diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index af0bf2ae364f..21cbc632d080 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -262,7 +262,7 @@ steps:
   - pytest -v -s v1/test_metrics_reader.py
   # TODO: accuracy does not match, whether setting
   # VLLM_USE_FLASHINFER_SAMPLER or not on H100.
-  - pytest -v -s v1/e2e
+  - pytest -v -s v1/e2e --ignore=v1/e2e/test_llama4_eagle.py
   # Integration test for streaming correctness (requires special branch).
   - pip install -U git+https://github.com/robertgshaw2-neuralmagic/lm-evaluation-harness.git@streaming-api
   - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
diff --git a/tests/models/registry.py b/tests/models/registry.py
index c10d375683ee..a5b87e6f27e6 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -32,6 +32,12 @@ class _HfExamplesInfo:
     for speculative decoding.
     """
 
+    speculative_method: Optional[str] = None
+    """
+    The default speculative decoding method to use when testing this
+    architecture.
+    """
+
     min_transformers_version: Optional[str] = None
     """
     The minimum version of HF Transformers that is required to run this model.
@@ -61,6 +67,9 @@ class _HfExamplesInfo:
     v0_only: bool = False
     """The model is only available with the vLLM V0 engine."""
 
+    v1_only: bool = False
+    """The model is only available with the vLLM V1 engine."""
+
     hf_overrides: dict[str, Any] = field(default_factory=dict)
     """The ``hf_overrides`` required to load the model."""
 
@@ -457,6 +466,13 @@ def check_available_online(
                                              trust_remote_code=True,
                                              speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
                                              tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"),  # noqa: E501
+    "EagleLlama4ForCausalLM": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct",  # noqa: E501
+                                              trust_remote_code=True,
+                                              speculative_model="ronaldbxu/EAGLE-Llama-4-Maverick-17B-128E-Instruct",  # noqa: E501
+                                              tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct",  # noqa: E501
+                                              speculative_method="eagle",
+                                              max_model_len=256,
+                                              v1_only=True),
     "Eagle3LlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",  # noqa: E501
                                               trust_remote_code=True,
                                               speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py
index ea6a2cc37ccf..d1a12ec61bff 100644
--- a/tests/models/test_initialization.py
+++ b/tests/models/test_initialization.py
@@ -33,7 +33,8 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
     # FIXME: Possible memory leak in the previous tests?
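+    # The EagleLlama4ForCausalLM registry entry pairs a Llama-4 Scout target
+    # with an EAGLE draft checkpoint, which is too large to initialize here.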
     if model_arch in ("Glm4vForConditionalGeneration",
                       "GraniteSpeechForConditionalGeneration",
-                      "KimiVLForConditionalGeneration"):
+                      "KimiVLForConditionalGeneration",
+                      "EagleLlama4ForCausalLM"):
         pytest.skip("Avoid OOM")
 
     # Avoid OOM and reduce initialization time by only using 1 layer
@@ -103,6 +104,8 @@ def _initialize_kv_caches_v1(self, vllm_config):
                        _initialize_kv_caches_v1), monkeypatch.context() as m):
         if model_info.v0_only:
             m.setenv("VLLM_USE_V1", "0")
+        if model_info.v1_only:
+            m.setenv("VLLM_USE_V1", "1")
         if model_arch == "Phi4FlashForCausalLM":
             # Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend
             m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN")
@@ -112,8 +115,13 @@ def _initialize_kv_caches_v1(self, vllm_config):
             tokenizer_mode=model_info.tokenizer_mode,
             revision=model_info.revision,
             speculative_config={
-                "model": model_info.speculative_model,
-                "num_speculative_tokens": 1,
+                "method":
+                model_info.speculative_method
+                if model_info.speculative_method else None,
+                "model":
+                model_info.speculative_model,
+                "num_speculative_tokens":
+                1,
             } if model_info.speculative_model else None,
             trust_remote_code=model_info.trust_remote_code,
             max_model_len=model_info.max_model_len,
diff --git a/tests/v1/e2e/test_llama4_eagle.py b/tests/v1/e2e/test_llama4_eagle.py
new file mode 100644
index 000000000000..56574ab7af52
--- /dev/null
+++ b/tests/v1/e2e/test_llama4_eagle.py
@@ -0,0 +1,111 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# To run this file, run
+# pytest -vx tests/v1/e2e/test_llama4_eagle.py
+
+from __future__ import annotations
+
+import random
+from typing import Any
+
+import pytest
+
+from vllm import LLM, SamplingParams
+
+
+@pytest.fixture
+def test_prompts():
+    prompt_types = ["repeat", "sentence"]
+    num_prompts = 100
+    prompts = []
+
+    random.seed(0)
+    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
+
+    # Generate a mixed batch of prompts, some of which can be easily
+    # predicted by n-gram matching and some which likely cannot.
+    for kind in random_prompt_type_choices:
+        word_choices = ["test", "temp", "hello", "where"]
+        word = random.choice(word_choices)
+        if kind == "repeat":
+            prompt = f"""
+            please repeat the word '{word}' 10 times.
+            give no other output than the word at least ten times in a row,
+            in lowercase with spaces between each word and without quotes.
+            """
+        elif kind == "sentence":
+            prompt = f"""
+            please give a ten-word sentence that
+            uses the word {word} at least once.
+            give no other output than that simple sentence without quotes.
+            """
+        else:
+            raise ValueError(f"Unknown prompt type: {kind}")
+        prompts.append([{"role": "user", "content": prompt}])
+
+    return prompts
+
+
+@pytest.fixture
+def sampling_config():
+    return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
+
+
+@pytest.mark.parametrize(
+    "method_model_and_draft_model",
+    [("eagle", "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
+      "ronaldbxu/EAGLE-Llama-4-Maverick-17B-128E-Instruct")],
+    ids=[
+        "llama4_eagle",
+    ])
+def test_eagle_correctness(
+    monkeypatch: pytest.MonkeyPatch,
+    test_prompts: list[list[dict[str, Any]]],
+    sampling_config: SamplingParams,
+    method_model_and_draft_model: tuple[str, str, str],
+):
+    '''
+    The outputs of the original LLM and the speculative LLM should match
+    when using EAGLE speculative decoding.
+    '''
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        method, model_name, spec_model_name = method_model_and_draft_model
+
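+        # Note: the FP8 Maverick target below is served with 8-way tensor
+        # parallelism, so this test assumes a node with 8 GPUs.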
+        tp = 8
+
+        ref_llm = LLM(model=model_name,
+                      tensor_parallel_size=tp,
+                      max_model_len=2048)
+        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+        del ref_llm
+
+        spec_llm = LLM(
+            model=model_name,
+            trust_remote_code=True,
+            tensor_parallel_size=tp,
+            speculative_config={
+                "method": method,
+                "model": spec_model_name,
+                "num_speculative_tokens": 3,
+                "max_model_len": 2048,
+            },
+            max_model_len=2048,
+        )
+        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+        matches = 0
+        misses = 0
+        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+            if ref_output.outputs[0].text == spec_output.outputs[0].text:
+                matches += 1
+            else:
+                misses += 1
+                print(f"ref_output: {ref_output.outputs[0].text}")
+                print(f"spec_output: {spec_output.outputs[0].text}")
+
+        # Heuristic: expect at least 66% of the prompts to match exactly
+        # Upon failure, inspect the outputs to check for inaccuracy.
+        assert matches > int(0.66 * len(ref_outputs))
+        del spec_llm
diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index 93e7c12f3a09..01cac773bc2f 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -53,14 +53,6 @@ def model_name():
     return "meta-llama/Llama-3.1-8B-Instruct"
 
 
-def eagle_model_name():
-    return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
-
-
-def eagle3_model_name():
-    return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
-
-
 def test_ngram_correctness(
     monkeypatch: pytest.MonkeyPatch,
     test_prompts: list[list[dict[str, Any]]],
@@ -105,13 +97,17 @@ def test_ngram_correctness(
     del spec_llm
 
 
-@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
+@pytest.mark.parametrize("method_model_and_draft_model",
+                         [("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+                           "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"),
+                          ("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+                           "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B")],
+                         ids=["llama3_eagle", "llama3_eagle3"])
 def test_eagle_correctness(
     monkeypatch: pytest.MonkeyPatch,
     test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
-    model_name: str,
-    use_eagle3: bool,
+    method_model_and_draft_model: tuple[str, str, str],
 ):
     '''
     Compare the outputs of a original LLM and a speculative LLM
@@ -120,17 +116,17 @@ def test_eagle_correctness(
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
 
+        method, model_name, spec_model_name = method_model_and_draft_model
+
         ref_llm = LLM(model=model_name, max_model_len=2048)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
 
-        spec_model_name = eagle3_model_name(
-        ) if use_eagle3 else eagle_model_name()
         spec_llm = LLM(
             model=model_name,
             trust_remote_code=True,
             speculative_config={
-                "method": "eagle3" if use_eagle3 else "eagle",
+                "method": method,
                 "model": spec_model_name,
                 "num_speculative_tokens": 3,
                 "max_model_len": 2048,
diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index 5efab2c14407..1afd79ce33c0 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -13,12 +13,16 @@
 from vllm.platforms import current_platform
 from vllm.v1.spec_decode.eagle import EagleProposer
 
-model_dir = "meta-llama/Llama-3.1-8B-Instruct"
-eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
-eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
+llama3_model_dir = "meta-llama/Llama-3.1-8B-Instruct"
+llama3_eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
+llama3_eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
+llama4_model_dir = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
+llama4_eagle_dir = "ronaldbxu/EAGLE-Llama-4-Maverick-17B-128E-Instruct"
 
 
-def _create_proposer(method: str, k: int) -> EagleProposer:
+
+def _create_proposer(method: str, model_dir: str, draft_model_dir: str,
+                     k: int) -> EagleProposer:
     model_config = ModelConfig(model=model_dir,
                                task="generate",
                                max_model_len=100,
@@ -28,9 +32,6 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
                                seed=None,
                                trust_remote_code=False)
 
-    # Choose model directory based on method
-    draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir
-
     speculative_config = SpeculativeConfig(
         target_model_config=model_config,
         target_parallel_config=ParallelConfig(),
@@ -118,8 +119,14 @@ def test_prepare_inputs():
 
 
 @pytest.mark.parametrize("method,proposer_helper", [
-    ("eagle", lambda k: _create_proposer("eagle", k)),
-    ("eagle3", lambda k: _create_proposer("eagle3", k)),
+    ("eagle",
+     lambda k: _create_proposer("eagle", llama3_model_dir, llama3_eagle_dir, k)
+     ),
+    ("eagle",
+     lambda k: _create_proposer("eagle", llama4_model_dir, llama4_eagle_dir, k)
+     ),
+    ("eagle3", lambda k: _create_proposer("eagle3", llama3_model_dir,
+                                          llama3_eagle3_dir, k)),
 ])
 @pytest.mark.parametrize("pp_size", [1, 2])
 @pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
@@ -199,7 +206,12 @@ class _TargetModelStub(LlamaForCausalLM):
 
 
 @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
-def test_propose(num_speculative_tokens):
+@pytest.mark.parametrize("model_and_draft_model",
+                         [(llama3_model_dir, llama3_eagle_dir),
+                          (llama4_model_dir, llama4_eagle_dir)])
+def test_propose(num_speculative_tokens, model_and_draft_model):
+    model_dir = model_and_draft_model[0]
+    draft_model_dir = model_and_draft_model[1]
     # Use GPU device
     device = torch.device(current_platform.device_type)
 
@@ -211,7 +223,8 @@ def test_propose(num_speculative_tokens):
     vocab_size = 100
 
     # Create proposer first so we can use its actual hidden_size
-    proposer = _create_proposer("eagle", num_speculative_tokens)
+    proposer = _create_proposer("eagle", model_dir, draft_model_dir,
+                                num_speculative_tokens)
 
     # Get the hidden_size from the proposer to ensure consistency
     hidden_size = proposer.hidden_size
diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py
new file mode 100644
index 000000000000..2671f9ad8622
--- /dev/null
+++ b/vllm/model_executor/models/llama4_eagle.py
@@ -0,0 +1,258 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Iterable
+
+import torch
+import torch.nn as nn
+
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
+                                               Llama4ForCausalLM)
+from vllm.model_executor.models.utils import (AutoWeightsLoader,
+                                              extract_layer_index,
+                                              is_pp_missing_parameter,
+                                              maybe_prefix)
+
+
+@support_torch_compile
+class EagleLlama4Model(nn.Module):
+
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 start_layer_id: int = 0):
+
+        super().__init__()
+        self.config = (
+            vllm_config.speculative_config.draft_model_config.hf_config)
+        self.vocab_size = self.config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            self.config.vocab_size,
+            self.config.hidden_size,
+            prefix=maybe_prefix(prefix, "embed_tokens"),
+        )
+
+        if vllm_config.speculative_config.quantization:
+            self.quant_config = vllm_config.quant_config
+        else:
+            self.quant_config = None
+
+        self.layers = nn.ModuleList([
+            Llama4DecoderLayer(
+                config=self.config,
+                quant_config=self.quant_config,
+                prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
+            )
+        ])
+        # EAGLE conditions the draft on both the current token embedding and
+        # the target model's last hidden state; fc projects their
+        # concatenation (2 * hidden_size) back down to hidden_size.
+        self.fc = torch.nn.Linear(self.config.hidden_size * 2,
+                                  self.config.hidden_size,
+                                  bias=False)
+        self.num_experts = self.config.num_local_experts
+
+        self.norm = RMSNorm(
+            hidden_size=self.config.hidden_size,
+            eps=self.config.rms_norm_eps,
+        )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        input_embeds = self.embed_tokens(input_ids)
+        hidden_states = self.fc(
+            torch.cat((input_embeds, hidden_states), dim=-1))
+        residual = None
+
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                residual,
+            )
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states, hidden_states
+
+    def load_moe_expert_weights(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        params_dict: dict[str, nn.Parameter],
+        loaded_params: set[str],
+        expert_params_mapping: list[tuple[str, str, int, str]],
+        fused: bool = True,
+    ) -> bool:
+        expert_param_loaded = False
+        if "experts.gate_up_proj" in name:
+            loaded_weight = loaded_weight.chunk(2, dim=-1)
+        for (param_name, weight_name, expert_id,
+             shard_id) in expert_params_mapping:
+            new_loaded_weight = loaded_weight
+            if fused:
+                e_str, _, proj_str, _ = weight_name.split('.')
+                weight_name = f"{e_str}.{proj_str}"
+                param_name = f"{param_name}weight"
+            if weight_name not in name:
+                continue
+            full_param_name = name.replace(weight_name, param_name)
+            # Skip layers on other devices.
+            if is_pp_missing_parameter(name, self):
+                continue
+            if ((name.endswith(".bias") or name.endswith("_bias"))
+                    and name not in params_dict):
+                continue
+            param = params_dict[full_param_name]
+            weight_loader = param.weight_loader
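+            # Fused checkpoints store gate_up_proj as a single tensor: pick
+            # the chunk for this shard, transpose the last two dims to match
+            # the FusedMoE parameter layout, and, under expert parallelism,
+            # keep only the experts owned by this rank (per expert_map).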
+            if fused:
+                if "w13" in full_param_name:
+                    shard_idx = 0 if shard_id == "w1" else 1
+                    new_loaded_weight = new_loaded_weight[shard_idx]
+                new_loaded_weight = new_loaded_weight.transpose(-1, -2)
+                layer_idx = extract_layer_index(name)
+                # EP mapping
+                expert_map = self.layers[
+                    layer_idx].feed_forward.experts.expert_map
+                if expert_map is not None:
+                    local_expert_indices = (expert_map != -1) \
+                        .nonzero() \
+                        .flatten() \
+                        .to(new_loaded_weight.device)
+                    new_loaded_weight = new_loaded_weight[local_expert_indices]
+                    expert_id = local_expert_indices[0].item()
+            else:
+                # TODO: add EP support for non fused weights
+                pass
+            weight_loader(param,
+                          new_loaded_weight,
+                          full_param_name,
+                          shard_id=shard_id,
+                          expert_id=expert_id)
+
+            loaded_params.add(full_param_name)
+            expert_param_loaded = True
+        return expert_param_loaded
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+        fused_experts_params = False
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.num_experts)
+        expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_up_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="gate_up_proj",
+            num_experts=1)
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if "experts.gate_up_proj" in name or "experts.down_proj" in name:
+                fused_experts_params = True
+                expert_params_mapping = expert_params_mapping_fused
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name or "experts" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                loaded_params.add(name)
+                break
+            else:
+                moe_loaded = self.load_moe_expert_weights(
+                    name,
+                    loaded_weight,
+                    params_dict,
+                    loaded_params,
+                    expert_params_mapping,
+                    fused=fused_experts_params)
+
+                if not moe_loaded:
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(name)
+
+        return loaded_params
+
+
+class EagleLlama4ForCausalLM(Llama4ForCausalLM):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        nn.Module.__init__(self)
+        self.config = (
+            vllm_config.speculative_config.draft_model_config.hf_config)
+
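+        # The draft decoder layer is registered as layers.{start_layer_id},
+        # i.e. numbered after the target model's layers on this rank, so its
+        # name cannot clash with a target layer.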
+        start_layer_id = vllm_config.model_config.get_num_layers(
+            vllm_config.parallel_config)
+        if start_layer_id > 0:
+            original_no_rope_layers = self.config.no_rope_layers
+
+            # If start_layer_id is 0, we will hit NotImplementedError in
+            # vllm/v1/utils.py. If we don't pad no_rope_layers, we will get an
+            # index-out-of-bounds error in the Llama4Attention constructor.
+            self.config.no_rope_layers = [None] * start_layer_id
+            self.config.no_rope_layers.extend(original_no_rope_layers)
+
+        self.model = EagleLlama4Model(vllm_config=vllm_config,
+                                      prefix="model",
+                                      start_layer_id=start_layer_id)
+
+        logit_scale = getattr(self.config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(self.config.vocab_size,
+                                                scale=logit_scale)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.model(input_ids, positions, hidden_states)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        loader = AutoWeightsLoader(self)
+
+        model_weights = {}
+        for name, loaded_weight in weights:
+            if "lm_head" not in name:
+                name = "model." + name
+            model_weights[name] = loaded_weight
+        loader.load_weights(model_weights.items())
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index e8530a555d28..74cdd0bec97b 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -241,6 +241,7 @@
     "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
     "EAGLEModel": ("eagle", "EAGLE"),
     "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
+    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
     "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
     "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
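
Usage sketch (not part of the diff): the e2e test above drives the new Llama-4 EAGLE path through the offline LLM API. A standalone script would look roughly like the following; the target/draft checkpoint names and the tensor_parallel_size=8 / num_speculative_tokens=3 settings are copied from the test, not tuned recommendations.

    import os

    from vllm import LLM, SamplingParams

    # The tests above force VLLM_USE_V1=1 for this path; do the same here.
    os.environ["VLLM_USE_V1"] = "1"

    llm = LLM(
        model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        trust_remote_code=True,
        tensor_parallel_size=8,
        max_model_len=2048,
        speculative_config={
            "method": "eagle",
            "model": "ronaldbxu/EAGLE-Llama-4-Maverick-17B-128E-Instruct",
            "num_speculative_tokens": 3,
            "max_model_len": 2048,
        },
    )

    outputs = llm.chat([{"role": "user", "content": "Say hello in ten words."}],
                       SamplingParams(temperature=0, max_tokens=32))
    print(outputs[0].outputs[0].text)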