py/torch_tensorrt/dynamo/backend/backends.py

@@ -1,7 +1,6 @@
 import logging
 from typing import Sequence
 import torch
-import traceback
 from functools import partial
 import torch._dynamo as td
 
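The `traceback` import can go because the error path below now routes the exception trace through the `logging` module instead of printing it directly. A minimal sketch of the equivalent behavior, with an illustrative logger and a stand-in exception (neither is from the patch):

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

try:
    raise RuntimeError("engine build failed")  # stand-in for a conversion error
except RuntimeError:
    # exc_info=True attaches the active exception and its traceback to the
    # log record, replacing the separate traceback.print_exc() call.
    logger.error("Conversion failed on the subgraph.", exc_info=True)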
@@ -23,14 +22,10 @@
 from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
 
 
-<<<<<<< HEAD:py/torch_tensorrt/dynamo/backend/backends.py
-@td.register_backend(name="torch_tensorrt")
-=======
 logger = logging.getLogger(__name__)
 
 
-@td.register_backend(name="tensorrt")
->>>>>>> 7e0f4405... feat: Prototype Module-Acceleration in Dynamo:py/torch_tensorrt/dynamo/torch_compile/backends.py
+@td.register_backend(name="torch_tensorrt")
 @fake_tensor_unsupported
 def torch_tensorrt_backend(
     gm: torch.fx.GraphModule,
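With the conflict resolved in favor of the `torch_tensorrt` name, the backend is selectable by that string through `torch.compile`. A hedged usage sketch, assuming the `torch_tensorrt` package is installed so the `@td.register_backend` decorator actually runs at import time; the module and input shapes here are placeholders:

import torch
import torch_tensorrt  # importing the package triggers backend registration

model = torch.nn.Linear(8, 4).cuda().eval()  # placeholder module
inputs = torch.randn(2, 8).cuda()

# The backend is looked up by the name passed to register_backend;
# after this change, "torch_tensorrt" resolves and "tensorrt" does not.
compiled = torch.compile(model, backend="torch_tensorrt")
out = compiled(inputs)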
@@ -85,25 +80,31 @@ def _pretraced_backend(
         Compiled FX GraphModule
     """
     try:
-<<<<<<< HEAD:py/torch_tensorrt/dynamo/backend/backends.py
-        trt_compiled = _compile_module(
-=======
         logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
 
-        trt_compiled = compile_module(
->>>>>>> 7e0f4405... feat: Prototype Module-Acceleration in Dynamo:py/torch_tensorrt/dynamo/torch_compile/backends.py
+        trt_compiled = _compile_module(
             gm,
             sample_inputs,
             settings=settings,
         )
         return trt_compiled
     except:
-        traceback.print_exc()
-        print(
+        logger.error(
             "FX2TRT conversion failed on the subgraph. See trace above. "
-            + "Returning GraphModule forward instead."
+            + "Returning GraphModule forward instead.",
+            exc_info=True,
         )
-        return gm.forward
+
+        if not settings.pass_through_build_failures:
+            return gm.forward
+        else:
+            raise AssertionError(
+                "Halting compilation on build failure since "
+                + "pass_through_build_failures was specified as True. "
+                + "To return the default Torch implementation and avoid "
+                + "halting compilation on engine build failures, "
+                + "specify pass_through_build_failures=False."
+            )
 
 
 def _compile_module(
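The rewritten `except` branch turns a silent fallback into a configurable one: log the failure with its traceback, then either hand back the uncompiled `gm.forward` or halt. A self-contained sketch of that pattern, assuming an illustrative `Settings` dataclass (the real settings object is defined elsewhere in the package):

import logging
from dataclasses import dataclass
from typing import Callable

logger = logging.getLogger(__name__)


@dataclass
class Settings:
    # Mirrors the pass_through_build_failures flag checked in the patch.
    pass_through_build_failures: bool = False


def compile_or_fallback(
    build: Callable[[], Callable], fallback: Callable, settings: Settings
) -> Callable:
    try:
        return build()
    except Exception:
        logger.error("Engine build failed. See trace above.", exc_info=True)
        if not settings.pass_through_build_failures:
            # Default: return the eager-mode callable, as gm.forward is
            # returned in the patch.
            return fallback
        raise AssertionError(
            "Halting compilation on build failure since "
            "pass_through_build_failures was specified as True."
        )

Raising `AssertionError` rather than re-raising the original exception keeps the user-facing message actionable at the cost of masking the concrete failure type; `exc_info=True` in the log call preserves the original trace.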