@@ -191,46 +191,57 @@ def _linear_fp8_act_fp8_weight_sparse_cutlass_check(input_tensor, weight_tensor,
191
191
from torchao .dtypes .floatx import Float8Layout
192
192
193
193
base_check = (
194
- isinstance (input_tensor , AffineQuantizedTensor ) and
195
- isinstance (input_tensor ._layout , Float8Layout ) and
196
- input_tensor .dtype in (torch .float16 , torch .bfloat16 ) and
197
- len (input_tensor .shape ) >= 2 and
198
- input_tensor .tensor_impl .scale .dtype == torch .float32 and
199
- isinstance (weight_tensor , AffineQuantizedTensor ) and
200
- isinstance (weight_tensor ._layout , CutlassSemiSparseLayout ) and
201
- weight_tensor .dtype == input_tensor .dtype and
202
- len (weight_tensor .shape ) == 2 and
203
- weight_tensor .tensor_impl .scale .dtype == torch .float32 and
204
- (bias is None or bias .dtype == input_tensor .dtype ) and
205
- (bias is None or len (bias .shape ) == 1 )
194
+ isinstance (input_tensor , AffineQuantizedTensor )
195
+ and isinstance (input_tensor ._layout , Float8Layout )
196
+ and input_tensor .dtype in (torch .float16 , torch .bfloat16 )
197
+ and len (input_tensor .shape ) >= 2
198
+ and input_tensor .tensor_impl .scale .dtype == torch .float32
199
+ and isinstance (weight_tensor , AffineQuantizedTensor )
200
+ and isinstance (weight_tensor ._layout , CutlassSemiSparseLayout )
201
+ and weight_tensor .dtype == input_tensor .dtype
202
+ and len (weight_tensor .shape ) == 2
203
+ and weight_tensor .tensor_impl .scale .dtype == torch .float32
204
+ and (bias is None or bias .dtype == input_tensor .dtype )
205
+ and (bias is None or len (bias .shape ) == 1 )
206
206
)
207
207
208
208
if base_check :
209
-
210
209
# do extra check and reshape if needed
211
210
input_tensor_squeezed = False
212
- if len (input_tensor .tensor_impl .scale .shape ) == len (input_tensor .shape ) and \
213
- len (input_tensor .tensor_impl .scale .shape ) > 1 and \
214
- input_tensor .tensor_impl .scale .shape [- 1 ] == 1 :
215
- input_tensor .tensor_impl .scale = torch .squeeze (input_tensor .tensor_impl .scale , dim = - 1 )
211
+ if (
212
+ len (input_tensor .tensor_impl .scale .shape ) == len (input_tensor .shape )
213
+ and len (input_tensor .tensor_impl .scale .shape ) > 1
214
+ and input_tensor .tensor_impl .scale .shape [- 1 ] == 1
215
+ ):
216
+ input_tensor .tensor_impl .scale = torch .squeeze (
217
+ input_tensor .tensor_impl .scale , dim = - 1
218
+ )
216
219
input_tensor_squeezed = True
217
-
220
+
218
221
weight_tensor_squeezed = False
219
- if len (weight_tensor .tensor_impl .scale .shape ) == 2 and \
220
- weight_tensor .tensor_impl .scale .shape [- 1 ] == 1 :
221
- weight_tensor .tensor_impl .scale = torch .squeeze (weight_tensor .tensor_impl .scale , dim = - 1 )
222
+ if (
223
+ len (weight_tensor .tensor_impl .scale .shape ) == 2
224
+ and weight_tensor .tensor_impl .scale .shape [- 1 ] == 1
225
+ ):
226
+ weight_tensor .tensor_impl .scale = torch .squeeze (
227
+ weight_tensor .tensor_impl .scale , dim = - 1
228
+ )
222
229
weight_tensor_squeezed = True
223
230
224
231
extra_check = (
225
232
len (input_tensor .tensor_impl .scale .shape ) == len (input_tensor .shape ) - 1
226
233
and len (weight_tensor .tensor_impl .scale .shape ) == 1
227
234
)
228
235
229
- if not extra_check : # revert if extra check failed
236
+ if not extra_check : # revert if extra check failed
230
237
if input_tensor_squeezed :
231
- input_tensor .tensor_impl .scale = torch .unsqueeze (input_tensor .tensor_impl .scale , dim = - 1 )
238
+ input_tensor .tensor_impl .scale = torch .unsqueeze (
239
+ input_tensor .tensor_impl .scale , dim = - 1
240
+ )
232
241
if weight_tensor_squeezed :
233
- weight_tensor .tensor_impl .scale = torch .unsqueeze (weight_tensor .tensor_impl .scale , dim = - 1 )
242
+ weight_tensor .tensor_impl .scale = torch .unsqueeze (
243
+ weight_tensor .tensor_impl .scale , dim = - 1
244
+ )
234
245
235
246
return extra_check
236
247
0 commit comments