diff --git a/.github/scripts/torchao_model_releases/README.md b/.github/scripts/torchao_model_releases/README.md
index 4ff1f96b14..be8c32be46 100644
--- a/.github/scripts/torchao_model_releases/README.md
+++ b/.github/scripts/torchao_model_releases/README.md
@@ -51,8 +51,7 @@ By default, we release FP8, INT4, INT8-INT4 checkpoints, with model card pre-fil
 
 Examples:
 ```
-# Note: first login with `huggingface-cli login`, the quantized model will be uploaded to
-# the logged in user
+# Note: first log in with `hf auth login`; the quantized model will be uploaded to the logged-in user
 
 # release with default quant options (FP8, INT4, INT8-INT4)
 ./release.sh --model_id Qwen/Qwen3-8B --push_to_hub
@@ -63,8 +62,17 @@ Examples:
 
 Note: for initial release, please include `--populate_model_card_template` to populate model card template.
 
+### SmoothQuant-INT8-INT8
+[SmoothQuant](https://arxiv.org/abs/2211.10438) smooths activation outliers by migrating quantization difficulty from activations to weights through a mathematically equivalent per-channel scaling transformation. This means SmoothQuant observes the activation distribution (via calibration) before applying quantization.
+
+Examples:
+```
+# release SmoothQuant-INT8-INT8 model, calibrated with a specific task
+python quantize_and_upload.py --model_id Qwen/Qwen3-8B --quant SmoothQuant-INT8-INT8 --push_to_hub --tasks bbh --populate_model_card_template
+```
+
 ### AWQ-INT4
-[AWQ](https://arxiv.org/abs/2306.00978) is a technique to improve accuracy for weight only quantization. It improves accuracy by preserving "salient" weight channels that has high impact on the accuracy of output, through multiplying the weight channel by a scale, and do the reverse for the correspnoding activation, since activation is not quantized, there is no additional loss from activation, while the quantization loss from weight can be reduced.
+Similar to SmoothQuant, [AWQ](https://arxiv.org/abs/2306.00978) improves accuracy by preserving "salient" weight channels that have a high impact on the output accuracy. The notable difference is that AWQ uses the activation distribution, not the weight distribution, to find the salient weights: it multiplies each salient weight channel by a scale and divides the corresponding activation channel by the same scale. Since activations are not quantized, this introduces no additional activation error, while the quantization loss from the weights is reduced.
 
 After eval for INT4 checkpoint is done, we might find some task have a large accuracy drop compared to high precision baseline, in that case we can do a calibration for that task, with a few samples, tasks are selected from [lm-eval](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/README.md). You can follow [new task guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/new_task_guide.md) to add new tasks to lm-eval.
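To make the "mathematically equivalent per-channel scaling transformation" in the README text above concrete, here is a minimal sketch in plain PyTorch (not the torchao API; the toy tensors, the outlier channel, and the `alpha = 0.5` exponent are illustrative assumptions): dividing each activation channel by a smoothing factor and multiplying the matching weight row by the same factor leaves the matmul output unchanged while damping activation outliers.

```python
import torch

# Toy linear layer: X is (batch, in_features), W is (in_features, out_features).
torch.manual_seed(0)
X = torch.randn(4, 8)
X[:, 2] *= 50.0  # channel 2 carries large activation outliers
W = torch.randn(8, 16)

# SmoothQuant-style per-channel smoothing factor with alpha = 0.5 (illustrative choice):
# s_j = max|X_j|^alpha / max|W_j|^(1 - alpha)
s = X.abs().amax(dim=0).pow(0.5) / W.abs().amax(dim=1).pow(0.5)

X_smooth = X / s           # activation outliers are damped
W_smooth = W * s[:, None]  # quantization difficulty migrated into the weights

# Mathematically equivalent: (X / s) @ (diag(s) @ W) == X @ W
assert torch.allclose(X_smooth @ W_smooth, X @ W, atol=1e-4)
print("activation ranges before:", X.abs().amax(dim=0))
print("activation ranges after: ", X_smooth.abs().amax(dim=0))
```

Roughly speaking, the `prepare`/`convert` steps added to `quantize_and_upload.py` below do the same thing, except that the smoothing factors come from activation statistics gathered while running the calibration tasks.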
diff --git a/.github/scripts/torchao_model_releases/quantize_and_upload.py b/.github/scripts/torchao_model_releases/quantize_and_upload.py
index 208b6f29a1..d58c8fae8a 100644
--- a/.github/scripts/torchao_model_releases/quantize_and_upload.py
+++ b/.github/scripts/torchao_model_releases/quantize_and_upload.py
@@ -15,9 +15,11 @@
 from torchao.prototype.awq import (
     AWQConfig,
 )
+from torchao.prototype.smoothquant import SmoothQuantConfig
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Int4WeightOnlyConfig,
+    Int8DynamicActivationInt8WeightConfig,
     Int8DynamicActivationIntxWeightConfig,
     IntxWeightOnlyConfig,
     ModuleFqnToConfig,
@@ -265,6 +267,42 @@ def _untie_weights_and_save_locally(model_id):
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 """
 
+
+_smoothquant_int8_int8_quant_code = """
+from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_
+from torchao.prototype.smoothquant import SmoothQuantConfig
+
+from torchao._models._eval import TransformerEvalWrapper
+model = AutoModelForCausalLM.from_pretrained(
+    model_to_quantize,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+base_config = Int8DynamicActivationInt8WeightConfig()
+quant_config = SmoothQuantConfig(base_config, step="prepare")
+quantize_(
+    model,
+    quant_config,
+)
+TransformerEvalWrapper(
+    model=model,
+    tokenizer=tokenizer,
+    max_seq_length=max_seq_length,
+).run_eval(
+    tasks=tasks,
+    limit=calibration_limit,
+)
+quant_config = SmoothQuantConfig(base_config, step="convert")
+quantize_(model, quant_config)
+
+quantized_model = model
+quant_config = SmoothQuantConfig(base_config, step="prepare_for_loading")
+quantized_model.config.quantization_config = TorchAoConfig(quant_config)
+"""
+
+
 _awq_int4_quant_code = """
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
 from torchao.prototype.awq import (
@@ -647,6 +685,7 @@ def quantize_and_upload(
                 ),
             }
         ),
+        "SmoothQuant-INT8-INT8": Int8DynamicActivationInt8WeightConfig(),
     }
 
     quant_to_quant_code = {
@@ -655,6 +694,7 @@
         "INT8-INT4": _int8_int4_quant_code,
         "INT8-INT4-HQQ": _int8_int4_hqq_quant_code,
         "AWQ-INT4": _awq_int4_quant_code,
+        "SmoothQuant-INT8-INT8": _smoothquant_int8_int8_quant_code,
     }
 
     # preparation
@@ -698,6 +738,35 @@
         quantized_model = model
         quant_config = AWQConfig(base_config, step="prepare_for_loading")
         quantized_model.config.quantization_config = TorchAoConfig(quant_config)
+    elif quant == "SmoothQuant-INT8-INT8":
+        model = AutoModelForCausalLM.from_pretrained(
+            model_to_quantize,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        base_config = Int8DynamicActivationInt8WeightConfig()
+        quant_config = SmoothQuantConfig(base_config, step="prepare")
+        quantize_(
+            model,
+            quant_config,
+        )
+        TransformerEvalWrapper(
+            model=model,
+            tokenizer=tokenizer,
+            max_seq_length=max_seq_length,
+        ).run_eval(
+            tasks=tasks,
+            limit=calibration_limit,
+        )
+        quant_config = SmoothQuantConfig(base_config, step="convert")
+        quantize_(model, quant_config)
+
+        quantized_model = model
+
+        load_config = SmoothQuantConfig(base_config, step="prepare_for_loading")
+        quantized_model.config.quantization_config = TorchAoConfig(load_config)
     else:
         # other quantization are integrated with `from_pretrained` in huggingface transformers
         assert quant in quant_to_config, f"Unsupported quant option: {quant}"
@@ -812,7 +881,7 @@
     parser.add_argument(
         "--quant",
         type=str,
-        help="Quantization method. Options are FP8, INT4, INT8-INT4, INT8-INT4-HQQ, AWQ-INT4",
+        help="Quantization method. Options are FP8, INT4, INT8-INT4, INT8-INT4-HQQ, AWQ-INT4, SmoothQuant-INT8-INT8",
     )
     parser.add_argument(
         "--tasks",
@@ -824,8 +893,8 @@
     parser.add_argument(
         "--calibration_limit",
         type=int,
-        default=10,
-        help="Number of samples to use for calibration. Default is 10.",
+        default=128,
+        help="Number of samples to use for calibration. Default is 128.",
     )
     parser.add_argument(
         "--max_seq_length",
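For context on the consumer side, the `step="prepare_for_loading"` config stored in `config.quantization_config` is what lets a downstream user load the pushed SmoothQuant-INT8-INT8 checkpoint straight through transformers. A minimal sketch, assuming the checkpoint was uploaded by this script and torchao is installed alongside transformers; the repo id is a hypothetical placeholder whose actual value depends on `--model_id` and the account logged in with `hf auth login`:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical repo id following this script's naming pattern.
model_id = "your-username/Qwen3-8B-SmoothQuant-INT8-INT8"

# The TorchAoConfig stored in the checkpoint's config tells transformers how to
# reconstruct the smoothed INT8 weights at load time.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("What is SmoothQuant?", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```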