35 changes: 34 additions & 1 deletion benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
@@ -80,6 +80,11 @@ def bench_run(
a, score, topk, renormalize=False
)

ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)

def run_triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
@@ -111,6 +116,10 @@ def run_cutlass_moe(
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
@@ -125,6 +134,10 @@ def run_cutlass_moe(
topk_ids,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
per_act_token,
a1_scale=None,
)
@@ -136,6 +149,10 @@ def run_cutlass_from_graph(
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
):
@@ -150,6 +167,10 @@ def run_cutlass_from_graph(
topk_ids,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
per_act_token,
a1_scale=None,
)
@@ -194,6 +215,10 @@ def replay_graph(graph, num_repeats):
w2_q,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
topk_weights,
topk_ids,
)
@@ -231,6 +256,10 @@ def replay_graph(graph, num_repeats):
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"per_act_token": per_act_token,
"ab_strides1": ab_strides1,
"ab_strides2": ab_strides2,
"c_strides1": c_strides1,
"c_strides2": c_strides2,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
@@ -289,6 +318,10 @@ def replay_graph(graph, num_repeats):
w2_q,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
topk_weights,
topk_ids,
per_act_token,
@@ -297,7 +330,7 @@ def replay_graph(graph, num_repeats):

results.append(
benchmark.Timer(
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
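A note on the values above: the benchmark's w1 has shape [num_experts, 2N, K] and w2 has shape [num_experts, K, N], both row-major, so the first grouped GEMM reads its operands with leading stride K and writes [M, 2N] rows, while the second reads with stride N and writes [M, K] rows. A minimal sketch of the precomputation, with illustrative sizes rather than the benchmark's sweep values:

  import torch

  # Illustrative sizes only; the benchmark sweeps many configurations.
  num_experts, n, k = 8, 1024, 2048

  # One entry per expert: the grouped GEMM reads a stride per group.
  # GEMM 1: A [M, K] x w1 [2N, K] -> C [M, 2N]
  # GEMM 2: A [M, N] x w2 [K, N]  -> C [M, K]
  ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
  ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
  c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
  c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)

Building these once up front also keeps the repeated torch.full allocations out of the timed region, so the timer measures the MoE kernels themselves.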
53 changes: 42 additions & 11 deletions csrc/moe/moe_permute_unpermute_op.cu
@@ -157,6 +157,30 @@ __global__ void shuffleInputRowsKernel(const T* input,
}
}

template <typename T>
__global__ void shuffleInputRowsKernelSlow(const T* input,
const int32_t* dst2src_map,
T* output, int64_t num_src_rows,
int64_t num_dst_rows,
int64_t num_cols) {
int64_t dest_row_idx = blockIdx.x;
int64_t const source_row_idx = dst2src_map[dest_row_idx];

if (blockIdx.x < num_dst_rows) {
// Duplicate and permute rows
auto const* source_row_ptr = input + source_row_idx * num_cols;
auto* dest_row_ptr = output + dest_row_idx * num_cols;

int64_t const start_offset = threadIdx.x;
int64_t const stride = blockDim.x;

for (int elem_index = start_offset; elem_index < num_cols;
elem_index += stride) {
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
}
}
}

void shuffle_rows(const torch::Tensor& input_tensor,
const torch::Tensor& dst2src_map,
torch::Tensor& output_tensor) {
@@ -170,17 +194,24 @@ void shuffle_rows(const torch::Tensor& input_tensor,
int64_t const num_src_rows = input_tensor.size(0);
int64_t const num_cols = input_tensor.size(1);

TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)),
"num_cols must be divisible by 128 / "
"sizeof(input_tensor.scalar_type()) / 8");

MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
if (num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)) {
// use slow kernel if num_cols can't be aligned to 128 bits
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernelSlow<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
} else {
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
}
Review comment from a Contributor on lines +197 to +214 (severity: critical):

There are two issues in this block:

  1. Critical Bug: The expression sizeof(input_tensor.scalar_type()) is incorrect for getting the size of the tensor's elements. input_tensor.scalar_type() returns a c10::ScalarType enum, and sizeof on it will return the size of the enum type itself (e.g., 1 or 4 bytes), not the size of the data type it represents. This will lead to an incorrect alignment check, which could cause the fast kernel path to be taken for unaligned inputs, leading to memory access errors or incorrect results. The correct way to get the element size is by using c10::elementSize(input_tensor.scalar_type()) or, inside the MOE_DISPATCH macro, sizeof(scalar_t).

  2. Code Duplication: The MOE_DISPATCH call is duplicated in the if and else branches. Moving the if/else inside the MOE_DISPATCH lambda removes the duplication and also makes sizeof(scalar_t) available for the alignment check.

Here is a suggested change that addresses both issues:

  MOE_DISPATCH(input_tensor.scalar_type(), [&] {
    if (num_cols % (128 / sizeof(scalar_t) / 8)) {
      // use slow kernel if num_cols can't be aligned to 128 bits
      shuffleInputRowsKernelSlow<scalar_t><<<blocks, threads, 0, stream>>>(
          reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
          dst2src_map.data_ptr<int32_t>(),
          reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
          num_dest_rows, num_cols);
    } else {
      shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
          reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
          dst2src_map.data_ptr<int32_t>(),
          reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
          num_dest_rows, num_cols);
    }
  });

}

#else
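The fallback added above exists because the fast kernel copies each row as whole 128-bit vectors, which is only valid when the row length is a multiple of the vector width. A small sketch of the intended predicate in Python (uses_fast_path is a hypothetical helper; the element size must come from the actual dtype, i.e. sizeof(scalar_t) inside the dispatch, which is exactly what the review comment above flags):

  import torch

  def uses_fast_path(num_cols: int, dtype: torch.dtype) -> bool:
      # A row qualifies for the vectorized kernel when it spans a
      # whole number of 128-bit (16-byte) vectors.
      elem_bytes = torch.empty((), dtype=dtype).element_size()
      elems_per_128bit = 128 // (8 * elem_bytes)
      return num_cols % elems_per_128bit == 0

  assert uses_fast_path(7168, torch.float8_e4m3fn)     # 7168 % 16 == 0
  assert not uses_fast_path(100, torch.float8_e4m3fn)  # 100 % 16 != 0 -> slow kernel
  assert uses_fast_path(100, torch.float32)            # 100 % 4 == 0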
14 changes: 12 additions & 2 deletions tests/kernels/moe/test_cutlass_moe.py
@@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids': topk_ids,
'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale,
'ab_strides1': moe_tensors.ab_strides1,
'ab_strides2': moe_tensors.ab_strides2,
'c_strides1': moe_tensors.c_strides1,
'c_strides2': moe_tensors.c_strides2,
'per_act_token': per_act_token,
'a1_scale': None #moe_tensors.a_scale
}
@@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8(
expert_map[start:end] = list(range(num_local_experts))
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")

ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)

activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
torch.float8_e4m3fn,
@@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8(
func = lambda output: run_cutlass_moe_fp8(
output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
per_act_token, per_out_channel, False)
a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
workspace13, workspace2, None, mt.a.dtype, per_act_token,
per_out_channel, False)

workspace13.random_()
output_random_workspace = torch.empty(output_shape,
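The new stride tensors slot into run_cutlass_moe_fp8 between a2_scale and the workspaces, which is easy to misread in a long positional call. The same invocation, annotated (argument roles inferred from the signature in this PR):

  func = lambda output: run_cutlass_moe_fp8(
      output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
      global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
      a1q_scale,
      None,                          # a2_scale
      ab_strides1, ab_strides2,      # input/weight strides, GEMMs 1 and 2
      c_strides1, c_strides2,        # output strides, GEMMs 1 and 2
      workspace13, workspace2,
      None,                          # expert_num_tokens
      mt.a.dtype, per_act_token, per_out_channel, False)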
22 changes: 22 additions & 0 deletions tests/kernels/moe/test_pplx_cutlass_moe.py
@@ -75,6 +75,7 @@ def pplx_cutlass_moe(
assert torch.cuda.current_device() == pgi.local_rank

num_tokens, hidden_dim = a.shape
intermediate_dim = w2.shape[2]
num_experts = w1.shape[0]
block_size = hidden_dim # TODO support more cases
device = pgi.device
@@ -123,10 +124,31 @@ def pplx_cutlass_moe(
num_local_experts=num_local_experts,
num_dispatchers=num_dispatchers)

ab_strides1 = torch.full((num_local_experts, ),
hidden_dim,
device="cuda",
dtype=torch.int64)
ab_strides2 = torch.full((num_local_experts, ),
intermediate_dim,
device="cuda",
dtype=torch.int64)
c_strides1 = torch.full((num_local_experts, ),
2 * intermediate_dim,
device="cuda",
dtype=torch.int64)
c_strides2 = torch.full((num_local_experts, ),
hidden_dim,
device="cuda",
dtype=torch.int64)

experts = CutlassExpertsFp8(num_local_experts,
out_dtype,
per_act_token,
per_out_ch,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
num_dispatchers=num_dispatchers,
use_batched_format=True)

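In this test hidden_dim plays the role of K and intermediate_dim the role of N, so the four tensors follow the same K / N / 2N / K pattern as the other call sites. Since CutlassExpertsFp8 now takes the strides at construction time, they can be built once per layer; a condensed sketch (make_expert_strides is a hypothetical helper, not part of the PR):

  def make_expert_strides(num_local_experts: int, hidden_dim: int,
                          intermediate_dim: int, device: str = "cuda"):
      # Per-expert stride tensors in the order CutlassExpertsFp8 expects.
      full = lambda v: torch.full(
          (num_local_experts, ), v, device=device, dtype=torch.int64)
      return (full(hidden_dim),            # ab_strides1 (K)
              full(intermediate_dim),      # ab_strides2 (N)
              full(2 * intermediate_dim),  # c_strides1  (2N)
              full(hidden_dim))            # c_strides2  (K)

Allocating the strides once per layer rather than on every forward pass also keeps torch.full calls out of CUDA-graph capture.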
62 changes: 39 additions & 23 deletions vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -13,8 +13,7 @@
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
_fp8_quantize,
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
_resize_cache,
extract_required_args)
from vllm.scalar_type import scalar_types
@@ -35,6 +34,10 @@ def run_cutlass_moe_fp8(
w2_scale: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
@@ -153,27 +156,11 @@ def run_cutlass_moe_fp8(
problem_sizes1, problem_sizes2, a_map,
c_map, global_num_experts, N, K)

a1q = _fp8_perm(a1q, a_map)
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
a1q = ops.shuffle_rows(a1q, a_map)
a1q_scale = (ops.shuffle_rows(a1q_scale, a_map)
if per_act_token else a1q_scale)
expert_offsets = expert_offsets[:-1]

ab_strides1 = torch.full((w1.size(0), ),
K,
device=device,
dtype=torch.int64)
c_strides1 = torch.full((w1.size(0), ),
2 * N,
device=device,
dtype=torch.int64)
ab_strides2 = torch.full((w1.size(0), ),
N,
device=device,
dtype=torch.int64)
c_strides2 = torch.full((w1.size(0), ),
K,
device=device,
dtype=torch.int64)

if use_batched_format:
c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
c2 = _resize_cache(workspace2, (local_E * padded_M, N))
@@ -210,7 +197,8 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
else:
# We can't do this inplace because output may point to the same tensor
# as c3.
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
output.copy_(ops.shuffle_rows(c3, c_map).view(M * topk, K),
non_blocking=True)


# TODO (bnell): split class batched vs. non-batched?
@@ -223,6 +211,10 @@ def __init__(
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
num_dispatchers: Optional[int] = None,
use_batched_format: bool = False,
@@ -239,6 +231,10 @@ def __init__(
self.max_experts_per_worker = max_experts_per_worker
self.num_dispatchers = num_dispatchers
self.out_dtype = out_dtype
self.ab_strides1 = ab_strides1
self.ab_strides2 = ab_strides2
self.c_strides1 = c_strides1
self.c_strides2 = c_strides2
self.use_batched_format = use_batched_format

@property
@@ -318,7 +314,8 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2, expert_num_tokens,
a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
self.c_strides2, workspace13, workspace2, expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant,
self.use_batched_format)
@@ -332,6 +329,10 @@ def cutlass_moe_fp8(
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
per_act_token: Optional[bool] = None,
activation: str = "silu",
a1_scale: Optional[torch.Tensor] = None,
@@ -359,6 +360,17 @@ def cutlass_moe_fp8(
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
Shape: [num_experts]
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
Shape: [num_experts]
- c_strides1 (torch.Tensor): The output strides for the first gemm.
Shape: [num_experts]
- c_strides2 (torch.Tensor): The output strides for the second gemm.
Shape: [num_experts]
- per_act_token (Optional[bool]): Whether the scale is per-token or
per-tensor.
- activation (str): The activation function to use.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
@@ -391,6 +403,10 @@ def cutlass_moe_fp8(
out_dtype=a.dtype,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
ab_strides1=ab_strides1,
ab_strides2=ab_strides2,
c_strides1=c_strides1,
c_strides2=c_strides2,
use_batched_format=False,
),
)
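One behavioral note on the permute changes in run_cutlass_moe_fp8 above: the Python-side gathers _fp8_perm(a1q, a_map) and c3[c_map] are replaced by the ops.shuffle_rows kernel added in this PR. A quick equivalence sketch, assuming the Python wrapper allocates and returns the output (as its use above implies) and that output row i is input row dst2src_map[i], per the kernel's dst2src_map naming:

  import torch
  from vllm import _custom_ops as ops

  x = torch.randn(64, 7168, device="cuda", dtype=torch.half)  # 14336 bytes per row: 128-bit aligned, fast path
  dst2src = torch.randperm(64, device="cuda", dtype=torch.int32)

  shuffled = ops.shuffle_rows(x, dst2src)
  torch.testing.assert_close(shuffled, x[dst2src.long()])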