From d115b6a5c379e5cf4a9ef23fd28deac7b30ce56a Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 5 Nov 2024 19:54:42 +0100 Subject: [PATCH 01/11] update --- .../models/cohere/modeling_cohere.py | 44 ++---------- .../models/falcon/modeling_falcon.py | 55 ++------------- src/transformers/models/glm/modeling_glm.py | 8 +-- .../models/gpt_neox/modeling_gpt_neox.py | 44 ++---------- .../modeling_gpt_neox_japanese.py | 15 +--- .../models/granite/modeling_granite.py | 8 +-- .../models/granitemoe/modeling_granitemoe.py | 8 +-- .../models/idefics2/modeling_idefics2.py | 1 - .../models/llama/modeling_llama.py | 69 ++----------------- .../models/mistral/modeling_mistral.py | 37 +++++----- .../models/mllama/modeling_mllama.py | 2 +- .../models/nemotron/modeling_nemotron.py | 2 +- .../models/persimmon/modeling_persimmon.py | 42 +---------- src/transformers/models/phi/modeling_phi.py | 68 ++---------------- .../models/qwen2/modeling_qwen2.py | 41 ++--------- .../models/qwen2_moe/modeling_qwen2_moe.py | 42 ++--------- .../models/qwen2_vl/modeling_qwen2_vl.py | 42 ++--------- .../models/stablelm/modeling_stablelm.py | 68 ++---------------- .../models/starcoder2/modeling_starcoder2.py | 41 ++--------- 19 files changed, 106 insertions(+), 531 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index d481d87e7ab8..25260fa5fd60 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -283,9 +283,6 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) - # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) - self.rotary_emb = CohereRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -295,7 +292,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -314,16 +311,7 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -389,7 +377,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): @@ -415,16 +403,7 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -502,7 +481,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -536,16 +515,7 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -615,7 +585,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index faea670ecbf4..bd90c637c717 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -197,33 +197,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon -class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding): - """FalconRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, *args, **kwargs): - logger.warning_once( - "`FalconLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " - "`FalconRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." - ) - kwargs["rope_type"] = "linear" - super().__init__(*args, **kwargs) - - -# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon -class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding): - """FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, *args, **kwargs): - logger.warning_once( - "`FalconDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " - "`FalconRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " - "__init__)." - ) - kwargs["rope_type"] = "dynamic" - super().__init__(*args, **kwargs) - - def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: batch_size, seq_length = attention_mask.shape closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) @@ -388,7 +361,7 @@ def forward( use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads @@ -402,16 +375,7 @@ def forward( value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) if alibi is None: - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). 
In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_layer, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin) if layer_past is not None: @@ -548,7 +512,7 @@ def forward( use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads @@ -562,16 +526,7 @@ def forward( value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) if alibi is None: - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_layer, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin) if layer_past is not None: @@ -695,7 +650,7 @@ def forward( use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ): residual = hidden_states diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 16a724f69464..dc24d6982aec 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -227,7 +227,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -303,7 +303,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: output_attentions = False @@ -402,7 +402,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], 
Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -503,7 +503,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 3fdb814ebab5..fd11dc4ec0d0 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -311,7 +311,7 @@ def forward( output_attentions: Optional[bool] = False, padding_mask: Optional[torch.Tensor] = None, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): bsz, seq_len, _ = hidden_states.shape @@ -404,7 +404,7 @@ def _attn_projections_and_rope( layer_past: Optional[Tuple[torch.Tensor]] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): # Compute QKV # Attention heads [batch, seq_len, hidden_size] @@ -427,16 +427,7 @@ def _attn_projections_and_rope( key_rot = key[..., : self.rotary_ndims] key_pass = key[..., self.rotary_ndims :] - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) query = torch.cat((query, query_pass), dim=-1) key = torch.cat((key, key_pass), dim=-1) @@ -583,33 +574,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->GPTNeoX -class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): - """GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, *args, **kwargs): - logger.warning_once( - "`GPTNeoXLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " - "`GPTNeoXRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." - ) - kwargs["rope_type"] = "linear" - super().__init__(*args, **kwargs) - - -# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->GPTNeoX -class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): - """GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, *args, **kwargs): - logger.warning_once( - "`GPTNeoXDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " - "`GPTNeoXRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " - "__init__)." - ) - kwargs["rope_type"] = "dynamic" - super().__init__(*args, **kwargs) - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -688,7 +652,7 @@ def forward( layer_past: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): attention_layer_outputs = self.attention( self.input_layernorm(hidden_states), diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 6c3f3313f57f..2a315998c3ee 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -105,7 +105,7 @@ def forward( use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): # Compute QKV # Attention heads [batch, seq_len, hidden_size] @@ -128,16 +128,7 @@ def forward( key_rot = key[..., : self.rotary_ndims] key_pass = key[..., self.rotary_ndims :] - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) query = torch.cat((query, query_pass), dim=-1) key = torch.cat((key, key_pass), dim=-1) @@ -415,7 +406,7 @@ def forward( layer_past: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): residual = hidden_states ln_out = self.input_layernorm(hidden_states) diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 9cabd48a5102..c2661e6ebcc4 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -242,7 +242,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -318,7 +318,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: output_attentions = False @@ -417,7 +417,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -520,7 +520,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 4871fc3584fa..ff4a229dd236 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -458,7 +458,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -535,7 +535,7 @@ def forward( output_attentions: bool = 
False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: output_attentions = False @@ -635,7 +635,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -739,7 +739,7 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, output_router_logits: Optional[bool] = False, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 3d46c3bd82e7..b095abf0007d 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -867,7 +867,6 @@ class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 0408bb73c7f2..2b17f0af74fd 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -168,31 +168,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, *args, **kwargs): - logger.warning_once( - "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " - "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." - ) - kwargs["rope_type"] = "linear" - super().__init__(*args, **kwargs) - - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, *args, **kwargs): - logger.warning_once( - "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " - "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " - "__init__)." 
- ) - kwargs["rope_type"] = "dynamic" - super().__init__(*args, **kwargs) - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -284,9 +259,6 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) - self.rotary_emb = LlamaRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -296,7 +268,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -310,16 +282,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -382,7 +345,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): @@ -406,16 +369,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -497,7 +451,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -528,16 +482,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -607,7 +552,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index e94281b29fe1..a228fa659269 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -205,12 +205,6 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.rotary_emb = MistralRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - def forward( self, hidden_states: torch.Tensor, @@ -220,6 +214,8 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -231,7 +227,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -270,6 +266,7 @@ def forward( return attn_output, attn_weights, 
past_key_value +# copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Mistral class MistralFlashAttention2(MistralAttention): """ Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays @@ -277,7 +274,6 @@ class MistralFlashAttention2(MistralAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -295,6 +291,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): if isinstance(past_key_value, StaticCache): raise ValueError( @@ -314,11 +311,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += cache_position[0] - - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -381,7 +374,6 @@ def forward( # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral -# TODO(joao): add me back asap :) class MistralSdpaAttention(MistralAttention): """ Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -399,6 +391,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -427,7 +420,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -478,7 +471,6 @@ def forward( # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Mistral, LLAMA->MISTRAL -# TODO(joao): add me back asap :) class MistralDecoderLayer(nn.Module): def __init__(self, config: MistralConfig, layer_idx: int): super().__init__() @@ -499,6 +491,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -533,6 +526,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -697,6 +691,12 @@ def __init__(self, config: MistralConfig): self._attn_implementation = config._attn_implementation self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = MistralRotaryEmbedding( + config.head_dim, + 
max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -771,6 +771,9 @@ def forward( hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -790,6 +793,7 @@ def forward( output_attentions, use_cache, cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -800,6 +804,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 763ad97b1e72..a9e96e8dbc3f 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -858,7 +858,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index a56b5c68085c..e787a5e4f4e5 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -536,7 +536,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index cd580ab0dc0f..058c213e844f 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -143,33 +143,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Persimmon -class PersimmonLinearScalingRotaryEmbedding(PersimmonRotaryEmbedding): - """PersimmonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, *args, **kwargs): - logger.warning_once( - "`PersimmonLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " - "`PersimmonRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." 
- ) - kwargs["rope_type"] = "linear" - super().__init__(*args, **kwargs) - - -# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Persimmon -class PersimmonDynamicNTKScalingRotaryEmbedding(PersimmonRotaryEmbedding): - """PersimmonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, *args, **kwargs): - logger.warning_once( - "`PersimmonDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " - "`PersimmonRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " - "__init__)." - ) - kwargs["rope_type"] = "dynamic" - super().__init__(*args, **kwargs) - - # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -286,7 +259,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -305,16 +278,7 @@ def forward( value_states = value_states.transpose(1, 2) key_states = key_states.transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings # Partial rotary embedding query_rot, query_pass = ( @@ -390,7 +354,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 4613672ff274..72f1138a9294 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -147,33 +147,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi -class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding): - """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, *args, **kwargs): - logger.warning_once( - "`PhiLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " - "`PhiRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." 
- ) - kwargs["rope_type"] = "linear" - super().__init__(*args, **kwargs) - - -# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi -class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding): - """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, *args, **kwargs): - logger.warning_once( - "`PhiDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " - "`PhiRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " - "__init__)." - ) - kwargs["rope_type"] = "dynamic" - super().__init__(*args, **kwargs) - - # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -294,7 +267,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -310,16 +283,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings # Partial rotary embedding query_rot, query_pass = ( @@ -406,7 +370,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # PhiFlashAttention2 attention does not support output_attentions @@ -430,16 +394,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings # Partial rotary embedding query_rot, query_pass = ( @@ -542,7 +497,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -575,16 +530,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings # Partial rotary embedding query_rot, query_pass = ( @@ -671,7 +617,7 @@ def forward( use_cache: Optional[bool] = False, past_key_value: Optional[Tuple[torch.Tensor]] = None, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 36c5271c5c5e..04c7904289bb 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -284,7 +284,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -296,16 +296,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -370,7 +361,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): bsz, q_len, _ = hidden_states.size() @@ -382,16 +373,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -479,7 +461,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -506,16 +488,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -590,7 +563,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 6c5cbec2220e..a1a31986b5c5 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -368,7 +368,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -380,17 +380,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -457,7 +447,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): bsz, q_len, _ = hidden_states.size() @@ -469,16 +459,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -567,7 +548,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -594,16 +575,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -742,7 +714,7 @@ def forward( output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index f7648f4a53d1..34c1122aaa14 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -537,7 +537,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -549,16 +549,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_multimodal_rotary_pos_emb( query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] ) @@ -630,7 +621,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): bsz, q_len, _ = hidden_states.size() @@ -643,17 +634,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) # Because the input can be padded, the absolute sequence length depends on the max position id. - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - + cos, sin = position_embeddings query_states, key_states = apply_multimodal_rotary_pos_emb( query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] ) @@ -742,7 +723,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -770,16 +751,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_multimodal_rotary_pos_emb( query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] ) @@ -856,7 +828,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 004e4ff3f6c0..55ca0944c35a 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -149,33 +149,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->StableLm -class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding): - """StableLmRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, *args, **kwargs): - logger.warning_once( - "`StableLmLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " - "`StableLmRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." - ) - kwargs["rope_type"] = "linear" - super().__init__(*args, **kwargs) - - -# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->StableLm -class StableLmDynamicNTKScalingRotaryEmbedding(StableLmRotaryEmbedding): - """StableLmRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, *args, **kwargs): - logger.warning_once( - "`StableLmDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " - "`StableLmRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " - "__init__)." - ) - kwargs["rope_type"] = "dynamic" - super().__init__(*args, **kwargs) - - # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -307,7 +280,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -323,16 +296,7 @@ def forward( query_states = self.q_layernorm(query_states) key_states = self.k_layernorm(key_states) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
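The StableLm-specific linear and dynamic-NTK rotary subclasses deleted above are covered by the single rotary embedding class selected through `rope_type`. As a rough sketch of what the two options mean mathematically (an approximation for orientation, with hypothetical helper names, not the transformers implementation):

import torch

def linear_scaled_positions(position_ids, scaling_factor):
    # rope_type="linear": positions are divided by the factor before computing cos/sin.
    return position_ids.float() / scaling_factor

def dynamic_ntk_base(base, dim, seq_len, max_position_embeddings, scaling_factor):
    # rope_type="dynamic": once the sequence outgrows the training length, the rotary base
    # is enlarged, which lowers inv_freq instead of rescaling the positions themselves.
    if seq_len <= max_position_embeddings:
        return base
    return base * (
        (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))

# Example: positions halved under linear scaling; base raised for an 8192-token prompt
# when the model was trained with a 4096-token window.
print(linear_scaled_positions(torch.arange(8), scaling_factor=2.0))
print(dynamic_ntk_base(10000.0, dim=64, seq_len=8192, max_position_embeddings=4096, scaling_factor=2.0))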
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings # Partial rotary embedding query_rot, query_pass = ( @@ -403,7 +367,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -434,16 +398,7 @@ def forward( query_states = self.q_layernorm(query_states) key_states = self.k_layernorm(key_states) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings # Partial rotary embedding query_rot, query_pass = ( @@ -533,7 +488,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # StableLmFlashAttention2 attention does not support output_attentions @@ -557,16 +512,7 @@ def forward( query_states = self.q_layernorm(query_states) key_states = self.k_layernorm(key_states) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
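In these StableLm hunks only the first `rotary_ndims` channels of each head are rotated (the "Partial rotary embedding" step that follows `cos, sin = position_embeddings`). A generic sketch of that split/rotate/concat step, where `rotate_fn` stands in for the model's apply_rotary_pos_emb and the function name is illustrative:

import torch

def apply_partial_rotary(query_states, key_states, cos, sin, rotary_ndims, rotate_fn):
    # Rotate only the first `rotary_ndims` channels of each head; pass the rest through untouched.
    query_rot, query_pass = query_states[..., :rotary_ndims], query_states[..., rotary_ndims:]
    key_rot, key_pass = key_states[..., :rotary_ndims], key_states[..., rotary_ndims:]
    query_rot, key_rot = rotate_fn(query_rot, key_rot, cos, sin)
    query_states = torch.cat((query_rot, query_pass), dim=-1)
    key_states = torch.cat((key_rot, key_pass), dim=-1)
    return query_states, key_states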
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings # Partial rotary embedding query_rot, query_pass = ( @@ -650,7 +596,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 5ecffc8719be..552ea5d6154c 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -262,7 +262,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -274,16 +274,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -346,7 +337,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): bsz, q_len, _ = hidden_states.size() @@ -358,16 +349,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -446,7 +428,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -473,16 +455,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -555,7 +528,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ From 74a886f3eae2f0956eeb912d5f17daaf1e8e1164 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 21 Nov 2024 15:21:36 +0100 Subject: [PATCH 02/11] style --- .../modular-transformers/modeling_dummy.py | 44 +++---------------- .../modular-transformers/modeling_super.py | 44 +++---------------- .../models/starcoder2/modeling_starcoder2.py | 2 - .../models/starcoder2/modular_starcoder2.py | 41 +++-------------- 4 files changed, 20 insertions(+), 111 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 0b373d4e6eab..33ac61201916 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -228,9 +228,6 @@ def __init__(self, config: DummyConfig, layer_idx: Optional[int] = None): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) - self.rotary_emb = DummyRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -240,7 +237,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # 
will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -254,16 +251,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -326,7 +314,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): @@ -350,16 +338,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -441,7 +420,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -472,16 +451,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
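Every `cos, sin = position_embeddings` line above feeds straight into apply_rotary_pos_emb. For reference, a condensed version of those standard helpers (the canonical ones live in modeling_llama.py and are copied into each model; this is a shortened sketch, not the exact library source):

import torch

def rotate_half(x):
    # Rotates half the hidden dims of the input: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin are (batch, seq_len, head_dim); unsqueeze to broadcast over the heads axis.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed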
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -551,7 +521,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 7ad606280dcc..8b1d4d71e8d7 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -228,9 +228,6 @@ def __init__(self, config: SuperConfig, layer_idx: Optional[int] = None): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) - self.rotary_emb = SuperRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -240,7 +237,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -254,16 +251,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -326,7 +314,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): @@ -350,16 +338,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -441,7 +420,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -472,16 +451,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -551,7 +521,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 552ea5d6154c..bafce167003f 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -251,8 +251,6 @@ def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=self.use_bias) - self.rotary_emb = Starcoder2RotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index b5d74bf7feb3..9d515d123967 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -117,8 +117,6 @@ def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=self.use_bias) - self.rotary_emb = Starcoder2RotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -128,7 +126,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -140,16 +138,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -212,7 +201,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): bsz, q_len, _ = hidden_states.size() @@ -224,16 +213,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -312,7 +292,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -339,16 +319,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: From ea676309c0e2c59debd77a1acc171b3269edfc96 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 21 Nov 2024 16:09:05 +0100 Subject: [PATCH 03/11] fix missing args --- src/transformers/models/cohere/modeling_cohere.py | 1 + src/transformers/models/mistral/modeling_mistral.py | 1 + src/transformers/models/phi/modeling_phi.py | 2 ++ src/transformers/models/qwen2/modeling_qwen2.py | 2 ++ src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 2 ++ src/transformers/models/qwen2_vl/modeling_qwen2_vl.py | 1 + src/transformers/models/stablelm/modeling_stablelm.py | 2 ++ src/transformers/models/starcoder2/modeling_starcoder2.py | 2 ++ src/transformers/models/starcoder2/modular_starcoder2.py | 2 ++ 9 files changed, 15 insertions(+) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 25260fa5fd60..c3936561e190 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -497,6 +497,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index a228fa659269..9ffdf84579f8 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -408,6 +408,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 72f1138a9294..0062317d19bb 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -514,6 +514,8 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 04c7904289bb..efbe504f910f 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -476,6 +476,8 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index a1a31986b5c5..73ed4a652e30 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -563,6 +563,8 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py 
b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 34c1122aaa14..6c79bbe2ea0c 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -739,6 +739,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 55ca0944c35a..66e60d899bb7 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -382,6 +382,8 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index bafce167003f..e6d22b833857 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -441,6 +441,8 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index 9d515d123967..a5e90354d17a 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -307,6 +307,8 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() From 117697482400db47c257c9170ca2ffdbeb76d932 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 21 Nov 2024 16:22:54 +0100 Subject: [PATCH 04/11] remove last trace of old rope classes --- tests/models/falcon/test_modeling_falcon.py | 8 ++-- .../models/gpt_neox/test_modeling_gpt_neox.py | 12 +++--- tests/models/llama/test_modeling_llama.py | 39 +------------------ 3 files changed, 10 insertions(+), 49 deletions(-) diff --git a/tests/models/falcon/test_modeling_falcon.py b/tests/models/falcon/test_modeling_falcon.py index ce04fae94ea9..129bd346a10d 100644 --- a/tests/models/falcon/test_modeling_falcon.py +++ b/tests/models/falcon/test_modeling_falcon.py @@ -50,8 +50,6 @@ FalconModel, ) from transformers.models.falcon.modeling_falcon import ( - FalconDynamicNTKScalingRotaryEmbedding, - FalconLinearScalingRotaryEmbedding, FalconRotaryEmbedding, ) @@ -484,11 +482,12 @@ def test_model_rope_scaling(self): # Sanity check linear RoPE scaling # New position "x" should match original position with index "x/scaling_factor" - linear_scaling_rope = FalconLinearScalingRotaryEmbedding( + linear_scaling_rope = FalconRotaryEmbedding( head_dim, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta, scaling_factor=scaling_factor, + rope_type="linear", ).to(torch_device) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) @@ -502,11 +501,12 @@ def test_model_rope_scaling(self): # Sanity check Dynamic NTK RoPE 
scaling # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase # with scaling_factor (or that `inv_freq` decreases) - ntk_scaling_rope = FalconDynamicNTKScalingRotaryEmbedding( + ntk_scaling_rope = FalconRotaryEmbedding( head_dim, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta, scaling_factor=scaling_factor, + rope_type="dynamic", ).to(torch_device) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index 435133e93860..ca9fbb225c6d 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -37,11 +37,7 @@ GPTNeoXForTokenClassification, GPTNeoXModel, ) - from transformers.models.gpt_neox.modeling_gpt_neox import ( - GPTNeoXDynamicNTKScalingRotaryEmbedding, - GPTNeoXLinearScalingRotaryEmbedding, - GPTNeoXRotaryEmbedding, - ) + from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding class GPTNeoXModelTester: @@ -400,11 +396,12 @@ def test_model_rope_scaling(self): # Sanity check linear RoPE scaling # New position "x" should match original position with index "x/scaling_factor" - linear_scaling_rope = GPTNeoXLinearScalingRotaryEmbedding( + linear_scaling_rope = GPTNeoXRotaryEmbedding( head_dim, max_position_embeddings=config.max_position_embeddings, base=config.rotary_emb_base, scaling_factor=scaling_factor, + rope_type="linear", ).to(torch_device) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) @@ -418,11 +415,12 @@ def test_model_rope_scaling(self): # Sanity check Dynamic NTK RoPE scaling # Scaling should only be observed after a long input is fed. 
We can observe that the frequencies increase # with scaling_factor (or that `inv_freq` decreases) - ntk_scaling_rope = GPTNeoXDynamicNTKScalingRotaryEmbedding( + ntk_scaling_rope = GPTNeoXRotaryEmbedding( head_dim, max_position_embeddings=config.max_position_embeddings, base=config.rotary_emb_base, scaling_factor=scaling_factor, + rope_type="dynamic", ).to(torch_device) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 9e67f4f7381e..0790de4e133b 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -51,7 +51,7 @@ LlamaModel, LlamaTokenizer, ) - from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding, LlamaRotaryEmbedding + from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding class LlamaModelTester: @@ -489,43 +489,6 @@ def test_model_rope_scaling(self): with self.assertRaises(AssertionError): torch.testing.assert_close(yarn_sin_long, original_sin_long) - def test_rope_class_retrocompatibility(self): - # Delete me when we remove compatibility for the old API :) - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - scaling_factor = 10 - short_input_length = 10 - long_input_length = int(config.max_position_embeddings * 1.5) - config.rope_scaling = {"type": "linear", "factor": 10} - - # Inputs - x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device - position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) - position_ids_short = position_ids_short.unsqueeze(0) - position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) - position_ids_long = position_ids_long.unsqueeze(0) - - # Old API -- under the hood, "type": "linear" is set and `LlamaRotaryEmbedding` is called - old_api_rope = LlamaLinearScalingRotaryEmbedding( - config.hidden_size // config.num_attention_heads, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - ).to(torch_device) - old_cos_short, old_sin_short = old_api_rope(x, position_ids_short) - old_cos_long, old_sin_long = old_api_rope(x, position_ids_long) - - # New API - config.rope_scaling = {"type": "linear", "factor": scaling_factor} - new_api_rope = LlamaRotaryEmbedding(config=config).to(torch_device) - new_cos_short, new_sin_short = new_api_rope(x, position_ids_short) - new_cos_long, new_sin_long = new_api_rope(x, position_ids_long) - - # The results should match - torch.testing.assert_close(old_cos_short, new_cos_short) - torch.testing.assert_close(old_sin_short, new_sin_short) - torch.testing.assert_close(old_cos_long, new_cos_long) - torch.testing.assert_close(old_sin_long, new_sin_long) - def test_model_loading_old_rope_configs(self): def _reinitialize_config(base_config, new_kwargs): # Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation From 99dc1628a28227302879381d79df88778bd41324 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 21 Nov 2024 16:24:54 +0100 Subject: [PATCH 05/11] remove deprecated copied from --- .../models/chameleon/modeling_chameleon.py | 4 ---- src/transformers/models/olmo/modeling_olmo.py | 4 ---- src/transformers/models/olmo2/modeling_olmo2.py | 12 ++++-------- 3 files changed, 4 insertions(+), 
16 deletions(-) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 3255b6f44c05..dfac0ef1cc6f 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -115,8 +115,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Chameleon -# TODO(joao): add me back asap :) class ChameleonLinearScalingRotaryEmbedding(ChameleonRotaryEmbedding): """ChameleonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" @@ -127,8 +125,6 @@ def forward(self, x, position_ids): return cos, sin -# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Chameleon -# TODO(joao): add me back asap :) class ChameleonDynamicNTKScalingRotaryEmbedding(ChameleonRotaryEmbedding): """ChameleonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index d865c51e5057..8b40c41e34dc 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -105,8 +105,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Olmo -# TODO(joao): add me back asap :) class OlmoLinearScalingRotaryEmbedding(OlmoRotaryEmbedding): """OlmoRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" @@ -117,8 +115,6 @@ def forward(self, x, position_ids): return cos, sin -# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Olmo -# TODO(joao): add me back asap :) class OlmoDynamicNTKScalingRotaryEmbedding(OlmoRotaryEmbedding): """OlmoRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index c042669e1ed5..e34a1e68f178 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -87,10 +87,8 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Olmo2 -# TODO(joao): add me back asap :) -class Olmo2LinearScalingRotaryEmbedding(Olmo2RotaryEmbedding): - """Olmo2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" +class Olmo1124LinearScalingRotaryEmbedding(Olmo1124RotaryEmbedding): + """Olmo1124RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" def forward(self, x, position_ids): # difference to the original RoPE: a scaling factor is aplied to the position ids @@ -99,10 +97,8 @@ def forward(self, x, position_ids): return cos, sin -# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Olmo2 -# TODO(joao): add me back asap :) -class Olmo2DynamicNTKScalingRotaryEmbedding(Olmo2RotaryEmbedding): - """Olmo2RotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" +class Olmo1124DynamicNTKScalingRotaryEmbedding(Olmo1124RotaryEmbedding): + """Olmo1124RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" def forward(self, x, position_ids): # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length From 91be558da31ab88945c97900eceb4ca2fa2cfaa5 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 21 Nov 2024 16:33:32 +0100 Subject: [PATCH 06/11] fix copies --- tests/models/persimmon/test_modeling_persimmon.py | 12 +++++------- tests/models/phi/test_modeling_phi.py | 12 +++++------- tests/models/stablelm/test_modeling_stablelm.py | 12 +++++------- 3 files changed, 15 insertions(+), 21 deletions(-) diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py index 99d84f9b5b5b..54ee49b65343 100644 --- a/tests/models/persimmon/test_modeling_persimmon.py +++ b/tests/models/persimmon/test_modeling_persimmon.py @@ -46,11 +46,7 @@ PersimmonForTokenClassification, PersimmonModel, ) - from transformers.models.persimmon.modeling_persimmon import ( - PersimmonDynamicNTKScalingRotaryEmbedding, - PersimmonLinearScalingRotaryEmbedding, - PersimmonRotaryEmbedding, - ) + from transformers.models.persimmon.modeling_persimmon import PersimmonRotaryEmbedding # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->Persimmon @@ -451,11 +447,12 @@ def test_model_rope_scaling(self): # Sanity check linear RoPE scaling # New position "x" should match original position with index "x/scaling_factor" - linear_scaling_rope = PersimmonLinearScalingRotaryEmbedding( + linear_scaling_rope = PersimmonRotaryEmbedding( head_dim, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta, scaling_factor=scaling_factor, + rope_type="linear", ).to(torch_device) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) @@ -469,11 +466,12 @@ def test_model_rope_scaling(self): # Sanity check Dynamic NTK RoPE scaling # Scaling should only be observed after a long input is fed. 
We can observe that the frequencies increase # with scaling_factor (or that `inv_freq` decreases) - ntk_scaling_rope = PersimmonDynamicNTKScalingRotaryEmbedding( + ntk_scaling_rope = PersimmonRotaryEmbedding( head_dim, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta, scaling_factor=scaling_factor, + rope_type="dynamic", ).to(torch_device) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py index eae6789bef25..df5278cb34e3 100644 --- a/tests/models/phi/test_modeling_phi.py +++ b/tests/models/phi/test_modeling_phi.py @@ -42,11 +42,7 @@ PhiForTokenClassification, PhiModel, ) - from transformers.models.phi.modeling_phi import ( - PhiDynamicNTKScalingRotaryEmbedding, - PhiLinearScalingRotaryEmbedding, - PhiRotaryEmbedding, - ) + from transformers.models.phi.modeling_phi import PhiRotaryEmbedding class PhiModelTester: @@ -430,11 +426,12 @@ def test_model_rope_scaling(self): # Sanity check linear RoPE scaling # New position "x" should match original position with index "x/scaling_factor" - linear_scaling_rope = PhiLinearScalingRotaryEmbedding( + linear_scaling_rope = PhiRotaryEmbedding( head_dim, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta, scaling_factor=scaling_factor, + rope_type="linear", ).to(torch_device) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) @@ -448,11 +445,12 @@ def test_model_rope_scaling(self): # Sanity check Dynamic NTK RoPE scaling # Scaling should only be observed after a long input is fed. 
We can observe that the frequencies increase # with scaling_factor (or that `inv_freq` decreases) - ntk_scaling_rope = PhiDynamicNTKScalingRotaryEmbedding( + ntk_scaling_rope = PhiRotaryEmbedding( head_dim, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta, scaling_factor=scaling_factor, + rope_type="dynamic", ).to(torch_device) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) diff --git a/tests/models/stablelm/test_modeling_stablelm.py b/tests/models/stablelm/test_modeling_stablelm.py index 91044a4eb750..bfab01578229 100644 --- a/tests/models/stablelm/test_modeling_stablelm.py +++ b/tests/models/stablelm/test_modeling_stablelm.py @@ -44,11 +44,7 @@ StableLmForTokenClassification, StableLmModel, ) - from transformers.models.stablelm.modeling_stablelm import ( - StableLmDynamicNTKScalingRotaryEmbedding, - StableLmLinearScalingRotaryEmbedding, - StableLmRotaryEmbedding, - ) + from transformers.models.stablelm.modeling_stablelm import StableLmRotaryEmbedding # Copied from transformers.tests.models.persimmon.test_modeling_persimmon.PersimmonModelTester with Persimmon -> StableLm @@ -436,11 +432,12 @@ def test_model_rope_scaling(self): # Sanity check linear RoPE scaling # New position "x" should match original position with index "x/scaling_factor" - linear_scaling_rope = StableLmLinearScalingRotaryEmbedding( + linear_scaling_rope = StableLmRotaryEmbedding( head_dim, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta, scaling_factor=scaling_factor, + rope_type="linear", ).to(torch_device) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) @@ -454,11 +451,12 @@ def test_model_rope_scaling(self): # Sanity check Dynamic NTK RoPE scaling # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase # with scaling_factor (or that `inv_freq` decreases) - ntk_scaling_rope = StableLmDynamicNTKScalingRotaryEmbedding( + ntk_scaling_rope = StableLmRotaryEmbedding( head_dim, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta, scaling_factor=scaling_factor, + rope_type="dynamic", ).to(torch_device) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) From fcd21eee708cb45b5b4d9f3b7fea1765c6e866c2 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 21 Nov 2024 16:38:56 +0100 Subject: [PATCH 07/11] trigger CIs From ba4ab337039b67c4b27bdeda729a94c952384084 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 4 Dec 2024 16:13:10 +0100 Subject: [PATCH 08/11] post rebase clean-up --- src/transformers/models/olmo2/modeling_olmo2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index e34a1e68f178..6c35587f1f14 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -87,8 +87,8 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -class Olmo1124LinearScalingRotaryEmbedding(Olmo1124RotaryEmbedding): - """Olmo1124RotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" +class Olmo2LinearScalingRotaryEmbedding(Olmo2RotaryEmbedding): + """Olmo2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" def forward(self, x, position_ids): # difference to the original RoPE: a scaling factor is aplied to the position ids @@ -97,8 +97,8 @@ def forward(self, x, position_ids): return cos, sin -class Olmo1124DynamicNTKScalingRotaryEmbedding(Olmo1124RotaryEmbedding): - """Olmo1124RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" +class Olmo2DynamicNTKScalingRotaryEmbedding(Olmo2RotaryEmbedding): + """Olmo2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" def forward(self, x, position_ids): # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length From e8e14a3c5b8cfabefca360a2fd4bcfd48915e882 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 4 Dec 2024 16:16:55 +0100 Subject: [PATCH 09/11] reverse mistral --- .../models/mistral/modeling_mistral.py | 34 +++++++------------ 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 9ffdf84579f8..6ed8178ed982 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -205,6 +205,12 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.rotary_emb = MistralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + def forward( self, hidden_states: torch.Tensor, @@ -214,8 +220,6 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -227,7 +231,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - cos, sin = position_embeddings + cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -266,7 +270,6 @@ def forward( return attn_output, attn_weights, past_key_value -# copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Mistral class MistralFlashAttention2(MistralAttention): """ Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays @@ -274,6 +277,7 @@ class MistralFlashAttention2(MistralAttention): flash attention and deal with padding tokens in case the input contains any of them. 
""" + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -291,7 +295,6 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): if isinstance(past_key_value, StaticCache): raise ValueError( @@ -311,7 +314,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - cos, sin = position_embeddings + cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -374,6 +377,7 @@ def forward( # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral +# TODO(joao): add me back asap :) class MistralSdpaAttention(MistralAttention): """ Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -391,7 +395,6 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -408,7 +411,6 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() @@ -421,7 +423,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - cos, sin = position_embeddings + cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -472,6 +474,7 @@ def forward( # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Mistral, LLAMA->MISTRAL +# TODO(joao): add me back asap :) class MistralDecoderLayer(nn.Module): def __init__(self, config: MistralConfig, layer_idx: int): super().__init__() @@ -492,7 +495,6 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -527,7 +529,6 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -692,12 +693,6 @@ def __init__(self, config: MistralConfig): self._attn_implementation = config._attn_implementation self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = MistralRotaryEmbedding( - config.head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -772,9 +767,6 @@ def forward( hidden_states = inputs_embeds - # create position embeddings to be shared across the decoder layers - position_embeddings = 
self.rotary_emb(hidden_states, position_ids) - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -794,7 +786,6 @@ def forward( output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( @@ -805,7 +796,6 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] From 5a2c962392d61ff9d4cbeba164d763fd296f6504 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 5 Dec 2024 17:04:59 +0100 Subject: [PATCH 10/11] cleanup after dropping commits --- .../modeling_multimodal1.py | 44 +++---------------- .../models/idefics2/modeling_idefics2.py | 1 + 2 files changed, 8 insertions(+), 37 deletions(-) diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index 02876996e0e5..8b7e03ef4552 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -228,9 +228,6 @@ def __init__(self, config: Multimodal1TextConfig, layer_idx: Optional[int] = Non self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) - self.rotary_emb = Multimodal1TextRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -240,7 +237,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -254,16 +251,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -326,7 +314,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): @@ -350,16 +338,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -441,7 +420,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -472,16 +451,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." 
- ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -553,7 +523,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index b095abf0007d..3d46c3bd82e7 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -867,6 +867,7 @@ class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention): flash attention and deal with padding tokens in case the input contains any of them. """ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) From c99274450f665e999c171ad07bfa0f895d2f6e2e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 5 Dec 2024 17:26:22 +0100 Subject: [PATCH 11/11] Add comment --- examples/modular-transformers/modeling_dummy.py | 8 ++++---- examples/modular-transformers/modeling_multimodal1.py | 8 ++++---- examples/modular-transformers/modeling_super.py | 8 ++++---- src/transformers/models/cohere/modeling_cohere.py | 8 ++++---- src/transformers/models/falcon/modeling_falcon.py | 6 +++--- src/transformers/models/glm/modeling_glm.py | 8 ++++---- src/transformers/models/gpt_neox/modeling_gpt_neox.py | 6 +++--- .../gpt_neox_japanese/modeling_gpt_neox_japanese.py | 4 ++-- src/transformers/models/granite/modeling_granite.py | 8 ++++---- src/transformers/models/granitemoe/modeling_granitemoe.py | 8 ++++---- src/transformers/models/llama/modeling_llama.py | 8 ++++---- src/transformers/models/mllama/modeling_mllama.py | 2 +- src/transformers/models/nemotron/modeling_nemotron.py | 2 +- src/transformers/models/persimmon/modeling_persimmon.py | 4 ++-- src/transformers/models/phi/modeling_phi.py | 8 ++++---- src/transformers/models/qwen2/modeling_qwen2.py | 8 ++++---- src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 8 ++++---- src/transformers/models/qwen2_vl/modeling_qwen2_vl.py | 8 ++++---- src/transformers/models/stablelm/modeling_stablelm.py | 8 ++++---- src/transformers/models/starcoder2/modeling_starcoder2.py | 8 ++++---- src/transformers/models/starcoder2/modular_starcoder2.py | 6 +++--- 21 files changed, 71 insertions(+), 71 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 33ac61201916..6172c9acfd21 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -237,7 +237,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, 
_ = hidden_states.size() @@ -314,7 +314,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): @@ -420,7 +420,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -521,7 +521,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index 8b7e03ef4552..562e7dcab2b9 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -237,7 +237,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -314,7 +314,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): @@ -420,7 +420,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -523,7 +523,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git 
a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 8b1d4d71e8d7..79e5ab15a5ed 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -237,7 +237,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -314,7 +314,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): @@ -420,7 +420,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -521,7 +521,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index c3936561e190..b9a235ed500c 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -292,7 +292,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -377,7 +377,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): @@ -481,7 +481,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # 
necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -586,7 +586,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index bd90c637c717..51d9ff39d48f 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -361,7 +361,7 @@ def forward( use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads @@ -512,7 +512,7 @@ def forward( use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads @@ -650,7 +650,7 @@ def forward( use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ): residual = hidden_states diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index dc24d6982aec..b4a292d69de9 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -227,7 +227,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -303,7 +303,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: output_attentions = False @@ -402,7 +402,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, 
cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -503,7 +503,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index fd11dc4ec0d0..70ff07ed7f6d 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -311,7 +311,7 @@ def forward( output_attentions: Optional[bool] = False, padding_mask: Optional[torch.Tensor] = None, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): bsz, seq_len, _ = hidden_states.shape @@ -404,7 +404,7 @@ def _attn_projections_and_rope( layer_past: Optional[Tuple[torch.Tensor]] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): # Compute QKV # Attention heads [batch, seq_len, hidden_size] @@ -652,7 +652,7 @@ def forward( layer_past: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): attention_layer_outputs = self.attention( self.input_layernorm(hidden_states), diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 2a315998c3ee..c9e1b2d72135 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -105,7 +105,7 @@ def forward( use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): # Compute QKV # Attention heads [batch, seq_len, hidden_size] @@ -406,7 +406,7 @@ def forward( layer_past: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): residual = hidden_states ln_out = self.input_layernorm(hidden_states) diff --git 
a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index c2661e6ebcc4..8cd24265d9ed 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -242,7 +242,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -318,7 +318,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: output_attentions = False @@ -417,7 +417,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -520,7 +520,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index ff4a229dd236..9f5fdeea07d4 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -458,7 +458,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -535,7 +535,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: output_attentions = False @@ -635,7 +635,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC 
**kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -739,7 +739,7 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, output_router_logits: Optional[bool] = False, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 2b17f0af74fd..8e06098b04c6 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -268,7 +268,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -345,7 +345,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): @@ -451,7 +451,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -552,7 +552,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index a9e96e8dbc3f..d53a80dd8929 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -858,7 +858,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index e787a5e4f4e5..78dace1a53ce 100644 --- 
a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -536,7 +536,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 058c213e844f..884ee4d86aaf 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -259,7 +259,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -354,7 +354,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 0062317d19bb..8e60798e857f 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -267,7 +267,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -370,7 +370,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # PhiFlashAttention2 attention does not support output_attentions @@ -497,7 +497,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
@@ -619,7 +619,7 @@ def forward( use_cache: Optional[bool] = False, past_key_value: Optional[Tuple[torch.Tensor]] = None, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index efbe504f910f..f4875386253c 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -284,7 +284,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -361,7 +361,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): bsz, q_len, _ = hidden_states.size() @@ -461,7 +461,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
@@ -565,7 +565,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 73ed4a652e30..a1e36b8ad7bc 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -368,7 +368,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -447,7 +447,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): bsz, q_len, _ = hidden_states.size() @@ -548,7 +548,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
@@ -716,7 +716,7 @@ def forward( output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 6c79bbe2ea0c..dce0702b0819 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -537,7 +537,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -621,7 +621,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): bsz, q_len, _ = hidden_states.size() @@ -723,7 +723,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
@@ -829,7 +829,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 66e60d899bb7..0ce550697e79 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -280,7 +280,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -367,7 +367,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -490,7 +490,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # StableLmFlashAttention2 attention does not support output_attentions @@ -598,7 +598,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index e6d22b833857..eb218accdb8c 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -260,7 +260,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -335,7 +335,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: 
Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): bsz, q_len, _ = hidden_states.size() @@ -426,7 +426,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -528,7 +528,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index a5e90354d17a..a1cec871baca 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -126,7 +126,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -201,7 +201,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): bsz, q_len, _ = hidden_states.size() @@ -292,7 +292,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
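
The pattern these patches converge on — computing the rotary embeddings once at the model level and threading the resulting `(cos, sin)` tuple through every decoder layer, while the per-layer `position_embeddings` argument stays in the signature for backward compatibility (the "# necessary, but kept here for BC" comment) — can be sketched roughly as below. This is a minimal illustrative sketch, not code from the patch: the names `ToyRotaryEmbedding`, `ToyAttention`, and `ToyModel` are invented stand-ins for the real `*RotaryEmbedding`, `*Attention`, and `*Model` classes.

    # Minimal sketch (assumed names, not patch code): rotary embeddings are built
    # once per forward pass at the model level and shared across all layers.
    from typing import Optional, Tuple

    import torch
    from torch import nn


    class ToyRotaryEmbedding(nn.Module):
        def __init__(self, dim: int, base: float = 10000.0):
            super().__init__()
            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        @torch.no_grad()
        def forward(self, x: torch.Tensor, position_ids: torch.LongTensor):
            # position_ids: (batch, seq_len) -> cos/sin: (batch, seq_len, dim)
            freqs = position_ids[:, :, None].float() * self.inv_freq[None, None, :]
            emb = torch.cat((freqs, freqs), dim=-1)
            return emb.cos().to(x.dtype), emb.sin().to(x.dtype)


    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)


    def apply_rotary_pos_emb(q, k, cos, sin):
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over the head dimension
        return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


    class ToyAttention(nn.Module):
        def __init__(self, hidden: int, heads: int):
            super().__init__()
            self.heads, self.head_dim = heads, hidden // heads
            self.qkv = nn.Linear(hidden, 3 * hidden, bias=False)
            self.o = nn.Linear(hidden, hidden, bias=False)

        def forward(
            self,
            hidden_states: torch.Tensor,
            # Optional and last-ish in the signature so older callers keep working (BC).
            position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        ):
            b, t, _ = hidden_states.size()
            q, k, v = self.qkv(hidden_states).chunk(3, dim=-1)
            q = q.view(b, t, self.heads, self.head_dim).transpose(1, 2)
            k = k.view(b, t, self.heads, self.head_dim).transpose(1, 2)
            v = v.view(b, t, self.heads, self.head_dim).transpose(1, 2)
            cos, sin = position_embeddings  # computed once by the model, reused by every layer
            q, k = apply_rotary_pos_emb(q, k, cos, sin)
            out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
            return self.o(out.transpose(1, 2).reshape(b, t, -1))


    class ToyModel(nn.Module):
        def __init__(self, hidden: int = 64, heads: int = 4, layers: int = 2):
            super().__init__()
            self.rotary_emb = ToyRotaryEmbedding(hidden // heads)
            self.layers = nn.ModuleList(ToyAttention(hidden, heads) for _ in range(layers))

        def forward(self, hidden_states: torch.Tensor):
            position_ids = torch.arange(hidden_states.size(1), device=hidden_states.device)[None, :]
            # Create the position embeddings once and pass the same (cos, sin) to each layer.
            position_embeddings = self.rotary_emb(hidden_states, position_ids)
            for layer in self.layers:
                hidden_states = layer(hidden_states, position_embeddings=position_embeddings)
            return hidden_states


    # Usage: ToyModel()(torch.randn(1, 8, 64)) returns a tensor of shape (1, 8, 64).

Mistral is the exception in this series: PATCH 09/11 reverts it to computing `cos, sin = self.rotary_emb(value_states, position_ids)` inside each attention layer, i.e. the pre-refactor layout, rather than consuming a shared `position_embeddings` tuple.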