
Commit 734e071

hlu1 authored and dominicshanshan committed
Refactor CutlassFusedMoE (NVIDIA#5344)
Signed-off-by: Hao Lu <[email protected]>
1 parent 04a6de8 commit 734e071

File tree

8 files changed: +1115 -749 lines changed


tensorrt_llm/_mnnvl_utils.py

Lines changed: 2 additions & 1 deletion
@@ -71,7 +71,8 @@ def __init__(self, mapping: Mapping, size: int):

     def __del__(self):
         if not sys.is_finalizing():
-            MnnvlMemory.close_mnnvl_memory(self.ptr)
+            if hasattr(self, "ptr"):
+                MnnvlMemory.close_mnnvl_memory(self.ptr)

     def as_torch_strided_tensor(self, dtype):
         num_segments = MnnvlMemory.comm.Get_size()

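The guard matters because CPython still finalizes a half-constructed object when __init__ raises before self.ptr is assigned. A minimal, self-contained sketch of that failure mode (the class name and the failing-constructor path below are illustrative, not from the repository):

import sys


class MemoryHandleSketch:
    """Shows why __del__ checks hasattr(self, "ptr") before cleanup."""

    def __init__(self, size: int):
        if size <= 0:
            # __init__ aborts here, so self.ptr is never assigned.
            raise ValueError("size must be positive")
        self.ptr = object()  # stand-in for a real memory allocation

    def __del__(self):
        if not sys.is_finalizing():
            # Without this check, finalizing a half-constructed instance
            # hits an AttributeError for the missing 'ptr' attribute
            # (reported by CPython as "Exception ignored in __del__").
            if hasattr(self, "ptr"):
                pass  # release self.ptr here


try:
    MemoryHandleSketch(0)
except ValueError:
    pass  # the partially built object is finalized without the error
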
tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 47 additions & 93 deletions
@@ -54,7 +54,7 @@
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import (CutlassFusedMoE, DeepSeekV3MoeRoutingMethod,
-                                 create_moe)
+                                 WideEPMoE, create_moe)
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -513,7 +513,7 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
                 self.mapping,
                 dim=0,
                 sizes=all_rank_num_tokens)
-        elif not isinstance(self.experts, CutlassFusedMoE) or (
+        elif not isinstance(self.experts, (CutlassFusedMoE, WideEPMoE)) or (
                 not self.experts.has_fp8_qdq and self.experts.has_nvfp4):
             # Use padding when not using the cutlass path or when x_sf in self.experts is not None
             use_dp_padding = True
@@ -780,116 +780,70 @@ def _run_MoE(hidden_states, hidden_states_fp4, do_finalize):
             do_finalize=do_finalize,
         )

-        cutlass_min_latency_mode = self._enable_min_latency_mode(
-            hidden_states.shape[0])
-
-        if cutlass_min_latency_mode:
-            assert self.fusion_config.PRE_MOE_FUSION and self.fusion_config.POST_MOE_FUSION
-            assert self.model_config.moe_backend == 'CUTLASS'
-
-            hidden_states, hidden_states_act, hidden_states_sf, residual = self.allreduce(
+        if self.fusion_config.PRE_MOE_FUSION:
+            # moe_backend can be either CUTLASS or TRTLLM here
+            # TODO: unify the two min-latency MoE backends by enabling quant fusion
+            hidden_states, residual = self.allreduce(
                 hidden_states,
                 all_reduce_params=AllReduceParams(
-                    fusion_op=AllReduceFusionOp.
-                    RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4,
+                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
                     residual=residual,
                     norm_weight=self.post_attention_layernorm.weight,
-                    scale=self.mlp.experts.fc31_input_scale,
                     eps=self.post_attention_layernorm.variance_epsilon,
+                    trigger_completion_at_end=False,
                 ))
-            hidden_states_fp4 = Fp4QuantizedTensor(hidden_states_act,
-                                                   hidden_states_sf)
-
-            hidden_states = _run_MoE(hidden_states,
-                                     hidden_states_fp4,
-                                     do_finalize=False)
+        else:
+            # No fusion
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual)

-            shared_output = hidden_states[0]
-            hidden_states_activated_experts = hidden_states[1]
-            num_activated_experts_per_node = hidden_states[2]
-            experts_to_token_score = hidden_states[3]
+        # Note: this fusion pattern is only supported for TRTLLM-nvfp4 backend now
+        do_finalize = not (hidden_states.shape[0]
+                           <= self.moe_allreduce.max_token
+                           and self.fusion_config.POST_MOE_FUSION
+                           and self.model_config.moe_backend == 'TRTLLM'
+                           and self.mlp.experts.has_nvfp4)

-            moe_all_reduce_params = MoEAllReduceParams(
-                residual=residual,
-                norm_weight=self.next_layer_layernorm.weight,
-                device_num_experts=num_activated_experts_per_node,
-                expert_scale_factor=experts_to_token_score,
-                shared_expert_output=shared_output,
-                eps=self.next_layer_layernorm.variance_epsilon,
-                is_cutlass_min_latency=True,
-            )
+        hidden_states = _run_MoE(hidden_states,
+                                 hidden_states_fp4=None,
+                                 do_finalize=do_finalize)

-            # MoE_finalize is fused into allreduce
-            hidden_states, residual = self.moe_allreduce(
-                hidden_states_activated_experts,
-                all_reduce_params=moe_all_reduce_params,
-            )
-        else:
-            if self.fusion_config.PRE_MOE_FUSION:
-                # moe_backend can be either CUTLASS or TRTLLM here
-                # TODO: unify the two min-latency MoE backends by enabling quant fusion
+        if self.fusion_config.POST_MOE_FUSION:
+            if do_finalize:
                 hidden_states, residual = self.allreduce(
                     hidden_states,
                     all_reduce_params=AllReduceParams(
                         fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
                         residual=residual,
-                        norm_weight=self.post_attention_layernorm.weight,
-                        eps=self.post_attention_layernorm.variance_epsilon,
+                        norm_weight=self.next_layer_layernorm.weight,
+                        eps=self.next_layer_layernorm.variance_epsilon,
                         trigger_completion_at_end=False,
                     ))
             else:
-                # No fusion
-                hidden_states, residual = self.post_attention_layernorm(
+                assert len(
+                    hidden_states) == 4, "hidden_states must have 4 elements"
+
+                shared_output = hidden_states[0]
+                fc2_output = hidden_states[1]
+                expert_scale_factor = hidden_states[2]
+                expanded_idx_to_permuted_idx = hidden_states[3]
+
+                moe_all_reduce_params = MoEAllReduceParams(
+                    expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
+                    expert_scale_factor=expert_scale_factor,
+                    shared_expert_output=shared_output,
+                    residual=residual,
+                    norm_weight=self.next_layer_layernorm.weight,
+                    eps=self.next_layer_layernorm.variance_epsilon,
+                    is_cutlass_min_latency=False,
+                )
+                hidden_states, residual = self.moe_allreduce(
+                    fc2_output, all_reduce_params=moe_all_reduce_params)
+        else:
+            if self.next_layer_layernorm is not None:
+                hidden_states, residual = self.next_layer_layernorm(
                     hidden_states, residual)

-            # Note: this fusion pattern is only supported for TRTLLM-nvfp4 backend now
-            do_finalize = not (hidden_states.shape[0]
-                               <= self.moe_allreduce.max_token
-                               and self.fusion_config.POST_MOE_FUSION
-                               and self.model_config.moe_backend == 'TRTLLM'
-                               and self.mlp.experts.has_nvfp4)
-
-            hidden_states = _run_MoE(hidden_states,
-                                     hidden_states_fp4=None,
-                                     do_finalize=do_finalize)
-
-            if self.fusion_config.POST_MOE_FUSION:
-                if do_finalize:
-                    hidden_states, residual = self.allreduce(
-                        hidden_states,
-                        all_reduce_params=AllReduceParams(
-                            fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
-                            residual=residual,
-                            norm_weight=self.next_layer_layernorm.weight,
-                            eps=self.next_layer_layernorm.variance_epsilon,
-                            trigger_completion_at_end=False,
-                        ))
-                else:
-                    assert len(hidden_states
-                               ) == 4, "hidden_states must have 4 elements"
-
-                    shared_output = hidden_states[0]
-                    fc2_output = hidden_states[1]
-                    expert_scale_factor = hidden_states[2]
-                    expanded_idx_to_permuted_idx = hidden_states[3]
-
-                    moe_all_reduce_params = MoEAllReduceParams(
-                        expanded_idx_to_permuted_idx=
-                        expanded_idx_to_permuted_idx,
-                        expert_scale_factor=expert_scale_factor,
-                        shared_expert_output=shared_output,
-                        residual=residual,
-                        norm_weight=self.next_layer_layernorm.weight,
-                        eps=self.next_layer_layernorm.variance_epsilon,
-                        is_cutlass_min_latency=False,
-                    )
-                    hidden_states, residual = self.moe_allreduce(
-                        fc2_output, all_reduce_params=moe_all_reduce_params)
-            else:
-                if self.next_layer_layernorm is not None:
-                    hidden_states, residual = self.next_layer_layernorm(
-                        hidden_states, residual)
-
         return hidden_states, residual

     def forward_mlp(

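For readers tracing the new control flow: the refactor decides do_finalize up front and only takes the unfinalized MoE-allreduce epilogue when post-MoE fusion is enabled, the backend is TRTLLM with nvfp4, and the batch fits under moe_allreduce.max_token. A small self-contained sketch of just that decision (the function name and string return values are illustrative; the real code operates on tensors and module state):

def select_moe_epilogue(num_tokens: int, post_moe_fusion: bool,
                        moe_backend: str, has_nvfp4: bool,
                        max_token: int) -> str:
    """Mirrors the do_finalize decision in the refactored forward path."""
    do_finalize = not (num_tokens <= max_token and post_moe_fusion
                       and moe_backend == 'TRTLLM' and has_nvfp4)
    if post_moe_fusion:
        if do_finalize:
            return "fused allreduce with RESIDUAL_RMS_NORM"
        return "MoEAllReduce over the 4-element (unfinalized) MoE output"
    return "next_layer_layernorm (no post-MoE fusion)"


# A small TRTLLM-nvfp4 batch skips finalization and uses MoEAllReduce.
print(select_moe_epilogue(8, True, 'TRTLLM', True, max_token=128))
# A CUTLASS batch always finalizes inside the MoE and allreduces afterwards.
print(select_moe_epilogue(8, True, 'CUTLASS', True, max_token=128))
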
tensorrt_llm/_torch/models/modeling_qwen3_moe.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 from ..modules.fused_moe import (BaseMoeRoutingMethod, CutlassFusedMoE, MoE,
                                  RenormalizeMoeRoutingMethod,
                                  RenormalizeNaiveMoeRoutingMethod,
-                                 RoutingMethodType, create_moe)
+                                 RoutingMethodType, WideEPMoE, create_moe)
 from ..modules.linear import TensorParallelMode
 from ..modules.rms_norm import RMSNorm
 from ..speculative import SpecMetadata
@@ -138,7 +138,7 @@ def forward(
                 self.mapping,
                 dim=0,
                 sizes=all_rank_num_tokens)
-        elif not isinstance(self.experts, CutlassFusedMoE) or (
+        elif not isinstance(self.experts, (CutlassFusedMoE, WideEPMoE)) or (
                 not self.experts.has_fp8_qdq and self.experts.has_nvfp4):
             # Use padding when not using the cutlass path or when x_sf in self.experts is not None
             use_dp_padding = True

tensorrt_llm/_torch/modules/fused_moe/__init__.py

Lines changed: 16 additions & 14 deletions
@@ -2,6 +2,7 @@
 from .fused_moe_cutlass import CutlassFusedMoE
 from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE
 from .fused_moe_vanilla import VanillaMoE
+from .fused_moe_wide_ep import WideEPMoE
 from .interface import MoE, MoEWeightLoadingMode
 from .moe_load_balancer import MoeLoadBalancer
 from .quantization import FusedMoEQuantScalesFP8
@@ -13,23 +14,24 @@
     SparseMixerMoeRoutingMethod, StaticMoeRoutingMethod)

 __all__ = [
-    "VanillaMoE",
-    "CutlassFusedMoE",
-    "TRTLLMGenFusedMoE",
     "BaseMoeRoutingMethod",
-    "MoeLoadBalancer",
-    "RenormalizeNaiveMoeRoutingMethod",
+    "create_moe",
+    "CutlassFusedMoE",
+    "DeepSeekV3MoeRoutingMethod",
+    "DefaultMoeRoutingMethod",
+    "FusedMoEQuantScalesFP8",
+    "get_moe_cls",
     "Llama4RenormalizeMoeRoutingMethod",
-    "SparseMixerMoeRoutingMethod",
     "LoadBalancedMoeRoutingMethod",
-    "StaticMoeRoutingMethod",
-    "DefaultMoeRoutingMethod",
-    "DeepSeekV3MoeRoutingMethod",
-    "RoutingMethodType",
-    "RenormalizeMoeRoutingMethod",
     "MoE",
+    "MoeLoadBalancer",
     "MoEWeightLoadingMode",
-    "get_moe_cls",
-    "create_moe",
-    "FusedMoEQuantScalesFP8",
+    "RenormalizeMoeRoutingMethod",
+    "RenormalizeNaiveMoeRoutingMethod",
+    "RoutingMethodType",
+    "SparseMixerMoeRoutingMethod",
+    "StaticMoeRoutingMethod",
+    "TRTLLMGenFusedMoE",
+    "VanillaMoE",
+    "WideEPMoE",
 ]

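With the export added above, the wide expert-parallel backend is importable alongside the existing ones; a usage sketch (the import path follows the file tree in this commit):

from tensorrt_llm._torch.modules.fused_moe import WideEPMoE, create_moe
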
tensorrt_llm/_torch/modules/fused_moe/create_moe.py

Lines changed: 21 additions & 1 deletion
@@ -9,13 +9,16 @@
 from .fused_moe_cutlass import CutlassFusedMoE
 from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE
 from .fused_moe_vanilla import VanillaMoE
+from .fused_moe_wide_ep import WideEPMoE
 from .interface import MoE, MoEWeightLoadingMode
 from .moe_load_balancer import get_moe_load_balancer
 from .routing import BaseMoeRoutingMethod


 def get_moe_cls(
         model_config: ModelConfig,
+        routing_method: BaseMoeRoutingMethod,
+        dtype: Optional[torch.dtype] = None,
         override_quant_config: Optional[QuantConfig] = None) -> Type[MoE]:
     moe_backend = model_config.moe_backend
     quant_config = model_config.quant_config
@@ -36,6 +39,8 @@ def get_moe_cls(
             f"Check out details in quant_config: {quant_config}"
             "Using CutlassFusedMoE instead.")
         return CutlassFusedMoE
+    elif moe_backend.upper() == "WIDEEP":
+        return WideEPMoE
     else:
         raise ValueError(f"Unsupported moe backend: {moe_backend}")

@@ -54,7 +59,8 @@ def create_moe(
         apply_router_weight_on_input: bool = False,
         layer_idx: Optional[int] = None,
 ) -> MoE:
-    moe_cls = get_moe_cls(model_config, override_quant_config)
+    moe_cls = get_moe_cls(model_config, routing_method, dtype,
+                          override_quant_config)

     moe_load_balancer = get_moe_load_balancer()
     if moe_load_balancer is not None:
@@ -88,6 +94,20 @@ def create_moe(
             apply_router_weight_on_input=apply_router_weight_on_input,
             layer_idx=layer_idx,
         )
+    elif moe_cls == WideEPMoE:
+        return moe_cls(
+            routing_method=routing_method,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            dtype=dtype,
+            reduce_results=reduce_results,
+            model_config=model_config,
+            aux_stream=aux_stream,
+            weight_loading_mode=weight_loading_mode,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            layer_idx=layer_idx,
+        )
     elif moe_cls == VanillaMoE:
         assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in VanillaMoE."

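Since get_moe_cls() now accepts routing_method and dtype and recognizes a "WIDEEP" backend string, backend selection reduces to a simple dispatch. A self-contained sketch of that branch (the classes below are placeholders; the real function also consults quant_config and may fall back to CutlassFusedMoE):

from typing import Type


class CutlassFusedMoE:  # placeholder for the class in fused_moe_cutlass
    pass


class WideEPMoE:  # placeholder for the class in fused_moe_wide_ep
    pass


def pick_moe_cls(moe_backend: str) -> Type:
    """Models only the backend-string dispatch added in this commit."""
    if moe_backend.upper() == "CUTLASS":
        return CutlassFusedMoE
    elif moe_backend.upper() == "WIDEEP":
        return WideEPMoE
    raise ValueError(f"Unsupported moe backend: {moe_backend}")


assert pick_moe_cls("WideEP") is WideEPMoE  # matching is case-insensitive

Callers go through create_moe(), which now forwards routing_method and dtype to get_moe_cls() and instantiates WideEPMoE with the same argument set as CutlassFusedMoE, as the diff above shows.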