
Commit e33264a

Rename to StretchedIntxWeightConfig

1 parent 1ccb298 commit e33264a

This commit renames Int8DynamicActivationStretchedIntxWeightConfig to StretchedIntxWeightConfig throughout the PARQ prototype: the config dataclass and its handler registration in config_torchao.py, the re-export in __init__.py, and both test files. Along the way, the tests now pass version=1 explicitly, test_parq.py folds two skip decorators into a single CUDA-only condition, and _get_config_from_quantizer moves its CPU version=1 override after the branch logic so it applies to the final config.

File tree: 4 files changed, +23 −19 lines

  test/prototype/test_dynamic_activation_lut.py
  test/prototype/test_parq.py
  torchao/prototype/parq/quant/__init__.py
  torchao/prototype/parq/quant/config_torchao.py
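
For callers, the change is mechanical: only the class name moves, its module does not. A minimal before/after sketch using just the names in this diff:

# Before this commit:
# from torchao.prototype.parq.quant import Int8DynamicActivationStretchedIntxWeightConfig

# After this commit:
from torchao.prototype.parq.quant import StretchedIntxWeightConfig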

test/prototype/test_dynamic_activation_lut.py

Lines changed: 5 additions & 3 deletions

@@ -12,7 +12,7 @@
 import torch

 from torchao.prototype.parq.quant import (
-    Int8DynamicActivationStretchedIntxWeightConfig,
+    StretchedIntxWeightConfig,
     StretchedUnifTorchaoQuantizer,
 )
 from torchao.prototype.quantization.dynamic_activation_lut import (

@@ -63,12 +63,13 @@ def run_before_and_after_tests():
 def test_parq_conversion(dtype, granularity, bit_width, lead_dim):
     torch.manual_seed(0)
     quantizer = StretchedUnifTorchaoQuantizer(bit_width)
-    config = Int8DynamicActivationStretchedIntxWeightConfig(
+    config = StretchedIntxWeightConfig(
         b=bit_width,
         quant_min=quantizer.quant_min,
         quant_max=quantizer.quant_max,
         granularity=granularity,
         activation_quantization=None,
+        version=1,
     )

     parq_model = ToyLinearModel(128, 256, 128, 1).to(dtype)

@@ -114,12 +115,13 @@ def test_parq_conversion(dtype, granularity, bit_width, lead_dim):
 @pytest.mark.skipif(not is_arm64_mac, reason="requires arm64 mac")
 def test_export(dtype, granularity, bit_width, lead_dim):
     quantizer = StretchedUnifTorchaoQuantizer(bit_width)
-    config = Int8DynamicActivationStretchedIntxWeightConfig(
+    config = StretchedIntxWeightConfig(
         b=bit_width,
         quant_min=quantizer.quant_min,
         quant_max=quantizer.quant_max,
         granularity=granularity,
         activation_quantization=None,
+        version=1,
     )

     parq_model = ToyLinearModel(128, 256, 128, 8).to(dtype)
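
Condensed, the updated test flow builds the renamed config and applies it to a model. A hedged sketch of that flow: quantize_'s import path does not appear in this diff and is an assumption, and a plain nn.Linear stands in for the test helper ToyLinearModel so the snippet is self-contained.

import torch
from torchao.prototype.parq.quant import (
    StretchedIntxWeightConfig,
    StretchedUnifTorchaoQuantizer,
)
from torchao.quantization import quantize_  # assumed import path

bit_width = 2
quantizer = StretchedUnifTorchaoQuantizer(bit_width)
config = StretchedIntxWeightConfig(
    b=bit_width,
    quant_min=quantizer.quant_min,
    quant_max=quantizer.quant_max,
    activation_quantization=None,  # weight-only, matching the updated tests
    version=1,  # now pinned explicitly by the tests
)

model = torch.nn.Sequential(torch.nn.Linear(128, 256))
quantize_(model, config)  # dispatches to the handler registered for the config's type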

test/prototype/test_parq.py

Lines changed: 10 additions & 7 deletions

@@ -19,9 +19,9 @@
 )
 from torchao.prototype.parq.quant import (
     Int4UnifTorchaoQuantizer,
-    Int8DynamicActivationStretchedIntxWeightConfig,
     LSBQuantizer,
     Quantizer,
+    StretchedIntxWeightConfig,
     StretchedUnifTorchaoQuantizer,
     TernaryUnifQuantizer,
     UnifQuantizer,

@@ -237,11 +237,14 @@ class TestUnifTorchaoQuantizer(common_utils.TestCase):
     def setUp(self):
         torch.manual_seed(123)

-    @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch >= 2.8.0")
     @unittest.skipIf(
-        torch.cuda.is_available()
-        and (not is_sm_at_least_90() or not _is_fbgemm_genai_gpu_available()),
-        "Requires sm90+ and fbgemm-gpu-genai >= 1.2.0 if GPU available",
+        _DEVICE == "cuda"
+        and (
+            not torch_version_at_least("2.8.0")
+            or not is_sm_at_least_90()
+            or not _is_fbgemm_genai_gpu_available()
+        ),
+        "Requires pytorch >= 2.8.0, sm90+ and fbgemm-gpu-genai >= 1.2.0 on GPU",
     )
     @common_utils.parametrize("group_size", [32, 256])
     def test_int4_weight_only(self, group_size: int = 32):

@@ -361,7 +364,7 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         quantize_(
             m_ref,
-            Int8DynamicActivationStretchedIntxWeightConfig(
+            StretchedIntxWeightConfig(
                 b=b,
                 quant_min=quantizer.quant_min,
                 quant_max=quantizer.quant_max,

@@ -381,7 +384,7 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         quantizer = StretchedUnifTorchaoQuantizer(b)

         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
-        config = Int8DynamicActivationStretchedIntxWeightConfig(
+        config = StretchedIntxWeightConfig(
             b=b,
             quant_min=quantizer.quant_min,
             quant_max=quantizer.quant_max,
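
The two decorators collapse into one, and the gate changes from torch.cuda.is_available() to _DEVICE == "cuda", so the pytorch >= 2.8.0 requirement no longer skips CPU-only runs. A minimal sketch of the resulting pattern; the helper import location is an assumption (the diff does not show it), and the private fbgemm check is omitted to keep the snippet self-contained:

import unittest

import torch
from torchao.utils import is_sm_at_least_90, torch_version_at_least  # assumed location

_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class TestInt4WeightOnly(unittest.TestCase):
    @unittest.skipIf(
        _DEVICE == "cuda"
        and (not torch_version_at_least("2.8.0") or not is_sm_at_least_90()),
        "Requires pytorch >= 2.8.0 and sm90+ on GPU",
    )
    def test_int4_weight_only(self):
        ...  # the decorator, not the body, is the point here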

torchao/prototype/parq/quant/__init__.py

Lines changed: 1 addition & 3 deletions

@@ -4,9 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

-from .config_torchao import (  # noqa: F401
-    Int8DynamicActivationStretchedIntxWeightConfig,
-)
+from .config_torchao import StretchedIntxWeightConfig  # noqa: F401
 from .lsbq import LSBQuantizer  # noqa: F401
 from .quantizer import Quantizer  # noqa: F401
 from .uniform import (  # noqa: F401

torchao/prototype/parq/quant/config_torchao.py

Lines changed: 7 additions & 6 deletions

@@ -42,7 +42,7 @@


 @dataclass
-class Int8DynamicActivationStretchedIntxWeightConfig(AOBaseConfig):
+class StretchedIntxWeightConfig(AOBaseConfig):
     granularity: Granularity = PerAxis(0)
     scale_dtype: Optional[torch.dtype] = None
     layout: Layout = QDQLayout()

@@ -53,16 +53,16 @@ class Int8DynamicActivationStretchedIntxWeightConfig(AOBaseConfig):
     activation_quantization: Optional[str] = "int8_asym_per_token"


-@register_quantize_module_handler(Int8DynamicActivationStretchedIntxWeightConfig)
+@register_quantize_module_handler(StretchedIntxWeightConfig)
 def _int8_dynamic_activation_stretched_intx_transform(
-    module: nn.Module, config: Int8DynamicActivationStretchedIntxWeightConfig
+    module: nn.Module, config: StretchedIntxWeightConfig
 ) -> nn.Module:
     weight = module.weight
     granularity = config.granularity
     mapping_type = MappingType.ASYMMETRIC

     assert weight.dim() == 2, (
-        f"Int8DynamicActivationStretchedIntxWeightConfig only works for 2-d Tensor, got: {weight.dim()}"
+        f"StretchedIntxWeightConfig only works for 2-d Tensor, got: {weight.dim()}"
     )
     if isinstance(granularity, PerGroup):
         group_size = granularity.group_size

@@ -138,9 +138,8 @@ def _get_config_from_quantizer(
        )
        if check_cpu_version(device):
            config.layout = Int4CPULayout()
-            config.version = 1
    elif isinstance(quantizer, StretchedUnifTorchaoQuantizer):
-        config = Int8DynamicActivationStretchedIntxWeightConfig(
+        config = StretchedIntxWeightConfig(
            b=b,
            quant_min=quantizer.quant_min,
            quant_max=quantizer.quant_max,

@@ -164,6 +163,8 @@ def _get_config_from_quantizer(
            act_mapping_type=MappingType.ASYMMETRIC,
            version=version,
        )
+    if check_cpu_version(device):
+        config.version = 1
    return config


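Two things are worth noting in this file. First, register_quantize_module_handler keys quantize_'s dispatch on the config class, which is why the rename touches the decorator and the type annotation while the transform keeps its old int8-flavored function name. Second, the CPU version=1 override moves below the branch logic, so it now applies to whichever config was constructed rather than only the Int4 CPU path. An illustrative sketch of the dispatch pattern, not torchao's actual implementation:

from typing import Callable, Dict

import torch.nn as nn

# Hypothetical registry standing in for torchao's real one.
_HANDLERS: Dict[type, Callable[[nn.Module, object], nn.Module]] = {}

def register_quantize_module_handler(config_cls: type) -> Callable:
    def decorator(fn: Callable) -> Callable:
        _HANDLERS[config_cls] = fn  # config class -> module transform
        return fn
    return decorator

def apply_handler(module: nn.Module, config: object) -> nn.Module:
    # Lookup is by the config's type, so renaming the config class changes
    # the registry key while the handler body stays untouched.
    return _HANDLERS[type(config)](module, config)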