from torchao.utils import torch_version_at_least


-class SelfAttnLikeModule(torch.nn.Module):
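+# Helper for the FP8 SDPA test below: fake-quantize a tensor to float8_e4m3fn with a
+# per-tensor scale and immediately dequantize it, emitting the explicit Q/DQ ops that
+# the pattern rewriter is expected to match.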
+def qdq(input, scale):
+    dtype = input.dtype
+    q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
+        input,
+        torch.tensor([scale]),
+        torch.float8_e4m3fn,
+    )
+    dq_input = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+        q_input,
+        torch.tensor([scale]),
+        dtype,
+    )
+    return dq_input
+
+
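+# Convert a model in place for the FP8 test: each Linear is replaced by FP8QDQLinear
+# (weight quantized to float8_e4m3fn with a per-tensor scale) and each SDPA by
+# FP8QDQSDPA (its default activation scales divided by the float8_e4m3fn max).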
+def fp8_convert_(model):
+    def generate_model_info(model):
+        from collections import namedtuple
+
+        mod_inst_info = namedtuple("ModInstInfo", ["name", "parent"])
+        parent_child_mod_dict = {}
+
+        def create_mod_info_recursion(parent):
+            for name, mod in parent.named_children():
+                parent_child_mod_dict[mod] = mod_inst_info(name=name, parent=parent)
+                create_mod_info_recursion(mod)
+
+        create_mod_info_recursion(model)
+        return parent_child_mod_dict
+
+    parent_child_mod_dict = generate_model_info(model)
+    for name, mod in model.named_modules():
+        mod_type_str = mod.__class__.__name__
+        if mod_type_str not in [
+            "Linear",
+            "SDPA",
+        ]:
+            continue
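+        # Linear: derive a per-tensor scale from the weight's max value, clamp and cast
+        # the weight to float8_e4m3fn, and swap in an FP8QDQLinear carrying that scale.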
+        if mod_type_str == "Linear":
+            param = mod.weight
+            xmax = torch.max(param)
+            weight_scale = xmax / torch.finfo(torch.float8_e4m3fn).max
+            mod.weight_scale = weight_scale
+            q_param = torch.clamp(
+                (param / weight_scale),
+                torch.finfo(torch.float8_e4m3fn).min,
+                torch.finfo(torch.float8_e4m3fn).max,
+            ).to(torch.float8_e4m3fn)
+            mod.weight.data = q_param
+            patched_mod = FP8QDQLinear(mod.in_features, mod.out_features, False)
+            patched_mod.bias = mod.bias
+            patched_mod.weight_scale = weight_scale.item()
+            patched_mod.weight.data = q_param
+        else:
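+            # SDPA: copy the original module's state into an FP8QDQSDPA and normalize
+            # its activation scales by the float8_e4m3fn max.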
+            patched_mod = FP8QDQSDPA()
+            patched_mod.__dict__.update(mod.__dict__)
+            patched_mod.transpose_for_scores = mod.transpose_for_scores
+
+            patched_mod.q_out_scale = (
+                patched_mod.q_out_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+            patched_mod.k_out_scale = (
+                patched_mod.k_out_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+            patched_mod.attn_weights_scale = (
+                patched_mod.attn_weights_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+            patched_mod.v_out_scale = (
+                patched_mod.v_out_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+            patched_mod.qk_out_scale = (
+                patched_mod.qk_out_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+            patched_mod.attn_out_scale = (
+                patched_mod.attn_out_scale / torch.finfo(torch.float8_e4m3fn).max
+            )
+
+        parent = parent_child_mod_dict[mod].parent
+        name = parent_child_mod_dict[mod].name
+        setattr(parent, name, patched_mod)
+    model.eval()
+    return model
+
+
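+# Linear layer emulating a frozen FP8 checkpoint: the weight is stored in float8_e4m3fn
+# and is dequantized at call time, while the input goes through an explicit
+# quantize/dequantize pair before the floating-point matmul.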
+class FP8QDQLinear(torch.nn.Module):
+    def __init__(self, in_features, out_features, has_bias):
+        super().__init__()
+        self.qtype = torch.float8_e4m3fn
+        self.weight = torch.randn((out_features, in_features)).to(self.qtype)
+        self.weight_scale = 2.0
+        self.scale = 2.0
+        self.bias = None
+        if has_bias:
+            self.bias = torch.randn((out_features,))
+
+    def forward(self, input):
+        weight = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+            tensor=self.weight.data,
+            scale=torch.tensor([self.weight_scale]),
+            output_dtype=torch.float,
+        )
+
+        q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
+            tensor=input,
+            scale=torch.tensor([self.scale]),
+            float8_dtype=self.qtype,
+        )
+        dq_input = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+            tensor=q_input,
+            scale=torch.tensor([self.scale]),
+            output_dtype=torch.float,
+        )
+
+        out = torch.nn.functional.linear(dq_input, weight, self.bias)
+        return out
+
+
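+# Scaled-dot-product attention with explicit FP8 Q/DQ ops around query, key, value and
+# the attention weights; mirrors the plain SDPA module defined below.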
+class FP8QDQSDPA(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.q_out_scale = 1.5
+        self.k_out_scale = 1.5
+        self.attn_weights_scale = 1.5
+        self.v_out_scale = 1.5
+        self.attn_out_scale = 1.5
+        self.qk_out_scale = 1.5
+
+    def forward(self, q, k, v, mask):
+        query = self.transpose_for_scores(q)
+        key = self.transpose_for_scores(k)
+        value = self.transpose_for_scores(v)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        query_qdq = qdq(query, self.q_out_scale)
+        key_qdq = qdq(key.transpose(-1, -2), self.k_out_scale)
+        attn_weights = torch.matmul(query_qdq, key_qdq) / (self.input_dim**0.5)
+
+        # Normalize the attention scores to probabilities.
+        attn_weights = torch.nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(query.dtype)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        dropout = 0.0 if not self.training else self.dropout_prob
+        attn_weights = torch.nn.functional.dropout(
+            attn_weights, p=dropout, training=self.training
+        )
+
+        # Mask heads if we want to
+        if mask is not None:
+            attn_weights = attn_weights + mask
+
+        value_qdq = qdq(value, self.v_out_scale)
+        attn_weights_qdq = qdq(attn_weights, self.attn_weights_scale)
+        attn_output = torch.matmul(attn_weights_qdq, value_qdq)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+        new_context_layer_shape = attn_output.size()[:-2] + (self.all_head_size,)
+        attn_output = attn_output.reshape(new_context_layer_shape)
+
+        return attn_output
+
+
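+# Plain attention block operating on pre-projected q/k/v; the projections and the output
+# dense layer now live in MHAModule, which lets fp8_convert_ target this module by its
+# class name.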
+class SDPA(torch.nn.Module):
    def __init__(
        self,
        input_dim,
        has_mask,
-        num_attention_heads=None,
-        attention_head_size=None,
+        num_attention_heads,
+        attention_head_size,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
-        self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
-        self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
-        self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
        self.softmax = torch.nn.Softmax(dim=-1)
-        assert num_attention_heads is not None
-        assert attention_head_size is not None
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = attention_head_size
        self.all_head_size = self.num_attention_heads * self.attention_head_size
-        self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size)
        self.dropout = torch.nn.Dropout(0)
        self.has_mask = has_mask

@@ -49,10 +207,7 @@ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        x = x.view(new_x_shape)
        return x.permute([0, 2, 1, 3])

-    def forward(self, x, mask):
-        q = self.q_proj(x)
-        k = self.k_proj(x)
-        v = self.v_proj(x)
+    def forward(self, q, k, v, mask):
        q = self.transpose_for_scores(q)
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)
@@ -63,9 +218,38 @@ def forward(self, x, mask):
        attention = self.dropout(attention)
        context_layer = torch.matmul(attention, v)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        context_layer = context_layer.view(
-            context_layer.size()[:-2] + (self.all_head_size,)
+        return context_layer.reshape(context_layer.size()[:-2] + (self.all_head_size,))
+
+
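+# Multi-head attention wrapper (successor of the removed SelfAttnLikeModule): owns the
+# q/k/v projections and the output dense layer and delegates attention to SDPA above.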
+class MHAModule(torch.nn.Module):
+    def __init__(
+        self,
+        input_dim,
+        has_mask,
+        num_attention_heads,
+        attention_head_size,
+    ) -> None:
+        super().__init__()
+        self.input_dim = input_dim
+        self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
+        self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
+        self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_size = attention_head_size
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size)
+        self.attn_mod = SDPA(
+            input_dim,
+            has_mask,
+            num_attention_heads,
+            attention_head_size,
        )
+
+    def forward(self, x, mask):
+        q = self.q_proj(x)
+        k = self.k_proj(x)
+        v = self.v_proj(x)
+        context_layer = self.attn_mod(q, k, v, mask)
        return self.dense(context_layer)


@@ -158,7 +342,7 @@ def _check_common(
        reason="cpp kernels not built",
    )
    @config.patch({"freezing": True})
-    def _test_qsdpa_rewriter(self):
+    def _test_int8_sdpa_rewriter(self):
        import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
        from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
        from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
@@ -171,7 +355,7 @@ def _test_qsdpa_rewriter(self):
            [torch.float32, torch.bfloat16], [True, False], [56, 1]
        ):
            seqlen, numhead, headsize = 197, 16, 64
-            mod = SelfAttnLikeModule(
+            mod = MHAModule(
                input_dim=headsize * numhead,
                has_mask=has_mask,
                num_attention_heads=numhead,
@@ -204,6 +388,51 @@ def _test_qsdpa_rewriter(self):
                prepare_model(*inputs)
                convert_model = convert_pt2e(prepare_model)
                torchao.quantization.pt2e.move_exported_model_to_eval(convert_model)
+
+                self._check_common(
+                    convert_model, args1=inputs, check_train=False, atol=1.0
+                )
+
+    @skipIfRocm
+    @unittest.skipIf(
+        not torch_version_at_least("2.7.0"),
+        reason="qsdpa requires torch 2.7 or later",
+    )
+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::qscaled_dot_product"),
+        reason="cpp kernels not built",
+    )
+    @config.patch({"freezing": True})
+    def _test_fp8_sdpa_rewriter(self):
+        import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
+
+        # pattern is different for bs=1
+        torch.manual_seed(1234)
+        for dtype, bs in itertools.product([torch.float32, torch.bfloat16], [56, 1]):
+            seqlen, numhead, headsize = 197, 16, 64
+            mod = MHAModule(
+                input_dim=headsize * numhead,
+                has_mask=False,
+                num_attention_heads=numhead,
+                attention_head_size=headsize,
+            ).eval()
+            inputs = (
+                torch.randn(
+                    (bs, seqlen, headsize * numhead), device=self.device, dtype=dtype
+                ),
+                None,
+            )
+            enable_autocast = dtype == torch.bfloat16
+            with (
+                torch.no_grad(),
+                torch.amp.autocast(
+                    self.device, enabled=enable_autocast, dtype=torch.bfloat16
+                ),
+                config.patch(post_grad_custom_pre_pass=custom_pass),
+            ):
+                _qsdpa_init()
+                convert_model = fp8_convert_(mod)
+
                self._check_common(
                    convert_model, args1=inputs, check_train=False, atol=1.0
                )
@@ -213,7 +442,12 @@ def _test_qsdpa_rewriter(self):

    class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate):
        device = "cpu"
-        test_qsdpa_rewriter_cpu = TestSDPAPatternRewriterTemplate._test_qsdpa_rewriter
+        test_int8_sdpa_rewriter_cpu = (
+            TestSDPAPatternRewriterTemplate._test_int8_sdpa_rewriter
+        )
+        test_fp8_sdpa_rewriter_cpu = (
+            TestSDPAPatternRewriterTemplate._test_fp8_sdpa_rewriter
+        )


if __name__ == "__main__":