Commit 8c5c33e

fix ruff lint in main branch (#3067)
Summary: #3057 was landed with failing lint; this PR fixes it.

Test Plan:
```
ruff format
```

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 7391a4a commit 8c5c33e
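To reproduce the test plan locally, ruff can be run over the touched file. The per-file invocation below is an assumption for illustration (the commit's test plan only shows a bare `ruff format`); both subcommands and the `--check` flag are standard ruff CLI:

```
ruff format torchao/dtypes/floatx/cutlass_semi_sparse_layout.py          # rewrite in place
ruff format --check torchao/dtypes/floatx/cutlass_semi_sparse_layout.py  # fail if the file would change
```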

File tree

1 file changed (+35, -24 lines)


torchao/dtypes/floatx/cutlass_semi_sparse_layout.py

Lines changed: 35 additions & 24 deletions
@@ -191,46 +191,57 @@ def _linear_fp8_act_fp8_weight_sparse_cutlass_check(input_tensor, weight_tensor,
     from torchao.dtypes.floatx import Float8Layout
 
     base_check = (
-        isinstance(input_tensor, AffineQuantizedTensor) and
-        isinstance(input_tensor._layout, Float8Layout) and
-        input_tensor.dtype in (torch.float16, torch.bfloat16) and
-        len(input_tensor.shape) >= 2 and
-        input_tensor.tensor_impl.scale.dtype == torch.float32 and
-        isinstance(weight_tensor, AffineQuantizedTensor) and
-        isinstance(weight_tensor._layout, CutlassSemiSparseLayout) and
-        weight_tensor.dtype == input_tensor.dtype and
-        len(weight_tensor.shape) == 2 and
-        weight_tensor.tensor_impl.scale.dtype == torch.float32 and
-        (bias is None or bias.dtype == input_tensor.dtype) and
-        (bias is None or len(bias.shape) == 1)
+        isinstance(input_tensor, AffineQuantizedTensor)
+        and isinstance(input_tensor._layout, Float8Layout)
+        and input_tensor.dtype in (torch.float16, torch.bfloat16)
+        and len(input_tensor.shape) >= 2
+        and input_tensor.tensor_impl.scale.dtype == torch.float32
+        and isinstance(weight_tensor, AffineQuantizedTensor)
+        and isinstance(weight_tensor._layout, CutlassSemiSparseLayout)
+        and weight_tensor.dtype == input_tensor.dtype
+        and len(weight_tensor.shape) == 2
+        and weight_tensor.tensor_impl.scale.dtype == torch.float32
+        and (bias is None or bias.dtype == input_tensor.dtype)
+        and (bias is None or len(bias.shape) == 1)
     )
 
     if base_check:
-
         # do extra check and reshape if needed
         input_tensor_squeezed = False
-        if len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) and \
-                len(input_tensor.tensor_impl.scale.shape) > 1 and \
-                input_tensor.tensor_impl.scale.shape[-1] == 1:
-            input_tensor.tensor_impl.scale = torch.squeeze(input_tensor.tensor_impl.scale, dim=-1)
+        if (
+            len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape)
+            and len(input_tensor.tensor_impl.scale.shape) > 1
+            and input_tensor.tensor_impl.scale.shape[-1] == 1
+        ):
+            input_tensor.tensor_impl.scale = torch.squeeze(
+                input_tensor.tensor_impl.scale, dim=-1
+            )
             input_tensor_squeezed = True
-
+
         weight_tensor_squeezed = False
-        if len(weight_tensor.tensor_impl.scale.shape) == 2 and \
-                weight_tensor.tensor_impl.scale.shape[-1] == 1:
-            weight_tensor.tensor_impl.scale = torch.squeeze(weight_tensor.tensor_impl.scale, dim=-1)
+        if (
+            len(weight_tensor.tensor_impl.scale.shape) == 2
+            and weight_tensor.tensor_impl.scale.shape[-1] == 1
+        ):
+            weight_tensor.tensor_impl.scale = torch.squeeze(
+                weight_tensor.tensor_impl.scale, dim=-1
+            )
             weight_tensor_squeezed = True
 
         extra_check = (
             len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1
             and len(weight_tensor.tensor_impl.scale.shape) == 1
         )
 
-        if not extra_check: # revert if extra check failed
+        if not extra_check:  # revert if extra check failed
             if input_tensor_squeezed:
-                input_tensor.tensor_impl.scale = torch.unsqueeze(input_tensor.tensor_impl.scale, dim=-1)
+                input_tensor.tensor_impl.scale = torch.unsqueeze(
+                    input_tensor.tensor_impl.scale, dim=-1
+                )
             if weight_tensor_squeezed:
-                weight_tensor.tensor_impl.scale = torch.unsqueeze(weight_tensor.tensor_impl.scale, dim=-1)
+                weight_tensor.tensor_impl.scale = torch.unsqueeze(
+                    weight_tensor.tensor_impl.scale, dim=-1
+                )
 
         return extra_check