Skip to content

Commit 9e6a581

Browse files
committed
get_name
1 parent be862bd commit 9e6a581

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

smdebug/pytorch/hook.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,14 @@ def register_hook(self, module):
197197
# for compatibility with ZCC patches which call this
198198
self.register_module(module)
199199

200+
@staticmethod
201+
def _add_module_name(module, module_name):
202+
if isinstance(module, torch.nn.parallel.data_parallel.DataParallel):
203+
module.module._module_name = module_name
204+
else:
205+
module._module_name = module_name
206+
return module
207+
200208
def register_module(self, module):
201209
"""
202210
This function registers the forward hook. If user wants to register the hook
@@ -215,9 +223,9 @@ def register_module(self, module):
215223

216224
for name, submodule in module.named_modules():
217225
assert submodule not in self.module_set, f"Don't register module={module} twice"
218-
submodule._module_name = name
226+
Hook._add_module_name(submodule, name)
219227
self.module_set.add(submodule)
220-
module._module_name = module._get_name()
228+
Hook._add_module_name(module, module._get_name())
221229
self.module_set.add(module)
222230

223231
# Use `forward_pre_hook` for the entire net

0 commit comments

Comments
 (0)