Skip to content

Commit d3aba66

Browse files
committed
Update
[ghstack-poisoned]
2 parents b91f59b + 53d2486 commit d3aba66

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

test/integration/test_integration.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -885,7 +885,6 @@ def _test_lin_weight_subclass_api_impl(
885885

886886

887887
@parameterized.expand(COMMON_DEVICE_DTYPE)
888-
@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "skip because there is some bug in inductor codegen")
889888
def test_int8_dynamic_quant_subclass_api(self, device, dtype):
890889
self._test_lin_weight_subclass_api_impl(
891890
_int8da_int8w_api, device, 35, test_dtype=dtype

torchao/dtypes/uintx/semi_sparse_layout.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,15 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(
4444
# must pad
4545
row, col = tmp.shape
4646
from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
47+
4748
tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp)
4849
# we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm
4950
y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
5051
w_vals_int8,
5152
tmp_padded.t(),
5253
alpha=w_scales.to(torch.float32),
5354
out_dtype=torch.bfloat16,
54-
).t()[:row, :]
55+
).t()[:row, :]
5556
y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape(
5657
*x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
5758
)

0 commit comments

Comments (0)