From 4c6ca52931c3edb2dca1e6f5c18489e1468d278e Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 24 Jan 2025 16:51:35 +0000 Subject: [PATCH 1/3] Fix moe align block issue for mixtral Signed-off-by: ElizaWszola --- csrc/moe/moe_align_sum_kernels.cu | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index d609ce1697df..e9cdd29b1f5a 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -31,9 +31,17 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); const size_t start_idx = threadIdx.x * tokens_per_thread; + // compute aligned shared mem offset to make sure cumsum is aligned + int cnts_byte_offset = + ((blockDim.x + 1) * num_experts) * sizeof(token_cnts_t); + int aligned_offset = + (cnts_byte_offset + sizeof(int32_t) - 1) / sizeof(int32_t); + extern __shared__ int32_t shared_mem[]; - int32_t* cumsum = shared_mem; // 1d tensor with shape (num_experts + 1) - token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + blockDim.x + 1); + token_cnts_t* tokens_cnts = (token_cnts_t*) + shared_mem; // 2d tensor with shape (blockDim.x + 1, num_experts) + int32_t* cumsum = + shared_mem + aligned_offset; // 1d tensor with shape (num_experts + 1) for (int i = 0; i < num_experts; ++i) { tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; From e215ef62fa3669da6b1ade1a3ea4c22b4cf69ae8 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 24 Jan 2025 15:23:04 -0500 Subject: [PATCH 2/3] Update csrc/moe/moe_align_sum_kernels.cu Co-authored-by: Tyler Michael Smith Signed-off-by: ElizaWszola --- csrc/moe/moe_align_sum_kernels.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index e9cdd29b1f5a..f9857d1532c0 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ 
b/csrc/moe/moe_align_sum_kernels.cu @@ -38,8 +38,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, (cnts_byte_offset + sizeof(int32_t) - 1) / sizeof(int32_t); extern __shared__ int32_t shared_mem[]; - token_cnts_t* tokens_cnts = (token_cnts_t*) - shared_mem; // 2d tensor with shape (blockDim.x + 1, num_experts) + token_cnts_t* tokens_cnts = reinterpret_cast<token_cnts_t*>( + shared_mem); // 2d tensor with shape (blockDim.x + 1, num_experts) int32_t* cumsum = shared_mem + aligned_offset; // 1d tensor with shape (num_experts + 1) From 70570d733e283e6f43edf350c6496f76f5f26a23 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 24 Jan 2025 22:16:04 +0000 Subject: [PATCH 3/3] make it simpler Signed-off-by: ElizaWszola --- csrc/moe/moe_align_sum_kernels.cu | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index f9857d1532c0..8b6fe72ad743 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -31,17 +31,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); const size_t start_idx = threadIdx.x * tokens_per_thread; - // compute aligned shared mem offset to make sure cumsum is aligned - int cnts_byte_offset = - ((blockDim.x + 1) * num_experts) * sizeof(token_cnts_t); - int aligned_offset = - (cnts_byte_offset + sizeof(int32_t) - 1) / sizeof(int32_t); - extern __shared__ int32_t shared_mem[]; - token_cnts_t* tokens_cnts = reinterpret_cast<token_cnts_t*>( - shared_mem); // 2d tensor with shape (blockDim.x + 1, num_experts) - int32_t* cumsum = - shared_mem + aligned_offset; // 1d tensor with shape (num_experts + 1) + int32_t* cumsum = shared_mem; // 1d tensor with shape (num_experts + 1) + token_cnts_t* tokens_cnts = + (token_cnts_t*)(shared_mem + num_experts + + 1); // 2d tensor with shape (blockDim.x + 1, num_experts) for (int i = 0; i < 
num_experts; ++i) { tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;