
Commit d382147

Deprecate top level quantization APIs
Summary:

This PR deprecates a few top-level quantization APIs. BC-breaking notes:

1. int8 weight-only quantization

The module-swap API

```
apply_weight_only_int8_quant(model)
```

and the tensor subclass API

```
change_linear_weights_to_int8_woqtensors(model)
```

are replaced by the unified tensor subclass API:

```
quantize(model, get_apply_int8wo_quant())
```

2. int8 dynamic quantization

```
apply_dynamic_quant(model)
```

or

```
change_linear_weights_to_int8_dqtensors(model)
```

are replaced by the unified tensor subclass API:

```
quantize(model, get_apply_int8dyn_quant())
```

3. int4 weight-only quantization

```
change_linear_weights_to_int4_woqtensors(model)
```

is replaced by the unified tensor subclass API:

```
quantize(model, get_apply_int4wo_quant())
```

Test Plan:

python test/quantization/test_quant_api.py
python test/integration/test_integration.py

Reviewers:

Subscribers:

Tasks:

Tags:
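For reference, a minimal migration sketch of the three cases above, based on the imports used in the updated tests below (the default arguments of the `get_apply_*` helpers are not spelled out here and may differ):

```python
import torch.nn as nn

from torchao.quantization.quant_api import (
    quantize,
    get_apply_int8wo_quant,
    get_apply_int8dyn_quant,
    get_apply_int4wo_quant,
)
from torchao.utils import unwrap_tensor_subclass

model = nn.Sequential(nn.Linear(1024, 1024)).eval()

# old: apply_weight_only_int8_quant(model) or change_linear_weights_to_int8_woqtensors(model)
model = quantize(model, get_apply_int8wo_quant())

# old: apply_dynamic_quant(model) or change_linear_weights_to_int8_dqtensors(model)
# model = quantize(model, get_apply_int8dyn_quant())

# old: change_linear_weights_to_int4_woqtensors(model)
# model = quantize(model, get_apply_int4wo_quant())

# the updated tests also unwrap the tensor subclasses before torch.compile or state_dict save
model = unwrap_tensor_subclass(model)
```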
1 parent 79f2c7f commit d382147

File tree: 7 files changed (+238, −282 lines)


test/integration/test_integration.py

Lines changed: 38 additions & 27 deletions
```diff
@@ -20,11 +20,10 @@
     DynamicallyPerAxisQuantizedLinear,
 )
 from torchao.quantization.quant_api import (
-    apply_dynamic_quant,
-    apply_weight_only_int8_quant,
-    change_linear_weights_to_int8_dqtensors,
-    change_linear_weights_to_int8_woqtensors,
-    change_linear_weights_to_int4_woqtensors,
+    get_apply_int4wo_quant,
+    get_apply_int8wo_quant,
+    get_apply_int8dyn_quant,
+    quantize,
     _replace_with_custom_fn_if_matches_filter,
 )
 from torchao.quantization.quant_primitives import (
@@ -73,7 +72,11 @@
 from parameterized import parameterized
 import itertools
 import logging
-from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_3,
+    TORCH_VERSION_AFTER_2_4,
+    unwrap_tensor_subclass,
+)
 
 logger = logging.getLogger("INFO")
 
@@ -82,9 +85,9 @@
 
 # TODO: use this to reduce the number of tests
 TENSOR_SUBCLASS_APIS = [
-    change_linear_weights_to_int8_dqtensors,
-    change_linear_weights_to_int8_woqtensors,
-    change_linear_weights_to_int4_woqtensors,
+    get_apply_int4wo_quant,
+    get_apply_int8wo_quant,
+    get_apply_int8dyn_quant,
 ]
 
 COMMON_DEVICES = ["cpu", "cuda"]
@@ -736,7 +739,8 @@ def _test_lin_weight_subclass_api_impl(
             nn.Linear(k, n, device=test_device), nn.ReLU(), nn.Linear(n, n, device=test_device)
         ).to(test_dtype)
         ref_f = mod(x)
-        api(mod)
+        quantize(mod, api())
+        unwrap_tensor_subclass(mod)
 
         test = mod(x)
         self.assertGreater(
@@ -756,13 +760,13 @@ def _test_lin_weight_subclass_api_impl(
     @unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen")
     def test_int8_dynamic_quant_subclass_api(self, device, dtype):
         self._test_lin_weight_subclass_api_impl(
-            change_linear_weights_to_int8_dqtensors, device, 35, test_dtype=dtype
+            get_apply_int8dyn_quant, device, 35, test_dtype=dtype
         )
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     def test_int8_weight_only_quant_subclass_api(self, device, dtype):
         self._test_lin_weight_subclass_api_impl(
-            change_linear_weights_to_int8_woqtensors, device, 40, test_dtype=dtype
+            get_apply_int8wo_quant, device, 40, test_dtype=dtype
         )
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
@@ -772,7 +776,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
             self.skipTest(f"Fails for {dtype}")
         for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 256)] if device=='cuda' else [])):
             self._test_lin_weight_subclass_api_impl(
-                change_linear_weights_to_int4_woqtensors,
+                get_apply_int4wo_quant,
                 device,
                 15,
                 test_shape=test_shape,
@@ -789,7 +793,7 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
             for inner_k_tiles in [4, 2]:
                 kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}
                 self._test_lin_weight_subclass_api_impl(
-                    lambda mod: change_linear_weights_to_int4_woqtensors(mod, **kwargs),
+                    lambda: get_apply_int4wo_quant(**kwargs),
                     device,
                     15,
                     test_shape=test_shape,
@@ -804,7 +808,7 @@ def test_dynamic_quant(self):
         m = nn.Sequential(nn.Linear(K, N))
 
         y_ref = m(x)
-        apply_dynamic_quant(m)
+        quantize(m, get_apply_int8dyn_quant())
        y_test = m(x)
 
         sqnr = compute_error(y_ref, y_test)
@@ -818,7 +822,7 @@ def test_weight_only_quant(self):
         x = torch.randn(*x_shape)
         m = nn.Sequential(nn.Linear(4, 5))
         y_ref = m(x)
-        apply_weight_only_int8_quant(m)
+        quantize(m, get_apply_int8wo_quant())
         y_wo = m(x)
         sqnr = compute_error(y_ref, y_wo)
         self.assertGreater(sqnr, 44.0)
@@ -841,7 +845,8 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
         x = torch.randn(*x_shape).to(device).to(dtype)
         m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
         y_ref = m(x)
-        apply_weight_only_int8_quant(m)
+        m = quantize(m, get_apply_int8wo_quant())
+        m = unwrap_tensor_subclass(m)
         m(x)
         m_c = torch.compile(m, mode="max-autotune")
         y_wo, (code,) = run_and_get_code(m_c, x)
@@ -868,7 +873,8 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype):
         x = torch.randn(*x_shape).to(device).to(dtype)
         m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
         y_ref = m(x)
-        apply_weight_only_int8_quant(m)
+        m = quantize(m, get_apply_int8wo_quant())
+        m = unwrap_tensor_subclass(m)
         m_c = torch.compile(m, mode="max-autotune")
         y_wo, (code,) = run_and_get_code(m_c, x)
         sqnr = compute_error(y_ref, y_wo)
@@ -908,7 +914,9 @@ def forward(self, x):
         ref_f = model(x)
 
         # save quantized state_dict
-        api(model)
+        quantize(model, api())
+        unwrap_tensor_subclass(model)
+
         torch.save(model.state_dict(), "test.pth")
         # get quantized reference
         model_qc = torch.compile(model, mode="max-autotune")
@@ -919,11 +927,13 @@ def forward(self, x):
         # load model structure
         with torch.device('meta'):
             model = test_model().to(dtype=test_dtype)
-        api(model)
+        quantize(model, api())
+        unwrap_tensor_subclass(model)
 
         # load quantized state_dict
         state_dict = torch.load("test.pth", mmap=True)
         os.remove("test.pth")
+
         model.load_state_dict(state_dict, assign=True)
         model = model.to(device=test_device, dtype=test_dtype).eval()
 
@@ -939,20 +949,20 @@ def forward(self, x):
     def test_save_load_dqtensors(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"indcutor failed for cpu right now")
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_dqtensors, device, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(get_apply_int8dyn_quant, device, test_dtype=dtype)
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @torch.no_grad()
     def test_save_load_int8woqtensors(self, device, dtype):
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_woqtensors, device, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(get_apply_int8wo_quant, device, test_dtype=dtype)
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
     @torch.no_grad()
     def test_save_load_int4woqtensors(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int4_woqtensors, device, 20, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(get_apply_int4wo_quant, device, 20, test_dtype=dtype)
 
 
 class TorchCompileUnitTest(unittest.TestCase):
@@ -1271,8 +1281,8 @@ def forward(self, x):
         model = test_model().to(dtype=test_dtype, device=test_device).eval()
         ref_f = model(x)
 
-        kwargs = {"dtype": test_dtype}
-        api(model, **kwargs)
+        # kwargs = {"dtype": test_dtype}
+        quantize(model, api())
 
         # running model
         model(x)
@@ -1317,8 +1327,9 @@ def forward(self, x):
         model = test_model().to(dtype=test_dtype, device=test_device).eval()
         ref_f = model(x)
 
-        kwargs = {"dtype": test_dtype}
-        api(model, **kwargs)
+        # kwargs = {"dtype": test_dtype}
+        model = quantize(model, api())
+        model = unwrap_tensor_subclass(model)
 
         # running model
         ref = model(x)
```
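Taken together, the updated integration tests follow one pattern: construct the module, quantize it with one of the `get_apply_*` constructors, unwrap the tensor subclasses, then compile or save. A condensed sketch of that flow, assuming the same helpers these tests import:

```python
import torch
import torch.nn as nn

from torchao.quantization.quant_api import quantize, get_apply_int8wo_quant
from torchao.utils import unwrap_tensor_subclass

m = nn.Sequential(nn.Linear(4, 5))
x = torch.randn(32, 4)
y_ref = m(x)

m = quantize(m, get_apply_int8wo_quant())  # replace Linear weights with quantized tensor subclasses
m = unwrap_tensor_subclass(m)              # the tests above do this before torch.compile

m_c = torch.compile(m, mode="max-autotune")
y_quant = m_c(x)
```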

test/quantization/test_quant_api.py

Lines changed: 17 additions & 25 deletions
```diff
@@ -34,8 +34,6 @@
 )
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
-    apply_dynamic_quant,
-    apply_weight_only_int8_quant,
     Quantizer,
     TwoStepQuantizer,
     quantize,
@@ -52,6 +50,7 @@
 from sentencepiece import SentencePieceProcessor
 from model import Transformer, prepare_inputs_for_model
 import copy
+import tempfile
 
 
 def dynamic_quant(model, example_inputs):
@@ -61,20 +60,6 @@ def dynamic_quant(model, example_inputs):
     m = convert_pt2e(m)
     return m
 
-def _apply_dynamic_quant(model):
-    """
-    Applies dynamic symmetric per-token activation and per-channel weight
-    quantization to all linear layers in the given model using
-    module swaps.
-    """
-    _replace_with_custom_fn_if_matches_filter(
-        model,
-        lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features),)),
-        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
-    )
-    return model
-
-
 def capture_and_prepare(model, example_inputs):
     m = torch.export.export(model, example_inputs)
     quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
```
```diff
@@ -103,7 +88,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module:
 
 class TorchCompileDynamicQuantizer(Quantizer):
     def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
-        apply_dynamic_quant(model)
+        quantize(model, get_apply_int8dyn_quant())
         return model
 
 class ToyLinearModel(torch.nn.Module):
```
```diff
@@ -126,11 +111,13 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs
     The deprecated implementation for int8 dynamic quant API, used as a reference for
     numerics and performance
     """
-    from torchao.quantization.quant_api import _in_features_greater_than_16
     from torchao.quantization.quant_api import _is_linear
     from torchao.quantization.quant_api import _get_subclass_inserter
     from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight
 
+    def _in_features_greater_than_16(mod, *args):
+        return hasattr(mod, "in_features") and mod.in_features > 16
+
     if filter_fn is None:
         filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16(
             *args
@@ -166,7 +153,7 @@ class TestQuantFlow(unittest.TestCase):
     def test_dynamic_quant_gpu_singleline(self):
         m = ToyLinearModel().eval()
         example_inputs = m.example_inputs()
-        m = _apply_dynamic_quant(m)
+        m = quantize(m, get_apply_int8dyn_quant())
         quantized = m(*example_inputs)
         # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
         # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
@@ -204,16 +191,21 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_int8_wo_quant_save_load(self):
         m = ToyLinearModel().eval().cpu()
-        apply_weight_only_int8_quant(m)
+        m = quantize(m, get_apply_int8wo_quant())
+
+        from torchao.utils import unwrap_tensor_subclass
+        unwrap_tensor_subclass(m)
         example_inputs = m.example_inputs()
         ref = m(*example_inputs)
-        _TMP_FN = "_test.pt"
-        torch.save(m.state_dict(), _TMP_FN)
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(m.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
 
-        state_dict = torch.load(_TMP_FN)
-        os.remove(_TMP_FN)
         m2 = ToyLinearModel().eval()
-        apply_weight_only_int8_quant(m2)
+        m2 = quantize(m2, get_apply_int8wo_quant())
+        unwrap_tensor_subclass(m2)
+
         m2.load_state_dict(state_dict)
         m2 = m2.to(device="cuda")
         example_inputs = map(lambda x: x.cuda(), example_inputs)
```
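The save/load test above now round-trips the quantized state_dict through a temporary file instead of a hard-coded path. A self-contained sketch of that pattern, where make_model stands in for the ToyLinearModel used by the test:

```python
import tempfile

import torch
import torch.nn as nn

from torchao.quantization.quant_api import quantize, get_apply_int8wo_quant
from torchao.utils import unwrap_tensor_subclass


def make_model():
    # stand-in for the ToyLinearModel defined in the test file
    return nn.Sequential(nn.Linear(64, 32), nn.Linear(32, 64)).eval()


m = quantize(make_model(), get_apply_int8wo_quant())
unwrap_tensor_subclass(m)

# round-trip the quantized state_dict through a temporary file
with tempfile.NamedTemporaryFile() as f:
    torch.save(m.state_dict(), f)
    f.seek(0)
    state_dict = torch.load(f)

# a second instance quantized the same way can load the saved weights
m2 = quantize(make_model(), get_apply_int8wo_quant())
unwrap_tensor_subclass(m2)
m2.load_state_dict(state_dict)
```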

torchao/dtypes/aqt.py

Lines changed: 8 additions & 6 deletions
```diff
@@ -386,11 +386,13 @@ def __new__(
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        device=None,
         dtype=None,
+        memory_format=None,
         strides=None,
     ):
         kwargs = {}
-        kwargs["device"] = layout_tensor.device
+        kwargs["device"] = layout_tensor.device if device is None else device
         kwargs["layout"] = (
             kwargs.get("layout") if kwargs.get("layout", False) else layout_tensor.layout
         )
@@ -500,7 +502,7 @@ def from_float(
         )
 
     @property
-    def layout(self) -> str:
+    def extended_layout(self) -> str:
         return self.layout_tensor.extended_layout
 
     @classmethod
@@ -596,8 +598,8 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
         is_cuda and
         input_is_int8 and
         input_tensor.dtype == weight_qtensor.dtype and
-        input_tensor.layout == "plain" and
-        weight_qtensor.layout == "plain"
+        input_tensor.extended_layout == "plain" and
+        weight_qtensor.extended_layout == "plain"
     ):
         #
         # 1. do the matrix form of dot(X_i, W_j)
@@ -639,7 +641,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
         weight_qtensor.dtype == torch.bfloat16 and
         len(weight_qtensor.shape) == 2 and
         weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
-        weight_qtensor.layout == "tensor_core_tiled"
+        weight_qtensor.extended_layout == "tensor_core_tiled"
     ):
         assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
         assert input_tensor.shape[-1] == weight_qtensor.shape[1], (
@@ -682,7 +684,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
         weight_qtensor.block_size[0] == 1 and
         weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
         weight_qtensor.zero_point_domain == ZeroPointDomain.INT and
-        weight_qtensor.layout == "plain"
+        weight_qtensor.extended_layout == "plain"
     ):
         # TODO: enable cpu and mps efficient path
         # per channel int8 weight only quantizated mm
```
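The diff above renames the `layout` property, which collided with the built-in `torch.Tensor.layout` attribute, to `extended_layout`, and the dispatch conditions in `_quantized_linear_op` now branch on that string. A schematic sketch of one such check, not the actual implementation, with `weight_qtensor` standing in for a quantized weight tensor instance:

```python
def is_plain_int8_weight_only(weight_qtensor) -> bool:
    # "plain" describes the packing format of the quantized weight,
    # exposed through the renamed extended_layout property (formerly .layout)
    return (
        weight_qtensor.block_size[0] == 1
        and weight_qtensor.block_size[1] == weight_qtensor.shape[1]
        and weight_qtensor.extended_layout == "plain"
    )
```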
