Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 43 additions & 33 deletions benchmarks/kernels/benchmark_moe_permute_unpermute.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
import torch
from transformers import AutoConfig

from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_permute,
_moe_unpermute_and_reduce,
moe_permute,
moe_unpermute,
)
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import *
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
Expand Down Expand Up @@ -63,18 +64,19 @@ def prepare(i: int):

def run():
if use_customized_permute:
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
moe_permute(
qhidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
token_expert_indices=token_expert_indices,
topk=topk,
n_expert=num_experts,
n_local_expert=num_experts,
expert_map=None,
align_block_size=align_block_size,
)
(
permuted_hidden_states,
a1q_scale,
first_token_off,
inv_perm_idx,
m_indices,
) = moe_permute(
qhidden_states,
a1q_scale=None,
topk_ids=topk_ids,
n_expert=num_experts,
expert_map=None,
align_block_size=align_block_size,
)
else:
(
Expand Down Expand Up @@ -150,18 +152,19 @@ def benchmark_unpermute(

def prepare():
if use_customized_permute:
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
moe_permute(
qhidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
token_expert_indices=token_expert_indices,
topk=topk,
n_expert=num_experts,
n_local_expert=num_experts,
expert_map=None,
align_block_size=align_block_size,
)
(
permuted_hidden_states,
a1q_scale,
first_token_off,
inv_perm_idx,
m_indices,
) = moe_permute(
qhidden_states,
a1q_scale=None,
topk_ids=topk_ids,
n_expert=num_experts,
expert_map=None,
align_block_size=align_block_size,
)
# convert to fp16/bf16 as gemm output
return (
Expand Down Expand Up @@ -191,16 +194,19 @@ def prepare():

def run(input: tuple):
if use_customized_permute:
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = input
(
permuted_hidden_states,
first_token_off,
inv_perm_idx,
m_indices,
) = input
output = torch.empty_like(hidden_states)
moe_unpermute(
output,
permuted_hidden_states,
topk_weights,
topk_ids,
inv_perm_idx,
first_token_off,
topk,
num_experts,
num_experts,
)
else:
(
Expand All @@ -211,7 +217,11 @@ def run(input: tuple):
inv_perm,
) = input
_moe_unpermute_and_reduce(
output_hidden_states, permuted_hidden_states, inv_perm, topk_weights
output_hidden_states,
permuted_hidden_states,
inv_perm,
topk_weights,
True,
)

# JIT compilation & warmup
Expand Down
73 changes: 35 additions & 38 deletions csrc/moe/moe_permute_unpermute_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,28 @@

void moe_permute(
const torch::Tensor& input, // [n_token, hidden]
const torch::Tensor& topk_weights, //[n_token, topk]
torch::Tensor& topk_ids, // [n_token, topk]
const torch::Tensor& topk_ids, // [n_token, topk]
const torch::Tensor& token_expert_indices, // [n_token, topk]
const std::optional<torch::Tensor>& expert_map, // [n_expert]
int64_t n_expert, int64_t n_local_expert, int64_t topk,
const std::optional<int64_t>& align_block_size,
torch::Tensor&
permuted_input, // [topk * n_token/align_block_size_m, hidden]
torch::Tensor& permuted_input, // [permuted_size, hidden]
torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk]
torch::Tensor& inv_permuted_idx, // [n_token, topk]
torch::Tensor& permuted_idx, // [permute_size]
torch::Tensor& m_indices) { // [align_expand_m]
TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float,
"topk_weights must be float32");
TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
"expert_first_token_offset must be int64");
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
"topk_ids must be int32");
TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int,
"token_expert_indices must be int32");
TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int,
"src_row_id2dst_row_id_map must be int32");
TORCH_CHECK(inv_permuted_idx.scalar_type() == at::ScalarType::Int,
"inv_permuted_idx must be int32");
TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
"expert_first_token_offset shape != n_local_expert+1")
TORCH_CHECK(
src_row_id2dst_row_id_map.sizes() == token_expert_indices.sizes(),
"token_expert_indices shape must be same as src_row_id2dst_row_id_map");
TORCH_CHECK(inv_permuted_idx.sizes() == token_expert_indices.sizes(),
"token_expert_indices shape must be same as inv_permuted_idx");
auto n_token = input.sizes()[0];
auto n_hidden = input.sizes()[1];
auto align_block_size_value =
Expand All @@ -46,8 +42,9 @@ void moe_permute(
auto sort_workspace = torch::empty(
{sorter_size},
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess
auto permuted_experts_id = torch::empty_like(topk_ids);
auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map);
auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
auto align_expert_first_token_offset =
torch::zeros_like(expert_first_token_offset);

Expand All @@ -67,24 +64,22 @@ void moe_permute(
const int* expert_map_ptr = get_ptr<int>(expert_map.value());
valid_num_ptr =
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
preprocessTopkIdLauncher(get_ptr<int>(topk_ids), n_token * topk,
preprocessTopkIdLauncher(get_ptr<int>(copy_topk_ids), n_token * topk,
expert_map_ptr, n_expert, stream);
}
// expert sort topk expert id and scan expert id get expert_first_token_offset
sortAndScanExpert(get_ptr<int>(topk_ids), get_ptr<int>(token_expert_indices),
get_ptr<int>(permuted_experts_id),
get_ptr<int>(dst_row_id2src_row_id_map),
get_ptr<int64_t>(expert_first_token_offset), n_token,
n_expert, n_local_expert, topk, sorter,
get_ptr<int>(sort_workspace), stream);
sortAndScanExpert(
get_ptr<int>(copy_topk_ids), get_ptr<int>(token_expert_indices),
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
get_ptr<int64_t>(expert_first_token_offset), n_token, n_expert,
n_local_expert, topk, sorter, get_ptr<int>(sort_workspace), stream);

// dispatch expandInputRowsKernelLauncher
MOE_DISPATCH(input.scalar_type(), [&] {
expandInputRowsKernelLauncher<scalar_t>(
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
get_ptr<float>(topk_weights), get_ptr<int>(permuted_experts_id),
get_ptr<int>(dst_row_id2src_row_id_map),
get_ptr<int>(src_row_id2dst_row_id_map),
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
get_ptr<int>(inv_permuted_idx), get_ptr<int>(permuted_idx),
get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
n_hidden, topk, n_local_expert, align_block_size_value, stream);
});
Expand All @@ -101,32 +96,34 @@ void moe_permute(
}

void moe_unpermute(
const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden]
const torch::Tensor& topk_weights, //[n_token, topk]
const torch::Tensor& topk_ids, // [n_token, topk]
const torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk]
const torch::Tensor& expert_first_token_offset, // [n_local_expert+1]
int64_t n_expert, int64_t n_local_expert, int64_t topk,
const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden]
const torch::Tensor& topk_weights, // [n_token, topk]
const torch::Tensor& inv_permuted_idx, // [n_token, topk]
const std::optional<torch::Tensor>&
expert_first_token_offset, // [n_local_expert+1]
int64_t topk,
torch::Tensor& hidden_states // [n_token, hidden]
) {
TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(),
"topk_ids shape must be same as src_row_id2dst_row_id_map");
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
"topk_ids must be int32");
TORCH_CHECK(
permuted_hidden_states.scalar_type() == hidden_states.scalar_type(),
"topk_ids dtype must be same as src_row_id2dst_row_id_map");
"permuted_hidden_states dtype must be same as hidden_states");
auto n_token = hidden_states.size(0);
auto n_hidden = hidden_states.size(1);
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int64_t* valid_ptr =
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;

