52 changes: 15 additions & 37 deletions test/prototype/test_dynamic_activation_lut.py
@@ -7,47 +7,25 @@
 import platform
 import sys
 from copy import deepcopy
-from dataclasses import dataclass
 
 import pytest
 import torch
-import torch.nn as nn
 
-from torchao.core.config import AOBaseConfig
-from torchao.prototype.parq.quant import StretchedUnifTorchaoQuantizer
-from torchao.prototype.parq.quant.quant_api import StretchedIntxWeightOnlyConfig
+from torchao.prototype.parq.quant import (
+    StretchedIntxWeightConfig,
+    StretchedUnifTorchaoQuantizer,
+)
 from torchao.prototype.quantization.dynamic_activation_lut import (
     StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig,
 )
 from torchao.quantization import quantize_
 from torchao.quantization.granularity import PerAxis, PerGroup
-from torchao.quantization.linear_activation_quantized_tensor import (
-    to_linear_activation_quantized,
-)
-from torchao.quantization.quant_api import (
-    _int8_asymm_per_token_quant,
-)
-from torchao.quantization.transform_module import register_quantize_module_handler
+from torchao.quantization.quant_api import _is_linear
 from torchao.quantization.utils import compute_error
 
 is_arm64_mac = sys.platform == "darwin" and platform.machine() == "arm64"
 
 
-@dataclass
-class Int8DynamicActivationConfig(AOBaseConfig):
-    pass
-
-
-@register_quantize_module_handler(Int8DynamicActivationConfig)
-def _int8_dynamic_activation_transform(
-    module: nn.Module, config: Int8DynamicActivationConfig
-) -> nn.Module:
-    weight = module.weight
-    weight = to_linear_activation_quantized(weight, _int8_asymm_per_token_quant)
-    module.weight = torch.nn.Parameter(weight, requires_grad=False)
-    return module
-
-
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, d1=512, d2=256, d3=128, d4=8):
         super().__init__()
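For context, the helper deleted in this hunk can be read as one self-contained block (the removed lines plus the imports they relied on): a no-op AOBaseConfig subclass whose registered handler wraps each linear weight with to_linear_activation_quantized so activations are int8-quantized per token at runtime. The activation_quantization field on StretchedIntxWeightConfig, used in the next hunk, covers the same need.

from dataclasses import dataclass

import torch
import torch.nn as nn

from torchao.core.config import AOBaseConfig
from torchao.quantization.linear_activation_quantized_tensor import (
    to_linear_activation_quantized,
)
from torchao.quantization.quant_api import _int8_asymm_per_token_quant
from torchao.quantization.transform_module import register_quantize_module_handler


@dataclass
class Int8DynamicActivationConfig(AOBaseConfig):
    pass


@register_quantize_module_handler(Int8DynamicActivationConfig)
def _int8_dynamic_activation_transform(
    module: nn.Module, config: Int8DynamicActivationConfig
) -> nn.Module:
    # Wrap the (already quantized) weight so each forward pass dynamically
    # quantizes the incoming activation to int8, asymmetric per token.
    weight = to_linear_activation_quantized(module.weight, _int8_asymm_per_token_quant)
    module.weight = torch.nn.Parameter(weight, requires_grad=False)
    return module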
@@ -85,26 +63,24 @@ def run_before_and_after_tests():
 def test_parq_conversion(dtype, granularity, bit_width, lead_dim):
     torch.manual_seed(0)
     quantizer = StretchedUnifTorchaoQuantizer(bit_width)
-    config = StretchedIntxWeightOnlyConfig(
+    config = StretchedIntxWeightConfig(
         b=bit_width,
         quant_min=quantizer.quant_min,
         quant_max=quantizer.quant_max,
         granularity=granularity,
+        activation_quantization=None,
+        version=1,
     )
 
     parq_model = ToyLinearModel(128, 256, 128, 1).to(dtype)
     activations = parq_model.example_inputs(lead_dim=lead_dim, dtype=dtype)
+    parq_model_with_dyn_quant = deepcopy(parq_model)
     quantize_(parq_model, config)
 
     # Apply dynamic activation to parq model. This will serve as the LUT reference
-    parq_model_with_dyn_quant = deepcopy(parq_model)
-    quantize_(
-        parq_model_with_dyn_quant,
-        Int8DynamicActivationConfig(),
-        # We have to explicitly provide filter_fn because the default linear filter
-        # excludes modules with AffinQUnatizedTensor weights
-        filter_fn=lambda m, fqn: isinstance(m, torch.nn.Linear),
-    )
+    dyn_act_config = deepcopy(config)
+    dyn_act_config.activation_quantization = "int8_asym_per_token"
+    quantize_(parq_model_with_dyn_quant, dyn_act_config, filter_fn=_is_linear)
 
     # Convert PARQ model to lowbit LUT model
     lut_model = deepcopy(parq_model)
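Taken together, this hunk reduces the reference setup to a small pattern: build the weight-only config once, deep-copy it, and flip activation_quantization to get the dynamic-activation variant. A minimal standalone sketch of that flow follows, using only APIs this test already imports; the bit width, group size, tensor shapes, and the SQNR print at the end are illustrative choices, not part of the diff.

from copy import deepcopy

import torch

from torchao.prototype.parq.quant import (
    StretchedIntxWeightConfig,
    StretchedUnifTorchaoQuantizer,
)
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import _is_linear
from torchao.quantization.utils import compute_error

bit_width = 4
quantizer = StretchedUnifTorchaoQuantizer(bit_width)

# Weight-only config, mirroring the construction in test_parq_conversion.
config = StretchedIntxWeightConfig(
    b=bit_width,
    quant_min=quantizer.quant_min,
    quant_max=quantizer.quant_max,
    granularity=PerGroup(32),
    activation_quantization=None,
    version=1,
)

model = torch.nn.Sequential(torch.nn.Linear(128, 64))
model_dyn = deepcopy(model)  # copy before quantizing, as the test now does

quantize_(model, config)

# Same weight settings plus int8 asymmetric per-token activation quantization.
dyn_config = deepcopy(config)
dyn_config.activation_quantization = "int8_asym_per_token"
quantize_(model_dyn, dyn_config, filter_fn=_is_linear)

x = torch.randn(2, 128)
print(compute_error(model(x), model_dyn(x)))  # SQNR between the two variants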
@@ -139,11 +115,13 @@ def test_parq_conversion(dtype, granularity, bit_width, lead_dim):
 @pytest.mark.skipif(not is_arm64_mac, reason="requires arm64 mac")
 def test_export(dtype, granularity, bit_width, lead_dim):
     quantizer = StretchedUnifTorchaoQuantizer(bit_width)
-    config = StretchedIntxWeightOnlyConfig(
+    config = StretchedIntxWeightConfig(
         b=bit_width,
         quant_min=quantizer.quant_min,
         quant_max=quantizer.quant_max,
         granularity=granularity,
+        activation_quantization=None,
+        version=1,
     )
 
     parq_model = ToyLinearModel(128, 256, 128, 8).to(dtype)
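The body of test_export is truncated here; nothing LUT-specific is visible in this hunk beyond the same config change. Under the same assumptions as the sketch above, the export step presumably comes down to standard torch.export usage, roughly:

# Continuing the sketch above: trace the dynamically quantized module into an
# ExportedProgram and inspect the resulting graph. This is a guess at what the
# truncated test checks, not a copy of its body.
example_inputs = (torch.randn(2, 128),)
exported = torch.export.export(model_dyn.eval(), example_inputs)
print(exported.graph_module.code)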