@@ -581,7 +581,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture

auto func = NVFP4 ? QuantParams::FP4 : QuantParams::FP8MXFP4;
mQuantParams = func(mExpertFP4ActScale1, mExpertFP4WeightSf1, mExpertFP4GlobalScale1, mExpertFP4ActScale2,
mExpertFP4WeightSf2, mExpertFP4GlobalScale2);
mExpertFP4WeightSf2, mExpertFP4GlobalScale2, false, false);
}

mSelectedExperts = allocBuffer<int>(mTotalTokens * mK);
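For context, the two new trailing false arguments in the call above map to the per-expert activation-scale flags introduced in the QuantParams factories below. Restated with the parameters named (a readability sketch of the same call, not a further change):

// Same construction as above; both per-expert activation scales stay disabled in the benchmark.
mQuantParams = func(mExpertFP4ActScale1, mExpertFP4WeightSf1, mExpertFP4GlobalScale1,
    mExpertFP4ActScale2, mExpertFP4WeightSf2, mExpertFP4GlobalScale2,
    /*fc1_use_per_expert_act_scale=*/false, /*fc2_use_per_expert_act_scale=*/false);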
43 changes: 26 additions & 17 deletions cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h
@@ -210,8 +210,9 @@ struct QuantParams
// FP8 quantization params
struct
{
bool fc2_use_per_expert_act_scale = false;
float const* dequant_fc1 = nullptr; // (num_experts_per_node, )
float const* quant_fc2 = nullptr; // (1, )
float const* quant_fc2 = nullptr; // (1, ) or (num_experts_per_node, ) based on fc2_use_per_expert_act_scale
float const* dequant_fc2 = nullptr; // (num_experts_per_node, )
float const* quant_final = nullptr; // (1, )
float const* dequant_input = nullptr; // (1, )
@@ -223,10 +224,12 @@
{
struct GemmInputs
{
float const* act_global_scale = nullptr; // (1, )
bool use_per_expert_act_scale = false;
float const* act_global_scale
= nullptr; // (1, ) or (num_experts_per_node, ) based on use_per_expert_act_scale
TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* weight_block_scale
= nullptr; // (experts, n, k / 32)
float const* global_scale = nullptr; // (num_experts_per_node, )
= nullptr; // (experts, n, k / 32)
float const* global_scale = nullptr; // (num_experts_per_node, )
};

GemmInputs fc1;
@@ -238,10 +241,13 @@
{
struct GemmInputs
{
float const* act_global_scale = nullptr; // (1, )
bool use_per_expert_act_scale = false;

float const* act_global_scale
= nullptr; // (1, ) or (num_experts_per_node, ) based on use_per_expert_act_scale
TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* weight_block_scale
= nullptr; // (experts, n, k / 16)
float const* global_scale = nullptr; // (num_experts_per_node, )
= nullptr; // (experts, n, k / 16)
float const* global_scale = nullptr; // (num_experts_per_node, )
};

GemmInputs fc1;
@@ -287,10 +293,11 @@ struct QuantParams
}

static QuantParams FP8(float const* dequant_fc1, float const* quant_fc2, float const* dequant_fc2,
float const* quant_final = nullptr, float const* dequant_input = nullptr)
float const* quant_final = nullptr, float const* dequant_input = nullptr,
bool fc2_use_per_expert_act_scale = false)
{
QuantParams qp;
qp.fp8 = {dequant_fc1, quant_fc2, dequant_fc2, quant_final, dequant_input};
qp.fp8 = {fc2_use_per_expert_act_scale, dequant_fc1, quant_fc2, dequant_fc2, quant_final, dequant_input};
return qp;
}
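A usage sketch of the extended FP8 factory, assuming hypothetical host-side buffers (real callers pass device pointers); with the new flag set, quant_fc2 holds one value per expert rather than a single broadcast value:

#include <vector>

// Illustration only: per-expert FC2 activation scale enabled.
int const num_experts_per_node = 8;                          // hypothetical size
std::vector<float> dequant_fc1(num_experts_per_node, 1.0f);  // (num_experts_per_node, )
std::vector<float> quant_fc2(num_experts_per_node, 1.0f);    // per-expert because the flag is set
std::vector<float> dequant_fc2(num_experts_per_node, 1.0f);  // (num_experts_per_node, )
QuantParams qp = QuantParams::FP8(dequant_fc1.data(), quant_fc2.data(), dequant_fc2.data(),
    /*quant_final=*/nullptr, /*dequant_input=*/nullptr,
    /*fc2_use_per_expert_act_scale=*/true);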

@@ -299,12 +306,14 @@ struct QuantParams
float const* fc1_global_scale, //
float const* fc2_act_global_scale,
TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* fc2_weight_block_scale,
float const* fc2_global_scale //
)
float const* fc2_global_scale, //
bool fc1_use_per_expert_act_scale = false, bool fc2_use_per_expert_act_scale = false)
{
QuantParams qp;
qp.fp8_mxfp4.fc1 = {fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
qp.fp8_mxfp4.fc2 = {fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale};
qp.fp8_mxfp4.fc1
= {fc1_use_per_expert_act_scale, fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
qp.fp8_mxfp4.fc2
= {fc2_use_per_expert_act_scale, fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale};
return qp;
}

@@ -313,12 +322,12 @@ struct QuantParams
float const* fc1_global_scale, //
float const* fc2_act_global_scale,
TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* fc2_weight_block_scale,
float const* fc2_global_scale //
)
float const* fc2_global_scale, //
bool fc1_use_per_expert_act_scale = false, bool fc2_use_per_expert_act_scale = false)
{
QuantParams qp;
qp.fp4.fc1 = {fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
qp.fp4.fc2 = {fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale};
qp.fp4.fc1 = {fc1_use_per_expert_act_scale, fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
qp.fp4.fc2 = {fc2_use_per_expert_act_scale, fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale};
return qp;
}

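The shape comments above ("(1, ) or (num_experts_per_node, )") reduce to a single rule; a minimal sketch, with a hypothetical helper name:

#include <cstddef>

// Expected element count of an activation-scale array under the new per-expert flag.
inline std::size_t actScaleCount(bool use_per_expert_act_scale, std::size_t num_experts_per_node)
{
    return use_per_expert_act_scale ? num_experts_per_node : 1;
}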
61 changes: 44 additions & 17 deletions cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
@@ -1237,7 +1237,8 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales,
int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row,
int64_t const num_rows, int64_t const cols, int64_t const k, float const* fc1_act_global_scale,
int64_t const* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
bool use_per_expert_act_scale, int64_t const* expert_first_token_offset,
TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, int64_t const num_experts_per_node)
{
#ifdef ENABLE_FP4
@@ -1300,7 +1301,8 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
int64_t expert = findTotalEltsLessThanTarget(
expert_first_token_offset, num_experts_per_node, (int64_t) expanded_dest_row + 1)
- 1;
float global_scale_val = fc1_act_global_scale ? *fc1_act_global_scale : 1.0f;
size_t act_scale_idx = use_per_expert_act_scale ? expert : 0;
float global_scale_val = fc1_act_global_scale ? fc1_act_global_scale[act_scale_idx] : 1.0f;
int64_t num_tokens_before_expert = expert_first_token_offset[expert];

for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride)
Expand All @@ -1315,6 +1317,7 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
}
else
{
assert(act_scale_idx == 0 && "Cannot use per-expert act scale for pre-quantized activations");
writeSF(num_tokens_before_expert, expert, source_row, expanded_dest_row, elem_index, cols, num_rows,
fc1_act_sf_flat, input_sf);
dest_row_ptr[elem_index] = in_vec;
@@ -1345,7 +1348,7 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales,
int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row,
int64_t const num_rows, int64_t const cols, int const k, int const num_experts_per_node,
float const* fc1_act_global_scale, int64_t* expert_first_token_offset,
float const* fc1_act_global_scale, bool use_per_expert_act_scale, int64_t* expert_first_token_offset,
TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream)
{
@@ -1360,6 +1363,11 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
check_cuda_error(cudaMemsetAsync(
fc1_act_sf_flat, 0x0, num_elems * sizeof(TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF), stream));
}
else
{
TLLM_CHECK_WITH_INFO(
!use_per_expert_act_scale, "Per-expert act scale for FC1 is only supported for FP4 activations");
}
#endif

static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
@@ -1380,7 +1388,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
config.attrs = attrs;
cudaLaunchKernelEx(&config, func, unpermuted_input, permuted_output, unpermuted_scales, permuted_scales,
expanded_dest_row_to_expanded_source_row, expanded_source_row_to_expanded_dest_row, num_rows, cols, k,
fc1_act_global_scale, expert_first_token_offset, fc1_act_sf_flat, input_sf, num_experts_per_node);
fc1_act_global_scale, use_per_expert_act_scale, expert_first_token_offset, fc1_act_sf_flat, input_sf,
num_experts_per_node);
}

enum class ScaleMode : int
@@ -1681,7 +1690,8 @@ template <class T, class GemmOutputType, class ScaleBiasType, template <class> c
__global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, float const* fp8_quant,
ScaleBiasType const* bias_ptr, bool bias_is_broadcast, int64_t const* expert_first_token_offset,
int num_experts_per_node, int64_t inter_size, int64_t max_tokens_per_expert, bool gated,
float const* fc2_act_global_scale, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat)
float const* fc2_act_global_scale, bool use_per_expert_act_scale,
TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat)
{
#ifdef ENABLE_FP4
constexpr bool IsFP4 = std::is_same_v<T, __nv_fp4_e2m1>;
@@ -1705,16 +1715,17 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
size_t output_offset = token * inter_size;

int64_t expert = 0;
if (bias_ptr || IsFP4)
if (bias_ptr || IsFP4 || use_per_expert_act_scale)
{
// TODO this is almost certainly faster as a linear scan
expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - 1;
}

float const quant_scale = fp8_quant ? *fp8_quant : 1.f;
size_t act_scale_idx = use_per_expert_act_scale ? expert : 0;
float const quant_scale = fp8_quant ? fp8_quant[act_scale_idx] : 1.f;

// Some globals for FP4
float global_scale_val = fc2_act_global_scale ? *fc2_act_global_scale : 1.0f;
float global_scale_val = fc2_act_global_scale ? fc2_act_global_scale[act_scale_idx] : 1.0f;
int64_t num_tokens_before_expert = IsFP4 ? expert_first_token_offset[expert] : 0;

size_t bias_offset = 0;
@@ -1790,7 +1801,7 @@ template <class T, class GemmOutputType, class ScaleBiasType>
void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8_quant, ScaleBiasType const* bias,
bool bias_is_broadcast, int64_t const* expert_first_token_offset, int num_experts_per_node, int64_t inter_size,
int64_t num_tokens, int64_t expanded_num_tokens, ActivationType activation_type, float const* fc2_act_global_scale,
TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, cudaStream_t stream)
bool use_per_expert_act_scale, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, cudaStream_t stream)
{
static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
// Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
@@ -1819,7 +1830,7 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
config.attrs = attrs;
cudaLaunchKernelEx(&config, fn, output, gemm_result, fp8_quant, bias, bias_is_broadcast, expert_first_token_offset,
num_experts_per_node, inter_size, num_tokens, isGatedActivation(activation_type), fc2_act_global_scale,
fc2_act_sf_flat);
use_per_expert_act_scale, fc2_act_sf_flat);
}

// ============================== Lora Add Bias =================================
@@ -2346,9 +2357,11 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Ena

sync_check_cuda_error(stream);
constexpr bool bias_is_broadcast = true;
constexpr bool use_per_expert_act_scale = false;
doActivation<T, UnfusedGemmOutputType>(output, static_cast<UnfusedGemmOutputType const*>(gemm_output),
fc2_fp8_quant, fc1_expert_biases, bias_is_broadcast, expert_first_token_offset, num_experts_per_node,
inter_size, num_rows, expanded_num_rows, fc1_activation_type, nullptr, nullptr, stream);
inter_size, num_rows, expanded_num_rows, fc1_activation_type, nullptr, use_per_expert_act_scale, nullptr,
stream);

sync_check_cuda_error(stream);
}
@@ -2498,10 +2511,16 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab

// TODO: when bias_is_broadcast is false, fuse bias to gemm
using GatedActOutputType = std::conditional_t<use_w4afp8, BackBoneType, T>;
bool use_per_expert_act_scale = use_fp4 ? quant_params.fp4.fc2.use_per_expert_act_scale
: use_wfp4afp8 ? quant_params.fp8_mxfp4.fc2.use_per_expert_act_scale
: use_fp8 ? quant_params.fp8.fc2_use_per_expert_act_scale
: false;

doActivation<GatedActOutputType, UnfusedGemmOutputType>(reinterpret_cast<GatedActOutputType*>(output),
static_cast<UnfusedGemmOutputType const*>(gemm_output), fc2_fp8_quant, fc1_expert_biases, bias_is_broadcast,
expert_first_token_offset, num_experts_per_node, inter_size, num_rows, expanded_num_rows,
fc1_activation_type, quant_params.fp4.fc2.act_global_scale, fc2_fp4_act_flat, stream);
fc1_activation_type, quant_params.fp4.fc2.act_global_scale, use_per_expert_act_scale, fc2_fp4_act_flat,
stream);

sync_check_cuda_error(stream);
}
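The ternary chain above selects the flag for the active quantization mode; an equivalent reading as a helper (a sketch, function name hypothetical and not part of this change):

bool fc2UsePerExpertActScale(QuantParams const& quant_params, bool use_fp4, bool use_wfp4afp8, bool use_fp8)
{
    if (use_fp4)
        return quant_params.fp4.fc2.use_per_expert_act_scale;
    if (use_wfp4afp8)
        return quant_params.fp8_mxfp4.fc2.use_per_expert_act_scale;
    if (use_fp8)
        return quant_params.fp8.fc2_use_per_expert_act_scale;
    return false;
}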
@@ -2522,9 +2541,11 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
/*use_fused_moe*/ false, stream, config};
gemm_runner.moeGemm(universal_input, TmaWarpSpecializedGroupedGemmInput{});

bool use_per_expert_act_scale = use_fp8 ? quant_params.fp8.fc2_use_per_expert_act_scale : false;
doActivation<T, UnfusedGemmOutputType>(output, static_cast<UnfusedGemmOutputType const*>(intermediate_result),
fc2_fp8_quant, fc1_expert_biases, bias_is_broadcast, expert_first_token_offset, num_experts_per_node,
inter_size, num_rows, expanded_num_rows, fc1_activation_type, nullptr, nullptr, stream);
inter_size, num_rows, expanded_num_rows, fc1_activation_type, nullptr, use_per_expert_act_scale, nullptr,
stream);

sync_check_cuda_error(stream);
}
@@ -2687,7 +2708,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
loraBiasApplyFunc(static_cast<UnfusedGemmOutputType*>(gemm_output),
static_cast<UnfusedGemmOutputType const*>(gemm_output), nullptr,
static_cast<ScaleBiasType const*>(fc2_lora), false, expert_first_token_offset, num_experts_per_node,
hidden_size, num_rows, expanded_num_rows, ActivationType::Identity, nullptr, nullptr, stream);
hidden_size, num_rows, expanded_num_rows, ActivationType::Identity, nullptr, false, nullptr, stream);
sync_check_cuda_error(stream);
}

@@ -3129,10 +3150,13 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
}

using ExpandedActivationsType = std::conditional_t<use_w4afp8, BackBoneType, T>;
// Only NVFP4xNVFP4 supports FC1 per-expert act scale
bool use_per_expert_act_scale = use_fp4 ? quant_params.fp4.fc1.use_per_expert_act_scale : false;
expandInputRowsKernelLauncher(input_activations, reinterpret_cast<ExpandedActivationsType*>(permuted_data_),
token_topk_unpermuted_scales, permuted_token_final_scales_, permuted_source_token_ids_,
expanded_source_row_to_expanded_dest_row, num_rows, hidden_size, experts_per_token, num_experts_per_node,
quant_params.fp4.fc1.act_global_scale, expert_first_token_offset_, fc1_fp4_act_scale_, input_sf, stream);
quant_params.fp4.fc1.act_global_scale, use_per_expert_act_scale, expert_first_token_offset_,
fc1_fp4_act_scale_, input_sf, stream);

sync_check_cuda_error(stream);

@@ -3211,10 +3235,12 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::

auto alpha_scale_flat1 = use_fp4 ? quant_params.fp4.fc1.global_scale
: use_wfp4afp8 ? quant_params.fp8_mxfp4.fc1.global_scale
: fp8_dequant1;
: use_fp8 ? fp8_dequant1
: nullptr;
auto alpha_scale_flat2 = use_fp4 ? quant_params.fp4.fc2.global_scale
: use_wfp4afp8 ? quant_params.fp8_mxfp4.fc2.global_scale
: fp8_dequant2;
: use_fp8 ? fp8_dequant2
: nullptr;
if (!alpha_scale_flat1 && !alpha_scale_flat2)
{
layout_info1.alpha_scale_ptr_array = nullptr;
@@ -3380,6 +3406,7 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
// fp8_mxfp4 memsets the scaling factors to 1.0f
if (quant_params.fp8_mxfp4.fc1.weight_block_scale)
{
// We are in FP8 x MXFP4 mode
TLLM_CHECK(quant_params.fp8_mxfp4.fc2.weight_block_scale);
TLLM_CHECK(fc1_fp4_act_scale_ != nullptr);
TLLM_CHECK_WITH_INFO(fc1_fp4_act_scale_ == fc2_fp4_act_scale_,
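Taken together, the kernel changes above follow one convention: when the per-expert flag is set, the scale array is indexed by the token's expert; otherwise the single value at index 0 is broadcast. A standalone sketch of that rule (helper name hypothetical):

#include <cstdint>

// Mirrors act_scale_idx = use_per_expert_act_scale ? expert : 0 in
// expandInputRowsKernel and doActivationKernel; a null pointer means "no scaling" (1.0f).
__device__ inline float loadActScale(float const* scales, bool use_per_expert_act_scale, int64_t expert)
{
    size_t const idx = use_per_expert_act_scale ? static_cast<size_t>(expert) : 0;
    return scales ? scales[idx] : 1.0f;
}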
Git LFS file not shown
@@ -1,2 +1,2 @@
a1180829a0d8fe772ff37934b72573bb41671e7ed76dfa3bd5cd449348b9683a libtensorrt_llm_internal_cutlass_kernels_static.a
commit c767347ff934578193ee4bad58ba3b9398046245
ad34c0f31247c880d60e2c8198093e8373cf0e1d3e8badee0424bfa607d6cd8e libtensorrt_llm_internal_cutlass_kernels_static.a
commit bac309ac608d35d7d0144e594bf3e5fa8cfca796