Merged
3 changes: 2 additions & 1 deletion tensorrt_llm/_mnnvl_utils.py
@@ -71,7 +71,8 @@ def __init__(self, mapping: Mapping, size: int):
 
     def __del__(self):
         if not sys.is_finalizing():
-            MnnvlMemory.close_mnnvl_memory(self.ptr)
+            if hasattr(self, "ptr"):
+                MnnvlMemory.close_mnnvl_memory(self.ptr)
 
     def as_torch_strided_tensor(self, dtype):
         num_segments = MnnvlMemory.comm.Get_size()
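Note on the __del__ guard above: Python still invokes __del__ on a partially constructed object, so if __init__ raises before self.ptr is assigned, the destructor would previously fail with an AttributeError during garbage collection. A minimal standalone sketch of the same pattern (toy class and attribute names, not the TensorRT-LLM implementation):

import sys


class GuardedBuffer:
    """Toy stand-in showing why __del__ checks hasattr before cleanup."""

    def __init__(self, size: int):
        if size <= 0:
            # Raising here means self.buf is never assigned, yet Python
            # still calls __del__ on the half-built object.
            raise ValueError("size must be positive")
        self.buf = bytearray(size)

    def __del__(self):
        # Skip work during interpreter shutdown, and only release what
        # __init__ actually managed to create.
        if not sys.is_finalizing() and hasattr(self, "buf"):
            self.buf.clear()


try:
    GuardedBuffer(0)  # raises; without the hasattr guard, garbage collection
except ValueError:    # would report "Exception ignored in: __del__"
    pass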
146 changes: 47 additions & 99 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -53,7 +53,7 @@
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import (CutlassFusedMoE, DeepSeekV3MoeRoutingMethod,
-                                 create_moe)
+                                 WideEPMoE, create_moe)
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -511,7 +511,7 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
                                           self.mapping,
                                           dim=0,
                                           sizes=all_rank_num_tokens)
-            elif not isinstance(self.experts, CutlassFusedMoE) or (
+            elif not isinstance(self.experts, (CutlassFusedMoE, WideEPMoE)) or (
                     not self.experts.has_fp8_qdq and self.experts.has_nvfp4):
                 # Use padding when not using the cutlass path or when x_sf in self.experts is not None
                 use_dp_padding = True
@@ -721,12 +721,6 @@ def _compute_mlp_tp_size(self, intermediate_size: int,
             ) if tp > self.mapping.gpus_per_node else tp  # Avoid costly inter-node TP
         return mlp_tp_size
 
-    def _enable_min_latency_mode(self, num_tokens: int):
-        return (num_tokens <= 128 and self.fusion_config.POST_MOE_FUSION
-                and self.is_nvfp4 and self.model_config.moe_backend == 'CUTLASS'
-                and not self.mapping.is_multi_node()
-                and self.allreduce.mnnvl_allreduce is None)
-
     def forward(
         self,
         position_ids: torch.IntTensor,
@@ -779,116 +773,70 @@ def _run_MoE(hidden_states, hidden_states_fp4, do_finalize):
                 do_finalize=do_finalize,
             )
 
-        cutlass_min_latency_mode = self._enable_min_latency_mode(
-            hidden_states.shape[0])
-
-        if cutlass_min_latency_mode:
-            assert self.fusion_config.PRE_MOE_FUSION and self.fusion_config.POST_MOE_FUSION
-            assert self.model_config.moe_backend == 'CUTLASS'
-
-            hidden_states, hidden_states_act, hidden_states_sf, residual = self.allreduce(
+        if self.fusion_config.PRE_MOE_FUSION:
+            # moe_backend can be either CUTLASS or TRTLLM here
+            # TODO: unify the two min-latency MoE backends by enabling quant fusion
+            hidden_states, residual = self.allreduce(
                 hidden_states,
                 all_reduce_params=AllReduceParams(
-                    fusion_op=AllReduceFusionOp.
-                    RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4,
+                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
                     residual=residual,
                     norm_weight=self.post_attention_layernorm.weight,
-                    scale=self.mlp.experts.fc31_input_scale,
                     eps=self.post_attention_layernorm.variance_epsilon,
                     trigger_completion_at_end=False,
                 ))
-            hidden_states_fp4 = Fp4QuantizedTensor(hidden_states_act,
-                                                   hidden_states_sf)
-
-            hidden_states = _run_MoE(hidden_states,
-                                     hidden_states_fp4,
-                                     do_finalize=False)
+        else:
+            # No fusion
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual)
 
-            shared_output = hidden_states[0]
-            hidden_states_activated_experts = hidden_states[1]
-            num_activated_experts_per_node = hidden_states[2]
-            experts_to_token_score = hidden_states[3]
+        # Note: this fusion pattern is only supported for TRTLLM-nvfp4 backend now
+        do_finalize = not (hidden_states.shape[0]
+                           <= self.moe_allreduce.max_token
+                           and self.fusion_config.POST_MOE_FUSION
+                           and self.model_config.moe_backend == 'TRTLLM'
+                           and self.mlp.experts.has_nvfp4)
 
-            moe_all_reduce_params = MoEAllReduceParams(
-                residual=residual,
-                norm_weight=self.next_layer_layernorm.weight,
-                device_num_experts=num_activated_experts_per_node,
-                expert_scale_factor=experts_to_token_score,
-                shared_expert_output=shared_output,
-                eps=self.next_layer_layernorm.variance_epsilon,
-                is_cutlass_min_latency=True,
-            )
+        hidden_states = _run_MoE(hidden_states,
+                                 hidden_states_fp4=None,
+                                 do_finalize=do_finalize)
 
-            # MoE_finalize is fused into allreduce
-            hidden_states, residual = self.moe_allreduce(
-                hidden_states_activated_experts,
-                all_reduce_params=moe_all_reduce_params,
-            )
-        else:
-            if self.fusion_config.PRE_MOE_FUSION:
-                # moe_backend can be either CUTLASS or TRTLLM here
-                # TODO: unify the two min-latency MoE backends by enabling quant fusion
+        if self.fusion_config.POST_MOE_FUSION:
+            if do_finalize:
                 hidden_states, residual = self.allreduce(
                     hidden_states,
                     all_reduce_params=AllReduceParams(
                         fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
                         residual=residual,
-                        norm_weight=self.post_attention_layernorm.weight,
-                        eps=self.post_attention_layernorm.variance_epsilon,
+                        norm_weight=self.next_layer_layernorm.weight,
+                        eps=self.next_layer_layernorm.variance_epsilon,
                         trigger_completion_at_end=False,
                     ))
             else:
-                # No fusion
-                hidden_states, residual = self.post_attention_layernorm(
+                assert len(
+                    hidden_states) == 4, "hidden_states must have 4 elements"
+
+                shared_output = hidden_states[0]
+                fc2_output = hidden_states[1]
+                expert_scale_factor = hidden_states[2]
+                expanded_idx_to_permuted_idx = hidden_states[3]
+
+                moe_all_reduce_params = MoEAllReduceParams(
+                    expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
+                    expert_scale_factor=expert_scale_factor,
+                    shared_expert_output=shared_output,
+                    residual=residual,
+                    norm_weight=self.next_layer_layernorm.weight,
+                    eps=self.next_layer_layernorm.variance_epsilon,
+                    is_cutlass_min_latency=False,
+                )
+                hidden_states, residual = self.moe_allreduce(
+                    fc2_output, all_reduce_params=moe_all_reduce_params)
+        else:
+            if self.next_layer_layernorm is not None:
+                hidden_states, residual = self.next_layer_layernorm(
                     hidden_states, residual)
 
-            # Note: this fusion pattern is only supported for TRTLLM-nvfp4 backend now
-            do_finalize = not (hidden_states.shape[0]
-                               <= self.moe_allreduce.max_token
-                               and self.fusion_config.POST_MOE_FUSION
-                               and self.model_config.moe_backend == 'TRTLLM'
-                               and self.mlp.experts.has_nvfp4)
-
-            hidden_states = _run_MoE(hidden_states,
-                                     hidden_states_fp4=None,
-                                     do_finalize=do_finalize)
-
-            if self.fusion_config.POST_MOE_FUSION:
-                if do_finalize:
-                    hidden_states, residual = self.allreduce(
-                        hidden_states,
-                        all_reduce_params=AllReduceParams(
-                            fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
-                            residual=residual,
-                            norm_weight=self.next_layer_layernorm.weight,
-                            eps=self.next_layer_layernorm.variance_epsilon,
-                            trigger_completion_at_end=False,
-                        ))
-                else:
-                    assert len(hidden_states
-                               ) == 4, "hidden_states must have 4 elements"
-
-                    shared_output = hidden_states[0]
-                    fc2_output = hidden_states[1]
-                    expert_scale_factor = hidden_states[2]
-                    expanded_idx_to_permuted_idx = hidden_states[3]
-
-                    moe_all_reduce_params = MoEAllReduceParams(
-                        expanded_idx_to_permuted_idx=
-                        expanded_idx_to_permuted_idx,
-                        expert_scale_factor=expert_scale_factor,
-                        shared_expert_output=shared_output,
-                        residual=residual,
-                        norm_weight=self.next_layer_layernorm.weight,
-                        eps=self.next_layer_layernorm.variance_epsilon,
-                        is_cutlass_min_latency=False,
-                    )
-                    hidden_states, residual = self.moe_allreduce(
-                        fc2_output, all_reduce_params=moe_all_reduce_params)
-            else:
-                if self.next_layer_layernorm is not None:
-                    hidden_states, residual = self.next_layer_layernorm(
-                        hidden_states, residual)
-
         return hidden_states, residual
 
     def forward_mlp(
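Note on the restructured forward above: with the Cutlass min-latency branch removed, every layer now runs the pre-MoE fused allreduce (or a plain post-attention layernorm), then the MoE, and only defers finalization into moe_allreduce when the TRTLLM nvfp4 post-MoE fusion pattern applies. A hedged sketch of that do_finalize decision, with assumed parameter names standing in for the attributes used above:

def should_finalize_in_moe(num_tokens: int,
                           max_fused_tokens: int,
                           post_moe_fusion: bool,
                           moe_backend: str,
                           has_nvfp4: bool) -> bool:
    """Return True when the MoE should finalize its own output.

    Finalization is deferred (False) only when the fused MoE+allreduce
    path can take over: token count within the fused-kernel limit,
    post-MoE fusion enabled, TRTLLM backend, and nvfp4 experts.
    """
    defer_to_fused_allreduce = (num_tokens <= max_fused_tokens
                                and post_moe_fusion
                                and moe_backend == 'TRTLLM'
                                and has_nvfp4)
    return not defer_to_fused_allreduce


# 32 tokens on the TRTLLM nvfp4 path: finalize happens inside moe_allreduce.
assert should_finalize_in_moe(32, 128, True, 'TRTLLM', True) is False
# Same batch on CUTLASS: the MoE finalizes its output normally.
assert should_finalize_in_moe(32, 128, True, 'CUTLASS', True) is True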
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/models/modeling_qwen3_moe.py
@@ -15,7 +15,7 @@
 from ..modules.fused_moe import (BaseMoeRoutingMethod, CutlassFusedMoE, MoE,
                                  RenormalizeMoeRoutingMethod,
                                  RenormalizeNaiveMoeRoutingMethod,
-                                 RoutingMethodType, create_moe)
+                                 RoutingMethodType, WideEPMoE, create_moe)
 from ..modules.linear import TensorParallelMode
 from ..modules.rms_norm import RMSNorm
 from ..speculative import SpecMetadata
@@ -138,7 +138,7 @@ def forward(
                                           self.mapping,
                                           dim=0,
                                           sizes=all_rank_num_tokens)
-            elif not isinstance(self.experts, CutlassFusedMoE) or (
+            elif not isinstance(self.experts, (CutlassFusedMoE, WideEPMoE)) or (
                     not self.experts.has_fp8_qdq and self.experts.has_nvfp4):
                 # Use padding when not using the cutlass path or when x_sf in self.experts is not None
                 use_dp_padding = True
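The Qwen3 change mirrors the DeepSeek-V3 one: WideEPMoE is now treated like CutlassFusedMoE when deciding whether the attention-DP all-gather needs padding. A simplified sketch of that predicate under assumed attribute names (the real check runs on the module instances inside forward()):

from dataclasses import dataclass


@dataclass
class ExpertsInfo:
    # Assumed stand-ins for the attributes referenced in the diffs above.
    cls_name: str        # e.g. "CutlassFusedMoE", "WideEPMoE", "TRTLLMGenFusedMoE"
    has_fp8_qdq: bool
    has_nvfp4: bool


def use_dp_padding(experts: ExpertsInfo) -> bool:
    """Pad unless the cutlass-style path (CutlassFusedMoE or WideEPMoE)
    is taken and the nvfp4-without-fp8-qdq case does not apply."""
    cutlass_like = experts.cls_name in ("CutlassFusedMoE", "WideEPMoE")
    return (not cutlass_like) or (not experts.has_fp8_qdq and experts.has_nvfp4)


assert use_dp_padding(ExpertsInfo("WideEPMoE", has_fp8_qdq=True, has_nvfp4=False)) is False
assert use_dp_padding(ExpertsInfo("TRTLLMGenFusedMoE", has_fp8_qdq=False, has_nvfp4=True)) is True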
30 changes: 16 additions & 14 deletions tensorrt_llm/_torch/modules/fused_moe/__init__.py
@@ -2,6 +2,7 @@
 from .fused_moe_cutlass import CutlassFusedMoE
 from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE
 from .fused_moe_vanilla import VanillaMoE
+from .fused_moe_wide_ep import WideEPMoE
 from .interface import MoE, MoEWeightLoadingMode
 from .moe_load_balancer import MoeLoadBalancer
 from .quantization import FusedMoEQuantScalesFP8
@@ -13,23 +14,24 @@
                       SparseMixerMoeRoutingMethod, StaticMoeRoutingMethod)
 
 __all__ = [
-    "VanillaMoE",
-    "CutlassFusedMoE",
-    "TRTLLMGenFusedMoE",
     "BaseMoeRoutingMethod",
-    "MoeLoadBalancer",
-    "RenormalizeNaiveMoeRoutingMethod",
+    "create_moe",
+    "CutlassFusedMoE",
+    "DeepSeekV3MoeRoutingMethod",
+    "DefaultMoeRoutingMethod",
+    "FusedMoEQuantScalesFP8",
+    "get_moe_cls",
     "Llama4RenormalizeMoeRoutingMethod",
-    "SparseMixerMoeRoutingMethod",
     "LoadBalancedMoeRoutingMethod",
-    "StaticMoeRoutingMethod",
-    "DefaultMoeRoutingMethod",
-    "DeepSeekV3MoeRoutingMethod",
-    "RoutingMethodType",
-    "RenormalizeMoeRoutingMethod",
     "MoE",
+    "MoeLoadBalancer",
     "MoEWeightLoadingMode",
-    "get_moe_cls",
-    "create_moe",
-    "FusedMoEQuantScalesFP8",
+    "RenormalizeMoeRoutingMethod",
+    "RenormalizeNaiveMoeRoutingMethod",
+    "RoutingMethodType",
+    "SparseMixerMoeRoutingMethod",
+    "StaticMoeRoutingMethod",
+    "TRTLLMGenFusedMoE",
+    "VanillaMoE",
+    "WideEPMoE",
 ]
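Besides registering the new class, the export list is now alphabetized, so the wide expert-parallel backend is importable directly from the fused_moe package (assuming a checkout that includes this change):

from tensorrt_llm._torch.modules.fused_moe import WideEPMoE, create_moe

print(WideEPMoE.__name__)  # part of the public fused_moe surface alongside create_moe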
22 changes: 21 additions & 1 deletion tensorrt_llm/_torch/modules/fused_moe/create_moe.py
@@ -9,13 +9,16 @@
 from .fused_moe_cutlass import CutlassFusedMoE
 from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE
 from .fused_moe_vanilla import VanillaMoE
+from .fused_moe_wide_ep import WideEPMoE
 from .interface import MoE, MoEWeightLoadingMode
 from .moe_load_balancer import get_moe_load_balancer
 from .routing import BaseMoeRoutingMethod
 
 
 def get_moe_cls(
         model_config: ModelConfig,
+        routing_method: BaseMoeRoutingMethod,
+        dtype: Optional[torch.dtype] = None,
         override_quant_config: Optional[QuantConfig] = None) -> Type[MoE]:
     moe_backend = model_config.moe_backend
     quant_config = model_config.quant_config
@@ -36,6 +39,8 @@ def get_moe_cls(
                 f"Check out details in quant_config: {quant_config}"
                 "Using CutlassFusedMoE instead.")
             return CutlassFusedMoE
+    elif moe_backend.upper() == "WIDEEP":
+        return WideEPMoE
     else:
         raise ValueError(f"Unsupported moe backend: {moe_backend}")
 
@@ -54,7 +59,8 @@ def create_moe(
     apply_router_weight_on_input: bool = False,
     layer_idx: Optional[int] = None,
 ) -> MoE:
-    moe_cls = get_moe_cls(model_config, override_quant_config)
+    moe_cls = get_moe_cls(model_config, routing_method, dtype,
+                          override_quant_config)
 
     moe_load_balancer = get_moe_load_balancer()
     if moe_load_balancer is not None:
@@ -88,6 +94,20 @@
             apply_router_weight_on_input=apply_router_weight_on_input,
             layer_idx=layer_idx,
         )
+    elif moe_cls == WideEPMoE:
+        return moe_cls(
+            routing_method=routing_method,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            dtype=dtype,
+            reduce_results=reduce_results,
+            model_config=model_config,
+            aux_stream=aux_stream,
+            weight_loading_mode=weight_loading_mode,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            layer_idx=layer_idx,
+        )
     elif moe_cls == VanillaMoE:
         assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in VanillaMoE."
 
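Selecting the new backend goes through the same string dispatch as the existing ones: setting moe_backend to "WIDEEP" makes get_moe_cls return WideEPMoE, and create_moe forwards the usual constructor arguments to it. A stripped-down sketch of that dispatch (stub classes so it runs standalone; the real get_moe_cls also applies quantization checks):

class CutlassFusedMoE: ...
class TRTLLMGenFusedMoE: ...
class VanillaMoE: ...
class WideEPMoE: ...


def pick_moe_cls(moe_backend: str) -> type:
    """Simplified, case-insensitive mirror of the get_moe_cls dispatch."""
    backend = moe_backend.upper()
    if backend == "CUTLASS":
        return CutlassFusedMoE
    elif backend == "TRTLLM":
        return TRTLLMGenFusedMoE
    elif backend == "VANILLA":
        return VanillaMoE
    elif backend == "WIDEEP":  # new in this PR
        return WideEPMoE
    else:
        raise ValueError(f"Unsupported moe backend: {moe_backend}")


assert pick_moe_cls("wideep") is WideEPMoE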