|
| 1 | +import time |
| 2 | +import numpy as np |
| 3 | +import tensorflow as tf |
| 4 | + |
| 5 | +from absl import app, flags |
| 6 | + |
| 7 | +from tensorflow.python.client import timeline |
| 8 | +from coco_constants import LABEL_MAP |
| 9 | +from utils import read_graph, non_max_suppression |
| 10 | + |
# Command-line flags (absl). The three path flags and ``config`` default to
# None, so callers must pass whichever ones the selected mode reads.

# --- model / data paths ---
flags.DEFINE_string("ground_truth", None, "ground truth file")
flags.DEFINE_string("input_graph", None, "input graph")
# Help text previously read "input graph" (copy-paste slip).
flags.DEFINE_string("output_graph", None, "output graph")
flags.DEFINE_string("config", None, "LPOT config file")

# --- inference / post-processing settings ---
flags.DEFINE_integer('batch_size', 1, "batch size")
flags.DEFINE_float("conf_threshold", 0.5, "confidence threshold")
flags.DEFINE_float("iou_threshold", 0.4, "IoU threshold")

# --- TensorFlow session threading ---
flags.DEFINE_integer("num_intra_threads", 0, "number of intra threads")
flags.DEFINE_integer("num_inter_threads", 1, "number of inter threads")

# --- run mode ---
flags.DEFINE_boolean("benchmark", False, "benchmark mode")
flags.DEFINE_boolean("profiling", False, "Signal of profiling")

FLAGS = flags.FLAGS
| 34 | + |
| 35 | + |
class NMS():
    """Post-processing callable: non-max suppression plus box normalisation.

    Instances are called with a ``(preds, labels)`` sample. ``preds`` is run
    through ``non_max_suppression`` and the surviving boxes are rescaled to
    fractions of the network input size.

    Args:
        conf_threshold: confidence threshold forwarded to
            ``non_max_suppression``.
        iou_threshold: IoU threshold forwarded to ``non_max_suppression``.
        height: input image height used to normalise y coordinates
            (defaults to 416, the value that used to be hard-coded).
        width: input image width used to normalise x coordinates
            (defaults to 416, the value that used to be hard-coded).
    """

    def __init__(self, conf_threshold, iou_threshold, height=416, width=416):
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold
        self.height = height
        self.width = width

    def __call__(self, sample):
        """Apply NMS to one sample.

        Args:
            sample: ``(preds, labels)`` pair; ``preds`` is converted to a
                NumPy array if it is not one already.

        Returns:
            ``([boxes, scores, classes], labels)`` where each of the three
            arrays is wrapped in one extra leading axis of length 1 and
            ``boxes`` rows are ``[y_min, x_min, y_max, x_max]`` divided by
            the configured height/width. ``labels`` is passed through
            untouched.
        """
        preds, labels = sample
        if not isinstance(preds, np.ndarray):
            preds = np.array(preds)

        # Iterated below as {class index: iterable of (box, score) pairs}.
        filtered_boxes = non_max_suppression(preds,
                                             self.conf_threshold,
                                             self.iou_threshold)

        height, width = self.height, self.width
        det_boxes = []
        det_scores = []
        det_classes = []
        for cls, bboxs in filtered_boxes.items():
            # LABEL_MAP is indexed with the class index shifted by one.
            label = LABEL_MAP[cls + 1]
            for box, score in bboxs:
                # Box coordinates arrive as [x_min, y_min, x_max, y_max].
                rect_pos = box.tolist()
                x_min, y_min = rect_pos[0], rect_pos[1]
                x_max, y_max = rect_pos[2], rect_pos[3]
                det_boxes.append(
                    [y_min / height, x_min / width,
                     y_max / height, x_max / width])
                det_scores.append(score)
                det_classes.append(label)

        if not det_boxes:
            # Keep well-formed (empty) array shapes when nothing survived.
            det_boxes = np.zeros((0, 4))
            det_scores = np.zeros((0, ))
            det_classes = np.zeros((0, ))

        return [np.array([det_boxes]),
                np.array([det_scores]),
                np.array([det_classes])], labels
| 69 | + |
| 70 | + |
def create_tf_config():
    """Build a TF1 ``ConfigProto`` with thread counts taken from FLAGS.

    Returns:
        A ``tf.compat.v1.ConfigProto`` whose intra-op and inter-op
        parallelism are set from ``--num_intra_threads`` and
        ``--num_inter_threads`` respectively.
    """
    return tf.compat.v1.ConfigProto(
        intra_op_parallelism_threads=FLAGS.num_intra_threads,
        inter_op_parallelism_threads=FLAGS.num_inter_threads,
    )
