[V1][Spec Decode] Eagle Model loading #16035
vllm/model_executor/models/llama_eagle.py (new file)

@@ -0,0 +1,151 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Iterable, Set, Tuple

import torch
import torch.nn as nn
from transformers import LlamaConfig

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
                                              LlamaForCausalLM)

from .utils import AutoWeightsLoader, maybe_prefix

logger = init_logger(__name__)


class LlamaDecoderLayer(LlamaDecoderLayer):

    def __init__(
        self,
        config: LlamaConfig,
        disable_input_layernorm: bool,
        prefix: str = "",
    ) -> None:
        super().__init__(config, prefix=prefix)

        # Skip the input_layernorm
        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
        if disable_input_layernorm:
            del self.input_layernorm
            self.input_layernorm = nn.Identity()


class LlamaModel(nn.Module):

    def __init__(
        self,
        *,
        model_config: ModelConfig,
        start_layer_id: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = model_config.hf_config
        self.vocab_size = self.config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(
                self.config,
                i == 0,
                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
            ) for i in range(self.config.num_hidden_layers)
        ])
        self.fc = torch.nn.Linear(self.config.hidden_size * 2,
                                  self.config.hidden_size,
                                  bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        input_embeds = self.embed_tokens(input_ids)
        hidden_states = self.fc(
            torch.cat((input_embeds, hidden_states), dim=-1))
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        return hidden_states + residual

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class EagleLlamaForCausalLM(LlamaForCausalLM):

    def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
        nn.Module.__init__(self)
        self.config = model_config.hf_config
        self.model = LlamaModel(model_config=model_config,
                                start_layer_id=start_layer_id,
                                prefix="model")

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                                scale=logit_scale)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        return self.model(input_ids, positions, hidden_states)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )

        model_weights = {}
        for name, loaded_weight in weights:
            if "lm_head" not in name:
                name = "model." + name
            model_weights[name] = loaded_weight

        loader.load_weights(model_weights.items())
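
To make the data flow concrete: the draft model takes both the current token ids and the target model's hidden states, concatenates the token embedding with the target hidden state, and projects the result back to hidden_size through self.fc before it enters the decoder layers. A standalone sketch of that fusion step (the hidden size, vocab size, and token count are illustrative, and a plain nn.Embedding stands in for VocabParallelEmbedding):

# Standalone sketch of the embedding/hidden-state fusion in LlamaModel.forward.
# hidden_size=4096 and vocab_size=32000 are illustrative values, not taken from
# the PR; nn.Embedding stands in for VocabParallelEmbedding.
import torch
import torch.nn as nn

hidden_size, vocab_size, num_tokens = 4096, 32000, 8

embed_tokens = nn.Embedding(vocab_size, hidden_size)
fc = nn.Linear(hidden_size * 2, hidden_size, bias=False)

input_ids = torch.randint(0, vocab_size, (num_tokens,))
target_hidden_states = torch.randn(num_tokens, hidden_size)  # from the target model

input_embeds = embed_tokens(input_ids)                        # [num_tokens, hidden_size]
fused = fc(torch.cat((input_embeds, target_hidden_states), dim=-1))
print(fused.shape)  # torch.Size([8, 4096]) -> fed into the Eagle decoder layer(s)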
vllm/v1/spec_decode/eagle.py

@@ -4,8 +4,11 @@
 import triton
 import triton.language as tl

-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.forward_context import set_forward_context
+from vllm.model_executor.model_loader.loader import get_model_loader
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata

@@ -21,8 +24,12 @@ def __init__(
         self.num_speculative_tokens = (
             vllm_config.speculative_config.num_speculative_tokens)
         self.block_size = vllm_config.cache_config.block_size
-        self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs,
-                                   device=device)
+        # We need +1 here because the arange is used to set query_start_loc,
+        # which has one more element than batch_size.
+        self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
+                                   1,
+                                   device=device,
+                                   dtype=torch.int32)

     def propose(
         self,
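
The comment above is the whole story: query_start_loc holds cumulative query offsets, so B single-token queries need B + 1 entries. A small standalone illustration (the max_num_seqs and batch_size values are made up); the slicing mirrors the query_start_loc assignment later in this diff:

# Illustration of why the arange needs max_num_seqs + 1 entries.
# During the speculative loop each request contributes exactly one query token,
# so query_start_loc is simply [0, 1, ..., batch_size].
import torch

max_num_seqs = 4          # illustrative value
batch_size = 3            # illustrative value

arange = torch.arange(max_num_seqs + 1, dtype=torch.int32)
query_start_loc = arange[:batch_size + 1]
print(query_start_loc)    # tensor([0, 1, 2, 3], dtype=torch.int32): 4 entries for 3 queries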
@@ -54,7 +61,9 @@ def propose(
         # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
         input_ids[last_token_indices] = next_token_ids

-        seq_lens = target_positions[last_token_indices] + 1
+        # FA requires seq_len to have dtype int32.
+        seq_lens = (target_positions[last_token_indices] + 1).int()
+
         # FIXME(woosuk): The below two ops cause synchronization. Optimize.
         max_seq_len = seq_lens.max().item()
         max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
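
For reference on the dtype cast above: target_positions is a regular int64 tensor, while FlashAttention's metadata expects int32 sequence lengths, so the .int() call narrows the dtype. A trivial standalone check (the position values are made up):

# Illustrative only: show the int64 -> int32 narrowing done for FlashAttention.
import torch

target_positions = torch.tensor([5, 9, 12])   # defaults to torch.int64
seq_lens = (target_positions + 1).int()
print(seq_lens, seq_lens.dtype)               # tensor([ 6, 10, 13], dtype=torch.int32) torch.int32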
@@ -98,7 +107,7 @@ def propose(
         hidden_states = sample_hidden_states
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
-        attn_metadata.query_start_loc = self.arange[:batch_size]
+        attn_metadata.query_start_loc = self.arange[:batch_size + 1]
         for _ in range(self.num_speculative_tokens - 1):
             # Update the inputs.
             input_ids = draft_token_ids_list[-1]
@@ -176,26 +185,28 @@ def prepare_inputs(
         return cu_num_tokens, token_indices

     def load_model(self, target_model: nn.Module) -> None:
-        self.model = DummyEagleModel()
-        self.model.get_input_embeddings = target_model.get_input_embeddings
-        self.model.compute_logits = target_model.compute_logits
-
-
-# FIXME(woosuk): This is a dummy model for testing.
-# Remove this once we have a real model.
-class DummyEagleModel(nn.Module):
-
-    def __init__(self):
-        super().__init__()
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        hidden_states: torch.Tensor,
-        positions: torch.Tensor,
-    ) -> torch.Tensor:
-        input_embeddings = self.get_input_embeddings(input_ids)
-        return hidden_states + input_embeddings  # Dummy return.
+        loader = get_model_loader(self.vllm_config.load_config)
+        target_layer_num = self.vllm_config.model_config.get_num_layers(
+            self.vllm_config.parallel_config)
+
+        draft_model_config = \
+            self.vllm_config.speculative_config.draft_model_config
+        # FIXME(lily): This does not handle with distributed inference.
+        target_device = self.vllm_config.device_config.device
+        # We need to set the vllm_config here to register attention
+        # layers in the forward context.
+        with set_default_torch_dtype(
+                draft_model_config.dtype), set_current_vllm_config(
+                    self.vllm_config):
+            self.model = EagleLlamaForCausalLM(
+                model_config=draft_model_config,
+                start_layer_id=target_layer_num).to(target_device)
+
+        self.model.load_weights(
+            loader.get_all_weights(
+                self.vllm_config.speculative_config.draft_model_config,
+                self.model))
+        self.model.lm_head = target_model.lm_head


 # FIXME(woosuk): The logic here is duplicated with the main sampling code.

Comment on lines +196 to +197:

"Do we need to call [...]?"
"Could you elaborate a bit on which [...]?"
"The one in this file."

Comment on the set_default_torch_dtype / set_current_vllm_config block:

"My understanding is that we did not change any [...]"
"Which [...]?"
"The one in spec_decode/eagle.py here:
https://github.com/vllm-project/vllm/pull/16035/files/59ee450306d3d719f78ad60c77ba9b739bc5cb11#diff-a4809a837fbf535a8f0999b11087a53ec1c53948b50c0a1fe64396bc86de9461R184
I have broken my above question into two parts along with my understanding so that it is easier for you to explain what I am missing. Looking forward to your response."
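
A note on start_layer_id in the new load_model above: it is set to the target model's layer count and is only used by LlamaModel to build the layer prefixes (layers.{i + start_layer_id}), which presumably keeps the draft decoder layer's attention module from sharing a name with any target layer registered in the same forward context. A toy illustration (the 32-layer target and single-layer Eagle head are assumptions, not values from the PR):

# Hypothetical illustration of the prefixes produced for the draft layers.
# target_layer_num=32 and num_draft_layers=1 are assumptions for the example.
target_layer_num = 32      # e.g. a 32-layer target model
num_draft_layers = 1       # the Eagle head typically has a single decoder layer

prefixes = [
    f"model.layers.{i + target_layer_num}" for i in range(num_draft_layers)
]
print(prefixes)            # ['model.layers.32'] -- does not overlap target layers 0..31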
Review thread on loading the embed_tokens weights:

"QQ: Doesn't Eagle share the same vocab embedding with the original model?"

"Yes, it's the same; it's just that the Eagle head also includes the embed_tokens weights, so we just load them here. We can also set it to the embed_tokens of the target model. I double-checked the weights of the two; they are the same."

"Thanks for checking it. I'm OK with having it for now, but maybe worth a comment?"

"If we don't set it to the embed_tokens of the target model, then we are loading the vocab twice, which takes more GPU memory for a larger vocab?"

"@ekagra-ranjan Yes, for now. We can share the embedding when the PP size is 1, but I was thinking that we can do this in a follow-up PR as an optimization."

Fixed by: #17326
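
As the thread above notes, the draft head currently loads its own copy of embed_tokens even though the weights are identical to the target model's. A minimal sketch of the sharing optimization being discussed (the attribute paths assume a LlamaForCausalLM-style target; the actual change landed in #17326 and may differ):

# Sketch only: reuse the target model's input embedding instead of keeping a
# second copy of the vocab embedding on the GPU. Attribute names assume a
# LlamaForCausalLM-style target model; see #17326 for the real change.
def share_input_embedding(draft_model, target_model):
    # Drop the separately loaded copy and point at the target's embedding.
    draft_model.model.embed_tokens = target_model.model.embed_tokens

# Usage sketch: call this right after the draft weights have been loaded,
# alongside the existing lm_head sharing in EagleProposer.load_model.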