Skip to content

Commit fa6283f

Browse files
authored
Support tf 2 3 tests (#313)
1 parent ef7f671 commit fa6283f

File tree

2 files changed

+40
-7
lines changed

2 files changed

+40
-7
lines changed

tests/tensorflow2/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,9 @@ def is_tf_2_2():
1515
if version.parse(tf.__version__) >= version.parse("2.2.0"):
1616
return True
1717
return False
18+
19+
20+
def is_tf_2_3():
21+
if version.parse(tf.__version__) == version.parse("2.3.0"):
22+
return True
23+
return False

tests/zero_code_change/test_tensorflow2_integration.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
# Third Party
2121
import pytest
2222
import tensorflow.compat.v2 as tf
23+
from tests.tensorflow2.utils import is_tf_2_3
2324
from tests.utils import SagemakerSimulator
2425

2526
# First Party
@@ -50,28 +51,40 @@ def helper_test_keras_v2(script_mode: bool = False, eager_mode: bool = True):
5051
""" Test the default ZCC behavior of saving losses and metrics in eager and non-eager modes."""
5152
smd.del_hook()
5253
tf.keras.backend.clear_session()
53-
if not eager_mode:
54+
if not eager_mode and is_tf_2_3() is False:
55+
# v1 training APIs are currently not supported
56+
# in ZCC mode with smdebug 0.9 and AWS TF 2.3.0
5457
tf.compat.v1.disable_eager_execution()
5558
enable_tb = False if tf.__version__ == "2.0.2" else True
5659
with SagemakerSimulator(enable_tb=enable_tb) as sim:
5760
model = get_keras_model_v2()
5861
(x_train, y_train), (x_test, y_test) = get_keras_data()
5962
x_train, x_test = x_train / 255, x_test / 255
63+
run_eagerly = None
64+
if is_tf_2_3():
65+
# Test eager and non eager mode for v2
66+
run_eagerly = eager_mode
6067

6168
opt = tf.keras.optimizers.RMSprop()
6269
if script_mode:
6370
hook = smd.KerasHook(out_dir=sim.out_dir, export_tensorboard=True)
6471
opt = hook.wrap_optimizer(opt)
6572
model.compile(
66-
loss="sparse_categorical_crossentropy", optimizer=opt, metrics=["accuracy"]
73+
loss="sparse_categorical_crossentropy",
74+
optimizer=opt,
75+
metrics=["accuracy"],
76+
run_eagerly=run_eagerly,
6777
)
6878
history = model.fit(
6979
x_train, y_train, batch_size=64, epochs=2, validation_split=0.2, callbacks=[hook]
7080
)
7181
test_scores = model.evaluate(x_test, y_test, verbose=2, callbacks=[hook])
7282
else:
7383
model.compile(
74-
loss="sparse_categorical_crossentropy", optimizer=opt, metrics=["accuracy"]
84+
loss="sparse_categorical_crossentropy",
85+
optimizer=opt,
86+
metrics=["accuracy"],
87+
run_eagerly=run_eagerly,
7588
)
7689
history = model.fit(x_train, y_train, batch_size=64, epochs=2, validation_split=0.2)
7790
test_scores = model.evaluate(x_test, y_test, verbose=2)
@@ -101,7 +114,9 @@ def helper_test_keras_v2_json_config(
101114
""" Tests ZCC with custom hook configs """
102115
smd.del_hook()
103116
tf.keras.backend.clear_session()
104-
if not eager_mode:
117+
if not eager_mode and is_tf_2_3() is False:
118+
# v1 training APIs are currently not supported
119+
# in ZCC mode with smdebug 0.9 and AWS TF 2.3.0
105120
tf.compat.v1.disable_eager_execution()
106121
enable_tb = False if tf.__version__ == "2.0.2" else True
107122
with SagemakerSimulator(json_file_contents=json_file_contents, enable_tb=enable_tb) as sim:
@@ -110,19 +125,29 @@ def helper_test_keras_v2_json_config(
110125
x_train, x_test = x_train / 255, x_test / 255
111126

112127
opt = tf.keras.optimizers.RMSprop()
128+
run_eagerly = None
129+
if is_tf_2_3():
130+
# Test eager and non eager mode for v2
131+
run_eagerly = eager_mode
113132
if script_mode:
114133
hook = smd.KerasHook.create_from_json_file()
115134
opt = hook.wrap_optimizer(opt)
116135
model.compile(
117-
loss="sparse_categorical_crossentropy", optimizer=opt, metrics=["accuracy"]
136+
loss="sparse_categorical_crossentropy",
137+
optimizer=opt,
138+
metrics=["accuracy"],
139+
run_eagerly=run_eagerly,
118140
)
119141
history = model.fit(
120142
x_train, y_train, batch_size=64, epochs=2, validation_split=0.2, callbacks=[hook]
121143
)
122144
test_scores = model.evaluate(x_test, y_test, verbose=2, callbacks=[hook])
123145
else:
124146
model.compile(
125-
loss="sparse_categorical_crossentropy", optimizer=opt, metrics=["accuracy"]
147+
loss="sparse_categorical_crossentropy",
148+
optimizer=opt,
149+
metrics=["accuracy"],
150+
run_eagerly=run_eagerly,
126151
)
127152
history = model.fit(x_train, y_train, epochs=2, batch_size=64, validation_split=0.2)
128153
test_scores = model.evaluate(x_test, y_test, verbose=2)
@@ -134,7 +159,9 @@ def helper_test_keras_v2_json_config(
134159
trial = smd.create_trial(path=sim.out_dir)
135160
assert len(trial.steps()) > 0, "Nothing saved at any step."
136161
assert len(trial.tensor_names()) > 0, "Tensors were not saved."
137-
if not eager_mode:
162+
if not eager_mode and is_tf_2_3() is False:
163+
# Gradients are currently not saved in ZCC mode with AWS TF 2.3.0
164+
# and smdebug 0.9
138165
assert len(trial.tensor_names(collection="gradients")) > 0
139166
assert len(trial.tensor_names(collection="weights")) > 0
140167
assert len(trial.tensor_names(collection="losses")) > 0

0 commit comments

Comments
 (0)