@@ -87,6 +87,32 @@ struct LoraParams
 
 namespace cutlass_kernels
 {
+static inline size_t pad_to_multiple_of_16(size_t const& input)
+{
+    static constexpr int ALIGNMENT = 16;
+    return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
+}
+
+class CubKeyValueSorter
+{
+public:
+    CubKeyValueSorter();
+
+    CubKeyValueSorter(int const num_experts_per_node);
+
+    void updateNumExperts(int const num_experts_per_node);
+
+    static size_t getWorkspaceSize(size_t const num_key_value_pairs, int const num_experts_per_node);
+
+    void run(void* workspace, size_t const workspace_size, int const* keys_in, int* keys_out, int const* values_in,
+        int* values_out, size_t const num_key_value_pairs, cudaStream_t stream);
+
+private:
+    static int expertsToBits(int experts);
+    int num_experts_;
+    int num_bits_;
+};
+
 /**
  * \brief Describes what parallelism mode the MoE is using
  *
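// Note (reviewer sketch, not part of the diff): the definitions behind CubKeyValueSorter live in the
// corresponding .cu file, so the snippet below is only a hedged illustration of how a sorter with this
// interface is typically backed by cub::DeviceRadixSort. The int key/value types, the workspace-size
// query pattern, and the bit-width choice are assumptions, not facts taken from this header.
#include <cub/cub.cuh>

// Sorting (expert id, expanded row id) pairs groups all rows routed to the same expert together.
// Only the low bits of the key matter, e.g. roughly ceil(log2(num_experts_per_node + 1)) bits.
inline size_t exampleSorterWorkspaceBytes(size_t num_key_value_pairs, int num_bits)
{
    size_t required = 0;
    // CUB reports the temporary-storage requirement when the workspace pointer is null.
    cub::DeviceRadixSort::SortPairs(nullptr, required, static_cast<int const*>(nullptr),
        static_cast<int*>(nullptr), static_cast<int const*>(nullptr), static_cast<int*>(nullptr),
        static_cast<int>(num_key_value_pairs), 0, num_bits);
    // pad_to_multiple_of_16 (added in this diff) keeps sub-buffers 16-byte aligned in a shared workspace.
    return pad_to_multiple_of_16(required);
}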
@@ -371,9 +397,9 @@ class CutlassMoeFCRunnerInterface
         ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
         QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
         int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
-        int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
-        bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
-        MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
+        int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config,
+        bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
+        bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
         = 0;
 
     // Aliases for profiling the gemms
@@ -387,22 +413,22 @@ class CutlassMoeFCRunnerInterface
         int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array,
         bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
         cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-        int* active_expert_global_ids)
+        int* active_expert_global_ids, int start_expert)
         = 0;
 
     virtual void gemm2(void const* const input, void* const gemm_output, void* const final_output,
         int64_t const* const expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput const tma_ws_input_template,
         void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales,
         float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
         QuantParams quant_params, float const* const token_topk_unpermuted_scales,
-        float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
-        int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
+        float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
+        int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
         int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
         int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
         bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config,
         bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
-        int* num_active_experts_per, int* active_expert_global_ids)
+        int* num_active_experts_per, int* active_expert_global_ids, int start_expert)
         = 0;
 
     virtual std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
@@ -518,9 +544,9 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
         QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
         int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
-        int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
-        bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
-        MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override;
+        int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config,
+        bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
+        bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override;
 
     // We make these GEMM1 & GEMM2 static because they need to be stateless for the profiler to work
     static void gemm1(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,
@@ -539,7 +565,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
         int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array,
         bool bias_is_broadcast, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config,
-        bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids);
+        bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids, int start_expert);
 
     static void gemm2(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,
         DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, void* const gemm_output,
@@ -548,14 +574,14 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         ScaleBiasType const* const fc2_expert_biases, ScaleBiasType const* const fc2_int_scales,
         float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
         QuantParams quant_params, float const* const token_topk_unpermuted_scales,
-        float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
-        int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
+        float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
+        int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
         int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
         int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
         cudaStream_t stream, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
         cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-        int* active_expert_global_ids);
+        int* active_expert_global_ids, int start_expert);
 
     // Overrides to allow us to forward on to the internal functions with the pointers using the correct type
     void gemm1(void const* const input, void* const output, void* const intermediate_result,
@@ -568,7 +594,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array,
         bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
         cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-        int* active_expert_global_ids) override
+        int* active_expert_global_ids, int start_expert) override
     {
         auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr;
         return Self::gemm1(moe_gemm_runner_, block_scale_gemm_runner, static_cast<T const*>(input),
@@ -577,33 +603,33 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
             num_valid_tokens_ptr, static_cast<ScaleBiasType const*>(fc1_int_scales), fc1_fp8_dequant, fc2_fp8_quant,
             fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, num_rows, expanded_num_rows, hidden_size, inter_size,
             num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array, bias_is_broadcast, stream, config,
-            min_latency_mode, num_active_experts_per, active_expert_global_ids);
+            min_latency_mode, num_active_experts_per, active_expert_global_ids, start_expert);
     }
 
     void gemm2(void const* const input, void* const gemm_output, void* const final_output,
         int64_t const* const expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput const tma_ws_input_template,
         void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales,
         float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
         QuantParams quant_params, float const* const token_topk_unpermuted_scales,
-        float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
-        int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
+        float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
+        int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
         int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
         int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
         bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config,
         bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
-        int* num_active_experts_per, int* active_expert_global_ids) override
+        int* num_active_experts_per, int* active_expert_global_ids, int start_expert) override
     {
         auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr;
         return Self::gemm2(moe_gemm_runner_, block_scale_gemm_runner, static_cast<T const*>(input), gemm_output,
             static_cast<OutputType*>(final_output), expert_first_token_offset, tma_ws_input_template,
             static_cast<WeightType const*>(fc2_expert_weights), static_cast<ScaleBiasType const*>(fc2_expert_biases),
             static_cast<ScaleBiasType const*>(fc2_int_scales), fc2_fp8_dequant, fc2_fp4_act_flat, quant_params,
-            token_topk_unpermuted_scales, token_topk_permuted_scales, unpermuted_row_to_permuted_row,
-            permuted_row_to_unpermuted_row, token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows,
-            hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, use_lora, fc2_lora,
-            stream, parallelism_config, enable_alltoall, config, min_latency_mode, num_active_experts_per,
-            active_expert_global_ids);
+            token_topk_unpermuted_scales, token_topk_permuted_scales, expanded_source_row_to_expanded_dest_row,
+            expanded_dest_row_to_expanded_source_row, expert_for_source_row, num_valid_tokens_ptr, num_rows,
+            expanded_num_rows, hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array,
+            use_lora, fc2_lora, stream, parallelism_config, enable_alltoall, config, min_latency_mode,
+            num_active_experts_per, active_expert_global_ids, start_expert);
     }
 
     virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override
@@ -737,29 +763,30 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
     static void BlockScaleFC2(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output,
         OutputType* const final_output, int64_t const* const expert_first_token_offset,
         WeightType const* const fc2_expert_weights, ScaleBiasType const* const fc2_expert_biases,
-        float const* const token_topk_unpermuted_scales, int const* const unpermuted_row_to_permuted_row,
-        int const* const permuted_row_to_unpermuted_row, int const* const token_selected_experts,
+        float const* const token_topk_unpermuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
+        int const* const expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
-        int64_t const hidden_size, int64_t const inter_size, int64_t const num_experts_per_node, int64_t const k,
+        int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const k,
         MOEParallelismConfig parallelism_config, bool const enable_alltoall, QuantParams& quant_params,
         cudaStream_t stream);
 
     T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
         int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq,
         cudaStream_t stream);
 
+    CubKeyValueSorter sorter_;
     MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType> moe_gemm_runner_;
     std::unique_ptr<DeepSeekBlockScaleGemmRunner> blockscale_gemm_runner_;
 
     std::optional<cutlass_extensions::CutlassGemmConfig> gemm1_config_;
     std::optional<cutlass_extensions::CutlassGemmConfig> gemm2_config_;
 
     // Pointers
-    int* permuted_row_to_unpermuted_row_{};
+    int* unpermuted_token_selected_experts_{};
+    int* unpermuted_source_token_ids_{};
+    int* permuted_source_token_ids_{};
     int* permuted_token_selected_experts_{};
-    int* blocked_expert_counts_{};
-    int* blocked_expert_counts_cumsum_{};
-    int* blocked_row_to_unpermuted_row_{};
+    char* sorter_ws_{};
     T* permuted_data_{};
     float* permuted_token_final_scales_{};
 
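// Note (reviewer sketch, not part of the diff): this change switches the row-map parameters and pointer
// members to the expanded_source_row / expanded_dest_row naming. The host-side snippet below is only a
// hedged illustration of what the two index arrays and expert_for_source_row mean under that convention;
// the real mapping is built on the GPU (here via CubKeyValueSorter), and this helper is hypothetical.
#include <algorithm>
#include <numeric>
#include <vector>

// expert_for_source_row holds one expert id per expanded (token, k) duplicate, i.e. num_rows * k entries.
inline void buildRowMapsReference(std::vector<int> const& expert_for_source_row,
    std::vector<int>& expanded_source_row_to_expanded_dest_row,
    std::vector<int>& expanded_dest_row_to_expanded_source_row)
{
    int const expanded_num_rows = static_cast<int>(expert_for_source_row.size());

    // Destination order: expanded source rows sorted by expert id, stable so ties keep token order.
    expanded_dest_row_to_expanded_source_row.resize(expanded_num_rows);
    std::iota(expanded_dest_row_to_expanded_source_row.begin(), expanded_dest_row_to_expanded_source_row.end(), 0);
    std::stable_sort(expanded_dest_row_to_expanded_source_row.begin(), expanded_dest_row_to_expanded_source_row.end(),
        [&](int a, int b) { return expert_for_source_row[a] < expert_for_source_row[b]; });

    // The forward map is the inverse permutation: where each expanded source row landed after permutation.
    expanded_source_row_to_expanded_dest_row.resize(expanded_num_rows);
    for (int dest = 0; dest < expanded_num_rows; ++dest)
    {
        expanded_source_row_to_expanded_dest_row[expanded_dest_row_to_expanded_source_row[dest]] = dest;
    }
}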
@@ -832,6 +859,7 @@ struct GemmProfilerBackend
         mParallelismConfig = parallelism_config;
         mEnableAlltoall = enable_alltoall;
         mSM = common::getSMVersion();
+        mSorter.updateNumExperts(mNumExpertsPerNode);
 
         mScalingType = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE;
         if (dtype == nvinfer1::DataType::kFP8
@@ -855,6 +883,7 @@ struct GemmProfilerBackend
         cudaStream_t const& stream);
 
     CutlassMoeFCRunnerInterface* mInterface;
+    CubKeyValueSorter mSorter;
 
     GemmToProfile mGemmToProfile = GemmToProfile::Undefined;
     std::vector<Config> mAllTacticsSaved;
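// Note (reviewer sketch, not part of the diff): GemmProfilerBackend now owns a CubKeyValueSorter
// (mSorter) and refreshes it in init() via updateNumExperts(mNumExpertsPerNode). The snippet below is a
// hedged illustration of how such a backend could fold the sorter's scratch space into its workspace
// budget; the function name and the num_expanded_rows parameter are assumptions made for the example.
inline size_t exampleProfilerSorterBytes(CubKeyValueSorter& sorter, int num_experts_per_node, size_t num_expanded_rows)
{
    sorter.updateNumExperts(num_experts_per_node);
    size_t const bytes = CubKeyValueSorter::getWorkspaceSize(num_expanded_rows, num_experts_per_node);
    // Round up so the next buffer in the profiler workspace stays 16-byte aligned.
    return pad_to_multiple_of_16(bytes);
}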