@@ -134,15 +134,6 @@ def __init__(
134134 self .hf_text_config = get_hf_text_config (self .hf_config )
135135 self .dtype = _get_and_verify_dtype (self .hf_text_config , dtype )
136136
137- if (getattr (self .hf_config , "max_position_embeddings" , 0 ) == 131072
138- and getattr (self .hf_config , "rope_scaling" , None ) is None ):
139- # Note(simon): this is a special case for a model that doesn't
140- # supply rope_scaling. We should remove this once the model is
141- # updated.
142- self .hf_config .update ({"rope_scaling" : {
143- "type" : "extended" ,
144- }})
145-
146137 if (not self .disable_sliding_window
147138 and self .hf_text_config .model_type == "gemma2"
148139 and self .hf_text_config .sliding_window is not None ):
@@ -1245,24 +1236,32 @@ def _get_and_verify_max_len(
12451236 derived_max_model_len = default_max_len
12461237
12471238 rope_scaling = getattr (hf_config , "rope_scaling" , None )
1248- # The correct one should be "longrope", kept "su" here
1249- # to be backward compatible
1250- if rope_scaling is not None and rope_scaling ["type" ] not in {
1251- "su" , "longrope" , "extended"
1252- }:
1253- if disable_sliding_window :
1254- # TODO(robertgshaw): Find a model that supports rope_scaling
1255- # with sliding window to see if this case should be allowed.
1256- raise NotImplementedError (
1257- "Disabling sliding window is not supported for models "
1258- "with rope_scaling. Please raise an issue so we can "
1259- "investigate." )
1260- assert "factor" in rope_scaling
1261- scaling_factor = rope_scaling ["factor" ]
1262- if rope_scaling ["type" ] == "yarn" :
1263- derived_max_model_len = rope_scaling [
1264- "original_max_position_embeddings" ]
1265- derived_max_model_len *= scaling_factor
1239+ if rope_scaling is not None :
1240+ if "type" in rope_scaling :
1241+ rope_type = rope_scaling ["type" ]
1242+ elif "rope_type" in rope_scaling :
1243+ rope_type = rope_scaling ["rope_type" ]
1244+ else :
1245+ raise ValueError (
1246+ "rope_scaling must have a 'type' or 'rope_type' key." )
1247+
1248+ # The correct one should be "longrope", kept "su" here
1249+ # to be backward compatible
1250+ if rope_type not in ("su" , "longrope" , "llama3" ):
1251+ if disable_sliding_window :
1252+ # TODO(robertgshaw): Find a model that supports rope_scaling
1253+ # with sliding window to see if this case should be allowed.
1254+ raise NotImplementedError (
1255+ "Disabling sliding window is not supported for models "
1256+ "with rope_scaling. Please raise an issue so we can "
1257+ "investigate." )
1258+
1259+ assert "factor" in rope_scaling
1260+ scaling_factor = rope_scaling ["factor" ]
1261+ if rope_type == "yarn" :
1262+ derived_max_model_len = rope_scaling [
1263+ "original_max_position_embeddings" ]
1264+ derived_max_model_len *= scaling_factor
12661265
12671266 # If the user specified a max length, make sure it is smaller than the
12681267 # derived length from the HF model config.
0 commit comments