@@ -52,10 +52,31 @@ def after_run(self, run_context, run_values):
5252 current_step , self .num_steps , duration * 1000 ,
5353 self .batch_size / self .iter_times [- 1 ]))
5454
55+ class DurationHook (tf .train .SessionRunHook ):
56+ """Limits run duration"""
57+ def __init__ (self , target_duration ):
58+ self .target_duration = target_duration
59+ self .start_time = None
60+
61+ def after_run (self , run_context , run_values ):
62+ if not self .target_duration :
63+ return
64+
65+ if not self .start_time :
66+ self .start_time = time .time ()
67+ print (" running for target duration from %d" % self .start_time )
68+ return
69+
70+ current_time = time .time ()
71+ if (current_time - self .start_time ) > self .target_duration :
72+ print (" target duration %d reached at %d, requesting stop" % (self .target_duration , current_time ))
73+ run_context .request_stop ()
74+
5575def run (frozen_graph , model , data_files , batch_size ,
56- num_iterations , num_warmup_iterations , use_synthetic = False , display_every = 100 , run_calibration = False ):
76+ num_iterations , num_warmup_iterations , use_synthetic , display_every = 100 , run_calibration = False ,
77+ target_duration = None ):
5778 """Evaluates a frozen graph
58-
79+
5980 This function evaluates a graph on the ImageNet validation set.
6081 tf.estimator.Estimator is used to evaluate the accuracy of the model
6182 and a few other metrics. The results are returned as a dict.
@@ -125,8 +146,9 @@ def eval_input_fn():
125146 model_fn = model_fn ,
126147 config = tf .estimator .RunConfig (session_config = tf_config ),
127148 model_dir = 'model_dir' )
128- results = estimator .evaluate (eval_input_fn , steps = num_iterations , hooks = [logger ])
129-
149+ duration_hook = DurationHook (target_duration )
150+ results = estimator .evaluate (eval_input_fn , steps = num_iterations , hooks = [logger , duration_hook ])
151+
130152 # Gather additional results
131153 iter_times = np .array (logger .iter_times [num_warmup_iterations :])
132154 results ['total_time' ] = np .sum (iter_times )
@@ -137,7 +159,7 @@ def eval_input_fn():
137159
138160class NetDef (object ):
139161 """Contains definition of a model
140-
162+
141163 name: Name of model
142164 url: (optional) Where to download archive containing checkpoint
143165 model_dir_in_archive: (optional) Subdirectory in archive containing
@@ -375,7 +397,7 @@ def get_checkpoint(model, model_dir=None, default_models_dir='.'):
375397 if get_netdef (model ).url :
376398 download_checkpoint (model , model_dir )
377399 return find_checkpoint_in_dir (model_dir )
378-
400+
379401 print ('No model_dir was provided and the model does not define a download' \
380402 ' URL.' )
381403 exit (1 )
@@ -533,7 +555,7 @@ def get_frozen_graph(
533555 parser .add_argument ('--model_dir' , type = str , default = None ,
534556 help = 'Directory containing model checkpoint. If not provided, a ' \
535557 'checkpoint may be downloaded automatically and stored in ' \
536- '"{--default_models_dir}/{--model}" for future use.' )
558+ '"{--default_models_dir}/{--model}" for future use.' )
537559 parser .add_argument ('--default_models_dir' , type = str , default = './data' ,
538560 help = 'Directory where downloaded model checkpoints will be stored and ' \
539561 'loaded from if --model_dir is not provided.' )
@@ -562,6 +584,8 @@ def get_frozen_graph(
562584 help = 'workspace size in bytes' )
563585 parser .add_argument ('--cache' , action = 'store_true' ,
564586 help = 'If set, graphs will be saved to disk after conversion. If a converted graph is present on disk, it will be loaded instead of building the graph again.' )
587+ parser .add_argument ('--target_duration' , type = int , default = None ,
588+ help = 'If set, script will run for specified number of seconds.' )
565589 args = parser .parse_args ()
566590
567591 if args .precision != 'fp32' and not args .use_trt :
@@ -625,7 +649,8 @@ def print_dict(input_dict, str='', scale=None):
625649 num_iterations = args .num_iterations ,
626650 num_warmup_iterations = args .num_warmup_iterations ,
627651 use_synthetic = args .use_synthetic ,
628- display_every = args .display_every )
652+ display_every = args .display_every ,
653+ target_duration = args .target_duration )
629654
630655 # Display results
631656 print ('results of {}:' .format (args .model ))
0 commit comments