cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h (91 changes: 31 additions & 60 deletions)
@@ -87,32 +87,6 @@ struct LoraParams

namespace cutlass_kernels
{
static inline size_t pad_to_multiple_of_16(size_t const& input)
{
static constexpr int ALIGNMENT = 16;
return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
}

class CubKeyValueSorter
{
public:
CubKeyValueSorter();

CubKeyValueSorter(int const num_experts_per_node);

void updateNumExperts(int const num_experts_per_node);

static size_t getWorkspaceSize(size_t const num_key_value_pairs, int const num_experts_per_node);

void run(void* workspace, size_t const workspace_size, int const* keys_in, int* keys_out, int const* values_in,
int* values_out, size_t const num_key_value_pairs, cudaStream_t stream);

private:
static int expertsToBits(int experts);
int num_experts_;
int num_bits_;
};
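For context on what is deleted here: CubKeyValueSorter was a thin wrapper around CUB's device radix sort, used to sort expanded token rows by expert id (the two-phase getWorkspaceSize/run split mirrors CUB's size-query-then-run pattern). A minimal, self-contained sketch of that pattern with cub::DeviceRadixSort::SortPairs; the function and buffer names below are illustrative, not code from this file:

#include <cub/cub.cuh>

// Sort (expert_id, row_index) pairs so rows routed to the same expert become contiguous.
// All d_* pointers are device buffers; num_pairs is the expanded number of rows.
void sortRowsByExpert(int const* d_experts_in, int* d_experts_out, int const* d_rows_in,
    int* d_rows_out, size_t num_pairs, int num_bits, cudaStream_t stream)
{
    void* d_temp_storage = nullptr;
    size_t temp_storage_bytes = 0;
    // First call with a null workspace only computes the required temporary storage size.
    cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, d_experts_in, d_experts_out,
        d_rows_in, d_rows_out, num_pairs, 0, num_bits, stream);
    cudaMalloc(&d_temp_storage, temp_storage_bytes);
    // Second call performs the sort over the low num_bits bits of the expert keys.
    cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, d_experts_in, d_experts_out,
        d_rows_in, d_rows_out, num_pairs, 0, num_bits, stream);
    cudaFree(d_temp_storage);
}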

/**
* \brief Describes what parallelism mode the MoE is using
*
@@ -397,9 +371,9 @@ class CutlassMoeFCRunnerInterface
ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config,
bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
= 0;

// Aliases for profiling the gemms
@@ -413,22 +387,22 @@ class CutlassMoeFCRunnerInterface
int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array,
bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
int* active_expert_global_ids, int start_expert)
int* active_expert_global_ids)
= 0;

virtual void gemm2(void const* const input, void* const gemm_output, void* const final_output,
int64_t const* const expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput const tma_ws_input_template,
void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales,
float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
QuantParams quant_params, float const* const token_topk_unpermuted_scales,
float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config,
bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
int* num_active_experts_per, int* active_expert_global_ids, int start_expert)
int* num_active_experts_per, int* active_expert_global_ids)
= 0;
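The renames in this signature (expanded_source_row_to_expanded_dest_row to unpermuted_row_to_permuted_row, expanded_dest_row_to_expanded_source_row to permuted_row_to_unpermuted_row) keep the same meaning. Assuming, as both the old and new names suggest, that the two arrays are mutually inverse mappings over the expanded (num_rows * experts_per_token) row space, they satisfy the invariant sketched below; the helper is hypothetical and not part of the header:

#include <cassert>
#include <vector>

// Check that the two row-mapping arrays are inverse permutations of each other.
void checkRowMappingsAreInverse(std::vector<int> const& unpermuted_row_to_permuted_row,
    std::vector<int> const& permuted_row_to_unpermuted_row)
{
    size_t const expanded_num_rows = unpermuted_row_to_permuted_row.size();
    for (size_t unpermuted_row = 0; unpermuted_row < expanded_num_rows; ++unpermuted_row)
    {
        int const permuted_row = unpermuted_row_to_permuted_row[unpermuted_row];
        // Following one mapping and then the other must return to the starting row.
        assert(permuted_row_to_unpermuted_row[permuted_row] == static_cast<int>(unpermuted_row));
    }
}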

virtual std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
@@ -544,9 +518,9 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config,
bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override;
int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override;

// We make these GEMM1 & GEMM2 static because they need to be stateless for the profiler to work
static void gemm1(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,
@@ -565,7 +539,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array,
bool bias_is_broadcast, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config,
bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids, int start_expert);
bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids);

static void gemm2(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,
DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, void* const gemm_output,
@@ -574,14 +548,14 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
ScaleBiasType const* const fc2_expert_biases, ScaleBiasType const* const fc2_int_scales,
float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
QuantParams quant_params, float const* const token_topk_unpermuted_scales,
float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
cudaStream_t stream, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
int* active_expert_global_ids, int start_expert);
int* active_expert_global_ids);

// Overrides to allow us to forward on to the internal functions with the pointers using the correct type
void gemm1(void const* const input, void* const output, void* const intermediate_result,
@@ -594,7 +568,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array,
bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
int* active_expert_global_ids, int start_expert) override
int* active_expert_global_ids) override
{
auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr;
return Self::gemm1(moe_gemm_runner_, block_scale_gemm_runner, static_cast<T const*>(input),
@@ -603,33 +577,33 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
num_valid_tokens_ptr, static_cast<ScaleBiasType const*>(fc1_int_scales), fc1_fp8_dequant, fc2_fp8_quant,
fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, num_rows, expanded_num_rows, hidden_size, inter_size,
num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array, bias_is_broadcast, stream, config,
min_latency_mode, num_active_experts_per, active_expert_global_ids, start_expert);
min_latency_mode, num_active_experts_per, active_expert_global_ids);
}

void gemm2(void const* const input, void* const gemm_output, void* const final_output,
int64_t const* const expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput const tma_ws_input_template,
void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales,
float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
QuantParams quant_params, float const* const token_topk_unpermuted_scales,
float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config,
bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
int* num_active_experts_per, int* active_expert_global_ids, int start_expert) override
int* num_active_experts_per, int* active_expert_global_ids) override
{
auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr;
return Self::gemm2(moe_gemm_runner_, block_scale_gemm_runner, static_cast<T const*>(input), gemm_output,
static_cast<OutputType*>(final_output), expert_first_token_offset, tma_ws_input_template,
static_cast<WeightType const*>(fc2_expert_weights), static_cast<ScaleBiasType const*>(fc2_expert_biases),
static_cast<ScaleBiasType const*>(fc2_int_scales), fc2_fp8_dequant, fc2_fp4_act_flat, quant_params,
token_topk_unpermuted_scales, token_topk_permuted_scales, expanded_source_row_to_expanded_dest_row,
expanded_dest_row_to_expanded_source_row, expert_for_source_row, num_valid_tokens_ptr, num_rows,
expanded_num_rows, hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array,
use_lora, fc2_lora, stream, parallelism_config, enable_alltoall, config, min_latency_mode,
num_active_experts_per, active_expert_global_ids, start_expert);
token_topk_unpermuted_scales, token_topk_permuted_scales, unpermuted_row_to_permuted_row,
permuted_row_to_unpermuted_row, token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows,
hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, use_lora, fc2_lora,
stream, parallelism_config, enable_alltoall, config, min_latency_mode, num_active_experts_per,
active_expert_global_ids);
}

virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override
@@ -763,30 +737,29 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
static void BlockScaleFC2(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output,
OutputType* const final_output, int64_t const* const expert_first_token_offset,
WeightType const* const fc2_expert_weights, ScaleBiasType const* const fc2_expert_biases,
float const* const token_topk_unpermuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
int const* const expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
float const* const token_topk_unpermuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* const permuted_row_to_unpermuted_row, int const* const token_selected_experts,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const k,
int64_t const hidden_size, int64_t const inter_size, int64_t const num_experts_per_node, int64_t const k,
MOEParallelismConfig parallelism_config, bool const enable_alltoall, QuantParams& quant_params,
cudaStream_t stream);

T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq,
cudaStream_t stream);

CubKeyValueSorter sorter_;
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType> moe_gemm_runner_;
std::unique_ptr<DeepSeekBlockScaleGemmRunner> blockscale_gemm_runner_;

std::optional<cutlass_extensions::CutlassGemmConfig> gemm1_config_;
std::optional<cutlass_extensions::CutlassGemmConfig> gemm2_config_;

// Pointers
int* unpermuted_token_selected_experts_{};
int* unpermuted_source_token_ids_{};
int* permuted_source_token_ids_{};
int* permuted_row_to_unpermuted_row_{};
int* permuted_token_selected_experts_{};
char* sorter_ws_{};
int* blocked_expert_counts_{};
int* blocked_expert_counts_cumsum_{};
int* blocked_row_to_unpermuted_row_{};
T* permuted_data_{};
float* permuted_token_final_scales_{};
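The old sorter_ws_ and permuted_source_token_ids_ workspace pointers are replaced by blocked_expert_counts_, blocked_expert_counts_cumsum_ and blocked_row_to_unpermuted_row_, which point toward a histogram-plus-prefix-sum (counting sort) style permutation rather than a radix sort. A host-side sketch of that general idea, purely illustrative; this is my reading of the buffer names, and the PR's actual CUDA kernels will differ:

#include <vector>

// Counting-sort style permutation: count rows per expert, prefix-sum the counts to get each
// expert's starting offset, then scatter rows so each expert owns one contiguous segment.
std::vector<int> buildPermutedRowToUnpermutedRow(
    std::vector<int> const& token_selected_experts, int num_experts)
{
    std::vector<int> expert_counts(num_experts + 1, 0);
    for (int expert : token_selected_experts)
    {
        ++expert_counts[expert + 1];
    }
    for (int e = 0; e < num_experts; ++e)
    {
        expert_counts[e + 1] += expert_counts[e]; // prefix sum: first permuted row of each expert
    }
    std::vector<int> permuted_row_to_unpermuted_row(token_selected_experts.size());
    std::vector<int> offsets(expert_counts.begin(), expert_counts.end() - 1);
    for (int row = 0; row < static_cast<int>(token_selected_experts.size()); ++row)
    {
        permuted_row_to_unpermuted_row[offsets[token_selected_experts[row]]++] = row;
    }
    return permuted_row_to_unpermuted_row;
}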

@@ -859,7 +832,6 @@ struct GemmProfilerBackend
mParallelismConfig = parallelism_config;
mEnableAlltoall = enable_alltoall;
mSM = common::getSMVersion();
mSorter.updateNumExperts(mNumExpertsPerNode);

mScalingType = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE;
if (dtype == nvinfer1::DataType::kFP8
@@ -883,7 +855,6 @@
cudaStream_t const& stream);

CutlassMoeFCRunnerInterface* mInterface;
CubKeyValueSorter mSorter;

GemmToProfile mGemmToProfile = GemmToProfile::Undefined;
std::vector<Config> mAllTacticsSaved;