diff --git a/smdebug/core/utils.py b/smdebug/core/utils.py
index 0557072b4..99f229391 100644
--- a/smdebug/core/utils.py
+++ b/smdebug/core/utils.py
@@ -64,14 +64,24 @@ class FRAMEWORK(Enum):
 except (ImportError, ModuleNotFoundError):
     _torch_dist_imported = None
 
-
+logger = get_logger()
 try:
     import horovod.torch as hvd
 
     # This redundant import is necessary because horovod does not raise an ImportError if the library is not present
     import torch  # noqa
 
+    # Make sure the library is correctly imported: a broken install can import
+    # without exposing the Horovod API. Only probe for the attribute; calling
+    # hvd.init() here would start the Horovod runtime as an import side effect.
+    if not hasattr(hvd, "init"):
+        raise AttributeError("horovod.torch does not expose init()")
     _hvd_imported = hvd
+except AttributeError:
+    # Degrade gracefully like the other import fallbacks in this module:
+    # report the broken install but keep this module importable.
+    _hvd_imported = None
+    logger.error("horovod.torch is not correctly imported.")
 except (ModuleNotFoundError, ImportError):
     try:
         import horovod.tensorflow as hvd
@@ -79,9 +89,8 @@ class FRAMEWORK(Enum):
         _hvd_imported = hvd
     except (ModuleNotFoundError, ImportError):
+        # Neither Horovod flavor is importable: record that and carry on
+        # instead of re-raising, so this module still imports without Horovod.
         _hvd_imported = None
 
-
-logger = get_logger()
 error_handling_agent = (
     ErrorHandlingAgent.get_error_handling_agent()
 )  # set up error handler to wrap smdebug functions