Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions smdebug/tensorflow/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
is_profiler_supported_for_tf_version,
is_tf_version_2_3_x,
is_tf_version_2x,
is_tf_version_greater_than_2_4_x,
supported_tf_variables,
)

Expand Down Expand Up @@ -139,9 +140,16 @@ def _is_not_supported(self):
self._hook_supported = False
elif self.distribution_strategy == TFDistributionStrategy.MIRRORED:
try:
from tensorflow.python.keras.distribute.distributed_training_utils import (
get_distributed_model,
)
if is_tf_version_greater_than_2_4_x():
Copy link
Contributor

@ndodda-amazon ndodda-amazon Jan 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we replace this function from utils with the function defined here to reduce code redundancy?

# distributed_training_utils.py renamed to distributed_training_utils_v1 in tf 2.4.0
from tensorflow.python.keras.distribute.distributed_training_utils_v1 import (
get_distributed_model,
)
else:
from tensorflow.python.keras.distribute.distributed_training_utils import (
get_distributed_model,
)

except ImportError:
# for tf1.13 we can't import this, so we can't support mirrored strategy
self.logger.info(
Expand Down
19 changes: 9 additions & 10 deletions smdebug/tensorflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,14 @@ def does_tf_support_mixed_precision_training():


def supported_tf_variables():
def is_mixed_precision_api_experimental():
"""
tensorflow mixed preicison api is experimental in versions below 2.4.0
:return: bool
"""
return version.parse(tf.__version__) < version.parse("2.4.0")

if does_tf_support_mixed_precision_training():
if is_mixed_precision_api_experimental():
from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
else:
if is_tf_version_greater_than_2_4_x():
# tensorflow mixed preicison api is experimental in versions below 2.4.0
from tensorflow.python.keras.mixed_precision import autocast_variable

else:
from tensorflow.python.keras.mixed_precision.experimental import autocast_variable

return tf_v1.Variable, autocast_variable.AutoCastVariable
else:
return tf_v1.Variable
Expand Down Expand Up @@ -425,5 +420,9 @@ def is_tf_version_2_4_x():
return version.parse("2.4.0") <= version.parse(tf.__version__) < version.parse("2.5.0")


def is_tf_version_greater_than_2_4_x():
return version.parse("2.4.0") <= version.parse(tf.__version__)


def is_profiler_supported_for_tf_version():
return is_tf_version_2_2_x() or is_tf_version_2_3_x()