
Commit d446acd

Add INT8-INT4-HQQ to model release script (#3127)
1 parent c96f2dd commit d446acd
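
The new INT8-INT4-HQQ option mirrors the existing INT8-INT4 recipe (int8 weight-only embeddings plus int8 dynamic-activation / int4-weight linear layers), but selects the HQQ scale-only algorithm for choosing quantization parameters. The mobile-specific handling (untying weights, quantizing the embeddings) is generalized to cover both variants via an `is_mobile` flag.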

.github/scripts/torchao_model_releases/quantize_and_upload.py (51 additions, 14 deletions)
```diff
@@ -242,6 +242,29 @@ def _untie_weights_and_save_locally(model_id):
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 """
 
+_int8_int4_hqq_quant_code = """
+from torchao.quantization.quant_api import (
+    IntxWeightOnlyConfig,
+    Int8DynamicActivationIntxWeightConfig,
+    ModuleFqnToConfig,
+)
+from torchao.quantization.granularity import PerGroup, PerAxis
+embedding_config = IntxWeightOnlyConfig(
+    weight_dtype=torch.int8,
+    granularity=PerAxis(0),
+    intx_choose_qparams_algorithm="hqq_scale_only",
+)
+linear_config = Int8DynamicActivationIntxWeightConfig(
+    weight_dtype=torch.int4,
+    weight_granularity=PerGroup(32),
+    intx_choose_qparams_algorithm="hqq_scale_only",
+)
+quant_config = ModuleFqnToConfig({{"_default": linear_config, "model.embed_tokens": embedding_config}})
+quantization_config = TorchAoConfig(quant_type=quant_config, include_input_output_embeddings=True, modules_to_not_convert=[])
+quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="cuda:0", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+"""
+
 _awq_int4_quant_code = """
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
 from torchao.prototype.awq import (
```
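
For readability, here is roughly what the new model-card snippet looks like once the `{{ }}` format escapes are rendered and the imports the template takes for granted are added. The `transformers` imports and the `Qwen/Qwen3-0.6B` placeholder are assumptions for illustration only: the script substitutes the actual model id, and the `model.embed_tokens` FQN presumes a Llama/Qwen-style module layout.

```python
# Sketch of the rendered INT8-INT4-HQQ snippet; the model id is a placeholder.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from torchao.quantization.quant_api import (
    IntxWeightOnlyConfig,
    Int8DynamicActivationIntxWeightConfig,
    ModuleFqnToConfig,
)
from torchao.quantization.granularity import PerGroup, PerAxis

model_id = "Qwen/Qwen3-0.6B"  # placeholder; the release script fills in the real model id
model_to_quantize = model_id

# int8 weight-only embeddings, per-axis scales chosen with HQQ (scale only)
embedding_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int8,
    granularity=PerAxis(0),
    intx_choose_qparams_algorithm="hqq_scale_only",
)
# int8 dynamic activations + int4 weights in groups of 32, HQQ scale only
linear_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
    intx_choose_qparams_algorithm="hqq_scale_only",
)
quant_config = ModuleFqnToConfig(
    {"_default": linear_config, "model.embed_tokens": embedding_config}
)
quantization_config = TorchAoConfig(
    quant_type=quant_config,
    include_input_output_embeddings=True,
    modules_to_not_convert=[],
)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_to_quantize,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
```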
```diff
@@ -589,14 +612,8 @@ def quantize_and_upload(
     push_to_user_id: str,
     populate_model_card_template: bool,
 ):
-    _int8_int4_linear_config = Int8DynamicActivationIntxWeightConfig(
-        weight_dtype=torch.int4,
-        weight_granularity=PerGroup(32),
-    )
-    _int8_int4_embedding_config = IntxWeightOnlyConfig(
-        weight_dtype=torch.int8,
-        granularity=PerAxis(0),
-    )
+    is_mobile = quant in ["INT8-INT4", "INT8-INT4-HQQ"]
+
     quant_to_config = {
         "FP8": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
         "INT4": Int4WeightOnlyConfig(
@@ -606,8 +623,28 @@ def quantize_and_upload(
         ),
         "INT8-INT4": ModuleFqnToConfig(
             {
-                "_default": _int8_int4_linear_config,
-                "model.embed_tokens": _int8_int4_embedding_config,
+                "_default": Int8DynamicActivationIntxWeightConfig(
+                    weight_dtype=torch.int4,
+                    weight_granularity=PerGroup(32),
+                ),
+                "model.embed_tokens": IntxWeightOnlyConfig(
+                    weight_dtype=torch.int8,
+                    granularity=PerAxis(0),
+                ),
+            }
+        ),
+        "INT8-INT4-HQQ": ModuleFqnToConfig(
+            {
+                "_default": Int8DynamicActivationIntxWeightConfig(
+                    weight_dtype=torch.int4,
+                    weight_granularity=PerGroup(32),
+                    intx_choose_qparams_algorithm="hqq_scale_only",
+                ),
+                "model.embed_tokens": IntxWeightOnlyConfig(
+                    weight_dtype=torch.int8,
+                    granularity=PerAxis(0),
+                    intx_choose_qparams_algorithm="hqq_scale_only",
+                ),
             }
         ),
     }
```
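
The only difference between the INT8-INT4 and INT8-INT4-HQQ entries is `intx_choose_qparams_algorithm="hqq_scale_only"`, which asks torchao to fit the quantization scale with an HQQ-style error-minimizing procedure rather than deriving it directly from the weight range. The toy snippet below is only a conceptual illustration of that difference on a single 32-element group, not torchao's implementation; the brute-force scale search stands in for the actual HQQ optimizer.

```python
import torch

def quantize_int4_scale_only(w: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # symmetric int4: values clamped to [-8, 7], scale only (no zero point)
    q = torch.clamp(torch.round(w / scale), -8, 7)
    return q * scale

# one group of 32 weights, as in PerGroup(32)
w = torch.randn(32)

# baseline: scale taken straight from the weight range (absmax)
absmax_scale = w.abs().max() / 7

# scale-only search: try shrunken scales, keep the one with lowest reconstruction error
candidates = absmax_scale * torch.linspace(0.5, 1.0, steps=64)
errors = torch.stack(
    [(w - quantize_int4_scale_only(w, s)).pow(2).sum() for s in candidates]
)
best_scale = candidates[errors.argmin()]

print("absmax error:  ", (w - quantize_int4_scale_only(w, absmax_scale)).pow(2).sum().item())
print("searched error:", (w - quantize_int4_scale_only(w, best_scale)).pow(2).sum().item())
```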
```diff
@@ -616,12 +653,13 @@ def quantize_and_upload(
         "FP8": _fp8_quant_code,
         "INT4": _int4_quant_code,
         "INT8-INT4": _int8_int4_quant_code,
+        "INT8-INT4-HQQ": _int8_int4_hqq_quant_code,
         "AWQ-INT4": _awq_int4_quant_code,
     }
 
     # preparation
     model_to_quantize = model_id
-    if quant == "INT8-INT4":
+    if is_mobile:
         model_to_quantize = _untie_weights_and_save_locally(model_to_quantize)
 
     # quantization
@@ -666,7 +704,7 @@ def quantize_and_upload(
     quant_config = quant_to_config[quant]
 
     torchao_config_kwargs = {}
-    if "INT8-INT4" in quant:
+    if is_mobile:
         torchao_config_kwargs["modules_to_not_convert"] = []
         torchao_config_kwargs["include_input_output_embeddings"] = True
 
@@ -688,7 +726,6 @@ def quantize_and_upload(
     save_to_user_id = username if push_to_user_id is None else push_to_user_id
     save_to = f"{save_to_user_id}/{MODEL_NAME}-{quant}"
     untied_model_path = 'f"{{MODEL_NAME}}-untied-weights"'
-    is_mobile = quant == "INT8-INT4"
     quantized_model_id = save_to
     # model card
     content = MODEL_CARD.format(
@@ -775,7 +812,7 @@ def quantize_and_upload(
     parser.add_argument(
         "--quant",
         type=str,
-        help="Quantization method. Options are FP8, INT4, INT8-INT4, AWQ-INT4",
+        help="Quantization method. Options are FP8, INT4, INT8-INT4, INT8-INT4-HQQ, AWQ-INT4",
     )
     parser.add_argument(
         "--tasks",
```
