@@ -104,6 +104,53 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
   }
 }

+namespace {
+inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
+                                         torch::Tensor& problem_sizes1,
+                                         torch::Tensor& problem_sizes2,
+                                         torch::Tensor& atomic_buffer,
+                                         int64_t num_experts, int64_t n,
+                                         int64_t k, cudaStream_t stream,
+                                         const bool swap_ab) {
+  int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
+
+  const int32_t* topk_ptr = static_cast<const int32_t*>(topk_ids.data_ptr());
+  int32_t* ps1_ptr = static_cast<int32_t*>(problem_sizes1.data_ptr());
+  int32_t* ps2_ptr = static_cast<int32_t*>(problem_sizes2.data_ptr());
+  int32_t* atomic_ptr = static_cast<int32_t*>(atomic_buffer.data_ptr());
+
+  if (swap_ab) {
+    compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
+        topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
+        static_cast<int>(topk_ids.numel()), static_cast<int>(n),
+        static_cast<int>(k));
+  } else {
+    compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
+        topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
+        static_cast<int>(topk_ids.numel()), static_cast<int>(n),
+        static_cast<int>(k));
+  }
+}
+}  // namespace
+
+void get_cutlass_moe_mm_problem_sizes_caller(
+    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
+    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
+    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
+  auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
+  auto options_int32 =
+      torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
+  torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
+
+  // Swap-AB should be disabled for FP4 path
+  bool may_swap_ab = (!blockscale_offsets.has_value()) &&
+                     (topk_ids.numel() <= SWAP_AB_THRESHOLD);
+
+  launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
+                               atomic_buffer, num_experts, n, k, stream,
+                               may_swap_ab);
+}
+
 void get_cutlass_moe_mm_data_caller(
     const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
     torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
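The new launch_compute_problem_sizes helper only changes where the compute_problem_sizes<SWAP_AB> launch happens; the kernel body itself is not part of this diff. As a rough sketch of what such a kernel can look like (an assumption for illustration, not the file's actual kernel, and the 2 * n width of the first GEMM assumes a fused gate/up projection): one block per expert counts how many topk_ids entries route to that expert, and thread 0 then writes the per-expert (M, N, K) GEMM shapes, exchanging M and N when SWAP_AB is set so the small token count lands on the N side.

template <bool SWAP_AB>
__global__ void compute_problem_sizes_sketch(
    const int32_t* __restrict__ topk_ids, int32_t* problem_sizes1,
    int32_t* problem_sizes2, int32_t* atomic_buffer, const int topk_length,
    const int n, const int k) {
  int expert_id = blockIdx.x;

  // Each thread counts this expert's occurrences in a strided slice of
  // topk_ids; the per-thread counts are accumulated into atomic_buffer.
  int occurrences = 0;
  for (int i = threadIdx.x; i < topk_length; i += blockDim.x) {
    occurrences += (topk_ids[i] == expert_id);
  }
  atomicAdd(&atomic_buffer[expert_id], occurrences);
  __syncthreads();

  if (threadIdx.x == 0) {
    int tokens = atomic_buffer[expert_id];
    if constexpr (!SWAP_AB) {
      // Regular layout: M is the number of tokens routed to this expert.
      problem_sizes1[expert_id * 3] = tokens;
      problem_sizes1[expert_id * 3 + 1] = 2 * n;
      problem_sizes1[expert_id * 3 + 2] = k;
      problem_sizes2[expert_id * 3] = tokens;
      problem_sizes2[expert_id * 3 + 1] = k;
      problem_sizes2[expert_id * 3 + 2] = n;
    } else {
      // Swap-AB: M and N are exchanged for the small-M case.
      problem_sizes1[expert_id * 3] = 2 * n;
      problem_sizes1[expert_id * 3 + 1] = tokens;
      problem_sizes1[expert_id * 3 + 2] = k;
      problem_sizes2[expert_id * 3] = k;
      problem_sizes2[expert_id * 3 + 1] = tokens;
      problem_sizes2[expert_id * 3 + 2] = n;
    }
  }
}
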
@@ -121,21 +168,9 @@ void get_cutlass_moe_mm_data_caller(
   bool may_swap_ab = (!blockscale_offsets.has_value()) &&
                      (topk_ids.numel() <= SWAP_AB_THRESHOLD);

-  if (may_swap_ab) {
-    compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
-        static_cast<const int32_t*>(topk_ids.data_ptr()),
-        static_cast<int32_t*>(problem_sizes1.data_ptr()),
-        static_cast<int32_t*>(problem_sizes2.data_ptr()),
-        static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
-        k);
-  } else {
-    compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
-        static_cast<const int32_t*>(topk_ids.data_ptr()),
-        static_cast<int32_t*>(problem_sizes1.data_ptr()),
-        static_cast<int32_t*>(problem_sizes2.data_ptr()),
-        static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
-        k);
-  }
+  launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
+                               atomic_buffer, num_experts, n, k, stream,
+                               may_swap_ab);

   if (blockscale_offsets.has_value()) {
     // fp4 path
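For reference, a hypothetical host-side call to the new entry point could look like the sketch below. Only the signature of get_cutlass_moe_mm_problem_sizes_caller comes from the diff above; the driver function, tensor sizes, and the assumed {num_experts, 3} layout of the problem-size outputs are illustrative. Passing std::nullopt for blockscale_offsets takes the non-FP4 path, so swap-AB may be applied whenever topk_ids.numel() is at or below SWAP_AB_THRESHOLD; supplying a blockscale offsets tensor (the FP4 path) forces may_swap_ab to false.

#include <torch/torch.h>
#include <optional>

// Declaration matching the caller added in this diff.
void get_cutlass_moe_mm_problem_sizes_caller(
    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);

// Hypothetical driver; everything below besides the caller itself is made up.
void run_problem_sizes_example(int64_t num_experts, int64_t n, int64_t k) {
  auto opts_i32 =
      torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);

  // One expert id per (token, top-k slot) pair; 256 tokens x top-8 here.
  torch::Tensor topk_ids = torch::randint(num_experts, {256, 8}, opts_i32);

  // Assumed layout: one (M, N, K) triple per expert for each grouped GEMM.
  torch::Tensor problem_sizes1 = torch::empty({num_experts, 3}, opts_i32);
  torch::Tensor problem_sizes2 = torch::empty({num_experts, 3}, opts_i32);

  // No blockscale offsets -> non-FP4 path, so swap-AB is allowed when
  // topk_ids.numel() <= SWAP_AB_THRESHOLD.
  get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
                                          problem_sizes2, num_experts, n, k,
                                          /*blockscale_offsets=*/std::nullopt);
}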