smdebug/core/hook.py (21 changes: 17 additions & 4 deletions)

@@ -47,9 +47,10 @@
 
 
 class ScalarCache(object):
-    def __init__(self, scalar_name, scalar_val, sm_metric, write_tb, write_event):
+    def __init__(self, scalar_name, scalar_val, mode, sm_metric, write_tb, write_event):
         self.name = scalar_name
         self.value = scalar_val
+        self.mode = mode
         self.sm_metric = sm_metric
         self.write_tb = write_tb
         self.write_event = write_event
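A cached scalar now records the mode it was captured under. A minimal sketch of the extended constructor, assuming the ScalarCache class above and smdebug's ModeKeys enum (smdebug.core.modes); the values are hypothetical:

    from smdebug.core.modes import ModeKeys

    # The entry is tagged with the mode active at capture time, so it can
    # later be written against that mode's own step counter.
    entry = ScalarCache("loss", 0.42, ModeKeys.TRAIN,
                        sm_metric=True, write_tb=True, write_event=True)
    assert entry.mode is ModeKeys.TRAIN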
@@ -442,6 +443,10 @@ def _increment_step(self):
 
         self.step += 1
         self.mode_steps[self.mode] += 1
+
+        # Keep the GLOBAL step count in sync with self.step, whatever the current mode is
+        if self.mode != ModeKeys.GLOBAL:
+            self.mode_steps[ModeKeys.GLOBAL] = self.step
         self._collections_to_save_for_step = None
 
     def _write_state(self):
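The effect on the step counters, as a standalone runnable sketch (not the real hook class; assumes smdebug.core.modes.ModeKeys):

    from collections import defaultdict
    from smdebug.core.modes import ModeKeys

    step = 0
    mode_steps = defaultdict(int)

    def increment_step(mode):
        global step
        step += 1
        mode_steps[mode] += 1
        # GLOBAL mirrors the overall step count regardless of the active mode
        if mode != ModeKeys.GLOBAL:
            mode_steps[ModeKeys.GLOBAL] = step

    for _ in range(3):
        increment_step(ModeKeys.TRAIN)
    for _ in range(2):
        increment_step(ModeKeys.EVAL)

    assert step == 5
    assert mode_steps[ModeKeys.TRAIN] == 3
    assert mode_steps[ModeKeys.EVAL] == 2
    assert mode_steps[ModeKeys.GLOBAL] == 5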
@@ -566,12 +571,15 @@ def _write_scalars(self):
         for scalar_obj in self.scalar_cache:
             scalar_name = scalar_obj.name
             scalar_val = scalar_obj.value
+            scalar_mode = scalar_obj.mode
             sm_metric = scalar_obj.sm_metric
             write_tb = scalar_obj.write_tb
             write_event = scalar_obj.write_event
             if self.metrics_writer and sm_metric:
                 self.metrics_writer.log_metric(
-                    scalar_name, scalar_val, iteration_number=self.mode_steps[self.mode]
+                    scalar_name + "_" + scalar_mode.name,
+                    scalar_val,
+                    iteration_number=self.mode_steps[scalar_mode],
                 )
             if write_tb:
                 tb_writer = self._maybe_get_tb_writer()
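For illustration (the tensor name is hypothetical), a scalar cached while the hook was in EVAL mode is now logged under a mode-suffixed name and indexed by EVAL's own step counter:

    # scalar_obj.name == "loss", scalar_obj.mode == ModeKeys.EVAL
    metric_name = "loss" + "_" + ModeKeys.EVAL.name   # -> "loss_EVAL"
    iteration = mode_steps[ModeKeys.EVAL]             # EVAL's step count, not the
                                                      # mode active at flush time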
@@ -598,7 +606,7 @@ def save_scalar(self, name, value, sm_metric=False):
         val = self._make_numpy_array(value)
         if val.size != 1:
             raise TypeError(f"{name} has non scalar value of type: {type(value)}")
-        scalar_obj = ScalarCache(name, val, sm_metric=True, write_tb=True, write_event=True)
+        scalar_obj = ScalarCache(name, val, self.mode, sm_metric, write_tb=True, write_event=True)
         self.scalar_cache.append(scalar_obj)
 
     def _write_raw_tensor(self, tensor_name, tensor_value, save_collections, tensor_ref=None):
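Note that save_scalar also stops hard-coding sm_metric=True and now honors the caller's flag. A hedged usage sketch (scalar names are hypothetical):

    hook.save_scalar("debug_ratio", 0.07)                    # event/TB files only
    hook.save_scalar("val_accuracy", 0.93, sm_metric=True)   # also a SageMaker metric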
@@ -659,7 +667,12 @@ def _save_for_tensor(self, tensor_name, tensor_value, check_before_write=True):
             # Always log loss to Minerva
             tensor_val = np.mean(np_val)
             scalar_obj = ScalarCache(
-                tensor_name, tensor_val, sm_metric=True, write_tb=False, write_event=False
+                tensor_name,
+                tensor_val,
+                self.mode,
+                sm_metric=True,
+                write_tb=False,
+                write_event=False,
             )
             self.scalar_cache.append(scalar_obj)
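Combined with the _write_scalars change above, a loss saved during EVAL now reaches the metrics writer correctly named and indexed. A hypothetical end-to-end trace:

    # hook in EVAL mode, mode_steps[ModeKeys.EVAL] == 7
    # _save_for_tensor("CrossEntropyLoss_output_0", value)
    #   -> ScalarCache(..., mode=ModeKeys.EVAL, sm_metric=True, ...)
    # _write_scalars()
    #   -> log_metric("CrossEntropyLoss_output_0_EVAL", val, iteration_number=7)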
smdebug/mxnet/collection.py (2 changes: 1 addition & 1 deletion)

@@ -24,7 +24,7 @@ def _register_default_collections(self):
         self.get(CollectionKeys.WEIGHTS).include("^(?!gradient).*weight")
         self.get(CollectionKeys.BIASES).include("^(?!gradient).*bias")
         self.get(CollectionKeys.GRADIENTS).include("^gradient")
-        self.get(CollectionKeys.LOSSES).include(".*loss")
+        self.get(CollectionKeys.LOSSES).include(".*loss._(?!input).*output")
 
     def create_collection(self, name):
         super().create_collection(name, cls=Collection)
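The tightened pattern keeps a loss block's outputs and excludes its inputs. A quick check, assuming the include patterns are applied with re.search semantics and using hypothetical MXNet-style tensor names:

    import re

    pattern = ".*loss._(?!input).*output"
    assert re.search(pattern, "softmaxcrossentropyloss0_output_0")
    assert not re.search(pattern, "softmaxcrossentropyloss0_input_0")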
smdebug/pytorch/collection.py (2 changes: 1 addition & 1 deletion)

@@ -39,7 +39,7 @@ def _register_default_collections(self):
         self.get(CollectionKeys.WEIGHTS).include("^(?!gradient).*weight")
         self.get(CollectionKeys.BIASES).include("^(?!gradient).*bias")
         self.get(CollectionKeys.GRADIENTS).include("^gradient")
-        self.get(CollectionKeys.LOSSES).include("[Ll]oss")
+        self.get(CollectionKeys.LOSSES).include("[Ll]oss_(?!input).*output")
 
     def create_collection(self, name):
         super().create_collection(name, cls=Collection)
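The same idea for PyTorch, where default module tensor names carry the module class name (example names hypothetical; again assuming re.search semantics):

    import re

    pattern = "[Ll]oss_(?!input).*output"
    assert re.search(pattern, "CrossEntropyLoss_output_0")
    assert not re.search(pattern, "CrossEntropyLoss_input_0")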