tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (46 changes: 12 additions & 34 deletions)

--- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -22,12 +22,9 @@ def bmm_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
 class MoERunner(TunableRunner):
     # avoid overhead of creating a new runner in forward pass
     runner_dict = dict()
-    # TODO: only profile for min_latency_mode = False due to the error in the moe_kernels
     tuning_config = TuningConfig(dynamic_tensor_specs=(
         DynamicTensorSpec(0, 0, get_last_power_of_2_num_tokens_buckets(8192),
-                          lambda x: min(last_positive_power_of_2(x), 8192)),
-        DynamicTensorSpec(3, 0, (0, ), lambda x: x),
-    ))
+                          lambda x: min(last_positive_power_of_2(x), 8192)), ))
 
     def __init__(
         self,
@@ -44,6 +41,7 @@ def __init__(
         enable_alltoall: bool,
         use_deepseek_fp8_block_scale: bool,
         use_w4a8_group_scaling: bool,
+        min_latency_mode: bool,
     ):
         self.x_dtype = x_dtype
         self.weight_dtype = weight_dtype
@@ -58,7 +56,7 @@ def __init__(
         self.enable_alltoall = enable_alltoall
         self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
         self.use_w4a8_group_scaling = use_w4a8_group_scaling
-
+        self.min_latency_mode = min_latency_mode
         instance_key = (x_dtype, weight_dtype, output_dtype,
                         use_deepseek_fp8_block_scale, use_w4a8_group_scaling)
 
@@ -74,22 +72,7 @@ def get_valid_tactics(
         inputs: List[torch.Tensor],
         profile: OptimizationProfile,
     ) -> List[int]:
-        x, _, _, min_latency_mode_tensor = inputs
-        min_latency_mode = min_latency_mode_tensor.size(0) == 1
-        m = x.shape[0]
-
-        # Only profile m <= 128 for min latency mode = True
-        # Profile all valid buckets for min latency mode = False
-        # TODO: min_latency_mode = True will cause the following error:
-        # Cannot profile configuration 4: Cutlass GEMM Tactic
-        # [TensorRT-LLM][ERROR] Assertion failed: Failed to initialize cutlass TMA WS grouped gemm.
-        # Should be fixed in the moe_kernels in the future.
-        invalid = (m > 128 and
-                   min_latency_mode) or (m <= 128 and min_latency_mode and
-                                         (not self.weight_dtype == torch.int64))
-
-        return [] if invalid else list(
-            range(self.fused_moe_runner.get_tactic_num()))
+        return range(self.fused_moe_runner.get_tactic_num())
 
     def forward(
         self,
@@ -98,8 +81,7 @@ def forward(
         tactic: int = -1,
         do_preparation: bool = False,
     ):
-        x, fc1_expert_weights, fc2_expert_weights, min_latency_mode_tensor = inputs
-        min_latency_mode = min_latency_mode_tensor.size(0) == 1
+        x, fc1_expert_weights, fc2_expert_weights = inputs
         # determine if we should use min latency mode according to the profiled seq len
         self.fused_moe_runner.run_gemm_profile(
             x,
@@ -113,7 +95,7 @@ def forward(
             self.cluster_size,
             self.cluster_rank,
             self.enable_alltoall,
-            min_latency_mode,
+            self.min_latency_mode,
             gemm_idx,
             tactic,
             do_preparation,
@@ -122,13 +104,11 @@ def forward(
     @classmethod
     @lru_cache(maxsize=None)
     def refine_tuning_config(cls, tune_max_num_tokens: int):
-        cls.tuning_config = TuningConfig(dynamic_tensor_specs=(
-            DynamicTensorSpec(
+        cls.tuning_config = TuningConfig(
+            dynamic_tensor_specs=(DynamicTensorSpec(
                 0, 0, get_last_power_of_2_num_tokens_buckets(
                     tune_max_num_tokens), lambda x: min(
-                        last_positive_power_of_2(x), tune_max_num_tokens)),
-            DynamicTensorSpec(3, 0, (0, ), lambda x: x),
-        ))
+                        last_positive_power_of_2(x), tune_max_num_tokens)), ))
 
 
 @torch.library.custom_op("trtllm::fused_moe", mutates_args=())
@@ -157,9 +137,6 @@ def fused_moe(
     tuner = AutoTuner.get()
     MoERunner.refine_tuning_config(tune_max_num_tokens)
 
-    # TODO: set min_latency_mode always to False due to the error in the moe_kernels
-    min_latency_tensor = torch.empty(0)
-
     # allocate workspace for profiling
     moe_runner = MoERunner(
         x_dtype=input.dtype,
@@ -175,21 +152,22 @@ def fused_moe(
         enable_alltoall=enable_alltoall,
         use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         use_w4a8_group_scaling=use_w4a8_group_scaling,
+        min_latency_mode=min_latency_mode,
     )
 
     _, gemm_tactic_1 = tuner.choose_one(
         "trtllm::fused_moe::gemm1",
         [moe_runner],
         MoERunner.tuning_config,
-        [input, fc1_expert_weights, fc2_expert_weights, min_latency_tensor],
+        [input, fc1_expert_weights, fc2_expert_weights],
         gemm_idx=1,
     )
 
     _, gemm_tactic_2 = tuner.choose_one(
         "trtllm::fused_moe::gemm2",
         [moe_runner],
         MoERunner.tuning_config,
-        [input, fc1_expert_weights, fc2_expert_weights, min_latency_tensor],
+        [input, fc1_expert_weights, fc2_expert_weights],
         gemm_idx=2,
    )
 
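For readers skimming the diff, the shape of the refactor (a boolean that used to ride along as a zero- or one-element sentinel tensor in the input list now lives as an explicit runner attribute) can be sketched outside TensorRT-LLM. The ToyRunner class and pick_first_tactic helper below are hypothetical stand-ins written for illustration only; they mirror the before/after calling convention of the diff but are not part of the library's API.

# Hypothetical sketch only: ToyRunner and pick_first_tactic are illustrative
# stand-ins for the pattern in this diff, not TensorRT-LLM APIs.
from typing import List

import torch


class ToyRunner:
    """Carries min_latency_mode as an explicit attribute (the post-change shape)."""

    def __init__(self, min_latency_mode: bool):
        self.min_latency_mode = min_latency_mode

    def get_valid_tactics(self, inputs: List[torch.Tensor]) -> List[int]:
        # After the change, every tactic is offered to the tuner; the flag no
        # longer filters candidates and no sentinel tensor is inspected.
        return list(range(4))

    def forward(self, inputs: List[torch.Tensor], tactic: int) -> torch.Tensor:
        x, w1, w2 = inputs  # exactly three tensors; no inputs[3] sentinel to strip
        # The flag is read from the instance instead of inputs[3].size(0) == 1.
        scale = 0.5 if self.min_latency_mode else 1.0
        return scale * (x @ w1 @ w2)


def pick_first_tactic(runner: ToyRunner, inputs: List[torch.Tensor]) -> int:
    # Stand-in for the tuner loop: take the first valid tactic, or -1 as fallback.
    tactics = runner.get_valid_tactics(inputs)
    return tactics[0] if tactics else -1


if __name__ == "__main__":
    x, w1, w2 = torch.randn(8, 16), torch.randn(16, 32), torch.randn(32, 16)
    runner = ToyRunner(min_latency_mode=False)
    tactic = pick_first_tactic(runner, [x, w1, w2])
    print(tactic, runner.forward([x, w1, w2], tactic).shape)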