Commit fbd4716
committed
Enable range learning for QAT
**Summary:** This commit adds the option for QAT users to use range
learning during training. Range learning means we train the scale
and zero point instead of recomputing them based on the input
at every iteration.
Example usage:
```
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import (
FakeQuantizeConfig,
IntXQuantizationAwareTrainingConfig,
initialize_fake_quantizers,
)
config = FakeQuantizeConfig(
torch.int8,
"per_channel",
is_dynamic=False,
range_learning=True,
scale_precision=torch.float32,
zero_point_precision=torch.float32,
)
m = M()
example_inputs = (torch.randn(16, 32),)
quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))
# New required step to turn scales and zero points into trainable
# `nn.Parameters`, must be called before initializing the optimizer
initialize_fake_quantizers(m, example_inputs)
# initialize the optimizer
# do training
```
**Test Plan:**
python test/quantization/test_qat.py -k test_fake_quantize_config_dynamic_and_range_learning
python test/quantization/test_qat.py -k test_fake_quantizer_range_learning
python test/quantization/test_qat.py -k test_qat_range_learning1 parent 5549da8 commit fbd4716
File tree
8 files changed
+217
-19
lines changed- test/quantization
- third_party
- torchao/quantization/qat
8 files changed
+217
-19
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
9 | 9 | | |
10 | 10 | | |
11 | 11 | | |
| 12 | + | |
12 | 13 | | |
13 | 14 | | |
14 | 15 | | |
| |||
26 | 27 | | |
27 | 28 | | |
28 | 29 | | |
| 30 | + | |
29 | 31 | | |
| 32 | + | |
30 | 33 | | |
31 | 34 | | |
32 | 35 | | |
| |||
99 | 102 | | |
100 | 103 | | |
101 | 104 | | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
102 | 115 | | |
103 | 116 | | |
104 | 117 | | |
| |||
996 | 1009 | | |
997 | 1010 | | |
998 | 1011 | | |
| 1012 | + | |
| 1013 | + | |
| 1014 | + | |
| 1015 | + | |
| 1016 | + | |
| 1017 | + | |
| 1018 | + | |
| 1019 | + | |
| 1020 | + | |
| 1021 | + | |
| 1022 | + | |
| 1023 | + | |
| 1024 | + | |
| 1025 | + | |
| 1026 | + | |
999 | 1027 | | |
1000 | 1028 | | |
1001 | 1029 | | |
| |||
1591 | 1619 | | |
1592 | 1620 | | |
1593 | 1621 | | |
| 1622 | + | |
| 1623 | + | |
| 1624 | + | |
| 1625 | + | |
| 1626 | + | |
| 1627 | + | |
| 1628 | + | |
| 1629 | + | |
| 1630 | + | |
| 1631 | + | |
| 1632 | + | |
| 1633 | + | |
| 1634 | + | |
| 1635 | + | |
| 1636 | + | |
| 1637 | + | |
| 1638 | + | |
| 1639 | + | |
| 1640 | + | |
| 1641 | + | |
| 1642 | + | |
| 1643 | + | |
| 1644 | + | |
| 1645 | + | |
| 1646 | + | |
| 1647 | + | |
| 1648 | + | |
| 1649 | + | |
| 1650 | + | |
| 1651 | + | |
| 1652 | + | |
| 1653 | + | |
| 1654 | + | |
| 1655 | + | |
| 1656 | + | |
| 1657 | + | |
| 1658 | + | |
| 1659 | + | |
| 1660 | + | |
| 1661 | + | |
| 1662 | + | |
| 1663 | + | |
| 1664 | + | |
| 1665 | + | |
| 1666 | + | |
| 1667 | + | |
| 1668 | + | |
| 1669 | + | |
| 1670 | + | |
| 1671 | + | |
| 1672 | + | |
| 1673 | + | |
| 1674 | + | |
| 1675 | + | |
| 1676 | + | |
| 1677 | + | |
| 1678 | + | |
| 1679 | + | |
| 1680 | + | |
| 1681 | + | |
| 1682 | + | |
| 1683 | + | |
| 1684 | + | |
| 1685 | + | |
| 1686 | + | |
| 1687 | + | |
| 1688 | + | |
| 1689 | + | |
| 1690 | + | |
| 1691 | + | |
| 1692 | + | |
| 1693 | + | |
| 1694 | + | |
| 1695 | + | |
| 1696 | + | |
| 1697 | + | |
| 1698 | + | |
| 1699 | + | |
| 1700 | + | |
| 1701 | + | |
| 1702 | + | |
| 1703 | + | |
| 1704 | + | |
| 1705 | + | |
| 1706 | + | |
| 1707 | + | |
| 1708 | + | |
| 1709 | + | |
| 1710 | + | |
1594 | 1711 | | |
1595 | 1712 | | |
1596 | 1713 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4 | 4 | | |
5 | 5 | | |
6 | 6 | | |
| 7 | + | |
7 | 8 | | |
8 | 9 | | |
9 | 10 | | |
| |||
17 | 18 | | |
18 | 19 | | |
19 | 20 | | |
20 | | - | |
| 21 | + | |
21 | 22 | | |
| 23 | + | |
22 | 24 | | |
| 25 | + | |
| 26 | + | |
23 | 27 | | |
24 | 28 | | |
25 | | - | |
26 | | - | |
27 | 29 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
5 | 5 | | |
6 | 6 | | |
7 | 7 | | |
8 | | - | |
| 8 | + | |
9 | 9 | | |
10 | 10 | | |
11 | 11 | | |
| |||
51 | 51 | | |
52 | 52 | | |
53 | 53 | | |
54 | | - | |
| 54 | + | |
| 55 | + | |
55 | 56 | | |
56 | 57 | | |
57 | 58 | | |
| |||
123 | 124 | | |
124 | 125 | | |
125 | 126 | | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
126 | 131 | | |
127 | 132 | | |
128 | 133 | | |
| |||
394 | 399 | | |
395 | 400 | | |
396 | 401 | | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
| 408 | + | |
| 409 | + | |
| 410 | + | |
| 411 | + | |
| 412 | + | |
| 413 | + | |
| 414 | + | |
| 415 | + | |
| 416 | + | |
| 417 | + | |
| 418 | + | |
| 419 | + | |
| 420 | + | |
| 421 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
92 | 92 | | |
93 | 93 | | |
94 | 94 | | |
| 95 | + | |
95 | 96 | | |
96 | 97 | | |
97 | 98 | | |
| |||
116 | 117 | | |
117 | 118 | | |
118 | 119 | | |
| 120 | + | |
119 | 121 | | |
120 | 122 | | |
121 | 123 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
31 | 31 | | |
32 | 32 | | |
33 | 33 | | |
| 34 | + | |
34 | 35 | | |
35 | 36 | | |
36 | 37 | | |
| |||
46 | 47 | | |
47 | 48 | | |
48 | 49 | | |
49 | | - | |
50 | | - | |
51 | | - | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
52 | 54 | | |
53 | | - | |
| 55 | + | |
54 | 56 | | |
55 | 57 | | |
56 | 58 | | |
57 | 59 | | |
58 | 60 | | |
59 | 61 | | |
60 | 62 | | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
61 | 74 | | |
62 | 75 | | |
63 | 76 | | |
64 | 77 | | |
65 | 78 | | |
66 | 79 | | |
67 | 80 | | |
68 | | - | |
| 81 | + | |
69 | 82 | | |
70 | 83 | | |
71 | 84 | | |
72 | 85 | | |
73 | 86 | | |
74 | | - | |
75 | 87 | | |
76 | 88 | | |
77 | 89 | | |
| |||
85 | 97 | | |
86 | 98 | | |
87 | 99 | | |
| 100 | + | |
88 | 101 | | |
89 | 102 | | |
90 | | - | |
| 103 | + | |
91 | 104 | | |
92 | 105 | | |
93 | 106 | | |
| |||
129 | 142 | | |
130 | 143 | | |
131 | 144 | | |
| 145 | + | |
132 | 146 | | |
133 | 147 | | |
134 | 148 | | |
| |||
147 | 161 | | |
148 | 162 | | |
149 | 163 | | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
150 | 184 | | |
151 | 185 | | |
152 | 186 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
18 | 18 | | |
19 | 19 | | |
20 | 20 | | |
| 21 | + | |
21 | 22 | | |
22 | 23 | | |
23 | 24 | | |
| |||
83 | 84 | | |
84 | 85 | | |
85 | 86 | | |
86 | | - | |
87 | | - | |
88 | | - | |
89 | | - | |
90 | | - | |
91 | | - | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
92 | 94 | | |
93 | 95 | | |
94 | 96 | | |
| |||
108 | 110 | | |
109 | 111 | | |
110 | 112 | | |
| 113 | + | |
111 | 114 | | |
112 | 115 | | |
113 | 116 | | |
| |||
131 | 134 | | |
132 | 135 | | |
133 | 136 | | |
| 137 | + | |
134 | 138 | | |
135 | 139 | | |
136 | 140 | | |
| |||
0 commit comments