@@ -134,7 +134,7 @@ def quantize(
134134 from torchao .quantization .quant_api import Int8DynActInt4WeightQuantizer
135135
136136 model = Int8DynActInt4WeightQuantizer (
137- precision = torch_dtype , group_size = group_size
137+ precision = torch_dtype , groupsize = group_size
138138 ).quantize (model )
139139 if verbose_export ():
140140 print ("quantized model:" , model )
@@ -153,6 +153,7 @@ def quantize(
153153 if calibration_tasks is None :
154154 calibration_tasks = ["wikitext" ]
155155
156+ from torchao .quantization .GPTQ import InputRecorder
156157 from torchao .quantization .quant_api import Int8DynActInt4WeightGPTQQuantizer
157158
158159 if tokenizer_path is None :
@@ -161,17 +162,28 @@ def quantize(
161162 tokenizer = SentencePieceProcessor ( # pyre-ignore[28]
162163 model_file = str (tokenizer_path )
163164 )
165+
166+ inputs = (
167+ InputRecorder (
168+ tokenizer ,
169+ calibration_seq_length ,
170+ None , # input_prep_func
171+ pad_calibration_inputs ,
172+ model .vocab_size ,
173+ )
174+ .record_inputs (
175+ calibration_tasks ,
176+ calibration_limit ,
177+ )
178+ .get_inputs ()
179+ )
180+
164181 gptq_quantizer = Int8DynActInt4WeightGPTQQuantizer (
165- tokenizer ,
166182 blocksize ,
167183 percdamp ,
168184 group_size ,
169- calibration_tasks ,
170- calibration_limit ,
171- calibration_seq_length ,
172- pad_calibration_inputs ,
173185 )
174- model = gptq_quantizer .quantize (model )
186+ model = gptq_quantizer .quantize (model , inputs )
175187 return model
176188 else :
177189 raise Exception (f"Unrecognized quantize mode: { qmode } " )
0 commit comments