@@ -30,7 +30,7 @@ There are multiple MOE backends inside TRT-LLM, not all of them supporting every
| B200/GB200 EP<=8 | NVFP4 | CUTLASS, TRTLLM |
| B200/GB200 EP<=8 | FP8 | DEEPGEMM |
| GB200 NVL72 EP>8 | NVFP4 | WIDEEP |
| GB200 NVL72 EP>8 | FP8 | N/A (WIP) |
| GB200 NVL72 EP>8 | FP8 | WIDEEP without EPLB |

The default MoE backend is `CUTLASS`, so for combinations that `CUTLASS` does not support, `moe_config.backend` must be set explicitly to run the model.
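As an illustration, a minimal sketch of forcing a non-default backend through the PyTorch LLM API; the import path and the `MoeConfig` class name are assumptions for this example (only the `moe_config.backend` knob is documented above), and the model path is a placeholder:

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import MoeConfig  # assumed location of the MoE config class

# Hypothetical example: GB200 NVL72 with EP > 8 is not covered by the default
# CUTLASS backend, so the WIDEEP backend is selected explicitly.
llm = LLM(
    model="<path-to-moe-checkpoint>",       # placeholder
    moe_config=MoeConfig(backend="WIDEEP"),
)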

55 changes: 34 additions & 21 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -5,8 +5,9 @@
import torch

from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
from tensorrt_llm._utils import logger
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping

from ...distributed import AllReduce, allgather, reducescatter
@@ -16,7 +17,9 @@
from .deep_ep_utils import buffer_pool, deep_ep_installed
from .interface import MoE
from .moe_load_balancer import get_moe_load_balancer
from .ops import MoEOp, MoEOpSelector
from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
FP8QDQFusedMoEMethod, MoEWeightLoadingMode,
NVFP4CutlassFusedMoEMethod,
UnquantizedFusedMoEMethod, WInt4AFP8FusedMoEMethod)
@@ -90,6 +93,9 @@ def __init__(
# If True, the router weight will be multiplied on the input rather than at the end of FC2
self.apply_router_weight_on_input = apply_router_weight_on_input

# Store original hidden size before any potential padding
self.unpadded_hidden_size = self.hidden_size

moe_load_balancer = get_moe_load_balancer()
self.layer_load_balancer = None
self.repeat_idx = 0
@@ -227,6 +233,9 @@ def __init__(
self.enable_dummy_allreduce = os.environ.get(
"TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"

# MoE op will be lazily initialized when first accessed (see moe_op_impl property)
self._moe_op_impl = None

def _check_configs(self):
assert self._weights_created

@@ -316,7 +325,10 @@ def _get_quant_method(self):
if self.quant_config.layer_quant_mode.has_fp8_qdq():
return FP8QDQFusedMoEMethod()
elif self.quant_config.layer_quant_mode.has_fp8_block_scales():
return DeepSeekFP8BlockScalesFusedMoEMethod()
if get_sm_version() == 100:
return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
else:
return DeepSeekFP8BlockScalesFusedMoEMethod()
elif self.quant_config.layer_quant_mode.has_nvfp4():
return NVFP4CutlassFusedMoEMethod()
elif self.quant_config.layer_quant_mode.is_int4_weight_only_per_group(
@@ -339,6 +351,19 @@ def create_weights(self):
self._weights_created = True
self._check_configs()

@property
def moe_op_impl(self) -> MoEOp:
"""
Lazily initialize and return the MoE op.

The op is selected based on hardware capabilities and quantization
configuration, which are only available after weights are created.
"""
if self._moe_op_impl is None:
assert self._weights_created, "Weights must be created before accessing moe_op"
self._moe_op_impl = MoEOpSelector.select_op(self)
return self._moe_op_impl

def dummy_allreduce(self):
"""
Debug function for eliminating imbalance during performance analysis.
@@ -389,8 +414,9 @@ def forward_chunk(
if self.layer_load_balancer and is_first_call:
self.layer_load_balancer.start_wait_gpu_stage()

use_deepseek_fp8_block_scale = False
use_w4_group_scaling = False
if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.MNNVL:
pass

weight_dtype = self.w3_w1_weight.dtype

token_selected_experts, token_final_scales = self.routing_method.apply(
@@ -544,9 +570,8 @@ def forward_chunk(
x_sf = x_sf.view((x_row, -1))

elif self.has_deepseek_fp8_block_scales:
use_deepseek_fp8_block_scale = True
pass
elif self.has_w4afp8:
use_w4_group_scaling = True
weight_dtype = torch.quint4x2
else:
raise ValueError(
@@ -569,12 +594,8 @@
sizes=None if use_dp_padding else all_rank_num_tokens)
x_row = x.shape[0]

ep_size = self.ep_size
ep_rank = self.ep_rank
w3_w1_weight = self.w3_w1_weight
w2_weight = self.w2_weight
cluster_size = self.cluster_size
cluster_rank = self.cluster_rank
quant_scales = self.quant_scales

if self.alltoall_method_type == AlltoallMethodType.MNNVL:
@@ -640,7 +661,8 @@ def forward_chunk(
f"Not available alltoall method type: {self.alltoall_method_type!r}"
)

final_hidden_states = torch.ops.trtllm.fused_moe(
final_hidden_states = self.moe_op_impl.run_moe(
self,
x,
token_selected_slots,
token_final_scales,
@@ -652,17 +674,8 @@
quant_scales=quant_scales,
input_sf=x_sf,
swizzled_input_sf=False,
tp_size=self.tp_size,
tp_rank=self.tp_rank,
ep_size=ep_size,
ep_rank=ep_rank,
cluster_size=cluster_size,
cluster_rank=cluster_rank,
enable_alltoall=use_all_to_all,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
use_w4_group_scaling=use_w4_group_scaling,
min_latency_mode=False,
tune_max_num_tokens=self.tune_max_num_tokens,
use_fused_finalize=True,
tuner_num_tokens=tuner_num_tokens,
tuner_top_k=tuner_top_k,
)
17 changes: 17 additions & 0 deletions tensorrt_llm/_torch/modules/fused_moe/ops/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MoE ops module for different computation implementations."""

from .moe_op import MoEOp, MoEOpSelector
from .moe_op_cutlass import CutlassMoEOp
from .moe_op_deepgemm import DeepGemmMoEOp

__all__ = ['MoEOp', 'MoEOpSelector', 'CutlassMoEOp', 'DeepGemmMoEOp']
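A short usage sketch of these exports, mirroring how `WideEPMoE.forward_chunk` drives the computation after this PR; `moe_module`, `hidden_states`, and the routing tensors below are placeholders taken from an existing MoE module whose weights are already created:

from tensorrt_llm._torch.modules.fused_moe.ops import MoEOpSelector

# Pick CutlassMoEOp or DeepGemmMoEOp based on the module's SM version and
# quantization config, then run the fused MoE through the common interface.
op = MoEOpSelector.select_op(moe_module)
output = op.run_moe(
    moe_module,
    hidden_states,                  # [num_tokens, hidden_size] input
    token_selected_slots,           # routing output: selected expert slots
    token_final_scales,             # routing output: per-token scales
    w3_w1_weight=moe_module.w3_w1_weight,
    w3_w1_bias=None,
    w2_weight=moe_module.w2_weight,
    w2_bias=None,
    output_dtype=hidden_states.dtype,
    quant_scales=moe_module.quant_scales,
)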
230 changes: 230 additions & 0 deletions tensorrt_llm/_torch/modules/fused_moe/ops/moe_op.py
@@ -0,0 +1,230 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MoE Op abstraction for supporting different MoE computation implementations.
This module provides a unified interface for different MoE ops (Cutlass, DeepGemm, etc.)
"""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional

import torch

from tensorrt_llm._utils import get_sm_version

if TYPE_CHECKING:
from ..interface import MoE


class MoEOp(ABC):
"""Abstract base class for MoE computation ops.

This class provides a strategy pattern for different MoE computation implementations.
It is used by MoE modules (like WideEPMoE) to delegate the actual computation.

Note: MoEOp is NOT a MoE module itself, but a computation strategy.
The actual MoE module (e.g., WideEPMoE) inherits from MoE and uses MoEOp
for the computation implementation.
"""

# Op-specific abstract methods
@abstractmethod
def finalize_tactic(
self,
module: 'MoE',
tuner_input: torch.Tensor,
output_dtype: torch.dtype,
min_latency_mode: bool = False,
use_fused_finalize: bool = True,
tuner_top_k: Optional[int] = None,
) -> None:
"""
Finalize tactics for the MoE computation.
For Cutlass op, this includes profiling and tactic selection.
For DeepGemm op, this can be a no-op.

Args:
module: The MoE module containing MoE configurations
tuner_input: Real input used for tuning (same shape/layout as the non-alltoall path)
output_dtype: Output dtype for tuner run
min_latency_mode: Whether to profile for min-latency path
use_fused_finalize: Whether to use fused finalize
tuner_top_k: Top-k value for tuning (Cutlass specific)
"""

@abstractmethod
def compute_moe(
self,
module: 'MoE',
# Input tensors
x: torch.Tensor,
token_selected_slots: torch.Tensor,
token_final_scales: Optional[torch.Tensor],
# Weight tensors
w3_w1_weight: torch.Tensor,
w3_w1_bias: Optional[torch.Tensor],
w2_weight: torch.Tensor,
w2_bias: Optional[torch.Tensor],
# Output configuration
output_dtype: torch.dtype,
# Quantization parameters
quant_scales: List[torch.Tensor],
input_sf: Optional[torch.Tensor] = None,
swizzled_input_sf: bool = True,
# Performance tuning (only runtime-variable parameters)
min_latency_mode: bool = False,
use_fused_finalize: bool = True,
tuner_num_tokens: Optional[int] = None,
tuner_top_k: Optional[int] = None,
**kwargs) -> torch.Tensor:
"""
Perform the actual MoE computation.

Configuration parameters (tp_size, ep_size, swiglu params, etc.) are
automatically extracted from the module parameter.

Args:
module: MoE module containing configuration and parameters.
x: Input tensor
token_selected_slots: Selected expert slots
token_final_scales: Scaling factors
w3_w1_weight: Fused gate and up projection weights
w3_w1_bias: Optional bias
w2_weight: Down projection weights
w2_bias: Optional bias
output_dtype: Output data type
quant_scales: Quantization scales
input_sf: Input scaling factor
swizzled_input_sf: Whether input_sf is swizzled
min_latency_mode: Use minimum latency optimizations
use_fused_finalize: Use fused finalization
tuner_num_tokens: Number of tokens for tuning
tuner_top_k: Top-k value for tuning

Returns:
Computed MoE output tensor
"""

def run_moe(
self,
module: 'MoE',
# Input tensors
input: torch.Tensor,
token_selected_slots: torch.Tensor,
token_final_scales: torch.Tensor,
w3_w1_weight: torch.Tensor,
w3_w1_bias: Optional[torch.Tensor],
w2_weight: torch.Tensor,
w2_bias: Optional[torch.Tensor],
output_dtype: torch.dtype,
# Quantization parameters
quant_scales: List[torch.Tensor],
input_sf: Optional[torch.Tensor] = None,
swizzled_input_sf: bool = True,
# Performance tuning (only runtime-variable parameters)
min_latency_mode: bool = False,
use_fused_finalize: bool = True,
tuner_num_tokens: Optional[int] = None,
tuner_top_k: Optional[int] = None,
**kwargs) -> torch.Tensor:
"""
Run the complete MoE computation pipeline.

Configuration parameters are automatically extracted from the module.

Args:
module: MoE module containing configuration
input: Input tensor to the MoE layer
token_selected_slots: Selected expert slots for each token
token_final_scales: Final scaling factors for each token
w3_w1_weight: Concatenated weights for w3 and w1 projections
w3_w1_bias: Optional bias for w3/w1 projections
w2_weight: Weight for w2 projection
w2_bias: Optional bias for w2 projection
output_dtype: Desired output data type
quant_scales: Quantization scales for weights
input_sf: Optional input scale factors for quantization
swizzled_input_sf: Whether input scale factors are swizzled
min_latency_mode: Use minimum latency optimizations
use_fused_finalize: Use fused finalization
tuner_num_tokens: Number of tokens for tuner input
tuner_top_k: Top-k value for tuning

Returns:
Computed MoE output tensor
"""
self.finalize_tactic(module, input, output_dtype, min_latency_mode,
use_fused_finalize, tuner_top_k)

# Call compute_moe with module
return self.compute_moe(module=module,
x=input,
token_selected_slots=token_selected_slots,
token_final_scales=token_final_scales,
w3_w1_weight=w3_w1_weight,
w3_w1_bias=w3_w1_bias,
w2_weight=w2_weight,
w2_bias=w2_bias,
output_dtype=output_dtype,
quant_scales=quant_scales,
input_sf=input_sf,
swizzled_input_sf=swizzled_input_sf,
min_latency_mode=min_latency_mode,
use_fused_finalize=use_fused_finalize,
tuner_num_tokens=tuner_num_tokens,
tuner_top_k=tuner_top_k,
**kwargs)


class MoEOpSelector:
"""
Utility class for selecting the appropriate MoE op based on
hardware capabilities and quantization configuration.

This class implements the strategy pattern for op selection,
choosing between Cutlass and DeepGemm implementations based on:
- Hardware capabilities (SM version)
- Quantization configuration (block FP8 support)
"""

@staticmethod
def select_op(module: 'MoE') -> MoEOp:
"""
Select the appropriate MoE op based on module configuration.

Selection criteria:
- Blackwell (SM100) with block FP8 quantization -> DeepGemm op
- All other configurations -> Cutlass op

Args:
module: The MoE module containing configuration information

Returns:
MoEOp: Selected op instance (CutlassMoEOp or DeepGemmMoEOp)

Example:
>>> op = MoEOpSelector.select_op(moe_module)
>>> output = op.run_moe(input, ...)
"""
from .moe_op_cutlass import CutlassMoEOp
from .moe_op_deepgemm import DeepGemmMoEOp

# Check if we should use DeepGemm op
# Blackwell has SM version 100
is_blackwell = get_sm_version() == 100
has_block_fp8 = module.has_deepseek_fp8_block_scales

if is_blackwell and has_block_fp8:
# Use DeepGemm op for Blackwell with block FP8
return DeepGemmMoEOp()
else:
# Use Cutlass op for all other cases
return CutlassMoEOp()
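To show the extension points of this interface, here is a hypothetical MoEOp subclass (not part of this PR; the real implementations live in moe_op_cutlass.py and moe_op_deepgemm.py, which are not shown here). finalize_tactic is a no-op, which the docstring above notes is acceptable for the DeepGemm op, and compute_moe is left as a clearly marked placeholder:

from typing import List, Optional

import torch

from tensorrt_llm._torch.modules.fused_moe.ops import MoEOp


class ReferenceMoEOp(MoEOp):
    """Hypothetical op that only illustrates the MoEOp extension points."""

    def finalize_tactic(self,
                        module,
                        tuner_input: torch.Tensor,
                        output_dtype: torch.dtype,
                        min_latency_mode: bool = False,
                        use_fused_finalize: bool = True,
                        tuner_top_k: Optional[int] = None) -> None:
        # No autotuning for this op; the DeepGemm op treats this stage the same way.
        pass

    def compute_moe(self,
                    module,
                    x: torch.Tensor,
                    token_selected_slots: torch.Tensor,
                    token_final_scales: Optional[torch.Tensor],
                    w3_w1_weight: torch.Tensor,
                    w3_w1_bias: Optional[torch.Tensor],
                    w2_weight: torch.Tensor,
                    w2_bias: Optional[torch.Tensor],
                    output_dtype: torch.dtype,
                    quant_scales: List[torch.Tensor],
                    input_sf: Optional[torch.Tensor] = None,
                    swizzled_input_sf: bool = True,
                    min_latency_mode: bool = False,
                    use_fused_finalize: bool = True,
                    tuner_num_tokens: Optional[int] = None,
                    tuner_top_k: Optional[int] = None,
                    **kwargs) -> torch.Tensor:
        # A real op dispatches to a fused kernel here, reading tp/ep/cluster
        # configuration from `module`; this placeholder only fixes the shape
        # of the contract that run_moe relies on.
        raise NotImplementedError("kernel dispatch goes here")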