
Commit 939ecdd

[None][chore] Refine qwen3-next implementation.
Signed-off-by: nv-guomingz <[email protected]>
1 parent 38d6e4e commit 939ecdd

5 files changed: +44, -47 lines changed


tensorrt_llm/_torch/models/modeling_qwen3_next.py

Lines changed: 11 additions & 16 deletions
@@ -1104,17 +1104,15 @@ def __init__(
             aux_stream,
             layer_idx=layer_idx)
 
-        use_gemma_rms_norm = True
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
                                        dtype=config.torch_dtype,
-                                       use_gemma_rms_norm=use_gemma_rms_norm)
+                                       use_gemma=True)
 
-        self.post_attention_layernorm = RMSNorm(
-            hidden_size=config.hidden_size,
-            eps=config.rms_norm_eps,
-            dtype=config.torch_dtype,
-            use_gemma_rms_norm=use_gemma_rms_norm)
+        self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
+                                                eps=config.rms_norm_eps,
+                                                dtype=config.torch_dtype,
+                                                use_gemma=True)
         self.layer_idx = layer_idx
 
         self.allreduce = AllReduce(mapping=model_config.mapping,
@@ -1266,17 +1264,15 @@ def __init__(self, model_config: ModelConfig[Qwen3NextConfig],
             aux_stream,
             layer_idx=layer_idx)
 
-        use_gemma_rms_norm = True
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
                                        dtype=config.torch_dtype,
-                                       use_gemma_rms_norm=use_gemma_rms_norm)
+                                       use_gemma=True)
 
-        self.post_attention_layernorm = RMSNorm(
-            hidden_size=config.hidden_size,
-            eps=config.rms_norm_eps,
-            dtype=config.torch_dtype,
-            use_gemma_rms_norm=use_gemma_rms_norm)
+        self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
+                                                eps=config.rms_norm_eps,
+                                                dtype=config.torch_dtype,
+                                                use_gemma=True)
         self.layer_idx = layer_idx
 
         self.allreduce = AllReduce(mapping=model_config.mapping,
@@ -1444,12 +1440,11 @@ def __init__(self, model_config: ModelConfig[Qwen3NextConfig]):
             ) for layer_idx in range(pretrained_config.num_hidden_layers)
         ])
 
-        use_gemma_rms_norm = True
         self.norm = RMSNorm(
            hidden_size=pretrained_config.hidden_size,
            eps=pretrained_config.rms_norm_eps,
            dtype=pretrained_config.torch_dtype,
-           use_gemma_rms_norm=use_gemma_rms_norm,
+           use_gemma=True,
        )
 
        self.mamba_metadata: Optional[Mamba2Metadata] = None
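
The decoder layers and the final model norm now request Gemma-style normalization directly with `use_gemma=True` instead of threading a local `use_gemma_rms_norm` flag. As a reference for what that flag selects, below is a minimal, framework-free sketch of the two scaling conventions: standard RMSNorm multiplies the normalized activations by `weight` (ones-initialized), while the Gemma convention multiplies by `1 + weight` (zeros-initialized), so both start as an identity scale. `TinyRMSNorm` is a hypothetical stand-in mirroring the eager fallback, not TensorRT-LLM's flashinfer kernel path.

```python
import torch
import torch.nn as nn


class TinyRMSNorm(nn.Module):
    """Illustrative only; not the RMSNorm from tensorrt_llm/_torch/modules/rms_norm.py."""

    def __init__(self, hidden_size: int, eps: float = 1e-6, use_gemma: bool = False):
        super().__init__()
        self.eps = eps
        self.use_gemma = use_gemma
        # Standard: scale initialized to ones. Gemma-style: offset initialized
        # to zeros and applied as (1 + weight).
        init = torch.zeros if use_gemma else torch.ones
        self.weight = nn.Parameter(init(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        scale = 1.0 + self.weight if self.use_gemma else self.weight
        return (scale.float() * x).to(dtype)


x = torch.randn(2, 16)
print(torch.allclose(TinyRMSNorm(16)(x), TinyRMSNorm(16, use_gemma=True)(x)))  # True at init
```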

tensorrt_llm/_torch/modules/attention.py

Lines changed: 8 additions & 9 deletions
@@ -177,8 +177,7 @@ def __init__(
         self.attn_output_gate = attn_output_gate
 
         if self.attn_output_gate:
-            logger.warning_once("using attn output gate!",
-                                key="attn_output_gate")
+            logger.info_once("using attn output gate!", key="attn_output_gate")
 
         # [Chunked Attention]
         # Chunked attention is applied to context requests only. Chunked attention will be
@@ -224,7 +223,7 @@ def __init__(
 
         self.qkv_proj = Linear(
             self.hidden_size,
-            tp_size * self.q_size * (1 + (1 if self.attn_output_gate else 0)) +
+            tp_size * self.q_size * (2 if self.attn_output_gate else 1) +
             2 * tp_size * self.kv_size,
             bias=bias,
             dtype=dtype,
@@ -533,10 +532,11 @@ def forward(
             q_gate, k, v = qkv.split(
                 [self.q_size * 2, self.kv_size, self.kv_size], dim=-1)
             orig_shape = q_gate.shape[:-1]
-            q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
-            q, gate = torch.chunk(q_gate, 2, dim=-1)
-            q = q.reshape(*orig_shape, -1)
-            gate = gate.reshape(*orig_shape, -1)
+            # Single line: view -> chunk -> reshape both q and gate
+            q, gate = [
+                t.reshape(*orig_shape, -1) for t in torch.chunk(
+                    q_gate.view(*orig_shape, self.num_heads, -1), 2, dim=-1)
+            ]
             ### TODO: avoid the redundant split and concat
             qkv = torch.concat([q, k, v], dim=-1)
 
@@ -584,8 +584,7 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
         """
         # If RoPE is fused into the attention OP, do not apply RoPE here.
         if not self.rope_fusion and position_ids is not None:
-            if k is None and v is None:
-                q, k, v = self.split_qkv(q, k, v)
+            q, k, v = self.split_qkv(q, k, v)
             q, k = self.rotary_emb(position_ids, [q, k])
         return q, k, v
 
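
Two details of this refactor can be checked in isolation: the projection-width factor `1 + (1 if gate else 0)` is the same number as `2 if gate else 1`, and the condensed view -> chunk -> reshape produces exactly the tensors the previous four-line version did. A small self-contained check; the shapes and the local `num_heads`/`head_dim` are made up for illustration and stand in for the module attributes.

```python
import torch

# Width factor: old and new spellings agree for both settings of the gate flag.
for gate in (False, True):
    assert 1 + (1 if gate else 0) == (2 if gate else 1)

num_heads, head_dim = 4, 8
q_gate = torch.randn(2, 5, num_heads * head_dim * 2)  # (batch, seq, 2 * q_size)
orig_shape = q_gate.shape[:-1]

# Previous formulation.
qg = q_gate.view(*orig_shape, num_heads, -1)
q_old, gate_old = torch.chunk(qg, 2, dim=-1)
q_old = q_old.reshape(*orig_shape, -1)
gate_old = gate_old.reshape(*orig_shape, -1)

# Condensed formulation from this commit.
q_new, gate_new = [
    t.reshape(*orig_shape, -1)
    for t in torch.chunk(q_gate.view(*orig_shape, num_heads, -1), 2, dim=-1)
]

assert torch.equal(q_old, q_new) and torch.equal(gate_old, gate_new)
print("equivalent")
```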

tensorrt_llm/_torch/modules/qk_norm_attention.py

Lines changed: 5 additions & 8 deletions
@@ -166,10 +166,9 @@ def __init__(
         if use_gemma_rms_norm:
             assert fuse_qk_norm_rope is False, "fused_qk_norm_rope is not supported for gemma rms norm."
 
-        # If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP, and self.rotary_emb will be skipped in the overridden apply_rope.
-        rope_fusion = not self.fuse_qk_norm_rope and not skip_rope
-        if attn_output_gate and use_gemma_rms_norm:
-            rope_fusion = False
+        # If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP, and self.rotary_emb
+        # will be skipped in the overridden apply_rope.
+        rope_fusion = not self.fuse_qk_norm_rope and not skip_rope and not attn_output_gate and not use_gemma_rms_norm
         assert not (fuse_qk_norm_rope and skip_rope
                     ), "Fusing qk norm and skipping rope is not supported"
 
@@ -180,8 +179,6 @@ def __init__(
             max_position_embeddings=max_position_embeddings,
             bias=bias,
             pos_embd_params=pos_embd_params,
-            # If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP,
-            # and self.rotary_emb will be skipped in the overridden apply_rope.
             rope_fusion=rope_fusion,
             layer_idx=layer_idx,
             dtype=dtype,
@@ -196,12 +193,12 @@ def __init__(
                               eps=self.pretrained_config.rms_norm_eps,
                               dtype=self.pretrained_config.torch_dtype,
                               has_weights=True,
-                              use_gemma_rms_norm=use_gemma_rms_norm)
+                              use_gemma=use_gemma_rms_norm)
         self.k_norm = RMSNorm(hidden_size=self.head_dim,
                               eps=self.pretrained_config.rms_norm_eps,
                               dtype=self.pretrained_config.torch_dtype,
                               has_weights=True,
-                              use_gemma_rms_norm=use_gemma_rms_norm)
+                              use_gemma=use_gemma_rms_norm)
         self.aux_stream = torch.cuda.Stream()
         self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
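
The fused-RoPE decision is now a single boolean expression. Note this is not a pure rewrite of the old logic: previously fusion was disabled only when `attn_output_gate` and `use_gemma_rms_norm` were both set, whereas the new condition disables it when either of them (or `fuse_qk_norm_rope`/`skip_rope`) is set. A standalone enumeration of the new expression over all flag combinations:

```python
from itertools import product

for fuse_qk_norm_rope, skip_rope, attn_output_gate, use_gemma_rms_norm in product(
        (False, True), repeat=4):
    rope_fusion = (not fuse_qk_norm_rope and not skip_rope
                   and not attn_output_gate and not use_gemma_rms_norm)
    if rope_fusion:
        print(fuse_qk_norm_rope, skip_rope, attn_output_gate, use_gemma_rms_norm)
# Only the all-False row prints: fused RoPE stays enabled only when none of the
# special-case paths (fused qk-norm RoPE, skipped RoPE, attention output gate,
# Gemma-style RMSNorm) is requested.
```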

tensorrt_llm/_torch/modules/rms_norm.py

Lines changed: 17 additions & 13 deletions
@@ -29,18 +29,22 @@ class RMSNorm(nn.Module):
     _ArgumentNotSpecifiedSentinelType: TypeAlias = EllipsisType
 
     def __init__(
-            self,
-            *,
-            hidden_size: int,
-            eps: float,
-            dtype: Optional[torch.dtype] = None,
-            device: Optional[torch.device] = None,
-            has_weights: bool = True,
-            use_gemma_rms_norm: bool = False,  # Assume has_weights = True
+        self,
+        *,
+        hidden_size: int,
+        eps: float,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        has_weights: bool = True,
+        use_gemma: bool = False,
     ):
         super().__init__()
+
+        if use_gemma and not has_weights:
+            raise ValueError("has_weights must be True if use_gemma is True")
+
         if has_weights:
-            if not use_gemma_rms_norm:
+            if not use_gemma:
                 self.weight = nn.Parameter(
                     torch.ones(hidden_size, dtype=dtype, device=device))
             else:
@@ -53,7 +57,7 @@ def __init__(
                                      device=device),
                                  persistent=False)
         self.variance_epsilon = eps
-        self.use_gemma_rms_norm = use_gemma_rms_norm
+        self.use_gemma = use_gemma
 
     def forward(
         self,
@@ -73,7 +77,7 @@ def forward(
                 flashinfer_gemma_rmsnorm,
                 flashinfer_rmsnorm)
             if residual is not None:
-                if not self.use_gemma_rms_norm:
+                if not self.use_gemma:
                     flashinfer_fused_add_rmsnorm(hidden_states, residual,
                                                  self.weight,
                                                  self.variance_epsilon)
@@ -82,7 +86,7 @@ def forward(
                                              self.weight,
                                              self.variance_epsilon)
             else:
-                if not self.use_gemma_rms_norm:
+                if not self.use_gemma:
                     hidden_states = flashinfer_rmsnorm(hidden_states,
                                                        self.weight,
                                                        self.variance_epsilon)
@@ -99,7 +103,7 @@ def forward(
             variance = hidden_states.pow(2).mean(-1, keepdim=True)
             hidden_states = hidden_states * torch.rsqrt(variance +
                                                         self.variance_epsilon)
-            if not self.use_gemma_rms_norm:
+            if not self.use_gemma:
                 hidden_states = self.weight * hidden_states.to(input_dtype)
             else:
                 hidden_states = (self.weight +
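
Besides the rename from `use_gemma_rms_norm` to `use_gemma`, the constructor now validates its precondition instead of silently assuming it (the old signature only carried an "Assume has_weights = True" comment). A minimal stand-in showing the guard in isolation; `MiniRMSNorm` is hypothetical and mirrors only the new check, not the rest of the class.

```python
import torch
import torch.nn as nn


class MiniRMSNorm(nn.Module):
    """Stand-in for rms_norm.RMSNorm; only the new argument check is mirrored."""

    def __init__(self, *, hidden_size: int, eps: float,
                 has_weights: bool = True, use_gemma: bool = False):
        super().__init__()
        if use_gemma and not has_weights:
            raise ValueError("has_weights must be True if use_gemma is True")
        self.variance_epsilon = eps
        self.use_gemma = use_gemma
        if has_weights:
            self.weight = nn.Parameter(torch.ones(hidden_size))


MiniRMSNorm(hidden_size=8, eps=1e-6, use_gemma=True)  # accepted
try:
    MiniRMSNorm(hidden_size=8, eps=1e-6, has_weights=False, use_gemma=True)
except ValueError as err:
    print(err)  # has_weights must be True if use_gemma is True
```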

tensorrt_llm/_torch/pyexecutor/config_utils.py

Lines changed: 3 additions & 1 deletion
@@ -14,4 +14,6 @@ def is_mla(config):
 
 
 def is_qwen3_next(config):
-    return getattr(config, 'linear_key_head_dim', 0) > 0
+    return hasattr(
+        config,
+        'architectures') and config.architectures[0] == 'Qwen3NextForCausalLM'
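
Detection now keys off the `architectures` field instead of probing for a `linear_key_head_dim` attribute. A quick standalone exercise of the new check, using `SimpleNamespace` objects as stand-ins for HuggingFace config objects:

```python
from types import SimpleNamespace


def is_qwen3_next(config):
    return hasattr(
        config,
        'architectures') and config.architectures[0] == 'Qwen3NextForCausalLM'


print(is_qwen3_next(SimpleNamespace(architectures=['Qwen3NextForCausalLM'])))  # True
print(is_qwen3_next(SimpleNamespace(architectures=['LlamaForCausalLM'])))      # False
print(is_qwen3_next(SimpleNamespace()))  # False: the hasattr guard short-circuits
```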
