
Commit 307eede (parent 1e8226e)

fix: Add I/O values for nodes

4 files changed: +94 −7 lines

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 0 additions & 6 deletions
```diff
@@ -307,7 +307,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False
 
-    logger.info("Beginning TensorRT operator Partitioning Phase")
     # If specified, try using the fast partitioner and fall back to the global one on failure
     if settings.use_fast_partitioner:
         try:
@@ -343,11 +342,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     if not settings.use_fast_partitioner:
         dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(partitioned_module))
 
-    logger.info(
-        "Successfully completed graph partitioning phase. "
-        "Beginning the conversion phase."
-    )
-
     # Store TRT replicas of Torch subgraphs
     trt_modules = {}
     # Iterate over all components that can be accelerated
```

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 12 additions & 1 deletion
```diff
@@ -1,7 +1,7 @@
 import logging
 import warnings
 from datetime import datetime
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple
 
 import numpy as np
 import tensorrt as trt
@@ -20,6 +20,7 @@
 )
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import CallingConvention
 from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_node_io,
     get_node_name,
     get_trt_tensor,
 )
@@ -106,6 +107,9 @@ def __init__(
             [dtype._from(o) for o in output_dtypes] if output_dtypes else None
         )
 
+        # Mapping of constants to shapes and dtypes
+        self.const_mapping: Dict[str, Tuple[Sequence[int], str]] = {}
+
     def validate_conversion(self) -> Set[str]:
         missing_converters: Set[str] = set()
 
@@ -352,6 +356,13 @@ def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
         )
         trt_node: torch.fx.Node = super().run_node(n)
 
+        if n.op == "get_attr":
+            self.const_mapping[str(n)] = (tuple(trt_node.shape), str(trt_node.dtype))
+
+        _LOGGER.debug(
+            f"Ran node {self._cur_node_name} with properties: {get_node_io(n, self.const_mapping)}"
+        )
+
         # remove "_itensor_to_tensor_meta"
         kwargs = dict(n.kwargs)
         del kwargs["_itensor_to_tensor_meta"]
```
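The new per-node report goes through `_LOGGER.debug`, so it is only visible when debug logging is enabled. A minimal sketch of how one might surface it, assuming `_LOGGER` follows the usual `logging.getLogger(__name__)` pattern and therefore sits under the `torch_tensorrt` logger hierarchy:

```python
import logging

# Attach a handler and lower the threshold for the torch_tensorrt hierarchy
# so the "Ran node ... with properties: ..." lines emitted by run_node appear.
logging.basicConfig(level=logging.DEBUG)
logging.getLogger("torch_tensorrt").setLevel(logging.DEBUG)
```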

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 58 additions & 0 deletions
```diff
@@ -8,6 +8,7 @@
 import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Argument, Target
+from torch.fx.passes.shape_prop import TensorMetadata
 from torch_tensorrt import _enums
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -44,6 +45,63 @@ def get_node_name(node: torch.fx.Node) -> str:
     return node_name
 
 
+def get_node_io(
+    node: torch.fx.Node, constant_mapping: Dict[str, Tuple[Sequence[int], str]]
+) -> str:
+    """Gets a string representing the node inputs and outputs including tensor shapes and dtypes"""
+
+    def format_tensor_metadata(
+        metadata: Union[TensorMetadata, Sequence[TensorMetadata]]
+    ) -> str:
+        """Formats the metadata for a single node"""
+        # If the provided data is a simple TensorMetadata object, parse it
+        if isinstance(metadata, TensorMetadata):
+            return f"{tuple(metadata.shape)}@{metadata.dtype}"
+        # If the provided data is a sequence, recursively parse it
+        else:
+            formatted_str = "("
+            for meta in metadata:
+                formatted_str += format_tensor_metadata(meta) + ", "
+
+            return formatted_str[:-2] + ")"
+
+    # Format input tensors
+    metadata_string = "Inputs: ("
+
+    # For each input argument, format it accordingly
+    for arg in node.args:
+        if isinstance(arg, torch.fx.Node):
+            if arg.op == "get_attr":
+                shape, dtype = constant_mapping[str(arg)]
+                arg_repr = f"{shape}@{dtype}"
+            elif arg.meta.get("tensor_meta", False):
+                arg_repr = format_tensor_metadata(arg.meta["tensor_meta"])
+            else:
+                arg_repr = ""
+
+            metadata_string += f"{arg}: {arg_repr}, "
+        else:
+            metadata_string += f"{arg}, "
+
+    metadata_string = (
+        metadata_string[:-2] if metadata_string[-1] != "(" else metadata_string
+    ) + ")"
+
+    # Format output tensors and arguments
+    metadata_string += " | Outputs: ("
+    if node.op == "get_attr":
+        shape, dtype = constant_mapping[str(node)]
+        node_repr = f"{shape}@{dtype}"
+    elif node.meta.get("tensor_meta", False):
+        node_repr = format_tensor_metadata(node.meta["tensor_meta"])
+    else:
+        node_repr = ""
+    metadata_string += f"{node}: {node_repr}, "
+    metadata_string = metadata_string[:-2] + ")"
+
+    return metadata_string
+
+
 def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool:
     """Detects whether a call_function node is the only operator on a placeholder"""
     # Returns true if the node operates on a placeholder and is a direct output
```
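A small sketch of what `get_node_io` produces, using a plain `torch.fx` graph with `ShapeProp` to populate each node's `tensor_meta` (the module and tensor shapes here are illustrative, not from the commit):

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

from torch_tensorrt.dynamo.conversion.converter_utils import get_node_io


class Add(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


gm = symbolic_trace(Add())
# ShapeProp fills node.meta["tensor_meta"], which get_node_io reads
ShapeProp(gm).propagate(torch.randn(2, 3), torch.randn(2, 3))

for node in gm.graph.nodes:
    if node.op == "call_function":
        # No get_attr constants in this graph, so the constant mapping is empty
        print(get_node_io(node, {}))
# Expected output along the lines of:
# Inputs: (x: (2, 3)@torch.float32, y: (2, 3)@torch.float32) | Outputs: (add: (2, 3)@torch.float32)
```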
(new file — path not shown in this view)

Lines changed: 24 additions & 0 deletions
```diff
@@ -0,0 +1,24 @@
+import logging
+from typing import Sequence
+
+import torch
+from torch.fx.passes.shape_prop import ShapeProp
+
+logger = logging.getLogger(__name__)
+
+
+def propagate_shapes(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
+    """Attempts to propagate shapes through the graph"""
+
+    # Propagate shapes through the graph
+    try:
+        ShapeProp(gm).propagate(*sample_inputs)
+    except (RuntimeError, AssertionError):
+        logger.warning(
+            "Shape Propagation Failed on Graph, skipping propagate_shapes lowering pass",
+            exc_info=True,
+        )
+
+    return gm
```
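This pass is a failure-tolerant wrapper over `torch.fx.passes.shape_prop.ShapeProp`: on success every node gains `meta["tensor_meta"]` with shape and dtype information; on failure the module is returned unchanged with a warning. A hypothetical usage sketch — the import path is a guess, since the new file's path is not shown in this view:

```python
import torch
from torch.fx import symbolic_trace

# Hypothetical import path; the diff above does not show where the new file lives
from torch_tensorrt.dynamo.lowering.passes import propagate_shapes


class Scale(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0


gm = symbolic_trace(Scale())
gm = propagate_shapes(gm, [torch.randn(4, 8)])

# Each annotated node now reports its shape and dtype
for node in gm.graph.nodes:
    meta = node.meta.get("tensor_meta")
    if meta is not None:
        print(node.name, tuple(meta.shape), meta.dtype)
```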
