66
77import gzip
88import json
9+ import logging
910import time
1011import warnings
1112from collections .abc import Iterable
1718from matplotlib import colors as mcolors
1819from pytorch3d .implicitron .tools .vis_utils import get_visdom_connection
1920
21+ logger = logging .getLogger (__name__ )
22+
2023
2124class AverageMeter (object ):
2225 """Computes and stores the average and current value"""
@@ -91,7 +94,9 @@ class Stats(object):
9194 # stats.update() automatically parses the 'objective' and 'top1e' from
9295 # the "output" dict and stores this into the db
9396 stats.update(output)
94- stats.print() # prints the averages over given epoch
97+ # prints the metric averages over given epoch
98+ std_out = stats.get_status_string()
99+ logger.info(str_out)
95100 # stores the training plots into '/tmp/epoch_stats.pdf'
96101 # and plots into a visdom server running at localhost (if running)
97102 stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
@@ -101,7 +106,6 @@ class Stats(object):
101106 def __init__ (
102107 self ,
103108 log_vars ,
104- verbose = False ,
105109 epoch = - 1 ,
106110 visdom_env = "main" ,
107111 do_plot = True ,
@@ -110,7 +114,6 @@ def __init__(
110114 visdom_port = 8097 ,
111115 ):
112116
113- self .verbose = verbose
114117 self .log_vars = log_vars
115118 self .visdom_env = visdom_env
116119 self .visdom_server = visdom_server
@@ -156,32 +159,29 @@ def __exit__(self, type, value, traceback):
156159 iserr = type is not None and issubclass (type , Exception )
157160 iserr = iserr or (type is KeyboardInterrupt )
158161 if iserr :
159- print ("error inside 'with' block" )
162+ logger . error ("error inside 'with' block" )
160163 return
161164 if self .do_plot :
162165 self .plot_stats (self .visdom_env )
163166
164167 def reset (self ): # to be called after each epoch
165168 stat_sets = list (self .stats .keys ())
166- if self .verbose :
167- print ("stats: epoch %d - reset" % self .epoch )
169+ logger .debug (f"stats: epoch { self .epoch } - reset" )
168170 self .it = {k : - 1 for k in stat_sets }
169171 for stat_set in stat_sets :
170172 for stat in self .stats [stat_set ]:
171173 self .stats [stat_set ][stat ].reset ()
172174
173175 def hard_reset (self , epoch = - 1 ): # to be called during object __init__
174176 self .epoch = epoch
175- if self .verbose :
176- print ("stats: epoch %d - hard reset" % self .epoch )
177+ logger .debug (f"stats: epoch { self .epoch } - hard reset" )
177178 self .stats = {}
178179
179180 # reset
180181 self .reset ()
181182
182183 def new_epoch (self ):
183- if self .verbose :
184- print ("stats: new epoch %d" % (self .epoch + 1 ))
184+ logger .debug (f"stats: new epoch { (self .epoch + 1 )} " )
185185 self .epoch += 1
186186 self .reset () # zero the stats + increase epoch counter
187187
@@ -193,18 +193,17 @@ def gather_value(self, val):
193193 val = float (val .sum ())
194194 return val
195195
196- def add_log_vars (self , added_log_vars , verbose = True ):
196+ def add_log_vars (self , added_log_vars ):
197197 for add_log_var in added_log_vars :
198198 if add_log_var not in self .stats :
199- if verbose :
200- print (f"Adding { add_log_var } " )
199+ logger .debug (f"Adding { add_log_var } " )
201200 self .log_vars .append (add_log_var )
202201
203202 def update (self , preds , time_start = None , freeze_iter = False , stat_set = "train" ):
204203
205204 if self .epoch == - 1 : # uninitialized
206- print (
207- "warning: epoch==-1 means uninitialized stats structure -> new_epoch() called"
205+ logger . warning (
206+ "epoch==-1 means uninitialized stats structure -> new_epoch() called"
208207 )
209208 self .new_epoch ()
210209
@@ -284,6 +283,12 @@ def print(
284283 skip_nan = False ,
285284 stat_format = lambda s : s .replace ("loss_" , "" ).replace ("prev_stage_" , "ps_" ),
286285 ):
286+ """
287+ stats.print() is deprecated. Please use get_status_string() instead.
288+ example:
289+ std_out = stats.get_status_string()
290+ logger.info(str_out)
291+ """
287292
288293 epoch = self .epoch
289294 stats = self .stats
@@ -311,8 +316,30 @@ def print(
311316 if get_str :
312317 return str_out
313318 else :
319+ warnings .warn (
320+ "get_str=False is deprecated."
321+ "Please enable this flag to get receive the output string." ,
322+ DeprecationWarning ,
323+ )
314324 print (str_out )
315325
326+ def get_status_string (
327+ self ,
328+ max_it = None ,
329+ stat_set = "train" ,
330+ vars_print = None ,
331+ skip_nan = False ,
332+ stat_format = lambda s : s .replace ("loss_" , "" ).replace ("prev_stage_" , "ps_" ),
333+ ):
334+ return self .print (
335+ max_it = max_it ,
336+ stat_set = stat_set ,
337+ vars_print = vars_print ,
338+ get_str = True ,
339+ skip_nan = skip_nan ,
340+ stat_format = stat_format ,
341+ )
342+
316343 def plot_stats (
317344 self , visdom_env = None , plot_file = None , visdom_server = None , visdom_port = None
318345 ):
@@ -329,16 +356,15 @@ def plot_stats(
329356
330357 stat_sets = list (self .stats .keys ())
331358
332- print (
333- "printing charts to visdom env '%s' (%s:%d)"
334- % (visdom_env , visdom_server , visdom_port )
359+ logger .debug (
360+ f"printing charts to visdom env '{ visdom_env } ' ({ visdom_server } :{ visdom_port } )"
335361 )
336362
337363 novisdom = False
338364
339365 viz = get_visdom_connection (server = visdom_server , port = visdom_port )
340366 if viz is None or not viz .check_connection ():
341- print ("no visdom server! -> skipping visdom plots" )
367+ logger . info ("no visdom server! -> skipping visdom plots" )
342368 novisdom = True
343369
344370 lines = []
@@ -385,7 +411,7 @@ def plot_stats(
385411 )
386412
387413 if plot_file :
388- print ( "exporting stats to %s" % plot_file )
414+ logger . info ( f"plotting stats to { plot_file } " )
389415 ncol = 3
390416 nrow = int (np .ceil (float (len (lines )) / ncol ))
391417 matplotlib .rcParams .update ({"font.size" : 5 })
@@ -423,15 +449,15 @@ def plot_stats(
423449 except PermissionError :
424450 warnings .warn ("Cant dump stats due to insufficient permissions!" )
425451
426- def synchronize_logged_vars (self , log_vars , default_val = float ("NaN" ), verbose = True ):
452+ def synchronize_logged_vars (self , log_vars , default_val = float ("NaN" )):
427453
428454 stat_sets = list (self .stats .keys ())
429455
430456 # remove the additional log_vars
431457 for stat_set in stat_sets :
432458 for stat in self .stats [stat_set ].keys ():
433459 if stat not in log_vars :
434- print ( "additional stat %s:%s -> removing" % ( stat_set , stat ) )
460+ logger . warning ( f "additional stat { stat_set } : { stat } -> removing" )
435461
436462 self .stats [stat_set ] = {
437463 stat : v for stat , v in self .stats [stat_set ].items () if stat in log_vars
@@ -442,21 +468,19 @@ def synchronize_logged_vars(self, log_vars, default_val=float("NaN"), verbose=Tr
442468 for stat_set in stat_sets :
443469 for stat in log_vars :
444470 if stat not in self .stats [stat_set ]:
445- if verbose :
446- print (
447- "missing stat %s:%s -> filling with default values (%1.2f)"
448- % (stat_set , stat , default_val )
449- )
471+ logger .info (
472+ "missing stat %s:%s -> filling with default values (%1.2f)"
473+ % (stat_set , stat , default_val )
474+ )
450475 elif len (self .stats [stat_set ][stat ].history ) != self .epoch + 1 :
451476 h = self .stats [stat_set ][stat ].history
452477 if len (h ) == 0 : # just never updated stat ... skip
453478 continue
454479 else :
455- if verbose :
456- print (
457- "incomplete stat %s:%s -> reseting with default values (%1.2f)"
458- % (stat_set , stat , default_val )
459- )
480+ logger .info (
481+ "incomplete stat %s:%s -> reseting with default values (%1.2f)"
482+ % (stat_set , stat , default_val )
483+ )
460484 else :
461485 continue
462486
0 commit comments