44 changes: 12 additions & 32 deletions tensorrt_llm/_torch/models/modeling_qwen3.py
@@ -16,8 +16,9 @@
 from ..modules.linear import TensorParallelMode
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
-from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
-                             register_auto_model)
+from ..speculative import SpecMetadata
+from .modeling_speculative import SpecDecOneEngineForCausalLM
+from .modeling_utils import DecoderModel, register_auto_model
 
 
 class Qwen3Attention(Attention):
@@ -148,6 +149,7 @@ def forward(
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if residual is None:
@@ -171,6 +173,10 @@ def forward(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
 
+        if spec_metadata is not None:
+            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
+                                                      hidden_states, residual)
+
         return hidden_states, residual
 
 
@@ -207,6 +213,7 @@ def forward(
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -227,48 +234,21 @@
                 attn_metadata=attn_metadata,
                 residual=residual,
                 mrope_config=mrope_config,
+                spec_metadata=spec_metadata,
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 
 @register_auto_model("Qwen3ForCausalLM")
-class Qwen3ForCausalLM(DecoderModelForCausalLM[Qwen3Model, Qwen3Config]):
+class Qwen3ForCausalLM(SpecDecOneEngineForCausalLM[Qwen3Model, Qwen3Config]):
 
     def __init__(
         self,
         model_config: ModelConfig[Qwen3Config],
     ):
         super().__init__(
             Qwen3Model(model_config),
-            config=model_config,
-            hidden_size=model_config.pretrained_config.hidden_size,
-            vocab_size=model_config.pretrained_config.vocab_size,
-        )
-
-    # NOTE: Qwen2-VL needs special mrope_config so adding separate forward() function to accept 'mrope_config'.
-    def forward(
-        self,
-        attn_metadata: AttentionMetadata,
-        input_ids: torch.IntTensor = None,
-        position_ids: Optional[torch.IntTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        return_context_logits: bool = False,
-        mrope_config: Optional[dict] = None,
-        **kwargs,
-    ) -> torch.Tensor:
-        output = self.model(
-            input_ids=input_ids,
-            attn_metadata=attn_metadata,
-            position_ids=position_ids,
-            inputs_embeds=inputs_embeds,
-            mrope_config=mrope_config,
-        )
-
-        return self.logits_processor.forward(
-            output,
-            self.lm_head,
-            attn_metadata,
-            return_context_logits,
+            model_config,
         )
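Note on the wiring above: every decoder layer now takes an optional spec_metadata and, after its MLP, offers its (hidden_states, residual) pair back to it, so Eagle3 can collect intermediate target-model activations as inputs for the draft model. A toy illustration of that capture contract follows; ToySpecMetadata is invented here for illustration and is not TensorRT-LLM's actual SpecMetadata:

# Hypothetical sketch of the per-layer capture contract; ToySpecMetadata
# is invented for illustration and is NOT tensorrt_llm's SpecMetadata.
from typing import Dict, Optional, Set

import torch


class ToySpecMetadata:
    """Collects hidden states from a chosen subset of decoder layers."""

    def __init__(self, layers_to_capture: Set[int]):
        self.layers_to_capture = layers_to_capture
        self.captured: Dict[int, torch.Tensor] = {}

    def maybe_capture_hidden_states(
            self,
            layer_idx: int,
            hidden_states: torch.Tensor,
            residual: Optional[torch.Tensor] = None) -> None:
        # No-op for layers the algorithm does not need, which is why the
        # model can call this unconditionally from every layer.
        if layer_idx in self.layers_to_capture:
            out = hidden_states if residual is None else hidden_states + residual
            self.captured[layer_idx] = out

Because the capture is a no-op by default, the same model code serves both the plain and the speculative path; only the metadata object changes.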
2 changes: 2 additions & 0 deletions tests/integration/defs/accuracy/references/mmlu.yaml
@@ -145,6 +145,8 @@ Qwen3/Qwen3-8B:
   - quant_algo: FP8_BLOCK_SCALES
     accuracy: 76.12
   - accuracy: 76.12
+  - spec_dec_algo: Eagle
+    accuracy: 76.12
 Qwen3/Qwen3-30B-A3B:
   - quant_algo: FP8_BLOCK_SCALES
     accuracy: 79.53
24 changes: 24 additions & 0 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -1607,6 +1607,30 @@ def test_bf16(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
         task = MMLU(self.MODEL_NAME)
         task.evaluate(llm)
 
+    def test_eagle3(self):
+        pytorch_config = dict(
+            disable_overlap_scheduler=True,
+            cuda_graph_config=CudaGraphConfig(batch_sizes=[1]),
+        )
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False)
+
+        eagle_model_dir = f"{llm_models_root()}/Qwen3/qwen3_8b_eagle3"
+        target_model_dir = f"{llm_models_root()}/Qwen3/Qwen3-8B"
+
+        draft_len = 4
+        spec_config = EagleDecodingConfig(max_draft_len=draft_len,
+                                          speculative_model_dir=eagle_model_dir)
+
+        llm = LLM(model=target_model_dir,
+                  **pytorch_config,
+                  kv_cache_config=kv_cache_config,
+                  speculative_config=spec_config,
+                  build_config=None)
+
+        with llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+
 
 class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "Qwen3/Qwen3-30B-A3B"
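For readers who want to exercise the same Eagle3 path outside the accuracy harness, here is a minimal standalone sketch. It mirrors test_eagle3 above; the llmapi import locations are assumptions based on the types the test uses, and both checkpoint paths are placeholders for local Qwen3-8B target and Eagle3 draft directories:

# Standalone sketch of the flow test_eagle3 exercises; import paths are
# assumed and both model paths below are placeholders.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
                                 KvCacheConfig)

spec_config = EagleDecodingConfig(
    max_draft_len=4,
    speculative_model_dir="/models/qwen3_8b_eagle3",  # placeholder path
)

llm = LLM(
    model="/models/Qwen3-8B",  # placeholder path
    disable_overlap_scheduler=True,
    cuda_graph_config=CudaGraphConfig(batch_sizes=[1]),
    kv_cache_config=KvCacheConfig(enable_block_reuse=False),
    speculative_config=spec_config,
)

with llm:
    outputs = llm.generate(["Speculative decoding works by"])
    print(outputs[0].outputs[0].text)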
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_h100.yml
@@ -38,6 +38,7 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=fp8-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency]
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=0]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=2]
   - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]