
Commit 2344192

Optimize moe_align_block_size for deepseek_v3 (#12850)

mgoin authored
Signed-off-by: mgoin <[email protected]>
1 parent: bffddd9

File tree: 2 files changed (+39, -16 lines)

csrc/moe/moe_align_sum_kernels.cu (37 additions, 15 deletions)
@@ -198,26 +198,27 @@ __global__ void moe_align_block_size_global_mem_kernel(
 }
 
 // taken from
-// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
+// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
 template <typename scalar_t>
 __global__ void sgl_moe_align_block_size_kernel(
     scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
     int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
     int32_t block_size, size_t numel, int32_t* cumsum) {
   __shared__ int32_t shared_counts[32][8];
-  __shared__ int32_t local_offsets[256];
 
   const int warp_id = threadIdx.x / 32;
-  const int lane_id = threadIdx.x % 32;
   const int experts_per_warp = 8;
   const int my_expert_start = warp_id * experts_per_warp;
 
+  // Initialize shared_counts for this warp's experts
   for (int i = 0; i < experts_per_warp; ++i) {
     if (my_expert_start + i < num_experts) {
       shared_counts[warp_id][i] = 0;
     }
   }
 
+  __syncthreads();
+
   const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
   const size_t start_idx = threadIdx.x * tokens_per_thread;
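The counting stage keeps a single 32x8 shared-memory tile: a 1024-thread block has 32 warps, and each warp owns 8 consecutive experts, which is what ties this kernel to exactly 256 experts (see the TORCH_CHECK added on the host side below). The counting loop itself falls outside the shown hunks; the following minimal sketch, not from the commit, only illustrates the index mapping this layout implies.

// Sketch (illustrative only): how a flat expert id maps into
// shared_counts[32][8] when 32 warps each own 8 consecutive experts.
#include <cstdio>

int main() {
  const int experts_per_warp = 8;  // matches the kernel constant
  const int examples[] = {0, 7, 8, 255};
  for (int expert_id : examples) {
    int warp_id = expert_id / experts_per_warp;  // warp that counts this expert
    int slot = expert_id % experts_per_warp;     // slot within that warp's row
    printf("expert %3d -> shared_counts[%2d][%d]\n", expert_id, warp_id, slot);
  }
  return 0;
}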

@@ -230,6 +231,7 @@ __global__ void sgl_moe_align_block_size_kernel(
 
   __syncthreads();
 
+  // Single thread computes cumulative sum and total tokens
   if (threadIdx.x == 0) {
     cumsum[0] = 0;
     for (int i = 1; i <= num_experts; ++i) {
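The hunk stops at the loop header, but this single-thread pass is the padded prefix sum that gives each expert a whole number of blocks: each expert's count is rounded up to a multiple of block_size before being accumulated, and the final entry is the padded total. A small sketch of that arithmetic, with made-up counts:

// Sketch of the padded prefix sum: each expert's token count is rounded up
// to a multiple of block_size so its tokens start on a block boundary.
#include <cstdio>

int main() {
  const int block_size = 16;
  const int num_experts = 4;
  int count[num_experts] = {5, 0, 33, 16};  // tokens routed to each expert (made up)
  int cumsum[num_experts + 1];
  cumsum[0] = 0;
  for (int i = 1; i <= num_experts; ++i) {
    int padded = (count[i - 1] + block_size - 1) / block_size * block_size;
    cumsum[i] = cumsum[i - 1] + padded;
  }
  // cumsum = {0, 16, 16, 64, 80}; the padded total is 80
  for (int i = 0; i <= num_experts; ++i) {
    printf("cumsum[%d] = %d\n", i, cumsum[i]);
  }
  return 0;
}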
@@ -246,19 +248,28 @@ __global__ void sgl_moe_align_block_size_kernel(
 
   __syncthreads();
 
+  // Assign expert IDs to blocks
   if (threadIdx.x < num_experts) {
     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
          i += block_size) {
       expert_ids[i / block_size] = threadIdx.x;
     }
-    local_offsets[threadIdx.x] = cumsum[threadIdx.x];
   }
+}
 
-  __syncthreads();
-
-  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+// taken from
+// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
+template <typename scalar_t>
+__global__ void sgl_moe_token_sort_kernel(scalar_t* __restrict__ topk_ids,
+                                          int32_t* sorted_token_ids,
+                                          int32_t* cumsum_buffer,
+                                          size_t numel) {
+  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const size_t stride = blockDim.x * gridDim.x;
+
+  for (size_t i = tid; i < numel; i += stride) {
     int32_t expert_id = topk_ids[i];
-    int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
+    int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
     sorted_token_ids[rank_post_pad] = i;
   }
 }
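The token scatter that previously ran inside the same single 1024-thread block is now its own kernel, so it can be launched across many blocks: each thread walks the flattened topk_ids with a grid-stride loop and claims an output slot via atomicAdd on the per-expert cursor held in cumsum_buffer. A self-contained sketch of that pattern on made-up data follows; the names and sizes are illustrative, not the commit's.

// Standalone sketch of the grid-stride scatter pattern used by the new kernel.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void scatter_by_expert(const int* topk_ids, int* sorted_token_ids,
                                  int* expert_cursor, size_t numel) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = (size_t)blockDim.x * gridDim.x;
  for (size_t i = tid; i < numel; i += stride) {
    int expert = topk_ids[i];
    int slot = atomicAdd(&expert_cursor[expert], 1);  // next free slot for this expert
    sorted_token_ids[slot] = (int)i;
  }
}

int main() {
  // Two experts, block_size 4: expert 0 owns slots [0,4), expert 1 owns [4,8).
  const int h_topk[6] = {0, 1, 0, 1, 1, 0};
  const int h_cursor[2] = {0, 4};  // cumsum of padded per-expert counts
  int *d_topk, *d_sorted, *d_cursor;
  cudaMalloc((void**)&d_topk, sizeof(h_topk));
  cudaMalloc((void**)&d_sorted, 8 * sizeof(int));
  cudaMalloc((void**)&d_cursor, sizeof(h_cursor));
  cudaMemcpy(d_topk, h_topk, sizeof(h_topk), cudaMemcpyHostToDevice);
  cudaMemcpy(d_cursor, h_cursor, sizeof(h_cursor), cudaMemcpyHostToDevice);
  cudaMemset(d_sorted, 0xff, 8 * sizeof(int));  // padding slots stay -1 in this sketch

  scatter_by_expert<<<2, 128>>>(d_topk, d_sorted, d_cursor, 6);

  int h_sorted[8];
  cudaMemcpy(h_sorted, d_sorted, sizeof(h_sorted), cudaMemcpyDeviceToHost);
  for (int i = 0; i < 8; ++i) printf("%d ", h_sorted[i]);  // token ids grouped by expert
  printf("\n");
  return 0;
}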
@@ -377,23 +388,34 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                               torch::Tensor experts_ids,
                               torch::Tensor num_tokens_post_pad) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  TORCH_CHECK(num_experts == 256,
+              "sgl_moe_align_block_size kernel only supports deepseek v3.");
+
   VLLM_DISPATCH_INTEGRAL_TYPES(
       topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
-        // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
-        // tensors
+        // calc needed amount of shared mem for `cumsum` tensors
         auto options_int =
             torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
-        // torch::Tensor token_cnts_buffer =
-        //     torch::empty({(num_experts + 1) * num_experts}, options_int);
         torch::Tensor cumsum_buffer =
-            torch::empty({num_experts + 1}, options_int);
+            torch::zeros({num_experts + 1}, options_int);
 
-        auto kernel = vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
-        kernel<<<1, 1024, 0, stream>>>(
+        auto align_kernel =
+            vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
+        align_kernel<<<1, 1024, 0, stream>>>(
             topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
             experts_ids.data_ptr<int32_t>(),
             num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
             topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
+
+        const int block_threads = 256;
+        const int num_blocks =
+            (topk_ids.numel() + block_threads - 1) / block_threads;
+        const int max_blocks = 65535;
+        const int actual_blocks = std::min(num_blocks, max_blocks);
+        auto sort_kernel = vllm::moe::sgl_moe_token_sort_kernel<scalar_t>;
+        sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
+            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
+            cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
       });
 }
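On the host, the align kernel keeps its single 1024-thread block, while the new sort kernel gets one 256-thread block per 256 top-k entries, capped at 65535 blocks; the grid-stride loop inside the kernel absorbs anything beyond the cap. A quick check of that launch arithmetic with an illustrative batch size:

// Sketch of the launch-size arithmetic for the sort kernel (numbers made up).
#include <algorithm>
#include <cstdio>

int main() {
  const long long numel = 16384LL * 8;  // e.g. 16384 tokens x top-8 experts
  const int block_threads = 256;
  const int num_blocks = (int)((numel + block_threads - 1) / block_threads);
  const int max_blocks = 65535;
  const int actual_blocks = std::min(num_blocks, max_blocks);
  // 131072 entries -> 512 blocks of 256 threads, well under the 65535 cap
  printf("numel=%lld -> %d blocks of %d threads\n", numel, actual_blocks,
         block_threads);
  return 0;
}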

vllm/model_executor/layers/fused_moe/fused_moe.py (2 additions, 1 deletion)
@@ -596,7 +596,7 @@ def moe_align_block_size(
                                       dtype=torch.int32,
                                       device=topk_ids.device)
     if num_experts >= 224:
-        if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON:
+        if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON or num_experts != 256:
             moe_align_block_size_triton(
                 topk_ids,
                 num_experts,
@@ -606,6 +606,7 @@ def moe_align_block_size(
                 num_tokens_post_pad,
             )
         else:
+            # Currently requires num_experts=256
            ops.sgl_moe_align_block_size(
                topk_ids,
                num_experts,