int64_t const* valid_ptr = nullptr;
if (expert_first_token_offset.has_value()) {
int n_local_expert = expert_first_token_offset.value().size(0) - 1;
valid_ptr =
get_ptr<int64_t>(expert_first_token_offset.value()) + n_local_expert;
}

MOE_DISPATCH(hidden_states.scalar_type(), [&] {
finalizeMoeRoutingKernelLauncher<scalar_t, scalar_t>(
get_ptr<scalar_t>(permuted_hidden_states),
get_ptr<scalar_t>(hidden_states), get_ptr<float>(topk_weights),
get_ptr<int>(src_row_id2dst_row_id_map), get_ptr<int>(topk_ids),
n_token, n_hidden, topk, valid_ptr, stream);
get_ptr<int>(inv_permuted_idx), n_token, n_hidden, topk, valid_ptr,
stream);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ __global__ void getMIndicesKernel(int64_t* expert_first_token_offset,
int tidx = threadIdx.x;
extern __shared__ int64_t smem_expert_first_token_offset[];
for (int i = tidx; i <= num_local_expert; i += blockDim.x) {
smem_expert_first_token_offset[tidx] = __ldg(expert_first_token_offset + i);
smem_expert_first_token_offset[i] = __ldg(expert_first_token_offset + i);
}
__syncthreads();
auto last_token_offset = smem_expert_first_token_offset[eidx + 1];
Expand Down
20 changes: 4 additions & 16 deletions csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,31 +57,19 @@ void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,

template <typename T>
void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output,
const float* unpermuted_scales, int* sorted_experts,
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream);

// Final kernel to unpermute and scale
// This kernel unpermutes the original data, does the k-way reduction and
// performs the final skip connection.
template <typename T, typename OutputType, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
int64_t const* num_valid_ptr);

template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const num_rows,
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
cudaStream_t stream);
int64_t const num_rows, int64_t const cols, int64_t const k,
int64_t const* num_valid_ptr, cudaStream_t stream);

void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
const int* expert_map_ptr, int num_experts,
Expand Down
Loading