Commit 876ef22

Fix tests, PR feedback
Signed-off-by: Luka Govedič <[email protected]>
1 parent: 6253d5b

4 files changed: +25 −16 lines

tests/compile/test_fusion.py

Lines changed: 11 additions & 6 deletions

```diff
@@ -169,24 +169,29 @@ def test_fusion_rmsnorm_quant(
     cleanup_pass = PostCleanupPass(vllm_config)
 
     backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
+    backend2 = TestBackend(noop_pass, cleanup_pass)
     model = TestModel(hidden_size, eps, static, cuda_force_torch)
 
     # First dimension dynamic
     x = torch.rand(num_tokens, hidden_size)
     torch._dynamo.mark_dynamic(x, 0)
 
-    result = model(x)
+    model_fused = torch.compile(model, backend=backend)
+    result_fused = model_fused(x)
 
-    model2 = torch.compile(model, backend=backend)
-    result2 = model2(x)
+    model_unfused = torch.compile(model, backend=backend2)
+    result_unfused = model_unfused(x)
 
-    # Higher tol for dynamic bfloat16
-    if dtype == torch.float16 or static:
+    if enable_rms_norm_custom_op and static:
+        ATOL, RTOL = (1e-5, 1e-5)  # up to 1e-8 close
+    elif dtype == torch.float16:
         ATOL, RTOL = (2e-3, 2e-3)
+    elif static:
+        ATOL, RTOL = (5e-3, 5e-3)
     else:
         ATOL, RTOL = (1e-2, 1e-2)
 
-    torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
+    torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL)
 
     assert fusion_pass.matched_count == 3
     backend.check_before_ops(model.ops_in_model_before())
```
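The test previously compared the fused, compiled model against plain eager execution; it now compiles the model twice, once with and once without the fusion pass, so both results go through the same compilation pipeline and only the fusion differs. Below is a minimal, self-contained sketch of that compare-two-backends pattern; the toy model and passthrough backend are stand-ins for vLLM's TestModel and TestBackend.

```python
import torch

# Stand-in for vLLM's TestBackend: the real one applies the noop/fusion/
# cleanup passes to the FX graph; this one returns the graph unchanged.
def passthrough_backend(gm, example_inputs):
    return gm.forward

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())

x = torch.rand(5, 16)
torch._dynamo.mark_dynamic(x, 0)  # first dimension dynamic, as in the test

# Compile the same eager module under two backends and compare outputs.
model_a = torch.compile(model, backend=passthrough_backend)
model_b = torch.compile(model, backend="eager")

torch.testing.assert_close(model_a(x), model_b(x), atol=1e-5, rtol=1e-5)
```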

tests/compile/test_sequence_parallelism.py

Lines changed: 3 additions & 4 deletions

```diff
@@ -18,6 +18,7 @@
     ModelConfig,
     PassConfig,
     VllmConfig,
+    get_current_vllm_config,
     set_current_vllm_config,
 )
 from vllm.distributed import tensor_model_parallel_all_reduce
@@ -94,13 +95,11 @@ def ops_in_model(self):
 
 
 class TestQuantModel(torch.nn.Module):
-    def __init__(
-        self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
-    ):
+    def __init__(self, hidden_size=16, intermediate_size=32):
         super().__init__()
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
-        self.vllm_config = vllm_config
+        self.vllm_config = get_current_vllm_config()
         self.gate_proj = torch.nn.Parameter(
             torch.empty((intermediate_size, hidden_size)), requires_grad=False
         )
```
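Instead of threading `vllm_config` through the constructor, `TestQuantModel` now reads it from the ambient context. A sketch of that pattern, assuming the model is constructed inside a `set_current_vllm_config` block (which the surrounding test code provides) and that both helpers live in `vllm.config`:

```python
from vllm.config import VllmConfig, get_current_vllm_config, set_current_vllm_config

vllm_config = VllmConfig()

# set_current_vllm_config is a context manager; code that runs inside it,
# such as a module's __init__, can recover the active config without an
# explicit argument.
with set_current_vllm_config(vllm_config):
    assert get_current_vllm_config() is vllm_config
```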

vllm/compilation/fusion.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -33,7 +33,7 @@
 
 
 def empty_bf16(*args, **kwargs):
-    return torch.empty(*args, **kwargs, dtype=torch.float16, device="cuda")
+    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
 
 
 def empty_fp32(*args, **kwargs):
@@ -144,7 +144,7 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
         inputs = [
             # input, weight
             *self.rmsnorm_matcher.inputs(),
-            empty_fp32(1, 1),  # scale
+            self.quant_matcher.inputs()[1],  # scale
         ]
         pattern(*inputs)
 
@@ -200,7 +200,7 @@ def replacement(
         inputs = [
             # input, weight, residual
             *self.rmsnorm_matcher.inputs(),
-            empty_fp32(1, 1),  # scale
+            self.quant_matcher.inputs()[1],  # scale
         ]
 
         pm.register_replacement(
```
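The first hunk fixes a copy-paste bug: `empty_bf16` allocated `float16` tensors despite its name. The other two hunks take the scale placeholder from the quant matcher rather than hard-coding a `(1, 1)` fp32 tensor, keeping it consistent with the quantization scheme. A quick sanity check of the corrected helper (a sketch; requires a CUDA device):

```python
import torch

def empty_bf16(*args, **kwargs):
    # After the fix, the dtype matches the function name.
    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")

if torch.cuda.is_available():
    assert empty_bf16(5, 16).dtype == torch.bfloat16
```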

vllm/compilation/matcher_utils.py

Lines changed: 8 additions & 3 deletions

```diff
@@ -112,9 +112,7 @@ def __init__(self, epsilon: float, enabled: bool | None = None):
 
     def inputs(self):
         input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
-        weight = self.empty(
-            16,
-        )
+        weight = self.empty(16)
         residual = self.empty(5, 16)
         return [input, weight, residual]
 
@@ -203,3 +201,10 @@ def make_scale(self, input: torch.Tensor):
         )
 
         return torch.empty(scale_shape, device=input.device, dtype=torch.float32)
+
+    def inputs(self) -> list[torch.Tensor]:
+        input = self.empty(5, 16)
+        if self.quant_key.scale.static:
+            return [input, self.empty_f32(1, 1)]
+
+        return [input]
```
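The new `inputs()` method is what lets `fusion.py` grab the scale via `self.quant_matcher.inputs()[1]`: with static quantization the example inputs include a `(1, 1)` fp32 scale, while with dynamic quantization the scale is computed inside the graph, so only the activation is returned. A standalone sketch of that branching; the function name and plain tensors are stand-ins for the vLLM matcher class:

```python
import torch

def example_inputs(static_scale: bool) -> list[torch.Tensor]:
    # Activation placeholder, matching the (5, 16) shape the matchers use.
    input = torch.empty(5, 16, dtype=torch.bfloat16)
    if static_scale:
        # Static quantization: the scale is a graph input, so supply one.
        return [input, torch.empty(1, 1, dtype=torch.float32)]
    # Dynamic quantization: the scale is derived from the input at runtime.
    return [input]

assert len(example_inputs(static_scale=True)) == 2   # [input, scale]
assert len(example_inputs(static_scale=False)) == 1  # [input]
```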
