@@ -1237,7 +1237,8 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
12371237 ExpandedActivationsType* permuted_output, float const * unpermuted_scales, float * permuted_scales,
12381238 int const * expanded_dest_row_to_expanded_source_row, int * expanded_source_row_to_expanded_dest_row,
12391239 int64_t const num_rows, int64_t const cols, int64_t const k, float const * fc1_act_global_scale,
1240- int64_t const * expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
1240+ bool use_per_expert_act_scale, int64_t const * expert_first_token_offset,
1241+ TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
12411242 TmaWarpSpecializedGroupedGemmInput::ElementSF const * input_sf, int64_t const num_experts_per_node)
12421243{
12431244#ifdef ENABLE_FP4
@@ -1300,7 +1301,8 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
13001301 int64_t expert = findTotalEltsLessThanTarget (
13011302 expert_first_token_offset, num_experts_per_node, (int64_t ) expanded_dest_row + 1 )
13021303 - 1 ;
1303- float global_scale_val = fc1_act_global_scale ? *fc1_act_global_scale : 1 .0f ;
1304+ size_t act_scale_idx = use_per_expert_act_scale ? expert : 0 ;
1305+ float global_scale_val = fc1_act_global_scale ? fc1_act_global_scale[act_scale_idx] : 1 .0f ;
13041306 int64_t num_tokens_before_expert = expert_first_token_offset[expert];
13051307
13061308 for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride)
@@ -1315,6 +1317,7 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
13151317 }
13161318 else
13171319 {
1320+ assert (act_scale_idx == 0 && " Cannot use per-expert act scale for pre-quantized activations" );
13181321 writeSF (num_tokens_before_expert, expert, source_row, expanded_dest_row, elem_index, cols, num_rows,
13191322 fc1_act_sf_flat, input_sf);
13201323 dest_row_ptr[elem_index] = in_vec;
@@ -1345,7 +1348,7 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
13451348 ExpandedActivationsType* permuted_output, float const * unpermuted_scales, float * permuted_scales,
13461349 int const * expanded_dest_row_to_expanded_source_row, int * expanded_source_row_to_expanded_dest_row,
13471350 int64_t const num_rows, int64_t const cols, int const k, int const num_experts_per_node,
1348- float const * fc1_act_global_scale, int64_t * expert_first_token_offset,
1351+ float const * fc1_act_global_scale, bool use_per_expert_act_scale, int64_t * expert_first_token_offset,
13491352 TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
13501353 TmaWarpSpecializedGroupedGemmInput::ElementSF const * input_sf, cudaStream_t stream)
13511354{
@@ -1360,6 +1363,11 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
13601363 check_cuda_error (cudaMemsetAsync (
13611364 fc1_act_sf_flat, 0x0 , num_elems * sizeof (TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF), stream));
13621365 }
1366+ else
1367+ {
1368+ TLLM_CHECK_WITH_INFO (
1369+ !use_per_expert_act_scale, " Per-expert act scale for FC1 is only supported for FP4 activations" );
1370+ }
13631371#endif
13641372
13651373 static int const smCount = tensorrt_llm::common::getMultiProcessorCount ();
@@ -1380,7 +1388,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
13801388 config.attrs = attrs;
13811389 cudaLaunchKernelEx (&config, func, unpermuted_input, permuted_output, unpermuted_scales, permuted_scales,
13821390 expanded_dest_row_to_expanded_source_row, expanded_source_row_to_expanded_dest_row, num_rows, cols, k,
1383- fc1_act_global_scale, expert_first_token_offset, fc1_act_sf_flat, input_sf, num_experts_per_node);
1391+ fc1_act_global_scale, use_per_expert_act_scale, expert_first_token_offset, fc1_act_sf_flat, input_sf,
1392+ num_experts_per_node);
13841393}
13851394
13861395enum class ScaleMode : int
@@ -1681,7 +1690,8 @@ template <class T, class GemmOutputType, class ScaleBiasType, template <class> c
16811690__global__ void doActivationKernel (T* output, GemmOutputType const * gemm_result, float const * fp8_quant,
16821691 ScaleBiasType const * bias_ptr, bool bias_is_broadcast, int64_t const * expert_first_token_offset,
16831692 int num_experts_per_node, int64_t inter_size, int64_t max_tokens_per_expert, bool gated,
1684- float const * fc2_act_global_scale, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat)
1693+ float const * fc2_act_global_scale, bool use_per_expert_act_scale,
1694+ TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat)
16851695{
16861696#ifdef ENABLE_FP4
16871697 constexpr bool IsFP4 = std::is_same_v<T, __nv_fp4_e2m1>;
@@ -1705,16 +1715,17 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
17051715 size_t output_offset = token * inter_size;
17061716
17071717 int64_t expert = 0 ;
1708- if (bias_ptr || IsFP4)
1718+ if (bias_ptr || IsFP4 || use_per_expert_act_scale )
17091719 {
17101720 // TODO this is almost certainly faster as a linear scan
17111721 expert = findTotalEltsLessThanTarget (expert_first_token_offset, num_experts_per_node, token + 1 ) - 1 ;
17121722 }
17131723
1714- float const quant_scale = fp8_quant ? *fp8_quant : 1 .f ;
1724+ size_t act_scale_idx = use_per_expert_act_scale ? expert : 0 ;
1725+ float const quant_scale = fp8_quant ? fp8_quant[act_scale_idx] : 1 .f ;
17151726
17161727 // Some globals for FP4
1717- float global_scale_val = fc2_act_global_scale ? * fc2_act_global_scale : 1 .0f ;
1728+ float global_scale_val = fc2_act_global_scale ? fc2_act_global_scale[act_scale_idx] : 1 .0f ;
17181729 int64_t num_tokens_before_expert = IsFP4 ? expert_first_token_offset[expert] : 0 ;
17191730
17201731 size_t bias_offset = 0 ;
@@ -1790,7 +1801,7 @@ template <class T, class GemmOutputType, class ScaleBiasType>
17901801void doActivation (T* output, GemmOutputType const * gemm_result, float const * fp8_quant, ScaleBiasType const * bias,
17911802 bool bias_is_broadcast, int64_t const * expert_first_token_offset, int num_experts_per_node, int64_t inter_size,
17921803 int64_t num_tokens, int64_t expanded_num_tokens, ActivationType activation_type, float const * fc2_act_global_scale,
1793- TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, cudaStream_t stream)
1804+ bool use_per_expert_act_scale, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, cudaStream_t stream)
17941805{
17951806 static int const smCount = tensorrt_llm::common::getMultiProcessorCount ();
17961807 // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
@@ -1819,7 +1830,7 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
18191830 config.attrs = attrs;
18201831 cudaLaunchKernelEx (&config, fn, output, gemm_result, fp8_quant, bias, bias_is_broadcast, expert_first_token_offset,
18211832 num_experts_per_node, inter_size, num_tokens, isGatedActivation (activation_type), fc2_act_global_scale,
1822- fc2_act_sf_flat);
1833+ use_per_expert_act_scale, fc2_act_sf_flat);
18231834}
18241835
18251836// ============================== Lora Add Bias =================================
@@ -2346,9 +2357,11 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Ena
23462357
23472358 sync_check_cuda_error (stream);
23482359 constexpr bool bias_is_broadcast = true ;
2360+ constexpr bool use_per_expert_act_scale = false ;
23492361 doActivation<T, UnfusedGemmOutputType>(output, static_cast <UnfusedGemmOutputType const *>(gemm_output),
23502362 fc2_fp8_quant, fc1_expert_biases, bias_is_broadcast, expert_first_token_offset, num_experts_per_node,
2351- inter_size, num_rows, expanded_num_rows, fc1_activation_type, nullptr , nullptr , stream);
2363+ inter_size, num_rows, expanded_num_rows, fc1_activation_type, nullptr , use_per_expert_act_scale, nullptr ,
2364+ stream);
23522365
23532366 sync_check_cuda_error (stream);
23542367}
@@ -2498,10 +2511,16 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
24982511
24992512 // TODO: when bias_is_broadcast is false, fuse bias to gemm
25002513 using GatedActOutputType = std::conditional_t <use_w4afp8, BackBoneType, T>;
2514+ bool use_per_expert_act_scale = use_fp4 ? quant_params.fp4 .fc2 .use_per_expert_act_scale
2515+ : use_wfp4afp8 ? quant_params.fp8_mxfp4 .fc2 .use_per_expert_act_scale
2516+ : use_fp8 ? quant_params.fp8 .fc2_use_per_expert_act_scale
2517+ : false ;
2518+
25012519 doActivation<GatedActOutputType, UnfusedGemmOutputType>(reinterpret_cast <GatedActOutputType*>(output),
25022520 static_cast <UnfusedGemmOutputType const *>(gemm_output), fc2_fp8_quant, fc1_expert_biases, bias_is_broadcast,
25032521 expert_first_token_offset, num_experts_per_node, inter_size, num_rows, expanded_num_rows,
2504- fc1_activation_type, quant_params.fp4 .fc2 .act_global_scale , fc2_fp4_act_flat, stream);
2522+ fc1_activation_type, quant_params.fp4 .fc2 .act_global_scale , use_per_expert_act_scale, fc2_fp4_act_flat,
2523+ stream);
25052524
25062525 sync_check_cuda_error (stream);
25072526 }
@@ -2522,9 +2541,11 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
25222541 /* use_fused_moe*/ false , stream, config};
25232542 gemm_runner.moeGemm (universal_input, TmaWarpSpecializedGroupedGemmInput{});
25242543
2544+ bool use_per_expert_act_scale = use_fp8 ? quant_params.fp8 .fc2_use_per_expert_act_scale : false ;
25252545 doActivation<T, UnfusedGemmOutputType>(output, static_cast <UnfusedGemmOutputType const *>(intermediate_result),
25262546 fc2_fp8_quant, fc1_expert_biases, bias_is_broadcast, expert_first_token_offset, num_experts_per_node,
2527- inter_size, num_rows, expanded_num_rows, fc1_activation_type, nullptr , nullptr , stream);
2547+ inter_size, num_rows, expanded_num_rows, fc1_activation_type, nullptr , use_per_expert_act_scale, nullptr ,
2548+ stream);
25282549
25292550 sync_check_cuda_error (stream);
25302551 }
@@ -2687,7 +2708,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
26872708 loraBiasApplyFunc (static_cast <UnfusedGemmOutputType*>(gemm_output),
26882709 static_cast <UnfusedGemmOutputType const *>(gemm_output), nullptr ,
26892710 static_cast <ScaleBiasType const *>(fc2_lora), false , expert_first_token_offset, num_experts_per_node,
2690- hidden_size, num_rows, expanded_num_rows, ActivationType::Identity, nullptr , nullptr , stream);
2711+ hidden_size, num_rows, expanded_num_rows, ActivationType::Identity, nullptr , false , nullptr , stream);
26912712 sync_check_cuda_error (stream);
26922713 }
26932714
@@ -3129,10 +3150,13 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
31293150 }
31303151
31313152 using ExpandedActivationsType = std::conditional_t <use_w4afp8, BackBoneType, T>;
3153+ // Only NVFP4xNVFP4 supports FC1 per-expert act scale
3154+ bool use_per_expert_act_scale = use_fp4 ? quant_params.fp4 .fc1 .use_per_expert_act_scale : false ;
31323155 expandInputRowsKernelLauncher (input_activations, reinterpret_cast <ExpandedActivationsType*>(permuted_data_),
31333156 token_topk_unpermuted_scales, permuted_token_final_scales_, permuted_source_token_ids_,
31343157 expanded_source_row_to_expanded_dest_row, num_rows, hidden_size, experts_per_token, num_experts_per_node,
3135- quant_params.fp4 .fc1 .act_global_scale , expert_first_token_offset_, fc1_fp4_act_scale_, input_sf, stream);
3158+ quant_params.fp4 .fc1 .act_global_scale , use_per_expert_act_scale, expert_first_token_offset_,
3159+ fc1_fp4_act_scale_, input_sf, stream);
31363160
31373161 sync_check_cuda_error (stream);
31383162
@@ -3211,10 +3235,12 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
32113235
32123236 auto alpha_scale_flat1 = use_fp4 ? quant_params.fp4 .fc1 .global_scale
32133237 : use_wfp4afp8 ? quant_params.fp8_mxfp4 .fc1 .global_scale
3214- : fp8_dequant1;
3238+ : use_fp8 ? fp8_dequant1
3239+ : nullptr ;
32153240 auto alpha_scale_flat2 = use_fp4 ? quant_params.fp4 .fc2 .global_scale
32163241 : use_wfp4afp8 ? quant_params.fp8_mxfp4 .fc2 .global_scale
3217- : fp8_dequant2;
3242+ : use_fp8 ? fp8_dequant2
3243+ : nullptr ;
32183244 if (!alpha_scale_flat1 && !alpha_scale_flat2)
32193245 {
32203246 layout_info1.alpha_scale_ptr_array = nullptr ;
@@ -3380,6 +3406,7 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
33803406 // fp8_mxfp4 memsets the scaling factors to 1.0f
33813407 if (quant_params.fp8_mxfp4 .fc1 .weight_block_scale )
33823408 {
3409+ // We are in FP8 x MXFP4 mode
33833410 TLLM_CHECK (quant_params.fp8_mxfp4 .fc2 .weight_block_scale );
33843411 TLLM_CHECK (fc1_fp4_act_scale_ != nullptr );
33853412 TLLM_CHECK_WITH_INFO (fc1_fp4_act_scale_ == fc2_fp4_act_scale_,
0 commit comments