112 changes: 81 additions & 31 deletions smdebug/tensorflow/keras.py
@@ -82,6 +82,7 @@ def __init__(
# this flag indicates to the train_batch_begin callback
# that the step was already incremented in the on_train_begin callback
self.step_incremented_in_on_train_begin = False
self.has_wrapped_model_with_input_output_saver = False

def _is_not_supported(self):
if self.distribution_strategy is None:
@@ -548,36 +549,56 @@ def _save_metrics(self, batch, logs, force_save=False):
def _save_layer_input_and_outputs(self):
if is_tf_version_2x() is False:
return
layer_collection = (
{self.get_collection(CollectionKeys.LAYERS)}
if self._is_collection_being_saved_for_step(CollectionKeys.LAYERS)
else set()
)
for layer_name in self.saved_layers:
# Save Input
tensor = self.saved_layers[layer_name].layer_input
export_name = get_export_name_for_keras(layer_name, tensor_type="input", tensor=tensor)
input_collection = (
{self.get_collection(CollectionKeys.LAYERS)}
if self._is_collection_being_saved_for_step(CollectionKeys.LAYERS)
else set()
)
t = tensor[0] if isinstance(tensor, list) and len(tensor) else tensor
if hasattr(t, "numpy") is False:
self.logger.warning("cannot save layer values during forward pass with tf.function")
continue
else:
self._save_tensor_to_file(export_name, tensor, input_collection)

layer_inputs = self.saved_layers[layer_name].layer_input
for layer_idx, tensor in enumerate(layer_inputs):
if isinstance(tensor, list):
tensor_list = tensor
else:
tensor_list = [tensor]
if hasattr(tensor_list[0], "numpy") is False:
self.logger.warning(
"cannot save layer values during forward pass with tf.function"
)
continue
else:
for t_idx, t in enumerate(tensor_list):
export_name = get_export_name_for_keras(
layer_name,
tensor_type="input",
tensor=tensor,
layer_idx=layer_idx,
tensor_idx=t_idx,
)
self._save_tensor_to_file(export_name, t, layer_collection)
Reviewer comment (Contributor) on lines +559 to +579:
This looks similar to what's done for the outputs below. Is it possible to make it common?
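
One possible deduplication, sketched only and not part of this PR; the helper name _save_layer_tensors is hypothetical, but the body is lifted from the two loops in this hunk:

def _save_layer_tensors(self, layer_name, tensors, tensor_type, layer_collection):
    # Shared loop for both the captured inputs and the captured outputs
    for layer_idx, tensor in enumerate(tensors):
        tensor_list = tensor if isinstance(tensor, list) else [tensor]
        if hasattr(tensor_list[0], "numpy") is False:
            self.logger.warning(
                "cannot save layer values during forward pass with tf.function"
            )
            continue
        for t_idx, t in enumerate(tensor_list):
            export_name = get_export_name_for_keras(
                layer_name,
                tensor_type=tensor_type,
                tensor=tensor,
                layer_idx=layer_idx,
                tensor_idx=t_idx,
            )
            self._save_tensor_to_file(export_name, t, layer_collection)

The two call sites would then collapse to:

saver = self.saved_layers[layer_name]
self._save_layer_tensors(layer_name, saver.layer_input, "input", layer_collection)
self._save_layer_tensors(layer_name, saver.layer_output, "output", layer_collection)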

# Save Output
tensor = self.saved_layers[layer_name].layer_output
export_name = get_export_name_for_keras(layer_name, tensor_type="output", tensor=tensor)
self._is_collection_being_saved_for_step(CollectionKeys.LAYERS)
output_collection = (
{self.get_collection(CollectionKeys.LAYERS)}
if self._is_collection_being_saved_for_step(CollectionKeys.LAYERS)
else set()
)
t = tensor[0] if isinstance(tensor, list) and len(tensor) else tensor
if hasattr(t, "numpy") is False:
self.logger.warning("cannot save layer values during forward pass with tf.function")
else:
self._save_tensor_to_file(export_name, tensor, output_collection)
layer_outputs = self.saved_layers[layer_name].layer_output
for layer_idx, tensor in enumerate(layer_outputs):
if isinstance(tensor, list):
tensor_list = tensor
else:
tensor_list = [tensor]
if hasattr(tensor_list[0], "numpy") is False:
self.logger.warning(
"cannot save layer values during forward pass with tf.function"
)
continue
else:
for t_idx, t in enumerate(tensor_list):
export_name = get_export_name_for_keras(
layer_name,
tensor_type="output",
tensor=tensor,
layer_idx=layer_idx,
tensor_idx=t_idx,
)
self._save_tensor_to_file(export_name, t, layer_collection)

def _save_tensors_post_step(self, batch, logs):
# some tensors available as value from within hook are saved here
@@ -707,15 +728,31 @@ def on_predict_begin(self, logs=None):
self._on_any_mode_begin(ModeKeys.PREDICT)

def _wrap_model_with_input_output_saver(self):
if self.has_registered_model:
if (
self.has_wrapped_model_with_input_output_saver
or self.model is None
or self.has_default_hook_configuration()
):
# do not proceed if the model has already been wrapped
# or the model has not been registered with smdebug yet
return
for layer in self.model.layers:
for layer in self.model._flatten_layers(include_self=False, recursive=True):
layer._hooks = []
layer._old_call = layer.call
layer.call = get_layer_call_fn(layer)
layer.register_hook = lambda hook: layer._hooks.append(hook)
saver = InputOutputSaver()
layer.register_hook(saver)
self.saved_layers[layer.name] = saver
self.has_wrapped_model_with_input_output_saver = True

def _unwrap_model_with_input_output_saver(self):
if self.has_wrapped_model_with_input_output_saver is False:
return
for layer in self.model._flatten_layers(include_self=False, recursive=True):
layer._hooks = []
layer.call = layer._old_call
self.has_wrapped_model_with_input_output_saver = False
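
For context, both methods rely on plain attribute patching of each layer's call. A minimal standalone sketch of the pattern (illustrative only; the real wrapper comes from get_layer_call_fn and records values via InputOutputSaver):

import tensorflow as tf

layer = tf.keras.layers.Dense(4)
layer._old_call = layer.call

def wrapped_call(*args, **kwargs):
    out = layer._old_call(*args, **kwargs)
    # a hook such as InputOutputSaver would record inputs and outputs here
    return out

layer.call = wrapped_call       # what _wrap_model_with_input_output_saver does per layer
# ... forward passes now run through wrapped_call ...
layer.call = layer._old_call    # what _unwrap_model_with_input_output_saver restores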

def _on_any_batch_begin(self, batch, mode, logs=None):
if self._is_not_supported():
@@ -780,6 +817,7 @@ def _save_layer_values(self, logs):
step_collections = self._get_collections_to_save_for_step()
layer_collection = self.get_collection(CollectionKeys.LAYERS)
collections_to_write = {layer_collection} if layer_collection in step_collections else set()
layer_name_dict = dict()
for layer_name, layer_input, layer_output in logs:
# Cast layer_name to str since it can also be of type bytes
# when run with mirrored strategy
@@ -794,9 +832,16 @@
# Layer Inputs are flattened and passed as a list into
# the next layer. Unpacking it speeds up the _make_numpy fn.
layer_input = layer_input[0]
layer_input_tensor_name = get_export_name_for_keras(str(layer_name), "input")
layer_name = str(layer_name)
idx = layer_name_dict.get(layer_name, 0)
Reviewer comment (Contributor):
Will this always be 0, since this is the first time the dict has been accessed after L820?
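
A standalone sketch of the counting behavior: idx is 0 only for the first occurrence of a given name within one _save_layer_values call, and counts up if logs lists the same layer name more than once (as with a reused layer); hypothetical data:

layer_name_dict = dict()
for layer_name in ["dense", "dense", "flatten"]:  # "dense" appears twice in logs
    idx = layer_name_dict.get(layer_name, 0)
    layer_name_dict[layer_name] = idx + 1
    print(layer_name, idx)  # prints: dense 0, dense 1, flatten 0

The dict is rebuilt on every call, so the indices reset at each saved step.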

layer_name_dict[layer_name] = idx + 1
layer_input_tensor_name = get_export_name_for_keras(
layer_name, "input", layer_idx=idx, tensor_idx=idx
)
self._save_tensor_to_file(layer_input_tensor_name, layer_input, collections_to_write)
layer_output_tensor_name = get_export_name_for_keras(str(layer_name), "output")
layer_output_tensor_name = get_export_name_for_keras(
layer_name, "output", layer_idx=idx, tensor_idx=idx
)
self._save_tensor_to_file(layer_output_tensor_name, layer_output, collections_to_write)

def _write_optimizer_variables(self):
@@ -849,6 +894,9 @@ def _on_any_batch_end(self, batch, mode, logs=None):
self._export_model()
self._exported_model[self.mode] = True

if is_tf_version_2x():
Reviewer comment (Contributor):
Any restriction for eager/non-eager/tf.function?
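
A standalone sketch of the eager vs. tf.function distinction that the hasattr(t, "numpy") guard in this PR relies on: only eager tensors carry concrete values that can be read back and saved.

import tensorflow as tf

print(hasattr(tf.constant([1.0]), "numpy"))  # True: eager tensors expose .numpy()

@tf.function
def forward(x):
    # during tracing, x is symbolic, so there is no value to save
    print("numpy available during tracing:", hasattr(x, "numpy"))  # False
    return x * 2

forward(tf.constant([1.0]))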

self._unwrap_model_with_input_output_saver()

def on_train_batch_end(self, batch, logs=None):
self._on_any_batch_end(batch, ModeKeys.TRAIN, logs=logs)

@@ -978,6 +1026,7 @@ def run(*args, **kwargs):
# this means sometimes collections will be exported after 1 step
self.export_collections()
self._exported_collections = True
self._wrap_model_with_input_output_saver()

return run

@@ -1054,6 +1103,7 @@ def run(*args, **kwargs):
return

self.last_saved_step = self.step
self._unwrap_model_with_input_output_saver()

return run

12 changes: 6 additions & 6 deletions smdebug/tensorflow/utils.py
@@ -293,11 +293,11 @@ def is_keras_optimizer(obj):
return False


def get_export_name_for_keras(layer, tensor_type, tensor=None):
def get_export_name_for_keras(layer, tensor_type, tensor=None, layer_idx=None, tensor_idx=None):
if tensor_type in ["input", "output", "weight"]:
if isinstance(layer, str):
# Tensor.name is meaningless when eager execution is enabled.
return f"{layer}/{tensor_type}s"
return f"{layer}_{layer_idx}/{tensor_type}_{tensor_idx}"
Reviewer comment (Contributor):
Does this change how layer/tensor names have looked so far?
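
For reference, the rename implied by the two format strings in this hunk (illustrative values, assuming a string layer name):

layer, tensor_type, layer_idx, tensor_idx = "dense", "input", 0, 0
old_name = f"{layer}/{tensor_type}s"                          # "dense/inputs"
new_name = f"{layer}_{layer_idx}/{tensor_type}_{tensor_idx}"  # "dense_0/input_0"

This is also why the regexes in test_keras.py below change from (inputs|outputs) to (input|output)_\d+.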

else:
return f"{layer.name}/{tensor_type}s/{tensor.name}"
else:
@@ -341,12 +341,12 @@ def register_hook(self, hook: Callable[[tf.Tensor, tf.Tensor], Optional[tf.Tenso

class InputOutputSaver:
def __init__(self):
self.layer_input = None
self.layer_output = None
self.layer_input = []
self.layer_output = []

def __call__(self, inputs, *args, **kwargs) -> None:
self.layer_input = kwargs["layer_input"]
self.layer_output = kwargs["layer_output"]
self.layer_input.append(kwargs["layer_input"])
self.layer_output.append(kwargs["layer_output"])
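
A standalone sketch of the new accumulation behavior with placeholder values: each forward call of a (possibly reused) layer now appends one entry instead of overwriting the last one.

saver = InputOutputSaver()
saver(None, layer_input=["x0"], layer_output=["y0"])  # first invocation of the layer
saver(None, layer_input=["x1"], layer_output=["y1"])  # layer reused later in the model
assert saver.layer_input == [["x0"], ["x1"]]
assert saver.layer_output == [["y0"], ["y1"]]
# _save_layer_input_and_outputs enumerates these entries as layer_idx 0, 1, ...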


def get_layer_call_fn(layer: tf.keras.layers.Layer) -> Callable[[tf.Tensor], tf.Tensor]:
12 changes: 8 additions & 4 deletions tests/tensorflow2/test_concat_layer.py
@@ -32,7 +32,11 @@ def test_multiple_inputs(out_dir):
my_model.fit(x_train, y_train, epochs=1, steps_per_epoch=1, callbacks=[hook])

trial = create_trial(path=out_dir)
tnames = sorted(trial.tensor_names(collection=smd.CollectionKeys.LAYERS))
assert "concatenate" in tnames[0]
assert len(trial.tensor(tnames[0]).value(0)) == 2
assert trial.tensor(tnames[0]).shape(0) == (2, 1000, 20)
tnames = trial.tensor_names(regex="concatenate")
assert len(tnames) == 3 # two inputs + one output
tnames = trial.tensor_names(regex="concatenate.+/input")
assert len(tnames) == 2 # Concatenate Layer receives two inputs
assert trial.tensor(tnames[0]).shape(0) == (1000, 20)
tnames = trial.tensor_names(regex="concatenate.+/output")
assert len(tnames) == 1 # Concatenate Layer emits a single output
assert trial.tensor(tnames[0]).shape(0) == (1000, 40)
4 changes: 2 additions & 2 deletions tests/tensorflow2/test_keras.py
@@ -194,7 +194,7 @@ def test_layer_names_gradient_tape(out_dir):

tr = create_trial_fast_refresh(out_dir)
tnames = tr.tensor_names(collection=CollectionKeys.LAYERS)
pattern = r"^(flatten|dense|dropout)(_\d+)?\/(inputs|outputs)"
pattern = r"^(flatten|dense|dropout)(_\d+)+?\/(input|output)_\d+"
for tname in tnames:
assert re.match(pattern=pattern, string=tname) is not None

@@ -580,7 +580,7 @@ def test_layer_names(out_dir, tf_eager_mode):

tr = create_trial_fast_refresh(out_dir)
tnames = tr.tensor_names(collection=CollectionKeys.LAYERS)
pattern = r"^(flatten|dense|dropout)(_\d+)?\/(inputs|outputs)"
pattern = r"^(flatten|dense|dropout)(_\d+)+?\/(input|output)_\d+"
for tname in tnames:
assert re.match(pattern=pattern, string=tname) is not None

52 changes: 52 additions & 0 deletions tests/tensorflow2/test_model_that_reuses_layers.py
@@ -0,0 +1,52 @@
# Third Party
import numpy as np
from tensorflow.keras.layers import Dense
from tensorflow.python.keras.models import Model

# First Party
import smdebug.tensorflow as smd
from smdebug.trials import create_trial


class CustomModel(Model):
def __init__(self):
super(CustomModel, self).__init__()
self.dense = Dense(10, activation="relu")

def call(self, x):
x = self.dense(x)
x = self.dense(x)
return self.dense(x)


def test_layer_reusability(out_dir):
model = CustomModel()
hook = smd.KerasHook(
out_dir,
save_all=True,
save_config=smd.SaveConfig(save_steps=[0], save_interval=1),
reduction_config=smd.ReductionConfig(save_shape=True, save_raw_tensor=True),
)

hook.register_model(model)
x_train = np.random.random((1000, 10))
y_train = np.random.random((1000, 1))
model.compile(optimizer="Adam", loss="mse", run_eagerly=True)
model.fit(x_train, y_train, epochs=1, steps_per_epoch=1, callbacks=[hook])

trial = create_trial(path=out_dir, name="training_run")
tensor_names = trial.tensor_names(collection=smd.CollectionKeys.LAYERS)
"""
[
'dense_0/input_0',
'dense_0/output_0',
'dense_1/input_0',
'dense_1/output_0',
'dense_2/input_0',
'dense_2/output_0'
]
"""
assert len(tensor_names) == 6
for name in tensor_names:
shape = trial.tensor(name).shape(step_num=0)
assert shape == (1000, 10)
58 changes: 58 additions & 0 deletions tests/tensorflow2/test_nested_layers.py
@@ -0,0 +1,58 @@
# Third Party
import numpy as np
from tensorflow.keras.layers import Concatenate, Dense, Layer
from tensorflow.python.keras.models import Model

# First Party
import smdebug.tensorflow as smd
from smdebug.trials import create_trial


class CustomLayer(Layer):
def __init__(self):
super(CustomLayer, self).__init__()
self.con = Concatenate()
self.dense = Dense(10, activation="relu")

def call(self, x):
x = self.con([x, x])
return self.dense(x)


class CustomModel(Model):
def __init__(self):
super(CustomModel, self).__init__()
self.custom_layer = CustomLayer()

def call(self, x):
return self.custom_layer(x)


def test_if_nested_layers_are_recorded(out_dir):
model = CustomModel()
hook = smd.KerasHook(
out_dir,
save_all=True,
save_config=smd.SaveConfig(save_steps=[0], save_interval=1),
reduction_config=smd.ReductionConfig(save_shape=True, save_raw_tensor=True),
)

hook.register_model(model)
x_train = np.random.random((1000, 20))
y_train = np.random.random((1000, 1))
model.compile(optimizer="Adam", loss="mse", run_eagerly=True)
model.fit(x_train, y_train, epochs=1, steps_per_epoch=1, callbacks=[hook])
trial = create_trial(path=out_dir)
layer_names = trial.tensor_names(collection=smd.CollectionKeys.LAYERS)
"""
[
'concatenate_0/input_0',
'concatenate_0/input_1',
'concatenate_0/output_0',
'custom_layer_0/input_0',
'custom_layer_0/output_0',
'dense_0/input_0',
'dense_0/output_0'
]
"""
assert len(layer_names) == 7