diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index 21427e421..6d8e24a8a 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -52,10 +52,31 @@ def after_run(self, run_context, run_values): current_step, self.num_steps, duration * 1000, self.batch_size / self.iter_times[-1])) +class DurationHook(tf.train.SessionRunHook): + """Limits run duration""" + def __init__(self, target_duration): + self.target_duration = target_duration + self.start_time = None + + def after_run(self, run_context, run_values): + if not self.target_duration: + return + + if not self.start_time: + self.start_time = time.time() + print(" running for target duration from %d" % self.start_time) + return + + current_time = time.time() + if (current_time - self.start_time) > self.target_duration: + print(" target duration %d reached at %d, requesting stop" % (self.target_duration, current_time)) + run_context.request_stop() + def run(frozen_graph, model, data_files, batch_size, - num_iterations, num_warmup_iterations, use_synthetic=False, display_every=100, run_calibration=False): + num_iterations, num_warmup_iterations, use_synthetic, display_every=100, run_calibration=False, + target_duration=None): """Evaluates a frozen graph - + This function evaluates a graph on the ImageNet validation set. tf.estimator.Estimator is used to evaluate the accuracy of the model and a few other metrics. The results are returned as a dict. @@ -125,8 +146,9 @@ def eval_input_fn(): model_fn=model_fn, config=tf.estimator.RunConfig(session_config=tf_config), model_dir='model_dir') - results = estimator.evaluate(eval_input_fn, steps=num_iterations, hooks=[logger]) - + duration_hook = DurationHook(target_duration) + results = estimator.evaluate(eval_input_fn, steps=num_iterations, hooks=[logger, duration_hook]) + # Gather additional results iter_times = np.array(logger.iter_times[num_warmup_iterations:]) results['total_time'] = np.sum(iter_times) @@ -137,7 +159,7 @@ def eval_input_fn(): class NetDef(object): """Contains definition of a model - + name: Name of model url: (optional) Where to download archive containing checkpoint model_dir_in_archive: (optional) Subdirectory in archive containing @@ -375,7 +397,7 @@ def get_checkpoint(model, model_dir=None, default_models_dir='.'): if get_netdef(model).url: download_checkpoint(model, model_dir) return find_checkpoint_in_dir(model_dir) - + print('No model_dir was provided and the model does not define a download' \ ' URL.') exit(1) @@ -533,7 +555,7 @@ def get_frozen_graph( parser.add_argument('--model_dir', type=str, default=None, help='Directory containing model checkpoint. If not provided, a ' \ 'checkpoint may be downloaded automatically and stored in ' \ - '"{--default_models_dir}/{--model}" for future use.') + '"{--default_models_dir}/{--model}" for future use.') parser.add_argument('--default_models_dir', type=str, default='./data', help='Directory where downloaded model checkpoints will be stored and ' \ 'loaded from if --model_dir is not provided.') @@ -562,6 +584,8 @@ def get_frozen_graph( help='workspace size in bytes') parser.add_argument('--cache', action='store_true', help='If set, graphs will be saved to disk after conversion. If a converted graph is present on disk, it will be loaded instead of building the graph again.') + parser.add_argument('--target_duration', type=int, default=None, + help='If set, script will run for specified number of seconds.') args = parser.parse_args() if args.precision != 'fp32' and not args.use_trt: @@ -625,7 +649,8 @@ def print_dict(input_dict, str='', scale=None): num_iterations=args.num_iterations, num_warmup_iterations=args.num_warmup_iterations, use_synthetic=args.use_synthetic, - display_every=args.display_every) + display_every=args.display_every, + target_duration=args.target_duration) # Display results print('results of {}:'.format(args.model))