Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit 90d62f6

Browse files
otstrelPooya Davoodi
authored andcommitted
Addin target_duration argument (#23)
1 parent 34344a9 commit 90d62f6

File tree

1 file changed

+33
-8
lines changed

1 file changed

+33
-8
lines changed

tftrt/examples/image-classification/image_classification.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,31 @@ def after_run(self, run_context, run_values):
5252
current_step, self.num_steps, duration * 1000,
5353
self.batch_size / self.iter_times[-1]))
5454

55+
class DurationHook(tf.train.SessionRunHook):
56+
"""Limits run duration"""
57+
def __init__(self, target_duration):
58+
self.target_duration = target_duration
59+
self.start_time = None
60+
61+
def after_run(self, run_context, run_values):
62+
if not self.target_duration:
63+
return
64+
65+
if not self.start_time:
66+
self.start_time = time.time()
67+
print(" running for target duration from %d" % self.start_time)
68+
return
69+
70+
current_time = time.time()
71+
if (current_time - self.start_time) > self.target_duration:
72+
print(" target duration %d reached at %d, requesting stop" % (self.target_duration, current_time))
73+
run_context.request_stop()
74+
5575
def run(frozen_graph, model, data_files, batch_size,
56-
num_iterations, num_warmup_iterations, use_synthetic=False, display_every=100, run_calibration=False):
76+
num_iterations, num_warmup_iterations, use_synthetic, display_every=100, run_calibration=False,
77+
target_duration=None):
5778
"""Evaluates a frozen graph
58-
79+
5980
This function evaluates a graph on the ImageNet validation set.
6081
tf.estimator.Estimator is used to evaluate the accuracy of the model
6182
and a few other metrics. The results are returned as a dict.
@@ -125,8 +146,9 @@ def eval_input_fn():
125146
model_fn=model_fn,
126147
config=tf.estimator.RunConfig(session_config=tf_config),
127148
model_dir='model_dir')
128-
results = estimator.evaluate(eval_input_fn, steps=num_iterations, hooks=[logger])
129-
149+
duration_hook = DurationHook(target_duration)
150+
results = estimator.evaluate(eval_input_fn, steps=num_iterations, hooks=[logger, duration_hook])
151+
130152
# Gather additional results
131153
iter_times = np.array(logger.iter_times[num_warmup_iterations:])
132154
results['total_time'] = np.sum(iter_times)
@@ -137,7 +159,7 @@ def eval_input_fn():
137159

138160
class NetDef(object):
139161
"""Contains definition of a model
140-
162+
141163
name: Name of model
142164
url: (optional) Where to download archive containing checkpoint
143165
model_dir_in_archive: (optional) Subdirectory in archive containing
@@ -375,7 +397,7 @@ def get_checkpoint(model, model_dir=None, default_models_dir='.'):
375397
if get_netdef(model).url:
376398
download_checkpoint(model, model_dir)
377399
return find_checkpoint_in_dir(model_dir)
378-
400+
379401
print('No model_dir was provided and the model does not define a download' \
380402
' URL.')
381403
exit(1)
@@ -533,7 +555,7 @@ def get_frozen_graph(
533555
parser.add_argument('--model_dir', type=str, default=None,
534556
help='Directory containing model checkpoint. If not provided, a ' \
535557
'checkpoint may be downloaded automatically and stored in ' \
536-
'"{--default_models_dir}/{--model}" for future use.')
558+
'"{--default_models_dir}/{--model}" for future use.')
537559
parser.add_argument('--default_models_dir', type=str, default='./data',
538560
help='Directory where downloaded model checkpoints will be stored and ' \
539561
'loaded from if --model_dir is not provided.')
@@ -562,6 +584,8 @@ def get_frozen_graph(
562584
help='workspace size in bytes')
563585
parser.add_argument('--cache', action='store_true',
564586
help='If set, graphs will be saved to disk after conversion. If a converted graph is present on disk, it will be loaded instead of building the graph again.')
587+
parser.add_argument('--target_duration', type=int, default=None,
588+
help='If set, script will run for specified number of seconds.')
565589
args = parser.parse_args()
566590

567591
if args.precision != 'fp32' and not args.use_trt:
@@ -625,7 +649,8 @@ def print_dict(input_dict, str='', scale=None):
625649
num_iterations=args.num_iterations,
626650
num_warmup_iterations=args.num_warmup_iterations,
627651
use_synthetic=args.use_synthetic,
628-
display_every=args.display_every)
652+
display_every=args.display_every,
653+
target_duration=args.target_duration)
629654

630655
# Display results
631656
print('results of {}:'.format(args.model))

0 commit comments

Comments
 (0)