From 681b5bf533a3c48f583dc9ca4c35bf7744945e0d Mon Sep 17 00:00:00 2001 From: xxi Date: Fri, 29 Aug 2025 06:08:15 +0000 Subject: [PATCH 1/2] feat: wide_ep support block-wise FP8 on blackwell Signed-off-by: xxi modified: tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py new file: tensorrt_llm/_torch/modules/fused_moe/moe_backend.py modified: tests/unittest/_torch/modules/test_fused_moe.py modified: tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py new file: tensorrt_llm/_torch/modules/fused_moe/moe_backend.py modified: tests/unittest/_torch/modules/test_fused_moe.py --- .../modules/fused_moe/fused_moe_wide_ep.py | 51 +- .../_torch/modules/fused_moe/moe_backend.py | 791 ++++++++++++++++++ .../unittest/_torch/modules/test_fused_moe.py | 155 ++++ 3 files changed, 978 insertions(+), 19 deletions(-) create mode 100644 tensorrt_llm/_torch/modules/fused_moe/moe_backend.py diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py index a9fb53a3b85..738485901e5 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py @@ -5,8 +5,9 @@ import torch from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo -from tensorrt_llm._utils import logger +from tensorrt_llm._utils import get_sm_version from tensorrt_llm.functional import AllReduceStrategy +from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from ...distributed import AllReduce, allgather, reducescatter @@ -15,8 +16,10 @@ from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor from .deep_ep_utils import buffer_pool, deep_ep_installed from .interface import MoE +from .moe_backend import MoEBackend, MoEBackendSelection from .moe_load_balancer import get_moe_load_balancer from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod, + DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm, FP8QDQFusedMoEMethod, MoEWeightLoadingMode, NVFP4CutlassFusedMoEMethod, UnquantizedFusedMoEMethod, WInt4AFP8FusedMoEMethod) @@ -90,6 +93,9 @@ def __init__( # If True, the router weight will be multiplied on the input rather than at the end of FC2 self.apply_router_weight_on_input = apply_router_weight_on_input + # Store original hidden size before any potential padding + self.unpadded_hidden_size = self.hidden_size + moe_load_balancer = get_moe_load_balancer() self.layer_load_balancer = None self.repeat_idx = 0 @@ -227,6 +233,9 @@ def __init__( self.enable_dummy_allreduce = os.environ.get( "TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1" + # MoE backend will be lazily initialized when first accessed (see moe_backend property) + self._moe_backend_impl = None + def _check_configs(self): assert self._weights_created @@ -316,7 +325,10 @@ def _get_quant_method(self): if self.quant_config.layer_quant_mode.has_fp8_qdq(): return FP8QDQFusedMoEMethod() elif self.quant_config.layer_quant_mode.has_fp8_block_scales(): - return DeepSeekFP8BlockScalesFusedMoEMethod() + if get_sm_version() == 100: + return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm() + else: + return DeepSeekFP8BlockScalesFusedMoEMethod() elif self.quant_config.layer_quant_mode.has_nvfp4(): return NVFP4CutlassFusedMoEMethod() elif self.quant_config.layer_quant_mode.is_int4_weight_only_per_group( @@ -339,6 +351,19 @@ def create_weights(self): self._weights_created = True self._check_configs() + @property + def moe_backend_impl(self) -> MoEBackend: + """ + Lazily initialize and return the MoE 
backend. + + The backend is selected based on hardware capabilities and quantization + configuration, which are only available after weights are created. + """ + if self._moe_backend_impl is None: + assert self._weights_created, "Weights must be created before accessing moe_backend" + self._moe_backend_impl = MoEBackendSelection.select_backend(self) + return self._moe_backend_impl + def dummy_allreduce(self): """ Debug function for eliminating imbalance during performance analysis. @@ -391,6 +416,7 @@ def forward_chunk( use_deepseek_fp8_block_scale = False use_w4_group_scaling = False + weight_dtype = self.w3_w1_weight.dtype token_selected_experts, token_final_scales = self.routing_method.apply( @@ -544,9 +570,8 @@ def forward_chunk( x_sf = x_sf.view((x_row, -1)) elif self.has_deepseek_fp8_block_scales: - use_deepseek_fp8_block_scale = True + pass elif self.has_w4afp8: - use_w4_group_scaling = True weight_dtype = torch.quint4x2 else: raise ValueError( @@ -569,12 +594,8 @@ def forward_chunk( sizes=None if use_dp_padding else all_rank_num_tokens) x_row = x.shape[0] - ep_size = self.ep_size - ep_rank = self.ep_rank w3_w1_weight = self.w3_w1_weight w2_weight = self.w2_weight - cluster_size = self.cluster_size - cluster_rank = self.cluster_rank quant_scales = self.quant_scales if self.alltoall_method_type == AlltoallMethodType.MNNVL: @@ -640,7 +661,8 @@ def forward_chunk( f"Not available alltoall method type: {self.alltoall_method_type!r}" ) - final_hidden_states = torch.ops.trtllm.fused_moe( + final_hidden_states = self.moe_backend_impl.run_moe( + self, x, token_selected_slots, token_final_scales, @@ -652,17 +674,8 @@ def forward_chunk( quant_scales=quant_scales, input_sf=x_sf, swizzled_input_sf=False, - tp_size=self.tp_size, - tp_rank=self.tp_rank, - ep_size=ep_size, - ep_rank=ep_rank, - cluster_size=cluster_size, - cluster_rank=cluster_rank, - enable_alltoall=use_all_to_all, - use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, - use_w4_group_scaling=use_w4_group_scaling, min_latency_mode=False, - tune_max_num_tokens=self.tune_max_num_tokens, + use_fused_finalize=True, tuner_num_tokens=tuner_num_tokens, tuner_top_k=tuner_top_k, ) diff --git a/tensorrt_llm/_torch/modules/fused_moe/moe_backend.py b/tensorrt_llm/_torch/modules/fused_moe/moe_backend.py new file mode 100644 index 00000000000..ae99c3d85bb --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/moe_backend.py @@ -0,0 +1,791 @@ +""" +MoE Backend abstraction for supporting different MoE computation implementations. +This module provides a unified interface for different MoE backends (Cutlass, DeepGemm, etc.) +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Dict, List, Optional + +import torch + +from tensorrt_llm._utils import get_sm_version + +if TYPE_CHECKING: + from .interface import MoE + + +class MoEBackend(ABC): + """Abstract base class for MoE computation backends. + + This class provides a strategy pattern for different MoE computation implementations. + It is used by MoE modules (like WideEPMoE) to delegate the actual computation. + + Note: MoEBackend is NOT a MoE module itself, but a computation strategy. + The actual MoE module (e.g., WideEPMoE) inherits from MoE and uses MoEBackend + for the computation implementation. 
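+
+    Example (a minimal sketch, assuming ``module`` is an already-constructed
+    MoE module whose weights have been created, and that ``x``,
+    ``token_selected_slots`` and ``token_final_scales`` come from its routing):
+
+        backend = MoEBackendSelection.select_backend(module)
+        output = backend.run_moe(
+            module,
+            x, token_selected_slots, token_final_scales,
+            module.w3_w1_weight, None, module.w2_weight, None,
+            output_dtype=x.dtype,
+            quant_scales=module.quant_scales,
+        )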
+ """ + + # Backend-specific abstract methods + @abstractmethod + def finalize_tactic( + self, + module: 'MoE', + tuner_input: torch.Tensor, + output_dtype: torch.dtype, + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_top_k: Optional[int] = None, + ) -> None: + """ + Finalize tactics for the MoE computation. + For Cutlass backend, this includes profiling and tactic selection. + For DeepGemm backend, this can be a no-op. + + Args: + module: The MoE module containing MoE configurations + tuner_input: Real input used for tuning (same shape/layout as non-alltoall) + output_dtype: Output dtype for tuner run + min_latency_mode: Whether to profile for min-latency path + use_fused_finalize: Whether to use fused finalize + tuner_top_k: Top-k value for tuning (Cutlass specific) + """ + + @abstractmethod + def compute_moe( + self, + module: 'MoE', + # Input tensors + x: torch.Tensor, + token_selected_slots: torch.Tensor, + token_final_scales: Optional[torch.Tensor], + # Weight tensors + w3_w1_weight: torch.Tensor, + w3_w1_bias: Optional[torch.Tensor], + w2_weight: torch.Tensor, + w2_bias: Optional[torch.Tensor], + # Output configuration + output_dtype: torch.dtype, + # Quantization parameters + quant_scales: List[torch.Tensor], + input_sf: Optional[torch.Tensor] = None, + swizzled_input_sf: bool = True, + # Performance tuning (only runtime-variable parameters) + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_num_tokens: Optional[int] = None, + tuner_top_k: Optional[int] = None, + **kwargs) -> torch.Tensor: + """ + Perform the actual MoE computation. + + Configuration parameters (tp_size, ep_size, swiglu params, etc.) are + automatically extracted from the module parameter. + + Args: + module: MoE module containing configuration and parameters. + The following will be extracted: + - tp_size, tp_rank, ep_size, ep_rank, cluster_size, cluster_rank + - enable_alltoall, tune_max_num_tokens + - swiglu_alpha, swiglu_beta, swiglu_limit + - Quantization flags based on module properties + x: Input tensor + token_selected_slots: Selected expert slots + token_final_scales: Scaling factors + w3_w1_weight: Fused gate and up projection weights + w3_w1_bias: Optional bias + w2_weight: Down projection weights + w2_bias: Optional bias + output_dtype: Output data type + quant_scales: Quantization scales + input_sf: Input scaling factor + swizzled_input_sf: Whether input_sf is swizzled + min_latency_mode: Use minimum latency optimizations + use_fused_finalize: Use fused finalization + tuner_num_tokens: Number of tokens for tuning + tuner_top_k: Top-k value for tuning + + Returns: + Computed MoE output tensor + """ + + def run_moe( + self, + module: 'MoE', + # Input tensors + input: torch.Tensor, + token_selected_slots: torch.Tensor, + token_final_scales: torch.Tensor, + w3_w1_weight: torch.Tensor, + w3_w1_bias: Optional[torch.Tensor], + w2_weight: torch.Tensor, + w2_bias: Optional[torch.Tensor], + output_dtype: torch.dtype, + # Quantization parameters + quant_scales: List[torch.Tensor], + input_sf: Optional[torch.Tensor] = None, + swizzled_input_sf: bool = True, + # Performance tuning (only runtime-variable parameters) + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_num_tokens: Optional[int] = None, + tuner_top_k: Optional[int] = None, + **kwargs) -> torch.Tensor: + """ + Run the complete MoE computation pipeline. + + Configuration parameters are automatically extracted from the module. 
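+
+        The default implementation calls finalize_tactic() once and then
+        delegates to compute_moe(); a backend may override run_moe() (as the
+        Cutlass backend does) to customize how the tuner input is derived.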
+ + Args: + module: MoE module containing configuration + input: Input tensor to the MoE layer + token_selected_slots: Selected expert slots for each token + token_final_scales: Final scaling factors for each token + w3_w1_weight: Concatenated weights for w3 and w1 projections + w3_w1_bias: Optional bias for w3/w1 projections + w2_weight: Weight for w2 projection + w2_bias: Optional bias for w2 projection + output_dtype: Desired output data type + quant_scales: Quantization scales for weights + input_sf: Optional input scale factors for quantization + swizzled_input_sf: Whether input scale factors are swizzled + min_latency_mode: Use minimum latency optimizations + use_fused_finalize: Use fused finalization + tuner_num_tokens: Number of tokens for tuner input + tuner_top_k: Top-k value for tuning + + Returns: + Computed MoE output tensor + """ + self.finalize_tactic(module, input, output_dtype, min_latency_mode, + use_fused_finalize, tuner_top_k) + + # Call compute_moe with module + return self.compute_moe(module=module, + x=input, + token_selected_slots=token_selected_slots, + token_final_scales=token_final_scales, + w3_w1_weight=w3_w1_weight, + w3_w1_bias=w3_w1_bias, + w2_weight=w2_weight, + w2_bias=w2_bias, + output_dtype=output_dtype, + quant_scales=quant_scales, + input_sf=input_sf, + swizzled_input_sf=swizzled_input_sf, + min_latency_mode=min_latency_mode, + use_fused_finalize=use_fused_finalize, + tuner_num_tokens=tuner_num_tokens, + tuner_top_k=tuner_top_k, + **kwargs) + + +class MoECutlassBackend(MoEBackend): + """Cutlass-based MoE backend using torch.ops.trtllm.fused_moe.""" + + def __init__(self): + """Initialize the Cutlass backend.""" + super().__init__() + self.moe_runner = None + self.gemm_tactics = None + + def finalize_tactic( + self, + module: 'MoE', + tuner_input: torch.Tensor, + output_dtype: torch.dtype, + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_top_k: Optional[int] = None, + ) -> None: + """ + Finalize tactics for Cutlass MoE by profiling and selecting optimal GEMM tactics. 
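+
+        The MoERunner is created lazily on the first call. AutoTuner.choose_one()
+        then profiles the two grouped GEMMs ("trtllm::fused_moe::gemm1" and
+        "trtllm::fused_moe::gemm2") and the selected tactics are cached in
+        self.gemm_tactics for later use by compute_moe().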
+ """ + + # Import necessary modules for profiling + from ...custom_ops.torch_custom_ops import AutoTuner, MoERunner + + # Use real tuner_input rather than dummy input + assert tuner_input is not None, "tuner_input must be provided to finalize_tactic" + if tuner_top_k is None: + tuner_top_k = getattr(module.routing_method, 'experts_per_token', 1) + + # Determine view dtype for weights to match runtime quantization layout + weight_view_dtype = module.w3_w1_weight.dtype + if getattr(module, 'has_w4afp8', False): + weight_view_dtype = torch.quint4x2 + elif getattr(module, 'has_w4a16_mxfp4', False): + weight_view_dtype = torch.uint8 + + # Create MoERunner for profiling + if self.moe_runner is None: + self.moe_runner = MoERunner( + x_dtype=tuner_input.dtype, + weight_dtype=module.w3_w1_weight.dtype, + output_dtype=output_dtype, + top_k=tuner_top_k, + tp_size=module.tp_size, + tp_rank=module.tp_rank, + ep_size=module.ep_size, + ep_rank=module.ep_rank, + cluster_size=module.cluster_size, + cluster_rank=module.cluster_rank, + use_deepseek_fp8_block_scale=getattr( + module, 'has_deepseek_fp8_block_scales', False), + use_w4_group_scaling=getattr(module, 'has_w4afp8', False), + use_int8_woq_per_channel=getattr(module, + 'has_int8_woq_per_channel', + False), + use_mxfp8_act_scaling=getattr(module, 'has_mxfp8_act_scaling', + False), + min_latency_mode=min_latency_mode, + use_fused_finalize=use_fused_finalize, + ) + + # Set tuning configuration + MoERunner.tuning_config.tune_max_num_tokens = getattr( + module, 'tune_max_num_tokens', 8192) + + # Get AutoTuner for tactic selection + tuner = AutoTuner.get() + + # Profile and select tactics (GEMM1) + _, gemm_tactic_1 = tuner.choose_one( + "trtllm::fused_moe::gemm1", + [self.moe_runner], + MoERunner.tuning_config, + [ + tuner_input, + module.w3_w1_weight.view(weight_view_dtype), + getattr(module, 'w3_w1_bias', None), + module.w2_weight.view(weight_view_dtype), + getattr(module, 'w2_bias', None), + ], + gemm_idx=1, + ) + + # Profile and select tactics (GEMM2) + _, gemm_tactic_2 = tuner.choose_one( + "trtllm::fused_moe::gemm2", + [self.moe_runner], + MoERunner.tuning_config, + [ + tuner_input, + module.w3_w1_weight.view(weight_view_dtype), + getattr(module, 'w3_w1_bias', None), + module.w2_weight.view(weight_view_dtype), + getattr(module, 'w2_bias', None), + ], + gemm_idx=2, + ) + + # Store selected tactics + self.gemm_tactics = [gemm_tactic_1, gemm_tactic_2] + + def compute_moe( + self, + module: 'MoE', # Now required as first parameter + # Input tensors + x: torch.Tensor, + token_selected_slots: torch.Tensor, + token_final_scales: Optional[torch.Tensor], + # Weight tensors + w3_w1_weight: torch.Tensor, + w3_w1_bias: Optional[torch.Tensor], + w2_weight: torch.Tensor, + w2_bias: Optional[torch.Tensor], + # Output configuration + output_dtype: torch.dtype, + # Quantization parameters + quant_scales: List[torch.Tensor], + input_sf: Optional[torch.Tensor] = None, + swizzled_input_sf: bool = True, + # Performance tuning (only runtime-variable parameters) + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_num_tokens: Optional[int] = None, + tuner_top_k: Optional[int] = None, + **kwargs) -> torch.Tensor: + """ + Compute MoE using Cutlass backend with MoERunner. 
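+
+        finalize_tactic() must have been called beforehand (directly or via
+        run_moe()); otherwise a RuntimeError is raised because no MoERunner or
+        GEMM tactics are available.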
+ """ + # Extract parameters from module + tp_size = module.tp_size + tp_rank = module.tp_rank + ep_size = module.ep_size + ep_rank = module.ep_rank + cluster_size = module.cluster_size + cluster_rank = module.cluster_rank + enable_alltoall = module.enable_alltoall + getattr(module, 'tune_max_num_tokens', 8192) + swiglu_alpha = module.swiglu_alpha + swiglu_beta = module.swiglu_beta + swiglu_limit = module.swiglu_limit + use_w4_group_scaling = getattr(module, 'has_w4afp8', False) + + # Determine weight dtype for view operation if needed + weight_dtype = w3_w1_weight.dtype + if use_w4_group_scaling and weight_dtype != torch.quint4x2: + weight_dtype = torch.quint4x2 + + # Validate that tactics have been finalized + if self.gemm_tactics is None or len(self.gemm_tactics) == 0: + raise RuntimeError( + "GEMM tactics have not been finalized. " + "Call finalize_tactic() before compute_moe() or use run_moe() instead." + ) + + if self.moe_runner is None: + raise RuntimeError( + "MoERunner has not been initialized. " + "Call finalize_tactic() before compute_moe() or use run_moe() instead." + ) + + # Select the appropriate run method based on latency mode + run_moe = self.moe_runner.fused_moe_runner.run_moe_min_latency if min_latency_mode else self.moe_runner.fused_moe_runner.run_moe + + # Run the actual MoE computation + output = run_moe( + x, + token_selected_slots, + token_final_scales, + w3_w1_weight.view(weight_dtype), + w3_w1_bias, + w2_weight.view(weight_dtype), + w2_bias, + quant_scales, + input_sf, + swizzled_input_sf, + swiglu_alpha, + swiglu_beta, + swiglu_limit, + tp_size, + tp_rank, + ep_size, + ep_rank, + cluster_size, + cluster_rank, + enable_alltoall, + min_latency_mode, + self.gemm_tactics, + ) + + # Return output based on latency mode + return output if min_latency_mode else [output] + + def run_moe( + self, + module: 'MoE', + # Input tensors + input: torch.Tensor, + token_selected_slots: torch.Tensor, + token_final_scales: torch.Tensor, + w3_w1_weight: torch.Tensor, + w3_w1_bias: Optional[torch.Tensor], + w2_weight: torch.Tensor, + w2_bias: Optional[torch.Tensor], + output_dtype: torch.dtype, + # Quantization parameters + quant_scales: List[torch.Tensor], + input_sf: Optional[torch.Tensor] = None, + swizzled_input_sf: bool = True, + # Performance tuning (only runtime-variable parameters) + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_num_tokens: Optional[int] = None, + tuner_top_k: Optional[int] = None, + **kwargs) -> torch.Tensor: + """ + Run the complete MoE computation pipeline for Cutlass backend. + + This override handles the specific tuner_input logic needed for Cutlass. 
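+
+        When module.enable_alltoall is True, tuner_num_tokens and tuner_top_k
+        must be provided and only the first tuner_num_tokens rows of the input
+        are used for profiling; otherwise the full input is used and
+        tuner_top_k is taken from token_selected_slots.size(1).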
+ + Args: + module: MoE module containing configuration + input: Input tensor to the MoE layer + token_selected_slots: Selected expert slots for each token + token_final_scales: Final scaling factors for each token + w3_w1_weight: Concatenated weights for w3 and w1 projections + w3_w1_bias: Optional bias for w3/w1 projections + w2_weight: Weight for w2 projection + w2_bias: Optional bias for w2 projection + output_dtype: Desired output data type + quant_scales: Quantization scales for weights + input_sf: Optional input scale factors for quantization + swizzled_input_sf: Whether input scale factors are swizzled + min_latency_mode: Use minimum latency optimizations + use_fused_finalize: Use fused finalization + tuner_num_tokens: Number of tokens for tuner input + tuner_top_k: Top-k value for tuning + + Returns: + Computed MoE output tensor + """ + # Extract enable_alltoall from module to determine tuner_input logic + enable_alltoall = module.enable_alltoall + + # Compute tuner_input per fused_moe logic + if enable_alltoall: + assert tuner_num_tokens is not None + assert tuner_top_k is not None + tuner_input = input[:tuner_num_tokens] + else: + assert tuner_num_tokens is None + assert tuner_top_k is None + tuner_input = input + tuner_top_k = token_selected_slots.size(1) + + self.finalize_tactic(module, tuner_input, output_dtype, + min_latency_mode, use_fused_finalize, tuner_top_k) + + # Call compute_moe with module + return self.compute_moe(module=module, + x=input, + token_selected_slots=token_selected_slots, + token_final_scales=token_final_scales, + w3_w1_weight=w3_w1_weight, + w3_w1_bias=w3_w1_bias, + w2_weight=w2_weight, + w2_bias=w2_bias, + output_dtype=output_dtype, + quant_scales=quant_scales, + input_sf=input_sf, + swizzled_input_sf=swizzled_input_sf, + min_latency_mode=min_latency_mode, + use_fused_finalize=use_fused_finalize, + tuner_num_tokens=tuner_num_tokens, + tuner_top_k=tuner_top_k, + **kwargs) + + +class MoEDeepGemmBackend(MoEBackend): + """DeepGemm-based MoE backend for GB200 block FP8.""" + + def __init__(self): + """Initialize DeepGemm backend.""" + super().__init__() + import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils + self.fp8_utils = fp8_utils + + from .fused_moe_deepgemm import deepgemm_fp8_group_blockwise_gemm + self.deepgemm_fp8_group_blockwise_gemm = deepgemm_fp8_group_blockwise_gemm + + def finalize_tactic( + self, + module: 'MoE', + tuner_input: torch.Tensor, + output_dtype: torch.dtype, + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_top_k: Optional[int] = None, + ) -> None: + """ + No-op for DeepGemm backend as it doesn't require tactic profiling. + + Args: + module: The MoE module + tuner_input: Input tensor for tuning + output_dtype: Output dtype + min_latency_mode: Whether to use min-latency mode + use_fused_finalize: Whether to use fused finalize + tuner_top_k: Top-k value for tuning + """ + + def _get_deepgemm_workspace(self, module: 'MoE', m_max: int, + group_size: int) -> Dict[str, torch.Tensor]: + """ + Get workspace for DeepGemm backend operations. 
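+
+        Three flat CUDA buffers are allocated per call: "workspace_0" for FP8
+        activations (sized by max(hidden_size, intermediate_size)),
+        "workspace_1" for bf16 intermediates (sized by
+        max(2 * intermediate_size, hidden_size)), and "workspace_sf" for the
+        packed int32 per-group scaling factors.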
+ + Args: + module: The MoE module containing configuration + m_max: Maximum number of tokens (aligned) + group_size: Group size for quantization + + Returns: + Dictionary containing workspace tensors + """ + import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils + + # Get dimensions from module + hidden_size = module.hidden_size + intermediate_size = module.intermediate_size + expert_size_per_partition = module.expert_size_per_partition + + # Calculate aligned dimensions + m_padded = fp8_utils.align(m_max, 4) + fp8_dim = max(hidden_size, intermediate_size) + scale_k = fp8_utils.ceil_div(fp8_dim, group_size) + scale_k_padded = fp8_utils.align(scale_k, 4) + + # Allocate workspace tensors + workspace = {} + + # Workspace for FP8 activations + workspace["workspace_0"] = torch.empty( + (expert_size_per_partition * m_max * fp8_dim), + dtype=torch.float8_e4m3fn, + device='cuda') + + # Workspace for intermediate results + workspace["workspace_1"] = torch.empty( + (expert_size_per_partition * m_max * + max(intermediate_size * 2, hidden_size)), + dtype=torch.bfloat16, + device='cuda') + + # Workspace for scaling factors + workspace["workspace_sf"] = torch.empty( + expert_size_per_partition * (scale_k_padded // 4) * m_padded, + dtype=torch.int32, + device='cuda') + + return workspace + + def compute_moe( + self, + module: 'MoE', + # Input tensors + x: torch.Tensor, + token_selected_slots: torch.Tensor, + token_final_scales: Optional[torch.Tensor], + # Weight tensors + w3_w1_weight: torch.Tensor, + w3_w1_bias: Optional[torch.Tensor], + w2_weight: torch.Tensor, + w2_bias: Optional[torch.Tensor], + # Output configuration + output_dtype: torch.dtype, + # Quantization parameters + quant_scales: List[torch.Tensor], + input_sf: Optional[torch.Tensor] = None, + swizzled_input_sf: bool = True, + # Performance tuning (only runtime-variable parameters) + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_num_tokens: Optional[int] = None, + tuner_top_k: Optional[int] = None, + **kwargs) -> torch.Tensor: + """ + Compute MoE using DeepGemm backend with block FP8 quantization. + + Note: This assumes the data has already been gathered/alltoall'd + by the WideEP forward_chunk method. + """ + + # Import necessary functions for DeepGemm + from .fused_moe_deepgemm import (masked_index_copy_group_quant_fp8, + preprocess_after_permute, set_strides, + triton_masked_index_gather) + + # Extract parameters from module + tp_size = module.tp_size + tp_rank = module.tp_rank + ep_size = module.ep_size + ep_rank = module.ep_rank + cluster_size = module.cluster_size + cluster_rank = module.cluster_rank + enable_alltoall = module.enable_alltoall + getattr(module, 'tune_max_num_tokens', 8192) + module.swiglu_alpha + module.swiglu_beta + module.swiglu_limit + + # Not supported: min_latency_mode. Raise error if enabled. 
+ if min_latency_mode: + raise NotImplementedError( + "DeepGemm backend does not support min_latency_mode=True") + + # Get expert configuration from module + expert_size_per_partition = module.expert_size_per_partition + intermediate_size = module.intermediate_size + hidden_size = x.shape[1] + + # Permute the data for expert-parallel processing + ( + permuted_row_to_unpermuted_row_tensor, + permuted_token_selected_experts_tensor, + permuted_data_tensor, + expert_first_token_offset_tensor, + permuted_token_final_scales_tensor, + unpermuted_row_to_permuted_row_tensor, + ) = torch.ops.trtllm.moe_permute_op( + x, + token_selected_slots, + token_final_scales, + None, # w3_w1_weight + None, # w2_weight + None, # quant_scales + input_sf=input_sf, + num_experts_on_rank=expert_size_per_partition, + tp_size=tp_size, + tp_rank=tp_rank, + ep_size=ep_size, + ep_rank=ep_rank, + cluster_size=cluster_size, + cluster_rank=cluster_rank, + min_latency_mode=min_latency_mode, + use_fp8_block_scaling=True, # Always use block scaling for DeepGemm + ) + + if permuted_data_tensor.numel() == 0: + return torch.zeros_like(x) + + # Preprocess for masked operations + masked_m, token_to_expert_map = preprocess_after_permute( + expert_first_token_offset_tensor, permuted_data_tensor) + + expected_m = (token_selected_slots.numel() + expert_size_per_partition - + 1) // expert_size_per_partition + + # Get workspace for DeepGemm operations + m_max = self.fp8_utils.align(x.shape[0], 128) + workspace = self._get_deepgemm_workspace(module, m_max, 128) + + # Padding and quantization for first GEMM input + m_padded = self.fp8_utils.align(m_max, 4) + scale_k = self.fp8_utils.ceil_div(hidden_size, 128) + scale_k_padded = self.fp8_utils.align(scale_k, 4) + + act_input_fp8 = set_strides(workspace["workspace_0"], + expert_size_per_partition, m_max, + hidden_size) + act_input_sf = set_strides(workspace["workspace_sf"], + expert_size_per_partition, + scale_k_padded // 4, m_padded) + + # Quantize and copy input with masking + act_input_sf = masked_index_copy_group_quant_fp8( + act_input_fp8, + act_input_sf, + permuted_data_tensor, + expert_first_token_offset_tensor, + token_to_expert_map, + group_size=128) + + # First grouped GEMM (w3 and w1) + h1 = set_strides(workspace["workspace_1"], expert_size_per_partition, + m_max, intermediate_size * 2) + + self.deepgemm_fp8_group_blockwise_gemm( + d=h1, + a=act_input_fp8, + b=w3_w1_weight, + sfa=act_input_sf, + sfb=quant_scales[0] if quant_scales else None, + masked_m=masked_m, + expected_m=expected_m, + ) + + # SiLU activation and quantization for second GEMM + act_input_fp8 = set_strides(workspace["workspace_0"], + expert_size_per_partition, m_max, + intermediate_size) + + scale_k = self.fp8_utils.ceil_div(intermediate_size, 128) + scale_k_padded = self.fp8_utils.align(scale_k, 4) + act_input_sf = set_strides(workspace["workspace_sf"], + expert_size_per_partition, + scale_k_padded // 4, m_padded) + + act_input_sf = self.fp8_utils.silu_and_mul_masked_post_quant_fwd( + output=act_input_fp8, + output_scale=act_input_sf, + input=h1, + quant_group_size=128, + masked_m=masked_m, + scale_ue8m0=True) + + # Second grouped GEMM (w2) + h3 = set_strides(workspace["workspace_1"], expert_size_per_partition, + m_max, hidden_size) + + self.deepgemm_fp8_group_blockwise_gemm( + d=h3, + a=act_input_fp8, + b=w2_weight, + sfa=act_input_sf, + sfb=quant_scales[1] if quant_scales else None, + masked_m=masked_m, + expected_m=expected_m, + ) + + # Gather results back to original token order + 
triton_masked_index_gather(permuted_data_tensor, h3, + expert_first_token_offset_tensor, + token_to_expert_map) + + # Finalize and scale the output + # Get unpadded_hidden_size from module if available, otherwise use hidden_size + # For now it is the user's responsibility to set unpadded_hidden_size. + # DeepGemmFusedMoE and WideEPMoE both have unpadded_hidden_size. + unpadded_hidden_size = getattr(module, 'unpadded_hidden_size', + x.shape[1]) + + final_hidden_states = torch.ops.trtllm.moe_finalize_scale_op( + permuted_data_tensor, + None, # biases (w2_bias could be added here if needed) + token_final_scales, + unpermuted_row_to_permuted_row_tensor, + permuted_row_to_unpermuted_row_tensor, + token_selected_slots, + expert_first_token_offset_tensor, + enable_alltoall, + x.shape[0], # num_rows + x.shape[1], # hidden_size + unpadded_hidden_size, # unpadded_hidden_size (may be different from hidden_size if padding was applied) + module.routing_method.top_k if module else 1, # experts_per_token + expert_size_per_partition, # num_experts_per_node + tp_size, + tp_rank, + ep_size, + ep_rank, + ) + + return final_hidden_states if min_latency_mode else [ + final_hidden_states + ] + + +class MoEBackendSelection: + """ + Utility class for selecting the appropriate MoE backend based on + hardware capabilities and quantization configuration. + + This class implements the strategy pattern for backend selection, + choosing between Cutlass and DeepGemm implementations based on: + - Hardware capabilities (SM version) + - Quantization configuration (block FP8 support) + """ + + @staticmethod + def select_backend(module: 'MoE') -> MoEBackend: + """ + Select the appropriate MoE backend based on module configuration. + + Selection criteria: + - Blackwell (SM100) with block FP8 quantization -> DeepGemm backend + - All other configurations -> Cutlass backend + + Args: + module: The MoE module containing configuration information + Expected attributes: + - has_deepseek_fp8_block_scales: Whether block FP8 is enabled + + Returns: + MoEBackend: Selected backend instance (MoECutlassBackend or MoEDeepGemmBackend) + + Example: + >>> backend = MoEBackendSelection.select_backend(moe_module) + >>> output = backend.run_moe(input, ...) 
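+            >>> # Note: run_moe() takes the owning module as its first
+            >>> # argument, i.e. backend.run_moe(moe_module, input, ...).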
+ """ + # Check if we should use DeepGemm backend + # Blackwell has SM version 100 + is_blackwell = get_sm_version() == 100 + has_block_fp8 = (hasattr(module, 'has_deepseek_fp8_block_scales') + and module.has_deepseek_fp8_block_scales) + + if is_blackwell and has_block_fp8: + # Use DeepGemm backend for Blackwell with block FP8 + return MoEDeepGemmBackend() + else: + # Use Cutlass backend for all other cases + return MoECutlassBackend() diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py index 2158e5fef30..2461f85984d 100644 --- a/tests/unittest/_torch/modules/test_fused_moe.py +++ b/tests/unittest/_torch/modules/test_fused_moe.py @@ -638,6 +638,161 @@ def set_tensor_value_4(x, num_row, num_cols): x.copy_(repeated) +@skip_pre_blackwell +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="needs 4 GPUs to run this test") +@pytest.mark.parametrize( + "alltoall_method_type", + [AlltoallMethodType.MNNVL, AlltoallMethodType.NotEnabled], + ids=lambda s: s.name) +def test_fused_moe_fp8_blockwise_wide_ep(alltoall_method_type): + """Test WideEPMoE with FP8 block-wise quantization using DeepGemmFusedMoE as reference.""" + + world_size = 4 + dtype = torch.bfloat16 + # Reduce model size to avoid MPI int32 overflow + HIDDEN_SIZE = 768 + INTERMEDIATE_SIZE = 512 + NUM_EXPERTS = 16 + TOP_K = 2 + MAX_NUM_TOKENS = 256 + + # The MPI can not support FP8, so create weights on each rank + def per_rank_test_fused_moe_alltoall_fp8_blockwise(job_id): + routing_method = DefaultMoeRoutingMethod(top_k=TOP_K) + mapping = Mapping(world_size=world_size, + rank=mpi_rank(), + tp_size=world_size, + moe_ep_size=world_size, + moe_tp_size=1, + enable_attention_dp=True) + torch.cuda.set_device(mapping.rank) + # Use same seed for all ranks to ensure consistency + torch.manual_seed(0) + torch.cuda.manual_seed(0) + + # Generate test data locally on each rank + x_list = [] + m = MAX_NUM_TOKENS + while m >= 1: + x = torch.randn((m, HIDDEN_SIZE), dtype=dtype, device="cuda") + set_tensor_value_2(x, m, HIDDEN_SIZE) + x_list.append(x) + m //= 2 + + # Generate weights locally on each rank (same weights due to same seed) + weights = {} + for expert_id in range(NUM_EXPERTS): + w1_weight = torch.randn( + (INTERMEDIATE_SIZE, HIDDEN_SIZE), dtype=dtype, + device="cuda") / HIDDEN_SIZE + w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE), + dtype=dtype, + device="cuda") + w3_weight = torch.randn( + (INTERMEDIATE_SIZE, HIDDEN_SIZE), dtype=dtype, + device="cuda") / HIDDEN_SIZE + + set_tensor_value_3(w1_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE) + set_tensor_value_4(w2_weight, HIDDEN_SIZE, INTERMEDIATE_SIZE) + set_tensor_value_3(w3_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE) + + # FP8 block-wise quantization + w1_weight_fp8, w1_weight_scale = per_block_cast_to_fp8_e8m0( + w1_weight) + w1_weight_fp8 = w1_weight_fp8.view(torch.float8_e4m3fn).cuda() + + w2_weight_fp8, w2_weight_scale = per_block_cast_to_fp8_e8m0( + w2_weight) + w2_weight_fp8 = w2_weight_fp8.view(torch.float8_e4m3fn).cuda() + + w3_weight_fp8, w3_weight_scale = per_block_cast_to_fp8_e8m0( + w3_weight) + w3_weight_fp8 = w3_weight_fp8.view(torch.float8_e4m3fn).cuda() + + weights[f"{expert_id}.w1.weight"] = w1_weight_fp8 + weights[f"{expert_id}.w2.weight"] = w2_weight_fp8 + weights[f"{expert_id}.w3.weight"] = w3_weight_fp8 + weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale + weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale + weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale + 
weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale + weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale + weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale + + quant_config = QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES) + + # Test WideEPMoE with alltoall method + with mock.patch.object(WideEPMoE, + "select_alltoall_method_type", + return_value=alltoall_method_type): + alltoall_model = WideEPMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=True, + model_config=ModelConfig(mapping=mapping, + max_num_tokens=MAX_NUM_TOKENS, + quant_config=quant_config), + ) + alltoall_model.to("cuda") + alltoall_model.load_weights([weights]) + + # Use DeepGemmFusedMoE as reference + ref_model = DeepGemmFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=True, + model_config=ModelConfig(mapping=mapping, + max_num_tokens=MAX_NUM_TOKENS, + quant_config=quant_config), + ) + ref_model.to("cuda") + ref_model.load_weights([weights]) + + # Evaluate the outputs on variant sequence lengths + m = MAX_NUM_TOKENS + i = 0 + while m >= 1: + x = x_list[i] + i += 1 + router_logits = torch.randn((m, NUM_EXPERTS), + dtype=dtype, + device="cuda") + all_rank_num_tokens = [m] * mapping.world_size + with torch.inference_mode(): + output = alltoall_model.forward( + x, + router_logits, + all_rank_num_tokens=all_rank_num_tokens, + all_rank_max_num_tokens=m, + use_dp_padding=False) + ref_output = ref_model.forward( + x, + router_logits, + all_rank_num_tokens=all_rank_num_tokens, + all_rank_max_num_tokens=m, + use_dp_padding=False) + + # Evaluate outputs with relaxed tolerance for FP8 + # If WideEPMoE output has TOP_K dimension, reduce it to match DeepGemmFusedMoE + if output.dim() == 3 and output.shape[1] == TOP_K: + output = output.sum(dim=1) + torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1) + m //= 2 + + with MPIPoolExecutor(max_workers=world_size) as executor: + results = executor.map(per_rank_test_fused_moe_alltoall_fp8_blockwise, + range(world_size)) + for r in results: + assert r is None + + @skip_pre_blackwell @pytest.mark.parametrize( "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls", From 94a52a3de21fb7404fad0a2db08be71f58d30e54 Mon Sep 17 00:00:00 2001 From: xxi Date: Wed, 3 Sep 2025 03:11:44 +0000 Subject: [PATCH 2/2] rename by comment Signed-off-by: xxi modified: tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py deleted: tensorrt_llm/_torch/modules/fused_moe/moe_backend.py new file: tensorrt_llm/_torch/modules/fused_moe/ops/__init__.py new file: tensorrt_llm/_torch/modules/fused_moe/ops/moe_op.py new file: tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_cutlass.py new file: tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_deepgemm.py modified: docs/source/deployment-guide/quick-start-recipe-for-deepseek-r1-on-trtllm.md modified: tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py deleted: tensorrt_llm/_torch/modules/fused_moe/moe_backend.py new file: tensorrt_llm/_torch/modules/fused_moe/ops/__init__.py new file: tensorrt_llm/_torch/modules/fused_moe/ops/moe_op.py new file: tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_cutlass.py new file: tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_deepgemm.py --- ...-start-recipe-for-deepseek-r1-on-trtllm.md | 2 +- .../modules/fused_moe/fused_moe_wide_ep.py | 26 +- 
.../_torch/modules/fused_moe/moe_backend.py | 791 ------------------ .../_torch/modules/fused_moe/ops/__init__.py | 17 + .../_torch/modules/fused_moe/ops/moe_op.py | 230 +++++ .../modules/fused_moe/ops/moe_op_cutlass.py | 306 +++++++ .../modules/fused_moe/ops/moe_op_deepgemm.py | 306 +++++++ 7 files changed, 873 insertions(+), 805 deletions(-) delete mode 100644 tensorrt_llm/_torch/modules/fused_moe/moe_backend.py create mode 100644 tensorrt_llm/_torch/modules/fused_moe/ops/__init__.py create mode 100644 tensorrt_llm/_torch/modules/fused_moe/ops/moe_op.py create mode 100644 tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_cutlass.py create mode 100644 tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_deepgemm.py diff --git a/docs/source/deployment-guide/quick-start-recipe-for-deepseek-r1-on-trtllm.md b/docs/source/deployment-guide/quick-start-recipe-for-deepseek-r1-on-trtllm.md index cd42786b8fa..8d811c65471 100644 --- a/docs/source/deployment-guide/quick-start-recipe-for-deepseek-r1-on-trtllm.md +++ b/docs/source/deployment-guide/quick-start-recipe-for-deepseek-r1-on-trtllm.md @@ -30,7 +30,7 @@ There are multiple MOE backends inside TRT-LLM, not all of them supporting every | B200/GB200 EP<=8 | NVFP4 | CUTLASS, TRTLLM | | B200/GB200 EP<=8 | FP8 | DEEPGEMM | | GB200 NVL72 EP>8 | NVFP4 | WIDEEP | -| GB200 NVL72 EP>8 | FP8 | N/A (WIP) | +| GB200 NVL72 EP>8 | FP8 | WIDEEP without EPLB | The default moe backend is `CUTLASS`, so for the combination which is not supported by `CUTLASS`, one must set the `moe_config.backend` explicitly to run the model. diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py index 738485901e5..86724a326e6 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py @@ -16,8 +16,8 @@ from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor from .deep_ep_utils import buffer_pool, deep_ep_installed from .interface import MoE -from .moe_backend import MoEBackend, MoEBackendSelection from .moe_load_balancer import get_moe_load_balancer +from .ops import MoEOp, MoEOpSelector from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod, DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm, FP8QDQFusedMoEMethod, MoEWeightLoadingMode, @@ -233,8 +233,8 @@ def __init__( self.enable_dummy_allreduce = os.environ.get( "TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1" - # MoE backend will be lazily initialized when first accessed (see moe_backend property) - self._moe_backend_impl = None + # MoE op will be lazily initialized when first accessed (see moe_op_impl property) + self._moe_op_impl = None def _check_configs(self): assert self._weights_created @@ -352,17 +352,17 @@ def create_weights(self): self._check_configs() @property - def moe_backend_impl(self) -> MoEBackend: + def moe_op_impl(self) -> MoEOp: """ - Lazily initialize and return the MoE backend. + Lazily initialize and return the MoE op. - The backend is selected based on hardware capabilities and quantization + The op is selected based on hardware capabilities and quantization configuration, which are only available after weights are created. 
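+
+        See MoEOpSelector.select_op() for the selection criteria.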
""" - if self._moe_backend_impl is None: - assert self._weights_created, "Weights must be created before accessing moe_backend" - self._moe_backend_impl = MoEBackendSelection.select_backend(self) - return self._moe_backend_impl + if self._moe_op_impl is None: + assert self._weights_created, "Weights must be created before accessing moe_op" + self._moe_op_impl = MoEOpSelector.select_op(self) + return self._moe_op_impl def dummy_allreduce(self): """ @@ -414,8 +414,8 @@ def forward_chunk( if self.layer_load_balancer and is_first_call: self.layer_load_balancer.start_wait_gpu_stage() - use_deepseek_fp8_block_scale = False - use_w4_group_scaling = False + if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.MNNVL: + pass weight_dtype = self.w3_w1_weight.dtype @@ -661,7 +661,7 @@ def forward_chunk( f"Not available alltoall method type: {self.alltoall_method_type!r}" ) - final_hidden_states = self.moe_backend_impl.run_moe( + final_hidden_states = self.moe_op_impl.run_moe( self, x, token_selected_slots, diff --git a/tensorrt_llm/_torch/modules/fused_moe/moe_backend.py b/tensorrt_llm/_torch/modules/fused_moe/moe_backend.py deleted file mode 100644 index ae99c3d85bb..00000000000 --- a/tensorrt_llm/_torch/modules/fused_moe/moe_backend.py +++ /dev/null @@ -1,791 +0,0 @@ -""" -MoE Backend abstraction for supporting different MoE computation implementations. -This module provides a unified interface for different MoE backends (Cutlass, DeepGemm, etc.) -""" - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict, List, Optional - -import torch - -from tensorrt_llm._utils import get_sm_version - -if TYPE_CHECKING: - from .interface import MoE - - -class MoEBackend(ABC): - """Abstract base class for MoE computation backends. - - This class provides a strategy pattern for different MoE computation implementations. - It is used by MoE modules (like WideEPMoE) to delegate the actual computation. - - Note: MoEBackend is NOT a MoE module itself, but a computation strategy. - The actual MoE module (e.g., WideEPMoE) inherits from MoE and uses MoEBackend - for the computation implementation. - """ - - # Backend-specific abstract methods - @abstractmethod - def finalize_tactic( - self, - module: 'MoE', - tuner_input: torch.Tensor, - output_dtype: torch.dtype, - min_latency_mode: bool = False, - use_fused_finalize: bool = True, - tuner_top_k: Optional[int] = None, - ) -> None: - """ - Finalize tactics for the MoE computation. - For Cutlass backend, this includes profiling and tactic selection. - For DeepGemm backend, this can be a no-op. 
- - Args: - module: The MoE module containing MoE configurations - tuner_input: Real input used for tuning (same shape/layout as non-alltoall) - output_dtype: Output dtype for tuner run - min_latency_mode: Whether to profile for min-latency path - use_fused_finalize: Whether to use fused finalize - tuner_top_k: Top-k value for tuning (Cutlass specific) - """ - - @abstractmethod - def compute_moe( - self, - module: 'MoE', - # Input tensors - x: torch.Tensor, - token_selected_slots: torch.Tensor, - token_final_scales: Optional[torch.Tensor], - # Weight tensors - w3_w1_weight: torch.Tensor, - w3_w1_bias: Optional[torch.Tensor], - w2_weight: torch.Tensor, - w2_bias: Optional[torch.Tensor], - # Output configuration - output_dtype: torch.dtype, - # Quantization parameters - quant_scales: List[torch.Tensor], - input_sf: Optional[torch.Tensor] = None, - swizzled_input_sf: bool = True, - # Performance tuning (only runtime-variable parameters) - min_latency_mode: bool = False, - use_fused_finalize: bool = True, - tuner_num_tokens: Optional[int] = None, - tuner_top_k: Optional[int] = None, - **kwargs) -> torch.Tensor: - """ - Perform the actual MoE computation. - - Configuration parameters (tp_size, ep_size, swiglu params, etc.) are - automatically extracted from the module parameter. - - Args: - module: MoE module containing configuration and parameters. - The following will be extracted: - - tp_size, tp_rank, ep_size, ep_rank, cluster_size, cluster_rank - - enable_alltoall, tune_max_num_tokens - - swiglu_alpha, swiglu_beta, swiglu_limit - - Quantization flags based on module properties - x: Input tensor - token_selected_slots: Selected expert slots - token_final_scales: Scaling factors - w3_w1_weight: Fused gate and up projection weights - w3_w1_bias: Optional bias - w2_weight: Down projection weights - w2_bias: Optional bias - output_dtype: Output data type - quant_scales: Quantization scales - input_sf: Input scaling factor - swizzled_input_sf: Whether input_sf is swizzled - min_latency_mode: Use minimum latency optimizations - use_fused_finalize: Use fused finalization - tuner_num_tokens: Number of tokens for tuning - tuner_top_k: Top-k value for tuning - - Returns: - Computed MoE output tensor - """ - - def run_moe( - self, - module: 'MoE', - # Input tensors - input: torch.Tensor, - token_selected_slots: torch.Tensor, - token_final_scales: torch.Tensor, - w3_w1_weight: torch.Tensor, - w3_w1_bias: Optional[torch.Tensor], - w2_weight: torch.Tensor, - w2_bias: Optional[torch.Tensor], - output_dtype: torch.dtype, - # Quantization parameters - quant_scales: List[torch.Tensor], - input_sf: Optional[torch.Tensor] = None, - swizzled_input_sf: bool = True, - # Performance tuning (only runtime-variable parameters) - min_latency_mode: bool = False, - use_fused_finalize: bool = True, - tuner_num_tokens: Optional[int] = None, - tuner_top_k: Optional[int] = None, - **kwargs) -> torch.Tensor: - """ - Run the complete MoE computation pipeline. - - Configuration parameters are automatically extracted from the module. 
- - Args: - module: MoE module containing configuration - input: Input tensor to the MoE layer - token_selected_slots: Selected expert slots for each token - token_final_scales: Final scaling factors for each token - w3_w1_weight: Concatenated weights for w3 and w1 projections - w3_w1_bias: Optional bias for w3/w1 projections - w2_weight: Weight for w2 projection - w2_bias: Optional bias for w2 projection - output_dtype: Desired output data type - quant_scales: Quantization scales for weights - input_sf: Optional input scale factors for quantization - swizzled_input_sf: Whether input scale factors are swizzled - min_latency_mode: Use minimum latency optimizations - use_fused_finalize: Use fused finalization - tuner_num_tokens: Number of tokens for tuner input - tuner_top_k: Top-k value for tuning - - Returns: - Computed MoE output tensor - """ - self.finalize_tactic(module, input, output_dtype, min_latency_mode, - use_fused_finalize, tuner_top_k) - - # Call compute_moe with module - return self.compute_moe(module=module, - x=input, - token_selected_slots=token_selected_slots, - token_final_scales=token_final_scales, - w3_w1_weight=w3_w1_weight, - w3_w1_bias=w3_w1_bias, - w2_weight=w2_weight, - w2_bias=w2_bias, - output_dtype=output_dtype, - quant_scales=quant_scales, - input_sf=input_sf, - swizzled_input_sf=swizzled_input_sf, - min_latency_mode=min_latency_mode, - use_fused_finalize=use_fused_finalize, - tuner_num_tokens=tuner_num_tokens, - tuner_top_k=tuner_top_k, - **kwargs) - - -class MoECutlassBackend(MoEBackend): - """Cutlass-based MoE backend using torch.ops.trtllm.fused_moe.""" - - def __init__(self): - """Initialize the Cutlass backend.""" - super().__init__() - self.moe_runner = None - self.gemm_tactics = None - - def finalize_tactic( - self, - module: 'MoE', - tuner_input: torch.Tensor, - output_dtype: torch.dtype, - min_latency_mode: bool = False, - use_fused_finalize: bool = True, - tuner_top_k: Optional[int] = None, - ) -> None: - """ - Finalize tactics for Cutlass MoE by profiling and selecting optimal GEMM tactics. 
- """ - - # Import necessary modules for profiling - from ...custom_ops.torch_custom_ops import AutoTuner, MoERunner - - # Use real tuner_input rather than dummy input - assert tuner_input is not None, "tuner_input must be provided to finalize_tactic" - if tuner_top_k is None: - tuner_top_k = getattr(module.routing_method, 'experts_per_token', 1) - - # Determine view dtype for weights to match runtime quantization layout - weight_view_dtype = module.w3_w1_weight.dtype - if getattr(module, 'has_w4afp8', False): - weight_view_dtype = torch.quint4x2 - elif getattr(module, 'has_w4a16_mxfp4', False): - weight_view_dtype = torch.uint8 - - # Create MoERunner for profiling - if self.moe_runner is None: - self.moe_runner = MoERunner( - x_dtype=tuner_input.dtype, - weight_dtype=module.w3_w1_weight.dtype, - output_dtype=output_dtype, - top_k=tuner_top_k, - tp_size=module.tp_size, - tp_rank=module.tp_rank, - ep_size=module.ep_size, - ep_rank=module.ep_rank, - cluster_size=module.cluster_size, - cluster_rank=module.cluster_rank, - use_deepseek_fp8_block_scale=getattr( - module, 'has_deepseek_fp8_block_scales', False), - use_w4_group_scaling=getattr(module, 'has_w4afp8', False), - use_int8_woq_per_channel=getattr(module, - 'has_int8_woq_per_channel', - False), - use_mxfp8_act_scaling=getattr(module, 'has_mxfp8_act_scaling', - False), - min_latency_mode=min_latency_mode, - use_fused_finalize=use_fused_finalize, - ) - - # Set tuning configuration - MoERunner.tuning_config.tune_max_num_tokens = getattr( - module, 'tune_max_num_tokens', 8192) - - # Get AutoTuner for tactic selection - tuner = AutoTuner.get() - - # Profile and select tactics (GEMM1) - _, gemm_tactic_1 = tuner.choose_one( - "trtllm::fused_moe::gemm1", - [self.moe_runner], - MoERunner.tuning_config, - [ - tuner_input, - module.w3_w1_weight.view(weight_view_dtype), - getattr(module, 'w3_w1_bias', None), - module.w2_weight.view(weight_view_dtype), - getattr(module, 'w2_bias', None), - ], - gemm_idx=1, - ) - - # Profile and select tactics (GEMM2) - _, gemm_tactic_2 = tuner.choose_one( - "trtllm::fused_moe::gemm2", - [self.moe_runner], - MoERunner.tuning_config, - [ - tuner_input, - module.w3_w1_weight.view(weight_view_dtype), - getattr(module, 'w3_w1_bias', None), - module.w2_weight.view(weight_view_dtype), - getattr(module, 'w2_bias', None), - ], - gemm_idx=2, - ) - - # Store selected tactics - self.gemm_tactics = [gemm_tactic_1, gemm_tactic_2] - - def compute_moe( - self, - module: 'MoE', # Now required as first parameter - # Input tensors - x: torch.Tensor, - token_selected_slots: torch.Tensor, - token_final_scales: Optional[torch.Tensor], - # Weight tensors - w3_w1_weight: torch.Tensor, - w3_w1_bias: Optional[torch.Tensor], - w2_weight: torch.Tensor, - w2_bias: Optional[torch.Tensor], - # Output configuration - output_dtype: torch.dtype, - # Quantization parameters - quant_scales: List[torch.Tensor], - input_sf: Optional[torch.Tensor] = None, - swizzled_input_sf: bool = True, - # Performance tuning (only runtime-variable parameters) - min_latency_mode: bool = False, - use_fused_finalize: bool = True, - tuner_num_tokens: Optional[int] = None, - tuner_top_k: Optional[int] = None, - **kwargs) -> torch.Tensor: - """ - Compute MoE using Cutlass backend with MoERunner. 
- """ - # Extract parameters from module - tp_size = module.tp_size - tp_rank = module.tp_rank - ep_size = module.ep_size - ep_rank = module.ep_rank - cluster_size = module.cluster_size - cluster_rank = module.cluster_rank - enable_alltoall = module.enable_alltoall - getattr(module, 'tune_max_num_tokens', 8192) - swiglu_alpha = module.swiglu_alpha - swiglu_beta = module.swiglu_beta - swiglu_limit = module.swiglu_limit - use_w4_group_scaling = getattr(module, 'has_w4afp8', False) - - # Determine weight dtype for view operation if needed - weight_dtype = w3_w1_weight.dtype - if use_w4_group_scaling and weight_dtype != torch.quint4x2: - weight_dtype = torch.quint4x2 - - # Validate that tactics have been finalized - if self.gemm_tactics is None or len(self.gemm_tactics) == 0: - raise RuntimeError( - "GEMM tactics have not been finalized. " - "Call finalize_tactic() before compute_moe() or use run_moe() instead." - ) - - if self.moe_runner is None: - raise RuntimeError( - "MoERunner has not been initialized. " - "Call finalize_tactic() before compute_moe() or use run_moe() instead." - ) - - # Select the appropriate run method based on latency mode - run_moe = self.moe_runner.fused_moe_runner.run_moe_min_latency if min_latency_mode else self.moe_runner.fused_moe_runner.run_moe - - # Run the actual MoE computation - output = run_moe( - x, - token_selected_slots, - token_final_scales, - w3_w1_weight.view(weight_dtype), - w3_w1_bias, - w2_weight.view(weight_dtype), - w2_bias, - quant_scales, - input_sf, - swizzled_input_sf, - swiglu_alpha, - swiglu_beta, - swiglu_limit, - tp_size, - tp_rank, - ep_size, - ep_rank, - cluster_size, - cluster_rank, - enable_alltoall, - min_latency_mode, - self.gemm_tactics, - ) - - # Return output based on latency mode - return output if min_latency_mode else [output] - - def run_moe( - self, - module: 'MoE', - # Input tensors - input: torch.Tensor, - token_selected_slots: torch.Tensor, - token_final_scales: torch.Tensor, - w3_w1_weight: torch.Tensor, - w3_w1_bias: Optional[torch.Tensor], - w2_weight: torch.Tensor, - w2_bias: Optional[torch.Tensor], - output_dtype: torch.dtype, - # Quantization parameters - quant_scales: List[torch.Tensor], - input_sf: Optional[torch.Tensor] = None, - swizzled_input_sf: bool = True, - # Performance tuning (only runtime-variable parameters) - min_latency_mode: bool = False, - use_fused_finalize: bool = True, - tuner_num_tokens: Optional[int] = None, - tuner_top_k: Optional[int] = None, - **kwargs) -> torch.Tensor: - """ - Run the complete MoE computation pipeline for Cutlass backend. - - This override handles the specific tuner_input logic needed for Cutlass. 
- - Args: - module: MoE module containing configuration - input: Input tensor to the MoE layer - token_selected_slots: Selected expert slots for each token - token_final_scales: Final scaling factors for each token - w3_w1_weight: Concatenated weights for w3 and w1 projections - w3_w1_bias: Optional bias for w3/w1 projections - w2_weight: Weight for w2 projection - w2_bias: Optional bias for w2 projection - output_dtype: Desired output data type - quant_scales: Quantization scales for weights - input_sf: Optional input scale factors for quantization - swizzled_input_sf: Whether input scale factors are swizzled - min_latency_mode: Use minimum latency optimizations - use_fused_finalize: Use fused finalization - tuner_num_tokens: Number of tokens for tuner input - tuner_top_k: Top-k value for tuning - - Returns: - Computed MoE output tensor - """ - # Extract enable_alltoall from module to determine tuner_input logic - enable_alltoall = module.enable_alltoall - - # Compute tuner_input per fused_moe logic - if enable_alltoall: - assert tuner_num_tokens is not None - assert tuner_top_k is not None - tuner_input = input[:tuner_num_tokens] - else: - assert tuner_num_tokens is None - assert tuner_top_k is None - tuner_input = input - tuner_top_k = token_selected_slots.size(1) - - self.finalize_tactic(module, tuner_input, output_dtype, - min_latency_mode, use_fused_finalize, tuner_top_k) - - # Call compute_moe with module - return self.compute_moe(module=module, - x=input, - token_selected_slots=token_selected_slots, - token_final_scales=token_final_scales, - w3_w1_weight=w3_w1_weight, - w3_w1_bias=w3_w1_bias, - w2_weight=w2_weight, - w2_bias=w2_bias, - output_dtype=output_dtype, - quant_scales=quant_scales, - input_sf=input_sf, - swizzled_input_sf=swizzled_input_sf, - min_latency_mode=min_latency_mode, - use_fused_finalize=use_fused_finalize, - tuner_num_tokens=tuner_num_tokens, - tuner_top_k=tuner_top_k, - **kwargs) - - -class MoEDeepGemmBackend(MoEBackend): - """DeepGemm-based MoE backend for GB200 block FP8.""" - - def __init__(self): - """Initialize DeepGemm backend.""" - super().__init__() - import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils - self.fp8_utils = fp8_utils - - from .fused_moe_deepgemm import deepgemm_fp8_group_blockwise_gemm - self.deepgemm_fp8_group_blockwise_gemm = deepgemm_fp8_group_blockwise_gemm - - def finalize_tactic( - self, - module: 'MoE', - tuner_input: torch.Tensor, - output_dtype: torch.dtype, - min_latency_mode: bool = False, - use_fused_finalize: bool = True, - tuner_top_k: Optional[int] = None, - ) -> None: - """ - No-op for DeepGemm backend as it doesn't require tactic profiling. - - Args: - module: The MoE module - tuner_input: Input tensor for tuning - output_dtype: Output dtype - min_latency_mode: Whether to use min-latency mode - use_fused_finalize: Whether to use fused finalize - tuner_top_k: Top-k value for tuning - """ - - def _get_deepgemm_workspace(self, module: 'MoE', m_max: int, - group_size: int) -> Dict[str, torch.Tensor]: - """ - Get workspace for DeepGemm backend operations. 
- - Args: - module: The MoE module containing configuration - m_max: Maximum number of tokens (aligned) - group_size: Group size for quantization - - Returns: - Dictionary containing workspace tensors - """ - import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils - - # Get dimensions from module - hidden_size = module.hidden_size - intermediate_size = module.intermediate_size - expert_size_per_partition = module.expert_size_per_partition - - # Calculate aligned dimensions - m_padded = fp8_utils.align(m_max, 4) - fp8_dim = max(hidden_size, intermediate_size) - scale_k = fp8_utils.ceil_div(fp8_dim, group_size) - scale_k_padded = fp8_utils.align(scale_k, 4) - - # Allocate workspace tensors - workspace = {} - - # Workspace for FP8 activations - workspace["workspace_0"] = torch.empty( - (expert_size_per_partition * m_max * fp8_dim), - dtype=torch.float8_e4m3fn, - device='cuda') - - # Workspace for intermediate results - workspace["workspace_1"] = torch.empty( - (expert_size_per_partition * m_max * - max(intermediate_size * 2, hidden_size)), - dtype=torch.bfloat16, - device='cuda') - - # Workspace for scaling factors - workspace["workspace_sf"] = torch.empty( - expert_size_per_partition * (scale_k_padded // 4) * m_padded, - dtype=torch.int32, - device='cuda') - - return workspace - - def compute_moe( - self, - module: 'MoE', - # Input tensors - x: torch.Tensor, - token_selected_slots: torch.Tensor, - token_final_scales: Optional[torch.Tensor], - # Weight tensors - w3_w1_weight: torch.Tensor, - w3_w1_bias: Optional[torch.Tensor], - w2_weight: torch.Tensor, - w2_bias: Optional[torch.Tensor], - # Output configuration - output_dtype: torch.dtype, - # Quantization parameters - quant_scales: List[torch.Tensor], - input_sf: Optional[torch.Tensor] = None, - swizzled_input_sf: bool = True, - # Performance tuning (only runtime-variable parameters) - min_latency_mode: bool = False, - use_fused_finalize: bool = True, - tuner_num_tokens: Optional[int] = None, - tuner_top_k: Optional[int] = None, - **kwargs) -> torch.Tensor: - """ - Compute MoE using DeepGemm backend with block FP8 quantization. - - Note: This assumes the data has already been gathered/alltoall'd - by the WideEP forward_chunk method. - """ - - # Import necessary functions for DeepGemm - from .fused_moe_deepgemm import (masked_index_copy_group_quant_fp8, - preprocess_after_permute, set_strides, - triton_masked_index_gather) - - # Extract parameters from module - tp_size = module.tp_size - tp_rank = module.tp_rank - ep_size = module.ep_size - ep_rank = module.ep_rank - cluster_size = module.cluster_size - cluster_rank = module.cluster_rank - enable_alltoall = module.enable_alltoall - getattr(module, 'tune_max_num_tokens', 8192) - module.swiglu_alpha - module.swiglu_beta - module.swiglu_limit - - # Not supported: min_latency_mode. Raise error if enabled. 
- if min_latency_mode: - raise NotImplementedError( - "DeepGemm backend does not support min_latency_mode=True") - - # Get expert configuration from module - expert_size_per_partition = module.expert_size_per_partition - intermediate_size = module.intermediate_size - hidden_size = x.shape[1] - - # Permute the data for expert-parallel processing - ( - permuted_row_to_unpermuted_row_tensor, - permuted_token_selected_experts_tensor, - permuted_data_tensor, - expert_first_token_offset_tensor, - permuted_token_final_scales_tensor, - unpermuted_row_to_permuted_row_tensor, - ) = torch.ops.trtllm.moe_permute_op( - x, - token_selected_slots, - token_final_scales, - None, # w3_w1_weight - None, # w2_weight - None, # quant_scales - input_sf=input_sf, - num_experts_on_rank=expert_size_per_partition, - tp_size=tp_size, - tp_rank=tp_rank, - ep_size=ep_size, - ep_rank=ep_rank, - cluster_size=cluster_size, - cluster_rank=cluster_rank, - min_latency_mode=min_latency_mode, - use_fp8_block_scaling=True, # Always use block scaling for DeepGemm - ) - - if permuted_data_tensor.numel() == 0: - return torch.zeros_like(x) - - # Preprocess for masked operations - masked_m, token_to_expert_map = preprocess_after_permute( - expert_first_token_offset_tensor, permuted_data_tensor) - - expected_m = (token_selected_slots.numel() + expert_size_per_partition - - 1) // expert_size_per_partition - - # Get workspace for DeepGemm operations - m_max = self.fp8_utils.align(x.shape[0], 128) - workspace = self._get_deepgemm_workspace(module, m_max, 128) - - # Padding and quantization for first GEMM input - m_padded = self.fp8_utils.align(m_max, 4) - scale_k = self.fp8_utils.ceil_div(hidden_size, 128) - scale_k_padded = self.fp8_utils.align(scale_k, 4) - - act_input_fp8 = set_strides(workspace["workspace_0"], - expert_size_per_partition, m_max, - hidden_size) - act_input_sf = set_strides(workspace["workspace_sf"], - expert_size_per_partition, - scale_k_padded // 4, m_padded) - - # Quantize and copy input with masking - act_input_sf = masked_index_copy_group_quant_fp8( - act_input_fp8, - act_input_sf, - permuted_data_tensor, - expert_first_token_offset_tensor, - token_to_expert_map, - group_size=128) - - # First grouped GEMM (w3 and w1) - h1 = set_strides(workspace["workspace_1"], expert_size_per_partition, - m_max, intermediate_size * 2) - - self.deepgemm_fp8_group_blockwise_gemm( - d=h1, - a=act_input_fp8, - b=w3_w1_weight, - sfa=act_input_sf, - sfb=quant_scales[0] if quant_scales else None, - masked_m=masked_m, - expected_m=expected_m, - ) - - # SiLU activation and quantization for second GEMM - act_input_fp8 = set_strides(workspace["workspace_0"], - expert_size_per_partition, m_max, - intermediate_size) - - scale_k = self.fp8_utils.ceil_div(intermediate_size, 128) - scale_k_padded = self.fp8_utils.align(scale_k, 4) - act_input_sf = set_strides(workspace["workspace_sf"], - expert_size_per_partition, - scale_k_padded // 4, m_padded) - - act_input_sf = self.fp8_utils.silu_and_mul_masked_post_quant_fwd( - output=act_input_fp8, - output_scale=act_input_sf, - input=h1, - quant_group_size=128, - masked_m=masked_m, - scale_ue8m0=True) - - # Second grouped GEMM (w2) - h3 = set_strides(workspace["workspace_1"], expert_size_per_partition, - m_max, hidden_size) - - self.deepgemm_fp8_group_blockwise_gemm( - d=h3, - a=act_input_fp8, - b=w2_weight, - sfa=act_input_sf, - sfb=quant_scales[1] if quant_scales else None, - masked_m=masked_m, - expected_m=expected_m, - ) - - # Gather results back to original token order - 
triton_masked_index_gather(permuted_data_tensor, h3, - expert_first_token_offset_tensor, - token_to_expert_map) - - # Finalize and scale the output - # Get unpadded_hidden_size from module if available, otherwise use hidden_size - # For now it is the user's responsibility to set unpadded_hidden_size. - # DeepGemmFusedMoE and WideEPMoE both have unpadded_hidden_size. - unpadded_hidden_size = getattr(module, 'unpadded_hidden_size', - x.shape[1]) - - final_hidden_states = torch.ops.trtllm.moe_finalize_scale_op( - permuted_data_tensor, - None, # biases (w2_bias could be added here if needed) - token_final_scales, - unpermuted_row_to_permuted_row_tensor, - permuted_row_to_unpermuted_row_tensor, - token_selected_slots, - expert_first_token_offset_tensor, - enable_alltoall, - x.shape[0], # num_rows - x.shape[1], # hidden_size - unpadded_hidden_size, # unpadded_hidden_size (may be different from hidden_size if padding was applied) - module.routing_method.top_k if module else 1, # experts_per_token - expert_size_per_partition, # num_experts_per_node - tp_size, - tp_rank, - ep_size, - ep_rank, - ) - - return final_hidden_states if min_latency_mode else [ - final_hidden_states - ] - - -class MoEBackendSelection: - """ - Utility class for selecting the appropriate MoE backend based on - hardware capabilities and quantization configuration. - - This class implements the strategy pattern for backend selection, - choosing between Cutlass and DeepGemm implementations based on: - - Hardware capabilities (SM version) - - Quantization configuration (block FP8 support) - """ - - @staticmethod - def select_backend(module: 'MoE') -> MoEBackend: - """ - Select the appropriate MoE backend based on module configuration. - - Selection criteria: - - Blackwell (SM100) with block FP8 quantization -> DeepGemm backend - - All other configurations -> Cutlass backend - - Args: - module: The MoE module containing configuration information - Expected attributes: - - has_deepseek_fp8_block_scales: Whether block FP8 is enabled - - Returns: - MoEBackend: Selected backend instance (MoECutlassBackend or MoEDeepGemmBackend) - - Example: - >>> backend = MoEBackendSelection.select_backend(moe_module) - >>> output = backend.run_moe(input, ...) - """ - # Check if we should use DeepGemm backend - # Blackwell has SM version 100 - is_blackwell = get_sm_version() == 100 - has_block_fp8 = (hasattr(module, 'has_deepseek_fp8_block_scales') - and module.has_deepseek_fp8_block_scales) - - if is_blackwell and has_block_fp8: - # Use DeepGemm backend for Blackwell with block FP8 - return MoEDeepGemmBackend() - else: - # Use Cutlass backend for all other cases - return MoECutlassBackend() diff --git a/tensorrt_llm/_torch/modules/fused_moe/ops/__init__.py b/tensorrt_llm/_torch/modules/fused_moe/ops/__init__.py new file mode 100644 index 00000000000..3c6f3bc3a79 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/ops/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""MoE ops module for different computation implementations.""" + +from .moe_op import MoEOp, MoEOpSelector +from .moe_op_cutlass import CutlassMoEOp +from .moe_op_deepgemm import DeepGemmMoEOp + +__all__ = ['MoEOp', 'MoEOpSelector', 'CutlassMoEOp', 'DeepGemmMoEOp'] diff --git a/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op.py b/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op.py new file mode 100644 index 00000000000..629f9bf7952 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op.py @@ -0,0 +1,230 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +MoE Op abstraction for supporting different MoE computation implementations. +This module provides a unified interface for different MoE ops (Cutlass, DeepGemm, etc.) +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List, Optional + +import torch + +from tensorrt_llm._utils import get_sm_version + +if TYPE_CHECKING: + from ..interface import MoE + + +class MoEOp(ABC): + """Abstract base class for MoE computation ops. + + This class provides a strategy pattern for different MoE computation implementations. + It is used by MoE modules (like WideEPMoE) to delegate the actual computation. + + Note: MoEOp is NOT a MoE module itself, but a computation strategy. + The actual MoE module (e.g., WideEPMoE) inherits from MoE and uses MoEOp + for the computation implementation. + """ + + # Op-specific abstract methods + @abstractmethod + def finalize_tactic( + self, + module: 'MoE', + tuner_input: torch.Tensor, + output_dtype: torch.dtype, + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_top_k: Optional[int] = None, + ) -> None: + """ + Finalize tactics for the MoE computation. + For Cutlass op, this includes profiling and tactic selection. + For DeepGemm op, this can be a no-op. 
+ + Args: + module: The MoE module containing MoE configurations + tuner_input: Real input used for tuning (same shape/layout as non-alltoall) + output_dtype: Output dtype for tuner run + min_latency_mode: Whether to profile for min-latency path + use_fused_finalize: Whether to use fused finalize + tuner_top_k: Top-k value for tuning (Cutlass specific) + """ + + @abstractmethod + def compute_moe( + self, + module: 'MoE', + # Input tensors + x: torch.Tensor, + token_selected_slots: torch.Tensor, + token_final_scales: Optional[torch.Tensor], + # Weight tensors + w3_w1_weight: torch.Tensor, + w3_w1_bias: Optional[torch.Tensor], + w2_weight: torch.Tensor, + w2_bias: Optional[torch.Tensor], + # Output configuration + output_dtype: torch.dtype, + # Quantization parameters + quant_scales: List[torch.Tensor], + input_sf: Optional[torch.Tensor] = None, + swizzled_input_sf: bool = True, + # Performance tuning (only runtime-variable parameters) + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_num_tokens: Optional[int] = None, + tuner_top_k: Optional[int] = None, + **kwargs) -> torch.Tensor: + """ + Perform the actual MoE computation. + + Configuration parameters (tp_size, ep_size, swiglu params, etc.) are + automatically extracted from the module parameter. + + Args: + module: MoE module containing configuration and parameters. + x: Input tensor + token_selected_slots: Selected expert slots + token_final_scales: Scaling factors + w3_w1_weight: Fused gate and up projection weights + w3_w1_bias: Optional bias + w2_weight: Down projection weights + w2_bias: Optional bias + output_dtype: Output data type + quant_scales: Quantization scales + input_sf: Input scaling factor + swizzled_input_sf: Whether input_sf is swizzled + min_latency_mode: Use minimum latency optimizations + use_fused_finalize: Use fused finalization + tuner_num_tokens: Number of tokens for tuning + tuner_top_k: Top-k value for tuning + + Returns: + Computed MoE output tensor + """ + + def run_moe( + self, + module: 'MoE', + # Input tensors + input: torch.Tensor, + token_selected_slots: torch.Tensor, + token_final_scales: torch.Tensor, + w3_w1_weight: torch.Tensor, + w3_w1_bias: Optional[torch.Tensor], + w2_weight: torch.Tensor, + w2_bias: Optional[torch.Tensor], + output_dtype: torch.dtype, + # Quantization parameters + quant_scales: List[torch.Tensor], + input_sf: Optional[torch.Tensor] = None, + swizzled_input_sf: bool = True, + # Performance tuning (only runtime-variable parameters) + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_num_tokens: Optional[int] = None, + tuner_top_k: Optional[int] = None, + **kwargs) -> torch.Tensor: + """ + Run the complete MoE computation pipeline. + + Configuration parameters are automatically extracted from the module. 
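+
+        The default implementation simply calls finalize_tactic() followed by
+        compute_moe(); subclasses may override it when the tuner input has to
+        be derived differently (see CutlassMoEOp.run_moe).
+
+        Example (sketch only; assumes a WideEPMoE-like module whose weights
+        have been created and whose routing produced `slots`/`scales`):
+            >>> op = MoEOpSelector.select_op(moe_module)
+            >>> out = op.run_moe(moe_module, x, slots, scales,
+            ...                  moe_module.w3_w1_weight, None,
+            ...                  moe_module.w2_weight, None,
+            ...                  output_dtype=torch.bfloat16,
+            ...                  quant_scales=moe_module.quant_scales)
+            >>> hidden = out[0]  # single-element list when min_latency_mode=False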
+
+        Args:
+            module: MoE module containing configuration
+            input: Input tensor to the MoE layer
+            token_selected_slots: Selected expert slots for each token
+            token_final_scales: Final scaling factors for each token
+            w3_w1_weight: Concatenated weights for w3 and w1 projections
+            w3_w1_bias: Optional bias for w3/w1 projections
+            w2_weight: Weight for w2 projection
+            w2_bias: Optional bias for w2 projection
+            output_dtype: Desired output data type
+            quant_scales: Quantization scales for weights
+            input_sf: Optional input scale factors for quantization
+            swizzled_input_sf: Whether input scale factors are swizzled
+            min_latency_mode: Use minimum latency optimizations
+            use_fused_finalize: Use fused finalization
+            tuner_num_tokens: Number of tokens for tuner input
+            tuner_top_k: Top-k value for tuning
+
+        Returns:
+            Computed MoE output (the Cutlass and DeepGemm ops return a
+            single-element list when min_latency_mode is False)
+        """
+        self.finalize_tactic(module, input, output_dtype, min_latency_mode,
+                             use_fused_finalize, tuner_top_k)
+
+        # Call compute_moe with module
+        return self.compute_moe(module=module,
+                                x=input,
+                                token_selected_slots=token_selected_slots,
+                                token_final_scales=token_final_scales,
+                                w3_w1_weight=w3_w1_weight,
+                                w3_w1_bias=w3_w1_bias,
+                                w2_weight=w2_weight,
+                                w2_bias=w2_bias,
+                                output_dtype=output_dtype,
+                                quant_scales=quant_scales,
+                                input_sf=input_sf,
+                                swizzled_input_sf=swizzled_input_sf,
+                                min_latency_mode=min_latency_mode,
+                                use_fused_finalize=use_fused_finalize,
+                                tuner_num_tokens=tuner_num_tokens,
+                                tuner_top_k=tuner_top_k,
+                                **kwargs)
+
+
+class MoEOpSelector:
+    """
+    Utility class for selecting the appropriate MoE op based on
+    hardware capabilities and quantization configuration.
+
+    This class implements the strategy pattern for op selection,
+    choosing between Cutlass and DeepGemm implementations based on:
+    - Hardware capabilities (SM version)
+    - Quantization configuration (block FP8 support)
+    """
+
+    @staticmethod
+    def select_op(module: 'MoE') -> MoEOp:
+        """
+        Select the appropriate MoE op based on module configuration.
+
+        Selection criteria:
+        - Blackwell (SM100) with block FP8 quantization -> DeepGemm op
+        - All other configurations -> Cutlass op
+
+        Args:
+            module: The MoE module containing configuration information
+
+        Returns:
+            MoEOp: Selected op instance (CutlassMoEOp or DeepGemmMoEOp)
+
+        Example:
+            >>> op = MoEOpSelector.select_op(moe_module)
+            >>> output = op.run_moe(moe_module, hidden_states, ...)
+        """
+        from .moe_op_cutlass import CutlassMoEOp
+        from .moe_op_deepgemm import DeepGemmMoEOp
+
+        # Check if we should use DeepGemm op
+        # Blackwell has SM version 100
+        is_blackwell = get_sm_version() == 100
+        has_block_fp8 = module.has_deepseek_fp8_block_scales
+
+        if is_blackwell and has_block_fp8:
+            # Use DeepGemm op for Blackwell with block FP8
+            return DeepGemmMoEOp()
+        else:
+            # Use Cutlass op for all other cases
+            return CutlassMoEOp()
diff --git a/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_cutlass.py
new file mode 100644
index 00000000000..45f27dc8235
--- /dev/null
+++ b/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_cutlass.py
@@ -0,0 +1,306 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Cutlass-based MoE op implementation. +""" + +from typing import TYPE_CHECKING, List, Optional + +import torch + +from .moe_op import MoEOp + +if TYPE_CHECKING: + from ..interface import MoE + + +class CutlassMoEOp(MoEOp): + """Cutlass-based MoE op using torch.ops.trtllm.fused_moe.""" + + def __init__(self): + """Initialize the Cutlass op.""" + super().__init__() + self.moe_runner = None + self.gemm_tactics = None + + def finalize_tactic( + self, + module: 'MoE', + tuner_input: torch.Tensor, + output_dtype: torch.dtype, + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_top_k: Optional[int] = None, + ) -> None: + """ + Finalize tactics for Cutlass MoE by profiling and selecting optimal GEMM tactics. + """ + + # Import necessary modules for profiling + from ....custom_ops.torch_custom_ops import AutoTuner, MoERunner + + # Use real tuner_input rather than dummy input + assert tuner_input is not None, "tuner_input must be provided to finalize_tactic" + if tuner_top_k is None: + tuner_top_k = getattr(module.routing_method, 'experts_per_token', 1) + + # Determine view dtype for weights to match runtime quantization layout + weight_view_dtype = module.w3_w1_weight.dtype + if getattr(module, 'has_w4afp8', False): + weight_view_dtype = torch.quint4x2 + elif module.has_w4a16_mxfp4: + weight_view_dtype = torch.uint8 + + # Create MoERunner for profiling + if self.moe_runner is None: + self.moe_runner = MoERunner( + x_dtype=tuner_input.dtype, + weight_dtype=weight_view_dtype, + output_dtype=output_dtype, + top_k=tuner_top_k, + tp_size=module.tp_size, + tp_rank=module.tp_rank, + ep_size=module.ep_size, + ep_rank=module.ep_rank, + cluster_size=module.cluster_size, + cluster_rank=module.cluster_rank, + use_deepseek_fp8_block_scale=module. 
+ has_deepseek_fp8_block_scales, + use_w4_group_scaling=getattr(module, 'has_w4afp8', False), + use_int8_woq_per_channel=getattr(module, + 'has_int8_woq_per_channel', + False), + use_mxfp8_act_scaling=getattr(module, 'has_mxfp8_act_scaling', + False), + min_latency_mode=min_latency_mode, + use_fused_finalize=use_fused_finalize, + ) + + # Set tuning configuration + MoERunner.tuning_config.tune_max_num_tokens = getattr( + module, 'tune_max_num_tokens', 8192) + + # Get AutoTuner for tactic selection + tuner = AutoTuner.get() + + # Profile and select tactics (GEMM1) + _, gemm_tactic_1 = tuner.choose_one( + "trtllm::fused_moe::gemm1", + [self.moe_runner], + MoERunner.tuning_config, + [ + tuner_input, + module.w3_w1_weight.view(weight_view_dtype), + getattr(module, 'w3_w1_bias', None), + module.w2_weight.view(weight_view_dtype), + getattr(module, 'w2_bias', None), + ], + gemm_idx=1, + ) + + # Profile and select tactics (GEMM2) + _, gemm_tactic_2 = tuner.choose_one( + "trtllm::fused_moe::gemm2", + [self.moe_runner], + MoERunner.tuning_config, + [ + tuner_input, + module.w3_w1_weight.view(weight_view_dtype), + getattr(module, 'w3_w1_bias', None), + module.w2_weight.view(weight_view_dtype), + getattr(module, 'w2_bias', None), + ], + gemm_idx=2, + ) + + # Store selected tactics + self.gemm_tactics = [gemm_tactic_1, gemm_tactic_2] + + def compute_moe( + self, + module: 'MoE', # Now required as first parameter + # Input tensors + x: torch.Tensor, + token_selected_slots: torch.Tensor, + token_final_scales: Optional[torch.Tensor], + # Weight tensors + w3_w1_weight: torch.Tensor, + w3_w1_bias: Optional[torch.Tensor], + w2_weight: torch.Tensor, + w2_bias: Optional[torch.Tensor], + # Output configuration + output_dtype: torch.dtype, + # Quantization parameters + quant_scales: List[torch.Tensor], + input_sf: Optional[torch.Tensor] = None, + swizzled_input_sf: bool = True, + # Performance tuning (only runtime-variable parameters) + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_num_tokens: Optional[int] = None, + tuner_top_k: Optional[int] = None, + **kwargs) -> torch.Tensor: + """ + Compute MoE using Cutlass op with MoERunner. + """ + # Extract parameters from module + tp_size = module.tp_size + tp_rank = module.tp_rank + ep_size = module.ep_size + ep_rank = module.ep_rank + cluster_size = module.cluster_size + cluster_rank = module.cluster_rank + enable_alltoall = module.enable_alltoall + swiglu_alpha = module.swiglu_alpha + swiglu_beta = module.swiglu_beta + swiglu_limit = module.swiglu_limit + use_w4_group_scaling = getattr(module, 'has_w4afp8', False) + + # Determine weight dtype for view operation if needed + weight_dtype = w3_w1_weight.dtype + if use_w4_group_scaling and weight_dtype != torch.quint4x2: + weight_dtype = torch.quint4x2 + + # Validate that tactics have been finalized + if self.gemm_tactics is None or len(self.gemm_tactics) == 0: + raise RuntimeError( + "GEMM tactics have not been finalized. " + "Call finalize_tactic() before compute_moe() or use run_moe() instead." + ) + + if self.moe_runner is None: + raise RuntimeError( + "MoERunner has not been initialized. " + "Call finalize_tactic() before compute_moe() or use run_moe() instead." 
+ ) + + # Select the appropriate run method based on latency mode + run_moe = self.moe_runner.fused_moe_runner.run_moe_min_latency if min_latency_mode else self.moe_runner.fused_moe_runner.run_moe + + # Get unpadded_hidden_size from module if available, otherwise use hidden_size + # For now it is the user's responsibility to set unpadded_hidden_size. + # DeepGemmFusedMoE and WideEPMoE both have unpadded_hidden_size. + unpadded_hidden_size = getattr(module, 'unpadded_hidden_size', + x.shape[1]) + + # Run the actual MoE computation + output = run_moe( + x, + token_selected_slots, + token_final_scales, + w3_w1_weight.view(weight_dtype), + w3_w1_bias, + w2_weight.view(weight_dtype), + w2_bias, + quant_scales, + input_sf, + swizzled_input_sf, + swiglu_alpha, + swiglu_beta, + swiglu_limit, + tp_size, + tp_rank, + ep_size, + ep_rank, + cluster_size, + cluster_rank, + enable_alltoall, + min_latency_mode, + self.gemm_tactics, + unpadded_hidden_size, + ) + + # Return output based on latency mode + return output if min_latency_mode else [output] + + def run_moe( + self, + module: 'MoE', + # Input tensors + input: torch.Tensor, + token_selected_slots: torch.Tensor, + token_final_scales: torch.Tensor, + w3_w1_weight: torch.Tensor, + w3_w1_bias: Optional[torch.Tensor], + w2_weight: torch.Tensor, + w2_bias: Optional[torch.Tensor], + output_dtype: torch.dtype, + # Quantization parameters + quant_scales: List[torch.Tensor], + input_sf: Optional[torch.Tensor] = None, + swizzled_input_sf: bool = True, + # Performance tuning (only runtime-variable parameters) + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_num_tokens: Optional[int] = None, + tuner_top_k: Optional[int] = None, + **kwargs) -> torch.Tensor: + """ + Run the complete MoE computation pipeline for Cutlass op. + + This override handles the specific tuner_input logic needed for Cutlass. 
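+
+        When alltoall is enabled, only the first tuner_num_tokens rows of the
+        input are used as the tuner sample and tuner_top_k must be provided by
+        the caller; otherwise the full input is used and tuner_top_k is taken
+        from token_selected_slots.size(1).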
+ + Args: + module: MoE module containing configuration + input: Input tensor to the MoE layer + token_selected_slots: Selected expert slots for each token + token_final_scales: Final scaling factors for each token + w3_w1_weight: Concatenated weights for w3 and w1 projections + w3_w1_bias: Optional bias for w3/w1 projections + w2_weight: Weight for w2 projection + w2_bias: Optional bias for w2 projection + output_dtype: Desired output data type + quant_scales: Quantization scales for weights + input_sf: Optional input scale factors for quantization + swizzled_input_sf: Whether input scale factors are swizzled + min_latency_mode: Use minimum latency optimizations + use_fused_finalize: Use fused finalization + tuner_num_tokens: Number of tokens for tuner input + tuner_top_k: Top-k value for tuning + + Returns: + Computed MoE output tensor + """ + # Extract enable_alltoall from module to determine tuner_input logic + enable_alltoall = module.enable_alltoall + + # Compute tuner_input per fused_moe logic + if enable_alltoall: + assert tuner_num_tokens is not None + assert tuner_top_k is not None + tuner_input = input[:tuner_num_tokens] + else: + assert tuner_num_tokens is None + assert tuner_top_k is None + tuner_input = input + tuner_top_k = token_selected_slots.size(1) + + self.finalize_tactic(module, tuner_input, output_dtype, + min_latency_mode, use_fused_finalize, tuner_top_k) + + # Call compute_moe with module + return self.compute_moe(module=module, + x=input, + token_selected_slots=token_selected_slots, + token_final_scales=token_final_scales, + w3_w1_weight=w3_w1_weight, + w3_w1_bias=w3_w1_bias, + w2_weight=w2_weight, + w2_bias=w2_bias, + output_dtype=output_dtype, + quant_scales=quant_scales, + input_sf=input_sf, + swizzled_input_sf=swizzled_input_sf, + min_latency_mode=min_latency_mode, + use_fused_finalize=use_fused_finalize, + tuner_num_tokens=tuner_num_tokens, + tuner_top_k=tuner_top_k, + **kwargs) diff --git a/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_deepgemm.py new file mode 100644 index 00000000000..c191f74cbdf --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_deepgemm.py @@ -0,0 +1,306 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +DeepGemm-based MoE op implementation for GB200 block FP8. 
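+
+Activations are quantized per 128-element group and the expert computation is
+performed with grouped block-scaled FP8 GEMMs (deepgemm_fp8_group_blockwise_gemm).
+MoEOpSelector picks this op on SM100 (Blackwell) when the module uses
+DeepSeek-style FP8 block scales.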
+""" + +from typing import TYPE_CHECKING, Dict, List, Optional + +import torch + +from .moe_op import MoEOp + +if TYPE_CHECKING: + from ..interface import MoE + + +class DeepGemmMoEOp(MoEOp): + """DeepGemm-based MoE op for GB200 block FP8.""" + + def __init__(self): + """Initialize DeepGemm op.""" + super().__init__() + import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils + self.fp8_utils = fp8_utils + + from ..fused_moe_deepgemm import deepgemm_fp8_group_blockwise_gemm + self.deepgemm_fp8_group_blockwise_gemm = deepgemm_fp8_group_blockwise_gemm + + def finalize_tactic( + self, + module: 'MoE', + tuner_input: torch.Tensor, + output_dtype: torch.dtype, + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_top_k: Optional[int] = None, + ) -> None: + """ + No-op for DeepGemm op as it doesn't require tactic profiling. + + Args: + module: The MoE module + tuner_input: Input tensor for tuning + output_dtype: Output dtype + min_latency_mode: Whether to use min-latency mode + use_fused_finalize: Whether to use fused finalize + tuner_top_k: Top-k value for tuning + """ + + def _get_deepgemm_workspace(self, module: 'MoE', m_max: int, + group_size: int) -> Dict[str, torch.Tensor]: + """ + Get workspace for DeepGemm op operations. + + Args: + module: The MoE module containing configuration + m_max: Maximum number of tokens (aligned) + group_size: Group size for quantization + + Returns: + Dictionary containing workspace tensors + """ + import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils + + # Get dimensions from module + hidden_size = module.hidden_size + intermediate_size = module.intermediate_size + expert_size_per_partition = module.expert_size_per_partition + + # Calculate aligned dimensions + m_padded = fp8_utils.align(m_max, 4) + fp8_dim = max(hidden_size, intermediate_size) + scale_k = fp8_utils.ceil_div(fp8_dim, group_size) + scale_k_padded = fp8_utils.align(scale_k, 4) + + # Allocate workspace tensors + workspace = {} + + # Workspace for FP8 activations + workspace["workspace_0"] = torch.empty( + (expert_size_per_partition * m_max * fp8_dim), + dtype=torch.float8_e4m3fn, + device='cuda') + + # Workspace for intermediate results + workspace["workspace_1"] = torch.empty( + (expert_size_per_partition * m_max * + max(intermediate_size * 2, hidden_size)), + dtype=torch.bfloat16, + device='cuda') + + # Workspace for scaling factors + workspace["workspace_sf"] = torch.empty( + expert_size_per_partition * (scale_k_padded // 4) * m_padded, + dtype=torch.int32, + device='cuda') + + return workspace + + def compute_moe( + self, + module: 'MoE', + # Input tensors + x: torch.Tensor, + token_selected_slots: torch.Tensor, + token_final_scales: Optional[torch.Tensor], + # Weight tensors + w3_w1_weight: torch.Tensor, + w3_w1_bias: Optional[torch.Tensor], + w2_weight: torch.Tensor, + w2_bias: Optional[torch.Tensor], + # Output configuration + output_dtype: torch.dtype, + # Quantization parameters + quant_scales: List[torch.Tensor], + input_sf: Optional[torch.Tensor] = None, + swizzled_input_sf: bool = True, + # Performance tuning (only runtime-variable parameters) + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_num_tokens: Optional[int] = None, + tuner_top_k: Optional[int] = None, + **kwargs) -> torch.Tensor: + """ + Compute MoE using DeepGemm op with block FP8 quantization. + + Note: This assumes the data has already been gathered/alltoall'd + by the WideEP forward_chunk method. 
+ """ + + # Import necessary functions for DeepGemm + from ..fused_moe_deepgemm import (masked_index_copy_group_quant_fp8, + preprocess_after_permute, set_strides, + triton_masked_index_gather) + + # Extract parameters from module + tp_size = module.tp_size + tp_rank = module.tp_rank + ep_size = module.ep_size + ep_rank = module.ep_rank + cluster_size = module.cluster_size + cluster_rank = module.cluster_rank + enable_alltoall = module.enable_alltoall + + # Not supported: min_latency_mode. Raise error if enabled. + if min_latency_mode: + raise NotImplementedError( + "DeepGemm op does not support min_latency_mode=True") + + # Get expert configuration from module + expert_size_per_partition = module.expert_size_per_partition + intermediate_size = module.intermediate_size + hidden_size = x.shape[1] + + # Permute the data for expert-parallel processing + ( + permuted_row_to_unpermuted_row_tensor, + permuted_token_selected_experts_tensor, + permuted_data_tensor, + expert_first_token_offset_tensor, + permuted_token_final_scales_tensor, + unpermuted_row_to_permuted_row_tensor, + ) = torch.ops.trtllm.moe_permute_op( + x, + token_selected_slots, + token_final_scales, + None, # w3_w1_weight + None, # w2_weight + None, # quant_scales + input_sf=input_sf, + num_experts_on_rank=expert_size_per_partition, + tp_size=tp_size, + tp_rank=tp_rank, + ep_size=ep_size, + ep_rank=ep_rank, + cluster_size=cluster_size, + cluster_rank=cluster_rank, + min_latency_mode=min_latency_mode, + use_fp8_block_scaling=True, # Always use block scaling for DeepGemm + ) + + if permuted_data_tensor.numel() == 0: + return torch.zeros_like(x) + + # Preprocess for masked operations + masked_m, token_to_expert_map = preprocess_after_permute( + expert_first_token_offset_tensor, permuted_data_tensor) + + expected_m = (token_selected_slots.numel() + expert_size_per_partition - + 1) // expert_size_per_partition + + # Get workspace for DeepGemm operations + m_max = self.fp8_utils.align(x.shape[0], 128) + workspace = self._get_deepgemm_workspace(module, m_max, 128) + + # Padding and quantization for first GEMM input + m_padded = self.fp8_utils.align(m_max, 4) + scale_k = self.fp8_utils.ceil_div(hidden_size, 128) + scale_k_padded = self.fp8_utils.align(scale_k, 4) + + act_input_fp8 = set_strides(workspace["workspace_0"], + expert_size_per_partition, m_max, + hidden_size) + act_input_sf = set_strides(workspace["workspace_sf"], + expert_size_per_partition, + scale_k_padded // 4, m_padded) + + # Quantize and copy input with masking + act_input_sf = masked_index_copy_group_quant_fp8( + act_input_fp8, + act_input_sf, + permuted_data_tensor, + expert_first_token_offset_tensor, + token_to_expert_map, + group_size=128) + + # First grouped GEMM (w3 and w1) + h1 = set_strides(workspace["workspace_1"], expert_size_per_partition, + m_max, intermediate_size * 2) + + self.deepgemm_fp8_group_blockwise_gemm( + d=h1, + a=act_input_fp8, + b=w3_w1_weight, + sfa=act_input_sf, + sfb=quant_scales[0] if quant_scales else None, + masked_m=masked_m, + expected_m=expected_m, + ) + + # SiLU activation and quantization for second GEMM + act_input_fp8 = set_strides(workspace["workspace_0"], + expert_size_per_partition, m_max, + intermediate_size) + + scale_k = self.fp8_utils.ceil_div(intermediate_size, 128) + scale_k_padded = self.fp8_utils.align(scale_k, 4) + act_input_sf = set_strides(workspace["workspace_sf"], + expert_size_per_partition, + scale_k_padded // 4, m_padded) + + act_input_sf = self.fp8_utils.silu_and_mul_masked_post_quant_fwd( + output=act_input_fp8, + 
output_scale=act_input_sf, + input=h1, + quant_group_size=128, + masked_m=masked_m, + scale_ue8m0=True) + + # Second grouped GEMM (w2) + h3 = set_strides(workspace["workspace_1"], expert_size_per_partition, + m_max, hidden_size) + + self.deepgemm_fp8_group_blockwise_gemm( + d=h3, + a=act_input_fp8, + b=w2_weight, + sfa=act_input_sf, + sfb=quant_scales[1] if quant_scales else None, + masked_m=masked_m, + expected_m=expected_m, + ) + + # Gather results back to original token order + triton_masked_index_gather(permuted_data_tensor, h3, + expert_first_token_offset_tensor, + token_to_expert_map) + + # Finalize and scale the output + # Get unpadded_hidden_size from module if available, otherwise use hidden_size + # For now it is the user's responsibility to set unpadded_hidden_size. + # DeepGemmFusedMoE and WideEPMoE both have unpadded_hidden_size. + unpadded_hidden_size = getattr(module, 'unpadded_hidden_size', + x.shape[1]) + + final_hidden_states = torch.ops.trtllm.moe_finalize_scale_op( + permuted_data_tensor, + None, # biases (w2_bias could be added here if needed) + token_final_scales, + unpermuted_row_to_permuted_row_tensor, + permuted_row_to_unpermuted_row_tensor, + token_selected_slots, + expert_first_token_offset_tensor, + enable_alltoall, + x.shape[0], # num_rows + x.shape[1], # hidden_size + unpadded_hidden_size, # unpadded_hidden_size (may be different from hidden_size if padding was applied) + module.routing_method.top_k if module else 1, # experts_per_token + expert_size_per_partition, # num_experts_per_node + tp_size, + tp_rank, + ep_size, + ep_rank, + ) + + return final_hidden_states if min_latency_mode else [ + final_hidden_states + ]