| 76 | + |
| 77 | + |
def run_benchmark():
    """Measure inference throughput of ``FLAGS.input_graph`` on random data.

    The graph is imported into the default graph, the ``inputs:0`` tensor is
    fed a random float32 batch of ``FLAGS.batch_size`` samples and
    ``output_boxes:0`` is fetched. After a warm-up phase every timed step is
    accumulated; throughput is printed every 10 steps and as an average at
    the end. With ``--profiling`` the iteration counts are reduced and a
    Chrome-trace ``timeline_<timestamp>.json`` file is written.
    """
    config = create_tf_config()

    graph_def = read_graph(FLAGS.input_graph)
    tf.import_graph_def(graph_def, name='')

    graph = tf.compat.v1.get_default_graph()
    input_tensor = graph.get_tensor_by_name('inputs:0')
    output_tensor = graph.get_tensor_by_name('output_boxes:0')

    # Use the graph's input shape, overriding only the leading dimension.
    dummy_data_shape = list(input_tensor.shape)
    dummy_data_shape[0] = FLAGS.batch_size
    dummy_data = np.random.random(dummy_data_shape).astype(np.float32)

    if FLAGS.profiling:
        num_warmup = 20
        total_iter = 100
    else:
        num_warmup = 200
        total_iter = 1000

    total_time = 0.0
    feed = {input_tensor: dummy_data}

    with tf.compat.v1.Session(config=config) as sess:
        print("Running warm-up")
        for _ in range(num_warmup):
            sess.run(output_tensor, feed)
        print("Warm-up complete")

        for i in range(1, total_iter + 1):
            # perf_counter is monotonic, so it suits interval timing better
            # than wall-clock time.time().
            start_time = time.perf_counter()
            sess.run(output_tensor, feed)
            duration = time.perf_counter() - start_time
            total_time += duration

            # Report *this* step's rate; previously the print ran before
            # `duration` was updated and showed the preceding step instead.
            if i % 10 == 0:
                print(
                    "Steps = {0}, {1:10.6f} samples/sec".format(
                        i, FLAGS.batch_size / duration))

            # NOTE(review): indentation is lost in the reviewed view; the
            # reduced iteration counts in profiling mode suggest this runs
            # once per step - confirm against the original file.
            if FLAGS.profiling:
                options = tf.compat.v1.RunOptions(
                    trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
                run_metadata = tf.compat.v1.RunMetadata()

                # Extra traced run; it is not included in total_time.
                sess.run(output_tensor, feed,
                         options=options, run_metadata=run_metadata)

                fetched_timeline = timeline.Timeline(run_metadata.step_stats)
                chrome_trace = fetched_timeline.generate_chrome_trace_format()
                with open("timeline_%s.json" % (time.time()), 'w') as f:
                    f.write(chrome_trace)

    # NOTE(review): "Thoughput" is misspelled, but the text is kept as-is in
    # case external scripts parse this line - confirm before renaming.
    print("Average Thoughput: %f samples/sec" %
          (total_iter * FLAGS.batch_size / total_time))
| 134 | + |
| 135 | + |
def main(_):
    """Entry point: benchmark the graph, or quantize it with LPOT.

    With ``--benchmark`` the throughput benchmark runs and nothing else.
    Otherwise the batch size is forced to 1, the input graph is quantized
    using the LPOT config with ``NMS`` registered as post-processing, and
    the result is saved to ``--output_graph``.
    """
    if FLAGS.benchmark:
        run_benchmark()
        return

    FLAGS.batch_size = 1
    # Imported lazily so benchmark-only runs do not need LPOT.
    from lpot.experimental import Quantization, common

    quantizer = Quantization(FLAGS.config)
    quantizer.model = common.Model(FLAGS.input_graph)
    quantizer.postprocess = common.Postprocess(
        NMS,
        'NMS',
        conf_threshold=FLAGS.conf_threshold,
        iou_threshold=FLAGS.iou_threshold,
    )

    quantized_model = quantizer()
    quantized_model.save(FLAGS.output_graph)


# Let absl parse the command-line flags, then dispatch to main().
if __name__ == '__main__':
    app.run(main)
0 commit comments