Skip to content

Commit 317f246

Browse files
INC new API tensorflow examples update (#174)
1 parent 9b15ad1 commit 317f246

File tree

24 files changed

+451
-602
lines changed

24 files changed

+451
-602
lines changed

.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2397,6 +2397,8 @@ grappler
23972397
amsgrad
23982398
qoperator
23992399
apis
2400+
PostTrainingQuantConfig
2401+
dgpu
24002402
CPz
24012403
PostTrainingQuantConfig
24022404
dgpu

examples/.config/model_params_tensorflow.json

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,8 @@
162162
"model_src_dir": "image_recognition/keras_models/inception_resnet_v2/quantization/ptq",
163163
"dataset_location": "/tf_dataset/dataset/imagenet",
164164
"input_model": "/tf_dataset2/models/tensorflow/inception_resnet_v2_keras/saved_model/",
165-
"yaml": "inception_resnet_v2.yaml",
166-
"strategy": "basic",
167-
"batch_size": 1,
168-
"new_benchmark": true
165+
"main_script": "main.py",
166+
"batch_size": 32
169167
},
170168
"vgg16": {
171169
"model_src_dir": "image_recognition/tensorflow_models/quantization/ptq",
@@ -288,10 +286,8 @@
288286
"model_src_dir": "image_recognition/keras_models/resnetv2_50/quantization/ptq",
289287
"dataset_location": "/tf_dataset/dataset/imagenet",
290288
"input_model": "/tf_dataset2/models/tensorflow/resnetv2_50_keras/saved_model",
291-
"yaml": "resnetv2_50.yaml",
292-
"strategy": "basic",
293-
"batch_size": 1,
294-
"new_benchmark": true
289+
"main_script": "main.py",
290+
"batch_size": 32
295291
},
296292
"resnetv2_101": {
297293
"model_src_dir": "image_recognition/tensorflow_models/quantization/ptq",
@@ -2486,10 +2482,8 @@
24862482
"model_src_dir": "image_recognition/keras_models/xception/quantization/ptq",
24872483
"dataset_location": "/tf_dataset/dataset/imagenet",
24882484
"input_model": "/tf_dataset2/models/tensorflow/xception_keras/saved_model/",
2489-
"yaml": "xception.yaml",
2490-
"strategy": "basic",
2491-
"batch_size": 1,
2492-
"new_benchmark": true
2485+
"main_script": "main.py",
2486+
"batch_size": 32
24932487
}
24942488
}
24952489
}

