Merged
docs/source/api_ref_qat.rst (2 additions, 0 deletions)
@@ -26,10 +26,12 @@ Custom QAT APIs

FakeQuantizeConfigBase
IntxFakeQuantizeConfig
Float8FakeQuantizeConfig
FakeQuantizedLinear
FakeQuantizedEmbedding
FakeQuantizerBase
IntxFakeQuantizer
Float8FakeQuantizer
linear.enable_linear_fake_quant
linear.disable_linear_fake_quant

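For context, a minimal sketch of how the two new APIs listed here fit together, mirroring the prepare/convert tests added in this PR (the sketch is illustrative, not part of the diff):

```python
import torch

from torchao.quantization.granularity import PerRow
from torchao.quantization.qat import (
    Float8FakeQuantizeConfig,
    Float8FakeQuantizer,
)

# Rowwise float8 fake quantization in the e4m3 format
config = Float8FakeQuantizeConfig(torch.float8_e4m3fn, PerRow())
fake_quantizer = Float8FakeQuantizer(config)

# The fake quantizer simulates float8 quantize/dequantize numerics
# in high precision, so QAT can train through it
x = torch.randn(32, 64)
x_fq = fake_quantizer(x)
```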
test/quantization/test_qat.py (131 additions, 30 deletions)
@@ -14,17 +14,22 @@

import torch
import torch.nn.functional as F
from parameterized import parameterized
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
)

from torchao import quantize_
from torchao.float8.config import ScalingGranularity
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_training_tensor import LinearMMConfig
from torchao.core.config import AOBaseConfig
from torchao.quantization import Float8Tensor
from torchao.quantization.granularity import (
Granularity,
PerAxis,
PerGroup,
PerRow,
PerTensor,
PerToken,
)
from torchao.quantization.linear_quant_modules import (
@@ -43,11 +48,12 @@
FakeQuantizedEmbedding,
)
from torchao.quantization.qat.fake_quantize_config import (
Float8FakeQuantizeConfig,
IntxFakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import (
Float8FakeQuantizer,
IntxFakeQuantizer,
_Float8RowwiseActivationFakeQuantizer,
)
from torchao.quantization.qat.linear import (
FakeQuantizedLinear,
@@ -58,10 +64,11 @@
from torchao.quantization.qat.utils import (
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
_Float8RowwiseFakeQuantize,
_get_qmin_qmax,
)
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.quant_primitives import (
@@ -83,6 +90,10 @@
get_groupwise_affine_qparams,
groupwise_affine_quantize_tensor,
)
from torchao.utils import (
_is_fbgemm_genai_gpu_available,
is_sm_at_least_89,
)

# TODO: put this in a common test utils file
_CUDA_IS_AVAILABLE = torch.cuda.is_available()
@@ -193,7 +204,7 @@ def forward(self, x):
return x


class TestQAT(unittest.TestCase):
class TestQAT(TestCase):
SEED = 123

def test_fake_quantize_per_channel_group(self):
@@ -1420,7 +1431,7 @@ def test_qat_linear_bias(self):
example_inputs = m.example_inputs()
m(*example_inputs)

@parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)])
@parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
"""
Test that the following produce the exact same numerics:
@@ -1437,7 +1448,7 @@ def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
baseline_out = per_token_dynamic_quant(x)
torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0)

@parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)])
@parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
"""
Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produce
@@ -1548,7 +1559,7 @@ def test_qat_8da4w_eps(self):
actual_out = converted_model.linear1(x)
torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0)

@parameterized.expand([(True,), (False,)])
@parametrize("is_symmetric", [True, False])
def test_fake_quantizer_range_learning(self, is_symmetric):
"""
Test that range learning requires `IntxFakeQuantizer`s to be initialized correctly.
@@ -1589,7 +1600,7 @@ def test_fake_quantizer_range_learning(self, is_symmetric):
self.assertTrue(fake_quantizer.zero_point.requires_grad)
fake_quantizer(*example_inputs)

@parameterized.expand([(True,), (False,)])
@parametrize("is_symmetric", [True, False])
def test_qat_range_learning(self, is_symmetric):
"""
Test end-to-end QAT flow with range learning.
@@ -1664,24 +1675,6 @@ def test_qat_range_learning(self, is_symmetric):
self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
self.assertFalse(torch.equal(new_weight, prev_weight))

def test_float8_rowwise_fake_quantize(self):
"""
Test that `_Float8RowwiseFakeQuantize` is numerically close to `Float8TrainingTensor`.
"""
torch.manual_seed(self.SEED)
dtype = torch.float8_e4m3fn
x = torch.randn(32, 64)
axiswise_dim = 0
out = _Float8RowwiseFakeQuantize.apply(x, dtype, axiswise_dim)
out_expected = hp_tensor_to_float8_dynamic(
x,
dtype,
LinearMMConfig(),
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=axiswise_dim,
).to_original_precision()
torch.testing.assert_close(out, out_expected, atol=0, rtol=0)

def test_qat_fp8a4w_quantizer(self):
"""
Test basic model training with `Float8ActInt4WeightQATQuantizer`.
@@ -1693,7 +1686,8 @@ def test_qat_fp8a4w_quantizer(self):
for linear in [m.linear1, m.sub.linear, m.linear2]:
self.assertIsInstance(linear, FakeQuantizedLinear)
self.assertIsInstance(
linear.activation_fake_quantizer, _Float8RowwiseActivationFakeQuantizer
linear.activation_fake_quantizer,
Float8FakeQuantizer,
)
self.assertIsInstance(linear.weight_fake_quantizer, IntxFakeQuantizer)
prev_weight = copy.deepcopy(m.linear1.weight)
@@ -1833,6 +1827,113 @@ def test_qat_api_convert_no_quantization(self):
baseline_out = baseline_model(*x2)
torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)

def test_float8_fake_quantize_config(self):
"""
Test that the correct errors are thrown if `Float8FakeQuantizeConfig` is not instantiated properly.
"""
# OK
Float8FakeQuantizeConfig(torch.float8_e4m3fn)
Float8FakeQuantizeConfig(torch.float8_e4m3fn, PerRow())
Float8FakeQuantizeConfig(torch.float8_e4m3fn, PerTensor())

with self.assertRaisesRegex(ValueError, "not a float8 dtype"):
Float8FakeQuantizeConfig(torch.int8)
with self.assertRaisesRegex(
ValueError, "Please specify the granularity object instead of the class"
):
Float8FakeQuantizeConfig(granularity=PerRow)
with self.assertRaisesRegex(
ValueError, "Expected PerRow or PerTensor granularity"
):
Float8FakeQuantizeConfig(granularity=PerToken())

@parametrize("granularity", [PerTensor(), PerRow()])
def test_float8_fake_quantize(self, granularity: Granularity):
[Inline review thread on this line]
jerryzh168 (Contributor), Aug 13, 2025: can you add a same test for fp8_int4?
Author reply: added some sqnr comparison against PTQ fp8_int4

"""
Test that `Float8FakeQuantizer` is numerically close to `Float8Tensor`.
"""
dtype = torch.float8_e4m3fn
fq_config = Float8FakeQuantizeConfig(dtype, granularity)
fake_quantizer = Float8FakeQuantizer(fq_config)
torch.manual_seed(self.SEED)
x = torch.randn(32, 64)
out = fake_quantizer(x)
out_expected = Float8Tensor.to_float8(x, dtype, granularity).dequantize()
sqnr = compute_error(out, out_expected)
self.assertGreater(sqnr, 16)

def _test_quantize_api_against_ptq(
self,
base_config: AOBaseConfig,
target_prepare_sqnr: float,
target_convert_sqnr: float,
):
"""
Test the following:

quantize_(model, QATConfig(base_config, step="prepare"))
quantize_(model, QATConfig(base_config, step="convert"))

and compare model outputs of each step against:

quantize_(model, base_config)
"""
torch.manual_seed(self.SEED)
m = M().to(torch.bfloat16).cuda()
example_inputs = (m.example_inputs()[0].to(torch.bfloat16).cuda(),)

# baseline
m_baseline = copy.deepcopy(m)
quantize_(m_baseline, base_config)
out_baseline = m_baseline(*example_inputs)

# compare prepare
quantize_(m, QATConfig(base_config, step="prepare"))
out_prepared = m(*example_inputs)
prepare_sqnr = compute_error(out_prepared, out_baseline)
self.assertGreaterEqual(prepare_sqnr, target_prepare_sqnr)

# compare convert
quantize_(m, QATConfig(base_config, step="convert"))
out_converted = m(*example_inputs)
convert_sqnr = compute_error(out_converted, out_baseline)
self.assertGreaterEqual(convert_sqnr, target_convert_sqnr)

@parametrize("granularity", [PerTensor(), PerRow()])
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
def test_quantize_api_fp8_fp8(self, granularity: Granularity):
"""
Test the following:
quantize_(model, QATConfig(Float8DynamicActivationFloat8WeightConfig(granularity), step="prepare"))
quantize_(model, QATConfig(Float8DynamicActivationFloat8WeightConfig(granularity), step="convert"))
"""
self._test_quantize_api_against_ptq(
Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
target_prepare_sqnr=15,
target_convert_sqnr=float("inf"),
)

@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
@unittest.skipIf(
not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
)
def test_quantize_api_fp8_int4(self):
"""
Test the following:
quantize_(model, QATConfig(Float8DynamicActivationInt4WeightConfig(), step="prepare"))
quantize_(model, QATConfig(Float8DynamicActivationInt4WeightConfig(), step="convert"))
"""
self._test_quantize_api_against_ptq(
Float8DynamicActivationInt4WeightConfig(group_size=128),
target_prepare_sqnr=15,
target_convert_sqnr=float("inf"),
)


instantiate_parametrized_tests(TestQAT)


if __name__ == "__main__":
unittest.main()
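A note on the SQNR thresholds used throughout these tests: `compute_error` is torchao's signal-to-quantization-noise helper, reported in decibels, so `float("inf")` for the convert step means bit-exact outputs. A minimal sketch of the standard definition assumed here (`sqnr_db` is an illustrative name, not the torchao helper):

```python
import torch

def sqnr_db(reference: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
    # SQNR in dB: 20 * log10(||reference|| / ||reference - candidate||).
    # Larger is better; identical tensors give inf, which is why the
    # convert step above targets float("inf").
    signal = torch.linalg.norm(reference)
    noise = torch.linalg.norm(reference - candidate)
    return 20 * torch.log10(signal / noise)
```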
torchao/quantization/qat/__init__.py (4 additions, 0 deletions)
@@ -15,11 +15,13 @@
from .fake_quantize_config import (
FakeQuantizeConfig,
FakeQuantizeConfigBase,
Float8FakeQuantizeConfig,
IntxFakeQuantizeConfig,
)
from .fake_quantizer import (
FakeQuantizer,
FakeQuantizerBase,
Float8FakeQuantizer,
IntxFakeQuantizer,
)
from .linear import (
@@ -34,6 +36,8 @@
"QATStep",
"FakeQuantizeConfigBase",
"FakeQuantizerBase",
"Float8FakeQuantizeConfig",
"Float8FakeQuantizer",
"IntxFakeQuantizeConfig",
"IntxFakeQuantizer",
"FakeQuantizedLinear",
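Putting the new exports in context, the end-to-end flow exercised by `_test_quantize_api_against_ptq` looks roughly like this (a sketch assuming `QATConfig` is importable from `torchao.quantization.qat` and a CUDA device with sm89+; the toy model stands in for the `M()` module from test_qat.py):

```python
import torch
import torch.nn as nn

from torchao import quantize_
from torchao.quantization.granularity import PerRow
from torchao.quantization.qat import QATConfig  # assumed export
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

# Toy stand-in for the M() module used in the tests
model = nn.Sequential(nn.Linear(64, 64)).to(torch.bfloat16).cuda()
base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

# Prepare: insert fake quantizers that simulate float8 numerics during training
quantize_(model, QATConfig(base_config, step="prepare"))
# ... fine-tune the model here ...

# Convert: swap the fake quantizers for real float8 quantized weights
quantize_(model, QATConfig(base_config, step="convert"))
```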