
Commit ccd883b

Call narrow only for TensorCoreTiledLayout (#1207)
* Call narrow only for TensorCoreTiledLayout

  Summary: Previously, #914 added the narrow op to dequantize() for all layouts. The extra narrow op breaks the pattern match for int8 dynamic activation + int4 weight quantization in ExecuTorch, so this PR guards the narrow op to run only for the tensor core tiled layout. If similar issues come up in the future, we can factor this into a proper API on Layout or TensorImpl.

  Test Plan: python test/test_integration.py -k test_export

* enable test

* version

* skip aoti

* version update

* skip aoti
1 parent 2ba1a61 commit ccd883b
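
For context, a minimal, untested sketch (not part of the commit) of the property the new test_export case checks: after quantizing a module with int8_dynamic_activation_int4_weight, the exported graph should contain no aten.narrow call. The toy model, shapes, and the use of torch.export.export here are illustrative assumptions; the actual test goes through torch._export.capture_pre_autograd_graph and is parameterized over devices and dtypes.

import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_dynamic_activation_int4_weight
from torchao.utils import unwrap_tensor_subclass

# Hypothetical toy model; the test uses its own small test_model class.
model = nn.Sequential(nn.Linear(64, 32, bias=False)).eval()
quantize_(model, int8_dynamic_activation_int4_weight(), set_inductor_config=False)
# The test unwraps tensor subclasses before exporting; mirrored here.
unwrap_tensor_subclass(model)

example_inputs = (torch.randn(32, 64),)
ep = torch.export.export(model, example_inputs)

# With the guard in this commit, no narrow op should leak into the graph.
targets = [node.target for node in ep.graph.nodes]
assert torch.ops.aten.narrow.default not in targets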

2 files changed (+38, -25 lines)

test/integration/test_integration.py

Lines changed: 30 additions & 20 deletions
@@ -24,6 +24,7 @@
     int4_weight_only,
     int8_weight_only,
     int8_dynamic_activation_int8_weight,
+    int8_dynamic_activation_int4_weight,
     quantize_,
     _replace_with_custom_fn_if_matches_filter,
 )
@@ -137,6 +138,12 @@ def _int4wo_api(mod):
     else:
         change_linear_weights_to_int4_woqtensors(mod)
 
+def _int8da_int4w_api(mod):
+    quantize_(mod, int8_dynamic_activation_int4_weight(), set_inductor_config=False)
+    if not TORCH_VERSION_AT_LEAST_2_5:
+        unwrap_tensor_subclass(mod)
+
+
 # TODO: use this to reduce the number of tests
 TENSOR_SUBCLASS_APIS = [
     _int8wo_api,
@@ -781,7 +788,7 @@ def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype):
         self._test_lin_weight_subclass_impl(
             AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
         )
-
+
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
     @unittest.skipIf(not is_H100, "Need H100 to run")
@@ -973,11 +980,11 @@ def test_weight_only_groupwise_embedding_quant(self):
         group_size = 64
         m = nn.Embedding(4096, 128)
         input = torch.randint(0, 4096, (1, 6))
-
+
         quantize_(m, int8_weight_only(group_size=group_size), filter_fn=lambda x, *args: isinstance(x, nn.Embedding))
         y_q = m(input)
         y_ref = m.weight.dequantize()[input]
-
+
         sqnr = compute_error(y_ref, y_q)
 
         self.assertGreater(sqnr, 45.0)
@@ -1486,22 +1493,22 @@ def forward(self, x):
 
 
 
+@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.")
+@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+@unittest.skip("AOTI tests are failing right now")
 class TestAOTI(unittest.TestCase):
     @parameterized.expand(
         list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
     )
-    @run_supported_device_dtype
     def test_aoti(self, api, test_device, test_dtype):
-        if not TORCH_VERSION_AT_LEAST_2_4:
-            self.skipTest("aoti compatibility requires 2.4+.")
-
-        print(f"TestAOTI: {api}, {test_device}, {test_dtype}")
-        logger.info(f"TestAOTI: {api}, {test_device}, {test_dtype}")
         if api is change_linear_weights_to_int8_dqtensors and test_device == "cuda":
             self.skipTest(f"{api} in {test_device} is not support for aoti compilation yet")
 
-        if test_dtype != torch.bfloat16:
-            self.skipTest(f"{api} in {test_dtype} is not support for aoti compilation yet")
+        if test_device == "cuda" and torch.cuda.is_available() and test_dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
+            self.skipTest("Need CUDA and SM80+ available.")
+
+
+        logger.info(f"TestAOTI: {api}, {test_device}, {test_dtype}")
 
         m, k, n = 32, 64, 32
 
@@ -1525,29 +1532,30 @@ def forward(self, x):
         ref_f = model(x)
 
         api(model)
+        unwrap_tensor_subclass(model)
 
         # running model
         model(x)
 
         # make sure it compiles
+        torch._inductor.config.mixed_mm_choice = "triton"
+
         example_inputs = (x,)
-        torch._export.aot_compile(model, example_inputs)
+        torch._inductor.aoti_compile_and_package(torch.export.export(model, example_inputs), example_inputs)
 
 
+@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.")
+@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
 class TestExport(unittest.TestCase):
     @parameterized.expand(
-        list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
+        list(itertools.product(TENSOR_SUBCLASS_APIS + [_int8da_int4w_api], COMMON_DEVICES, COMMON_DTYPES)),
     )
-    @run_supported_device_dtype
     def test_export(self, api, test_device, test_dtype):
-        if not TORCH_VERSION_AT_LEAST_2_4:
-            self.skipTest("aoti compatibility requires 2.4+.")
+        if test_device == "cuda" and torch.cuda.is_available() and test_dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
+            self.skipTest("Need CUDA and SM80+ available.")
 
         logger.info(f"TestExport: {api}, {test_device}, {test_dtype}")
 
-        if test_dtype != torch.bfloat16:
-            self.skipTest(f"{api} in {test_dtype} is not support for aoti compilation yet")
-
         m, k, n = 32, 64, 32
 
         class test_model(nn.Module):
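
A side note on the AOTI change in the hunk above (hedged, since the whole class is currently skipped): the old torch._export.aot_compile(model, example_inputs) call is replaced by exporting first and then packaging with torch._inductor.aoti_compile_and_package, the PyTorch 2.5-era entry point. A minimal standalone sketch of that flow, with a toy unquantized model as a stand-in:

import torch
import torch._inductor
import torch.nn as nn

# Toy stand-in; in the test the model is first quantized by the parameterized API.
model = nn.Sequential(nn.Linear(64, 32)).eval()
example_inputs = (torch.randn(32, 64),)

ep = torch.export.export(model, example_inputs)
# PyTorch 2.5-style call mirroring the diff; newer releases accept just the
# ExportedProgram without the extra example_inputs argument.
torch._inductor.aoti_compile_and_package(ep, example_inputs)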
@@ -1570,6 +1578,7 @@ def forward(self, x):
         ref_f = model(x)
 
         api(model)
+        unwrap_tensor_subclass(model)
 
         # running model
         ref = model(x)
@@ -1585,10 +1594,11 @@ def forward(self, x):
         model = torch._export.capture_pre_autograd_graph(model, example_inputs)
         after_export = model(x)
         self.assertTrue(torch.equal(after_export, ref))
-        if api is _int8da_int8w_api:
+        if api is _int8da_int4w_api:
             targets = [n.target for n in model.graph.nodes]
             self.assertTrue(torch.ops.quant.choose_qparams_affine.default in targets)
             self.assertTrue(torch.ops.quant.quantize_affine.default in targets)
+            self.assertFalse(torch.ops.aten.narrow.default in targets)
 
 
 
torchao/dtypes/affine_quantized_tensor.py

Lines changed: 8 additions & 5 deletions
@@ -238,10 +238,13 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
             self.zero_point_domain,
             output_dtype=output_dtype,
         )
-        # need to return to original shape if tensor was padded
-        # in preprocessing
-        for dim, dim_size in enumerate(self.shape):
-            dq = dq.narrow(dim, 0, dim_size)
+        if isinstance(self._layout, TensorCoreTiledLayout):
+            # need to return to original shape if tensor was padded
+            # in preprocessing
+            # TODO: we could add an API for this if there are more use cases
+            # (e.g. dequant_post_process) in TensorImpl or Layout
+            for dim, dim_size in enumerate(self.shape):
+                dq = dq.narrow(dim, 0, dim_size)
         return dq
 
     @staticmethod
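
On the dequantize() change above: TensorCoreTiledLayout pads the packed weight so its dimensions meet the tiled int4 kernel's alignment requirements, so the dequantized tensor has to be narrowed back to the tensor's logical shape; other layouts presumably do not pad, and for them the unconditional narrow only injected a spurious aten.narrow op into exported graphs. A small standalone sketch of the narrow-back step, with made-up shapes for illustration:

import torch

logical_shape = (10, 70)       # hypothetical original weight shape
padded = torch.zeros(16, 128)  # hypothetical padded storage used by the layout
padded[: logical_shape[0], : logical_shape[1]] = torch.randn(logical_shape)

# dequantize() walks self.shape and narrows each dim back to its logical size.
dq = padded
for dim, dim_size in enumerate(logical_shape):
    dq = dq.narrow(dim, 0, dim_size)
assert dq.shape == torch.Size(logical_shape)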
@@ -1698,7 +1701,7 @@ def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
     output_dtype = input_tensor.dtype
     y = y.to(output_dtype)
     if bias is not None:
-        y += bias
+        y = y + bias
     return y
 
 