diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index abb4911a8e8..6adf5cbf348 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -87,32 +87,6 @@ struct LoraParams namespace cutlass_kernels { -static inline size_t pad_to_multiple_of_16(size_t const& input) -{ - static constexpr int ALIGNMENT = 16; - return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); -} - -class CubKeyValueSorter -{ -public: - CubKeyValueSorter(); - - CubKeyValueSorter(int const num_experts_per_node); - - void updateNumExperts(int const num_experts_per_node); - - static size_t getWorkspaceSize(size_t const num_key_value_pairs, int const num_experts_per_node); - - void run(void* workspace, size_t const workspace_size, int const* keys_in, int* keys_out, int const* values_in, - int* values_out, size_t const num_key_value_pairs, cudaStream_t stream); - -private: - static int expertsToBits(int experts); - int num_experts_; - int num_bits_; -}; - /** * \brief Describes what parallelism mode the MoE is using * @@ -397,9 +371,9 @@ class CutlassMoeFCRunnerInterface ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output, - int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config, - bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, - bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) + int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall, + bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, + MoeMinLatencyParams& min_latency_params, cudaStream_t stream) = 0; // Aliases for profiling the gemms @@ -413,7 +387,7 @@ class CutlassMoeFCRunnerInterface int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per, - int* active_expert_global_ids, int start_expert) + int* active_expert_global_ids) = 0; virtual void gemm2(void const* const input, void* const gemm_output, void* const final_output, @@ -421,14 +395,14 @@ class CutlassMoeFCRunnerInterface void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales, float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, float const* const token_topk_unpermuted_scales, - float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row, - int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row, + float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row, + int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, int const 
num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config, bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, - int* num_active_experts_per, int* active_expert_global_ids, int start_expert) + int* num_active_experts_per, int* active_expert_global_ids) = 0; virtual std::pair @@ -544,9 +518,9 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output, - int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config, - bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, - bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override; + int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall, + bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, + MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override; // We make these GEMM1 & GEMM2 static because they need to be stateless for the profiler to work static void gemm1(MoeGemmRunner& gemm_runner, @@ -565,7 +539,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, - bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids, int start_expert); + bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids); static void gemm2(MoeGemmRunner& gemm_runner, DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, void* const gemm_output, @@ -574,14 +548,14 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface ScaleBiasType const* const fc2_expert_biases, ScaleBiasType const* const fc2_int_scales, float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, float const* const token_topk_unpermuted_scales, - float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row, - int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row, + float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row, + int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, cudaStream_t stream, MOEParallelismConfig parallelism_config, bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per, - int* active_expert_global_ids, 
int start_expert); + int* active_expert_global_ids); // Overrides to allow us to forward on to the internal functions with the pointers using the correct type void gemm1(void const* const input, void* const output, void* const intermediate_result, @@ -594,7 +568,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per, - int* active_expert_global_ids, int start_expert) override + int* active_expert_global_ids) override { auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr; return Self::gemm1(moe_gemm_runner_, block_scale_gemm_runner, static_cast(input), @@ -603,7 +577,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface num_valid_tokens_ptr, static_cast(fc1_int_scales), fc1_fp8_dequant, fc2_fp8_quant, fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, num_rows, expanded_num_rows, hidden_size, inter_size, num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array, bias_is_broadcast, stream, config, - min_latency_mode, num_active_experts_per, active_expert_global_ids, start_expert); + min_latency_mode, num_active_experts_per, active_expert_global_ids); } void gemm2(void const* const input, void* const gemm_output, void* const final_output, @@ -611,25 +585,25 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales, float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, float const* const token_topk_unpermuted_scales, - float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row, - int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row, + float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row, + int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config, bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, - int* num_active_experts_per, int* active_expert_global_ids, int start_expert) override + int* num_active_experts_per, int* active_expert_global_ids) override { auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? 
getDeepSeekBlockScaleGemmRunner() : nullptr; return Self::gemm2(moe_gemm_runner_, block_scale_gemm_runner, static_cast(input), gemm_output, static_cast(final_output), expert_first_token_offset, tma_ws_input_template, static_cast(fc2_expert_weights), static_cast(fc2_expert_biases), static_cast(fc2_int_scales), fc2_fp8_dequant, fc2_fp4_act_flat, quant_params, - token_topk_unpermuted_scales, token_topk_permuted_scales, expanded_source_row_to_expanded_dest_row, - expanded_dest_row_to_expanded_source_row, expert_for_source_row, num_valid_tokens_ptr, num_rows, - expanded_num_rows, hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, - use_lora, fc2_lora, stream, parallelism_config, enable_alltoall, config, min_latency_mode, - num_active_experts_per, active_expert_global_ids, start_expert); + token_topk_unpermuted_scales, token_topk_permuted_scales, unpermuted_row_to_permuted_row, + permuted_row_to_unpermuted_row, token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows, + hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, use_lora, fc2_lora, + stream, parallelism_config, enable_alltoall, config, min_latency_mode, num_active_experts_per, + active_expert_global_ids); } virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override @@ -763,10 +737,10 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface static void BlockScaleFC2(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output, OutputType* const final_output, int64_t const* const expert_first_token_offset, WeightType const* const fc2_expert_weights, ScaleBiasType const* const fc2_expert_biases, - float const* const token_topk_unpermuted_scales, int const* const expanded_source_row_to_expanded_dest_row, - int const* const expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row, + float const* const token_topk_unpermuted_scales, int const* const unpermuted_row_to_permuted_row, + int const* const permuted_row_to_unpermuted_row, int const* const token_selected_experts, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows, - int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const k, + int64_t const hidden_size, int64_t const inter_size, int64_t const num_experts_per_node, int64_t const k, MOEParallelismConfig parallelism_config, bool const enable_alltoall, QuantParams& quant_params, cudaStream_t stream); @@ -774,7 +748,6 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, cudaStream_t stream); - CubKeyValueSorter sorter_; MoeGemmRunner moe_gemm_runner_; std::unique_ptr blockscale_gemm_runner_; @@ -782,11 +755,11 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface std::optional gemm2_config_; // Pointers - int* unpermuted_token_selected_experts_{}; - int* unpermuted_source_token_ids_{}; - int* permuted_source_token_ids_{}; + int* permuted_row_to_unpermuted_row_{}; int* permuted_token_selected_experts_{}; - char* sorter_ws_{}; + int* blocked_expert_counts_{}; + int* blocked_expert_counts_cumsum_{}; + int* blocked_row_to_unpermuted_row_{}; T* permuted_data_{}; float* permuted_token_final_scales_{}; @@ -859,7 +832,6 @@ struct GemmProfilerBackend mParallelismConfig = parallelism_config; mEnableAlltoall = enable_alltoall; mSM = common::getSMVersion(); - 
mSorter.updateNumExperts(mNumExpertsPerNode); mScalingType = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE; if (dtype == nvinfer1::DataType::kFP8 @@ -883,7 +855,6 @@ struct GemmProfilerBackend cudaStream_t const& stream); CutlassMoeFCRunnerInterface* mInterface; - CubKeyValueSorter mSorter; GemmToProfile mGemmToProfile = GemmToProfile::Undefined; std::vector mAllTacticsSaved; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu index 682b7ec2ec8..963ba2a2910 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu @@ -59,14 +59,10 @@ #error CUDART_VERSION Undefined! #elif (CUDART_VERSION >= 11050) #include -#include -#include #include #include #else #include "3rdparty/cub/cub.cuh" -#include "3rdparty/cub/device/device_radix_sort.cuh" -#include "3rdparty/cub/util_type.cuh" #endif using namespace tensorrt_llm::kernels; @@ -309,7 +305,7 @@ void buildMinLatencyActiveExpertMaps(int* num_active_experts_per_node, float* ex template __global__ void fusedBuildExpertMapsSortFirstTokenKernel(int const* const token_selected_experts, - int* const unpermuted_token_selected_experts, int* const permuted_source_token_ids, + int* const permuted_row_to_unpermuted_row, int* const unpermuted_row_to_permuted_row, int64_t* const expert_first_token_offset, int64_t const num_tokens, int const experts_per_token, int const start_expert, int const end_expert, int const num_experts_per_node) { @@ -381,8 +377,10 @@ __global__ void fusedBuildExpertMapsSortFirstTokenKernel(int const* const token_ #pragma unroll for (int i = 0; i < EXPERTS_PER_TOKEN; i++) { - unpermuted_token_selected_experts[token * EXPERTS_PER_TOKEN + i] = local_token_selected_experts[i]; - permuted_source_token_ids[local_token_permuted_indices[i]] = i * num_tokens + token; + int const unpermuted_row = i * num_tokens + token; + int const permuted_row = local_token_permuted_indices[i]; + permuted_row_to_unpermuted_row[permuted_row] = unpermuted_row; + unpermuted_row_to_permuted_row[unpermuted_row] = permuted_row; } } @@ -398,10 +396,10 @@ __global__ void fusedBuildExpertMapsSortFirstTokenKernel(int const* const token_ } template -bool fusedBuildExpertMapsSortFirstTokenDispatch(int const* token_selected_experts, - int* unpermuted_token_selected_experts, int* permuted_source_token_ids, int64_t* expert_first_token_offset, - int64_t const num_tokens, int const num_experts_per_node, int const experts_per_token, int const start_expert, - int const end_expert, cudaStream_t stream) +bool fusedBuildExpertMapsSortFirstTokenDispatch(int const* token_selected_experts, int* permuted_row_to_unpermuted_row, + int* unpermuted_row_to_permuted_row, int64_t* expert_first_token_offset, int64_t const num_tokens, + int const num_experts_per_node, int const experts_per_token, int const start_expert, int const end_expert, + cudaStream_t stream) { TLLM_CHECK_WITH_INFO(num_experts_per_node == (end_expert - start_expert), "num_experts_per_node must be equal to end_expert - start_expert"); @@ -438,18 +436,18 @@ bool fusedBuildExpertMapsSortFirstTokenDispatch(int const* token_selected_expert } check_cuda_error(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_size)); - check_cuda_error(cudaLaunchKernelEx(&config, kernel, token_selected_experts, unpermuted_token_selected_experts, - permuted_source_token_ids, expert_first_token_offset, num_tokens, 
experts_per_token, start_expert, end_expert, - num_experts_per_node)); + check_cuda_error(cudaLaunchKernelEx(&config, kernel, token_selected_experts, permuted_row_to_unpermuted_row, + unpermuted_row_to_permuted_row, expert_first_token_offset, num_tokens, experts_per_token, start_expert, + end_expert, num_experts_per_node)); return true; } template -bool fusedBuildExpertMapsSortFirstTokenBlockSize(int const* token_selected_experts, - int* unpermuted_token_selected_experts, int* permuted_source_token_ids, int64_t* expert_first_token_offset, - int64_t const num_tokens, int const num_experts_per_node, int const experts_per_token, int const start_expert, - int const end_expert, cudaStream_t stream) +bool fusedBuildExpertMapsSortFirstTokenBlockSize(int const* token_selected_experts, int* permuted_row_to_unpermuted_row, + int* unpermuted_row_to_permuted_row, int64_t* expert_first_token_offset, int64_t const num_tokens, + int const num_experts_per_node, int const experts_per_token, int const start_expert, int const end_expert, + cudaStream_t stream) { int const block_size = num_tokens; if (num_tokens > 256) @@ -473,16 +471,16 @@ bool fusedBuildExpertMapsSortFirstTokenBlockSize(int const* token_selected_exper func = &fusedBuildExpertMapsSortFirstTokenDispatch<256, EXPERTS_PER_TOKEN, LOG2_NUM_EXPERTS>; } - return func(token_selected_experts, unpermuted_token_selected_experts, permuted_source_token_ids, + return func(token_selected_experts, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_tokens, num_experts_per_node, experts_per_token, start_expert, end_expert, stream); } template -bool fusedBuildExpertMapsSortFirstTokenBlockSize(int const* token_selected_experts, - int* unpermuted_token_selected_experts, int* permuted_source_token_ids, int64_t* expert_first_token_offset, - int64_t const num_tokens, int const num_experts_per_node, int const experts_per_token, int const start_expert, - int const end_expert, cudaStream_t stream) +bool fusedBuildExpertMapsSortFirstTokenBlockSize(int const* token_selected_experts, int* permuted_row_to_unpermuted_row, + int* unpermuted_row_to_permuted_row, int64_t* expert_first_token_offset, int64_t const num_tokens, + int const num_experts_per_node, int const experts_per_token, int const start_expert, int const end_expert, + cudaStream_t stream) { auto func = &fusedBuildExpertMapsSortFirstTokenBlockSize<1, LOG2_NUM_EXPERTS>; switch (experts_per_token) @@ -518,13 +516,13 @@ bool fusedBuildExpertMapsSortFirstTokenBlockSize(int const* token_selected_exper return false; } } - return func(token_selected_experts, unpermuted_token_selected_experts, permuted_source_token_ids, + return func(token_selected_experts, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_tokens, num_experts_per_node, experts_per_token, start_expert, end_expert, stream); } -bool fusedBuildExpertMapsSortFirstToken(int const* token_selected_experts, int* unpermuted_token_selected_experts, - int* permuted_source_token_ids, int64_t* expert_first_token_offset, int64_t const num_tokens, +bool fusedBuildExpertMapsSortFirstToken(int const* token_selected_experts, int* permuted_row_to_unpermuted_row, + int* unpermuted_row_to_permuted_row, int64_t* expert_first_token_offset, int64_t const num_tokens, int const num_experts_per_node, int const experts_per_token, int const start_expert, int const end_expert, cudaStream_t stream) { @@ -539,57 +537,86 @@ bool fusedBuildExpertMapsSortFirstToken(int const* token_selected_experts, 
int* &fusedBuildExpertMapsSortFirstTokenBlockSize<6>, &fusedBuildExpertMapsSortFirstTokenBlockSize<7>, &fusedBuildExpertMapsSortFirstTokenBlockSize<8>, &fusedBuildExpertMapsSortFirstTokenBlockSize<9>}; - return funcs[expert_log - 1](token_selected_experts, unpermuted_token_selected_experts, - permuted_source_token_ids, expert_first_token_offset, num_tokens, num_experts_per_node, experts_per_token, - start_expert, end_expert, stream); + return funcs[expert_log - 1](token_selected_experts, permuted_row_to_unpermuted_row, + unpermuted_row_to_permuted_row, expert_first_token_offset, num_tokens, num_experts_per_node, + experts_per_token, start_expert, end_expert, stream); } TLLM_LOG_TRACE("Experts per node %d does not have supported fused moe prologues", num_experts_per_node); return false; } -/** - * Takes the input maps and prepares the expanded maps for the sort step - * @param unpermuted_token_selected_experts: Buffer of transformed expert ids masked for the current node, used as the - * keys for the sort - * @param unpermuted_source_token_ids: Buffer of unpermuted token ids that will be used to identify the source row for - * each expanded token, used as the values for the sort - */ -__global__ void buildExpertMapsKernel(int const* token_selected_experts, int* unpermuted_token_selected_experts, - int* unpermuted_source_token_ids, int64_t const num_tokens, int const experts_per_token, int const start_expert, - int const end_expert, int const num_experts_per_node) +int64_t computeNumTokensPerBlock(int64_t const num_tokens, int64_t const num_experts_per_node) { - int const token = blockIdx.x * blockDim.x + threadIdx.x; - if (token >= num_tokens) + for (int64_t num_tokens_per_block = 32; num_tokens_per_block <= 1024; num_tokens_per_block *= 2) { - return; + int64_t const num_blocks_per_seq = tensorrt_llm::common::ceilDiv(num_tokens, num_tokens_per_block); + if (num_blocks_per_seq * num_experts_per_node <= num_tokens_per_block) + { + return num_tokens_per_block; + } } + return 1024; +} + +template +__global__ void blockExpertPrefixSumKernel(int const* token_selected_experts, int* blocked_expert_counts, + int* blocked_row_to_unpermuted_row, int64_t const num_tokens, int64_t const num_experts_per_token, + int const start_expert_id) +{ + using BlockScan = cub::BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + // target_expert_id and expert_id are offset by start_expert_id + int const target_expert_id = blockIdx.x; + int const block_id = blockIdx.y; + int const num_blocks_per_seq = gridDim.y; + int const token_id = block_id * kNumTokensPerBlock + threadIdx.x; + #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif - for (int i = 0; i < experts_per_token; i++) + int expanded_token_id = -1; + if (token_id < num_tokens) { - int const expert = token_selected_experts[token * experts_per_token + i]; - // If expert is not in the current node, set it to num_experts_per_node - // If expert is in the current node, subtract start_expert to shift the range to [0, num_experts_per_node) - bool is_valid_expert = expert >= start_expert && expert < end_expert; - unpermuted_token_selected_experts[token * experts_per_token + i] - = is_valid_expert ? (expert - start_expert) : num_experts_per_node; - unpermuted_source_token_ids[token * experts_per_token + i] = i * num_tokens + token; + for (int i = 0; i < num_experts_per_token; i++) + { + // TODO(enweiz): Fix uncoalesced access with shared memory. 
+ int const expert_id = token_selected_experts[token_id * num_experts_per_token + i] - start_expert_id; + if (expert_id == target_expert_id) + { + expanded_token_id = i * num_tokens + token_id; + break; + } + } } + + int const has_matched = expanded_token_id >= 0 ? 1 : 0; + int index; + BlockScan(temp_storage).ExclusiveSum(has_matched, index); + + if (has_matched) + { + blocked_row_to_unpermuted_row[target_expert_id * num_tokens + block_id * kNumTokensPerBlock + index] + = expanded_token_id; + } + if (threadIdx.x == kNumTokensPerBlock - 1) + { + blocked_expert_counts[target_expert_id * num_blocks_per_seq + block_id] = index + has_matched; + } + #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif } -void buildExpertMaps(int const* token_selected_experts, int* unpermuted_token_selected_experts, - int* unpermuted_source_token_ids, int64_t const num_tokens, int const num_experts_per_node, - int const experts_per_token, int const start_expert, int const end_expert, cudaStream_t stream) +void blockExpertPrefixSum(int const* token_selected_experts, int* blocked_expert_counts, + int* blocked_row_to_unpermuted_row, int64_t const num_tokens, int64_t const num_experts_per_node, + int64_t const num_experts_per_token, int64_t const num_tokens_per_block, int64_t const num_blocks_per_seq, + int const start_expert_id, cudaStream_t stream) { - TLLM_CHECK_WITH_INFO(num_experts_per_node == (end_expert - start_expert), - "num_experts_per_node must be equal to end_expert - start_expert"); - int const threads = std::min(int64_t(1024), num_tokens); - int const blocks = (num_tokens + threads - 1) / threads; + dim3 const blocks(num_experts_per_node, num_blocks_per_seq); + dim3 const threads(num_tokens_per_block); cudaLaunchConfig_t config; config.gridDim = blocks; @@ -601,117 +628,208 @@ void buildExpertMaps(int const* token_selected_experts, int* unpermuted_token_se attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); config.numAttrs = 1; config.attrs = attrs; - cudaLaunchKernelEx(&config, buildExpertMapsKernel, token_selected_experts, unpermuted_token_selected_experts, - unpermuted_source_token_ids, num_tokens, experts_per_token, start_expert, end_expert, num_experts_per_node); -} -// ========================== CUB Sorting things ==================================== -CubKeyValueSorter::CubKeyValueSorter() - : num_experts_(0) - , num_bits_(sizeof(int) * 8) -{ + auto func = blockExpertPrefixSumKernel<1024>; + if (num_tokens_per_block <= 32) + { + func = blockExpertPrefixSumKernel<32>; + } + else if (num_tokens_per_block <= 64) + { + func = blockExpertPrefixSumKernel<64>; + } + else if (num_tokens_per_block <= 128) + { + func = blockExpertPrefixSumKernel<128>; + } + else if (num_tokens_per_block <= 256) + { + func = blockExpertPrefixSumKernel<256>; + } + else if (num_tokens_per_block <= 512) + { + func = blockExpertPrefixSumKernel<512>; + } + cudaLaunchKernelEx(&config, func, token_selected_experts, blocked_expert_counts, blocked_row_to_unpermuted_row, + num_tokens, num_experts_per_token, start_expert_id); } -int CubKeyValueSorter::expertsToBits(int num_experts) +template +__global__ void globalExpertPrefixSumLargeKernel(int const* blocked_expert_counts, int* blocked_expert_counts_cumsum, + int64_t* expert_first_token_offset, int64_t const num_experts_per_node, int64_t const num_blocks_per_seq, + int64_t const num_elem_per_thread) { - // Max value we represent is V = num_experts + (num_experts - 1) = 2 * 
num_experts - 1 - // The maximum number of bits is therefore floor(log2(V)) + 1 - return static_cast(log2(2 * num_experts - 1)) + 1; -} + using BlockScan = cub::BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; -CubKeyValueSorter::CubKeyValueSorter(int const num_experts) - : num_experts_(num_experts) - , num_bits_(expertsToBits(num_experts)) -{ -} + int offset = threadIdx.x * num_elem_per_thread; + int cnt = 0; -void CubKeyValueSorter::updateNumExperts(int const num_experts) -{ - num_experts_ = num_experts; - num_bits_ = expertsToBits(num_experts); -} +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif -size_t CubKeyValueSorter::getWorkspaceSize(size_t const num_key_value_pairs, int const num_experts) -{ - int num_bits = expertsToBits(num_experts); - size_t required_storage = 0; - int* null_int = nullptr; - cub::DeviceRadixSort::SortPairs( - nullptr, required_storage, null_int, null_int, null_int, null_int, num_key_value_pairs, 0, num_bits); + // Note: Because of limited registers, cannot store thread-level prefix sum or enable #pragma unroll + for (int i = 0; i < num_elem_per_thread; i++) + { + // TODO(enweiz): Fix uncoalesced access with shared memory. + if (offset + i < num_experts_per_node * num_blocks_per_seq) + { + cnt += blocked_expert_counts[offset + i]; + } + } + + int cumsum; + BlockScan(temp_storage).ExclusiveSum(cnt, cumsum); - // TODO: fix DeviceRadixSort - // when num_key_value_pairs, num_experts, num_bits, required_storage = 64, 4, 3, 0 - // The required_storage seems to vary between 0 and 1 for the same inputs - if (required_storage == 0) + for (int i = 0; i < num_elem_per_thread; i++) { - required_storage = 1; + if (offset + i < num_experts_per_node * num_blocks_per_seq) + { + blocked_expert_counts_cumsum[offset + i] = cumsum; + if ((offset + i) % num_blocks_per_seq == 0) + { + expert_first_token_offset[(offset + i) / num_blocks_per_seq] = cumsum; + } + cumsum += blocked_expert_counts[offset + i]; + if ((offset + i) == num_experts_per_node * num_blocks_per_seq - 1) + { + expert_first_token_offset[num_experts_per_node] = cumsum; + } + } } - return required_storage; + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif } -void CubKeyValueSorter::run(void* workspace, size_t const workspace_size, int const* keys_in, int* keys_out, - int const* values_in, int* values_out, size_t const num_key_value_pairs, cudaStream_t stream) +template +__global__ void globalExpertPrefixSumKernel(int const* blocked_expert_counts, int* blocked_expert_counts_cumsum, + int64_t* expert_first_token_offset, int64_t const num_experts_per_node, int64_t const num_blocks_per_seq) { - size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs, num_experts_); - size_t actual_ws_size = workspace_size; + using BlockScan = cub::BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; - TLLM_CHECK_WITH_INFO(expected_ws_size <= workspace_size, - "[CubKeyValueSorter::run] The allocated workspace is too small to run this problem."); - cub::DeviceRadixSort::SortPairs( - workspace, actual_ws_size, keys_in, keys_out, values_in, values_out, num_key_value_pairs, 0, num_bits_, stream); -} +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif -// ============================== Infer GEMM sizes ================================= -// TODO Could linear search be better for small # experts -template -__device__ inline 
int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) -{ - int64_t low = 0, high = arr_length - 1, target_location = -1; - while (low <= high) - { - int64_t mid = (low + high) / 2; + int const cnt = threadIdx.x < num_experts_per_node * num_blocks_per_seq ? blocked_expert_counts[threadIdx.x] : 0; + int cumsum; + BlockScan(temp_storage).ExclusiveSum(cnt, cumsum); - if (sorted_indices[mid] >= target) + if (threadIdx.x < num_experts_per_node * num_blocks_per_seq) + { + blocked_expert_counts_cumsum[threadIdx.x] = cumsum; + if (threadIdx.x % num_blocks_per_seq == 0) { - high = mid - 1; + expert_first_token_offset[threadIdx.x / num_blocks_per_seq] = cumsum; } - else + if (threadIdx.x == num_experts_per_node * num_blocks_per_seq - 1) { - low = mid + 1; - target_location = mid; + expert_first_token_offset[num_experts_per_node] = cumsum + cnt; } } - return target_location + 1; + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif } -// Calculates the start offset of the tokens for a given expert. The last element is the total number of valid tokens -__global__ void computeExpertFirstTokenOffsetKernel(int const* sorted_experts, int64_t const sorted_experts_len, - int64_t const num_experts_per_node, int64_t* expert_first_token_offset) +void globalExpertPrefixSum(int const* blocked_expert_counts, int* blocked_expert_counts_cumsum, + int64_t* expert_first_token_offset, int64_t const num_experts_per_node, int64_t const num_tokens_per_block, + int64_t const num_blocks_per_seq, cudaStream_t stream) { - // First, compute the global tid. We only need 1 thread per expert. - int const expert = blockIdx.x * blockDim.x + threadIdx.x; + int64_t const num_elements = num_experts_per_node * num_blocks_per_seq; - // Note that expert goes [0, num_experts] (inclusive) because we want a count for the total number of active tokens - // at the end of the scan. 
- if (expert >= num_experts_per_node + 1) + cudaLaunchConfig_t config; + config.gridDim = 1; + config.blockDim = 1024; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); + config.numAttrs = 1; + config.attrs = attrs; + + if (num_elements <= 1024) { - return; + auto func = globalExpertPrefixSumKernel<1024>; + if (num_elements <= 32) + { + func = globalExpertPrefixSumKernel<32>; + config.blockDim = 32; + } + else if (num_elements <= 64) + { + func = globalExpertPrefixSumKernel<64>; + config.blockDim = 64; + } + else if (num_elements <= 128) + { + func = globalExpertPrefixSumKernel<128>; + config.blockDim = 128; + } + else if (num_elements <= 256) + { + func = globalExpertPrefixSumKernel<256>; + config.blockDim = 256; + } + else if (num_elements <= 512) + { + func = globalExpertPrefixSumKernel<512>; + config.blockDim = 512; + } + cudaLaunchKernelEx(&config, func, blocked_expert_counts, blocked_expert_counts_cumsum, + expert_first_token_offset, num_experts_per_node, num_blocks_per_seq); + } + else + { + auto func = globalExpertPrefixSumLargeKernel<1024>; + int64_t const num_elem_per_thread = tensorrt_llm::common::ceilDiv(num_elements, 1024); + cudaLaunchKernelEx(&config, func, blocked_expert_counts, blocked_expert_counts_cumsum, + expert_first_token_offset, num_experts_per_node, num_blocks_per_seq, num_elem_per_thread); } +} + +__global__ void mergeExpertPrefixSumKernel(int const* blocked_expert_counts, int const* blocked_expert_counts_cumsum, + int const* blocked_row_to_unpermuted_row, int* permuted_token_selected_experts, int* permuted_row_to_unpermuted_row, + int* unpermuted_row_to_permuted_row, int const num_tokens) +{ + int const target_expert_id = blockIdx.x; + int const block_id = blockIdx.y; + int const num_blocks_per_seq = gridDim.y; + int const token_id = block_id * blockDim.x + threadIdx.x; + #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif - expert_first_token_offset[expert] = findTotalEltsLessThanTarget(sorted_experts, sorted_experts_len, expert); + + int const cnt = blocked_expert_counts[target_expert_id * num_blocks_per_seq + block_id]; + int const offset = blocked_expert_counts_cumsum[target_expert_id * num_blocks_per_seq + block_id]; + if (threadIdx.x < cnt) + { + int const unpermuted_row = blocked_row_to_unpermuted_row[target_expert_id * num_tokens + token_id]; + int const permuted_row = offset + threadIdx.x; + permuted_row_to_unpermuted_row[permuted_row] = unpermuted_row; + permuted_token_selected_experts[permuted_row] = target_expert_id; + unpermuted_row_to_permuted_row[unpermuted_row] = permuted_row; + } + #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif } -void computeExpertFirstTokenOffset(int const* sorted_indices, int const total_indices, int const num_experts_per_node, - int64_t* expert_first_token_offset, cudaStream_t stream) +void mergeExpertPrefixSum(int const* blocked_expert_counts, int const* blocked_expert_counts_cumsum, + int const* blocked_row_to_unpermuted_row, int* permuted_token_selected_experts, int* permuted_row_to_unpermuted_row, + int* unpermuted_row_to_permuted_row, int64_t const num_tokens, int64_t const num_experts_per_node, + int64_t const num_tokens_per_block, int64_t const num_blocks_per_seq, cudaStream_t stream) { - int const 
num_entries = num_experts_per_node + 1; - int const threads = std::min(1024, num_entries); - int const blocks = (num_entries + threads - 1) / threads; + dim3 const blocks(num_experts_per_node, num_blocks_per_seq); + dim3 const threads(num_tokens_per_block); cudaLaunchConfig_t config; config.gridDim = blocks; @@ -723,8 +841,80 @@ void computeExpertFirstTokenOffset(int const* sorted_indices, int const total_in attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); config.numAttrs = 1; config.attrs = attrs; - cudaLaunchKernelEx(&config, computeExpertFirstTokenOffsetKernel, sorted_indices, total_indices, - num_experts_per_node, expert_first_token_offset); + + cudaLaunchKernelEx(&config, mergeExpertPrefixSumKernel, blocked_expert_counts, blocked_expert_counts_cumsum, + blocked_row_to_unpermuted_row, permuted_token_selected_experts, permuted_row_to_unpermuted_row, + unpermuted_row_to_permuted_row, num_tokens); +} + +// threeStepBuildExpertMapsSortFirstToken uses three kernels to sort token_selected_experts + +// 1. blockExpertPrefixSumKernel launches [num_experts_per_node, num_blocks_per_seq] CTAs; each CTA has +// num_tokens_per_block threads. blocked_row_to_unpermuted_row points to a 2D buffer of size [num_experts_per_node, +// num_tokens], which can be viewed as [num_experts_per_node, num_blocks_per_seq] blocks, and each block has +// num_tokens_per_block tokens. Note that each CTA corresponds to a block in blocked_row_to_unpermuted_row. Within each +// CTA, the threads leverage cub::BlockScan to compute the offsets of tokens that activate the target expert. If a +// thread's token activates the target expert, the thread stores its unpermuted_row to the buffer block at that offset. +// In addition, the kernel also stores the expert counts for each block to another 2D buffer blocked_expert_counts of +// size [num_experts_per_node, num_blocks_per_seq]. + +// 2. globalExpertPrefixSumKernel launches 1 CTA; that CTA has num_experts_per_node * num_blocks_per_seq threads. +// The kernel views blocked_expert_counts as a 1D buffer, and leverages cub::BlockScan to compute the prefix sum of the +// expert counts for each block. The prefix sum is stored to blocked_expert_counts_cumsum. + +// 3. mergeExpertPrefixSumKernel launches [num_experts_per_node, num_blocks_per_seq] CTAs; each CTA has +// num_tokens_per_block threads. Each CTA obtains the block-level offset from blocked_expert_counts_cumsum, and thus +// compacts blocked_row_to_unpermuted_row to permuted_row_to_unpermuted_row. In addition, with the block-level offsets, +// the kernel fills permuted_token_selected_experts. + +// computeNumTokensPerBlock decides num_tokens_per_block. Note that both blockExpertPrefixSumKernel and +// globalExpertPrefixSumKernel leverage cub::BlockScan, and their CTA sizes are num_tokens_per_block and +// num_experts_per_node * num_blocks_per_seq, respectively. computeNumTokensPerBlock tries to find a minimum CTA size +// for both kernels, so that the block-level cub::BlockScan can be efficient.
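For illustration, a minimal host-side sketch of the CTA-size heuristic described above. The names ceil_div and pick_num_tokens_per_block are hypothetical stand-ins for tensorrt_llm::common::ceilDiv and computeNumTokensPerBlock; only the loop structure mirrors the code in this change.

#include <cstdint>

// Stand-in for tensorrt_llm::common::ceilDiv (assumed round-up integer division).
static int64_t ceil_div(int64_t a, int64_t b)
{
    return (a + b - 1) / b;
}

// Mirrors computeNumTokensPerBlock: pick the smallest power-of-two block size (starting at 32) such that the
// single-CTA scan over the per-block counters (num_blocks_per_seq * num_experts_per_node elements) also fits.
static int64_t pick_num_tokens_per_block(int64_t num_tokens, int64_t num_experts_per_node)
{
    for (int64_t tokens_per_block = 32; tokens_per_block <= 1024; tokens_per_block *= 2)
    {
        int64_t const blocks_per_seq = ceil_div(num_tokens, tokens_per_block);
        if (blocks_per_seq * num_experts_per_node <= tokens_per_block)
        {
            return tokens_per_block;
        }
    }
    return 1024;
}

// Worked example: num_tokens = 4096, num_experts_per_node = 8.
//    32 tokens/block -> 128 blocks/seq; 128 * 8 = 1024 >  32  -> reject
//    64 tokens/block ->  64 blocks/seq;  64 * 8 =  512 >  64  -> reject
//   128 tokens/block ->  32 blocks/seq;  32 * 8 =  256 > 128  -> reject
//   256 tokens/block ->  16 blocks/seq;  16 * 8 =  128 <= 256 -> pick 256

When no candidate satisfies the bound, the heuristic falls back to 1024 tokens per block, and globalExpertPrefixSum then dispatches globalExpertPrefixSumLargeKernel (multiple elements per thread) instead of the single-element kernel, as shown in the launcher above.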
+ +void threeStepBuildExpertMapsSortFirstToken(int const* token_selected_experts, int* permuted_token_selected_experts, + int* permuted_row_to_unpermuted_row, int* unpermuted_row_to_permuted_row, int64_t* expert_first_token_offset, + int* blocked_expert_counts, int* blocked_expert_counts_cumsum, int* blocked_row_to_unpermuted_row, + int64_t const num_tokens, int64_t const num_experts_per_node, int64_t const num_experts_per_token, + int const start_expert_id, cudaStream_t stream) +{ + int64_t const num_tokens_per_block = computeNumTokensPerBlock(num_tokens, num_experts_per_node); + int64_t const num_blocks_per_seq = tensorrt_llm::common::ceilDiv(num_tokens, num_tokens_per_block); + + blockExpertPrefixSum(token_selected_experts, blocked_expert_counts, blocked_row_to_unpermuted_row, num_tokens, + num_experts_per_node, num_experts_per_token, num_tokens_per_block, num_blocks_per_seq, start_expert_id, stream); + sync_check_cuda_error(stream); + + globalExpertPrefixSum(blocked_expert_counts, blocked_expert_counts_cumsum, expert_first_token_offset, + num_experts_per_node, num_tokens_per_block, num_blocks_per_seq, stream); + sync_check_cuda_error(stream); + + mergeExpertPrefixSum(blocked_expert_counts, blocked_expert_counts_cumsum, blocked_row_to_unpermuted_row, + permuted_token_selected_experts, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, num_tokens, + num_experts_per_node, num_tokens_per_block, num_blocks_per_seq, stream); +} + +// ============================== Infer GEMM sizes ================================= +// TODO Could linear search be better for small # experts +template +__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) +{ + int64_t low = 0, high = arr_length - 1, target_location = -1; + while (low <= high) + { + int64_t mid = (low + high) / 2; + + if (sorted_indices[mid] >= target) + { + high = mid - 1; + } + else + { + low = mid + 1; + target_location = mid; + } + } + return target_location + 1; } template @@ -1225,7 +1415,7 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) // duplicate some rows in the input matrix to match the dimensions. Duplicates will always get routed to separate // experts in the end. -// Note that the expanded_dest_row_to_expanded_source_row map referred to here has indices in the range (0, +// Note that the permuted_row_to_unpermuted_row map referred to here has indices in the range (0, // k*rows_in_input - 1). However, it is set up so that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input // all map to row 0 in the original matrix. Thus, to know where to read in the source matrix, we simply take the modulus // of the expanded index. 
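For illustration, a minimal host-side sketch of the maps that the three-step sort fills, making the index convention above concrete. SortResult and reference_sort_by_expert are hypothetical names; the sketch assumes the ordering the kernels in this change produce, i.e. rows grouped by local expert id, token order preserved within each expert, only the first matching k-slot of a token counted, and unpermuted_row = k_slot * num_tokens + token.

#include <cstdint>
#include <vector>

// Hypothetical container for the maps filled by the three-step sort.
struct SortResult
{
    std::vector<int64_t> expert_first_token_offset;  // size num_experts_per_node + 1
    std::vector<int> permuted_token_selected_experts;
    std::vector<int> permuted_row_to_unpermuted_row;
    std::vector<int> unpermuted_row_to_permuted_row; // size num_tokens * k, -1 where no permuted row exists
};

SortResult reference_sort_by_expert(std::vector<int> const& token_selected_experts, int64_t num_tokens, int64_t k,
    int num_experts_per_node, int start_expert_id)
{
    SortResult r;
    r.expert_first_token_offset.assign(num_experts_per_node + 1, 0);
    r.unpermuted_row_to_permuted_row.assign(num_tokens * k, -1);
    for (int expert = 0; expert < num_experts_per_node; ++expert)
    {
        r.expert_first_token_offset[expert] = static_cast<int64_t>(r.permuted_row_to_unpermuted_row.size());
        for (int64_t token = 0; token < num_tokens; ++token)
        {
            for (int64_t i = 0; i < k; ++i) // first matching k-slot wins, as in blockExpertPrefixSumKernel
            {
                if (token_selected_experts[token * k + i] - start_expert_id == expert)
                {
                    int const unpermuted_row = static_cast<int>(i * num_tokens + token);
                    int const permuted_row = static_cast<int>(r.permuted_row_to_unpermuted_row.size());
                    r.permuted_row_to_unpermuted_row.push_back(unpermuted_row);
                    r.permuted_token_selected_experts.push_back(expert);
                    r.unpermuted_row_to_permuted_row[unpermuted_row] = permuted_row;
                    break;
                }
            }
        }
    }
    r.expert_first_token_offset[num_experts_per_node] = static_cast<int64_t>(r.permuted_row_to_unpermuted_row.size());
    return r;
}

// Example: num_tokens = 3, k = 2, start_expert_id = 0, num_experts_per_node = 2,
// token_selected_experts = { 0, 1,   1, 0,   1, 0 }  (one k-tuple per token).
// Expert 0 gathers unpermuted rows {0, 4, 5} and expert 1 gathers {3, 1, 2}, giving
// expert_first_token_offset = {0, 3, 6} and permuted_row_to_unpermuted_row = {0, 4, 5, 3, 1, 2};
// the source token of a permuted row is unpermuted_row % num_tokens and its k-slot is unpermuted_row / num_tokens.

The modulus and division at the end are the same source_row / source_k_rank computation used by expandInputRowsKernel and the finalize kernels below.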
@@ -1235,9 +1425,8 @@ constexpr static int EXPAND_THREADS_PER_BLOCK = 256; template __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales, - int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, - int64_t const num_rows, int64_t const cols, int64_t const k, float const* fc1_act_global_scale, - bool use_per_expert_act_scale, int64_t const* expert_first_token_offset, + int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const cols, int64_t const k, + float const* fc1_act_global_scale, bool use_per_expert_act_scale, int64_t const* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, int64_t const num_experts_per_node) { @@ -1260,18 +1449,9 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp int64_t const num_valid_tokens = expert_first_token_offset[num_experts_per_node]; - for (int64_t expanded_dest_row = blockIdx.x; expanded_dest_row < num_valid_tokens; expanded_dest_row += gridDim.x) + for (int64_t permuted_row = blockIdx.x; permuted_row < num_valid_tokens; permuted_row += gridDim.x) { - // Reverse permutation map. - // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need - // the reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in - // MoE. 1 thread block will be responsible for all k summations. - int64_t const expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; - if (threadIdx.x == 0) - { - assert(expanded_dest_row <= INT32_MAX); - expanded_source_row_to_expanded_dest_row[expanded_source_row] = static_cast(expanded_dest_row); - } + int64_t const unpermuted_row = permuted_row_to_unpermuted_row[permuted_row]; // Load 128-bits per thread constexpr int64_t ELEM_PER_THREAD @@ -1282,14 +1462,13 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp using OutputElem = std::conditional_t; // Duplicate and permute rows - int64_t const source_k_rank = expanded_source_row / num_rows; - int64_t const source_row = expanded_source_row % num_rows; + int64_t const source_k_rank = unpermuted_row / num_rows; + int64_t const source_row = unpermuted_row % num_rows; auto const* source_row_ptr = reinterpret_cast(unpermuted_input + source_row * cols / ELEM_PER_BYTE); // Cast first to handle when this is FP4 - auto* dest_row_ptr - = reinterpret_cast(permuted_output) + expanded_dest_row * cols / ELEM_PER_THREAD; + auto* dest_row_ptr = reinterpret_cast(permuted_output) + permuted_row * cols / ELEM_PER_THREAD; int64_t const start_offset = threadIdx.x; int64_t const stride = EXPAND_THREADS_PER_BLOCK; @@ -1299,7 +1478,7 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp if constexpr (is_fp4) { int64_t expert = findTotalEltsLessThanTarget( - expert_first_token_offset, num_experts_per_node, (int64_t) expanded_dest_row + 1) + expert_first_token_offset, num_experts_per_node, (int64_t) permuted_row + 1) - 1; size_t act_scale_idx = use_per_expert_act_scale ? expert : 0; float global_scale_val = fc1_act_global_scale ? 
fc1_act_global_scale[act_scale_idx] : 1.0f; @@ -1311,14 +1490,14 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp if constexpr (need_fp4_quant) { auto res = quantizePackedFP4Value(in_vec, global_scale_val, - num_tokens_before_expert, expert, expanded_dest_row, elem_index, cols, num_rows, - fc1_act_sf_flat, TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4); + num_tokens_before_expert, expert, permuted_row, elem_index, cols, num_rows, fc1_act_sf_flat, + TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4); dest_row_ptr[elem_index] = res; } else { assert(act_scale_idx == 0 && "Cannot use per-expert act scale for pre-quantized activations"); - writeSF(num_tokens_before_expert, expert, source_row, expanded_dest_row, elem_index, cols, num_rows, + writeSF(num_tokens_before_expert, expert, source_row, permuted_row, elem_index, cols, num_rows, fc1_act_sf_flat, input_sf); dest_row_ptr[elem_index] = in_vec; } @@ -1335,7 +1514,7 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp if (permuted_scales && threadIdx.x == 0) { int64_t const source_k_idx = source_row * k + source_k_rank; - permuted_scales[expanded_dest_row] = unpermuted_scales ? unpermuted_scales[source_k_idx] : 1.0f; + permuted_scales[permuted_row] = unpermuted_scales ? unpermuted_scales[source_k_idx] : 1.0f; } } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) @@ -1346,10 +1525,9 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp template void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales, - int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, - int64_t const num_rows, int64_t const cols, int const k, int const num_experts_per_node, - float const* fc1_act_global_scale, bool use_per_expert_act_scale, int64_t* expert_first_token_offset, - TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, + int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const cols, int const k, + int const num_experts_per_node, float const* fc1_act_global_scale, bool use_per_expert_act_scale, + int64_t* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream) { #ifdef ENABLE_FP4 @@ -1387,9 +1565,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, config.numAttrs = 1; config.attrs = attrs; cudaLaunchKernelEx(&config, func, unpermuted_input, permuted_output, unpermuted_scales, permuted_scales, - expanded_dest_row_to_expanded_source_row, expanded_source_row_to_expanded_dest_row, num_rows, cols, k, - fc1_act_global_scale, use_per_expert_act_scale, expert_first_token_offset, fc1_act_sf_flat, input_sf, - num_experts_per_node); + permuted_row_to_unpermuted_row, num_rows, cols, k, fc1_act_global_scale, use_per_expert_act_scale, + expert_first_token_offset, fc1_act_sf_flat, input_sf, num_experts_per_node); } enum class ScaleMode : int @@ -1405,8 +1582,8 @@ constexpr static int FINALIZE_THREADS_PER_BLOCK = 256; template __global__ void finalizeMoeRoutingKernel(GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* scales, - int const* expanded_source_row_to_expanded_dest_row, int const* expert_for_source_row, int64_t const 
orig_cols, - int64_t const experts_per_token, int const num_experts_per_node) + int const* unpermuted_row_to_permuted_row, int const* token_selected_experts, int64_t const orig_cols, + int64_t const experts_per_token, int const num_experts_per_node, int const start_expert_id) { assert(orig_cols % 4 == 0); int64_t const original_row = blockIdx.x; @@ -1442,14 +1619,14 @@ __global__ void finalizeMoeRoutingKernel(GemmOutputType const* expanded_permuted for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { int64_t const k_offset = original_row * experts_per_token + k_idx; - int64_t const expert_idx = expert_for_source_row[k_offset]; - if (expert_idx >= num_experts_per_node) + int64_t const expert_id = token_selected_experts[k_offset] - start_expert_id; + if (expert_id < 0 || expert_id >= num_experts_per_node) { continue; } int64_t const expanded_original_row = original_row + k_idx * num_rows; - int64_t const expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row]; + int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row]; float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset]; @@ -1461,7 +1638,7 @@ __global__ void finalizeMoeRoutingKernel(GemmOutputType const* expanded_permuted if (bias) { - auto const* bias_ptr = bias_v + expert_idx * num_elems_in_col; + auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); } @@ -1481,9 +1658,9 @@ __global__ void finalizeMoeRoutingKernel(GemmOutputType const* expanded_permuted template __global__ void finalizeMoeRoutingNoFillingKernel(GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* scales, - int const* const expanded_source_row_to_expanded_dest_row, int const* expanded_dest_row_to_expanded_source_row, - int const* expert_for_source_row, int64_t const* expert_first_token_offset, int64_t const num_rows, - int64_t const orig_cols, int64_t const experts_per_token, int const num_experts_per_node) + int const* const unpermuted_row_to_permuted_row, int const* permuted_row_to_unpermuted_row, + int const* token_selected_experts, int64_t const* expert_first_token_offset, int64_t const num_rows, + int64_t const orig_cols, int64_t const experts_per_token, int const num_experts_per_node, int const start_expert_id) { assert(orig_cols % 4 == 0); @@ -1495,18 +1672,19 @@ __global__ void finalizeMoeRoutingNoFillingKernel(GemmOutputType const* expanded for (int64_t expanded_permuted_row = blockIdx.x; expanded_permuted_row < num_valid_tokens; expanded_permuted_row += gridDim.x) { - int64_t expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_permuted_row]; + int64_t unpermuted_row = permuted_row_to_unpermuted_row[expanded_permuted_row]; // Duplicate and permute rows - int64_t const source_k_rank = expanded_source_row / num_rows; - int64_t const source_row = expanded_source_row % num_rows; + int64_t const source_k_rank = unpermuted_row / num_rows; + int64_t const source_row = unpermuted_row % num_rows; // If the expert is the first selected (valid) one of the corresponding token on the current EP rank, do // reduction; otherwise, skip. 
bool is_first_selected_expert = true; for (int k_idx = 0; k_idx < source_k_rank; ++k_idx) { - if (expert_for_source_row[source_row * experts_per_token + k_idx] < num_experts_per_node) + int const expert_id = token_selected_experts[source_row * experts_per_token + k_idx] - start_expert_id; + if (expert_id >= 0 && expert_id < num_experts_per_node) { is_first_selected_expert = false; break; @@ -1542,14 +1720,14 @@ __global__ void finalizeMoeRoutingNoFillingKernel(GemmOutputType const* expanded for (int k_idx = 0; k_idx < experts_per_token; ++k_idx) { int64_t const k_offset = source_row * experts_per_token + k_idx; - int64_t const expert_idx = expert_for_source_row[k_offset]; - if (expert_idx >= num_experts_per_node) + int64_t const expert_id = token_selected_experts[k_offset] - start_expert_id; + if (expert_id < 0 || expert_id >= num_experts_per_node) { continue; } int64_t const expanded_permuted_row_from_k_idx - = expanded_source_row_to_expanded_dest_row[source_row + k_idx * num_rows]; + = unpermuted_row_to_permuted_row[source_row + k_idx * num_rows]; float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset]; @@ -1561,7 +1739,7 @@ __global__ void finalizeMoeRoutingNoFillingKernel(GemmOutputType const* expanded if (bias) { - auto const* bias_ptr = bias_v + expert_idx * num_elems_in_col; + auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; expert_result = expert_result + arrayConvert(bias_ptr[elem_index]); } @@ -1579,14 +1757,15 @@ __global__ void finalizeMoeRoutingNoFillingKernel(GemmOutputType const* expanded template void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* final_scales, - int const* expanded_source_row_to_expanded_dest_row, int const* expanded_dest_row_to_expanded_source_row, - int const* expert_for_source_row, int64_t const* expert_first_token_offset, int64_t const num_rows, - int64_t const cols, int64_t const experts_per_token, int const num_experts_per_node, + int const* unpermuted_row_to_permuted_row, int const* permuted_row_to_unpermuted_row, + int const* token_selected_experts, int64_t const* expert_first_token_offset, int64_t const num_rows, + int64_t const cols, int64_t const experts_per_token, int64_t const num_experts_per_node, MOEParallelismConfig parallelism_config, bool const enable_alltoall, cudaStream_t stream) { // Only add bias on rank 0 for tensor parallelism bool const is_rank_0 = parallelism_config.tp_rank == 0; ScaleBiasType const* bias_ptr = is_rank_0 ? bias : nullptr; + int const start_expert_id = num_experts_per_node * parallelism_config.ep_rank; cudaLaunchConfig_t config; config.dynamicSmemBytes = 0; @@ -1610,8 +1789,8 @@ void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_ro ? &finalizeMoeRoutingNoFillingKernel : &finalizeMoeRoutingNoFillingKernel; cudaLaunchKernelEx(&config, func, expanded_permuted_rows, reduced_unpermuted_output, bias_ptr, final_scales, - expanded_source_row_to_expanded_dest_row, expanded_dest_row_to_expanded_source_row, expert_for_source_row, - expert_first_token_offset, num_rows, cols, experts_per_token, num_experts_per_node); + unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, token_selected_experts, + expert_first_token_offset, num_rows, cols, experts_per_token, num_experts_per_node, start_expert_id); } else { @@ -1624,8 +1803,8 @@ void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_ro ? 
&finalizeMoeRoutingKernel : &finalizeMoeRoutingKernel; cudaLaunchKernelEx(&config, func, expanded_permuted_rows, reduced_unpermuted_output, bias_ptr, final_scales, - expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, experts_per_token, - num_experts_per_node); + unpermuted_row_to_permuted_row, token_selected_experts, cols, experts_per_token, num_experts_per_node, + start_expert_id); } } @@ -2051,17 +2230,22 @@ CutlassMoeFCRunner:: constexpr float dtype_size = act_fp4 ? 0.5f : (use_w4afp8 ? 2.0f : sizeof(T)); - size_t const unpermuted_token_selected_experts_size = min_latency_mode ? 0 : num_moe_inputs * sizeof(int); - size_t const unpermuted_source_token_ids_size = min_latency_mode ? 0 : num_moe_inputs * sizeof(int); - size_t const permuted_source_token_ids_size = min_latency_mode ? 0 : num_moe_inputs * sizeof(int); + size_t const permuted_row_to_unpermuted_row_size = min_latency_mode ? 0 : num_moe_inputs * sizeof(int); size_t const permuted_token_selected_experts_size = min_latency_mode ? 0 : num_moe_inputs * sizeof(int); + + int64_t const num_tokens_per_block = computeNumTokensPerBlock(num_rows, num_experts_per_node); + int64_t const num_blocks_per_seq = tensorrt_llm::common::ceilDiv(num_rows, num_tokens_per_block); + size_t const blocked_expert_counts_size + = min_latency_mode ? 0 : num_experts_per_node * num_blocks_per_seq * sizeof(int); + size_t const blocked_expert_counts_cumsum_size = blocked_expert_counts_size; + size_t const blocked_row_to_unpermuted_row_size + = min_latency_mode ? 0 : num_experts_per_node * num_rows * sizeof(int); + size_t const permuted_data_size = permuted_elems * dtype_size; size_t const expert_first_token_offset_size = (num_experts_per_node + 1) * sizeof(int64_t); size_t const permuted_token_final_scales_size = mayHaveFinalizeFused() ? num_moe_inputs * sizeof(float) : 0; size_t const glu_inter_size = glu_inter_elems * gemm_output_dtype; // May be an intermediate type for quantization size_t const fc1_result_size = interbuf_elems * dtype_size; // Activation quantizes so back to dtype_size - size_t const sorter_ws_size - = min_latency_mode ? 0 : CubKeyValueSorter::getWorkspaceSize(num_rows, num_experts_per_node); size_t const fc2_result_size = min_latency_mode ? 
0 : num_moe_inputs * hidden_size * gemm_output_dtype; // May be an intermediate type for quantization @@ -2148,13 +2332,13 @@ CutlassMoeFCRunner:: } while (false) #define ADD(name) ADD_NAME(name, name##_size) - ADD(unpermuted_source_token_ids); - ADD(unpermuted_token_selected_experts); - ADD(permuted_source_token_ids); + ADD(permuted_row_to_unpermuted_row); ADD(permuted_token_selected_experts); + ADD(blocked_expert_counts); + ADD(blocked_expert_counts_cumsum); + ADD(blocked_row_to_unpermuted_row); ADD(expert_first_token_offset); ADD(permuted_token_final_scales); - ADD(sorter_ws); ADD(overlapped_gemm1_gemm2_inputs); ADD(overlapped_gemm1_gemm2_outputs); ADD_NAME(alpha_scale_ptr_array_fc1, alpha_scale_ptr_array_size); @@ -2208,10 +2392,11 @@ void CutlassMoeFCRunner kernels::fp8_blockscale_gemm::CutlassFp8BlockScaleGemmRunnerInterface* CutlassMoeFCRunner::getDeepSeekBlockScaleGemmRunner() const @@ -2371,11 +2534,11 @@ void CutlassMoeFCRunner( static_cast(gemm_output), final_output, fc2_expert_biases, - unpermuted_final_scales, expanded_source_row_to_expanded_dest_row, expanded_dest_row_to_expanded_source_row, - expert_for_source_row, expert_first_token_offset, num_rows, hidden_size, k, num_experts_per_node, - parallelism_config, enable_alltoall, stream); + unpermuted_final_scales, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, token_selected_experts, + expert_first_token_offset, num_rows, hidden_size, k, num_experts_per_node, parallelism_config, enable_alltoall, + stream); } template @@ -2435,7 +2598,7 @@ void CutlassMoeFCRunner( static_cast(gemm_output), final_output, fc2_expert_biases, - unpermuted_final_scales, expanded_source_row_to_expanded_dest_row, expanded_dest_row_to_expanded_source_row, - expert_for_source_row, expert_first_token_offset, num_rows, hidden_size, k, num_experts_per_node, + unpermuted_final_scales, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, + token_selected_experts, expert_first_token_offset, num_rows, hidden_size, k, num_experts_per_node, parallelism_config, enable_alltoall, stream); } else if (!using_tma_ws_gemm2) { finalizeMoeRoutingKernelLauncher(static_cast(gemm_output), final_output, - fc2_expert_biases, unpermuted_final_scales, expanded_source_row_to_expanded_dest_row, - expanded_dest_row_to_expanded_source_row, expert_for_source_row, expert_first_token_offset, num_rows, - hidden_size, k, num_experts_per_node, parallelism_config, enable_alltoall, stream); + fc2_expert_biases, unpermuted_final_scales, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, + token_selected_experts, expert_first_token_offset, num_rows, hidden_size, k, num_experts_per_node, + parallelism_config, enable_alltoall, stream); } sync_check_cuda_error(stream); } @@ -2945,7 +3108,7 @@ void CutlassMoeFCRunner::value) == 0, @@ -3092,8 +3255,7 @@ void CutlassMoeFCRunner(sorter_ws_), stream); + threeStepBuildExpertMapsSortFirstToken(token_selected_experts, permuted_token_selected_experts_, + permuted_row_to_unpermuted_row_, unpermuted_row_to_permuted_row, expert_first_token_offset_, + blocked_expert_counts_, blocked_expert_counts_cumsum_, blocked_row_to_unpermuted_row_, num_rows, + num_experts_per_node, experts_per_token, start_expert, stream); } sync_check_cuda_error(stream); @@ -3141,7 +3300,7 @@ void CutlassMoeFCRunner& host_expert_first_token_offset = host_lora_workspace_.host_expert_first_token_offset; host_permuted_rows.resize(expanded_num_rows); TLLM_CUDA_CHECK(tensorrt_llm::common::cudaMemcpyAsyncSanitized(host_permuted_rows.data(), - 
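// threeStepBuildExpertMapsSortFirstToken replaces the fused/CUB-sorted path with a counting sort
// over local expert ids: per-block histograms (blocked_expert_counts), a prefix sum
// (blocked_expert_counts_cumsum), then a scatter via blocked_row_to_unpermuted_row. The host
// reference below shows the outputs the unit test checks. It is sequential, assumes the k-major
// expanded-row numbering (k_idx * num_rows + token), and stores local ids in
// permuted_token_selected_experts; those details are assumptions for illustration.
#include <cstdint>
#include <vector>

void three_step_build_reference(std::vector<int> const& token_selected_experts, // [num_rows, k], global ids
                                int num_rows, int num_experts_per_node, int k, int start_expert,
                                std::vector<int>& permuted_token_selected_experts, // [num_rows * k]
                                std::vector<int>& permuted_row_to_unpermuted_row,  // [num_rows * k]
                                std::vector<int>& unpermuted_row_to_permuted_row,  // [num_rows * k]
                                std::vector<int64_t>& expert_first_token_offset)   // [num_experts_per_node + 1]
{
    int const expanded = num_rows * k;
    permuted_token_selected_experts.assign(expanded, 0);
    permuted_row_to_unpermuted_row.assign(expanded, -1);
    unpermuted_row_to_permuted_row.assign(expanded, -1);
    expert_first_token_offset.assign(num_experts_per_node + 1, 0);

    // Step 1: histogram of expanded rows per local expert.
    std::vector<int64_t> counts(num_experts_per_node, 0);
    for (int token = 0; token < num_rows; ++token)
        for (int k_idx = 0; k_idx < k; ++k_idx)
        {
            int const expert = token_selected_experts[token * k + k_idx] - start_expert;
            if (expert >= 0 && expert < num_experts_per_node)
                ++counts[expert];
        }

    // Step 2: exclusive prefix sum gives the first permuted row of each expert.
    for (int e = 0; e < num_experts_per_node; ++e)
        expert_first_token_offset[e + 1] = expert_first_token_offset[e] + counts[e];

    // Step 3: scatter, preserving token order within an expert (the order the test asserts).
    std::vector<int64_t> cursor(expert_first_token_offset.begin(), expert_first_token_offset.end() - 1);
    for (int token = 0; token < num_rows; ++token)
        for (int k_idx = 0; k_idx < k; ++k_idx)
        {
            int const expert = token_selected_experts[token * k + k_idx] - start_expert;
            if (expert < 0 || expert >= num_experts_per_node)
                continue;
            int const unpermuted_row = k_idx * num_rows + token;
            int64_t const permuted_row = cursor[expert]++;
            permuted_token_selected_experts[permuted_row] = expert;
            permuted_row_to_unpermuted_row[permuted_row] = unpermuted_row;
            unpermuted_row_to_permuted_row[unpermuted_row] = permuted_row;
        }
}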
permuted_source_token_ids_, expanded_num_rows * sizeof(int), cudaMemcpyDeviceToHost, stream)); + permuted_row_to_unpermuted_row_, expanded_num_rows * sizeof(int), cudaMemcpyDeviceToHost, stream)); host_expert_first_token_offset.resize(num_experts_per_node + 1); TLLM_CUDA_CHECK(tensorrt_llm::common::cudaMemcpyAsyncSanitized(host_expert_first_token_offset.data(), expert_first_token_offset_, (num_experts_per_node + 1) * sizeof(int64_t), cudaMemcpyDeviceToHost, @@ -3153,10 +3312,9 @@ void CutlassMoeFCRunner(permuted_data_), - token_topk_unpermuted_scales, permuted_token_final_scales_, permuted_source_token_ids_, - expanded_source_row_to_expanded_dest_row, num_rows, hidden_size, experts_per_token, num_experts_per_node, - quant_params.fp4.fc1.act_global_scale, use_per_expert_act_scale, expert_first_token_offset_, - fc1_fp4_act_scale_, input_sf, stream); + token_topk_unpermuted_scales, permuted_token_final_scales_, permuted_row_to_unpermuted_row_, num_rows, + hidden_size, experts_per_token, num_experts_per_node, quant_params.fp4.fc1.act_global_scale, + use_per_expert_act_scale, expert_first_token_offset_, fc1_fp4_act_scale_, input_sf, stream); sync_check_cuda_error(stream); @@ -3191,7 +3349,7 @@ void CutlassMoeFCRunner:: assert(min_latency_mode == false); gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE; gemm2_tma_ws_input.setFinalizeFusionParams(final_output, permuted_token_final_scales_, - expert_first_token_offset_, permuted_source_token_ids_, apply_bias ? fc2_expert_biases : nullptr, + expert_first_token_offset_, permuted_row_to_unpermuted_row_, apply_bias ? fc2_expert_biases : nullptr, hidden_size, num_rows); } @@ -3448,8 +3606,8 @@ CutlassMoeFCRunner:: // ==================== Helper for getting load balanced routing for profiling ================================== -__global__ void prepareFakeRouterBuffers(int* unpermuted_source_rows, int* unpermuted_expert_selection, - int64_t num_tokens, int64_t k, int64_t num_experts, int64_t num_experts_per_node) +__global__ void prepareFakeRouterBuffers( + int* token_selected_experts, int64_t num_tokens, int64_t k, int64_t num_experts) { int64_t tid = (int64_t) blockIdx.x * blockDim.x + threadIdx.x; int64_t sample = blockIdx.y; @@ -3459,8 +3617,7 @@ __global__ void prepareFakeRouterBuffers(int* unpermuted_source_rows, int* unper } // Offset the buffers to the start of the sample - unpermuted_source_rows += sample * num_tokens * k; - unpermuted_expert_selection += sample * num_tokens * k; + token_selected_experts += sample * num_tokens * k; // This is not perf sensitive we just init the state here every time prepare is called // This means the first N tokens will always have the same distribution, regardless of num_tokens @@ -3476,7 +3633,7 @@ __global__ void prepareFakeRouterBuffers(int* unpermuted_source_rows, int* unper bool valid = true; for (int prev_k = 0; prev_k < k_idx; prev_k++) { - int prev_expert = unpermuted_expert_selection[k * tid + prev_k]; + int prev_expert = token_selected_experts[k * tid + prev_k]; if (expert == prev_expert) { valid = false; @@ -3486,9 +3643,7 @@ __global__ void prepareFakeRouterBuffers(int* unpermuted_source_rows, int* unper if (valid) { - int64_t const idx = k * tid + k_idx; - unpermuted_expert_selection[idx] = expert < num_experts_per_node ? 
expert : num_experts_per_node; - unpermuted_source_rows[idx] = k_idx * num_tokens + tid; + token_selected_experts[k * tid + k_idx] = expert; break; } } @@ -3513,18 +3668,6 @@ __global__ void populateRandomBufferKernel(void* buffer_void, size_t size) buffer[tid * elem_per_thread + i] = curand4(&state); } -__global__ void buildReverseMap(int* expanded_source_row_to_expanded_dest_row, - int const* expanded_dest_row_to_expanded_source_row, int64_t expanded_num_tokens) -{ - int tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid < expanded_num_tokens) - { - assert(expanded_dest_row_to_expanded_source_row[tid] >= 0); - assert(expanded_dest_row_to_expanded_source_row[tid] < expanded_num_tokens); - expanded_source_row_to_expanded_dest_row[expanded_dest_row_to_expanded_source_row[tid]] = tid; - } -} - template __global__ void prepareMinLatencyBuffer(int* num_active_experts_per_node, int* active_expert_global_ids, int64_t* expert_first_token_offset, int const num_tokens, int const num_experts_per_token, @@ -3758,9 +3901,15 @@ std::map> GemmProfilerBackend::getProfile size_t map_size = mMinLatencyMode ? 0 : NUM_ROUTING_SAMPLES * num_expanded_tokens * sizeof(int); size_t unpermuted_size = mMinLatencyMode ? 0 : NUM_ROUTING_SAMPLES * num_expanded_tokens * sizeof(int); size_t permuted_size = mMinLatencyMode ? 0 : num_expanded_tokens * sizeof(int); - size_t sorter_ws_size = mMinLatencyMode ? 0 : mSorter.getWorkspaceSize(num_expanded_tokens, mNumExpertsPerNode); size_t token_topk_unpermuted_scales_size = mMinLatencyMode ? 0 : num_expanded_tokens * sizeof(float); + int64_t const num_tokens_per_block = computeNumTokensPerBlock(maxM, num_experts_per_node); + int64_t const num_blocks_per_seq = tensorrt_llm::common::ceilDiv(maxM, num_tokens_per_block); + size_t const blocked_expert_counts_size + = mMinLatencyMode ? 0 : num_experts_per_node * num_blocks_per_seq * sizeof(int); + size_t const blocked_expert_counts_cumsum_size = blocked_expert_counts_size; + size_t const blocked_row_to_unpermuted_row_size = mMinLatencyMode ? 0 : num_experts_per_node * maxM * sizeof(int); + // The follow buffers are used in min_latency_mode size_t num_active_experts_per_node_size = mMinLatencyMode ? 
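// prepareFakeRouterBuffers now writes raw global expert ids only: there is no clamping to
// num_experts_per_node and no separate source-row buffer, and buildReverseMap disappears because
// the three-step builder emits both row maps itself. A host stand-in for the fake router, using
// std::mt19937_64 instead of curand (illustrative only; requires k <= num_experts to terminate).
#include <cstdint>
#include <random>
#include <vector>

std::vector<int> fake_routing(int64_t num_tokens, int k, int num_experts, uint64_t seed = 0)
{
    std::mt19937_64 rng(seed);
    std::uniform_int_distribution<int> pick(0, num_experts - 1);
    std::vector<int> token_selected_experts(num_tokens * k, -1);
    for (int64_t token = 0; token < num_tokens; ++token)
        for (int k_idx = 0; k_idx < k; ++k_idx)
        {
            int expert;
            bool duplicate;
            do
            {
                // Resample until this token has k distinct experts, mirroring the kernel's retry loop.
                expert = pick(rng);
                duplicate = false;
                for (int prev = 0; prev < k_idx; ++prev)
                    duplicate |= (token_selected_experts[token * k + prev] == expert);
            } while (duplicate);
            token_selected_experts[token * k + k_idx] = expert;
        }
    return token_selected_experts;
}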
sizeof(int) * NUM_ROUTING_SAMPLES : 0; // smaller than or equal to num_experts_per_node @@ -3779,12 +3928,13 @@ std::map> GemmProfilerBackend::getProfile #define ADD(name) ADD_NAME(name, name##_size) ADD(expert_first_token_offset); - ADD_NAME(source_to_dest, map_size); - ADD_NAME(dest_to_source, map_size); - ADD_NAME(unpermuted_selected_experts, unpermuted_size); - ADD_NAME(unpermuted_source_rows, unpermuted_size); + ADD_NAME(unpermuted_row_to_permuted_row, map_size); + ADD_NAME(permuted_row_to_unpermuted_row, map_size); + ADD_NAME(token_selected_experts, unpermuted_size); ADD_NAME(permuted_token_selected_experts, permuted_size); - ADD(sorter_ws); + ADD(blocked_expert_counts); + ADD(blocked_expert_counts_cumsum); + ADD(blocked_row_to_unpermuted_row); ADD(token_topk_unpermuted_scales); ADD(num_active_experts_per_node); ADD(active_expert_global_ids); @@ -3824,12 +3974,13 @@ void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_cha : nullptr) GET_WS_PTR_BASE(int64_t*, expert_first_token_offset); - GET_WS_PTR_BASE(int*, source_to_dest); - GET_WS_PTR_BASE(int*, dest_to_source); - GET_WS_PTR_BASE(int*, unpermuted_selected_experts); - GET_WS_PTR_BASE(int*, unpermuted_source_rows); + GET_WS_PTR_BASE(int*, unpermuted_row_to_permuted_row); + GET_WS_PTR_BASE(int*, permuted_row_to_unpermuted_row); + GET_WS_PTR_BASE(int*, token_selected_experts); GET_WS_PTR(int*, permuted_token_selected_experts); - GET_WS_PTR(int*, sorter_ws); + GET_WS_PTR(int*, blocked_expert_counts); + GET_WS_PTR(int*, blocked_expert_counts_cumsum); + GET_WS_PTR(int*, blocked_row_to_unpermuted_row); GET_WS_PTR(int*, num_active_experts_per_node); GET_WS_PTR(int*, active_expert_global_ids); @@ -3847,29 +3998,27 @@ void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_cha } else { - int64_t num_expanded_tokens = num_tokens * mK; + int64_t const num_expanded_tokens = num_tokens * mK; + int const start_expert_id = mNumExpertsPerNode * mParallelismConfig.ep_rank; + uint32_t num_threads = 256; dim3 grid_dim{(num_tokens + num_threads - 1) / num_threads, NUM_ROUTING_SAMPLES, 1}; - prepareFakeRouterBuffers<<>>(unpermuted_source_rows_base, - unpermuted_selected_experts_base, num_tokens, mK, mNumExperts, mNumExpertsPerNode); + prepareFakeRouterBuffers<<>>( + token_selected_experts_base, num_tokens, mK, mNumExperts); sync_check_cuda_error(stream); for (int64_t i = 0; i < NUM_ROUTING_SAMPLES; i++) { int64_t* expert_first_token_offset = expert_first_token_offset_base + i * (mNumExpertsPerNode + 1); - int* source_to_dest = source_to_dest_base + i * num_expanded_tokens; - int* dest_to_source = dest_to_source_base + i * num_expanded_tokens; - int* unpermuted_expert_selection = unpermuted_selected_experts_base + i * num_expanded_tokens; - int* unpermuted_source_rows = unpermuted_source_rows_base + i * num_expanded_tokens; - - generateTokenPermutation(unpermuted_expert_selection, unpermuted_source_rows, - permuted_token_selected_experts, dest_to_source, expert_first_token_offset, num_tokens, - mNumExpertsPerNode, mK, mSorter, sorter_ws, stream); - + int* unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_base + i * num_expanded_tokens; + int* permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_base + i * num_expanded_tokens; + int* token_selected_experts = token_selected_experts_base + i * num_expanded_tokens; + + threeStepBuildExpertMapsSortFirstToken(token_selected_experts, permuted_token_selected_experts, + permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, 
expert_first_token_offset, + blocked_expert_counts, blocked_expert_counts_cumsum, blocked_row_to_unpermuted_row, num_tokens, + mNumExpertsPerNode, mK, start_expert_id, stream); sync_check_cuda_error(stream); - - int grid_dim = (num_expanded_tokens + num_threads - 1) / num_threads; - buildReverseMap<<>>(source_to_dest, dest_to_source, num_expanded_tokens); } } } @@ -3955,8 +4104,8 @@ void GemmProfilerBackend::prepareTmaWsInputs( GET_WS_PTR(int64_t*, expert_first_token_offset); int64_t* expert_first_token_offset_base = expert_first_token_offset; - GET_WS_PTR(int*, dest_to_source); - int* dest_to_source_base = dest_to_source; + GET_WS_PTR(int*, permuted_row_to_unpermuted_row); + int* permuted_row_to_unpermuted_row_base = permuted_row_to_unpermuted_row; GET_WS_PTR(void*, input); GET_WS_PTR(void*, output); GET_WS_PTR(void*, intermediate); @@ -3989,7 +4138,7 @@ void GemmProfilerBackend::prepareTmaWsInputs( tma_ws_input_workspace += tma_ws_size; int64_t* expert_first_token_offset = expert_first_token_offset_base + i * (mNumExpertsPerNode + 1); - int* expanded_dest_row_to_expanded_source_row = dest_to_source_base + i * num_expanded_tokens; + int* permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_base + i * num_expanded_tokens; auto& gemm1_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_1 ? mTmaInputCache[i] : dummy_tma_ws_input; auto& gemm2_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_2 ? mTmaInputCache[i] : dummy_tma_ws_input; @@ -4008,7 +4157,7 @@ void GemmProfilerBackend::prepareTmaWsInputs( assert(!mMinLatencyMode); gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE; gemm2_tma_ws_input.setFinalizeFusionParams(output, token_topk_unpermuted_scales, - expert_first_token_offset, expanded_dest_row_to_expanded_source_row, apply_bias ? bias : nullptr, + expert_first_token_offset, permuted_row_to_unpermuted_row, apply_bias ? 
bias : nullptr, mExpertHiddenSize, num_tokens); } @@ -4043,8 +4192,6 @@ void GemmProfilerBackend::prepare( mAllTacticsSaved = mInterface->getTactics(); mSampleIndex = 0; - mSorter.updateNumExperts(mNumExpertsPerNode); - auto workspace_size = getWorkspaceSize(num_tokens); populateRandomBuffer(workspace_ptr_char, workspace_size, stream); @@ -4083,9 +4230,9 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac : nullptr) GET_WS_PTR_OFFSET(int64_t const*, expert_first_token_offset, (mSampleIndex * (mNumExpertsPerNode + 1))); - GET_WS_PTR_OFFSET(int const*, source_to_dest, (mSampleIndex * expanded_num_tokens)); - GET_WS_PTR_OFFSET(int const*, dest_to_source, (mSampleIndex * expanded_num_tokens)); - GET_WS_PTR_OFFSET(int const*, unpermuted_selected_experts, (mSampleIndex * expanded_num_tokens)); + GET_WS_PTR_OFFSET(int const*, unpermuted_row_to_permuted_row, (mSampleIndex * expanded_num_tokens)); + GET_WS_PTR_OFFSET(int const*, permuted_row_to_unpermuted_row, (mSampleIndex * expanded_num_tokens)); + GET_WS_PTR_OFFSET(int const*, token_selected_experts, (mSampleIndex * expanded_num_tokens)); GET_WS_PTR(float const*, token_topk_unpermuted_scales); auto const* token_topk_permuted_scales = token_topk_unpermuted_scales; @@ -4144,8 +4291,7 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac tactic, // mMinLatencyMode, // num_active_experts_per_node, // - active_expert_global_ids, // - /*start_expert=*/0); + active_expert_global_ids); // } else { @@ -4163,9 +4309,9 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac mQuantParams, // token_topk_unpermuted_scales, // token_topk_permuted_scales, // - source_to_dest, // - dest_to_source, // - unpermuted_selected_experts, // + unpermuted_row_to_permuted_row, // + permuted_row_to_unpermuted_row, // + token_selected_experts, // expert_first_token_offset + mNumExpertsPerNode, // original_num_tokens, // expanded_num_tokens, // @@ -4183,8 +4329,7 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac tactic, // mMinLatencyMode, // num_active_experts_per_node, // - active_expert_global_ids, // - /*start_expert=*/0); + active_expert_global_ids); // } mInterface->is_profiler = false; diff --git a/cpp/tensorrt_llm/kernels/moeUtilOp.cu b/cpp/tensorrt_llm/kernels/moeUtilOp.cu index 87531a5e64d..94f73a7df52 100644 --- a/cpp/tensorrt_llm/kernels/moeUtilOp.cu +++ b/cpp/tensorrt_llm/kernels/moeUtilOp.cu @@ -55,6 +55,62 @@ using namespace tensorrt_llm::common; namespace tensorrt_llm::kernels { +// ========================== CUB Sorting things ==================================== +CubKeyValueSorter::CubKeyValueSorter() + : num_experts_(0) + , num_bits_(sizeof(int) * 8) +{ +} + +int CubKeyValueSorter::expertsToBits(int num_experts) +{ + // Max value we represent is V = num_experts + (num_experts - 1) = 2 * num_experts - 1 + // The maximum number of bits is therefore floor(log2(V)) + 1 + return static_cast(log2(2 * num_experts - 1)) + 1; +} + +CubKeyValueSorter::CubKeyValueSorter(int const num_experts) + : num_experts_(num_experts) + , num_bits_(expertsToBits(num_experts)) +{ +} + +void CubKeyValueSorter::updateNumExperts(int const num_experts) +{ + num_experts_ = num_experts; + num_bits_ = expertsToBits(num_experts); +} + +size_t CubKeyValueSorter::getWorkspaceSize(size_t const num_key_value_pairs, int const num_experts) +{ + int num_bits = expertsToBits(num_experts); + size_t required_storage = 0; + int* null_int = nullptr; + 
cub::DeviceRadixSort::SortPairs( + nullptr, required_storage, null_int, null_int, null_int, null_int, num_key_value_pairs, 0, num_bits); + + // TODO: fix DeviceRadixSort + // when num_key_value_pairs, num_experts, num_bits, required_storage = 64, 4, 3, 0 + // The required_storage seems to vary between 0 and 1 for the same inputs + if (required_storage == 0) + { + required_storage = 1; + } + return required_storage; +} + +void CubKeyValueSorter::run(void* workspace, size_t const workspace_size, int const* keys_in, int* keys_out, + int const* values_in, int* values_out, size_t const num_key_value_pairs, cudaStream_t stream) +{ + size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs, num_experts_); + size_t actual_ws_size = workspace_size; + + TLLM_CHECK_WITH_INFO(expected_ws_size <= workspace_size, + "[CubKeyValueSorter::run] The allocated workspace is too small to run this problem."); + cub::DeviceRadixSort::SortPairs( + workspace, actual_ws_size, keys_in, keys_out, values_in, values_out, num_key_value_pairs, 0, num_bits_, stream); +} + // TODO: These kernel implementations are duplicated in moe_kernels.cu. They will be refactored later (tracked by // https://jirasw.nvidia.com/browse/TRTLLM-708) template @@ -468,13 +524,13 @@ __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, int void generateTokenPermutation(int const* unpermuted_token_selected_experts, int const* unpermuted_source_token_ids, int* permuted_token_selected_experts, int* permuted_source_token_ids, int64_t* expert_first_token_offset, - int64_t num_rows, int64_t num_experts_per_node, int64_t k, cutlass_kernels::CubKeyValueSorter& sorter, - void* sorter_ws, cudaStream_t stream) + int64_t num_rows, int64_t num_experts_per_node, int64_t k, CubKeyValueSorter& sorter, void* sorter_ws, + cudaStream_t stream) { int64_t const expanded_num_rows = k * num_rows; sorter.updateNumExperts(num_experts_per_node); size_t const sorter_ws_size_bytes - = cutlass_kernels::pad_to_multiple_of_16(sorter.getWorkspaceSize(expanded_num_rows, num_experts_per_node)); + = pad_to_multiple_of_16(sorter.getWorkspaceSize(expanded_num_rows, num_experts_per_node)); sorter.run((void*) sorter_ws, sorter_ws_size_bytes, unpermuted_token_selected_experts, permuted_token_selected_experts, unpermuted_source_token_ids, permuted_source_token_ids, expanded_num_rows, stream); diff --git a/cpp/tensorrt_llm/kernels/moeUtilOp.h b/cpp/tensorrt_llm/kernels/moeUtilOp.h index 968067d615c..888b1a72ffd 100644 --- a/cpp/tensorrt_llm/kernels/moeUtilOp.h +++ b/cpp/tensorrt_llm/kernels/moeUtilOp.h @@ -23,6 +23,32 @@ namespace tensorrt_llm::kernels { +static inline size_t pad_to_multiple_of_16(size_t const& input) +{ + static constexpr int ALIGNMENT = 16; + return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); +} + +class CubKeyValueSorter +{ +public: + CubKeyValueSorter(); + + CubKeyValueSorter(int const num_experts_per_node); + + void updateNumExperts(int const num_experts_per_node); + + static size_t getWorkspaceSize(size_t const num_key_value_pairs, int const num_experts_per_node); + + void run(void* workspace, size_t const workspace_size, int const* keys_in, int* keys_out, int const* values_in, + int* values_out, size_t const num_key_value_pairs, cudaStream_t stream); + +private: + static int expertsToBits(int experts); + int num_experts_; + int num_bits_; +}; + bool fusedBuildExpertMapsSortFirstToken(int const* token_selected_experts, int* unpermuted_token_selected_experts, int* permuted_source_token_ids, int64_t* expert_first_token_offset, 
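// CubKeyValueSorter::expertsToBits keeps the radix sort narrow: per the comment above, keys never
// exceed num_experts + (num_experts - 1) = 2 * num_experts - 1, so floor(log2(2n - 1)) + 1 bits
// always cover the key space and SortPairs can stop at that bit. A standalone sanity check of the
// bound (experts_to_bits is a local copy for illustration, not the class member):
#include <cassert>
#include <cmath>
#include <initializer_list>

static int experts_to_bits(int num_experts)
{
    return static_cast<int>(std::log2(2 * num_experts - 1)) + 1;
}

int main()
{
    assert(experts_to_bits(4) == 3);   // keys 0..7
    assert(experts_to_bits(64) == 7);  // keys 0..127
    assert(experts_to_bits(256) == 9); // keys 0..511
    // Every representable key fits in the reported bit count.
    for (int n : {1, 2, 3, 8, 60, 257})
        assert((2 * n - 1) < (1 << experts_to_bits(n)));
    return 0;
}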
int64_t const num_tokens, int const num_experts_per_node, int const experts_per_token, int const start_expert, int const end_expert, @@ -34,8 +60,8 @@ void buildExpertMaps(int const* token_selected_experts, int* unpermuted_token_se void generateTokenPermutation(int const* unpermuted_token_selected_experts, int const* unpermuted_source_token_ids, int* permuted_token_selected_experts, int* permuted_source_token_ids, int64_t* expert_first_token_offset, - int64_t num_rows, int64_t num_experts_per_node, int64_t k, cutlass_kernels::CubKeyValueSorter& sorter, - void* sorter_ws, cudaStream_t stream); + int64_t num_rows, int64_t num_experts_per_node, int64_t k, CubKeyValueSorter& sorter, void* sorter_ws, + cudaStream_t stream); template void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, diff --git a/cpp/tensorrt_llm/thop/moeUtilOp.cpp b/cpp/tensorrt_llm/thop/moeUtilOp.cpp index 319d68f7605..8e50656b32f 100644 --- a/cpp/tensorrt_llm/thop/moeUtilOp.cpp +++ b/cpp/tensorrt_llm/thop/moeUtilOp.cpp @@ -48,7 +48,7 @@ void runPermute(void const* input_activations_void, void const* input_sf_void, i int* unpermuted_token_selected_experts_, int* unpermuted_source_token_ids_, int* permuted_source_token_ids_, int* permuted_token_selected_experts_, T* permuted_data_, char* sorter_ws_, int64_t* expert_first_token_offset_, float* permuted_token_final_scales_, int* expanded_source_row_to_expanded_dest_row, - cutlass_kernels::MOEParallelismConfig parallelism_config, cutlass_kernels::CubKeyValueSorter sorter_, bool use_lora, + cutlass_kernels::MOEParallelismConfig parallelism_config, kernels::CubKeyValueSorter sorter_, bool use_lora, kernels::LoraParams& lora_params, bool use_fp8_block_scaling, bool min_latency_mode, cutlass_kernels::MoeMinLatencyParams& min_latency_params, cudaStream_t stream) { @@ -121,7 +121,7 @@ moe_permute_op(torch::Tensor const& input, torch::Tensor const& token_selected_e int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank, bool min_latency_mode, bool use_fp8_block_scaling) { - cutlass_kernels::CubKeyValueSorter sorter_; + kernels::CubKeyValueSorter sorter_; TORCH_CHECK(cluster_size == 1 && cluster_rank == 0, "smart_router is supported in min_latency mode"); TORCH_CHECK(min_latency_mode == false, "min_latency_mode is not supported now"); @@ -179,7 +179,7 @@ moe_permute_op(torch::Tensor const& input, torch::Tensor const& token_selected_e size_t const sorter_size = min_latency_mode ? 
0 - : cutlass_kernels::CubKeyValueSorter::getWorkspaceSize(num_rows * experts_per_token, num_experts_per_node); + : kernels::CubKeyValueSorter::getWorkspaceSize(num_rows * experts_per_token, num_experts_per_node); auto sorter_ws_tensor = torch::empty( {static_cast(sorter_size)}, torch::dtype(torch::kChar).device(torch::kCUDA).requires_grad(false)); diff --git a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu index b5cc8c5f2e0..a44ca2a4a89 100644 --- a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu +++ b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu @@ -2118,7 +2118,7 @@ TEST_F(MixtureOfExpertsProfilerTest, TestGeneratedProfilerDistribution) #ifdef USING_OSS_CUTLASS_MOE_GEMM backend.init(this->mMoERunner, GemmProfilerBackend::GemmToProfile::GEMM_1, nvinfer1::DataType::kHALF, nvinfer1::DataType::kHALF, nvinfer1::DataType::kHALF, num_experts, k, 1024, 4096, mGroupSize, {}, false, - mUseLora, /*min_latency_mode=*/false, /*need_weights=*/true, MOEParallelismConfig{1, 0, ep, ep - 1}, + mUseLora, /*min_latency_mode=*/false, /*need_weights=*/true, MOEParallelismConfig{1, 0, ep, 0}, /*enable_alltoall=*/false); #else backend.init(this->mMoERunner, GemmProfilerBackend::GemmToProfile::GEMM_1, nvinfer1::DataType::kHALF, @@ -2136,34 +2136,47 @@ TEST_F(MixtureOfExpertsProfilerTest, TestGeneratedProfilerDistribution) #define GET_WS_PTR(type, name) auto* name = reinterpret_cast(workspace + workspaces.at(#name).second) GET_WS_PTR(int64_t*, expert_first_token_offset); - GET_WS_PTR(int*, source_to_dest); - GET_WS_PTR(int*, dest_to_source); + GET_WS_PTR(int*, unpermuted_row_to_permuted_row); + GET_WS_PTR(int*, permuted_row_to_unpermuted_row); +#ifdef USING_OSS_CUTLASS_MOE_GEMM + GET_WS_PTR(int*, token_selected_experts); +#else GET_WS_PTR(int*, unpermuted_selected_experts); - +#endif #undef GET_WS_PTR for (int sample = 0; sample < backend.NUM_ROUTING_SAMPLES; sample++) { auto host_expert_first_token_offset_size = getDataFromDevice( expert_first_token_offset + sample * (num_experts_per_node + 1), num_experts_per_node + 1); - auto host_source_to_dest_map - = getDataFromDevice(source_to_dest + sample * expanded_num_tokens, expanded_num_tokens); - auto host_dest_to_source_map - = getDataFromDevice(dest_to_source + sample * expanded_num_tokens, expanded_num_tokens); + auto host_unpermuted_row_to_permuted_row_map = getDataFromDevice( + unpermuted_row_to_permuted_row + sample * expanded_num_tokens, expanded_num_tokens); + auto host_permuted_row_to_unpermuted_row_map = getDataFromDevice( + permuted_row_to_unpermuted_row + sample * expanded_num_tokens, expanded_num_tokens); +#ifdef USING_OSS_CUTLASS_MOE_GEMM + auto host_token_selected_experts + = getDataFromDevice(token_selected_experts + sample * expanded_num_tokens, expanded_num_tokens); +#else auto host_token_selected_experts = getDataFromDevice( unpermuted_selected_experts + sample * expanded_num_tokens, expanded_num_tokens); +#endif std::vector calculated_routing_values(num_experts_per_node + 1, 0); int skipped = 0; for (auto v : host_token_selected_experts) { +#ifndef USING_OSS_CUTLASS_MOE_GEMM ASSERT_TRUE(v < num_experts_per_node || (v == num_experts_per_node && ep > 1)) << "v " << v << " num_experts_per_node " << num_experts_per_node << " ep " << ep; - skipped += (v == num_experts_per_node); +#endif if (v < num_experts_per_node) { calculated_routing_values[v]++; } + else + { + skipped++; + } } if (num_tokens > 1) @@ -2206,14 +2219,18 @@ TEST_F(MixtureOfExpertsProfilerTest, 
TestGeneratedProfilerDistribution)
             int64_t idx = token_idx * k + k_idx;
             int64_t expert_idx = host_token_selected_experts[idx];
+#ifdef USING_OSS_CUTLASS_MOE_GEMM
+            if (expert_idx < num_experts_per_node)
+#else
             if (expert_idx < num_experts)
+#endif
             {
-                int64_t source_location = k_idx * num_tokens + token_idx;
-                int64_t dest_location = host_expert_first_token_offset_size[expert_idx]
+                int64_t unpermuted_row = k_idx * num_tokens + token_idx;
+                int64_t permuted_row = host_expert_first_token_offset_size[expert_idx]
                     + calculated_routing_values[expert_idx];
-                ASSERT_EQ(host_source_to_dest_map[source_location], dest_location);
-                ASSERT_EQ(host_dest_to_source_map[dest_location], source_location);
+                ASSERT_EQ(host_unpermuted_row_to_permuted_row_map[unpermuted_row], permuted_row);
+                ASSERT_EQ(host_permuted_row_to_unpermuted_row_map[permuted_row], unpermuted_row);
                 calculated_routing_values[expert_idx]++;
             }
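// The rewritten assertions amount to two invariants: the two routing maps are mutual inverses on
// the rows owned by this node, and expert_first_token_offset is a non-decreasing prefix sum. A
// host checker expressing the same invariants (names assumed; compatible with the reference
// builder sketched earlier, not part of the patch):
#include <cassert>
#include <cstdint>
#include <vector>

void check_routing_maps(std::vector<int> const& unpermuted_row_to_permuted_row,
                        std::vector<int> const& permuted_row_to_unpermuted_row,
                        std::vector<int64_t> const& expert_first_token_offset,
                        int num_experts_per_node)
{
    int64_t const num_valid = expert_first_token_offset[num_experts_per_node];
    for (int64_t permuted_row = 0; permuted_row < num_valid; ++permuted_row)
    {
        int const unpermuted_row = permuted_row_to_unpermuted_row[permuted_row];
        assert(unpermuted_row >= 0
            && unpermuted_row < static_cast<int64_t>(unpermuted_row_to_permuted_row.size()));
        // Mutual inverse on the rows that landed on this node.
        assert(unpermuted_row_to_permuted_row[unpermuted_row] == permuted_row);
    }
    // The offsets must form a non-decreasing prefix sum over local experts.
    for (int e = 0; e < num_experts_per_node; ++e)
        assert(expert_first_token_offset[e] <= expert_first_token_offset[e + 1]);
}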