diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py index a29cee509d..beb0b1f1d3 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py @@ -1,4 +1,5 @@ import logging +import re import warnings from datetime import datetime from packaging import version @@ -291,8 +292,31 @@ def run( engine, self._input_names, self._output_names, serialized_cache ) + def get_node_name(self, node): + # nn_module_stack preserves the call stack of pytorch nn.modules + # The call stack contains a detailed name of the module + # which shows exactly where the module is located in the + # network architecture. + stack_item = node.meta.get("nn_module_stack", None) + # The current node is the last item in the stack + mod_stack = stack_item.popitem() if stack_item else "" + node_name = str(node) + if mod_stack: + mod_name = str(mod_stack[0]).replace("___", "/") + # Clean up the module name + mod_name = re.sub("^.*__self", "", mod_name) + mod_name = re.sub("_(\d+)$", "/\g<1>", mod_name) + node_name = mod_name + "/" + node_name + else: + # Try an alternative way to get the module info + # like the node.meta['source_fn'] attr + pass + + _LOGGER.debug(f"Node meta name {node_name}") + return node_name + def run_node(self, n): - self._cur_node_name = str(n) + self._cur_node_name = self.get_node_name(n) # add "_itensor_to_tensor_meta" kwargs = dict(n.kwargs) kwargs["_itensor_to_tensor_meta"] = self._itensor_to_tensor_meta diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py index d7ef976fba..aff46f3290 100644 --- a/py/torch_tensorrt/fx/fx2trt.py +++ b/py/torch_tensorrt/fx/fx2trt.py @@ -1,5 +1,6 @@ import logging import os +import re import warnings from datetime import datetime from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence @@ -270,8 +271,27 @@ def run( engine, self._input_names, self._output_names, serialized_cache ) + def get_node_name(self, node): + # nn_module_stack preserves the call stack of pytorch nn.modules + # The call stack contains a detailed name of the module + # which shows exactly where the module is located in the + # network architecture. + stack_item = node.meta.get("nn_module_stack", None) + # The current node is the last item in the stack + mod_stack = stack_item.popitem() if stack_item else "" + node_name = str(node) + if mod_stack: + mod_name = str(mod_stack[0]).replace("___", "/") + # Clean up the module name + mod_name = re.sub("^.*__self", "", mod_name) + mod_name = re.sub("_(\d+)$", "/\g<1>", mod_name) + node_name = mod_name + "/" + node_name + + _LOGGER.debug(f"Node meta name {node_name}") + return node_name + def run_node(self, n): - self._cur_node_name = str(n) + self._cur_node_name = self.get_node_name(n) # add "_itensor_to_tensor_meta" kwargs = dict(n.kwargs) kwargs["_itensor_to_tensor_meta"] = self._itensor_to_tensor_meta