
Commit 1e8226e

feat: Improve logging throughout the Dynamo path
- Add clear logging at the beginning and end of each phase of compilation
- Reword logging in certain locations for clarity
1 parent d9b2840 commit 1e8226e

File tree

4 files changed: +32 -4 lines changed

py/torch_tensorrt/dynamo/_compiler.py
py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
py/torch_tensorrt/dynamo/conversion/truncate_double.py
py/torch_tensorrt/dynamo/utils.py

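The new messages go through Python's standard logging module, so a caller can surface them by raising the log level before compiling. A minimal sketch (the torch_tensorrt.compile invocation and the interpreter logger name are assumptions inferred from the module paths in this commit, not taken from it):

import logging

import torch
import torch_tensorrt  # import name assumed to match this repository's package

# Show the new INFO-level phase messages (partitioning start/end, conversion start).
logging.basicConfig(level=logging.INFO)

# Per-node conversion messages are logged at DEBUG; enable them selectively.
# Logger name assumed to follow the module path, as with logging.getLogger(__name__).
logging.getLogger("torch_tensorrt.dynamo.conversion._TRTInterpreter").setLevel(
    logging.DEBUG
)

model = torch.nn.Linear(4, 4).eval().cuda()
inputs = [torch.randn(2, 4).cuda()]

# Hypothetical invocation; exact keyword arguments may differ between releases.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)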

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 12 additions & 3 deletions

@@ -307,9 +307,11 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False
 
+    logger.info("Beginning TensorRT operator Partitioning Phase")
     # If specified, try using the fast partitioner and fall back to the global one on failure
     if settings.use_fast_partitioner:
         try:
+            logger.info("Partitioning the graph via the fast partitioner")
             partitioned_module, supported_ops = partitioning.fast_partition(
                 gm,
                 verbose=settings.debug,
@@ -319,14 +321,15 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
         except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
             logger.error(
                 "Partitioning failed on the subgraph with fast partition. See trace above. "
-                + "Retrying with global partition.",
+                "Retrying with global partition.",
                 exc_info=True,
             )
 
             fast_partitioner_failed = True
             settings.use_fast_partitioner = False
 
     if not settings.use_fast_partitioner:
+        logger.info("Partitioning the graph via the global partitioner")
         partitioned_module, supported_ops = partitioning.global_partition(
             gm,
             verbose=settings.debug,
@@ -340,6 +343,11 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     if not settings.use_fast_partitioner:
         dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(partitioned_module))
 
+    logger.info(
+        "Successfully completed graph partitioning phase. "
+        "Beginning the conversion phase."
+    )
+
     # Store TRT replicas of Torch subgraphs
     trt_modules = {}
     # Iterate over all components that can be accelerated
@@ -364,14 +372,15 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
         # Get the submodule inputs for min, opt, max shapes of the graph inputs
         submodule_inputs = partitioning.construct_submodule_inputs(submodule)
 
+        assert submodule_inputs is not None
+
         logger.debug(
-            "Submodule name: %s\n Input shapes: %s\n %s",
+            "Converting submodule: %s\n Input shapes: %s\n %s",
             str(name),
            [input.shape for input in submodule_inputs],
            str(submodule.graph),
        )
 
-        assert submodule_inputs is not None
        # Handle long/double inputs if requested by the user
        if settings.truncate_double:
            submodule_inputs = repair_double_inputs(
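One change above is purely syntactic: in the logger.error call, the runtime + concatenation of the retry message is replaced with adjacent string literals, which Python joins at parse time. A small standalone illustration (not code from this commit) showing the two spellings produce the same message:

msg_plus = (
    "Partitioning failed on the subgraph with fast partition. "
    + "Retrying with global partition."
)
msg_adjacent = (
    "Partitioning failed on the subgraph with fast partition. "
    "Retrying with global partition."
)
assert msg_plus == msg_adjacent  # both spellings yield the identical string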

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 4 additions & 0 deletions

@@ -346,6 +346,10 @@ def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
         n.kwargs = kwargs
 
         # run the node
+        _LOGGER.debug(
+            f"Running node {self._cur_node_name}, a {self._cur_node.op} node "
+            f"with target {self._cur_node.target} in the TensorRT Interpreter"
+        )
         trt_node: torch.fx.Node = super().run_node(n)
 
         # remove "_itensor_to_tensor_meta"
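The new per-node message is built with f-strings, which are evaluated before the logging framework checks the level, whereas the pre-existing logger.debug call in _compiler.py uses lazy %-style arguments. A small standalone illustration of that standard-library behavior (not code from this commit):

import logging

logging.basicConfig(level=logging.INFO)  # DEBUG records are filtered out
logger = logging.getLogger("demo")

node_name = "aten.add.Tensor"
# Eager: the f-string is formatted even though the record is then dropped.
logger.debug(f"Running node {node_name} in the TensorRT Interpreter")
# Lazy: %s interpolation only happens if the record is actually emitted.
logger.debug("Running node %s in the TensorRT Interpreter", node_name)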

py/torch_tensorrt/dynamo/conversion/truncate_double.py

Lines changed: 7 additions & 0 deletions

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import logging
 from typing import Optional, Sequence, Set
 
 import torch
@@ -8,6 +9,8 @@
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo.utils import get_torch_inputs
 
+logger = logging.getLogger(__name__)
+
 
 def _extract_downstream_get_nodes(
     module_node: torch.fx.Node, output_indices: Set[int]
@@ -62,6 +65,10 @@ def _repair_64bit_input(
         torch.float64,
     ), f"dtype argument must be torch.float64, got {dtype}"
 
+    logger.info(
+        f"Downcasting a 64-bit input at position {position} of submodule {submodule_name}"
+    )
+
     # Determine target data type in 32 and 64 bit forms
     dtype_64bit = dtype
     dtype_32bit = torch.float32
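Because the new logger is created with logging.getLogger(__name__), the downcast notices can be tuned per module without touching the rest of the Dynamo logging. A short sketch, assuming the installed package path matches the file path without the leading py/ directory (an assumption about the package layout):

import logging

# Silence only the float64-downcast notices; other INFO messages still appear.
logging.getLogger("torch_tensorrt.dynamo.conversion.truncate_double").setLevel(
    logging.WARNING
)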

py/torch_tensorrt/dynamo/utils.py

Lines changed: 9 additions & 1 deletion

@@ -5,6 +5,7 @@
 from typing import Any, Callable, Dict, Optional, Sequence, Union
 
 import torch
+import torch_tensorrt
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._Input import Input
@@ -189,9 +190,16 @@ def parse_complex_tensor_structs(
 
         return torchtrt_inputs_dict
 
+    elif isinstance(inputs, (torch.SymBool, torch.SymFloat, torch.SymInt)):
+        raise ValueError(
+            f"Detected Torch symbolic input type {type(inputs)} during input parsing. "
+            "Symbolic inputs are not currently allowed; please specify dynamic=False "
+            "if using torch.compile with the Torch-TensorRT backend."
+        )
+
     else:
         raise ValueError(
-            f"Invalid input type {type(inputs)} encountered in parse_complex_tensor_structs parsing. "
+            f"Invalid input type {type(inputs)} encountered during Dynamo input parsing. "
             + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}"
         )
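The new elif branch rejects symbolic scalars (SymInt/SymFloat/SymBool) that Dynamo can hand to the backend when dynamic shape tracing is on, and its message tells the caller to disable it. A minimal sketch of that workaround, assuming the backend is registered under the name "torch_tensorrt" (the registered alias is an assumption, not stated in this commit):

import torch

model = torch.nn.Linear(8, 8).eval().cuda()

# dynamic=False keeps Dynamo from introducing SymInt placeholders,
# which the new check in parse_complex_tensor_structs would reject.
compiled = torch.compile(model, backend="torch_tensorrt", dynamic=False)
compiled(torch.randn(4, 8).cuda())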
