diff --git a/tensorrt_llm/_torch/models/modeling_qwen3.py b/tensorrt_llm/_torch/models/modeling_qwen3.py
index 26353acdb04..8635e510f42 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen3.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen3.py
@@ -16,8 +16,9 @@
 from ..modules.linear import TensorParallelMode
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
-from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
-                             register_auto_model)
+from ..speculative import SpecMetadata
+from .modeling_speculative import SpecDecOneEngineForCausalLM
+from .modeling_utils import DecoderModel, register_auto_model
 
 
 class Qwen3Attention(Attention):
@@ -148,6 +149,7 @@ def forward(
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if residual is None:
@@ -171,6 +173,10 @@ def forward(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
 
+        if spec_metadata is not None:
+            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
+                                                      hidden_states, residual)
+
         return hidden_states, residual
 
 
@@ -207,6 +213,7 @@ def forward(
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -227,6 +234,7 @@
                 attn_metadata=attn_metadata,
                 residual=residual,
                 mrope_config=mrope_config,
+                spec_metadata=spec_metadata,
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -234,7 +242,7 @@
 
 
 @register_auto_model("Qwen3ForCausalLM")
-class Qwen3ForCausalLM(DecoderModelForCausalLM[Qwen3Model, Qwen3Config]):
+class Qwen3ForCausalLM(SpecDecOneEngineForCausalLM[Qwen3Model, Qwen3Config]):
 
     def __init__(
         self,
@@ -242,33 +250,5 @@ def __init__(
     ):
         super().__init__(
             Qwen3Model(model_config),
-            config=model_config,
-            hidden_size=model_config.pretrained_config.hidden_size,
-            vocab_size=model_config.pretrained_config.vocab_size,
-        )
-
-    # NOTE: Qwen2-VL needs special mrope_config so adding separate forward() function to accept 'mrope_config'.
-    def forward(
-        self,
-        attn_metadata: AttentionMetadata,
-        input_ids: torch.IntTensor = None,
-        position_ids: Optional[torch.IntTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        return_context_logits: bool = False,
-        mrope_config: Optional[dict] = None,
-        **kwargs,
-    ) -> torch.Tensor:
-        output = self.model(
-            input_ids=input_ids,
-            attn_metadata=attn_metadata,
-            position_ids=position_ids,
-            inputs_embeds=inputs_embeds,
-            mrope_config=mrope_config,
-        )
-
-        return self.logits_processor.forward(
-            output,
-            self.lm_head,
-            attn_metadata,
-            return_context_logits,
+            model_config,
         )
diff --git a/tests/integration/defs/accuracy/references/mmlu.yaml b/tests/integration/defs/accuracy/references/mmlu.yaml
index 7beba282671..25111ebec23 100644
--- a/tests/integration/defs/accuracy/references/mmlu.yaml
+++ b/tests/integration/defs/accuracy/references/mmlu.yaml
@@ -145,6 +145,8 @@ Qwen3/Qwen3-8B:
   - quant_algo: FP8_BLOCK_SCALES
     accuracy: 76.12
   - accuracy: 76.12
+  - spec_dec_algo: Eagle
+    accuracy: 76.12
 Qwen3/Qwen3-30B-A3B:
   - quant_algo: FP8_BLOCK_SCALES
     accuracy: 79.53
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 701461b19d1..6777386349b 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -1607,6 +1607,30 @@ def test_bf16(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm)
 
+    def test_eagle3(self):
+        pytorch_config = dict(
+            disable_overlap_scheduler=True,
+            cuda_graph_config=CudaGraphConfig(batch_sizes=[1]),
+        )
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False)
+
+        eagle_model_dir = f"{llm_models_root()}/Qwen3/qwen3_8b_eagle3"
+        target_model_dir = f"{llm_models_root()}/Qwen3/Qwen3-8B"
+
+        draft_len = 4
+        spec_config = EagleDecodingConfig(max_draft_len=draft_len,
+                                          speculative_model_dir=eagle_model_dir)
+
+        llm = LLM(model=target_model_dir,
+                  **pytorch_config,
+                  kv_cache_config=kv_cache_config,
+                  speculative_config=spec_config,
+                  build_config=None)
+
+        with llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+
 
 class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "Qwen3/Qwen3-30B-A3B"
diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml
index 3e9f0d3995b..46ab1bb4d6e 100644
--- a/tests/integration/test_lists/test-db/l0_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_h100.yml
@@ -38,6 +38,7 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=fp8-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency]
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=0]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=2]
   - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]