Commit 7fc9ea4

[FP8 SDPA] Enable fp8 sdpa pattern match
1 parent: d23ed9e

File tree

2 files changed: +297 −41 lines changed


test/prototype/inductor/test_qsdpa_fusion.py

Lines changed: 252 additions & 18 deletions
@@ -18,26 +18,184 @@
 from torchao.utils import torch_version_at_least


-class SelfAttnLikeModule(torch.nn.Module):
+def qdq(input, scale):
+    dtype = input.dtype
+    q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
+        input,
+        torch.tensor([scale]),
+        torch.float8_e4m3fn,
+    )
+    dq_input = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+        q_input,
+        torch.tensor([scale]),
+        dtype,
+    )
+    return dq_input
+
+
+def fp8_convert_(model):
+    def generate_model_info(model):
+        from collections import namedtuple
+
+        mod_inst_info = namedtuple("ModInstInfo", ["name", "parent"])
+        parent_child_mod_dict = {}
+
+        def create_mod_info_recursion(parent):
+            for name, mod in parent.named_children():
+                parent_child_mod_dict[mod] = mod_inst_info(name=name, parent=parent)
+                create_mod_info_recursion(mod)
+
+        create_mod_info_recursion(model)
+        return parent_child_mod_dict
+
+    parent_child_mod_dict = generate_model_info(model)
+    for name, mod in model.named_modules():
+        mod_type_str = mod.__class__.__name__
+        if mod_type_str not in [
+            "Linear",
+            "SDPA",
+        ]:
+            continue
+        if mod_type_str == "Linear":
+            param = mod.weight
+            xmax = torch.max(param)
+            weight_scale = xmax / torch.finfo(torch.float8_e4m3fn).max
+            mod.weight_scale = weight_scale
+            q_param = torch.clamp(
+                (param / weight_scale),
+                torch.finfo(torch.float8_e4m3fn).min,
+                torch.finfo(torch.float8_e4m3fn).max,
+            ).to(torch.float8_e4m3fn)
+            mod.weight.data = q_param
+            patched_mod = FP8QDQLinear(mod.in_features, mod.out_features, False)
+            patched_mod.bias = mod.bias
+            patched_mod.weight_scale = weight_scale.item()
+            patched_mod.weight.data = q_param
+        else:
+            patched_mod = FP8QDQSDPA()
+            patched_mod.__dict__.update(mod.__dict__)
+            patched_mod.transpose_for_scores = mod.transpose_for_scores
+
+            patched_mod.q_out_scale = (
+                patched_mod.q_out_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+            patched_mod.k_out_scale = (
+                patched_mod.k_out_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+            patched_mod.attn_weights_scale = (
+                patched_mod.attn_weights_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+            patched_mod.v_out_scale = (
+                patched_mod.v_out_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+            patched_mod.qk_out_scale = (
+                patched_mod.qk_out_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+            patched_mod.attn_out_scale = (
+                patched_mod.attn_out_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+
+        parent = parent_child_mod_dict[mod].parent
+        name = parent_child_mod_dict[mod].name
+        setattr(parent, name, patched_mod)
+    model.eval()
+    return model
+
+
+class FP8QDQLinear(torch.nn.Module):
+    def __init__(self, in_features, out_features, has_bias):
+        super().__init__()
+        self.qtype = torch.float8_e4m3fn
+        self.weight = torch.randn((out_features, in_features)).to(self.qtype)
+        self.weight_scale = 2.0
+        self.scale = 2.0
+        self.bias = None
+        if has_bias:
+            self.bias = torch.randn((out_features,))
+
+    def forward(self, input):
+        weight = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+            tensor=self.weight.data,
+            scale=torch.tensor([self.weight_scale]),
+            output_dtype=torch.float,
+        )
+
+        q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
+            tensor=input,
+            scale=torch.tensor([self.scale]),
+            float8_dtype=self.qtype,
+        )
+        dq_input = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+            tensor=q_input,
+            scale=torch.tensor([self.scale]),
+            output_dtype=torch.float,
+        )
+
+        out = torch.nn.functional.linear(dq_input, weight, self.bias)
+        return out
+
+
+class FP8QDQSDPA(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.q_out_scale = 1.5
+        self.k_out_scale = 1.5
+        self.attn_weights_scale = 1.5
+        self.v_out_scale = 1.5
+        self.attn_out_scale = 1.5
+        self.qk_out_scale = 1.5
+
+    def forward(self, q, k, v, mask):
+        key = self.transpose_for_scores(q)
+        value = self.transpose_for_scores(k)
+        query = self.transpose_for_scores(v)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        query_qdq = qdq(query, self.q_out_scale)
+        key_qdq = qdq(key.transpose(-1, -2), self.k_out_scale)
+        attn_weights = torch.matmul(query_qdq, key_qdq) / (self.input_dim**0.5)
+
+        # Normalize the attention scores to probabilities.
+        attn_weights = torch.nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(query.dtype)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        dropout = 0.0 if not self.training else self.dropout_prob
+        attn_weights = torch.nn.functional.dropout(
+            attn_weights, p=dropout, training=self.training
+        )
+
+        # Mask heads if we want to
+        if mask is not None:
+            attn_weights = attn_weights + mask
+
+        value_qdq = qdq(value, self.v_out_scale)
+        attn_weights_qdq = qdq(attn_weights, self.attn_weights_scale)
+        attn_output = torch.matmul(attn_weights_qdq, value_qdq)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+        new_context_layer_shape = attn_output.size()[:-2] + (self.all_head_size,)
+        attn_output = attn_output.reshape(new_context_layer_shape)
+
+        return attn_output
+
+
+class SDPA(torch.nn.Module):
     def __init__(
         self,
         input_dim,
         has_mask,
-        num_attention_heads=None,
-        attention_head_size=None,
+        num_attention_heads,
+        attention_head_size,
     ) -> None:
         super().__init__()
         self.input_dim = input_dim
-        self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
-        self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
-        self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
         self.softmax = torch.nn.Softmax(dim=-1)
-        assert num_attention_heads is not None
-        assert attention_head_size is not None
         self.num_attention_heads = num_attention_heads
         self.attention_head_size = attention_head_size
         self.all_head_size = self.num_attention_heads * self.attention_head_size
-        self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size)
         self.dropout = torch.nn.Dropout(0)
         self.has_mask = has_mask

@@ -49,10 +207,7 @@ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         x = x.view(new_x_shape)
         return x.permute([0, 2, 1, 3])

-    def forward(self, x, mask):
-        q = self.q_proj(x)
-        k = self.k_proj(x)
-        v = self.v_proj(x)
+    def forward(self, q, k, v, mask):
         q = self.transpose_for_scores(q)
         k = self.transpose_for_scores(k)
         v = self.transpose_for_scores(v)
@@ -63,9 +218,38 @@ def forward(self, x, mask):
         attention = self.dropout(attention)
         context_layer = torch.matmul(attention, v)
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        context_layer = context_layer.view(
-            context_layer.size()[:-2] + (self.all_head_size,)
+        return context_layer.reshape(context_layer.size()[:-2] + (self.all_head_size,))
+
+
+class MHAModule(torch.nn.Module):
+    def __init__(
+        self,
+        input_dim,
+        has_mask,
+        num_attention_heads,
+        attention_head_size,
+    ) -> None:
+        super().__init__()
+        self.input_dim = input_dim
+        self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
+        self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
+        self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_size = attention_head_size
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size)
+        self.attn_mod = SDPA(
+            input_dim,
+            has_mask,
+            num_attention_heads,
+            attention_head_size,
         )
+
+    def forward(self, x, mask):
+        q = self.q_proj(x)
+        k = self.k_proj(x)
+        v = self.v_proj(x)
+        context_layer = self.attn_mod(q, k, v, mask)
         return self.dense(context_layer)


@@ -158,7 +342,7 @@ def _check_common(
         reason="cpp kernels not built",
     )
     @config.patch({"freezing": True})
-    def _test_qsdpa_rewriter(self):
+    def _test_int8_sdpa_rewriter(self):
         import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
         from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
         from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
@@ -171,7 +355,7 @@ def _test_qsdpa_rewriter(self):
             [torch.float32, torch.bfloat16], [True, False], [56, 1]
         ):
             seqlen, numhead, headsize = 197, 16, 64
-            mod = SelfAttnLikeModule(
+            mod = MHAModule(
                 input_dim=headsize * numhead,
                 has_mask=has_mask,
                 num_attention_heads=numhead,
@@ -204,6 +388,51 @@ def _test_qsdpa_rewriter(self):
             prepare_model(*inputs)
             convert_model = convert_pt2e(prepare_model)
             torchao.quantization.pt2e.move_exported_model_to_eval(convert_model)
+
+                self._check_common(
+                    convert_model, args1=inputs, check_train=False, atol=1.0
+                )
+
+    @skipIfRocm
+    @unittest.skipIf(
+        not torch_version_at_least("2.7.0"),
+        reason="qsdpa requires torch 2.7 or later",
+    )
+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::qscaled_dot_product"),
+        reason="cpp kernels not built",
+    )
+    @config.patch({"freezing": True})
+    def _test_fp8_sdpa_rewriter(self):
+        import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
+
+        # pattern is different for bs=1
+        torch.manual_seed(1234)
+        for dtype, bs in itertools.product([torch.float32, torch.bfloat16], [56, 1]):
+            seqlen, numhead, headsize = 197, 16, 64
+            mod = MHAModule(
+                input_dim=headsize * numhead,
+                has_mask=False,
+                num_attention_heads=numhead,
+                attention_head_size=headsize,
+            ).eval()
+            inputs = (
+                torch.randn(
+                    (bs, seqlen, headsize * numhead), device=self.device, dtype=dtype
+                ),
+                None,
+            )
+            enable_autocast = dtype == torch.bfloat16
+            with (
+                torch.no_grad(),
+                torch.amp.autocast(
+                    self.device, enabled=enable_autocast, dtype=torch.bfloat16
+                ),
+                config.patch(post_grad_custom_pre_pass=custom_pass),
+            ):
+                _qsdpa_init()
+                convert_model = fp8_convert_(mod)
+
                 self._check_common(
                     convert_model, args1=inputs, check_train=False, atol=1.0
                 )
@@ -213,7 +442,12 @@ def _test_qsdpa_rewriter(self):

 class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate):
     device = "cpu"
-    test_qsdpa_rewriter_cpu = TestSDPAPatternRewriterTemplate._test_qsdpa_rewriter
+    test_int8_sdpa_rewriter_cpu = (
+        TestSDPAPatternRewriterTemplate._test_int8_sdpa_rewriter
+    )
+    test_fp8_sdpa_rewriter_cpu = (
+        TestSDPAPatternRewriterTemplate._test_fp8_sdpa_rewriter
+    )


 if __name__ == "__main__":
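
For context on the numerics the new helpers exercise: qdq and fp8_convert_ both reduce to a per-tensor FP8 (e4m3fn) quantize/dequantize round trip, with the scale derived by dividing the tensor's max by torch.finfo(torch.float8_e4m3fn).max (448.0). The sketch below is an illustrative, standalone version of that round trip using plain PyTorch casts; the helper name fp8_qdq_reference, the absmax-based scale, and the tensor shape are assumptions for illustration, and it deliberately does not use the torchao non-decomposed ops that the Inductor pattern matcher actually targets.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def fp8_qdq_reference(x: torch.Tensor):
    # Hypothetical reference helper: per-tensor absmax scale, clamp into the
    # representable fp8 range, cast to float8_e4m3fn, then dequantize back.
    scale = x.abs().max() / FP8_MAX
    q = torch.clamp(x / scale, -FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    dq = q.to(x.dtype) * scale
    return q, dq

if __name__ == "__main__":
    x = torch.randn(4, 16)
    q, dq = fp8_qdq_reference(x)
    # round-trip error stays small relative to the tensor's magnitude
    print((x - dq).abs().max().item())

With pytest available, the two CPU tests registered above can be selected by name, e.g. pytest test/prototype/inductor/test_qsdpa_fusion.py -k "int8_sdpa_rewriter_cpu or fp8_sdpa_rewriter_cpu", provided the torchao C++ kernels for torchao::qscaled_dot_product are built; otherwise the skipIf guards skip them.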
