
Commit 6c27c19: Support calib_func on TF 3x API (#1934)

Signed-off-by: zehao-intel <[email protected]>
Parent: 53e6ee6

File tree: 8 files changed (+79 lines, -36 lines)


docs/3x/TensorFlow.md

Lines changed: 8 additions & 1 deletion
````diff
@@ -23,14 +23,15 @@ Intel(R) Neural Compressor provides `quantize_model` and `autotune` as main inte
 
 **quantize_model**
 
-The design philosophy of the `quantize_model` interface is easy-of-use. With minimal parameters requirement, including `model`, `quant_config`, `calib_dataloader` and `calib_iteration`, it offers a straightforward choice of quantizing TF model in one-shot.
+The design philosophy of the `quantize_model` interface is easy-of-use. With minimal parameters requirement, including `model`, `quant_config`, `calib_dataloader`, `calib_iteration`, it offers a straightforward choice of quantizing TF model in one-shot.
 
 ```python
 def quantize_model(
     model: Union[str, tf.keras.Model, BaseModel],
     quant_config: Union[BaseConfig, list],
     calib_dataloader: Callable = None,
     calib_iteration: int = 100,
+    calib_func: Callable = None,
 ):
 ```
 `model` should be a string of the model's location, the object of Keras model or INC TF model wrapper class.
@@ -41,6 +42,9 @@ def quantize_model(
 
 `calib_iteration` is used to decide how many iterations the calibration process will be run.
 
+`calib_func` is a substitution for `calib_dataloader` when the built-in calibration function of INC does not work for model inference.
+
+
 Here is a simple example of using `quantize_model` interface with a dummy calibration dataloader and the default `StaticQuantConfig`:
 ```python
 from neural_compressor.tensorflow import StaticQuantConfig, quantize_model
@@ -68,6 +72,7 @@ def autotune(
     eval_args: Optional[Tuple[Any]] = None,
     calib_dataloader: Callable = None,
     calib_iteration: int = 100,
+    calib_func: Callable = None,
 ) -> Optional[BaseModel]:
 ```
 `model` should be a string of the model's location, the object of Keras model or INC TF model wrapper class.
@@ -82,6 +87,8 @@ def autotune(
 
 `calib_iteration` is used to decide how many iterations the calibration process will be run.
 
+`calib_func` is a substitution for `calib_dataloader` when the built-in calibration function of INC does not work for model inference.
+
 Here is a simple example of using `autotune` interface with different quantization rules defined by a list of `StaticQuantConfig`:
 ```python
 from neural_compressor.common.base_tuning import TuningConfig
````
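For context, here is a minimal sketch of how the new `calib_func` argument can stand in for a dataloader when calling the documented `quantize_model` interface. The SavedModel path, input shape, and loop count are illustrative assumptions, as is the assumption that the model object handed to the function is directly callable:

```python
import numpy as np

from neural_compressor.tensorflow import StaticQuantConfig, quantize_model


def my_calib_func(model):
    # Run representative inference so INC can observe activation ranges.
    # The dummy input shape and 100 iterations are assumptions for a
    # typical image model; substitute real preprocessed data in practice.
    for _ in range(100):
        model(np.random.randn(1, 224, 224, 3).astype(np.float32))


quant_config = StaticQuantConfig()
q_model = quantize_model("./fp32_saved_model", quant_config, calib_func=my_calib_func)
```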

neural_compressor/tensorflow/algorithms/smoother/core.py

Lines changed: 8 additions & 4 deletions
```diff
@@ -37,19 +37,23 @@ class SmoothQuant:
     def __init__(
         self,
         config: SmoothQuantConfig,
-        calib_dataloader: Callable,
+        calib_dataloader: Callable = None,
         calib_iteration: int = 1,
+        calib_func: Callable = None,
     ):
         """Convert the model by smooth quant.
 
         Args:
-            config: the SmoothQuantConfig class used to set this class
-            calibdataloader: the calibration dataloader
-            calib_iteration: how many steps of iterations on the dataloader to move forward
+            config: the SmoothQuantConfig class used to set this class.
+            calibdataloader: the calibration dataloader.
+            calib_iteration: how many steps of iterations on the dataloader to move forward.
+            calib_func: the function used for calibration, should be a substitution for calib_dataloader
+                when the built-in calibration function of INC does not work for model inference.
 
         Returns:
             model: A smoothed Tensorflow model
         """
+        assert calib_func is None, "calibration function is not supported for smooth quant."
        self.config = config
        self.calib_dataloader = calib_dataloader
        self.calib_iteration = calib_iteration
```
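Note the guard added at the end of `__init__`: `calib_func` is accepted only to keep the signature aligned with the other quantization entries, and smooth quant still requires a dataloader. A sketch of the resulting behavior, assuming `SmoothQuantConfig` is importable from `neural_compressor.tensorflow` and constructible with defaults:

```python
from neural_compressor.tensorflow import SmoothQuantConfig
from neural_compressor.tensorflow.algorithms import SmoothQuant

# Passing a calibration function trips the assert added in this commit.
try:
    SmoothQuant(SmoothQuantConfig(), calib_func=lambda model: None)
except AssertionError as err:
    print(err)  # calibration function is not supported for smooth quant.
```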

neural_compressor/tensorflow/algorithms/static_quant/keras.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -314,16 +314,18 @@ def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5):
         return bn_fused_model
 
     @dump_elapsed_time("Pass quantize model")
-    def quantize(self, quant_config, model, dataloader, iteration, q_func=None):
+    def quantize(self, quant_config, model, dataloader, iteration, calib_func=None):
         """Execute the quantize process on the specified model.
 
         Args:
-            tune_cfg(dict): The user defined 'StaticQuantConfig' class.
+            quant_config(dict): The user defined 'StaticQuantConfig' class.
             model (object): The model to do quantization.
             dataloader(object): The calibration dataloader used to load quantization dataset.
             iteration(int): The iteration of calibration.
-            q_func (optional): training function for quantization aware training mode.
+            calib_func (optional): the function used for calibration, should be a substitution for calibration
+                dataloader when the built-in calibration function of INC does not work for model inference.
         """
+        assert calib_func is None, "The calibration function is not supported on Keras backend yet"
         self.query_fw_capability(model)
         converter = KerasConfigConverter(quant_config, iteration)
         tune_cfg = converter.parse_to_tune_cfg()
@@ -367,15 +369,13 @@ def quantize(self, quant_config, model, dataloader, iteration, q_func=None):
 
         return quantized_model
 
-    def _calibrate(self, model, dataloader, calib_interation):
+    def _calibrate(self, model, dataloader=None, calib_interation=None):
         """Apply calibration.
 
         Args:
             model (tf.keras.Model): The model inserted with FakeQuant layers for calibration.
             dataloader(object): The calibration dataloader used to load quantization dataset.
             iteration(int): The iteration of calibration.
-            fq_output_layers (dict): A dict mapping from names of FakeQuant layers to
-                names of their output layers.
         """
         # run eagerly to fetch the numpy min/max
         results = {}
```
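The Keras backend takes the same fail-fast stance: `calib_func` is threaded through for signature parity but rejected at the top of `quantize`. A sketch of what that means for a caller, assuming a Keras model routes to this backend through the public `quantize_model` entry:

```python
import tensorflow as tf

from neural_compressor.tensorflow import StaticQuantConfig, quantize_model

# Assumed toy model; any Keras model is handled by the Keras backend.
model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])

try:
    quantize_model(model, StaticQuantConfig(), calib_func=lambda m: None)
except AssertionError as err:
    print(err)  # The calibration function is not supported on Keras backend yet
```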

neural_compressor/tensorflow/algorithms/static_quant/tensorflow.py

Lines changed: 18 additions & 18 deletions
```diff
@@ -172,7 +172,7 @@ def quantize(
         model: BaseModel,
         calib_dataloader: Callable = None,
         calib_iteration: int = 100,
-        q_func=None,
+        calib_func: Callable = None,
     ):
         """Execute the quantize process on the specified model.
 
@@ -181,11 +181,11 @@ def quantize(
             model: the fp32 model to be quantized.
             calib_dataloader: a data loader for calibration.
             calib_iteration: the iteration of calibration.
-            q_func: training function for quantization aware training mode,
-                which not enabled for tensorflow yet.
+            calib_func: the function used for calibration, should be a substitution for calib_dataloader
+                when the built-in calibration function of INC does not work for model inference.
 
         Returns:
-            tf.compat.v1.GraphDef: the quantized model
+            converted_model: the quantized INC model wrapper.
         """
         assert (
             self.approach != "post_training_dynamic_quant"
@@ -195,7 +195,7 @@ def quantize(
             self.approach != "quant_aware_training"
         ), "Quantize Aware Training is not supported on TensorFlow framework now!"
 
-        self.calib_sampling_size = calib_dataloader.batch_size * calib_iteration
+        self.calib_sampling_size = calib_dataloader.batch_size * calib_iteration if calib_dataloader else 100
         tune_cfg = self.parse_quant_config(quant_config, model, calib_iteration)
         self._tuning_cfg_to_fw(tune_cfg)
         self.bf16_ops.extend(self.smooth_quant_mul_ops)
@@ -228,7 +228,7 @@ def quantize(
             fp32_ops=self.fp32_ops,
             bf16_ops=self.bf16_ops,
             data_loader=calib_dataloader,
-            calib_func=q_func,
+            calib_func=calib_func,
             qdq_enabled=self.qdq_enabled,
             new_api=self.new_api,
             performance_only=self.performance_only,
@@ -251,7 +251,7 @@ def quantize(
             fp32_ops=self.fp32_ops,
             bf16_ops=self.bf16_ops,
             data_loader=calib_dataloader,
-            calib_func=q_func,
+            calib_func=calib_func,
             qdq_enabled=self.qdq_enabled,
             new_api=self.new_api,
             performance_only=self.performance_only,
@@ -275,7 +275,7 @@ def quantize(
             fp32_ops=self.fp32_ops,
             bf16_ops=self.bf16_ops,
             data_loader=calib_dataloader,
-            calib_func=q_func,
+            calib_func=calib_func,
             qdq_enabled=self.qdq_enabled,
             new_api=self.new_api,
             performance_only=self.performance_only,
@@ -750,21 +750,21 @@ def quantize(
         model: BaseModel,
         calib_dataloader: Callable = None,
         calib_iteration: int = 100,
-        q_func=None,
+        calib_func: Callable = None,
     ):
         """Execute the quantize process on the specified model.
 
         Args:
-            tune_cfg (dict): quantization configuration
-            model (tf.compat.v1.GraphDef): fp32 model
-            data_loader (generator): generator the data and labels
-            q_func (optional): training function for quantization aware training mode,
-                which not enabled for tensorflow yet.
+            quant_config: a quantization configuration.
+            model: the fp32 model to be quantized.
+            calib_dataloader: a data loader for calibration.
+            calib_iteration: the iteration of calibration.
+            calib_func: the function used for calibration, should be a substitution for calib_dataloader
+                when the built-in calibration function of INC does not work for model inference.
 
         Returns:
-            tf.compat.v1.GraphDef: the quantized model
+            converted_model: the quantized INC model wrapper.
         """
-        assert q_func is None, "quantization aware training mode is not support on tensorflow"
         self.calib_sampling_size = calib_dataloader.batch_size * calib_iteration
         tune_cfg = self.parse_quant_config(quant_config, model, calib_iteration)
         self._tuning_cfg_to_fw(tune_cfg)
@@ -798,7 +798,7 @@ def quantize(
             fp32_ops=self.fp32_ops,
             bf16_ops=self.bf16_ops,
             data_loader=calib_dataloader,
-            calib_func=q_func,
+            calib_func=calib_func,
             itex_mode=self.itex_mode,
             qdq_enabled=self.qdq_enabled,
             new_api=self.new_api,
@@ -846,7 +846,7 @@ def quantize(
             fp32_ops=self.fp32_ops,
             bf16_ops=self.bf16_ops,
             data_loader=calib_dataloader,
-            calib_func=q_func,
+            calib_func=calib_func,
             itex_mode=self.itex_mode,
             qdq_enabled=self.qdq_enabled,
             new_api=self.new_api,
```
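The behavioral subtlety in the first `quantize` is the sampling-size line: when calibration is driven by `calib_func` there is no dataloader to derive a sample count from, so the adaptor falls back to a nominal 100 (the second `quantize` in this file keeps the unguarded expression and still expects a dataloader). The guard in isolation:

```python
def calib_sampling_size(calib_dataloader, calib_iteration=100):
    # Mirrors the expression above: a dataloader yields batch_size * iterations
    # samples, while a calib_func-driven run records a nominal size of 100.
    return calib_dataloader.batch_size * calib_iteration if calib_dataloader else 100


class DummyLoader:
    batch_size = 32


print(calib_sampling_size(DummyLoader()))  # 3200
print(calib_sampling_size(None))  # 100, calibration left to calib_func
```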

neural_compressor/tensorflow/quantization/algorithm_entry.py

Lines changed: 19 additions & 2 deletions
```diff
@@ -28,6 +28,7 @@ def static_quant_entry(
     quant_config: BaseConfig,
     calib_dataloader: Callable = None,
     calib_iteration: int = 100,
+    calib_func: Callable = None,
 ):
     """The main entry to apply static quantization.
 
@@ -36,6 +37,8 @@ def static_quant_entry(
         quant_config: a quantization configuration.
         calib_dataloader: a data loader for calibration.
         calib_iteration: the iteration of calibration.
+        calib_func: the function used for calibration, should be a substitution for calib_dataloader
+            when the built-in calibration function of INC does not work for model inference.
 
     Returns:
         q_model: the quantized model.
@@ -49,7 +52,7 @@ def static_quant_entry(
         framework = TensorFlowAdaptor
 
     quantizer = framework(TFConfig.global_config)
-    q_model = quantizer.quantize(quant_config, model, calib_dataloader, calib_iteration)
+    q_model = quantizer.quantize(quant_config, model, calib_dataloader, calib_iteration, calib_func)
     TFConfig.reset_global_config()
 
     return q_model
@@ -61,12 +64,26 @@ def smooth_quant_entry(
     smooth_quant_config: SmoothQuantConfig,
     calib_dataloader: Callable = None,
     calib_iteration: int = 100,
+    calib_func: Callable = None,
 ):
+    """The main entry to apply smooth quantization.
+
+    Args:
+        model: a fp32 model to be quantized.
+        quant_config: a quantization configuration.
+        calib_dataloader: a data loader for calibration.
+        calib_iteration: the iteration of calibration.
+        calib_func: the function used for calibration, should be a substitution for calib_dataloader
+            when the built-in calibration function of INC does not work for model inference.
+
+    Returns:
+        q_model: the quantized model.
+    """
     assert not isinstance(model, KerasModel), "INC don't support smooth quantization for Keras models now."
 
     from neural_compressor.tensorflow.algorithms import SmoothQuant
 
-    converter = SmoothQuant(smooth_quant_config, calib_dataloader, calib_iteration)
+    converter = SmoothQuant(smooth_quant_config, calib_dataloader, calib_iteration, calib_func)
     sq_model = converter(model)
 
     return sq_model
```

neural_compressor/tensorflow/quantization/autotune.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -44,6 +44,7 @@ def autotune(
     eval_args: Optional[Tuple[Any]] = None,
     calib_dataloader: Callable = None,
     calib_iteration: int = 100,
+    calib_func: Callable = None,
 ) -> Optional[BaseModel]:
     """The main entry of auto-tune."""
     model = Model(model)
@@ -57,7 +58,7 @@ def autotune(
         tuning_logger.trial_start(trial_index=trial_index)
         tuning_logger.execution_start()
         logger.info(quant_config.to_dict())
-        q_model = quantize_model(model, quant_config, calib_dataloader, calib_iteration)
+        q_model = quantize_model(model, quant_config, calib_dataloader, calib_iteration, calib_func)
         tuning_logger.execution_end()
         tuning_logger.evaluation_start()
         eval_result: float = eval_func_wrapper.evaluate(q_model)
@@ -71,7 +72,9 @@ def autotune(
             logger.info("Re-quantizing with best quantization config...")
             del q_model
             best_quant_config: BaseConfig = best_trial_record.quant_config
-            best_quant_model = quantize_model(model, best_quant_config, calib_dataloader, calib_iteration)
+            best_quant_model = quantize_model(
+                model, best_quant_config, calib_dataloader, calib_iteration, calib_func
+            )
         else:
             best_quant_model = q_model
             break
```
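A sketch of driving `autotune` with a calibration function instead of a dataloader; the model path, input shape, and placeholder eval function are illustrative, and the `eval_fn` keyword plus the `neural_compressor.tensorflow` import path for `autotune` are assumptions based on the surrounding 3x API:

```python
import numpy as np

from neural_compressor.common.base_tuning import TuningConfig
from neural_compressor.tensorflow import StaticQuantConfig, autotune


def my_calib_func(model):
    # Assumed callable model wrapper; feed representative samples.
    for _ in range(100):
        model(np.random.randn(1, 224, 224, 3).astype(np.float32))


def eval_acc(model) -> float:
    # Placeholder metric; a real eval function should score the model.
    return 1.0


best_model = autotune(
    "./fp32_saved_model",
    TuningConfig(config_set=[StaticQuantConfig()]),
    eval_fn=eval_acc,
    calib_func=my_calib_func,
)
```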

neural_compressor/tensorflow/quantization/quantize.py

Lines changed: 11 additions & 3 deletions
```diff
@@ -32,6 +32,7 @@ def quantize_model(
     quant_config: Union[BaseConfig, list],
     calib_dataloader: Callable = None,
     calib_iteration: int = 100,
+    calib_func: Callable = None,
 ):
     """The main entry to quantize model.
 
@@ -40,16 +41,20 @@ def quantize_model(
         quant_config: single or lists of quantization configuration.
         calib_dataloader: a data loader for calibration.
         calib_iteration: the iteration of calibration.
+        calib_func: the function used for calibration, should be a substitution for calib_dataloader
+            when the built-in calibration function of INC does not work for model inference.
 
     Returns:
         q_model: the quantized model.
     """
     q_model = Model(model)
     if isinstance(quant_config, list):
         for config in quant_config:
-            q_model = quantize_model_with_single_config(q_model, config, calib_dataloader, calib_iteration)
+            q_model = quantize_model_with_single_config(q_model, config, calib_dataloader, calib_iteration, calib_func)
     else:
-        q_model = quantize_model_with_single_config(q_model, quant_config, calib_dataloader, calib_iteration)
+        q_model = quantize_model_with_single_config(
+            q_model, quant_config, calib_dataloader, calib_iteration, calib_func
+        )
 
     return q_model
 
@@ -59,6 +64,7 @@ def quantize_model_with_single_config(
     quant_config: BaseConfig,
     calib_dataloader: Callable = None,
     calib_iteration: int = 100,
+    calib_func: Callable = None,
 ):
     """Quantize model using single config.
 
@@ -67,6 +73,8 @@ def quantize_model_with_single_config(
         quant_config: a quantization configuration.
         calib_dataloader: a data loader for calibration.
         calib_iteration: the iteration of calibration.
+        calib_func: the function used for calibration, should be a substitution for calib_dataloader
+            when the built-in calibration function of INC does not work for model inference.
 
     Returns:
         q_model: the quantized model.
@@ -89,5 +97,5 @@ def quantize_model_with_single_config(
     for algo_name, algo_func in algos_mapping.items():
         if need_apply(configs_mapping, algo_name):
             logger.info(f"Start to apply {algo_name} on the model.")
-            q_model = algo_func(q_model, configs_mapping, calib_dataloader, calib_iteration)
+            q_model = algo_func(q_model, configs_mapping, calib_dataloader, calib_iteration, calib_func)
     return q_model
```

neural_compressor/tensorflow/quantization/utils/graph_converter.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -231,6 +231,10 @@ def _inference(self, model):
         Args:
             model(TensorflowBaseModel): input TensorflowBaseModel
         """
+        if self.calib_func:
+            self.calib_func(model)
+            return
+
         if model.model_type == "llm_saved_model":
             self._inference_llm(model)
             return
```
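This early return is the heart of the feature: during calibration `_inference` hands the whole model to the user-supplied function instead of looping over the attached dataloader, so that function is responsible for running enough forward passes to populate the collected min/max statistics. A minimal shape for such a function, with an assumed sample source and the assumption that the model wrapper is callable:

```python
import numpy as np


def representative_batches(num_batches=100):
    # Assumed stand-in for real preprocessed calibration data.
    for _ in range(num_batches):
        yield np.random.randn(1, 224, 224, 3).astype(np.float32)


def my_calib_func(model):
    # `model` is the object _inference passes through unchanged; each
    # forward pass feeds the inserted range-collection nodes.
    for batch in representative_batches():
        model(batch)
```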
