Commit 14fdc8b (parent: e151e6d)

quant with fix for pure torch, broke others

Signed-off-by: Luka Govedič <[email protected]>

3 files changed: +13 −11 lines


tests/compile/test_fusion.py (2 additions, 4 deletions)

@@ -147,10 +147,8 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
     model2 = torch.compile(model, backend=backend)
     result2 = model2(x)
 
-    # Higher tol for dynamic, even higher for bfloat16
-    if static:
-        ATOL, RTOL = (1e-3, 1e-3)
-    elif dtype == torch.float16:
+    # Higher tol for dynamic bfloat16
+    if dtype == torch.float16 or static:
         ATOL, RTOL = (2e-3, 2e-3)
     else:
         ATOL, RTOL = (1e-2, 1e-2)
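For context, these tolerances presumably feed a comparison between the eager and compiled outputs later in the test. A minimal sketch of such a check, assuming result holds the eager-mode output computed earlier in the test (torch.testing.assert_close is a standard PyTorch API):

    # Sketch only: `result` is assumed to be the eager-mode output computed
    # earlier in the test; ATOL/RTOL come from the branch above.
    torch.testing.assert_close(result2, result, atol=ATOL, rtol=RTOL)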

vllm/compilation/fusion.py (4 additions, 4 deletions)

@@ -26,7 +26,7 @@
 
 
 def empty_bf16(*args, **kwargs):
-    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
+    return torch.empty(*args, **kwargs, dtype=torch.float16, device="cuda")
 
 
 def empty_fp32(*args, **kwargs):
@@ -133,7 +133,7 @@ def replacement(input: torch.Tensor, weight: torch.Tensor,
             return at[1]
 
         inputs = [
-            empty_bf16(5, 4),  # input
+            empty_fp32(5, 4),  # input  # TODO: rms_input
             empty_bf16(4, ),  # weight
             empty_fp32(1, 1)  # scale
         ]
@@ -185,8 +185,8 @@ def replacement(input: torch.Tensor, residual: torch.Tensor,
             return at[1], at[2]
 
         inputs = [
-            # TODO: maybe 32bit for torch impl?
-            # TODO dtype doesn't seem to matter?
+            # TODO: maybe 32bit for torch impl? yes to resolve bug
+            # TODO dtype doesn't seem to matter? it does matter for what cvts get traced
             empty_bf16(5, 4),  # input
             empty_bf16(5, 4),  # residual
             empty_bf16(4, ),  # weight
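The updated TODO in the last hunk notes that the dtype of the trace inputs determines which convert ops get traced into the pattern. A standalone sketch of that effect, not taken from this patch (make_fx is a public torch.fx tracing entry point; rms_like is a made-up stand-in for the matched pattern):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def rms_like(x):
        # Same cast structure as the matched pattern: upcast to fp32,
        # compute, downcast back to a 16-bit dtype.
        return x.to(torch.float32).pow(2).mean(dim=-1, keepdim=True).to(torch.bfloat16)

    # Tracing with bf16 inputs records an upcast to float32; tracing with
    # fp32 inputs does not (Tensor.to is a no-op when the dtype already
    # matches), so a pattern traced one way won't match graphs built the
    # other way.
    g_bf16 = make_fx(rms_like)(torch.empty(5, 4, dtype=torch.bfloat16))
    g_fp32 = make_fx(rms_like)(torch.empty(5, 4, dtype=torch.float32))
    print(g_bf16.graph)
    print(g_fp32.graph)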

vllm/compilation/matcher_utils.py (7 additions, 3 deletions)

@@ -43,6 +43,10 @@ def __init__(self, epsilon: float, enabled: Optional[bool] = None):
 
         self.forward = self.forward_custom if enabled else self.forward_native
         self.model_dtype = get_current_vllm_config().model_config.dtype
+        print(self.model_dtype)
+
+    def inputs(self):
+        return
 
     def forward_custom(
         self,
@@ -76,10 +80,10 @@ def forward_native(
         weight: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        x = input  # .to(torch.float32)
+        x = input.to(torch.float32)
         if residual is not None:
-            x = x + residual.to(torch.float32)
-            residual = x  # conversion to 16-bit is eliminated in full graph
+            x = x + residual
+            residual = x.to(self.model_dtype)
 
         variance = x.pow(2).mean(dim=-1, keepdim=True)
 
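The hunk is cut off at the variance line. For orientation, a sketch of how the rest of forward_native plausibly continues, following the standard RMSNorm reference formulation (everything after the variance line is an assumption, not shown in this diff; self.epsilon is assumed to be the epsilon stored by __init__):

    def forward_native(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        x = input.to(torch.float32)
        if residual is not None:
            x = x + residual
            residual = x.to(self.model_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        # Assumed tail: normalize, apply the weight, downcast to model dtype.
        x = x * torch.rsqrt(variance + self.epsilon)
        x = (x * weight.to(torch.float32)).to(self.model_dtype)
        return x if residual is None else (x, residual)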