Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,15 +508,12 @@ def __init__(
):
super().__init__()

if rotary_dim != head_size:
raise ValueError(
f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
rotary_dim != head_size ({rotary_dim}!={head_size}).")
if is_neox_style is False:
raise ValueError(
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
)

self.rotary_dim = rotary_dim
self.head_size = head_size
self.max_position_embeddings = max_position_embeddings
self.original_max_position_embeddings = original_max_position_embeddings
Expand Down Expand Up @@ -556,7 +553,7 @@ def __init__(
def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
    """Return the rescaled inverse rotary frequencies.

    The frequencies are computed over the rotary dimension
    (``self.rotary_dim``) rather than the full head size, so the result
    has one entry per even index in ``range(0, self.rotary_dim, 2)``.

    Args:
        rescale_factors: Per-frequency scaling factors. They are
            multiplied element-wise with ``base ** (i / rotary_dim)``,
            so they must broadcast against that tensor.

    Returns:
        A float tensor holding
        ``1 / (rescale_factors * base ** (i / rotary_dim))`` for
        ``i = 0, 2, 4, ...`` below ``self.rotary_dim``.
    """
    rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
    # Exponents i / rotary_dim for every even index i of the rotated slice.
    exponents = torch.arange(
        0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
    inv_freq = 1.0 / (rescale_factors * (self.base**exponents))
    return inv_freq

def _compute_cos_sin_cache(
Expand Down Expand Up @@ -595,8 +592,15 @@ def forward(
cos = cos.repeat(1, 2).unsqueeze(-2)
sin = sin.repeat(1, 2).unsqueeze(-2)

query = query * cos + _rotate_neox(query) * sin
key = key * cos + _rotate_neox(key) * sin
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = query_rot * cos + _rotate_neox(query_rot) * sin
query = torch.cat((query_rot, query_pass), dim=-1)

key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = key_rot * cos + _rotate_neox(key_rot) * sin
key = torch.cat((key_rot, key_pass), dim=-1)

return query.flatten(-2), key.flatten(-2)

Expand Down
5 changes: 4 additions & 1 deletion vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ def __init__(self,
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
self.head_dim = getattr(config, "head_dim",
self.hidden_size // self.total_num_heads)
# Phi models introduced a partial_rotary_factor parameter in the config
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
self.rotary_dim = int(partial_rotary_factor * self.head_dim)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
Expand Down Expand Up @@ -159,7 +162,7 @@ def __init__(self,

self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
rotary_dim=self.rotary_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
Expand Down