diff --git a/src/transformers/models/apertus/modeling_apertus.py b/src/transformers/models/apertus/modeling_apertus.py
index 8ca279ded178..7a3db9c4c168 100644
--- a/src/transformers/models/apertus/modeling_apertus.py
+++ b/src/transformers/models/apertus/modeling_apertus.py
@@ -89,7 +89,10 @@ def __init__(self, config: ApertusConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py
index 7d9afa092def..c752c67eee72 100644
--- a/src/transformers/models/arcee/modeling_arcee.py
+++ b/src/transformers/models/arcee/modeling_arcee.py
@@ -96,7 +96,10 @@ def __init__(self, config: ArceeConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py
index 789bea15ef54..f6e37ac5cb18 100644
--- a/src/transformers/models/aria/modeling_aria.py
+++ b/src/transformers/models/aria/modeling_aria.py
@@ -629,7 +629,10 @@ def __init__(self, config: AriaTextConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py
index 97e911351c80..0573c5fb6871 100644
--- a/src/transformers/models/bamba/modeling_bamba.py
+++ b/src/transformers/models/bamba/modeling_bamba.py
@@ -207,7 +207,10 @@ def __init__(self, config: BambaConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py
index 0af7c794b155..4a31e6010111 100644
--- a/src/transformers/models/bitnet/modeling_bitnet.py
+++ b/src/transformers/models/bitnet/modeling_bitnet.py
@@ -284,7 +284,10 @@ def __init__(self, config: BitNetConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py
index 1e677dda4a98..54447c3da9a0 100644
--- a/src/transformers/models/blt/modeling_blt.py
+++ b/src/transformers/models/blt/modeling_blt.py
@@ -99,7 +99,10 @@ def __init__(self, config: BltConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 1dfa0ce0be33..a9a5d359632b 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -79,7 +79,10 @@ def __init__(self, config: CohereConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py
index bab804aab67e..b5186a79fb9b 100644
--- a/src/transformers/models/cohere2/modeling_cohere2.py
+++ b/src/transformers/models/cohere2/modeling_cohere2.py
@@ -54,7 +54,10 @@ def __init__(self, config: Cohere2Config, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py
index 80157e2aa93a..ecc4e9f94c60 100644
--- a/src/transformers/models/csm/modeling_csm.py
+++ b/src/transformers/models/csm/modeling_csm.py
@@ -130,7 +130,10 @@ def __init__(self, config: CsmConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index 5043f5898255..f755cf12fa59 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -53,7 +53,10 @@ def __init__(self, config: DbrxConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
index 629d1cf2ccad..0025c7eb5c06 100644
--- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
+++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
@@ -67,7 +67,10 @@ def __init__(self, config: DeepseekV3Config, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py
index 3025c8de4faa..78c6f8cff858 100644
--- a/src/transformers/models/dia/modeling_dia.py
+++ b/src/transformers/models/dia/modeling_dia.py
@@ -154,7 +154,10 @@ def __init__(self, config: DiaConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py
index 094cc375057f..bba34280d87b 100644
--- a/src/transformers/models/diffllama/modeling_diffllama.py
+++ b/src/transformers/models/diffllama/modeling_diffllama.py
@@ -568,7 +568,10 @@ def __init__(self, config: DiffLlamaConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py
index f92371fdaba6..078a4490d131 100644
--- a/src/transformers/models/doge/modeling_doge.py
+++ b/src/transformers/models/doge/modeling_doge.py
@@ -84,7 +84,10 @@ def __init__(self, config: DogeConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py
index ea6698b470de..eb2549fc6e2b 100644
--- a/src/transformers/models/dots1/modeling_dots1.py
+++ b/src/transformers/models/dots1/modeling_dots1.py
@@ -76,7 +76,10 @@ def __init__(self, config: Dots1Config, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py
index 5e791d1042f6..d64953e9580e 100644
--- a/src/transformers/models/emu3/modeling_emu3.py
+++ b/src/transformers/models/emu3/modeling_emu3.py
@@ -1121,7 +1121,10 @@ def __init__(self, config: Emu3Config, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/ernie4_5/modeling_ernie4_5.py b/src/transformers/models/ernie4_5/modeling_ernie4_5.py
index 13ec6fb3a3b6..922d83318143 100644
--- a/src/transformers/models/ernie4_5/modeling_ernie4_5.py
+++ b/src/transformers/models/ernie4_5/modeling_ernie4_5.py
@@ -53,7 +53,10 @@ def __init__(self, config: Ernie4_5Config, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py
index 66d6b1935096..949c379d6d77 100644
--- a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py
+++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py
@@ -92,7 +92,10 @@ def __init__(self, config: Ernie4_5_MoeConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/evolla/modeling_evolla.py b/src/transformers/models/evolla/modeling_evolla.py
index 75db8a22a022..2aac4be9dc52 100644
--- a/src/transformers/models/evolla/modeling_evolla.py
+++ b/src/transformers/models/evolla/modeling_evolla.py
@@ -1093,7 +1093,10 @@ def __init__(self, config: EvollaConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/exaone4/modeling_exaone4.py b/src/transformers/models/exaone4/modeling_exaone4.py
index 2693a80c79fd..ee5d02832e4f 100644
--- a/src/transformers/models/exaone4/modeling_exaone4.py
+++ b/src/transformers/models/exaone4/modeling_exaone4.py
@@ -82,7 +82,10 @@ def __init__(self, config: Exaone4Config, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index b32981c51353..21c26ae4a90b 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -98,7 +98,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


-# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Falcon
+# Adapted from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Falcon
 class FalconRotaryEmbedding(nn.Module):
     inv_freq: torch.Tensor  # fix linting for `register_buffer`

diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py
index 62797f2ecc63..18a4e20222dd 100644
--- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py
+++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py
@@ -236,7 +236,10 @@ def __init__(self, config: FalconH1Config, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/flex_olmo/modeling_flex_olmo.py b/src/transformers/models/flex_olmo/modeling_flex_olmo.py
index 1777b8f31237..53981b4e7d9d 100644
--- a/src/transformers/models/flex_olmo/modeling_flex_olmo.py
+++ b/src/transformers/models/flex_olmo/modeling_flex_olmo.py
@@ -75,7 +75,10 @@ def __init__(self, config: FlexOlmoConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index ef0a688d4608..228288980402 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -93,7 +93,10 @@ def __init__(self, config: GemmaConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 2a218338384a..c6221cd2126d 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -97,7 +97,10 @@ def __init__(self, config: Gemma2Config, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py
index ed9c83180059..fbc23e09507a 100644
--- a/src/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/transformers/models/gemma3/modeling_gemma3.py
@@ -157,7 +157,10 @@ def __init__(self, config: Gemma3TextConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py
index 68595ead4371..667537f102fa 100644
--- a/src/transformers/models/gemma3n/modeling_gemma3n.py
+++ b/src/transformers/models/gemma3n/modeling_gemma3n.py
@@ -1158,7 +1158,10 @@ def __init__(self, config: Gemma3nTextConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py
index 59c9f39da527..1032c8539f0a 100644
--- a/src/transformers/models/glm/modeling_glm.py
+++ b/src/transformers/models/glm/modeling_glm.py
@@ -253,7 +253,10 @@ def __init__(self, config: GlmConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py
index dafab297f566..3df0ff43433a 100644
--- a/src/transformers/models/glm4/modeling_glm4.py
+++ b/src/transformers/models/glm4/modeling_glm4.py
@@ -303,7 +303,10 @@ def __init__(self, config: Glm4Config, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/glm4_moe/modeling_glm4_moe.py b/src/transformers/models/glm4_moe/modeling_glm4_moe.py
index fea2b97685d2..cbbb948e06dc 100644
--- a/src/transformers/models/glm4_moe/modeling_glm4_moe.py
+++ b/src/transformers/models/glm4_moe/modeling_glm4_moe.py
@@ -439,7 +439,10 @@ def __init__(self, config: Glm4MoeConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py
index fbfeae9130e1..11e1b08f608c 100644
--- a/src/transformers/models/glm4v/modeling_glm4v.py
+++ b/src/transformers/models/glm4v/modeling_glm4v.py
@@ -399,7 +399,10 @@ def __init__(self, config: Glm4vTextConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py
index 0b65072b2404..063d2ab1b2f0 100644
--- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py
+++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py
@@ -748,7 +748,10 @@ def __init__(self, config: Glm4vMoeTextConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index 07072c077089..39f55a4eb5d1 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -259,7 +259,10 @@ def __init__(self, config: GPTNeoXConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
index a930070bfba7..ff55b35ea76d 100755
--- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
+++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
@@ -227,7 +227,10 @@ def __init__(self, config: GPTNeoXJapaneseConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
index d55aa3f31d4f..a0c7a14f67a6 100644
--- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py
+++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -185,7 +185,10 @@ def __init__(self, config: GptOssConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py
index 846865c55508..a0662ea500a8 100644
--- a/src/transformers/models/granite/modeling_granite.py
+++ b/src/transformers/models/granite/modeling_granite.py
@@ -334,7 +334,10 @@ def __init__(self, config: GraniteConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py
index 8dca068b5915..ebd687b22192 100644
--- a/src/transformers/models/granitemoe/modeling_granitemoe.py
+++ b/src/transformers/models/granitemoe/modeling_granitemoe.py
@@ -76,7 +76,10 @@ def __init__(self, config: GraniteMoeConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
index cc623b0f1d28..50befb799512 100644
--- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
@@ -1161,7 +1161,10 @@ def __init__(self, config: GraniteMoeHybridConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
index e844e23305ca..62f3beb32c96 100644
--- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
+++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
@@ -490,7 +490,10 @@ def __init__(self, config: GraniteMoeSharedConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py
index 9f4a2e73affd..efecafea6e2a 100644
--- a/src/transformers/models/helium/modeling_helium.py
+++ b/src/transformers/models/helium/modeling_helium.py
@@ -75,7 +75,10 @@ def __init__(self, config: HeliumConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py
index 7b1af7433e8d..4e12d8f3a7e5 100644
--- a/src/transformers/models/jetmoe/modeling_jetmoe.py
+++ b/src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -79,7 +79,10 @@ def __init__(self, config: JetMoeConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py
index 77c636570d58..fdb2aebcc98c 100644
--- a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py
+++ b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py
@@ -267,6 +267,7 @@ def forward(self, x, layer_idx=None):
         return self.linear(x)


+# Adapted from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->KyutaiSpeechToText
 class KyutaiSpeechToTextRotaryEmbedding(nn.Module):
     inv_freq: torch.Tensor  # fix linting for `register_buffer`

@@ -281,7 +282,10 @@ def __init__(self, config: KyutaiSpeechToTextConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py
index 5ea4314968e2..c53080566b1c 100644
--- a/src/transformers/models/lfm2/modeling_lfm2.py
+++ b/src/transformers/models/lfm2/modeling_lfm2.py
@@ -80,7 +80,10 @@ def __init__(self, config: Lfm2Config, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index f43a5fc9b523..ed34cdd3399a 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -84,7 +84,10 @@ def __init__(self, config: LlamaConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/longcat_flash/modeling_longcat_flash.py b/src/transformers/models/longcat_flash/modeling_longcat_flash.py
index e13474815ac7..7cdbb2e88652 100644
--- a/src/transformers/models/longcat_flash/modeling_longcat_flash.py
+++ b/src/transformers/models/longcat_flash/modeling_longcat_flash.py
@@ -78,7 +78,10 @@ def __init__(self, config: LongcatFlashConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py
index f22cad968247..0e25ed561dda 100644
--- a/src/transformers/models/mimi/modeling_mimi.py
+++ b/src/transformers/models/mimi/modeling_mimi.py
@@ -501,7 +501,7 @@ def forward(self, x: torch.Tensor):
         return self.scale * x


-# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mimi
+# adapted from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mimi
 class MimiRotaryEmbedding(nn.Module):
     inv_freq: torch.Tensor  # fix linting for `register_buffer`

@@ -516,7 +516,10 @@ def __init__(self, config: MimiConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py
index f12f3b817e3d..444890e2f54f 100644
--- a/src/transformers/models/minimax/modeling_minimax.py
+++ b/src/transformers/models/minimax/modeling_minimax.py
@@ -568,7 +568,10 @@ def __init__(self, config: MiniMaxConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/ministral/modeling_ministral.py b/src/transformers/models/ministral/modeling_ministral.py
index 91ea520c6167..2136387056c9 100644
--- a/src/transformers/models/ministral/modeling_ministral.py
+++ b/src/transformers/models/ministral/modeling_ministral.py
@@ -285,7 +285,10 @@ def __init__(self, config: MinistralConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index 5b7c7b2c1790..b8339fef2eb7 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -281,7 +281,10 @@ def __init__(self, config: MistralConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 533cab15f647..85022c3d7cf8 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -343,7 +343,10 @@ def __init__(self, config: MixtralConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py
index 00fbe19c3a63..7a6fd3d8e984 100644
--- a/src/transformers/models/modernbert/modeling_modernbert.py
+++ b/src/transformers/models/modernbert/modeling_modernbert.py
@@ -256,7 +256,10 @@ def __init__(self, config: ModernBertConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
index 20b9db63950c..13abd9e05081 100644
--- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
+++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
@@ -109,7 +109,10 @@ def __init__(self, config: ModernBertDecoderConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py
index 42b66aa185c8..c18de027354a 100644
--- a/src/transformers/models/moonshine/modeling_moonshine.py
+++ b/src/transformers/models/moonshine/modeling_moonshine.py
@@ -309,7 +309,10 @@ def __init__(self, config: MoonshineConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py
index 7f685050930c..e9f6efa95e91 100644
--- a/src/transformers/models/moshi/modeling_moshi.py
+++ b/src/transformers/models/moshi/modeling_moshi.py
@@ -269,7 +269,7 @@ def forward(self, x, layer_idx=None):
         return self.linear(x)


-# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Moshi
+# Adapted from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Moshi
 class MoshiRotaryEmbedding(nn.Module):
     inv_freq: torch.Tensor  # fix linting for `register_buffer`

@@ -284,7 +284,10 @@ def __init__(self, config: MoshiConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index 2d6e6e7092cb..2bfe89424cec 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -267,7 +267,10 @@ def __init__(self, config: OlmoConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py
index 3fe4cfaf91de..967134fa4378 100644
--- a/src/transformers/models/olmo2/modeling_olmo2.py
+++ b/src/transformers/models/olmo2/modeling_olmo2.py
@@ -272,7 +272,10 @@ def __init__(self, config: Olmo2Config, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py
index fac2604bbf7e..7414f373c445 100644
--- a/src/transformers/models/olmoe/modeling_olmoe.py
+++ b/src/transformers/models/olmoe/modeling_olmoe.py
@@ -72,7 +72,10 @@ def __init__(self, config: OlmoeConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py
index c963bb53852a..07fd410c45b6 100644
--- a/src/transformers/models/persimmon/modeling_persimmon.py
+++ b/src/transformers/models/persimmon/modeling_persimmon.py
@@ -55,7 +55,7 @@
 logger = logging.get_logger(__name__)


-# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Persimmon
+# Adapted from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Persimmon
 class PersimmonRotaryEmbedding(nn.Module):
     inv_freq: torch.Tensor  # fix linting for `register_buffer`

diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index 165a2b887423..e975bc135cae 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -270,7 +270,10 @@ def __init__(self, config: PhiConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index 23820075a020..f69e6d83935e 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -313,7 +313,10 @@ def __init__(self, config: Phi3Config, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py
index bb495642c710..660eea59c2a7 100644
--- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py
+++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py
@@ -1491,7 +1491,10 @@ def __init__(self, config: Phi4MultimodalConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index 2fcb44372fe4..c78454730b41 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -284,7 +284,10 @@ def __init__(self, config: Qwen2Config, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
index c8b99164730e..01faeb560d64 100644
--- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -1241,7 +1241,10 @@ def __init__(self, config: Qwen2_5OmniThinkerConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
index a98574551922..5a2912fe682d 100644
--- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -502,7 +502,10 @@ def __init__(self, config: Qwen2_5_VLTextConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index 58d316e5a587..aa5adef80911 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -86,7 +86,10 @@ def __init__(self, config: Qwen2MoeConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index 269f37492ad6..c9883d53d386 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -121,7 +121,10 @@ def __init__(self, config: Qwen2VLTextConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py
index 81b16c4ee6b6..65b13acbd14a 100644
--- a/src/transformers/models/qwen3/modeling_qwen3.py
+++ b/src/transformers/models/qwen3/modeling_qwen3.py
@@ -310,7 +310,10 @@ def __init__(self, config: Qwen3Config, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
index d9568865be7e..2341bd10a685 100644
--- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
+++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -372,7 +372,10 @@ def __init__(self, config: Qwen3MoeConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py
index f10d24c96c46..1e57e1873a6b 100644
--- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py
+++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py
@@ -190,7 +190,10 @@ def __init__(self, config: Qwen3NextConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
index a008c68d4fdc..8020446493d2 100644
--- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
+++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
@@ -2398,7 +2398,10 @@ def __init__(self, config: Qwen3OmniMoeConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
@@ -3247,7 +3250,10 @@ def __init__(self, config: Qwen3OmniMoeConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/seed_oss/modeling_seed_oss.py b/src/transformers/models/seed_oss/modeling_seed_oss.py
index f0be87883d94..7b7e9d1c57e5 100644
--- a/src/transformers/models/seed_oss/modeling_seed_oss.py
+++ b/src/transformers/models/seed_oss/modeling_seed_oss.py
@@ -309,7 +309,10 @@ def __init__(self, config: SeedOssConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py
index 1e08e288193b..ae516c990d31 100644
--- a/src/transformers/models/smollm3/modeling_smollm3.py
+++ b/src/transformers/models/smollm3/modeling_smollm3.py
@@ -314,7 +314,10 @@ def __init__(self, config: SmolLM3Config, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py
index 6b31565a1b1d..aef0cb385fe2 100755
--- a/src/transformers/models/stablelm/modeling_stablelm.py
+++ b/src/transformers/models/stablelm/modeling_stablelm.py
@@ -59,7 +59,7 @@
 logger = logging.get_logger(__name__)


-# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->StableLm
+# Adapted from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->StableLm
 class StableLmRotaryEmbedding(nn.Module):
     inv_freq: torch.Tensor  # fix linting for `register_buffer`

diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index dfdfec22ca99..1a970398ef7c 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -265,7 +265,10 @@ def __init__(self, config: Starcoder2Config, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py
index 336e67ce42b6..5539284184dc 100644
--- a/src/transformers/models/t5gemma/modeling_t5gemma.py
+++ b/src/transformers/models/t5gemma/modeling_t5gemma.py
@@ -103,7 +103,10 @@ def __init__(self, config, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/vaultgemma/modeling_vaultgemma.py b/src/transformers/models/vaultgemma/modeling_vaultgemma.py
index eaad6c5335a4..fda15a18de12 100644
--- a/src/transformers/models/vaultgemma/modeling_vaultgemma.py
+++ b/src/transformers/models/vaultgemma/modeling_vaultgemma.py
@@ -302,7 +302,10 @@ def __init__(self, config: VaultGemmaConfig, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py
index dcef3c5237d7..1a9737fee7d3 100644
--- a/src/transformers/models/zamba2/modeling_zamba2.py
+++ b/src/transformers/models/zamba2/modeling_zamba2.py
@@ -222,7 +222,10 @@ def __init__(self, config: Zamba2Config, device=None):
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        if self.rope_type == "default":
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
+        else:
+            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
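For reference, a minimal standalone sketch (not part of the patch) of the lookup pattern every hunk above rewrites. `pick_rope_init_fn` and the stub registry are illustrative stand-ins, not transformers code; for `rope_type == "default"` both branches resolve to the same callable, the patched form only swaps the dynamic key for a literal one in the default case.

```python
# Standalone sketch of the branch each hunk introduces; the registry below is a
# hypothetical stand-in for transformers' real ROPE_INIT_FUNCTIONS mapping.
from typing import Callable, Dict


def _default_rope_init(config, device=None):
    # Placeholder for the real default RoPE inverse-frequency computation.
    return "inv_freq", 1.0


ROPE_INIT_FUNCTIONS: Dict[str, Callable] = {"default": _default_rope_init}


def pick_rope_init_fn(rope_type: str) -> Callable:
    # Mirrors the patched code: the "default" case indexes the registry with a
    # literal key; any other rope_type still goes through the dynamic lookup.
    if rope_type == "default":
        return ROPE_INIT_FUNCTIONS["default"]
    else:
        return ROPE_INIT_FUNCTIONS[rope_type]


# Both paths return the same callable for the default rope type.
assert pick_rope_init_fn("default") is ROPE_INIT_FUNCTIONS["default"]
```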