
Commit 620f676

Add Float8ActInt4WeightQATQuantizer
**Summary:** This commit adds a QAT quantizer that performs float8 dynamic activation + int4 symmetric per-channel weight fake quantization. Note that there is no corresponding config for float8 QAT yet; this will be added in a future PR.

**Test Plan:**
python test/quantization/test_qat.py -k test_float8_fake_quantize
python test/quantization/test_qat.py -k test_qat_fp8a4w_quantizer
1 parent f0f1f6c commit 620f676
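
For orientation, a minimal usage sketch of the new quantizer following the prepare flow added in this commit. The toy model, shapes, and optimizer settings below are illustrative, not part of the commit; only the `Float8ActInt4WeightQATQuantizer` API is taken from the diff.

```python
import torch
from torchao.quantization.qat import Float8ActInt4WeightQATQuantizer

# Any model containing nn.Linear layers; a toy stand-in here.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
)

# Swap nn.Linear -> FakeQuantizedLinear with float8 activation fake quantization
# (axiswise scaling by default) and int4 symmetric per-channel weight fake quantization.
quantizer = Float8ActInt4WeightQATQuantizer()
model = quantizer.prepare(model)

# Train / fine-tune as usual; fake quantization is applied in the forward pass.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
out = model(torch.randn(16, 64))
out.sum().backward()
optimizer.step()

# Note: quantizer.convert(model) is not implemented yet and raises NotImplementedError.
```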

5 files changed: 240 additions & 9 deletions


test/quantization/test_qat.py

Lines changed: 67 additions & 1 deletion
@@ -17,6 +17,9 @@
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 
 from torchao import quantize_
+from torchao.float8.config import ScalingGranularity
+from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
+from torchao.float8.float8_tensor import LinearMMConfig
 from torchao.quantization.granularity import (
     PerAxis,
     PerGroup,
@@ -40,15 +43,18 @@
 )
 from torchao.quantization.qat.fake_quantizer import (
     FakeQuantizer,
+    _Float8ActivationFakeQuantizer,
 )
 from torchao.quantization.qat.linear import (
     FakeQuantizedLinear,
+    Float8ActInt4WeightQATQuantizer,
     Int4WeightOnlyQATLinear,
     Int8DynActInt4WeightQATLinear,
 )
 from torchao.quantization.qat.utils import (
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
+    _Float8FakeQuantize,
     _GenericFakeQuantize,
     _get_qmin_qmax,
 )
@@ -69,6 +75,7 @@
 )
 from torchao.quantization.utils import (
     _get_per_token_block_size,
+    compute_error,
     get_group_qparams_symmetric,
     get_groupwise_affine_qparams,
     groupwise_affine_quantize_tensor,
@@ -1511,7 +1518,6 @@ def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
         numerics that match exactly over N trials.
         """
         from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
-        from torchao.quantization.utils import compute_error
 
         num_trials = 1000
         group_size = 16
@@ -1711,6 +1717,66 @@ def test_qat_range_learning(self):
         loss.backward()
         optimizer.step()
 
+    @parameterized.expand([
+        (ScalingGranularity.TENSORWISE,),
+        (ScalingGranularity.AXISWISE,),
+    ])
+    def test_float8_fake_quantize(self, scaling_granularity: ScalingGranularity):
+        """
+        Test that `_Float8FakeQuantize` is numerically close to `Float8Tensor`.
+        """
+        torch.manual_seed(self.SEED)
+        dtype = torch.float8_e4m3fn
+        x = torch.randn(32, 64)
+        if scaling_granularity == ScalingGranularity.AXISWISE:
+            axiswise_dim = 0
+        else:
+            axiswise_dim = None
+        out = _Float8FakeQuantize.apply(x, dtype, scaling_granularity, axiswise_dim)
+        out_expected = hp_tensor_to_float8_dynamic(
+            x,
+            dtype,
+            LinearMMConfig(),
+            scaling_granularity=scaling_granularity,
+            axiswise_dim=axiswise_dim,
+        ).to_original_precision()
+        torch.testing.assert_close(out, out_expected, atol=0, rtol=0)
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_qat_fp8a4w_quantizer(self):
+        """
+        Test basic model training with `Float8ActInt4WeightQATQuantizer`.
+        """
+        torch.manual_seed(self.SEED)
+        m = M()
+        qat_quantizer = Float8ActInt4WeightQATQuantizer()
+        qat_model = qat_quantizer.prepare(m)
+        for linear in [m.linear1, m.sub.linear, m.linear2]:
+            self.assertIsInstance(linear, FakeQuantizedLinear)
+            self.assertIsInstance(linear.activation_fake_quantizer, _Float8ActivationFakeQuantizer)
+            self.assertIsInstance(linear.weight_fake_quantizer, FakeQuantizer)
+        prev_weight = copy.deepcopy(m.linear1.weight)
+
+        # Simulate training
+        optimizer = torch.optim.SGD(
+            m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
+        )
+        loss_fn = torch.nn.CrossEntropyLoss()
+        optimizer.zero_grad()
+        target = torch.randn(1, 512).float()
+        example_inputs = m.example_inputs()
+        out = m(*example_inputs)
+        loss = loss_fn(out, target)
+        loss.backward()
+        optimizer.step()
+        # Assert that weights have valid gradients and are being updated
+        new_weight = m.linear1.weight
+        self.assertIsNotNone(new_weight.grad)
+        self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
+        self.assertFalse(torch.equal(new_weight, prev_weight))
+
 
 if __name__ == "__main__":
     unittest.main()
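
For intuition about the two granularities parameterized in `test_float8_fake_quantize`: TENSORWISE derives a single scale from the global max-abs value, while AXISWISE derives one scale per slice along the reduced dimension (the activation fake quantizer added in this commit reduces over the last dim, i.e. one scale per row). The following is a rough standalone illustration of the quantize-dequantize round trip, not the torchao implementation; the helper name and epsilon are made up for the sketch.

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def fake_quantize_fp8_sketch(x: torch.Tensor, axiswise: bool) -> torch.Tensor:
    # Illustrative only: scale, saturate-cast to float8, then dequantize back.
    if axiswise:
        # One scale per row (reduce over the last dim).
        amax = x.abs().amax(dim=-1, keepdim=True)
    else:
        # A single scale for the whole tensor.
        amax = x.abs().amax()
    scale = FP8_MAX / amax.clamp(min=1e-12)
    x_fp8 = (x * scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_fp8.to(x.dtype) / scale

x = torch.randn(32, 64)
print(fake_quantize_fp8_sketch(x, axiswise=True).shape)  # torch.Size([32, 64])
```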

torchao/quantization/qat/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -11,13 +11,15 @@
     Int4WeightOnlyEmbeddingQATQuantizer,
 )
 from .linear import (
+    Float8ActInt4WeightQATQuantizer,
     Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATQuantizer,
 )
 
 __all__ = [
     "ComposableQATQuantizer",
     "FakeQuantizeConfig",
+    "Float8ActInt4WeightQATQuantizer",
     "FromIntXQuantizationAwareTrainingConfig",
     "Int4WeightOnlyEmbeddingQATQuantizer",
     "Int4WeightOnlyQATQuantizer",

torchao/quantization/qat/fake_quantizer.py

Lines changed: 27 additions & 0 deletions
@@ -8,6 +8,8 @@
 
 import torch
 
+from torchao.float8.config import ScalingGranularity
+from torchao.float8.float8_scaling_utils import get_maybe_axiswise_dim
 from torchao.quantization.granularity import (
     PerAxis,
     PerGroup,
@@ -31,6 +33,7 @@
 from .utils import (
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
+    _Float8FakeQuantize,
     _Round,
 )
 
@@ -186,3 +189,27 @@ def __repr__(self) -> str:
         Return a human readable representation of this `FakeQuantizer` with config details.
         """
         return "FakeQuantizer(%s)" % self.config
+
+
+class _Float8ActivationFakeQuantizer(torch.nn.Module):
+    """
+    Simple fake quantizer for float8 fake quantization, intended for activations only.
+    """
+
+    FLOAT8_DTYPE = torch.float8_e4m3fn
+
+    def __init__(self, scaling_granularity: ScalingGranularity):
+        super().__init__()
+        self.enabled = True
+        self.scaling_granularity = scaling_granularity
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.enabled:
+            return _Float8FakeQuantize.apply(
+                x,
+                self.FLOAT8_DTYPE,
+                self.scaling_granularity,
+                get_maybe_axiswise_dim(-1, self.scaling_granularity),
+            )
+        else:
+            return x
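
Based on the forward logic above, the new fake quantizer is an ordinary `nn.Module` whose output stays in the input's precision and that can be toggled via its `enabled` flag. A small usage sketch; the import paths match this diff, but the module is private API and the tensor shapes here are arbitrary.

```python
import torch
from torchao.float8.config import ScalingGranularity
from torchao.quantization.qat.fake_quantizer import _Float8ActivationFakeQuantizer

# Per-row (axiswise) float8 fake quantization of an activation tensor.
fq = _Float8ActivationFakeQuantizer(ScalingGranularity.AXISWISE)
x = torch.randn(8, 128)
y = fq(x)
assert y.shape == x.shape and y.dtype == x.dtype  # quantize-dequantize round trip

# When disabled, the module is a pass-through.
fq.enabled = False
x2 = torch.randn(8, 128)
assert torch.equal(fq(x2), x2)
```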

torchao/quantization/qat/linear.py

Lines changed: 110 additions & 7 deletions
@@ -10,6 +10,7 @@
 import torch.nn.functional as F
 
 from torchao.dtypes.utils import is_device
+from torchao.float8.config import ScalingGranularity
 from torchao.quantization.granularity import PerGroup
 from torchao.quantization.linear_quant_modules import (
     Int8DynActInt4WeightLinear,
@@ -28,7 +29,10 @@
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
 
 from .api import FakeQuantizeConfig
-from .fake_quantizer import FakeQuantizer
+from .fake_quantizer import (
+    FakeQuantizer,
+    _Float8ActivationFakeQuantizer,
+)
 from .utils import (
     _get_qmin_qmax,
 )
@@ -145,6 +149,11 @@ def from_linear(
         return new_linear
 
 
+# ===========================
+# | QAT quantizer interface |
+# ===========================
+
+
 class _LegacyQATQuantizer(TwoStepQuantizer):
     """
     Base class for sharing common methods across legacy QAT quantizers.
@@ -157,9 +166,30 @@ def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
         return None
 
 
-# =========================================================
-# |   Linear int8 dynamic activations + int4 weight QAT   |
-# =========================================================
+def enable_linear_fake_quant(
+    mod: torch.nn.Module,
+    enabled: bool = True,
+):
+    """
+    Helper function to enable fake quantization in `FakeQuantizedLinear`.
+    """
+    if isinstance(mod, FakeQuantizedLinear):
+        if mod.activation_fake_quantizer is not None:
+            mod.activation_fake_quantizer.enabled = enabled
+        if mod.weight_fake_quantizer is not None:
+            mod.weight_fake_quantizer.enabled = enabled
+
+
+def disable_linear_fake_quant(mod: torch.nn.Module):
+    """
+    Helper function to disable fake quantization in `FakeQuantizedLinear`.
+    """
+    enable_linear_fake_quant(mod, enabled=False)
+
+
+# ===========================================
+# | int8 dynamic activations + int4 weights |
+# ===========================================
 
 
 class Int8DynActInt4WeightQATQuantizer(_LegacyQATQuantizer):
@@ -307,6 +337,7 @@ def disable_fake_quant(self):
         self.enable_fake_quant(False)
 
 
+# TODO: remove these in favor of enable_linear_fake_quant
 def enable_8da4w_fake_quant(mod: torch.nn.Module):
     """
     Enable fake quantization for `Int8DynActInt4WeightQATLinear`.
@@ -315,6 +346,7 @@ def enable_8da4w_fake_quant(mod: torch.nn.Module):
         mod.enable_fake_quant()
 
 
+# TODO: remove in favor of disable_linear_fake_quant
 def disable_8da4w_fake_quant(mod: torch.nn.Module):
     """
     Disable fake quantization for `Int8DynActInt4WeightQATLinear`.
@@ -357,9 +389,9 @@ def _get_8da4w_weight_config(
     )
 
 
-# ===================================
-# |   Linear int4 weight-only QAT   |
-# ===================================
+# ====================
+# | int4 weight-only |
+# ====================
 
 
 class Int4WeightOnlyQATQuantizer(_LegacyQATQuantizer):
@@ -501,6 +533,7 @@ def disable_fake_quant(self):
         self.enable_fake_quant(False)
 
 
+# TODO: remove these in favor of enable_linear_fake_quant
 def enable_4w_fake_quant(mod: torch.nn.Module):
     """
     Enable fake quantization for `Int4WeightOnlyQATLinear`.
@@ -509,6 +542,7 @@ def enable_4w_fake_quant(mod: torch.nn.Module):
         mod.enable_fake_quant()
 
 
+# TODO: remove these in favor of disable_linear_fake_quant
 def disable_4w_fake_quant(mod: torch.nn.Module):
     """
     Disable fake quantization for `Int4WeightOnlyQATLinear`.
@@ -533,3 +567,72 @@ def _get_4w_weight_config(
         zero_point_precision=qparams_precision,
         zero_point_domain=ZeroPointDomain.FLOAT,
     )
+
+
+# =====================================
+# | float8 activations + int4 weights |
+# =====================================
+
+
+class Float8ActInt4WeightQATQuantizer:
+    """
+    QAT quantizer for applying dynamic float8 activation + int4
+    per-channel symmetric weight fake quantization to linear
+    layers in the model.
+
+    Args:
+        activation_scaling_granularity (ScalingGranularity): float8 scaling granularity
+            for activation fake quantization, defaults to AXISWISE (per row)
+        scale_precision (torch.dtype): precision of weight scales, defaults to torch.bfloat16
+    """
+
+    def __init__(
+        self,
+        activation_scaling_granularity: ScalingGranularity = ScalingGranularity.AXISWISE,
+        scale_precision: torch.dtype = torch.bfloat16,
+    ):
+        # symmetric, so zero point precision does not matter
+        zero_point_precision = torch.float32
+        self._activation_scaling_granularity = activation_scaling_granularity
+        self._weight_config = FakeQuantizeConfig(
+            dtype=torch.int4,
+            granularity="per_channel",
+            is_symmetric=True,
+            is_dynamic=True,
+            scale_precision=scale_precision,
+            zero_point_precision=zero_point_precision,
+        )
+
+    def prepare(
+        self, model: torch.nn.Module, *args: Any, **kwargs: Any
+    ) -> torch.nn.Module:
+        """
+        Swap all `nn.Linear` children with `FakeQuantizedLinear`, using a float8
+        fake quantizer for activations and an int4 fake quantizer for weights.
+        """
+        for name, child in model.named_children():
+            if isinstance(child, torch.nn.Linear):
+                # TODO: add a config for float8?
+                new_linear = FakeQuantizedLinear.from_linear(
+                    child,
+                    weight_config=self._weight_config,
+                )
+                new_linear.activation_fake_quantizer = _Float8ActivationFakeQuantizer(
+                    self._activation_scaling_granularity
+                )
+                setattr(model, name, new_linear)
+            else:
+                self.prepare(child)
+        return model
+
+    # TODO: add convert path
+    def convert(
+        self, model: torch.nn.Module, *args: Any, **kwargs: Any
+    ) -> torch.nn.Module:
+        raise NotImplementedError
+
+    def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        raise NotImplementedError("Float8 FakeQuantizeConfig does not exist yet")
+
+    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return self._weight_config
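
The new `enable_linear_fake_quant` / `disable_linear_fake_quant` helpers operate on a single module, so, like the older per-quantizer helpers they are slated to replace, they would typically be broadcast over a prepared model with `Module.apply`. A small self-contained sketch under that assumption; the toy model and shapes are illustrative.

```python
import torch
from torchao.quantization.qat import Float8ActInt4WeightQATQuantizer
from torchao.quantization.qat.linear import (
    disable_linear_fake_quant,
    enable_linear_fake_quant,
)

model = torch.nn.Sequential(torch.nn.Linear(32, 32))
model = Float8ActInt4WeightQATQuantizer().prepare(model)
x = torch.randn(4, 32)

# Temporarily bypass fake quantization on every FakeQuantizedLinear,
# e.g. to compare against the full-precision behavior of the same modules.
model.apply(disable_linear_fake_quant)
baseline_out = model(x)

# Re-enable fake quantization before resuming QAT.
model.apply(enable_linear_fake_quant)
qat_out = model(x)
```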
