diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index cbe279c12e..323802757d 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -17,6 +17,9 @@
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 
 from torchao import quantize_
+from torchao.float8.config import ScalingGranularity
+from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
+from torchao.float8.float8_tensor import LinearMMConfig
 from torchao.quantization.granularity import (
     PerAxis,
     PerGroup,
@@ -40,15 +43,18 @@
 )
 from torchao.quantization.qat.fake_quantizer import (
     FakeQuantizer,
+    _Float8RowwiseActivationFakeQuantizer,
 )
 from torchao.quantization.qat.linear import (
     FakeQuantizedLinear,
+    Float8ActInt4WeightQATQuantizer,
     Int4WeightOnlyQATLinear,
     Int8DynActInt4WeightQATLinear,
 )
 from torchao.quantization.qat.utils import (
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
+    _Float8RowwiseFakeQuantize,
     _get_qmin_qmax,
 )
 from torchao.quantization.quant_api import (
@@ -68,6 +74,7 @@
 )
 from torchao.quantization.utils import (
     _get_per_token_block_size,
+    compute_error,
     get_group_qparams_symmetric,
     get_groupwise_affine_qparams,
     groupwise_affine_quantize_tensor,
@@ -1474,7 +1481,6 @@ def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
         numerics that match exactly over N trials.
         """
         from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
-        from torchao.quantization.utils import compute_error
 
         num_trials = 1000
         group_size = 16
@@ -1688,6 +1694,61 @@ def test_qat_range_learning(self):
         self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
         self.assertFalse(torch.equal(new_weight, prev_weight))
 
+    def test_float8_rowwise_fake_quantize(self):
+        """
+        Test that `_Float8RowwiseFakeQuantize` exactly matches the numerics of `Float8Tensor`.
+        """
+        torch.manual_seed(self.SEED)
+        dtype = torch.float8_e4m3fn
+        x = torch.randn(32, 64)
+        axiswise_dim = 0
+        out = _Float8RowwiseFakeQuantize.apply(x, dtype, axiswise_dim)
+        out_expected = hp_tensor_to_float8_dynamic(
+            x,
+            dtype,
+            LinearMMConfig(),
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=axiswise_dim,
+        ).to_original_precision()
+        torch.testing.assert_close(out, out_expected, atol=0, rtol=0)
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is below 2.6"
+    )
+    def test_qat_fp8a4w_quantizer(self):
+        """
+        Test basic model training with `Float8ActInt4WeightQATQuantizer`.
+ """ + torch.manual_seed(self.SEED) + m = M() + qat_quantizer = Float8ActInt4WeightQATQuantizer() + qat_model = qat_quantizer.prepare(m) + for linear in [m.linear1, m.sub.linear, m.linear2]: + self.assertIsInstance(linear, FakeQuantizedLinear) + self.assertIsInstance( + linear.activation_fake_quantizer, _Float8RowwiseActivationFakeQuantizer + ) + self.assertIsInstance(linear.weight_fake_quantizer, FakeQuantizer) + prev_weight = copy.deepcopy(m.linear1.weight) + + # Simulate training + optimizer = torch.optim.SGD( + m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5 + ) + loss_fn = torch.nn.CrossEntropyLoss() + optimizer.zero_grad() + target = torch.randn(1, 512).float() + example_inputs = m.example_inputs() + out = qat_model(*example_inputs) + loss = loss_fn(out, target) + loss.backward() + optimizer.step() + # Assert that weights have valid gradients and are being updated + new_weight = m.linear1.weight + self.assertIsNotNone(new_weight.grad) + self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0) + self.assertFalse(torch.equal(new_weight, prev_weight)) + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/qat/__init__.py b/torchao/quantization/qat/__init__.py index 010ccfc8cc..4a4359e682 100644 --- a/torchao/quantization/qat/__init__.py +++ b/torchao/quantization/qat/__init__.py @@ -11,6 +11,7 @@ Int4WeightOnlyEmbeddingQATQuantizer, ) from .linear import ( + Float8ActInt4WeightQATQuantizer, Int4WeightOnlyQATQuantizer, Int8DynActInt4WeightQATQuantizer, ) @@ -18,6 +19,7 @@ __all__ = [ "ComposableQATQuantizer", "FakeQuantizeConfig", + "Float8ActInt4WeightQATQuantizer", "FromIntXQuantizationAwareTrainingConfig", "Int4WeightOnlyEmbeddingQATQuantizer", "Int4WeightOnlyQATQuantizer", diff --git a/torchao/quantization/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py index aca0c032bb..b7ad792dc1 100644 --- a/torchao/quantization/qat/fake_quantizer.py +++ b/torchao/quantization/qat/fake_quantizer.py @@ -32,6 +32,7 @@ from .utils import ( _fake_quantize_per_channel_group, _fake_quantize_per_token, + _Float8RowwiseFakeQuantize, ) @@ -186,3 +187,23 @@ def __repr__(self) -> str: Return a human readable representation of this `FakeQuantizer` with config details. """ return "FakeQuantizer(%s)" % self.config + + +class _Float8RowwiseActivationFakeQuantizer(torch.nn.Module): + """ + Simple fake quantizer for float8 rowwise fake quantization, intended for activations only. + """ + + def __init__(self): + super().__init__() + self.enabled = True + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.enabled: + return _Float8RowwiseFakeQuantize.apply( + x, + torch.float8_e4m3fn, + -1, + ) + else: + return x diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index bffd5dc31f..567b87f342 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -28,7 +28,10 @@ from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 from .api import FakeQuantizeConfig -from .fake_quantizer import FakeQuantizer +from .fake_quantizer import ( + FakeQuantizer, + _Float8RowwiseActivationFakeQuantizer, +) from .utils import ( _get_qmin_qmax, ) @@ -145,6 +148,11 @@ def from_linear( return new_linear +# =========================== +# | QAT quantizer interface | +# =========================== + + class _LegacyQATQuantizer(TwoStepQuantizer): """ Base class for sharing common methods across legacy QAT quantizers. 
@@ -157,9 +165,30 @@ def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
         return None
 
 
-# =========================================================
-# |   Linear int8 dynamic activations + int4 weight QAT   |
-# =========================================================
+def enable_linear_fake_quant(
+    mod: torch.nn.Module,
+    enabled: bool = True,
+):
+    """
+    Helper function to enable fake quantization in `FakeQuantizedLinear`.
+    """
+    if isinstance(mod, FakeQuantizedLinear):
+        if mod.activation_fake_quantizer is not None:
+            mod.activation_fake_quantizer.enabled = enabled
+        if mod.weight_fake_quantizer is not None:
+            mod.weight_fake_quantizer.enabled = enabled
+
+
+def disable_linear_fake_quant(mod: torch.nn.Module):
+    """
+    Helper function to disable fake quantization in `FakeQuantizedLinear`.
+    """
+    enable_linear_fake_quant(mod, enabled=False)
+
+
+# ===========================================
+# | int8 dynamic activations + int4 weights |
+# ===========================================
 
 
 class Int8DynActInt4WeightQATQuantizer(_LegacyQATQuantizer):
@@ -307,6 +336,7 @@ def disable_fake_quant(self):
         self.enable_fake_quant(False)
 
 
+# TODO: remove these in favor of enable_linear_fake_quant
 def enable_8da4w_fake_quant(mod: torch.nn.Module):
     """
     Enable fake quantization for `Int8DynActInt4WeightQATLinear`.
@@ -315,6 +345,7 @@ def enable_8da4w_fake_quant(mod: torch.nn.Module):
         mod.enable_fake_quant()
 
 
+# TODO: remove in favor of disable_linear_fake_quant
 def disable_8da4w_fake_quant(mod: torch.nn.Module):
     """
     Disable fake quantization for `Int8DynActInt4WeightQATLinear`.
@@ -357,9 +388,9 @@ def _get_8da4w_weight_config(
     )
 
 
-# ===================================
-# |   Linear int4 weight-only QAT   |
-# ===================================
+# ====================
+# | int4 weight-only |
+# ====================
 
 
 class Int4WeightOnlyQATQuantizer(_LegacyQATQuantizer):
@@ -501,6 +532,7 @@ def disable_fake_quant(self):
         self.enable_fake_quant(False)
 
 
+# TODO: remove these in favor of enable_linear_fake_quant
 def enable_4w_fake_quant(mod: torch.nn.Module):
     """
     Enable fake quantization for `Int4WeightOnlyQATLinear`.
@@ -509,6 +541,7 @@ def enable_4w_fake_quant(mod: torch.nn.Module):
         mod.enable_fake_quant()
 
 
+# TODO: remove these in favor of disable_linear_fake_quant
 def disable_4w_fake_quant(mod: torch.nn.Module):
     """
     Disable fake quantization for `Int4WeightOnlyQATLinear`.
@@ -533,3 +566,74 @@ def _get_4w_weight_config(
         zero_point_precision=qparams_precision,
         zero_point_domain=ZeroPointDomain.FLOAT,
     )
+
+
+# =============================================
+# | float8 rowwise activations + int4 weights |
+# =============================================
+
+
+class Float8ActInt4WeightQATQuantizer(_LegacyQATQuantizer):
+    """
+    QAT quantizer for applying dynamic rowwise float8 activation + int4
+    per group/channel symmetric weight fake quantization to linear layers
+    in the model. Currently only supports rowwise granularity for float8
+    activations.
+
+    Args:
+        group_size (Optional[int]): the number of elements in each quantized
+            group for weights, defaults to 64. Use None for per channel.
+        scale_precision: precision of weight scales, defaults to torch.bfloat16.
+ """ + + def __init__( + self, + group_size: Optional[int] = 64, + scale_precision: torch.dtype = torch.bfloat16, + ): + if group_size is not None: + weight_granularity = "per_group" + else: + weight_granularity = "per_channel" + self._weight_config = FakeQuantizeConfig( + dtype=torch.int4, + granularity=weight_granularity, + group_size=group_size, + is_symmetric=True, + is_dynamic=True, + scale_precision=scale_precision, + ) + + def prepare( + self, model: torch.nn.Module, *args: Any, **kwargs: Any + ) -> torch.nn.Module: + """ + Swap all `nn.Linear` with `FakeQuantizedLinear` with float8 + fake quantizer for activations and int4 fake quantizer for weights. + """ + for name, child in model.named_children(): + if isinstance(child, torch.nn.Linear): + # TODO: add a config for float8? + new_linear = FakeQuantizedLinear.from_linear( + child, + weight_config=self._weight_config, + ) + new_linear.activation_fake_quantizer = ( + _Float8RowwiseActivationFakeQuantizer() + ) + setattr(model, name, new_linear) + else: + self.prepare(child) + return model + + # TODO: add convert path + def convert( + self, model: torch.nn.Module, *args: Any, **kwargs: Any + ) -> torch.nn.Module: + raise NotImplementedError + + def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + raise NotImplementedError("Float8 FakeQuantizeConfig does not exist yet") + + def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + return self.weight_config diff --git a/torchao/quantization/qat/utils.py b/torchao/quantization/qat/utils.py index 01818ef2b2..132020499c 100644 --- a/torchao/quantization/qat/utils.py +++ b/torchao/quantization/qat/utils.py @@ -16,6 +16,38 @@ ) +class _Float8RowwiseFakeQuantize(torch.autograd.Function): + """ + Implementation of float8 rowwise fake quantize with backward STE. + """ + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + x: torch.Tensor, + float8_dtype: torch.dtype, + axiswise_dim: int, + ): + # compute rowwise scale based on `torchao.float8.float8_utils.tensor_to_scale` + eps = 1e-12 + amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True) + amax = amax.to(torch.float64) + scale = torch.finfo(float8_dtype).max / torch.clamp(amax, min=eps) + scale = scale.to(torch.float32) + + # fake quantize + max_value = torch.finfo(float8_dtype).max + x_fq = x.to(torch.float32) * scale + x_fq = x_fq.clamp(min=-max_value, max=max_value) + x_fq = x_fq.to(float8_dtype).to(x.dtype) + x_fq = x_fq / scale + return x_fq.to(x.dtype) + + @staticmethod + def backward(ctx, gy): + return gy, None, None + + # TODO: delete? class _UnwrapAffineFakeQuantizedTensor(torch.autograd.Function): """