From faed92ddea639b8b4fd792a20b2f9c9785f57b54 Mon Sep 17 00:00:00 2001 From: Daniel Stokes <40156487+djns99@users.noreply.github.com> Date: Thu, 29 May 2025 18:43:25 -0700 Subject: [PATCH] feat: Add support for per expert activation scaling factors Signed-off-by: Daniel Stokes <40156487+djns99@users.noreply.github.com> --- .../mixtureOfExpertsBackendBenchmarkFixture.h | 2 +- .../cutlass_kernels/include/moe_kernels.h | 43 ++++++---- .../cutlass_kernels/moe_gemm/moe_kernels.cu | 61 ++++++++++---- ...llm_internal_cutlass_kernels_static.tar.xz | 4 +- .../aarch64-linux-gnu/version.txt | 4 +- .../include/moe_kernels.h | 46 ++++++----- ...llm_internal_cutlass_kernels_static.tar.xz | 4 +- .../x86_64-linux-gnu/version.txt | 4 +- cpp/tensorrt_llm/thop/moeOp.cpp | 35 ++++++-- .../kernels/mixtureOfExpertsTest.cu | 79 +++++++++++++++---- 10 files changed, 197 insertions(+), 85 deletions(-) diff --git a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h index 4827d5c4938..816af1f0949 100644 --- a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h +++ b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h @@ -581,7 +581,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture auto func = NVFP4 ? QuantParams::FP4 : QuantParams::FP8MXFP4; mQuantParams = func(mExpertFP4ActScale1, mExpertFP4WeightSf1, mExpertFP4GlobalScale1, mExpertFP4ActScale2, - mExpertFP4WeightSf2, mExpertFP4GlobalScale2); + mExpertFP4WeightSf2, mExpertFP4GlobalScale2, false, false); } mSelectedExperts = allocBuffer(mTotalTokens * mK); diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index b1826eb39bb..abb4911a8e8 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -210,8 +210,9 @@ struct QuantParams // FP8 quantization params struct { + bool fc2_use_per_expert_act_scale = false; float const* dequant_fc1 = nullptr; // (num_experts_per_node, ) - float const* quant_fc2 = nullptr; // (1, ) + float const* quant_fc2 = nullptr; // (1, ) or (num_experts_per_node, ) based on fc2_use_per_expert_act_scale float const* dequant_fc2 = nullptr; // (num_experts_per_node, ) float const* quant_final = nullptr; // (1, ) float const* dequant_input = nullptr; // (1, ) @@ -223,10 +224,12 @@ struct QuantParams { struct GemmInputs { - float const* act_global_scale = nullptr; // (1, ) + bool use_per_expert_act_scale = false; + float const* act_global_scale + = nullptr; // (1, ) or (num_experts_per_node, ) based on use_per_expert_act_scale TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* weight_block_scale - = nullptr; // (experts, n, k / 32) - float const* global_scale = nullptr; // (num_experts_per_node, ) + = nullptr; // (experts, n, k / 32) + float const* global_scale = nullptr; // (num_experts_per_node, ) }; GemmInputs fc1; @@ -238,10 +241,13 @@ struct QuantParams { struct GemmInputs { - float const* act_global_scale = nullptr; // (1, ) + bool use_per_expert_act_scale = false; + + float const* act_global_scale + = nullptr; // (1, ) or (num_experts_per_node, ) based on use_per_expert_act_scale TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* weight_block_scale - = nullptr; // (experts, n, k / 16) - float const* global_scale = nullptr; // (num_experts_per_node, ) + = nullptr; // (experts, n, k / 16) + float const* global_scale = nullptr; // (num_experts_per_node, ) }; GemmInputs fc1; @@ -287,10 +293,11 @@ struct QuantParams } static QuantParams FP8(float const* dequant_fc1, float const* quant_fc2, float const* dequant_fc2, - float const* quant_final = nullptr, float const* dequant_input = nullptr) + float const* quant_final = nullptr, float const* dequant_input = nullptr, + bool fc2_use_per_expert_act_scale = false) { QuantParams qp; - qp.fp8 = {dequant_fc1, quant_fc2, dequant_fc2, quant_final, dequant_input}; + qp.fp8 = {fc2_use_per_expert_act_scale, dequant_fc1, quant_fc2, dequant_fc2, quant_final, dequant_input}; return qp; } @@ -299,12 +306,14 @@ struct QuantParams float const* fc1_global_scale, // float const* fc2_act_global_scale, TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* fc2_weight_block_scale, - float const* fc2_global_scale // - ) + float const* fc2_global_scale, // + bool fc1_use_per_expert_act_scale = false, bool fc2_use_per_expert_act_scale = false) { QuantParams qp; - qp.fp8_mxfp4.fc1 = {fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale}; - qp.fp8_mxfp4.fc2 = {fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale}; + qp.fp8_mxfp4.fc1 + = {fc1_use_per_expert_act_scale, fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale}; + qp.fp8_mxfp4.fc2 + = {fc2_use_per_expert_act_scale, fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale}; return qp; } @@ -313,12 +322,12 @@ struct QuantParams float const* fc1_global_scale, // float const* fc2_act_global_scale, TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* fc2_weight_block_scale, - float const* fc2_global_scale // - ) + float const* fc2_global_scale, // + bool fc1_use_per_expert_act_scale = false, bool fc2_use_per_expert_act_scale = false) { QuantParams qp; - qp.fp4.fc1 = {fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale}; - qp.fp4.fc2 = {fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale}; + qp.fp4.fc1 = {fc1_use_per_expert_act_scale, fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale}; + qp.fp4.fc2 = {fc2_use_per_expert_act_scale, fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale}; return qp; } diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu index 40a6da8977a..682b7ec2ec8 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu @@ -1237,7 +1237,8 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales, int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, int64_t const num_rows, int64_t const cols, int64_t const k, float const* fc1_act_global_scale, - int64_t const* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, + bool use_per_expert_act_scale, int64_t const* expert_first_token_offset, + TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, int64_t const num_experts_per_node) { #ifdef ENABLE_FP4 @@ -1300,7 +1301,8 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp int64_t expert = findTotalEltsLessThanTarget( expert_first_token_offset, num_experts_per_node, (int64_t) expanded_dest_row + 1) - 1; - float global_scale_val = fc1_act_global_scale ? *fc1_act_global_scale : 1.0f; + size_t act_scale_idx = use_per_expert_act_scale ? expert : 0; + float global_scale_val = fc1_act_global_scale ? fc1_act_global_scale[act_scale_idx] : 1.0f; int64_t num_tokens_before_expert = expert_first_token_offset[expert]; for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) @@ -1315,6 +1317,7 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp } else { + assert(act_scale_idx == 0 && "Cannot use per-expert act scale for pre-quantized activations"); writeSF(num_tokens_before_expert, expert, source_row, expanded_dest_row, elem_index, cols, num_rows, fc1_act_sf_flat, input_sf); dest_row_ptr[elem_index] = in_vec; @@ -1345,7 +1348,7 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales, int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, int64_t const num_rows, int64_t const cols, int const k, int const num_experts_per_node, - float const* fc1_act_global_scale, int64_t* expert_first_token_offset, + float const* fc1_act_global_scale, bool use_per_expert_act_scale, int64_t* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream) { @@ -1360,6 +1363,11 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, check_cuda_error(cudaMemsetAsync( fc1_act_sf_flat, 0x0, num_elems * sizeof(TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF), stream)); } + else + { + TLLM_CHECK_WITH_INFO( + !use_per_expert_act_scale, "Per-expert act scale for FC1 is only supported for FP4 activations"); + } #endif static int const smCount = tensorrt_llm::common::getMultiProcessorCount(); @@ -1380,7 +1388,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, config.attrs = attrs; cudaLaunchKernelEx(&config, func, unpermuted_input, permuted_output, unpermuted_scales, permuted_scales, expanded_dest_row_to_expanded_source_row, expanded_source_row_to_expanded_dest_row, num_rows, cols, k, - fc1_act_global_scale, expert_first_token_offset, fc1_act_sf_flat, input_sf, num_experts_per_node); + fc1_act_global_scale, use_per_expert_act_scale, expert_first_token_offset, fc1_act_sf_flat, input_sf, + num_experts_per_node); } enum class ScaleMode : int @@ -1681,7 +1690,8 @@ template c __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, float const* fp8_quant, ScaleBiasType const* bias_ptr, bool bias_is_broadcast, int64_t const* expert_first_token_offset, int num_experts_per_node, int64_t inter_size, int64_t max_tokens_per_expert, bool gated, - float const* fc2_act_global_scale, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat) + float const* fc2_act_global_scale, bool use_per_expert_act_scale, + TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat) { #ifdef ENABLE_FP4 constexpr bool IsFP4 = std::is_same_v; @@ -1705,16 +1715,17 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, size_t output_offset = token * inter_size; int64_t expert = 0; - if (bias_ptr || IsFP4) + if (bias_ptr || IsFP4 || use_per_expert_act_scale) { // TODO this is almost certainly faster as a linear scan expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - 1; } - float const quant_scale = fp8_quant ? *fp8_quant : 1.f; + size_t act_scale_idx = use_per_expert_act_scale ? expert : 0; + float const quant_scale = fp8_quant ? fp8_quant[act_scale_idx] : 1.f; // Some globals for FP4 - float global_scale_val = fc2_act_global_scale ? *fc2_act_global_scale : 1.0f; + float global_scale_val = fc2_act_global_scale ? fc2_act_global_scale[act_scale_idx] : 1.0f; int64_t num_tokens_before_expert = IsFP4 ? expert_first_token_offset[expert] : 0; size_t bias_offset = 0; @@ -1790,7 +1801,7 @@ template void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8_quant, ScaleBiasType const* bias, bool bias_is_broadcast, int64_t const* expert_first_token_offset, int num_experts_per_node, int64_t inter_size, int64_t num_tokens, int64_t expanded_num_tokens, ActivationType activation_type, float const* fc2_act_global_scale, - TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, cudaStream_t stream) + bool use_per_expert_act_scale, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, cudaStream_t stream) { static int const smCount = tensorrt_llm::common::getMultiProcessorCount(); // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200). @@ -1819,7 +1830,7 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 config.attrs = attrs; cudaLaunchKernelEx(&config, fn, output, gemm_result, fp8_quant, bias, bias_is_broadcast, expert_first_token_offset, num_experts_per_node, inter_size, num_tokens, isGatedActivation(activation_type), fc2_act_global_scale, - fc2_act_sf_flat); + use_per_expert_act_scale, fc2_act_sf_flat); } // ============================== Lora Add Bias ================================= @@ -2346,9 +2357,11 @@ void CutlassMoeFCRunner(output, static_cast(gemm_output), fc2_fp8_quant, fc1_expert_biases, bias_is_broadcast, expert_first_token_offset, num_experts_per_node, - inter_size, num_rows, expanded_num_rows, fc1_activation_type, nullptr, nullptr, stream); + inter_size, num_rows, expanded_num_rows, fc1_activation_type, nullptr, use_per_expert_act_scale, nullptr, + stream); sync_check_cuda_error(stream); } @@ -2498,10 +2511,16 @@ void CutlassMoeFCRunner; + bool use_per_expert_act_scale = use_fp4 ? quant_params.fp4.fc2.use_per_expert_act_scale + : use_wfp4afp8 ? quant_params.fp8_mxfp4.fc2.use_per_expert_act_scale + : use_fp8 ? quant_params.fp8.fc2_use_per_expert_act_scale + : false; + doActivation(reinterpret_cast(output), static_cast(gemm_output), fc2_fp8_quant, fc1_expert_biases, bias_is_broadcast, expert_first_token_offset, num_experts_per_node, inter_size, num_rows, expanded_num_rows, - fc1_activation_type, quant_params.fp4.fc2.act_global_scale, fc2_fp4_act_flat, stream); + fc1_activation_type, quant_params.fp4.fc2.act_global_scale, use_per_expert_act_scale, fc2_fp4_act_flat, + stream); sync_check_cuda_error(stream); } @@ -2522,9 +2541,11 @@ void CutlassMoeFCRunner(output, static_cast(intermediate_result), fc2_fp8_quant, fc1_expert_biases, bias_is_broadcast, expert_first_token_offset, num_experts_per_node, - inter_size, num_rows, expanded_num_rows, fc1_activation_type, nullptr, nullptr, stream); + inter_size, num_rows, expanded_num_rows, fc1_activation_type, nullptr, use_per_expert_act_scale, nullptr, + stream); sync_check_cuda_error(stream); } @@ -2687,7 +2708,7 @@ void CutlassMoeFCRunner(gemm_output), static_cast(gemm_output), nullptr, static_cast(fc2_lora), false, expert_first_token_offset, num_experts_per_node, - hidden_size, num_rows, expanded_num_rows, ActivationType::Identity, nullptr, nullptr, stream); + hidden_size, num_rows, expanded_num_rows, ActivationType::Identity, nullptr, false, nullptr, stream); sync_check_cuda_error(stream); } @@ -3129,10 +3150,13 @@ void CutlassMoeFCRunner; + // Only NVFP4xNVFP4 supports FC1 per-expert act scale + bool use_per_expert_act_scale = use_fp4 ? quant_params.fp4.fc1.use_per_expert_act_scale : false; expandInputRowsKernelLauncher(input_activations, reinterpret_cast(permuted_data_), token_topk_unpermuted_scales, permuted_token_final_scales_, permuted_source_token_ids_, expanded_source_row_to_expanded_dest_row, num_rows, hidden_size, experts_per_token, num_experts_per_node, - quant_params.fp4.fc1.act_global_scale, expert_first_token_offset_, fc1_fp4_act_scale_, input_sf, stream); + quant_params.fp4.fc1.act_global_scale, use_per_expert_act_scale, expert_first_token_offset_, + fc1_fp4_act_scale_, input_sf, stream); sync_check_cuda_error(stream); @@ -3211,10 +3235,12 @@ CutlassMoeFCRunner:: auto alpha_scale_flat1 = use_fp4 ? quant_params.fp4.fc1.global_scale : use_wfp4afp8 ? quant_params.fp8_mxfp4.fc1.global_scale - : fp8_dequant1; + : use_fp8 ? fp8_dequant1 + : nullptr; auto alpha_scale_flat2 = use_fp4 ? quant_params.fp4.fc2.global_scale : use_wfp4afp8 ? quant_params.fp8_mxfp4.fc2.global_scale - : fp8_dequant2; + : use_fp8 ? fp8_dequant2 + : nullptr; if (!alpha_scale_flat1 && !alpha_scale_flat2) { layout_info1.alpha_scale_ptr_array = nullptr; @@ -3380,6 +3406,7 @@ CutlassMoeFCRunner:: // fp8_mxfp4 memsets the scaling factors to 1.0f if (quant_params.fp8_mxfp4.fc1.weight_block_scale) { + // We are in FP8 x MXFP4 mode TLLM_CHECK(quant_params.fp8_mxfp4.fc2.weight_block_scale); TLLM_CHECK(fc1_fp4_act_scale_ != nullptr); TLLM_CHECK_WITH_INFO(fc1_fp4_act_scale_ == fc2_fp4_act_scale_, diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz index ce594851068..1a800b30dce 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c01175cbc8e003e8288e30ad2dc88c2c819147f4d435a5121460533141b04719 -size 64321452 +oid sha256:6d12357919fe6c63749a81e124afd60453153489a3f50cb44b41671d9b55f947 +size 64338696 diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt index 240af5e5cfa..62c9a58c081 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt @@ -1,2 +1,2 @@ -a1180829a0d8fe772ff37934b72573bb41671e7ed76dfa3bd5cd449348b9683a libtensorrt_llm_internal_cutlass_kernels_static.a -commit c767347ff934578193ee4bad58ba3b9398046245 +ad34c0f31247c880d60e2c8198093e8373cf0e1d3e8badee0424bfa607d6cd8e libtensorrt_llm_internal_cutlass_kernels_static.a +commit bac309ac608d35d7d0144e594bf3e5fa8cfca796 diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_kernels.h b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_kernels.h index 42821ab4fad..f7d17057099 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_kernels.h +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_kernels.h @@ -159,8 +159,9 @@ struct QuantParams // FP8 quantization params struct { + bool fc2_use_per_expert_act_scale = false; float const* dequant_fc1 = nullptr; // (num_experts_per_node, ) - float const* quant_fc2 = nullptr; // (1, ) + float const* quant_fc2 = nullptr; // (1, ) or (num_experts_per_node, ) based on fc2_use_per_expert_act_scale float const* dequant_fc2 = nullptr; // (num_experts_per_node, ) float const* quant_final = nullptr; // (1, ) float const* dequant_input = nullptr; // (1, ) @@ -172,10 +173,12 @@ struct QuantParams { struct GemmInputs { - float const* act_global_scale = nullptr; // (1, ) + bool use_per_expert_act_scale = false; + float const* act_global_scale + = nullptr; // (1, ) or (num_experts_per_node, ) based on use_per_expert_act_scale TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* weight_block_scale - = nullptr; // (experts, n, k / 32) - float const* global_scale = nullptr; // (num_experts_per_node, ) + = nullptr; // (experts, n, k / 32) + float const* global_scale = nullptr; // (num_experts_per_node, ) }; GemmInputs fc1; @@ -187,10 +190,12 @@ struct QuantParams { struct GemmInputs { - float const* act_global_scale = nullptr; // (1, ) + bool use_per_expert_act_scale = false; + float const* act_global_scale + = nullptr; // (1, ) or (num_experts_per_node, ) based on use_per_expert_act_scale TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* weight_block_scale - = nullptr; // (experts, n, k / 16) - float const* global_scale = nullptr; // (num_experts_per_node, ) + = nullptr; // (experts, n, k / 16) + float const* global_scale = nullptr; // (num_experts_per_node, ) }; GemmInputs fc1; @@ -236,10 +241,11 @@ struct QuantParams } static QuantParams FP8(float const* dequant_fc1, float const* quant_fc2, float const* dequant_fc2, - float const* quant_final = nullptr, float const* dequant_input = nullptr) + float const* quant_final = nullptr, float const* dequant_input = nullptr, + bool fc2_use_per_expert_act_scale = false) { QuantParams qp; - qp.fp8 = {dequant_fc1, quant_fc2, dequant_fc2, quant_final, dequant_input}; + qp.fp8 = {fc2_use_per_expert_act_scale, dequant_fc1, quant_fc2, dequant_fc2, quant_final, dequant_input}; return qp; } @@ -248,12 +254,14 @@ struct QuantParams float const* fc1_global_scale, // float const* fc2_act_global_scale, TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* fc2_weight_block_scale, - float const* fc2_global_scale // - ) + float const* fc2_global_scale, // + bool fc1_use_per_expert_act_scale = false, bool fc2_use_per_expert_act_scale = false) { QuantParams qp; - qp.fp8_mxfp4.fc1 = {fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale}; - qp.fp8_mxfp4.fc2 = {fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale}; + qp.fp8_mxfp4.fc1 + = {fc1_use_per_expert_act_scale, fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale}; + qp.fp8_mxfp4.fc2 + = {fc2_use_per_expert_act_scale, fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale}; return qp; } @@ -262,12 +270,12 @@ struct QuantParams float const* fc1_global_scale, // float const* fc2_act_global_scale, TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* fc2_weight_block_scale, - float const* fc2_global_scale // - ) + float const* fc2_global_scale, // + bool fc1_use_per_expert_act_scale = false, bool fc2_use_per_expert_act_scale = false) { QuantParams qp; - qp.fp4.fc1 = {fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale}; - qp.fp4.fc2 = {fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale}; + qp.fp4.fc1 = {fc1_use_per_expert_act_scale, fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale}; + qp.fp4.fc2 = {fc2_use_per_expert_act_scale, fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale}; return qp; } @@ -760,8 +768,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface QuantParams& quant_params, cudaStream_t stream); T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales, - int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, - int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, cudaStream_t stream); + int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, + cudaStream_t stream); CubKeyValueSorter sorter_; MoeGemmRunner moe_gemm_runner_; diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz index 518b2531ac1..2178f48db9c 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a139d8316f640da7c2d10cf5461cb0d0d9462d97f00748467ca4202c896a6187 -size 63833516 +oid sha256:53b6f54a21bd547c0da17e3723b7822d4ee16b66b66a545948c0cbee5760bf65 +size 63835444 diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt index d48c4297480..721c4d5e522 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt @@ -1,2 +1,2 @@ -e7130e36217c1df0d281788fc87764945d9c308bef11ad61b3b1a49c7d41c8af libtensorrt_llm_internal_cutlass_kernels_static.a -commit c767347ff934578193ee4bad58ba3b9398046245 +21c59ede16aa448b6135327bd0f95e72a6e614f219935b8f67fe635b3cb4b38b libtensorrt_llm_internal_cutlass_kernels_static.a +commit bac309ac608d35d7d0144e594bf3e5fa8cfca796 diff --git a/cpp/tensorrt_llm/thop/moeOp.cpp b/cpp/tensorrt_llm/thop/moeOp.cpp index 3d36c53f932..b3f9ef876e5 100644 --- a/cpp/tensorrt_llm/thop/moeOp.cpp +++ b/cpp/tensorrt_llm/thop/moeOp.cpp @@ -632,22 +632,28 @@ class FusedMoeRunner : public torch::CustomClassHolder auto const fc2_dequant = quant_scales.value()[2]; auto const fc1_input_dequant = quant_scales.value()[3]; + // Check types CHECK_INPUT(fc1_dequant, c10::ScalarType::Float); CHECK_INPUT(fc2_quant, c10::ScalarType::Float); CHECK_INPUT(fc2_dequant, c10::ScalarType::Float); CHECK_INPUT(fc1_input_dequant, c10::ScalarType::Float); + // Check ranks TORCH_CHECK(fc1_dequant.dim() == 1, "fc1 dequant must be 1D"); - TORCH_CHECK(fc2_quant.dim() == 0, "fc2 quant must be a scalar tensor"); + TORCH_CHECK(fc2_quant.dim() == 0 || fc2_quant.dim() == 1, "fc2 quant must be a scalar or 1-D tensor"); TORCH_CHECK(fc2_dequant.dim() == 1, "fc2 quant must be 1D"); TORCH_CHECK(fc1_input_dequant.dim() == 0, "fc1 input dequant must be a scalar tensor"); + // Check shapes TORCH_CHECK( fc1_dequant.sizes()[0] == num_experts_on_rank, "fc1 dequant size must be (num_experts_on_rank,)"); + TORCH_CHECK(fc2_quant.dim() == 0 || fc2_quant.sizes()[0] == num_experts_on_rank, + "fc2 quant must be scalar or (num_experts_on_rank,)"); TORCH_CHECK( fc2_dequant.sizes()[0] == num_experts_on_rank, "fc2 dequant size must be (num_experts_on_rank,)"); return kernels::QuantParams::FP8(static_cast(fc1_dequant.data_ptr()), static_cast(fc2_quant.data_ptr()), static_cast(fc2_dequant.data_ptr()), - /* fp8 output quant scale */ nullptr, static_cast(fc1_input_dequant.data_ptr())); + /* fp8 output quant scale */ nullptr, static_cast(fc1_input_dequant.data_ptr()), + fc2_quant.dim() == 1); } else if (isWFp4AFp8Quant()) @@ -663,16 +669,20 @@ class FusedMoeRunner : public torch::CustomClassHolder // The input for scale fc1_weight_block / fc2_weight_block is packed into INT32 constexpr int FP8_PER_INT32 = 4; + // Check types CHECK_INPUT(fc1_weight_block, c10::ScalarType::Int); CHECK_INPUT(fc1_global, c10::ScalarType::Float); CHECK_INPUT(fc2_act_global, c10::ScalarType::Float); CHECK_INPUT(fc2_weight_block, c10::ScalarType::Int); CHECK_INPUT(fc2_global, c10::ScalarType::Float); + // Check ranks TORCH_CHECK(fc1_weight_block.dim() == 3, "fc1 weight block must be #D"); TORCH_CHECK(fc1_global.dim() == 1, "fc1 global must be 1D"); - TORCH_CHECK(fc2_act_global.dim() == 0, "fc2 act global must be a scalar tensor"); + TORCH_CHECK(fc2_act_global.dim() == 0 || fc2_act_global.dim() == 1, + "fc2 act global must be a scalar or 1-D tensor"); TORCH_CHECK(fc2_weight_block.dim() == 3, "fc2 weight block must be 3D"); TORCH_CHECK(fc2_global.dim() == 1, "fc2 global must be 1D"); + // Check shapes TORCH_CHECK(fc1_weight_block.sizes()[0] == num_experts_on_rank && fc1_weight_block.sizes()[1] == inter_size * 2 && fc1_weight_block.sizes()[2] * FP8_PER_INT32 @@ -681,6 +691,8 @@ class FusedMoeRunner : public torch::CustomClassHolder "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 // " "block_scale_vector_size)"); TORCH_CHECK(fc1_global.sizes()[0] == num_experts_on_rank, "fc1 global size must be (num_experts_on_rank,)"); + TORCH_CHECK(fc2_act_global.dim() == 0 || fc2_act_global.sizes()[0] == num_experts_on_rank, + "fc2 act global must be scalar or (num_experts_on_rank,)"); TORCH_CHECK(fc2_weight_block.sizes()[0] == num_experts_on_rank && fc2_weight_block.sizes()[1] == hidden_size && fc2_weight_block.sizes()[2] * FP8_PER_INT32 * TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize @@ -693,7 +705,7 @@ class FusedMoeRunner : public torch::CustomClassHolder static_cast(fc1_weight_block.data_ptr()), static_cast(fc1_global.data_ptr()), static_cast(fc2_act_global.data_ptr()), static_cast(fc2_weight_block.data_ptr()), - static_cast(fc2_global.data_ptr())); + static_cast(fc2_global.data_ptr()), false, fc2_act_global.dim() == 1); } else if (isNvfp4Quant()) { @@ -709,18 +721,25 @@ class FusedMoeRunner : public torch::CustomClassHolder // The input for scale fc1_weight_block / fc2_weight_block is packed into INT32 constexpr int FP8_PER_INT32 = 4; + // Check types CHECK_INPUT(fc1_act_global, c10::ScalarType::Float); CHECK_INPUT(fc1_weight_block, c10::ScalarType::Int); CHECK_INPUT(fc1_global, c10::ScalarType::Float); CHECK_INPUT(fc2_act_global, c10::ScalarType::Float); CHECK_INPUT(fc2_weight_block, c10::ScalarType::Int); CHECK_INPUT(fc2_global, c10::ScalarType::Float); - TORCH_CHECK(fc1_act_global.dim() == 0, "fc1 act global must be a scalar tensor"); + // Check ranks + TORCH_CHECK(fc1_act_global.dim() == 0 || fc1_act_global.dim() == 1, + "fc1 act global must be a scalar or 1-D tensor"); TORCH_CHECK(fc1_weight_block.dim() == 3, "fc1 weight block must be #D"); TORCH_CHECK(fc1_global.dim() == 1, "fc1 global must be 1D"); - TORCH_CHECK(fc2_act_global.dim() == 0, "fc2 act global must be a scalar tensor"); + TORCH_CHECK(fc2_act_global.dim() == 0 || fc2_act_global.dim() == 1, + "fc2 act global must be a scalar or 1-D tensor"); TORCH_CHECK(fc2_weight_block.dim() == 3, "fc2 weight block must be 3D"); TORCH_CHECK(fc2_global.dim() == 1, "fc2 global must be 1D"); + // Check shapes + TORCH_CHECK(fc1_act_global.dim() == 0 || fc1_act_global.sizes()[0] == num_experts_on_rank, + "fc1 act global must be scalar or (num_experts_on_rank,)"); TORCH_CHECK(fc1_weight_block.sizes()[0] == num_experts_on_rank && fc1_weight_block.sizes()[1] == inter_size * 2 && fc1_weight_block.sizes()[2] * FP8_PER_INT32 @@ -729,6 +748,8 @@ class FusedMoeRunner : public torch::CustomClassHolder "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 // " "block_scale_vector_size)"); TORCH_CHECK(fc1_global.sizes()[0] == num_experts_on_rank, "fc1 global size must be (num_experts_on_rank,)"); + TORCH_CHECK(fc2_act_global.dim() == 0 || fc2_act_global.sizes()[0] == num_experts_on_rank, + "fc2 act global must be scalar or (num_experts_on_rank,)"); TORCH_CHECK(fc2_weight_block.sizes()[0] == num_experts_on_rank && fc2_weight_block.sizes()[1] == hidden_size && fc2_weight_block.sizes()[2] * FP8_PER_INT32 * TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize @@ -741,7 +762,7 @@ class FusedMoeRunner : public torch::CustomClassHolder static_cast(fc1_weight_block.data_ptr()), static_cast(fc1_global.data_ptr()), static_cast(fc2_act_global.data_ptr()), static_cast(fc2_weight_block.data_ptr()), - static_cast(fc2_global.data_ptr())); + static_cast(fc2_global.data_ptr()), fc1_act_global.dim() == 1, fc2_act_global.dim() == 1); } else if (mUseDeepSeekFP8BlockScaling) { diff --git a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu index 1cd26b3b910..b5cc8c5f2e0 100644 --- a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu +++ b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu @@ -306,6 +306,9 @@ protected: bool mUseLora = false; bool mUsePrequantScale = false; + // Run tests with per-expert act scale + bool mUsePerExpertActScale = true; + bool mIsGated = false; int64_t mGatedMultiplier = 1; int64_t mGroupSize = -1; @@ -480,12 +483,12 @@ protected: { // FP4 uses the same logic as FP8 to generate the global scales mExpertFPXScale1 = allocBuffer(mNumExperts); - mExpertFPXScale2 = allocBuffer(1); + mExpertFPXScale2 = allocBuffer(mNumExperts); // mNumExperts or 1 mExpertFPXScale3 = allocBuffer(mNumExperts); if (ANY_FP4) { - mExpertFP4ActGlobalScale1 = allocBuffer(1); + mExpertFP4ActGlobalScale1 = allocBuffer(mNumExperts); // mNumExperts or 1 mExpertFP4WeightGlobalScale1 = allocBuffer(mNumExperts); mExpertFP4WeightGlobalScale2 = allocBuffer(mNumExperts); } @@ -665,23 +668,37 @@ protected: float scaleAct1 = getFPXActScalar(max_input); float maxFC1Output = calcMLPVal(max_input, maxIndex) / maxW2; - float scaleAct2 = getFPXActScalar(maxFC1Output); + + std::vector scales_1; + std::vector scales_2; + std::vector scales_3; + if (mUsePerExpertActScale) + { + scales_2 = std::vector(mNumExperts); + for (int i = 0; i < mNumExperts; i++) + { + float maxExpertOutput = calcMLPVal(max_input, i) / applyExpertShift(mExpertWDiag2, i); + float scaleAct2 = getFPXActScalar(maxExpertOutput); + scales_2[i] = scaleAct2; + } + } + else + { + float scaleAct2 = getFPXActScalar(maxFC1Output); + scales_2 = std::vector(mNumExperts, scaleAct2); + } ASSERT_NE(mExpertFPXScale1, nullptr); ASSERT_NE(mExpertFPXScale2, nullptr); ASSERT_NE(mExpertFPXScale3, nullptr); - std::vector scales_1; - std::vector scales_2; - std::vector scales_3; if (ANY_FP4) { std::vector scale_global_w1(mNumExperts); std::vector scale_global_w2(mNumExperts); - std::vector scales_0(1, scaleAct1); + std::vector scales_0(mUsePerExpertActScale && NVFP4 ? mNumExperts : 1, scaleAct1); scales_1 = std::vector(mNumExperts); - scales_2 = std::vector(1, scaleAct2); scales_3 = std::vector(mNumExperts); for (int i = 0; i < mNumExperts; i++) @@ -695,7 +712,7 @@ protected: // TODO Per expert scaling factors scales_1[i] = 1.f / (scaleAct1 * scaleW1); - scales_3[i] = 1.f / (scaleAct2 * scaleW2); + scales_3[i] = 1.f / (scales_2[i] * scaleW2); } ASSERT_NE(mExpertFP4ActGlobalScale1, nullptr); @@ -713,8 +730,17 @@ protected: mFP8WeightScalar1 = scaleW1; mFP8WeightScalar2 = scaleW2; scales_1 = std::vector(mNumExperts, 1.f / (scaleW1 * scaleAct1)); - scales_2 = std::vector(1, scaleAct2); - scales_3 = std::vector(mNumExperts, 1.f / (scaleW2 * scaleAct2)); + scales_3 = std::vector(mNumExperts); + + for (int i = 0; i < mNumExperts; i++) + { + scales_3[i] = 1.f / (scaleW2 * scales_2[i]); + } + } + + if (!mUsePerExpertActScale) + { + scales_2.resize(1); } check_cuda_error(cudaMemcpyAsync(mExpertFPXScale1, scales_1.data(), scales_1.size() * sizeof(float), @@ -893,6 +919,10 @@ protected: ep_scale_1 = mExpertFPXScale1 + experts_per_node * parallelism_config.ep_rank; ep_scale_3 = mExpertFPXScale3 + experts_per_node * parallelism_config.ep_rank; } + if (mUsePerExpertActScale) + { + ep_scale_2 = mExpertFPXScale2 + experts_per_node * parallelism_config.ep_rank; + } // Slice weights for TP void* scale_1 = ep_scale_1; @@ -1039,18 +1069,22 @@ protected: else if (FP8) { ASSERT_TRUE(scale1_ptr && scale2_ptr && scale3_ptr); - quant_params = QuantParams::FP8(static_cast(scale1_ptr), - static_cast(scale2_ptr), static_cast(scale3_ptr)); + quant_params + = QuantParams::FP8(static_cast(scale1_ptr), static_cast(scale2_ptr), + static_cast(scale3_ptr), nullptr, nullptr, mUsePerExpertActScale); } else if (ANY_FP4) { ASSERT_TRUE(mExpertFP4ActGlobalScale1); ASSERT_TRUE(mFP4ScalingFactorsW1 && mFP4ScalingFactorsW2); ASSERT_TRUE(scale1_ptr && scale2_ptr && scale3_ptr); + auto fc1_sf_offset = mUsePerExpertActScale && NVFP4 + ? mNumExperts / parallelism_config.ep_size * parallelism_config.ep_rank + : 0; auto constructor = NVFP4 ? &QuantParams::FP4 : &QuantParams::FP8MXFP4; - quant_params - = constructor(mExpertFP4ActGlobalScale1, mFP4ScalingFactorsW1, static_cast(scale1_ptr), - static_cast(scale2_ptr), mFP4ScalingFactorsW2, static_cast(scale3_ptr)); + quant_params = constructor(mExpertFP4ActGlobalScale1 + fc1_sf_offset, mFP4ScalingFactorsW1, + static_cast(scale1_ptr), static_cast(scale2_ptr), mFP4ScalingFactorsW2, + static_cast(scale3_ptr), mUsePerExpertActScale && NVFP4, mUsePerExpertActScale); } if constexpr (WEIGHT_FP4) @@ -1497,6 +1531,19 @@ TYPED_TEST(MixtureOfExpertsTest, PermuteNoBias) this->BasicPermuteTest(3); } +TYPED_TEST(MixtureOfExpertsTest, PermuteSingletonScale) +{ + if (!this->ANY_FPX) + { + GTEST_SKIP() << "Only FPX cares about per-expert act scale"; + return; + } + this->mUsePerExpertActScale = false; + this->BasicPermuteTest(1); + this->BasicPermuteTest(2); + this->BasicPermuteTest(3); +} + TYPED_TEST(MixtureOfExpertsTest, PermuteGelu) { this->mActType = ActivationType::Gelu;