1 change: 1 addition & 0 deletions examples/offline_inference/eagle.py
@@ -76,6 +76,7 @@ def main():
max_num_seqs=args.max_num_seqs,
gpu_memory_utilization=0.8,
speculative_config={
"method": "eagle",
"model": eagle_dir,
"num_speculative_tokens": args.num_spec_tokens,
"draft_tensor_parallel_size": args.draft_tp,
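For quick reference, here is a minimal, self-contained sketch of launching EAGLE speculative decoding offline, distilled from the example above and the e2e test added below. The model names, num_speculative_tokens value, and prompt are illustrative, and setting VLLM_USE_V1=1 mirrors how the test enables the V1 engine.

import os

os.environ["VLLM_USE_V1"] = "1"  # the e2e test below enables the V1 engine the same way

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
        "num_speculative_tokens": 3,
    },
    max_model_len=1024,
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0, max_tokens=32))
print(outputs[0].outputs[0].text)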
4 changes: 4 additions & 0 deletions tests/models/registry.py
@@ -374,6 +374,10 @@ def check_available_online(
"DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random",
speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501
trust_remote_code=True),
"EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B",
trust_remote_code=True,
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501
}

_TRANSFORMERS_MODELS = {
@@ -53,6 +53,11 @@ def model_name():
return "meta-llama/Meta-Llama-3-8B-Instruct"


@pytest.fixture
def eagle_model_name():
return "yuhuili/EAGLE-LLaMA3-Instruct-8B"


def test_ngram_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
@@ -95,3 +100,47 @@ def test_ngram_correctness(
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs))
del spec_llm


def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
eagle_model_name: str,
):
'''
Compare the outputs of the original LLM and the speculative LLM.
They should be the same when using EAGLE speculative decoding.
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

ref_llm = LLM(model=model_name, max_model_len=1024)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm

spec_llm = LLM(
model=model_name,
speculative_config={
"method": "eagle",
"model": eagle_model_name,
"num_speculative_tokens": 3,
},
max_model_len=1024,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")

# Heuristic: expect at least 70% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs))
del spec_llm
4 changes: 2 additions & 2 deletions vllm/model_executor/model_loader/loader.py
@@ -414,7 +414,7 @@ def _hpu_weights_iterator(iterator: Generator):
return ((source.prefix + name, tensor)
for (name, tensor) in weights_iterator)

def _get_all_weights(
def get_all_weights(
self,
model_config: ModelConfig,
model: nn.Module,
@@ -453,7 +453,7 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:

weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(
self._get_all_weights(model_config, model))
self.get_all_weights(model_config, model))
self.counter_after_loading_weights = time.perf_counter()
logger.info(
"Loading weights took %.2f seconds",
151 changes: 151 additions & 0 deletions vllm/model_executor/models/llama_eagle.py
@@ -0,0 +1,151 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Iterable, Set, Tuple

import torch
import torch.nn as nn
from transformers import LlamaConfig

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
LlamaForCausalLM)

from .utils import AutoWeightsLoader, maybe_prefix

logger = init_logger(__name__)


class LlamaDecoderLayer(LlamaDecoderLayer):

def __init__(
self,
config: LlamaConfig,
disable_input_layernorm: bool,
prefix: str = "",
) -> None:
super().__init__(config, prefix=prefix)

# Skip the input_layernorm
# https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
if disable_input_layernorm:
del self.input_layernorm
self.input_layernorm = nn.Identity()


class LlamaModel(nn.Module):

def __init__(
self,
*,
model_config: ModelConfig,
start_layer_id: int = 0,
prefix: str = "",
) -> None:
super().__init__()
self.config = model_config.hf_config
self.vocab_size = self.config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
Collaborator: QQ: Doesn't Eagle share the same vocab embedding with the original model?

Collaborator (Author): Yes, it's the same; the EAGLE head checkpoint also includes the embed_tokens weights, so we just load them here. We could also set it to the target model's embed_tokens. I double-checked the weights of the two; they are identical.

Collaborator: Thanks for checking. I'm OK with having it for now, but maybe worth a comment?

Contributor: If we don't set it to the target model's embed_tokens, aren't we loading the vocab embedding twice, which takes more GPU memory for larger vocabularies?

Collaborator (Author): @ekagra-ranjan Yes, for now. We can share the embedding when the PP size is 1, but I was thinking we can do this in a follow-up PR as an optimization.

Contributor: Fixed by: #17326

(A hedged sketch of this embedding-sharing follow-up appears after this file's diff.)

self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)
self.layers = nn.ModuleList([
LlamaDecoderLayer(
self.config,
i == 0,
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
) for i in range(self.config.num_hidden_layers)
])
self.fc = torch.nn.Linear(self.config.hidden_size * 2,
self.config.hidden_size,
bias=False)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
input_embeds = self.embed_tokens(input_ids)
hidden_states = self.fc(
torch.cat((input_embeds, hidden_states), dim=-1))
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
residual,
)
return hidden_states + residual

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params


class EagleLlamaForCausalLM(LlamaForCausalLM):

def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
nn.Module.__init__(self)
self.config = model_config.hf_config
self.model = LlamaModel(model_config=model_config,
start_layer_id=start_layer_id,
prefix="model")

logit_scale = getattr(self.config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.config.vocab_size,
scale=logit_scale)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
return self.model(input_ids, positions, hidden_states)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)

model_weights = {}
for name, loaded_weight in weights:
if "lm_head" not in name:
name = "model." + name
model_weights[name] = loaded_weight

loader.load_weights(model_weights.items())
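The review thread above ends with the embedding-sharing follow-up (#17326). As a rough illustration only, not code from this PR: once the draft weights are loaded, the duplicated vocab embedding could be dropped and the target model's table reused, mirroring how lm_head is rebound in EagleProposer.load_model() later in this diff. The share_embed_tokens helper and the model.embed_tokens attribute path are assumptions based on the usual LlamaForCausalLM layout.

def share_embed_tokens(draft_model, target_model) -> None:
    """Reuse the target model's vocab embedding in the EAGLE head (sketch)."""
    # Drop the draft's own copy first so the duplicated table can be freed,
    # then rebind to the target model's embedding. This is valid when both
    # live on the same device, e.g. PP size 1 as discussed above.
    del draft_model.model.embed_tokens
    draft_model.model.embed_tokens = target_model.model.embed_tokens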
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -206,6 +206,7 @@

_SPECULATIVE_DECODING_MODELS = {
"EAGLEModel": ("eagle", "EAGLE"),
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
5 changes: 4 additions & 1 deletion vllm/transformers_utils/configs/eagle.py
@@ -5,6 +5,7 @@

from transformers import AutoConfig, PretrainedConfig

import vllm.envs as envs
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config


@@ -41,8 +42,10 @@ def __init__(self,
self.truncated_vocab_size = self.model.vocab_size if \
truncated_vocab_size is None else truncated_vocab_size

if "architectures" not in kwargs:
if not envs.VLLM_USE_V1:
kwargs["architectures"] = ["EAGLEModel"]
else:
kwargs["architectures"] = ["EagleLlamaForCausalLM"]

super().__init__(**kwargs)

61 changes: 36 additions & 25 deletions vllm/v1/spec_decode/eagle.py
@@ -4,8 +4,11 @@
import triton
import triton.language as tl

from vllm.config import VllmConfig
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata

@@ -21,8 +24,12 @@ def __init__(
self.num_speculative_tokens = (
vllm_config.speculative_config.num_speculative_tokens)
self.block_size = vllm_config.cache_config.block_size
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs,
device=device)
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
1,
device=device,
dtype=torch.int32)

def propose(
self,
@@ -54,7 +61,9 @@ def propose(
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
input_ids[last_token_indices] = next_token_ids

seq_lens = target_positions[last_token_indices] + 1
# FA requires seq_len to have dtype int32.
seq_lens = (target_positions[last_token_indices] + 1).int()

# FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len = seq_lens.max().item()
max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
@@ -98,7 +107,7 @@ def propose(
hidden_states = sample_hidden_states
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size]
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
for _ in range(self.num_speculative_tokens - 1):
# Update the inputs.
input_ids = draft_token_ids_list[-1]
@@ -176,26 +185,28 @@ def prepare_inputs(
return cu_num_tokens, token_indices

def load_model(self, target_model: nn.Module) -> None:
self.model = DummyEagleModel()
self.model.get_input_embeddings = target_model.get_input_embeddings
self.model.compute_logits = target_model.compute_logits


# FIXME(woosuk): This is a dummy model for testing.
# Remove this once we have a real model.
class DummyEagleModel(nn.Module):

def __init__(self):
super().__init__()

def forward(
self,
input_ids: torch.Tensor,
hidden_states: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
input_embeddings = self.get_input_embeddings(input_ids)
return hidden_states + input_embeddings # Dummy return.
loader = get_model_loader(self.vllm_config.load_config)
target_layer_num = self.vllm_config.model_config.get_num_layers(
self.vllm_config.parallel_config)

draft_model_config = \
self.vllm_config.speculative_config.draft_model_config
# FIXME(lily): This does not handle distributed inference.
target_device = self.vllm_config.device_config.device
# We need to set the vllm_config here to register attention
# layers in the forward context.
Comment on lines +196 to +197:

Contributor (@ekagra-ranjan, Apr 4, 2025): Do we need to call load_model() from __init__() so that this function runs and the attention layers are registered?

Collaborator (Author): Could you elaborate a bit on which load_model you are talking about?

with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
Contributor (@ekagra-ranjan, Apr 4, 2025):
1. I didn't get the part which says that setting the current vllm config will lead to registering attention layers. Can you please share more? My understanding is that we did not change any self.vllm_config in this function, and the attention layers are registered when the model is initialized, where the attention prefixes are saved in static_forward_context, which is then used during bind_kv_cache().
2. So if load_model() is called in __init__ in this file, would it not register the attention layers without the need for set_current_vllm_config(self.vllm_config)?

Collaborator (Author) (@LiuXiaoxuanPKU, Apr 8, 2025): Which load_model are you referring to? This one?

Contributor (@ekagra-ranjan, Apr 9, 2025): The one in spec_decode/eagle.py here: https://github.com/vllm-project/vllm/pull/16035/files/59ee450306d3d719f78ad60c77ba9b739bc5cb11#diff-a4809a837fbf535a8f0999b11087a53ec1c53948b50c0a1fe64396bc86de9461R184
I have broken my above question into two parts, along with my understanding, so that it is easier for you to explain what I am missing. Looking forward to your response.
self.vllm_config):
self.model = EagleLlamaForCausalLM(
model_config=draft_model_config,
start_layer_id=target_layer_num).to(target_device)

self.model.load_weights(
loader.get_all_weights(
self.vllm_config.speculative_config.draft_model_config,
self.model))
self.model.lm_head = target_model.lm_head


# FIXME(woosuk): The logic here is duplicated with the main sampling code.
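For context on the + 1 in the arange above and the self.arange[:batch_size + 1] slice: query_start_loc holds cumulative query offsets (CSR style), so it always has batch_size + 1 entries. A toy illustration with made-up lengths:

import torch

# Illustrative only: three requests with query lengths [4, 2, 3].
query_lens = torch.tensor([4, 2, 3], dtype=torch.int32)
query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
query_start_loc[1:] = torch.cumsum(query_lens, dim=0)
print(query_start_loc)  # tensor([0, 4, 6, 9], dtype=torch.int32): batch_size + 1 entries

# Inside the drafting loop every request contributes exactly one query token,
# so the offsets degenerate to [0, 1, ..., batch_size], which is exactly
# self.arange[:batch_size + 1] in the code above.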
7 changes: 5 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1191,9 +1191,12 @@ def execute_model(

if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
# We need to slice token_ids, positions, and hidden_states
# because the eagle head does not use cuda graph and should
# not include padding.
target_token_ids = self.input_ids[:num_scheduled_tokens]
target_positions = positions
target_hidden_states = hidden_states
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states[:num_scheduled_tokens]
target_slot_mapping = attn_metadata.slot_mapping
cu_num_tokens = attn_metadata.query_start_loc
else:
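As a toy illustration of why the slices above matter (all sizes are hypothetical): with CUDA graphs the main model may run on a padded batch, while the EAGLE head runs eagerly and must only see the real tokens.

import torch

num_scheduled_tokens = 5   # real tokens in this step
padded_num_tokens = 8      # nearest captured CUDA-graph size (made up)
hidden_size = 4096         # made up

positions = torch.arange(padded_num_tokens)
hidden_states = torch.randn(padded_num_tokens, hidden_size)  # padded main-model output

# Drop the padding before feeding the EAGLE head.
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states[:num_scheduled_tokens]
assert target_hidden_states.shape == (num_scheduled_tokens, hidden_size)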