@@ -594,6 +594,8 @@ class GotOcr2PreTrainedModel(PreTrainedModel):
     _supports_cache_class = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_quantized_cache = True
+    _supports_static_cache = True

     def _init_weights(self, module):
         # important: this ported version of GotOcr2 isn't meant for training from scratch - only
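For context: these two flags advertise that the model can be driven with the quantized and static KV-cache backends that `generate()` dispatches via its `cache_implementation` switch. A minimal sketch of what they unlock, assuming a hypothetical GOT-OCR2 checkpoint id (neither the id nor this usage appears in the diff itself):

```python
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

ckpt = "stepfun-ai/GOT-OCR2_0"  # hypothetical checkpoint id, for illustration only
processor = AutoProcessor.from_pretrained(ckpt)
model = AutoModelForImageTextToText.from_pretrained(ckpt)

inputs = processor(Image.open("ocr_sample.png"), return_tensors="pt")

# _supports_static_cache: generate() may pre-allocate a fixed-size KV cache,
# the prerequisite for torch.compile-friendly decoding.
out = model.generate(**inputs, max_new_tokens=64, cache_implementation="static")

# _supports_quantized_cache: the cache can instead be stored quantized to cut
# memory on long outputs (requires a quantization backend such as quanto).
out = model.generate(**inputs, max_new_tokens=64, cache_implementation="quantized")
```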
@@ -748,89 +750,6 @@ def get_image_features(
         image_outputs = self.vision_tower(pixel_values).last_hidden_state
         return self.multi_modal_projector(image_outputs)

-    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
-        num_images, num_image_patches, embed_dim = image_features.shape
-        batch_size, sequence_length = input_ids.shape
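-        # If no sequence in the batch ends with a pad token, assume the batch is left-padded.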
-        left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
-        # 1. Create a mask to know where special image tokens are
-        special_image_token_mask = input_ids == self.config.image_token_index
-        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
-        # Compute the maximum embed dimension
-        max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
-        batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
-
-        # 2. Compute the positions where text should be written
-        # Calculate new positions for text tokens in the merged image-text sequence.
-        # `special_image_token_mask` identifies image tokens. Each image token is expanded into `num_image_patches - 1` additional slots for its features.
-        # `torch.cumsum` computes how each image token shifts subsequent text token positions.
-        # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
-        new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
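-        # e.g. ["hey", "<image>", "how"] with num_image_patches = 576: widths are [1, 576, 1],
-        # so cumsum - 1 yields [0, 576, 577], each original token's final slot in the merged sequence.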
-        nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
-        if left_padding:
-            new_token_positions += nb_image_pad[:, None]  # offset for left padding
-        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
-
-        # 3. Create the full embedding, already padded to the maximum position
-        final_embedding = torch.zeros(
-            batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
-        )
-        final_attention_mask = torch.zeros(
-            batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
-        )
-        if labels is not None:
-            final_labels = torch.full(
-                (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
-            )
-        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
-        # set the corresponding tensors into their correct target device.
-        target_device = inputs_embeds.device
-        batch_indices, non_image_indices, text_to_overwrite = (
-            batch_indices.to(target_device),
-            non_image_indices.to(target_device),
-            text_to_overwrite.to(target_device),
-        )
-        attention_mask = attention_mask.to(target_device)
-
-        # 4. Fill the embeddings based on the mask. If we have ["hey", "<image>", "how", "are"]
-        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
-        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
-        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
-        if labels is not None:
-            final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
-
-        # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
-        image_to_overwrite = torch.full(
-            (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
-        )
-        image_to_overwrite[batch_indices, text_to_overwrite] = False
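-        # Prune candidate image slots that are really padding: under left padding, drop the
-        # first `nb_image_pad` candidates per row; under right padding, drop every slot past
-        # the row's final text position.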
-        if left_padding:
-            image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
-        else:
-            mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
-            padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
-            image_to_overwrite &= padding_mask
-
-        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
-            raise ValueError(
-                f"The inputs provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
-                f" the number of images given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
-            )
-
-        final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
-        final_attention_mask |= image_to_overwrite
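-        # position_ids count only attended slots; fully-masked padding slots get dummy position 1.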
-        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
-
-        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
-        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
-        indices_to_mask = new_token_positions[batch_indices, pad_indices]
-
-        final_embedding[batch_indices, indices_to_mask] = 0
-
-        if labels is None:
-            final_labels = None
-
-        return final_embedding, final_attention_mask, final_labels, position_ids
-
     @add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
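For reference, the scatter logic of the deleted `_merge_input_ids_with_image_features` can be re-enacted on toy tensors. This sketch uses made-up shapes and token ids (nothing here comes from the model's config) to show how the cumsum positions route text embeddings and image patches into the merged sequence:

```python
import torch

batch_size, seq_len, patches, dim = 1, 4, 3, 2
image_token_id = 99
input_ids = torch.tensor([[10, 99, 11, 12]])      # "hey <image> how are"
inputs_embeds = torch.randn(batch_size, seq_len, dim)
image_features = torch.randn(1, patches, dim)     # one image -> 3 patch embeddings

special = input_ids == image_token_id
widths = special * (patches - 1) + 1              # [[1, 3, 1, 1]]: image token widens to 3 slots
new_pos = torch.cumsum(widths, -1) - 1            # [[0, 3, 4, 5]]: final slot of each token
max_len = seq_len + int(special.sum()) * (patches - 1)

# Scatter text embeddings to their shifted positions ...
final = torch.zeros(batch_size, max_len, dim)
b_idx, t_idx = torch.where(~special)
final[b_idx, new_pos[b_idx, t_idx]] = inputs_embeds[b_idx, t_idx]

# ... and fill every remaining slot with image patch features.
image_slots = torch.ones(batch_size, max_len, dtype=torch.bool)
image_slots[b_idx, new_pos[b_idx, t_idx]] = False
final[image_slots] = image_features.reshape(-1, dim)

print(image_slots)  # tensor([[False, True, True, True, False, False]])
```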