Merged. This repository was archived by the owner on Feb 3, 2025 and is now read-only.
tftrt/examples/image-classification/image_classification.py (33 additions, 8 deletions)
@@ -52,10 +52,31 @@ def after_run(self, run_context, run_values):
current_step, self.num_steps, duration * 1000,
self.batch_size / self.iter_times[-1]))

class DurationHook(tf.train.SessionRunHook):
    """Stops evaluation once a target wall-clock duration (in seconds) is exceeded."""
    def __init__(self, target_duration):
        self.target_duration = target_duration
        self.start_time = None

    def after_run(self, run_context, run_values):
        if not self.target_duration:
            return

        if not self.start_time:
            self.start_time = time.time()
            print(" running for target duration from %d" % self.start_time)
            return

        current_time = time.time()
        if (current_time - self.start_time) > self.target_duration:
            print(" target duration %d reached at %d, requesting stop" % (self.target_duration, current_time))
            run_context.request_stop()

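For context, a minimal sketch of the SessionRunHook protocol the class above relies on; this is not part of the change, and it assumes TF 1.x with DurationHook in scope. request_stop() called inside after_run() makes should_stop() return True, which is also how the estimator.evaluate() call further down ends its loop early.

import tensorflow as tf

with tf.Graph().as_default():
    trivial_op = tf.constant(1)                    # stand-in for the real evaluation ops
    hook = DurationHook(target_duration=3)         # ask the hook to stop after roughly 3 seconds
    with tf.train.MonitoredSession(hooks=[hook]) as sess:
        while not sess.should_stop():              # becomes True once request_stop() is called
            sess.run(trivial_op)                   # after_run() fires after every run call
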
def run(frozen_graph, model, data_files, batch_size,
num_iterations, num_warmup_iterations, use_synthetic=False, display_every=100, run_calibration=False):
num_iterations, num_warmup_iterations, use_synthetic, display_every=100, run_calibration=False,
target_duration=None):
"""Evaluates a frozen graph

This function evaluates a graph on the ImageNet validation set.
tf.estimator.Estimator is used to evaluate the accuracy of the model
and a few other metrics. The results are returned as a dict.
@@ -125,8 +146,9 @@ def eval_input_fn():
model_fn=model_fn,
config=tf.estimator.RunConfig(session_config=tf_config),
model_dir='model_dir')
results = estimator.evaluate(eval_input_fn, steps=num_iterations, hooks=[logger])

duration_hook = DurationHook(target_duration)
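# With target_duration set, evaluation ends at whichever limit is reached first:
# num_iterations evaluation steps or target_duration seconds of wall-clock time.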
results = estimator.evaluate(eval_input_fn, steps=num_iterations, hooks=[logger, duration_hook])

# Gather additional results
iter_times = np.array(logger.iter_times[num_warmup_iterations:])
results['total_time'] = np.sum(iter_times)
@@ -137,7 +159,7 @@ def eval_input_fn():

class NetDef(object):
"""Contains definition of a model

name: Name of model
url: (optional) Where to download archive containing checkpoint
model_dir_in_archive: (optional) Subdirectory in archive containing
@@ -375,7 +397,7 @@ def get_checkpoint(model, model_dir=None, default_models_dir='.'):
if get_netdef(model).url:
download_checkpoint(model, model_dir)
return find_checkpoint_in_dir(model_dir)

print('No model_dir was provided and the model does not define a download' \
' URL.')
exit(1)
@@ -533,7 +555,7 @@ def get_frozen_graph(
parser.add_argument('--model_dir', type=str, default=None,
help='Directory containing model checkpoint. If not provided, a ' \
'checkpoint may be downloaded automatically and stored in ' \
'"{--default_models_dir}/{--model}" for future use.')
'"{--default_models_dir}/{--model}" for future use.')
parser.add_argument('--default_models_dir', type=str, default='./data',
help='Directory where downloaded model checkpoints will be stored and ' \
'loaded from if --model_dir is not provided.')
@@ -562,6 +584,8 @@ def get_frozen_graph(
help='workspace size in bytes')
parser.add_argument('--cache', action='store_true',
help='If set, graphs will be saved to disk after conversion. If a converted graph is present on disk, it will be loaded instead of building the graph again.')
parser.add_argument('--target_duration', type=int, default=None,
help='If set, evaluation stops after roughly this many seconds.')
args = parser.parse_args()

if args.precision != 'fp32' and not args.use_trt:
@@ -625,7 +649,8 @@ def print_dict(input_dict, str='', scale=None):
num_iterations=args.num_iterations,
num_warmup_iterations=args.num_warmup_iterations,
use_synthetic=args.use_synthetic,
display_every=args.display_every)
display_every=args.display_every,
target_duration=args.target_duration)

# Display results
print('results of {}:'.format(args.model))
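For reference, a hypothetical invocation exercising the new flag; the model name and other values below are placeholders, not taken from this change:

python image_classification.py --model resnet_v1_50 --use_synthetic --target_duration 60

With --target_duration 60, evaluation stops after roughly 60 seconds even if the requested number of iterations has not yet completed.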