File tree Expand file tree Collapse file tree 1 file changed +10
-2
lines changed Expand file tree Collapse file tree 1 file changed +10
-2
lines changed Original file line number Diff line number Diff line change @@ -197,6 +197,14 @@ def register_hook(self, module):
197197 # for compatibility with ZCC patches which call this
198198 self .register_module (module )
199199
200+ @staticmethod
201+ def _add_module_name (module , module_name ):
202+ if isinstance (module , torch .nn .parallel .data_parallel .DataParallel ):
203+ module .module ._module_name = module_name
204+ else :
205+ module ._module_name = module_name
206+ return module
207+
200208 def register_module (self , module ):
201209 """
202210 This function registers the forward hook. If user wants to register the hook
@@ -215,9 +223,9 @@ def register_module(self, module):
215223
216224 for name , submodule in module .named_modules ():
217225 assert submodule not in self .module_set , f"Don't register module={ module } twice"
218- submodule . _module_name = name
226+ Hook . _add_module_name ( submodule , name )
219227 self .module_set .add (submodule )
220- module . _module_name = module ._get_name ()
228+ Hook . _add_module_name ( module , module ._get_name () )
221229 self .module_set .add (module )
222230
223231 # Use `forward_pre_hook` for the entire net
You can’t perform that action at this time.
0 commit comments