From 54dbf26fb41c91afb900abd7492718bfcbcf16c2 Mon Sep 17 00:00:00 2001
From: Yukun He <23156053+hyukn@users.noreply.github.com>
Date: Sat, 21 Jun 2025 13:19:23 +0000
Subject: [PATCH] [TRTLLM-6019] feat: Remove cutlass min latency code from AutoTuner.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
---
 .../_torch/custom_ops/torch_custom_ops.py | 46 +++++--------------
 1 file changed, 12 insertions(+), 34 deletions(-)

diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
index 9b8b2f059b2..e94ee6df448 100644
--- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -22,12 +22,9 @@ def bmm_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
 class MoERunner(TunableRunner):
     # avoid overhead of creating a new runner in forward pass
     runner_dict = dict()
-    # TODO: only profile for min_latency_mode = False due to the error in the moe_kernels
     tuning_config = TuningConfig(dynamic_tensor_specs=(
         DynamicTensorSpec(0, 0, get_last_power_of_2_num_tokens_buckets(8192),
-                          lambda x: min(last_positive_power_of_2(x), 8192)),
-        DynamicTensorSpec(3, 0, (0, ), lambda x: x),
-    ))
+                          lambda x: min(last_positive_power_of_2(x), 8192)), ))
 
     def __init__(
         self,
@@ -44,6 +41,7 @@ def __init__(
         enable_alltoall: bool,
         use_deepseek_fp8_block_scale: bool,
         use_w4a8_group_scaling: bool,
+        min_latency_mode: bool,
     ):
         self.x_dtype = x_dtype
         self.weight_dtype = weight_dtype
@@ -58,7 +56,7 @@ def __init__(
         self.enable_alltoall = enable_alltoall
         self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
         self.use_w4a8_group_scaling = use_w4a8_group_scaling
-
+        self.min_latency_mode = min_latency_mode
         instance_key = (x_dtype, weight_dtype, output_dtype,
                         use_deepseek_fp8_block_scale, use_w4a8_group_scaling)
 
@@ -74,22 +72,7 @@ def get_valid_tactics(
         inputs: List[torch.Tensor],
         profile: OptimizationProfile,
     ) -> List[int]:
-        x, _, _, min_latency_mode_tensor = inputs
-        min_latency_mode = min_latency_mode_tensor.size(0) == 1
-        m = x.shape[0]
-
-        # Only profile m <= 128 for min latency mode = True
-        # Profile all valid buckets for min latency mode = False
-        # TODO: min_latency_mode = True will cause the following error:
-        # Cannot profile configuration 4: Cutlass GEMM Tactic
-        # [TensorRT-LLM][ERROR] Assertion failed: Failed to initialize cutlass TMA WS grouped gemm.
-        # Should be fixed in the moe_kernels in the future.
-        invalid = (m > 128 and
-                   min_latency_mode) or (m <= 128 and min_latency_mode and
-                                         (not self.weight_dtype == torch.int64))
-
-        return [] if invalid else list(
-            range(self.fused_moe_runner.get_tactic_num()))
+        return range(self.fused_moe_runner.get_tactic_num())
 
     def forward(
         self,
@@ -98,8 +81,7 @@ def forward(
         tactic: int = -1,
         do_preparation: bool = False,
     ):
-        x, fc1_expert_weights, fc2_expert_weights, min_latency_mode_tensor = inputs
-        min_latency_mode = min_latency_mode_tensor.size(0) == 1
+        x, fc1_expert_weights, fc2_expert_weights = inputs
         # determine if we should use min latency mode according to the profiled seq len
         self.fused_moe_runner.run_gemm_profile(
             x,
@@ -113,7 +95,7 @@ def forward(
             self.cluster_size,
             self.cluster_rank,
             self.enable_alltoall,
-            min_latency_mode,
+            self.min_latency_mode,
             gemm_idx,
             tactic,
             do_preparation,
@@ -122,13 +104,11 @@ def forward(
     @classmethod
     @lru_cache(maxsize=None)
     def refine_tuning_config(cls, tune_max_num_tokens: int):
-        cls.tuning_config = TuningConfig(dynamic_tensor_specs=(
-            DynamicTensorSpec(
+        cls.tuning_config = TuningConfig(
+            dynamic_tensor_specs=(DynamicTensorSpec(
                 0, 0, get_last_power_of_2_num_tokens_buckets(
                     tune_max_num_tokens), lambda x: min(
-                        last_positive_power_of_2(x), tune_max_num_tokens)),
-            DynamicTensorSpec(3, 0, (0, ), lambda x: x),
-        ))
+                        last_positive_power_of_2(x), tune_max_num_tokens)), ))
 
 
 @torch.library.custom_op("trtllm::fused_moe", mutates_args=())
@@ -157,9 +137,6 @@ def fused_moe(
     tuner = AutoTuner.get()
     MoERunner.refine_tuning_config(tune_max_num_tokens)
 
-    # TODO: set min_latency_mode always to False due to the error in the moe_kernels
-    min_latency_tensor = torch.empty(0)
-
     # allocate workspace for profiling
     moe_runner = MoERunner(
         x_dtype=input.dtype,
@@ -175,13 +152,14 @@ def fused_moe(
         enable_alltoall=enable_alltoall,
         use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         use_w4a8_group_scaling=use_w4a8_group_scaling,
+        min_latency_mode=min_latency_mode,
     )
 
     _, gemm_tactic_1 = tuner.choose_one(
         "trtllm::fused_moe::gemm1",
         [moe_runner],
         MoERunner.tuning_config,
-        [input, fc1_expert_weights, fc2_expert_weights, min_latency_tensor],
+        [input, fc1_expert_weights, fc2_expert_weights],
        gemm_idx=1,
     )
 
@@ -189,7 +167,7 @@ def fused_moe(
         "trtllm::fused_moe::gemm2",
         [moe_runner],
         MoERunner.tuning_config,
-        [input, fc1_expert_weights, fc2_expert_weights, min_latency_tensor],
+        [input, fc1_expert_weights, fc2_expert_weights],
         gemm_idx=2,
     )
 
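
Note: the removed min_latency_tensor was a sentinel appended as the fourth tuner
input (the DynamicTensorSpec(3, 0, ...) entry dropped above referred to it); its
first dimension encoded the flag (size 1 meant min latency mode on, size 0 meant
off) and was decoded inside get_valid_tactics() and forward(). After this patch the
flag is passed once to MoERunner.__init__ and read as self.min_latency_mode. A
minimal illustrative sketch of the old convention, assuming only PyTorch and not
part of the diff above:

    import torch

    # Old convention (removed by this patch): encode the flag in a sentinel tensor
    # appended to the tuner inputs and decode it from the tensor's first dimension.
    min_latency_tensor = torch.empty(0)            # size 0 -> min latency mode off
    min_latency_mode = min_latency_tensor.size(0) == 1
    assert not min_latency_mode

    # New convention: the flag travels as a plain constructor argument, e.g.
    # MoERunner(..., min_latency_mode=False), and is read as self.min_latency_mode.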