
Commit 91f2cc3

Revert "[TRTLLM-5965] perf: Optimize MoE sort kernels for large-scale EP (NVIDIA#5435)"
This reverts commit b4dab23.
1 parent 98a7c24 commit 91f2cc3

File tree

6 files changed: +414 -629 lines changed

cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h

Lines changed: 60 additions & 31 deletions
@@ -87,6 +87,32 @@ struct LoraParams
 
 namespace cutlass_kernels
 {
+static inline size_t pad_to_multiple_of_16(size_t const& input)
+{
+    static constexpr int ALIGNMENT = 16;
+    return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
+}
+
+class CubKeyValueSorter
+{
+public:
+    CubKeyValueSorter();
+
+    CubKeyValueSorter(int const num_experts_per_node);
+
+    void updateNumExperts(int const num_experts_per_node);
+
+    static size_t getWorkspaceSize(size_t const num_key_value_pairs, int const num_experts_per_node);
+
+    void run(void* workspace, size_t const workspace_size, int const* keys_in, int* keys_out, int const* values_in,
+        int* values_out, size_t const num_key_value_pairs, cudaStream_t stream);
+
+private:
+    static int expertsToBits(int experts);
+    int num_experts_;
+    int num_bits_;
+};
+
 /**
  * \brief Describes what parallelism mode the MoE is using
  *
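The revert restores the CUB-based sorter that the optimized sort kernels had replaced. For context, below is a minimal sketch of how a sorter with this interface is commonly implemented on top of cub::DeviceRadixSort; the bit-width heuristic and the exact call pattern are assumptions for illustration, not taken from this commit.

```cpp
// Sketch only: one plausible implementation of the CubKeyValueSorter interface
// declared above. Keys are expert ids and values are row ids, so only enough
// radix bits to cover [0, num_experts] need to be compared.
#include <cmath>
#include <cub/cub.cuh>

int CubKeyValueSorter::expertsToBits(int num_experts)
{
    // +1 leaves room for a sentinel "no expert" key (an assumption for illustration).
    return static_cast<int>(std::ceil(std::log2(num_experts + 1)));
}

size_t CubKeyValueSorter::getWorkspaceSize(size_t const num_key_value_pairs, int const num_experts_per_node)
{
    size_t temp_bytes = 0;
    int const num_bits = expertsToBits(num_experts_per_node);
    // Passing a null workspace makes CUB report the required temp-storage size only.
    cub::DeviceRadixSort::SortPairs(nullptr, temp_bytes, static_cast<int const*>(nullptr),
        static_cast<int*>(nullptr), static_cast<int const*>(nullptr), static_cast<int*>(nullptr),
        num_key_value_pairs, 0, num_bits);
    return temp_bytes;
}

void CubKeyValueSorter::run(void* workspace, size_t const workspace_size, int const* keys_in, int* keys_out,
    int const* values_in, int* values_out, size_t const num_key_value_pairs, cudaStream_t stream)
{
    size_t temp_bytes = workspace_size;
    // Sort (expert, row) pairs so that rows routed to the same expert become contiguous.
    cub::DeviceRadixSort::SortPairs(
        workspace, temp_bytes, keys_in, keys_out, values_in, values_out, num_key_value_pairs, 0, num_bits_, stream);
}
```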
@@ -371,9 +397,9 @@ class CutlassMoeFCRunnerInterface
         ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
         QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
         int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
-        int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
-        bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
-        MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
+        int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config,
+        bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
+        bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
         = 0;
 
     // Aliases for profiling the gemms
@@ -387,22 +413,22 @@ class CutlassMoeFCRunnerInterface
         int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array,
         bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
         cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-        int* active_expert_global_ids)
+        int* active_expert_global_ids, int start_expert)
         = 0;
 
     virtual void gemm2(void const* const input, void* const gemm_output, void* const final_output,
         int64_t const* const expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput const tma_ws_input_template,
         void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales,
         float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
         QuantParams quant_params, float const* const token_topk_unpermuted_scales,
-        float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
-        int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
+        float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
+        int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
         int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
         int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
         bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config,
         bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
-        int* num_active_experts_per, int* active_expert_global_ids)
+        int* num_active_experts_per, int* active_expert_global_ids, int start_expert)
         = 0;
 
     virtual std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
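The two renamed pointer arguments above are the two halves of the same permutation: expanded_dest_row_to_expanded_source_row records, for each expert-grouped (permuted) row, which expanded source row it came from, and expanded_source_row_to_expanded_dest_row is its inverse. Below is a toy host-side illustration of that relationship; the row-expansion convention (`row * k + j`) is an assumption for illustration, and the CUDA kernels may order rows differently.

```cpp
// Toy CPU analogue of the permutation maps passed to gemm2. Grouping expanded
// rows by expert is the job CubKeyValueSorter performs on the GPU.
#include <algorithm>
#include <numeric>
#include <vector>

int main()
{
    int const num_rows = 3, k = 2;
    // Hypothetical routing: expert_for_source_row[row * k + j] is the expert
    // chosen for the j-th copy of token `row`.
    std::vector<int> expert_for_source_row = {1, 0, 0, 1, 1, 0};

    int const expanded_rows = num_rows * k;
    std::vector<int> expanded_dest_row_to_expanded_source_row(expanded_rows);
    std::iota(expanded_dest_row_to_expanded_source_row.begin(), expanded_dest_row_to_expanded_source_row.end(), 0);
    // Stable-sort expanded rows by expert id so each expert's rows are contiguous.
    std::stable_sort(expanded_dest_row_to_expanded_source_row.begin(), expanded_dest_row_to_expanded_source_row.end(),
        [&](int a, int b) { return expert_for_source_row[a] < expert_for_source_row[b]; });

    // Inverse map: where each expanded source row landed after permutation.
    std::vector<int> expanded_source_row_to_expanded_dest_row(expanded_rows);
    for (int dest = 0; dest < expanded_rows; ++dest)
    {
        expanded_source_row_to_expanded_dest_row[expanded_dest_row_to_expanded_source_row[dest]] = dest;
    }
    return 0;
}
```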
@@ -518,9 +544,9 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
         QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
         int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
-        int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
-        bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
-        MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override;
+        int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config,
+        bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
+        bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override;
 
     // We make these GEMM1 & GEMM2 static because they need to be stateless for the profiler to work
     static void gemm1(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,
@@ -539,7 +565,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
         int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array,
         bool bias_is_broadcast, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config,
-        bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids);
+        bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids, int start_expert);
 
     static void gemm2(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,
         DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, void* const gemm_output,
@@ -548,14 +574,14 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         ScaleBiasType const* const fc2_expert_biases, ScaleBiasType const* const fc2_int_scales,
         float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
         QuantParams quant_params, float const* const token_topk_unpermuted_scales,
-        float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
-        int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
+        float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
+        int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
         int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
         int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
         cudaStream_t stream, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
         cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-        int* active_expert_global_ids);
+        int* active_expert_global_ids, int start_expert);
 
     // Overrides to allow us to forward on to the internal functions with the pointers using the correct type
     void gemm1(void const* const input, void* const output, void* const intermediate_result,
@@ -568,7 +594,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array,
         bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
         cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-        int* active_expert_global_ids) override
+        int* active_expert_global_ids, int start_expert) override
     {
         auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr;
         return Self::gemm1(moe_gemm_runner_, block_scale_gemm_runner, static_cast<T const*>(input),
@@ -577,33 +603,33 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
             num_valid_tokens_ptr, static_cast<ScaleBiasType const*>(fc1_int_scales), fc1_fp8_dequant, fc2_fp8_quant,
             fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, num_rows, expanded_num_rows, hidden_size, inter_size,
             num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array, bias_is_broadcast, stream, config,
-            min_latency_mode, num_active_experts_per, active_expert_global_ids);
+            min_latency_mode, num_active_experts_per, active_expert_global_ids, start_expert);
     }
 
     void gemm2(void const* const input, void* const gemm_output, void* const final_output,
         int64_t const* const expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput const tma_ws_input_template,
         void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales,
         float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
         QuantParams quant_params, float const* const token_topk_unpermuted_scales,
-        float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
-        int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
+        float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
+        int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
         int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
         int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
         bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config,
         bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
-        int* num_active_experts_per, int* active_expert_global_ids) override
+        int* num_active_experts_per, int* active_expert_global_ids, int start_expert) override
     {
         auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr;
         return Self::gemm2(moe_gemm_runner_, block_scale_gemm_runner, static_cast<T const*>(input), gemm_output,
             static_cast<OutputType*>(final_output), expert_first_token_offset, tma_ws_input_template,
             static_cast<WeightType const*>(fc2_expert_weights), static_cast<ScaleBiasType const*>(fc2_expert_biases),
             static_cast<ScaleBiasType const*>(fc2_int_scales), fc2_fp8_dequant, fc2_fp4_act_flat, quant_params,
-            token_topk_unpermuted_scales, token_topk_permuted_scales, unpermuted_row_to_permuted_row,
-            permuted_row_to_unpermuted_row, token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows,
-            hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, use_lora, fc2_lora,
-            stream, parallelism_config, enable_alltoall, config, min_latency_mode, num_active_experts_per,
-            active_expert_global_ids);
+            token_topk_unpermuted_scales, token_topk_permuted_scales, expanded_source_row_to_expanded_dest_row,
+            expanded_dest_row_to_expanded_source_row, expert_for_source_row, num_valid_tokens_ptr, num_rows,
+            expanded_num_rows, hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array,
+            use_lora, fc2_lora, stream, parallelism_config, enable_alltoall, config, min_latency_mode,
+            num_active_experts_per, active_expert_global_ids, start_expert);
     }
 
     virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override
@@ -737,29 +763,30 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
     static void BlockScaleFC2(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output,
         OutputType* const final_output, int64_t const* const expert_first_token_offset,
         WeightType const* const fc2_expert_weights, ScaleBiasType const* const fc2_expert_biases,
-        float const* const token_topk_unpermuted_scales, int const* const unpermuted_row_to_permuted_row,
-        int const* const permuted_row_to_unpermuted_row, int const* const token_selected_experts,
+        float const* const token_topk_unpermuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
+        int const* const expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
-        int64_t const hidden_size, int64_t const inter_size, int64_t const num_experts_per_node, int64_t const k,
+        int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const k,
         MOEParallelismConfig parallelism_config, bool const enable_alltoall, QuantParams& quant_params,
         cudaStream_t stream);
 
     T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
         int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq,
         cudaStream_t stream);
 
+    CubKeyValueSorter sorter_;
     MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType> moe_gemm_runner_;
     std::unique_ptr<DeepSeekBlockScaleGemmRunner> blockscale_gemm_runner_;
 
     std::optional<cutlass_extensions::CutlassGemmConfig> gemm1_config_;
     std::optional<cutlass_extensions::CutlassGemmConfig> gemm2_config_;
 
     // Pointers
-    int* permuted_row_to_unpermuted_row_{};
+    int* unpermuted_token_selected_experts_{};
+    int* unpermuted_source_token_ids_{};
+    int* permuted_source_token_ids_{};
     int* permuted_token_selected_experts_{};
-    int* blocked_expert_counts_{};
-    int* blocked_expert_counts_cumsum_{};
-    int* blocked_row_to_unpermuted_row_{};
+    char* sorter_ws_{};
     T* permuted_data_{};
     float* permuted_token_final_scales_{};
 
@@ -832,6 +859,7 @@ struct GemmProfilerBackend
         mParallelismConfig = parallelism_config;
         mEnableAlltoall = enable_alltoall;
         mSM = common::getSMVersion();
+        mSorter.updateNumExperts(mNumExpertsPerNode);
 
         mScalingType = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE;
         if (dtype == nvinfer1::DataType::kFP8
@@ -855,6 +883,7 @@ struct GemmProfilerBackend
         cudaStream_t const& stream);
 
     CutlassMoeFCRunnerInterface* mInterface;
+    CubKeyValueSorter mSorter;
 
     GemmToProfile mGemmToProfile = GemmToProfile::Undefined;
     std::vector<Config> mAllTacticsSaved;
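Finally, a small hedged sketch of how the two restored helpers typically combine when the sorter's scratch buffer is carved out of a shared MoE workspace; the function name and surrounding layout are hypothetical, not taken from this diff.

```cpp
// Illustration only: size the CUB sort scratch space and round it up so the
// next sub-buffer packed after it stays 16-byte aligned. Names are hypothetical.
size_t sorterWorkspaceBytes(size_t num_tokens, int experts_per_token, int num_experts_per_node)
{
    size_t const num_key_value_pairs = num_tokens * static_cast<size_t>(experts_per_token);
    size_t const raw_bytes = CubKeyValueSorter::getWorkspaceSize(num_key_value_pairs, num_experts_per_node);
    return pad_to_multiple_of_16(raw_bytes);
}
```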
