Save Nested Layers For Rubik #377
base: master
Changes from all commits
@@ -82,6 +82,7 @@ def __init__(
        # this flag indicates to the train_batch_begin callback
        # that the step was already incremented in the on_train_begin callback
        self.step_incremented_in_on_train_begin = False
+       self.has_wrapped_model_with_input_output_saver = False

    def _is_not_supported(self):
        if self.distribution_strategy is None:
@@ -548,36 +549,56 @@ def _save_metrics(self, batch, logs, force_save=False):
    def _save_layer_input_and_outputs(self):
        if is_tf_version_2x() is False:
            return
+       layer_collection = (
+           {self.get_collection(CollectionKeys.LAYERS)}
+           if self._is_collection_being_saved_for_step(CollectionKeys.LAYERS)
+           else set()
+       )
        for layer_name in self.saved_layers:
            # Save Input
-           tensor = self.saved_layers[layer_name].layer_input
-           export_name = get_export_name_for_keras(layer_name, tensor_type="input", tensor=tensor)
-           input_collection = (
-               {self.get_collection(CollectionKeys.LAYERS)}
-               if self._is_collection_being_saved_for_step(CollectionKeys.LAYERS)
-               else set()
-           )
-           t = tensor[0] if isinstance(tensor, list) and len(tensor) else tensor
-           if hasattr(t, "numpy") is False:
-               self.logger.warning("cannot save layer values during forward pass with tf.function")
-               continue
-           else:
-               self._save_tensor_to_file(export_name, tensor, input_collection)
+           layer_inputs = self.saved_layers[layer_name].layer_input
+           for layer_idx, tensor in enumerate(layer_inputs):
+               if isinstance(tensor, list):
+                   tensor_list = tensor
+               else:
+                   tensor_list = [tensor]
+               if hasattr(tensor_list[0], "numpy") is False:
+                   self.logger.warning(
+                       "cannot save layer values during forward pass with tf.function"
+                   )
+                   continue
+               else:
+                   for t_idx, t in enumerate(tensor_list):
+                       export_name = get_export_name_for_keras(
+                           layer_name,
+                           tensor_type="input",
+                           tensor=tensor,
+                           layer_idx=layer_idx,
+                           tensor_idx=t_idx,
+                       )
+                       self._save_tensor_to_file(export_name, t, layer_collection)
            # Save Output
-           tensor = self.saved_layers[layer_name].layer_output
-           export_name = get_export_name_for_keras(layer_name, tensor_type="output", tensor=tensor)
-           self._is_collection_being_saved_for_step(CollectionKeys.LAYERS)
-           output_collection = (
-               {self.get_collection(CollectionKeys.LAYERS)}
-               if self._is_collection_being_saved_for_step(CollectionKeys.LAYERS)
-               else set()
-           )
-           t = tensor[0] if isinstance(tensor, list) and len(tensor) else tensor
-           if hasattr(t, "numpy") is False:
-               self.logger.warning("cannot save layer values during forward pass with tf.function")
-           else:
-               self._save_tensor_to_file(export_name, tensor, output_collection)
+           layer_outputs = self.saved_layers[layer_name].layer_output
+           for layer_idx, tensor in enumerate(layer_outputs):
+               if isinstance(tensor, list):
+                   tensor_list = tensor
+               else:
+                   tensor_list = [tensor]
+               if hasattr(tensor_list[0], "numpy") is False:
+                   self.logger.warning(
+                       "cannot save layer values during forward pass with tf.function"
+                   )
+                   continue
+               else:
+                   for t_idx, t in enumerate(tensor_list):
+                       export_name = get_export_name_for_keras(
+                           layer_name,
+                           tensor_type="output",
+                           tensor=tensor,
+                           layer_idx=layer_idx,
+                           tensor_idx=t_idx,
+                       )
+                       self._save_tensor_to_file(export_name, t, layer_collection)

    def _save_tensors_post_step(self, batch, logs):
        # some tensors available as value from within hook are saved here
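For readers tracking the new naming scheme: each saver now records one entry per invocation of its layer, `layer_idx` indexes the invocation, and `tensor_idx` indexes tensors within that invocation. A minimal sketch of the names this produces; the `export_name` helper below is a hypothetical stand-in that mirrors the new f-string in get_export_name_for_keras, not the real function:

# Hypothetical stand-in mirroring the new export-name format
# "{layer}_{layer_idx}/{tensor_type}_{tensor_idx}".
def export_name(layer, tensor_type, layer_idx, tensor_idx):
    return f"{layer}_{layer_idx}/{tensor_type}_{tensor_idx}"

# A layer invoked twice in one step: the first call saw one input
# tensor, the second call saw two (e.g. a Concatenate).
recorded_inputs = [["t0"], ["t1", "t2"]]

names = [
    export_name("concatenate", "input", layer_idx, tensor_idx)
    for layer_idx, tensors in enumerate(recorded_inputs)
    for tensor_idx, _ in enumerate(tensors)
]
print(names)
# ['concatenate_0/input_0', 'concatenate_1/input_0', 'concatenate_1/input_1']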
@@ -707,15 +728,31 @@ def on_predict_begin(self, logs=None):
        self._on_any_mode_begin(ModeKeys.PREDICT)

    def _wrap_model_with_input_output_saver(self):
-       if self.has_registered_model:
+       if (
+           self.has_wrapped_model_with_input_output_saver
+           or self.model is None
+           or self.has_default_hook_configuration()
+       ):
+           # do not proceed if the model has already been wrapped
+           # or the model has not been registered with smdebug yet
+           return
-       for layer in self.model.layers:
+       for layer in self.model._flatten_layers(include_self=False, recursive=True):
            layer._hooks = []
            layer._old_call = layer.call
            layer.call = get_layer_call_fn(layer)
            layer.register_hook = lambda hook: layer._hooks.append(hook)
            saver = InputOutputSaver()
            layer.register_hook(saver)
            self.saved_layers[layer.name] = saver
+       self.has_wrapped_model_with_input_output_saver = True
+
+   def _unwrap_model_with_input_output_saver(self):
+       if self.has_wrapped_model_with_input_output_saver is False:
+           return
+       for layer in self.model._flatten_layers(include_self=False, recursive=True):
+           layer._hooks = []
+           layer.call = layer._old_call
+       self.has_wrapped_model_with_input_output_saver = False

    def _on_any_batch_begin(self, batch, mode, logs=None):
        if self._is_not_supported():
@@ -780,6 +817,7 @@ def _save_layer_values(self, logs):
        step_collections = self._get_collections_to_save_for_step()
        layer_collection = self.get_collection(CollectionKeys.LAYERS)
        collections_to_write = {layer_collection} if layer_collection in step_collections else set()
+       layer_name_dict = dict()
        for layer_name, layer_input, layer_output in logs:
            # Cast layer_name to str since it can also be of type bytes
            # when run with mirrored strategy
@@ -794,9 +832,16 @@ def _save_layer_values(self, logs):
            # Layer Inputs are flattened and passed as a list into
            # the next layer. Unpacking it speeds up the _make_numpy fn.
            layer_input = layer_input[0]
-           layer_input_tensor_name = get_export_name_for_keras(str(layer_name), "input")
+           layer_name = str(layer_name)
+           idx = layer_name_dict.get(layer_name, 0)
+           layer_name_dict[layer_name] = idx + 1
+           layer_input_tensor_name = get_export_name_for_keras(
+               layer_name, "input", layer_idx=idx, tensor_idx=idx
+           )
            self._save_tensor_to_file(layer_input_tensor_name, layer_input, collections_to_write)
-           layer_output_tensor_name = get_export_name_for_keras(str(layer_name), "output")
+           layer_output_tensor_name = get_export_name_for_keras(
+               layer_name, "output", layer_idx=idx, tensor_idx=idx
+           )
            self._save_tensor_to_file(layer_output_tensor_name, layer_output, collections_to_write)

    def _write_optimizer_variables(self):

Review comment on the line "idx = layer_name_dict.get(layer_name, 0)": will this be 0 always because this is the first time the dict has been accessed after L820?
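On the reviewer's question above: idx is 0 only the first time a given name is seen within one _save_layer_values call, since layer_name_dict is created empty before the loop over logs; a layer that appears several times in logs counts up. A minimal sketch of the counting pattern from the diff:

# Simulated logs: (layer_name, layer_input, layer_output) triples, with a
# reused layer appearing twice. The payloads are irrelevant to the naming.
logs = [("dense", None, None), ("dense", None, None), ("relu", None, None)]

layer_name_dict = dict()
for layer_name, _layer_input, _layer_output in logs:
    idx = layer_name_dict.get(layer_name, 0)  # 0 only on first occurrence
    layer_name_dict[layer_name] = idx + 1
    print(f"{layer_name}_{idx}/input_{idx}")
# dense_0/input_0
# dense_1/input_1
# relu_0/input_0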
@@ -849,6 +894,9 @@ def _on_any_batch_end(self, batch, mode, logs=None):
            self._export_model()
            self._exported_model[self.mode] = True

+       if is_tf_version_2x():
+           self._unwrap_model_with_input_output_saver()
+
    def on_train_batch_end(self, batch, logs=None):
        self._on_any_batch_end(batch, ModeKeys.TRAIN, logs=logs)

Review comment on the line "if is_tf_version_2x():": any restriction for eager/non-eager/tf.function?
@@ -978,6 +1026,7 @@ def run(*args, **kwargs):
                # this means sometimes collections will be exported after 1 step
                self.export_collections()
                self._exported_collections = True
+           self._wrap_model_with_input_output_saver()

        return run
@@ -1054,6 +1103,7 @@ def run(*args, **kwargs):
                return

            self.last_saved_step = self.step
+           self._unwrap_model_with_input_output_saver()

        return run
@@ -293,11 +293,11 @@ def is_keras_optimizer(obj):
    return False


-def get_export_name_for_keras(layer, tensor_type, tensor=None):
+def get_export_name_for_keras(layer, tensor_type, tensor=None, layer_idx=None, tensor_idx=None):
    if tensor_type in ["input", "output", "weight"]:
        if isinstance(layer, str):
            # Tensor.name is meaningless when eager execution is enabled.
-           return f"{layer}/{tensor_type}s"
+           return f"{layer}_{layer_idx}/{tensor_type}_{tensor_idx}"
        else:
            return f"{layer.name}/{tensor_type}s/{tensor.name}"
    else:

Review comment on the new return statement: does this change how layer/tensor names have looked so far?
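On the naming question above: yes, the eager-path names change shape. Evaluating the old and new return expressions side by side (values chosen for illustration):

layer, tensor_type = "dense", "input"
layer_idx = tensor_idx = 0

old_name = f"{layer}/{tensor_type}s"                          # 'dense/inputs'
new_name = f"{layer}_{layer_idx}/{tensor_type}_{tensor_idx}"  # 'dense_0/input_0'
print(old_name, "->", new_name)

Any tooling that matches on the old form like 'dense/inputs' would presumably need to account for the indexed form.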
@@ -341,12 +341,12 @@ def register_hook(self, hook: Callable[[tf.Tensor, tf.Tensor], Optional[tf.Tenso

class InputOutputSaver:
    def __init__(self):
-       self.layer_input = None
-       self.layer_output = None
+       self.layer_input = []
+       self.layer_output = []

    def __call__(self, inputs, *args, **kwargs) -> None:
-       self.layer_input = kwargs["layer_input"]
-       self.layer_output = kwargs["layer_output"]
+       self.layer_input.append(kwargs["layer_input"])
+       self.layer_output.append(kwargs["layer_output"])


def get_layer_call_fn(layer: tf.keras.layers.Layer) -> Callable[[tf.Tensor], tf.Tensor]:
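With layer_input and layer_output now lists, a single saver accumulates every invocation of its layer within a step rather than overwriting the previous call's tensors. A self-contained sketch of that behavior, with plain strings standing in for tensors:

class InputOutputSaver:
    def __init__(self):
        self.layer_input = []
        self.layer_output = []

    def __call__(self, inputs, *args, **kwargs) -> None:
        self.layer_input.append(kwargs["layer_input"])
        self.layer_output.append(kwargs["layer_output"])

saver = InputOutputSaver()
# Simulate a layer that is called twice in one forward pass.
saver(None, layer_input=["x0"], layer_output="y0")
saver(None, layer_input=["y0"], layer_output="y1")
print(len(saver.layer_input), len(saver.layer_output))  # 2 2 -- both calls kept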
@@ -0,0 +1,52 @@
+# Third Party
+import numpy as np
+from tensorflow.keras.layers import Dense
+from tensorflow.python.keras.models import Model
+
+# First Party
+import smdebug.tensorflow as smd
+from smdebug.trials import create_trial
+
+
+class CustomModel(Model):
+    def __init__(self):
+        super(CustomModel, self).__init__()
+        self.dense = Dense(10, activation="relu")
+
+    def call(self, x):
+        x = self.dense(x)
+        x = self.dense(x)
+        return self.dense(x)
+
+
+def test_layer_reusability(out_dir):
+    model = CustomModel()
+    hook = smd.KerasHook(
+        out_dir,
+        save_all=True,
+        save_config=smd.SaveConfig(save_steps=[0], save_interval=1),
+        reduction_config=smd.ReductionConfig(save_shape=True, save_raw_tensor=True),
+    )
+
+    hook.register_model(model)
+    x_train = np.random.random((1000, 10))
+    y_train = np.random.random((1000, 1))
+    model.compile(optimizer="Adam", loss="mse", run_eagerly=True)
+    model.fit(x_train, y_train, epochs=1, steps_per_epoch=1, callbacks=[hook])
+
+    trial = create_trial(path=out_dir, name="training_run")
+    tensor_names = trial.tensor_names(collection=smd.CollectionKeys.LAYERS)
+    """
+    [
+        'dense_0/input_0',
+        'dense_0/output_0',
+        'dense_1/input_0',
+        'dense_1/output_0',
+        'dense_2/input_0',
+        'dense_2/output_0'
+    ]
+    """
+    assert len(tensor_names) == 6
+    for name in tensor_names:
+        shape = trial.tensor(name).shape(step_num=0)
+        assert shape == (1000, 10)
@@ -0,0 +1,58 @@
+# Third Party
+import numpy as np
+from tensorflow.keras.layers import Concatenate, Dense, Layer
+from tensorflow.python.keras.models import Model
+
+# First Party
+import smdebug.tensorflow as smd
+from smdebug.trials import create_trial
+
+
+class CustomLayer(Layer):
+    def __init__(self):
+        super(CustomLayer, self).__init__()
+        self.con = Concatenate()
+        self.dense = Dense(10, activation="relu")
+
+    def call(self, x):
+        x = self.con([x, x])
+        return self.dense(x)
+
+
+class CustomModel(Model):
+    def __init__(self):
+        super(CustomModel, self).__init__()
+        self.custom_layer = CustomLayer()
+
+    def call(self, x):
+        return self.custom_layer(x)
+
+
+def test_if_nested_layers_are_recorded(out_dir):
+    model = CustomModel()
+    hook = smd.KerasHook(
+        out_dir,
+        save_all=True,
+        save_config=smd.SaveConfig(save_steps=[0], save_interval=1),
+        reduction_config=smd.ReductionConfig(save_shape=True, save_raw_tensor=True),
+    )
+
+    hook.register_model(model)
+    x_train = np.random.random((1000, 20))
+    y_train = np.random.random((1000, 1))
+    model.compile(optimizer="Adam", loss="mse", run_eagerly=True)
+    model.fit(x_train, y_train, epochs=1, steps_per_epoch=1, callbacks=[hook])
+    trial = create_trial(path=out_dir)
+    layer_names = trial.tensor_names(collection=smd.CollectionKeys.LAYERS)
+    """
+    [
+        'concatenate_0/input_0',
+        'concatenate_0/input_1',
+        'concatenate_0/output_0',
+        'custom_layer_0/input_0',
+        'custom_layer_0/output_0',
+        'dense_0/input_0',
+        'dense_0/output_0'
+    ]
+    """
+    assert len(layer_names) == 7
Review comment: this looks similar to what's done for outputs below. possible to make it common?
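One way to act on this suggestion, sketched under the assumption that the input and output loops in _save_layer_input_and_outputs differ only in the attribute read and the tensor_type string. The helper name _save_layer_tensors is hypothetical, not part of this PR, and is shown as it would sit on the hook class (so self.logger and self._save_tensor_to_file come from the existing class):

def _save_layer_tensors(self, layer_name, tensors, tensor_type, layer_collection):
    # Shared body for the "Save Input" and "Save Output" loops;
    # tensor_type is "input" or "output".
    for layer_idx, tensor in enumerate(tensors):
        tensor_list = tensor if isinstance(tensor, list) else [tensor]
        if hasattr(tensor_list[0], "numpy") is False:
            self.logger.warning(
                "cannot save layer values during forward pass with tf.function"
            )
            continue
        for t_idx, t in enumerate(tensor_list):
            export_name = get_export_name_for_keras(
                layer_name,
                tensor_type=tensor_type,
                tensor=tensor,
                layer_idx=layer_idx,
                tensor_idx=t_idx,
            )
            self._save_tensor_to_file(export_name, t, layer_collection)

# The two loops would then collapse to:
#     self._save_layer_tensors(layer_name, saver.layer_input, "input", layer_collection)
#     self._save_layer_tensors(layer_name, saver.layer_output, "output", layer_collection)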