diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py
index afb8b9f426a2..87989354db6a 100644
--- a/tests/kernels/test_moe.py
+++ b/tests/kernels/test_moe.py
@@ -3,6 +3,8 @@
 
 Run `pytest tests/kernels/test_moe.py`.
 """
+import unittest.mock as mock
+
 import pytest
 import torch
 from torch.nn import Parameter
@@ -40,6 +42,7 @@
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("ep_size", EP_SIZE)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("padding", [True, False])
 def test_fused_moe(
     m: int,
     n: int,
@@ -48,20 +51,20 @@ def test_fused_moe(
     topk: int,
     ep_size: int,
     dtype: torch.dtype,
+    padding: bool,
 ):
+    if padding:
+        padding_size = 128
+        envs.VLLM_MOE_PADDING = True
+    else:
+        padding_size = 0
+        envs.VLLM_MOE_PADDING = False
+
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
     w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
     score = torch.randn((m, e), device="cuda", dtype=dtype)
-
-    # Pad the input if use padding
-    if envs.VLLM_MOE_PADDING:
-        w1 = F.pad(w1, (0, 128), "constant", 0)
-        torch.cuda.empty_cache()
-        w2 = F.pad(w2, (0, 128), "constant", 0)
-        torch.cuda.empty_cache()
-
     if ep_size > 1:
         local_e = e // ep_size
         e_ids = torch.randint(0,
@@ -75,16 +78,7 @@ def test_fused_moe(
     else:
         e_map = None
 
-    triton_output = fused_moe(a,
-                              w1,
-                              w2,
-                              score,
-                              topk,
-                              global_num_experts=e,
-                              expert_map=e_map,
-                              renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk, e_map)
-    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     iterative_output = iterative_moe(a,
                                      w1,
                                      w2,
@@ -93,6 +87,26 @@ def test_fused_moe(
                                      global_num_experts=e,
                                      expert_map=e_map,
                                      renormalize=False)
+    # Pad the input if use padding
+    if envs.VLLM_MOE_PADDING:
+        w1 = F.pad(w1, (0, 128), "constant", 0)
+        torch.cuda.empty_cache()
+        w2 = F.pad(w2, (0, 128), "constant", 0)
+        torch.cuda.empty_cache()
+
+    with mock.patch(
+            'vllm.model_executor.layers.fused_moe.fused_moe.padding_size',
+            padding_size):
+        triton_output = fused_moe(a,
+                                  w1,
+                                  w2,
+                                  score,
+                                  topk,
+                                  global_num_experts=e,
+                                  expert_map=e_map,
+                                  renormalize=False)
+
+    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     torch.testing.assert_close(iterative_output,
                                torch_output,
                                atol=1e-2,
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 37be0fd9227f..7e3d05509bb5 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -719,6 +719,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
             block_shape is not None and block_shape[1] > 0:
         assert B_scale is not None and B_scale.ndim == 3
         assert B_zp is None or B_zp.ndim == 3
+        assert padding_size == 0, "MoE padding is not supported " \
+            "with GPTQ/AWQ quantization"
 
         fused_moe_kernel_gptq_awq[grid](
             A,
@@ -770,7 +772,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        A.shape[1] - padding_size,
+        B.shape[2] - padding_size,
         EM,
         topk_ids.numel(),
         A.stride(0),