diff --git a/.circleci/config.yml b/.circleci/config.yml index 76727084e1..a7d799eb9d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -521,6 +521,7 @@ commands: - store_artifacts: path: /tmp/testlogs +# =================== FX tests start ======================== # test-fx_core: description: "Test the fx core" steps: @@ -720,6 +721,61 @@ commands: - store_artifacts: path: /tmp/testlogs +# =================== FX tests end ======================== # + +# =================== Dynamo tests start ======================== # + test-dynamo-fx_ts: + description: "Test the Dynamo fx_ts_compat path" + steps: + - run: + name: Run Dynamo fx_ts_compat core tests + command: | + cd py/torch_tensorrt/dynamo/fx_ts_compat/test + pushd core/ + pytest --junitxml=/tmp/artifacts/test_results/dynamo/fx_ts_compat/test_results.xml + popd + + - store_test_results: + path: /tmp/artifacts + - store_artifacts: + path: /tmp/testlogs + + test-dynamo-torch_compile-core: + description: "Test the Dynamo torch_compile path" + steps: + - run: + name: Run Dynamo torch_compile core tests + command: | + cd py/torch_tensorrt/dynamo/torch_compile + pushd test/ + pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml + popd + + - store_test_results: + path: /tmp/artifacts + - store_artifacts: + path: /tmp/testlogs + + test-dynamo-torch_compile: + description: "Test the Dynamo torch_compile path" + steps: + - run: + name: Run Dynamo torch_compile E2E tests + command: | + cd py/torch_tensorrt/dynamo/ + pushd test/ + pip3 install timm + pip3 install transformers + pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml --ir torch_compile + popd + + - store_test_results: + path: /tmp/artifacts + - store_artifacts: + path: /tmp/testlogs + +# =================== Dynamo tests end ======================== # + # Define a job to be invoked later in a workflow. 
# See: https://circleci.com/docs/2.0/configuration-reference/#jobs jobs: @@ -911,6 +967,43 @@ jobs: - dump-test-env - test-fx-no-aten + test-py-dynamo-x86_64-linux: + parameters: + torch-build: + type: string + torch-build-index: + type: string + trt-version-long: + type: string + python-version: + type: string + machine: + image: linux-cuda-11:2023.02.1 + resource_class: gpu.nvidia.large + steps: + - checkout + - setup-py-version: + python-version: << parameters.python-version >> + - attach_workspace: + at: /tmp/dist/ + - install-torch-from-index: + torch-build: << parameters.torch-build >> + torch-build-index: << parameters.torch-build-index >> + - create-py-env: + trt-version-long: << parameters.trt-version-long >> + - install-cudnn + # - run: + # name: "Set LD_LIBRARY_PATH path to include the installed CUDNN" + # command: export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH + - run: + name: "Install torch-tensorrt" + command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl + # We install torch after torch-trt because pip automatically enforces the version constraint otherwise + - dump-test-env + - test-dynamo-torch_compile + - test-dynamo-torch_compile-core + - test-dynamo-fx_ts + package-x86_64-linux: parameters: enabled: @@ -1300,6 +1393,14 @@ workflows: requires: - build-x86_64-linux + - test-py-dynamo-x86_64-linux: + torch-build: << pipeline.parameters.torch-build >> + torch-build-index: << pipeline.parameters.torch-build-index >> + trt-version-long: << pipeline.parameters.trt-version-long >> + python-version: << pipeline.parameters.python-version >> + requires: + - build-x86_64-linux + - build-x86_64-linux: name: build-x86_64-linux-legacy torch-build: << pipeline.parameters.torch-build-legacy >> @@ -1374,6 +1475,14 @@ workflows: requires: - package-x86_64-linux + - test-py-dynamo-x86_64-linux: + torch-build: << pipeline.parameters.torch-build >> + torch-build-index: << pipeline.parameters.torch-build-index >> + trt-version-long: << pipeline.parameters.trt-version-long >> + python-version: << pipeline.parameters.python-version >> + requires: + - package-x86_64-linux + on-push: jobs: - build-x86_64-linux: @@ -1407,6 +1516,14 @@ workflows: requires: - build-x86_64-linux + - test-py-dynamo-x86_64-linux: + torch-build: << pipeline.parameters.torch-build >> + torch-build-index: << pipeline.parameters.torch-build-index >> + trt-version-long: << pipeline.parameters.trt-version-long >> + python-version: << pipeline.parameters.python-version >> + requires: + - build-x86_64-linux + - build-x86_64-linux-cmake: torch-build: << pipeline.parameters.torch-build >> torch-build-index: << pipeline.parameters.torch-build-index >> diff --git a/py/setup.py b/py/setup.py index d88ae4fc04..7e44695cf4 100644 --- a/py/setup.py +++ b/py/setup.py @@ -362,6 +362,10 @@ def run(self): "torch_tensorrt.fx.tools", "torch_tensorrt.fx.tracer.acc_tracer", "torch_tensorrt.fx.tracer.dispatch_tracer", + "torch_tensorrt.dynamo", + "torch_tensorrt.dynamo.fx_ts_compat", + "torch_tensorrt.dynamo.fx_ts_compat.passes", + "torch_tensorrt.dynamo.fx_ts_compat.tools", ] package_dir = { "torch_tensorrt.fx": "torch_tensorrt/fx", @@ -370,11 +374,47 @@ def run(self): "torch_tensorrt.fx.tools": "torch_tensorrt/fx/tools", "torch_tensorrt.fx.tracer.acc_tracer": "torch_tensorrt/fx/tracer/acc_tracer", "torch_tensorrt.fx.tracer.dispatch_tracer": "torch_tensorrt/fx/tracer/dispatch_tracer", + "torch_tensorrt.dynamo": "torch_tensorrt/dynamo", + "torch_tensorrt.dynamo.fx_ts_compat": "torch_tensorrt/dynamo/fx_ts_compat", 
+ "torch_tensorrt.dynamo.fx_ts_compat.passes": "torch_tensorrt/dynamo/fx_ts_compat/passes", + "torch_tensorrt.dynamo.fx_ts_compat.tools": "torch_tensorrt/dynamo/fx_ts_compat/tools", } with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() +if FX_ONLY: + package_data_list = [ + "_Input.py", + ] +else: + package_data_list = [ + "lib/*", + "include/torch_tensorrt/*.h", + "include/torch_tensorrt/core/*.h", + "include/torch_tensorrt/core/conversion/*.h", + "include/torch_tensorrt/core/conversion/conversionctx/*.h", + "include/torch_tensorrt/core/conversion/converters/*.h", + "include/torch_tensorrt/core/conversion/evaluators/*.h", + "include/torch_tensorrt/core/conversion/tensorcontainer/*.h", + "include/torch_tensorrt/core/conversion/var/*.h", + "include/torch_tensorrt/core/ir/*.h", + "include/torch_tensorrt/core/lowering/*.h", + "include/torch_tensorrt/core/lowering/passes/*.h", + "include/torch_tensorrt/core/partitioning/*.h", + "include/torch_tensorrt/core/partitioning/segmentedblock/*.h", + "include/torch_tensorrt/core/partitioning/partitioninginfo/*.h", + "include/torch_tensorrt/core/partitioning/partitioningctx/*.h", + "include/torch_tensorrt/core/plugins/*.h", + "include/torch_tensorrt/core/plugins/impl/*.h", + "include/torch_tensorrt/core/runtime/*.h", + "include/torch_tensorrt/core/util/*.h", + "include/torch_tensorrt/core/util/logging/*.h", + "bin/*", + "BUILD", + "WORKSPACE", + ] + setup( name="torch_tensorrt", version=__version__, @@ -418,32 +458,7 @@ def run(self): python_requires=">=3.8", include_package_data=True, package_data={ - "torch_tensorrt": [ - "lib/*", - "include/torch_tensorrt/*.h", - "include/torch_tensorrt/core/*.h", - "include/torch_tensorrt/core/conversion/*.h", - "include/torch_tensorrt/core/conversion/conversionctx/*.h", - "include/torch_tensorrt/core/conversion/converters/*.h", - "include/torch_tensorrt/core/conversion/evaluators/*.h", - "include/torch_tensorrt/core/conversion/tensorcontainer/*.h", - "include/torch_tensorrt/core/conversion/var/*.h", - "include/torch_tensorrt/core/ir/*.h", - "include/torch_tensorrt/core/lowering/*.h", - "include/torch_tensorrt/core/lowering/passes/*.h", - "include/torch_tensorrt/core/partitioning/*.h", - "include/torch_tensorrt/core/partitioning/segmentedblock/*.h", - "include/torch_tensorrt/core/partitioning/partitioninginfo/*.h", - "include/torch_tensorrt/core/partitioning/partitioningctx/*.h", - "include/torch_tensorrt/core/plugins/*.h", - "include/torch_tensorrt/core/plugins/impl/*.h", - "include/torch_tensorrt/core/runtime/*.h", - "include/torch_tensorrt/core/util/*.h", - "include/torch_tensorrt/core/util/logging/*.h", - "bin/*", - "BUILD", - "WORKSPACE", - ], + "torch_tensorrt": package_data_list, }, exclude_package_data={ "": ["*.cpp"], diff --git a/py/torch_tensorrt/_Device.py b/py/torch_tensorrt/_Device.py index 0662e17aa1..3eaa5aad4e 100644 --- a/py/torch_tensorrt/_Device.py +++ b/py/torch_tensorrt/_Device.py @@ -1,11 +1,17 @@ import torch -from torch_tensorrt import _enums +# from torch_tensorrt import _enums +import tensorrt as trt from torch_tensorrt import logging -from torch_tensorrt import _C - import warnings +try: + from torch_tensorrt import _C +except: + warnings.warn( + "Unable to import torchscript frontend core and torch-tensorrt runtime. Some dependent features may be unavailable." 
+ ) + class Device(object): """ @@ -51,7 +57,7 @@ def __init__(self, *args, **kwargs): ) else: (self.device_type, id) = Device._parse_device_str(args[0]) - if self.device_type == _enums.DeviceType.GPU: + if self.device_type == trt.DeviceType.GPU: self.gpu_id = id else: self.dla_core = id @@ -64,7 +70,7 @@ def __init__(self, *args, **kwargs): elif len(args) == 0: if "gpu_id" in kwargs or "dla_core" in kwargs: if "dla_core" in kwargs: - self.device_type = _enums.DeviceType.DLA + self.device_type = trt.DeviceType.DLA self.dla_core = kwargs["dla_core"] if "gpu_id" in kwargs: self.gpu_id = kwargs["gpu_id"] @@ -76,7 +82,7 @@ def __init__(self, *args, **kwargs): ) else: self.gpu_id = kwargs["gpu_id"] - self.device_type = _enums.DeviceType.GPU + self.device_type = trt.DeviceType.GPU else: raise ValueError( "Either gpu_id or dla_core or both must be defined if no string with device specs is provided as an arg" @@ -97,7 +103,7 @@ def __init__(self, *args, **kwargs): def __str__(self) -> str: return ( "Device(type={}, gpu_id={}".format(self.device_type, self.gpu_id) + ")" - if self.device_type == _enums.DeviceType.GPU + if self.device_type == trt.DeviceType.GPU else ", dla_core={}, allow_gpu_fallback={}".format( self.dla_core, self.allow_gpu_fallback ) @@ -105,7 +111,15 @@ def __str__(self) -> str: def _to_internal(self) -> _C.Device: internal_dev = _C.Device() - internal_dev.device_type = self.device_type + if self.device_type == trt.DeviceType.GPU: + internal_dev.device_type = _C.DeviceType.GPU + elif self.device_type == trt.DeviceType.DLA: + internal_dev.device_type = _C.DeviceType.DLA + else: + raise ValueError( + "Invalid DeviceType detected while parsing the Device class" + ) + internal_dev.gpu_id = self.gpu_id internal_dev.dla_core = self.dla_core internal_dev.allow_gpu_fallback = self.allow_gpu_fallback @@ -136,6 +150,6 @@ def _parse_device_str(s): s = s.lower() spec = s.split(":") if spec[0] == "gpu" or spec[0] == "cuda": - return (_enums.DeviceType.GPU, int(spec[1])) + return (trt.DeviceType.GPU, int(spec[1])) elif spec[0] == "dla": - return (_enums.DeviceType.DLA, int(spec[1])) + return (trt.DeviceType.DLA, int(spec[1])) diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index 324c385fab..e76817e041 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -4,7 +4,6 @@ import torch from torch_tensorrt import _enums -from torch_tensorrt import _C class Input(object): @@ -41,6 +40,7 @@ class _ShapeMode(Enum): DOMAIN_OFFSET = 2.0 low_tensor_domain_incl = 0.0 high_tensor_domain_excl = low_tensor_domain_incl + DOMAIN_OFFSET + torch_dtype = torch.float32 def __init__(self, *args, **kwargs): """__init__ Method for torch_tensorrt.Input @@ -138,7 +138,11 @@ def __init__(self, *args, **kwargs): ) if "dtype" in kwargs: + if isinstance(kwargs["dtype"], torch.dtype): + self.torch_dtype = kwargs["dtype"] + self.dtype = Input._parse_dtype(kwargs["dtype"]) + self.torch_dtype = Input._to_torch_dtype(self.dtype) self._explicit_set_dtype = True if "format" in kwargs: @@ -173,59 +177,6 @@ def __str__(self) -> str: else: raise RuntimeError("Unknown input shape mode") - def _to_internal(self) -> _C.Input: - internal_in = _C.Input() - if self.shape_mode == Input._ShapeMode.DYNAMIC: - if not Input._supported_input_size_type(self.shape["min_shape"]): - raise TypeError( - "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " - + str(type(self.shape["min_shape"])) - + " for min_shape" - ) - else: - internal_in.min = 
self.shape["min_shape"] - - if not Input._supported_input_size_type(self.shape["opt_shape"]): - raise TypeError( - "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " - + str(type(self.shape["opt_shape"])) - + " for opt_shape" - ) - else: - internal_in.opt = self.shape["opt_shape"] - - if not Input._supported_input_size_type(self.shape["max_shape"]): - raise TypeError( - "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " - + str(type(self.shape["max_shape"])) - + " for max_shape" - ) - else: - internal_in.max = self.shape["max_shape"] - internal_in.input_is_dynamic = True - else: - if not Input._supported_input_size_type(self.shape): - raise TypeError( - "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " - + str(type(self.shape)) - + " for shape" - ) - else: - internal_in.opt = self.shape - internal_in.input_is_dynamic = False - - if self.dtype != _enums.dtype.unknown: - self._explicit_set_dtype = True - else: - self._explicit_set_dtype = False - - internal_in.dtype = Input._parse_dtype(self.dtype) - internal_in._explicit_set_dtype = self._explicit_set_dtype - internal_in.format = Input._parse_format(self.format) - - internal_in.tensor_domain = Input._parse_tensor_domain(self.tensor_domain) - return internal_in - @staticmethod def _supported_input_size_type(input_size: Any) -> bool: if isinstance(input_size, torch.Size): @@ -265,6 +216,22 @@ def _parse_dtype(dtype: Any) -> _enums.dtype: + str(type(dtype)) ) + @staticmethod + def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype: + if dtype == _enums.dtype.long: + return torch.long + elif dtype == _enums.dtype.int32: + return torch.int32 + elif dtype == _enums.dtype.half: + return torch.half + elif dtype == _enums.dtype.float: + return torch.float + elif dtype == _enums.dtype.bool: + return torch.bool + else: + # Default torch_dtype used in FX path + return torch.float32 + def is_trt_dtype(self) -> bool: return self.dtype != _enums.dtype.long @@ -304,6 +271,7 @@ def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple: Input.low_tensor_domain_incl, Input.high_tensor_domain_excl, ) + elif len(domain) == 2: domain_lo, domain_hi = domain @@ -416,8 +384,8 @@ def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor ) if self.shape_mode == Input._ShapeMode.STATIC: - return torch.randn(self.shape).to(dtype=self.dtype) + return torch.rand(self.shape).to(dtype=self.torch_dtype) else: - return torch.randn(self.shape[optimization_profile_field]).to( - dtype=self.dtype + return torch.rand(self.shape[optimization_profile_field]).to( + dtype=self.torch_dtype ) diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py index 6447c4d537..b8e4fd0d9d 100644 --- a/py/torch_tensorrt/__init__.py +++ b/py/torch_tensorrt/__init__.py @@ -4,6 +4,7 @@ import sys import platform import warnings +from packaging import version from torch_tensorrt._version import ( __version__, __cuda_version__, @@ -94,6 +95,10 @@ def _find_lib(name, paths): from torch_tensorrt import fx +if version.parse(torch.__version__) >= version.parse("2.dev"): + from torch_tensorrt import dynamo + from torch_tensorrt.dynamo import torch_compile + def _register_with_torch(): trtorch_dir = os.path.dirname(__file__) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index cbd1b87c5c..e300669fd5 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py 
@@ -1,6 +1,6 @@
 from typing import List, Dict, Any
-from torch_tensorrt import _enums
 import torch_tensorrt.ts
+
 from torch_tensorrt import logging
 import torch
 import torch.fx
@@ -15,6 +15,8 @@ class _IRType(Enum):
     ts = 0
     fx = 1
+    fx_ts_compat = 2
+    torch_compile = 3


 class _ModuleType(Enum):
@@ -45,11 +47,17 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
     ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]])
     ir_targets_fx = ir == "fx"
+    ir_targets_torch_compile = ir == "torch_compile"
+    ir_targets_fx_ts_compat = ir == "fx_ts_compat"

     if module_is_tsable and ir_targets_torchscript:
         return _IRType.ts
     elif module_is_fxable and ir_targets_fx:
         return _IRType.fx
+    elif module_is_fxable and ir_targets_fx_ts_compat:
+        return _IRType.fx_ts_compat
+    elif module_is_fxable and ir_targets_torch_compile:
+        return _IRType.torch_compile
     else:
         if ir == "default":
             # Options are listed in order of preference
@@ -74,7 +82,7 @@ def compile(
     module: Any,
     ir="default",
     inputs=[],
-    enabled_precisions=set([_enums.dtype.float]),
+    enabled_precisions=set([torch.float]),
     **kwargs,
 ):
     """Compile a PyTorch module for NVIDIA GPUs using TensorRT
@@ -148,6 +156,14 @@ def compile(
             dynamic_batch=False,
             **kwargs,
         )
+    elif target_ir == _IRType.torch_compile:
+        return torch_tensorrt.dynamo.torch_compile(
+            module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
+        )
+    elif target_ir == _IRType.fx_ts_compat:
+        return torch_tensorrt.dynamo.fx_ts_compat.compile(
+            module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
+        )
     else:
         raise RuntimeError("Module is an unknown format or the ir requested is unknown")
@@ -157,7 +173,7 @@ def convert_method_to_trt_engine(
     method_name: str,
     ir="default",
     inputs=[],
-    enabled_precisions=set([_enums.dtype.float]),
+    enabled_precisions=set([torch.float]),
     **kwargs,
 ):
     """Convert a TorchScript module method to a serialized TensorRT engine
diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py
index bc9ed42df4..63dffceb9d 100644
--- a/py/torch_tensorrt/_enums.py
+++ b/py/torch_tensorrt/_enums.py
@@ -1 +1,2 @@
-from torch_tensorrt._C import dtype, DeviceType, EngineCapability, TensorFormat
+from torch_tensorrt._C import dtype, EngineCapability, TensorFormat
+from tensorrt import DeviceType
diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py
new file mode 100644
index 0000000000..26e8b7aa3e
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/__init__.py
@@ -0,0 +1,2 @@
+from torch_tensorrt.dynamo import fx_ts_compat
+from .torch_compile import compile as torch_compile
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/README.md b/py/torch_tensorrt/dynamo/fx_ts_compat/README.md
new file mode 100644
index 0000000000..d2a9e295a3
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/README.md
@@ -0,0 +1,13 @@
+The code in this directory is similar to `torch_tensorrt.fx`. We intend to make changes under the `dynamo` namespace to ensure we
+have the same top-level API as `torch_tensorrt.ts.compile`. Right now, the usage is as follows:
+
+```
+import torch_tensorrt
+trt_module = torch_tensorrt.compile(
+    module,
+    ir="fx_ts_compat",
+    inputs=torchtrt_inputs,
+    enabled_precisions={torch.float32},
+)
+```
+This will internally call `torch_tensorrt.dynamo.fx_ts_compat.compile`, which has the same signature as `torch_tensorrt.ts.compile`. We intend to add features existing in the TorchScript backend (e.g. `torch_executed_ops`, `torch_executed_modules`, and many more) to this Dynamo backend in the coming months.
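As context for the dispatch logic above, here is a minimal sketch of how the two new IR strings registered in `_compile.py` are exercised end to end, mirroring the `pytest ... --ir torch_compile` invocation added to the CircleCI config. The toy module, input shape, and precision set are illustrative placeholders rather than values taken from this change, and a CUDA-capable environment with this patch installed is assumed.

```
import torch
import torch_tensorrt

# Illustrative stand-in for any traceable model (an assumption, not part of this patch)
model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU()).eval().cuda()
inputs = [torch.randn(1, 3, 224, 224).cuda()]

# ir="torch_compile" dispatches to torch_tensorrt.dynamo.torch_compile via _IRType.torch_compile
trt_gm = torch_tensorrt.compile(
    model,
    ir="torch_compile",
    inputs=inputs,
    enabled_precisions={torch.float32},
)

# ir="fx_ts_compat" dispatches to torch_tensorrt.dynamo.fx_ts_compat.compile via _IRType.fx_ts_compat
trt_fx = torch_tensorrt.compile(
    model,
    ir="fx_ts_compat",
    inputs=inputs,
    enabled_precisions={torch.float32},
)

with torch.no_grad():
    out = trt_gm(*inputs)
```

Both calls go through the same top-level `torch_tensorrt.compile` entry point; only the `ir` string selects the backend, which keeps the new Dynamo paths signature-compatible with the existing TorchScript and FX frontends.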
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/__init__.py b/py/torch_tensorrt/dynamo/fx_ts_compat/__init__.py new file mode 100644 index 0000000000..85ce01ef20 --- /dev/null +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/__init__.py @@ -0,0 +1,14 @@ +import logging + +from torch_tensorrt.fx.converter_registry import ( # noqa + CONVERTERS, + NO_EXPLICIT_BATCH_DIM_SUPPORT, + NO_IMPLICIT_BATCH_DIM_SUPPORT, + tensorrt_converter, +) +from .fx2trt import TRTInterpreter, TRTInterpreterResult # noqa +from .input_tensor_spec import InputTensorSpec # noqa +from .lower_setting import LowerSetting # noqa +from .lower import compile # usort: skip #noqa + +logging.basicConfig(level=logging.INFO) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py new file mode 100644 index 0000000000..b5165c6f2d --- /dev/null +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py @@ -0,0 +1,385 @@ +import logging +import warnings +from datetime import datetime +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence + +import numpy + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +import torch +import torch.fx +from torch._ops import OpOverload +from torch.fx.node import _get_qualified_name +from torch.fx.passes.shape_prop import TensorMetadata + +from torch_tensorrt.dynamo.fx_ts_compat import CONVERTERS +from .input_tensor_spec import InputTensorSpec +from torch_tensorrt.fx.observer import Observer +from torch_tensorrt.fx.utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt + +_LOGGER: logging.Logger = logging.getLogger(__name__) + +TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[ + Callable[[torch.fx.GraphModule], None] +] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER") + + +class TRTInterpreterResult(NamedTuple): + engine: Any + input_names: Sequence[str] + output_names: Sequence[str] + serialized_cache: bytearray + + +class TRTInterpreter(torch.fx.Interpreter): + def __init__( + self, + module: torch.fx.GraphModule, + input_specs: List[InputTensorSpec], + explicit_batch_dimension: bool = True, + explicit_precision: bool = False, + logger_level=None, + ): + super().__init__(module) + + self.logger = trt.Logger(logger_level or trt.Logger.WARNING) + self.builder = trt.Builder(self.logger) + + flag = 0 + if explicit_batch_dimension: + EXPLICIT_BATCH = 1 << (int)( + trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH + ) + flag |= EXPLICIT_BATCH + + if explicit_precision: + EXPLICIT_PRECISION = 1 << (int)( + trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION + ) + flag |= EXPLICIT_PRECISION + self.network = self.builder.create_network(flag) + + missing_ops = self.validate_conversion() + if missing_ops: + warnings.warn( + "Interpretation will fail due to missing operations \n" + + "\n".join(f"{i}" for i in missing_ops) + ) + + self.optimization_profiles: Optional[List] = None + self.input_specs = input_specs + self.input_specs_iter = 0 + self.validate_input_specs() + self._cur_node_name: Optional[str] = None + self._input_names: List[str] = [] + self._output_names: List[str] = [] + self._itensor_to_tensor_meta: Dict[ + trt.tensorrt.ITensor, TensorMetadata + ] = dict() + + def validate_input_specs(self): + for shape, _, _, shape_ranges, has_batch_dim in self.input_specs: + if not self.network.has_implicit_batch_dimension: + assert ( + has_batch_dim + ), "It's required to specify batch dimension when it's explicit in TensorRT network." 
+ + dynamic_dims = get_dynamic_dims(shape) + if len(dynamic_dims): + assert not self.network.has_implicit_batch_dimension, ( + "Can't have dynamic dim when " + f"batch dim is implicit, got {shape}." + ) + assert len( + shape_ranges + ), "shape_ranges must be provided when shape has dynamic dim." + + if self.optimization_profiles: + assert len(shape_ranges) == len(self.optimization_profiles), ( + "Number of optimization " + f"profiles {len(self.optimization_profiles)} doesn't match with the number of shape_range" + f" {len(shape_ranges)} provided." + ) + else: + self.optimization_profiles = [ + self.builder.create_optimization_profile() + for _ in range(len(shape_ranges)) + ] + + for shape_range in shape_ranges: + assert ( + len(shape_range) == 3 + ), f"Expect three elements in shape_range, got {len(shape_range)}" + assert all(len(s) == len(shape) for s in shape_range), ( + "Expect elements in shape_range" + f" {shape_range} have the same number of dimension as the provided shape {len(shape)}" + ) + + for i in range(len(shape)): + if i in dynamic_dims: + assert all( + shape_range[j][i] <= shape_range[j + 1][i] + for j in range(2) + ), ( + "Expect dynamic dim" + f" {i} to have incremental value for shapes in shape_range {shape_range}." + ) + else: + assert all(s[i] == shape[i] for s in shape_range), ( + f"Expect non dynamic dim {i} to be the same" + f" for all shapes in shape_range {shape_range}." + ) + else: + assert ( + len(shape_ranges) == 0 + ), "shape_ranges are provided for input that doesn't have dynamic dim." + + def validate_conversion(self): + missing_converter = set() + + for node in self.module.graph.nodes: + if node.op == "call_function" and not CONVERTERS.get(node.target): + missing_converter.add(f"{node.op} {_get_qualified_name(node.target)}") + elif node.op == "call_method" and not CONVERTERS.get(node.target): + missing_converter.add(f"{node.op} torch.Tensor.{node.target}") + elif node.op == "call_module": + submod = self.fetch_attr(node.target) + submod_type = getattr(submod, "_base_class_origin", type(submod)) + if not CONVERTERS.get(submod_type): + missing_converter.add(f"{node.op} {torch.typename(submod_type)}") + + return missing_converter + + def run( + self, + workspace_size=0, + lower_precision=LowerPrecision.FP16, + sparse_weights=False, + disable_tf32=False, + force_fp32_output=False, + strict_type_constraints=False, + algorithm_selector=None, + timing_cache=None, + profiling_verbosity=None, + tactic_sources=None, + ) -> TRTInterpreterResult: + """ + Build TensorRT engine with some configs. + Args: + workspace_size: Amount of memory used by TensorRT to store intermediate buffers within an operation. + lower_precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision). + sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity + force_fp32_output: force output to be fp32 + strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons. + algorithm_selector: set up algorithm selection for certain layer + timing_cache: enable timing cache for TensorRT + profiling_verbosity: TensorRT logging level + Return: + TRTInterpreterResult + """ + TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module) + + # For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and + # force_fp32_output=False. 
+ self.output_fp16 = ( + not force_fp32_output and lower_precision == LowerPrecision.FP16 + ) + + if ( + lower_precision == LowerPrecision.INT8 + and not self.builder.platform_has_fast_int8 + ): + raise RuntimeError("Current platform doesn't support fast native int8!") + + if ( + lower_precision == LowerPrecision.FP16 + and not self.builder.platform_has_fast_fp16 + ): + warnings.warn("Current platform doesn't support fast native fp16!") + + self.input_specs_iter = 0 + run_module_start_time = datetime.now() + super().run() + _LOGGER.info( + f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}" + ) + build_engine_start_time = datetime.now() + + builder_config = self.builder.create_builder_config() + + if workspace_size != 0: + builder_config.set_memory_pool_limit( + trt.MemoryPoolType.WORKSPACE, workspace_size + ) + + cache = None + if timing_cache: + cache_file = numpy.array(timing_cache) + cache = builder_config.create_timing_cache(cache_file.tobytes()) + else: + cache = builder_config.create_timing_cache(b"") + builder_config.set_timing_cache(cache, False) + + if trt.__version__ >= "8.2": + builder_config.profiling_verbosity = ( + profiling_verbosity + if profiling_verbosity + else trt.ProfilingVerbosity.LAYER_NAMES_ONLY + ) + if lower_precision == LowerPrecision.FP16: + builder_config.set_flag(trt.BuilderFlag.FP16) + + if lower_precision == LowerPrecision.INT8: + builder_config.set_flag(trt.BuilderFlag.INT8) + + if sparse_weights: + builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS) + + if disable_tf32: + builder_config.clear_flag(trt.BuilderFlag.TF32) + + if strict_type_constraints: + builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES) + + if self.optimization_profiles: + for optimization_profile in self.optimization_profiles: + builder_config.add_optimization_profile(optimization_profile) + + if algorithm_selector: + builder_config.set_flag(trt.BuilderFlag.DISABLE_TIMING_CACHE) + builder_config.algorithm_selector = algorithm_selector + + if tactic_sources is not None: + builder_config.set_tactic_sources(tactic_sources=tactic_sources) + + engine = self.builder.build_engine(self.network, builder_config) + assert engine + + serialized_cache = ( + bytearray(cache.serialize()) + if builder_config.get_timing_cache() + else bytearray() + ) + _LOGGER.info( + f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" + ) + + return TRTInterpreterResult( + engine, self._input_names, self._output_names, serialized_cache + ) + + def run_node(self, n): + self._cur_node_name = str(n) + # add "_itensor_to_tensor_meta" + kwargs = dict(n.kwargs) + kwargs["_itensor_to_tensor_meta"] = self._itensor_to_tensor_meta + n.kwargs = kwargs + + # run the node + trt_node = super().run_node(n) + + # remove "_itensor_to_tensor_meta" + kwargs = dict(n.kwargs) + del kwargs["_itensor_to_tensor_meta"] + n.kwargs = kwargs + + if isinstance(trt_node, trt.tensorrt.ITensor): + self._itensor_to_tensor_meta[trt_node] = n.meta.get("tensor_meta") + + return trt_node + + def placeholder(self, target, args, kwargs): + self._input_names.append(target) + shape, dtype, _, shape_ranges, has_batch_dim = self.input_specs[ + self.input_specs_iter + ] + self.input_specs_iter += 1 + + if self.network.has_implicit_batch_dimension: + if has_batch_dim: + shape = shape[1:] + else: + for i, shape_range in enumerate(shape_ranges): + assert self.optimization_profiles + self.optimization_profiles[i].set_shape(target, *shape_range) + + return self.network.add_input( + name=target, 
shape=tuple(shape), dtype=torch_dtype_to_trt(dtype) + ) + + def call_module(self, target, args, kwargs): + assert isinstance(target, str) + submod = self.fetch_attr(target) + submod_type = getattr(submod, "_base_class_origin", type(submod)) + converter = CONVERTERS.get(submod_type) + + if not converter: + raise RuntimeError( + f"Conversion of module of type {submod_type} not currently supported!" + ) + + assert self._cur_node_name is not None + return converter(self.network, submod, args, kwargs, self._cur_node_name) + + def call_function(self, target, args, kwargs): + converter = CONVERTERS.get(target) + if not converter: + raise RuntimeError( + f"Conversion of function {torch.typename(target)} not currently supported!" + ) + + assert self._cur_node_name is not None + return converter(self.network, target, args, kwargs, self._cur_node_name) + + def call_method(self, target, args, kwargs): + assert isinstance(target, str) + converter = CONVERTERS.get(target) + + if not converter: + raise RuntimeError( + f"Conversion of method {target} not currently supported!" + ) + + assert self._cur_node_name is not None + return converter(self.network, target, args, kwargs, self._cur_node_name) + + def output(self, target, args, kwargs): + assert len(args) == 1 + if isinstance(args[0], tuple): + outputs = args[0] + elif isinstance(args[0], list): + outputs = tuple(args[0]) + else: + outputs = (args[0],) + + if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs): + raise RuntimeError("TensorRT requires all outputs to be Tensor!") + + for i, output in enumerate(outputs): + if any( + op_name in output.name.split("_") + for op_name in ( + "eq", + "gt", + "lt", + "or", + "xor", + "and", + "not", + "ne", + "isinf", + "any", + ) + ): + output_bool = True + else: + output_bool = False + name = f"output{i}" + output.name = name + self.network.mark_output(output) + if output_bool: + output.dtype = trt.bool + elif self.output_fp16 and output.dtype == trt.float32: + output.dtype = trt.float16 + self._output_names.append(name) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py b/py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py new file mode 100644 index 0000000000..7f67e8abbf --- /dev/null +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py @@ -0,0 +1,181 @@ +from typing import Iterable, List, NamedTuple, Optional, Sequence, Tuple + +import torch + +from torch_tensorrt.fx.types import Shape, ShapeRange +from torch_tensorrt.fx.utils import get_dynamic_dims +from torch_tensorrt._Input import Input + + +class InputTensorSpec(NamedTuple): + """ + This class contains the information of a input tensor. + + shape: shape of the tensor. + + dtype: dtyep of the tensor. + + device: device of the tensor. This is only used to generate inputs to the given model + in order to run shape prop. For TensorRT engine, inputs have to be on cuda device. + + shape_ranges: If dynamic shape is needed (shape has dimensions of -1), then this field + has to be provided (default is empty list). Every shape_range is a tuple of three + tuples ((min_input_shape), (optimized_input_shape), (max_input_shape)). Each shape_range + is used to populate a TensorRT optimization profile. + e.g. If the input shape varies from (1, 224) to (100, 224) and we want to optimize + for (25, 224) because it's the most common input shape, then we set shape_ranges to + ((1, 224), (25, 225), (100, 224)). + + has_batch_dim: Whether the shape includes batch dimension. 
Batch dimension has to be provided + if the engine want to run with dynamic shape. + """ + + shape: Shape + dtype: torch.dtype + device: torch.device = torch.device("cpu") + shape_ranges: List[ShapeRange] = [] + has_batch_dim: bool = True + + @classmethod + def from_tensor(cls, tensor: torch.Tensor) -> "InputTensorSpec": + """ + Produce an InputTenosrSpec named tuple which contains the + information of the given PyTorch tensor. + + Args: + tensor (torch.Tensor): A PyTorch tensor. + + Returns: + An InputTensorSpec named tuple. + """ + return cls(tensor.shape, tensor.dtype, tensor.device) + + @classmethod + def from_tensors(cls, tensors: Sequence[torch.Tensor]) -> List["InputTensorSpec"]: + """ + Produce a list of InputTenosrSpec named tuples which contain + the information of all the given PyTorch tensors. + + Args: + tensors (Iterable[torch.Tensor]): A list of PyTorch tensors. + + Returns: + A list of InputTensorSpec named tuples. + """ + assert isinstance(tensors, (list, tuple)) + return [cls.from_tensor(t) for t in tensors] + + @classmethod + def from_input(cls, input_obj: Input) -> "InputTensorSpec": + """ + Produce a list of InputTenosrSpec named tuples which contain + the information of all the given PyTorch tensors. + + Args: + tensors (Iterable[torch.Tensor]): A list of PyTorch tensors. + + Returns: + A list of InputTensorSpec named tuples. + """ + assert isinstance(input_obj, Input) + input_spec = None + if isinstance(input_obj.shape, dict): + min_shape = input_obj.shape["min_shape"] + opt_shape = input_obj.shape["opt_shape"] + max_shape = input_obj.shape["max_shape"] + dyn_shape = [] + for min, opt, max in zip(min_shape, opt_shape, max_shape): + if min == opt == max: + dyn_shape.append(min) + else: + dyn_shape.append(-1) + dtype = input_obj.torch_dtype + input_spec = cls( + shape=dyn_shape, + dtype=dtype, + shape_ranges=[(min_shape, opt_shape, max_shape)], + ) + else: + shape = input_obj.shape + dtype = input_obj.torch_dtype + input_spec = cls(shape=shape, dtype=dtype) + + return input_spec + + @classmethod + def from_tensors_with_dynamic_batch_size( + cls, + tensors: Sequence[torch.Tensor], + batch_size_range: Tuple[int, int, int], + opt_profile_replica: int = 1, + batch_dims: Optional[List[int]] = None, + ) -> List["InputTensorSpec"]: + """ + Produce a list of InputTenosrSpec named tuples which would contain + the information of all the given PyTorch tensors. The produced input + tensor specs will treat all tensors' first dimension as batch dimension + and mark them as dynmaic. + + Args: + tensors (Sequence[torch.Tensor]): A list of PyTorch tensors. + batch_size_range (Tuple[int, int, int]): The first integer indicates + the smallest batch size allowed. The second integer indiceates + the batch size that we'll optimize for. The third integer indicates + the largest batch size allowed. + opt_profile_replica (int): If dynamic shape is enabled, each execution + context requires a different optimization profile. This arg determines + how many optimization profile replicas we want to produce. + batch_dims (Optional[List[int]]): The batch dim might not be the leading dim + and allow user to specify the batch dims using this arg. Default we treat + dim 0 as the batch dim. + + Returns: + A list of InputTensorSpec named tuples with dynamic ranges. 
+ """ + if batch_dims is None: + batch_dims = [0] * len(tensors) + + input_specs = [] + batch_size = tensors[0].size(batch_dims[0]) + + for i, tensor in enumerate(tensors): + batch_dim = batch_dims[i] + assert batch_size == tensor.size( + batch_dim + ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}." + shape = list(tensor.shape) + shape[batch_dim] = -1 + shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica # type: ignore[list-item] + input_specs.append( + cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges) + ) + + return input_specs + + def to_random_tensor(self, id=1): + shape = tuple(self.shape) + if len(get_dynamic_dims(shape)): + # id=0 -> min shape + # id=1 -> optimal shape + # id=2 -> max shape + shape = tuple(self.shape_ranges[0][id]) + elif not self.has_batch_dim: + shape = (1,) + tuple(shape) + + return torch.randn(shape).to(dtype=self.dtype, device=self.device) + + @staticmethod + def create_inputs_from_specs(input_specs: Iterable["InputTensorSpec"]): + inputs = [] + for spec in input_specs: + inputs.append(spec.to_random_tensor()) + + return inputs + + @staticmethod + def create_inputs_from_max_specs(input_specs: Iterable["InputTensorSpec"]): + inputs = [] + for spec in input_specs: + inputs.append(spec.to_random_tensor(2)) + + return inputs diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py b/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py new file mode 100644 index 0000000000..60ace0f12a --- /dev/null +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py @@ -0,0 +1,361 @@ +import dataclasses as dc +import logging +from typing import Any, Callable, Optional, Sequence + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +import torch +import torch.fx as fx +import torch.nn as nn +import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer +from torch.fx.passes.splitter_base import SplitResult + +from .fx2trt import TRTInterpreter, TRTInterpreterResult +from .lower_setting import LowerSetting +from .passes.lower_pass_manager_builder import LowerPassManagerBuilder +from .passes.pass_utils import PassFunc, validate_inference +from torch_tensorrt.fx.tools.timing_cache_utils import TimingCacheManager +from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting + +from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer +from torch_tensorrt.fx.trt_module import TRTModule +from torch_tensorrt.fx.utils import LowerPrecision +from torch_tensorrt._Device import Device + +logger = logging.getLogger(__name__) + +Input = Sequence[Any] + + +def compile( + module: nn.Module, + inputs, + device=torch.device(torch.cuda.current_device()), + disable_tf32=False, + sparse_weights=False, + enabled_precisions=set(), + min_block_size: int = 3, + workspace_size=0, + dla_sram_size=1048576, + dla_local_dram_size=1073741824, + dla_global_dram_size=536870912, + calibrator=None, + truncate_long_and_double=False, + require_full_compilation=False, + debug=False, + refit=False, + timing_cache_prefix="", + save_timing_cache=False, + cuda_graph_batch_size=-1, + is_aten=False, + use_experimental_fx_rt=False, + num_avg_timing_iters=1, + torch_executed_ops=[], + torch_executed_modules=[], + **kwargs, +) -> nn.Module: + """ + Takes in original module, input and lowering setting, run lowering workflow to turn module + into lowered module, or so called TRTModule. 
+ + Args: + module: Original module for lowering. + input: Input for module. + min_block_size: Minimal number of nodes for an accelerated submodule + workspace_size: Maximum size of workspace given to TensorRT. + debug: Enable verbose log for TensorRT if set True. + timing_cache_prefix: Timing cache file name for timing cache used by fx2trt. + save_timing_cache: Update timing cache with current timing cache data if set to True. + cuda_graph_batch_size: Cuda graph batch size, default to be -1. + use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++). + Returns: + A torch.nn.Module lowered by TensorRT. + """ + if use_experimental_fx_rt and not explicit_batch_dimension: + raise ValueError( + "The experimental unifed runtime only supports explicit batch. Please make sure to set explicit_batch_dimension=True when use_experimental_fx_rt=True" + ) + + logger.warn( + "For ir=fx_ts_compat backend only the " + + "following arguments are supported: " + + "{enabled_precisions, debug, workspace_size, device, disable_tf32, sparse_weights, min_block_size}" + ) + + # Parse precision into LowerPrecision + lower_precision = LowerPrecision.FP32 + if torch.float16 in enabled_precisions: + lower_precision = LowerPrecision.FP16 + elif torch.float32 in enabled_precisions: + lower_precision = LowerPrecision.FP32 + else: + raise ValueError(f"Precision {enabled_precisions} not supported on FX") + + # Parse device + if isinstance(device, Device): + if device.gpu_id != -1: + device = torch.device(device.gpu_id) + else: + raise ValueError("Invalid GPU ID provided for the CUDA device provided") + elif isinstance(device, torch.device): + device = device + elif isinstance(device, dict): + if "device_type" in device and device["device_type"] == trt.DeviceType.GPU: + if "gpu_id" in device: + device = torch.device(device["gpu_id"]) + else: + device = torch.device("cuda:0") + else: + raise ValueError( + "Invalid device provided. Supported options: torch.device | torch_tensorrt.Device" + ) + + lower_setting = LowerSetting( + device=device, + min_block_size=min_block_size, + disable_tf32=disable_tf32, + sparse_weights=sparse_weights, + workspace_size=workspace_size, + lower_precision=lower_precision, + debug=debug, + timing_cache_prefix=timing_cache_prefix, + save_timing_cache=save_timing_cache, + cuda_graph_batch_size=cuda_graph_batch_size, + is_aten=is_aten, + use_experimental_rt=use_experimental_fx_rt, + ) + lowerer = Lowerer.create(lower_setting=lower_setting) + return lowerer(module, inputs) + + +@dc.dataclass +class LowerTrtInterpreter: + lower_setting: LowerSetting + timing_cache_manager: TimingCacheManager + + @classmethod + def create(cls, lower_setting): + timing_cache_manager = TimingCacheManager( + lower_setting.timing_cache_prefix, lower_setting.save_timing_cache + ) + return LowerTrtInterpreter(lower_setting, timing_cache_manager) + + def __call__(self, mod, input, split_name) -> TRTInterpreterResult: + assert self.lower_setting.input_specs, "Can't find input specs for lowering!" 
+ logger.info( + f"split_name={split_name}, input_specs={self.lower_setting.input_specs}" + ) + + # Prepare algorithm selector and timing_cache for TRTInterpreter + algo_selector = None + if self.lower_setting.algo_selector: + algo_selector = self.lower_setting.algo_selector(f"{split_name}.json") + cache_data = None + if self.timing_cache_manager: + try: + cache_data = self.timing_cache_manager.get_timing_cache_trt(split_name) + logger.info("Timing cache is used!") + except Exception as e: + logger.warning(f"Cannot load timing cache for {split_name}: {str(e)}") + cache_data = None + + interpreter = TRTInterpreter( + mod, + input_specs=self.lower_setting.input_specs, + explicit_batch_dimension=self.lower_setting.explicit_batch_dimension, + explicit_precision=self.lower_setting.explicit_precision, + logger_level=trt.Logger.VERBOSE + if self.lower_setting.debug + else trt.Logger.WARNING, + ) + + interp_result: TRTInterpreterResult = interpreter.run( + workspace_size=self.lower_setting.workspace_size, + lower_precision=self.lower_setting.lower_precision, + sparse_weights=self.lower_setting.sparse_weights, + disable_tf32=self.lower_setting.disable_tf32, + strict_type_constraints=self.lower_setting.strict_type_constraints, + algorithm_selector=algo_selector, + timing_cache=cache_data, + profiling_verbosity=trt.ProfilingVerbosity.DETAILED + if self.lower_setting.verbose_profile + else trt.ProfilingVerbosity.LAYER_NAMES_ONLY, + tactic_sources=self.lower_setting.tactic_sources, + ) + + # Update timing cache file if needed + timing_cache = interp_result.serialized_cache + if timing_cache and self.timing_cache_manager: + self.timing_cache_manager.update_timing_cache(split_name, timing_cache) + + return interp_result + + +def default_split_function( + model: fx.GraphModule, inputs: Input, lower_setting: LowerSetting +) -> SplitResult: + splitter_setting = TRTSplitterSetting() + splitter_setting.use_implicit_batch_dim = not lower_setting.explicit_batch_dimension + splitter_setting.min_block_size = lower_setting.min_block_size + splitter_setting.use_experimental_rt = lower_setting.use_experimental_rt + splitter = TRTSplitter(model, inputs, settings=splitter_setting) + splitter.node_support_preview() + return splitter.generate_split_results() + + +def create_lower_trt_interpreter(lower_setting: LowerSetting) -> LowerTrtInterpreter: + return LowerTrtInterpreter.create(lower_setting) + + +def default_lower_pass( + create_trt_interpreter: Callable[[LowerSetting], LowerTrtInterpreter], +) -> PassFunc: + def lower_pass( + mod: nn.Module, input: Input, lower_setting: LowerSetting, module_name: str + ) -> nn.Module: + """ + Create a module transformation pass which lowers an `fx.GraphModule` into a + `TRTModule` + """ + interpreter = create_trt_interpreter(lower_setting) + interp_res: TRTInterpreterResult = interpreter(mod, input, module_name) + if lower_setting.use_experimental_rt: + import io + + from torch_tensorrt._Device import Device + from torch_tensorrt._TRTModuleNext import TRTModuleNext + + with io.BytesIO() as engine_bytes: + engine_bytes.write(interp_res.engine.serialize()) + engine_str = engine_bytes.getvalue() + + trt_module = TRTModuleNext( + engine_str, + name=module_name, + input_binding_names=interp_res.input_names, + output_binding_names=interp_res.output_names, + target_device=Device(f"cuda:{torch.cuda.current_device()}"), + # cuda_graph_batch_size=lower_setting.cuda_graph_batch_size, # NOTE: Not sure what this is supposed to do + ) + return trt_module + + else: + trt_module = TRTModule( + 
engine=interp_res.engine, + input_names=interp_res.input_names, + output_names=interp_res.output_names, + cuda_graph_batch_size=lower_setting.cuda_graph_batch_size, + ) + return trt_module + + return lower_pass + + +@dc.dataclass(frozen=True) +class Lowerer: + """Lowers a module using fx2trt. + + This is a composable class to facilitate fx2trt. A normal fx2trt process + composes of the following passes to transform an `fx.GraphModule`: + + 1. trace - use torch.fx to trace the module so we can get the graph + representation of the model. + 2. split - the graph module is split into several submodules, + running either via TensorRT, or via regular CUDA. + + For each split that need to run via TRT, the following passes are + invoked: + + 3. `TRTInterpreter` - build the TRT engine for the submodule that + can be supported through `TRTInterpreter`. + 4. Wraps the executable TRT engine into `TRTModule`, which is an `nn.Module`. + 5. The converted submodule is then set back onto the top-level module + + """ + + lower_pass_manager_builder: LowerPassManagerBuilder + + @classmethod + def create( + cls, + lower_setting: LowerSetting, + interpreter_builder: Callable = create_lower_trt_interpreter, + split_func: Callable = default_split_function, + ) -> "Lowerer": + """Instantiate a `Lowerer` instance.""" + if not lower_setting.is_aten: + return cls( + lower_pass_manager_builder=LowerPassManagerBuilder( + lower_setting=lower_setting, + trace_func=lambda module, inputs: acc_tracer.trace( + module, + inputs, # type: ignore[arg-type] + ast_rewriter_allow_list=lower_setting.ast_rewriter_allow_list, + leaf_module_list=lower_setting.leaf_module_list, + ), + split_func=split_func, + lower_func=default_lower_pass(interpreter_builder), + ) + ) + # proxytensor_trace + else: + return cls( + lower_pass_manager_builder=LowerPassManagerBuilder( + lower_setting=lower_setting, + trace_func=lambda module, inputs: aten_tracer.opt_trace( + module, inputs + ), + split_func=split_func, + lower_func=default_lower_pass(interpreter_builder), + ) + ) + + def __call__( + self, + module: nn.Module, + inputs: Input, + additional_inputs: Optional[Input] = None, + fp16_conversion_fn: Optional[Callable[[Input], Input]] = None, + ) -> nn.Module: + lower_setting = self.lower_pass_manager_builder.lower_setting + atol = lower_setting.correctness_atol + rtol = lower_setting.correctness_rtol + device = lower_setting.device + + @validate_inference( + atol=atol, + rtol=rtol, + device=device, + ) + def do_lower(module: nn.Module, inputs: Input) -> nn.Module: + module.eval() + if ( + self.lower_pass_manager_builder.lower_setting.lower_precision + == LowerPrecision.FP16 + ): + module.half() + # A custom conversion function can be passed to the lowerer to + # handle inputs with custom types. By default, just handle + # tensors and NoneType. 
+ if fp16_conversion_fn is None: + conversion_fn = ( + lambda x: x.half() + if x is not None and x.dtype == torch.float32 + else x + ) + else: + conversion_fn = fp16_conversion_fn + + inputs = tuple(conversion_fn(x) for x in inputs) + if lower_setting.is_aten: + pm = self.lower_pass_manager_builder.build_aten2trt_lower_pipeline( + inputs, additional_inputs + ) + else: + pm = self.lower_pass_manager_builder.build_trt_lower_pipeline( + inputs, additional_inputs + ) + lower_result = pm(module) + return lower_result + + return do_lower(module, inputs) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py b/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py new file mode 100644 index 0000000000..9008bbe8e9 --- /dev/null +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py @@ -0,0 +1,98 @@ +import dataclasses as dc +from typing import List, Optional, Set, Type +import torch +from torch import nn +from torch.fx.passes.pass_manager import PassManager + +from .input_tensor_spec import InputTensorSpec +from torch_tensorrt.fx.passes.lower_basic_pass import ( + fuse_permute_linear, + fuse_permute_matmul, +) +from torch_tensorrt.fx.utils import LowerPrecision + + +@dc.dataclass +class LowerSettingBasic: + """ + Basic class for lowering. + lower_precision: lower precision dtype during lowering. + min_block_size(int): The minimum number of contiguous TensorRT convertable nodes in order to run them in TensorRT + ast_rewriter_allow_list (Optional[Set[nn.Module]]): Optional allow list of + modules that need AST rewriting. This is aiming to eliminate input variable involve in + exception checking control flow. + leaf_module_list (Optional[Set[nn.Module]]): Optional leaf module list where + modules will not be traced into. + verbose_profile (bool): verbosity of profiler, default to False. + """ + + lower_precision: LowerPrecision = LowerPrecision.FP32 + device: torch.device = torch.device(torch.cuda.current_device()) + min_block_size: int = 3 + disable_tf32: bool = False + sparse_weights: bool = False + ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None + leaf_module_list: Optional[Set[Type[nn.Module]]] = None + verbose_profile: bool = False + is_aten: bool = False + + +@dc.dataclass +class LowerSetting(LowerSettingBasic): + """ + Basic configuration for lowering stack. + Args: + input_specs: Specs for inputs to engine, can either be a single size or a + range defined by Min, Optimal, Max sizes. + explicit_precision: Use explicit precision during lowering. + workspace_size: The maximum workspace size. The maximum GPU temporary + memory which the TensorRT engine can use at execution time. + strict_type_constraints: Require TensorRT engine to strictly follow data type + setting at execution time. + customized_fuse_pass: List of custmozied pass to apply during lowering process. + lower_basic_fuse_pass: Enable basic pass fuse duirng lowering, i.e. fuse multiple operations + as (a->b->c->d)=>(e). Current basic fuse patterns are: + permute->linear + permute->matmul + debug: Enable TensorRT engine verbose log mode. + algo_selector: Enable TensorRT algorithm selector at execution time. + timing_cache_prefix: TensorRT timing cache file path. TensorRT engine will use timing + cache file at execution time if valid timing cache file is provided. + save_timing_cache: Save updated timing cache data into timing cache file if the timing + cache file is provided. + cuda_graph_batch_size (int): Cuda graph batch size, default to be -1. 
+ preset_lowerer (str): when specified, use a preset logic to build the + instance of Lowerer. + only used by explicit batch dim with dynamic shape mode. In general, we use 2 GPU setting with + 2 stream on each. Set total number to 8 as a safe default value. + tactic_sources: tactic sources for TensorRT kernel selection. Default to None, + meaning all possible tactic sources. + correctness_atol: absolute tolerance for correctness check + correctness_rtol: relative tolerance for correctness check + use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++). + """ + + input_specs: List[InputTensorSpec] = dc.field(default_factory=list) + explicit_batch_dimension: bool = True + explicit_precision: bool = False + workspace_size: int = 0 + strict_type_constraints: bool = False + customized_fuse_pass: PassManager = dc.field( + default_factory=lambda: PassManager.build_from_passlist([]) + ) + lower_basic_fuse_pass: PassManager = dc.field( + default_factory=lambda: PassManager.build_from_passlist( + [fuse_permute_matmul, fuse_permute_linear] + ) + ) + debug: bool = False + algo_selector = None + timing_cache_prefix: str = "" + save_timing_cache: bool = False + cuda_graph_batch_size: int = -1 + preset_lowerer: str = "" + opt_profile_replica: int = 8 + tactic_sources: Optional[int] = None + correctness_atol: float = 0.1 + correctness_rtol: float = 0.1 + use_experimental_rt: bool = False diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/__init__.py b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py new file mode 100644 index 0000000000..0fd3777254 --- /dev/null +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py @@ -0,0 +1,333 @@ +import datetime +import logging +from functools import partial, wraps +from typing import Any, Callable, Optional, Sequence + +import torch +from torch import nn +from torch.fx.passes.pass_manager import inplace_wrapper, PassManager +from torch.fx.passes.shape_prop import ShapeProp +from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult +from torch_tensorrt.fx.utils import LowerPrecision +from torch_tensorrt import _Input +from ..input_tensor_spec import InputTensorSpec + +from ..lower_setting import LowerSetting +from torch_tensorrt.fx.observer import Observer +from torch_tensorrt.fx.passes.remove_duplicate_output_args import ( + remove_duplicate_output_args, +) +from torch_tensorrt.fx.passes.graph_opts import common_subexpression_elimination +from .pass_utils import extract_example_tensors_from_input + +from torch_tensorrt.fx.passes.lower_basic_pass import ( # noqa + fix_clamp_numerical_limits_to_fp16, + fix_reshape_batch_dim, + replace_mutable_op, + replace_op_with_indices, + run_const_fold, +) + + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +Input = Sequence[Any] + + +# ---------------------------------------------------------------------- +# OBSERVERS +# ---------------------------------------------------------------------- +# List of observers. 
We can subscribe to them by calling its `add(callback)` +# function from anywhere in code: +# +# >>> from torch_tensorrt.fx.lower import FUSE_PASSES_POST_OBSERVER +# >>> with FUSE_PASSES_POST_OBSERVER.add(print_module_and_input): +# >>> # print_module_and_input will be called right after the fuse passes +# >>> lower(module, sample_input) + +# Observer for the model after the fuse passes. +FUSE_PASSES_POST_OBSERVER: Observer[Callable[[nn.Module, Input], None]] = Observer( + "FUSE_PASSES_POST_OBSERVER" +) + +# Observer for the TRT split submodules before lowering +LOWER_SPLIT_PRE_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer( + "LOWER_SPLIT_PRE_OBSERVER" +) + +# Observer for the TRT split submodules after lowering +LOWER_SPLIT_POST_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer( + "LOWER_SPLIT_POST_OBSERVER" +) +# ---------------------------------------------------------------------- + + +def wrapper(fn: Callable, input) -> Callable: + @wraps(fn) + def wrapped_fn(gm): + if isinstance(gm, torch.fx.GraphModule): + ShapeProp(gm).propagate(*input) + return fn(gm, input) + + return wrapped_fn + + +class LowerPassManagerBuilder: + """ + Build PassManager for lowering. + + Attributes: + lower_setting: Setting that will be used during process of lowering, see lower_setting.py for the details. + _trace_func: fx trace function for TRT conversion. + _split_func: the fx2trt split function. + _lower_func: function to create and run `TRTInterpreter` to convert `fx.GraphModule` + into a TensorRT engine. + + """ + + def __init__( + self, + lower_setting: LowerSetting, + trace_func: Callable, + split_func: Callable, + lower_func: Callable, + ): + self.lower_setting = lower_setting + self._trace_func = trace_func + self._split_func = split_func + self._lower_func = lower_func + + def _const_fold_pass(self) -> PassManager: + passes = [ + wrapper(self._trace_func, self._input), + run_const_fold, + ] + return PassManager.build_from_passlist(passes) + + def graph_optimization_pass(self) -> PassManager: + passes = [ + wrapper(self._trace_func, self._input), + ] + for p in self.lower_setting.customized_fuse_pass.passes: + passes.append(wrapper(p, self._input)) + for p in self.lower_setting.lower_basic_fuse_pass.passes: + passes.append(wrapper(p, self._input)) + if ( + hasattr(self.lower_setting, "lower_precision") + and self.lower_setting.lower_precision is LowerPrecision.FP16 + ) or ( + hasattr(self.lower_setting, "precision") + and self.lower_setting.precision is LowerPrecision.FP16 + ): + passes.append(wrapper(fix_clamp_numerical_limits_to_fp16, self._input)) + + passes.append(inplace_wrapper(common_subexpression_elimination)) + passes.append( + inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input)) + ) + passes.append(fix_reshape_batch_dim) + + return PassManager.build_from_passlist(passes) + + def graph_optimization_pass_aten(self) -> PassManager: + passes = [] + + for p in self.lower_setting.customized_fuse_pass.passes: + passes.append(wrapper(p, self._input)) + for p in self.lower_setting.lower_basic_fuse_pass.passes: + passes.append(wrapper(p, self._input)) + # TODO fix this pass for aten graph + # if ( + # hasattr(self.lower_setting, "lower_precision") + # and self.lower_setting.lower_precision is LowerPrecision.FP16 + # ) or ( + # hasattr(self.lower_setting, "precision") + # and self.lower_setting.precision is LowerPrecision.FP16 + # ): + # passes.append(wrapper(fix_clamp_numerical_limits_to_fp16, self._input)) + + passes.append( + 
inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input)) + ) + # TODO we most likely do not need it for aten + # passes.append(fix_reshape_batch_dim) + + return PassManager.build_from_passlist(passes) + + def _split_pass(self) -> PassManager: + passes = [ + partial( + self._split_func, inputs=self._input, lower_setting=self.lower_setting + ) + ] + passes.append( + inplace_wrapper( + lambda split_result: remove_duplicate_output_args( + split_result.split_module, split_result.submodule_inputs.keys() + ) + ) + ) + + return PassManager.build_from_passlist(passes) + + def _trt_lower_pass(self) -> PassManager: + def lower_func(split_result: SplitResult) -> nn.Module: + if ( + hasattr(self.lower_setting, "explicit_batch_dimension") + and self.lower_setting.explicit_batch_dimension + and self._additional_input + ): + additional_submodule_inputs = generate_inputs_for_submodules( + split_result.split_module, + self._additional_input, + list(split_result.submodule_inputs.keys()), + ) + else: + additional_submodule_inputs = None + + for submod_name, submod_inputs in split_result.submodule_inputs.items(): + submod = getattr(split_result.split_module, submod_name) + + LOWER_SPLIT_PRE_OBSERVER.observe(submod_name, submod, submod_inputs) + + # Only acc submodules will be lowered. + if not submod_name.startswith(split_result.non_acc_submodule_prefix): + _LOGGER.info(f"Now lowering submodule {submod_name}") + lowering_start_time = datetime.datetime.now() + + self.lower_setting.input_specs = self._trt_input + + lowered_module = self._lower_func( + submod, submod_inputs, self.lower_setting, submod_name + ) + setattr(split_result.split_module, submod_name, lowered_module) + LOWER_SPLIT_POST_OBSERVER.observe( + submod_name, lowered_module, submod_inputs + ) + _LOGGER.info( + f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}" + ) + + return split_result.split_module + + return PassManager.build_from_passlist([lower_func]) + + def _default_lower_pass(self) -> PassManager: + def lower_func(split_result: SplitResult) -> nn.Module: + if self._additional_input: + additional_submodule_inputs = generate_inputs_for_submodules( + split_result.split_module, + self._additional_input, + list(split_result.submodule_inputs.keys()), + ) + else: + additional_submodule_inputs = None + + for submod_name, submod_inputs in split_result.submodule_inputs.items(): + submod = getattr(split_result.split_module, submod_name) + + LOWER_SPLIT_PRE_OBSERVER.observe(submod_name, submod, submod_inputs) + + # Only acc submodules will be lowered. 
+ if not submod_name.startswith(split_result.non_acc_submodule_prefix): + _LOGGER.info(f"Now lowering submodule {submod_name}") + lowering_start_time = datetime.datetime.now() + + self.lower_setting.additional_inputs = ( + additional_submodule_inputs[submod_name] + if additional_submodule_inputs + else None, + ) + + lowered_module = self._lower_func( + submod, submod_inputs, self.lower_setting, submod_name + ) + setattr(split_result.split_module, submod_name, lowered_module) + LOWER_SPLIT_POST_OBSERVER.observe( + submod_name, lowered_module, submod_inputs + ) + _LOGGER.info( + f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}" + ) + + return split_result.split_module + + return PassManager.build_from_passlist([lower_func]) + + def _default_replace_mutable_op_pass(self) -> PassManager: + return PassManager.build_from_passlist([replace_mutable_op]) + + def build_trt_lower_pipeline( + self, input: Input, additional_input: Optional[Input] = None + ) -> PassManager: + + self._input = extract_example_tensors_from_input( + input, self.lower_setting.device + ) + self._trt_input = [] + for input_obj in input: + if isinstance(input_obj, _Input.Input): + self._trt_input.append(InputTensorSpec.from_input(input_obj)) + elif isinstance(input_obj, torch.Tensor): + self._trt_input.append(InputTensorSpec.from_tensor(input_obj)) + else: + raise ValueError( + "Invalid input type provided in the FX lowering. Expected type: torch_tensorrt.Input or torch.Tensor" + ) + + self._additional_input = additional_input + passes = [] + + passes.append(self._default_replace_mutable_op_pass()) + passes.append(self._const_fold_pass()) + passes.append(self.graph_optimization_pass()) + passes.append(self._split_pass()) + passes.append(self._trt_lower_pass()) + + pm = PassManager.build_from_passlist(passes) + return pm + + def build_aten2trt_lower_pipeline( + self, input: Input, additional_input: Optional[Input] = None + ) -> PassManager: + + self._input = extract_example_tensors_from_input(input) + self._trt_input = [] + for input_obj in input: + if isinstance(input_obj, _Input.Input): + self._trt_input.append(InputTensorSpec.from_input(input_obj)) + elif isinstance(input_obj, torch.Tensor): + self._trt_input.append(InputTensorSpec.from_tensor(input_obj)) + else: + raise ValueError( + "Invalid input type provided in the FX lowering. 
Expected type: torch_tensorrt.Input or torch.Tensor" + ) + + self._additional_input = additional_input + passes = [] + passes.append( + wrapper(self._trace_func, self._input), + ) + passes.append(self.graph_optimization_pass_aten()) + passes.append(self._split_pass()) + passes.append(self._trt_lower_pass()) + + pm = PassManager.build_from_passlist(passes) + return pm + + def build_default_lower_pipeline( + self, input: Input, additional_input: Optional[Input] = None + ) -> PassManager: + self._input = input + self._additional_input = additional_input + passes = [] + + passes.append(self._default_replace_mutable_op_pass()) + passes.append(self._const_fold_pass()) + passes.append(self.graph_optimization_pass()) + passes.append(self._split_pass()) + passes.append(self._default_lower_pass()) + + pm = PassManager.build_from_passlist(passes) + return pm diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py new file mode 100644 index 0000000000..96fa96cfae --- /dev/null +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py @@ -0,0 +1,298 @@ +import io +import logging +import tempfile +from datetime import datetime +from functools import wraps +from typing import Any, Callable, List, Optional + +import torch +from torch import fx +from torch.fx.passes.shape_prop import ShapeProp +from torch_tensorrt import _Input + +# Create an alias for module input type to avoid littering pyre-ignore for Any +# throughout the file. +Input = Any +_LOGGER: logging.Logger = logging.getLogger(__name__) + +PassFunc = Callable[[fx.GraphModule, Input], fx.GraphModule] + +RELAX_ACCURACY_FAILURE: bool = False +FINAL_CHECK_ATOL_MULTIPLIER: float = 10 +FINAL_CHECK_RTOL_MULTIPLIER: float = 10 + + +def extract_example_tensors_from_input( + inputs: Any, device: torch.device = torch.device("cuda") +): + input_tensors = [] + for input_obj in inputs: + if isinstance(input_obj, _Input.Input): + if isinstance(input_obj.shape, dict): + input_tensors.append( + input_obj.example_tensor(optimization_profile_field="opt_shape").to( + device + ) + ) + else: + input_tensors.append(input_obj.example_tensor().to(device)) + elif isinstance(input_obj, torch.Tensor): + input_tensors.append(input_obj) + else: + raise ValueError( + "Invalid input type provided in the FX lowering. Expected type: torch_tensorrt.Input or torch.Tensor" + ) + + return input_tensors + + +class RelaxAccuracyCheckMode: + """ + Basically a context manager that controls a global variable that controls + the accuracy check mode. Use it like + with RelaxAccuracyCheckMode(True): + fx2trt() + """ + + def __init__( + self, + mode: bool, + final_atol_multiplier: Optional[float] = None, + final_rtol_multiplier: Optional[float] = None, + ): + """ + Arguments: + mode: whether we relax the immediate accuracy check failure or not. If yes, we will do an extra + accruacy check by raising the tolerance by the multipler times and only raise error if that fails. + This is to avoid catastrophic errors. + final_atol_multiplier [optional]: set FINAL_CHECK_ATOL_MULTIPLIER if specifier. + final_rtol_multiplier [optional]: set FINAL_CHECK_RTOL_MULTIPLIER if specifier. 
+ """ + global RELAX_ACCURACY_FAILURE + global FINAL_CHECK_ATOL_MULTIPLIER + global FINAL_CHECK_RTOL_MULTIPLIER + self._old_mode = ( + RELAX_ACCURACY_FAILURE, + FINAL_CHECK_ATOL_MULTIPLIER, + FINAL_CHECK_RTOL_MULTIPLIER, + ) + RELAX_ACCURACY_FAILURE = mode + FINAL_CHECK_ATOL_MULTIPLIER = ( + final_atol_multiplier + if final_atol_multiplier + else FINAL_CHECK_ATOL_MULTIPLIER + ) + FINAL_CHECK_RTOL_MULTIPLIER = ( + final_rtol_multiplier + if final_rtol_multiplier + else FINAL_CHECK_RTOL_MULTIPLIER + ) + _LOGGER.info( + f"Set new relaxed accuracy check mode: {RELAX_ACCURACY_FAILURE=}, {FINAL_CHECK_ATOL_MULTIPLIER=}, {FINAL_CHECK_RTOL_MULTIPLIER=}" + ) + + def __enter__(self): + pass + + def __exit__(self, type, value, traceback): + global RELAX_ACCURACY_FAILURE + global FINAL_CHECK_ATOL_MULTIPLIER + global FINAL_CHECK_RTOL_MULTIPLIER + ( + RELAX_ACCURACY_FAILURE, + FINAL_CHECK_ATOL_MULTIPLIER, + FINAL_CHECK_RTOL_MULTIPLIER, + ) = self._old_mode + _LOGGER.info( + f"Restored old relaxed accuracy check mode: {RELAX_ACCURACY_FAILURE=}, {FINAL_CHECK_ATOL_MULTIPLIER=}, {FINAL_CHECK_RTOL_MULTIPLIER=}" + ) + + +def chain_passes(*passes: PassFunc) -> PassFunc: + """ + Chains a sequence of pass functions to form a single pass function + """ + + def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule: + for pass_ in passes: + if isinstance(module, torch.fx.GraphModule): + ShapeProp(module).propagate(*input) + module = pass_(module, input) + return module + + return parent_pass + + +# (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall +# on pass that failed accuracy check. +def validate_inference( + rtol=None, atol=None, device=torch.device(torch.cuda.current_device()) +): + def _validate_inference(pass_: PassFunc) -> PassFunc: + """ + Wraps a pass function to validate that its inference results before and + after the pass run should be `close`. + """ + + @wraps(pass_) + def pass_with_validation( + module: fx.GraphModule, + input: Input, + *args, + **kwargs, + ) -> fx.GraphModule: + input_tensors = extract_example_tensors_from_input(input, device) + res0 = module(*input_tensors) + processed_module = pass_(module, input, *args, **kwargs) + res1 = processed_module(*input_tensors) + tensor_res_0 = _collect_tensors(res0) + tensor_res_1 = _collect_tensors(res1) + relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE + + for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)): + kwargs2 = {"equal_nan": True} + if rtol: + kwargs2["rtol"] = rtol + if atol: + kwargs2["atol"] = atol + kwargs2[ + "msg" + ] = ( + lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}" + ) + # If tensors are on different devices, make sure to compare + # their copies that are on the same device. 
+ if x.get_device() != y.get_device(): + x = x.cpu() + y = y.cpu() + try: + torch.testing.assert_close(x, y, **kwargs2) + except Exception as e: + if relax_accuracy_check_failure: + _LOGGER.error(f"{e}") + kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER + kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER + new_atol = kwargs2["atol"] + new_rtol = kwargs2["rtol"] + _LOGGER.info( + f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}" + ) + torch.testing.assert_close(x, y, **kwargs2) + return processed_module + else: + raise e + + return processed_module + + return pass_with_validation + + return _validate_inference + + +Decorator = Callable[[Callable], Callable] + + +def decorate_method(dec_for_function: Decorator) -> Decorator: + def dec_for_method(unbounded_method) -> Callable: + def decorated_unbounded_method(self, *args, **kwargs): + @dec_for_function + def bounded_method(*args, **kwargs): + return unbounded_method(self, *args, **kwargs) + + return bounded_method(*args, **kwargs) + + return decorated_unbounded_method + + return dec_for_method + + +def log_perf_before_after(pass_: PassFunc) -> PassFunc: + """ + Wraps a pass function to log perf of the module before and after the pass + """ + + @wraps(pass_) + def check_perf_with_before_after_log( + module: fx.GraphModule, input: Input + ) -> fx.GraphModule: + def benchmark_torch_function(iters: int, f, *args) -> float: + """Estimates the average time duration for a single inference call in second + + If the input is batched, then the estimation is for the batches inference call. + + Args: + iters: number of inference iterations to run + f: a function to perform a single inference call + + Returns: + estimated average time duration in second for a single inference call + """ + with torch.inference_mode(): + f(*args) + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + # print("== Start benchmark iterations") + with torch.inference_mode(): + start_event.record() + for _ in range(iters): + f(*args) + end_event.record() + torch.cuda.synchronize() + # print("== End benchmark iterations") + return (start_event.elapsed_time(end_event) * 1.0e-3) / iters + + time_before = benchmark_torch_function(100, lambda: module(*input)) + _LOGGER.info(f"[{pass_}] Perf Before(eager mode): {time_before}") + + module = pass_(module, input) + time_after = benchmark_torch_function(100, lambda: module(*input)) + _LOGGER.info(f"[{pass_}] Perf After(eager mode): {time_after}") + return module + + return check_perf_with_before_after_log + + +def log_before_after(pass_: PassFunc) -> PassFunc: + """ + Wraps a pass function to log the module graph before and after the pass + """ + + @wraps(pass_) + def pass_with_before_after_log( + module: fx.GraphModule, input: Input + ) -> fx.GraphModule: + before_io = io.StringIO() + after_io = io.StringIO() + with tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + delete=False, + ) as f: + print(f"[{pass_}] Before:\n{module.graph}", file=f) + print(module.graph, file=before_io) + start_time = datetime.now() + module = pass_(module, input) + t_elapsed = datetime.now() - start_time + print(f"[{pass_}] After:\n{module.graph}", file=f) + print(module.graph, file=after_io) + t = before_io.getvalue() == after_io.getvalue() + _LOGGER.info( + f"== Log pass {pass_} before/after graph to {f.name}, before/after are the same = {t}, time elapsed = {t_elapsed}" + ) + return module + + return pass_with_before_after_log + + 
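Editor's note: the decorators above (validate_inference, log_before_after, log_perf_before_after) and chain_passes are meant to be stacked on individual lowering passes. The snippet below is a minimal illustrative sketch of that usage, not part of this patch; it assumes a CUDA-enabled environment, and my_noop_pass is a placeholder pass that leaves the module unchanged.

    import torch
    from torch import fx

    from torch_tensorrt.dynamo.fx_ts_compat.passes.pass_utils import (
        chain_passes,
        log_before_after,
        validate_inference,
    )

    # Hypothetical pass, for illustration only: returns the module unchanged.
    @validate_inference(atol=1e-3, rtol=1e-3)
    @log_before_after
    def my_noop_pass(gm: fx.GraphModule, input):
        return gm

    # chain_passes folds any number of instrumented passes into a single PassFunc.
    pipeline = chain_passes(my_noop_pass)

    mod = fx.symbolic_trace(torch.nn.ReLU()).cuda()
    sample = [torch.randn(2, 4, device="cuda")]
    mod = pipeline(mod, sample)

Because validate_inference is outermost, the before/after correctness check covers whatever the inner (logged) pass produced; wrapping the call in RelaxAccuracyCheckMode(True) retries a failed tolerance check with the multiplied tolerances instead of raising immediately.
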
+def _collect_tensors(arg: fx.node.Argument) -> List[torch.Tensor]: + """Collects all the tensors found in a nested container object""" + res: List[torch.Tensor] = [] + + def collect(x: fx.node.Argument) -> fx.node.Argument: + if isinstance(x, torch.Tensor): + res.append(x) + return x + + fx.node.map_aggregate(arg, collect) + return res diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input.py new file mode 100644 index 0000000000..b7dd8153cb --- /dev/null +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input.py @@ -0,0 +1,89 @@ +# Owner(s): ["oncall: gpu_enablement"] + +import io +import os + +import torch +import torch_tensorrt +from torch.testing._internal.common_utils import run_tests, TestCase + + +class TestInput(TestCase): + def test_add_model(self): + class TestModule(torch.nn.Module): + def forward(self, x): + return x + x + + inputs = [torch_tensorrt.Input(shape=(1, 3, 3, 4), dtype=torch.float32)] + rand_inputs = [torch.randn((1, 3, 3, 4), dtype=torch.float32).cuda()] + mod = TestModule().cuda().eval() + ref_output = mod(*rand_inputs) + + trt_mod = torch_tensorrt.compile( + mod, + ir="fx_ts_compat", + inputs=inputs, + min_block_size=1, + ) + trt_output = trt_mod(*rand_inputs) + + torch.testing.assert_close(trt_output, ref_output, rtol=1e-04, atol=1e-04) + + def test_conv_model(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 1, 1, 1, 1, 1, True) + + def forward(self, x): + return self.conv(x) + + inputs = [torch_tensorrt.Input(shape=(1, 3, 32, 32), dtype=torch.float32)] + rand_inputs = [torch.randn((1, 3, 32, 32), dtype=torch.float32).cuda()] + mod = TestModule().cuda().eval() + ref_output = mod(*rand_inputs) + + trt_mod = torch_tensorrt.compile( + mod, + ir="fx_ts_compat", + inputs=inputs, + min_block_size=1, + ) + trt_output = trt_mod(*rand_inputs) + + torch.testing.assert_close(trt_output, ref_output, rtol=1e-04, atol=1e-04) + + def test_conv_model_with_dyn_shapes(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 1, 1, 1, 1, 1, True) + + def forward(self, x): + return self.conv(x) + + inputs = [ + torch_tensorrt.Input( + min_shape=(1, 3, 32, 32), + opt_shape=(8, 3, 32, 32), + max_shape=(16, 3, 32, 32), + dtype=torch.float32, + ) + ] + rand_inputs = [torch.randn((4, 3, 32, 32), dtype=torch.float32).cuda()] + mod = TestModule().cuda().eval() + ref_output = mod(*rand_inputs) + + trt_mod = torch_tensorrt.compile( + mod, + ir="fx_ts_compat", + inputs=inputs, + min_block_size=1, + ) + trt_output = trt_mod(*rand_inputs) + + torch.testing.assert_close(trt_output, ref_output, rtol=1e-04, atol=1e-04) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py new file mode 100644 index 0000000000..0761b964f8 --- /dev/null +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py @@ -0,0 +1,84 @@ +# Owner(s): ["oncall: gpu_enablement"] + +from typing import List, Optional + +import torch +import torch_tensorrt +from torch.testing._internal.common_utils import run_tests, TestCase +from torch_tensorrt.dynamo.fx_ts_compat import InputTensorSpec, LowerSetting + + +class TestTRTModule(TestCase): + def _validate_spec( + self, + spec: InputTensorSpec, + tensor: torch.Tensor, + dynamic_dims: 
Optional[List[int]] = None, + ): + expected_shape = list(tensor.shape) + if dynamic_dims: + for dim in dynamic_dims: + expected_shape[dim] = -1 + self.assertSequenceEqual(spec.shape, expected_shape) + self.assertEqual(spec.dtype, tensor.dtype) + self.assertEqual(spec.device, tensor.device) + self.assertTrue(spec.has_batch_dim) + + def test_from_tensor(self): + tensor = torch.randn(1, 2, 3) + spec = InputTensorSpec.from_tensor(tensor) + self._validate_spec(spec, tensor) + + def test_from_tensors(self): + tensors = [torch.randn(1, 2, 3), torch.randn(2, 4)] + specs = InputTensorSpec.from_tensors(tensors) + for spec, tensor in zip(specs, tensors): + self._validate_spec(spec, tensor) + + def test_from_tensors_with_dynamic_batch_size(self): + tensors = [torch.randn(1, 2, 3), torch.randn(1, 4)] + batch_size_range = [2, 3, 4] + specs = InputTensorSpec.from_tensors_with_dynamic_batch_size( + tensors, batch_size_range + ) + for spec, tensor in zip(specs, tensors): + self._validate_spec(spec, tensor, dynamic_dims=[0]) + + for batch_size, shape in zip(batch_size_range, spec.shape_ranges[0]): + self.assertEqual(batch_size, shape[0]) + self.assertSequenceEqual(tensor.shape[1:], shape[1:]) + + def test_from_tensors_with_dynamic_batch_size_different_batch_dims(self): + tensors = [torch.randn(1, 2, 3), torch.randn(2, 1, 4)] + batch_size_range = [2, 3, 4] + specs = InputTensorSpec.from_tensors_with_dynamic_batch_size( + tensors, batch_size_range, batch_dims=[0, 1] + ) + for i, spec_and_tensor in enumerate(zip(specs, tensors)): + spec, tensor = spec_and_tensor + self._validate_spec(spec, tensor, dynamic_dims=[i]) + + for batch_size, shape in zip(batch_size_range, spec.shape_ranges[0]): + self.assertEqual(batch_size, shape[i]) + tensor_shape = list(tensor.shape) + tensor_shape[i] = batch_size + self.assertSequenceEqual(tensor_shape, shape) + + def test_from_static_input(self): + tensors = [torch.randn(1, 2, 3), torch.randn(2, 1, 4)] + inputs = torch_tensorrt.Input.from_tensors(tensors) + specs = [InputTensorSpec.from_input(input) for input in inputs] + for spec, tensor in zip(specs, tensors): + self._validate_spec(spec, tensor) + + def test_from_dynamic_input(self): + inputs = torch_tensorrt.Input( + min_shape=(2, 2, 3), opt_shape=(4, 2, 3), max_shape=(8, 2, 3) + ) + example_tensor = inputs.example_tensor(optimization_profile_field="opt_shape") + spec = InputTensorSpec.from_input(inputs) + self._validate_spec(spec, example_tensor, dynamic_dims=[0]) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/__init__.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/__init__.py new file mode 100644 index 0000000000..6423aa65ea --- /dev/null +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/__init__.py @@ -0,0 +1 @@ +from .trt_minimizer import * # noqa: F401 F403 diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py new file mode 100644 index 0000000000..334243fef4 --- /dev/null +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py @@ -0,0 +1,446 @@ +import logging +import time +import unittest +from typing import Callable, List, Optional, Set, Tuple + +import torch +import torch.fx + +import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer +import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer +from torch_tensorrt.fx import TRTModule +from torch.fx.experimental.normalize import NormalizeArgs +from torch.fx.passes import shape_prop +from 
torch.fx.passes.infra.pass_base import PassResult +from torch.testing._internal.common_utils import TestCase +from torch_tensorrt.dynamo.fx_ts_compat import InputTensorSpec, TRTInterpreter +from torch_tensorrt.fx.passes.lower_basic_pass_aten import ( + compose_bmm, + compose_chunk, + compose_getitem_slice, + remove_ops, + replace_aten_op_with_indices, + replace_aten_reshape_alias_with_replace, + replace_builtin_ops, + replace_native_layernorm_with_layernorm, + replace_transpose_mm_op_with_linear, + run_const_fold, +) +from torch_tensorrt.dynamo.fx_ts_compat.passes.pass_utils import chain_passes +from torch_tensorrt.fx.utils import LowerPrecision, proxytensor_trace + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +def fetch_attr(mod, target): + """ + Fetch an attribute from the ``Module`` hierarchy of ``mod.module``. + + Args: + target (str): The fully-qualfiied name of the attribute to fetch + + Return: + Any: The value of the attribute. + """ + target_atoms = target.split(".") + attr_itr = mod + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +@unittest.skipIf(not torch.cuda.is_available(), "Skip because CUDA is not available") +class TRTTestCase(TestCase): + def setUp(self): + super().setUp() + torch.manual_seed(3) + + def run_test( + self, + mod, + inputs, + expected_ops, + unexpected_ops, + interpreter, + rtol, + atol, + precision=LowerPrecision.FP32, + ): + with torch.no_grad(): + cuda_inputs = [] + for i in inputs: + cuda_inputs.append(i.cuda()) + + mod.eval() + if len(expected_ops): + self.assert_has_op(mod, expected_ops) + if unexpected_ops: + self.assert_unexpected_op(mod, unexpected_ops) + start = time.perf_counter() + interpreter_result = interpreter.run(lower_precision=precision) + sec = time.perf_counter() - start + _LOGGER.info(f"Interpreter run time(s): {sec}") + trt_mod = TRTModule( + interpreter_result.engine, + interpreter_result.input_names, + interpreter_result.output_names, + ) + + ref_outputs = mod(*inputs) + + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + outputs = trt_mod(*cuda_inputs) + end_event.record() + torch.cuda.synchronize() + _LOGGER.info( + f"TRT run time(s)= {(start_event.elapsed_time(end_event) * 1.0e-3)}" + ) + + if type(outputs) not in (list, tuple): + outputs = [outputs] + if type(ref_outputs) not in ( + list, + tuple, + torch.return_types.max, + torch.return_types.min, + ): + ref_outputs = [ref_outputs] + for out, ref in zip(outputs, ref_outputs): + if not isinstance(ref, torch.Tensor): + ref = torch.tensor([ref]) + ref = ref.cpu() # to_dtype test has cases with gpu output + if ref.dtype == torch.int64: + ref = ref.int() # convert torch.max's index output tensor to int32 + torch.testing.assert_close( + out.cpu(), ref, rtol=rtol, atol=atol, equal_nan=True + ) + + def run_test_custom_compare_results( + self, + mod, + inputs, + expected_ops, + interpreter, + comparators: List[Tuple[Callable, List]], + fp16_mode=False, + ): + """ + Runs the test and compares the result using the provided comparators. + The size of comparators must be equal to the number of outputs from 'mod'. + + mod - a model to run. + inputs - a list of the model inputs. + expected ops - a list of ops that should be verified. + interpreter - used for converting the model to TRT. 
+ comparators - a list of (func, args) pairs corresponding to each of + the module outputs. usage: func(x, y, *args) + + """ + with torch.no_grad(): + cuda_inputs = [] + for i in inputs: + cuda_inputs.append(i.cuda()) + + mod.eval() + if len(expected_ops): + self.assert_has_op(mod, expected_ops) + + interpreter_result = interpreter.run( + lower_precision=LowerPrecision.FP16 + if fp16_mode + else LowerPrecision.FP32 + ) + trt_mod = TRTModule( + interpreter_result.engine, + interpreter_result.input_names, + interpreter_result.output_names, + ) + res_trt = trt_mod(*cuda_inputs).cpu() + res_cpu = mod(*inputs) + assert len(res_trt) == len(res_cpu) + assert len(res_cpu) == len(comparators) + for output_trt, output_cpu, comparator in zip( + res_trt, res_cpu, comparators + ): + comp_func = comparator[0] + args = comparator[1] + self.assertTrue(comp_func(output_trt, output_cpu, *args)) + + def run_test_with_error(self, mod, inputs, interpreter, expect_error): + with self.assertRaises(expect_error): + with torch.no_grad(): + cuda_inputs = [] + for i in inputs: + cuda_inputs.append(i.cuda()) + + mod.eval() + interpreter.run(lower_precision=LowerPrecision.FP32) + + def assert_has_op(self, mod, ops): + ops_in_mod = set() + + for node in mod.graph.nodes: + if node.op == "call_module": + ops_in_mod.add(type(fetch_attr(mod, node.target))) + elif node.op in {"call_function", "call_method"}: + ops_in_mod.add(node.target) + + self.assertTrue( + ops_in_mod >= ops, f"expected ops {ops}, actuall ops {ops_in_mod}" + ) + + def assert_unexpected_op(self, mod, ops): + for node in mod.graph.nodes: + if node.op == "call_module": + if type(fetch_attr(mod, node.target)) in ops: + return False + elif node.op in {"call_function", "call_method"}: + if node.target in ops: + return False + return True + + +class VanillaTestCase(TRTTestCase): + def run_test(self, mod, inputs, expected_ops, rtol=1e-03, atol=1e-03): + mod = torch.fx.symbolic_trace(mod) + shape_prop.ShapeProp(mod).propagate(*inputs) + mod = NormalizeArgs(mod).transform() + interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs)) + super().run_test(mod, inputs, expected_ops, None, interp, rtol, atol) + + def run_test_custom_compare_results( + self, + mod, + inputs, + expected_ops, + interpreter, + comparators: List[Tuple[Callable, List]], + fp16_mode=False, + ): + # interpreter is ignored, we do not need this for Vanilla tests + # Note this is different from internal version, we need to fix the test case + # after we refactor the internal callsites to use this file + mod = torch.fx.symbolic_trace(mod) + shape_prop.ShapeProp(mod).propagate(*inputs) + mod = NormalizeArgs(mod).transform() + interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs)) + super().run_test_custom_compare_results( + mod, inputs, expected_ops, interp, comparators, fp16_mode=fp16_mode + ) + + +class AccTestCase(TRTTestCase): + def run_test( + self, + mod, + inputs, + expected_ops, + unexpected_ops=None, + apply_passes=None, + test_explicit_batch_dim=True, + test_implicit_batch_dim=False, + test_explicit_precision=False, + rtol=1e-03, + atol=1e-03, + precision=LowerPrecision.FP32, + ): + mod.eval() + mod = acc_tracer.trace(mod, inputs) + + if apply_passes is not None: + pass_tracer = chain_passes(*apply_passes) + mod = pass_tracer(mod, inputs) + + if test_implicit_batch_dim: + interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs)) + super().run_test( + mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision + ) + + if test_explicit_batch_dim: + 
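+            # Re-run the same assertions with the interpreter built in explicit-batch-dimension mode.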
interp = TRTInterpreter( + mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True + ) + super().run_test( + mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision + ) + + if test_explicit_precision: + interp = TRTInterpreter( + mod, + InputTensorSpec.from_tensors(inputs), + explicit_precision=test_explicit_precision, + ) + super().run_test( + mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol + ) + + interp = TRTInterpreter( + mod, + InputTensorSpec.from_tensors(inputs), + explicit_batch_dimension=True, + explicit_precision=test_explicit_precision, + ) + super().run_test( + mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision + ) + + def run_test_with_assert_error( + self, + mod, + inputs, + expect_error, + test_explicit_batch_dim=True, + test_implicit_batch_dim=True, + ): + mod.eval() + mod = acc_tracer.trace(mod, inputs) + + if test_implicit_batch_dim: + interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs)) + super().run_test_with_error(mod, inputs, interp, expect_error) + + if test_explicit_batch_dim: + interp = TRTInterpreter( + mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True + ) + super().run_test_with_error(mod, inputs, interp, expect_error) + + def run_test_with_dynamic_shape( + self, + mod, + input_specs, + expected_ops, + unexpected_ops=None, + rtol=1e-03, + atol=1e-03, + ): + mod.eval() + inputs = InputTensorSpec.create_inputs_from_specs(input_specs) + mod = acc_tracer.trace(mod, inputs) + interp = TRTInterpreter(mod, input_specs, explicit_batch_dimension=True) + super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol) + + +class DispatchTestCase(TRTTestCase): + def generate_graph( + self, + mod: torch.nn.Module, + original_inputs: List[torch.Tensor], + expected_ops: Set[Callable], + unexpected_ops: Optional[Set[Callable]] = None, + customized_passes: List[Callable] = None, + ): + # Torchdynamo+aot proxytensor tracer + # Below are common passes + passes_list = [ + compose_bmm, + compose_chunk, + compose_getitem_slice, + replace_aten_reshape_alias_with_replace, + replace_aten_op_with_indices, + replace_transpose_mm_op_with_linear, # after compose_bmm + replace_native_layernorm_with_layernorm, + remove_ops, + replace_builtin_ops, # after replace_native_layernorm_with_layernorm + ] + # Combine with customized passes specific to any model + if customized_passes: + passes_list.extend(customized_passes) + fx_module, _ = aten_tracer.trace(mod, original_inputs) + for passes in passes_list: + pr: PassResult = passes(fx_module) + fx_module = pr.graph_module + fx_module(*original_inputs) + + fx_module = run_const_fold(fx_module) + _LOGGER.info(f"FX graph= {fx_module.graph}") + + if len(expected_ops): + self.assert_has_op(fx_module, expected_ops) + if unexpected_ops: + self.assert_unexpected_op(fx_module, unexpected_ops) + + return fx_module + + def run_test( + self, + mod, + inputs, + expected_ops, + unexpected_ops=None, + apply_passes=None, + test_explicit_batch_dim=True, + test_explicit_precision=False, + rtol=1e-03, + atol=1e-03, + precision=LowerPrecision.FP32, + ): + mod.eval() + mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None) + + if apply_passes is not None: + pass_tracer = chain_passes(*apply_passes) + mod = pass_tracer(mod, inputs) + + if test_explicit_batch_dim: + interp = TRTInterpreter( + mod, + InputTensorSpec.from_tensors(inputs), + explicit_batch_dimension=True, + ) + super().run_test( + mod, inputs, expected_ops, 
unexpected_ops, interp, rtol, atol, precision + ) + + if test_explicit_precision: + interp = TRTInterpreter( + mod, + InputTensorSpec.from_tensors(inputs), + explicit_precision=test_explicit_precision, + ) + super().run_test( + mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol + ) + + interp = TRTInterpreter( + mod, + InputTensorSpec.from_tensors(inputs), + explicit_batch_dimension=True, + explicit_precision=test_explicit_precision, + ) + super().run_test( + mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision + ) + + def run_test_with_dynamic_shape( + self, + mod, + input_specs, + expected_ops, + unexpected_ops=None, + rtol=1e-03, + atol=1e-03, + ): + mod.eval() + inputs = InputTensorSpec.create_inputs_from_specs(input_specs) + mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None) + + interp = TRTInterpreter( + mod, + input_specs, + explicit_batch_dimension=True, + ) + # Since the lowering is based on optimal shape. We need to test with + # different shape(for ex. max shape) for testing dynamic shape + inputs_max = InputTensorSpec.create_inputs_from_max_specs(input_specs) + super().run_test( + mod, inputs_max, expected_ops, unexpected_ops, interp, rtol, atol + ) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py new file mode 100644 index 0000000000..f5c15b049b --- /dev/null +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py @@ -0,0 +1,101 @@ +import logging +from typing import Any, Callable, Tuple + +import torch +import torch.fx.passes.net_min_base as net_min_base +from torch.fx.passes.tools_common import Tensors + +from .. import InputTensorSpec, TRTInterpreter + +from torch_tensorrt.fx import TRTModule + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +def lower_mod_default( + mod: torch.fx.GraphModule, + inputs: Tensors, + use_experimental_rt: bool = False, +) -> TRTModule: + interp = TRTInterpreter( + mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True + ) + interpreter_result = interp.run() + if use_experimental_rt: + import io + + from torch_tensorrt._Device import Device + from torch_tensorrt._TRTModuleNext import TRTModuleNext + + with io.BytesIO() as engine_bytes: + engine_bytes.write(interpreter_result.engine.serialize()) + engine_str = engine_bytes.getvalue() + + res_mod = TRTModuleNext( + engine_str, + name=str(type(mod)), + input_binding_names=interpreter_result.input_names, + output_binding_names=interpreter_result.output_names, + target_device=Device(f"cuda:{torch.cuda.current_device()}"), + # cuda_graph_batch_size=lower_setting.cuda_graph_batch_size, # NOTE: Not sure what this is supposed to do + ) + else: + res_mod = TRTModule( + interpreter_result.engine, + interpreter_result.input_names, + interpreter_result.output_names, + ) + return res_mod + + +class TensorRTMinizerSetting(net_min_base._MinimizerSettingBase): + def __init__( + self, explicit_batch_dimension: Any = True, use_experimental_rt: bool = False + ): + if use_experimental_rt and not explicit_batch_dimension: + raise ValueError( + "The experimental unifed runtime only supports explicit batch. 
Please make sure to set explicit_batch_dimension=True when use_experimental_rt=True" + ) + + self.explicit_batch_dimension = explicit_batch_dimension + self.use_experimental_rt = use_experimental_rt + super(TensorRTMinizerSetting, self).__init__() + + +class TensorRTMinimizer(net_min_base._MinimizerBase): + def __init__( + self, + module: torch.fx.GraphModule, + sample_input: Tensors, + compare_fn: Callable[[Any, Any, Any], Tuple[float, bool]], + settings: TensorRTMinizerSetting = TensorRTMinizerSetting(), + lower_fn: Callable[ + [torch.fx.GraphModule, Tensors, Any, bool], TRTModule + ] = lower_mod_default, + ): + self.lower_fn = lower_fn + self.use_experiemental_rt = settings.use_experimental_rt + super().__init__(module, sample_input, compare_fn, settings) + + def run_a(self, mod, inputs): + mod.eval() + with torch.no_grad(): + return mod(*inputs) + + def run_b(self, mod, inputs): + mod.eval() + try: + mod = self.lower_fn(mod, inputs, self.use_experiemental_rt) + output = mod(*inputs) + except RuntimeError as e: + raise net_min_base.FxNetMinimizerRunFuncError( + f"Encounter an error when processing \n{mod.graph}\n {e}" + ) + else: + return output + + def get_nodes(self, start=None, end=None, enable_print=False): + nodes = self._collect_nodes(start, end) + if enable_print: + _LOGGER.info(f"Nodes fetched from start {start} to end {end} as: {nodes}") + return nodes diff --git a/py/torch_tensorrt/dynamo/test/conftest.py b/py/torch_tensorrt/dynamo/test/conftest.py new file mode 100644 index 0000000000..98be643435 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/conftest.py @@ -0,0 +1,18 @@ +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--ir", + metavar="Internal Representation", + nargs=1, + type=str, + required=True, + help="IR to compile with", + choices=["torch_compile", "fx_ts_compat"], + ) + + +@pytest.fixture +def ir(request): + return request.config.getoption("--ir")[0] diff --git a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py new file mode 100644 index 0000000000..4852f033bd --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py @@ -0,0 +1,144 @@ +import torch +import timm +import pytest + +import torch_tensorrt as torchtrt +import torchvision.models as models + +from transformers import BertModel + +from utils import COSINE_THRESHOLD, cosine_similarity + + +@pytest.mark.unit +def test_resnet18(ir): + model = models.resnet18(pretrained=True).eval().to("cuda") + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "ir": ir, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + assert ( + cos_sim > COSINE_THRESHOLD, + f"Resnet50 TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +def test_mobilenet_v2(ir): + model = models.mobilenet_v2(pretrained=True).eval().to("cuda") + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "ir": ir, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + assert ( + cos_sim > COSINE_THRESHOLD, + f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +def test_efficientnet_b0(ir): + model = timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda") + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "ir": ir, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + assert ( + cos_sim > COSINE_THRESHOLD, + f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +def test_bert_base_uncased(ir): + model = BertModel.from_pretrained("bert-base-uncased").cuda().eval() + input = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda") + input2 = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, + dtype=input.dtype, + format=torch.contiguous_format, + ), + torchtrt.Input( + input.shape, + dtype=input.dtype, + format=torch.contiguous_format, + ), + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "truncate_long_and_double": True, + "debug": True, + "ir": ir, + } + trt_mod = torchtrt.compile(model, **compile_spec) + + model_outputs = model(input, input2) + trt_model_outputs = trt_mod(input, input2) + for key in model_outputs.keys(): + out, trt_out = model_outputs[key], trt_model_outputs[key] + cos_sim = cosine_similarity(out, trt_out) + assert ( + cos_sim > COSINE_THRESHOLD, + f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +def test_resnet18_half(ir): + model = models.resnet18(pretrained=True).eval().to("cuda").half() + input = torch.randn((1, 3, 224, 224)).to("cuda").half() + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.half, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.half}, + "ir": ir, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + assert ( + cos_sim > COSINE_THRESHOLD, + f"Resnet50 Half TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) diff --git a/py/torch_tensorrt/dynamo/test/utils.py b/py/torch_tensorrt/dynamo/test/utils.py new file mode 100644 index 0000000000..b1e6632ec3 --- /dev/null +++ b/py/torch_tensorrt/dynamo/test/utils.py @@ -0,0 +1,15 @@ +import torch + +COSINE_THRESHOLD = 0.99 + + +def cosine_similarity(gt_tensor, pred_tensor): + gt_tensor = gt_tensor.flatten().to(torch.float32) + pred_tensor = pred_tensor.flatten().to(torch.float32) + if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0: + if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True): + return 1.0 + res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6) + res = res.cpu().detach().item() + + return res diff --git a/py/torch_tensorrt/dynamo/torch_compile/__init__.py b/py/torch_tensorrt/dynamo/torch_compile/__init__.py new file mode 100644 index 0000000000..32e5567c51 --- /dev/null +++ b/py/torch_tensorrt/dynamo/torch_compile/__init__.py @@ -0,0 +1,126 @@ +import torch +import logging +import collections.abc +import torch_tensorrt +from functools import partial + +from typing import Any +from torch_tensorrt import EngineCapability, Device +from torch_tensorrt.fx.utils import LowerPrecision + +from torch_tensorrt.dynamo.torch_compile._settings import CompilationSettings +from torch_tensorrt.dynamo.torch_compile.utils import prepare_inputs, prepare_device +from torch_tensorrt.dynamo.torch_compile.backends import tensorrt_backend +from torch_tensorrt.dynamo.torch_compile._defaults import ( + PRECISION, + DEBUG, + MAX_WORKSPACE_SIZE, + MAX_NUM_TRT_ENGINES, +) + + +logger = logging.getLogger(__name__) + + +def compile( + gm: torch.nn.Module, + inputs: Any, + *, + device=Device._current_device(), + disable_tf32=False, + sparse_weights=False, + enabled_precisions=set(), + refit=False, + debug=DEBUG, + capability=EngineCapability.default, + num_avg_timing_iters=1, + workspace_size=MAX_WORKSPACE_SIZE, + dla_sram_size=1048576, + dla_local_dram_size=1073741824, + dla_global_dram_size=536870912, + calibrator=None, + truncate_long_and_double=False, + require_full_compilation=False, + min_block_size=3, + torch_executed_ops=[], + torch_executed_modules=[], + **kwargs, +): + + logger.warn( + "The Dynamo backend is an experimental feature, for which only the " + + "following arguments are supported: " + + "{enabled_precisions, debug, workspace_size, max_num_trt_engines}" + ) + + if not isinstance(inputs, collections.abc.Sequence): + inputs = [inputs] + + inputs = prepare_inputs(inputs, prepare_device(device)) + + if ( + torch.float16 in enabled_precisions + or torch_tensorrt.dtype.half in enabled_precisions + ): + lower_precision = LowerPrecision.FP16 + elif ( + torch.float32 in enabled_precisions + or torch_tensorrt.dtype.float in enabled_precisions + ): + lower_precision = LowerPrecision.FP32 + elif len(enabled_precisions) == 0: + logger.info(f"No precision specified, defaulting to {PRECISION}") + lower_precision = PRECISION + else: + raise ValueError( + f"Precision {enabled_precisions} not supported in the Dynamo Path" + ) + + custom_backend = create_backend( + precision=lower_precision, + debug=debug, + workspace_size=workspace_size, + **kwargs, + ) + + model = torch.compile(gm, backend=custom_backend) + + # Ensure compilation occurs by calling the function with provided inputs + model(*inputs) + + return model + + +from torch_tensorrt.fx.utils import LowerPrecision + +logger = logging.getLogger(__name__) + + +def create_backend( + 
precision: LowerPrecision = PRECISION, + debug: bool = DEBUG, + workspace_size: int = MAX_WORKSPACE_SIZE, + max_num_trt_engines: int = MAX_NUM_TRT_ENGINES, + **kwargs, +): + """Create torch.compile backend given specified arguments + + Args: + precision: + debug: Whether to print out verbose debugging information + workspace_size: Maximum workspace TRT is allowed to use for the module + precision: Model Layer precision + Returns: + Backend for torch.compile + """ + settings = CompilationSettings( + debug=debug, + precision=precision, + workspace_size=workspace_size, + max_num_trt_engines=max_num_trt_engines, + ) + + return partial( + tensorrt_backend, + settings=settings, + ) diff --git a/py/torch_tensorrt/dynamo/torch_compile/_defaults.py b/py/torch_tensorrt/dynamo/torch_compile/_defaults.py new file mode 100644 index 0000000000..48c9a26f9e --- /dev/null +++ b/py/torch_tensorrt/dynamo/torch_compile/_defaults.py @@ -0,0 +1,7 @@ +from torch_tensorrt.fx.utils import LowerPrecision + + +PRECISION = LowerPrecision.FP32 +DEBUG = False +MAX_WORKSPACE_SIZE = 20 << 30 +MAX_NUM_TRT_ENGINES = 200 diff --git a/py/torch_tensorrt/dynamo/torch_compile/_settings.py b/py/torch_tensorrt/dynamo/torch_compile/_settings.py new file mode 100644 index 0000000000..276b8742ff --- /dev/null +++ b/py/torch_tensorrt/dynamo/torch_compile/_settings.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass + +from torch_tensorrt.fx.utils import LowerPrecision +from torch_tensorrt.dynamo.torch_compile._defaults import ( + PRECISION, + DEBUG, + MAX_WORKSPACE_SIZE, + MAX_NUM_TRT_ENGINES, +) + + +@dataclass(frozen=True) +class CompilationSettings: + precision: LowerPrecision = PRECISION + debug: bool = DEBUG + workspace_size: int = MAX_WORKSPACE_SIZE + max_num_trt_engines: int = MAX_NUM_TRT_ENGINES diff --git a/py/torch_tensorrt/dynamo/torch_compile/backends.py b/py/torch_tensorrt/dynamo/torch_compile/backends.py new file mode 100644 index 0000000000..9ceab947f0 --- /dev/null +++ b/py/torch_tensorrt/dynamo/torch_compile/backends.py @@ -0,0 +1,118 @@ +from typing import Sequence +import torch +import traceback +from functools import partial +import torch._dynamo as td + +from torch_tensorrt.dynamo.torch_compile._settings import CompilationSettings +from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import ( + get_decompositions, +) +from torch_tensorrt.dynamo.torch_compile.lowering._partition import ( + partition, + get_submod_inputs, +) +from torch_tensorrt.dynamo.torch_compile.conversion import convert_module + +from torch._dynamo.backends.common import fake_tensor_unsupported + +from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler + + +@td.register_backend(name="tensorrt") +@fake_tensor_unsupported +def tensorrt_backend( + gm: torch.nn.Module, + sample_inputs: Sequence[torch.Tensor], + settings: CompilationSettings = CompilationSettings(), +): + custom_backend = partial( + fx_dynamo_backend, + settings=settings, + ) + + # Invoke AOTAutograd to translate operators to aten + return aot_module_simplified( + gm, + sample_inputs, + fw_compiler=make_boxed_compiler(custom_backend), + decompositions=get_decompositions(), + ) + + +@td.register_backend(name="fx_tensorrt") +@fake_tensor_unsupported +def fx_dynamo_backend( + gm: torch.fx.GraphModule, + example_inputs: Sequence[torch.Tensor], + settings: CompilationSettings = CompilationSettings(), +): + """Helper function to manage translation of FX module to TRT engines + + Args: + module: FX GraphModule to convert + inputs: Inputs to the 
module + settings: Compilation settings + Returns: + Compiled FX GraphModule + """ + try: + trt_compiled = compile_module( + gm, + example_inputs, + settings=settings, + ) + return trt_compiled + except: + traceback.print_exc() + print( + "FX2TRT conversion failed on the subgraph. See trace above. " + + "Returning GraphModule forward instead." + ) + return gm.forward + + +def compile_module( + gm: torch.fx.GraphModule, + example_inputs: Sequence[torch.Tensor], + settings: CompilationSettings = CompilationSettings(), +) -> torch.fx.GraphModule: + """Compile an FX module + + Includes: Partitioning + Conversion Phases + + Args: + module: FX GraphModule to convert + inputs: Inputs to the module + settings: Compilation settings + Returns: + Compiled FX GraphModule + """ + # Partition module into components that can be TRT-accelerated + partitioned_module = partition( + gm, verbose=settings.debug, max_num_trt_engines=settings.max_num_trt_engines + ) + + # Iterate over all components that can be accelerated + # Generate the corresponding TRT Module for those + for name, _ in partitioned_module.named_children(): + submodule = getattr(partitioned_module, name) + + # Get submodule inputs + submodule_inputs = get_submod_inputs( + partitioned_module, submodule, example_inputs + ) + + # Create TRT Module from submodule + trt_mod = convert_module( + submodule, + submodule_inputs, + debug=settings.debug, + workspace_size=settings.workspace_size, + precision=settings.precision, + ) + + # Replace FX Module with TRT Module + setattr(partitioned_module, name, trt_mod) + + return partitioned_module diff --git a/py/torch_tensorrt/dynamo/torch_compile/conversion.py b/py/torch_tensorrt/dynamo/torch_compile/conversion.py new file mode 100644 index 0000000000..4f495dad4b --- /dev/null +++ b/py/torch_tensorrt/dynamo/torch_compile/conversion.py @@ -0,0 +1,48 @@ +from typing import Sequence, Union +import torch +from torch_tensorrt.fx.trt_module import TRTModule +from torch_tensorrt import TRTModuleNext +from torch_tensorrt.fx.fx2trt import ( + InputTensorSpec, + TRTInterpreter, +) +from torch_tensorrt.fx.utils import LowerPrecision + +import tensorrt as trt + + +def convert_module( + module: torch.fx.GraphModule, + inputs: Sequence[torch.Tensor], + debug: bool = False, + workspace_size: int = 20 << 30, + precision: LowerPrecision = LowerPrecision.FP32, +) -> Union[TRTModuleNext, TRTModule]: + """Convert an FX module to a TRT module + Args: + module: FX GraphModule to convert + inputs: Sequence of Tensors representing inputs to the module + debug: Whether to print out verbose debugging information + workspace_size: Maximum workspace TRT is allowed to use for the module + precision: Model Layer precision + Returns: + TRTModule or TRTModuleNext + """ + interp = TRTInterpreter( + module, + InputTensorSpec.from_tensors(inputs), + explicit_batch_dimension=True, + logger_level=(trt.Logger.VERBOSE if debug else trt.Logger.WARNING), + ) + + r = interp.run( + max_workspace_size=workspace_size, + lower_precision=precision, + profiling_verbosity=( + trt.ProfilingVerbosity.VERBOSE + if debug + else trt.ProfilingVerbosity.LAYER_NAMES_ONLY + ), + ) + + return TRTModule(*r) diff --git a/py/torch_tensorrt/dynamo/torch_compile/lowering/__init__.py b/py/torch_tensorrt/dynamo/torch_compile/lowering/__init__.py new file mode 100644 index 0000000000..e0a41df755 --- /dev/null +++ b/py/torch_tensorrt/dynamo/torch_compile/lowering/__init__.py @@ -0,0 +1,7 @@ +from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import ( + 
get_decompositions, +) +from torch_tensorrt.dynamo.torch_compile.lowering._partition import ( + partition, + get_submod_inputs, +) diff --git a/py/torch_tensorrt/dynamo/torch_compile/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/torch_compile/lowering/_decompositions.py new file mode 100644 index 0000000000..7aff1a79d1 --- /dev/null +++ b/py/torch_tensorrt/dynamo/torch_compile/lowering/_decompositions.py @@ -0,0 +1,45 @@ +import torch +from torch._decomp import register_decomposition, core_aten_decompositions + + +DECOMPOSITIONS = {**core_aten_decompositions()} + +aten = torch.ops.aten + + +def replace_inplace_op(aten_op, outplace_op): + """Replace inplace operation with functional equivalent + Adapted from: + https://github.com/pytorch/pytorch/blob/3344d79e3f732dadd5c85b99a7aa1a022f187929/torch/_decomp/decompositions.py#L3355-L3361 + """ + + @register_decomposition(aten_op, registry=DECOMPOSITIONS) + def inplace_op(*args, **kwargs): + out = outplace_op(*args, **kwargs) + return args[0].copy_(out) + + return inplace_op + + +replace_inplace_op(aten.add_, aten.add) +replace_inplace_op(aten.addbmm_, aten.addbmm) +replace_inplace_op(aten.addmm_, aten.addmm) +replace_inplace_op(aten.addmv_, aten.addmv) +replace_inplace_op(aten.baddbmm_, aten.baddbmm) +replace_inplace_op(aten.cumprod_, aten.cumprod) +replace_inplace_op(aten.fill_, aten.fill) +replace_inplace_op(aten.gelu_, aten.gelu) +replace_inplace_op(aten.hardsigmoid_, aten.hardsigmoid) +replace_inplace_op(aten.index_put_, aten.index_put) +replace_inplace_op(aten.index_reduce_, aten.index_reduce) +replace_inplace_op(aten.logit_, aten.logit) +replace_inplace_op(aten.relu_, aten.relu) +replace_inplace_op(aten.renorm_, aten.renorm) +replace_inplace_op(aten.round_, aten.round) +replace_inplace_op(aten.scatter_, aten.scatter) +replace_inplace_op(aten.scatter_add_, aten.scatter_add) +replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce) + + +def get_decompositions(): + return DECOMPOSITIONS diff --git a/py/torch_tensorrt/dynamo/torch_compile/lowering/_partition.py b/py/torch_tensorrt/dynamo/torch_compile/lowering/_partition.py new file mode 100644 index 0000000000..1dd38e0bd9 --- /dev/null +++ b/py/torch_tensorrt/dynamo/torch_compile/lowering/_partition.py @@ -0,0 +1,117 @@ +from typing import Dict, Optional, Sequence + +import torch + +from torch_tensorrt.dynamo.torch_compile._defaults import MAX_NUM_TRT_ENGINES +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx.passes.operator_support import OperatorSupport + +from torch_tensorrt.fx.converter_registry import CONVERTERS + + +class TorchTensorRTOperatorSupport(OperatorSupport): + """Class to determine whether operators within a module are supported""" + + def __init__(self, support_dict=None): + super().__init__(support_dict) + + # Initialize sets of supported/unsupported operators + self.supported_operators = set() + self.unsupported_operators = set() + + def is_node_supported( + self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + if node.target in CONVERTERS.keys(): + # If node is a proper computational node, store the operator + if not node.is_impure(): + node_name = node._pretty_print_target(node.target) + self.supported_operators.add(node_name) + + return True + else: + if not node.is_impure(): + node_name = node._pretty_print_target(node.target) + self.unsupported_operators.add(node_name) + + return False + + def print_support_overview(self, num_trt_blocks: Optional[int] = None): + if num_trt_blocks 
is not None: + print(f"\nNumber of TensorRT-Accelerated Subgraphs: {num_trt_blocks}") + + print("\nSupported Nodes:") + for node_name in self.supported_operators: + print("-", node_name) + + if len(self.unsupported_operators) != 0: + print("\nUnsupported Nodes:") + for node_name in self.unsupported_operators: + print("-", node_name) + print("\n") + else: + print("\nAll Nodes Supported\n") + + +def partition( + gm: torch.fx.GraphModule, + verbose: bool = True, + max_num_trt_engines: int = MAX_NUM_TRT_ENGINES, +) -> torch.fx.GraphModule: + """Partition an FX GraphModule with aten ops into TRT engines + Partitioning is based on converter operator support + + Args: + gm: FX GraphModule to partition + verbose: Bool representing whether to print operator support + max_num_trt_engines: Maximum number of allowed TRT engines in partitioning + Returns: + torch.fx.GraphModule + """ + supported_ops = TorchTensorRTOperatorSupport() + partitioner = CapabilityBasedPartitioner(gm, supported_ops) + + # Determine partitions, and raise error if the degree of partitioning + # exceeds a specified threshold + partitions = partitioner.propose_partitions() + num_blocks = len(partitions) + if num_blocks > max_num_trt_engines: + raise AssertionError( + f"The graph module has {num_blocks} TRT Engines which is larger than the " + + f"threshold={max_num_trt_engines}. Falling back to non-TRT module." + ) + + # Fuse partitions and display overview of supported/unsupported operators + fused_graph = partitioner.fuse_partitions(partitions) + num_blocks = len(partitions) + + if verbose: + supported_ops.print_support_overview(num_blocks) + + return fused_graph + + +def get_submod_inputs( + mod: torch.fx.GraphModule, + submod: torch.fx.GraphModule, + inputs: Sequence[torch.Tensor], +) -> Sequence[torch.Tensor]: + """Helper function to get inputs to a Torch submodule + + Args: + mod: Parent FX GraphModule + submod: Child FX GraphModule + inputs: Sample inputs to parent module + Returns: + Sequence of Tensors representing inputs to child module + """ + acc_inputs = None + + def get_input(self, inputs): + nonlocal acc_inputs + acc_inputs = inputs + + handle = submod.register_forward_pre_hook(get_input) + mod(*inputs) + handle.remove() + return acc_inputs diff --git a/py/torch_tensorrt/dynamo/torch_compile/test/test_compiler_utils.py b/py/torch_tensorrt/dynamo/torch_compile/test/test_compiler_utils.py new file mode 100644 index 0000000000..da7157c3e5 --- /dev/null +++ b/py/torch_tensorrt/dynamo/torch_compile/test/test_compiler_utils.py @@ -0,0 +1,57 @@ +from torch_tensorrt.dynamo.torch_compile.utils import prepare_device, prepare_inputs +from utils import same_output_format +import torch_tensorrt +import unittest +import torch + + +class TestPrepareDevice(unittest.TestCase): + def test_prepare_cuda_device(self): + gpu_id = 0 + device = torch.device(f"cuda:{gpu_id}") + prepared_device = prepare_device(device) + self.assertTrue(isinstance(prepared_device, torch.device)) + self.assertTrue(prepared_device.index == gpu_id) + + def test_prepare_trt_device(self): + gpu_id = 4 + device = torch_tensorrt.Device(gpu_id=gpu_id) + prepared_device = prepare_device(device) + self.assertTrue(isinstance(prepared_device, torch.device)) + self.assertTrue(prepared_device.index == gpu_id) + + +class TestPrepareInputs(unittest.TestCase): + def test_prepare_single_tensor_input(self): + inputs = [torch.ones((4, 4))] + prepared_inputs = prepare_inputs(inputs) + self.assertTrue( + same_output_format(inputs, prepared_inputs, enforce_tensor_type=False) + ) 
+ + def test_prepare_trt_input(self): + inputs = [torch_tensorrt.Input(shape=(4, 3), dtype=torch.float)] + prepared_inputs = prepare_inputs(inputs) + self.assertTrue( + same_output_format(inputs, prepared_inputs, enforce_tensor_type=False) + ) + + def test_prepare_mixed_type_compound_tensor_input(self): + inputs = { + "first": [ + torch.ones((4, 4)), + torch_tensorrt.Input(shape=(4, 3), dtype=torch.float), + ], + "second": ( + torch.rand((5, 1)), + (torch.rand((5, 1)), torch_tensorrt.Input(shape=(2, 3))), + ), + } + prepared_inputs = prepare_inputs(inputs) + self.assertTrue( + same_output_format(inputs, prepared_inputs, enforce_tensor_type=False) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/py/torch_tensorrt/dynamo/torch_compile/test/test_lowering.py b/py/torch_tensorrt/dynamo/torch_compile/test/test_lowering.py new file mode 100644 index 0000000000..d14acb815b --- /dev/null +++ b/py/torch_tensorrt/dynamo/torch_compile/test/test_lowering.py @@ -0,0 +1,54 @@ +from functools import partial +from utils import fx_dynamo_testing_backend +from torch.testing._internal.common_utils import run_tests, TestCase +import torch + + +class TestLowering(TestCase): + def test_lowering_inplace_op(self): + class FullySupported(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x, y): + x = torch.ops.aten.add_.Tensor(x, y) + x = torch.ops.aten.relu_.default(x) + return x + + # Operations expected to be included in the traced graph after decompositions + expected_ops = {torch.ops.aten.add.Tensor, torch.ops.aten.relu.default} + + # Trace module and set up custom backend to track intermediate graphs + fx_graph = torch.fx.symbolic_trace(FullySupported()) + partitioned_graphs = [] + custom_backend = partial( + fx_dynamo_testing_backend, + store_intermediate_graphs=partitioned_graphs, + ) + + # Invoke compilation + compiled_graph = torch.compile(fx_graph, backend=custom_backend) + compiled_graph( + torch.rand( + 5, + ).cuda(), + torch.rand( + 5, + ).cuda(), + ) + + # Iterate over intermediate graphs, attempt to match nodes + for fx_module in partitioned_graphs: + for _, submodule in fx_module.named_children(): + for node in submodule.graph.nodes: + + if node.op == "call_function" and node.target in expected_ops: + expected_ops.remove(node.target) + + self.assertEqual( + len(expected_ops), 0, "All operators should have been decomposed" + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/torch_compile/test/test_partitioning.py b/py/torch_tensorrt/dynamo/torch_compile/test/test_partitioning.py new file mode 100644 index 0000000000..b068f9c413 --- /dev/null +++ b/py/torch_tensorrt/dynamo/torch_compile/test/test_partitioning.py @@ -0,0 +1,68 @@ +from torch_tensorrt.dynamo.torch_compile.lowering import partition +from torch.testing._internal.common_utils import run_tests, TestCase +import torch +from copy import deepcopy +import numpy as np + + +class TestPartitioning(TestCase): + def test_partition_fully_supported_one_op(self): + class FullySupportedOneOp(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x, y): + return torch.ops.aten.add.Tensor(x, y) + + fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp()) + partitioned_graph = partition(deepcopy(fx_graph)) + self.assertEqual( + len(list(partitioned_graph.named_children())), + 0, + "Single operators should not be segmented", + ) + + def 
test_partition_fully_supported_multi_op(self): + class FullySupportedMultiOp(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x, y): + sum_ = torch.ops.aten.sub.Tensor(x, y) + concat_ = torch.ops.aten.cat.default(x, sum_) + relu_ = torch.ops.aten.relu.default(concat_) + pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2) + return pow_ + + fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) + partitioned_graph = partition(deepcopy(fx_graph)) + self.assertEqual( + len(list(partitioned_graph.named_children())), + 1, + "All operators are supported, there should be one segment", + ) + + def test_partition_partially_supported_multi_op(self): + class PartiallySupportedMultiOp(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x, y): + sum_1 = torch.ops.aten.add.Tensor(x, y) + sum_2 = torch.ops.aten.add.Tensor(x, sum_1) + sum_ = np.sum(sum_1) + np.sum(sum_2) + relu_ = torch.ops.aten.relu.default(sum_) + pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2) + return pow_ + + fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) + partitioned_graph = partition(deepcopy(fx_graph)) + self.assertEqual( + len(list(partitioned_graph.named_children())), + 2, + "Unsupported operators interleave supported ones, expected 2 segments", + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/torch_compile/test/utils.py b/py/torch_tensorrt/dynamo/torch_compile/test/utils.py new file mode 100644 index 0000000000..bdcbbfcc4a --- /dev/null +++ b/py/torch_tensorrt/dynamo/torch_compile/test/utils.py @@ -0,0 +1,94 @@ +from copy import deepcopy +from functools import partial +from typing import List, Sequence +import torch +from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import ( + get_decompositions, +) +from torch_tensorrt.dynamo.torch_compile.lowering._partition import ( + partition, +) + +from torch._dynamo.backends.common import fake_tensor_unsupported + +from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler + + +@fake_tensor_unsupported +def fx_dynamo_testing_backend( + gm: torch.fx.GraphModule, + sample_inputs: Sequence[torch.Tensor], + *, + store_intermediate_graphs: List, +): + """Helper Dynamo backend exclusively for testing""" + custom_backend = partial( + compile_module_testing, + store_intermediate_graphs=store_intermediate_graphs, + ) + + # Invoke AOTAutograd to translate operators to aten + return aot_module_simplified( + gm, + sample_inputs, + fw_compiler=make_boxed_compiler(custom_backend), + decompositions=get_decompositions(), + ) + + +def compile_module_testing( + gm: torch.fx.GraphModule, + example_inputs: Sequence[torch.Tensor], + *, + store_intermediate_graphs: List, +) -> torch.fx.GraphModule: + """Helper compiler exclusively for testing""" + partitioned_module = partition(gm) + + # Store intermediate graph from partitioned module + store_intermediate_graphs.append(deepcopy(partitioned_module)) + + return partitioned_module + + +def same_output_format(trt_output, torch_output, enforce_tensor_type=True): + # For each encountered collection type, ensure the torch and trt outputs agree + # on type and size, checking recursively through all member elements. 
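+ # Leaf elements are compared by exact type only when enforce_tensor_type is True;
+ # otherwise matching positions in the two structures are accepted regardless of type.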
+ if isinstance(trt_output, tuple):
+ return (
+ isinstance(torch_output, tuple)
+ and (len(trt_output) == len(torch_output))
+ and all(
+ same_output_format(trt_entry, torch_entry, enforce_tensor_type)
+ for trt_entry, torch_entry in zip(trt_output, torch_output)
+ )
+ )
+ elif isinstance(trt_output, list):
+ return (
+ isinstance(torch_output, list)
+ and (len(trt_output) == len(torch_output))
+ and all(
+ same_output_format(trt_entry, torch_entry, enforce_tensor_type)
+ for trt_entry, torch_entry in zip(trt_output, torch_output)
+ )
+ )
+ elif isinstance(trt_output, dict):
+ return (
+ isinstance(torch_output, dict)
+ and (len(trt_output) == len(torch_output))
+ and (trt_output.keys() == torch_output.keys())
+ and all(
+ same_output_format(
+ trt_output[key], torch_output[key], enforce_tensor_type
+ )
+ for key in trt_output.keys()
+ )
+ )
+ elif isinstance(trt_output, set) or isinstance(trt_output, frozenset):
+ raise AssertionError(
+ "Unsupported output type 'set' encountered in output format check."
+ )
+ elif enforce_tensor_type:
+ return type(trt_output) is type(torch_output)
+ else:
+ return True
diff --git a/py/torch_tensorrt/dynamo/torch_compile/utils.py b/py/torch_tensorrt/dynamo/torch_compile/utils.py
new file mode 100644
index 0000000000..ba76536338
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/torch_compile/utils.py
@@ -0,0 +1,68 @@
+import torch
+
+from typing import Any, Union, Sequence, Dict
+from torch_tensorrt import _Input, Device
+
+
+def prepare_inputs(
+ inputs: Union[_Input.Input, torch.Tensor, Sequence, Dict],
+ device: torch.device = torch.device("cuda"),
+) -> Any:
+ if isinstance(inputs, _Input.Input):
+ if isinstance(inputs.shape, dict):
+ return inputs.example_tensor(optimization_profile_field="opt_shape").to(
+ device
+ )
+ else:
+ return inputs.example_tensor().to(device)
+
+ elif isinstance(inputs, torch.Tensor):
+ return inputs
+
+ elif isinstance(inputs, list):
+ prepared_input = list()
+
+ for input_obj in inputs:
+ prepared_input.append(prepare_inputs(input_obj, device))
+
+ return prepared_input
+
+ elif isinstance(inputs, tuple):
+ prepared_input = list()
+
+ for input_obj in inputs:
+ prepared_input.append(prepare_inputs(input_obj, device))
+
+ return tuple(prepared_input)
+
+ elif isinstance(inputs, dict):
+ prepared_input = dict()
+
+ for key, input_obj in inputs.items():
+ prepared_input[key] = prepare_inputs(input_obj, device)
+
+ return prepared_input
+
+ else:
+ raise ValueError(
+ f"Invalid input type {type(inputs)} encountered in the torch_compile input parsing. "
+ + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}"
+ )
+
+
+def prepare_device(device: Union[Device, torch.device]) -> torch.device:
+ if isinstance(device, Device):
+ if device.gpu_id != -1:
+ device = torch.device(device.gpu_id)
+ else:
+ raise ValueError("Invalid GPU ID provided for the CUDA device")
+
+ elif isinstance(device, torch.device):
+ device = device
+
+ else:
+ raise ValueError(
+ "Invalid device provided. 
Supported options: torch.device | torch_tensorrt.Device" + ) + + return device diff --git a/py/torch_tensorrt/ts/__init__.py b/py/torch_tensorrt/ts/__init__.py index ddee197aef..47ef249e55 100644 --- a/py/torch_tensorrt/ts/__init__.py +++ b/py/torch_tensorrt/ts/__init__.py @@ -1,2 +1,3 @@ from torch_tensorrt.ts._compiler import * from torch_tensorrt.ts._compile_spec import TensorRTCompileSpec +from torch_tensorrt.ts.ts_input import TSInput diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index 8f06e2ef71..f17a9fa5bf 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -9,6 +9,8 @@ from typing import Tuple, List, Dict import warnings from copy import deepcopy +from torch_tensorrt.ts.ts_input import TSInput +import tensorrt as trt def _internal_input_to_torch_class_input(i: _C.Input) -> torch.classes.tensorrt._Input: @@ -38,46 +40,6 @@ def _supported_input_size_type(input_size: Any) -> bool: ) -def _parse_input_ranges(input_sizes: List) -> List: - - if any( - not isinstance(i, dict) and not _supported_input_size_type(i) - for i in input_sizes - ): - raise KeyError( - "An input size must either be a static size or a range of three sizes (min, opt, max) as Dict" - ) - - parsed_input_sizes = [] - for i in input_sizes: - if isinstance(i, dict): - if all(k in i for k in ["min", "opt", "min"]): - parsed_input_sizes.append( - Input( - min_shape=i["min"], opt_shape=i["opt"], max_shape=i["max"] - )._to_internal() - ) - - elif "opt" in i: - parsed_input_sizes.append(Input(shape=i["opt"])._to_internal()) - - else: - raise KeyError( - "An input size must either be a static size or a range of three sizes (min, opt, max) as Dict" - ) - - elif isinstance(i, list): - parsed_input_sizes.append(Input(shape=i)._to_internal()) - - elif isinstance(i, tuple): - parsed_input_sizes.append(Input(shape=i)._to_internal()) - - elif isinstance(i, torch.Size): - parsed_input_sizes.append(Input(shape=i)._to_internal()) - - return parsed_input_sizes - - def _parse_op_precision(precision: Any) -> _enums.dtype: if isinstance(precision, torch.dtype): if precision == torch.int8: @@ -115,20 +77,24 @@ def _parse_enabled_precisions(precisions: Any) -> Set: def _parse_device_type(device: Any) -> _enums.DeviceType: if isinstance(device, torch.device): if device.type == "cuda": - return _enums.DeviceType.gpu + return _C.DeviceType.gpu else: ValueError( "Got a device type other than GPU or DLA (type: " + str(device.type) + ")" ) - elif isinstance(device, _enums.DeviceType): + elif isinstance(device, _C.DeviceType): return device + elif isinstance(device, trt.DeviceType): + if device == trt.DeviceType.DLA: + return _C.DeviceType.DLA + return _C.DeviceType.GPU elif isinstance(device, str): if device == "gpu" or device == "GPU": - return _enums.DeviceType.gpu + return _C.DeviceType.GPU elif device == "dla" or device == "DLA": - return _enums.DeviceType.dla + return _C.DeviceType.DLA else: ValueError( "Got a device type other than GPU or DLA (type: " + str(device) + ")" @@ -146,7 +112,6 @@ def _parse_device(device_info: Any) -> _C.Device: if "device_type" not in device_info: raise KeyError("Device type is required parameter") else: - assert isinstance(device_info["device_type"], _enums.DeviceType) info.device_type = _parse_device_type(device_info["device_type"]) if "gpu_id" in device_info: @@ -228,7 +193,23 @@ def _parse_input_signature(input_signature: Any, depth: int = 0): + "non-TRT types." 
) - clone = _internal_input_to_torch_class_input(i._to_internal()) + ts_i = i + if i.shape_mode == Input._ShapeMode.STATIC: + ts_i = TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + elif i.shape_mode == Input._ShapeMode.DYNAMIC: + ts_i = TSInput( + min_shape=i.shape["min_shape"], + opt_shape=i.shape["opt_shape"], + max_shape=i.shape["max_shape"], + dtype=i.dtype, + format=i.format, + ) + else: + raise ValueError( + "Invalid shape mode detected for input while parsing the input_signature" + ) + + clone = _internal_input_to_torch_class_input(ts_i._to_internal()) return clone else: raise KeyError( @@ -260,7 +241,25 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: Input.from_tensor(i) if isinstance(i, torch.Tensor) else i for i in compile_spec["inputs"] ] - info.inputs = [i._to_internal() for i in inputs] + ts_inputs = [] + for i in inputs: + if i.shape_mode == Input._ShapeMode.STATIC: + ts_inputs.append( + TSInput( + shape=i.shape, dtype=i.dtype, format=i.format + )._to_internal() + ) + elif i.shape_mode == Input._ShapeMode.DYNAMIC: + ts_inputs.append( + TSInput( + min_shape=i.shape["min_shape"], + opt_shape=i.shape["opt_shape"], + max_shape=i.shape["max_shape"], + dtype=i.dtype, + format=i.format, + )._to_internal() + ) + info.inputs = ts_inputs elif compile_spec["input_signature"] is not None: log( @@ -268,7 +267,7 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: "Input signature parsing is an experimental feature, behavior and APIs may change", ) signature = _parse_input_signature(compile_spec["input_signature"]) - info.input_signature = _C.InputSignature(signature) + info.input_signature = _C.InputSignature(signature) # py_object else: raise KeyError( diff --git a/py/torch_tensorrt/ts/ts_input.py b/py/torch_tensorrt/ts/ts_input.py new file mode 100644 index 0000000000..00055d4f13 --- /dev/null +++ b/py/torch_tensorrt/ts/ts_input.py @@ -0,0 +1,108 @@ +from enum import Enum +from typing import List, Dict, Any, Tuple, Optional + +import torch + +from torch_tensorrt import _C +from torch_tensorrt import _enums +from torch_tensorrt import _Input +from torch_tensorrt._Input import Input + + +class TSInput(Input): + """ + Defines an input to a module in terms of expected shape, data type and tensor format. + + Attributes: + shape_mode (torch_tensorrt.Input._ShapeMode): Is input statically or dynamically shaped + shape (Tuple or Dict): Either a single Tuple or a dict of tuples defining the input shape. + Static shaped inputs will have a single tuple. 
Dynamic inputs will have a dict of the form
+ ``{
+ "min_shape": Tuple,
+ "opt_shape": Tuple,
+ "max_shape": Tuple
+ }``
+ dtype (torch_tensorrt.dtype): The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
+ format (torch_tensorrt.TensorFormat): The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
+ """
+
+ def __init__(self, *args, **kwargs):
+ """__init__ Method for torch_tensorrt.ts.TSInput
+
+ Input accepts one of a few construction patterns
+
+ Args:
+ shape (Tuple or List, optional): Static shape of input tensor
+
+ Keyword Arguments:
+ shape (Tuple or List, optional): Static shape of input tensor
+ min_shape (Tuple or List, optional): Min size of input tensor's shape range
+ Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input's shape_mode to DYNAMIC
+ opt_shape (Tuple or List, optional): Opt size of input tensor's shape range
+ Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input's shape_mode to DYNAMIC
+ max_shape (Tuple or List, optional): Max size of input tensor's shape range
+ Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input's shape_mode to DYNAMIC
+ dtype (torch.dtype or torch_tensorrt.dtype): Expected data type for input tensor (default: torch_tensorrt.dtype.float32)
+ format (torch.memory_format or torch_tensorrt.TensorFormat): The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
+ tensor_domain (Tuple(float, float), optional): The domain of allowed values for the tensor, as interval notation: [tensor_domain[0], tensor_domain[1]).
+ Note: Entering "None" (or not specifying) will set the bound to [0, 2) + + Examples: + - Input([1,3,32,32], dtype=torch.float32, format=torch.channel_last) + - Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW) + - Input(min_shape=(1,3,32,32), opt_shape=[2,3,32,32], max_shape=(3,3,32,32)) #Implicitly dtype=torch_tensorrt.dtype.float32, format=torch_tensorrt.TensorFormat.NCHW + """ + super(TSInput, self).__init__(*args, **kwargs) + + def _to_internal(self) -> _C.Input: + internal_in = _C.Input() + if self.shape_mode == Input._ShapeMode.DYNAMIC: + if not Input._supported_input_size_type(self.shape["min_shape"]): + raise TypeError( + "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " + + str(type(self.shape["min_shape"])) + + " for min_shape" + ) + else: + internal_in.min = self.shape["min_shape"] + + if not Input._supported_input_size_type(self.shape["opt_shape"]): + raise TypeError( + "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " + + str(type(self.shape["opt_shape"])) + + " for opt_shape" + ) + else: + internal_in.opt = self.shape["opt_shape"] + + if not Input._supported_input_size_type(self.shape["max_shape"]): + raise TypeError( + "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " + + str(type(self.shape["max_shape"])) + + " for max_shape" + ) + else: + internal_in.max = self.shape["max_shape"] + internal_in.input_is_dynamic = True + else: + if not Input._supported_input_size_type(self.shape): + raise TypeError( + "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " + + str(type(self.shape)) + + " for shape" + ) + else: + internal_in.opt = self.shape + internal_in.input_is_dynamic = False + + if self.dtype != _enums.dtype.unknown: + self._explicit_set_dtype = True + else: + self._explicit_set_dtype = False + + internal_in.dtype = Input._parse_dtype(self.dtype) + internal_in._explicit_set_dtype = self._explicit_set_dtype + internal_in.format = Input._parse_format(self.format) + + internal_in.tensor_domain = Input._parse_tensor_domain(self.tensor_domain) + return internal_in diff --git a/tests/py/api/test_classes.py b/tests/py/api/test_classes.py index 861efd84a7..b9729b9d4d 100644 --- a/tests/py/api/test_classes.py +++ b/tests/py/api/test_classes.py @@ -103,7 +103,8 @@ def test_infer_from_example_tensor(self): example_tensor = torch.randn(shape).half() i = torchtrt.Input.from_tensor(example_tensor) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) def test_static_shape(self): shape = [1, 3, 255, 255] @@ -118,22 +119,28 @@ def test_static_shape(self): } i = torchtrt.Input(shape) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input(tuple(shape)) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input(torch.randn(shape).shape) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, 
target)) i = torchtrt.Input(shape=shape) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input(shape=tuple(shape)) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input(shape=torch.randn(shape).shape) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) def test_data_type(self): shape = [1, 3, 255, 255] @@ -148,10 +155,12 @@ def test_data_type(self): } i = torchtrt.Input(shape, dtype=torchtrt.dtype.half) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input(shape, dtype=torch.half) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) def test_tensor_format(self): shape = [1, 3, 255, 255] @@ -166,10 +175,12 @@ def test_tensor_format(self): } i = torchtrt.Input(shape, format=torchtrt.TensorFormat.channels_last) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input(shape, format=torch.channels_last) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput(shape=i.shape, dtype=i.dtype, format=i.format) + self.assertTrue(self._verify_correctness(ts_i, target)) def test_dynamic_shape(self): min_shape = [1, 3, 128, 128] @@ -188,14 +199,28 @@ def test_dynamic_shape(self): i = torchtrt.Input( min_shape=min_shape, opt_shape=opt_shape, max_shape=max_shape ) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput( + min_shape=i.shape["min_shape"], + opt_shape=i.shape["opt_shape"], + max_shape=i.shape["max_shape"], + dtype=i.dtype, + format=i.format, + ) + self.assertTrue(self._verify_correctness(ts_i, target)) i = torchtrt.Input( min_shape=tuple(min_shape), opt_shape=tuple(opt_shape), max_shape=tuple(max_shape), ) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput( + min_shape=i.shape["min_shape"], + opt_shape=i.shape["opt_shape"], + max_shape=i.shape["max_shape"], + dtype=i.dtype, + format=i.format, + ) + self.assertTrue(self._verify_correctness(ts_i, target)) tensor_shape = lambda shape: torch.randn(shape).shape i = torchtrt.Input( @@ -203,7 +228,14 @@ def test_dynamic_shape(self): opt_shape=tensor_shape(opt_shape), max_shape=tensor_shape(max_shape), ) - self.assertTrue(self._verify_correctness(i, target)) + ts_i = torchtrt.ts.TSInput( + min_shape=i.shape["min_shape"], + opt_shape=i.shape["opt_shape"], + max_shape=i.shape["max_shape"], + dtype=i.dtype, + format=i.format, + ) + self.assertTrue(self._verify_correctness(ts_i, target)) class TestTRTModuleNext(unittest.TestCase):