
Commit 239e57a

Build SmoothQuant release pipeline (#3010)
* Summary: Adds the SMOOTHQUANT-W8A8 quantization method to the TorchAO model release pipeline.
  - Adjusted defaults: increased calibration samples from 10 to 128 to ensure consistency; reduced max sequence length (SeqLen) from 2048 to 1024
  - Updated HF CLI command: `huggingface-cli login` to `hf auth login`

  Test plan:
  ```bash
  python quantize_and_upload.py --model_id Qwen/Qwen3-8B --quant SMOOTHQUANT-W8A8 --push_to_hub --task bbh
  ```

* add SmoothQuant uploader
* separate docs for AWQ & SmoothQuant
* rename SMOOTHQUANT-W8A8 to SMOOTHQUANT-INT8-INT8
* add SmoothQuant release example
* update example in docs
* rename SMOOTHQUANT-INT8-INT8 to SmoothQuant-INT8-INT8
* rename SMOOTHQUANT to SmoothQuant
* revert max_seq_length default to 2048
1 parent 2d31ac3 commit 239e57a

File tree

2 files changed: +83 −6 lines changed

.github/scripts/torchao_model_releases/README.md

Lines changed: 11 additions & 3 deletions
````diff
@@ -51,8 +51,7 @@ By default, we release FP8, INT4, INT8-INT4 checkpoints, with model card pre-filled
 
 Examples:
 ```
-# Note: first login with `huggingface-cli login`, the quantized model will be uploaded to
-# the logged in user
+# Note: first login with `hf auth login`, the quantized model will be uploaded to the logged-in user
 
 # release with default quant options (FP8, INT4, INT8-INT4)
 ./release.sh --model_id Qwen/Qwen3-8B --push_to_hub
@@ -63,8 +62,17 @@ Examples:
 
 Note: for the initial release, please include `--populate_model_card_template` to populate the model card template.
 
+### SmoothQuant-INT8-INT8
+[SmoothQuant](https://arxiv.org/abs/2211.10438) smooths activation outliers by migrating quantization difficulty from activations to weights through a mathematically equivalent per-channel scaling transformation. This means SmoothQuant observes the activation distribution before applying quantization.
+
+Examples:
+```
+# release SmoothQuant-INT8-INT8 model, calibrated with a specific task
+python quantize_and_upload.py --model_id Qwen/Qwen3-8B --quant SmoothQuant-INT8-INT8 --push_to_hub --task bbh --populate_model_card_template
+```
+
 ### AWQ-INT4
-[AWQ](https://arxiv.org/abs/2306.00978) is a technique to improve accuracy for weight only quantization. It improves accuracy by preserving "salient" weight channels that has high impact on the accuracy of output, through multiplying the weight channel by a scale, and do the reverse for the correspnoding activation, since activation is not quantized, there is no additional loss from activation, while the quantization loss from weight can be reduced.
+Similar to SmoothQuant, [AWQ](https://arxiv.org/abs/2306.00978) improves accuracy by preserving "salient" weight channels that have a high impact on the output accuracy. Notably, AWQ uses the activation distribution, not the weight distribution, to find salient weights, multiplying each such weight channel by a scale and applying the inverse scale to the corresponding activation. Since activations are not quantized, this adds no activation loss, while the quantization loss from weights is reduced.
 
 After eval for the INT4 checkpoint is done, we might find some tasks have a large accuracy drop compared to the high-precision baseline. In that case we can run a calibration for that task with a few samples; tasks are selected from [lm-eval](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/README.md). You can follow the [new task guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/new_task_guide.md) to add new tasks to lm-eval.
````
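Both technique descriptions above rest on the same identity: dividing activations per channel by a scale and multiplying the matching weight channels by that scale leaves the matmul output unchanged. Below is a minimal sketch of that equivalence, not part of this commit; it assumes plain `torch`, with the max-based smoothing factor and `alpha` taken from the SmoothQuant paper.

```python
import torch

torch.manual_seed(0)
X = torch.randn(4, 8)    # activations: tokens x in_channels
W = torch.randn(8, 16)   # weights: in_channels x out_channels

# Per-input-channel smoothing factor, as in the SmoothQuant paper:
# s_j = max|X_j|^alpha / max|W_j|^(1 - alpha)
alpha = 0.5
s = X.abs().amax(dim=0).pow(alpha) / W.abs().amax(dim=1).pow(1 - alpha)

X_smooth = X / s                # activation outliers shrink per channel
W_smooth = W * s.unsqueeze(1)   # the difficulty migrates into the weights

# Mathematically equivalent before any quantization is applied; AWQ uses
# the same identity in the opposite direction, scaling up salient weight
# channels and dividing the corresponding activations instead.
assert torch.allclose(X @ W, X_smooth @ W_smooth, atol=1e-5)
```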

.github/scripts/torchao_model_releases/quantize_and_upload.py

Lines changed: 72 additions & 3 deletions
````diff
@@ -15,9 +15,11 @@
 from torchao.prototype.awq import (
     AWQConfig,
 )
+from torchao.prototype.smoothquant import SmoothQuantConfig
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Int4WeightOnlyConfig,
+    Int8DynamicActivationInt8WeightConfig,
     Int8DynamicActivationIntxWeightConfig,
     IntxWeightOnlyConfig,
     ModuleFqnToConfig,
@@ -265,6 +267,42 @@ def _untie_weights_and_save_locally(model_id):
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 """
 
+
+_smoothquant_int8_int8_quant_code = """
+from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_
+from torchao.prototype.smoothquant import SmoothQuantConfig
+
+from torchao._models._eval import TransformerEvalWrapper
+model = AutoModelForCausalLM.from_pretrained(
+    model_to_quantize,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+base_config = Int8DynamicActivationInt8WeightConfig()
+quant_config = SmoothQuantConfig(base_config, step="prepare")
+quantize_(
+    model,
+    quant_config,
+)
+TransformerEvalWrapper(
+    model=model,
+    tokenizer=tokenizer,
+    max_seq_length=max_seq_length,
+).run_eval(
+    tasks=tasks,
+    limit=calibration_limit,
+)
+quant_config = SmoothQuantConfig(base_config, step="convert")
+quantize_(model, quant_config)
+
+quantized_model = model
+quant_config = SmoothQuantConfig(base_config, step="prepare_for_loading")
+quantized_model.config.quantization_config = TorchAoConfig(quant_config)
+"""
+
+
 _awq_int4_quant_code = """
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
 from torchao.prototype.awq import (
@@ -647,6 +685,7 @@ def quantize_and_upload(
                 ),
             }
         ),
+        "SmoothQuant-INT8-INT8": Int8DynamicActivationInt8WeightConfig(),
     }
 
     quant_to_quant_code = {
@@ -655,6 +694,7 @@ def quantize_and_upload(
         "INT8-INT4": _int8_int4_quant_code,
         "INT8-INT4-HQQ": _int8_int4_hqq_quant_code,
         "AWQ-INT4": _awq_int4_quant_code,
+        "SmoothQuant-INT8-INT8": _smoothquant_int8_int8_quant_code,
     }
 
     # preparation
@@ -698,6 +738,35 @@ def quantize_and_upload(
         quantized_model = model
         quant_config = AWQConfig(base_config, step="prepare_for_loading")
         quantized_model.config.quantization_config = TorchAoConfig(quant_config)
+    elif quant == "SmoothQuant-INT8-INT8":
+        model = AutoModelForCausalLM.from_pretrained(
+            model_to_quantize,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        base_config = Int8DynamicActivationInt8WeightConfig()
+        quant_config = SmoothQuantConfig(base_config, step="prepare")
+        quantize_(
+            model,
+            quant_config,
+        )
+        TransformerEvalWrapper(
+            model=model,
+            tokenizer=tokenizer,
+            max_seq_length=max_seq_length,
+        ).run_eval(
+            tasks=tasks,
+            limit=calibration_limit,
+        )
+        quant_config = SmoothQuantConfig(base_config, step="convert")
+        quantize_(model, quant_config)
+
+        quantized_model = model
+
+        load_config = SmoothQuantConfig(base_config, step="prepare_for_loading")
+        quantized_model.config.quantization_config = TorchAoConfig(load_config)
     else:
         # other quantization are integrated with `from_pretrained` in huggingface transformers
         assert quant in quant_to_config, f"Unsupported quant option: {quant}"
@@ -812,7 +881,7 @@ def quantize_and_upload(
     parser.add_argument(
         "--quant",
         type=str,
-        help="Quantization method. Options are FP8, INT4, INT8-INT4, INT8-INT4-HQQ, AWQ-INT4",
+        help="Quantization method. Options are FP8, INT4, INT8-INT4, INT8-INT4-HQQ, AWQ-INT4, SmoothQuant-INT8-INT8",
     )
     parser.add_argument(
         "--tasks",
@@ -824,8 +893,8 @@ def quantize_and_upload(
     parser.add_argument(
         "--calibration_limit",
         type=int,
-        default=10,
-        help="Number of samples to use for calibration. Default is 10.",
+        default=128,
+        help="Number of samples to use for calibration. Default is 128.",
     )
     parser.add_argument(
         "--max_seq_length",
````
