
Commit 38d6e4e

byshiue and nv-guomingz authored
[None][feat] Support Qwen3 next (#7892)
Signed-off-by: mengw <[email protected]>
Signed-off-by: bhsueh <[email protected]>
Signed-off-by: nv-guomingz <[email protected]>
Co-authored-by: nv-guomingz <[email protected]>
1 parent a0d489a commit 38d6e4e

30 files changed: 5,286 additions, 39 deletions

cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu

Lines changed: 8 additions & 0 deletions
@@ -312,6 +312,14 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_
                 base, position_ids, num_tokens, factor, low, high, attention_factor);
             });
         break;
+    case 256:
+        DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
+            fusedQKNormRopeKernel<256, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
+                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
+                reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
+                base, position_ids, num_tokens, factor, low, high, attention_factor);
+        });
+        break;
     default: TLLM_THROW("Unsupported head dimension for fusedQKNormRope: %d", head_dim);
     }
 }
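For context, the fused kernel applies per-head RMSNorm to the Q and K slices of the packed QKV tensor and then rotary position embedding, now dispatchable for head_dim 256 as well. The PyTorch fragment below is a rough, unfused reference of that math; tensor names, shapes, and the non-interleaved rotation variant are illustrative assumptions, not the kernel's actual Python binding.

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Normalize each head vector over its last (head_dim) axis, then scale by weight.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Non-interleaved (NeoX-style) rotation: rotate the two halves of head_dim.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

# Illustrative shapes: [num_tokens, num_heads, head_dim] with head_dim = 256.
q = torch.randn(4, 16, 256, dtype=torch.bfloat16)
k = torch.randn(4, 2, 256, dtype=torch.bfloat16)
q_weight = torch.ones(256, dtype=torch.bfloat16)
k_weight = torch.ones(256, dtype=torch.bfloat16)
cos = torch.randn(4, 1, 128, dtype=torch.bfloat16)
sin = torch.randn(4, 1, 128, dtype=torch.bfloat16)

q = apply_rope(rms_norm(q, q_weight, 1e-6), cos, sin)
k = apply_rope(rms_norm(k, k_weight, 1e-6), cos, sin)
```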

examples/models/core/qwen/README.md

Lines changed: 12 additions & 2 deletions
@@ -21,13 +21,14 @@ This document shows how to build and run a [Qwen](https://huggingface.co/Qwen) m
   - [Quick start](#quick-start)
   - [Run a single inference](#run-a-single-inference)
   - [Evaluation](#evaluation)
-  - [Model Quantization to FP4](#model-quantization-to-fp4)
+  - [Model Quantization](#model-quantization)
   - [Benchmark](#benchmark)
   - [Serving](#serving)
     - [trtllm-serve](#trtllm-serve)
     - [Disaggregated Serving](#disaggregated-serving)
     - [Eagle3](#eagle3)
-    - [Dynamo](#dynamo)
+  - [Dynamo](#dynamo)
+  - [Qwen3-Next](#qwen3-next)
 - [Notes and Troubleshooting](#notes-and-troubleshooting)
 - [Credits](#credits)

@@ -926,6 +927,15 @@ For further details, please refer to [speculative-decoding.md](../../../../docs/
 NVIDIA Dynamo is a high-throughput low-latency inference framework designed for serving generative AI and reasoning models in multi-node distributed environments.
 Dynamo supports TensorRT LLM as one of its inference engine. For details on how to use TensorRT LLM with Dynamo please refer to [LLM Deployment Examples using TensorRT-LLM](https://github.com/ai-dynamo/dynamo/blob/main/examples/tensorrt_llm/README.md)

+## Qwen3-Next
+
+Below is the command to run the Qwen3-Next model.
+
+```bash
+mpirun -n 1 --allow-run-as-root --oversubscribe python3 examples/llm-api/quickstart_advanced.py --model_dir /Qwen3-Next-80B-A3B-Thinking --kv_cache_fraction 0.6 --disable_kv_cache_reuse --max_batch_size 1 --tp_size 4
+```
+
 ## Notes and Troubleshooting

 - **Model Directory:** Update `<YOUR_MODEL_DIR>` with the actual path where the model weights reside.
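The quickstart script added to the README drives the LLM API under the hood. A roughly equivalent standalone sketch is shown below; the model path mirrors the command above, while the remaining options are assumed for illustration and are not part of this diff.

```python
from tensorrt_llm import LLM, SamplingParams

# Assumed, illustrative configuration mirroring the quickstart flags above.
llm = LLM(model="/Qwen3-Next-80B-A3B-Thinking",
          tensor_parallel_size=4,
          max_batch_size=1)

prompts = ["Explain the difference between throughput and latency."]
outputs = llm.generate(prompts, SamplingParams(max_tokens=128))
for out in outputs:
    print(out.outputs[0].text)
```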

tensorrt_llm/_torch/custom_ops/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -22,13 +22,15 @@
 if IS_FLASHINFER_AVAILABLE:
     from .flashinfer_custom_ops import (
         flashinfer_apply_rope_with_cos_sin_cache_inplace,
-        flashinfer_fused_add_rmsnorm, flashinfer_rmsnorm,
-        flashinfer_silu_and_mul)
+        flashinfer_fused_add_rmsnorm, flashinfer_gemma_fused_add_rmsnorm,
+        flashinfer_gemma_rmsnorm, flashinfer_rmsnorm, flashinfer_silu_and_mul)
     __all__ += [
         'flashinfer_silu_and_mul',
         'flashinfer_rmsnorm',
         'flashinfer_fused_add_rmsnorm',
         'flashinfer_apply_rope_with_cos_sin_cache_inplace',
+        'flashinfer_gemma_fused_add_rmsnorm',
+        'flashinfer_gemma_rmsnorm',
     ]

 if IS_CUTLASS_DSL_AVAILABLE:

tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py

Lines changed: 25 additions & 1 deletion
@@ -4,7 +4,8 @@

 if IS_FLASHINFER_AVAILABLE:
     from flashinfer.activation import silu_and_mul
-    from flashinfer.norm import fused_add_rmsnorm, rmsnorm
+    from flashinfer.norm import (fused_add_rmsnorm, gemma_fused_add_rmsnorm,
+                                 gemma_rmsnorm, rmsnorm)
     from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace

     # Warp this into custom op since flashinfer didn't warp it properly and we want to avoid graph break between mlp layer for user buffer optimization

@@ -27,13 +28,36 @@ def _(input: torch.Tensor, weight: torch.Tensor,
           eps: float) -> torch.Tensor:
         return torch.empty_like(input)

+    @torch.library.custom_op("trtllm::flashinfer_gemma_rmsnorm",
+                             mutates_args=())
+    def flashinfer_gemma_rmsnorm(input: torch.Tensor, weight: torch.Tensor,
+                                 eps: float) -> torch.Tensor:
+        return gemma_rmsnorm(input, weight, eps, enable_pdl=ENABLE_PDL)
+
+    @flashinfer_gemma_rmsnorm.register_fake
+    def _(input: torch.Tensor, weight: torch.Tensor,
+          eps: float) -> torch.Tensor:
+        return torch.empty_like(input)
+
     @torch.library.custom_op("trtllm::flashinfer_fused_add_rmsnorm",
                              mutates_args=("input", "residual"))
     def flashinfer_fused_add_rmsnorm(input: torch.Tensor,
                                      residual: torch.Tensor,
                                      weight: torch.Tensor, eps: float) -> None:
         fused_add_rmsnorm(input, residual, weight, eps, enable_pdl=ENABLE_PDL)

+    @torch.library.custom_op("trtllm::flashinfer_gemma_fused_add_rmsnorm",
+                             mutates_args=("input", "residual"))
+    def flashinfer_gemma_fused_add_rmsnorm(input: torch.Tensor,
+                                           residual: torch.Tensor,
+                                           weight: torch.Tensor,
+                                           eps: float) -> None:
+        gemma_fused_add_rmsnorm(input,
+                                residual,
+                                weight,
+                                eps,
+                                enable_pdl=ENABLE_PDL)
+
     @torch.library.custom_op(
         "trtllm::flashinfer_apply_rope_with_cos_sin_cache_inplace",
         mutates_args=("query", "key"))

tensorrt_llm/_torch/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,7 @@
 from .modeling_qwen2vl import Qwen2_5_VLModel, Qwen2VLModel
 from .modeling_qwen3 import Qwen3ForCausalLM
 from .modeling_qwen3_moe import Qwen3MoeForCausalLM
+from .modeling_qwen3_next import Qwen3NextForCausalLM
 from .modeling_qwen_moe import Qwen2MoeForCausalLM
 from .modeling_seedoss import SeedOssForCausalLM
 from .modeling_siglip import SiglipVisionModel

@@ -66,6 +67,7 @@
     "Qwen2_5_VLModel",
     "Qwen3ForCausalLM",
     "Qwen3MoeForCausalLM",
+    "Qwen3NextForCausalLM",
     "GptOssForCausalLM",
     "SeedOssForCausalLM",
 ]

tensorrt_llm/_torch/models/checkpoints/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -8,12 +8,14 @@
 from .hf.qwen2_moe_weight_mapper import Qwen2MoeHfWeightMapper
 from .hf.qwen2vl_weight_mapper import Qwen2VLHfWeightMapper
 from .hf.qwen3_moe_weight_mapper import Qwen3MoeHfWeightMapper
+from .hf.qwen3_next_weight_mapper import Qwen3NextHfWeightMapper
 from .hf.weight_loader import HfWeightLoader
 from .hf.weight_mapper import HfWeightMapper

 __all__ = [
     "HfConfigLoader", "HfWeightLoader", "HfWeightMapper",
     "BaseCheckpointLoader", "HfCheckpointLoader", "NemotronHHfWeightMapper",
     "Gemma3HfWeightMapper", "MixtralHfWeightMapper", "Llama4HfWeightMapper",
-    "Qwen2MoeHfWeightMapper", "Qwen3MoeHfWeightMapper", "Qwen2VLHfWeightMapper"
+    "Qwen2MoeHfWeightMapper", "Qwen3MoeHfWeightMapper", "Qwen2VLHfWeightMapper",
+    "Qwen3NextHfWeightMapper"
 ]
tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py

Lines changed: 105 additions & 0 deletions

@@ -0,0 +1,105 @@
+from typing import Union
+
+import torch
+from torch import nn
+
+from tensorrt_llm._torch.model_config import ModelConfig
+from tensorrt_llm._torch.models.checkpoints.hf.qwen2_moe_weight_mapper import \
+    Qwen2MoeHfWeightMapper
+from tensorrt_llm._torch.models.modeling_nemotron_h import split
+from tensorrt_llm._torch.models.modeling_utils import register_mapper
+from tensorrt_llm.models.modeling_utils import DecoderModelForCausalLM
+
+
+@register_mapper("HF", "Qwen3NextForCausalLM")
+class Qwen3NextHfWeightMapper(Qwen2MoeHfWeightMapper):
+
+    def init_model_and_config(self, model: Union[nn.Module,
+                                                 DecoderModelForCausalLM],
+                              config: ModelConfig):
+        super().init_model_and_config(model, config)
+        self._num_kv_heads = model.config.num_key_value_heads if hasattr(
+            model.config, 'num_key_value_heads'
+        ) and model.config.num_key_value_heads is not None else model.config.num_attention_heads
+
+    def should_skip_module(self, module_name: str) -> bool:
+        if module_name.startswith("draft_model"):
+            return True
+        return super().should_skip_module(module_name)
+
+    def _duplicate_kv_weights(self, module: nn.Module, new_name: str,
+                              weights: dict):
+        tensors_to_duplicate = ["weight", "bias"]
+        if module.quant_config.quant_mode.has_nvfp4():
+            tensors_to_duplicate.append("weight_scale")
+        if module.quant_config.quant_mode.has_fp8_block_scales():
+            tensors_to_duplicate.append("weight_scale_inv")
+
+        if new_name in ['k_proj', 'v_proj']:
+            num_kv_heads_list = [self._num_kv_heads
+                                 ] * len(weights) if isinstance(
+                                     self._num_kv_heads,
+                                     int) else self._num_kv_heads
+            processed_weights = {
+                k:
+                self._duplicate_kv(weight=v[:],
+                                   num_kv_heads=num_kv_heads_list[i],
+                                   tensor_parallel_size=self._tp_size)
+                if k in tensors_to_duplicate else v
+                for i, (k, v) in enumerate(weights.items())
+            }
+            return processed_weights
+
+        return weights
+
+    def preprocess_weights(self, weights: dict) -> dict:
+        config = self.config.pretrained_config
+        tp_size = self.config.mapping.tp_size
+        tp_rank = self.config.mapping.tp_rank
+
+        linear_key_dim = config.linear_key_head_dim * config.linear_num_key_heads  # 16 * 128
+        linear_value_dim = config.linear_value_head_dim * config.linear_num_value_heads  # 32 * 128
+
+        new_weights = {}
+        for name, _ in weights.items():
+            key = name
+
+            if "A_log" in key:
+                w = split(weights[name], tp_size, tp_rank)
+                w = w.to(torch.float32)
+                new_weights[key] = w
+            elif "dt_bias" in key:
+                w = split(weights[name], tp_size, tp_rank)
+                w = w.to(torch.float32)
+                new_weights[key] = w
+            elif "in_proj" in key:
+                # The in_proj weight is not split here, following the reference
+                # implementation; the reason is still to be confirmed.
+                new_weights[key] = weights[name]
+            elif "conv1d" in key:
+                w = weights[name]
+                # Remove dim 1 because the conv1d weights are stored in a Linear module.
+                if "weight" in key:
+                    w = w.squeeze(1)
+
+                conv_q, conv_k, conv_v = torch.split(
+                    w, [linear_key_dim, linear_key_dim, linear_value_dim],
+                    dim=0)
+
+                w = []
+                for rank in range(tp_size):
+                    conv_q_rank = split(conv_q, tp_size, rank)
+                    conv_k_rank = split(conv_k, tp_size, rank)
+                    conv_v_rank = split(conv_v, tp_size, rank)
+                    y = torch.concat([conv_q_rank, conv_k_rank, conv_v_rank])
+                    w.append(y)
+                w = torch.concat(w).contiguous()
+                new_weights[key] = w
+            else:
+                new_weights[key] = weights[name]
+
+        return new_weights
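The conv1d branch above re-packs the fused [q | k | v] channel dimension so that each tensor-parallel rank ends up with a contiguous [q_rank | k_rank | v_rank] slice. A toy illustration of that re-sharding with made-up sizes (not the model's real dimensions):

```python
import torch

linear_key_dim, linear_value_dim, tp_size = 4, 8, 2
w = torch.arange(2 * linear_key_dim + linear_value_dim, dtype=torch.float32)

# Split the fused channel dim into its q, k and v parts.
conv_q, conv_k, conv_v = torch.split(
    w, [linear_key_dim, linear_key_dim, linear_value_dim])

# Re-pack per rank: each rank gets its own contiguous q/k/v slices.
per_rank = [
    torch.cat([conv_q.chunk(tp_size)[r],
               conv_k.chunk(tp_size)[r],
               conv_v.chunk(tp_size)[r]]) for r in range(tp_size)
]
packed = torch.cat(per_rank)
# packed == tensor([0., 1., 4., 5., 8., 9., 10., 11., 2., 3., 6., 7., 12., 13., 14., 15.])
```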

tensorrt_llm/_torch/models/modeling_qwen3.py

Lines changed: 8 additions & 2 deletions
@@ -32,8 +32,12 @@ def __init__(
         model_config: ModelConfig[Qwen3Config],
         layer_idx: Optional[int] = None,
         fuse_qk_norm_rope: bool = True,
+        attn_output_gate: bool = False,
+        use_gemma_rms_norm: bool = False,
     ):
         config = model_config.pretrained_config
+        self.pretrained_config = config
+        self.attn_output_gate = attn_output_gate

         if getattr(config, "rope_scaling", None) is not None:
             if "type" in config.rope_scaling:

@@ -58,13 +62,15 @@ def __init__(
             num_attention_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
             max_position_embeddings=config.max_position_embeddings,
-            bias=config.attention_bias,
+            bias=getattr(config, "attention_bias", None),
             pos_embd_params=pos_embd_params,
             fuse_qk_norm_rope=fuse_qk_norm_rope,
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
-            dense_bias=config.attention_bias,
+            dense_bias=getattr(config, "attention_bias", None),
             config=model_config,
+            attn_output_gate=self.attn_output_gate,
+            use_gemma_rms_norm=use_gemma_rms_norm,
         )
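The new `attn_output_gate` flag is only plumbed through to the Attention module here; its actual wiring is not part of this hunk. As a hedged illustration of what an attention output gate of this kind typically does (names and placement are assumptions, not this repository's implementation), a gate is projected from the layer input and applied element-wise to the attention output before the dense projection:

```python
import torch
from torch import nn


class GatedAttnOutput(nn.Module):
    """Illustrative only: sigmoid-gate the attention output with a projection of the input."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, attn_out: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        return attn_out * torch.sigmoid(self.gate_proj(hidden_states))
```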
