smdebug/core/hook.py (6 additions, 11 deletions)
@@ -136,7 +136,7 @@ def __init__(
self.reduction_config = reduction_config
self.include_regex = include_regex
self.collection_manager = collection_manager
self.collection_manager.set_num_workers(self.get_num_workers())
self.collection_manager.set_num_workers(self._get_num_workers())
self.init_step = init_step

self.logger = logger
@@ -228,11 +228,11 @@ def __repr__(self):
)

@abstractmethod
def get_worker_name(self):
def _get_worker_name(self):
pass

@abstractmethod
def get_num_workers(self):
def _get_num_workers(self):
pass

#### Save Manager methods ####
@@ -361,7 +361,7 @@ def _initialize_writers(self) -> None:
return
self.writer = FileWriter(trial_dir=self.out_dir, step=self.step, worker=self.worker)

def get_writers(self, tensor_name, tensor_ref=None) -> List[FileWriter]:
def _get_writers(self, tensor_name, tensor_ref=None) -> List[FileWriter]:
"""
:param tensor_name:
:param tensor_ref: used by TF
@@ -451,7 +451,7 @@ def set_mode(self, mode):
self._collections_to_save_for_step = None

def export_collections(self):
num_workers = self.get_num_workers()
num_workers = self._get_num_workers()
if self.save_all_workers is False:
if self.chief_worker != self.worker:
return
@@ -604,7 +604,7 @@ def _write_raw_tensor_simple(self, tensor_name, tensor_value, tensor_ref=None):
numpy_tensor_value = self._make_numpy_array(tensor_value)
this_size, this_shape = size_and_shape(numpy_tensor_value)
if self.dry_run is False and this_size > 0:
writers = self.get_writers(tensor_name, tensor_ref=tensor_ref)
writers = self._get_writers(tensor_name, tensor_ref=tensor_ref)
for writer in writers:
writer.write_tensor(
tdata=numpy_tensor_value,
@@ -709,11 +709,6 @@ def _make_numpy_array(tensor_value):
:return: numpy ndarray
"""

def _set_collection_manager(self, coll_manager):
# used when creating hook from json config
# using this elsewhere may have unintended consequences
self.collection_manager = coll_manager

def add_to_collection(self, collection_name, variable):
self.collection_manager.get(collection_name).add(variable)

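The net effect in the base hook is that worker discovery and writer lookup become internal contracts: subclasses implement `_get_worker_name` and `_get_num_workers`, and the write path resolves writers through `_get_writers`. A minimal, self-contained sketch of implementing the renamed abstract methods; the `BaseHook` stand-in and the single-process fallback values are assumptions for illustration, not smdebug's real class:

```python
from abc import ABC, abstractmethod


class BaseHook(ABC):
    """Stand-in for the smdebug base hook; only the renamed abstract methods are modeled."""

    @abstractmethod
    def _get_worker_name(self) -> str:
        ...

    @abstractmethod
    def _get_num_workers(self) -> int:
        ...


class SingleProcessHook(BaseHook):
    """Hypothetical framework hook for a non-distributed job."""

    def _get_worker_name(self) -> str:
        return "worker_0"  # assumed default worker name

    def _get_num_workers(self) -> int:
        return 1  # one process, one worker


hook = SingleProcessHook()
print(hook._get_worker_name(), hook._get_num_workers())
```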
smdebug/mxnet/hook.py (12 additions, 8 deletions)
@@ -61,10 +61,10 @@ def __init__(
self.exported_model = False
# Keep the set of blocks to which this hook is registered. The blocks include loss blocks as well.
self.registered_blocks = set()
self.worker = self.get_worker_name()
self.worker = self._get_worker_name()
set_hook(self)

def get_worker_name(self):
def _get_worker_name(self):
try:
import horovod.mxnet as hvd

@@ -74,7 +74,7 @@ def get_worker_name(self):
pass
return CONFIG_DEFAULT_WORKER_NAME

def get_num_workers(self):
def _get_num_workers(self):
try:
import horovod.mxnet as hvd

@@ -91,17 +91,17 @@ def hook_from_config(cls, json_config_path=None):
def _cleanup(self):
# Write the gradients of the past step if the writer is still available.
if self.writer is not None and self.last_block is not None:
self.log_params(self.last_block)
self._log_params(self.last_block)
if self.exported_model is False:
self._export_model()
super()._cleanup()

def log_params(self, block):
def _log_params(self, block):
params = block.collect_params().values()
for param in params:
self.log_param(param)
self._log_param(param)

def log_param(self, param):
def _log_param(self, param):
self._save_for_tensor(tensor_name=param.name, tensor_value=param.data(param.list_ctx()[0]))
# If Gradient for this param is available
if param.grad_req != "null":
@@ -127,7 +127,7 @@ def forward_pre_hook(self, block, inputs):
if self.writer is not None:
# Write the params and gradients of the
# past step if the writer is still available.
self.log_params(block)
self._log_params(block)
self._close_writers()
self._close_tb_writer()

@@ -206,6 +206,10 @@ def _is_recursive_needed(self):
return len(extra_coll) != 0

def register_hook(self, block):
# for compatibility with ZCC patches which call this
self.register_block(block)

def register_block(self, block):
"""
This function registers the forward hook. If user wants to register the hook
for every child in the given block, then the function calls "apply" API for
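The MXNet hook keeps `register_hook` as a thin alias so existing callers (such as the ZCC patches mentioned in the new comment) continue to work while new code moves to `register_block`. A simplified sketch of that alias pattern, independent of MXNet itself; the class name and the print output are illustrative only:

```python
class MXNetHookSketch:
    """Not the real smdebug class; shows only the old-name-to-new-name delegation."""

    def register_block(self, block):
        # New public entry point: attach forward hooks to `block` and its children.
        print(f"registering forward hooks on {block!r}")

    def register_hook(self, block):
        # Old name kept for backward compatibility; simply forwards to the new name.
        self.register_block(block)


hook = MXNetHookSketch()
hook.register_hook("net")   # legacy spelling still works
hook.register_block("net")  # preferred name after this change
```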
smdebug/pytorch/hook.py (11 additions, 7 deletions)
@@ -55,10 +55,10 @@ def __init__(

self.has_registered_module = False
self.has_registered_loss_module = False
self.worker = self.get_worker_name()
self.worker = self._get_worker_name()
set_hook(self)

def get_num_workers(self):
def _get_num_workers(self):
"""Check horovod and torch.distributed."""
# Try torch.distributed
# torch.distributed is empty on Mac on Torch <= 1.2
@@ -76,7 +76,7 @@ def get_num_workers(self):
# Return default
return 1

def get_worker_name(self):
def _get_worker_name(self):
"""Check horovod and torch.distributed."""
# Try torch.distributed
# torch.distributed is empty on Mac on Torch <= 1.2
@@ -108,7 +108,7 @@ def hook_from_config(cls, json_config_path=None):
"""
return create_hook_from_json_config(cls, json_config_path=json_config_path)

def log_params(self, module):
def _log_params(self, module):
module_name = module._get_name()
params = module.named_parameters()
for name, param in params:
@@ -149,7 +149,7 @@ def forward_pre_hook(self, module, inputs):

if self._get_collections_to_save_for_step():
self._initialize_writers()
self.log_params(module)
self._log_params(module)

if self.last_saved_step is not None and not self.exported_collections:
self.export_collections()
@@ -200,11 +200,15 @@ def _backward_apply(self, module):
pname = module._get_name() + "_" + name
param.register_hook(self.backward_hook(pname))

def closure_for_registering_forward_hook(self, module):
def _closure_for_registering_forward_hook(self, module):
"""Lambda functions don't work here."""
module.register_forward_hook(self.forward_hook)

def register_hook(self, module):
# for compatibility with ZCC patches which call this
self.register_module(module)

def register_module(self, module):
"""
This function registers the forward hook. If user wants to register the hook
for every child in the given block, then the function calls "apply" API for
@@ -228,7 +232,7 @@ def register_hook(self, module):

# Set `self.forward_hook` as a callback for each submodule/layer.
# `module.apply(fn)` calls fn for each submodule in module.children()
module.apply(self.closure_for_registering_forward_hook)
module.apply(self._closure_for_registering_forward_hook)

# Capture the gradient for each parameter in the net
self._backward_apply(module)
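The PyTorch hook mirrors the MXNet change (`register_module` is the new name, `register_hook` forwards to it) and keeps the bound-method closure for attaching forward hooks, since the original docstring notes that a lambda does not work there. A runnable sketch of that registration flow using a plain class in place of the real hook; the printing callback is an assumption, smdebug's actual `forward_hook` saves tensors instead:

```python
import torch
import torch.nn as nn


class ForwardHookRegistrar:
    """Simplified stand-in for smdebug's PyTorch hook registration logic."""

    def forward_hook(self, module, inputs, outputs):
        # The real hook records inputs/outputs here; we just log the module name.
        print(f"forward pass through {module._get_name()}")

    def _closure_for_registering_forward_hook(self, module):
        # Bound method instead of a lambda, as in the original code.
        module.register_forward_hook(self.forward_hook)

    def register_module(self, module):
        # `module.apply(fn)` calls fn on the module and every submodule.
        module.apply(self._closure_for_registering_forward_hook)

    def register_hook(self, module):
        # Backward-compatible alias for callers still using the old name.
        self.register_module(module)


registrar = ForwardHookRegistrar()
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
registrar.register_module(net)
net(torch.randn(1, 4))  # triggers the registered forward hooks
```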
smdebug/tensorflow/base_hook.py (7 additions, 6 deletions)
@@ -98,7 +98,7 @@ def __init__(
def hook_from_config(cls, json_config_path=None):
return create_hook_from_json_config(cls, json_config_path=json_config_path)

def get_distribution_strategy(self) -> TFDistributionStrategy:
def _get_distribution_strategy(self) -> TFDistributionStrategy:
try:
import horovod.tensorflow as hvd

@@ -120,7 +120,7 @@ def get_distribution_strategy(self) -> TFDistributionStrategy:

return TFDistributionStrategy.UNSUPPORTED

def get_worker_name(self) -> str:
def _get_worker_name(self) -> str:
"""
This function returns the name of the worker based on
the distribution strategy.
@@ -146,7 +146,7 @@ def get_worker_name(self) -> str:
return CONFIG_DEFAULT_WORKER_NAME

def export_collections(self):
num_workers = self.get_num_workers()
num_workers = self._get_num_workers()
if self.save_all_workers is False:
num_workers = 1
if (
@@ -168,7 +168,7 @@ def export_collections(self):
collection_file_name = f"{self.worker}_collections.json"
self.collection_manager.export(self.out_dir, collection_file_name)

def get_num_workers(self):
def _get_num_workers(self):
try:
import horovod.tensorflow as hvd

@@ -194,7 +194,7 @@ def _add_to_device_map(self, tensor):
if tensor.device and "CPU" not in tensor.device and tensor.device not in self.device_map:
self.device_map[tensor.device] = serialize_tf_device(tensor.device)

def get_writers(self, tensor_name, tensor_ref) -> List[FileWriter]:
def _get_writers(self, tensor_name, tensor_ref) -> List[FileWriter]:
"""
For tensors generated during distributed tf jobs, we map the tensor to a writer
with its device attribute.
@@ -379,7 +379,8 @@ def save_scalar(self, name, value, searchable=False):
save_scalar() not supported on Tensorflow
"""
self.logger.warning(
"save_scalar not supported on Tensorflow. Add the scalar to searchable_scalars collection instead."
"save_scalar not supported on Tensorflow. "
"Add the scalar to scalars or searchable_scalars collection instead. "
)
return

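`_get_num_workers` and `_get_distribution_strategy` in the TensorFlow base hook share the same probe-and-fall-back shape: try Horovod first and drop back to a default when it is absent or uninitialized. A hedged sketch of that probe; the exact exception set and fallback value are not fully visible in this diff, so treat them as assumptions:

```python
def get_num_workers_sketch() -> int:
    """Return the Horovod world size if Horovod is importable and initialized, else 1."""
    try:
        import horovod.tensorflow as hvd  # only present on Horovod-enabled jobs

        return hvd.size()  # raises ValueError if hvd.init() was never called
    except (ImportError, ValueError):
        return 1  # assumed single-worker fallback


print(get_num_workers_sketch())
```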
smdebug/tensorflow/keras.py (4 additions, 4 deletions)
@@ -78,7 +78,7 @@ def _is_not_supported(self):
):
self.logger.info("Disabling SMDebug as it does not support eager mode")
self._hook_supported = False
elif self.get_distribution_strategy() == TFDistributionStrategy.MIRRORED_STRATEGY:
elif self._get_distribution_strategy() == TFDistributionStrategy.MIRRORED_STRATEGY:
try:
from tensorflow.python.keras.distribute.distributed_training_utils import (
get_distributed_model,
@@ -90,7 +90,7 @@ def _is_not_supported(self):
"with TensorFlow version <1.14"
)
self._hook_supported = False
elif self.get_distribution_strategy() == TFDistributionStrategy.UNSUPPORTED:
elif self._get_distribution_strategy() == TFDistributionStrategy.UNSUPPORTED:
self.logger.info(
f"Disabling SMDebug as it does not support " f"{tf.distribute.get_strategy()}"
)
@@ -430,8 +430,8 @@ def on_epoch_end(self, batch, logs=None):
def _on_any_mode_begin(self, mode):
if self._is_not_supported():
return
self.distribution_strategy = self.get_distribution_strategy()
self.worker = self.get_worker_name()
self.distribution_strategy = self._get_distribution_strategy()
self.worker = self._get_worker_name()
self.graph = tf.get_default_graph()
self.set_mode(mode)

smdebug/tensorflow/session.py (4 additions, 4 deletions)
@@ -196,7 +196,7 @@ def _add_weights_and_biases(self):
def _is_not_supported(self):
if self._hook_supported is None:
self._hook_supported = True
if self.get_distribution_strategy() == TFDistributionStrategy.MIRRORED_STRATEGY:
if self._get_distribution_strategy() == TFDistributionStrategy.MIRRORED_STRATEGY:
from packaging import version

if version.parse(tf.__version__) < version.parse("1.14.0"):
@@ -207,7 +207,7 @@ def _is_not_supported(self):
"Disabling SMDebug as it does not support mirrored strategy"
"with TensorFlow version <1.14"
)
elif self.get_distribution_strategy() == TFDistributionStrategy.UNSUPPORTED:
elif self._get_distribution_strategy() == TFDistributionStrategy.UNSUPPORTED:
self.logger.info(
f"Disabling SMDebug as it does not support " f"{tf.distribute.get_strategy()}"
)
@@ -231,8 +231,8 @@ def begin(self):
# todo: use global step from TF instead of tornasole steps

# todo: handle multiple graphs in the model
self.worker = self.get_worker_name()
self.distribution_strategy = self.get_distribution_strategy()
self.worker = self._get_worker_name()
self.distribution_strategy = self._get_distribution_strategy()
self.graph = tf.get_default_graph()

self._add_weights_and_biases()
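Both the Keras and session hooks gate themselves on TensorFlow version when a mirrored strategy is detected, since the distributed-model helpers they rely on only exist from TF 1.14. The check is a plain `packaging.version` comparison, shown here as a standalone snippet:

```python
import tensorflow as tf
from packaging import version

# Mirrors the gate in _is_not_supported(): mirrored strategy requires TF >= 1.14,
# otherwise the hook disables itself instead of failing later.
supports_mirrored = version.parse(tf.__version__) >= version.parse("1.14.0")
print(f"TF {tf.__version__}: mirrored-strategy support = {supports_mirrored}")
```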
smdebug/trials/local_trial.py (4 additions, 8 deletions)
@@ -34,22 +34,18 @@ def __init__(
self.index_reader = LocalIndexReader(self.path)
self.logger.info(f"Loading trial {name} at path {self.trial_dir}")
self._load_collections()
self.load_tensors()
self._load_tensors()

def get_collection_files(self) -> list:
def _get_collection_files(self) -> list:
return list_files_in_directory(get_path_to_collections(self.path))

def _load_tensors_from_index_tensors(self, index_tensors_dict):
for tname in index_tensors_dict:
for step, itds in index_tensors_dict[tname].items():
for worker in itds:
self.add_tensor(int(step), worker, itds[worker]["tensor_location"])
self._add_tensor(int(step), worker, itds[worker]["tensor_location"])

def read_collections(self, collection_files):
def _read_collections(self, collection_files):
first_collection_file = collection_files[0] # First Collection File
self.collection_manager = CollectionManager.load(first_collection_file)
self.num_workers = self.collection_manager.get_num_workers()

def get_tensors(self, tname_steps_dict, should_regex_match=False):
# now we do not need to do anything since we read the full event file
pass
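With `load_tensors`, `get_collection_files`, and `read_collections` now underscore-prefixed, the intended way to read a local run is through the trial factory and its public accessors rather than these loaders. A hedged usage sketch: `create_trial` is smdebug's public entry point, but the exact accessor names (`tensor_names` vs. the older `tensors`) may differ by version:

```python
from smdebug.trials import create_trial  # public factory; returns a LocalTrial for a local path

trial = create_trial("/tmp/smdebug-demo")  # out_dir previously written by a hook
print(trial.steps())         # public accessors instead of the now-private
print(trial.tensor_names())  # _load_tensors / _read_collections helpers
```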
smdebug/trials/s3_trial.py (4 additions, 9 deletions)
@@ -46,9 +46,9 @@ def __init__(
self.index_reader = S3IndexReader(self.path)
self.s3_handler = S3Handler()
self._load_collections()
self.load_tensors()
self._load_tensors()

def get_collection_files(self) -> list:
def _get_collection_files(self) -> list:
collection_files, _ = list_s3_objects(
self.bucket_name,
get_path_to_collections(self.prefix_name),
@@ -61,9 +61,9 @@ def _load_tensors_from_index_tensors(self, index_tensors_dict):
for tname in index_tensors_dict:
for step, itds in index_tensors_dict[tname].items():
for worker in itds:
self.add_tensor(int(step), worker, itds[worker]["tensor_location"])
self._add_tensor(int(step), worker, itds[worker]["tensor_location"])

def read_collections(self, collection_files):
def _read_collections(self, collection_files):
first_collection_file = collection_files[0] # First Collection File
key = os.path.join(first_collection_file)
collections_req = ReadObjectRequest(self._get_s3_location(key))
@@ -72,10 +72,5 @@ def read_collections(self, collection_files):
self.collection_manager = CollectionManager.load_from_string(obj_data)
self.num_workers = self.collection_manager.get_num_workers()

def get_tensors(self, tname_steps_dict, should_regex_match=False):
# to be used when getting selective tensors from S3
# now we do not need to do anything since we read the full event file from S3
pass

def _get_s3_location(self, obj):
return "s3://" + self.bucket_name + "/" + obj