Commit d593e74

Refactor torchao.prototype.parq.quant.quant_api
1 parent ef8f26c commit d593e74

7 files changed: +182 −129 lines

test/prototype/test_dynamic_activation_lut.py
Lines changed: 4 additions & 2 deletions

@@ -14,8 +14,10 @@
 import torch.nn as nn
 
 from torchao.core.config import AOBaseConfig
-from torchao.prototype.parq.quant import StretchedUnifTorchaoQuantizer
-from torchao.prototype.parq.quant.quant_api import StretchedIntxWeightOnlyConfig
+from torchao.prototype.parq.quant import (
+    StretchedIntxWeightOnlyConfig,
+    StretchedUnifTorchaoQuantizer,
+)
 from torchao.prototype.quantization.dynamic_activation_lut import (
     StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig,
 )

test/prototype/test_parq.py
Lines changed: 1 addition & 1 deletion

@@ -21,12 +21,12 @@
     Int4UnifTorchaoQuantizer,
     LSBQuantizer,
     Quantizer,
+    StretchedIntxWeightOnlyConfig,
     StretchedUnifTorchaoQuantizer,
     TernaryUnifQuantizer,
     UnifQuantizer,
     UnifTorchaoQuantizer,
 )
-from torchao.prototype.parq.quant.quant_api import StretchedIntxWeightOnlyConfig
 from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
 from torchao.quantization.granularity import PerGroup
 from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

torchao/prototype/parq/optim/quantopt.py
Lines changed: 28 additions & 52 deletions

@@ -13,25 +13,9 @@
 from torch import Tensor, nn
 from torch.optim import Optimizer
 
-from torchao.quantization import (
-    Int4WeightOnlyConfig,
-    Int8DynamicActivationIntxWeightConfig,
-    IntxWeightOnlyConfig,
-    MappingType,
-    PerGroup,
-    PerRow,
-    quantize_,
-)
-from torchao.quantization.quantize_.common import PackingFormat
-
-from ..quant import Quantizer
-from ..quant.quant_api import StretchedIntxWeightOnlyConfig
-from ..quant.uniform_torchao import (
-    _BIT_WIDTH_TO_DTYPE,
-    Int4UnifTorchaoQuantizer,
-    StretchedUnifTorchaoQuantizer,
-    UnifTorchaoQuantizer,
-)
+from torchao.quantization import quantize_
+
+from ..quant import Quantizer, UnifTorchaoQuantizer, get_config_from_quantizer
 from ..utils import HAS_DTENSOR, is_dtensor
 from .proxmap import ProxMap
 

@@ -109,6 +93,23 @@ def __repr__(self) -> str:
     def state(self) -> defaultdict[Tensor, Any]:  # pyre-ignore[3]
         return self._state if hasattr(self, "_state") else self.base_optimizer.state
 
+    @property
+    def num_steps(self) -> int:
+        for group in self.regularized_param_groups():
+            return group.setdefault("num_steps", 0)
+
+    @num_steps.setter
+    def num_steps(self, value: int) -> None:
+        for group in self.regularized_param_groups():
+            group["num_steps"] = value
+            return
+
+    @num_steps.deleter
+    def num_steps(self) -> None:
+        for group in self.regularized_param_groups():
+            group.pop("num_steps", None)
+            return
+
     @staticmethod
     def quantize_(
         p: Tensor,

@@ -165,40 +166,15 @@ def torchao_convert(self, model: nn.Module) -> None:
             if not isinstance(quantizer, UnifTorchaoQuantizer):
                 continue
 
-            weight_dtype = _BIT_WIDTH_TO_DTYPE[group["quant_bits"]]
-            granularity = (
-                PerGroup(group["quant_block_size"])
-                if "quant_block_size" in group
-                else PerRow()
+            device = group["params"][0].device
+            is_embed = all(p.data_ptr() in embed_data_ptrs for p in group["params"])
+            config = get_config_from_quantizer(
+                quantizer,
+                is_embed,
+                device,
+                group["quant_bits"],
+                group.get("quant_block_size"),
             )
-            version = 2
-            if isinstance(quantizer, Int4UnifTorchaoQuantizer):
-                config = Int4WeightOnlyConfig(group_size=group["quant_block_size"])
-            elif isinstance(quantizer, StretchedUnifTorchaoQuantizer):
-                config = StretchedIntxWeightOnlyConfig(
-                    b=group["quant_bits"],
-                    quant_min=quantizer.quant_min,
-                    quant_max=quantizer.quant_max,
-                    granularity=granularity,
-                    version=version,
-                )
-            elif all(p.data_ptr() in embed_data_ptrs for p in group["params"]):
-                config = IntxWeightOnlyConfig(
-                    weight_dtype=weight_dtype,
-                    granularity=granularity,
-                    mapping_type=quantizer.mapping_type,
-                    packing_format=PackingFormat.UNPACKED_TO_INT8,
-                    version=version,
-                )
-            else:
-                config = Int8DynamicActivationIntxWeightConfig(
-                    weight_dtype=weight_dtype,
-                    weight_granularity=granularity,
-                    weight_mapping_type=quantizer.mapping_type,
-                    act_mapping_type=MappingType.ASYMMETRIC,
-                    packing_format=PackingFormat.UNPACKED_TO_INT8,
-                    version=version,
-                )
             quantize_(model, config, filter_fn=filter_fn)
 
     @torch._disable_dynamo
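
The refactored torchao_convert delegates config selection to get_config_from_quantizer (see config_torchao.py below), while the new num_steps property keeps its counter inside a regularized param-group dict, so it travels with the optimizer's state_dict instead of living as a plain attribute. A runnable sketch of that property pattern, with hypothetical stand-in names rather than QuantOptimizer internals:

# Illustrative sketch of the num_steps property pattern above; the class and
# helper names here are hypothetical stand-ins, not from the repo.
class GroupCounter:
    def __init__(self, param_groups):
        self.param_groups = param_groups  # list of dicts, as in torch.optim

    def regularized_param_groups(self):
        # PARQ regularizes only some groups; this sketch yields all of them.
        yield from self.param_groups

    @property
    def num_steps(self) -> int:
        for group in self.regularized_param_groups():
            return group.setdefault("num_steps", 0)
        return 0  # no regularized groups

    @num_steps.setter
    def num_steps(self, value: int) -> None:
        for group in self.regularized_param_groups():
            group["num_steps"] = value
            return


opt = GroupCounter([{"params": []}])
opt.num_steps += 1  # getter seeds the key with 0, setter writes back 1
assert opt.num_steps == 1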

torchao/prototype/parq/quant/__init__.py
Lines changed: 4 additions & 0 deletions

@@ -4,6 +4,10 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+from .config_torchao import (  # noqa: F401
+    StretchedIntxWeightOnlyConfig,
+    get_config_from_quantizer,
+)
 from .lsbq import LSBQuantizer  # noqa: F401
 from .quantizer import Quantizer  # noqa: F401
 from .uniform import (  # noqa: F401
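
With this re-export, callers (including the updated tests above) import both names from the package root instead of reaching into the quant_api submodule:

# New public import path exercised by the updated tests.
from torchao.prototype.parq.quant import (
    StretchedIntxWeightOnlyConfig,
    get_config_from_quantizer,
)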

torchao/prototype/parq/quant/config_torchao.py
Lines changed: 141 additions & 0 deletions

@@ -0,0 +1,141 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+from torchao.core.config import AOBaseConfig
+from torchao.dtypes import Int4CPULayout
+from torchao.quantization import MappingType, PerAxis, PerGroup
+from torchao.quantization.quant_api import (
+    Int4WeightOnlyConfig,
+    Int8DynamicActivationIntxWeightConfig,
+    IntxWeightOnlyConfig,
+)
+from torchao.quantization.quantize_.common.packing_format import PackingFormat
+from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor
+from torchao.quantization.transform_module import register_quantize_module_handler
+from torchao.utils import check_cpu_version
+
+from .quant_api import (
+    choose_qparams_stretched_affine,
+    quantize_stretched_affine,
+    to_stretched_affine_quantized_intx,
+)
+from .uniform_torchao import (
+    _BIT_WIDTH_TO_DTYPE,
+    Int4UnifTorchaoQuantizer,
+    StretchedUnifTorchaoQuantizer,
+)
+
+
+@dataclass
+class StretchedIntxWeightOnlyConfig(IntxWeightOnlyConfig):
+    b: Optional[int] = None
+    quant_min: Optional[int] = None
+    quant_max: Optional[int] = None
+    activation_quantization: Optional[str] = "int8_asym_per_token"
+
+
+@register_quantize_module_handler(StretchedIntxWeightOnlyConfig)
+def _stretched_intx_weight_only_transform(
+    module: torch.nn.Module, config: StretchedIntxWeightOnlyConfig
+) -> torch.nn.Module:
+    weight = module.weight
+    granularity = config.granularity
+    mapping_type = MappingType.ASYMMETRIC
+
+    assert weight.dim() == 2, (
+        f"StretchedIntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}"
+    )
+    if isinstance(granularity, PerGroup):
+        group_size = granularity.group_size
+    elif isinstance(granularity, PerAxis):
+        assert granularity.axis == 0, (
+            f"axis must be 0 with PerAxis, but got {granularity.axis}"
+        )
+        group_size = weight.shape[-1]
+    else:
+        raise ValueError(f"granularity must be PerGroup or PerAxis, got {granularity}")
+
+    block_size = (1, group_size)
+    target_dtype = torch.int8
+    q_args = (weight, mapping_type, block_size, target_dtype, config.b)
+    if config.version == 2:
+        scale, zero_point = choose_qparams_stretched_affine(
+            *q_args,
+            quant_min=config.quant_min,
+            quant_max=config.quant_max,
+        )
+        qdata = quantize_stretched_affine(
+            weight,
+            block_size,
+            scale,
+            zero_point,
+            target_dtype,
+            quant_min=config.quant_min,
+            quant_max=config.quant_max,
+        )
+        n_blocks = [qdata.shape[i] // block_size[i] for i in range(len(block_size))]
+        scale = scale.reshape(*n_blocks)
+        zero_point = zero_point.reshape(*n_blocks)
+
+        weight = IntxUnpackedToInt8Tensor(
+            qdata=qdata,
+            scale=scale,
+            zero_point=zero_point,
+            target_dtype=getattr(torch, f"int{config.b}"),
+            block_size=block_size,
+            dtype=weight.dtype,
+            activation_quantization=config.activation_quantization,
+        )
+    else:
+        weight = to_stretched_affine_quantized_intx(
+            *q_args,
+            quant_min=config.quant_min,
+            quant_max=config.quant_max,
+            scale_dtype=config.scale_dtype,
+            _layout=config.layout,
+        )
+    module.weight = torch.nn.Parameter(weight, requires_grad=False)
+    return module
+
+
+def get_config_from_quantizer(
+    quantizer,
+    is_embed: bool,
+    device: torch.device,
+    b: int,
+    block_size: Optional[int],
+    version: int = 2,
+) -> AOBaseConfig:
+    granularity = PerGroup(block_size) if block_size is not None else PerAxis(0)
+    weight_dtype = _BIT_WIDTH_TO_DTYPE[b]
+    if isinstance(quantizer, Int4UnifTorchaoQuantizer):
+        kwargs = {"layout": Int4CPULayout()} if check_cpu_version(device) else {}
+        config = Int4WeightOnlyConfig(group_size=block_size, **kwargs)
+    elif isinstance(quantizer, StretchedUnifTorchaoQuantizer):
+        config = StretchedIntxWeightOnlyConfig(
+            b=b,
+            quant_min=quantizer.quant_min,
+            quant_max=quantizer.quant_max,
+            granularity=granularity,
+            version=version,
+        )
+    elif is_embed:
+        config = IntxWeightOnlyConfig(
+            weight_dtype=weight_dtype,
+            granularity=granularity,
+            mapping_type=quantizer.mapping_type,
+            packing_format=PackingFormat.UNPACKED_TO_INT8,
+            version=version,
+        )
+    else:
+        config = Int8DynamicActivationIntxWeightConfig(
+            weight_dtype=weight_dtype,
+            weight_granularity=granularity,
+            weight_mapping_type=quantizer.mapping_type,
+            act_mapping_type=MappingType.ASYMMETRIC,
+            packing_format=PackingFormat.UNPACKED_TO_INT8,
+            version=version,
+        )
+    return config
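
A hedged usage sketch of the new helper: it maps a quantizer instance, bit width, and optional group size to the matching AOBaseConfig. The StretchedUnifTorchaoQuantizer constructor call below is an assumption, not taken from this diff:

import torch

from torchao.prototype.parq.quant import (
    StretchedUnifTorchaoQuantizer,
    get_config_from_quantizer,
)

quantizer = StretchedUnifTorchaoQuantizer(b=2)  # assumed constructor signature
config = get_config_from_quantizer(
    quantizer,
    is_embed=False,              # linear weight, not an embedding
    device=torch.device("cpu"),  # the Int4 path would pick Int4CPULayout here
    b=2,                         # bit width -> _BIT_WIDTH_TO_DTYPE[2]
    block_size=32,               # PerGroup(32); None would select PerAxis(0)
)
# For a StretchedUnifTorchaoQuantizer this returns a StretchedIntxWeightOnlyConfig.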

torchao/prototype/parq/quant/quant_api.py
Lines changed: 3 additions & 70 deletions

@@ -8,23 +8,20 @@
 from typing import Optional, Tuple, Union
 
 import torch
-from torch import nn
 
 from torchao.dtypes import AffineQuantizedTensor, Layout, QDQLayout
 from torchao.quantization import (
     MappingType,
-    PerAxis,
-    PerGroup,
     ZeroPointDomain,
     dequantize_affine,
 )
-from torchao.quantization.quant_api import IntxWeightOnlyConfig
+from torchao.quantization.quant_api import (
+    IntxWeightOnlyConfig,
+)
 from torchao.quantization.quant_primitives import (
     _SUB_BYTE_UINT_BOUNDS,
     _get_reduction_params,
 )
-from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor
-from torchao.quantization.transform_module import register_quantize_module_handler
 
 
 def choose_qparams_stretched_affine(

@@ -188,67 +185,3 @@ class StretchedIntxWeightOnlyConfig(IntxWeightOnlyConfig):
     quant_min: Optional[int] = None
     quant_max: Optional[int] = None
     activation_quantization: Optional[str] = "int8_asym_per_token"
-
-
-@register_quantize_module_handler(StretchedIntxWeightOnlyConfig)
-def _stretched_intx_weight_only_transform(
-    module: nn.Module, config: StretchedIntxWeightOnlyConfig
-) -> nn.Module:
-    weight = module.weight
-    granularity = config.granularity
-    mapping_type = MappingType.ASYMMETRIC
-
-    assert weight.dim() == 2, (
-        f"StretchedIntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}"
-    )
-    if isinstance(granularity, PerGroup):
-        group_size = granularity.group_size
-    elif isinstance(granularity, PerAxis):
-        assert granularity.axis == 0, (
-            f"axis must be 0 with PerAxis, but got {granularity.axis}"
-        )
-        group_size = weight.shape[-1]
-    else:
-        raise ValueError(f"granularity must be PerGroup or PerAxis, got {granularity}")
-
-    block_size = (1, group_size)
-    target_dtype = torch.int8
-    q_args = (weight, mapping_type, block_size, target_dtype, config.b)
-    if config.version == 2:
-        scale, zero_point = choose_qparams_stretched_affine(
-            *q_args,
-            quant_min=config.quant_min,
-            quant_max=config.quant_max,
-        )
-        qdata = quantize_stretched_affine(
-            weight,
-            block_size,
-            scale,
-            zero_point,
-            target_dtype,
-            quant_min=config.quant_min,
-            quant_max=config.quant_max,
-        )
-        n_blocks = [qdata.shape[i] // block_size[i] for i in range(len(block_size))]
-        scale = scale.reshape(*n_blocks)
-        zero_point = zero_point.reshape(*n_blocks)
-
-        weight = IntxUnpackedToInt8Tensor(
-            qdata=qdata,
-            scale=scale,
-            zero_point=zero_point,
-            target_dtype=getattr(torch, f"int{config.b}"),
-            block_size=block_size,
-            dtype=weight.dtype,
-            activation_quantization=config.activation_quantization,
-        )
-    else:
-        weight = to_stretched_affine_quantized_intx(
-            *q_args,
-            quant_min=config.quant_min,
-            quant_max=config.quant_max,
-            scale_dtype=config.scale_dtype,
-            _layout=config.layout,
-        )
-    module.weight = torch.nn.Parameter(weight, requires_grad=False)
-    return module
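
The transform handler moved verbatim to config_torchao.py; configs are still consumed through torchao.quantization.quantize_, as torchao_convert does above. A minimal sketch of that consumer side, with quant_min/quant_max taken from the quantizer exactly as get_config_from_quantizer does (the quantizer constructor call is an assumption):

import torch

from torchao.prototype.parq.quant import (
    StretchedIntxWeightOnlyConfig,
    StretchedUnifTorchaoQuantizer,
)
from torchao.quantization import PerGroup, quantize_

quantizer = StretchedUnifTorchaoQuantizer(b=2)  # assumed constructor signature
model = torch.nn.Sequential(torch.nn.Linear(64, 64))
config = StretchedIntxWeightOnlyConfig(
    b=2,
    quant_min=quantizer.quant_min,
    quant_max=quantizer.quant_max,
    granularity=PerGroup(32),
    version=2,  # routes through the IntxUnpackedToInt8Tensor branch
)
quantize_(model, config)  # dispatches to _stretched_intx_weight_only_transform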

torchao/prototype/parq/quant/uniform_torchao.py
Lines changed: 1 addition & 4 deletions

@@ -27,10 +27,7 @@
     quantize_affine,
 )
 
-from .quant_api import (
-    choose_qparams_stretched_affine,
-    quantize_stretched_affine,
-)
+from .quant_api import choose_qparams_stretched_affine, quantize_stretched_affine
 from .quantizer import Quantizer
 
 _BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()}
