@@ -581,7 +581,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
 
         auto func = NVFP4 ? QuantParams::FP4 : QuantParams::FP8MXFP4;
         mQuantParams = func(mExpertFP4ActScale1, mExpertFP4WeightSf1, mExpertFP4GlobalScale1, mExpertFP4ActScale2,
-            mExpertFP4WeightSf2, mExpertFP4GlobalScale2);
+            mExpertFP4WeightSf2, mExpertFP4GlobalScale2, false, false);
     }
 
     mSelectedExperts = allocBuffer<int>(mTotalTokens * mK);
134 changes: 57 additions & 77 deletions cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h

Large diffs are not rendered by default.

860 changes: 516 additions & 344 deletions cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu

Large diffs are not rendered by default.

Git LFS file not shown
@@ -1,2 +1,2 @@
-a1180829a0d8fe772ff37934b72573bb41671e7ed76dfa3bd5cd449348b9683a libtensorrt_llm_internal_cutlass_kernels_static.a
-commit c767347ff934578193ee4bad58ba3b9398046245
+ad34c0f31247c880d60e2c8198093e8373cf0e1d3e8badee0424bfa607d6cd8e libtensorrt_llm_internal_cutlass_kernels_static.a
+commit bac309ac608d35d7d0144e594bf3e5fa8cfca796
@@ -159,8 +159,9 @@ struct QuantParams
     // FP8 quantization params
     struct
     {
+        bool fc2_use_per_expert_act_scale = false;
         float const* dequant_fc1 = nullptr; // (num_experts_per_node, )
-        float const* quant_fc2 = nullptr; // (1, )
+        float const* quant_fc2 = nullptr; // (1, ) or (num_experts_per_node, ) based on fc2_use_per_expert_act_scale
         float const* dequant_fc2 = nullptr; // (num_experts_per_node, )
         float const* quant_final = nullptr; // (1, )
         float const* dequant_input = nullptr; // (1, )
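
To make the new shape convention concrete: when fc2_use_per_expert_act_scale is set, quant_fc2 points to one activation scale per local expert rather than a single scalar. A minimal hypothetical helper (not part of this diff; the name and signature are illustrative only) shows how a consumer of the fp8 struct could pick the scale for a given expert:

// Sketch only: resolve the fc2 activation quantization scale for one expert.
// expert_idx is a hypothetical local expert index in [0, num_experts_per_node).
inline float fc2ActScale(float const* quant_fc2, bool fc2_use_per_expert_act_scale, int expert_idx)
{
    // Per-expert mode: one scale per local expert; otherwise a single shared scalar.
    return fc2_use_per_expert_act_scale ? quant_fc2[expert_idx] : quant_fc2[0];
}
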
@@ -172,10 +173,12 @@
     {
         struct GemmInputs
         {
-            float const* act_global_scale = nullptr; // (1, )
+            bool use_per_expert_act_scale = false;
+            float const* act_global_scale
+                = nullptr; // (1, ) or (num_experts_per_node, ) based on use_per_expert_act_scale
             TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* weight_block_scale
-                = nullptr; // (experts, n, k / 32)
-            float const* global_scale = nullptr; // (num_experts_per_node, )
+                = nullptr; // (experts, n, k / 32)
+            float const* global_scale = nullptr; // (num_experts_per_node, )
         };
 
         GemmInputs fc1;
@@ -187,10 +190,12 @@
     {
         struct GemmInputs
         {
-            float const* act_global_scale = nullptr; // (1, )
+            bool use_per_expert_act_scale = false;
+            float const* act_global_scale
+                = nullptr; // (1, ) or (num_experts_per_node, ) based on use_per_expert_act_scale
             TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* weight_block_scale
-                = nullptr; // (experts, n, k / 16)
-            float const* global_scale = nullptr; // (num_experts_per_node, )
+                = nullptr; // (experts, n, k / 16)
+            float const* global_scale = nullptr; // (num_experts_per_node, )
         };
 
         GemmInputs fc1;
@@ -236,10 +241,11 @@ struct QuantParams
     }
 
     static QuantParams FP8(float const* dequant_fc1, float const* quant_fc2, float const* dequant_fc2,
-        float const* quant_final = nullptr, float const* dequant_input = nullptr)
+        float const* quant_final = nullptr, float const* dequant_input = nullptr,
+        bool fc2_use_per_expert_act_scale = false)
     {
         QuantParams qp;
-        qp.fp8 = {dequant_fc1, quant_fc2, dequant_fc2, quant_final, dequant_input};
+        qp.fp8 = {fc2_use_per_expert_act_scale, dequant_fc1, quant_fc2, dequant_fc2, quant_final, dequant_input};
         return qp;
     }
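
For reference, a usage sketch of the extended FP8 factory (not part of this diff): the helper name and buffer arguments are hypothetical, and enabling the flag simply means quant_fc2 carries num_experts_per_node scales instead of one.

// Sketch only: build FP8 MoE quant params with a per-expert fc2 activation scale.
// All pointers are assumed to be valid device buffers with the shapes noted above.
QuantParams makeFp8QuantParamsPerExpert(
    float const* dequant_fc1, // (num_experts_per_node, )
    float const* quant_fc2,   // (num_experts_per_node, ) because the flag below is true
    float const* dequant_fc2) // (num_experts_per_node, )
{
    return QuantParams::FP8(dequant_fc1, quant_fc2, dequant_fc2,
        /*quant_final=*/nullptr, /*dequant_input=*/nullptr,
        /*fc2_use_per_expert_act_scale=*/true);
}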

@@ -248,12 +254,14 @@ struct QuantParams
         float const* fc1_global_scale, //
         float const* fc2_act_global_scale,
         TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* fc2_weight_block_scale,
-        float const* fc2_global_scale //
-    )
+        float const* fc2_global_scale, //
+        bool fc1_use_per_expert_act_scale = false, bool fc2_use_per_expert_act_scale = false)
     {
         QuantParams qp;
-        qp.fp8_mxfp4.fc1 = {fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
-        qp.fp8_mxfp4.fc2 = {fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale};
+        qp.fp8_mxfp4.fc1
+            = {fc1_use_per_expert_act_scale, fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
+        qp.fp8_mxfp4.fc2
+            = {fc2_use_per_expert_act_scale, fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale};
         return qp;
     }

@@ -262,12 +270,12 @@ struct QuantParams
         float const* fc1_global_scale, //
         float const* fc2_act_global_scale,
         TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* fc2_weight_block_scale,
-        float const* fc2_global_scale //
-    )
+        float const* fc2_global_scale, //
+        bool fc1_use_per_expert_act_scale = false, bool fc2_use_per_expert_act_scale = false)
     {
         QuantParams qp;
-        qp.fp4.fc1 = {fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
-        qp.fp4.fc2 = {fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale};
+        qp.fp4.fc1 = {fc1_use_per_expert_act_scale, fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
+        qp.fp4.fc2 = {fc2_use_per_expert_act_scale, fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale};
         return qp;
     }
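
A similar sketch for the block-scaled NVFP4 overload (the FP8-MXFP4 overload above takes the same two trailing flags); the helper name is a placeholder, and passing false for both flags, as the benchmark fixture above now does, keeps the previous single-scale behaviour.

// Sketch only: NVFP4 quant params with a per-expert fc2 activation scale.
// Pointer arguments are assumed to be valid device buffers with the shapes noted above.
QuantParams makeFp4QuantParamsPerExpert(float const* fc1_act_global_scale,
    TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* fc1_weight_block_scale, float const* fc1_global_scale,
    float const* fc2_act_global_scale,
    TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* fc2_weight_block_scale, float const* fc2_global_scale)
{
    // fc1 keeps a single activation scale; fc2 switches to one scale per local expert.
    return QuantParams::FP4(fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale, fc2_act_global_scale,
        fc2_weight_block_scale, fc2_global_scale,
        /*fc1_use_per_expert_act_scale=*/false, /*fc2_use_per_expert_act_scale=*/true);
}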

@@ -760,8 +768,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         QuantParams& quant_params, cudaStream_t stream);
 
     T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
-        int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr,
-        int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, cudaStream_t stream);
+        int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq,
+        cudaStream_t stream);
 
     CubKeyValueSorter sorter_;
     MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType> moe_gemm_runner_;
Git LFS file not shown
@@ -1,2 +1,2 @@
-e7130e36217c1df0d281788fc87764945d9c308bef11ad61b3b1a49c7d41c8af libtensorrt_llm_internal_cutlass_kernels_static.a
-commit c767347ff934578193ee4bad58ba3b9398046245
+21c59ede16aa448b6135327bd0f95e72a6e614f219935b8f67fe635b3cb4b38b libtensorrt_llm_internal_cutlass_kernels_static.a
+commit bac309ac608d35d7d0144e594bf3e5fa8cfca796