|
5 | 5 | import torch |
6 | 6 |
|
7 | 7 | from torchao.prototype.smoothquant import ( |
| 8 | + SmoothQuantConfig, |
8 | 9 | SmoothQuantObservedLinear, |
9 | 10 | insert_smooth_quant_observer_, |
10 | 11 | load_smooth_quant_recipe, |
11 | 12 | save_smooth_quant_recipe, |
12 | | - smooth_quant, |
13 | 13 | ) |
14 | 14 | from torchao.quantization import quantize_ |
15 | 15 | from torchao.quantization.utils import ( |
@@ -85,7 +85,7 @@ def forward(self, x): |
85 | 85 | m(data) |
86 | 86 | # quantize |
87 | 87 | is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) |
88 | | - quantize_(m, smooth_quant(), is_observed_linear) |
| 88 | + quantize_(m, SmoothQuantConfig(), is_observed_linear) |
89 | 89 | with torch.inference_mode(): |
90 | 90 | if TORCH_VERSION_AT_LEAST_2_5: |
91 | 91 | m = torch.compile(m, fullgraph=True) |
@@ -173,7 +173,7 @@ def test_save_load_recipe(alpha, quant_mode, device, idtype): |
173 | 173 |
|
174 | 174 | # quantize |
175 | 175 | is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) |
176 | | - quantize_(m, smooth_quant(), is_observed_linear) |
| 176 | + quantize_(m, SmoothQuantConfig(), is_observed_linear) |
177 | 177 | if TORCH_VERSION_AT_LEAST_2_5: |
178 | 178 | # earlier versions are not compatible |
179 | 179 | m = torch.compile(m, fullgraph=True) |
|
0 commit comments