examples/tensorflow/image_recognition/keras_models/inception_resnet_v2/quantization/ptq/README.md

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ Intel Extension for Tensorflow is mandatory to be installed for quantizing the m
2626
```shell
2727
pip install --upgrade intel-extension-for-tensorflow[gpu]
2828
```
29-
For any more details, please follow the procedure in [install-gpu-drivers](https://github.com/intel-innersource/frameworks.ai.infrastructure.intel-extension-for-tensorflow.intel-extension-for-tensorflow/blob/master/docs/install/install_for_gpu.md#install-gpu-drivers)
29+
Please refer to the [Installation Guides](https://dgpu-docs.intel.com/installation-guides/ubuntu/ubuntu-focal-dc.html) for latest Intel GPU driver installation.
30+
For any more details, please follow the procedure in [install-gpu-drivers](https://github.com/intel-innersource/frameworks.ai.infrastructure.intel-extension-for-tensorflow.intel-extension-for-tensorflow/blob/master/docs/install/install_for_gpu.md#install-gpu-drivers).
3031

3132
#### Quantizing the model on Intel CPU(Experimental)
3233
Intel Extension for Tensorflow for Intel CPUs is experimental currently. It's not mandatory for quantizing the model on Intel CPUs.
@@ -43,12 +44,20 @@ python prepare_model.py --output_model=/path/to/model
4344
```
4445
`--output_model ` the model should be saved as SavedModel format or H5 format.
4546

46-
## Write Yaml config file
47-
In examples directory, there is a inception_resnet_v2.yaml for tuning the model on Intel CPUs. The 'framework' in the yaml is set to 'tensorflow'. If running this example on Intel GPUs, the 'framework' should be set to 'tensorflow_itex' and the device in yaml file should be set to 'gpu'. The inception_resnet_v2_itex.yaml is prepared for the GPU case. We could remove most of items and only keep mandatory item for tuning. We also implement a calibration dataloader and have evaluation field for creation of evaluation function at internal neural_compressor.
47+
## Quantization Config
48+
The Quantization Config class has default parameters setting for running on Intel CPUs. If running this example on Intel GPUs, the 'backend' parameter should be set to 'itex' and the 'device' parameter should be set to 'gpu'.
49+
50+
```
51+
config = PostTrainingQuantConfig(
52+
device="gpu",
53+
backend="itex",
54+
...
55+
)
56+
```
4857

4958
## Run Command
5059
```shell
51-
bash run_tuning.sh --config=inception_resnet_v2.yaml --input_model=./path/to/model --output_model=./result --eval_data=/path/to/evaluation/dataset --calib_data=/path/to/calibration/dataset
52-
bash run_benchmark.sh --config=inception_resnet_v2.yaml --input_model=./path/to/model --mode=performance --eval_data=/path/to/evaluation/dataset
60+
bash run_tuning.sh --input_model=./path/to/model --output_model=./result --dataset_location=/path/to/evaluation/dataset --batch_size=32
61+
bash run_benchmark.sh --input_model=./path/to/model --mode=performance --dataset_location=/path/to/evaluation/dataset --batch_size=1
5362
```
5463

examples/tensorflow/image_recognition/keras_models/inception_resnet_v2/quantization/ptq/inception_resnet_v2.yaml

Lines changed: 0 additions & 44 deletions
This file was deleted.

examples/tensorflow/image_recognition/keras_models/inception_resnet_v2/quantization/ptq/inception_resnet_v2_itex.yaml

Lines changed: 0 additions & 44 deletions
This file was deleted.
Lines changed: 95 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#
22
# -*- coding: utf-8 -*-
33
#
4-
# Copyright (c) 2018 Intel Corporation
4+
# Copyright (c) 2022 Intel Corporation
55
#
66
# Licensed under the Apache License, Version 2.0 (the "License");
77
# you may not use this file except in compliance with the License.
@@ -16,11 +16,9 @@
1616
# limitations under the License.
1717
#
1818
import time
19-
import shutil
2019
import numpy as np
21-
from argparse import ArgumentParser
22-
from neural_compressor import data
2320
import tensorflow as tf
21+
from neural_compressor.utils import logger
2422
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
2523

2624
flags = tf.compat.v1.flags
@@ -42,87 +40,111 @@
4240
flags.DEFINE_bool(
4341
'benchmark', False, 'whether to benchmark the model')
4442

45-
flags.DEFINE_string(
46-
'config', 'bert.yaml', 'yaml configuration of the model')
47-
4843
flags.DEFINE_string(
4944
'calib_data', None, 'location of calibration dataset')
5045

5146
flags.DEFINE_string(
5247
'eval_data', None, 'location of evaluate dataset')
5348

54-
from neural_compressor.experimental.metric.metric import TensorflowTopK
55-
from neural_compressor.experimental.data.transforms.transform import ComposeTransform
56-
from neural_compressor.experimental.data.datasets.dataset import TensorflowImageRecord
57-
from neural_compressor.experimental.data.transforms.imagenet_transform import LabelShift
58-
from neural_compressor.experimental.data.dataloaders.default_dataloader import DefaultDataLoader
49+
flags.DEFINE_integer('batch_size', 32, 'batch_size')
50+
51+
flags.DEFINE_integer(
52+
'iters', 100, 'maximum iteration when evaluating performance')
53+
54+
from neural_compressor.metric.metric import TensorflowTopK
55+
from neural_compressor.data.transforms.transform import ComposeTransform
56+
from neural_compressor.data.datasets.dataset import TensorflowImageRecord
57+
from neural_compressor.data.transforms.imagenet_transform import LabelShift
58+
from neural_compressor.data.dataloaders.default_dataloader import DefaultDataLoader
5959
from neural_compressor.data.transforms.imagenet_transform import BilinearImagenetTransform
6060

6161
eval_dataset = TensorflowImageRecord(root=FLAGS.eval_data, transform=ComposeTransform(transform_list= \
62-
[BilinearImagenetTransform(height=299, width=299)]))
63-
if FLAGS.benchmark and FLAGS.mode == 'performance':
64-
eval_dataloader = DefaultDataLoader(dataset=eval_dataset, batch_size=1)
65-
else:
66-
eval_dataloader = DefaultDataLoader(dataset=eval_dataset, batch_size=32)
62+
[BilinearImagenetTransform(height=299, width=299)]))
63+
64+
eval_dataloader = DefaultDataLoader(dataset=eval_dataset, batch_size=FLAGS.batch_size)
65+
6766
if FLAGS.calib_data:
68-
calib_dataset = TensorflowImageRecord(root=FLAGS.calib_data, transform=ComposeTransform(transform_list= \
69-
[BilinearImagenetTransform(height=299, width=299)]))
70-
calib_dataloader = DefaultDataLoader(dataset=calib_dataset, batch_size=10)
71-
72-
def evaluate(model, measurer=None):
73-
"""
74-
Custom evaluate function to inference the model for specified metric on validation dataset.
75-
76-
Args:
77-
model (tf.saved_model.load): The input model will be the class of tf.saved_model.load(quantized_model_path).
78-
measurer (object, optional): for benchmark measurement of duration.
79-
80-
Returns:
81-
accuracy (float): evaluation result, the larger is better.
82-
"""
83-
infer = model.signatures["serving_default"]
84-
output_dict_keys = infer.structured_outputs.keys()
85-
output_name = list(output_dict_keys )[0]
86-
postprocess = LabelShift(label_shift=1)
87-
metric = TensorflowTopK(k=1)
88-
89-
def eval_func(dataloader, metric):
90-
results = []
91-
for idx, (inputs, labels) in enumerate(dataloader):
92-
inputs = np.array(inputs)
93-
input_tensor = tf.constant(inputs)
94-
if measurer:
95-
measurer.start()
96-
predictions = infer(input_tensor)[output_name]
97-
if measurer:
98-
measurer.end()
99-
predictions = predictions.numpy()
100-
predictions, labels = postprocess((predictions, labels))
101-
metric.update(predictions, labels)
102-
return results
103-
104-
results = eval_func(eval_dataloader, metric)
105-
acc = metric.result()
106-
return acc
67+
calib_dataset = TensorflowImageRecord(root=FLAGS.calib_data, transform= \
68+
ComposeTransform(transform_list= [BilinearImagenetTransform(height=299, width=299)]))
69+
calib_dataloader = DefaultDataLoader(dataset=calib_dataset, batch_size=10)
70+
71+
def evaluate(model):
72+
"""
73+
Custom evaluate function to inference the model for specified metric on validation dataset.
74+
75+
Args:
76+
model (tf.saved_model.load): The input model will be the class of tf.saved_model.load(quantized_model_path).
77+
measurer (object, optional): for benchmark measurement of duration.
78+
79+
Returns:
80+
accuracy (float): evaluation result, the larger is better.
81+
"""
82+
infer = model.signatures["serving_default"]
83+
output_dict_keys = infer.structured_outputs.keys()
84+
output_name = list(output_dict_keys )[0]
85+
postprocess = LabelShift(label_shift=1)
86+
metric = TensorflowTopK(k=1)
87+
latency_list = []
88+
89+
def eval_func(dataloader, metric):
90+
warmup = 5
91+
iteration = None
92+
93+
if FLAGS.benchmark and FLAGS.mode == 'performance':
94+
iteration = FLAGS.iters
95+
for idx, (inputs, labels) in enumerate(dataloader):
96+
inputs = np.array(inputs)
97+
input_tensor = tf.constant(inputs)
98+
start = time.time()
99+
predictions = infer(input_tensor)[output_name]
100+
end = time.time()
101+
latency_list.append(end - start)
102+
predictions = predictions.numpy()
103+
predictions, labels = postprocess((predictions, labels))
104+
metric.update(predictions, labels)
105+
if iteration and idx >= iteration:
106+
break
107+
latency = np.array(latency_list[warmup:]).mean() / eval_dataloader.batch_size
108+
return latency
109+
110+
latency = eval_func(eval_dataloader, metric)
111+
if FLAGS.benchmark:
112+
logger.info("\n{} mode benchmark result:".format(FLAGS.mode))
113+
for i, res in enumerate(latency_list):
114+
logger.debug("Iteration {} result {}:".format(i, res))
115+
if FLAGS.benchmark and FLAGS.mode == 'performance':
116+
logger.info("Batch size = {}".format(eval_dataloader.batch_size))
117+
logger.info("Latency: {:.3f} ms".format(latency * 1000))
118+
logger.info("Throughput: {:.3f} images/sec".format(1. / latency))
119+
acc = metric.result()
120+
return acc
107121

108122
def main(_):
109-
if FLAGS.tune:
110-
from neural_compressor.experimental import Quantization, common
111-
quantizer = Quantization(FLAGS.config)
112-
quantizer.model = common.Model(FLAGS.input_model)
113-
quantizer.eval_func = evaluate
114-
quantizer.calib_dataloader = calib_dataloader
115-
q_model = quantizer.fit()
116-
q_model.save(FLAGS.output_model)
117-
118-
119-
if FLAGS.benchmark:
120-
from neural_compressor.experimental import Benchmark, common
121-
evaluator = Benchmark(FLAGS.config)
122-
evaluator.model = common.Model(FLAGS.input_model)
123-
evaluator.b_func = evaluate
124-
evaluator.b_dataloader = eval_dataloader
125-
evaluator(FLAGS.mode)
123+
if FLAGS.tune:
124+
from neural_compressor.quantization import fit
125+
from neural_compressor.config import PostTrainingQuantConfig
126+
from neural_compressor.utils.utility import set_random_seed
127+
set_random_seed(9527)
128+
config = PostTrainingQuantConfig(calibration_sampling_size=[50, 100])
129+
q_model = fit(
130+
model=FLAGS.input_model,
131+
conf=config,
132+
calib_dataloader=calib_dataloader,
133+
eval_dataloader=eval_dataloader,
134+
eval_func=evaluate)
135+
q_model.save(FLAGS.output_model)
136+
137+
if FLAGS.benchmark:
138+
from neural_compressor.benchmark import fit
139+
from neural_compressor.config import BenchmarkConfig
140+
if FLAGS.mode == 'performance':
141+
conf = BenchmarkConfig(iteration=100, cores_per_instance=4, num_of_instance=7)
142+
fit(FLAGS.input_model, conf, b_func=evaluate)
143+
else:
144+
from neural_compressor.model.model import Model
145+
accuracy = evaluate(Model(FLAGS.input_model).model)
146+
logger.info('Batch size = %d' % FLAGS.batch_size)
147+
logger.info("Accuracy: %.5f" % accuracy)
126148

127149
if __name__ == "__main__":
128150
tf.compat.v1.app.run()

examples/tensorflow/image_recognition/keras_models/inception_resnet_v2/quantization/ptq/prepare_model.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
1+
#
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (c) 2022 Intel Corporation
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
119
import argparse
220
import tensorflow as tf
321
def get_inception_resnet_v2_model(saved_path):

0 commit comments

Comments
 (0)