diff --git a/tftrt/examples/image-classification/README.md b/tftrt/examples/image-classification/README.md
new file mode 100644
index 000000000..d4b8d662f
--- /dev/null
+++ b/tftrt/examples/image-classification/README.md
@@ -0,0 +1,70 @@
# TensorFlow-TensorRT Examples

This script runs inference with a few popular image classification models
on the ImageNet validation set.

You can turn on TensorFlow-TensorRT integration with the flag `--use_trt`. This
applies TensorRT inference optimization to speed up execution for the portions
of the model's graph where it is supported, and falls back to native TensorFlow
for layers and operations which are not supported. See
https://devblogs.nvidia.com/tensorrt-integration-speeds-tensorflow-inference/
for more information.

When using TF-TRT, you can also control the precision with `--precision`.
float32 is the default (`--precision fp32`), with float16 (`--precision fp16`) or
int8 (`--precision int8`) allowing further performance improvements at the cost
of some accuracy. int8 mode requires a calibration step, which is done
automatically.

## Models

This script supports the following models for image classification:
* MobileNet v1
* MobileNet v2
* NASNet - Large
* NASNet - Mobile
* ResNet50 v1
* ResNet50 v2
* VGG16
* VGG19
* Inception v3
* Inception v4

## Setup
```
# Clone tensorflow/models.
git clone https://github.com/tensorflow/models.git

# Add the models directory to PYTHONPATH to install tensorflow/models.
cd models
export PYTHONPATH="$PYTHONPATH:$PWD"

# Run the TF Slim setup.
cd research/slim
python setup.py install

# You may also need to install the requests package.
pip install requests
```
Note: the PYTHONPATH environment variable will not be saved between different
shells. You can either repeat that step each time you work in a new shell, or
add `export PYTHONPATH="$PYTHONPATH:/path/to/tensorflow_models"` to your .bashrc
file (replacing /path/to/tensorflow_models with the path to your
tensorflow/models repository).

### Data

The script supports only the TFRecord format for input data, and assumes that
validation TFRecords are named according to the pattern:
`validation-*-of-00128`.

You can download and process ImageNet using [this script provided by TF
Slim](https://github.com/tensorflow/models/blob/master/research/slim/datasets/download_imagenet.sh).
Please note that this script downloads both the training and validation sets,
while this example only requires the validation set.

## Usage

`python image_classification.py --data_dir /imagenet_validation_data --model vgg_16 [--use_trt]`

Run with `--help` to see all available options.
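For example, to evaluate ResNet50 v1 with TF-TRT in int8 mode (the data paths
below are placeholders for your own directories; here the validation set is
simply reused for calibration):

```
python image_classification.py --model resnet_v1_50 \
    --data_dir /imagenet_validation_data \
    --calib_data_dir /imagenet_validation_data \
    --use_trt --precision int8
```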
diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py
new file mode 100644
index 000000000..0090b802a
--- /dev/null
+++ b/tftrt/examples/image-classification/image_classification.py
@@ -0,0 +1,606 @@
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import argparse
import glob
import os
import shutil
import subprocess
import sys
import time

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.tensorrt as trt

import nets.nets_factory
import official.resnet.imagenet_main
from preprocessing import inception_preprocessing, vgg_preprocessing


class LoggerHook(tf.train.SessionRunHook):
    """Logs runtime of each iteration"""
    def __init__(self, batch_size, num_records, display_every):
        self.iter_times = []
        self.display_every = display_every
        # Ceiling division: the last batch may be smaller than batch_size.
        self.num_steps = (num_records + batch_size - 1) // batch_size
        self.batch_size = batch_size

    def begin(self):
        self.start_time = time.time()

    def after_run(self, run_context, run_values):
        current_time = time.time()
        duration = current_time - self.start_time
        self.start_time = current_time
        self.iter_times.append(duration)
        current_step = len(self.iter_times)
        if current_step % self.display_every == 0:
            print("  step %d/%d, iter_time(ms)=%.4f, images/sec=%d" % (
                current_step, self.num_steps, duration * 1000,
                self.batch_size / self.iter_times[-1]))
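
# A minimal sketch (not part of the script's control flow) of how the raw
# per-iteration times collected by LoggerHook become the metrics reported by
# run() below; the values here are made up for illustration:
#
#   iter_times = np.array([0.050, 0.052, 0.048])   # seconds per batch
#   images_per_sec = np.mean(8 / iter_times)       # ~160 img/s at batch_size=8
#   latency_mean_ms = np.mean(iter_times) * 1000   # ~50 ms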
def run(frozen_graph, model, data_dir, batch_size,
        num_iterations, num_warmup_iterations, use_synthetic, display_every=100):
    """Evaluates a frozen graph

    This function evaluates a graph on the ImageNet validation set.
    tf.estimator.Estimator is used to evaluate the accuracy of the model
    and a few other metrics. The results are returned as a dict.

    frozen_graph: GraphDef, a graph containing input node 'input' and
        outputs 'logits' and 'classes'
    model: string, the model name (see NETS table in get_netdef())
    data_dir: str, directory containing ImageNet validation TFRecord files
    batch_size: int, batch size used for evaluation
    num_iterations: int, number of iterations (batches) to run for
    num_warmup_iterations: int, number of initial iterations excluded from timing
    use_synthetic: bool, if True, one batch of random data is generated and reused
    display_every: int, print timing metrics every N iterations
    """
    # Define model function for tf.estimator.Estimator
    def model_fn(features, labels, mode):
        logits_out, classes_out = tf.import_graph_def(frozen_graph,
            input_map={'input': features},
            return_elements=['logits:0', 'classes:0'],
            name='')
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits_out)
        accuracy = tf.metrics.accuracy(labels=labels, predictions=classes_out, name='acc_op')
        if mode == tf.estimator.ModeKeys.EVAL:
            return tf.estimator.EstimatorSpec(
                mode,
                loss=loss,
                eval_metric_ops={'accuracy': accuracy})

    # Create the dataset
    preprocess_fn = get_preprocess_fn(model)
    validation_files = tf.gfile.Glob(os.path.join(data_dir, 'validation*'))

    def get_tfrecords_count(files):
        num_records = 0
        for fn in files:
            for record in tf.python_io.tf_record_iterator(fn):
                num_records += 1
        return num_records

    # Define the dataset input function for tf.estimator.Estimator
    def eval_input_fn():
        if use_synthetic:
            input_width, input_height = get_netdef(model).get_input_dims()
            features = np.random.normal(
                loc=112, scale=70,
                size=(batch_size, input_height, input_width, 3)).astype(np.float32)
            features = np.clip(features, 0.0, 255.0)
            features = tf.identity(tf.constant(features))
            labels = np.random.randint(
                low=0,
                high=get_netdef(model).get_num_classes(),
                size=(batch_size,),
                dtype=np.int32)
            labels = tf.identity(tf.constant(labels))
        else:
            dataset = tf.data.TFRecordDataset(validation_files)
            dataset = dataset.apply(tf.contrib.data.map_and_batch(
                map_func=preprocess_fn,
                batch_size=batch_size,
                num_parallel_calls=8))
            dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
            dataset = dataset.repeat(count=1)
            iterator = dataset.make_one_shot_iterator()
            features, labels = iterator.get_next()
        return features, labels

    # Evaluate model
    logger = LoggerHook(
        display_every=display_every,
        batch_size=batch_size,
        num_records=get_tfrecords_count(validation_files))
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        config=tf.estimator.RunConfig(session_config=tf_config),
        model_dir='model_dir')
    results = estimator.evaluate(eval_input_fn, steps=num_iterations, hooks=[logger])

    # Gather additional results
    iter_times = np.array(logger.iter_times[num_warmup_iterations:])
    results['total_time'] = np.sum(iter_times)
    results['images_per_sec'] = np.mean(batch_size / iter_times)
    results['99th_percentile'] = np.percentile(iter_times, q=99, interpolation='lower') * 1000
    results['latency_mean'] = np.mean(iter_times) * 1000
    return results
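
# Example (sketch, with a hypothetical data path) of calling run() directly on
# an already-frozen graph instead of going through the CLI at the bottom of
# this file:
#
#   results = run(frozen_graph, model='vgg_16',
#                 data_dir='/data/imagenet_validation', batch_size=8,
#                 num_iterations=None, num_warmup_iterations=50,
#                 use_synthetic=False)
#   print(results['accuracy'], results['images_per_sec'])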
class NetDef(object):
    """Contains definition of a model

    name: Name of model
    url: (optional) Where to download archive containing checkpoint
    model_dir_in_archive: (optional) Subdirectory in archive containing
        checkpoint files.
    checkpoint_name: (optional) Name of the checkpoint file inside the archive.
    preprocess: Which preprocessing method to use for inputs.
    input_size: Input dimensions.
    slim: If True, use tensorflow/research/slim/nets to build graph. Else, use
        model_fn to build graph.
    postprocess: Postprocessing function on predictions.
    model_fn: Function to build graph if slim=False
    num_classes: Number of output classes in model. The background class is
        automatically adjusted for if num_classes is 1001.
    """
    def __init__(self, name, url=None, model_dir_in_archive=None,
                 checkpoint_name=None, preprocess='inception',
                 input_size=224, slim=True, postprocess=tf.nn.softmax,
                 model_fn=None, num_classes=1001):
        self.name = name
        self.url = url
        self.model_dir_in_archive = model_dir_in_archive
        self.checkpoint_name = checkpoint_name
        if preprocess == 'inception':
            self.preprocess = inception_preprocessing.preprocess_image
        elif preprocess == 'vgg':
            self.preprocess = vgg_preprocessing.preprocess_image
        self.input_width = input_size
        self.input_height = input_size
        self.slim = slim
        self.postprocess = postprocess
        self.model_fn = model_fn
        self.num_classes = num_classes

    def get_input_dims(self):
        return self.input_width, self.input_height

    def get_num_classes(self):
        return self.num_classes


def get_netdef(model):
    """
    Creates the dictionary NETS with model names as keys and NetDef as values.
    Returns the NetDef corresponding to the model specified in the parameter.

    model: string, the model name (see NETS table)
    """
    NETS = {
        'mobilenet_v1': NetDef(
            name='mobilenet_v1',
            url='http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz'),

        'mobilenet_v2': NetDef(
            name='mobilenet_v2_140',
            url='https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz'),

        'nasnet_mobile': NetDef(
            name='nasnet_mobile',
            url='https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz'),

        'nasnet_large': NetDef(
            name='nasnet_large',
            url='https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz',
            input_size=331),

        'resnet_v1_50': NetDef(
            name='resnet_v1_50',
            url='http://download.tensorflow.org/models/official/20180601_resnet_v1_imagenet_checkpoint.tar.gz',
            model_dir_in_archive='20180601_resnet_v1_imagenet_checkpoint',
            slim=False,
            preprocess='vgg',
            model_fn=official.resnet.imagenet_main.ImagenetModel(resnet_size=50, resnet_version=1)),

        'resnet_v2_50': NetDef(
            name='resnet_v2_50',
            url='http://download.tensorflow.org/models/official/20180601_resnet_v2_imagenet_checkpoint.tar.gz',
            model_dir_in_archive='20180601_resnet_v2_imagenet_checkpoint',
            slim=False,
            preprocess='vgg',
            model_fn=official.resnet.imagenet_main.ImagenetModel(resnet_size=50, resnet_version=2)),

        'resnet_v2_152': NetDef(
            name='resnet_v2_152',
            slim=False,
            preprocess='vgg',
            model_fn=official.resnet.imagenet_main.ImagenetModel(resnet_size=152, resnet_version=2)),

        'vgg_16': NetDef(
            name='vgg_16',
            url='http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz',
            preprocess='vgg',
            num_classes=1000),

        'vgg_19': NetDef(
            name='vgg_19',
            url='http://download.tensorflow.org/models/vgg_19_2016_08_28.tar.gz',
            preprocess='vgg',
            num_classes=1000),

        'inception_v3': NetDef(
            name='inception_v3',
            url='http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz',
            input_size=299),

        'inception_v4': NetDef(
            name='inception_v4',
            url='http://download.tensorflow.org/models/inception_v4_2016_09_09.tar.gz',
            input_size=299),
    }
    return NETS[model]
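
# For example, get_netdef('inception_v3') returns a NetDef whose input size and
# class count come from the table above, and whose preprocess function is the
# Inception-style preprocessing imported at the top of this file:
#
#   netdef = get_netdef('inception_v3')
#   width, height = netdef.get_input_dims()   # (299, 299)
#   netdef.get_num_classes()                  # 1001 (includes background class)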
def deserialize_image_record(record):
    feature_map = {
        'image/encoded': tf.FixedLenFeature([], tf.string, ''),
        'image/class/label': tf.FixedLenFeature([1], tf.int64, -1),
        'image/class/text': tf.FixedLenFeature([], tf.string, ''),
        'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32)
    }
    with tf.name_scope('deserialize_image_record'):
        obj = tf.parse_single_example(record, feature_map)
        imgdata = obj['image/encoded']
        label = tf.cast(obj['image/class/label'], tf.int32)
        bbox = tf.stack([obj['image/object/bbox/%s' % x].values
                         for x in ['ymin', 'xmin', 'ymax', 'xmax']])
        bbox = tf.transpose(tf.expand_dims(bbox, 0), [0, 2, 1])
        text = obj['image/class/text']
        return imgdata, label, bbox, text


def get_preprocess_fn(model, mode='classification'):
    """Creates a function to parse and process a TFRecord using the model's parameters

    model: string, the model name (see NETS table)
    mode: string, whether the model is for classification or detection
    returns: function, the preprocessing function for a record
    """
    def process(record):
        # Parse TFRecord
        imgdata, label, bbox, text = deserialize_image_record(record)
        label -= 1  # Change to 0-based (don't use background class)
        try:
            image = tf.image.decode_jpeg(imgdata, channels=3,
                                         fancy_upscaling=False,
                                         dct_method='INTEGER_FAST')
        except:
            image = tf.image.decode_png(imgdata, channels=3)
        # Use model's preprocessing function
        netdef = get_netdef(model)
        image = netdef.preprocess(image, netdef.input_height, netdef.input_width,
                                  is_training=False)
        return image, label

    return process
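
# The returned function plugs straight into tf.data; a minimal sketch of the
# same wiring that eval_input_fn in run() uses above:
#
#   preprocess_fn = get_preprocess_fn('vgg_16')
#   dataset = tf.data.TFRecordDataset(validation_files)
#   dataset = dataset.map(preprocess_fn).batch(8)   # yields (image, label) batches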
def build_classification_graph(model, model_dir=None, default_models_dir='./data'):
    """Builds an image classification model by name

    This function builds an image classification model given a model name and
    an optional checkpoint location. It performs some graph processing to
    produce a graph that is well optimized by the TensorRT package in
    TensorFlow 1.7+.

    model: string, the model name (see NETS table in get_netdef())
    model_dir: string, optional user-provided checkpoint location
    default_models_dir: string, directory to store downloaded model checkpoints
    returns: tensorflow.GraphDef, the TensorRT-compatible frozen graph
    """
    netdef = get_netdef(model)
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True

    with tf.Graph().as_default() as tf_graph:
        with tf.Session(config=tf_config) as tf_sess:
            tf_input = tf.placeholder(tf.float32,
                [None, netdef.input_height, netdef.input_width, 3], name='input')
            if netdef.slim:
                # TF Slim Model: get model function from nets_factory
                network_fn = nets.nets_factory.get_network_fn(
                    netdef.name, netdef.num_classes, is_training=False)
                tf_net, tf_end_points = network_fn(tf_input)
            else:
                # TF Official Model: get model function from NETS
                tf_net = netdef.model_fn(tf_input, training=False)

            tf_output = tf.identity(tf_net, name='logits')
            num_classes = tf_output.get_shape().as_list()[1]
            if num_classes == 1001:
                # Shift class down by 1 if background class was included
                tf_output_classes = tf.add(tf.argmax(tf_output, axis=1), -1, name='classes')
            else:
                tf_output_classes = tf.argmax(tf_output, axis=1, name='classes')

            # Get checkpoint.
            checkpoint_path = get_checkpoint(model, model_dir, default_models_dir)
            print('Using checkpoint:', checkpoint_path)
            # Load checkpoint.
            tf_saver = tf.train.Saver()
            tf_saver.restore(save_path=checkpoint_path, sess=tf_sess)

            # Freeze graph.
            frozen_graph = tf.graph_util.convert_variables_to_constants(
                tf_sess,
                tf_sess.graph_def,
                output_node_names=['logits', 'classes']
            )

    return frozen_graph
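
# Sanity check (sketch): the frozen graph built above should expose the three
# node names that run() and the TF-TRT conversion in get_frozen_graph() rely on:
#
#   node_names = set(n.name for n in frozen_graph.node)
#   assert {'input', 'logits', 'classes'} <= node_names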
def get_checkpoint(model, model_dir=None, default_models_dir='.'):
    """Get the checkpoint. The user may provide their own checkpoint via
    model_dir. If model_dir is None, this function attempts to download the
    checkpoint using the url property from the model definition (see
    get_netdef()). default_models_dir/model is first checked to see if the
    checkpoint was already downloaded. If not, the checkpoint is downloaded
    from the url.

    model: string, the model name (see NETS table in get_netdef())
    model_dir: string, optional user-provided checkpoint location
    default_models_dir: string, the directory where files are downloaded to
    returns: string, path to the checkpoint file containing trained model params
    """
    # User has provided a checkpoint
    if model_dir:
        checkpoint_path = find_checkpoint_in_dir(model_dir)
        if not checkpoint_path:
            print('No checkpoint was found in', model_dir)
            sys.exit(1)
        return checkpoint_path

    # User has not provided a checkpoint. We need to download one. First check
    # if the checkpoint was already downloaded and stored in default_models_dir.
    model_dir = os.path.join(default_models_dir, model)
    checkpoint_path = find_checkpoint_in_dir(model_dir)
    if checkpoint_path:
        return checkpoint_path

    # Checkpoint has not yet been downloaded. Download it if the model has
    # defined a URL.
    if get_netdef(model).url:
        download_checkpoint(model, model_dir)
        return find_checkpoint_in_dir(model_dir)

    print('No model_dir was provided and the model does not define a download'
          ' URL.')
    sys.exit(1)


def find_checkpoint_in_dir(model_dir):
    # tf.train.latest_checkpoint will find checkpoints if a 'checkpoint' file
    # is present in the directory.
    checkpoint_path = tf.train.latest_checkpoint(model_dir)
    if checkpoint_path:
        return checkpoint_path

    # tf.train.latest_checkpoint did not find anything. Find a .ckpt file
    # manually.
    files = glob.glob(os.path.join(model_dir, '*.ckpt*'))
    if len(files) == 0:
        return None
    # Use the last file for consistency if there is more than one (it may not
    # actually be the "latest").
    checkpoint_path = sorted(files)[-1]
    # Trim after the .ckpt-* segment. For example:
    # model.ckpt-257706.data-00000-of-00002 -> model.ckpt-257706
    parts = checkpoint_path.split('.')
    ckpt_index = [i for i in range(len(parts)) if 'ckpt' in parts[i]][0]
    checkpoint_path = '.'.join(parts[:ckpt_index + 1])
    return checkpoint_path


def download_checkpoint(model, destination_path):
    # Make directories if they don't exist.
    if not os.path.exists(destination_path):
        os.makedirs(destination_path)
    # Download the archive.
    archive_path = os.path.join(destination_path,
                                os.path.basename(get_netdef(model).url))
    if not os.path.isfile(archive_path):
        subprocess.call(['wget', '--no-check-certificate',
                         get_netdef(model).url, '-O', archive_path])
    # Extract.
    subprocess.call(['tar', '-xzf', archive_path, '-C', destination_path])
    # Move checkpoints out of archive subdirectories into destination_path.
    if get_netdef(model).model_dir_in_archive:
        source_files = os.path.join(destination_path,
                                    get_netdef(model).model_dir_in_archive,
                                    '*')
        for f in glob.glob(source_files):
            shutil.copy2(f, destination_path)


def get_frozen_graph(
        model,
        model_dir=None,
        use_trt=False,
        use_dynamic_op=False,
        precision='fp32',
        batch_size=8,
        minimum_segment_size=2,
        calib_data_dir=None,
        num_calib_inputs=None,
        use_synthetic=False,
        cache=False,
        default_models_dir='./data'):
    """Retrieves a frozen GraphDef from the model definitions in get_netdef()
    and applies TF-TRT

    model: str, the model name (see NETS table in get_netdef())
    model_dir: str, optional user-provided checkpoint location
    use_trt: bool, if True, use TensorRT
    use_dynamic_op: bool, if True, use dynamic TRT ops instead of static ones
    precision: str, floating point precision (fp32, fp16, or int8)
    batch_size: int, batch size for TensorRT optimizations
    minimum_segment_size: int, minimum number of TF ops in a TRT engine
    calib_data_dir: str, directory with TFRecords used for int8 calibration
    num_calib_inputs: int, number of inputs used for calibration
    cache: bool, if True, load/save the frozen graph from/to disk
    returns: (tensorflow.GraphDef, dict, dict), the TensorRT compatible frozen
        graph, plus node counts and timings collected during conversion
    """
    num_nodes = {}
    times = {}

    # Load from pb file if frozen graph was already created and cached
    if cache:
        # Graph must match the model, TRT mode, precision, and batch size
        prebuilt_graph_path = "graphs/frozen_graph_%s_%d_%s_%d.pb" % (
            model, int(use_trt), precision, batch_size)
        if os.path.isfile(prebuilt_graph_path):
            print('Loading cached frozen graph from \'%s\'' % prebuilt_graph_path)
            start_time = time.time()
            with tf.gfile.GFile(prebuilt_graph_path, "rb") as f:
                frozen_graph = tf.GraphDef()
                frozen_graph.ParseFromString(f.read())
            times['loading_frozen_graph'] = time.time() - start_time
            num_nodes['loaded_frozen_graph'] = len(frozen_graph.node)
            num_nodes['trt_only'] = len([1 for n in frozen_graph.node
                                         if str(n.op) == 'TRTEngineOp'])
            return frozen_graph, num_nodes, times

    # Build graph and load weights
    frozen_graph = build_classification_graph(model, model_dir, default_models_dir)
    num_nodes['native_tf'] = len(frozen_graph.node)

    # Convert to TensorRT graph
    if use_trt:
        start_time = time.time()
        frozen_graph = trt.create_inference_graph(
            input_graph_def=frozen_graph,
            outputs=['logits', 'classes'],
            max_batch_size=batch_size,
            max_workspace_size_bytes=(4096 << 20) - 1000,
            precision_mode=precision,
            minimum_segment_size=minimum_segment_size,
            is_dynamic_op=use_dynamic_op
        )
        times['trt_conversion'] = time.time() - start_time
        num_nodes['tftrt_total'] = len(frozen_graph.node)
        num_nodes['trt_only'] = len([1 for n in frozen_graph.node
                                     if str(n.op) == 'TRTEngineOp'])

        if precision == 'int8':
            calib_graph = frozen_graph
            # INT8 calibration step
            print('Calibrating INT8...')
            start_time = time.time()
            run(calib_graph, model, calib_data_dir, batch_size,
                num_calib_inputs // batch_size, 0, False)
            times['trt_calibration'] = time.time() - start_time

            start_time = time.time()
            frozen_graph = trt.calib_graph_to_infer_graph(calib_graph)
            times['trt_int8_conversion'] = time.time() - start_time

            del calib_graph
            print('INT8 graph created.')

    # Cache graph to avoid long conversions each time
    if cache:
        if not os.path.exists(os.path.dirname(prebuilt_graph_path)):
            os.makedirs(os.path.dirname(prebuilt_graph_path))
        start_time = time.time()
        with tf.gfile.GFile(prebuilt_graph_path, "wb") as f:
            f.write(frozen_graph.SerializeToString())
        times['saving_frozen_graph'] = time.time() - start_time

    return frozen_graph, num_nodes, times
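
# Example (sketch) of converting a model with TF-TRT in FP16 without going
# through the command line below:
#
#   frozen_graph, num_nodes, times = get_frozen_graph(
#       model='resnet_v1_50', use_trt=True, precision='fp16', batch_size=8)
#   print('%d TRTEngineOp nodes created' % num_nodes['trt_only'])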
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate model')
    parser.add_argument('--model', type=str, default='inception_v4',
        choices=['mobilenet_v1', 'mobilenet_v2', 'nasnet_mobile', 'nasnet_large',
                 'resnet_v1_50', 'resnet_v2_50', 'resnet_v2_152', 'vgg_16',
                 'vgg_19', 'inception_v3', 'inception_v4'],
        help='Which model to use.')
    parser.add_argument('--data_dir', type=str, required=True,
        help='Directory containing validation set TFRecord files.')
    parser.add_argument('--calib_data_dir', type=str,
        help='Directory containing TFRecord files for calibrating int8.')
    parser.add_argument('--model_dir', type=str, default=None,
        help='Directory containing model checkpoint. If not provided, a '
             'checkpoint may be downloaded automatically and stored in '
             '"{--default_models_dir}/{--model}" for future use.')
    parser.add_argument('--default_models_dir', type=str, default='./data',
        help='Directory where downloaded model checkpoints will be stored and '
             'loaded from if --model_dir is not provided.')
    parser.add_argument('--use_trt', action='store_true',
        help='If set, the graph will be converted to a TensorRT graph.')
    parser.add_argument('--use_trt_dynamic_op', action='store_true',
        help='If set, TRT conversion will be done using dynamic op instead of statically.')
    parser.add_argument('--precision', type=str, choices=['fp32', 'fp16', 'int8'],
        default='fp32',
        help='Precision mode to use. FP16 and INT8 only work in conjunction with --use_trt.')
    parser.add_argument('--batch_size', type=int, default=8,
        help='Number of images per batch.')
    parser.add_argument('--minimum_segment_size', type=int, default=2,
        help='Minimum number of TF ops in a TRT engine.')
    parser.add_argument('--num_iterations', type=int, default=None,
        help='How many iterations (batches) to evaluate. If not supplied, the '
             'whole set will be evaluated.')
    parser.add_argument('--display_every', type=int, default=100,
        help='Number of iterations executed between two consecutive displays of metrics.')
    parser.add_argument('--use_synthetic', action='store_true',
        help='If set, one batch of random data is generated and used at every iteration.')
    parser.add_argument('--num_warmup_iterations', type=int, default=50,
        help='Number of initial iterations skipped from timing.')
    parser.add_argument('--num_calib_inputs', type=int, default=500,
        help='Number of inputs (e.g. images) used for calibration '
             '(last batch is skipped in case it is not full).')
    parser.add_argument('--cache', action='store_true',
        help='If set, graphs will be saved to disk after conversion. If a '
             'converted graph is present on disk, it will be loaded instead of '
             'building the graph again.')
    args = parser.parse_args()
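
    # Example invocation (hypothetical data path):
    #   python image_classification.py --model resnet_v2_50 \
    #       --data_dir /data/imagenet_validation --use_trt --precision fp16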
    if args.precision != 'fp32' and not args.use_trt:
        raise ValueError('TensorRT must be enabled for fp16 or int8 modes (--use_trt).')
    if args.precision == 'int8' and not args.calib_data_dir:
        raise ValueError('--calib_data_dir is required for int8 mode')
    if (args.num_iterations is not None
            and args.num_iterations <= args.num_warmup_iterations):
        raise ValueError('--num_iterations must be larger than --num_warmup_iterations '
                         '({} <= {})'.format(args.num_iterations, args.num_warmup_iterations))
    if args.num_calib_inputs < args.batch_size:
        raise ValueError('--num_calib_inputs must not be smaller than --batch_size '
                         '({} < {})'.format(args.num_calib_inputs, args.batch_size))

    # Retrieve graph using the NETS table in get_netdef()
    frozen_graph, num_nodes, times = get_frozen_graph(
        model=args.model,
        model_dir=args.model_dir,
        use_trt=args.use_trt,
        use_dynamic_op=args.use_trt_dynamic_op,
        precision=args.precision,
        batch_size=args.batch_size,
        minimum_segment_size=args.minimum_segment_size,
        calib_data_dir=args.calib_data_dir,
        num_calib_inputs=args.num_calib_inputs,
        use_synthetic=args.use_synthetic,
        cache=args.cache,
        default_models_dir=args.default_models_dir)

    def print_dict(input_dict, prefix=''):
        for k, v in sorted(input_dict.items()):
            headline = '{}({}): '.format(prefix, k) if prefix else '{}: '.format(k)
            print('{}{}'.format(headline, '%.1f' % v if isinstance(v, float) else v))

    print_dict(vars(args))
    print_dict(num_nodes, prefix='num_nodes')
    print_dict(times, prefix='time(s)')

    # Evaluate model
    print('running inference...')
    results = run(
        frozen_graph,
        model=args.model,
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        num_iterations=args.num_iterations,
        num_warmup_iterations=args.num_warmup_iterations,
        use_synthetic=args.use_synthetic,
        display_every=args.display_every)

    # Display results
    print('results of {}:'.format(args.model))
    print('    accuracy: %.2f' % (results['accuracy'] * 100))
    print('    images/sec: %d' % results['images_per_sec'])
    print('    99th_percentile(ms): %.1f' % results['99th_percentile'])
    print('    total_time(s): %.1f' % results['total_time'])
    print('    latency_mean(ms): %.1f' % results['latency_mean'])