
Commit 5f59cbb

Deprecate top level quantization APIs
Summary:

This PR deprecates a few quantization APIs; the BC-breaking notes are below.

1. int8 weight-only quantization

The module-swap API

```
apply_weight_only_int8_quant(model)
```

and the tensor subclass API

```
change_linear_weights_to_int8_woqtensors(model)
```

are replaced by the unified tensor subclass API

```
quantize(model, get_apply_int8wo_quant())
```

2. int8 dynamic quantization

```
apply_dynamic_quant(model)
```

or

```
change_linear_weights_to_int8_dqtensors(model)
```

are replaced by the unified tensor subclass API

```
quantize(model, get_apply_int8dyn_quant())
```

3. int4 weight-only quantization

```
change_linear_weights_to_int4_woqtensors(model)
```

is replaced by the unified tensor subclass API

```
quantize(model, get_apply_int4wo_quant())
```

Test Plan:

python test/quantization/test_quant_api.py
python test/integration/test_integration.py

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 950a893 commit 5f59cbb
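
For reference, here is a minimal before/after sketch of the migration this commit describes. It assumes a torchao build that includes this commit; the import path follows the updated test imports in the diffs below, and the toy `nn.Sequential` model is only illustrative.

```python
# Hedged sketch: assumes torchao at (or after) this commit exposes `quantize`
# and the `get_apply_*_quant` config constructors in quant_api.
import torch
import torch.nn as nn
from torchao.quantization.quant_api import quantize, get_apply_int8wo_quant

model = nn.Sequential(nn.Linear(1024, 1024)).eval()

# Before this commit (deprecated):
#   apply_weight_only_int8_quant(model)              # module-swap API
#   change_linear_weights_to_int8_woqtensors(model)  # tensor subclass API
# After this commit (unified tensor subclass API):
model = quantize(model, get_apply_int8wo_quant())

out = model(torch.randn(1, 1024))
```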

File tree

7 files changed: +240 -292 lines changed


test/integration/test_integration.py

Lines changed: 38 additions & 27 deletions
```diff
@@ -20,11 +20,10 @@
     DynamicallyPerAxisQuantizedLinear,
 )
 from torchao.quantization.quant_api import (
-    apply_dynamic_quant,
-    apply_weight_only_int8_quant,
-    change_linear_weights_to_int8_dqtensors,
-    change_linear_weights_to_int8_woqtensors,
-    change_linear_weights_to_int4_woqtensors,
+    get_apply_int4wo_quant,
+    get_apply_int8wo_quant,
+    get_apply_int8dyn_quant,
+    quantize,
     _replace_with_custom_fn_if_matches_filter,
 )
 from torchao.quantization.quant_primitives import (
@@ -73,7 +72,11 @@
 from parameterized import parameterized
 import itertools
 import logging
-from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_3,
+    TORCH_VERSION_AFTER_2_4,
+    unwrap_tensor_subclass,
+)
 
 logger = logging.getLogger("INFO")
 
@@ -82,9 +85,9 @@
 
 # TODO: use this to reduce the number of tests
 TENSOR_SUBCLASS_APIS = [
-    change_linear_weights_to_int8_dqtensors,
-    change_linear_weights_to_int8_woqtensors,
-    change_linear_weights_to_int4_woqtensors,
+    get_apply_int4wo_quant,
+    get_apply_int8wo_quant,
+    get_apply_int8dyn_quant,
 ]
 
 COMMON_DEVICES = ["cpu", "cuda"]
@@ -736,7 +739,8 @@ def _test_lin_weight_subclass_api_impl(
             nn.Linear(k, n, device=test_device), nn.ReLU(), nn.Linear(n, n, device=test_device)
         ).to(test_dtype)
         ref_f = mod(x)
-        api(mod)
+        quantize(mod, api())
+        unwrap_tensor_subclass(mod)
 
         test = mod(x)
         self.assertGreater(
@@ -756,13 +760,13 @@ def _test_lin_weight_subclass_api_impl(
     @unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen")
     def test_int8_dynamic_quant_subclass_api(self, device, dtype):
         self._test_lin_weight_subclass_api_impl(
-            change_linear_weights_to_int8_dqtensors, device, 35, test_dtype=dtype
+            get_apply_int8dyn_quant, device, 35, test_dtype=dtype
         )
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     def test_int8_weight_only_quant_subclass_api(self, device, dtype):
         self._test_lin_weight_subclass_api_impl(
-            change_linear_weights_to_int8_woqtensors, device, 40, test_dtype=dtype
+            get_apply_int8wo_quant, device, 40, test_dtype=dtype
         )
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
@@ -772,7 +776,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
             self.skipTest(f"Fails for {dtype}")
         for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 256)] if device=='cuda' else [])):
             self._test_lin_weight_subclass_api_impl(
-                change_linear_weights_to_int4_woqtensors,
+                get_apply_int4wo_quant,
                 device,
                 15,
                 test_shape=test_shape,
@@ -789,7 +793,7 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
             for inner_k_tiles in [4, 2]:
                 kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}
                 self._test_lin_weight_subclass_api_impl(
-                    lambda mod: change_linear_weights_to_int4_woqtensors(mod, **kwargs),
+                    lambda: get_apply_int4wo_quant(**kwargs),
                     device,
                     15,
                     test_shape=test_shape,
@@ -804,7 +808,7 @@ def test_dynamic_quant(self):
         m = nn.Sequential(nn.Linear(K, N))
 
         y_ref = m(x)
-        apply_dynamic_quant(m)
+        quantize(m, get_apply_int8dyn_quant())
         y_test = m(x)
 
         sqnr = compute_error(y_ref, y_test)
@@ -818,7 +822,7 @@ def test_weight_only_quant(self):
         x = torch.randn(*x_shape)
         m = nn.Sequential(nn.Linear(4, 5))
         y_ref = m(x)
-        apply_weight_only_int8_quant(m)
+        quantize(m, get_apply_int8wo_quant())
         y_wo = m(x)
         sqnr = compute_error(y_ref, y_wo)
         self.assertGreater(sqnr, 44.0)
@@ -841,7 +845,8 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
         x = torch.randn(*x_shape).to(device).to(dtype)
         m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
         y_ref = m(x)
-        apply_weight_only_int8_quant(m)
+        m = quantize(m, get_apply_int8wo_quant())
+        m = unwrap_tensor_subclass(m)
         m(x)
         m_c = torch.compile(m, mode="max-autotune")
         y_wo, (code,) = run_and_get_code(m_c, x)
@@ -868,7 +873,8 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype):
         x = torch.randn(*x_shape).to(device).to(dtype)
         m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
         y_ref = m(x)
-        apply_weight_only_int8_quant(m)
+        m = quantize(m, get_apply_int8wo_quant())
+        m = unwrap_tensor_subclass(m)
         m_c = torch.compile(m, mode="max-autotune")
         y_wo, (code,) = run_and_get_code(m_c, x)
         sqnr = compute_error(y_ref, y_wo)
@@ -908,7 +914,9 @@ def forward(self, x):
         ref_f = model(x)
 
         # save quantized state_dict
-        api(model)
+        quantize(model, api())
+        unwrap_tensor_subclass(model)
+
         torch.save(model.state_dict(), "test.pth")
         # get quantized reference
         model_qc = torch.compile(model, mode="max-autotune")
@@ -919,11 +927,13 @@ def forward(self, x):
         # load model structure
         with torch.device('meta'):
             model = test_model().to(dtype=test_dtype)
-        api(model)
+        quantize(model, api())
+        unwrap_tensor_subclass(model)
 
         # load quantized state_dict
         state_dict = torch.load("test.pth", mmap=True)
         os.remove("test.pth")
+
         model.load_state_dict(state_dict, assign=True)
         model = model.to(device=test_device, dtype=test_dtype).eval()
 
@@ -939,20 +949,20 @@ def forward(self, x):
     def test_save_load_dqtensors(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"indcutor failed for cpu right now")
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_dqtensors, device, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(get_apply_int8dyn_quant, device, test_dtype=dtype)
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @torch.no_grad()
     def test_save_load_int8woqtensors(self, device, dtype):
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_woqtensors, device, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(get_apply_int8wo_quant, device, test_dtype=dtype)
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
     @torch.no_grad()
     def test_save_load_int4woqtensors(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int4_woqtensors, device, 20, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(get_apply_int4wo_quant, device, 20, test_dtype=dtype)
 
 
 class TorchCompileUnitTest(unittest.TestCase):
@@ -1271,8 +1281,8 @@ def forward(self, x):
         model = test_model().to(dtype=test_dtype, device=test_device).eval()
         ref_f = model(x)
 
-        kwargs = {"dtype": test_dtype}
-        api(model, **kwargs)
+        # kwargs = {"dtype": test_dtype}
+        quantize(model, api())
 
         # running model
         model(x)
@@ -1317,8 +1327,9 @@ def forward(self, x):
         model = test_model().to(dtype=test_dtype, device=test_device).eval()
         ref_f = model(x)
 
-        kwargs = {"dtype": test_dtype}
-        api(model, **kwargs)
+        # kwargs = {"dtype": test_dtype}
+        model = quantize(model, api())
+        model = unwrap_tensor_subclass(model)
 
         # running model
         ref = model(x)
```
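
The updated integration tests above all follow the same flow: build a float reference, apply `quantize` with one of the `get_apply_*_quant()` configs, call `unwrap_tensor_subclass` before `torch.compile`, and compare against the reference. A rough standalone sketch of that flow, assuming this commit's APIs and import paths:

```python
# Hedged sketch of the test flow above; function names and import paths are
# taken from the diff and may change in later torchao releases.
import torch
import torch.nn as nn
from torchao.quantization.quant_api import quantize, get_apply_int8dyn_quant
from torchao.utils import unwrap_tensor_subclass

m = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64)).eval()
x = torch.randn(2, 64)
y_ref = m(x)                                # float reference

m = quantize(m, get_apply_int8dyn_quant())  # int8 dynamic quant via tensor subclass
m = unwrap_tensor_subclass(m)               # required before torch.compile in these tests

m_c = torch.compile(m)
y_test = m_c(x)                             # compare y_test vs y_ref (e.g. SQNR) as the tests do
```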

test/quantization/test_quant_api.py

Lines changed: 17 additions & 25 deletions
```diff
@@ -35,8 +35,6 @@
 )
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
-    apply_dynamic_quant,
-    apply_weight_only_int8_quant,
     Quantizer,
     TwoStepQuantizer,
     quantize,
@@ -53,6 +51,7 @@
 from torchao._models.llama.tokenizer import get_tokenizer
 from torchao._models.llama.model import Transformer, prepare_inputs_for_model
 import copy
+import tempfile
 
 
 def dynamic_quant(model, example_inputs):
@@ -62,20 +61,6 @@ def dynamic_quant(model, example_inputs):
     m = convert_pt2e(m)
     return m
 
-def _apply_dynamic_quant(model):
-    """
-    Applies dynamic symmetric per-token activation and per-channel weight
-    quantization to all linear layers in the given model using
-    module swaps.
-    """
-    _replace_with_custom_fn_if_matches_filter(
-        model,
-        lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features),)),
-        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
-    )
-    return model
-
-
 def capture_and_prepare(model, example_inputs):
     m = torch.export.export(model, example_inputs)
     quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
@@ -104,7 +89,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module:
 
 class TorchCompileDynamicQuantizer(Quantizer):
     def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
-        apply_dynamic_quant(model)
+        quantize(model, get_apply_int8dyn_qunat())
         return model
 
 class ToyLinearModel(torch.nn.Module):
@@ -127,11 +112,13 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs
     The deprecated implementation for int8 dynamic quant API, used as a reference for
     numerics and performance
     """
-    from torchao.quantization.quant_api import _in_features_greater_than_16
     from torchao.quantization.quant_api import _is_linear
     from torchao.quantization.quant_api import _get_subclass_inserter
     from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight
 
+    def _in_features_greater_than_16(mod, *args):
+        return hasattr(mod, "in_features") and mod.in_features > 16
+
     if filter_fn is None:
         filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16(
             *args
@@ -167,7 +154,7 @@ class TestQuantFlow(unittest.TestCase):
     def test_dynamic_quant_gpu_singleline(self):
         m = ToyLinearModel().eval()
         example_inputs = m.example_inputs()
-        m = _apply_dynamic_quant(m)
+        m = quantize(m, get_apply_int8dyn_quant())
         quantized = m(*example_inputs)
         # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
         # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
@@ -205,16 +192,21 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_int8_wo_quant_save_load(self):
         m = ToyLinearModel().eval().cpu()
-        apply_weight_only_int8_quant(m)
+        m = quantize(m, get_apply_int8wo_quant())
+
+        from torchao.utils import unwrap_tensor_subclass
+        unwrap_tensor_subclass(m)
         example_inputs = m.example_inputs()
         ref = m(*example_inputs)
-        _TMP_FN = "_test.pt"
-        torch.save(m.state_dict(), _TMP_FN)
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(m.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
 
-        state_dict = torch.load(_TMP_FN)
-        os.remove(_TMP_FN)
         m2 = ToyLinearModel().eval()
-        apply_weight_only_int8_quant(m2)
+        m2 = quantize(m2, get_apply_int8wo_quant())
+        unwrap_tensor_subclass(m2)
+
         m2.load_state_dict(state_dict)
         m2 = m2.to(device="cuda")
         example_inputs = map(lambda x: x.cuda(), example_inputs)
```
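
The save/load test above also switches to `tempfile` instead of a hard-coded path. A rough sketch of the round trip it exercises, under the same assumptions about this commit's `quantize`, `get_apply_int8wo_quant`, and `unwrap_tensor_subclass`:

```python
# Hedged sketch mirroring test_int8_wo_quant_save_load above (CPU-only here;
# the real test moves the reloaded model to CUDA).
import tempfile
import torch
import torch.nn as nn
from torchao.quantization.quant_api import quantize, get_apply_int8wo_quant
from torchao.utils import unwrap_tensor_subclass

m = nn.Sequential(nn.Linear(32, 32)).eval()
m = quantize(m, get_apply_int8wo_quant())
unwrap_tensor_subclass(m)

with tempfile.NamedTemporaryFile() as f:
    torch.save(m.state_dict(), f)            # quantized weights serialized
    f.seek(0)
    state_dict = torch.load(f)

m2 = nn.Sequential(nn.Linear(32, 32)).eval()
m2 = quantize(m2, get_apply_int8wo_quant())  # rebuild the quantized structure
unwrap_tensor_subclass(m2)
m2.load_state_dict(state_dict)               # then restore the saved weights
```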

torchao/dtypes/aqt.py

Lines changed: 8 additions & 6 deletions
```diff
@@ -386,11 +386,13 @@ def __new__(
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        device=None,
         dtype=None,
+        memory_format=None,
         strides=None,
     ):
         kwargs = {}
-        kwargs["device"] = layout_tensor.device
+        kwargs["device"] = layout_tensor.device if device is None else device
         kwargs["layout"] = (
             kwargs.get("layout") if kwargs.get("layout", False) else layout_tensor.layout
         )
@@ -500,7 +502,7 @@ def from_float(
         )
 
     @property
-    def layout(self) -> str:
+    def extended_layout(self) -> str:
         return self.layout_tensor.extended_layout
 
     @classmethod
@@ -596,8 +598,8 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
             is_cuda and
             input_is_int8 and
             input_tensor.dtype == weight_qtensor.dtype and
-            input_tensor.layout == "plain" and
-            weight_qtensor.layout == "plain"
+            input_tensor.extended_layout == "plain" and
+            weight_qtensor.extended_layout == "plain"
         ):
             #
             # 1. do the matrix form of dot(X_i, W_j)
@@ -639,7 +641,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
             weight_qtensor.dtype == torch.bfloat16 and
             len(weight_qtensor.shape) == 2 and
             weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
-            weight_qtensor.layout == "tensor_core_tiled"
+            weight_qtensor.extended_layout == "tensor_core_tiled"
         ):
             assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
             assert input_tensor.shape[-1] == weight_qtensor.shape[1], (
@@ -682,7 +684,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
             weight_qtensor.block_size[0] == 1 and
             weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
             weight_qtensor.zero_point_domain == ZeroPointDomain.INT and
-            weight_qtensor.layout == "plain"
+            weight_qtensor.extended_layout == "plain"
         ):
             # TODO: enable cpu and mps efficient path
             # per channel int8 weight only quantizated mm
```
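
The `aqt.py` change renames the tensor subclass's string-valued `layout` property to `extended_layout` (and updates the dispatch checks accordingly), presumably because `layout` already means something else on `torch.Tensor`. A small illustrative sketch of that naming clash, not torchao code:

```python
# Illustrative only: shows why a string-valued `layout` property on a tensor
# subclass would shadow torch's built-in notion of layout (torch.strided, ...).
import torch

class Wrapped(torch.Tensor):
    @property
    def extended_layout(self) -> str:
        # separate, string-valued description of how the data is packed
        return "plain"

w = torch.randn(2, 2).as_subclass(Wrapped)
print(w.layout)           # torch.strided -- the built-in torch layout, untouched
print(w.extended_layout)  # "plain"       -- the packing format under its own name
```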
