Commit b172747

Functionalize attn+quant patterns
Signed-off-by: Luka Govedič <[email protected]>
1 parent d96913a · commit b172747

File tree

5 files changed: +281 -214 lines

tests/compile/backend.py

Lines changed: 4 additions & 2 deletions
@@ -56,8 +56,10 @@ def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]):
         self.inductor_config["post_grad_custom_post_pass"] = self.post_pass

         if compile_config.debug_dump_path:
-            self.debug_dump_path = (Path(compile_config.debug_dump_path) /
-                                    f"rank_{vllm_config.parallel_config.rank}")
+            self.debug_dump_path = (
+                Path(compile_config.debug_dump_path)
+                / f"rank_{vllm_config.parallel_config.rank}"
+            )
             self.ctx = depyf.prepare_debug(str(self.debug_dump_path))
             self.ctx.__enter__()
         else:
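
For context, the reformatted block above builds one debug dump directory per rank and hands it to depyf. A minimal standalone sketch of the same idea, with a hypothetical dump root and rank standing in for compile_config.debug_dump_path and vllm_config.parallel_config.rank (requires the depyf package):

import depyf
from pathlib import Path

# Hypothetical stand-ins for the config values used in backend.py above.
debug_dump_root = "/tmp/vllm_compile_debug"
rank = 0

# One subdirectory per rank, so parallel workers do not overwrite each
# other's dumps.
debug_dump_path = Path(debug_dump_root) / f"rank_{rank}"

# backend.py stores the context object and calls __enter__ manually because
# the debug scope must outlive __init__; a plain `with` block is enough for a
# one-off script.
with depyf.prepare_debug(str(debug_dump_path)):
    pass  # run the torch.compile workload whose artifacts should be dumped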

tests/compile/test_fusion.py

Lines changed: 50 additions & 27 deletions
@@ -8,13 +8,24 @@
 from vllm.compilation.fusion import RMSNormQuantFusionPass
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
-from vllm.config import (CompilationConfig, CompilationLevel, ModelConfig,
-                         PassConfig, VllmConfig)
+from vllm.config import (
+    CompilationConfig,
+    CompilationLevel,
+    ModelConfig,
+    PassConfig,
+    VllmConfig,
+)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    GroupShape, QuantKey, ScaleDesc)
+    GroupShape,
+    QuantKey,
+    ScaleDesc,
+)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity)
+    Fp8LinearOp,
+    cutlass_fp8_supported,
+    maybe_create_device_identity,
+)
 from vllm.platforms import current_platform

 from ..utils import override_cutlass_fp8_supported
@@ -24,9 +35,15 @@


 class TestModel(torch.nn.Module):
-
-    def __init__(self, hidden_size: int, eps: float, static: bool,
-                 cuda_force_torch: bool, *args, **kwargs):
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float,
+        static: bool,
+        cuda_force_torch: bool,
+        *args,
+        **kwargs,
+    ):
         super().__init__(*args, **kwargs)
         self.cuda_force_torch = cuda_force_torch
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
@@ -57,30 +74,27 @@ def forward(self, x):
         x = resid = torch.relu(x)
         y = self.norm[0](x)

-        x2 = self.fp8_linear.apply(y,
-                                   self.w[0],
-                                   self.wscale[0],
-                                   input_scale=self.scale[0])
+        x2 = self.fp8_linear.apply(
+            y, self.w[0], self.wscale[0], input_scale=self.scale[0]
+        )
         # make sure resid is used for replacement to work
         y2, resid = self.norm[1](x2, resid)

-        x3 = self.fp8_linear.apply(y2,
-                                   self.w[1],
-                                   self.wscale[1],
-                                   input_scale=self.scale[1])
+        x3 = self.fp8_linear.apply(
+            y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
+        )

         y3, resid = self.norm[2](x3, resid)  # use resid here

-        x4 = self.fp8_linear.apply(y3,
-                                   self.w[2],
-                                   self.wscale[2],
-                                   input_scale=self.scale[2])
+        x4 = self.fp8_linear.apply(
+            y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
+        )

         y4, resid = self.norm[3](x4, resid)  # use resid here
         return y4


-@pytest.mark.parametrize("dtype", [torch.float16]) #, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.float16])  # , torch.bfloat16])
 @pytest.mark.parametrize("hidden_size", [64])
 @pytest.mark.parametrize("num_tokens", [257])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
@@ -89,13 +103,22 @@ def forward(self, x):
 @pytest.mark.parametrize("enable_quant_fp8", [True, False])
 # cuda_force_torch used to test torch code path on platforms that
 # cutlass_fp8_supported() == True.
-@pytest.mark.parametrize("cuda_force_torch",
-                         [True, False] if cutlass_fp8_supported() else [True])
-@pytest.mark.skipif(not current_platform.is_cuda_alike(),
-                    reason="Only test on CUDA and ROCm")
-def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
-                              enable_rms_norm, enable_quant_fp8,
-                              cuda_force_torch):
+@pytest.mark.parametrize(
+    "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
+)
+@pytest.mark.skipif(
+    not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
+)
+def test_fusion_rmsnorm_quant(
+    dtype,
+    hidden_size,
+    num_tokens,
+    eps,
+    static,
+    enable_rms_norm,
+    enable_quant_fp8,
+    cuda_force_torch,
+):
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
     torch.manual_seed(1)
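
The decorator stack above is only reformatted, not changed, so the test still expands into the same parametrization grid and can be selected by node id. A minimal driver sketch, assuming it is run from the vLLM repo root on a CUDA or ROCm machine with vLLM and pytest installed:

import pytest

# Select only the reformatted test; pytest expands the decorators above
# (dtype, hidden_size, num_tokens, eps, static, enable_rms_norm,
# enable_quant_fp8, cuda_force_torch) into individual parametrized cases.
raise SystemExit(
    pytest.main(["-q", "tests/compile/test_fusion.py::test_fusion_rmsnorm_quant"])
)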
