Skip to content

Commit 8602808

Browse files
guomingzmengniwang95
authored andcommitted
Enable QuantizedConv2DWithBiasAndKLeakyRelu fusion for yolo_v3 model … (#148)
* Enable QuantizedConv2DWithBiasAndKLeakyRelu fusion for yolo_v3 model and add 1.15.up3 config into tensorflow.yaml. Signed-off-by: Zhang, Guoming <[email protected]> Co-authored-by: mengniwa <[email protected]>
1 parent 625e69f commit 8602808

File tree

15 files changed

+798
-77
lines changed

15 files changed

+798
-77
lines changed
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
This document describes the step-by-step instructions to reproduce the Yolo-v3 tuning results with LPOT.
2+
3+
## Prerequisite
4+
5+
6+
### 1. Installation
7+
Python 3.6 or higher is recommended.
8+
9+
```shell
10+
# Install Intel® Low Precision Optimization Tool
11+
pip install lpot
12+
```
13+
### 2. Install Intel Tensorflow
14+
```shell
15+
pip install intel-tensorflow==1.15.0up3
16+
```
17+
> Note: For the supported TensorFlow versions, please refer to the LPOT README file.
18+
19+
### 3. Install Dependency Packages
20+
```shell
21+
cd examples/tensorflow/object_detection
22+
pip install -r requirements.txt
23+
```
24+
25+
### 4. Download Yolo-v3 Model
26+
```shell
27+
git clone https://github.com/mystic123/tensorflow-yolo-v3.git
28+
cd tensorflow-yolo-v3
29+
```
30+
31+
### 5. Download COCO Class Names File
32+
```shell
33+
wget https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names
34+
```
35+
36+
### 6. Download Model Weights (Full):
37+
```shell
38+
wget https://pjreddie.com/media/files/yolov3.weights
39+
```
40+
41+
### 7. Generate PB:
42+
```shell
43+
python convert_weights_pb.py --class_names coco.names --weights_file yolov3.weights --data_format NHWC --size 416 --output_graph yolov3.pb
44+
```
45+
46+
### 8. Prepare Dataset
47+
48+
#### Automatic dataset download
49+
50+
> **_Note: `prepare_dataset.sh` script works with TF version 1.x._**
51+
52+
Run the `prepare_dataset.sh` script located in `examples/tensorflow/object_detection`.
53+
54+
Usage:
55+
```shell
56+
cd examples/tensorflow/object_detection
57+
. prepare_dataset.sh
58+
```
59+
60+
This script will download the *train*, *validation* and *test* COCO datasets. Furthermore, it will convert them to
61+
TensorFlow records using the dedicated script from `https://github.com/tensorflow/models.git`.
62+
63+
#### Manual dataset download
64+
Download the COCO dataset from the [Official Website](https://cocodataset.org/#download).
65+
66+
## Get Quantized Yolo-v3 model with LPOT
67+
68+
### 1. Configure `yolo_v3.yaml` with a valid cocoraw data path.
69+
70+
### 2. Run the commands below one by one.
71+
Usage
72+
```shell
73+
cd examples/tensorflow/object_detection/yolo_v3
74+
```
75+
```shell
76+
python infer_detections.py --input_graph /path/to/yolov3_fp32.pb --config ./yolo_v3.yaml --output_graph /path/to/save/yolov3_tuned3.pb
77+
```
78+
79+
Finally, LPOT will generate the quantized Yolo-v3 model within a 1% relative accuracy loss.
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
COCO_NUM_VAL_IMAGES = 4952
2+
LABEL_MAP = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20,
3+
21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
4+
42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
5+
61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82,
6+
84, 85, 86, 87, 88, 89, 90]
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import time
2+
import numpy as np
3+
import tensorflow as tf
4+
5+
from absl import app, flags
6+
7+
from tensorflow.python.client import timeline
8+
from coco_constants import LABEL_MAP
9+
from utils import read_graph, non_max_suppression
10+
11+
# Command-line flags for this script (parsed by absl when `app.run` starts).

# Batch size of the random input used in benchmark mode.
flags.DEFINE_integer('batch_size', 1, "batch size")

flags.DEFINE_string("ground_truth", None, "ground truth file")

# Frozen graph to benchmark or to quantize.
flags.DEFINE_string("input_graph", None, "input graph")

# Destination of the quantized graph (help text previously said "input graph").
flags.DEFINE_string("output_graph", None, "output graph")

flags.DEFINE_string("config", None, "LPOT config file")

# Thresholds forwarded to the NMS post-processing step.
flags.DEFINE_float("conf_threshold", 0.5, "confidence threshold")

flags.DEFINE_float("iou_threshold", 0.4, "IoU threshold")

# Thread-pool sizes copied into the TensorFlow session config.
flags.DEFINE_integer("num_intra_threads", 0, "number of intra threads")

flags.DEFINE_integer("num_inter_threads", 1, "number of inter threads")

# When set, run the throughput benchmark instead of quantization.
flags.DEFINE_boolean("benchmark", False, "benchmark mode")

# When set, benchmark mode uses fewer iterations and dumps timeline traces.
flags.DEFINE_boolean("profiling", False, "Signal of profiling")

FLAGS = flags.FLAGS
34+
35+
36+
class NMS():
    """Post-processing callable that filters raw predictions into detections.

    Args:
        conf_threshold: confidence threshold forwarded to `non_max_suppression`.
        iou_threshold: IoU threshold forwarded to `non_max_suppression`.
    """

    def __init__(self, conf_threshold, iou_threshold):
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold

    def __call__(self, sample):
        """Convert one `(predictions, labels)` sample into detection arrays.

        Args:
            sample: pair of raw predictions and labels; predictions are
                converted to a NumPy array when they are not one already.

        Returns:
            A pair `([boxes, scores, classes], labels)`. Each of the three
            arrays is wrapped in one extra leading dimension, and `labels`
            is passed through untouched. Boxes are stored as
            `[y_min, x_min, y_max, x_max]` divided by 416.
        """
        raw_preds, labels = sample
        preds = raw_preds if isinstance(raw_preds, np.ndarray) else np.array(raw_preds)

        # The helper's result is consumed as a mapping from class index to a
        # sequence of (box, score) pairs.
        kept = non_max_suppression(preds, self.conf_threshold, self.iou_threshold)

        # Boxes are normalized with a fixed 416x416 size.
        height, width = 416, 416

        boxes, scores, classes = [], [], []
        for class_idx, detections in kept.items():
            # LABEL_MAP is indexed with an offset of one from the class index.
            classes.extend([LABEL_MAP[class_idx + 1]] * len(detections))
            for box, score in detections:
                coords = box.tolist()
                # Input order is x_min, y_min, x_max, y_max; output is y-first.
                boxes.append([coords[1] / height,
                              coords[0] / width,
                              coords[3] / height,
                              coords[2] / width])
                scores.append(score)

        if not boxes:
            # Keep well-defined trailing dimensions when nothing was detected.
            boxes = np.zeros((0, 4))
            scores = np.zeros((0, ))
            classes = np.zeros((0, ))

        detections_out = [np.array([boxes]), np.array([scores]), np.array([classes])]
        return detections_out, labels
69+
70+
71+
def create_tf_config():
    """Build the TensorFlow session config from the thread-count flags.

    Returns:
        A `tf.compat.v1.ConfigProto` whose intra-op and inter-op parallelism
        thread counts are taken from `FLAGS.num_intra_threads` and
        `FLAGS.num_inter_threads`.
    """
    return tf.compat.v1.ConfigProto(
        intra_op_parallelism_threads=FLAGS.num_intra_threads,
        inter_op_parallelism_threads=FLAGS.num_inter_threads,
    )
76+
77+
78+
def run_benchmark():
    """Measure inference throughput of `FLAGS.input_graph` on random input.

    The graph is imported into the default graph, fed a random float32 tensor
    whose first dimension is `FLAGS.batch_size`, warmed up, and then timed.
    Progress is printed every 10 steps and the average throughput at the end.
    With `FLAGS.profiling` set, fewer iterations are run and a Chrome-trace
    timeline JSON file is written to the current directory.
    """
    config = create_tf_config()

    graph_def = read_graph(FLAGS.input_graph)
    tf.import_graph_def(graph_def, name='')

    graph = tf.compat.v1.get_default_graph()
    input_tensor = graph.get_tensor_by_name('inputs:0')
    output_tensor = graph.get_tensor_by_name('output_boxes:0')

    # Use the graph's input shape, overriding the first dimension with the
    # requested batch size.
    dummy_data_shape = list(input_tensor.shape)
    dummy_data_shape[0] = FLAGS.batch_size
    dummy_data = np.random.random(dummy_data_shape).astype(np.float32)
    feed = {input_tensor: dummy_data}

    if FLAGS.profiling:
        num_warmup, total_iter = 20, 100
    else:
        num_warmup, total_iter = 200, 1000

    total_time = 0.0

    with tf.compat.v1.Session(config=config) as sess:
        print("Running warm-up")
        for _ in range(num_warmup):
            sess.run(output_tensor, feed)
        print("Warm-up complete")

        for i in range(1, total_iter + 1):
            start_time = time.time()
            sess.run(output_tensor, feed)
            end_time = time.time()

            # Compute the step duration BEFORE reporting it; previously the
            # progress line printed the duration of the preceding step.
            duration = end_time - start_time
            total_time += duration

            if i % 10 == 0:
                print(
                    "Steps = {0}, {1:10.6f} samples/sec".format(i, FLAGS.batch_size / duration))

            # NOTE(review): the nesting of this block was reconstructed; it is
            # assumed to run once per timed iteration (outside the timed
            # region) - confirm against the original file.
            if FLAGS.profiling:
                options = tf.compat.v1.RunOptions(
                    trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
                run_metadata = tf.compat.v1.RunMetadata()

                sess.run(output_tensor, feed,
                         options=options, run_metadata=run_metadata)

                fetched_timeline = timeline.Timeline(run_metadata.step_stats)
                chrome_trace = fetched_timeline.generate_chrome_trace_format()
                with open("timeline_%s.json" % (time.time()), 'w') as f:
                    f.write(chrome_trace)

        print("Average Thoughput: %f samples/sec" %
              (total_iter * FLAGS.batch_size / total_time))
134+
135+
136+
def main(_):
    """Entry point: run the benchmark, or quantize the model with LPOT.

    With `FLAGS.benchmark` set, only `run_benchmark` is executed. Otherwise
    the batch size flag is forced to 1, the graph at `FLAGS.input_graph` is
    quantized using the LPOT config in `FLAGS.config` with `NMS` registered
    as post-processing, and the result is saved to `FLAGS.output_graph`.
    """
    if FLAGS.benchmark:
        run_benchmark()
        return

    FLAGS.batch_size = 1

    # Imported lazily so benchmark mode does not need LPOT.
    from lpot.experimental import Quantization, common

    quantizer = Quantization(FLAGS.config)
    quantizer.model = common.Model(FLAGS.input_graph)

    nms_kwargs = dict(conf_threshold=FLAGS.conf_threshold,
                      iou_threshold=FLAGS.iou_threshold)
    quantizer.postprocess = common.Postprocess(NMS, 'NMS', **nms_kwargs)

    quantized_model = quantizer()
    quantized_model.save(FLAGS.output_graph)
149+
150+
151+
# Script entry point: let absl parse the flags, then invoke main().
if __name__ == "__main__":
    app.run(main)

0 commit comments

Comments
 (0)