
Commit e5eec9d

feat: Improve logging throughout the Dynamo path
- Add clear logging at the beginning and end of each phase of compilation
- Reword logging in certain locations for clarity
1 parent 1ff10a6 commit e5eec9d
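To surface the new messages, a standard Python logging setup is sufficient, since every module touched here obtains its logger via logging.getLogger(__name__). A minimal sketch (the "torch_tensorrt.dynamo" logger namespace is inferred from the file paths in this commit; pick the level to match the verbosity you want):

import logging

# Emit records to stderr with the logger name and level visible.
logging.basicConfig(format="%(name)s [%(levelname)s]: %(message)s")

# INFO shows the new phase-boundary messages; DEBUG additionally shows
# the per-node messages added to the TensorRT Interpreter below.
logging.getLogger("torch_tensorrt.dynamo").setLevel(logging.DEBUG)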

File tree

5 files changed (+34, −7 lines)


py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 14 additions & 4 deletions
@@ -5,7 +5,6 @@
 from typing import Any, List, Optional, Sequence, Set, Tuple, Union
 
 import torch
-import torch_tensorrt
 from torch.export import ExportedProgram
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import ( # TODO: Should probabably be the TRT EngineCapability Enum
@@ -43,6 +42,8 @@
     to_torch_tensorrt_device,
 )
 
+import torch_tensorrt
+
 logger = logging.getLogger(__name__)
 
 
@@ -244,9 +245,11 @@ def compile_module(
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False
 
+    logger.info("Beginning TensorRT operator Partitioning Phase")
     # If specified, try using the fast partitioner and fall back to the global one on failure
     if settings.use_fast_partitioner:
         try:
+            logger.info("Partitioning the graph via the fast partitioner")
             partitioned_module = partitioning.fast_partition(
                 gm,
                 verbose=settings.debug,
@@ -256,21 +259,27 @@
         except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
             logger.error(
                 "Partitioning failed on the subgraph with fast partition. See trace above. "
-                + "Retrying with global partition.",
+                "Retrying with global partition.",
                 exc_info=True,
             )
 
             fast_partitioner_failed = True
             settings.use_fast_partitioner = False
 
     if not settings.use_fast_partitioner:
+        logger.info("Partitioning the graph via the global partitioner")
         partitioned_module = partitioning.global_partition(
             gm,
             verbose=settings.debug,
             min_block_size=settings.min_block_size,
             torch_executed_ops=settings.torch_executed_ops,
         )
 
+    logger.info(
+        "Successfully completed graph partitioning phase. "
+        "Beginning the conversion phase."
+    )
+
     # Store TRT replicas of Torch subgraphs
     trt_modules = {}
     # Iterate over all components that can be accelerated
@@ -289,14 +298,15 @@ def compile_module(
             to_torch_device(settings.device),
         )
 
+        assert submodule_inputs is not None
+
         logger.debug(
-            "Submodule name: %s\n Input shapes: %s\n %s",
+            "Converting submodule: %s\n Input shapes: %s\n %s",
            str(name),
             [input.shape for input in submodule_inputs],
             str(submodule.graph),
         )
 
-        assert submodule_inputs is not None
         # Handle long/double inputs if requested by the user
         if settings.truncate_long_and_double:
             submodule_inputs = repair_long_or_double_inputs(
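Condensed for reference, the partitioning control flow these hunks instrument looks roughly as follows. This is a hedged sketch, not the library's implementation: fast_partition and global_partition stand in for the partitioning module's functions, and the extra keyword arguments are dropped.

import logging

logger = logging.getLogger(__name__)


def partition_with_fallback(gm, settings, fast_partition, global_partition):
    # Sketch of compile_module's partitioning phase: try the fast
    # partitioner first and fall back to the global one on failure.
    logger.info("Beginning TensorRT operator Partitioning Phase")
    partitioned_module = None
    if settings.use_fast_partitioner:
        try:
            logger.info("Partitioning the graph via the fast partitioner")
            partitioned_module = fast_partition(gm)
        except Exception:  # the real code catches FxNetSplitterInternalError
            logger.error(
                "Partitioning failed on the subgraph with fast partition. "
                "Retrying with global partition.",
                exc_info=True,
            )
            settings.use_fast_partitioner = False
    if not settings.use_fast_partitioner:
        logger.info("Partitioning the graph via the global partitioner")
        partitioned_module = global_partition(gm)
    logger.info(
        "Successfully completed graph partitioning phase. "
        "Beginning the conversion phase."
    )
    return partitioned_module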

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 4 additions & 0 deletions
@@ -256,6 +256,10 @@ def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
         n.kwargs = kwargs
 
         # run the node
+        _LOGGER.debug(
+            f"Running node {self._cur_node_name}, a {self._cur_node.op} node "
+            f"with target {self._cur_node.target} in the TensorRT Interpreter"
+        )
         trt_node: torch.fx.Node = super().run_node(n)
 
         # remove "_itensor_to_tensor_meta"

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 0 additions & 1 deletion
@@ -44,7 +44,6 @@ def get_node_name(node: torch.fx.Node) -> str:
         # like the node.meta['source_fn'] attr
         pass
 
-    _LOGGER.debug(f"Node meta name {node_name}")
     return node_name
 
 

py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py

Lines changed: 7 additions & 0 deletions
@@ -1,12 +1,15 @@
 from __future__ import annotations
 
+import logging
 from typing import Optional, Sequence, Set
 
 import torch
 from torch.fx.node import _get_qualified_name
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo.utils import get_torch_inputs
 
+logger = logging.getLogger(__name__)
+
 
 def _extract_downstream_get_nodes(
     module_node: torch.fx.Node, output_indices: Set[int]
@@ -62,6 +65,10 @@ def _repair_64bit_input(
         torch.float64,
     ), f"dtype argument must be torch.int64 or torch.float64, got {dtype}"
 
+    logger.info(
+        f"Downcasting a 64-bit input at position {position} of submodule {submodule_name}"
+    )
+
     # Determine target data type in 32 and 64 bit forms
     dtype_64bit = dtype
     dtype_32bit = torch.int32 if (dtype == torch.int64) else torch.float32
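For context, the downcast the new log message announces maps int64 to int32 and float64 to float32, as the hunk's dtype_32bit line shows. A standalone sketch of that mapping (an illustrative helper, not the library's repair_long_or_double_inputs):

import torch


def downcast_64bit(t: torch.Tensor) -> torch.Tensor:
    # Same dtype mapping as the hunk above.
    if t.dtype == torch.int64:
        return t.to(torch.int32)
    if t.dtype == torch.float64:
        return t.to(torch.float32)
    return t


x = torch.arange(4, dtype=torch.int64)
print(downcast_64bit(x).dtype)  # torch.int32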

py/torch_tensorrt/dynamo/utils.py

Lines changed: 9 additions & 2 deletions
@@ -5,12 +5,12 @@
 from typing import Any, Callable, Dict, Optional, Sequence, Union
 
 import torch
-import torch_tensorrt
 from torch_tensorrt._Device import Device
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._defaults import PRECISION
 from torch_tensorrt.dynamo._settings import CompilationSettings
 
+import torch_tensorrt
 from packaging import version
 
 logger = logging.getLogger(__name__)
@@ -145,9 +145,16 @@ def prepare_inputs(
 
         return torchtrt_inputs_dict
 
+    elif isinstance(inputs, (torch.SymBool, torch.SymFloat, torch.SymInt)):
+        raise ValueError(
+            f"Detected Torch symbolic input type {type(inputs)} during input parsing. "
+            "Symbolic inputs are not currently allowed; please specify dynamic=False "
+            "if using torch.compile with the Torch-TensorRT backend."
+        )
+
     else:
         raise ValueError(
-            f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. "
+            f"Invalid input type {type(inputs)} encountered during Dynamo input parsing. "
             + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}"
         )
 
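On the caller's side, the workaround named in the new error message is passing dynamic=False to torch.compile. A hedged sketch (this assumes the backend is registered under the name "torch_tensorrt" once the package is imported, and that a CUDA device is available):

import torch
import torch_tensorrt  # noqa: F401  (importing registers the backend)

model = torch.nn.Linear(4, 4).eval().cuda()

# dynamic=False keeps Dynamo from passing SymInt/SymFloat/SymBool inputs
# to the backend, which would now raise the ValueError above.
compiled = torch.compile(model, backend="torch_tensorrt", dynamic=False)
out = compiled(torch.randn(2, 4).cuda())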
