diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index c7b4368c9b13..474b745a6106 100644 --- a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -52,8 +52,8 @@ def main(): args = parse_args() - model_dir = "meta-llama/Meta-Llama-3-8B-Instruct" - eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm" + model_dir = "meta-llama/Llama-3.1-8B-Instruct" + eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" max_model_len = 2048 @@ -81,7 +81,7 @@ def main(): max_num_seqs=args.max_num_seqs, gpu_memory_utilization=0.8, speculative_config={ - "method": "eagle", + "method": "eagle3" if "eagle3" in eagle_dir.lower() else "eagle", "model": eagle_dir, "num_speculative_tokens": args.num_spec_tokens, "draft_tensor_parallel_size": args.draft_tp, @@ -95,6 +95,9 @@ def main(): outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params) + if not hasattr(outputs, "metrics") or outputs.metrics is None: + return + # calculate the average number of accepted tokens per forward pass, +1 is # to account for the token from the target model that's always going to be # accepted @@ -109,6 +112,11 @@ def main(): {sum(acceptance_counts) / acceptance_counts[0]:.2f}") print("-" * 50) + # print acceptance at each token position + for i in range(len(acceptance_counts)): + print(f"acceptance at token {i}:" + f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}") + if __name__ == "__main__": main() diff --git a/tests/models/registry.py b/tests/models/registry.py index c15ae3619844..2f7ddfd907ed 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -392,6 +392,10 @@ def check_available_online( trust_remote_code=True, speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501 + "Eagle3LlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", # noqa: E501 + trust_remote_code=True, + speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", + tokenizer="meta-llama/Llama-3.1-8B-Instruct"), } _TRANSFORMERS_MODELS = { diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 673714980592..485b011acc00 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -50,12 +50,15 @@ def sampling_config(): @pytest.fixture def model_name(): - return "meta-llama/Meta-Llama-3-8B-Instruct" + return "meta-llama/Llama-3.1-8B-Instruct" -@pytest.fixture def eagle_model_name(): - return "yuhuili/EAGLE-LLaMA3-Instruct-8B" + return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" + + +def eagle3_model_name(): + return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" def test_ngram_correctness( @@ -102,12 +105,13 @@ def test_ngram_correctness( del spec_llm +@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"]) def test_eagle_correctness( monkeypatch: pytest.MonkeyPatch, test_prompts: list[list[dict[str, Any]]], sampling_config: SamplingParams, model_name: str, - eagle_model_name: str, + use_eagle3: bool, ): ''' Compare the outputs of a original LLM and a speculative LLM @@ -116,18 +120,22 @@ def test_eagle_correctness( with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - ref_llm = LLM(model=model_name, max_model_len=1024) + ref_llm = LLM(model=model_name, max_model_len=2048) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm + spec_model_name = eagle3_model_name( + ) if use_eagle3 else eagle_model_name() spec_llm = LLM( model=model_name, + trust_remote_code=True, speculative_config={ - 
"method": "eagle", - "model": eagle_model_name, + "method": "eagle3" if use_eagle3 else "eagle", + "model": spec_model_name, "num_speculative_tokens": 3, + "max_model_len": 2048, }, - max_model_len=1024, + max_model_len=2048, ) spec_outputs = spec_llm.chat(test_prompts, sampling_config) matches = 0 diff --git a/vllm/config.py b/vllm/config.py index 3e5a17802f0f..cffaa2942b8b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2338,9 +2338,10 @@ def __post_init__(self): ) # Automatically detect the method - if self.method == 'eagle': + if self.method in ('eagle', 'eagle3'): pass - elif "eagle-" in self.draft_model_config.model.lower(): + elif "eagle-" in self.draft_model_config.model.lower() or \ + "eagle3-" in self.draft_model_config.model.lower(): self.method = "eagle" elif self.draft_model_config.hf_config.model_type == "medusa": self.method = "medusa" @@ -2351,7 +2352,7 @@ def __post_init__(self): self.method = "draft_model" # Replace hf_config for EAGLE draft_model - if self.method == "eagle": + if self.method in ("eagle", "eagle3"): if self.enable_chunked_prefill and not envs.VLLM_USE_V1: raise ValueError( "Chunked prefill and EAGLE are not compatible " @@ -2548,6 +2549,12 @@ def _verify_args(self) -> None: "speculative decoding is > 1, but got " f"{self.disable_by_batch_size=}") + if self.method == "eagle3" and self.target_model_config and \ + "llama" not in self.target_model_config.hf_text_config.model_type: + raise ValueError( + "Eagle3 is only supported for Llama models. " + f"Got {self.target_model_config.hf_text_config.model_type=}") + @property def num_lookahead_slots(self) -> int: """The number of additional slots the scheduler should allocate per diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6d6b5ac02b14..5b110d7c1ecc 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1456,7 +1456,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: if speculative_method: if speculative_method in ("ngram", "[ngram]"): is_ngram_enabled = True - elif speculative_method == "eagle": + elif speculative_method in ("eagle", "eagle3"): is_eagle_enabled = True else: speculative_model = self.speculative_config.get("model") diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 31ffa4e1e63c..17d080fa5a28 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -330,6 +330,8 @@ def __init__(self, else: self.norm = PPMissingLayer() + self.aux_hidden_state_layers: tuple[int] = tuple() + self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) @@ -355,7 +357,11 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + aux_hidden_states = [] + for idx, layer in enumerate( + self.layers[self.start_layer:self.end_layer]): + if idx in self.aux_hidden_state_layers: + aux_hidden_states.append(hidden_states + residual) hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: @@ -365,6 +371,9 @@ def forward( }) hidden_states, _ = self.norm(hidden_states, residual) + + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states return hidden_states def load_weights(self, weights: Iterable[Tuple[str, @@ -517,6 +526,13 @@ def __init__(self, self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) 
+ def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + def _init_model(self, vllm_config: VllmConfig, prefix: str = "", diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 28ad6128c4f1..06f7cb08a7c8 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -82,7 +82,8 @@ def forward( hidden_states, residual, ) - return hidden_states + residual + hidden_states = hidden_states + residual + return hidden_states, hidden_states def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py new file mode 100644 index 000000000000..ffbb9d75a06b --- /dev/null +++ b/vllm/model_executor/models/llama_eagle3.py @@ -0,0 +1,232 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Iterable, Optional, Set, Tuple + +import torch +import torch.nn as nn +from transformers import LlamaConfig + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import QKVParallelLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama import (LlamaDecoderLayer, + LlamaForCausalLM) +from vllm.v1.sample.metadata import SamplingMetadata + +from .utils import AutoWeightsLoader, maybe_prefix + +logger = init_logger(__name__) + + +class LlamaDecoderLayer(LlamaDecoderLayer): + + def __init__( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, quant_config=quant_config, prefix=prefix) + + # override qkv + self.self_attn.qkv_proj = QKVParallelLinear( + 2 * self.hidden_size, + self.self_attn.head_dim, + self.self_attn.total_num_heads, + self.self_attn.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "qkv_proj"), + ) + + self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + embeds: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + + residual = hidden_states + embeds = self.input_layernorm(embeds) + hidden_states = self.hidden_norm(hidden_states) + + hidden_states = torch.cat([embeds, hidden_states], dim=-1) + # Self Attention + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + # Fully Connected + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +class LlamaModel(nn.Module): + + def __init__( + self, + *, + model_config: ModelConfig, + start_layer_id: int = 0, + prefix: str = "", + ) -> None: + super().__init__() + self.config = model_config.hf_config + self.vocab_size = 
self.config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "embed_tokens"), + ) + self.layers = nn.ModuleList([ + LlamaDecoderLayer( + self.config, + prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"), + ) + ]) + if hasattr(self.config, "target_hidden_size"): + self.fc = torch.nn.Linear(self.config.target_hidden_size * 3, + self.config.hidden_size, + bias=False) + else: + self.fc = torch.nn.Linear(self.config.hidden_size * 3, + self.config.hidden_size, + bias=False) + self.norm = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + input_embeds = self.embed_tokens(input_ids) + if (hidden_states.shape[-1] != input_embeds.shape[-1]): + hidden_states = self.fc(hidden_states) + + residual = None + hidden_states, residual = self.layers[0]( + positions, + input_embeds, + hidden_states, + residual, + ) + + hidden_states, hidden_prenorm = self.norm(hidden_states, residual) + return hidden_states, hidden_prenorm + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if 'midlayer.' in name: + name = name.replace('midlayer.', 'layers.0.') + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Eagle3LlamaForCausalLM(LlamaForCausalLM): + + def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0): + nn.Module.__init__(self) + self.config = model_config.hf_config + self.model = LlamaModel(model_config=model_config, + start_layer_id=start_layer_id, + prefix="model") + + logit_scale = getattr(self.config, "logit_scale", 1.0) + self.lm_head = ParallelLMHead( + self.config.draft_vocab_size, + self.config.hidden_size, + org_num_embeddings=self.config.draft_vocab_size, + padding_size=(DEFAULT_VOCAB_PADDING_SIZE), + prefix="") + self.logits_processor = LogitsProcessor(self.config.draft_vocab_size, + scale=logit_scale) + self.draft_id_to_target_id = nn.Parameter( + torch.zeros((self.config.draft_vocab_size), + dtype=torch.long).type(torch.LongTensor), + requires_grad=False, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + return self.model(input_ids, positions, hidden_states) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + base = torch.arange(self.config.draft_vocab_size, device=logits.device) + targets = base + self.draft_id_to_target_id + logits_new = logits.new_full(( + 
logits.shape[0], + self.config.vocab_size, + ), float('-inf')) + logits_new[:, targets] = logits + return logits_new + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + skip_prefixes=None, + ) + + model_weights = {} + for name, loaded_weight in weights: + if "t2d" in name: + continue + if "d2t" in name: + name = name.replace("d2t", "draft_id_to_target_id") + elif "lm_head" not in name: + name = "model." + name + model_weights[name] = loaded_weight + + return loader.load_weights(model_weights.items()) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 621b9d69faa5..11e663e32d45 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -214,6 +214,7 @@ _SPECULATIVE_DECODING_MODELS = { "EAGLEModel": ("eagle", "EAGLE"), "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"), + "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 44dd9b026c2d..adec4462963c 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -126,7 +126,7 @@ def __init__( self.num_spec_tokens = self.num_lookahead_tokens = 0 if speculative_config: self.num_spec_tokens = speculative_config.num_speculative_tokens - if speculative_config.method == "eagle": + if speculative_config.method in ("eagle", "eagle3"): self.num_lookahead_tokens = self.num_spec_tokens def schedule(self) -> SchedulerOutput: diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 95f0c067d406..1de14584d396 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -6,12 +6,16 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.forward_context import set_forward_context +from vllm.logger import init_logger from vllm.model_executor.model_loader.loader import get_model_loader from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM +from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.sample.metadata import SamplingMetadata +logger = init_logger(__name__) + PADDING_SLOT_ID = -1 @@ -87,12 +91,12 @@ def propose( ) with set_forward_context(attn_metadata, self.vllm_config): - hidden_states = self.model( + hidden_states_logits, hidden_states_fwd = self.model( input_ids=input_ids, hidden_states=target_hidden_states, positions=target_positions, ) - sample_hidden_states = hidden_states[last_token_indices] + sample_hidden_states = hidden_states_logits[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) draft_token_ids = logits.argmax(dim=-1) @@ -105,7 +109,7 @@ def propose( draft_token_ids_list = [draft_token_ids] positions = target_positions[last_token_indices] - hidden_states = sample_hidden_states + hidden_states = hidden_states_fwd[last_token_indices] attn_metadata.num_actual_tokens = batch_size attn_metadata.max_query_len = 1 attn_metadata.query_start_loc = self.arange[:batch_size + 1] @@ -151,12 +155,12 @@ def propose( # Run the model. 
with set_forward_context(attn_metadata, self.vllm_config): - hidden_states = self.model( + hidden_states_logits, hidden_states = self.model( input_ids=input_ids, hidden_states=hidden_states, positions=clamped_positions, ) - logits = self.model.compute_logits(hidden_states, None) + logits = self.model.compute_logits(hidden_states_logits, None) draft_token_ids = logits.argmax(dim=-1) draft_token_ids_list.append(draft_token_ids) @@ -221,15 +225,28 @@ def load_model(self, target_model: nn.Module) -> None: with set_default_torch_dtype( draft_model_config.dtype), set_current_vllm_config( self.vllm_config): - self.model = EagleLlamaForCausalLM( - model_config=draft_model_config, - start_layer_id=target_layer_num).to(target_device) - - self.model.load_weights( + if self.vllm_config.speculative_config.method == "eagle": + self.model = EagleLlamaForCausalLM( + model_config=draft_model_config, + start_layer_id=target_layer_num).to(target_device) + else: + assert self.vllm_config.speculative_config.method == "eagle3" + self.model = Eagle3LlamaForCausalLM( + model_config=draft_model_config, + start_layer_id=target_layer_num).to(target_device) + + loaded_weights = self.model.load_weights( loader.get_all_weights( self.vllm_config.speculative_config.draft_model_config, self.model)) - self.model.lm_head = target_model.lm_head + if self.vllm_config.speculative_config.method == "eagle3": + if "model.embed_tokens.weight" not in loaded_weights: + logger.info( + "Loading EAGLE embedding weights from the target model.") + self.model.model.embed_tokens = target_model.model.embed_tokens + else: + logger.info("Loading EAGLE LM head weights from the target model.") + self.model.lm_head = target_model.lm_head # NOTE(woosuk): Currently, the below code is not used and we always use argmax diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 86f6a301fbb6..7910481762ef 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -165,14 +165,18 @@ def __init__( # Set up speculative decoding. self.use_spec_decode = False + self.use_aux_hidden_state_outputs = False if self.speculative_config: self.use_spec_decode = True if get_pp_group().is_last_rank: if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) - elif self.speculative_config.method == "eagle": + elif self.speculative_config.method == "eagle" or \ + self.speculative_config.method == "eagle3": self.drafter = EagleProposer(self.vllm_config, self.device) # type: ignore + if self.speculative_config.method == "eagle3": + self.use_aux_hidden_state_outputs = True else: raise ValueError("Unknown speculative decoding method: " f"{self.speculative_config.method}") @@ -1079,12 +1083,18 @@ def execute_model( # Run the decoder. # Use persistent buffers for CUDA graphs. with set_forward_context(attn_metadata, self.vllm_config): - hidden_states = self.model( + output = self.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = output + else: + hidden_states = output + if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. 
return hidden_states @@ -1182,7 +1192,8 @@ def execute_model( assert isinstance(self.drafter, NgramProposer) spec_token_ids = self.generate_draft_token_ids( valid_sampled_token_ids, sampling_metadata) - elif self.speculative_config.method == "eagle": + elif self.speculative_config.method == "eagle" or \ + self.speculative_config.method == "eagle3": assert isinstance(self.drafter, EagleProposer) # TODO(woosuk): Refactor the loop. next_token_ids: list[int] = [] @@ -1210,7 +1221,12 @@ def execute_model( # not include padding. target_token_ids = self.input_ids[:num_scheduled_tokens] target_positions = positions[:num_scheduled_tokens] - target_hidden_states = hidden_states[:num_scheduled_tokens] + if self.use_aux_hidden_state_outputs: + target_hidden_states = [ + h[:num_scheduled_tokens] for h in aux_hidden_states + ] + else: + target_hidden_states = hidden_states[:num_scheduled_tokens] target_slot_mapping = attn_metadata.slot_mapping cu_num_tokens = attn_metadata.query_start_loc else: @@ -1231,9 +1247,16 @@ def execute_model( ) target_token_ids = self.input_ids[token_indices] target_positions = positions[token_indices] - target_hidden_states = hidden_states[token_indices] + if self.use_aux_hidden_state_outputs: + target_hidden_states = [ + h[token_indices] for h in aux_hidden_states + ] + else: + target_hidden_states = hidden_states[token_indices] target_slot_mapping = attn_metadata.slot_mapping[token_indices] + if self.use_aux_hidden_state_outputs: + target_hidden_states = torch.cat(target_hidden_states, dim=-1) draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, @@ -1311,6 +1334,9 @@ def load_model(self) -> None: if hasattr(self, "drafter"): logger.info("Loading drafter model...") self.drafter.load_model(self.model) + if self.use_aux_hidden_state_outputs: + self.model.set_aux_hidden_state_layers( + self.model.get_eagle3_aux_hidden_state_layers()) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory logger.info("Model loading took %.4f GiB and %.6f seconds", @@ -1463,12 +1489,16 @@ def _dummy_run( with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): - hidden_states = model( + outputs = model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + if self.use_aux_hidden_state_outputs: + hidden_states, _ = outputs + else: + hidden_states = outputs logit_indices = np.cumsum(num_scheduled_tokens) - 1 return hidden_states[logit_indices]
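
A minimal end-to-end sketch of the EAGLE-3 path added above, assuming the checkpoints referenced in this diff ("meta-llama/Llama-3.1-8B-Instruct" as the target and "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" as the draft) are accessible; it mirrors the configuration used in examples/offline_inference/eagle.py and tests/v1/e2e/test_spec_decode.py rather than introducing any new API. Note that "method" is set to "eagle3" explicitly, since the auto-detection branch added in vllm/config.py maps "eagle3-" model names to "eagle".

import os

# The e2e test pins the V1 engine for this path via VLLM_USE_V1=1.
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM, SamplingParams

# The target model runs normally; the EAGLE-3 draft head consumes the auxiliary
# hidden states exposed via set_aux_hidden_state_layers() in llama.py above.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    max_model_len=2048,
    speculative_config={
        "method": "eagle3",  # selects the Eagle3LlamaForCausalLM draft model
        "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
        "num_speculative_tokens": 3,
    },
)

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)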