48 changes: 31 additions & 17 deletions tests/kernels/test_moe.py
@@ -3,6 +3,8 @@

Run `pytest tests/kernels/test_moe.py`.
"""
import unittest.mock as mock

import pytest
import torch
from torch.nn import Parameter
@@ -40,6 +42,7 @@
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
def test_fused_moe(
m: int,
n: int,
@@ -48,20 +51,20 @@ def test_fused_moe(
topk: int,
ep_size: int,
dtype: torch.dtype,
padding: bool,
):
if padding:
padding_size = 128
envs.VLLM_MOE_PADDING = True
else:
padding_size = 0
envs.VLLM_MOE_PADDING = False

a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

score = torch.randn((m, e), device="cuda", dtype=dtype)

# Pad the input if use padding
if envs.VLLM_MOE_PADDING:
w1 = F.pad(w1, (0, 128), "constant", 0)
torch.cuda.empty_cache()
w2 = F.pad(w2, (0, 128), "constant", 0)
torch.cuda.empty_cache()

if ep_size > 1:
local_e = e // ep_size
e_ids = torch.randint(0,
@@ -75,16 +78,7 @@
else:
e_map = None

triton_output = fused_moe(a,
w1,
w2,
score,
topk,
global_num_experts=e,
expert_map=e_map,
renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk, e_map)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
iterative_output = iterative_moe(a,
w1,
w2,
@@ -93,6 +87,26 @@
global_num_experts=e,
expert_map=e_map,
renormalize=False)
# Pad the weights if padding is enabled
if envs.VLLM_MOE_PADDING:
w1 = F.pad(w1, (0, 128), "constant", 0)
torch.cuda.empty_cache()
w2 = F.pad(w2, (0, 128), "constant", 0)
torch.cuda.empty_cache()

with mock.patch(
'vllm.model_executor.layers.fused_moe.fused_moe.padding_size',
padding_size):
triton_output = fused_moe(a,
w1,
w2,
score,
topk,
global_num_experts=e,
expert_map=e_map,
renormalize=False)

torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
torch.testing.assert_close(iterative_output,
torch_output,
atol=1e-2,
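Note on the padding done in the test above: `F.pad` with a two-element pad spec `(0, 128)` pads only the last dimension, adding 128 zero columns on the right of each expert's weight matrix. A minimal sketch of the shape effect (the tensor names and sizes here are illustrative, not part of the PR):

```python
import torch
import torch.nn.functional as F

# Hypothetical small weights with the same layout as w1/w2 in the test:
# (num_experts, out_features, in_features).
w = torch.randn(4, 8, 16)

# (0, 128) pads only the last dim: 0 columns on the left, 128 zero columns on the right.
w_padded = F.pad(w, (0, 128), "constant", 0)
print(w_padded.shape)  # torch.Size([4, 8, 144])

# The original weights are recoverable by slicing off the padding, which is
# what subtracting padding_size from the K dimension in the kernel relies on.
assert torch.equal(w_padded[..., :16], w)
```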
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -719,6 +719,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
block_shape is not None and block_shape[1] > 0:
assert B_scale is not None and B_scale.ndim == 3
assert B_zp is None or B_zp.ndim == 3
assert padding_size == 0, "MoE padding is not supported " \
"with GPTQ/AWQ quantization"

fused_moe_kernel_gptq_awq[grid](
A,
@@ -770,7 +772,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
expert_ids,
num_tokens_post_padded,
B.shape[1],
A.shape[1] - padding_size,
B.shape[2] - padding_size,
EM,
topk_ids.numel(),
A.stride(0),
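The change from `A.shape[1] - padding_size` to `B.shape[2] - padding_size` reflects that only the weight tensor `B` is padded in its last dimension, not the activations, so the kernel's K (reduction) length should be derived from `B`. A rough sketch of the invariant this relies on, using illustrative tensors rather than the actual Triton kernel:

```python
import torch
import torch.nn.functional as F

padding_size = 128
a = torch.randn(2, 16)             # activations: (tokens, K) -- never padded
w = torch.randn(8, 16)             # one expert's weight slice: (N, K)
w_padded = F.pad(w, (0, padding_size), "constant", 0)  # (N, K + 128)

# K must exclude the padded zero columns, i.e. w_padded.shape[-1] - padding_size,
# so the matmul over the padded weights matches the unpadded reference.
k = w_padded.shape[-1] - padding_size
out_padded = a @ w_padded[:, :k].t()
out_ref = a @ w.t()
assert torch.allclose(out_padded, out_ref)
```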