
Commit 80cc501

Update on "Add NVFP4 QAT"
**Summary:** This commit adds a QAT flow for NVFP4, following the numerics in `NVFP4Tensor` closely but without the dtype casting, swizzling, and packing/unpacking. Users can call this flow as follows:

```
from torchao.quantization import quantize_
from torchao.quantization.qat import NVFP4FakeQuantizeConfig, QATConfig

qat_config = QATConfig(
    activation_config=NVFP4FakeQuantizeConfig(),
    weight_config=NVFP4FakeQuantizeConfig(),
    step="prepare",
)
quantize_(model, qat_config)
```

**Test Plan:**

```
python test/quantization/test_qat.py -k test_qat_nvfp4
```

Initial benchmarks from fine-tuning Qwen3-1.7B on alpaca for 3 epochs:

```
# Without QAT
|  Tasks |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.8322|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  | 1.7804|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |21.8611|±  |   N/A|

# With QAT
|  Tasks |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.8271|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  | 1.7741|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |21.4467|±  |   N/A|
```

[ghstack-poisoned]
2 parents: 732fb16 + cda3a85
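For context on the numerics the summary references: NVFP4 stores FP4 (E2M1) element values with one FP8 (E4M3) scale per 16-element block. The sketch below is illustrative only (it is not the code added in this commit) and assumes the standard E2M1 value grid, a fixed block size of 16, and no per-tensor scale; a real QAT quantizer would additionally route gradients through a straight-through estimator.

```
import torch

# Representable magnitudes of FP4 E2M1: {0, 0.5, 1, 1.5, 2, 3, 4, 6}
E2M1_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def nvfp4_fake_quantize(x: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    """Illustrative NVFP4-style fake quantization (not this commit's code).

    Each block of `block_size` elements along the last dim shares one scale,
    which is itself rounded through float8_e4m3fn; elements are then snapped
    to the nearest E2M1 value. No packing or swizzling, as in the summary.
    Assumes x.numel() is divisible by block_size.
    """
    orig_shape = x.shape
    blocks = x.reshape(-1, block_size)
    # Per-block scale chosen so the block max maps to E2M1's max value (6.0)
    amax = blocks.abs().amax(dim=-1, keepdim=True)
    # Keep the scale nonzero and small enough to survive the E4M3 round-trip
    scale = (amax / 6.0).clamp(min=2**-8)
    # NVFP4 block scales are E4M3, so round the scale through that dtype
    scale = scale.to(torch.float8_e4m3fn).to(x.dtype)
    # Snap each scaled element to the nearest representable E2M1 magnitude
    # (argmin breaks ties toward the smaller value; real kernels may differ)
    scaled = (blocks / scale).clamp(-6.0, 6.0)
    grid = E2M1_GRID.to(device=x.device, dtype=x.dtype)
    idx = (scaled.abs().unsqueeze(-1) - grid).abs().argmin(dim=-1)
    fq = grid[idx] * scaled.sign()
    return (fq * scale).reshape(orig_shape)
```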

File tree

3 files changed: +4 -3

docs/source/api_ref_qat.rst

Lines changed: 2 additions & 1 deletion
@@ -27,7 +27,6 @@ Custom QAT APIs
     FakeQuantizeConfigBase
     IntxFakeQuantizeConfig
     Float8FakeQuantizeConfig
-    NVFP4FakeQuantizeConfig
     FakeQuantizedLinear
     FakeQuantizedEmbedding
     FakeQuantizerBase
@@ -63,3 +62,5 @@ Prototype
    :nosignatures:

     initialize_fake_quantizers
+    NVFP4FakeQuantizeConfig
+    NVFP4FakeQuantizer

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ def __post_init__(self):
 @dataclass
 class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
     """
-    Config for fake quantizing weights or activations to NVIDIA's NVFP4 format
+    (Prototype) Config for fake quantizing weights or activations to NVIDIA's NVFP4 format
     according to https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/.

     Fake quantization numerics follow `NVFP4Tensor` closely: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/nvfp4_tensor.py.
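Since the docstring now labels the config as a prototype, the API may still change. For orientation, the prepare step from the summary pairs with torchao's standard QAT convert step. Below is a minimal sketch of the two-step flow; the `QATConfig(..., step="convert")` pattern is torchao's documented QAT API, but the name `NVFP4InferenceConfig` and its import path are my assumptions, not confirmed by this diff:

```
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import NVFP4FakeQuantizeConfig, QATConfig

model = torch.nn.Sequential(torch.nn.Linear(256, 256))  # toy model

# Prepare: insert NVFP4 fake quantizers for training (from the summary)
quantize_(model, QATConfig(
    activation_config=NVFP4FakeQuantizeConfig(),
    weight_config=NVFP4FakeQuantizeConfig(),
    step="prepare",
))
# ... fine-tune the model here ...

# Convert: replace fake quantization with real NVFP4 quantization.
# NVFP4InferenceConfig is an assumed name; check torchao.prototype.mx_formats
# for the actual inference-side config to pair with this QAT flow.
from torchao.prototype.mx_formats import NVFP4InferenceConfig
quantize_(model, QATConfig(NVFP4InferenceConfig(), step="convert"))
```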

torchao/quantization/qat/fake_quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

 class NVFP4FakeQuantizer(FakeQuantizerBase):
     """
-    Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
+    (Prototype) Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
     """

     def __init__(self, config: NVFP4FakeQuantizeConfig):
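Given the signatures visible in this hunk (`__init__(self, config: NVFP4FakeQuantizeConfig)` and `forward(self, x: torch.Tensor) -> torch.Tensor`), the quantizer can presumably also be exercised standalone, e.g. for inspecting fake-quantization error. A minimal sketch; importing `NVFP4FakeQuantizer` from `torchao.quantization.qat` is inferred from the docs change above rather than shown in this diff:

```
import torch
from torchao.quantization.qat import NVFP4FakeQuantizeConfig, NVFP4FakeQuantizer

# Standalone fake quantization of a single tensor, mirroring the hunk's
# signatures; the import path is inferred from the api_ref_qat.rst change.
fq = NVFP4FakeQuantizer(NVFP4FakeQuantizeConfig())
x = torch.randn(128, 256, dtype=torch.bfloat16)
x_fq = fq(x)
assert x_fq.shape == x.shape  # fake quantization preserves shape
```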
