Skip to content

Commit dcd6ad0

Browse files
WoosukKwon and gshtras
authored and committed
[BugFix] Fix RoPE error in Llama 3.1 (vllm-project#6693)
1 parent be817de commit dcd6ad0

File tree

2 files changed

+30
-30
lines changed

2 files changed

+30
-30
lines changed

vllm/config.py

Lines changed: 26 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -134,15 +134,6 @@ def __init__(
134134
self.hf_text_config = get_hf_text_config(self.hf_config)
135135
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
136136

137-
if (getattr(self.hf_config, "max_position_embeddings", 0) == 131072
138-
and getattr(self.hf_config, "rope_scaling", None) is None):
139-
# Note(simon): this is a special case for a model that doesn't
140-
# supply rope_scaling. We should remove this once the model is
141-
# updated.
142-
self.hf_config.update({"rope_scaling": {
143-
"type": "extended",
144-
}})
145-
146137
if (not self.disable_sliding_window
147138
and self.hf_text_config.model_type == "gemma2"
148139
and self.hf_text_config.sliding_window is not None):
@@ -1245,24 +1236,32 @@ def _get_and_verify_max_len(
12451236
derived_max_model_len = default_max_len
12461237

12471238
rope_scaling = getattr(hf_config, "rope_scaling", None)
1248-
# The correct one should be "longrope", kept "su" here
1249-
# to be backward compatible
1250-
if rope_scaling is not None and rope_scaling["type"] not in {
1251-
"su", "longrope", "extended"
1252-
}:
1253-
if disable_sliding_window:
1254-
# TODO(robertgshaw): Find a model that supports rope_scaling
1255-
# with sliding window to see if this case should be allowed.
1256-
raise NotImplementedError(
1257-
"Disabling sliding window is not supported for models "
1258-
"with rope_scaling. Please raise an issue so we can "
1259-
"investigate.")
1260-
assert "factor" in rope_scaling
1261-
scaling_factor = rope_scaling["factor"]
1262-
if rope_scaling["type"] == "yarn":
1263-
derived_max_model_len = rope_scaling[
1264-
"original_max_position_embeddings"]
1265-
derived_max_model_len *= scaling_factor
1239+
if rope_scaling is not None:
1240+
if "type" in rope_scaling:
1241+
rope_type = rope_scaling["type"]
1242+
elif "rope_type" in rope_scaling:
1243+
rope_type = rope_scaling["rope_type"]
1244+
else:
1245+
raise ValueError(
1246+
"rope_scaling must have a 'type' or 'rope_type' key.")
1247+
1248+
# The correct one should be "longrope", kept "su" here
1249+
# to be backward compatible
1250+
if rope_type not in ("su", "longrope", "llama3"):
1251+
if disable_sliding_window:
1252+
# TODO(robertgshaw): Find a model that supports rope_scaling
1253+
# with sliding window to see if this case should be allowed.
1254+
raise NotImplementedError(
1255+
"Disabling sliding window is not supported for models "
1256+
"with rope_scaling. Please raise an issue so we can "
1257+
"investigate.")
1258+
1259+
assert "factor" in rope_scaling
1260+
scaling_factor = rope_scaling["factor"]
1261+
if rope_type == "yarn":
1262+
derived_max_model_len = rope_scaling[
1263+
"original_max_position_embeddings"]
1264+
derived_max_model_len *= scaling_factor
12661265

12671266
# If the user specified a max length, make sure it is smaller than the
12681267
# derived length from the HF model config.

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -687,12 +687,13 @@ def get_rope(
687687
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
688688
is_neox_style, dtype)
689689
else:
690-
scaling_type = rope_scaling["type"]
690+
scaling_type = rope_scaling[
691+
"type"] if "type" in rope_scaling else rope_scaling["rope_type"]
691692
# The correct one should be "longrope" but keep "su" here
692693
# for backward compatible
693-
if scaling_type not in {"su", "longrope", "extended"}:
694+
if scaling_type not in {"su", "longrope", "llama3"}:
694695
scaling_factor = rope_scaling["factor"]
695-
if scaling_type == "extended":
696+
if scaling_type == "llama3":
696697
rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
697698
max_position, base,
698699
is_neox_style, dtype)

0 commit comments

Comments (0)