tensorrt_llm/_torch/models/modeling_qwen3_moe.py: 57 changes (49 additions, 8 deletions)
```diff
@@ -8,7 +8,7 @@

 from ..attention_backend import AttentionMetadata
 from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
-                           allgather)
+                           MoEAllReduce, MoEAllReduceParams, allgather)
 from ..model_config import ModelConfig
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
```
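The two new imports carry the whole change. Judging by its parameters at the call site added below, `MoEAllReduce` is a fused op that finalizes the routed expert output (applies the per-token expert scales and scatters permuted rows back to token order) and then performs the tensor-parallel all-reduce plus residual RMSNorm, while `MoEAllReduceParams` bundles its arguments. A minimal sketch of the call pattern, using only names and arguments visible in this diff; the comments are inferences from those names, not a documented contract:

```python
from tensorrt_llm._torch.distributed import MoEAllReduce, MoEAllReduceParams

def fused_finalize_allreduce(moe_allreduce: MoEAllReduce, fc2_output,
                             expert_scale_factor, expanded_idx_to_permuted_idx,
                             residual, norm_weight, eps):
    # Mirrors the call site added in Qwen3MoEDecoderLayer.forward below.
    params = MoEAllReduceParams(
        expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
        expert_scale_factor=expert_scale_factor,
        shared_expert_output=None,  # no shared-expert output on this path
        residual=residual,
        norm_weight=norm_weight,
        eps=eps,
        is_cutlass_min_latency=False,
    )
    # Returns the normalized hidden states and the updated residual, like the
    # AllReduce RESIDUAL_RMS_NORM fusion it replaces on this path.
    return moe_allreduce(fc2_output, all_reduce_params=params)
```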
```diff
@@ -119,13 +119,17 @@ def forward(
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         all_reduce_params: Optional[AllReduceParams] = None,
+        do_finalize: Optional[bool] = True,
     ) -> torch.Tensor:
         assert hidden_states.shape[-1] == self.hidden_dim
         orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_dim)
         use_dp_padding = False
         all_rank_num_tokens = attn_metadata.all_rank_num_tokens

+        if not do_finalize:
+            assert not self.enable_attention_dp
+
         if self.enable_attention_dp and self.mapping.tp_size > 1:
             # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
             # to reduce allreduce BW
```
```diff
@@ -148,7 +152,12 @@ def forward(
             hidden_states,
             router_logits,
             all_rank_num_tokens=all_rank_num_tokens,
-            use_dp_padding=use_dp_padding)
+            use_dp_padding=use_dp_padding,
+            do_finalize=do_finalize,
+        )
+
+        if not do_finalize:
+            return final_hidden_states

         if not self.enable_attention_dp and self.mapping.tp_size > 1:
             final_hidden_states = self.allreduce(
```
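With this change `Qwen3MoE.forward` has two return shapes: the default finalized tensor, or, when `do_finalize=False`, the unfinalized expert results passed straight through (the decoder layer below asserts these are exactly three tensors, and the new assert above shows the path is incompatible with attention DP). A small sketch of the caller-side contract; the element meanings are inferred from their names:

```python
def run_moe(mlp, hidden_states, attn_metadata, do_finalize: bool = True):
    out = mlp(hidden_states, attn_metadata, do_finalize=do_finalize)
    if do_finalize:
        return out  # a single finalized [num_tokens, hidden_size] tensor
    # Unfinalized path: three tensors, consumed later by MoEAllReduce.
    # Judging by the names: the raw FC2 expert output (still in permuted,
    # per-expert order), the per-token expert scale factors, and the map
    # from expanded token index to permuted row index.
    fc2_output, expert_scale_factor, expanded_idx_to_permuted_idx = out
    return fc2_output, expert_scale_factor, expanded_idx_to_permuted_idx
```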
```diff
@@ -162,6 +171,7 @@ class Qwen3MoEDecoderLayer(DecoderLayer):
     def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
                  layer_idx: int, aux_stream: torch.cuda.Stream):
         super().__init__()
+        self.model_config = model_config
         config = model_config.pretrained_config
         self.self_attn = Qwen3Attention(
             model_config,
```
```diff
@@ -198,6 +208,7 @@ def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
         self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION
                                        or self.mapping.tp_size == 1
                                        or self.enable_attention_dp)
+        self.moe_allreduce = MoEAllReduce(mapping=model_config.mapping)

     def forward(
         self,
```
```diff
@@ -236,25 +247,55 @@ def forward(
             hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual)

+        # Note: this fusion pattern is only supported for TRTLLM-nvfp4 backend now
+        do_finalize = not (hidden_states.shape[0]
+                           <= self.moe_allreduce.max_token
+                           and self.fusion_config.POST_MOE_FUSION
+                           and self.model_config.moe_backend == 'TRTLLM'
+                           and self.mlp.experts.has_nvfp4)
+
         hidden_states = self.mlp(
             hidden_states,
             attn_metadata,
             all_reduce_params=AllReduceParams(
                 enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
-                                      or self.mapping.tp_size == 1)))
+                                      or self.mapping.tp_size == 1)),
+            do_finalize=do_finalize,
+        )

         if spec_metadata:
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                       hidden_states, residual)
         if self.fusion_config.POST_MOE_FUSION:
-            hidden_states, residual = self.allreduce(
-                hidden_states,
-                all_reduce_params=AllReduceParams(
-                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+            if do_finalize:
+                hidden_states, residual = self.allreduce(
+                    hidden_states,
+                    all_reduce_params=AllReduceParams(
+                        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+                        residual=residual,
+                        norm_weight=self.next_layer_layernorm.weight,
+                        eps=self.next_layer_layernorm.variance_epsilon,
+                    ))
+            else:
+                assert len(
+                    hidden_states
+                ) == 3, f"hidden_states must have 3 elements, but got {len(hidden_states)}"
+
+                fc2_output = hidden_states[0]
+                expert_scale_factor = hidden_states[1]
+                expanded_idx_to_permuted_idx = hidden_states[2]
+
+                moe_all_reduce_params = MoEAllReduceParams(
+                    expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
+                    expert_scale_factor=expert_scale_factor,
+                    shared_expert_output=None,
                     residual=residual,
                     norm_weight=self.next_layer_layernorm.weight,
                     eps=self.next_layer_layernorm.variance_epsilon,
-                ))
+                    is_cutlass_min_latency=False,
+                )
+                hidden_states, residual = self.moe_allreduce(
+                    fc2_output, all_reduce_params=moe_all_reduce_params)
         else:
             if self.next_layer_layernorm is not None:
                 hidden_states, residual = self.next_layer_layernorm(
```
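The `do_finalize` predicate is easier to read positively: finalization is deferred to the fused `MoEAllReduce` only when all four conditions hold, matching the in-code note that the pattern is TRTLLM-nvfp4 only. A sketch of the same gate, assuming nothing beyond the attributes the diff reads:

```python
def should_defer_finalize(num_tokens: int, max_token: int,
                          post_moe_fusion: bool, moe_backend: str,
                          has_nvfp4: bool) -> bool:
    # `do_finalize` in the diff is `not should_defer_finalize(...)`.
    return (num_tokens <= max_token      # within the fused kernel's token limit
            and post_moe_fusion          # POST_MOE_FUSION enabled for the layer
            and moe_backend == 'TRTLLM'  # only the TRTLLM MoE backend
            and has_nvfp4)               # only with NVFP4-quantized experts
```

Everything else (larger batches, other MoE backends, non-NVFP4 experts) keeps the existing `AllReduce` path with the `RESIDUAL_RMS_NORM` fusion.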