Skip to content

Commit a2cffea

Browse files
CalebDu authored and heyselbi committed
Fix CUDA permute/unpermute for use with DeepGemm Moe (vllm-project#17934)
Signed-off-by: Caleb_Du <[email protected]>
1 parent 0dd9ef1 commit a2cffea

File tree

8 files changed

+236
-209
lines changed

8 files changed

+236
-209
lines changed

benchmarks/kernels/benchmark_moe_permute_unpermute.py

Lines changed: 43 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -8,12 +8,13 @@
88
import torch
99
from transformers import AutoConfig
1010

11-
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
11+
from vllm.model_executor.layers.fused_moe.fused_moe import *
12+
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
1213
_moe_permute,
1314
_moe_unpermute_and_reduce,
15+
moe_permute,
16+
moe_unpermute,
1417
)
15-
from vllm.model_executor.layers.fused_moe.fused_moe import *
16-
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import *
1718
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
1819
from vllm.platforms import current_platform
1920
from vllm.utils import FlexibleArgumentParser
@@ -63,18 +64,19 @@ def prepare(i: int):
6364

6465
def run():
6566
if use_customized_permute:
66-
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
67-
moe_permute(
68-
qhidden_states,
69-
topk_weights=topk_weights,
70-
topk_ids=topk_ids,
71-
token_expert_indices=token_expert_indices,
72-
topk=topk,
73-
n_expert=num_experts,
74-
n_local_expert=num_experts,
75-
expert_map=None,
76-
align_block_size=align_block_size,
77-
)
67+
(
68+
permuted_hidden_states,
69+
a1q_scale,
70+
first_token_off,
71+
inv_perm_idx,
72+
m_indices,
73+
) = moe_permute(
74+
qhidden_states,
75+
a1q_scale=None,
76+
topk_ids=topk_ids,
77+
n_expert=num_experts,
78+
expert_map=None,
79+
align_block_size=align_block_size,
7880
)
7981
else:
8082
(
@@ -150,18 +152,19 @@ def benchmark_unpermute(
150152

151153
def prepare():
152154
if use_customized_permute:
153-
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
154-
moe_permute(
155-
qhidden_states,
156-
topk_weights=topk_weights,
157-
topk_ids=topk_ids,
158-
token_expert_indices=token_expert_indices,
159-
topk=topk,
160-
n_expert=num_experts,
161-
n_local_expert=num_experts,
162-
expert_map=None,
163-
align_block_size=align_block_size,
164-
)
155+
(
156+
permuted_hidden_states,
157+
a1q_scale,
158+
first_token_off,
159+
inv_perm_idx,
160+
m_indices,
161+
) = moe_permute(
162+
qhidden_states,
163+
a1q_scale=None,
164+
topk_ids=topk_ids,
165+
n_expert=num_experts,
166+
expert_map=None,
167+
align_block_size=align_block_size,
165168
)
166169
# convert to fp16/bf16 as gemm output
167170
return (
@@ -191,16 +194,19 @@ def prepare():
191194

192195
def run(input: tuple):
193196
if use_customized_permute:
194-
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = input
197+
(
198+
permuted_hidden_states,
199+
first_token_off,
200+
inv_perm_idx,
201+
m_indices,
202+
) = input
203+
output = torch.empty_like(hidden_states)
195204
moe_unpermute(
205+
output,
196206
permuted_hidden_states,
197207
topk_weights,
198-
topk_ids,
199208
inv_perm_idx,
200209
first_token_off,
201-
topk,
202-
num_experts,
203-
num_experts,
204210
)
205211
else:
206212
(
@@ -211,7 +217,11 @@ def run(input: tuple):
211217
inv_perm,
212218
) = input
213219
_moe_unpermute_and_reduce(
214-
output_hidden_states, permuted_hidden_states, inv_perm, topk_weights
220+
output_hidden_states,
221+
permuted_hidden_states,
222+
inv_perm,
223+
topk_weights,
224+
True,
215225
)
216226

217227
# JIT compilation & warmup

csrc/moe/moe_permute_unpermute_op.cu

Lines changed: 35 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -10,32 +10,28 @@
1010

1111
void moe_permute(
1212
const torch::Tensor& input, // [n_token, hidden]
13-
const torch::Tensor& topk_weights, //[n_token, topk]
14-
torch::Tensor& topk_ids, // [n_token, topk]
13+
const torch::Tensor& topk_ids, // [n_token, topk]
1514
const torch::Tensor& token_expert_indices, // [n_token, topk]
1615
const std::optional<torch::Tensor>& expert_map, // [n_expert]
1716
int64_t n_expert, int64_t n_local_expert, int64_t topk,
1817
const std::optional<int64_t>& align_block_size,
19-
torch::Tensor&
20-
permuted_input, // [topk * n_token/align_block_size_m, hidden]
18+
torch::Tensor& permuted_input, // [permuted_size, hidden]
2119
torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
22-
torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk]
20+
torch::Tensor& inv_permuted_idx, // [n_token, topk]
21+
torch::Tensor& permuted_idx, // [permuted_size]
2322
torch::Tensor& m_indices) { // [align_expand_m]
24-
TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float,
25-
"topk_weights must be float32");
2623
TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
2724
"expert_first_token_offset must be int64");
2825
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
2926
"topk_ids must be int32");
3027
TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int,
3128
"token_expert_indices must be int32");
32-
TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int,
33-
"src_row_id2dst_row_id_map must be int32");
29+
TORCH_CHECK(inv_permuted_idx.scalar_type() == at::ScalarType::Int,
30+
"inv_permuted_idx must be int32");
3431
TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
3532
"expert_first_token_offset shape != n_local_expert+1")
36-
TORCH_CHECK(
37-
src_row_id2dst_row_id_map.sizes() == token_expert_indices.sizes(),
38-
"token_expert_indices shape must be same as src_row_id2dst_row_id_map");
33+
TORCH_CHECK(inv_permuted_idx.sizes() == token_expert_indices.sizes(),
34+
"token_expert_indices shape must be same as inv_permuted_idx");
3935
auto n_token = input.sizes()[0];
4036
auto n_hidden = input.sizes()[1];
4137
auto align_block_size_value =
@@ -46,8 +42,9 @@ void moe_permute(
4642
auto sort_workspace = torch::empty(
4743
{sorter_size},
4844
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
45+
auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess
4946
auto permuted_experts_id = torch::empty_like(topk_ids);
50-
auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map);
47+
auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
5148
auto align_expert_first_token_offset =
5249
torch::zeros_like(expert_first_token_offset);
5350

@@ -67,24 +64,22 @@ void moe_permute(
6764
const int* expert_map_ptr = get_ptr<int>(expert_map.value());
6865
valid_num_ptr =
6966
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
70-
preprocessTopkIdLauncher(get_ptr<int>(topk_ids), n_token * topk,
67+
preprocessTopkIdLauncher(get_ptr<int>(copy_topk_ids), n_token * topk,
7168
expert_map_ptr, n_expert, stream);
7269
}
7370
// sort the top-k expert ids by expert, then scan them to get expert_first_token_offset
74-
sortAndScanExpert(get_ptr<int>(topk_ids), get_ptr<int>(token_expert_indices),
75-
get_ptr<int>(permuted_experts_id),
76-
get_ptr<int>(dst_row_id2src_row_id_map),
77-
get_ptr<int64_t>(expert_first_token_offset), n_token,
78-
n_expert, n_local_expert, topk, sorter,
79-
get_ptr<int>(sort_workspace), stream);
71+
sortAndScanExpert(
72+
get_ptr<int>(copy_topk_ids), get_ptr<int>(token_expert_indices),
73+
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
74+
get_ptr<int64_t>(expert_first_token_offset), n_token, n_expert,
75+
n_local_expert, topk, sorter, get_ptr<int>(sort_workspace), stream);
8076

8177
// dispatch expandInputRowsKernelLauncher
8278
MOE_DISPATCH(input.scalar_type(), [&] {
8379
expandInputRowsKernelLauncher<scalar_t>(
8480
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
85-
get_ptr<float>(topk_weights), get_ptr<int>(permuted_experts_id),
86-
get_ptr<int>(dst_row_id2src_row_id_map),
87-
get_ptr<int>(src_row_id2dst_row_id_map),
81+
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
82+
get_ptr<int>(inv_permuted_idx), get_ptr<int>(permuted_idx),
8883
get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
8984
n_hidden, topk, n_local_expert, align_block_size_value, stream);
9085
});
@@ -101,32 +96,34 @@ void moe_permute(
10196
}
10297

10398
void moe_unpermute(
104-
const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden]
105-
const torch::Tensor& topk_weights, //[n_token, topk]
106-
const torch::Tensor& topk_ids, // [n_token, topk]
107-
const torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk]
108-
const torch::Tensor& expert_first_token_offset, // [n_local_expert+1]
109-
int64_t n_expert, int64_t n_local_expert, int64_t topk,
99+
const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden]
100+
const torch::Tensor& topk_weights, // [n_token, topk]
101+
const torch::Tensor& inv_permuted_idx, // [n_token, topk]
102+
const std::optional<torch::Tensor>&
103+
expert_first_token_offset, // [n_local_expert+1]
104+
int64_t topk,
110105
torch::Tensor& hidden_states // [n_token, hidden]
111106
) {
112-
TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(),
113-
"topk_ids shape must be same as src_row_id2dst_row_id_map");
114-
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
115-
"topk_ids must be int32");
116107
TORCH_CHECK(
117108
permuted_hidden_states.scalar_type() == hidden_states.scalar_type(),
118-
"topk_ids dtype must be same as src_row_id2dst_row_id_map");
109+
"permuted_hidden_states dtype must be same as hidden_states");
119110
auto n_token = hidden_states.size(0);
120111
auto n_hidden = hidden_states.size(1);
121112
auto stream = at::cuda::getCurrentCUDAStream().stream();
122-
const int64_t* valid_ptr =
123-
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
113+
114+
int64_t const* valid_ptr = nullptr;
115+
if (expert_first_token_offset.has_value()) {
116+
int n_local_expert = expert_first_token_offset.value().size(0) - 1;
117+
valid_ptr =
118+
get_ptr<int64_t>(expert_first_token_offset.value()) + n_local_expert;
119+
}
120+
124121
MOE_DISPATCH(hidden_states.scalar_type(), [&] {
125122
finalizeMoeRoutingKernelLauncher<scalar_t, scalar_t>(
126123
get_ptr<scalar_t>(permuted_hidden_states),
127124
get_ptr<scalar_t>(hidden_states), get_ptr<float>(topk_weights),
128-
get_ptr<int>(src_row_id2dst_row_id_map), get_ptr<int>(topk_ids),
129-
n_token, n_hidden, topk, valid_ptr, stream);
125+
get_ptr<int>(inv_permuted_idx), n_token, n_hidden, topk, valid_ptr,
126+
stream);
130127
});
131128
}
132129

csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -177,7 +177,7 @@ __global__ void getMIndicesKernel(int64_t* expert_first_token_offset,
177177
int tidx = threadIdx.x;
178178
extern __shared__ int64_t smem_expert_first_token_offset[];
179179
for (int i = tidx; i <= num_local_expert; i += blockDim.x) {
180-
smem_expert_first_token_offset[tidx] = __ldg(expert_first_token_offset + i);
180+
smem_expert_first_token_offset[i] = __ldg(expert_first_token_offset + i);
181181
}
182182
__syncthreads();
183183
auto last_token_offset = smem_expert_first_token_offset[eidx + 1];

csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h

Lines changed: 4 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -57,31 +57,19 @@ void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
5757

5858
template <typename T>
5959
void expandInputRowsKernelLauncher(
60-
T const* unpermuted_input, T* permuted_output,
61-
const float* unpermuted_scales, int* sorted_experts,
60+
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
6261
int const* expanded_dest_row_to_expanded_source_row,
63-
int* expanded_source_row_to_expanded_dest_row,
62+
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
6463
int64_t* expert_first_token_offset, int64_t const num_rows,
6564
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
6665
int num_local_experts, const int& align_block_size, cudaStream_t stream);
6766

68-
// Final kernel to unpermute and scale
69-
// This kernel unpermutes the original data, does the k-way reduction and
70-
// performs the final skip connection.
71-
template <typename T, typename OutputType, bool CHECK_SKIPPED>
72-
__global__ void finalizeMoeRoutingKernel(
73-
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
74-
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
75-
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
76-
int64_t const* num_valid_ptr);
77-
7867
template <class T, class OutputType>
7968
void finalizeMoeRoutingKernelLauncher(
8069
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
8170
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
82-
int const* expert_for_source_row, int64_t const num_rows,
83-
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
84-
cudaStream_t stream);
71+
int64_t const num_rows, int64_t const cols, int64_t const k,
72+
int64_t const* num_valid_ptr, cudaStream_t stream);
8573

8674
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
8775
const int* expert_map_ptr, int num_experts,

0 commit comments

Comments
 (0)