tensorrt_llm/_torch/models/modeling_qwen3_next.py (11 additions, 16 deletions)
@@ -1104,17 +1104,15 @@ def __init__(
             aux_stream,
             layer_idx=layer_idx)
 
-        use_gemma_rms_norm = True
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
                                        dtype=config.torch_dtype,
-                                       use_gemma_rms_norm=use_gemma_rms_norm)
+                                       use_gemma=True)
 
-        self.post_attention_layernorm = RMSNorm(
-            hidden_size=config.hidden_size,
-            eps=config.rms_norm_eps,
-            dtype=config.torch_dtype,
-            use_gemma_rms_norm=use_gemma_rms_norm)
+        self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
+                                                eps=config.rms_norm_eps,
+                                                dtype=config.torch_dtype,
+                                                use_gemma=True)
         self.layer_idx = layer_idx
 
         self.allreduce = AllReduce(mapping=model_config.mapping,
@@ -1266,17 +1264,15 @@ def __init__(self, model_config: ModelConfig[Qwen3NextConfig],
             aux_stream,
             layer_idx=layer_idx)
 
-        use_gemma_rms_norm = True
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
                                        dtype=config.torch_dtype,
-                                       use_gemma_rms_norm=use_gemma_rms_norm)
+                                       use_gemma=True)
 
-        self.post_attention_layernorm = RMSNorm(
-            hidden_size=config.hidden_size,
-            eps=config.rms_norm_eps,
-            dtype=config.torch_dtype,
-            use_gemma_rms_norm=use_gemma_rms_norm)
+        self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
+                                                eps=config.rms_norm_eps,
+                                                dtype=config.torch_dtype,
+                                                use_gemma=True)
         self.layer_idx = layer_idx
 
         self.allreduce = AllReduce(mapping=model_config.mapping,
@@ -1444,12 +1440,11 @@ def __init__(self, model_config: ModelConfig[Qwen3NextConfig]):
             ) for layer_idx in range(pretrained_config.num_hidden_layers)
         ])
 
-        use_gemma_rms_norm = True
         self.norm = RMSNorm(
             hidden_size=pretrained_config.hidden_size,
             eps=pretrained_config.rms_norm_eps,
             dtype=pretrained_config.torch_dtype,
-            use_gemma_rms_norm=use_gemma_rms_norm,
+            use_gemma=True,
         )
 
         self.mamba_metadata: Optional[Mamba2Metadata] = None
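
Note on the call sites in this file: the only functional change is the renamed RMSNorm keyword. A minimal construction sketch of the new surface, assuming the module path used in this PR; hidden size, eps, and dtype are illustrative, not taken from the model config:

    import torch

    from tensorrt_llm._torch.modules.rms_norm import RMSNorm

    # Gemma-style norm requested through the renamed keyword; the old
    # `use_gemma_rms_norm=...` spelling is gone from the constructor.
    norm = RMSNorm(hidden_size=2048, eps=1e-6, dtype=torch.bfloat16, use_gemma=True)
    assert norm.use_gemma  # stored under the new attribute name as well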
tensorrt_llm/_torch/modules/attention.py (8 additions, 9 deletions)
@@ -177,8 +177,7 @@ def __init__(
         self.attn_output_gate = attn_output_gate
 
         if self.attn_output_gate:
-            logger.warning_once("using attn output gate!",
-                                key="attn_output_gate")
+            logger.info_once("using attn output gate!", key="attn_output_gate")
 
         # [Chunked Attention]
         # Chunked attention is applied to context requests only. Chunked attention will be
@@ -224,7 +223,7 @@ def __init__(
 
         self.qkv_proj = Linear(
             self.hidden_size,
-            tp_size * self.q_size * (1 + (1 if self.attn_output_gate else 0)) +
+            tp_size * self.q_size * (2 if self.attn_output_gate else 1) +
             2 * tp_size * self.kv_size,
             bias=bias,
             dtype=dtype,
@@ -533,10 +532,11 @@ def forward(
             q_gate, k, v = qkv.split(
                 [self.q_size * 2, self.kv_size, self.kv_size], dim=-1)
             orig_shape = q_gate.shape[:-1]
-            q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
-            q, gate = torch.chunk(q_gate, 2, dim=-1)
-            q = q.reshape(*orig_shape, -1)
-            gate = gate.reshape(*orig_shape, -1)
+            # Single line: view -> chunk -> reshape both q and gate
+            q, gate = [
+                t.reshape(*orig_shape, -1) for t in torch.chunk(
+                    q_gate.view(*orig_shape, self.num_heads, -1), 2, dim=-1)
+            ]
             ### TODO: avoid the redundant split and concat
             qkv = torch.concat([q, k, v], dim=-1)
 
@@ -584,8 +584,7 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
         """
         # If RoPE is fused into the attention OP, do not apply RoPE here.
         if not self.rope_fusion and position_ids is not None:
-            if k is None and v is None:
-                q, k, v = self.split_qkv(q, k, v)
+            q, k, v = self.split_qkv(q, k, v)
             q, k = self.rotary_emb(position_ids, [q, k])
         return q, k, v
 
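
The reshuffled q/gate split in the forward hunk is intended to be behavior-preserving. A standalone equivalence check with dummy shapes (plain PyTorch, not code from this PR; the head count and token count are illustrative):

    import torch

    # Hypothetical sizes: 3 tokens, 4 heads of dim 8, so q_size = 32 and the
    # fused q+gate slice is 64 wide.
    num_heads, head_dim = 4, 8
    q_size = num_heads * head_dim
    q_gate = torch.randn(3, 2 * q_size)
    orig_shape = q_gate.shape[:-1]

    # Old formulation: view -> chunk -> reshape q and gate separately.
    qg = q_gate.view(*orig_shape, num_heads, -1)
    q_old, gate_old = torch.chunk(qg, 2, dim=-1)
    q_old = q_old.reshape(*orig_shape, -1)
    gate_old = gate_old.reshape(*orig_shape, -1)

    # New formulation: the same three steps folded into one comprehension.
    q_new, gate_new = [
        t.reshape(*orig_shape, -1)
        for t in torch.chunk(q_gate.view(*orig_shape, num_heads, -1), 2, dim=-1)
    ]

    assert torch.equal(q_old, q_new) and torch.equal(gate_old, gate_new)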
tensorrt_llm/_torch/modules/qk_norm_attention.py (5 additions, 8 deletions)
@@ -166,10 +166,9 @@ def __init__(
         if use_gemma_rms_norm:
             assert fuse_qk_norm_rope is False, "fused_qk_norm_rope is not supported for gemma rms norm."
 
-        # If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP, and self.rotary_emb will be skipped in the overridden apply_rope.
-        rope_fusion = not self.fuse_qk_norm_rope and not skip_rope
-        if attn_output_gate and use_gemma_rms_norm:
-            rope_fusion = False
+        # If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP, and self.rotary_emb
+        # will be skipped in the overridden apply_rope.
+        rope_fusion = not self.fuse_qk_norm_rope and not skip_rope and not attn_output_gate and not use_gemma_rms_norm
         assert not (fuse_qk_norm_rope and skip_rope
                     ), "Fusing qk norm and skipping rope is not supported"
 
@@ -180,8 +179,6 @@
             max_position_embeddings=max_position_embeddings,
             bias=bias,
             pos_embd_params=pos_embd_params,
-            # If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP,
-            # and self.rotary_emb will be skipped in the overridden apply_rope.
             rope_fusion=rope_fusion,
             layer_idx=layer_idx,
             dtype=dtype,
@@ -196,12 +193,12 @@
                               eps=self.pretrained_config.rms_norm_eps,
                               dtype=self.pretrained_config.torch_dtype,
                               has_weights=True,
-                              use_gemma_rms_norm=use_gemma_rms_norm)
+                              use_gemma=use_gemma_rms_norm)
         self.k_norm = RMSNorm(hidden_size=self.head_dim,
                               eps=self.pretrained_config.rms_norm_eps,
                               dtype=self.pretrained_config.torch_dtype,
                               has_weights=True,
-                              use_gemma_rms_norm=use_gemma_rms_norm)
+                              use_gemma=use_gemma_rms_norm)
         self.aux_stream = torch.cuda.Stream()
         self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
 
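
One thing worth flagging on the first hunk: the consolidated predicate is slightly broader than the old two-step logic, which forced rope_fusion off only when attn_output_gate and use_gemma_rms_norm were both set. A small standalone comparison of the two forms (plain Python, illustrative only, not code from this PR):

    from itertools import product

    def rope_fusion_old(fuse_qk_norm_rope, skip_rope, attn_output_gate,
                        use_gemma_rms_norm):
        # Pre-change logic: fusion is disabled only when both flags are set.
        rope_fusion = not fuse_qk_norm_rope and not skip_rope
        if attn_output_gate and use_gemma_rms_norm:
            rope_fusion = False
        return rope_fusion

    def rope_fusion_new(fuse_qk_norm_rope, skip_rope, attn_output_gate,
                        use_gemma_rms_norm):
        # Post-change logic: any one of the four conditions disables fusion.
        return (not fuse_qk_norm_rope and not skip_rope
                and not attn_output_gate and not use_gemma_rms_norm)

    # The two disagree exactly when only one of attn_output_gate /
    # use_gemma_rms_norm is set and fusion would otherwise be allowed.
    disagreements = [
        flags for flags in product([False, True], repeat=4)
        if rope_fusion_old(*flags) != rope_fusion_new(*flags)
    ]
    assert disagreements == [(False, False, False, True),
                             (False, False, True, False)]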
tensorrt_llm/_torch/modules/rms_norm.py (17 additions, 13 deletions)
@@ -29,18 +29,22 @@ class RMSNorm(nn.Module):
     _ArgumentNotSpecifiedSentinelType: TypeAlias = EllipsisType
 
     def __init__(
-            self,
-            *,
-            hidden_size: int,
-            eps: float,
-            dtype: Optional[torch.dtype] = None,
-            device: Optional[torch.device] = None,
-            has_weights: bool = True,
-            use_gemma_rms_norm: bool = False,  # Assume has_weights = True
+        self,
+        *,
+        hidden_size: int,
+        eps: float,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        has_weights: bool = True,
+        use_gemma: bool = False,
     ):
         super().__init__()
+
+        if use_gemma and not has_weights:
+            raise ValueError("has_weights must be True if use_gemma is True")
+
         if has_weights:
-            if not use_gemma_rms_norm:
+            if not use_gemma:
                 self.weight = nn.Parameter(
                     torch.ones(hidden_size, dtype=dtype, device=device))
             else:
@@ -53,7 +57,7 @@ def __init__(
                                             device=device),
                                  persistent=False)
         self.variance_epsilon = eps
-        self.use_gemma_rms_norm = use_gemma_rms_norm
+        self.use_gemma = use_gemma
 
     def forward(
         self,
@@ -73,7 +77,7 @@ def forward(
                                       flashinfer_gemma_rmsnorm,
                                       flashinfer_rmsnorm)
             if residual is not None:
-                if not self.use_gemma_rms_norm:
+                if not self.use_gemma:
                     flashinfer_fused_add_rmsnorm(hidden_states, residual,
                                                  self.weight,
                                                  self.variance_epsilon)
@@ -82,7 +86,7 @@
                                                        self.weight,
                                                        self.variance_epsilon)
             else:
-                if not self.use_gemma_rms_norm:
+                if not self.use_gemma:
                     hidden_states = flashinfer_rmsnorm(hidden_states,
                                                        self.weight,
                                                        self.variance_epsilon)
@@ -99,7 +103,7 @@
             variance = hidden_states.pow(2).mean(-1, keepdim=True)
             hidden_states = hidden_states * torch.rsqrt(variance +
                                                         self.variance_epsilon)
-            if not self.use_gemma_rms_norm:
+            if not self.use_gemma:
                 hidden_states = self.weight * hidden_states.to(input_dtype)
             else:
                 hidden_states = (self.weight +
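
For readers new to the two variants handled here: numerically, the only difference is the scale applied after normalization, weight for standard RMSNorm versus (weight + 1) for the Gemma flavor, whose checkpoints store the scale as an offset from one. A self-contained reference of the non-flashinfer fallback path (a sketch, not the module itself; dtype handling simplified):

    import torch

    def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float,
                           use_gemma: bool) -> torch.Tensor:
        input_dtype = x.dtype
        x = x.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + eps)
        if not use_gemma:
            # Standard RMSNorm: scale directly by the learned weight.
            return weight * x.to(input_dtype)
        # Gemma-style RMSNorm: the stored weight is an offset from one, so the
        # effective multiplier is (weight + 1).
        return ((weight.float() + 1.0) * x).to(input_dtype)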
tensorrt_llm/_torch/pyexecutor/config_utils.py (4 additions, 1 deletion)
@@ -14,4 +14,7 @@ def is_mla(config):
 
 
 def is_qwen3_next(config):
-    return getattr(config, 'linear_key_head_dim', 0) > 0
+    return hasattr(
+        config, 'architectures'
+    ) and config.architectures is not None and config.architectures[
+        0] == 'Qwen3NextForCausalLM'
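
The detection now keys off the architecture string rather than the Mamba-specific linear_key_head_dim field. A quick illustration against stand-in config objects (SimpleNamespace is used only for this example):

    from types import SimpleNamespace

    from tensorrt_llm._torch.pyexecutor.config_utils import is_qwen3_next

    qwen3_next_cfg = SimpleNamespace(architectures=['Qwen3NextForCausalLM'])
    qwen3_moe_cfg = SimpleNamespace(architectures=['Qwen3MoeForCausalLM'])
    no_arch_cfg = SimpleNamespace(architectures=None)

    assert is_qwen3_next(qwen3_next_cfg)
    assert not is_qwen3_next(qwen3_moe_cfg)
    assert not is_qwen3_next(no_arch_cfg)  # guarded by the `is not None` check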