@@ -503,6 +503,159 @@ def forward(
         return query.flatten(-2), key.flatten(-2)
 
 
+def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
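+    # For illustration: with scale=16 and mscale=1.0 this evaluates to
+    # 0.1 * ln(16) + 1 ≈ 1.277, the magnitude correction that is later applied
+    # to the cos/sin cache in DeepseekScalingRotaryEmbedding below.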
+
+
+class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with YaRN method.
+
+    Credits to Peng et al. github.com/jquesnelle/yarn
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+        dtype: torch.dtype,
+        *,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
+        mscale: float = 1,
+        mscale_all_dim: float = 0,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        # Get n-d magnitude scaling corrected for interpolation.
+        self.mscale = float(
+            yarn_get_mscale(self.scaling_factor, float(mscale)) /
+            yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
+            attn_factor)
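+        # NOTE: the base class constructor below builds the cos/sin cache via
+        # _compute_cos_sin_cache(), which reads the YaRN attributes set above,
+        # so they must be assigned before calling super().__init__().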
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style, dtype)
+
+    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
+        pos_freqs = self.base**(torch.arange(
+            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
+                                self.rotary_dim)
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
+
+        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
+                                                self.rotary_dim, self.base,
+                                                self.max_position_embeddings)
+        # Get n-d rotational scaling corrected for extrapolation.
+        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
+            low, high, self.rotary_dim // 2,
+            dtype=torch.float)) * self.extrapolation_factor
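+        # inv_freq_mask is 1 for the high-frequency (low-index) dimensions,
+        # which keep their original "extrapolated" frequencies, and 0 for the
+        # low-frequency dimensions, which are interpolated by scaling_factor;
+        # the blend below ramps linearly between the two regimes.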
+        inv_freq = inv_freq_interpolation * (
+            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.scaling_factor)
+        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
+                         device="cuda",
+                         dtype=torch.float32)
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = (freqs.cos() * self.mscale)
+        sin = (freqs.sin() * self.mscale)
+        cache = torch.cat((cos, sin), dim=-1)
+        # cache shape: [max_position_embeddings * scaling_factor, rotary_dim],
+        # with cos values in the first half and sin values in the second.
+        return cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
585+ """PyTorch-native implementation equivalent to forward()."""
+        query_rot = query[..., :self.rotary_dim]
+        key_rot = key[..., :self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim:]
+            key_pass = key[..., self.rotary_dim:]
+
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
+            positions.device)
+        cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
+                                     if offsets is not None else positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
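+        # cos and sin each carry rotary_dim // 2 values per position; they are
+        # repeated below so that they line up with the rotary slice of the
+        # query/key tensors.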
+        if self.is_neox_style:
+            # NOTE(woosuk): Here we assume that the positions tensor has the
+            # shape [batch_size, seq_len].
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+
+        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
+        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+
+        if self.rotary_dim < self.head_size:
+            query = torch.cat((query_rot, query_pass), dim=-1)
+            key = torch.cat((key_rot, key_pass), dim=-1)
+        else:
+            query = query_rot
+            key = key_rot
+        return query, key
+
+
+class GemmaRotaryEmbedding(RotaryEmbedding):
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
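+        # The exponent is built from an int64 arange cast to float (rather than
+        # a float arange) to match the numerics of the HF Gemma implementation
+        # linked above.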
+        inv_freq = 1.0 / (base**(
+            torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() /
+            self.rotary_dim))
+        return inv_freq
+
+
+class ExtendedRotaryEmbedding(RotaryEmbedding):
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        inv_freqs = super()._compute_inv_freq(base)
+        return self.apply_scaling(inv_freqs)
+
+    def apply_scaling(self, freqs: torch.Tensor):
+        scale_factor = 8
+        low_freq_factor = 1
+        high_freq_factor = 4
+        old_context_len = 8192
+
+        low_freq_wavelen = old_context_len / low_freq_factor
+        high_freq_wavelen = old_context_len / high_freq_factor
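+        # Three regimes, split by wavelength: components with wavelengths
+        # shorter than high_freq_wavelen are kept as-is, those longer than
+        # low_freq_wavelen are slowed down by scale_factor, and everything in
+        # between is smoothly interpolated. The constants above correspond to
+        # the Llama 3.1 rope_scaling defaults (factor 8, original 8K context).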
+        new_freqs = []
+        for freq in freqs:
+            wavelen = 2 * math.pi / freq
+            if wavelen < high_freq_wavelen:
+                new_freqs.append(freq)
+            elif wavelen > low_freq_wavelen:
+                new_freqs.append(freq / scale_factor)
+            else:
+                assert low_freq_wavelen != high_freq_wavelen
+                smooth = (old_context_len / wavelen - low_freq_factor) / (
+                    high_freq_factor - low_freq_factor)
+                new_freqs.append((1 - smooth) * freq / scale_factor +
+                                 smooth * freq)
+        return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
 
 
@@ -534,10 +687,17 @@ def get_rope(
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style, dtype)
     else:
-        scaling_type = rope_scaling["type"]
-        if scaling_type != "su":
+        scaling_type = rope_scaling[
+            "type"] if "type" in rope_scaling else rope_scaling["rope_type"]
+        # The correct type name should be "longrope", but "su" is kept here
+        # for backward compatibility.
+        if scaling_type not in {"su", "longrope", "llama3"}:
             scaling_factor = rope_scaling["factor"]
-        if scaling_type == "linear":
+        if scaling_type == "llama3":
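+            # The llama3-style scaling parameters are currently hard-coded in
+            # ExtendedRotaryEmbedding.apply_scaling rather than read from the
+            # rope_scaling config.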
+            rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
+                                                 max_position, base,
+                                                 is_neox_style, dtype)
+        elif scaling_type == "linear":
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
                                                       is_neox_style,