
Commit e7a9112

djns99 authored and dominicshanshan committed
feat: Add support for per expert activation scaling factors (NVIDIA#5013)
Signed-off-by: Daniel Stokes <[email protected]>
1 parent 09eeeb8 commit e7a9112
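
Note: before this commit, the MoE quantization paths applied a single global activation scaling factor shared by all experts (scale pointers documented as `(1, )`). This change lets each GEMM's activation scale pointer optionally hold one value per expert (`(num_experts_per_node, )`), selected by new `use_per_expert_act_scale` flags. A minimal sketch of the selection rule the kernels below apply (an illustrative helper, not the actual kernel code):

// Sketch: pick the activation scale for a token routed to `expert`.
// With per-expert scales the array holds num_experts_per_node entries;
// otherwise index 0 holds the single broadcast value.
__device__ float selectActScale(float const* act_scale, bool use_per_expert_act_scale, int64_t expert)
{
    size_t idx = use_per_expert_act_scale ? expert : 0;
    return act_scale ? act_scale[idx] : 1.0f; // a missing pointer means "no scaling"
}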

File tree

10 files changed: +197 −85 lines


cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h

Lines changed: 1 addition & 1 deletion
@@ -581,7 +581,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
 
         auto func = NVFP4 ? QuantParams::FP4 : QuantParams::FP8MXFP4;
         mQuantParams = func(mExpertFP4ActScale1, mExpertFP4WeightSf1, mExpertFP4GlobalScale1, mExpertFP4ActScale2,
-            mExpertFP4WeightSf2, mExpertFP4GlobalScale2);
+            mExpertFP4WeightSf2, mExpertFP4GlobalScale2, false, false);
     }
 
     mSelectedExperts = allocBuffer<int>(mTotalTokens * mK);

cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h

Lines changed: 26 additions & 17 deletions
@@ -210,8 +210,9 @@ struct QuantParams
     // FP8 quantization params
     struct
     {
+        bool fc2_use_per_expert_act_scale = false;
         float const* dequant_fc1 = nullptr;   // (num_experts_per_node, )
-        float const* quant_fc2 = nullptr;     // (1, )
+        float const* quant_fc2 = nullptr;     // (1, ) or (num_experts_per_node, ) based on fc2_use_per_expert_act_scale
         float const* dequant_fc2 = nullptr;   // (num_experts_per_node, )
         float const* quant_final = nullptr;   // (1, )
         float const* dequant_input = nullptr; // (1, )
@@ -223,10 +224,12 @@ struct QuantParams
     {
         struct GemmInputs
         {
-            float const* act_global_scale = nullptr; // (1, )
+            bool use_per_expert_act_scale = false;
+            float const* act_global_scale
+                = nullptr; // (1, ) or (num_experts_per_node, ) based on use_per_expert_act_scale
             TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* weight_block_scale
-                = nullptr;                           // (experts, n, k / 32)
-            float const* global_scale = nullptr;     // (num_experts_per_node, )
+                = nullptr;                       // (experts, n, k / 32)
+            float const* global_scale = nullptr; // (num_experts_per_node, )
         };
 
         GemmInputs fc1;
@@ -238,10 +241,13 @@ struct QuantParams
     {
         struct GemmInputs
         {
-            float const* act_global_scale = nullptr; // (1, )
+            bool use_per_expert_act_scale = false;
+
+            float const* act_global_scale
+                = nullptr; // (1, ) or (num_experts_per_node, ) based on use_per_expert_act_scale
             TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* weight_block_scale
-                = nullptr;                           // (experts, n, k / 16)
-            float const* global_scale = nullptr;     // (num_experts_per_node, )
+                = nullptr;                       // (experts, n, k / 16)
+            float const* global_scale = nullptr; // (num_experts_per_node, )
         };
 
         GemmInputs fc1;
@@ -287,10 +293,11 @@ struct QuantParams
     }
 
     static QuantParams FP8(float const* dequant_fc1, float const* quant_fc2, float const* dequant_fc2,
-        float const* quant_final = nullptr, float const* dequant_input = nullptr)
+        float const* quant_final = nullptr, float const* dequant_input = nullptr,
+        bool fc2_use_per_expert_act_scale = false)
     {
         QuantParams qp;
-        qp.fp8 = {dequant_fc1, quant_fc2, dequant_fc2, quant_final, dequant_input};
+        qp.fp8 = {fc2_use_per_expert_act_scale, dequant_fc1, quant_fc2, dequant_fc2, quant_final, dequant_input};
         return qp;
     }
@@ -299,12 +306,14 @@ struct QuantParams
         float const* fc1_global_scale, //
         float const* fc2_act_global_scale,
         TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* fc2_weight_block_scale,
-        float const* fc2_global_scale //
-    )
+        float const* fc2_global_scale, //
+        bool fc1_use_per_expert_act_scale = false, bool fc2_use_per_expert_act_scale = false)
     {
         QuantParams qp;
-        qp.fp8_mxfp4.fc1 = {fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
-        qp.fp8_mxfp4.fc2 = {fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale};
+        qp.fp8_mxfp4.fc1
+            = {fc1_use_per_expert_act_scale, fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
+        qp.fp8_mxfp4.fc2
+            = {fc2_use_per_expert_act_scale, fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale};
         return qp;
     }
@@ -313,12 +322,12 @@ struct QuantParams
         float const* fc1_global_scale, //
         float const* fc2_act_global_scale,
         TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* fc2_weight_block_scale,
-        float const* fc2_global_scale //
-    )
+        float const* fc2_global_scale, //
+        bool fc1_use_per_expert_act_scale = false, bool fc2_use_per_expert_act_scale = false)
     {
         QuantParams qp;
-        qp.fp4.fc1 = {fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
-        qp.fp4.fc2 = {fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale};
+        qp.fp4.fc1 = {fc1_use_per_expert_act_scale, fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
+        qp.fp4.fc2 = {fc2_use_per_expert_act_scale, fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale};
         return qp;
     }
 
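
The factory helpers keep the new flags optional, so existing call sites (such as the benchmark fixture above, which now passes `false, false`) stay valid. A usage sketch, assuming hypothetical device buffers sized as the comments above require (per-expert arrays hold `num_experts_per_node` entries):

// NVFP4 path: opt in to per-expert activation scales for both FC1 and FC2.
// fc1_act_scale, fc1_weight_sf, etc. are hypothetical device pointers.
QuantParams qp_fp4 = QuantParams::FP4(fc1_act_scale, fc1_weight_sf, fc1_global_scale,
    fc2_act_scale, fc2_weight_sf, fc2_global_scale,
    /*fc1_use_per_expert_act_scale=*/true, /*fc2_use_per_expert_act_scale=*/true);

// FP8 path: only FC2 gains a per-expert quantization scale; quant_fc2 must
// then point to num_experts_per_node values rather than a single float.
QuantParams qp_fp8 = QuantParams::FP8(dequant_fc1, quant_fc2, dequant_fc2,
    /*quant_final=*/nullptr, /*dequant_input=*/nullptr,
    /*fc2_use_per_expert_act_scale=*/true);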

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu

Lines changed: 44 additions & 17 deletions
@@ -1237,7 +1237,8 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
     ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales,
     int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row,
     int64_t const num_rows, int64_t const cols, int64_t const k, float const* fc1_act_global_scale,
-    int64_t const* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
+    bool use_per_expert_act_scale, int64_t const* expert_first_token_offset,
+    TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
     TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, int64_t const num_experts_per_node)
 {
 #ifdef ENABLE_FP4
@@ -1300,7 +1301,8 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
         int64_t expert = findTotalEltsLessThanTarget(
                              expert_first_token_offset, num_experts_per_node, (int64_t) expanded_dest_row + 1)
             - 1;
-        float global_scale_val = fc1_act_global_scale ? *fc1_act_global_scale : 1.0f;
+        size_t act_scale_idx = use_per_expert_act_scale ? expert : 0;
+        float global_scale_val = fc1_act_global_scale ? fc1_act_global_scale[act_scale_idx] : 1.0f;
         int64_t num_tokens_before_expert = expert_first_token_offset[expert];
 
         for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride)
@@ -1315,6 +1317,7 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
         }
         else
         {
+            assert(act_scale_idx == 0 && "Cannot use per-expert act scale for pre-quantized activations");
             writeSF(num_tokens_before_expert, expert, source_row, expanded_dest_row, elem_index, cols, num_rows,
                 fc1_act_sf_flat, input_sf);
             dest_row_ptr[elem_index] = in_vec;
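
Because permuted rows are stored grouped by expert, the expert owning a row can be recovered from `expert_first_token_offset` alone. A host-side reference of the scale lookup above, e.g. for unit testing (a hypothetical helper, assuming `offsets` holds num_experts_per_node + 1 prefix entries and `row` is in range):

#include <cstdint>
#include <vector>

// Mirrors the kernel's act_scale_idx = use_per_expert_act_scale ? expert : 0
// selection for a permuted row index; linear scan instead of the kernel's search.
float referenceRowScale(std::vector<int64_t> const& offsets, std::vector<float> const& act_scale,
    int64_t row, bool use_per_expert_act_scale)
{
    int64_t expert = 0;
    while (offsets[expert + 1] <= row)
    {
        ++expert; // advance until row falls inside [offsets[expert], offsets[expert + 1])
    }
    return act_scale.empty() ? 1.0f : act_scale[use_per_expert_act_scale ? expert : 0];
}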
@@ -1345,7 +1348,7 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
     ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales,
     int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row,
     int64_t const num_rows, int64_t const cols, int const k, int const num_experts_per_node,
-    float const* fc1_act_global_scale, int64_t* expert_first_token_offset,
+    float const* fc1_act_global_scale, bool use_per_expert_act_scale, int64_t* expert_first_token_offset,
     TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
     TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream)
 {
@@ -1360,6 +1363,11 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
         check_cuda_error(cudaMemsetAsync(
             fc1_act_sf_flat, 0x0, num_elems * sizeof(TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF), stream));
     }
+    else
+    {
+        TLLM_CHECK_WITH_INFO(
+            !use_per_expert_act_scale, "Per-expert act scale for FC1 is only supported for FP4 activations");
+    }
 #endif
 
     static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
@@ -1380,7 +1388,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
     config.attrs = attrs;
     cudaLaunchKernelEx(&config, func, unpermuted_input, permuted_output, unpermuted_scales, permuted_scales,
         expanded_dest_row_to_expanded_source_row, expanded_source_row_to_expanded_dest_row, num_rows, cols, k,
-        fc1_act_global_scale, expert_first_token_offset, fc1_act_sf_flat, input_sf, num_experts_per_node);
+        fc1_act_global_scale, use_per_expert_act_scale, expert_first_token_offset, fc1_act_sf_flat, input_sf,
+        num_experts_per_node);
 }
 
 enum class ScaleMode : int
@@ -1681,7 +1690,8 @@ template <class T, class GemmOutputType, class ScaleBiasType, template <class> c
 __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, float const* fp8_quant,
     ScaleBiasType const* bias_ptr, bool bias_is_broadcast, int64_t const* expert_first_token_offset,
     int num_experts_per_node, int64_t inter_size, int64_t max_tokens_per_expert, bool gated,
-    float const* fc2_act_global_scale, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat)
+    float const* fc2_act_global_scale, bool use_per_expert_act_scale,
+    TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat)
 {
 #ifdef ENABLE_FP4
     constexpr bool IsFP4 = std::is_same_v<T, __nv_fp4_e2m1>;
@@ -1705,16 +1715,17 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
     size_t output_offset = token * inter_size;
 
     int64_t expert = 0;
-    if (bias_ptr || IsFP4)
+    if (bias_ptr || IsFP4 || use_per_expert_act_scale)
     {
         // TODO this is almost certainly faster as a linear scan
         expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - 1;
     }
 
-    float const quant_scale = fp8_quant ? *fp8_quant : 1.f;
+    size_t act_scale_idx = use_per_expert_act_scale ? expert : 0;
+    float const quant_scale = fp8_quant ? fp8_quant[act_scale_idx] : 1.f;
 
     // Some globals for FP4
-    float global_scale_val = fc2_act_global_scale ? *fc2_act_global_scale : 1.0f;
+    float global_scale_val = fc2_act_global_scale ? fc2_act_global_scale[act_scale_idx] : 1.0f;
     int64_t num_tokens_before_expert = IsFP4 ? expert_first_token_offset[expert] : 0;
 
     size_t bias_offset = 0;
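
`findTotalEltsLessThanTarget` itself is not shown in this diff; from its call sites the lookup is equivalent to counting prefix offsets below a target, so `findTotalEltsLessThanTarget(offsets, n, token + 1) - 1` yields the expert whose token range contains `token`. A standalone sketch of that equivalent (an assumption about the helper's semantics, not its actual implementation):

#include <cstdint>

// Counts entries of the sorted array strictly less than `target` (binary search).
// With target = token + 1 this counts offsets <= token, and subtracting 1 gives
// the index of the expert whose first token is at or before `token`.
int64_t countEltsLessThanTarget(int64_t const* sorted, int64_t n, int64_t target)
{
    int64_t lo = 0, hi = n;
    while (lo < hi)
    {
        int64_t mid = lo + (hi - lo) / 2;
        if (sorted[mid] < target)
        {
            lo = mid + 1;
        }
        else
        {
            hi = mid;
        }
    }
    return lo;
}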
@@ -1790,7 +1801,7 @@ template <class T, class GemmOutputType, class ScaleBiasType>
 void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8_quant, ScaleBiasType const* bias,
     bool bias_is_broadcast, int64_t const* expert_first_token_offset, int num_experts_per_node, int64_t inter_size,
     int64_t num_tokens, int64_t expanded_num_tokens, ActivationType activation_type, float const* fc2_act_global_scale,
-    TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, cudaStream_t stream)
+    bool use_per_expert_act_scale, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, cudaStream_t stream)
 {
     static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
     // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
@@ -1819,7 +1830,7 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
     config.attrs = attrs;
     cudaLaunchKernelEx(&config, fn, output, gemm_result, fp8_quant, bias, bias_is_broadcast, expert_first_token_offset,
         num_experts_per_node, inter_size, num_tokens, isGatedActivation(activation_type), fc2_act_global_scale,
-        fc2_act_sf_flat);
+        use_per_expert_act_scale, fc2_act_sf_flat);
 }
 
 // ============================== Lora Add Bias =================================
@@ -2346,9 +2357,11 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Ena
 
     sync_check_cuda_error(stream);
     constexpr bool bias_is_broadcast = true;
+    constexpr bool use_per_expert_act_scale = false;
     doActivation<T, UnfusedGemmOutputType>(output, static_cast<UnfusedGemmOutputType const*>(gemm_output),
         fc2_fp8_quant, fc1_expert_biases, bias_is_broadcast, expert_first_token_offset, num_experts_per_node,
-        inter_size, num_rows, expanded_num_rows, fc1_activation_type, nullptr, nullptr, stream);
+        inter_size, num_rows, expanded_num_rows, fc1_activation_type, nullptr, use_per_expert_act_scale, nullptr,
+        stream);
 
     sync_check_cuda_error(stream);
 }
@@ -2498,10 +2511,16 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
 
     // TODO: when bias_is_broadcast is false, fuse bias to gemm
     using GatedActOutputType = std::conditional_t<use_w4afp8, BackBoneType, T>;
+    bool use_per_expert_act_scale = use_fp4 ? quant_params.fp4.fc2.use_per_expert_act_scale
+        : use_wfp4afp8            ? quant_params.fp8_mxfp4.fc2.use_per_expert_act_scale
+        : use_fp8                 ? quant_params.fp8.fc2_use_per_expert_act_scale
+                                  : false;
+
     doActivation<GatedActOutputType, UnfusedGemmOutputType>(reinterpret_cast<GatedActOutputType*>(output),
         static_cast<UnfusedGemmOutputType const*>(gemm_output), fc2_fp8_quant, fc1_expert_biases, bias_is_broadcast,
         expert_first_token_offset, num_experts_per_node, inter_size, num_rows, expanded_num_rows,
-        fc1_activation_type, quant_params.fp4.fc2.act_global_scale, fc2_fp4_act_flat, stream);
+        fc1_activation_type, quant_params.fp4.fc2.act_global_scale, use_per_expert_act_scale, fc2_fp4_act_flat,
+        stream);
 
     sync_check_cuda_error(stream);
 }
@@ -2522,9 +2541,11 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
         /*use_fused_moe*/ false, stream, config};
     gemm_runner.moeGemm(universal_input, TmaWarpSpecializedGroupedGemmInput{});
 
+    bool use_per_expert_act_scale = use_fp8 ? quant_params.fp8.fc2_use_per_expert_act_scale : false;
     doActivation<T, UnfusedGemmOutputType>(output, static_cast<UnfusedGemmOutputType const*>(intermediate_result),
         fc2_fp8_quant, fc1_expert_biases, bias_is_broadcast, expert_first_token_offset, num_experts_per_node,
-        inter_size, num_rows, expanded_num_rows, fc1_activation_type, nullptr, nullptr, stream);
+        inter_size, num_rows, expanded_num_rows, fc1_activation_type, nullptr, use_per_expert_act_scale, nullptr,
+        stream);
 
     sync_check_cuda_error(stream);
 }
@@ -2687,7 +2708,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
         loraBiasApplyFunc(static_cast<UnfusedGemmOutputType*>(gemm_output),
             static_cast<UnfusedGemmOutputType const*>(gemm_output), nullptr,
             static_cast<ScaleBiasType const*>(fc2_lora), false, expert_first_token_offset, num_experts_per_node,
-            hidden_size, num_rows, expanded_num_rows, ActivationType::Identity, nullptr, nullptr, stream);
+            hidden_size, num_rows, expanded_num_rows, ActivationType::Identity, nullptr, false, nullptr, stream);
         sync_check_cuda_error(stream);
     }
 
@@ -3129,10 +3150,13 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
     }
 
     using ExpandedActivationsType = std::conditional_t<use_w4afp8, BackBoneType, T>;
+    // Only NVFP4xNVFP4 supports FC1 per-expert act scale
+    bool use_per_expert_act_scale = use_fp4 ? quant_params.fp4.fc1.use_per_expert_act_scale : false;
    expandInputRowsKernelLauncher(input_activations, reinterpret_cast<ExpandedActivationsType*>(permuted_data_),
         token_topk_unpermuted_scales, permuted_token_final_scales_, permuted_source_token_ids_,
         expanded_source_row_to_expanded_dest_row, num_rows, hidden_size, experts_per_token, num_experts_per_node,
-        quant_params.fp4.fc1.act_global_scale, expert_first_token_offset_, fc1_fp4_act_scale_, input_sf, stream);
+        quant_params.fp4.fc1.act_global_scale, use_per_expert_act_scale, expert_first_token_offset_,
+        fc1_fp4_act_scale_, input_sf, stream);
 
     sync_check_cuda_error(stream);
 
@@ -3211,10 +3235,12 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
 
     auto alpha_scale_flat1 = use_fp4 ? quant_params.fp4.fc1.global_scale
         : use_wfp4afp8               ? quant_params.fp8_mxfp4.fc1.global_scale
-                                     : fp8_dequant1;
+        : use_fp8                    ? fp8_dequant1
+                                     : nullptr;
     auto alpha_scale_flat2 = use_fp4 ? quant_params.fp4.fc2.global_scale
         : use_wfp4afp8               ? quant_params.fp8_mxfp4.fc2.global_scale
-                                     : fp8_dequant2;
+        : use_fp8                    ? fp8_dequant2
+                                     : nullptr;
     if (!alpha_scale_flat1 && !alpha_scale_flat2)
     {
         layout_info1.alpha_scale_ptr_array = nullptr;
@@ -3380,6 +3406,7 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
     // fp8_mxfp4 memsets the scaling factors to 1.0f
     if (quant_params.fp8_mxfp4.fc1.weight_block_scale)
     {
+        // We are in FP8 x MXFP4 mode
         TLLM_CHECK(quant_params.fp8_mxfp4.fc2.weight_block_scale);
         TLLM_CHECK(fc1_fp4_act_scale_ != nullptr);
         TLLM_CHECK_WITH_INFO(fc1_fp4_act_scale_ == fc2_fp4_act_scale_,
Lines changed: 2 additions & 2 deletions
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c01175cbc8e003e8288e30ad2dc88c2c819147f4d435a5121460533141b04719
-size 64321452
+oid sha256:6d12357919fe6c63749a81e124afd60453153489a3f50cb44b41671d9b55f947
+size 64338696
Lines changed: 2 additions & 2 deletions
@@ -1,2 +1,2 @@
-a1180829a0d8fe772ff37934b72573bb41671e7ed76dfa3bd5cd449348b9683a libtensorrt_llm_internal_cutlass_kernels_static.a
-commit c767347ff934578193ee4bad58ba3b9398046245
+ad34c0f31247c880d60e2c8198093e8373cf0e1d3e8badee0424bfa607d6cd8e libtensorrt_llm_internal_cutlass_kernels_static.a
+commit bac309ac608d35d7d0144e594bf3e5fa8cfca796
