@@ -198,26 +198,27 @@ __global__ void moe_align_block_size_global_mem_kernel(
}

// taken from
- // https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
+ // https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
template <typename scalar_t>
__global__ void sgl_moe_align_block_size_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
    int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
    int32_t block_size, size_t numel, int32_t* cumsum) {
  __shared__ int32_t shared_counts[32][8];
-   __shared__ int32_t local_offsets[256];

  const int warp_id = threadIdx.x / 32;
-   const int lane_id = threadIdx.x % 32;
  const int experts_per_warp = 8;
  const int my_expert_start = warp_id * experts_per_warp;

+   // Initialize shared_counts for this warp's experts
  for (int i = 0; i < experts_per_warp; ++i) {
    if (my_expert_start + i < num_experts) {
      shared_counts[warp_id][i] = 0;
    }
  }

+   __syncthreads();
+
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

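For orientation, the shared_counts layout above spreads the experts across warps: each of the 32 warps in the 1024-thread block initializes 8 consecutive experts, which is why the host-side TORCH_CHECK added further down pins num_experts to 256. A small host-side sketch of that mapping, illustrative only (the sample expert ids are made up):

// Illustrative only: reproduces the expert-to-warp mapping implied by the
// init loop above (warp_id = threadIdx.x / 32, experts_per_warp = 8).
#include <cstdio>

int main() {
  const int threads_per_block = 1024;                 // launch config used below
  const int warps = threads_per_block / 32;           // 32 warps
  const int experts_per_warp = 8;
  const int num_experts = warps * experts_per_warp;   // 256, see TORCH_CHECK
  printf("num_experts = %d\n", num_experts);

  // Expert e is counted in shared_counts[e / experts_per_warp][e % experts_per_warp].
  const int samples[] = {0, 7, 8, 255};
  for (int e : samples) {
    printf("expert %3d -> shared_counts[%2d][%d]\n", e,
           e / experts_per_warp, e % experts_per_warp);
  }
  return 0;
}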
@@ -230,6 +231,7 @@ __global__ void sgl_moe_align_block_size_kernel(

  __syncthreads();

+   // Single thread computes cumulative sum and total tokens
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
@@ -246,19 +248,28 @@ __global__ void sgl_moe_align_block_size_kernel(

  __syncthreads();

+   // Assign expert IDs to blocks
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
-     local_offsets[threadIdx.x] = cumsum[threadIdx.x];
  }
+ }

-   __syncthreads();
-
-   for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+ // taken from
+ // https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
+ template <typename scalar_t>
+ __global__ void sgl_moe_token_sort_kernel(scalar_t* __restrict__ topk_ids,
+                                           int32_t* sorted_token_ids,
+                                           int32_t* cumsum_buffer,
+                                           size_t numel) {
+   const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+   const size_t stride = blockDim.x * gridDim.x;
+
+   for (size_t i = tid; i < numel; i += stride) {
    int32_t expert_id = topk_ids[i];
-     int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
+     int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
    sorted_token_ids[rank_post_pad] = i;
  }
}
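The new sgl_moe_token_sort_kernel is a plain grid-stride scatter: each thread walks the token list with a grid-wide stride and claims the next output slot for its token's expert via atomicAdd on a per-expert offset table. A self-contained sketch of the same pattern; the kernel name, expert ids, and the unpadded offsets below are made up for illustration (in the real kernel the offsets come from the cumsum buffer filled by the align kernel and include block_size padding):

// Standalone sketch of the grid-stride atomicAdd scatter; values are made up.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void scatter_by_expert(const int* topk_ids, int* sorted_token_ids,
                                  int* offsets, size_t numel) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  // atomicAdd hands out the next free slot for the token's expert, so token
  // indices end up grouped by expert in sorted_token_ids.
  for (size_t i = tid; i < numel; i += stride) {
    const int expert = topk_ids[i];
    const int slot = atomicAdd(&offsets[expert], 1);
    sorted_token_ids[slot] = static_cast<int>(i);
  }
}

int main() {
  const int num_experts = 4, numel = 8;
  int h_topk[numel] = {2, 0, 1, 2, 3, 0, 1, 2};
  // Exclusive prefix sums of the per-expert counts (unpadded for simplicity).
  int h_off[num_experts] = {0, 2, 4, 7};
  int *d_topk, *d_sorted, *d_off;
  cudaMalloc(&d_topk, numel * sizeof(int));
  cudaMalloc(&d_sorted, numel * sizeof(int));
  cudaMalloc(&d_off, num_experts * sizeof(int));
  cudaMemcpy(d_topk, h_topk, sizeof(h_topk), cudaMemcpyHostToDevice);
  cudaMemcpy(d_off, h_off, sizeof(h_off), cudaMemcpyHostToDevice);
  scatter_by_expert<<<2, 4>>>(d_topk, d_sorted, d_off, numel);
  int h_sorted[numel];
  cudaMemcpy(h_sorted, d_sorted, sizeof(h_sorted), cudaMemcpyDeviceToHost);
  for (int i = 0; i < numel; ++i) printf("%d ", h_sorted[i]);  // grouped by expert
  printf("\n");
  cudaFree(d_topk); cudaFree(d_sorted); cudaFree(d_off);
  return 0;
}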
@@ -377,23 +388,34 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                              torch::Tensor experts_ids,
                              torch::Tensor num_tokens_post_pad) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+   TORCH_CHECK(num_experts == 256,
+               "sgl_moe_align_block_size kernel only supports deepseek v3.");
+
  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
-         // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
-         // tensors
+         // calc needed amount of shared mem for `cumsum` tensors
        auto options_int =
            torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
-         // torch::Tensor token_cnts_buffer =
-         //     torch::empty({(num_experts + 1) * num_experts}, options_int);
        torch::Tensor cumsum_buffer =
-             torch::empty({num_experts + 1}, options_int);
+             torch::zeros({num_experts + 1}, options_int);

-         auto kernel = vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
-         kernel<<<1, 1024, 0, stream>>>(
+         auto align_kernel =
+             vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
+         align_kernel<<<1, 1024, 0, stream>>>(
            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
            topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
+
+         const int block_threads = 256;
+         const int num_blocks =
+             (topk_ids.numel() + block_threads - 1) / block_threads;
+         const int max_blocks = 65535;
+         const int actual_blocks = std::min(num_blocks, max_blocks);
+         auto sort_kernel = vllm::moe::sgl_moe_token_sort_kernel<scalar_t>;
+         sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
+             topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
+             cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
      });
}
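On the host side, the sort kernel's grid size is a ceil-divide of topk_ids.numel() over 256 threads per block, capped at 65535 blocks; the grid-stride loop inside the kernel covers whatever the cap leaves over. A hypothetical standalone helper mirroring that arithmetic (the function name and sample inputs are illustrative, not part of the change):

// Hypothetical helper mirroring the actual_blocks computation above.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static int sort_kernel_grid_size(int64_t numel) {
  const int block_threads = 256;
  const int64_t num_blocks = (numel + block_threads - 1) / block_threads;  // ceil-divide
  const int64_t max_blocks = 65535;                                        // cap used above
  return static_cast<int>(std::min(num_blocks, max_blocks));
}

int main() {
  printf("%d\n", sort_kernel_grid_size(1 << 20));           // 1,048,576 entries -> 4096 blocks
  printf("%d\n", sort_kernel_grid_size(INT64_C(1) << 31));  // huge input -> capped at 65535
  return 0;
}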