diff --git a/docs/source/deployment-guide/quick-start-recipe-for-deepseek-r1-on-trtllm.md b/docs/source/deployment-guide/quick-start-recipe-for-deepseek-r1-on-trtllm.md index cd42786b8fa..8d811c65471 100644 --- a/docs/source/deployment-guide/quick-start-recipe-for-deepseek-r1-on-trtllm.md +++ b/docs/source/deployment-guide/quick-start-recipe-for-deepseek-r1-on-trtllm.md @@ -30,7 +30,7 @@ There are multiple MOE backends inside TRT-LLM, not all of them supporting every | B200/GB200 EP<=8 | NVFP4 | CUTLASS, TRTLLM | | B200/GB200 EP<=8 | FP8 | DEEPGEMM | | GB200 NVL72 EP>8 | NVFP4 | WIDEEP | -| GB200 NVL72 EP>8 | FP8 | N/A (WIP) | +| GB200 NVL72 EP>8 | FP8 | WIDEEP without EPLB | The default moe backend is `CUTLASS`, so for the combination which is not supported by `CUTLASS`, one must set the `moe_config.backend` explicitly to run the model. diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py index a9fb53a3b85..86724a326e6 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py @@ -5,8 +5,9 @@ import torch from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo -from tensorrt_llm._utils import logger +from tensorrt_llm._utils import get_sm_version from tensorrt_llm.functional import AllReduceStrategy +from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from ...distributed import AllReduce, allgather, reducescatter @@ -16,7 +17,9 @@ from .deep_ep_utils import buffer_pool, deep_ep_installed from .interface import MoE from .moe_load_balancer import get_moe_load_balancer +from .ops import MoEOp, MoEOpSelector from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod, + DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm, FP8QDQFusedMoEMethod, MoEWeightLoadingMode, NVFP4CutlassFusedMoEMethod, UnquantizedFusedMoEMethod, WInt4AFP8FusedMoEMethod) @@ -90,6 +93,9 @@ def __init__( # If True, the router weight will be multiplied on the input rather than at the end of FC2 self.apply_router_weight_on_input = apply_router_weight_on_input + # Store original hidden size before any potential padding + self.unpadded_hidden_size = self.hidden_size + moe_load_balancer = get_moe_load_balancer() self.layer_load_balancer = None self.repeat_idx = 0 @@ -227,6 +233,9 @@ def __init__( self.enable_dummy_allreduce = os.environ.get( "TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1" + # MoE op will be lazily initialized when first accessed (see moe_op_impl property) + self._moe_op_impl = None + def _check_configs(self): assert self._weights_created @@ -316,7 +325,10 @@ def _get_quant_method(self): if self.quant_config.layer_quant_mode.has_fp8_qdq(): return FP8QDQFusedMoEMethod() elif self.quant_config.layer_quant_mode.has_fp8_block_scales(): - return DeepSeekFP8BlockScalesFusedMoEMethod() + if get_sm_version() == 100: + return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm() + else: + return DeepSeekFP8BlockScalesFusedMoEMethod() elif self.quant_config.layer_quant_mode.has_nvfp4(): return NVFP4CutlassFusedMoEMethod() elif self.quant_config.layer_quant_mode.is_int4_weight_only_per_group( @@ -339,6 +351,19 @@ def create_weights(self): self._weights_created = True self._check_configs() + @property + def moe_op_impl(self) -> MoEOp: + """ + Lazily initialize and return the MoE op. + + The op is selected based on hardware capabilities and quantization + configuration, which are only available after weights are created. 
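+
+        A minimal usage sketch (illustrative only; ``moe`` stands for a
+        constructed WideEPMoE instance):
+
+            moe.create_weights()
+            op = moe.moe_op_impl            # CutlassMoEOp or DeepGemmMoEOp
+            assert op is moe.moe_op_impl    # cached after the first access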
+ """ + if self._moe_op_impl is None: + assert self._weights_created, "Weights must be created before accessing moe_op" + self._moe_op_impl = MoEOpSelector.select_op(self) + return self._moe_op_impl + def dummy_allreduce(self): """ Debug function for eliminating imbalance during performance analysis. @@ -389,8 +414,9 @@ def forward_chunk( if self.layer_load_balancer and is_first_call: self.layer_load_balancer.start_wait_gpu_stage() - use_deepseek_fp8_block_scale = False - use_w4_group_scaling = False + if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.MNNVL: + pass + weight_dtype = self.w3_w1_weight.dtype token_selected_experts, token_final_scales = self.routing_method.apply( @@ -544,9 +570,8 @@ def forward_chunk( x_sf = x_sf.view((x_row, -1)) elif self.has_deepseek_fp8_block_scales: - use_deepseek_fp8_block_scale = True + pass elif self.has_w4afp8: - use_w4_group_scaling = True weight_dtype = torch.quint4x2 else: raise ValueError( @@ -569,12 +594,8 @@ def forward_chunk( sizes=None if use_dp_padding else all_rank_num_tokens) x_row = x.shape[0] - ep_size = self.ep_size - ep_rank = self.ep_rank w3_w1_weight = self.w3_w1_weight w2_weight = self.w2_weight - cluster_size = self.cluster_size - cluster_rank = self.cluster_rank quant_scales = self.quant_scales if self.alltoall_method_type == AlltoallMethodType.MNNVL: @@ -640,7 +661,8 @@ def forward_chunk( f"Not available alltoall method type: {self.alltoall_method_type!r}" ) - final_hidden_states = torch.ops.trtllm.fused_moe( + final_hidden_states = self.moe_op_impl.run_moe( + self, x, token_selected_slots, token_final_scales, @@ -652,17 +674,8 @@ def forward_chunk( quant_scales=quant_scales, input_sf=x_sf, swizzled_input_sf=False, - tp_size=self.tp_size, - tp_rank=self.tp_rank, - ep_size=ep_size, - ep_rank=ep_rank, - cluster_size=cluster_size, - cluster_rank=cluster_rank, - enable_alltoall=use_all_to_all, - use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, - use_w4_group_scaling=use_w4_group_scaling, min_latency_mode=False, - tune_max_num_tokens=self.tune_max_num_tokens, + use_fused_finalize=True, tuner_num_tokens=tuner_num_tokens, tuner_top_k=tuner_top_k, ) diff --git a/tensorrt_llm/_torch/modules/fused_moe/ops/__init__.py b/tensorrt_llm/_torch/modules/fused_moe/ops/__init__.py new file mode 100644 index 00000000000..3c6f3bc3a79 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/ops/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""MoE ops module for different computation implementations.""" + +from .moe_op import MoEOp, MoEOpSelector +from .moe_op_cutlass import CutlassMoEOp +from .moe_op_deepgemm import DeepGemmMoEOp + +__all__ = ['MoEOp', 'MoEOpSelector', 'CutlassMoEOp', 'DeepGemmMoEOp'] diff --git a/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op.py b/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op.py new file mode 100644 index 00000000000..629f9bf7952 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op.py @@ -0,0 +1,230 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +MoE Op abstraction for supporting different MoE computation implementations. +This module provides a unified interface for different MoE ops (Cutlass, DeepGemm, etc.) +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List, Optional + +import torch + +from tensorrt_llm._utils import get_sm_version + +if TYPE_CHECKING: + from ..interface import MoE + + +class MoEOp(ABC): + """Abstract base class for MoE computation ops. + + This class provides a strategy pattern for different MoE computation implementations. + It is used by MoE modules (like WideEPMoE) to delegate the actual computation. + + Note: MoEOp is NOT a MoE module itself, but a computation strategy. + The actual MoE module (e.g., WideEPMoE) inherits from MoE and uses MoEOp + for the computation implementation. + """ + + # Op-specific abstract methods + @abstractmethod + def finalize_tactic( + self, + module: 'MoE', + tuner_input: torch.Tensor, + output_dtype: torch.dtype, + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_top_k: Optional[int] = None, + ) -> None: + """ + Finalize tactics for the MoE computation. + For Cutlass op, this includes profiling and tactic selection. + For DeepGemm op, this can be a no-op. 
+ + Args: + module: The MoE module containing MoE configurations + tuner_input: Real input used for tuning (same shape/layout as non-alltoall) + output_dtype: Output dtype for tuner run + min_latency_mode: Whether to profile for min-latency path + use_fused_finalize: Whether to use fused finalize + tuner_top_k: Top-k value for tuning (Cutlass specific) + """ + + @abstractmethod + def compute_moe( + self, + module: 'MoE', + # Input tensors + x: torch.Tensor, + token_selected_slots: torch.Tensor, + token_final_scales: Optional[torch.Tensor], + # Weight tensors + w3_w1_weight: torch.Tensor, + w3_w1_bias: Optional[torch.Tensor], + w2_weight: torch.Tensor, + w2_bias: Optional[torch.Tensor], + # Output configuration + output_dtype: torch.dtype, + # Quantization parameters + quant_scales: List[torch.Tensor], + input_sf: Optional[torch.Tensor] = None, + swizzled_input_sf: bool = True, + # Performance tuning (only runtime-variable parameters) + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_num_tokens: Optional[int] = None, + tuner_top_k: Optional[int] = None, + **kwargs) -> torch.Tensor: + """ + Perform the actual MoE computation. + + Configuration parameters (tp_size, ep_size, swiglu params, etc.) are + automatically extracted from the module parameter. + + Args: + module: MoE module containing configuration and parameters. + x: Input tensor + token_selected_slots: Selected expert slots + token_final_scales: Scaling factors + w3_w1_weight: Fused gate and up projection weights + w3_w1_bias: Optional bias + w2_weight: Down projection weights + w2_bias: Optional bias + output_dtype: Output data type + quant_scales: Quantization scales + input_sf: Input scaling factor + swizzled_input_sf: Whether input_sf is swizzled + min_latency_mode: Use minimum latency optimizations + use_fused_finalize: Use fused finalization + tuner_num_tokens: Number of tokens for tuning + tuner_top_k: Top-k value for tuning + + Returns: + Computed MoE output tensor + """ + + def run_moe( + self, + module: 'MoE', + # Input tensors + input: torch.Tensor, + token_selected_slots: torch.Tensor, + token_final_scales: torch.Tensor, + w3_w1_weight: torch.Tensor, + w3_w1_bias: Optional[torch.Tensor], + w2_weight: torch.Tensor, + w2_bias: Optional[torch.Tensor], + output_dtype: torch.dtype, + # Quantization parameters + quant_scales: List[torch.Tensor], + input_sf: Optional[torch.Tensor] = None, + swizzled_input_sf: bool = True, + # Performance tuning (only runtime-variable parameters) + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_num_tokens: Optional[int] = None, + tuner_top_k: Optional[int] = None, + **kwargs) -> torch.Tensor: + """ + Run the complete MoE computation pipeline. + + Configuration parameters are automatically extracted from the module. 
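+
+        This is a thin wrapper; it is roughly equivalent to (sketch):
+
+            self.finalize_tactic(module, input, output_dtype, min_latency_mode,
+                                 use_fused_finalize, tuner_top_k)
+            return self.compute_moe(module=module, x=input, ...)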
+ + Args: + module: MoE module containing configuration + input: Input tensor to the MoE layer + token_selected_slots: Selected expert slots for each token + token_final_scales: Final scaling factors for each token + w3_w1_weight: Concatenated weights for w3 and w1 projections + w3_w1_bias: Optional bias for w3/w1 projections + w2_weight: Weight for w2 projection + w2_bias: Optional bias for w2 projection + output_dtype: Desired output data type + quant_scales: Quantization scales for weights + input_sf: Optional input scale factors for quantization + swizzled_input_sf: Whether input scale factors are swizzled + min_latency_mode: Use minimum latency optimizations + use_fused_finalize: Use fused finalization + tuner_num_tokens: Number of tokens for tuner input + tuner_top_k: Top-k value for tuning + + Returns: + Computed MoE output tensor + """ + self.finalize_tactic(module, input, output_dtype, min_latency_mode, + use_fused_finalize, tuner_top_k) + + # Call compute_moe with module + return self.compute_moe(module=module, + x=input, + token_selected_slots=token_selected_slots, + token_final_scales=token_final_scales, + w3_w1_weight=w3_w1_weight, + w3_w1_bias=w3_w1_bias, + w2_weight=w2_weight, + w2_bias=w2_bias, + output_dtype=output_dtype, + quant_scales=quant_scales, + input_sf=input_sf, + swizzled_input_sf=swizzled_input_sf, + min_latency_mode=min_latency_mode, + use_fused_finalize=use_fused_finalize, + tuner_num_tokens=tuner_num_tokens, + tuner_top_k=tuner_top_k, + **kwargs) + + +class MoEOpSelector: + """ + Utility class for selecting the appropriate MoE op based on + hardware capabilities and quantization configuration. + + This class implements the strategy pattern for op selection, + choosing between Cutlass and DeepGemm implementations based on: + - Hardware capabilities (SM version) + - Quantization configuration (block FP8 support) + """ + + @staticmethod + def select_op(module: 'MoE') -> MoEOp: + """ + Select the appropriate MoE op based on module configuration. + + Selection criteria: + - Blackwell (SM100) with block FP8 quantization -> DeepGemm op + - All other configurations -> Cutlass op + + Args: + module: The MoE module containing configuration information + + Returns: + MoEOp: Selected op instance (CutlassMoEOp or DeepGemmMoEOp) + + Example: + >>> op = MoEOpSelector.select_op(moe_module) + >>> output = op.run_moe(input, ...) + """ + from .moe_op_cutlass import CutlassMoEOp + from .moe_op_deepgemm import DeepGemmMoEOp + + # Check if we should use DeepGemm op + # Blackwell has SM version 100 + is_blackwell = get_sm_version() == 100 + has_block_fp8 = module.has_deepseek_fp8_block_scales + + if is_blackwell and has_block_fp8: + # Use DeepGemm op for Blackwell with block FP8 + return DeepGemmMoEOp() + else: + # Use Cutlass op for all other cases + return CutlassMoEOp() diff --git a/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_cutlass.py new file mode 100644 index 00000000000..45f27dc8235 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_cutlass.py @@ -0,0 +1,306 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Cutlass-based MoE op implementation. +""" + +from typing import TYPE_CHECKING, List, Optional + +import torch + +from .moe_op import MoEOp + +if TYPE_CHECKING: + from ..interface import MoE + + +class CutlassMoEOp(MoEOp): + """Cutlass-based MoE op using torch.ops.trtllm.fused_moe.""" + + def __init__(self): + """Initialize the Cutlass op.""" + super().__init__() + self.moe_runner = None + self.gemm_tactics = None + + def finalize_tactic( + self, + module: 'MoE', + tuner_input: torch.Tensor, + output_dtype: torch.dtype, + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_top_k: Optional[int] = None, + ) -> None: + """ + Finalize tactics for Cutlass MoE by profiling and selecting optimal GEMM tactics. + """ + + # Import necessary modules for profiling + from ....custom_ops.torch_custom_ops import AutoTuner, MoERunner + + # Use real tuner_input rather than dummy input + assert tuner_input is not None, "tuner_input must be provided to finalize_tactic" + if tuner_top_k is None: + tuner_top_k = getattr(module.routing_method, 'experts_per_token', 1) + + # Determine view dtype for weights to match runtime quantization layout + weight_view_dtype = module.w3_w1_weight.dtype + if getattr(module, 'has_w4afp8', False): + weight_view_dtype = torch.quint4x2 + elif module.has_w4a16_mxfp4: + weight_view_dtype = torch.uint8 + + # Create MoERunner for profiling + if self.moe_runner is None: + self.moe_runner = MoERunner( + x_dtype=tuner_input.dtype, + weight_dtype=weight_view_dtype, + output_dtype=output_dtype, + top_k=tuner_top_k, + tp_size=module.tp_size, + tp_rank=module.tp_rank, + ep_size=module.ep_size, + ep_rank=module.ep_rank, + cluster_size=module.cluster_size, + cluster_rank=module.cluster_rank, + use_deepseek_fp8_block_scale=module. 
+ has_deepseek_fp8_block_scales, + use_w4_group_scaling=getattr(module, 'has_w4afp8', False), + use_int8_woq_per_channel=getattr(module, + 'has_int8_woq_per_channel', + False), + use_mxfp8_act_scaling=getattr(module, 'has_mxfp8_act_scaling', + False), + min_latency_mode=min_latency_mode, + use_fused_finalize=use_fused_finalize, + ) + + # Set tuning configuration + MoERunner.tuning_config.tune_max_num_tokens = getattr( + module, 'tune_max_num_tokens', 8192) + + # Get AutoTuner for tactic selection + tuner = AutoTuner.get() + + # Profile and select tactics (GEMM1) + _, gemm_tactic_1 = tuner.choose_one( + "trtllm::fused_moe::gemm1", + [self.moe_runner], + MoERunner.tuning_config, + [ + tuner_input, + module.w3_w1_weight.view(weight_view_dtype), + getattr(module, 'w3_w1_bias', None), + module.w2_weight.view(weight_view_dtype), + getattr(module, 'w2_bias', None), + ], + gemm_idx=1, + ) + + # Profile and select tactics (GEMM2) + _, gemm_tactic_2 = tuner.choose_one( + "trtllm::fused_moe::gemm2", + [self.moe_runner], + MoERunner.tuning_config, + [ + tuner_input, + module.w3_w1_weight.view(weight_view_dtype), + getattr(module, 'w3_w1_bias', None), + module.w2_weight.view(weight_view_dtype), + getattr(module, 'w2_bias', None), + ], + gemm_idx=2, + ) + + # Store selected tactics + self.gemm_tactics = [gemm_tactic_1, gemm_tactic_2] + + def compute_moe( + self, + module: 'MoE', # Now required as first parameter + # Input tensors + x: torch.Tensor, + token_selected_slots: torch.Tensor, + token_final_scales: Optional[torch.Tensor], + # Weight tensors + w3_w1_weight: torch.Tensor, + w3_w1_bias: Optional[torch.Tensor], + w2_weight: torch.Tensor, + w2_bias: Optional[torch.Tensor], + # Output configuration + output_dtype: torch.dtype, + # Quantization parameters + quant_scales: List[torch.Tensor], + input_sf: Optional[torch.Tensor] = None, + swizzled_input_sf: bool = True, + # Performance tuning (only runtime-variable parameters) + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_num_tokens: Optional[int] = None, + tuner_top_k: Optional[int] = None, + **kwargs) -> torch.Tensor: + """ + Compute MoE using Cutlass op with MoERunner. + """ + # Extract parameters from module + tp_size = module.tp_size + tp_rank = module.tp_rank + ep_size = module.ep_size + ep_rank = module.ep_rank + cluster_size = module.cluster_size + cluster_rank = module.cluster_rank + enable_alltoall = module.enable_alltoall + swiglu_alpha = module.swiglu_alpha + swiglu_beta = module.swiglu_beta + swiglu_limit = module.swiglu_limit + use_w4_group_scaling = getattr(module, 'has_w4afp8', False) + + # Determine weight dtype for view operation if needed + weight_dtype = w3_w1_weight.dtype + if use_w4_group_scaling and weight_dtype != torch.quint4x2: + weight_dtype = torch.quint4x2 + + # Validate that tactics have been finalized + if self.gemm_tactics is None or len(self.gemm_tactics) == 0: + raise RuntimeError( + "GEMM tactics have not been finalized. " + "Call finalize_tactic() before compute_moe() or use run_moe() instead." + ) + + if self.moe_runner is None: + raise RuntimeError( + "MoERunner has not been initialized. " + "Call finalize_tactic() before compute_moe() or use run_moe() instead." 
+ ) + + # Select the appropriate run method based on latency mode + run_moe = self.moe_runner.fused_moe_runner.run_moe_min_latency if min_latency_mode else self.moe_runner.fused_moe_runner.run_moe + + # Get unpadded_hidden_size from module if available, otherwise use hidden_size + # For now it is the user's responsibility to set unpadded_hidden_size. + # DeepGemmFusedMoE and WideEPMoE both have unpadded_hidden_size. + unpadded_hidden_size = getattr(module, 'unpadded_hidden_size', + x.shape[1]) + + # Run the actual MoE computation + output = run_moe( + x, + token_selected_slots, + token_final_scales, + w3_w1_weight.view(weight_dtype), + w3_w1_bias, + w2_weight.view(weight_dtype), + w2_bias, + quant_scales, + input_sf, + swizzled_input_sf, + swiglu_alpha, + swiglu_beta, + swiglu_limit, + tp_size, + tp_rank, + ep_size, + ep_rank, + cluster_size, + cluster_rank, + enable_alltoall, + min_latency_mode, + self.gemm_tactics, + unpadded_hidden_size, + ) + + # Return output based on latency mode + return output if min_latency_mode else [output] + + def run_moe( + self, + module: 'MoE', + # Input tensors + input: torch.Tensor, + token_selected_slots: torch.Tensor, + token_final_scales: torch.Tensor, + w3_w1_weight: torch.Tensor, + w3_w1_bias: Optional[torch.Tensor], + w2_weight: torch.Tensor, + w2_bias: Optional[torch.Tensor], + output_dtype: torch.dtype, + # Quantization parameters + quant_scales: List[torch.Tensor], + input_sf: Optional[torch.Tensor] = None, + swizzled_input_sf: bool = True, + # Performance tuning (only runtime-variable parameters) + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_num_tokens: Optional[int] = None, + tuner_top_k: Optional[int] = None, + **kwargs) -> torch.Tensor: + """ + Run the complete MoE computation pipeline for Cutlass op. + + This override handles the specific tuner_input logic needed for Cutlass. 
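+
+        Tuner-input selection, as implemented below (sketch):
+
+            if module.enable_alltoall:
+                tuner_input = input[:tuner_num_tokens]   # both tuner args required
+            else:
+                tuner_input = input
+                tuner_top_k = token_selected_slots.size(1)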
+ + Args: + module: MoE module containing configuration + input: Input tensor to the MoE layer + token_selected_slots: Selected expert slots for each token + token_final_scales: Final scaling factors for each token + w3_w1_weight: Concatenated weights for w3 and w1 projections + w3_w1_bias: Optional bias for w3/w1 projections + w2_weight: Weight for w2 projection + w2_bias: Optional bias for w2 projection + output_dtype: Desired output data type + quant_scales: Quantization scales for weights + input_sf: Optional input scale factors for quantization + swizzled_input_sf: Whether input scale factors are swizzled + min_latency_mode: Use minimum latency optimizations + use_fused_finalize: Use fused finalization + tuner_num_tokens: Number of tokens for tuner input + tuner_top_k: Top-k value for tuning + + Returns: + Computed MoE output tensor + """ + # Extract enable_alltoall from module to determine tuner_input logic + enable_alltoall = module.enable_alltoall + + # Compute tuner_input per fused_moe logic + if enable_alltoall: + assert tuner_num_tokens is not None + assert tuner_top_k is not None + tuner_input = input[:tuner_num_tokens] + else: + assert tuner_num_tokens is None + assert tuner_top_k is None + tuner_input = input + tuner_top_k = token_selected_slots.size(1) + + self.finalize_tactic(module, tuner_input, output_dtype, + min_latency_mode, use_fused_finalize, tuner_top_k) + + # Call compute_moe with module + return self.compute_moe(module=module, + x=input, + token_selected_slots=token_selected_slots, + token_final_scales=token_final_scales, + w3_w1_weight=w3_w1_weight, + w3_w1_bias=w3_w1_bias, + w2_weight=w2_weight, + w2_bias=w2_bias, + output_dtype=output_dtype, + quant_scales=quant_scales, + input_sf=input_sf, + swizzled_input_sf=swizzled_input_sf, + min_latency_mode=min_latency_mode, + use_fused_finalize=use_fused_finalize, + tuner_num_tokens=tuner_num_tokens, + tuner_top_k=tuner_top_k, + **kwargs) diff --git a/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_deepgemm.py new file mode 100644 index 00000000000..c191f74cbdf --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_deepgemm.py @@ -0,0 +1,306 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +DeepGemm-based MoE op implementation for GB200 block FP8. 
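+
+This op is chosen by MoEOpSelector (see ops/moe_op.py); the selection logic
+is roughly:
+
+    if get_sm_version() == 100 and module.has_deepseek_fp8_block_scales:
+        op = DeepGemmMoEOp()   # Blackwell + DeepSeek FP8 block scales
+    else:
+        op = CutlassMoEOp()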
+""" + +from typing import TYPE_CHECKING, Dict, List, Optional + +import torch + +from .moe_op import MoEOp + +if TYPE_CHECKING: + from ..interface import MoE + + +class DeepGemmMoEOp(MoEOp): + """DeepGemm-based MoE op for GB200 block FP8.""" + + def __init__(self): + """Initialize DeepGemm op.""" + super().__init__() + import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils + self.fp8_utils = fp8_utils + + from ..fused_moe_deepgemm import deepgemm_fp8_group_blockwise_gemm + self.deepgemm_fp8_group_blockwise_gemm = deepgemm_fp8_group_blockwise_gemm + + def finalize_tactic( + self, + module: 'MoE', + tuner_input: torch.Tensor, + output_dtype: torch.dtype, + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_top_k: Optional[int] = None, + ) -> None: + """ + No-op for DeepGemm op as it doesn't require tactic profiling. + + Args: + module: The MoE module + tuner_input: Input tensor for tuning + output_dtype: Output dtype + min_latency_mode: Whether to use min-latency mode + use_fused_finalize: Whether to use fused finalize + tuner_top_k: Top-k value for tuning + """ + + def _get_deepgemm_workspace(self, module: 'MoE', m_max: int, + group_size: int) -> Dict[str, torch.Tensor]: + """ + Get workspace for DeepGemm op operations. + + Args: + module: The MoE module containing configuration + m_max: Maximum number of tokens (aligned) + group_size: Group size for quantization + + Returns: + Dictionary containing workspace tensors + """ + import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils + + # Get dimensions from module + hidden_size = module.hidden_size + intermediate_size = module.intermediate_size + expert_size_per_partition = module.expert_size_per_partition + + # Calculate aligned dimensions + m_padded = fp8_utils.align(m_max, 4) + fp8_dim = max(hidden_size, intermediate_size) + scale_k = fp8_utils.ceil_div(fp8_dim, group_size) + scale_k_padded = fp8_utils.align(scale_k, 4) + + # Allocate workspace tensors + workspace = {} + + # Workspace for FP8 activations + workspace["workspace_0"] = torch.empty( + (expert_size_per_partition * m_max * fp8_dim), + dtype=torch.float8_e4m3fn, + device='cuda') + + # Workspace for intermediate results + workspace["workspace_1"] = torch.empty( + (expert_size_per_partition * m_max * + max(intermediate_size * 2, hidden_size)), + dtype=torch.bfloat16, + device='cuda') + + # Workspace for scaling factors + workspace["workspace_sf"] = torch.empty( + expert_size_per_partition * (scale_k_padded // 4) * m_padded, + dtype=torch.int32, + device='cuda') + + return workspace + + def compute_moe( + self, + module: 'MoE', + # Input tensors + x: torch.Tensor, + token_selected_slots: torch.Tensor, + token_final_scales: Optional[torch.Tensor], + # Weight tensors + w3_w1_weight: torch.Tensor, + w3_w1_bias: Optional[torch.Tensor], + w2_weight: torch.Tensor, + w2_bias: Optional[torch.Tensor], + # Output configuration + output_dtype: torch.dtype, + # Quantization parameters + quant_scales: List[torch.Tensor], + input_sf: Optional[torch.Tensor] = None, + swizzled_input_sf: bool = True, + # Performance tuning (only runtime-variable parameters) + min_latency_mode: bool = False, + use_fused_finalize: bool = True, + tuner_num_tokens: Optional[int] = None, + tuner_top_k: Optional[int] = None, + **kwargs) -> torch.Tensor: + """ + Compute MoE using DeepGemm op with block FP8 quantization. + + Note: This assumes the data has already been gathered/alltoall'd + by the WideEP forward_chunk method. 
+ """ + + # Import necessary functions for DeepGemm + from ..fused_moe_deepgemm import (masked_index_copy_group_quant_fp8, + preprocess_after_permute, set_strides, + triton_masked_index_gather) + + # Extract parameters from module + tp_size = module.tp_size + tp_rank = module.tp_rank + ep_size = module.ep_size + ep_rank = module.ep_rank + cluster_size = module.cluster_size + cluster_rank = module.cluster_rank + enable_alltoall = module.enable_alltoall + + # Not supported: min_latency_mode. Raise error if enabled. + if min_latency_mode: + raise NotImplementedError( + "DeepGemm op does not support min_latency_mode=True") + + # Get expert configuration from module + expert_size_per_partition = module.expert_size_per_partition + intermediate_size = module.intermediate_size + hidden_size = x.shape[1] + + # Permute the data for expert-parallel processing + ( + permuted_row_to_unpermuted_row_tensor, + permuted_token_selected_experts_tensor, + permuted_data_tensor, + expert_first_token_offset_tensor, + permuted_token_final_scales_tensor, + unpermuted_row_to_permuted_row_tensor, + ) = torch.ops.trtllm.moe_permute_op( + x, + token_selected_slots, + token_final_scales, + None, # w3_w1_weight + None, # w2_weight + None, # quant_scales + input_sf=input_sf, + num_experts_on_rank=expert_size_per_partition, + tp_size=tp_size, + tp_rank=tp_rank, + ep_size=ep_size, + ep_rank=ep_rank, + cluster_size=cluster_size, + cluster_rank=cluster_rank, + min_latency_mode=min_latency_mode, + use_fp8_block_scaling=True, # Always use block scaling for DeepGemm + ) + + if permuted_data_tensor.numel() == 0: + return torch.zeros_like(x) + + # Preprocess for masked operations + masked_m, token_to_expert_map = preprocess_after_permute( + expert_first_token_offset_tensor, permuted_data_tensor) + + expected_m = (token_selected_slots.numel() + expert_size_per_partition - + 1) // expert_size_per_partition + + # Get workspace for DeepGemm operations + m_max = self.fp8_utils.align(x.shape[0], 128) + workspace = self._get_deepgemm_workspace(module, m_max, 128) + + # Padding and quantization for first GEMM input + m_padded = self.fp8_utils.align(m_max, 4) + scale_k = self.fp8_utils.ceil_div(hidden_size, 128) + scale_k_padded = self.fp8_utils.align(scale_k, 4) + + act_input_fp8 = set_strides(workspace["workspace_0"], + expert_size_per_partition, m_max, + hidden_size) + act_input_sf = set_strides(workspace["workspace_sf"], + expert_size_per_partition, + scale_k_padded // 4, m_padded) + + # Quantize and copy input with masking + act_input_sf = masked_index_copy_group_quant_fp8( + act_input_fp8, + act_input_sf, + permuted_data_tensor, + expert_first_token_offset_tensor, + token_to_expert_map, + group_size=128) + + # First grouped GEMM (w3 and w1) + h1 = set_strides(workspace["workspace_1"], expert_size_per_partition, + m_max, intermediate_size * 2) + + self.deepgemm_fp8_group_blockwise_gemm( + d=h1, + a=act_input_fp8, + b=w3_w1_weight, + sfa=act_input_sf, + sfb=quant_scales[0] if quant_scales else None, + masked_m=masked_m, + expected_m=expected_m, + ) + + # SiLU activation and quantization for second GEMM + act_input_fp8 = set_strides(workspace["workspace_0"], + expert_size_per_partition, m_max, + intermediate_size) + + scale_k = self.fp8_utils.ceil_div(intermediate_size, 128) + scale_k_padded = self.fp8_utils.align(scale_k, 4) + act_input_sf = set_strides(workspace["workspace_sf"], + expert_size_per_partition, + scale_k_padded // 4, m_padded) + + act_input_sf = self.fp8_utils.silu_and_mul_masked_post_quant_fwd( + output=act_input_fp8, + 
output_scale=act_input_sf, + input=h1, + quant_group_size=128, + masked_m=masked_m, + scale_ue8m0=True) + + # Second grouped GEMM (w2) + h3 = set_strides(workspace["workspace_1"], expert_size_per_partition, + m_max, hidden_size) + + self.deepgemm_fp8_group_blockwise_gemm( + d=h3, + a=act_input_fp8, + b=w2_weight, + sfa=act_input_sf, + sfb=quant_scales[1] if quant_scales else None, + masked_m=masked_m, + expected_m=expected_m, + ) + + # Gather results back to original token order + triton_masked_index_gather(permuted_data_tensor, h3, + expert_first_token_offset_tensor, + token_to_expert_map) + + # Finalize and scale the output + # Get unpadded_hidden_size from module if available, otherwise use hidden_size + # For now it is the user's responsibility to set unpadded_hidden_size. + # DeepGemmFusedMoE and WideEPMoE both have unpadded_hidden_size. + unpadded_hidden_size = getattr(module, 'unpadded_hidden_size', + x.shape[1]) + + final_hidden_states = torch.ops.trtllm.moe_finalize_scale_op( + permuted_data_tensor, + None, # biases (w2_bias could be added here if needed) + token_final_scales, + unpermuted_row_to_permuted_row_tensor, + permuted_row_to_unpermuted_row_tensor, + token_selected_slots, + expert_first_token_offset_tensor, + enable_alltoall, + x.shape[0], # num_rows + x.shape[1], # hidden_size + unpadded_hidden_size, # unpadded_hidden_size (may be different from hidden_size if padding was applied) + module.routing_method.top_k if module else 1, # experts_per_token + expert_size_per_partition, # num_experts_per_node + tp_size, + tp_rank, + ep_size, + ep_rank, + ) + + return final_hidden_states if min_latency_mode else [ + final_hidden_states + ] diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py index 2158e5fef30..2461f85984d 100644 --- a/tests/unittest/_torch/modules/test_fused_moe.py +++ b/tests/unittest/_torch/modules/test_fused_moe.py @@ -638,6 +638,161 @@ def set_tensor_value_4(x, num_row, num_cols): x.copy_(repeated) +@skip_pre_blackwell +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="needs 4 GPUs to run this test") +@pytest.mark.parametrize( + "alltoall_method_type", + [AlltoallMethodType.MNNVL, AlltoallMethodType.NotEnabled], + ids=lambda s: s.name) +def test_fused_moe_fp8_blockwise_wide_ep(alltoall_method_type): + """Test WideEPMoE with FP8 block-wise quantization using DeepGemmFusedMoE as reference.""" + + world_size = 4 + dtype = torch.bfloat16 + # Reduce model size to avoid MPI int32 overflow + HIDDEN_SIZE = 768 + INTERMEDIATE_SIZE = 512 + NUM_EXPERTS = 16 + TOP_K = 2 + MAX_NUM_TOKENS = 256 + + # The MPI can not support FP8, so create weights on each rank + def per_rank_test_fused_moe_alltoall_fp8_blockwise(job_id): + routing_method = DefaultMoeRoutingMethod(top_k=TOP_K) + mapping = Mapping(world_size=world_size, + rank=mpi_rank(), + tp_size=world_size, + moe_ep_size=world_size, + moe_tp_size=1, + enable_attention_dp=True) + torch.cuda.set_device(mapping.rank) + # Use same seed for all ranks to ensure consistency + torch.manual_seed(0) + torch.cuda.manual_seed(0) + + # Generate test data locally on each rank + x_list = [] + m = MAX_NUM_TOKENS + while m >= 1: + x = torch.randn((m, HIDDEN_SIZE), dtype=dtype, device="cuda") + set_tensor_value_2(x, m, HIDDEN_SIZE) + x_list.append(x) + m //= 2 + + # Generate weights locally on each rank (same weights due to same seed) + weights = {} + for expert_id in range(NUM_EXPERTS): + w1_weight = torch.randn( + (INTERMEDIATE_SIZE, HIDDEN_SIZE), dtype=dtype, + 
device="cuda") / HIDDEN_SIZE + w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE), + dtype=dtype, + device="cuda") + w3_weight = torch.randn( + (INTERMEDIATE_SIZE, HIDDEN_SIZE), dtype=dtype, + device="cuda") / HIDDEN_SIZE + + set_tensor_value_3(w1_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE) + set_tensor_value_4(w2_weight, HIDDEN_SIZE, INTERMEDIATE_SIZE) + set_tensor_value_3(w3_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE) + + # FP8 block-wise quantization + w1_weight_fp8, w1_weight_scale = per_block_cast_to_fp8_e8m0( + w1_weight) + w1_weight_fp8 = w1_weight_fp8.view(torch.float8_e4m3fn).cuda() + + w2_weight_fp8, w2_weight_scale = per_block_cast_to_fp8_e8m0( + w2_weight) + w2_weight_fp8 = w2_weight_fp8.view(torch.float8_e4m3fn).cuda() + + w3_weight_fp8, w3_weight_scale = per_block_cast_to_fp8_e8m0( + w3_weight) + w3_weight_fp8 = w3_weight_fp8.view(torch.float8_e4m3fn).cuda() + + weights[f"{expert_id}.w1.weight"] = w1_weight_fp8 + weights[f"{expert_id}.w2.weight"] = w2_weight_fp8 + weights[f"{expert_id}.w3.weight"] = w3_weight_fp8 + weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale + weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale + weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale + weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale + weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale + weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale + + quant_config = QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES) + + # Test WideEPMoE with alltoall method + with mock.patch.object(WideEPMoE, + "select_alltoall_method_type", + return_value=alltoall_method_type): + alltoall_model = WideEPMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=True, + model_config=ModelConfig(mapping=mapping, + max_num_tokens=MAX_NUM_TOKENS, + quant_config=quant_config), + ) + alltoall_model.to("cuda") + alltoall_model.load_weights([weights]) + + # Use DeepGemmFusedMoE as reference + ref_model = DeepGemmFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=True, + model_config=ModelConfig(mapping=mapping, + max_num_tokens=MAX_NUM_TOKENS, + quant_config=quant_config), + ) + ref_model.to("cuda") + ref_model.load_weights([weights]) + + # Evaluate the outputs on variant sequence lengths + m = MAX_NUM_TOKENS + i = 0 + while m >= 1: + x = x_list[i] + i += 1 + router_logits = torch.randn((m, NUM_EXPERTS), + dtype=dtype, + device="cuda") + all_rank_num_tokens = [m] * mapping.world_size + with torch.inference_mode(): + output = alltoall_model.forward( + x, + router_logits, + all_rank_num_tokens=all_rank_num_tokens, + all_rank_max_num_tokens=m, + use_dp_padding=False) + ref_output = ref_model.forward( + x, + router_logits, + all_rank_num_tokens=all_rank_num_tokens, + all_rank_max_num_tokens=m, + use_dp_padding=False) + + # Evaluate outputs with relaxed tolerance for FP8 + # If WideEPMoE output has TOP_K dimension, reduce it to match DeepGemmFusedMoE + if output.dim() == 3 and output.shape[1] == TOP_K: + output = output.sum(dim=1) + torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1) + m //= 2 + + with MPIPoolExecutor(max_workers=world_size) as executor: + results = executor.map(per_rank_test_fused_moe_alltoall_fp8_blockwise, + range(world_size)) + for r in results: + assert r is None + + @skip_pre_blackwell @pytest.mark.parametrize( "dtype, 
num_experts, seq_len, hidden_size, RoutingMethodCls",