Merged
Changes from all commits
Commits
29 commits
5d6c5e1
llavas
zucchini-nlp Jan 15, 2025
b500dcf
add more models
zucchini-nlp Jan 16, 2025
b56b40c
fix `compile_forward` test for all models
zucchini-nlp Jan 16, 2025
040a83c
fix copies
zucchini-nlp Jan 16, 2025
4c8e6ab
make style
zucchini-nlp Jan 16, 2025
b72d845
also doesn't support cache class
zucchini-nlp Jan 16, 2025
70a0510
fix some tests
zucchini-nlp Jan 16, 2025
17b0c8f
not copied from
zucchini-nlp Jan 16, 2025
8ddee32
ci green?
zucchini-nlp Jan 17, 2025
370c9d2
fix tests
zucchini-nlp Jan 17, 2025
91d268d
Merge remote-tracking branch 'upstream/main' into compile-llava-enable
zucchini-nlp Jan 30, 2025
fcc6454
fix copies
zucchini-nlp Jan 30, 2025
41b50d8
fix tests
zucchini-nlp Jan 30, 2025
2b602ba
check with `numel` and remove `item`
zucchini-nlp Feb 10, 2025
4a3ff89
merge main
zucchini-nlp Feb 10, 2025
4e9cd52
fix copies
zucchini-nlp Feb 10, 2025
1776f0f
fix copies
zucchini-nlp Feb 10, 2025
e906616
Merge remote-tracking branch 'upstream/main' into compile-llava-enable
zucchini-nlp Feb 10, 2025
2232f62
Update src/transformers/models/cohere2/modeling_cohere2.py
zucchini-nlp Feb 13, 2025
f84242e
merge main
zucchini-nlp Feb 13, 2025
e089e34
opt remove cross attn
zucchini-nlp Feb 13, 2025
210bb5f
gemma2
zucchini-nlp Feb 13, 2025
7271490
fixup
zucchini-nlp Feb 13, 2025
496fc05
Merge branch 'main' into compile-llava-enable
zucchini-nlp Feb 13, 2025
45ad329
fixup
zucchini-nlp Feb 13, 2025
7a79bac
Merge branch 'main' into compile-llava-enable
zucchini-nlp Feb 14, 2025
2f219eb
fix newly added test
zucchini-nlp Feb 14, 2025
0cf1cfe
maybe fixed?
zucchini-nlp Feb 14, 2025
eccc5fa
green please?
zucchini-nlp Feb 14, 2025
3 changes: 3 additions & 0 deletions src/transformers/models/blip_2/modeling_blip_2.py
Expand Up @@ -2016,6 +2016,9 @@ def forward(
class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
config_class = Blip2Config
main_input_name = "pixel_values"
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False # not all LM backbones support it (e.g. T5)

def __init__(self, config: Blip2Config):
super().__init__(config)
Expand Down
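For context, a minimal sketch of what the new `_supports_static_cache = True` flag enables for BLIP-2 (the same flags are added to InstructBLIP and InstructBLIP-Video further down): generation with a pre-allocated KV cache, the prerequisite for compiling the decode step. The checkpoint name and the optional `torch.compile` call are illustrative, not part of this PR.

```python
# Minimal sketch, assuming the public transformers generation API.
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration

model_id = "Salesforce/blip2-opt-2.7b"  # illustrative checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = Blip2ForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="Question: what is in the image? Answer:", return_tensors="pt").to(
    "cuda", torch.float16
)

# Optional (assumption: the usual compile recipe): with a static cache the shapes
# stay fixed, so the decode step is not recaptured at every token.
# model.forward = torch.compile(model.forward, mode="reduce-overhead")

# `cache_implementation="static"` pre-allocates the KV cache to a fixed length,
# which is what `_supports_static_cache = True` advertises.
out = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(processor.batch_decode(out, skip_special_tokens=True)[0])
```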
8 changes: 4 additions & 4 deletions src/transformers/models/chameleon/modeling_chameleon.py
Expand Up @@ -1284,13 +1284,13 @@ def forward(

if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values)
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
if n_image_tokens_in_text != n_image_features:
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_tokens.numel():
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)

Expand Down
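The Chameleon change above is the recurring pattern in this PR: drop `.item()` (which forces a host sync and a graph break) and only run the data-dependent shape check in eager mode. A standalone sketch of the pattern, with illustrative names rather than the exact Chameleon code:

```python
# Compile-friendly validation + masked_scatter pattern.
# `is_torchdynamo_compiling` stands in for the transformers utility of the same name.
import torch

def is_torchdynamo_compiling():
    return torch.compiler.is_compiling() if hasattr(torch.compiler, "is_compiling") else False

def replace_image_tokens(input_ids, image_tokens, image_token_id):
    special_image_mask = input_ids == image_token_id
    # Eager-only sanity check: comparing numel() keeps everything on-device,
    # whereas `.sum().item()` would synchronize with the host and break fullgraph compilation.
    if not is_torchdynamo_compiling() and special_image_mask.sum() != image_tokens.numel():
        raise ValueError(
            f"Image tokens in text ({int(special_image_mask.sum())}) do not match "
            f"the provided image tokens ({image_tokens.numel()})"
        )
    image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
    return input_ids.masked_scatter(special_image_mask, image_tokens)

# Toy usage: token id 9 marks image placeholders.
ids = torch.tensor([[1, 9, 9, 2]])
img = torch.tensor([[101, 102]])
print(replace_image_tokens(ids, img, image_token_id=9))  # tensor([[  1, 101, 102,   2]])
```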
4 changes: 2 additions & 2 deletions src/transformers/models/cohere2/modeling_cohere2.py
Expand Up @@ -25,7 +25,7 @@
import torch.nn as nn

from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
Expand Down Expand Up @@ -701,7 +701,7 @@ def _update_causal_mask(

dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if isinstance(past_key_values, HybridCache):
if isinstance(past_key_values, (HybridCache, StaticCache)):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
Expand Down
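Cohere2 here (and Gemma2 just below) previously only expanded the causal mask to the full cache length for `HybridCache`; the added `StaticCache` branch gives `torch.compile` a fixed mask shape as well. A minimal sketch of the branch, assuming the public `cache_utils` classes:

```python
# Minimal sketch of the target-length selection for the causal mask.
from transformers.cache_utils import HybridCache, StaticCache

def mask_target_length(past_key_values, attention_mask, input_tensor):
    if isinstance(past_key_values, (HybridCache, StaticCache)):
        # Pre-allocated caches expose a fixed maximum length, so the mask shape
        # never changes between decode steps -- exactly what torch.compile needs.
        return past_key_values.get_max_cache_shape()
    # Dynamic caches: fall back to whatever the 2D attention mask covers.
    return attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
```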
4 changes: 2 additions & 2 deletions src/transformers/models/gemma2/modeling_gemma2.py
Expand Up @@ -25,7 +25,7 @@
import torch.nn as nn

from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
Expand Down Expand Up @@ -713,7 +713,7 @@ def _update_causal_mask(

dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if isinstance(past_key_values, HybridCache):
if isinstance(past_key_values, (HybridCache, StaticCache)):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/gemma2/modular_gemma2.py
Expand Up @@ -20,7 +20,7 @@
import torch.utils.checkpoint

from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...cache_utils import Cache, HybridCache, StaticCache
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
Expand Down Expand Up @@ -550,7 +550,7 @@ def _update_causal_mask(

dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if isinstance(past_key_values, HybridCache):
if isinstance(past_key_values, (HybridCache, StaticCache)):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
Expand Down
4 changes: 0 additions & 4 deletions src/transformers/models/got_ocr2/configuration_got_ocr2.py
Expand Up @@ -132,8 +132,6 @@ class GotOcr2Config(PretrainedConfig):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 151859):
The image token index to encode the image prompt.
image_seq_length (`int`, *optional*, defaults to 576):
Expand Down Expand Up @@ -161,13 +159,11 @@ def __init__(
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=151859,
image_seq_length=576,
pad_token_id=-1,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.image_seq_length = image_seq_length
self.pad_token_id = pad_token_id
Expand Down
85 changes: 2 additions & 83 deletions src/transformers/models/got_ocr2/modeling_got_ocr2.py
Expand Up @@ -594,6 +594,8 @@ class GotOcr2PreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True

def _init_weights(self, module):
# important: this ported version of GotOcr2 isn't meant for training from scratch - only
Expand Down Expand Up @@ -748,89 +750,6 @@ def get_image_features(
image_outputs = self.vision_tower(pixel_values).last_hidden_state
return self.multi_modal_projector(image_outputs)

def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
num_images, num_image_patches, embed_dim = image_features.shape
batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
# 1. Create a mask to know where special image tokens are
special_image_token_mask = input_ids == self.config.image_token_index
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
# Compute the maximum embed dimension
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged image-text sequence.
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
if left_padding:
new_token_positions += nb_image_pad[:, None] # offset for left padding
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
final_attention_mask = torch.zeros(
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
)
if labels is not None:
final_labels = torch.full(
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_image_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_image_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)

# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
image_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
if left_padding:
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
else:
mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
image_to_overwrite &= padding_mask

if image_to_overwrite.sum() != image_features.shape[:-1].numel():
raise ValueError(
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
)

final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
indices_to_mask = new_token_positions[batch_indices, pad_indices]

final_embedding[batch_indices, indices_to_mask] = 0

if labels is None:
final_labels = None

return final_embedding, final_attention_mask, final_labels, position_ids

@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
Expand Down
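The ~80-line `_merge_input_ids_with_image_features` helper (and with it the `ignore_index` config attribute) is removed. The forward path, not shown in this hunk, can then merge vision features into placeholder `<image>` positions with a masked scatter, the pattern used by the other LLaVA-style models in this PR. A sketch of that merge with illustrative names, not the exact GotOcr2 code:

```python
# Placeholder-based merge that replaces the removed index-arithmetic helper.
import torch

def merge_image_features(inputs_embeds, input_ids, image_features, image_token_index):
    # Each image placeholder in the prompt already occupies one position per
    # image patch, so merging is a masked scatter rather than re-indexing the
    # whole sequence (which is what the removed helper did).
    special_image_mask = (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
    image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
    return inputs_embeds.masked_scatter(special_image_mask, image_features)

# Toy usage: sequence of 5 tokens, 2 image placeholders (id 32), hidden size 4.
embeds = torch.zeros(1, 5, 4)
ids = torch.tensor([[1, 32, 32, 7, 2]])
feats = torch.arange(8, dtype=torch.float).reshape(1, 2, 4)
print(merge_image_features(embeds, ids, feats, image_token_index=32)[0, 1])  # tensor([0., 1., 2., 3.])
```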
4 changes: 0 additions & 4 deletions src/transformers/models/got_ocr2/modular_got_ocr2.py
Expand Up @@ -170,8 +170,6 @@ class GotOcr2Config(PretrainedConfig):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 151859):
The image token index to encode the image prompt.
image_seq_length (`int`, *optional*, defaults to 576):
Expand Down Expand Up @@ -199,13 +197,11 @@ def __init__(
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=151859,
image_seq_length=576,
pad_token_id=-1,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.image_seq_length = image_seq_length
self.pad_token_id = pad_token_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = False # TODO (fix me): compilation fails due to a stride error?
_supports_static_cache = True

def _init_weights(self, module):
"""Initialize the weights"""
Expand Down Expand Up @@ -129,8 +129,8 @@ def forward(

cos, sin = position_embeddings
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1)
query = torch.cat((query, query_pass), dim=-1).contiguous()
key = torch.cat((key, key_pass), dim=-1).contiguous()

# Cache QKV values
if layer_past is not None:
Expand Down
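The GPT-NeoX-Japanese change flips `_supports_static_cache` back on and adds `.contiguous()` after re-assembling the rotary and pass-through halves of the head dimension: slicing that dimension produces strided views, and forcing a contiguous layout gives the compiled static-cache update a predictable stride (the exact stride error from the old TODO is not reproduced here). A small illustration:

```python
# The rotary slices are non-contiguous views; the final .contiguous() makes the
# layout explicit before the tensors are written into the KV cache.
import torch

q = torch.randn(1, 8, 16, 64)                            # (batch, heads, seq, head_dim)
rotary_ndims = 16
q_rot, q_pass = q[..., :rotary_ndims], q[..., rotary_ndims:]
print(q_rot.is_contiguous(), q_pass.is_contiguous())     # False False -- last-dim slices are strided views

query = torch.cat((q_rot, q_pass), dim=-1).contiguous()
print(query.shape, query.is_contiguous())                 # torch.Size([1, 8, 16, 64]) True
```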
40 changes: 14 additions & 26 deletions src/transformers/models/granitemoe/modeling_granitemoe.py
Expand Up @@ -1108,6 +1108,7 @@ def forward(
router_logits=all_router_logits,
)

# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
Expand All @@ -1116,13 +1117,8 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

Expand All @@ -1143,7 +1139,6 @@ def _update_causal_mask(
return None

dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
Expand All @@ -1154,25 +1149,17 @@ def _update_causal_mask(
else past_seen_tokens + sequence_length + 1
)

if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
# In case the provided `attention_mask` is 2D, we generate the full 4D causal mask here.
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)

if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
Expand All @@ -1182,6 +1169,7 @@ def _update_causal_mask(
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

return causal_mask
Expand Down
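Two compile-oriented changes land in GraniteMoe: the flash-attention branch now tests `(attention_mask == 0.0).any()` instead of the Python-level `0.0 in attention_mask`, and the inline mask construction is delegated to the shared `_prepare_4d_causal_attention_mask_with_cache_position` helper. A self-contained sketch of what that helper computes, adapted from the inline code removed above (not the exact library implementation):

```python
# Sketch adapted from the removed inline code; not the exact library helper.
import torch

def prepare_4d_causal_mask_with_cache_position(
    attention_mask, sequence_length, target_length, dtype, device, cache_position, batch_size
):
    if attention_mask is not None and attention_mask.dim() == 4:
        # Already a 4D (inverted) mask: pass it through untouched.
        return attention_mask
    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    # Slots beyond the current cache position stay masked; earlier slots are visible.
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # contiguous copy for the in-place edit below
        mask_length = attention_mask.shape[-1]
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask == 0, min_dtype
        )
    return causal_mask

# Toy usage: 3 prompt tokens written into a static cache of length 8.
mask = prepare_4d_causal_mask_with_cache_position(
    attention_mask=torch.ones(1, 3), sequence_length=3, target_length=8,
    dtype=torch.float32, device="cpu", cache_position=torch.arange(3), batch_size=1,
)
print(mask.shape)  # torch.Size([1, 1, 3, 8])
```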
3 changes: 3 additions & 0 deletions src/transformers/models/instructblip/modeling_instructblip.py
Expand Up @@ -1290,6 +1290,9 @@ def forward(
class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
config_class = InstructBlipConfig
main_input_name = "pixel_values"
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False # not all LM backbones support it (e.g. T5)

def __init__(self, config: InstructBlipConfig):
super().__init__(config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1284,6 +1284,9 @@ def forward(
class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
config_class = InstructBlipVideoConfig
main_input_name = "pixel_values"
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False # not all LM backbones support it (e.g. T5)

def __init__(self, config: InstructBlipVideoConfig):
super().__init__(config)
Expand Down
4 changes: 0 additions & 4 deletions src/transformers/models/llava/configuration_llava.py
Expand Up @@ -37,8 +37,6 @@ class LlavaConfig(PretrainedConfig):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 32000):
The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
Expand Down Expand Up @@ -83,7 +81,6 @@ def __init__(
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=32000,
projector_hidden_act="gelu",
vision_feature_select_strategy="default",
Expand All @@ -92,7 +89,6 @@ def __init__(
multimodal_projector_bias=True,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.image_seq_length = image_seq_length
Expand Down