Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions tftrt/examples/image-classification/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

This example includes scripts to run inference using a number of popular image classification models.

You can turn on TensorFlow-TensorRT integration with the flag `--use_trt`. This
You can turn on TF-TRT integration with the flag `--use_trt`. This
will apply TensorRT inference optimization to speed up execution for portions of
the model's graph where supported, and will fall back to native TensorFlow for
layers and operations which are not supported.
Expand Down Expand Up @@ -85,13 +85,3 @@ Run with `--help` to see all available options.
See [General Script Usage
](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#image-class-usage)
for more information.

### Accuracy tests

There is the script `check_accuracy.py` provided in the example that parses the output log of `inference.py`
to find the reported accuracy, and reports whether that accuracy matches with the
baseline numbers.

See [Checking Accuracy
](https://docs.nvidia.com/deeplearning/dgx/integrate-tf-trt/index.html#image-class-accuracy)
for more information.
35 changes: 21 additions & 14 deletions tftrt/examples/image-classification/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def after_run(self, run_context, run_values):
current_step, self.num_steps, duration * 1000,
self.batch_size / self.iter_times[-1]))

def run(frozen_graph, model, data_dir, batch_size,
def run(frozen_graph, model, data_files, batch_size,
num_iterations, num_warmup_iterations, use_synthetic, display_every=100, run_calibration=False):
"""Evaluates a frozen graph

Expand All @@ -62,7 +62,7 @@ def run(frozen_graph, model, data_dir, batch_size,

frozen_graph: GraphDef, a graph containing input node 'input' and outputs 'logits' and 'classes'
model: string, the model name (see NETS table in graph.py)
data_dir: str, directory containing ImageNet validation TFRecord files
data_files: List of TFRecord files used for inference
batch_size: int, batch size for TensorRT optimizations
num_iterations: int, number of iterations(batches) to run for
"""
Expand All @@ -80,14 +80,9 @@ def model_fn(features, labels, mode):
loss=loss,
eval_metric_ops={'accuracy': accuracy})

# Create the dataset
# preprocess function for input data
preprocess_fn = get_preprocess_fn(model)

if run_calibration:
validation_files = tf.gfile.Glob(os.path.join(data_dir, 'train*'))
else:
validation_files = tf.gfile.Glob(os.path.join(data_dir, 'validation*'))

def get_tfrecords_count(files):
num_records = 0
for fn in files:
Expand All @@ -111,7 +106,7 @@ def eval_input_fn():
dtype=np.int32)
labels = tf.identity(tf.constant(labels))
else:
dataset = tf.data.TFRecordDataset(validation_files)
dataset = tf.data.TFRecordDataset(data_files)
dataset = dataset.apply(tf.contrib.data.map_and_batch(map_func=preprocess_fn, batch_size=batch_size, num_parallel_calls=8))
dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
dataset = dataset.repeat(count=1)
Expand All @@ -123,7 +118,7 @@ def eval_input_fn():
logger = LoggerHook(
display_every=display_every,
batch_size=batch_size,
num_records=get_tfrecords_count(validation_files))
num_records=get_tfrecords_count(data_files))
tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True
estimator = tf.estimator.Estimator(
Expand Down Expand Up @@ -431,7 +426,7 @@ def get_frozen_graph(
precision='fp32',
batch_size=8,
minimum_segment_size=2,
calib_data_dir=None,
calib_files=None,
num_calib_inputs=None,
use_synthetic=False,
cache=False,
Expand Down Expand Up @@ -488,7 +483,7 @@ def get_frozen_graph(
# INT8 calibration step
print('Calibrating INT8...')
start_time = time.time()
run(calib_graph, model, calib_data_dir, batch_size,
run(calib_graph, model, calib_files, batch_size,
num_calib_inputs // batch_size, 0, False, run_calibration=True)
times['trt_calibration'] = time.time() - start_time

Expand Down Expand Up @@ -569,6 +564,18 @@ def get_frozen_graph(
raise ValueError('--num_calib_inputs must not be smaller than --batch_size'
'({} <= {})'.format(args.num_calib_inputs, args.batch_size))

def get_files(data_dir, filename_pattern):
if data_dir == None:
return []
files = tf.gfile.Glob(os.path.join(data_dir, filename_pattern))
if files == []:
raise ValueError('Can not find any files in {} with pattern "{}"'.format(
data_dir, filename_pattern))
return files

validation_files = get_files(args.data_dir, 'validation*')
calib_files = get_files(args.calib_data_dir, 'train*')

# Retreive graph using NETS table in graph.py
frozen_graph, num_nodes, times = get_frozen_graph(
model=args.model,
Expand All @@ -578,7 +585,7 @@ def get_frozen_graph(
precision=args.precision,
batch_size=args.batch_size,
minimum_segment_size=args.minimum_segment_size,
calib_data_dir=args.calib_data_dir,
calib_files=calib_files,
num_calib_inputs=args.num_calib_inputs,
use_synthetic=args.use_synthetic,
cache=args.cache,
Expand All @@ -602,7 +609,7 @@ def print_dict(input_dict, str=''):
results = run(
frozen_graph,
model=args.model,
data_dir=args.data_dir,
data_files=validation_files,
batch_size=args.batch_size,
num_iterations=args.num_iterations,
num_warmup_iterations=args.num_warmup_iterations,
Expand Down