 
 import torch
 import torch.nn.functional as F
-from parameterized import parameterized
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+from torch.testing._internal.common_utils import (
+    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
+)
 
 from torchao import quantize_
-from torchao.float8.config import ScalingGranularity
-from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
-from torchao.float8.float8_training_tensor import LinearMMConfig
+from torchao.core.config import AOBaseConfig
+from torchao.quantization import Float8Tensor
 from torchao.quantization.granularity import (
+    Granularity,
     PerAxis,
     PerGroup,
     PerRow,
+    PerTensor,
     PerToken,
 )
 from torchao.quantization.linear_quant_modules import (
@@ -43,11 +48,12 @@
     FakeQuantizedEmbedding,
 )
 from torchao.quantization.qat.fake_quantize_config import (
+    Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
+    Float8FakeQuantizer,
     IntxFakeQuantizer,
-    _Float8RowwiseActivationFakeQuantizer,
 )
 from torchao.quantization.qat.linear import (
     FakeQuantizedLinear,
@@ -58,10 +64,11 @@
 from torchao.quantization.qat.utils import (
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
-    _Float8RowwiseFakeQuantize,
     _get_qmin_qmax,
 )
 from torchao.quantization.quant_api import (
+    Float8DynamicActivationFloat8WeightConfig,
+    Float8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt4WeightConfig,
 )
 from torchao.quantization.quant_primitives import (
@@ -83,6 +90,10 @@
     get_groupwise_affine_qparams,
     groupwise_affine_quantize_tensor,
 )
+from torchao.utils import (
+    _is_fbgemm_genai_gpu_available,
+    is_sm_at_least_89,
+)
 
 # TODO: put this in a common test utils file
 _CUDA_IS_AVAILABLE = torch.cuda.is_available()
@@ -193,7 +204,7 @@ def forward(self, x):
         return x
 
 
-class TestQAT(unittest.TestCase):
+class TestQAT(TestCase):
     SEED = 123
 
     def test_fake_quantize_per_channel_group(self):
@@ -1420,7 +1431,7 @@ def test_qat_linear_bias(self):
         example_inputs = m.example_inputs()
         m(*example_inputs)
 
-    @parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)])
+    @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
     def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
         """
         Test that the following produce the exact same numerics:
@@ -1437,7 +1448,7 @@ def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
         baseline_out = per_token_dynamic_quant(x)
         torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0)
 
-    @parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)])
+    @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
     def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
         """
         Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces
@@ -1548,7 +1559,7 @@ def test_qat_8da4w_eps(self):
         actual_out = converted_model.linear1(x)
         torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0)
 
-    @parameterized.expand([(True,), (False,)])
+    @parametrize("is_symmetric", [True, False])
     def test_fake_quantizer_range_learning(self, is_symmetric):
         """
         Test that range learning requires `IntxFakeQuantizer`s to be initialized correctly.
@@ -1589,7 +1600,7 @@ def test_fake_quantizer_range_learning(self, is_symmetric):
         self.assertTrue(fake_quantizer.zero_point.requires_grad)
         fake_quantizer(*example_inputs)
 
-    @parameterized.expand([(True,), (False,)])
+    @parametrize("is_symmetric", [True, False])
     def test_qat_range_learning(self, is_symmetric):
         """
         Test end-to-end QAT flow with range learning.
@@ -1664,24 +1675,6 @@ def test_qat_range_learning(self, is_symmetric):
         self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
         self.assertFalse(torch.equal(new_weight, prev_weight))
 
-    def test_float8_rowwise_fake_quantize(self):
-        """
-        Test that `_Float8RowwiseFakeQuantize` is numerically close to `Float8TrainingTensor`.
-        """
-        torch.manual_seed(self.SEED)
-        dtype = torch.float8_e4m3fn
-        x = torch.randn(32, 64)
-        axiswise_dim = 0
-        out = _Float8RowwiseFakeQuantize.apply(x, dtype, axiswise_dim)
-        out_expected = hp_tensor_to_float8_dynamic(
-            x,
-            dtype,
-            LinearMMConfig(),
-            scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=axiswise_dim,
-        ).to_original_precision()
-        torch.testing.assert_close(out, out_expected, atol=0, rtol=0)
-
     def test_qat_fp8a4w_quantizer(self):
         """
         Test basic model training with `Float8ActInt4WeightQATQuantizer`.
@@ -1693,7 +1686,8 @@ def test_qat_fp8a4w_quantizer(self):
         for linear in [m.linear1, m.sub.linear, m.linear2]:
             self.assertIsInstance(linear, FakeQuantizedLinear)
             self.assertIsInstance(
-                linear.activation_fake_quantizer, _Float8RowwiseActivationFakeQuantizer
+                linear.activation_fake_quantizer,
+                Float8FakeQuantizer,
             )
             self.assertIsInstance(linear.weight_fake_quantizer, IntxFakeQuantizer)
         prev_weight = copy.deepcopy(m.linear1.weight)
@@ -1833,6 +1827,113 @@ def test_qat_api_convert_no_quantization(self):
         baseline_out = baseline_model(*x2)
         torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
 
+    def test_float8_fake_quantize_config(self):
+        """
+        Test that the correct errors are thrown if `Float8FakeQuantizeConfig` is not instantiated properly.
+        """
+        # OK
+        Float8FakeQuantizeConfig(torch.float8_e4m3fn)
+        Float8FakeQuantizeConfig(torch.float8_e4m3fn, PerRow())
+        Float8FakeQuantizeConfig(torch.float8_e4m3fn, PerTensor())
+
+        with self.assertRaisesRegex(ValueError, "not a float8 dtype"):
+            Float8FakeQuantizeConfig(torch.int8)
+        with self.assertRaisesRegex(
+            ValueError, "Please specify the granularity object instead of the class"
+        ):
+            Float8FakeQuantizeConfig(granularity=PerRow)
+        with self.assertRaisesRegex(
+            ValueError, "Expected PerRow or PerTensor granularity"
+        ):
+            Float8FakeQuantizeConfig(granularity=PerToken())
+
+    @parametrize("granularity", [PerTensor(), PerRow()])
+    def test_float8_fake_quantize(self, granularity: Granularity):
+        """
+        Test that `Float8FakeQuantizer` is numerically close to `Float8Tensor`.
+        """
+        dtype = torch.float8_e4m3fn
+        fq_config = Float8FakeQuantizeConfig(dtype, granularity)
+        fake_quantizer = Float8FakeQuantizer(fq_config)
+        torch.manual_seed(self.SEED)
+        x = torch.randn(32, 64)
+        out = fake_quantizer(x)
+        out_expected = Float8Tensor.to_float8(x, dtype, granularity).dequantize()
+        sqnr = compute_error(out, out_expected)
+        self.assertGreater(sqnr, 16)
+
+    def _test_quantize_api_against_ptq(
+        self,
+        base_config: AOBaseConfig,
+        target_prepare_sqnr: float,
+        target_convert_sqnr: float,
+    ):
+        """
+        Test the following:
+
+            quantize_(model, QATConfig(base_config, step="prepare"))
+            quantize_(model, QATConfig(base_config, step="convert"))
+
+        and compare model outputs of each step against:
+
+            quantize_(model, base_config)
+        """
+        torch.manual_seed(self.SEED)
+        m = M().to(torch.bfloat16).cuda()
+        example_inputs = (m.example_inputs()[0].to(torch.bfloat16).cuda(),)
+
+        # baseline
+        m_baseline = copy.deepcopy(m)
+        quantize_(m_baseline, base_config)
+        out_baseline = m_baseline(*example_inputs)
+
+        # compare prepare
+        quantize_(m, QATConfig(base_config, step="prepare"))
+        out_prepared = m(*example_inputs)
+        prepare_sqnr = compute_error(out_prepared, out_baseline)
+        self.assertGreaterEqual(prepare_sqnr, target_prepare_sqnr)
+
+        # compare convert
+        quantize_(m, QATConfig(base_config, step="convert"))
+        out_converted = m(*example_inputs)
+        convert_sqnr = compute_error(out_converted, out_baseline)
+        self.assertGreaterEqual(convert_sqnr, target_convert_sqnr)
+
+    @parametrize("granularity", [PerTensor(), PerRow()])
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
+    def test_quantize_api_fp8_fp8(self, granularity: Granularity):
+        """
+        Test the following:
+            quantize_(model, QATConfig(Float8DynamicActivationFloat8WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Float8DynamicActivationFloat8WeightConfig(), step="convert"))
+        """
+        self._test_quantize_api_against_ptq(
+            Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
+            target_prepare_sqnr=15,
+            target_convert_sqnr=float("inf"),
+        )
+
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
+    @unittest.skipIf(
+        not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+    )
+    def test_quantize_api_fp8_int4(self):
+        """
+        Test the following:
+            quantize_(model, QATConfig(Float8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Float8DynamicActivationInt4WeightConfig(), step="convert"))
+        """
+        self._test_quantize_api_against_ptq(
+            Float8DynamicActivationInt4WeightConfig(group_size=128),
+            target_prepare_sqnr=15,
+            target_convert_sqnr=float("inf"),
+        )
+
+
+instantiate_parametrized_tests(TestQAT)
+
 
 if __name__ == "__main__":
     unittest.main()
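
For context when reviewing the new tests, the sketch below (not part of the patch) spells out the prepare/convert flow that `_test_quantize_api_against_ptq` compares against plain PTQ. It substitutes a bare `nn.Linear` for the test's `M()` module, and it assumes `QATConfig` is importable from `torchao.quantization.qat` (the diff uses it but its import is outside the shown hunks) plus a CUDA device with sm89+ and bfloat16 weights, matching the skip conditions above.

import copy

import torch

from torchao import quantize_
from torchao.quantization.granularity import PerRow
from torchao.quantization.qat import QATConfig
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

# A tiny bf16 model on GPU standing in for the test's M() module.
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).to(torch.bfloat16).cuda()
base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

# PTQ baseline: quantize a copy of the model directly with the base config.
ptq_model = copy.deepcopy(model)
quantize_(ptq_model, base_config)

# QAT "prepare": insert fake quantizers so training sees quantization error.
quantize_(model, QATConfig(base_config, step="prepare"))
# ... fine-tune `model` here ...

# QAT "convert": swap fake quantization for real float8 quantization; the new
# tests check that outputs of both steps stay close (by SQNR) to the PTQ baseline.
quantize_(model, QATConfig(base_config, step="convert"))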