diff --git a/py/torch_tensorrt/dynamo/backend/__init__.py b/py/torch_tensorrt/dynamo/backend/__init__.py
index 037294965c..4fa76b3dfc 100644
--- a/py/torch_tensorrt/dynamo/backend/__init__.py
+++ b/py/torch_tensorrt/dynamo/backend/__init__.py
@@ -8,7 +8,6 @@
 from torch_tensorrt import EngineCapability, Device
 from torch_tensorrt.fx.utils import LowerPrecision
 
-from torch_tensorrt.dynamo.backend._settings import CompilationSettings
 from torch_tensorrt.dynamo.backend.utils import prepare_inputs, prepare_device
 from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend
 from torch_tensorrt.dynamo.backend._defaults import (
@@ -62,6 +61,10 @@ def compile(
 
     inputs = prepare_inputs(inputs, prepare_device(device))
 
+    if not isinstance(enabled_precisions, collections.abc.Collection):
+        enabled_precisions = [enabled_precisions]
+
+    # Parse user-specified enabled precisions
     if (
         torch.float16 in enabled_precisions
         or torch_tensorrt.dtype.half in enabled_precisions
@@ -123,10 +126,8 @@ def create_backend(
     Returns:
         Backend for torch.compile
     """
-    if debug:
-        logger.setLevel(logging.DEBUG)
-
-    settings = CompilationSettings(
+    return partial(
+        torch_tensorrt_backend,
         debug=debug,
         precision=precision,
         workspace_size=workspace_size,
@@ -134,8 +135,3 @@
         torch_executed_ops=torch_executed_ops,
         pass_through_build_failures=pass_through_build_failures,
     )
-
-    return partial(
-        torch_tensorrt_backend,
-        settings=settings,
-    )
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index cf869562b6..04de9c8ce9 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -77,13 +77,12 @@ def _pretraced_backend(
         )
         return trt_compiled
     except:
-        logger.error(
-            "FX2TRT conversion failed on the subgraph. See trace above. "
-            + "Returning GraphModule forward instead.",
-            exc_info=True,
-        )
-
-        if not settings.pass_through_build_failures:
+        logger.warning(
+            "TRT conversion failed on the subgraph. See trace above. "
+            + "Returning GraphModule forward instead.",
+            exc_info=True,
+        )
+        if not pass_through_build_failures:
             return gm.forward
         else:
             raise AssertionError(
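
Note on the `compile()` hunk above: the new `isinstance` guard lets callers pass either a single dtype or a collection of dtypes as `enabled_precisions`. A minimal standalone sketch of the pattern (`normalize_precisions` is a hypothetical helper name for illustration, not part of the patch):

```python
import collections.abc

import torch


def normalize_precisions(enabled_precisions):
    # A bare dtype such as torch.float16 is not a Collection, so it gets
    # wrapped in a list; sets/lists/tuples pass through unchanged. This lets
    # the later membership checks ("torch.float16 in enabled_precisions")
    # work for both call styles.
    if not isinstance(enabled_precisions, collections.abc.Collection):
        enabled_precisions = [enabled_precisions]
    return enabled_precisions


assert normalize_precisions(torch.float16) == [torch.float16]
assert normalize_precisions({torch.float16}) == {torch.float16}
```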
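
Note on the `create_backend()` hunks: since the function now returns `functools.partial(torch_tensorrt_backend, ...)` with the user's kwargs bound directly, the result can be handed straight to `torch.compile` as a custom backend. A usage sketch, assuming the omitted parameters take their `_defaults` values and that a CUDA device with TensorRT is available:

```python
import torch
from torch import nn

from torch_tensorrt.dynamo.backend import create_backend

# create_backend returns a partial of torch_tensorrt_backend with these
# kwargs bound, i.e. a callable that torch.compile accepts as a backend.
backend = create_backend(debug=True, pass_through_build_failures=False)

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 8)).cuda().eval()
compiled = torch.compile(model, backend=backend)

with torch.no_grad():
    out = compiled(torch.randn(4, 32, device="cuda"))
```

With `pass_through_build_failures=False`, the `backends.py` hunk above means a failed TensorRT build logs a warning and falls back to `gm.forward` rather than raising.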