@@ -242,6 +242,29 @@ def _untie_weights_and_save_locally(model_id):
     tokenizer = AutoTokenizer.from_pretrained(model_id)
 """
 
+_int8_int4_hqq_quant_code = """
+from torchao.quantization.quant_api import (
+    IntxWeightOnlyConfig,
+    Int8DynamicActivationIntxWeightConfig,
+    ModuleFqnToConfig,
+)
+from torchao.quantization.granularity import PerGroup, PerAxis
+embedding_config = IntxWeightOnlyConfig(
+    weight_dtype=torch.int8,
+    granularity=PerAxis(0),
+    intx_choose_qparams_algorithm="hqq_scale_only",
+)
+linear_config = Int8DynamicActivationIntxWeightConfig(
+    weight_dtype=torch.int4,
+    weight_granularity=PerGroup(32),
+    intx_choose_qparams_algorithm="hqq_scale_only",
+)
+quant_config = ModuleFqnToConfig({{"_default": linear_config, "model.embed_tokens": embedding_config}})
+quantization_config = TorchAoConfig(quant_type=quant_config, include_input_output_embeddings=True, modules_to_not_convert=[])
+quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="cuda:0", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+"""
+
 _awq_int4_quant_code = """
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
 from torchao.prototype.awq import (
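
For readers who want to try the new HQQ path outside of the transformers integration, here is a minimal sketch that applies the same two configs with torchao's quantize_ directly. The ToyLM module, its fqns, and the filter_fn are illustrative assumptions, not part of this PR; only the config arguments mirror the template above.

import torch
from torchao.quantization import quantize_
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    IntxWeightOnlyConfig,
    ModuleFqnToConfig,
)
from torchao.quantization.granularity import PerAxis, PerGroup

class ToyLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(128, 64)
        self.proj = torch.nn.Linear(64, 64)

    def forward(self, ids):
        return self.proj(self.embed_tokens(ids))

model = ToyLM()
quantize_(
    model,
    ModuleFqnToConfig({
        # any module not matched by a more specific fqn gets the linear config
        "_default": Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(32),
            intx_choose_qparams_algorithm="hqq_scale_only",
        ),
        # the embedding gets int8 weight-only, per output axis
        "embed_tokens": IntxWeightOnlyConfig(
            weight_dtype=torch.int8,
            granularity=PerAxis(0),
            intx_choose_qparams_algorithm="hqq_scale_only",
        ),
    }),
    # admit embeddings as well as linears (quantize_ targets linears by default)
    filter_fn=lambda m, fqn: isinstance(m, (torch.nn.Linear, torch.nn.Embedding)),
)
print(model(torch.randint(0, 128, (1, 8))).shape)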
@@ -589,14 +612,8 @@ def quantize_and_upload(
     push_to_user_id: str,
     populate_model_card_template: bool,
 ):
-    _int8_int4_linear_config = Int8DynamicActivationIntxWeightConfig(
-        weight_dtype=torch.int4,
-        weight_granularity=PerGroup(32),
-    )
-    _int8_int4_embedding_config = IntxWeightOnlyConfig(
-        weight_dtype=torch.int8,
-        granularity=PerAxis(0),
-    )
+    is_mobile = quant in ["INT8-INT4", "INT8-INT4-HQQ"]
+
     quant_to_config = {
         "FP8": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
         "INT4": Int4WeightOnlyConfig(
@@ -606,8 +623,28 @@ def quantize_and_upload(
         ),
         "INT8-INT4": ModuleFqnToConfig(
             {
-                "_default": _int8_int4_linear_config,
-                "model.embed_tokens": _int8_int4_embedding_config,
+                "_default": Int8DynamicActivationIntxWeightConfig(
+                    weight_dtype=torch.int4,
+                    weight_granularity=PerGroup(32),
+                ),
+                "model.embed_tokens": IntxWeightOnlyConfig(
+                    weight_dtype=torch.int8,
+                    granularity=PerAxis(0),
+                ),
+            }
+        ),
+        "INT8-INT4-HQQ": ModuleFqnToConfig(
+            {
+                "_default": Int8DynamicActivationIntxWeightConfig(
+                    weight_dtype=torch.int4,
+                    weight_granularity=PerGroup(32),
+                    intx_choose_qparams_algorithm="hqq_scale_only",
+                ),
+                "model.embed_tokens": IntxWeightOnlyConfig(
+                    weight_dtype=torch.int8,
+                    granularity=PerAxis(0),
+                    intx_choose_qparams_algorithm="hqq_scale_only",
+                ),
             }
         ),
     }
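
The only difference between the INT8-INT4 and INT8-INT4-HQQ entries is intx_choose_qparams_algorithm="hqq_scale_only": instead of deriving the scale from the weight's min/max, the scale is optimized to reduce quantization error (Half-Quadratic Quantization, with zero-points left untouched). As rough intuition only (this is not torchao's implementation), a scale-only, error-minimizing search looks like:

import torch

def choose_scale_error_minimizing(w, qmin=-8, qmax=7, steps=32):
    # start from the plain absmax scale, then search nearby scales
    # for the one with the smallest reconstruction error
    base = w.abs().max() / qmax
    candidates = base * torch.linspace(0.8, 1.2, steps)
    errs = []
    for s in candidates:
        q = torch.clamp(torch.round(w / s), qmin, qmax)
        errs.append(((w - q * s) ** 2).sum())
    return candidates[torch.stack(errs).argmin()]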
@@ -616,12 +653,13 @@ def quantize_and_upload(
         "FP8": _fp8_quant_code,
         "INT4": _int4_quant_code,
         "INT8-INT4": _int8_int4_quant_code,
+        "INT8-INT4-HQQ": _int8_int4_hqq_quant_code,
         "AWQ-INT4": _awq_int4_quant_code,
     }
 
     # preparation
     model_to_quantize = model_id
-    if quant == "INT8-INT4":
+    if is_mobile:
         model_to_quantize = _untie_weights_and_save_locally(model_to_quantize)
 
     # quantization
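
Both mobile variants go through _untie_weights_and_save_locally first (its body is outside this diff). The reason: with include_input_output_embeddings=True, model.embed_tokens gets a weight-only int8 config while lm_head falls under the linear _default, and a tensor shared between the two cannot carry both quantized representations. A hypothetical minimal version of untying, assuming standard transformers attribute names and not the helper actually used here:

import torch

def untie_lm_head(model):
    # hypothetical sketch: clone the shared tensor so embed_tokens and
    # lm_head can be quantized with different configs
    if model.lm_head.weight is model.model.embed_tokens.weight:
        model.lm_head.weight = torch.nn.Parameter(
            model.model.embed_tokens.weight.detach().clone()
        )
        model.config.tie_word_embeddings = False
    return model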
@@ -666,7 +704,7 @@ def quantize_and_upload(
     quant_config = quant_to_config[quant]
 
     torchao_config_kwargs = {}
-    if "INT8-INT4" in quant:
+    if is_mobile:
         torchao_config_kwargs["modules_to_not_convert"] = []
         torchao_config_kwargs["include_input_output_embeddings"] = True
 
@@ -688,7 +726,6 @@ def quantize_and_upload(
     save_to_user_id = username if push_to_user_id is None else push_to_user_id
     save_to = f"{save_to_user_id}/{MODEL_NAME}-{quant}"
     untied_model_path = 'f"{{MODEL_NAME}}-untied-weights"'
-    is_mobile = quant == "INT8-INT4"
     quantized_model_id = save_to
     # model card
     content = MODEL_CARD.format(
@@ -775,7 +812,7 @@ def quantize_and_upload(
     parser.add_argument(
         "--quant",
         type=str,
-        help="Quantization method. Options are FP8, INT4, INT8-INT4, AWQ-INT4",
+        help="Quantization method. Options are FP8, INT4, INT8-INT4, INT8-INT4-HQQ, AWQ-INT4",
     )
     parser.add_argument(
         "--tasks",