@@ -72,7 +72,7 @@
     AQInt8WeightOnlyQuantizedLinearWeight2,
     AQInt8WeightOnlyQuantizedLinearWeight3,
     AutoQuantizableLinearWeight,
-
+    AQFloat8WeightOnlyQuantizedLinearWeight,
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
@@ -98,6 +98,7 @@
 COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
 
 COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 
 def _int8wo_api(mod):
     if TORCH_VERSION_AT_LEAST_2_4:
@@ -744,6 +745,14 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
             AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
         )
 
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
+    @unittest.skipIf(not is_H100, "Need H100 to run")
+    def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
+        )
+
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
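For context on what the new test exercises: `_test_lin_weight_subclass_impl` swaps a plain linear weight for the tensor subclass produced by `from_float` and checks that the quantized output stays close to the reference. Below is a minimal standalone sketch of that pattern. It assumes the `torchao.quantization.autoquant` import path and the `nn.Parameter` wrapping used by the sibling int8 tests; neither appears in this hunk, so read it as illustrative rather than the harness's exact code.

```python
import torch
from torchao.quantization.autoquant import AQFloat8WeightOnlyQuantizedLinearWeight

# Mirror the is_H100 gate from the diff: the float8 path is only tested on
# devices with compute capability >= (9, 0).
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
    lin = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
    x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")
    ref = lin(x)

    # from_float wraps the bf16 weight in the float8 weight-only subclass;
    # subsequent linear calls dispatch through the quantized representation.
    lin.weight = torch.nn.Parameter(
        AQFloat8WeightOnlyQuantizedLinearWeight.from_float(lin.weight),
        requires_grad=False,
    )
    out = lin(x)

    # Weight-only fp8 should track the bf16 reference closely but not exactly.
    print(torch.max(torch.abs(out - ref)))
```

The `30` passed to `_test_lin_weight_subclass_impl` plays the same role as the `35` in the int8 test above it: a closeness threshold (minimum signal-to-quantization-noise ratio, in this harness) that the quantized output must achieve against the unquantized reference.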