From 888baef636c5b3bf76c80be69df5e80cd9e83dd4 Mon Sep 17 00:00:00 2001 From: Caleb_Du Date: Fri, 9 May 2025 05:16:23 -0700 Subject: [PATCH 1/6] [1] refactor permute/unpermute kernel - remove token_expert_indices dependence - remove unused parameter - align to triton kernel [2] integrate permute/unpermute kernel into deepgemm moe Signed-off-by: Caleb_Du --- .../benchmark_moe_permute_unpermute.py | 58 +++----- csrc/moe/moe_permute_unpermute_op.cu | 68 +++++---- .../moe_permute_unpermute_kernel.h | 20 +-- .../moe_permute_unpermute_kernel.inl | 37 +++-- csrc/moe/torch_bindings.cpp | 13 +- .../kernels/moe/test_moe_permute_unpermute.py | 129 +++++++++++------- .../layers/fused_moe/deep_gemm_moe.py | 36 +++++ .../layers/fused_moe/moe_permute_unpermute.py | 73 ++++++---- 8 files changed, 242 insertions(+), 192 deletions(-) diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py index 4ed690090144..ed50e407d5f4 100644 --- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py +++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py @@ -63,19 +63,14 @@ def prepare(i: int): def run(): if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = ( - moe_permute( - qhidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - token_expert_indices=token_expert_indices, - topk=topk, - n_expert=num_experts, - n_local_expert=num_experts, - expert_map=None, - align_block_size=align_block_size, - ) - ) + (permuted_hidden_states, first_token_off, inv_perm_idx, + permuted_idx, m_indices) = moe_permute( + qhidden_states, + topk_ids=topk_ids, + topk=topk, + n_expert=num_experts, + expert_map=None, + align_block_size=align_block_size) else: ( permuted_hidden_states, @@ -150,24 +145,20 @@ def benchmark_unpermute( def prepare(): if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = ( - moe_permute( - qhidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - token_expert_indices=token_expert_indices, - topk=topk, - n_expert=num_experts, - n_local_expert=num_experts, - expert_map=None, - align_block_size=align_block_size, - ) - ) + (permuted_hidden_states, first_token_off, inv_perm_idx, + permuted_idx, m_indices) = moe_permute( + qhidden_states, + topk_ids=topk_ids, + topk=topk, + n_expert=num_experts, + expert_map=None, + align_block_size=align_block_size) # convert to fp16/bf16 as gemm output return ( permuted_hidden_states.to(dtype), first_token_off, inv_perm_idx, + permuted_idx, m_indices, ) else: @@ -191,17 +182,10 @@ def prepare(): def run(input: tuple): if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = input - moe_unpermute( - permuted_hidden_states, - topk_weights, - topk_ids, - inv_perm_idx, - first_token_off, - topk, - num_experts, - num_experts, - ) + (permuted_hidden_states, first_token_off, inv_perm_idx, + permuted_idx, m_indices) = input + moe_unpermute(permuted_hidden_states, topk_weights, inv_perm_idx, + first_token_off, topk) else: ( permuted_hidden_states, diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index a77471a7f207..e57e90f64aba 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -9,33 +9,32 @@ #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) void moe_permute( - const torch::Tensor& input, // [n_token, hidden] - const torch::Tensor& topk_weights, //[n_token, topk] - torch::Tensor& topk_ids, // 
[n_token, topk] - const torch::Tensor& token_expert_indices, // [n_token, topk] + const torch::Tensor& input, // [n_token, hidden] + // const torch::Tensor& topk_weights, //[n_token, topk] + const torch::Tensor& topk_ids, // [n_token, topk] + const torch::Tensor& token_expert_indicies, // [n_token, topk] const std::optional& expert_map, // [n_expert] int64_t n_expert, int64_t n_local_expert, int64_t topk, const std::optional& align_block_size, - torch::Tensor& - permuted_input, // [topk * n_token/align_block_size_m, hidden] + torch::Tensor& permuted_input, // [permuted_size, hidden] torch::Tensor& expert_first_token_offset, // [n_local_expert + 1] - torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk] + torch::Tensor& inv_permuted_idx, // [n_token, topk] + torch::Tensor& permuted_idx, // [permute_size] torch::Tensor& m_indices) { // [align_expand_m] - TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float, - "topk_weights must be float32"); + // TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float, + // "topk_weights must be float32"); TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long, "expert_first_token_offset must be int64"); TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, "topk_ids must be int32"); - TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int, - "token_expert_indices must be int32"); - TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int, - "src_row_id2dst_row_id_map must be int32"); + TORCH_CHECK(token_expert_indicies.scalar_type() == at::ScalarType::Int, + "token_expert_indicies must be int32"); + TORCH_CHECK(inv_permuted_idx.scalar_type() == at::ScalarType::Int, + "inv_permuted_idx must be int32"); TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1, "expert_first_token_offset shape != n_local_expert+1") - TORCH_CHECK( - src_row_id2dst_row_id_map.sizes() == token_expert_indices.sizes(), - "token_expert_indices shape must be same as src_row_id2dst_row_id_map"); + TORCH_CHECK(inv_permuted_idx.sizes() == token_expert_indicies.sizes(), + "token_expert_indicies shape must be same as inv_permuted_idx"); auto n_token = input.sizes()[0]; auto n_hidden = input.sizes()[1]; auto align_block_size_value = @@ -46,8 +45,9 @@ void moe_permute( auto sort_workspace = torch::empty( {sorter_size}, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); + auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess auto permuted_experts_id = torch::empty_like(topk_ids); - auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map); + auto sorted_row_idx = torch::empty_like(inv_permuted_idx); auto align_expert_first_token_offset = torch::zeros_like(expert_first_token_offset); @@ -67,24 +67,22 @@ void moe_permute( const int* expert_map_ptr = get_ptr(expert_map.value()); valid_num_ptr = get_ptr(expert_first_token_offset) + n_local_expert; - preprocessTopkIdLauncher(get_ptr(topk_ids), n_token * topk, + preprocessTopkIdLauncher(get_ptr(copy_topk_ids), n_token * topk, expert_map_ptr, n_expert, stream); } // expert sort topk expert id and scan expert id get expert_first_token_offset - sortAndScanExpert(get_ptr(topk_ids), get_ptr(token_expert_indices), - get_ptr(permuted_experts_id), - get_ptr(dst_row_id2src_row_id_map), - get_ptr(expert_first_token_offset), n_token, - n_expert, n_local_expert, topk, sorter, - get_ptr(sort_workspace), stream); + sortAndScanExpert( + get_ptr(copy_topk_ids), get_ptr(token_expert_indicies), + 
get_ptr(permuted_experts_id), get_ptr(sorted_row_idx), + get_ptr(expert_first_token_offset), n_token, n_expert, + n_local_expert, topk, sorter, get_ptr(sort_workspace), stream); // dispatch expandInputRowsKernelLauncher MOE_DISPATCH(input.scalar_type(), [&] { expandInputRowsKernelLauncher( get_ptr(input), get_ptr(permuted_input), - get_ptr(topk_weights), get_ptr(permuted_experts_id), - get_ptr(dst_row_id2src_row_id_map), - get_ptr(src_row_id2dst_row_id_map), + get_ptr(permuted_experts_id), get_ptr(sorted_row_idx), + get_ptr(inv_permuted_idx), get_ptr(permuted_idx), get_ptr(expert_first_token_offset), n_token, valid_num_ptr, n_hidden, topk, n_local_expert, align_block_size_value, stream); }); @@ -103,30 +101,26 @@ void moe_permute( void moe_unpermute( const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden] const torch::Tensor& topk_weights, //[n_token, topk] - const torch::Tensor& topk_ids, // [n_token, topk] - const torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk] + const torch::Tensor& inv_permuted_idx, // [n_token, topk] const torch::Tensor& expert_first_token_offset, // [n_local_expert+1] - int64_t n_expert, int64_t n_local_expert, int64_t topk, + int64_t topk, torch::Tensor& hidden_states // [n_token, hidden] ) { - TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(), - "topk_ids shape must be same as src_row_id2dst_row_id_map"); - TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, - "topk_ids must be int32"); TORCH_CHECK( permuted_hidden_states.scalar_type() == hidden_states.scalar_type(), - "topk_ids dtype must be same as src_row_id2dst_row_id_map"); + "permuted_hidden_states dtype must be same as hidden_states"); auto n_token = hidden_states.size(0); auto n_hidden = hidden_states.size(1); auto stream = at::cuda::getCurrentCUDAStream().stream(); + int n_local_expert = expert_first_token_offset.size(0) - 1; const int64_t* valid_ptr = get_ptr(expert_first_token_offset) + n_local_expert; MOE_DISPATCH(hidden_states.scalar_type(), [&] { finalizeMoeRoutingKernelLauncher( get_ptr(permuted_hidden_states), get_ptr(hidden_states), get_ptr(topk_weights), - get_ptr(src_row_id2dst_row_id_map), get_ptr(topk_ids), - n_token, n_hidden, topk, valid_ptr, stream); + get_ptr(inv_permuted_idx), n_token, n_hidden, topk, valid_ptr, + stream); }); } diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h index 43c29721cd16..108091efbefa 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h @@ -57,31 +57,19 @@ void sortAndScanExpert(int* expert_for_source_row, const int* source_rows, template void expandInputRowsKernelLauncher( - T const* unpermuted_input, T* permuted_output, - const float* unpermuted_scales, int* sorted_experts, + T const* unpermuted_input, T* permuted_output, int* sorted_experts, int const* expanded_dest_row_to_expanded_source_row, - int* expanded_source_row_to_expanded_dest_row, + int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, int64_t* expert_first_token_offset, int64_t const num_rows, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, int num_local_experts, const int& align_block_size, cudaStream_t stream); -// Final kernel to unpermute and scale -// This kernel unpermutes the original data, does the k-way reduction and -// performs the final skip connection. 
-template -__global__ void finalizeMoeRoutingKernel( - T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, - float const* scales, int const* expanded_source_row_to_expanded_dest_row, - int const* expert_for_source_row, int64_t const orig_cols, int64_t const k, - int64_t const* num_valid_ptr); - template void finalizeMoeRoutingKernelLauncher( T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, float const* scales, int const* expanded_source_row_to_expanded_dest_row, - int const* expert_for_source_row, int64_t const num_rows, - int64_t const cols, int64_t const k, int64_t const* num_valid_ptr, - cudaStream_t stream); + int64_t const num_rows, int64_t const cols, int64_t const k, + int64_t const* num_valid_ptr, cudaStream_t stream); void preprocessTopkIdLauncher(int* topk_id_ptr, int size, const int* expert_map_ptr, int num_experts, diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl index ad0d390665a0..02c420f8217b 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl @@ -2,10 +2,9 @@ template __global__ void expandInputRowsKernel( - T const* unpermuted_input, T* permuted_output, - const float* unpermuted_scales, int* sorted_experts, + T const* unpermuted_input, T* permuted_output, int* sorted_experts, int const* expanded_dest_row_to_expanded_source_row, - int* expanded_source_row_to_expanded_dest_row, + int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, int64_t* expert_first_token_offset, int64_t const num_rows, int64_t const* num_dest_rows, int64_t const cols, int64_t k, int num_local_experts, int align_block_size) { @@ -54,6 +53,10 @@ __global__ void expandInputRowsKernel( assert(expanded_dest_row <= INT32_MAX); expanded_source_row_to_expanded_dest_row[expanded_source_row] = static_cast(expanded_dest_row); + // skip non local expert token + if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) { + permuted_idx[expanded_dest_row] = expanded_source_row; + } } if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) { @@ -62,7 +65,7 @@ __global__ void expandInputRowsKernel( using DataElem = cutlass::Array; // Duplicate and permute rows - int64_t const source_row = expanded_source_row % num_rows; + int64_t const source_row = expanded_source_row / k; auto const* source_row_ptr = reinterpret_cast(unpermuted_input + source_row * cols); @@ -82,10 +85,9 @@ __global__ void expandInputRowsKernel( template void expandInputRowsKernelLauncher( - T const* unpermuted_input, T* permuted_output, - const float* unpermuted_scales, int* sorted_experts, + T const* unpermuted_input, T* permuted_output, int* sorted_experts, int const* expanded_dest_row_to_expanded_source_row, - int* expanded_source_row_to_expanded_dest_row, + int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, int64_t* expert_first_token_offset, int64_t const num_rows, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, int num_local_experts, const int& align_block_size, cudaStream_t stream) { @@ -105,11 +107,11 @@ void expandInputRowsKernelLauncher( int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1); func<<>>( - unpermuted_input, permuted_output, unpermuted_scales, sorted_experts, + unpermuted_input, permuted_output, sorted_experts, expanded_dest_row_to_expanded_source_row, - expanded_source_row_to_expanded_dest_row, expert_first_token_offset, - num_rows, 
num_valid_tokens_ptr, cols, k, num_local_experts, - align_block_size); + expanded_source_row_to_expanded_dest_row, permuted_idx, + expert_first_token_offset, num_rows, num_valid_tokens_ptr, cols, k, + num_local_experts, align_block_size); } template @@ -128,8 +130,7 @@ template __global__ void finalizeMoeRoutingKernel( T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, float const* scales, int const* expanded_source_row_to_expanded_dest_row, - int const* expert_for_source_row, int64_t const orig_cols, int64_t const k, - int64_t const* num_valid_ptr) { + int64_t const orig_cols, int64_t const k, int64_t const* num_valid_ptr) { assert(orig_cols % 4 == 0); int64_t const original_row = blockIdx.x; int64_t const num_rows = gridDim.x; @@ -159,7 +160,7 @@ __global__ void finalizeMoeRoutingKernel( ComputeElem thread_output; thread_output.fill(0); for (int k_idx = 0; k_idx < k; ++k_idx) { - int64_t const expanded_original_row = original_row + k_idx * num_rows; + int64_t const expanded_original_row = original_row * k + k_idx; int64_t const expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row]; @@ -189,9 +190,8 @@ template void finalizeMoeRoutingKernelLauncher( T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, float const* scales, int const* expanded_source_row_to_expanded_dest_row, - int const* expert_for_source_row, int64_t const num_rows, - int64_t const cols, int64_t const k, int64_t const* num_valid_ptr, - cudaStream_t stream) { + int64_t const num_rows, int64_t const cols, int64_t const k, + int64_t const* num_valid_ptr, cudaStream_t stream) { int64_t const blocks = num_rows; int64_t const threads = 256; bool const check_finished = num_valid_ptr != nullptr; @@ -201,6 +201,5 @@ void finalizeMoeRoutingKernelLauncher( auto* const kernel = func_map[check_finished]; kernel<<>>( expanded_permuted_rows, reduced_unpermuted_output, scales, - expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k, - num_valid_ptr); + expanded_source_row_to_expanded_dest_row, cols, k, num_valid_ptr); } diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 97df311d0440..9102a774ee3d 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -56,18 +56,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { " -> Tensor"); m.def( - "moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids," - "Tensor token_expert_indices, Tensor? expert_map, int n_expert," + "moe_permute(Tensor input, Tensor topk_ids," + "Tensor token_expert_indicies, Tensor? expert_map, int n_expert," "int n_local_expert," "int topk, int? align_block_size,Tensor! permuted_input, Tensor! " - "expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! " - "m_indices)->()"); + "expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! " + "permuted_idx, Tensor! m_indices)->()"); m.def( "moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights," - "Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor " - "expert_first_token_offset, int n_expert, int n_local_expert,int " - "topk, Tensor! hidden_states)->()"); + "Tensor inv_permuted_idx, Tensor expert_first_token_offset, " + "int topk, Tensor! 
hidden_states)->()"); m.def("moe_permute_unpermute_supported() -> bool"); m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported); diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index 7cc83b512c8b..48d066ac745d 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -23,22 +23,28 @@ current_platform.seed_everything(0) -def torch_permute(hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - token_expert_indices: torch.Tensor, - topk: int, - n_expert: int, - n_local_expert: int, - start_expert: int, - expert_map: Optional[torch.Tensor] = None, - align_block_size: Optional[int] = None, - fill_invalid_expert: int = -1) -> list[torch.Tensor]: +def torch_permute( + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + # token_expert_indices: torch.Tensor, + topk: int, + n_expert: int, + n_local_expert: int, + start_expert: int, + expert_map: Optional[torch.Tensor] = None, + align_block_size: Optional[int] = None, + fill_invalid_expert: int = -1) -> list[torch.Tensor]: n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1] if expert_map is not None: is_local_expert = (expert_map[topk_ids] != -1) not_local_expert = (expert_map[topk_ids] == -1) topk_ids = is_local_expert * ( topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert) + token_expert_indices = torch.arange(0, + n_token * topk, + dtype=torch.int32, + device=hidden_states.device).reshape( + (n_token, topk)) sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), stable=True) @@ -59,8 +65,8 @@ def torch_permute(hidden_states: torch.Tensor, valid_row_idx = [] if align_block_size is None: - permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map % - n_token, ...] + permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // + topk, ...] permuted_row_size = permuted_hidden_states.shape[0] m_indices = torch.empty(permuted_row_size, device="cuda", @@ -73,14 +79,21 @@ def torch_permute(hidden_states: torch.Tensor, 0, n_token * topk, device="cuda", dtype=torch.int32)[src2dst_idx].reshape((n_token, topk)) valid_row_idx += [i for i in range(expert_first_token_offset[-1])] + dst_row_id2src_row_id_map[ + expert_first_token_offset[-1]:] = n_token * topk return [ permuted_hidden_states, expert_first_token_offset, - src_row_id2dst_row_id_map, m_indices, valid_row_idx + src_row_id2dst_row_id_map, dst_row_id2src_row_id_map, m_indices, + valid_row_idx ] else: permuted_row_size = (topk * n_token + n_expert * (align_block_size - 1) + align_block_size - 1) // align_block_size * align_block_size + permuted_idx = torch.full((permuted_row_size, ), + n_token * topk, + dtype=torch.int32, + device=hidden_states.device) permuted_hidden_states = torch.empty((permuted_row_size, n_hidden), device="cuda", dtype=hidden_states.dtype) @@ -105,13 +118,16 @@ def torch_permute(hidden_states: torch.Tensor, align_first_token_offset = align_expert_first_token_offset[i - 1] align_last_token_offset = align_expert_first_token_offset[i] dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[ - first_token_offset:first_token_offset + - n_token_in_expert] % n_token + first_token_offset:first_token_offset + n_token_in_expert] # store token in current expert with align_first_token_offset permuted_hidden_states[align_first_token_offset:\ align_first_token_offset+n_token_in_expert,\ ...] = hidden_states[\ - dst_row_id2src_row_id_in_expert, ...] 
+ dst_row_id2src_row_id_in_expert // topk,\ + ...] + permuted_idx[align_first_token_offset:\ + align_first_token_offset+\ + n_token_in_expert] = dst_row_id2src_row_id_in_expert # set current expert m_indices m_indices[align_first_token_offset:align_last_token_offset] = i - 1 valid_row_idx += [ @@ -135,7 +151,7 @@ def torch_permute(hidden_states: torch.Tensor, src2dst_idx].reshape((n_token, topk)) return [ permuted_hidden_states, align_expert_first_token_offset, - align_src_row_id2dst_row_id, m_indices, valid_row_idx + align_src_row_id2dst_row_id, permuted_idx, m_indices, valid_row_idx ] @@ -146,15 +162,22 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor, valid_row_idx: torch.Tensor, topk: int, n_expert: int) -> torch.Tensor: # ignore invalid row + n_hidden = permuted_hidden_states.shape[1] mask = torch.zeros(permuted_hidden_states.shape[0], dtype=bool, device="cuda") mask[valid_row_idx] = True permuted_hidden_states[~mask] = 0 - idx = src_row_id2dst_row_id_map.flatten()[ - token_expert_indices.flatten()].reshape(token_expert_indices.shape) - output = permuted_hidden_states[idx, ...] * topk_weights[..., None] - output = output.sum(dim=1).to(permuted_hidden_states.dtype) + # idx = src_row_id2dst_row_id_map.flatten()[ + # token_expert_indices.flatten()].reshape(token_expert_indices.shape) + # output = permuted_hidden_states[idx, ...] * topk_weights[..., None] + # output = output.sum(dim=1).to(permuted_hidden_states.dtype) + + permuted_hidden_states = permuted_hidden_states[ + src_row_id2dst_row_id_map.flatten(), ...] + permuted_hidden_states = permuted_hidden_states.view(-1, topk, n_hidden) + output = (permuted_hidden_states * topk_weights.unsqueeze(2)).sum(1).to( + permuted_hidden_states.dtype) return output @@ -184,43 +207,55 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype) topk_weights, topk_ids, token_expert_indices = fused_topk( hidden_states, gating_output, topk, False) - gold0, gold1, gold2, gold3, valid_row_idx = torch_permute( - hidden_states, - topk_ids, - token_expert_indices, - topk, - n_expert, - n_local_expert, - start_expert, - expert_map=expert_map, - align_block_size=align_block_size, - fill_invalid_expert=fill_invalid_expert) - - result0, result1, result2, result3 = moe_permute( - hidden_states, topk_weights, topk_ids, token_expert_indices, topk, - n_expert, n_local_expert, expert_map, align_block_size, - fill_invalid_expert) + (gold_permuted_hidden_states, gold_expert_first_token_offset, + gold_inv_permuted_idx, gold_permuted_idx, gold_m_indices, + valid_row_idx) = torch_permute( + hidden_states, + topk_ids, + # token_expert_indices, + topk, + n_expert, + n_local_expert, + start_expert, + expert_map=expert_map, + align_block_size=align_block_size, + fill_invalid_expert=fill_invalid_expert) + + (permuted_hidden_states, expert_first_token_offset, inv_permuted_idx, + permuted_idx, m_indices) = moe_permute(hidden_states, topk_ids, topk, + n_expert, expert_map, + align_block_size, + fill_invalid_expert) # check expert_first_token_offset - torch.testing.assert_close(gold1, result1, atol=0, rtol=0) + torch.testing.assert_close(gold_expert_first_token_offset, + expert_first_token_offset, + atol=0, + rtol=0) # check src_row_id2dst_row_id_map - torch.testing.assert_close(gold2, result2, atol=0, rtol=0) + torch.testing.assert_close(gold_inv_permuted_idx, + inv_permuted_idx, + atol=0, + rtol=0) # check mindice - torch.testing.assert_close(gold3, result3, atol=0, rtol=0) + 
torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0) # check permuted_hidden_states, only valid token - torch.testing.assert_close(gold0[valid_row_idx], - result0[valid_row_idx], + torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx], + permuted_hidden_states[valid_row_idx], atol=0, rtol=0) + # check permuted_idx + torch.testing.assert_close(gold_permuted_idx, permuted_idx, atol=0, rtol=0) # add a random tensor to simulate group gemm - result0 = 0.5 * result0 + torch.randn_like(result0) + result0 = 0.5 * permuted_hidden_states + torch.randn_like( + permuted_hidden_states) - result4 = moe_unpermute(result0, topk_weights, topk_ids, result2, result1, - topk, n_expert, n_local_expert) + result4 = moe_unpermute(result0, topk_weights, inv_permuted_idx, + expert_first_token_offset, topk) gold4 = torch_unpermute(result0, topk_weights, topk_ids, - token_expert_indices, result2, valid_row_idx, topk, - n_local_expert) + token_expert_indices, inv_permuted_idx, + valid_row_idx, topk, n_local_expert) # check unpermuted hidden torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index b89e5ac6f093..ffb785202bb5 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -215,6 +215,42 @@ def apply( output=output) +def _customized_moe_permute( + curr_hidden_states: torch.Tensor, + a1q_scale: Optional[torch.Tensor], + curr_topk_ids: torch.Tensor, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + block_m: int, +): + fill_invalid_expert = 0 + topk = curr_topk_ids.shape[1] + tokens_in_chunk, _ = curr_hidden_states.shape + num_tokens = topk * tokens_in_chunk + (permuted_hidden_states, expert_first_token_offset, inv_permuted_idx, + permuted_idx, m_indices) = moe_permute(curr_hidden_states, curr_topk_ids, + topk, global_num_experts, + expert_map, block_m, + fill_invalid_expert) + permuted_idx = permuted_idx.clamp(max=num_tokens - 1) + if a1q_scale is not None: + a1q_scale = a1q_scale[permuted_idx // topk] + return (permuted_hidden_states, a1q_scale, permuted_idx, m_indices, + inv_permuted_idx, expert_first_token_offset) + + +def _customized_moe_unpermute_and_reduce( + curr_hidden: torch.Tensor, + inv_perm: Optional[torch.Tensor], + topk_weight: torch.Tensor, + first_token_offset: torch.Tensor, +) -> torch.Tensor: + M, topk = topk_weight.shape + output = moe_unpermute(curr_hidden, topk_weight, inv_perm, + first_token_offset, topk) + return output + + def deep_gemm_moe_fp8( hidden_states: torch.Tensor, w1: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index 20ee0d9f780a..9eed81c03d2f 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -76,29 +76,30 @@ def _moe_unpermute_and_reduce( def moe_permute( hidden_states: torch.Tensor, - topk_weights: torch.Tensor, + # topk_weights: torch.Tensor, topk_ids: torch.Tensor, - token_expert_indices: torch.Tensor, + # token_expert_indices: torch.Tensor, topk: int, n_expert: int, - n_local_expert: int, + # n_local_expert: int, expert_map: Optional[torch.Tensor] = None, align_block_size: Optional[int] = None, fill_invalid_expert: int = -1 -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor]: """ This function expands and permutes activation to gather uncontinuous tokens for each expert. Parameters: - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - topk_weights (torch.Tensor): topk expert route weight for each token. + # - topk_weights (torch.Tensor): topk expert route weight for each token. - topk_ids (torch.Tensor): topk expert route id for each token. - - token_expert_indices (torch.Tensor): indice for expanded hidden. + # - token_expert_indices (torch.Tensor): indice for expanded hidden. - topk (int): The number of top-k experts to select. - n_expert (int): The number of expert. - - n_local_expert (int): The number of expert in current EP rank. + # - n_local_expert (int): The number of expert in current EP rank. - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert + from the global expert space to the local expert space of the expert parallel shard. - align_block_size (Optional[int]): align group gemm block size for deepgemm - fill_invalid_expert(int): fill expert id in m_indices for invalid expert @@ -108,7 +109,8 @@ def moe_permute( - expert_first_token_offset (torch.Tensor): offset of the first token of each expert for standard grouped gemm. if enable 'align_block_size' expert_first_token_offset will align up to 'align_block_size'. - - src_row_id2dst_row_id_map (torch.Tensor): idx map for moe_unpermute. + - inv_permuted_idx (torch.Tensor): idx map for moe_unpermute. + - permuted_idx (torch.Tensor): idx map from hidden to permuted_hidden. - m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records the group which the j-th row of the LHS belong to.` """ @@ -120,12 +122,21 @@ def moe_permute( permuted_row_size = (permuted_row_size + n_expert * (align_block_size - 1) + align_block_size - 1) // align_block_size * align_block_size + n_local_expert = n_expert + if expert_map is not None: + n_local_expert = torch.sum(expert_map != -1).item() permuted_hidden_states = torch.empty( (permuted_row_size, n_hidden), dtype=hidden_states.dtype, device=hidden_states.device, ) + token_expert_indices = torch.arange(0, + n_token * topk, + dtype=torch.int32, + device=hidden_states.device).reshape( + (n_token, topk)) + m_indices = torch.full((permuted_row_size, ), fill_invalid_expert, dtype=torch.int32, @@ -133,28 +144,32 @@ def moe_permute( expert_first_token_offset = torch.empty(n_local_expert + 1, dtype=torch.int64, device=hidden_states.device) - src_row_id2dst_row_id_map = torch.empty((n_token, topk), - dtype=torch.int32, - device=hidden_states.device) - torch.ops._moe_C.moe_permute(hidden_states, topk_weights, topk_ids, - token_expert_indices, expert_map, n_expert, - n_local_expert, topk, align_block_size, - permuted_hidden_states, - expert_first_token_offset, - src_row_id2dst_row_id_map, m_indices) + # todo clamp (0, n_token * topk - 1) to avoid out of bound ? 
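# [Editor's note] A sketch of what the todo above guards against, assuming
# the sentinel convention this patch introduces: padding slots of
# permuted_idx are filled with n_token * topk, one past the last valid
# expanded row, so a consumer gathering with permuted_idx must clamp first:
#     safe_idx = permuted_idx.clamp(max=n_token * topk - 1)
#     a1q_scale = a1q_scale[safe_idx // topk]
# (mirroring what _customized_moe_permute in deep_gemm_moe.py does above).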
+ permuted_idx = torch.full((permuted_row_size, ), + n_token * topk, + dtype=torch.int32, + device=hidden_states.device) + inv_permuted_idx = torch.empty((n_token, topk), + dtype=torch.int32, + device=hidden_states.device) + torch.ops._moe_C.moe_permute(hidden_states, topk_ids, token_expert_indices, + expert_map, n_expert, n_local_expert, topk, + align_block_size, permuted_hidden_states, + expert_first_token_offset, inv_permuted_idx, + permuted_idx, m_indices) return (permuted_hidden_states, expert_first_token_offset, - src_row_id2dst_row_id_map, m_indices) + inv_permuted_idx, permuted_idx, m_indices) def moe_unpermute( permuted_hidden_states: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - src_row_id2dst_row_id_map: torch.Tensor, + # topk_ids: torch.Tensor, + inv_permuted_idx: torch.Tensor, expert_first_token_offset: torch.Tensor, topk: int, - n_expert: int, - n_local_expert: int, + # n_expert: int, + # n_local_expert: int, ) -> torch.Tensor: """ This function expands and permutes activation to gathering uncontinuous @@ -162,12 +177,13 @@ def moe_unpermute( Parameters: - permuted_hidden_states (torch.Tensor): permuted activation. - topk_weights (torch.Tensor): topk expert route weight for each token. - - topk_ids (torch.Tensor): topk expert route id for each token. + # - topk_ids (torch.Tensor): topk expert route id for each token. + - inv_permuted_idx (torch.Tensor): row idx map for moe_unpermute. - expert_first_token_offset (torch.Tensor): offset of the first token of each expert for grouped gemm. - topk (int): The number of top-k experts to select. - - n_expert (int): The number of expert. - - n_local_expert (int): The number of expert in current EP rank. + # - n_expert (int): The number of expert. + # - n_local_expert (int): The number of expert in current EP rank. Returns: - hidden_states (torch.Tensor): The reduced and unpermuted activation tensor. 
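[Editor's note] A minimal round-trip sketch of the pair as defined at this
point in the series (keyword names follow this patch; the grouped gemm in
the middle is elided, and later patches revise both signatures):

    (permuted, first_token_off, inv_perm_idx, permuted_idx,
     m_indices) = moe_permute(hidden_states, topk_ids=topk_ids, topk=topk,
                              n_expert=n_expert, expert_map=None,
                              align_block_size=None)
    # grouped gemm runs over `permuted`, one group per expert, guided by
    # expert_first_token_offset / m_indices
    out = moe_unpermute(permuted, topk_weights, inv_perm_idx,
                        first_token_off, topk)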
@@ -180,9 +196,8 @@ def moe_unpermute( device=permuted_hidden_states.device) torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights, - topk_ids, src_row_id2dst_row_id_map, - expert_first_token_offset, n_expert, - n_local_expert, topk, hidden_states) + inv_permuted_idx, expert_first_token_offset, + topk, hidden_states) return hidden_states From 4482899a27b8fd7df9b5958cc75e4c1871082ae1 Mon Sep 17 00:00:00 2001 From: Caleb_Du Date: Fri, 9 May 2025 05:24:40 -0700 Subject: [PATCH 2/6] remove useless code and comment Signed-off-by: Caleb_Du --- .../benchmark_moe_permute_unpermute.py | 28 +++++++++---------- csrc/moe/moe_permute_unpermute_op.cu | 5 +--- .../moe_permute_unpermute_kernel.cu | 2 +- .../moe_permute_unpermute_kernel.inl | 2 -- .../kernels/moe/test_moe_permute_unpermute.py | 6 +--- .../layers/fused_moe/moe_permute_unpermute.py | 12 -------- 6 files changed, 17 insertions(+), 38 deletions(-) diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py index ed50e407d5f4..b7e86f13db2c 100644 --- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py +++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py @@ -64,13 +64,13 @@ def prepare(i: int): def run(): if use_customized_permute: (permuted_hidden_states, first_token_off, inv_perm_idx, - permuted_idx, m_indices) = moe_permute( - qhidden_states, - topk_ids=topk_ids, - topk=topk, - n_expert=num_experts, - expert_map=None, - align_block_size=align_block_size) + permuted_idx, + m_indices) = moe_permute(qhidden_states, + topk_ids=topk_ids, + topk=topk, + n_expert=num_experts, + expert_map=None, + align_block_size=align_block_size) else: ( permuted_hidden_states, @@ -146,13 +146,13 @@ def benchmark_unpermute( def prepare(): if use_customized_permute: (permuted_hidden_states, first_token_off, inv_perm_idx, - permuted_idx, m_indices) = moe_permute( - qhidden_states, - topk_ids=topk_ids, - topk=topk, - n_expert=num_experts, - expert_map=None, - align_block_size=align_block_size) + permuted_idx, + m_indices) = moe_permute(qhidden_states, + topk_ids=topk_ids, + topk=topk, + n_expert=num_experts, + expert_map=None, + align_block_size=align_block_size) # convert to fp16/bf16 as gemm output return ( permuted_hidden_states.to(dtype), diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index e57e90f64aba..e81d3e337d4e 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -9,8 +9,7 @@ #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) void moe_permute( - const torch::Tensor& input, // [n_token, hidden] - // const torch::Tensor& topk_weights, //[n_token, topk] + const torch::Tensor& input, // [n_token, hidden] const torch::Tensor& topk_ids, // [n_token, topk] const torch::Tensor& token_expert_indicies, // [n_token, topk] const std::optional& expert_map, // [n_expert] @@ -21,8 +20,6 @@ void moe_permute( torch::Tensor& inv_permuted_idx, // [n_token, topk] torch::Tensor& permuted_idx, // [permute_size] torch::Tensor& m_indices) { // [align_expand_m] - // TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float, - // "topk_weights must be float32"); TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long, "expert_first_token_offset must be int64"); TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu index 
de2c153882d9..2271c1bc75b1 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu @@ -177,7 +177,7 @@ __global__ void getMIndicesKernel(int64_t* expert_first_token_offset, int tidx = threadIdx.x; extern __shared__ int64_t smem_expert_first_token_offset[]; for (int i = tidx; i <= num_local_expert; i += blockDim.x) { - smem_expert_first_token_offset[tidx] = __ldg(expert_first_token_offset + i); + smem_expert_first_token_offset[i] = __ldg(expert_first_token_offset + i); } __syncthreads(); auto last_token_offset = smem_expert_first_token_offset[eidx + 1]; diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl index 02c420f8217b..449243b92a28 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl @@ -133,7 +133,6 @@ __global__ void finalizeMoeRoutingKernel( int64_t const orig_cols, int64_t const k, int64_t const* num_valid_ptr) { assert(orig_cols % 4 == 0); int64_t const original_row = blockIdx.x; - int64_t const num_rows = gridDim.x; auto const offset = original_row * orig_cols; OutputType* reduced_row_ptr = reduced_unpermuted_output + offset; int64_t const num_valid = *num_valid_ptr; @@ -167,7 +166,6 @@ __global__ void finalizeMoeRoutingKernel( int64_t const k_offset = original_row * k + k_idx; float const row_scale = scales[k_offset]; - // Check after row_rescale has accumulated if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) { continue; } diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index 48d066ac745d..c80db4c16608 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -17,7 +17,7 @@ moe_permute, moe_permute_unpermute_supported, moe_unpermute) from vllm.platforms import current_platform -NUM_EXPERTS = [16, 64] +NUM_EXPERTS = [16, 64, 256] TOP_KS = [2, 4, 6, 8] EP_SIZE = [1, 4, 16] current_platform.seed_everything(0) @@ -168,10 +168,6 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor, device="cuda") mask[valid_row_idx] = True permuted_hidden_states[~mask] = 0 - # idx = src_row_id2dst_row_id_map.flatten()[ - # token_expert_indices.flatten()].reshape(token_expert_indices.shape) - # output = permuted_hidden_states[idx, ...] * topk_weights[..., None] - # output = output.sum(dim=1).to(permuted_hidden_states.dtype) permuted_hidden_states = permuted_hidden_states[ src_row_id2dst_row_id_map.flatten(), ...] diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index 9eed81c03d2f..344cb4644008 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -76,12 +76,9 @@ def _moe_unpermute_and_reduce( def moe_permute( hidden_states: torch.Tensor, - # topk_weights: torch.Tensor, topk_ids: torch.Tensor, - # token_expert_indices: torch.Tensor, topk: int, n_expert: int, - # n_local_expert: int, expert_map: Optional[torch.Tensor] = None, align_block_size: Optional[int] = None, fill_invalid_expert: int = -1 @@ -92,12 +89,9 @@ def moe_permute( for each expert. Parameters: - hidden_states (torch.Tensor): The input tensor to the MoE layer. 
- # - topk_weights (torch.Tensor): topk expert route weight for each token. - topk_ids (torch.Tensor): topk expert route id for each token. - # - token_expert_indices (torch.Tensor): indice for expanded hidden. - topk (int): The number of top-k experts to select. - n_expert (int): The number of expert. - # - n_local_expert (int): The number of expert in current EP rank. - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard. @@ -164,12 +158,9 @@ def moe_permute( def moe_unpermute( permuted_hidden_states: torch.Tensor, topk_weights: torch.Tensor, - # topk_ids: torch.Tensor, inv_permuted_idx: torch.Tensor, expert_first_token_offset: torch.Tensor, topk: int, - # n_expert: int, - # n_local_expert: int, ) -> torch.Tensor: """ This function expands and permutes activation to gathering uncontinuous @@ -177,13 +168,10 @@ def moe_unpermute( Parameters: - permuted_hidden_states (torch.Tensor): permuted activation. - topk_weights (torch.Tensor): topk expert route weight for each token. - # - topk_ids (torch.Tensor): topk expert route id for each token. - inv_permuted_idx (torch.Tensor): row idx map for moe_unpermute. - expert_first_token_offset (torch.Tensor): offset of the first token of each expert for grouped gemm. - topk (int): The number of top-k experts to select. - # - n_expert (int): The number of expert. - # - n_local_expert (int): The number of expert in current EP rank. Returns: - hidden_states (torch.Tensor): The reduced and unpermuted activation tensor. From 4795c8bea9c04d8be5b030e032d2324b9207d6f8 Mon Sep 17 00:00:00 2001 From: Caleb_Du Date: Fri, 16 May 2025 14:50:02 +0000 Subject: [PATCH 3/6] refactor moe unpermute 1. set output as parameter 2. support expert_first_token_offset is optional Signed-off-by: Caleb_Du --- .../benchmark_moe_permute_unpermute.py | 77 +++++++++++++------ csrc/moe/moe_permute_unpermute_op.cu | 20 +++-- csrc/moe/torch_bindings.cpp | 2 +- .../kernels/moe/test_moe_permute_unpermute.py | 6 +- .../layers/fused_moe/deep_gemm_moe.py | 36 --------- .../layers/fused_moe/moe_permute_unpermute.py | 55 ++++++++++--- 6 files changed, 113 insertions(+), 83 deletions(-) diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py index b7e86f13db2c..90f8ca7fa2e6 100644 --- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py +++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py @@ -8,12 +8,13 @@ import torch from transformers import AutoConfig -from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( +from vllm.model_executor.layers.fused_moe.fused_moe import * +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( _moe_permute, _moe_unpermute_and_reduce, + moe_permute, + moe_unpermute, ) -from vllm.model_executor.layers.fused_moe.fused_moe import * -from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import * from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser @@ -63,14 +64,20 @@ def prepare(i: int): def run(): if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, - permuted_idx, - m_indices) = moe_permute(qhidden_states, - topk_ids=topk_ids, - topk=topk, - n_expert=num_experts, - expert_map=None, - align_block_size=align_block_size) + ( + permuted_hidden_states, + first_token_off, + inv_perm_idx, + permuted_idx, + m_indices, + ) = 
moe_permute( + qhidden_states, + topk_ids=topk_ids, + topk=topk, + n_expert=num_experts, + expert_map=None, + align_block_size=align_block_size, + ) else: ( permuted_hidden_states, @@ -145,14 +152,20 @@ def benchmark_unpermute( def prepare(): if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, - permuted_idx, - m_indices) = moe_permute(qhidden_states, - topk_ids=topk_ids, - topk=topk, - n_expert=num_experts, - expert_map=None, - align_block_size=align_block_size) + ( + permuted_hidden_states, + first_token_off, + inv_perm_idx, + permuted_idx, + m_indices, + ) = moe_permute( + qhidden_states, + topk_ids=topk_ids, + topk=topk, + n_expert=num_experts, + expert_map=None, + align_block_size=align_block_size, + ) # convert to fp16/bf16 as gemm output return ( permuted_hidden_states.to(dtype), @@ -182,10 +195,22 @@ def prepare(): def run(input: tuple): if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, - permuted_idx, m_indices) = input - moe_unpermute(permuted_hidden_states, topk_weights, inv_perm_idx, - first_token_off, topk) + ( + permuted_hidden_states, + first_token_off, + inv_perm_idx, + permuted_idx, + m_indices, + ) = input + output = torch.empty_like(hidden_states) + moe_unpermute( + output, + permuted_hidden_states, + topk_weights, + inv_perm_idx, + topk, + first_token_off, + ) else: ( permuted_hidden_states, @@ -195,7 +220,11 @@ def run(input: tuple): inv_perm, ) = input _moe_unpermute_and_reduce( - output_hidden_states, permuted_hidden_states, inv_perm, topk_weights + output_hidden_states, + permuted_hidden_states, + inv_perm, + topk_weights, + True, ) # JIT compilation & warmup diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index e81d3e337d4e..14499bd7c65e 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -96,10 +96,11 @@ void moe_permute( } void moe_unpermute( - const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden] - const torch::Tensor& topk_weights, //[n_token, topk] - const torch::Tensor& inv_permuted_idx, // [n_token, topk] - const torch::Tensor& expert_first_token_offset, // [n_local_expert+1] + const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden] + const torch::Tensor& topk_weights, // [n_token, topk] + const torch::Tensor& inv_permuted_idx, // [n_token, topk] + const std::optional& + expert_first_token_offset, // [n_local_expert+1] int64_t topk, torch::Tensor& hidden_states // [n_token, hidden] ) { @@ -109,9 +110,14 @@ void moe_unpermute( auto n_token = hidden_states.size(0); auto n_hidden = hidden_states.size(1); auto stream = at::cuda::getCurrentCUDAStream().stream(); - int n_local_expert = expert_first_token_offset.size(0) - 1; - const int64_t* valid_ptr = - get_ptr(expert_first_token_offset) + n_local_expert; + + int64_t const* valid_ptr = nullptr; + if (expert_first_token_offset.has_value()) { + int n_local_expert = expert_first_token_offset.value().size(0) - 1; + valid_ptr = + get_ptr(expert_first_token_offset.value()) + n_local_expert; + } + MOE_DISPATCH(hidden_states.scalar_type(), [&] { finalizeMoeRoutingKernelLauncher( get_ptr(permuted_hidden_states), diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 9102a774ee3d..fb0a7f8b101d 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -65,7 +65,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights," - "Tensor 
inv_permuted_idx, Tensor expert_first_token_offset, " + "Tensor inv_permuted_idx, Tensor? expert_first_token_offset, " "int topk, Tensor! hidden_states)->()"); m.def("moe_permute_unpermute_supported() -> bool"); diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index c80db4c16608..6f4149f34180 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -246,12 +246,12 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, # add a random tensor to simulate group gemm result0 = 0.5 * permuted_hidden_states + torch.randn_like( permuted_hidden_states) + result4 = torch.empty_like(hidden_states) + moe_unpermute(result4, result0, topk_weights, inv_permuted_idx, topk, + expert_first_token_offset) - result4 = moe_unpermute(result0, topk_weights, inv_permuted_idx, - expert_first_token_offset, topk) gold4 = torch_unpermute(result0, topk_weights, topk_ids, token_expert_indices, inv_permuted_idx, valid_row_idx, topk, n_local_expert) - # check unpermuted hidden torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index ffb785202bb5..b89e5ac6f093 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -215,42 +215,6 @@ def apply( output=output) -def _customized_moe_permute( - curr_hidden_states: torch.Tensor, - a1q_scale: Optional[torch.Tensor], - curr_topk_ids: torch.Tensor, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - block_m: int, -): - fill_invalid_expert = 0 - topk = curr_topk_ids.shape[1] - tokens_in_chunk, _ = curr_hidden_states.shape - num_tokens = topk * tokens_in_chunk - (permuted_hidden_states, expert_first_token_offset, inv_permuted_idx, - permuted_idx, m_indices) = moe_permute(curr_hidden_states, curr_topk_ids, - topk, global_num_experts, - expert_map, block_m, - fill_invalid_expert) - permuted_idx = permuted_idx.clamp(max=num_tokens - 1) - if a1q_scale is not None: - a1q_scale = a1q_scale[permuted_idx // topk] - return (permuted_hidden_states, a1q_scale, permuted_idx, m_indices, - inv_permuted_idx, expert_first_token_offset) - - -def _customized_moe_unpermute_and_reduce( - curr_hidden: torch.Tensor, - inv_perm: Optional[torch.Tensor], - topk_weight: torch.Tensor, - first_token_offset: torch.Tensor, -) -> torch.Tensor: - M, topk = topk_weight.shape - output = moe_unpermute(curr_hidden, topk_weight, inv_perm, - first_token_offset, topk) - return output - - def deep_gemm_moe_fp8( hidden_states: torch.Tensor, w1: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index 344cb4644008..f1b6883df578 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -82,7 +82,7 @@ def moe_permute( expert_map: Optional[torch.Tensor] = None, align_block_size: Optional[int] = None, fill_invalid_expert: int = -1 -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ This function expands and permutes activation to gather uncontinuous tokens @@ -138,7 +138,6 @@ def moe_permute( expert_first_token_offset = torch.empty(n_local_expert + 1, dtype=torch.int64, 
device=hidden_states.device) - # todo clamp (0, n_token * topk - 1) to avoid out of bound ? permuted_idx = torch.full((permuted_row_size, ), n_token * topk, dtype=torch.int32, @@ -156,38 +155,70 @@ def moe_permute( def moe_unpermute( + out: torch.Tensor, permuted_hidden_states: torch.Tensor, topk_weights: torch.Tensor, inv_permuted_idx: torch.Tensor, - expert_first_token_offset: torch.Tensor, topk: int, + expert_first_token_offset: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ This function expands and permutes activation to gathering uncontinuous tokens for each expert. Parameters: + - out (torch.Tensor): output tensor - permuted_hidden_states (torch.Tensor): permuted activation. - topk_weights (torch.Tensor): topk expert route weight for each token. - inv_permuted_idx (torch.Tensor): row idx map for moe_unpermute. - - expert_first_token_offset (torch.Tensor): offset of the first token - of each expert for grouped gemm. - topk (int): The number of top-k experts to select. + - expert_first_token_offset (Optional[torch.Tensor]): offset of the first + token of each expert for grouped gemm. Returns: - hidden_states (torch.Tensor): The reduced and unpermuted activation tensor. """ - n_token, n_hidden = topk_weights.size(0), permuted_hidden_states.size(-1) + n_hidden = permuted_hidden_states.size(-1) assert (n_hidden * permuted_hidden_states.element_size() ) % 16 == 0, "unpermue kernel need hidden dim align to 16B" - hidden_states = torch.empty((n_token, n_hidden), - dtype=permuted_hidden_states.dtype, - device=permuted_hidden_states.device) - torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights, inv_permuted_idx, expert_first_token_offset, - topk, hidden_states) - return hidden_states + topk, out) +def _customized_moe_permute( + curr_hidden_states: torch.Tensor, + a1q_scale: Optional[torch.Tensor], + curr_topk_ids: torch.Tensor, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + block_m: int, +): + fill_invalid_expert = -1 + topk = curr_topk_ids.shape[1] + tokens_in_chunk, _ = curr_hidden_states.shape + num_tokens = topk * tokens_in_chunk + (permuted_hidden_states, expert_first_token_offset, inv_permuted_idx, + permuted_idx, m_indices) = moe_permute(curr_hidden_states, curr_topk_ids, + topk, global_num_experts, + expert_map, block_m, + fill_invalid_expert) + permuted_idx = permuted_idx.clamp(max=num_tokens - 1) + if a1q_scale is not None: + a1q_scale = a1q_scale[permuted_idx // topk] + return (permuted_hidden_states, a1q_scale, permuted_idx, m_indices, + inv_permuted_idx, expert_first_token_offset) + + +def _customized_moe_unpermute_and_reduce( + curr_hidden: torch.Tensor, + inv_perm: Optional[torch.Tensor], + topk_weight: torch.Tensor, + first_token_offset: torch.Tensor, +) -> torch.Tensor: + M, topk = topk_weight.shape + output = moe_unpermute(curr_hidden, topk_weight, inv_perm, + first_token_offset, topk) + return output + def moe_permute_unpermute_supported(): return torch.ops._moe_C.moe_permute_unpermute_supported() From 1dc61102e1dae779a6606581f157237ea2d86d48 Mon Sep 17 00:00:00 2001 From: Caleb_Du Date: Sun, 18 May 2025 13:48:15 +0000 Subject: [PATCH 4/6] 1. 
integrate moe_permute first, after the MoE refactor

Signed-off-by: Caleb_Du 
---
 .../kernels/benchmark_moe_permute_unpermute.py     |  4 ++++
 tests/kernels/moe/test_moe_permute_unpermute.py    | 15 ++++++++++-----
 .../layers/fused_moe/moe_permute_unpermute.py      | 13 +++++++++----
 3 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py
index 90f8ca7fa2e6..3fb52a3de78d 100644
--- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py
+++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py
@@ -66,12 +66,14 @@ def run():
         if use_customized_permute:
             (
                 permuted_hidden_states,
+                a1q_scale,
                 first_token_off,
                 inv_perm_idx,
                 permuted_idx,
                 m_indices,
             ) = moe_permute(
                 qhidden_states,
+                a1q_scale=None,
                 topk_ids=topk_ids,
                 topk=topk,
                 n_expert=num_experts,
@@ -154,12 +156,14 @@ def prepare():
         if use_customized_permute:
             (
                 permuted_hidden_states,
+                a1q_scale,
                 first_token_off,
                 inv_perm_idx,
                 permuted_idx,
                 m_indices,
             ) = moe_permute(
                 qhidden_states,
+                a1q_scale=None,
                 topk_ids=topk_ids,
                 topk=topk,
                 n_expert=num_experts,
diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py
index 6f4149f34180..1c63990a43c4 100644
--- a/tests/kernels/moe/test_moe_permute_unpermute.py
+++ b/tests/kernels/moe/test_moe_permute_unpermute.py
@@ -217,11 +217,16 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
         align_block_size=align_block_size,
         fill_invalid_expert=fill_invalid_expert)

-    (permuted_hidden_states, expert_first_token_offset, inv_permuted_idx,
-     permuted_idx, m_indices) = moe_permute(hidden_states, topk_ids, topk,
-                                            n_expert, expert_map,
-                                            align_block_size,
-                                            fill_invalid_expert)
+    (permuted_hidden_states, _, expert_first_token_offset, inv_permuted_idx,
+     permuted_idx,
+     m_indices) = moe_permute(hidden_states=hidden_states,
+                              a1q_scale=None,
+                              topk_ids=topk_ids,
+                              topk=topk,
+                              n_expert=n_expert,
+                              expert_map=expert_map,
+                              align_block_size=align_block_size,
+                              fill_invalid_expert=fill_invalid_expert)
diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
index f1b6883df578..47123154af70 100644
--- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
+++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
@@ -76,6 +76,7 @@ def _moe_unpermute_and_reduce(

 def moe_permute(
     hidden_states: torch.Tensor,
+    a1q_scale: Optional[torch.Tensor],
     topk_ids: torch.Tensor,
     topk: int,
     n_expert: int,
@@ -83,12 +84,13 @@ def moe_permute(
     expert_map: Optional[torch.Tensor] = None,
     align_block_size: Optional[int] = None,
     fill_invalid_expert: int = -1
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
-           torch.Tensor]:
+           torch.Tensor, torch.Tensor]:
     """
     This function expands and permutes activation to gather uncontinuous tokens
     for each expert.
     Parameters:
     - hidden_states (torch.Tensor): The input tensor to the MoE layer.
+    - a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states
     - topk_ids (torch.Tensor): topk expert route id for each token.
     - topk (int): The number of top-k experts to select.
     - n_expert (int): The number of expert.
@@ -100,6 +102,7 @@ def moe_permute(
         to workaround DeepGemm unsupported -1 in m_indices
     Returns:
     - permuted_hidden_states (torch.Tensor): permuted activation.
diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
index f1b6883df578..47123154af70 100644
--- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
+++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
@@ -76,6 +76,7 @@ def _moe_unpermute_and_reduce(
 
 def moe_permute(
     hidden_states: torch.Tensor,
+    a1q_scale: Optional[torch.Tensor],
     topk_ids: torch.Tensor,
     topk: int,
     n_expert: int,
@@ -83,12 +84,13 @@ def moe_permute(
     align_block_size: Optional[int] = None,
     fill_invalid_expert: int = -1
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
-           torch.Tensor]:
+           torch.Tensor, torch.Tensor]:
     """
     This function expands and permutes activation to gather non-contiguous
     tokens for each expert.
     Parameters:
     - hidden_states (torch.Tensor): The input tensor to the MoE layer.
+    - a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states
     - topk_ids (torch.Tensor): topk expert route id for each token.
     - topk (int): The number of top-k experts to select.
     - n_expert (int): The number of experts.
@@ -100,6 +102,7 @@ def moe_permute(
       to work around DeepGemm unsupported -1 in m_indices
     Returns:
     - permuted_hidden_states (torch.Tensor): permuted activation.
+    - a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states
     - expert_first_token_offset (torch.Tensor): offset of the first token of
       each expert for standard grouped gemm. if 'align_block_size' is given,
       expert_first_token_offset is aligned up to 'align_block_size'.
@@ -150,7 +153,10 @@ def moe_permute(
                                      align_block_size, permuted_hidden_states,
                                      expert_first_token_offset,
                                      inv_permuted_idx, permuted_idx, m_indices)
-    return (permuted_hidden_states, expert_first_token_offset,
+    if a1q_scale is not None:
+        a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) //
+                              topk]
+    return (permuted_hidden_states, a1q_scale, expert_first_token_offset,
             inv_permuted_idx, permuted_idx, m_indices)
 
 
@@ -161,7 +167,7 @@ def moe_unpermute(
     inv_permuted_idx: torch.Tensor,
     topk: int,
     expert_first_token_offset: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
+) -> None:
     """
     This function reduces the permuted activation and scatters it back into
     the original token order.
@@ -184,7 +190,6 @@ def moe_unpermute(
                                    inv_permuted_idx, expert_first_token_offset,
                                    topk, out)
 
-
 def _customized_moe_permute(
     curr_hidden_states: torch.Tensor,
     a1q_scale: Optional[torch.Tensor],

From 58fe99ee5152a20a2e9110a47e191ef7bd938a3e Mon Sep 17 00:00:00 2001
From: Caleb_Du
Date: Thu, 29 May 2025 12:52:16 +0000
Subject: [PATCH 5/6] update with bnell's review

Signed-off-by: Caleb_Du
---
 .../benchmark_moe_permute_unpermute.py        |  7 ---
 csrc/moe/torch_bindings.cpp                   |  2 +-
 .../kernels/moe/test_moe_permute_unpermute.py |  7 +--
 .../layers/fused_moe/moe_permute_unpermute.py | 46 ++-----------------
 4 files changed, 7 insertions(+), 55 deletions(-)

diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py
index 3fb52a3de78d..04d2205aa372 100644
--- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py
+++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py
@@ -69,13 +69,11 @@ def run():
                 a1q_scale,
                 first_token_off,
                 inv_perm_idx,
-                permuted_idx,
                 m_indices,
             ) = moe_permute(
                 qhidden_states,
                 a1q_scale=None,
                 topk_ids=topk_ids,
-                topk=topk,
                 n_expert=num_experts,
                 expert_map=None,
                 align_block_size=align_block_size,
@@ -159,13 +157,11 @@ def prepare():
                 a1q_scale,
                 first_token_off,
                 inv_perm_idx,
-                permuted_idx,
                 m_indices,
             ) = moe_permute(
                 qhidden_states,
                 a1q_scale=None,
                 topk_ids=topk_ids,
-                topk=topk,
                 n_expert=num_experts,
                 expert_map=None,
                 align_block_size=align_block_size,
@@ -175,7 +171,6 @@ def prepare():
             permuted_hidden_states.to(dtype),
             first_token_off,
             inv_perm_idx,
-            permuted_idx,
             m_indices,
         )
     else:
@@ -203,7 +198,6 @@ def run(input: tuple):
                 permuted_hidden_states,
                 first_token_off,
                 inv_perm_idx,
-                permuted_idx,
                 m_indices,
             ) = input
             output = torch.empty_like(hidden_states)
@@ -212,7 +206,6 @@ def run(input: tuple):
                 permuted_hidden_states,
                 topk_weights,
                 inv_perm_idx,
-                topk,
                 first_token_off,
             )
         else:
diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp
index fb0a7f8b101d..d96e082f6ef1 100644
--- a/csrc/moe/torch_bindings.cpp
+++ b/csrc/moe/torch_bindings.cpp
@@ -57,7 +57,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
 
   m.def(
       "moe_permute(Tensor input, Tensor topk_ids,"
-      "Tensor token_expert_indicies, Tensor? expert_map, int n_expert,"
+      "Tensor token_expert_indices, Tensor? expert_map, int n_expert,"
      "int n_local_expert,"
      "int topk, int? align_block_size,Tensor! permuted_input, Tensor! "
      "expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! "
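In the op schema just above, the Tensor! arguments are outputs that the Python wrapper preallocates and the kernel fills in place. A sketch of the shapes and dtypes involved, inferred by the editor from the checks and allocations elsewhere in this series (toy sizes, CPU tensors for illustration only):

    import torch

    n_token, topk, n_hidden, n_local_expert = 4, 2, 128, 8
    permuted_row_size = n_token * topk  # larger if align_block_size is set

    permuted_input = torch.empty(permuted_row_size, n_hidden)
    # int64, one slot per local expert plus a trailing total count
    expert_first_token_offset = torch.empty(n_local_expert + 1,
                                            dtype=torch.int64)
    inv_permuted_idx = torch.empty(n_token, topk, dtype=torch.int32)
    permuted_idx = torch.empty(permuted_row_size, dtype=torch.int32)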
" diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index 1c63990a43c4..192256b6ba3d 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -218,11 +218,9 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, fill_invalid_expert=fill_invalid_expert) (permuted_hidden_states, _, expert_first_token_offset, inv_permuted_idx, - permuted_idx, m_indices) = moe_permute(hidden_states=hidden_states, a1q_scale=None, topk_ids=topk_ids, - topk=topk, n_expert=n_expert, expert_map=expert_map, align_block_size=align_block_size, @@ -245,14 +243,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, permuted_hidden_states[valid_row_idx], atol=0, rtol=0) - # check permuted_idx - torch.testing.assert_close(gold_permuted_idx, permuted_idx, atol=0, rtol=0) - # add a random tensor to simulate group gemm result0 = 0.5 * permuted_hidden_states + torch.randn_like( permuted_hidden_states) result4 = torch.empty_like(hidden_states) - moe_unpermute(result4, result0, topk_weights, inv_permuted_idx, topk, + moe_unpermute(result4, result0, topk_weights, inv_permuted_idx, expert_first_token_offset) gold4 = torch_unpermute(result0, topk_weights, topk_ids, diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index 47123154af70..26fe7182b091 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -78,13 +78,12 @@ def moe_permute( hidden_states: torch.Tensor, a1q_scale: Optional[torch.Tensor], topk_ids: torch.Tensor, - topk: int, n_expert: int, expert_map: Optional[torch.Tensor] = None, align_block_size: Optional[int] = None, fill_invalid_expert: int = -1 -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, + torch.Tensor]: """ This function expands and permutes activation to gather uncontinuous tokens for each expert. @@ -92,7 +91,6 @@ def moe_permute( - hidden_states (torch.Tensor): The input tensor to the MoE layer. - a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states - topk_ids (torch.Tensor): topk expert route id for each token. - - topk (int): The number of top-k experts to select. - n_expert (int): The number of expert. 
diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
index 47123154af70..26fe7182b091 100644
--- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
+++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
@@ -78,13 +78,12 @@ def moe_permute(
     hidden_states: torch.Tensor,
     a1q_scale: Optional[torch.Tensor],
     topk_ids: torch.Tensor,
-    topk: int,
     n_expert: int,
     expert_map: Optional[torch.Tensor] = None,
     align_block_size: Optional[int] = None,
     fill_invalid_expert: int = -1
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
-           torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
+           torch.Tensor]:
     """
     This function expands and permutes activation to gather non-contiguous
     tokens for each expert.
@@ -92,7 +91,6 @@ def moe_permute(
     - hidden_states (torch.Tensor): The input tensor to the MoE layer.
     - a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states
     - topk_ids (torch.Tensor): topk expert route id for each token.
-    - topk (int): The number of top-k experts to select.
     - n_expert (int): The number of experts.
     - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
       from the global expert space to the local expert space of the expert
@@ -112,6 +110,7 @@ def moe_permute(
       the group which the j-th row of the LHS belong to.`
     """
     n_token, n_hidden = hidden_states.size()
+    topk = topk_ids.size(1)
     assert (n_hidden * hidden_states.element_size()
             ) % 16 == 0, "permute kernel needs hidden dim aligned to 16B"
     permuted_row_size = n_token * topk
@@ -157,7 +156,7 @@ def moe_permute(
         a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) //
                               topk]
     return (permuted_hidden_states, a1q_scale, expert_first_token_offset,
-            inv_permuted_idx, permuted_idx, m_indices)
+            inv_permuted_idx.flatten(), m_indices)
 
 
 def moe_unpermute(
@@ -165,7 +164,6 @@ def moe_unpermute(
     permuted_hidden_states: torch.Tensor,
     topk_weights: torch.Tensor,
     inv_permuted_idx: torch.Tensor,
-    topk: int,
     expert_first_token_offset: Optional[torch.Tensor] = None,
 ) -> None:
     """
@@ -176,13 +174,13 @@ def moe_unpermute(
     - permuted_hidden_states (torch.Tensor): permuted activation.
     - topk_weights (torch.Tensor): topk expert route weight for each token.
     - inv_permuted_idx (torch.Tensor): row idx map for moe_unpermute.
-    - topk (int): The number of top-k experts to select.
     - expert_first_token_offset (Optional[torch.Tensor]): offset of the first
       token of each expert for grouped gemm.
     Returns:
     - hidden_states (torch.Tensor): The reduced and unpermuted activation
       tensor.
     """
+    topk = topk_weights.size(1)
     n_hidden = permuted_hidden_states.size(-1)
     assert (n_hidden * permuted_hidden_states.element_size()
             ) % 16 == 0, "unpermute kernel needs hidden dim aligned to 16B"
@@ -190,40 +188,6 @@ def moe_unpermute(
                                    inv_permuted_idx, expert_first_token_offset,
                                    topk, out)
 
 
-def _customized_moe_permute(
-    curr_hidden_states: torch.Tensor,
-    a1q_scale: Optional[torch.Tensor],
-    curr_topk_ids: torch.Tensor,
-    global_num_experts: int,
-    expert_map: Optional[torch.Tensor],
-    block_m: int,
-):
-    fill_invalid_expert = -1
-    topk = curr_topk_ids.shape[1]
-    tokens_in_chunk, _ = curr_hidden_states.shape
-    num_tokens = topk * tokens_in_chunk
-    (permuted_hidden_states, expert_first_token_offset, inv_permuted_idx,
-     permuted_idx, m_indices) = moe_permute(curr_hidden_states, curr_topk_ids,
-                                            topk, global_num_experts,
-                                            expert_map, block_m,
-                                            fill_invalid_expert)
-    permuted_idx = permuted_idx.clamp(max=num_tokens - 1)
-    if a1q_scale is not None:
-        a1q_scale = a1q_scale[permuted_idx // topk]
-    return (permuted_hidden_states, a1q_scale, permuted_idx, m_indices,
-            inv_permuted_idx, expert_first_token_offset)
-
-
-def _customized_moe_unpermute_and_reduce(
-    curr_hidden: torch.Tensor,
-    inv_perm: Optional[torch.Tensor],
-    topk_weight: torch.Tensor,
-    first_token_offset: torch.Tensor,
-) -> torch.Tensor:
-    M, topk = topk_weight.shape
-    output = moe_unpermute(curr_hidden, topk_weight, inv_perm,
-                           first_token_offset, topk)
-    return output
-
 
 def moe_permute_unpermute_supported():
     return torch.ops._moe_C.moe_permute_unpermute_supported()
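Patch 6 below makes n_local_expert an explicit argument instead of deriving it from expert_map inside moe_permute. For context, an expert map under expert parallelism looks like the following editor's sketch (toy sizes; the counting line mirrors the code the patch removes):

    import torch

    # 8 global experts split across 2 EP ranks; this is rank 1's view:
    # global experts 4..7 map to local ids 0..3, everything else is -1.
    n_expert, ep_rank, ep_size = 8, 1, 2
    per_rank = n_expert // ep_size
    expert_map = torch.full((n_expert, ), -1, dtype=torch.int32)
    expert_map[ep_rank * per_rank:(ep_rank + 1) * per_rank] = torch.arange(
        per_rank, dtype=torch.int32)

    # what moe_permute used to compute internally, now passed by the caller
    n_local_expert = int(torch.sum(expert_map != -1).item())  # 4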
From 8e17c5130ec93cc4b802e982f2a8c55a7be7ea1e Mon Sep 17 00:00:00 2001
From: Caleb_Du
Date: Thu, 3 Jul 2025 06:20:21 -0700
Subject: [PATCH 6/6] add n_local_expert for moe_permute

Signed-off-by: Caleb_Du
---
 csrc/moe/moe_permute_unpermute_op.cu            | 12 ++++++------
 tests/kernels/moe/test_moe_permute_unpermute.py |  3 ++-
 .../layers/fused_moe/moe_permute_unpermute.py   |  9 +++++----
 3 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu
index 14499bd7c65e..2922352a3f7c 100644
--- a/csrc/moe/moe_permute_unpermute_op.cu
+++ b/csrc/moe/moe_permute_unpermute_op.cu
@@ -11,7 +11,7 @@
 void moe_permute(
     const torch::Tensor& input,                  // [n_token, hidden]
     const torch::Tensor& topk_ids,               // [n_token, topk]
-    const torch::Tensor& token_expert_indicies,  // [n_token, topk]
+    const torch::Tensor& token_expert_indices,   // [n_token, topk]
     const std::optional<torch::Tensor>& expert_map,  // [n_expert]
     int64_t n_expert, int64_t n_local_expert, int64_t topk,
     const std::optional<int64_t>& align_block_size,
@@ -24,14 +24,14 @@ void moe_permute(
               "expert_first_token_offset must be int64");
   TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
               "topk_ids must be int32");
-  TORCH_CHECK(token_expert_indicies.scalar_type() == at::ScalarType::Int,
-              "token_expert_indicies must be int32");
+  TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int,
+              "token_expert_indices must be int32");
   TORCH_CHECK(inv_permuted_idx.scalar_type() == at::ScalarType::Int,
               "inv_permuted_idx must be int32");
   TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
               "expert_first_token_offset shape != n_local_expert+1")
-  TORCH_CHECK(inv_permuted_idx.sizes() == token_expert_indicies.sizes(),
-              "token_expert_indicies shape must be same as inv_permuted_idx");
+  TORCH_CHECK(inv_permuted_idx.sizes() == token_expert_indices.sizes(),
+              "token_expert_indices shape must be same as inv_permuted_idx");
   auto n_token = input.sizes()[0];
   auto n_hidden = input.sizes()[1];
   auto align_block_size_value =
@@ -69,7 +69,7 @@ void moe_permute(
   }
   // sort the topk expert ids and scan them to get expert_first_token_offset
   sortAndScanExpert(
-      get_ptr<int>(copy_topk_ids), get_ptr<int>(token_expert_indicies),
+      get_ptr<int>(copy_topk_ids), get_ptr<int>(token_expert_indices),
       get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
       get_ptr<int64_t>(expert_first_token_offset), n_token, n_expert,
       n_local_expert, topk, sorter, get_ptr<int>(sort_workspace), stream);
diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py
index 192256b6ba3d..8d215a0cbeed 100644
--- a/tests/kernels/moe/test_moe_permute_unpermute.py
+++ b/tests/kernels/moe/test_moe_permute_unpermute.py
@@ -222,6 +222,7 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
                               a1q_scale=None,
                               topk_ids=topk_ids,
                               n_expert=n_expert,
+                              n_local_expert=n_local_expert,
                               expert_map=expert_map,
                               align_block_size=align_block_size,
                               fill_invalid_expert=fill_invalid_expert)
@@ -232,7 +233,7 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
                                atol=0,
                                rtol=0)
     # check src_row_id2dst_row_id_map
-    torch.testing.assert_close(gold_inv_permuted_idx,
+    torch.testing.assert_close(gold_inv_permuted_idx.flatten(),
                                inv_permuted_idx,
                                atol=0,
                                rtol=0)
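The hunk below reserves the worst-case row count for an aligned layout: each of the n_expert segments may be padded up to the next multiple of align_block_size, and the total is rounded up once more. Plugging toy numbers into the same formula (editor's arithmetic, not patch code):

    n_token, topk, n_expert, align_block_size = 32, 2, 8, 128
    permuted_row_size = n_token * topk  # 64 real rows
    permuted_row_size = (permuted_row_size + n_expert *
                         (align_block_size - 1) + align_block_size -
                         1) // align_block_size * align_block_size
    print(permuted_row_size)  # 1152 rows reserved: 64 real + padding headroom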
diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
index 26fe7182b091..d9059f50b445 100644
--- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
+++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
@@ -79,6 +79,7 @@ def moe_permute(
     a1q_scale: Optional[torch.Tensor],
     topk_ids: torch.Tensor,
     n_expert: int,
+    n_local_expert: int = -1,
     expert_map: Optional[torch.Tensor] = None,
     align_block_size: Optional[int] = None,
     fill_invalid_expert: int = -1
@@ -92,6 +93,7 @@ def moe_permute(
     - a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states
     - topk_ids (torch.Tensor): topk expert route id for each token.
     - n_expert (int): The number of experts.
+    - n_local_expert (int): The number of experts on the current EP rank.
     - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
       from the global expert space to the local expert space of the expert
       parallel shard.
@@ -118,10 +120,8 @@ def moe_permute(
         permuted_row_size = (permuted_row_size + n_expert *
                              (align_block_size - 1) + align_block_size -
                              1) // align_block_size * align_block_size
-    n_local_expert = n_expert
-    if expert_map is not None:
-        n_local_expert = torch.sum(expert_map != -1).item()
-
+    if n_local_expert == -1:
+        n_local_expert = n_expert
     permuted_hidden_states = torch.empty(
         (permuted_row_size, n_hidden),
         dtype=hidden_states.dtype,
@@ -147,6 +147,7 @@ def moe_permute(
     inv_permuted_idx = torch.empty((n_token, topk),
                                    dtype=torch.int32,
                                    device=hidden_states.device)
+    topk_ids = topk_ids.to(torch.int32)
     torch.ops._moe_C.moe_permute(hidden_states, topk_ids,
                                  token_expert_indices, expert_map, n_expert,
                                  n_local_expert, topk, align_block_size,
                                  permuted_hidden_states,