Commit 34bb5b0

Merge pull request #4 from DhruvaBansal00/gptq-marlin-refactor
Refactoring for maintainability
2 parents: 8f4648c + 315e3b6 · commit 34bb5b0

13 files changed: +814 additions, -606 deletions

tests/kernels/test_moe.py

Lines changed: 41 additions & 35 deletions

@@ -10,8 +10,9 @@
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import (fused_marlin_moe, fused_moe,
-                                                  single_marlin_moe)
+from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.fused_moe.fused_moe_marlin import (
+    fused_moe_marlin, single_marlin_moe)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     marlin_quantize)
 from vllm.model_executor.models.mixtral import MixtralMoE
@@ -63,11 +64,11 @@ def test_fused_moe(
     topk: int,
     dtype: torch.dtype,
 ):
-    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
-    w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
-    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

-    score = torch.randn((m, e), device='cuda', dtype=dtype)
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
     triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk)
     torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
@@ -166,11 +167,11 @@ def test_fused_marlin_moe(

     quant_type = scalar_types.uint4b8
     dtype = torch.float16
-    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
-    w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
-    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
     for i in range(w2.shape[0]):
-        w2[0] = torch.eye(k, n, device='cuda', dtype=dtype)
+        w2[0] = torch.eye(k, n, device="cuda", dtype=dtype)

     w_ref1_l = []
     qweight1_l = []
@@ -218,27 +219,32 @@ def test_fused_marlin_moe(
     g_idx2 = stack_and_dev(g_idx2_l)
     sort_indices2 = stack_and_dev(sort_indices2_l)

-    score = torch.randn((m, e), device='cuda', dtype=dtype)
-    triton_output = fused_moe(a,
-                              w_ref1.transpose(1, 2).contiguous(),
-                              w_ref2.transpose(1, 2).contiguous(),
-                              score,
-                              topk,
-                              renormalize=False)
-    marlin_output = fused_marlin_moe(a,
-                                     qweight1,
-                                     qweight2,
-                                     score,
-                                     g_idx1,
-                                     g_idx2,
-                                     sort_indices1,
-                                     sort_indices2,
-                                     topk,
-                                     renormalize=False,
-                                     w1_scale=scales1,
-                                     w2_scale=scales2)
-
-    assert (compute_max_diff(marlin_output, triton_output) < 4e-2)
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+    triton_output = fused_moe(
+        a,
+        w_ref1.transpose(1, 2).contiguous(),
+        w_ref2.transpose(1, 2).contiguous(),
+        score,
+        topk,
+        renormalize=False,
+    )
+    marlin_output = fused_moe_marlin(
+        a,
+        qweight1,
+        qweight2,
+        score,
+        g_idx1,
+        g_idx2,
+        sort_indices1,
+        sort_indices2,
+        topk,
+        renormalize=False,
+        w1_scale=scales1,
+        w2_scale=scales2,
+        num_bits=4,
+    )
+
+    assert compute_max_diff(marlin_output, triton_output) < 4e-2


 # TODO: make sure this test works
@@ -275,8 +281,8 @@ def test_single_marlin_moe(

     quant_type = scalar_types.uint4b8
     dtype = torch.float16
-    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
-    w = torch.randn((e, n, k), device='cuda', dtype=dtype) / 10
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

     w_ref_l = []
     qweights_l = []
@@ -300,7 +306,7 @@ def test_single_marlin_moe(
     g_idx = stack_and_dev(g_idx_l)
     sort_indices = stack_and_dev(sort_indices_l)

-    score = torch.randn((m, e), device='cuda', dtype=dtype)
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
     marlin_output = single_marlin_moe(a,
                                       qweight,
                                       scales,
@@ -311,4 +317,4 @@ def test_single_marlin_moe(
                                       renormalize=False)
     torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

-    assert (compute_max_diff(marlin_output, torch_output) < 1e-2)
+    assert compute_max_diff(marlin_output, torch_output) < 1e-2
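
For callers, the change exercised above boils down to a new import path, a rename, and an explicit bit-width argument. The sketch below is a hedged summary of that call pattern, assuming CUDA and Marlin-quantized weights prepared the same way the test prepares them (via marlin_quantize); it is not part of the diff.

```python
# Hedged caller-side sketch of the renamed API, mirroring the test above.
# Assumes qweight1/2, scales1/2, g_idx1/2 and sort_indices1/2 were produced by
# marlin_quantize (uint4b8), as in test_fused_marlin_moe.
from vllm.model_executor.layers.fused_moe.fused_moe_marlin import (
    fused_moe_marlin, single_marlin_moe)


def run_marlin_moe(a, qweight1, qweight2, score, g_idx1, g_idx2,
                   sort_indices1, sort_indices2, scales1, scales2, topk):
    # fused_marlin_moe -> fused_moe_marlin; num_bits is now passed explicitly.
    return fused_moe_marlin(
        a, qweight1, qweight2, score,
        g_idx1, g_idx2, sort_indices1, sort_indices2,
        topk,
        renormalize=False,
        w1_scale=scales1,
        w2_scale=scales2,
        num_bits=4,  # 4 here because the test uses uint4b8 weights
    )
```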

vllm/_custom_ops.py

Lines changed: 1 addition & 1 deletion

@@ -304,7 +304,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                            size_k: int, size_n: int,
                            num_bits: int) -> torch.Tensor:
     num_experts = b_q_weight.shape[0]
-    output = torch.empty((num_experts, size_k // 16, size_n * 2),
+    output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)),
                          device=b_q_weight.device,
                          dtype=b_q_weight.dtype)
     for e in range(num_experts):
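
The single changed line generalizes the repacked output width from a hard-coded factor of 2 to one derived from num_bits. My reading of the arithmetic (an assumption, not stated in the diff): the output consists of 32-bit words, so a tile of 16 size_k rows of num_bits-bit values needs 16 * num_bits / 32 = num_bits // 2 words per logical column, which equals 2 only in the 4-bit case.

```python
# Hedged sketch of the shape change in gptq_marlin_moe_repack. The expert,
# hidden and intermediate sizes below are illustrative (Mixtral-like), not
# taken from the diff.
def repacked_shape(num_experts: int, size_k: int, size_n: int,
                   num_bits: int) -> tuple:
    # 16 rows of num_bits-bit values pack into num_bits // 2 32-bit words
    # per column.
    return (num_experts, size_k // 16, size_n * (num_bits // 2))


assert repacked_shape(8, 4096, 14336, num_bits=4) == (8, 256, 28672)  # same as before
assert repacked_shape(8, 4096, 14336, num_bits=8) == (8, 256, 57344)  # now sized for 8-bit
```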

vllm/model_executor/layers/fused_moe/__init__.py

Lines changed: 3 additions & 4 deletions

@@ -1,18 +1,17 @@
-from vllm.model_executor.layers.fused_moe.fused_moe import (fused_marlin_moe,
-                                                            single_marlin_moe)
+from vllm.model_executor.layers.fused_moe.fused_moe_marlin import (
+    fused_moe_marlin, single_marlin_moe)
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEMethodBase)
 from vllm.triton_utils import HAS_TRITON

 __all__ = [
     "FusedMoE",
     "FusedMoEMethodBase",
-    "fused_marlin_moe",
+    "fused_moe_marlin",
     "single_marlin_moe",
 ]

 if HAS_TRITON:
-
     from vllm.model_executor.layers.fused_moe.fused_moe import (
         fused_experts, fused_moe, fused_topk, get_config_file_name,
         grouped_topk)
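
At the package level only the exported name changes; single_marlin_moe keeps its name and both symbols are still re-exported from the package root. A hedged before/after for downstream imports:

```python
# Before this commit:
#   from vllm.model_executor.layers.fused_moe import fused_marlin_moe
# After this commit (same package root, renamed symbol backed by the new module):
from vllm.model_executor.layers.fused_moe import fused_moe_marlin, single_marlin_moe
```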

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 0 additions & 178 deletions

@@ -666,181 +666,3 @@ def fused_moe(
         w2_scale=w2_scale,
         a1_scale=a1_scale,
         a2_scale=a2_scale)
-
-
-def single_marlin_moe(
-    hidden_states: torch.Tensor,
-    w: torch.Tensor,
-    scales: torch.Tensor,
-    gating_output: torch.Tensor,
-    g_idx: torch.Tensor,
-    rand_perm: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    override_config: Optional[Dict[str, Any]] = None,
-    use_fp8: bool = False,
-) -> torch.Tensor:
-    """
-    This function computes a Marlin MoE MMM using weights w
-    and top-k gating mechanism. It is meant for testing and debugging.
-
-    Parameters:
-    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
-    - w (torch.Tensor): The first set of expert weights.
-    - gating_output (torch.Tensor): The output of the gating operation
-      (before softmax).
-    - topk (int): The number of top-k experts to select.
-    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
-    - inplace (bool): If True, perform the operation in-place.
-      Defaults to False.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-      for the kernel configuration.
-    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
-      products for w and w2. Defaults to False.
-
-    Returns:
-    - torch.Tensor: The output tensor after applying the MoE layer.
-    """
-    # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
-    assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
-    assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
-    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w.is_contiguous(), "Expert weights must be contiguous"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16
-    ]
-    M, K = hidden_states.shape
-    E = w.shape[0]
-    N = w.shape[2] // 2
-
-    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
-                                        renormalize)
-
-    # This might not be an optimal config for a single MMM
-    get_config_func = functools.partial(try_get_optimal_moe_config,
-                                        w.shape,
-                                        w.shape,
-                                        topk_ids.shape[1],
-                                        "float8" if use_fp8 else None,
-                                        override_config=override_config,
-                                        is_marlin=True)
-    config = get_config_func(M)
-
-    block_size_m = config['BLOCK_SIZE_M']
-
-    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
-
-    max_workspace_size = (N // 64) * 16
-    workspace = torch.zeros(max_workspace_size,
-                            dtype=torch.int,
-                            device="cuda",
-                            requires_grad=False)
-
-    intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
-        hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
-        g_idx, rand_perm, workspace, M, N, K, True, E, topk, block_size_m,
-        True, False)
-
-    return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
-
-
-def fused_marlin_moe(hidden_states: torch.Tensor,
-                     w1: torch.Tensor,
-                     w2: torch.Tensor,
-                     gating_output: torch.Tensor,
-                     g_idx1: torch.Tensor,
-                     g_idx2: torch.Tensor,
-                     rand_perm1: torch.Tensor,
-                     rand_perm2: torch.Tensor,
-                     topk: int,
-                     renormalize: bool,
-                     override_config: Optional[Dict[str, Any]] = None,
-                     use_fp8: bool = False,
-                     w1_scale: Optional[torch.Tensor] = None,
-                     w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
-    """
-    This function computes a Mixture of Experts (MoE) layer using two sets of
-    weights, w1 and w2, and top-k gating mechanism.
-
-    Parameters:
-    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
-    - w1 (torch.Tensor): The first set of expert weights.
-    - w2 (torch.Tensor): The second set of expert weights.
-    - gating_output (torch.Tensor): The output of the gating operation
-      (before softmax).
-    - topk (int): The number of top-k experts to select.
-    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
-    - inplace (bool): If True, perform the operation in-place.
-      Defaults to False.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-      for the kernel configuration.
-    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
-      products for w1 and w2. Defaults to False.
-    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
-      w1.
-    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
-      w2.
-
-    Returns:
-    - torch.Tensor: The output tensor after applying the MoE layer.
-    """
-    # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
-    assert hidden_states.shape[
-        1] == w1.shape[1] * 16, "Hidden size mismatch w1"
-    assert hidden_states.shape[
-        1] == w2.shape[2] // 2, "Hidden size mismatch w2"
-    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
-    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16
-    ]
-    M, K = hidden_states.shape
-    E = w1.shape[0]
-    N = w2.shape[1] * 16
-
-    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
-                                        renormalize)
-
-    get_config_func = functools.partial(try_get_optimal_moe_config,
-                                        w1.shape,
-                                        w2.shape,
-                                        topk_ids.shape[1],
-                                        "float8" if use_fp8 else None,
-                                        override_config=override_config,
-                                        is_marlin=True)
-    config = get_config_func(M)
-
-    block_size_m = config['BLOCK_SIZE_M']
-
-    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
-
-    max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
-    workspace = torch.zeros(max_workspace_size,
-                            dtype=torch.int,
-                            device="cuda",
-                            requires_grad=False)
-
-    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-
-    intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
-        hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale,
-        g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk,
-        block_size_m, True, False)
-
-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
-
-    intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
-        intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids,
-        w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk,
-        block_size_m, False, True)
-
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1)
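
These two functions are removed from fused_moe.py as part of the refactor; per the new imports above they now live in vllm.model_executor.layers.fused_moe.fused_moe_marlin, with fused_marlin_moe renamed to fused_moe_marlin. For orientation only, here is a minimal unquantized PyTorch reference of the computation fused_marlin_moe performs (top-k routing, w1 GEMM, SiLU-and-mul, w2 GEMM, weighted sum over the selected experts), using the weight layout from the tests above (w1 is (E, 2N, K), w2 is (E, K, N)). It is a sketch of the math, not the Marlin kernel.

```python
# Hedged, unquantized reference of the MoE math the removed fused_marlin_moe
# implements with Marlin GEMMs; shapes follow the tests: hidden_states (M, K),
# w1 (E, 2N, K), w2 (E, K, N), gating_output (M, E).
import torch
import torch.nn.functional as F


def moe_reference(hidden_states, w1, w2, gating_output, topk, renormalize):
    # Top-k routing over softmaxed gating logits (what fused_topk computes).
    topk_weights, topk_ids = torch.topk(
        torch.softmax(gating_output.float(), dim=-1), topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(hidden_states)
    for e in range(w1.shape[0]):
        token_idx, slot_idx = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        x = hidden_states[token_idx]                 # tokens routed to expert e
        gate, up = (x @ w1[e].t()).chunk(2, dim=-1)  # stand-in for the first marlin_gemm_moe
        h = F.silu(gate) * up                        # SiLU-and-mul activation
        y = h @ w2[e].t()                            # stand-in for the second marlin_gemm_moe
        out.index_add_(0, token_idx,
                       y * topk_weights[token_idx, slot_idx].unsqueeze(-1).to(y.dtype))
    return out
```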
