
Commit 0c78ef6

🔴 VLM: compile compatibility (#35724)
* llavas
* add more models
* fix `compile_forward` test for all models
* fix copies
* make style
* also doesn't support cache class
* fix some tests
* not copied from
* ci green?
* fix tests
* fix copies
* fix tests
* check with `numel` and remove `item`
* fix copies
* fix copies
* Update src/transformers/models/cohere2/modeling_cohere2.py
  Co-authored-by: Arthur <[email protected]>
* opt remove cross attn
* gemma2
* fixup
* fixup
* fix newly added test
* maybe fixed?
* green please?

Co-authored-by: Arthur <[email protected]>
Parent: b45cf0e · Commit: 0c78ef6


44 files changed: +461, -1212 lines
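The thread connecting the diffs below is making these VLMs usable under `torch.compile` with a static KV cache. As a hedged sketch of the workflow this enables (the checkpoint name, dtype, and generation settings are illustrative, not taken from the commit):

import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Example checkpoint only; any VLM touched by this commit should follow the same pattern.
model_id = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# Compile the forward pass; a static cache keeps decode shapes fixed so compiled graphs are reused.
model.forward = torch.compile(model.forward, mode="reduce-overhead")

image = Image.new("RGB", (336, 336))  # placeholder image
prompt = "USER: <image>\nDescribe the image. ASSISTANT:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda", torch.float16)
out = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(processor.batch_decode(out, skip_special_tokens=True)[0])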

src/transformers/models/blip_2/modeling_blip_2.py

Lines changed: 3 additions & 0 deletions
@@ -2016,6 +2016,9 @@ def forward(
 class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
     config_class = Blip2Config
     main_input_name = "pixel_values"
+    _supports_cache_class = True
+    _supports_static_cache = True
+    _supports_quantized_cache = False  # not all LM bacbones support (e.g. T5)
 
     def __init__(self, config: Blip2Config):
         super().__init__(config)
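These class-level flags advertise which cache types the model can take; quantized caches stay disabled because not every BLIP-2 language backbone supports them (the comment cites T5). A rough, hedged sketch of how such flags can be consulted (the real selection logic lives inside `generate` and the common tests, not in a helper like this):

# Simplified sketch, not actual transformers logic: choose a cache implementation from
# the capability flags declared on the class in the diff above.
from transformers import Blip2ForConditionalGeneration

def pick_cache_implementation(model_cls):
    if getattr(model_cls, "_supports_static_cache", False):
        return "static"  # fixed shapes, the setting torch.compile needs
    if getattr(model_cls, "_supports_quantized_cache", False):
        return "quantized"
    return None  # fall back to the default dynamic cache

print(pick_cache_implementation(Blip2ForConditionalGeneration))  # "static" after this commit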

src/transformers/models/chameleon/modeling_chameleon.py

Lines changed: 4 additions & 4 deletions
@@ -1284,13 +1284,13 @@ def forward(
 
         if pixel_values is not None:
             image_tokens = self.get_image_tokens(pixel_values)
-            n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
-            n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
-            if n_image_tokens_in_text != n_image_features:
+            special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
+            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_tokens.numel():
+                n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
+                n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
                 )
-            special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
             image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
             input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
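The shape check now runs only outside of tracing and avoids `.item()`, so `torch.compile` never hits a data-dependent Python branch. A standalone sketch of the pattern (simplified, not the Chameleon code; `is_torchdynamo_compiling` is the helper the diff relies on, importable from `transformers.utils` in recent releases):

# Illustration only: validation that needs a Python bool is skipped while tracing,
# and masked_scatter handles the token replacement without dynamic shapes.
import torch
from transformers.utils import is_torchdynamo_compiling

def place_image_tokens(input_ids, image_tokens, image_token_id):
    special_image_mask = input_ids == image_token_id
    if not is_torchdynamo_compiling() and special_image_mask.sum() != image_tokens.numel():
        raise ValueError("number of image placeholder tokens does not match the image features")
    return input_ids.masked_scatter(special_image_mask, image_tokens.to(input_ids.device, input_ids.dtype))

ids = torch.tensor([[101, 7, 7, 102]])     # 7 stands in for the image token id
img_tokens = torch.tensor([[5001, 5002]])  # two "image" tokens to splice in
print(place_image_tokens(ids, img_tokens, image_token_id=7))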

src/transformers/models/cohere2/modeling_cohere2.py

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@
 import torch.nn as nn
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, HybridCache
+from ...cache_utils import Cache, HybridCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -701,7 +701,7 @@ def _update_causal_mask(
 
         dtype, device = input_tensor.dtype, input_tensor.device
         sequence_length = input_tensor.shape[1]
-        if isinstance(past_key_values, HybridCache):
+        if isinstance(past_key_values, (HybridCache, StaticCache)):
             target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
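The identical two-line change shows up again below in Gemma2's modeling and modular files. What the touched branch amounts to, as a hedged standalone sketch (the cache classes and `get_max_cache_shape()` come from `transformers.cache_utils` in recent releases; the surrounding function is simplified):

# Simplified sketch of the branch: fixed-size caches (HybridCache, StaticCache) expose
# their preallocated length, which keeps the causal-mask width static under torch.compile;
# otherwise the width follows the (dynamic) attention mask.
from transformers.cache_utils import HybridCache, StaticCache

def mask_target_length(past_key_values, attention_mask, sequence_length):
    if isinstance(past_key_values, (HybridCache, StaticCache)):
        return past_key_values.get_max_cache_shape()
    return attention_mask.shape[-1] if attention_mask is not None else sequence_length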

src/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@
 import torch.nn as nn
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, HybridCache
+from ...cache_utils import Cache, HybridCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
@@ -713,7 +713,7 @@ def _update_causal_mask(
 
         dtype, device = input_tensor.dtype, input_tensor.device
         sequence_length = input_tensor.shape[1]
-        if isinstance(past_key_values, HybridCache):
+        if isinstance(past_key_values, (HybridCache, StaticCache)):
             target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]

src/transformers/models/gemma2/modular_gemma2.py

Lines changed: 2 additions & 2 deletions
@@ -20,7 +20,7 @@
 import torch.utils.checkpoint
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, HybridCache
+from ...cache_utils import Cache, HybridCache, StaticCache
 from ...configuration_utils import PretrainedConfig
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
@@ -550,7 +550,7 @@ def _update_causal_mask(
 
         dtype, device = input_tensor.dtype, input_tensor.device
         sequence_length = input_tensor.shape[1]
-        if isinstance(past_key_values, HybridCache):
+        if isinstance(past_key_values, (HybridCache, StaticCache)):
             target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]

src/transformers/models/got_ocr2/configuration_got_ocr2.py

Lines changed: 0 additions & 4 deletions
@@ -132,8 +132,6 @@ class GotOcr2Config(PretrainedConfig):
             The config object or dictionary of the vision backbone.
         text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
             The config object or dictionary of the text backbone.
-        ignore_index (`int`, *optional*, defaults to -100):
-            The ignore index for the loss function.
         image_token_index (`int`, *optional*, defaults to 151859):
             The image token index to encode the image prompt.
         image_seq_length (`int`, *optional*, defaults to 576):
@@ -161,13 +159,11 @@ def __init__(
         self,
         vision_config=None,
         text_config=None,
-        ignore_index=-100,
         image_token_index=151859,
         image_seq_length=576,
         pad_token_id=-1,
         **kwargs,
     ):
-        self.ignore_index = ignore_index
         self.image_token_index = image_token_index
         self.image_seq_length = image_seq_length
         self.pad_token_id = pad_token_id

src/transformers/models/got_ocr2/modeling_got_ocr2.py

Lines changed: 2 additions & 83 deletions
@@ -594,6 +594,8 @@ class GotOcr2PreTrainedModel(PreTrainedModel):
     _supports_cache_class = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_quantized_cache = True
+    _supports_static_cache = True
 
     def _init_weights(self, module):
         # important: this ported version of GotOcr2 isn't meant for training from scratch - only
@@ -748,89 +750,6 @@ def get_image_features(
         image_outputs = self.vision_tower(pixel_values).last_hidden_state
         return self.multi_modal_projector(image_outputs)
 
-    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
-        num_images, num_image_patches, embed_dim = image_features.shape
-        batch_size, sequence_length = input_ids.shape
-        left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
-        # 1. Create a mask to know where special image tokens are
-        special_image_token_mask = input_ids == self.config.image_token_index
-        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
-        # Compute the maximum embed dimension
-        max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
-        batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
-
-        # 2. Compute the positions where text should be written
-        # Calculate new positions for text tokens in merged image-text sequence.
-        # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
-        # `torch.cumsum` computes how each image token shifts subsequent text token positions.
-        # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
-        new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
-        nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
-        if left_padding:
-            new_token_positions += nb_image_pad[:, None]  # offset for left padding
-        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
-
-        # 3. Create the full embedding, already padded to the maximum position
-        final_embedding = torch.zeros(
-            batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
-        )
-        final_attention_mask = torch.zeros(
-            batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
-        )
-        if labels is not None:
-            final_labels = torch.full(
-                (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
-            )
-        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
-        # set the corresponding tensors into their correct target device.
-        target_device = inputs_embeds.device
-        batch_indices, non_image_indices, text_to_overwrite = (
-            batch_indices.to(target_device),
-            non_image_indices.to(target_device),
-            text_to_overwrite.to(target_device),
-        )
-        attention_mask = attention_mask.to(target_device)
-
-        # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
-        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
-        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
-        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
-        if labels is not None:
-            final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
-
-        # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
-        image_to_overwrite = torch.full(
-            (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
-        )
-        image_to_overwrite[batch_indices, text_to_overwrite] = False
-        if left_padding:
-            image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
-        else:
-            mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
-            padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
-            image_to_overwrite &= padding_mask
-
-        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
-            raise ValueError(
-                f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
-                f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
-            )
-
-        final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
-        final_attention_mask |= image_to_overwrite
-        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
-
-        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
-        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
-        indices_to_mask = new_token_positions[batch_indices, pad_indices]
-
-        final_embedding[batch_indices, indices_to_mask] = 0
-
-        if labels is None:
-            final_labels = None
-
-        return final_embedding, final_attention_mask, final_labels, position_ids
-
     @add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
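The bulk of the deletion is the legacy `_merge_input_ids_with_image_features` helper, which rebuilt the whole sequence at a data-dependent length. The placeholder-token path the model keeps instead is closer to the following hedged sketch (simplified, with illustrative names, not the exact GotOcr2 code): the prompt already carries one image token per visual patch, so features are scattered into the embeddings in place and no tensor shape depends on the data.

import torch

def scatter_image_features(inputs_embeds, input_ids, image_features, image_token_index):
    # one image placeholder token per image patch is assumed to already be in input_ids
    mask = (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
    return inputs_embeds.masked_scatter(mask, image_features.to(inputs_embeds.dtype))

embeds = torch.zeros(1, 4, 8)         # (batch, seq, hidden)
ids = torch.tensor([[1, 99, 99, 2]])  # 99 stands in for image_token_index
feats = torch.randn(2, 8)             # two image "patch" features
merged = scatter_image_features(embeds, ids, feats, image_token_index=99)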

src/transformers/models/got_ocr2/modular_got_ocr2.py

Lines changed: 0 additions & 4 deletions
@@ -170,8 +170,6 @@ class GotOcr2Config(PretrainedConfig):
             The config object or dictionary of the vision backbone.
         text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
             The config object or dictionary of the text backbone.
-        ignore_index (`int`, *optional*, defaults to -100):
-            The ignore index for the loss function.
         image_token_index (`int`, *optional*, defaults to 151859):
             The image token index to encode the image prompt.
         image_seq_length (`int`, *optional*, defaults to 576):
@@ -199,13 +197,11 @@ def __init__(
         self,
         vision_config=None,
         text_config=None,
-        ignore_index=-100,
         image_token_index=151859,
         image_seq_length=576,
         pad_token_id=-1,
         **kwargs,
     ):
-        self.ignore_index = ignore_index
         self.image_token_index = image_token_index
         self.image_seq_length = image_seq_length
         self.pad_token_id = pad_token_id

src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py

Lines changed: 3 additions & 3 deletions
@@ -51,7 +51,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_cache_class = True
     _supports_quantized_cache = True
-    _supports_static_cache = False  # TODO (fix me): compilation fails due to a stide error?
+    _supports_static_cache = True
 
     def _init_weights(self, module):
         """Initialize the weights"""
@@ -129,8 +129,8 @@ def forward(
 
         cos, sin = position_embeddings
         query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
-        query = torch.cat((query, query_pass), dim=-1)
-        key = torch.cat((key, key_pass), dim=-1)
+        query = torch.cat((query, query_pass), dim=-1).contiguous()
+        key = torch.cat((key, key_pass), dim=-1).contiguous()
 
         # Cache QKV values
         if layer_past is not None:
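The removed TODO blamed a stride error for the earlier `_supports_static_cache = False`; the fix forces `query`/`key` into a plain memory layout before they are written into the cache. A loose, standalone illustration of why a contiguous layout matters when filling a preallocated cache (not the model code; shapes are invented):

import torch

key = torch.randn(1, 8, 64, 4).transpose(2, 3)  # non-contiguous view, e.g. after a permute
print(key.is_contiguous())                      # False
key = key.contiguous()                          # now has standard strides

static_key_cache = torch.zeros(1, 8, 32, 64)    # preallocated to the max cache length
static_key_cache[:, :, :4].copy_(key)           # copy into the static cache slot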

src/transformers/models/granitemoe/modeling_granitemoe.py

Lines changed: 14 additions & 26 deletions
@@ -1108,6 +1108,7 @@ def forward(
             router_logits=all_router_logits,
         )
 
+    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
     def _update_causal_mask(
         self,
         attention_mask: torch.Tensor,
@@ -1116,13 +1117,8 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
         if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and 0.0 in attention_mask:
+            if attention_mask is not None and (attention_mask == 0.0).any():
                 return attention_mask
             return None
 
@@ -1143,7 +1139,6 @@ def _update_causal_mask(
             return None
 
         dtype, device = input_tensor.dtype, input_tensor.device
-        min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
@@ -1154,25 +1149,17 @@ def _update_causal_mask(
             else past_seen_tokens + sequence_length + 1
         )
 
-        if attention_mask is not None and attention_mask.dim() == 4:
-            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
-            causal_mask = attention_mask
-        else:
-            causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
-            )
-            if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
-            if attention_mask is not None:
-                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-                mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
-                padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    padding_mask, min_dtype
-                )
+        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            device=device,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+        )
+
         if (
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
@@ -1182,6 +1169,7 @@ def _update_causal_mask(
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
             # Details: https://github.com/pytorch/pytorch/issues/110213
+            min_dtype = torch.finfo(dtype).min
             causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
 
         return causal_mask
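Two compile-relevant details in this hunk: the Python-level membership test `0.0 in attention_mask` is replaced by a tensor-native check, and the hand-rolled 4D mask construction is folded into `_prepare_4d_causal_attention_mask_with_cache_position`, matching the Llama implementation the new `# Copied from` comment points to. A standalone sketch of the first point (illustrative only, not the model code):

# The tensor-native check stays inside the compiled graph, whereas `0.0 in mask` has to
# materialize a Python bool from tensor data and typically breaks fullgraph tracing.
import torch

@torch.compile(fullgraph=True)
def has_padding(mask: torch.Tensor) -> torch.Tensor:
    return (mask == 0.0).any()

print(has_padding(torch.tensor([[1.0, 1.0, 0.0]])))  # tensor(True)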
