
Commit f3b4cf1

TEMP Mostly working
Signed-off-by: Luka Govedič <[email protected]>
1 parent 21d7d67 commit f3b4cf1

File tree

5 files changed: +204 -104 lines changed


tests/compile/test_fusion.py

Lines changed: 33 additions & 4 deletions
@@ -8,6 +8,7 @@
 from vllm.compilation.fusion import (
     FUSED_OPS,
     QUANT_OPS,
+    RMS_OP,
     FusedRMSQuantKey,
     RMSNormQuantFusionPass,
 )
@@ -65,6 +66,9 @@ def __init__(
             act_quant_group_shape=group_shape,
         )

+        self.enable_rms_norm = self.norm[0].enabled()
+        self.enable_quant_fp8 = self.fp8_linear.quant_fp8.enabled()
+
     def forward(self, x):
         resid = torch.sqrt(x)
         y = self.norm[0](x)
@@ -82,7 +86,18 @@ def forward(self, x):
         return y3

     def ops_in_model_before(self):
-        return [QUANT_OPS[self.key]]
+        ops = []
+        if self.enable_rms_norm:
+            ops += [RMS_OP]
+        else:
+            ops += [torch.ops.aten.rsqrt.default]
+
+        if self.enable_quant_fp8:
+            ops += [QUANT_OPS[self.key]]
+        else:
+            ops += [torch.ops.aten.reciprocal.default]
+
+        return ops

     def ops_in_model_after(self):
         return [
@@ -91,11 +106,13 @@ def ops_in_model_after(self):
         ]


-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.float16])  # , torch.bfloat16])
 @pytest.mark.parametrize("hidden_size", [64])
 @pytest.mark.parametrize("num_tokens", [257])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
 @pytest.mark.parametrize("static", [True, False])
+@pytest.mark.parametrize("enable_rms_norm", [True])  # , False])
+@pytest.mark.parametrize("enable_quant_fp8", [True])  # , False])
 # cuda_force_torch used to test torch code path on platforms that
 # cutlass_fp8_supported() == True.
 @pytest.mark.parametrize(
@@ -105,17 +122,29 @@ def ops_in_model_after(self):
     not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
 )
 def test_fusion_rmsnorm_quant(
-    dtype, hidden_size, num_tokens, eps, static, cuda_force_torch
+    dtype,
+    hidden_size,
+    num_tokens,
+    eps,
+    static,
+    enable_rms_norm,
+    enable_quant_fp8,
+    cuda_force_torch,
 ):
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
     torch.manual_seed(1)
     maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths

+    custom_ops = []
+    if enable_rms_norm:
+        custom_ops.append("+rms_norm")
+    if enable_quant_fp8:
+        custom_ops.append("+quant_fp8")
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
-            custom_ops=["+rms_norm", "+quant_fp8"],
+            custom_ops=custom_ops,
             pass_config=PassConfig(enable_fusion=True, enable_noop=True),
         )
     )
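The test changes above can be read as a single rule: when a custom op is enabled via "+rms_norm" or "+quant_fp8", the pre-fusion graph must contain the unfused custom kernel; when it is disabled, the native-PyTorch decomposition shows up instead. A minimal summary sketch (not part of the diff) that restates this mapping, reusing the same RMS_OP/QUANT_OPS symbols the test imports; "key" stands in for the test's quant key:

import torch

from vllm.compilation.fusion import QUANT_OPS, RMS_OP

def expected_op_before_fusion(flag: str, enabled: bool, key=None):
    # Enabled custom op -> the unfused custom kernel must appear pre-fusion.
    # Disabled custom op -> its aten decomposition (rsqrt / reciprocal) appears.
    if flag == "rms_norm":
        return RMS_OP if enabled else torch.ops.aten.rsqrt.default
    if flag == "quant_fp8":
        return QUANT_OPS[key] if enabled else torch.ops.aten.reciprocal.default
    raise ValueError(f"unknown flag {flag!r}")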

vllm/_custom_ops.py

Lines changed: 1 addition & 1 deletion
@@ -1507,7 +1507,7 @@ def scaled_fp8_quant(
                 output, input, scale, scale_ub
             )
         else:
-            scale = torch.empty(1, device=input.device, dtype=torch.float32)
+            scale = torch.empty((1, 1), device=input.device, dtype=torch.float32)
             torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
     else:
         assert scale.numel() == 1, f"{scale.shape}"
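The (1, 1) shape keeps the eagerly allocated per-tensor scale consistent with what the pattern matcher's make_scale helper produces, where the per-tensor "group" spans the whole 2-D input so each dimension collapses to one entry. A small illustrative check (not part of the commit, shapes made up):

import torch

def per_tensor_scale_shape(x: torch.Tensor) -> tuple[int, int]:
    # Per-tensor quantization: the scale group covers the entire tensor,
    # so the scale shape is (rows // rows, cols // cols) == (1, 1).
    rows, cols = x.shape
    return (rows // rows, cols // cols)

x = torch.randn(257, 64)
scale = torch.empty(per_tensor_scale_shape(x), dtype=torch.float32)
assert scale.shape == (1, 1)  # matches the new torch.empty((1, 1), ...) allocation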

vllm/compilation/fusion.py

Lines changed: 20 additions & 79 deletions
@@ -24,6 +24,7 @@
 from vllm.platforms import current_platform

 from .inductor_pass import enable_fake_mode
+from .matcher_utils import MatcherQuant, MatcherRMSNorm
 from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass

 logger = init_logger(__name__)
@@ -99,6 +100,9 @@ def __init__(self, epsilon: float, key: FusedRMSQuantKey):
         assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
         self.FUSED_OP = FUSED_OPS[key]

+        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
+        self.quant_matcher = MatcherQuant(key.quant)
+

 class RMSNormStaticQuantPattern(RMSNormQuantPattern):
     def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
@@ -113,25 +117,8 @@ def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
     def register(self, pm_pass: PatternMatcherPass):
         # Cannot use methods, as the self argument affects tracing
         def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
-            result_rms = torch.empty_like(input)
-            # TODO: why does empty_like produce a permute but
-            # empty via shape doesn't?
-            result = torch.empty(
-                input.shape, device=input.device, dtype=self.quant_dtype
-            )
-            at1 = auto_functionalized(
-                RMS_OP,
-                result=result_rms,
-                input=input,
-                weight=weight,
-                epsilon=self.epsilon,
-            )
-            at2 = auto_functionalized(
-                self.QUANT_OP, result=result, input=at1[1], scale=scale
-            )
-
-            # result
-            return at2[1]
+            result_rms = self.rmsnorm_matcher(input, weight)
+            return self.quant_matcher(result_rms, scale)

         def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
             result = torch.empty_like(input, dtype=self.quant_dtype)
@@ -173,22 +160,10 @@ def pattern(
             weight: torch.Tensor,
             scale: torch.Tensor,
         ):
-            result = torch.empty(
-                input.shape, device=input.device, dtype=self.quant_dtype
-            )
-            at = auto_functionalized(
-                RMS_ADD_OP,
-                input=input,
-                residual=residual,
-                weight=weight,
-                epsilon=self.epsilon,
-            )
-            at1 = auto_functionalized(
-                self.QUANT_OP, result=result, input=at[1], scale=scale
-            )
+            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
+            result = self.quant_matcher(result_rms, scale)

-            # result, residual
-            return at1[1], at[2]
+            return result, residual

         def replacement(
             input: torch.Tensor,
@@ -242,27 +217,14 @@ def __init__(
         super().__init__(epsilon, key)

     def register(self, pm_pass: PatternMatcherPass):
-        def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
-            result_rms = torch.empty_like(input)
-            result = torch.empty(
-                input.shape, device=input.device, dtype=self.quant_dtype
-            )
-            at1 = auto_functionalized(
-                RMS_OP,
-                result=result_rms,
-                input=input,
-                weight=weight,
-                epsilon=self.epsilon,
-            )
-            at2 = auto_functionalized(
-                self.QUANT_OP, result=result, input=at1[1], scale=scale, scale_ub=None
-            )
-
+        def pattern(input: torch.Tensor, weight: torch.Tensor):
+            result_rms = self.rmsnorm_matcher(input, weight)
             # result, scale
-            return at2[1], at2[2]
+            return self.quant_matcher(result_rms)

-        def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
+        def replacement(input: torch.Tensor, weight: torch.Tensor):
             result = torch.empty_like(input, dtype=self.quant_dtype)
+            scale = self.quant_matcher.make_scale(input)
             at = auto_functionalized(
                 self.FUSED_OP,
                 result=result,
@@ -280,7 +242,6 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
         inputs = [
             empty_bf16(5, 4),  # input
             empty_bf16(1, 5),  # weight
-            empty_fp32(1, 1),  # scale
         ]

         pm.register_replacement(
@@ -308,36 +269,17 @@ def __init__(
         super().__init__(epsilon, key)

     def register(self, pm_pass: PatternMatcherPass):
-        def pattern(
-            input: torch.Tensor,
-            residual: torch.Tensor,
-            weight: torch.Tensor,
-            scale: torch.Tensor,
-        ):
-            result = torch.empty(
-                input.shape, device=input.device, dtype=self.quant_dtype
-            )
-            at = auto_functionalized(
-                RMS_ADD_OP,
-                input=input,
-                residual=residual,
-                weight=weight,
-                epsilon=self.epsilon,
-            )
-            at1 = auto_functionalized(
-                self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None
-            )
+        def pattern(input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor):
+            result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
+            result, scale = self.quant_matcher(result_rms)

-            # result, residual, scale
-            return at1[1], at[2], at1[2]
+            return result, residual, scale

         def replacement(
-            input: torch.Tensor,
-            residual: torch.Tensor,
-            weight: torch.Tensor,
-            scale: torch.Tensor,
+            input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor
         ):
             result = torch.empty_like(input, dtype=self.quant_dtype)
+            scale = self.quant_matcher.make_scale(input)
             at = auto_functionalized(
                 self.FUSED_OP,
                 result=result,
@@ -356,7 +298,6 @@ def replacement(
             empty_bf16(5, 4),  # input
             empty_bf16(5, 4),  # residual
             empty_bf16(1, 5),  # weight
-            empty_fp32(1, 1),  # scale
         ]

         pm.register_replacement(
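For context, each pattern/replacement pair above is handed to the inductor pattern matcher through the pm.register_replacement call that closes every register() method (only the opening line is visible in these hunks). A sketch of the call shape, assuming the pm alias for torch._inductor.pattern_matcher that fusion.py already uses; with the scale argument dropped from the dynamic-quant patterns, the example-input list shrinks to input/residual/weight placeholders only:

pm.register_replacement(
    pattern,      # traced search pattern, now built from the matcher helpers
    replacement,  # graph that calls the fused op and allocates scale via make_scale
    inputs,       # e.g. [empty_bf16(5, 4), empty_bf16(5, 4), empty_bf16(1, 5)]
    pm.fwd_only,  # trace the pattern in forward-only mode
    pm_pass,      # the PatternMatcherPass this pattern is registered into
)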

vllm/compilation/matcher_utils.py

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional, Union
+
+import torch
+from torch._higher_order_ops import auto_functionalized
+from torch._ops import OpOverload
+
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    QuantKey,
+    _normalize_quant_group_shape,
+    kFp8DynamicTensorSym,
+    kFp8DynamicTokenSym,
+    kFp8StaticTensorSym,
+)
+
+RMS_OP = torch.ops._C.rms_norm.default
+RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
+
+QUANT_OPS: dict[QuantKey, OpOverload] = {
+    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501
+    kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa: E501
+    kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
+}
+
+# TODO
+# if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
+#     QUANT_OPS[
+#         kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default  # noqa: E501
+
+
+class MatcherRMSNorm:
+    def __init__(self, epsilon: float):
+        self.epsilon = epsilon
+
+    def forward(
+        self,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        if residual is None:
+            result = torch.empty_like(input)
+            _, result = auto_functionalized(
+                RMS_OP,
+                result=result,
+                input=input,
+                weight=weight,
+                epsilon=self.epsilon,
+            )
+
+            return result
+        else:
+            _, result, residual = auto_functionalized(
+                RMS_ADD_OP,
+                input=input,
+                residual=residual,
+                weight=weight,
+                epsilon=self.epsilon,
+            )
+
+            return result, residual
+
+    def __call__(
+        self,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        return self.forward(input, weight, residual)
+
+
+class MatcherQuant:
+    def __init__(self, quant_key: QuantKey):
+        self.quant_key = quant_key
+        assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
+        self.QUANT_OP = QUANT_OPS[quant_key]
+
+    def forward(
+        self, input: torch.Tensor, scale: Optional[torch.Tensor] = None
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        # TODO: why does empty_like produce a permute but
+        # empty via shape doesn't?
+        result = torch.empty(
+            input.shape, device=input.device, dtype=self.quant_key.dtype
+        )
+
+        if self.quant_key.scale.static:
+            assert scale is not None
+            _, result = auto_functionalized(
+                self.QUANT_OP, result=result, input=input, scale=scale
+            )
+            return result
+        else:
+            assert scale is None
+            scale = self.make_scale(input)
+            _, result, scale = auto_functionalized(
+                self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None
+            )
+            return result, scale
+
+    def make_scale(self, input: torch.Tensor):
+        normalized_group_shape = _normalize_quant_group_shape(
+            input, self.quant_key.scale.group_shape
+        )
+        scale_shape = (
+            input.shape[0] // normalized_group_shape[0],
+            input.shape[1] // normalized_group_shape[1],
+        )
+
+        return torch.empty(scale_shape, device=input.device, dtype=torch.float32)
+
+    def __call__(
+        self, input: torch.Tensor, scale: Optional[torch.Tensor] = None
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        return self.forward(input, scale)
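A minimal usage sketch of the two helpers, mirroring how the fusion patterns above now compose them; the epsilon value and the kFp8DynamicTokenSym key are example choices, and the pattern function is only meaningful when traced by the inductor pattern matcher:

from vllm.compilation.matcher_utils import MatcherQuant, MatcherRMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8DynamicTokenSym,
)

rmsnorm_matcher = MatcherRMSNorm(epsilon=1e-6)
quant_matcher = MatcherQuant(kFp8DynamicTokenSym)

def pattern(input, residual, weight):
    # fused_add_rms_norm followed by dynamic per-token fp8 quant, expressed
    # through the shared helpers instead of hand-written auto_functionalized calls
    rms_out, residual = rmsnorm_matcher(input, weight, residual)
    quant_out, scale = quant_matcher(rms_out)
    return quant_out, residual, scale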
