diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 2442c2c26f..7a1549aad9 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -229,3 +229,30 @@ jobs: ${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_partitioning_test_results.xml partitioning/ ${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_lowering_test_results.xml lowering/ popd + + tests-py-core: + name: Test core [Python] + needs: [generate-matrix, build] + strategy: + fail-fast: false + matrix: + include: + - repository: pytorch/tensorrt + package-name: torch_tensorrt + pre-script: packaging/pre_build_script.sh + uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main + with: + job-name: tests-py-core + repository: "pytorch/tensorrt" + ref: "" + test-infra-repository: pytorch/test-infra + test-infra-ref: main + build-matrix: ${{ needs.generate-matrix.outputs.matrix }} + pre-script: ${{ matrix.pre-script }} + script: | + export USE_HOST_DEPS=1 + pushd . + cd tests/py/core + ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver + ${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_core_test_results.xml . + popd diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f4ac2ab9e2..61d97503a2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ exclude: ^.github/actions/assigner/dist repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: check-yaml - id: trailing-whitespace @@ -16,38 +16,38 @@ repos: - --fix=lf exclude: ^docs - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v16.0.6 + rev: v18.1.1 hooks: - id: clang-format types_or: [c++, c, cuda] - repo: https://github.com/keith/pre-commit-buildifier - rev: 6.1.0.2 + rev: 6.4.0 hooks: - id: buildifier args: - --warnings=all - id: buildifier-lint - repo: https://github.com/abravalheri/validate-pyproject - rev: v0.13 + rev: v0.16 hooks: - id: validate-pyproject - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort name: isort (python) - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.4.1' + rev: 'v1.9.0' hooks: - id: mypy exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py" - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. 
- rev: v0.0.278 + rev: v0.3.3 hooks: - id: ruff - repo: https://github.com/psf/black - rev: 24.1.1 + rev: 24.3.0 hooks: - id: black exclude: ^examples/custom_converters/elu_converter/setup.py|^docs diff --git a/BUILD b/BUILD index c40d52e0f9..3138a5d021 100644 --- a/BUILD +++ b/BUILD @@ -33,6 +33,14 @@ pkg_tar( ], ) +pkg_tar( + name = "include_rt", + package_dir = "include/torch_tensorrt", + deps = [ + "//core/runtime:include", + ], +) + pkg_tar( name = "include", srcs = [ @@ -55,6 +63,18 @@ pkg_tar( package_dir = "lib/", ) +pkg_tar( + name = "lib_rt", + srcs = select({ + ":windows": ["//cpp/lib:torch_tensorrt_runtime.dll"], + "//conditions:default": [ + "//cpp/lib:libtorchtrt_runtime.so", + ], + }), + mode = "0755", + package_dir = "lib/", +) + pkg_tar( name = "bin", srcs = [ @@ -82,3 +102,18 @@ pkg_tar( "//conditions:default": [":bin"], }), ) + +pkg_tar( + name = "libtorchtrt_runtime", + srcs = [ + "//:LICENSE", + "//bzl_def:BUILD", + "//bzl_def:WORKSPACE", + ], + extension = "tar.gz", + package_dir = "torch_tensorrt_runtime", + deps = [ + ":include_rt", + ":lib_rt", + ], +) diff --git a/core/runtime/RTDevice.cpp b/core/runtime/RTDevice.cpp index 34ecc22e97..f78ce306ad 100644 --- a/core/runtime/RTDevice.cpp +++ b/core/runtime/RTDevice.cpp @@ -7,8 +7,6 @@ namespace torch_tensorrt { namespace core { namespace runtime { -const std::string DEVICE_INFO_DELIM = "%"; - typedef enum { ID_IDX = 0, SM_MAJOR_IDX, SM_MINOR_IDX, DEVICE_TYPE_IDX, DEVICE_NAME_IDX } SerializedDeviceInfoIndex; RTDevice::RTDevice() : id{-1}, major{-1}, minor{-1}, device_type{nvinfer1::DeviceType::kGPU} {} diff --git a/core/runtime/RTDevice.h b/core/runtime/RTDevice.h index bd1484d4b0..60963c36e1 100644 --- a/core/runtime/RTDevice.h +++ b/core/runtime/RTDevice.h @@ -6,6 +6,8 @@ namespace torch_tensorrt { namespace core { namespace runtime { +const std::string DEVICE_INFO_DELIM = "%"; + struct RTDevice { int64_t id; // CUDA device id int64_t major; // CUDA compute major version diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 4ae4f92337..c2a344a307 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -116,6 +116,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = TORCH_LIBRARY(tensorrt, m) { m.def("execute_engine", execute_engine); m.def("SERIALIZED_ENGINE_BINDING_DELIM", []() -> std::string { return std::string(1, TRTEngine::BINDING_DELIM); }); + m.def("SERIALIZED_RT_DEVICE_DELIM", []() -> std::string { return DEVICE_INFO_DELIM; }); m.def("ABI_VERSION", []() -> std::string { return ABI_VERSION; }); m.def("get_multi_device_safe_mode", []() -> bool { return MULTI_DEVICE_SAFE_MODE; }); m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void { diff --git a/docsrc/py_api/torch_tensorrt.rst b/docsrc/py_api/torch_tensorrt.rst index 22fda13ba2..eb8285e103 100644 --- a/docsrc/py_api/torch_tensorrt.rst +++ b/docsrc/py_api/torch_tensorrt.rst @@ -37,10 +37,6 @@ Classes :members: :special-members: __init__ -.. autoclass:: TRTModuleNext - :members: - :special-members: __init__ - Enums ------- @@ -50,7 +46,7 @@ Enums .. autoclass:: EngineCapability -.. autoclass:: TensorFormat +.. 
autoclass:: memory_format Submodules ---------- diff --git a/examples/dynamo/torch_compile_stable_diffusion.py b/examples/dynamo/torch_compile_stable_diffusion.py index 0511e5a363..a0b725572b 100644 --- a/examples/dynamo/torch_compile_stable_diffusion.py +++ b/examples/dynamo/torch_compile_stable_diffusion.py @@ -18,9 +18,8 @@ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ import torch -from diffusers import DiffusionPipeline - import torch_tensorrt +from diffusers import DiffusionPipeline model_id = "CompVis/stable-diffusion-v1-4" device = "cuda:0" @@ -39,7 +38,7 @@ backend=backend, options={ "truncate_long_and_double": True, - "precision": torch.float16, + "enabled_precisions": {torch.float32, torch.float16}, }, dynamic=False, ) diff --git a/py/torch_tensorrt/_Device.py b/py/torch_tensorrt/_Device.py index 6f20b6c84c..9171cc2c27 100644 --- a/py/torch_tensorrt/_Device.py +++ b/py/torch_tensorrt/_Device.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +import logging import sys from typing import Any, Optional, Tuple @@ -6,19 +9,10 @@ else: from typing_extensions import Self -import warnings - -# from torch_tensorrt import _enums import tensorrt as trt import torch -from torch_tensorrt import logging - -try: - from torch_tensorrt import _C -except ImportError: - warnings.warn( - "Unable to import torchscript frontend core and torch-tensorrt runtime. Some dependent features may be unavailable." - ) +from torch_tensorrt._enums import DeviceType +from torch_tensorrt._features import ENABLED_FEATURES class Device(object): @@ -32,9 +26,9 @@ class Device(object): allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed """ - device_type: Optional[trt.DeviceType] = ( - None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified. - ) + device_type: DeviceType = ( + DeviceType.UNKNOWN + ) #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified. 
gpu_id: int = -1 #: Device ID for target GPU dla_core: int = -1 #: Core ID for target DLA core allow_gpu_fallback: bool = ( @@ -69,32 +63,31 @@ def __init__(self, *args: Any, **kwargs: Any): ) else: (self.device_type, id) = Device._parse_device_str(args[0]) - if self.device_type == trt.DeviceType.GPU: - self.gpu_id = id - else: + if self.device_type == DeviceType.DLA: self.dla_core = id self.gpu_id = 0 - logging.log( - logging.Level.Warning, - "Setting GPU id to 0 for device because device 0 manages DLA on Xavier", + logging.warning( + "Setting GPU id to 0 for device because device 0 manages DLA on AGX Devices", ) + else: + self.gpu_id = id elif len(args) == 0: if "gpu_id" in kwargs or "dla_core" in kwargs: if "dla_core" in kwargs: - self.device_type = trt.DeviceType.DLA self.dla_core = kwargs["dla_core"] - if "gpu_id" in kwargs: - self.gpu_id = kwargs["gpu_id"] - else: + if "gpu_id" in kwargs: + self.gpu_id = kwargs["gpu_id"] + + if self.dla_core >= 0: + self.device_type = DeviceType.DLA + if self.gpu_id != 0: self.gpu_id = 0 - logging.log( - logging.Level.Warning, - "Setting GPU id to 0 for device because device 0 manages DLA on Xavier", + logging.warning( + "Setting GPU id to 0 for device because device 0 manages DLA on AGX Platforms", ) else: - self.gpu_id = kwargs["gpu_id"] - self.device_type = trt.DeviceType.GPU + self.device_type = DeviceType.GPU else: raise ValueError( "Either gpu_id or dla_core or both must be defined if no string with device specs is provided as an arg" @@ -102,9 +95,7 @@ def __init__(self, *args: Any, **kwargs: Any): else: raise ValueError( - "Unexpected number of positional arguments for class Device \n Found {} arguments, expected either zero or a single positional arguments".format( - len(args) - ) + f"Unexpected number of positional arguments for class Device \n Found {len(args)} arguments, expected either zero or a single positional arguments" ) if "allow_gpu_fallback" in kwargs: @@ -112,58 +103,85 @@ def __init__(self, *args: Any, **kwargs: Any): raise TypeError("allow_gpu_fallback must be a bool") self.allow_gpu_fallback = kwargs["allow_gpu_fallback"] + if "device_type" in kwargs: + if isinstance(kwargs["device_type"], trt.DeviceType): + self.device_type = DeviceType._from(kwargs["device_type"]) + def __str__(self) -> str: - return ( - "Device(type={}, gpu_id={}".format(self.device_type, self.gpu_id) + ")" - if self.device_type == trt.DeviceType.GPU - else ", dla_core={}, allow_gpu_fallback={}".format( - self.dla_core, self.allow_gpu_fallback - ) + suffix = ( + ")" + if self.device_type == DeviceType.GPU + else f", dla_core={self.dla_core}, allow_gpu_fallback={self.allow_gpu_fallback})" ) + dev_str: str = f"Device(type={self.device_type}, gpu_id={self.gpu_id}{suffix}" + return dev_str def __repr__(self) -> str: return self.__str__() - def _to_internal(self) -> _C.Device: - internal_dev = _C.Device() - if self.device_type == trt.DeviceType.GPU: - internal_dev.device_type = _C.DeviceType.GPU - elif self.device_type == trt.DeviceType.DLA: - internal_dev.device_type = _C.DeviceType.DLA - else: - raise ValueError( - "Invalid DeviceType detected while parsing the Device class" - ) + @classmethod + def _from(cls, d: Optional[Self | torch.device | str]) -> Device: + """Cast a device-type to torch_tensorrt.Device - internal_dev.gpu_id = self.gpu_id - internal_dev.dla_core = self.dla_core - internal_dev.allow_gpu_fallback = self.allow_gpu_fallback - return internal_dev + Returns the corresponding torch_tensorrt.Device + """ + if isinstance(d, Device): + return d - 
def _to_serialized_rt_device(self) -> str: - internal_dev = self._to_internal() - serialized_rt_device: str = internal_dev._to_serialized_rt_device() - return serialized_rt_device + elif isinstance(d, torch.device): + if d.type != "cuda": + raise ValueError('Torch Device specs must have type "cuda"') + return cls(gpu_id=d.index) + + elif d is None: + return cls(gpu_id=torch.cuda.current_device()) + + else: + return cls(d) @classmethod - def _from_torch_device(cls, torch_dev: torch.device) -> Self: - if torch_dev.type != "cuda": - raise ValueError('Torch Device specs must have type "cuda"') - gpu_id = torch_dev.index - return cls(gpu_id=gpu_id) + def _from_torch_device(cls, torch_dev: torch.device) -> Device: + return cls._from(torch_dev) @classmethod - def _current_device(cls) -> Self: - dev = _C._get_current_device() - return cls(gpu_id=dev.gpu_id) + def _current_device(cls) -> Device: + dev_id = torch.cuda.current_device() + return cls(gpu_id=dev_id) @staticmethod def _parse_device_str(s: str) -> Tuple[trt.DeviceType, int]: s = s.lower() spec = s.split(":") if spec[0] == "gpu" or spec[0] == "cuda": - return (trt.DeviceType.GPU, int(spec[1])) + return (DeviceType.GPU, int(spec[1])) elif spec[0] == "dla": - return (trt.DeviceType.DLA, int(spec[1])) + return (DeviceType.DLA, int(spec[1])) else: raise ValueError(f"Unknown device type {spec[0]}") + + def to(self, t: type) -> torch.device: + if t == torch.device: + if self.gpu_id != -1: + return torch.device(self.gpu_id) + else: + raise ValueError("Invalid GPU ID provided for the CUDA device provided") + else: + raise TypeError("Unsupported target type for device conversion") + + def _to_serialized_rt_device(self) -> str: + if not ENABLED_FEATURES.torch_tensorrt_runtime: + raise NotImplementedError("Torch-TensorRT runtime is not available") + + delim = torch.ops.tensorrt.SERIALIZED_RT_DEVICE_DELIM()[0] + dev_info = torch.cuda.get_device_properties(self.gpu_id) + rt_info = [ + self.gpu_id, + dev_info.major, + dev_info.minor, + int(self.device_type.to(trt.DeviceType)), # type: ignore[arg-type] + dev_info.name, + ] + rt_info = [str(i) for i in rt_info] + packed_rt_info: str = delim.join(rt_info) + logging.debug(f"Serialized Device Info: {packed_rt_info}") + return packed_rt_info diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index db36678d17..32f19ce1f0 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple import torch -from torch_tensorrt import _enums +from torch_tensorrt._enums import dtype, memory_format class Input(object): @@ -34,18 +34,17 @@ class _ShapeMode(Enum): shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = ( None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. 
Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }`` ) - dtype: _enums.dtype = ( - _enums.dtype.unknown + dtype: dtype = ( + dtype.unknown ) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32) _explicit_set_dtype: bool = False - format: _enums.TensorFormat = ( - _enums.TensorFormat.contiguous - ) #: The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW) + format: memory_format = ( + memory_format.linear + ) #: The expected format of the input tensor (default: torch_tensorrt.memory_format.linear) DOMAIN_OFFSET: float = 2.0 low_tensor_domain_incl: float = 0.0 high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET - torch_dtype: torch.dtype = torch.float32 torch_tensor: torch.Tensor = None name: str = "" @@ -151,21 +150,19 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: else: raise ValueError( - "Unexpected number of positional arguments for class Input \n Found {} arguments, expected either zero or a single positional arguments".format( - len(args) - ) + f"Unexpected number of positional arguments for class Input \n Found {len(args)} arguments, expected either zero or a single positional arguments" ) if "dtype" in kwargs: - if isinstance(kwargs["dtype"], torch.dtype): - self.torch_dtype = kwargs["dtype"] + self.dtype = dtype._from(kwargs["dtype"]) - self.dtype = Input._parse_dtype(kwargs["dtype"]) - self.torch_dtype = Input._to_torch_dtype(self.dtype) + if self.dtype != dtype.unknown: self._explicit_set_dtype = True + else: + self._explicit_set_dtype = False if "format" in kwargs: - self.format = Input._parse_format(kwargs["format"]) + self.format = memory_format._from(kwargs["format"]) if "tensor_domain" in kwargs: domain = kwargs["tensor_domain"] @@ -212,6 +209,9 @@ def __str__(self) -> str: else: raise RuntimeError("Unknown input shape mode") + def __repr__(self) -> str: + return self.__str__() + @staticmethod def _supported_input_size_type(input_size: Any) -> bool: if isinstance(input_size, torch.Size): @@ -223,77 +223,6 @@ def _supported_input_size_type(input_size: Any) -> bool: else: return False - @staticmethod - def _parse_dtype(dtype: Any) -> _enums.dtype: - if isinstance(dtype, torch.dtype): - if dtype == torch.long: - return _enums.dtype.long - elif dtype == torch.int32: - return _enums.dtype.int32 - elif dtype == torch.half: - return _enums.dtype.half - elif dtype == torch.float: - return _enums.dtype.float - elif dtype == torch.float64: - return _enums.dtype.double - elif dtype == torch.bool: - return _enums.dtype.bool - else: - raise TypeError( - "Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: " - + str(dtype) - ) - - elif isinstance(dtype, _enums.dtype): - return dtype - - else: - raise TypeError( - "Input data type needs to be specified with a torch.dtype or a torch_tensorrt.dtype, got: " - + str(type(dtype)) - ) - - @staticmethod - def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype: - if dtype == _enums.dtype.long: - return torch.long - elif dtype == _enums.dtype.int32: - return torch.int32 - elif dtype == _enums.dtype.half: - return torch.half - elif dtype == _enums.dtype.float: - return torch.float - elif dtype == _enums.dtype.bool: - return torch.bool - elif dtype == _enums.dtype.double: - return torch.float64 - else: - # Default torch_dtype used in FX path - return torch.float32 - - def is_trt_dtype(self) -> bool: - return bool(self.dtype != _enums.dtype.long) - - 
@staticmethod - def _parse_format(format: Any) -> _enums.TensorFormat: - if isinstance(format, torch.memory_format): - if format == torch.contiguous_format: - return _enums.TensorFormat.contiguous - elif format == torch.channels_last: - return _enums.TensorFormat.channels_last - else: - raise ValueError( - "Provided an unsupported tensor format (support: NCHW/contiguous_format, NHWC/channel_last)" - ) - - elif isinstance(format, _enums.TensorFormat): - return format - - else: - raise TypeError( - "Tensor format needs to be specified with either torch.memory_format or torch_tensorrt.TensorFormat" - ) - @staticmethod def _parse_tensor_domain( domain: Optional[Tuple[float, float]] @@ -415,7 +344,9 @@ def example_tensor( ) else: if isinstance(self.shape, tuple): - return torch.rand(self.shape).to(dtype=self.torch_dtype) + return torch.rand(self.shape).to( + dtype=self.dtype.to(torch.dtype, use_default=True) + ) else: RuntimeError( f"Input shape is dynamic but shapes are not provided as sequence (found: {self.shape})" @@ -434,7 +365,7 @@ def example_tensor( if isinstance(self.shape, dict): return torch.rand(self.shape[optimization_profile_field]).to( - dtype=self.torch_dtype + dtype=self.dtype.to(torch.dtype, use_default=True) ) else: raise RuntimeError( diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py index b9d2af39c5..f95f33bc74 100644 --- a/py/torch_tensorrt/__init__.py +++ b/py/torch_tensorrt/__init__.py @@ -80,26 +80,46 @@ def _find_lib(name: str, paths: List[str]) -> str: for lib in LINUX_LIBS: ctypes.CDLL(_find_lib(lib, LINUX_PATHS)) -import torch -from torch_tensorrt._compile import * # noqa: F403 -from torch_tensorrt._Device import Device # noqa: F401 -from torch_tensorrt._enums import * # noqa: F403 -from torch_tensorrt._Input import Input # noqa: F401 -from torch_tensorrt._utils import * # noqa: F403 -from torch_tensorrt._utils import sanitized_torch_version -from torch_tensorrt.logging import * -from torch_tensorrt.ptq import * -from torch_tensorrt.runtime import * # noqa: F403 +import logging -if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"): - from torch_tensorrt.dynamo import backend # noqa: F401 +import torch +from torch_tensorrt._features import ENABLED_FEATURES, _enabled_features_str - from torch_tensorrt import dynamo # noqa: F401 +_LOGGER = logging.getLogger(__name__) +_LOGGER.debug(_enabled_features_str()) def _register_with_torch() -> None: trtorch_dir = os.path.dirname(__file__) - torch.ops.load_library(trtorch_dir + "/lib/libtorchtrt.so") + if os.path.isfile(trtorch_dir + "/lib/libtorchtrt.so"): + assert ENABLED_FEATURES.torchscript_frontend + assert ENABLED_FEATURES.torch_tensorrt_runtime + torch.ops.load_library(trtorch_dir + "/lib/libtorchtrt.so") + elif os.path.isfile(trtorch_dir + "/lib/libtorchtrt_runtime.so"): + assert ENABLED_FEATURES.torch_tensorrt_runtime + torch.ops.load_library(trtorch_dir + "/lib/libtorchtrt_runtime.so") _register_with_torch() + +from torch_tensorrt._Device import Device # noqa: F401 +from torch_tensorrt._enums import ( # noqa: F401 + DeviceType, + EngineCapability, + dtype, + memory_format, +) +from torch_tensorrt._Input import Input # noqa: F401 +from torch_tensorrt.runtime import * # noqa: F403 + +if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt import ts + +if ENABLED_FEATURES.fx_frontend: + from torch_tensorrt import fx + +if ENABLED_FEATURES.dynamo_frontend: + from torch_tensorrt.dynamo import backend # noqa: F401 + from torch_tensorrt import dynamo # noqa: F401 + +from 
torch_tensorrt._compile import * # noqa: F403 diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 7acf83124a..01692006a6 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -6,24 +6,29 @@ import torch import torch.fx -import torch_tensorrt.dynamo -import torch_tensorrt.ts from torch_tensorrt._enums import dtype +from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input -from torch_tensorrt._utils import sanitized_torch_version +from torch_tensorrt.dynamo import _defaults from torch_tensorrt.fx import InputTensorSpec from torch_tensorrt.fx.lower import compile as fx_compile from torch_tensorrt.fx.utils import LowerPrecision -from torch_tensorrt.ts._compiler import compile as torchscript_compile from typing_extensions import TypeGuard -from packaging import version - -DYNAMO_ENABLED = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev") +if ENABLED_FEATURES.torchscript_frontend: + import torch_tensorrt.ts + from torch_tensorrt.ts._compiler import compile as torchscript_compile + from torch_tensorrt.ts._compiler import ( + convert_method_to_trt_engine as ts_convert_method_to_trt_engine, + ) -if DYNAMO_ENABLED: +if ENABLED_FEATURES.dynamo_frontend: from torch._export import ExportedProgram from torch_tensorrt.dynamo._compiler import compile as dynamo_compile + from torch_tensorrt.dynamo._compiler import ( + convert_module_to_trt_engine as dynamo_convert_module_to_trt_engine, + ) + from torch_tensorrt.dynamo._tracer import trace as dynamo_trace logger = logging.getLogger(__name__) @@ -69,7 +74,7 @@ def _parse_module_type(module: Any) -> _ModuleType: return _ModuleType.ts elif isinstance(module, torch.fx.GraphModule): return _ModuleType.fx - elif DYNAMO_ENABLED and isinstance(module, ExportedProgram): + elif isinstance(module, ExportedProgram): return _ModuleType.ep elif isinstance(module, torch.nn.Module): return _ModuleType.nn @@ -77,7 +82,7 @@ def _parse_module_type(module: Any) -> _ModuleType: raise RuntimeError("Module is an unknown format") -def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: +def _get_target_fe(module_type: _ModuleType, ir: str) -> _IRType: module_is_tsable = any(module_type == t for t in [_ModuleType.nn, _ModuleType.ts]) module_is_fxable = any(module_type == t for t in [_ModuleType.nn, _ModuleType.fx]) module_is_exportable = module_type == _ModuleType.ep @@ -88,35 +93,52 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: ir_targets_torch_compile = ir == "torch_compile" if module_is_tsable and ir_targets_torchscript: - return _IRType.ts + if ENABLED_FEATURES.torchscript_frontend: + return _IRType.ts + else: + raise ValueError( + "Requested using the TS frontend but the TS frontend is not available in this build of Torch-TensorRT" + ) elif module_is_fxable and ir_targets_fx: - return _IRType.fx - elif module_is_fxable and ir_targets_dynamo: - return _IRType.dynamo + if ENABLED_FEATURES.fx_frontend: + return _IRType.fx + else: + raise ValueError( + "Requested using the FX frontend but the FX frontend is not available in this build of Torch-TensorRT" + ) + elif (module_is_fxable or module_is_exportable) and ir_targets_dynamo: + if ENABLED_FEATURES.dynamo_frontend: + return _IRType.dynamo + else: + raise ValueError( + "Requested using the Dynamo frontend but the Dynamo frontend is not available in this build of Torch-TensorRT" + ) elif module_is_fxable and ir_targets_torch_compile: - return _IRType.torch_compile + if 
ENABLED_FEATURES.dynamo_frontend: + return _IRType.torch_compile + else: + raise ValueError( + "Requested using the Torch-TensorRT torch.compile backend but the Torch-TensorRT torch.compile backend is not available in this build of Torch-TensorRT" + ) else: if ir == "default": # Options are listed in order of preference - if DYNAMO_ENABLED and module_is_fxable: - logger.info("ir was set to default, using dynamo as ir") + if ENABLED_FEATURES.dynamo_frontend and module_is_fxable: + logger.info("ir was set to default, using dynamo frontend") return _IRType.dynamo - elif module_is_tsable: - if DYNAMO_ENABLED: + elif ENABLED_FEATURES.torchscript_frontend and module_is_tsable: + if ENABLED_FEATURES.dynamo_frontend: logger.warning( - "Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to suppress the warning. Compiling the module with ir=torchscript" + "Input is a torchscript module but the ir was not specified (default=dynamo), please set ir=torchscript to suppress the warning." ) return _IRType.ts - elif module_is_exportable: + elif ENABLED_FEATURES.dynamo_frontend and module_is_exportable: + logger.info("ir was set to default, using dynamo frontend") + return _IRType.dynamo + else: raise ValueError( - "Input graph is an ExportedProgram which is not currently supported. Please provide torch.nn.Module or torch.fx.GraphModule as input." + f"Module was provided in an unsupported format\nInstalled frontends:\n\tDynamo - {ENABLED_FEATURES.dynamo_frontend}\n\tTorchScript - {ENABLED_FEATURES.torchscript_frontend}\n\tFX - {ENABLED_FEATURES.fx_frontend})" ) - else: - raise ValueError("Module was provided in an unsupported format") - elif ir == "exported_program": - raise ValueError( - "ir=exported_program is not currently supported. Supported ir options : ts|fx|dynamo" - ) else: raise ValueError("Unknown ir was requested") @@ -166,12 +188,14 @@ def compile( torch.nn.Module: Compiled Module, when run it will execute via TensorRT """ input_list = inputs if inputs is not None else [] - enabled_precisions_set = ( - enabled_precisions if enabled_precisions is not None else {torch.float} + enabled_precisions_set: Set[dtype | torch.dtype] = ( + enabled_precisions + if enabled_precisions is not None + else _defaults.ENABLED_PRECISIONS ) module_type = _parse_module_type(module) - target_ir = _get_target_ir(module_type, ir) + target_ir = _get_target_fe(module_type, ir) if target_ir == _IRType.ts: ts_mod = module if module_type == _ModuleType.nn: @@ -222,7 +246,7 @@ def compile( # Export the module torchtrt_inputs = prepare_inputs(input_list) - exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs) + exp_program = dynamo_trace(module, torchtrt_inputs, **kwargs) trt_graph_module = dynamo_compile( exp_program, inputs=torchtrt_inputs, @@ -297,7 +321,7 @@ def convert_method_to_trt_engine( ) module_type = _parse_module_type(module) - target_ir = _get_target_ir(module_type, ir) + target_ir = _get_target_fe(module_type, ir) if target_ir == _IRType.ts: ts_mod = module if module_type == _ModuleType.nn: @@ -305,19 +329,20 @@ def convert_method_to_trt_engine( "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. 
In the event of a failure please preconvert your module to TorchScript" ) ts_mod = torch.jit.script(module) - return torch_tensorrt.ts.convert_method_to_trt_engine( # type: ignore[no-any-return] + serialized_engine: bytes = ts_convert_method_to_trt_engine( ts_mod, inputs=inputs, method_name=method_name, enabled_precisions=enabled_precisions_set, **kwargs, ) + return serialized_engine elif target_ir == _IRType.fx: raise RuntimeError( "convert_method_to_trt_engine call is not supported for ir=fx" ) elif target_ir == _IRType.dynamo: - return torch_tensorrt.dynamo.convert_module_to_trt_engine( # type: ignore[no-any-return] + return dynamo_convert_module_to_trt_engine( # type: ignore[no-any-return] module, inputs=inputs, method_name=method_name, diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py index 44cb772dc3..5c16cd03cd 100644 --- a/py/torch_tensorrt/_enums.py +++ b/py/torch_tensorrt/_enums.py @@ -1,3 +1,728 @@ -from torch_tensorrt._C import EngineCapability, TensorFormat, dtype # noqa: F401 +from __future__ import annotations -from tensorrt import DeviceType # noqa: F401 +import logging +from enum import Enum, auto +from typing import Any, Optional, Type, Union + +import numpy as np +import tensorrt as trt +import torch +from torch_tensorrt._features import ENABLED_FEATURES + + +class dtype(Enum): + """Enum to set supported dtypes in the compiler""" + + # Supported types in Torch-TensorRT + unknown = auto() + u8 = auto() + i8 = auto() + i32 = auto() + i64 = auto() + f16 = auto() + f32 = auto() + f64 = auto() + b = auto() + # TODO: Enable FP8 and BF16 + # f8 = auto() + # bf16 = auto() + + uint8 = u8 + int8 = i8 + + int32 = i32 + + long = i64 + int64 = i64 + + half = f16 + fp16 = f16 + float16 = f16 + + float = f32 + fp32 = f32 + float32 = f32 + + double = f64 + fp64 = f64 + float64 = f64 + + # TODO: Enable when FP8 is enabled + # float8 = f8 + # fp8 = f8 + + # TODO: Enable when BF16 is enabled + # bfloat16 = bf16 + + @staticmethod + def _is_np_obj(t: Any) -> bool: + if isinstance(t, np.dtype): + return True + elif isinstance(t, type): + if issubclass(t, np.generic): + return True + return False + + @classmethod + def _from( + cls, + t: Union[torch.dtype, trt.DataType, np.dtype, dtype, type], + use_default: bool = False, + ) -> dtype: + # TODO: Ideally implemented with match statement but need to wait for Py39 EoL + if isinstance(t, torch.dtype): + if t == torch.uint8: + return dtype.u8 + elif t == torch.int8: + return dtype.i8 + elif t == torch.long: + return dtype.i64 + elif t == torch.int32: + return dtype.i32 + elif t == torch.half: + return dtype.f16 + elif t == torch.float: + return dtype.f32 + elif t == torch.float64: + return dtype.f64 + elif t == torch.bool: + return dtype.b + elif use_default: + logging.warning( + f"Given dtype that does not have direct mapping to Torch-TensorRT supported types ({t}), defaulting to torch_tensorrt.dtype.float" + ) + return dtype.float + else: + raise TypeError( + f"Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: {t}" + ) + elif isinstance(t, trt.DataType): + if t == trt.uint8: + return dtype.u8 + elif t == trt.int8: + return dtype.i8 + elif t == trt.int32: + return dtype.i32 + elif t == trt.float16: + return dtype.f16 + elif t == trt.float32: + return dtype.f32 + elif trt.__version__ >= "7.0" and t == trt.bool: + return dtype.b + else: + raise TypeError( + f"Provided an unsupported data type as an input data type (support: bool, int32, half, float), got: {t}" + ) + + elif 
dtype._is_np_obj(t): + if t == np.uint8: + return dtype.u8 + elif t == np.int8: + return dtype.i8 + elif t == np.int32: + return dtype.i32 + elif t == np.int64: + return dtype.i64 + elif t == np.float16: + return dtype.f16 + elif t == np.float32: + return dtype.f32 + elif t == np.float64: + return dtype.f64 + elif t == np.bool: + return dtype.b + elif use_default: + logging.warning( + f"Given dtype that does not have direct mapping to Torch-TensorRT supported types ({t}), defaulting to torch_tensorrt.dtype.float" + ) + return dtype.float + else: + raise TypeError( + "Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: " + + str(t) + ) + + elif isinstance(t, dtype): + return t + + elif ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt import _C + + if isinstance(t, _C.dtype): + if t == _C.dtype.long: + return dtype.i64 + elif t == _C.dtype.int32: + return dtype.i32 + elif t == _C.dtype.int8: + return dtype.i8 + elif t == _C.dtype.half: + return dtype.f16 + elif t == _C.dtype.float: + return dtype.f32 + elif t == _C.dtype.double: + return dtype.f64 + elif t == _C.dtype.bool: + return dtype.b + elif t == _C.dtype.unknown: + return dtype.unknown + else: + raise TypeError( + f"Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: {t}" + ) + # else: # commented out for mypy + raise TypeError( + f"Provided unsupported source type for dtype conversion (got: {t})" + ) + + @classmethod + def try_from( + cls, + t: Union[torch.dtype, trt.DataType, np.dtype, dtype], + use_default: bool = False, + ) -> Optional[dtype]: + try: + casted_format = dtype._from(t, use_default=use_default) + return casted_format + except (ValueError, TypeError) as e: + logging.debug( + f"Conversion from {t} to torch_tensorrt.dtype failed", exc_info=True + ) + return None + + def to( + self, + t: Union[Type[torch.dtype], Type[trt.DataType], Type[np.dtype], Type[dtype]], + use_default: bool = False, + ) -> Union[torch.dtype, trt.DataType, np.dtype, dtype]: + # TODO: Ideally implemented with match statement but need to wait for Py39 EoL + if t == torch.dtype: + if self == dtype.u8: + return torch.uint8 + elif self == dtype.i8: + return torch.int8 + elif self == dtype.i32: + return torch.int + elif self == dtype.i64: + return torch.long + elif self == dtype.f16: + return torch.half + elif self == dtype.f32: + return torch.float + elif self == dtype.f64: + return torch.double + elif self == dtype.b: + return torch.bool + elif use_default: + logging.warning( + f"Given dtype that does not have direct mapping to torch ({self}), defaulting to torch.float" + ) + return torch.float + else: + raise TypeError(f"Unsupported torch dtype (had: {self})") + + elif t == trt.DataType: + if self == dtype.u8: + return trt.DataType.UINT8 + if self == dtype.i8: + return trt.DataType.INT8 + elif self == dtype.i32: + return trt.DataType.INT32 + elif self == dtype.f16: + return trt.DataType.HALF + elif self == dtype.f32: + return trt.DataType.FLOAT + elif self == dtype.b: + return trt.DataType.BOOL + elif use_default: + return trt.DataType.FLOAT + else: + raise TypeError("Unsupported tensorrt dtype") + + elif t == np.dtype: + if self == dtype.u8: + return np.uint8 + elif self == dtype.i8: + return np.int8 + elif self == dtype.i32: + return np.int32 + elif self == dtype.i64: + return np.int64 + elif self == dtype.f16: + return np.float16 + elif self == dtype.f32: + return np.float32 + elif self == dtype.f64: + return np.float64 + elif self == 
dtype.b: + return np.bool_ + elif use_default: + return np.float32 + else: + raise TypeError("Unspported numpy dtype") + + elif t == dtype: + return self + + elif ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt import _C + + if t == _C.dtype: + if self == dtype.i64: + return _C.dtype.long + elif self == dtype.i8: + return _C.dtype.int8 + elif self == dtype.i32: + return _C.dtype.int32 + elif self == dtype.f16: + return _C.dtype.half + elif self == dtype.f32: + return _C.dtype.float + elif self == dtype.f64: + return _C.dtype.double + elif self == dtype.b: + return _C.dtype.bool + elif self == dtype.unknown: + return _C.dtype.unknown + else: + raise TypeError( + f"Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: {self}" + ) + # else: # commented out for mypy + raise TypeError( + f"Provided unsupported destination type for dtype conversion {t}" + ) + + def try_to( + self, + t: Union[Type[torch.dtype], Type[trt.DataType], Type[np.dtype], Type[dtype]], + use_default: bool, + ) -> Optional[Union[torch.dtype, trt.DataType, np.dtype, dtype]]: + try: + casted_format = self.to(t, use_default) + return casted_format + except (ValueError, TypeError) as e: + logging.debug( + f"torch_tensorrt.dtype conversion to target type {t} failed", + exc_info=True, + ) + return None + + def __eq__(self, other: Union[torch.dtype, trt.DataType, np.dtype, dtype]) -> bool: + other_ = dtype._from(other) + return bool(self.value == other_.value) + + def __hash__(self) -> int: + return hash(self.value) + + # Putting aliases here that mess with mypy + bool = b + int = i32 + + +class memory_format(Enum): + + # TensorRT supported memory layouts + linear = auto() + chw2 = auto() + hwc8 = auto() + chw4 = auto() + chw16 = auto() + chw32 = auto() + dhwc8 = auto() + cdhw32 = auto() + hwc = auto() + dla_linear = auto() + dla_hwc4 = auto() + hwc16 = auto() + dhwc = auto() + + # PyTorch aliases for TRT layouts + contiguous = linear + channels_last = hwc + channels_last_3d = dhwc + + @classmethod + def _from( + cls, f: Union[torch.memory_format, trt.TensorFormat, memory_format] + ) -> memory_format: + # TODO: Ideally implemented with match statement but need to wait for Py39 EoL + if isinstance(f, torch.memory_format): + if f == torch.contiguous_format: + return memory_format.contiguous + elif f == torch.channels_last: + return memory_format.channels_last + elif f == torch.channels_last_3d: + return memory_format.channels_last_3d + else: + raise TypeError( + f"Provided an unsupported memory format for tensor, got: {dtype}" + ) + + elif isinstance(f, trt.DataType): + if f == trt.TensorFormat.LINEAR: + return memory_format.linear + elif f == trt.TensorFormat.CHW2: + return memory_format.chw2 + elif f == trt.TensorFormat.HWC8: + return memory_format.hwc8 + elif f == trt.TensorFormat.CHW4: + return memory_format.chw4 + elif f == trt.TensorFormat.CHW16: + return memory_format.chw16 + elif f == trt.TensorFormat.CHW32: + return memory_format.chw32 + elif f == trt.TensorFormat.DHWC8: + return memory_format.dhwc8 + elif f == trt.TensorFormat.CDHW32: + return memory_format.cdhw32 + elif f == trt.TensorFormat.HWC: + return memory_format.hwc + elif f == trt.TensorFormat.DLA_LINEAR: + return memory_format.dla_linear + elif f == trt.TensorFormat.DLA_HWC4: + return memory_format.dla_hwc4 + elif f == trt.TensorFormat.HWC16: + return memory_format.hwc16 + elif f == trt.TensorFormat.DHWC: + return memory_format.dhwc + else: + raise TypeError( + f"Provided an unsupported tensor format 
for tensor, got: {dtype}" + ) + + elif isinstance(f, memory_format): + return f + + elif ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt import _C + + if isinstance(f, _C.TensorFormat): + if f == _C.TensorFormat.contiguous: + return memory_format.contiguous + elif f == _C.TensorFormat.channels_last: + return memory_format.channels_last + else: + raise ValueError( + "Provided an unsupported tensor format (support: NCHW/contiguous_format, NHWC/channel_last)" + ) + # else: # commented out for mypy + raise TypeError("Provided unsupported source type for memory_format conversion") + + @classmethod + def try_from( + cls, f: Union[torch.memory_format, trt.TensorFormat, memory_format] + ) -> Optional[memory_format]: + try: + casted_format = memory_format._from(f) + return casted_format + except (ValueError, TypeError) as e: + logging.debug( + f"Conversion from {f} to torch_tensorrt.memory_format failed", + exc_info=True, + ) + return None + + def to( + self, + t: Union[ + Type[torch.memory_format], Type[trt.TensorFormat], Type[memory_format] + ], + ) -> Union[torch.memory_format, trt.TensorFormat, memory_format]: + if t == torch.memory_format: + if self == memory_format.contiguous: + return torch.contiguous_format + elif self == memory_format.channels_last: + return torch.channels_last + elif self == memory_format.channels_last_3d: + return torch.channels_last_3d + else: + raise TypeError("Unsupported torch dtype") + + elif t == trt.TensorFormat: + if self == memory_format.linear: + return trt.TensorFormat.LINEAR + elif self == memory_format.chw2: + return trt.TensorFormat.CHW2 + elif self == memory_format.hwc8: + return trt.TensorFormat.HWC8 + elif self == memory_format.chw4: + return trt.TensorFormat.CHW4 + elif self == memory_format.chw16: + return trt.TensorFormat.CHW16 + elif self == memory_format.chw32: + return trt.TensorFormat.CHW32 + elif self == memory_format.dhwc8: + return trt.TensorFormat.DHWC8 + elif self == memory_format.cdhw32: + return trt.TensorFormat.CDHW32 + elif self == memory_format.hwc: + return trt.TensorFormat.HWC + elif self == memory_format.dla_linear: + return trt.TensorFormat.DLA_LINEAR + elif self == memory_format.dla_hwc4: + return trt.TensorFormat.DLA_HWC4 + elif self == memory_format.hwc16: + return trt.TensorFormat.HWC16 + elif self == memory_format.dhwc: + return trt.TensorFormat.DHWC + else: + raise TypeError("Unsupported tensorrt memory format") + + elif t == memory_format: + return self + + elif ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt import _C + + if t == _C.TensorFormat: + if self == memory_format.contiguous: + return _C.TensorFormat.contiguous + elif self == memory_format.channels_last: + return _C.TensorFormat.channels_last + else: + raise ValueError( + "Provided an unsupported tensor format (support: NCHW/contiguous_format, NHWC/channel_last)" + ) + # else: # commented out for mypy + raise TypeError( + "Provided unsupported destination type for memory format conversion" + ) + + def try_to( + self, + t: Union[ + Type[torch.memory_format], Type[trt.TensorFormat], Type[memory_format] + ], + ) -> Optional[Union[torch.memory_format, trt.TensorFormat, memory_format]]: + try: + casted_format = self.to(t) + return casted_format + except (ValueError, TypeError) as e: + logging.debug( + f"torch_tensorrt.memory_format conversion to target type {t} failed", + exc_info=True, + ) + return None + + def __eq__( + self, other: Union[torch.memory_format, trt.TensorFormat, memory_format] + ) -> bool: + other_ = memory_format._from(other) + 
return self.value == other_.value + + def __hash__(self) -> int: + return hash(self.value) + + +class DeviceType(Enum): + UNKNOWN = auto() + GPU = auto() + DLA = auto() + + @classmethod + def _from(cls, d: Union[trt.DeviceType, DeviceType]) -> DeviceType: + if isinstance(d, trt.DeviceType): + if d == trt.DeviceType.GPU: + return DeviceType.GPU + elif d == trt.DeviceType.DLA: + return DeviceType.DLA + else: + raise ValueError( + "Provided an unsupported device type (support: GPU/DLA)" + ) + + elif isinstance(d, DeviceType): + return d + + elif ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt import _C + + if isinstance(d, _C.DeviceType): + if d == _C.DeviceType.GPU: + return DeviceType.GPU + elif d == _C.DeviceType.DLA: + return DeviceType.DLA + else: + raise ValueError( + "Provided an unsupported device type (support: GPU/DLA)" + ) + # else: # commented out for mypy + raise TypeError("Provided unsupported source type for DeviceType conversion") + + @classmethod + def try_from(cls, d: Union[trt.DeviceType, DeviceType]) -> Optional[DeviceType]: + try: + casted_format = DeviceType._from(d) + return casted_format + except (ValueError, TypeError) as e: + logging.debug( + f"Conversion from {d} to torch_tensorrt.DeviceType failed", + exc_info=True, + ) + return None + + def to( + self, + t: Union[Type[trt.DeviceType], Type[DeviceType]], + use_default: bool = False, + ) -> Union[trt.DeviceType, DeviceType]: + if t == trt.DeviceType: + if self == DeviceType.GPU: + return trt.DeviceType.GPU + elif self == DeviceType.DLA: + return trt.DeviceType.DLA + elif use_default: + return trt.DeviceType.GPU + else: + raise ValueError( + "Provided an unsupported device type (support: GPU/DLA)" + ) + + elif t == DeviceType: + return self + + elif ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt import _C + + if t == _C.DeviceType: + if self == DeviceType.GPU: + return _C.DeviceType.GPU + elif self == DeviceType.DLA: + return _C.DeviceType.DLA + else: + raise ValueError( + "Provided an unsupported device type (support: GPU/DLA)" + ) + # else: # commented out for mypy + raise TypeError( + "Provided unsupported destination type for device type conversion" + ) + + def try_to( + self, + t: Union[Type[trt.DeviceType], Type[DeviceType]], + use_default: bool = False, + ) -> Optional[Union[trt.DeviceType, DeviceType]]: + try: + casted_format = self.to(t, use_default=use_default) + return casted_format + except (ValueError, TypeError) as e: + logging.debug( + f"torch_tensorrt.DeviceType conversion to target type {t} failed", + exc_info=True, + ) + return None + + def __eq__(self, other: Union[trt.DeviceType, DeviceType]) -> bool: + other_ = DeviceType._from(other) + return bool(self.value == other_.value) + + def __hash__(self) -> int: + return hash(self.value) + + +class EngineCapability(Enum): + STANDARD = auto() + SAFETY = auto() + DLA_STANDALONE = auto() + + @classmethod + def _from( + cls, c: Union[trt.EngineCapability, EngineCapability] + ) -> EngineCapability: + if isinstance(c, trt.EngineCapability): + if c == trt.EngineCapability.STANDARD: + return EngineCapability.STANDARD + elif c == trt.EngineCapability.SAFETY: + return EngineCapability.SAFETY + elif c == trt.EngineCapability.DLA_STANDALONE: + return EngineCapability.DLA_STANDALONE + else: + raise ValueError("Provided an unsupported engine capability") + + elif isinstance(c, EngineCapability): + return c + + elif ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt import _C + + if isinstance(c, _C.EngineCapability): + if c == 
_C.EngineCapability.STANDARD: + return EngineCapability.STANDARD + elif c == _C.EngineCapability.SAFETY: + return EngineCapability.SAFETY + elif c == _C.EngineCapability.DLA_STANDALONE: + return EngineCapability.DLA_STANDALONE + else: + raise ValueError("Provided an unsupported engine capability") + # else: # commented out for mypy + raise TypeError( + "Provided unsupported source type for EngineCapability conversion" + ) + + @classmethod + def try_from( + c: Union[trt.EngineCapability, EngineCapability] + ) -> Optional[EngineCapability]: + try: + casted_format = EngineCapability._from(c) + return casted_format + except (ValueError, TypeError) as e: + logging.debug( + f"Conversion from {c} to torch_tensorrt.EngineCapablity failed", + exc_info=True, + ) + return None + + def to( + self, t: Union[Type[trt.EngineCapability], Type[EngineCapability]] + ) -> Union[trt.EngineCapability, EngineCapability]: + if t == trt.EngineCapability: + if self == EngineCapability.STANDARD: + return trt.EngineCapability.STANDARD + elif self == EngineCapability.SAFETY: + return trt.EngineCapability.SAFETY + elif self == EngineCapability.DLA_STANDALONE: + return trt.EngineCapability.DLA_STANDALONE + else: + raise ValueError("Provided an unsupported engine capability") + + elif t == EngineCapability: + return self + + elif ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt import _C + + if t == _C.EngineCapability: + if self == EngineCapability.STANDARD: + return _C.EngineCapability.STANDARD + elif self == EngineCapability.SAFETY: + return _C.EngineCapability.SAFETY + elif self == EngineCapability.DLA_STANDALONE: + return _C.EngineCapability.DLA_STANDALONE + else: + raise ValueError("Provided an unsupported engine capability") + # else: # commented out for mypy + raise TypeError( + "Provided unsupported destination type for engine capablity type conversion" + ) + + def try_to( + self, t: Union[Type[trt.EngineCapability], Type[EngineCapability]] + ) -> Optional[Union[trt.EngineCapability, EngineCapability]]: + try: + casted_format = self.to(t) + return casted_format + except (ValueError, TypeError) as e: + logging.debug( + f"torch_tensorrt.EngineCapablity conversion to target type {t} failed", + exc_info=True, + ) + return None + + def __eq__(self, other: Union[trt.EngineCapability, EngineCapability]) -> bool: + other_ = EngineCapability._from(other) + return bool(self.value == other_.value) + + def __hash__(self) -> int: + return hash(self.value) diff --git a/py/torch_tensorrt/_features.py b/py/torch_tensorrt/_features.py new file mode 100644 index 0000000000..dde99cbaf6 --- /dev/null +++ b/py/torch_tensorrt/_features.py @@ -0,0 +1,35 @@ +import os +from collections import namedtuple + +from torch_tensorrt._utils import sanitized_torch_version + +from packaging import version + +FeatureSet = namedtuple( + "FeatureSet", + [ + "torchscript_frontend", + "torch_tensorrt_runtime", + "dynamo_frontend", + "fx_frontend", + ], +) + +_TS_FE_AVAIL = os.path.isfile(os.path.dirname(__file__) + "/lib/libtorchtrt.so") +_TORCHTRT_RT_AVAIL = _TS_FE_AVAIL or os.path.isfile( + os.path.dirname(__file__) + "/lib/libtorchtrt_runtime.so" +) +_DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev") +_FX_FE_AVAIL = True + +ENABLED_FEATURES = FeatureSet( + _TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL +) + + +def _enabled_features_str() -> str: + enabled = lambda x: "ENABLED" if x else "DISABLED" + out_str: str = ( + f"Enabled Features:\n - Dynamo Frontend: 
{enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n" # type: ignore[no-untyped-call] + ) + return out_str diff --git a/py/torch_tensorrt/_utils.py b/py/torch_tensorrt/_utils.py index b21696427b..3d5f98b5e5 100644 --- a/py/torch_tensorrt/_utils.py +++ b/py/torch_tensorrt/_utils.py @@ -1,36 +1,6 @@ from typing import Any import torch -from torch_tensorrt import _C -from torch_tensorrt._version import __version__ - - -def dump_build_info() -> None: - """Prints build information about the torch_tensorrt distribution to stdout""" - print(get_build_info()) - - -def get_build_info() -> str: - """Returns a string containing the build information of torch_tensorrt distribution - - Returns: - str: String containing the build information for torch_tensorrt distribution - """ - core_build_info = _C.get_build_info() - build_info = str( - "Torch-TensorRT Version: " - + str(__version__) - + "\n" - + "Using PyTorch Version: " - + str(torch.__version__) - + "\n" - + core_build_info - ) - return build_info - - -def set_device(gpu_id: int) -> None: - _C.set_device(gpu_id) def sanitized_torch_version() -> Any: diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.cpp b/py/torch_tensorrt/csrc/tensorrt_classes.cpp index 4794a679eb..bd3aa6b305 100644 --- a/py/torch_tensorrt/csrc/tensorrt_classes.cpp +++ b/py/torch_tensorrt/csrc/tensorrt_classes.cpp @@ -235,23 +235,23 @@ std::string Device::to_str() { std::string to_str(EngineCapability value) { switch (value) { - case EngineCapability::kSAFE_GPU: - return "Safe GPU"; - case EngineCapability::kSAFE_DLA: - return "Safe DLA"; - case EngineCapability::kDEFAULT: + case EngineCapability::kDLA_STANDALONE: + return "DLA Standalone"; + case EngineCapability::kSAFETY: + return "Safety"; + case EngineCapability::kSTANDARD: default: - return "Default"; + return "Standard"; } } nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value) { switch (value) { - case EngineCapability::kSAFE_DLA: + case EngineCapability::kDLA_STANDALONE: return TRT_ENGINE_CAPABILITY_DLA_STANDALONE; - case EngineCapability::kSAFE_GPU: + case EngineCapability::kSAFETY: return TRT_ENGINE_CAPABILITY_SAFETY; - case EngineCapability::kDEFAULT: + case EngineCapability::kSTANDARD: default: return TRT_ENGINE_CAPABILITY_STANDARD; } diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.h b/py/torch_tensorrt/csrc/tensorrt_classes.h index 9bdd00b7e0..89c5c8661e 100644 --- a/py/torch_tensorrt/csrc/tensorrt_classes.h +++ b/py/torch_tensorrt/csrc/tensorrt_classes.h @@ -114,9 +114,9 @@ struct TorchFallback : torch::CustomClassHolder { }; enum class EngineCapability : int8_t { - kDEFAULT, - kSAFE_GPU, - kSAFE_DLA, + kSTANDARD, + kSAFETY, + kDLA_STANDALONE, }; std::string to_str(EngineCapability value); @@ -160,7 +160,7 @@ struct CompileSpec : torch::CustomClassHolder { ADD_FIELD_GET_SET(sparse_weights, bool); ADD_FIELD_GET_SET(refit, bool); ADD_FIELD_GET_SET(debug, bool); - ADD_ENUM_GET_SET(capability, EngineCapability, static_cast(EngineCapability::kSAFE_DLA)); + ADD_ENUM_GET_SET(capability, EngineCapability, static_cast(EngineCapability::kSTANDARD)); ADD_FIELD_GET_SET(num_avg_timing_iters, int64_t); ADD_FIELD_GET_SET(workspace_size, int64_t); ADD_FIELD_GET_SET(dla_sram_size, int64_t); @@ -184,7 +184,7 @@ struct CompileSpec : torch::CustomClassHolder { bool allow_shape_tensors = false; Device device; TorchFallback torch_fallback; - EngineCapability capability = 
EngineCapability::kDEFAULT; + EngineCapability capability = EngineCapability::kSTANDARD; int64_t num_avg_timing_iters = 1; int64_t workspace_size = 0; int64_t dla_sram_size = 1048576; diff --git a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp index 33c7e27398..e4d88088e4 100644 --- a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp +++ b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp @@ -261,9 +261,9 @@ PYBIND11_MODULE(_C, m) { m, "EngineCapability", "Enum to specify engine capability settings (selections of kernels to meet safety requirements)") - .value("safe_gpu", EngineCapability::kSAFE_GPU, "Use safety GPU kernels only") - .value("safe_dla", EngineCapability::kSAFE_DLA, "Use safety DLA kernels only") - .value("default", EngineCapability::kDEFAULT, "Use default behavior"); + .value("SAFETY", EngineCapability::kSAFETY, "Use safe kernels only") + .value("DLA_STANDALONE", EngineCapability::kDLA_STANDALONE, "Use DLA kernels only") + .value("STANDARD", EngineCapability::kSTANDARD, "Use default behavior"); py::enum_(m, "TensorFormat", "Enum to specifiy the memory layout of tensors") .value("contiguous", TensorFormat::kContiguous, "Contiguous memory layout (NCHW / Linear)") diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 884f99f5c0..09543a5d64 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -5,42 +5,12 @@ from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union import torch -import torch_tensorrt from torch.export import ExportedProgram from torch.fx.node import Target -from torch_tensorrt import _enums from torch_tensorrt._Device import Device -from torch_tensorrt._enums import ( # TODO: Should probabably be the TRT EngineCapability Enum - EngineCapability, -) +from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt._Input import Input -from torch_tensorrt.dynamo import partitioning -from torch_tensorrt.dynamo._defaults import ( - DEBUG, - DEVICE, - DISABLE_TF32, - DLA_GLOBAL_DRAM_SIZE, - DLA_LOCAL_DRAM_SIZE, - DLA_SRAM_SIZE, - DRYRUN, - ENABLE_EXPERIMENTAL_DECOMPOSITIONS, - ENGINE_CAPABILITY, - HARDWARE_COMPATIBLE, - MAX_AUX_STREAMS, - MIN_BLOCK_SIZE, - NUM_AVG_TIMING_ITERS, - OPTIMIZATION_LEVEL, - PASS_THROUGH_BUILD_FAILURES, - PRECISION, - REFIT, - REQUIRE_FULL_COMPILATION, - SPARSE_WEIGHTS, - TRUNCATE_LONG_AND_DOUBLE, - USE_FAST_PARTITIONER, - USE_PYTHON_RUNTIME, - VERSION_COMPATIBLE, - WORKSPACE_SIZE, -) +from torch_tensorrt.dynamo import _defaults, partitioning from torch_tensorrt.dynamo._DryRunTracker import ( DryRunTracker, PerSubgraphData, @@ -74,32 +44,34 @@ def compile( exported_program: ExportedProgram, inputs: Tuple[Any, ...], *, - device: Optional[Union[Device, torch.device, str]] = DEVICE, - disable_tf32: bool = DISABLE_TF32, - sparse_weights: bool = SPARSE_WEIGHTS, - enabled_precisions: Set[torch.dtype] | Tuple[torch.dtype] = (torch.float32,), - engine_capability: EngineCapability = ENGINE_CAPABILITY, - refit: bool = REFIT, - debug: bool = DEBUG, - num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS, - workspace_size: int = WORKSPACE_SIZE, - dla_sram_size: int = DLA_SRAM_SIZE, - dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE, - dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE, - truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE, - require_full_compilation: bool = REQUIRE_FULL_COMPILATION, - min_block_size: int = MIN_BLOCK_SIZE, + device: Optional[Union[Device, torch.device, str]] = 
_defaults.DEVICE, + disable_tf32: bool = _defaults.DISABLE_TF32, + sparse_weights: bool = _defaults.SPARSE_WEIGHTS, + enabled_precisions: ( + Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] + ) = _defaults.ENABLED_PRECISIONS, + engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, + refit: bool = _defaults.REFIT, + debug: bool = _defaults.DEBUG, + num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, + workspace_size: int = _defaults.WORKSPACE_SIZE, + dla_sram_size: int = _defaults.DLA_SRAM_SIZE, + dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, + dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, + truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE, + require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, + min_block_size: int = _defaults.MIN_BLOCK_SIZE, torch_executed_ops: Optional[Collection[Target]] = None, torch_executed_modules: Optional[List[str]] = None, - pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES, - max_aux_streams: Optional[int] = MAX_AUX_STREAMS, - version_compatible: bool = VERSION_COMPATIBLE, - optimization_level: Optional[int] = OPTIMIZATION_LEVEL, - use_python_runtime: bool = USE_PYTHON_RUNTIME, - use_fast_partitioner: bool = USE_FAST_PARTITIONER, - enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS, - dryrun: bool = DRYRUN, - hardware_compatible: bool = HARDWARE_COMPATIBLE, + pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES, + max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS, + version_compatible: bool = _defaults.VERSION_COMPATIBLE, + optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL, + use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME, + use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER, + enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, + dryrun: bool = _defaults.DRYRUN, + hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile a TorchScript module for NVIDIA GPUs using TensorRT @@ -158,7 +130,6 @@ def compile( enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the grap easier to covert to TensorRT, potentially increasing the amount of graphs run in TensorRT. dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) - output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". 
Default is "exported_program" **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -167,6 +138,8 @@ def compile( if debug: set_log_level(logger.parent, logging.DEBUG) + engine_capability = EngineCapability._from(engine_capability) + if torch_executed_modules is not None and torch_executed_modules: logger.warning( f"Detected torch_executed_modules was non-empty: {torch_executed_modules}" @@ -179,6 +152,7 @@ def compile( # Prepare torch_trt inputs inputs = prepare_inputs(inputs) device = to_torch_tensorrt_device(device) + enabled_precisions = {dtype._from(p) for p in enabled_precisions} if not isinstance(exported_program, ExportedProgram): raise AssertionError( @@ -195,28 +169,10 @@ def compile( logger.debug("Lowered Input graph: " + str(gm.graph)) - enabled_precisions = set(enabled_precisions) - - if ( - torch.float16 in enabled_precisions - or torch_tensorrt.dtype.half in enabled_precisions - ): - precision = torch.float16 - elif ( - torch.float32 in enabled_precisions - or torch_tensorrt.dtype.float in enabled_precisions - ): - precision = torch.float32 - elif len(enabled_precisions) == 0: - logger.info(f"No precision specified, defaulting to {PRECISION}") - precision = PRECISION - else: - raise ValueError( - f"Precision {enabled_precisions} not supported in the Dynamo Path" - ) - compilation_options = { - "precision": precision, + "enabled_precisions": ( + enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS + ), "debug": debug, "device": device, "workspace_size": workspace_size, @@ -283,7 +239,7 @@ def compile_module( sample_inputs, "shape", lambda x: dict(x) if isinstance(x, dict) else tuple(x) ) dryrun_tracker.graph_input_dtypes = parse_complex_tensor_structs( - sample_inputs, "torch_dtype" + sample_inputs, "dtype", lambda t: t.to(torch.dtype, use_default=True) ) dryrun_tracker.compilation_settings = settings @@ -408,7 +364,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: lambda x: dict(x) if isinstance(x, dict) else tuple(x), ) subgraph_data.subgraph_input_dtypes = parse_complex_tensor_structs( - submodule_inputs, "torch_dtype" + submodule_inputs, "dtype", lambda t: t.to(torch.dtype) ) submodule_outputs = submodule( @@ -469,29 +425,31 @@ def convert_module_to_trt_engine( module: torch.fx.GraphModule, method_name: str = "forward", inputs: Optional[Sequence[Input | torch.Tensor]] = None, - enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None, - debug: bool = DEBUG, - workspace_size: int = WORKSPACE_SIZE, - min_block_size: int = MIN_BLOCK_SIZE, - torch_executed_ops: Set[str] = set(), - pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES, - max_aux_streams: Optional[int] = MAX_AUX_STREAMS, - version_compatible: bool = VERSION_COMPATIBLE, - optimization_level: Optional[int] = OPTIMIZATION_LEVEL, - use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME, - truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE, - use_fast_partitioner: bool = USE_FAST_PARTITIONER, - enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS, + enabled_precisions: ( + Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] + ) = _defaults.ENABLED_PRECISIONS, + debug: bool = _defaults.DEBUG, + workspace_size: int = _defaults.WORKSPACE_SIZE, + min_block_size: int = _defaults.MIN_BLOCK_SIZE, + torch_executed_ops: Optional[Set[str]] = None, + pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES, + max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS, 
+ version_compatible: bool = _defaults.VERSION_COMPATIBLE, + optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL, + use_python_runtime: Optional[bool] = _defaults.USE_PYTHON_RUNTIME, + truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE, + use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER, + enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, device: Device = Device._current_device(), - require_full_compilation: bool = REQUIRE_FULL_COMPILATION, - disable_tf32: bool = DISABLE_TF32, - sparse_weights: bool = SPARSE_WEIGHTS, - refit: bool = REFIT, - engine_capability: EngineCapability = ENGINE_CAPABILITY, - num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS, - dla_sram_size: int = DLA_SRAM_SIZE, - dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE, - dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE, + require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, + disable_tf32: bool = _defaults.DISABLE_TF32, + sparse_weights: bool = _defaults.SPARSE_WEIGHTS, + refit: bool = _defaults.REFIT, + engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, + num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, + dla_sram_size: int = _defaults.DLA_SRAM_SIZE, + dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, + dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, calibrator: object = None, allow_shape_tensors: bool = False, ) -> bytes: @@ -575,34 +533,15 @@ def convert_module_to_trt_engine( set_log_level(logger.parent, logging.DEBUG) input_list = list(inputs) if inputs is not None else [] + torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set() # Prepare torch_trt inputs input_list = prepare_inputs(input_list) device = to_torch_tensorrt_device(device) - enabled_precisions = ( - enabled_precisions if enabled_precisions is not None else {torch.float} - ) - - if ( - torch.float16 in enabled_precisions - or torch_tensorrt.dtype.half in enabled_precisions - ): - precision = torch.float16 - elif ( - torch.float32 in enabled_precisions - or torch_tensorrt.dtype.float in enabled_precisions - ): - precision = torch.float32 - elif len(enabled_precisions) == 0: - logger.info(f"No precision specified, defaulting to {PRECISION}") - precision = PRECISION - else: - raise ValueError( - f"Precision {enabled_precisions} not supported in the Dynamo Path" - ) + enabled_precisions = {dtype._from(e) for e in enabled_precisions} compilation_options = { - "precision": precision, + "enabled_precisions": enabled_precisions, "debug": debug, "workspace_size": workspace_size, "min_block_size": min_block_size, diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 3d48ab3def..27db215466 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -1,8 +1,8 @@ import torch -from tensorrt import EngineCapability from torch_tensorrt._Device import Device +from torch_tensorrt._enums import EngineCapability, dtype -PRECISION = torch.float32 +ENABLED_PRECISIONS = {dtype.f32} DEBUG = False DEVICE = None DISABLE_TF32 = False @@ -26,6 +26,7 @@ REQUIRE_FULL_COMPILATION = False DRYRUN = False HARDWARE_COMPATIBLE = False +SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.i8} def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 2420a227d8..3aee629812 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ 
-1,10 +1,9 @@ from dataclasses import dataclass, field from typing import Collection, Optional, Union -import torch -from tensorrt import EngineCapability from torch.fx.node import Target from torch_tensorrt._Device import Device +from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt.dynamo._defaults import ( DEBUG, DISABLE_TF32, @@ -13,6 +12,7 @@ DLA_SRAM_SIZE, DRYRUN, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, + ENABLED_PRECISIONS, ENGINE_CAPABILITY, HARDWARE_COMPATIBLE, MAX_AUX_STREAMS, @@ -20,7 +20,6 @@ NUM_AVG_TIMING_ITERS, OPTIMIZATION_LEVEL, PASS_THROUGH_BUILD_FAILURES, - PRECISION, REFIT, REQUIRE_FULL_COMPILATION, SPARSE_WEIGHTS, @@ -38,7 +37,7 @@ class CompilationSettings: """Compilation settings for Torch-TensorRT Dynamo Paths Args: - precision (torch.dtype): Model Layer precision + enabled_precisions (Set[dtype]): Available kernel dtype precisions debug (bool): Whether to print out verbose debugging information workspace_size (int): Workspace TRT is allowed to use for the module (0 is default) min_block_size (int): Minimum number of operators per TRT-Engine Block @@ -72,7 +71,7 @@ class CompilationSettings: hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) """ - precision: torch.dtype = PRECISION + enabled_precisions: dtype = field(default_factory=lambda: ENABLED_PRECISIONS) debug: bool = DEBUG workspace_size: int = WORKSPACE_SIZE min_block_size: int = MIN_BLOCK_SIZE @@ -90,7 +89,9 @@ class CompilationSettings: disable_tf32: bool = DISABLE_TF32 sparse_weights: bool = SPARSE_WEIGHTS refit: bool = REFIT - engine_capability: EngineCapability = ENGINE_CAPABILITY + engine_capability: EngineCapability = field( + default_factory=lambda: ENGINE_CAPABILITY + ) num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS dla_sram_size: int = DLA_SRAM_SIZE dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 06ae596ed0..1cebc8679d 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -10,7 +10,9 @@ from torch.fx.node import _get_qualified_name from torch.fx.passes.shape_prop import TensorMetadata from torch.utils._python_dispatch import _disable_current_modes +from torch_tensorrt._enums import dtype from torch_tensorrt._Input import Input +from torch_tensorrt.dynamo import _defaults from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( @@ -22,7 +24,7 @@ get_trt_tensor, ) from torch_tensorrt.fx.observer import Observer -from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter +from torch_tensorrt.logging import TRT_LOGGER from packaging import version @@ -50,13 +52,12 @@ def __init__( module: torch.fx.GraphModule, input_specs: Sequence[Input], logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING, - output_dtypes: Optional[Sequence[torch.dtype]] = None, + output_dtypes: Optional[Sequence[dtype]] = None, compilation_settings: CompilationSettings = CompilationSettings(), ): super().__init__(module) - # TODO: @narendasan replace with Torch-TensorRT Logger - self.logger = trt.Logger(logger_level) + self.logger = TRT_LOGGER self.builder = trt.Builder(self.logger) flag = 
0 @@ -69,9 +70,11 @@ def __init__( self.builder.create_network(flag), compilation_settings ) + assert TRTInterpreter._all_precisions_supported( + compilation_settings.enabled_precisions + ), f"Attempted to enable kernel precisions that are not supported (got: {compilation_settings.enabled_precisions}, support: {_defaults.SUPPORTED_KERNEL_PRECISIONS})" missing_ops = self.validate_conversion() if missing_ops: - # TODO: @narendasan make sure to set logging.captureWarnings(True) warnings.warn( "Interpretation will fail due to missing operations \n" + "\n".join(f"{i}" for i in missing_ops) @@ -98,7 +101,11 @@ def __init__( self.compilation_settings = compilation_settings # Data types for TRT Module output Tensors - self.output_dtypes = output_dtypes + self.output_dtypes = ( + [dtype._from(o) for o in output_dtypes] if output_dtypes else None + ) + + _LOGGER.debug(f"Graph to be compiled to TensorRT: {self.module.graph}") def validate_conversion(self) -> Set[str]: missing_converters: Set[str] = set() @@ -116,60 +123,58 @@ def validate_conversion(self) -> Set[str]: return missing_converters - def run( - self, - force_fp32_output: bool = False, - strict_type_constraints: bool = False, - algorithm_selector: Optional[trt.IAlgorithmSelector] = None, - timing_cache: Optional[trt.ITimingCache] = None, - tactic_sources: Optional[int] = None, - ) -> TRTInterpreterResult: - """ - Build TensorRT engine with some configs. - Args: - force_fp32_output: force output to be fp32 - strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons. - algorithm_selector: set up algorithm selection for certain layer - timing_cache: enable timing cache for TensorRT - Return: - TRTInterpreterResult - """ - TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module) + @staticmethod + def _args_str(args: List[Any]) -> str: + def clean_repr(x: Any, depth: int = 0) -> Any: + if isinstance(x, trt.ITensor): + return f"{x.name} " + elif isinstance(x, torch.Tensor): + return f"" + elif isinstance(x, np.ndarray): + return ( + f"" + ) + elif isinstance(x, Sequence) and not isinstance(x, str): + if depth < 3: + return type(x)([clean_repr(i, depth=depth + 1) for i in x]) # type: ignore[call-arg] + else: + return "(...)" + else: + return x + + str_args = [clean_repr(a) for a in args] + return repr(tuple(str_args)) - precision = self.compilation_settings.precision - # For float outputs, we set their dtype to fp16 only if precision == torch.float16 and - # force_fp32_output=False. 
Overriden by specifying output_dtypes - self.output_fp16 = not force_fp32_output and precision == torch.float16 + @staticmethod + def _all_precisions_supported(enabled_precisions: Set[dtype]) -> bool: + return enabled_precisions.issubset(_defaults.SUPPORTED_KERNEL_PRECISIONS) - if precision == torch.int8 and not self.builder.platform_has_fast_int8: + def validate_compile_settings(self) -> None: + if ( + dtype.i8 in self.compilation_settings.enabled_precisions + and not self.builder.platform_has_fast_int8 + ): raise RuntimeError("Current platform doesn't support fast native int8!") - if precision == torch.float16 and not self.builder.platform_has_fast_fp16: + if ( + dtype.f16 in self.compilation_settings.enabled_precisions + and not self.builder.platform_has_fast_fp16 + ): warnings.warn("Current platform doesn't support fast native fp16!") - self.input_specs_iter = 0 - run_module_start_time = datetime.now() - super().run() - _LOGGER.info( - f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}" - ) - build_engine_start_time = datetime.now() + def _populate_trt_builder_config( + self, + strict_type_constraints: bool = False, + algorithm_selector: Optional[trt.IAlgorithmSelector] = None, + tactic_sources: Optional[int] = None, + ) -> trt.IBuilderConfig: builder_config = self.builder.create_builder_config() - if self.compilation_settings.workspace_size != 0: builder_config.set_memory_pool_limit( trt.MemoryPoolType.WORKSPACE, self.compilation_settings.workspace_size ) - cache = None - if timing_cache: - cache_file = np.array(timing_cache) - cache = builder_config.create_timing_cache(cache_file.tobytes()) - else: - cache = builder_config.create_timing_cache(b"") - builder_config.set_timing_cache(cache, False) - if version.parse(trt.__version__) >= version.parse("8.2"): builder_config.profiling_verbosity = ( trt.ProfilingVerbosity.VERBOSE @@ -201,12 +206,20 @@ def run( self.compilation_settings.optimization_level ) - builder_config.engine_capability = self.compilation_settings.engine_capability + builder_config.engine_capability = ( + self.compilation_settings.engine_capability.to(trt.EngineCapability) + ) builder_config.avg_timing_iterations = ( self.compilation_settings.num_avg_timing_iters ) if self.compilation_settings.device.device_type == trt.DeviceType.DLA: + device_info = torch.cuda.get_device_properties( + self.compilation_settings.device.gpu_id + ) + assert (device_info.major == 8 and device_info.minor == 7) or ( + device_info.major == 7 and device_info.minor == 2 + ), "DLA is not available on non AGX systems" builder_config.DLA_core = self.compilation_settings.device.dla_core _LOGGER.info(f"Using DLA core {self.compilation_settings.device.dla_core}") builder_config.set_memory_pool_limit( @@ -222,10 +235,10 @@ def run( self.compilation_settings.dla_global_dram_size, ) - if precision == torch.float16: + if dtype.float16 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.FP16) - if precision == torch.int8: + if dtype.int8 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.INT8) if self.compilation_settings.sparse_weights: @@ -252,11 +265,58 @@ def run( if tactic_sources is not None: builder_config.set_tactic_sources(tactic_sources=tactic_sources) + return builder_config + + def _create_timing_cache( + self, + builder_config: trt.IBuilderConfig, + existing_cache: Optional[trt.ITimingCache] = None, + ) -> trt.ITimingCache: + cache = None + if existing_cache: + cache_file = 
np.array(existing_cache) + cache = builder_config.create_timing_cache(cache_file.tobytes()) + else: + cache = builder_config.create_timing_cache(b"") + builder_config.set_timing_cache(cache, False) + return cache + + def run( + self, + strict_type_constraints: bool = False, + algorithm_selector: Optional[trt.IAlgorithmSelector] = None, + existing_cache: Optional[trt.ITimingCache] = None, + tactic_sources: Optional[int] = None, + ) -> TRTInterpreterResult: + """ + Build TensorRT engine with some configs. + Args: + strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons. + algorithm_selector: set up algorithm selection for certain layer + existing_cache: enable timing cache for TensorRT + Return: + TRTInterpreterResult + """ + TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module) + + self.input_specs_iter = 0 + run_module_start_time = datetime.now() + super().run() + _LOGGER.info( + f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}" + ) + build_engine_start_time = datetime.now() + + builder_config = self._populate_trt_builder_config( + strict_type_constraints, algorithm_selector, tactic_sources + ) + timing_cache = self._create_timing_cache(builder_config, existing_cache) + engine = self.builder.build_engine(self.ctx.net, builder_config) assert engine serialized_cache = ( - bytearray(cache.serialize()) + bytearray(timing_cache.serialize()) if builder_config.get_timing_cache() else bytearray() ) @@ -285,7 +345,7 @@ def run_node(self, n: torch.fx.Node) -> torch.fx.Node: del kwargs["_itensor_to_tensor_meta"] n.kwargs = kwargs - if isinstance(trt_node, trt.tensorrt.ITensor): + if isinstance(trt_node, trt.ITensor): self._itensor_to_tensor_meta[trt_node] = n.meta.get("tensor_meta") return trt_node @@ -323,10 +383,14 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor: f"Unable to access shape spec for input: {target} (got: {current_input})" ) + trt_input_dtype = current_input.dtype.to(trt.DataType, use_default=True) + _LOGGER.debug( + f"Adding input to in-progress INetwork: {target} [shape={shape}, dtype={trt_input_dtype}]" + ) return self.ctx.net.add_input( name=target, shape=tuple(shape), - dtype=unified_dtype_converter(current_input.torch_dtype, Frameworks.TRT), + dtype=trt_input_dtype, ) def call_module( @@ -345,6 +409,9 @@ def call_module( converter, calling_convention = converter_packet assert self._cur_node_name is not None + _LOGGER.debug( + f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})" + ) if calling_convention is CallingConvention.LEGACY: return converter(self.ctx.net, submod, args, kwargs, self._cur_node_name) else: @@ -361,6 +428,9 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any: converter, calling_convention = converter_packet assert self._cur_node_name is not None + _LOGGER.debug( + f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})" + ) if calling_convention is CallingConvention.LEGACY: return converter(self.ctx.net, target, args, kwargs, self._cur_node_name) else: @@ -392,6 +462,9 @@ def call_method(self, target: str, args: Any, kwargs: Any) -> Any: converter, calling_convention = converter_packet assert self._cur_node_name is not None + _LOGGER.debug( + f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})" + ) if calling_convention is CallingConvention.LEGACY: return 
converter(self.ctx.net, target, args, kwargs, self._cur_node_name) else: @@ -409,13 +482,13 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]: for output_idx in range(len(outputs)): output = outputs[output_idx] - if not isinstance(output, trt.tensorrt.ITensor): + if not isinstance(output, trt.ITensor): new_output = get_trt_tensor(self.ctx, output, target) outputs = ( outputs[:output_idx] + (new_output,) + outputs[output_idx + 1 :] ) - if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs): + if not all(isinstance(output, trt.ITensor) for output in outputs): raise RuntimeError("TensorRT requires all outputs to be Tensor!") if self.output_dtypes is not None and len(self.output_dtypes) != len(outputs): @@ -436,6 +509,7 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]: "not", "ne", "isinf", + "isnan", "any", ) ): @@ -446,13 +520,13 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]: output.name = name self.ctx.net.mark_output(output) if output_bool: - output.dtype = trt.bool + output.dtype = trt.DataType.BOOL elif self.output_dtypes is not None: - output.dtype = unified_dtype_converter( - self.output_dtypes[i], Frameworks.TRT - ) - elif self.output_fp16 and output.dtype == trt.float32: - output.dtype = trt.float16 + output.dtype = self.output_dtypes[i].to(trt.DataType) + self._output_names.append(name) + _LOGGER.debug( + f"Marking output {name} [shape={output.shape}, dtype={output.dtype}]" + ) return list(outputs) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 844cb6789a..ec7fdf4126 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -1,10 +1,14 @@ from __future__ import annotations import io -from typing import Sequence +import logging +from typing import List, Sequence import tensorrt as trt import torch +from torch_tensorrt._Device import Device +from torch_tensorrt._enums import dtype +from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.conversion._TRTInterpreter import ( @@ -12,7 +16,38 @@ TRTInterpreterResult, ) from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule -from torch_tensorrt.dynamo.utils import get_torch_inputs, to_torch_device +from torch_tensorrt.dynamo.utils import get_torch_inputs + +logger = logging.getLogger(__name__) + + +def infer_module_output_dtypes( + module: torch.fx.GraphModule, + inputs: Sequence[Input], + device: Device, + truncate_long_and_double: bool = False, +) -> List[dtype]: + torch_inputs = get_torch_inputs(inputs, device) + module = module.to(device.to(torch.device)) + module_outputs = module(*torch_inputs) + + if not isinstance(module_outputs, (list, tuple)): + module_outputs = [module_outputs] + + # Int64 outputs can sometimes be generated from within other operators + # such as aten.sum - such outputs can be truncated + output_dtypes = [] + for output in module_outputs: + if not isinstance(output, torch.Tensor): + output = torch.tensor(output) + if truncate_long_and_double and output.dtype == dtype.float64: + output_dtypes.append(dtype.float32) + elif truncate_long_and_double and output.dtype == dtype.int64: + output_dtypes.append(dtype.int32) + else: + output_dtypes.append(dtype._from(output.dtype)) + + return output_dtypes def interpret_module_to_result( @@ 
-28,23 +63,12 @@ def interpret_module_to_result( Returns: TRTInterpreterResult """ - torch_inputs = get_torch_inputs(inputs, settings.device) - module.to(to_torch_device(settings.device)) - module_outputs = module(*torch_inputs) - - if not isinstance(module_outputs, (list, tuple)): - module_outputs = [module_outputs] - - # Int64 outputs can sometimes be generated from within other operators - # such as aten.sum - such outputs can be truncated - output_dtypes = [] - for output in module_outputs: - if settings.truncate_long_and_double and output.dtype == torch.float64: - output_dtypes.append(torch.float32) - elif settings.truncate_long_and_double and output.dtype == torch.int64: - output_dtypes.append(torch.int32) - else: - output_dtypes.append(output.dtype) + output_dtypes = infer_module_output_dtypes( + module, + inputs, + settings.device, + truncate_long_and_double=settings.truncate_long_and_double, + ) interpreter = TRTInterpreter( module, @@ -74,7 +98,11 @@ def convert_module( """ interpreter_result = interpret_module_to_result(module, inputs, settings) - if settings.use_python_runtime: + if settings.use_python_runtime or not ENABLED_FEATURES.torch_tensorrt_runtime: + if not settings.use_python_runtime: + logger.info( + "Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available" + ) return PythonTorchTensorRTModule( engine=interpreter_result.engine, input_names=list(interpreter_result.input_names), diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index f9d14917f1..7e55110459 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -8,17 +8,14 @@ import torch from torch import SymBool, SymFloat, SymInt from torch.fx.node import Argument, Target +from torch_tensorrt import _enums from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( ConverterRegistry, DynamoConverterImplSignature, ) -from torch_tensorrt.fx.converters.converter_utils import ( - Frameworks, - get_axes_for_reduce_op, - unified_dtype_converter, -) +from torch_tensorrt.fx.converters.converter_utils import get_axes_for_reduce_op from torch_tensorrt.fx.types import TRTDataType, TRTTensor _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -44,7 +41,6 @@ def get_node_name(node: torch.fx.Node) -> str: # like the node.meta['source_fn'] attr pass - _LOGGER.debug(f"Node meta name {node_name}") return node_name @@ -121,7 +117,7 @@ def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool: def cast_trt_tensor( ctx: ConversionContext, input_val: TRTTensor, - dtype: TRTDataType, + dtype: Union[TRTDataType, torch.dtype, np.dtype, _enums.dtype], name: str, target: Target = "", source_ir: Optional[SourceIR] = None, @@ -142,7 +138,7 @@ def cast_trt_tensor( Returns: A TensorRT ITensor which has been casted to the specified dtype """ - trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT) + trt_dtype = _enums.dtype._from(dtype).to(trt.DataType) if input_val.dtype != trt_dtype: source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN @@ -253,7 +249,7 @@ def create_constant( ctx: ConversionContext, value: Union[int, float, bool, np.ndarray, torch.Tensor], name: str, - dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]], + dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, 
_enums.dtype]], ) -> TRTTensor: """ Add a TensorRT constant layer whose value is `value` to `ctx.net`. @@ -268,7 +264,9 @@ def create_constant( Returns: A TensorRT ITensor that represents the given value. """ - numpy_value = to_numpy(value, dtype) + numpy_value = to_numpy( + value, _enums.dtype._from(dtype).to(np.dtype) if dtype is not None else None + ) constant = ctx.net.add_constant( (1,) if isinstance(value, (int, float, bool)) else value.shape, numpy_value.copy() if isinstance(numpy_value, np.ndarray) else numpy_value, @@ -281,7 +279,7 @@ def get_trt_tensor( ctx: ConversionContext, input_val: Any, name: str, - dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None, + dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None, ) -> TRTTensor: """ Given a value of random type, we try to convert it to a TensorRT ITensor. @@ -466,7 +464,7 @@ def convert_with_type_enforcement( def to_numpy( value: Optional[Union[torch.Tensor, np.ndarray, int, float, bool]], - dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None, + dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None, ) -> Optional[np.ndarray]: """ Convert a PyTorch Tensor, Numpy array, or scalar to a Numpy Array. If the tensor is @@ -503,7 +501,7 @@ def to_numpy( return ( output if (dtype is None or output is None) - else output.astype(unified_dtype_converter(dtype, Frameworks.NUMPY)) + else output.astype(_enums.dtype._from(dtype).to(np.dtype)) ) else: raise AssertionError( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cast.py b/py/torch_tensorrt/dynamo/conversion/impl/cast.py index bc6af1a32d..817cecd2c7 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/cast.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/cast.py @@ -1,15 +1,15 @@ import logging -from typing import Optional +from typing import Optional, Union +import numpy as np +import tensorrt as trt +import torch from torch.fx.node import Target +from torch_tensorrt import _enums from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterRegistry from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor -from torch_tensorrt.fx.converters.converter_utils import ( - Frameworks, - unified_dtype_converter, -) from torch_tensorrt.fx.types import TRTDataType, TRTTensor LOGGER: logging.Logger = logging.getLogger(__name__) @@ -21,7 +21,7 @@ def to_copy( source_ir: Optional[SourceIR], name: str, input: TRTTensor, - dtype: TRTDataType, + dtype: Union[TRTDataType, torch.dtype, np.dtype, _enums.dtype], force_layer: bool = False, ) -> TRTTensor: if not isinstance(input, TRTTensor): @@ -32,7 +32,7 @@ def to_copy( # If cast is forced, insert identity layer regardless of whether the dtype # doesn't change if force_layer: - trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT) + trt_dtype = _enums.dtype._from(dtype).to(trt.DataType) source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN target_str = ConverterRegistry.qualified_name_or_str(target) target_name = f"{source_ir}_ops{('.' 
+ target_str) if target_str else ''}" diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index 8282ee8698..90a1c07229 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -6,6 +6,7 @@ import tensorrt as trt import torch from torch.fx.node import Target +from torch_tensorrt import _enums from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( @@ -14,7 +15,6 @@ ) from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name from torch_tensorrt.fx.types import TRTElementWiseOp, TRTTensor -from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter def get_python_op_from_trt_elementwise_op( @@ -121,22 +121,20 @@ def convert_binary_elementwise( # dtype but we don't have a way to detect whether it makes sense for the # scalar to be float or half. Hence we go with the lhs dtype. if is_lhs_trt_tensor and isinstance(rhs_val, (float, int, bool)): - rhs_val = np.array( - [rhs_val], dtype=unified_dtype_converter(lhs_dtype, Frameworks.NUMPY) - ) + rhs_val = np.array([rhs_val], dtype=_enums.dtype._from(lhs_dtype).to(np.dtype)) if is_rhs_trt_tensor and isinstance(lhs_val, (float, int, bool)): - lhs_val = np.array( - [lhs_val], dtype=unified_dtype_converter(rhs_dtype, Frameworks.NUMPY) - ) + lhs_val = np.array([lhs_val], dtype=_enums.dtype._from(rhs_dtype).to(np.dtype)) lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype) rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype) - promoted_type = torch.promote_types( - unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH), - unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH), + promoted_type = _enums.dtype._from( + torch.promote_types( + _enums.dtype._from(lhs_val.dtype).to(torch.dtype), + _enums.dtype._from(rhs_val.dtype).to(torch.dtype), + ) ) - trt_promoted_type = unified_dtype_converter(promoted_type, Frameworks.TRT) + trt_promoted_type = promoted_type.to(trt.DataType) if trt_promoted_type != lhs_val.dtype: lhs_val = cast_trt_tensor( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py index d0f6d29482..81c3a3e867 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py @@ -4,6 +4,7 @@ import torch import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.node import Target +from torch_tensorrt import _enums from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( @@ -17,7 +18,6 @@ from torch_tensorrt.dynamo.conversion.impl.unary import sign from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary from torch_tensorrt.fx.types import TRTTensor -from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter def trunc_div( @@ -70,7 +70,7 @@ def trunc_div( ctx, other, f"{name}_other", - dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH), + dtype=_enums.dtype._from(input.dtype).to(torch.dtype), ) abs_input_output = convert_unary( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py 
b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py index a50ec3c434..fe48d5f42e 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py @@ -1,13 +1,14 @@ from typing import Optional import tensorrt as trt +import torch from torch.fx.node import Target +from torch_tensorrt import _enums from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name from torch_tensorrt.fx.types import TRTTensor -from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter def matrix_multiply( @@ -20,14 +21,14 @@ def matrix_multiply( input_matrix_op: trt.MatrixOperation = trt.MatrixOperation.NONE, other_matrix_op: trt.MatrixOperation = trt.MatrixOperation.NONE, ) -> TRTTensor: - if not isinstance(input, trt.tensorrt.ITensor): + if not isinstance(input, trt.ITensor): input = get_trt_tensor(ctx, input, f"{name}_input") - if not isinstance(other, trt.tensorrt.ITensor): + if not isinstance(other, trt.ITensor): other = get_trt_tensor( ctx, other, f"{name}_other", - dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH), + dtype=_enums.dtype._from(input.dtype).to(torch.dtype), ) preset_diff = 0 diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index dc33129d24..bf63e4300f 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -81,7 +81,7 @@ def index( source_ir: Optional[SourceIR], name: str, input: TRTTensor, - index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], + indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], ) -> TRTTensor: adv_indx_indices = [] tensor_indices = [] @@ -93,12 +93,14 @@ def index( "Determining whether aten.index constant-index optimization can be invoked" ) is_numpy = all( - isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None + isinstance(ind, (torch.Tensor, np.ndarray)) + for ind in indices + if ind is not None ) # here we need to check if all the index are broadcastable # if no, then we need to broadcast last_index = None - for i, ind in enumerate(index): + for i, ind in enumerate(indices): if ind is not None: _LOGGER.debug(f"Shape of {i} index is {ind.shape}") adv_indx_indices.append(i) @@ -369,4 +371,22 @@ def index( ) reshape_output = reshape_layer.get_output(0) - return reshape_output + return reshape_output + + +def index_select( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: int, + index: TRTTensor, +) -> TRTTensor: + # The axis parameter specifies the dimension along which to index. 
+ dim = get_positive_dim(dim, len(input.shape)) + gather_layer = ctx.net.add_gather(input, index, axis=dim) + + set_layer_name(gather_layer, target, f"{name}_gather", source_ir) + + return gather_layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py b/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py index 9390bc3bde..d5670be1db 100644 --- a/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py +++ b/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py @@ -4,6 +4,7 @@ import torch from torch.fx.node import _get_qualified_name +from torch_tensorrt._enums import dtype from torch_tensorrt._Input import Input from torch_tensorrt.dynamo.utils import get_torch_inputs @@ -217,6 +218,8 @@ def repair_long_or_double_inputs( # Set the 32bit inputs and their types to the submodule Inputs for idx in range(len(submodule_inputs)): submodule_inputs[idx].torch_tensor = submodule_torch_inputs[idx] - submodule_inputs[idx].torch_dtype = submodule_torch_inputs[idx].dtype + submodule_inputs[idx].dtype = dtype._from( + submodule_torch_inputs[idx].dtype + ) return submodule_inputs diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index e263b51bb2..c00d92577c 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -110,6 +110,7 @@ def __init__( allowed_single_node_partition_ops: Optional[Collection[str]] = None, min_block_size: int = MIN_BLOCK_SIZE, require_full_compilation: bool = REQUIRE_FULL_COMPILATION, + return_tuple: bool = False, ): """ Preprocesses graph before splitting: @@ -149,6 +150,7 @@ def __init__( self.num_trt_accelerated_subgraphs: Optional[int] = None self.allowed_single_node_partition_ops = allowed_single_node_partition_ops self.require_full_compilation = require_full_compilation + self._return_tuple = return_tuple def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]: """ diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 3a66ed3716..ac1329e8f8 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -9,12 +9,12 @@ import torch_tensorrt from torch.nn import Module from torch_tensorrt._Device import Device +from torch_tensorrt._enums import dtype from torch_tensorrt.dynamo.runtime.tools import ( _is_switch_required, _select_rt_device, multi_gpu_device_check, ) -from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter logger = logging.getLogger(__name__) @@ -84,9 +84,7 @@ def _initialize(self) -> None: ) self.input_dtypes = [ - unified_dtype_converter( - self.engine.get_binding_dtype(idx), Frameworks.TORCH - ) + dtype._from(self.engine.get_binding_dtype(idx)) for idx in self.input_binding_indices_in_order ] self.input_shapes: Sequence[Sequence[int]] = [ @@ -94,9 +92,7 @@ def _initialize(self) -> None: for idx in self.input_binding_indices_in_order ] self.output_dtypes = [ - unified_dtype_converter( - self.engine.get_binding_dtype(idx), Frameworks.TORCH - ) + dtype._from(self.engine.get_binding_dtype(idx)) for idx in self.output_binding_indices_in_order ] self.output_shapes = [ @@ -108,9 +104,7 @@ def _initialize(self) -> None: for idx in self.output_binding_indices_in_order ] self.hidden_output_dtypes = [ - 
unified_dtype_converter( - self.engine.get_binding_dtype(idx), Frameworks.TORCH - ) + dtype._from(self.engine.get_binding_dtype(idx)) for idx in self.hidden_output_binding_indices_in_order ] self.hidden_output_shapes = [ @@ -263,7 +257,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . output = torch.empty( size=shape, - dtype=self.output_dtypes[i], + dtype=self.output_dtypes[i].to(torch.dtype), device=torch.cuda.current_device(), ) outputs.append(output) @@ -274,7 +268,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . output = torch.empty( size=shape, - dtype=self.hidden_output_dtypes[i], + dtype=self.hidden_output_dtypes[i].to(torch.dtype), device=torch.cuda.current_device(), ) bindings[idx] = output.data_ptr() diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 549636b3c7..6ea9503b84 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -6,11 +6,11 @@ import torch from torch_tensorrt._Device import Device +from torch_tensorrt._enums import dtype from torch_tensorrt._Input import Input -from torch_tensorrt.dynamo._defaults import PRECISION +from torch_tensorrt.dynamo import _defaults from torch_tensorrt.dynamo._settings import CompilationSettings -import torch_tensorrt from packaging import version logger = logging.getLogger(__name__) @@ -197,10 +197,7 @@ def to_torch_device(device: Optional[Union[Device, torch.device, str]]) -> torch Returns the corresponding torch.device """ if isinstance(device, Device): - if device.gpu_id != -1: - return torch.device(device.gpu_id) - else: - raise ValueError("Invalid GPU ID provided for the CUDA device provided") + return device.to(torch.device) elif isinstance(device, torch.device): return device @@ -219,17 +216,7 @@ def to_torch_tensorrt_device( Returns the corresponding torch_tensorrt.Device """ - if isinstance(device, Device): - return device - - elif isinstance(device, torch.device): - return Device(gpu_id=device.index) - - elif device is None: - return Device(gpu_id=torch.cuda.current_device()) - - else: - return Device(device) + return Device._from(device) def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings: @@ -258,25 +245,15 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings: # TODO: Remove once Dynamo precisions refactoring is complete if "enabled_precisions" in kwargs: - enabled_precisions = kwargs["enabled_precisions"] - - if ( - torch.float16 in enabled_precisions - or torch_tensorrt.dtype.half in enabled_precisions - ): - settings.precision = torch.float16 - elif ( - torch.float32 in enabled_precisions - or torch_tensorrt.dtype.float in enabled_precisions - ): - settings.precision = torch.float32 - elif len(enabled_precisions) == 0: - logger.info(f"No precision specified, defaulting to {PRECISION}") - settings.precision = PRECISION - else: - raise ValueError( - f"Precision {enabled_precisions} not supported in the Dynamo Path" + enabled_precisions = {dtype._from(e) for e in kwargs["enabled_precisions"]} + + if len(enabled_precisions) == 0: + logger.info( + f"No precision specified, defaulting to {_defaults.ENABLED_PRECISION}" ) + enabled_precisions = _defaults.ENABLED_PRECISIONS + + settings.enabled_precisions = enabled_precisions # Parse input runtime specification settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime) diff --git a/py/torch_tensorrt/logging.py b/py/torch_tensorrt/logging.py index e48e3c6317..a8d047e3ef 100644 --- 
a/py/torch_tensorrt/logging.py +++ b/py/torch_tensorrt/logging.py @@ -1,118 +1,34 @@ -from enum import Enum +import logging from typing import Any -from torch_tensorrt._C import ( - LogLevel, - _get_is_colored_output_on, - _get_logging_prefix, - _get_reportable_log_level, - _log, - _set_is_colored_output_on, - _set_logging_prefix, - _set_reportable_log_level, -) - - -class Level(Enum): - """Enum to set the minimum required logging level to print a message to stdout""" - - InternalError = LogLevel.INTERNAL_ERROR - Error = LogLevel.ERROR - Warning = LogLevel.WARNING - Info = LogLevel.INFO - Debug = LogLevel.DEBUG - Graph = LogLevel.GRAPH - - @staticmethod - def _to_internal_level(external: "Level") -> LogLevel: - if external == Level.InternalError: - return LogLevel.INTERNAL_ERROR - elif external == Level.Error: - return LogLevel.ERROR - elif external == Level.Warning: - return LogLevel.WARNING - elif external == Level.Info: - return LogLevel.INFO - elif external == Level.Debug: - return LogLevel.DEBUG - elif external == Level.Graph: - return LogLevel.GRAPH - else: - raise ValueError("Unknown log severity") - - -def get_logging_prefix() -> str: - """Get the prefix set for logging messages - - Returns: - str: Prefix used for logger - """ - return str(_get_logging_prefix()) - - -def set_logging_prefix(prefix: str) -> None: - """Set the prefix used when logging messages - - Args: - prefix (str): Prefix to use for logging messages - """ - _set_logging_prefix(prefix) - - -def get_reportable_log_level() -> Level: - """Get the level required for a message to be printed in the log - - Returns: - torch_tensorrt.logging.Level: The enum representing the level required to print - """ - return Level(_get_reportable_log_level()) - - -def set_reportable_log_level(level: Level) -> None: - """Set the level required for a message to be printed to the log +import tensorrt as trt +from torch_tensorrt._features import ENABLED_FEATURES - Args: - level (torch_tensorrt.logging.Level): The enum representing the level required to print - """ - _set_reportable_log_level(Level._to_internal_level(level)) - - -def get_is_colored_output_on() -> bool: - """Get if colored output is enabled for logging +logging.captureWarnings(True) +_LOGGER = logging.getLogger("torch_tensorrt [TensorRT Conversion Context]") - Returns: - bool: If colored output is one - """ - return bool(_get_is_colored_output_on()) +class _TRTLogger(trt.ILogger): # type: ignore[misc] -def set_is_colored_output_on(colored_output_on: bool) -> None: - """Enable or disable color in the log output + def __init__(self) -> None: + trt.ILogger.__init__(self) - Args: - colored_output_on (bool): If colored output should be enabled or not - """ - _set_is_colored_output_on(colored_output_on) + def log(self, severity: trt.ILogger.Severity, msg: str) -> None: + # TODO: Move to match once py39 reaches EoL + if severity == trt.ILogger.Severity.INTERNAL_ERROR: + _LOGGER.critical(msg) + raise RuntimeError(msg) + elif severity == trt.ILogger.Severity.ERROR: + _LOGGER.error(msg) + elif severity == trt.ILogger.Severity.WARNING: + _LOGGER.warning(msg) + elif severity == trt.ILogger.Severity.INFO: + _LOGGER.info(msg) + elif severity == trt.ILogger.Severity.VERBOSE: + _LOGGER.debug(msg) -def log(level: Level, msg: str) -> None: - """Add a new message to the log - - Adds a new message to the log at a specified level. 
The message - will only get printed out if Level > reportable_log_level - - Args: - level (torch_tensorrt.logging.Level): Severity of the message - msg (str): Actual message text - """ - _log(Level._to_internal_level(level), msg) - - InternalError = LogLevel.INTERNAL_ERROR - Error = LogLevel.ERROR - Warning = LogLevel.WARNING - Info = LogLevel.INFO - Debug = LogLevel.DEBUG - Graph = LogLevel.GRAPH +TRT_LOGGER = _TRTLogger() class internal_errors: @@ -125,11 +41,22 @@ class internal_errors: """ def __enter__(self) -> None: - self.external_lvl = get_reportable_log_level() - set_reportable_log_level(Level.InternalError) + self.external_lvl = _LOGGER.getEffectiveLevel() + _LOGGER.setLevel(logging.CRITICAL) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + self.ts_level = ts_logging.get_reportable_log_level() + ts_logging.set_reportable_log_level(ts_logging.Level.InternalError) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: - set_reportable_log_level(self.external_lvl) + _LOGGER.setLevel(self.external_lvl) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + ts_logging.set_reportable_log_level(self.ts_level) class errors: @@ -142,11 +69,22 @@ class errors: """ def __enter__(self) -> None: - self.external_lvl = get_reportable_log_level() - set_reportable_log_level(Level.Error) + self.external_lvl = _LOGGER.getEffectiveLevel() + _LOGGER.setLevel(logging.ERROR) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + self.ts_level = ts_logging.get_reportable_log_level() + ts_logging.set_reportable_log_level(ts_logging.Level.Error) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: - set_reportable_log_level(self.external_lvl) + _LOGGER.setLevel(self.external_lvl) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + ts_logging.set_reportable_log_level(self.ts_level) class warnings: @@ -159,11 +97,22 @@ class warnings: """ def __enter__(self) -> None: - self.external_lvl = get_reportable_log_level() - set_reportable_log_level(Level.Warning) + self.external_lvl = _LOGGER.getEffectiveLevel() + _LOGGER.setLevel(logging.WARNING) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + self.ts_level = ts_logging.get_reportable_log_level() + ts_logging.set_reportable_log_level(ts_logging.Level.Warning) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: - set_reportable_log_level(self.external_lvl) + _LOGGER.setLevel(self.external_lvl) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + ts_logging.set_reportable_log_level(self.ts_level) class info: @@ -176,11 +125,22 @@ class info: """ def __enter__(self) -> None: - self.external_lvl = get_reportable_log_level() - set_reportable_log_level(Level.Info) + self.external_lvl = _LOGGER.getEffectiveLevel() + _LOGGER.setLevel(logging.INFO) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + self.ts_level = ts_logging.get_reportable_log_level() + ts_logging.set_reportable_log_level(ts_logging.Level.Info) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: - set_reportable_log_level(self.external_lvl) + _LOGGER.setLevel(self.external_lvl) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as 
ts_logging + + ts_logging.set_reportable_log_level(self.ts_level) class debug: @@ -193,11 +153,22 @@ class debug: """ def __enter__(self) -> None: - self.external_lvl = get_reportable_log_level() - set_reportable_log_level(Level.Debug) + self.external_lvl = _LOGGER.getEffectiveLevel() + _LOGGER.setLevel(logging.DEBUG) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + self.ts_level = ts_logging.get_reportable_log_level() + ts_logging.set_reportable_log_level(ts_logging.Level.Debug) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: - set_reportable_log_level(self.external_lvl) + _LOGGER.setLevel(self.external_lvl) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + ts_logging.set_reportable_log_level(self.ts_level) class graphs: @@ -211,8 +182,19 @@ class graphs: """ def __enter__(self) -> None: - self.external_lvl = get_reportable_log_level() - set_reportable_log_level(Level.Graph) + self.external_lvl = _LOGGER.getEffectiveLevel() + _LOGGER.setLevel(logging.NOTSET) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + self.ts_level = ts_logging.get_reportable_log_level() + ts_logging.set_reportable_log_level(ts_logging.Level.Graph) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: - set_reportable_log_level(self.external_lvl) + _LOGGER.setLevel(self.external_lvl) + + if ENABLED_FEATURES.torchscript_frontend: + from torch_tensorrt.ts import logging as ts_logging + + ts_logging.set_reportable_log_level(self.ts_level) diff --git a/py/torch_tensorrt/ts/_Device.py b/py/torch_tensorrt/ts/_Device.py new file mode 100644 index 0000000000..3ae10a9c4d --- /dev/null +++ b/py/torch_tensorrt/ts/_Device.py @@ -0,0 +1,69 @@ +import sys +from typing import Any + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +import warnings + +from torch_tensorrt._Device import Device + +try: + from torch_tensorrt import _C +except ImportError: + warnings.warn( + "Unable to import torchscript frontend core and torch-tensorrt runtime. Some dependent features may be unavailable." + ) + + +class TorchScriptDevice(Device): + """ + Defines a device that can be used to specify target devices for engines + + Attributes: + device_type (torch_tensorrt.DeviceType): Target device type (GPU or DLA). Set implicitly based on if dla_core is specified. + gpu_id (int): Device ID for target GPU + dla_core (int): Core ID for target DLA core + allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed + """ + + def __init__(self, *args: Any, **kwargs: Any): + """__init__ Method for torch_tensorrt.Device + + Device accepts one of a few construction patterns + + Args: + spec (str): String with device spec e.g. "dla:0" for dla, core_id 0 + + Keyword Arguments: + gpu_id (int): ID of target GPU (will get overrided if dla_core is specified to the GPU managing DLA). If specified, no positional arguments should be provided + dla_core (int): ID of target DLA core. If specified, no positional arguments should be provided. 
+ allow_gpu_fallback (bool): Allow TensorRT to schedule operations on GPU if they are not supported on DLA (ignored if device type is not DLA) + + Examples: + - Device("gpu:1") + - Device("cuda:1") + - Device("dla:0", allow_gpu_fallback=True) + - Device(gpu_id=0, dla_core=0, allow_gpu_fallback=True) + - Device(dla_core=0, allow_gpu_fallback=True) + - Device(gpu_id=1) + """ + super().__init__(*args, **kwargs) + + def _to_internal(self) -> _C.Device: + internal_dev = _C.Device() + internal_dev.device_type = self.device_type.to(_C.DeviceType) + internal_dev.gpu_id = self.gpu_id + internal_dev.dla_core = self.dla_core + internal_dev.allow_gpu_fallback = self.allow_gpu_fallback + return internal_dev + + @classmethod + def _from(cls, d: object) -> Self: + return cls( + gpu_id=d.gpu_id, + dla_core=d.dla_core, + allow_gpu_fallback=d.allow_gpu_fallback, + ) diff --git a/py/torch_tensorrt/ts/_Input.py b/py/torch_tensorrt/ts/_Input.py index f9cbf2c333..6099efbcd2 100644 --- a/py/torch_tensorrt/ts/_Input.py +++ b/py/torch_tensorrt/ts/_Input.py @@ -1,6 +1,7 @@ from typing import Any -from torch_tensorrt import _C, _enums +from torch_tensorrt import _C +from torch_tensorrt._enums import dtype from torch_tensorrt._Input import Input @@ -49,11 +50,16 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: """ super().__init__(*args, **kwargs) + def is_trt_dtype(self) -> bool: + return bool(self.dtype != dtype.long) + def _to_internal(self) -> _C.Input: internal_in = _C.Input() if self.shape_mode == Input._ShapeMode.DYNAMIC: if isinstance(self.shape, dict): - if not Input._supported_input_size_type(self.shape["min_shape"]): + if not TorchScriptInput._supported_input_size_type( + self.shape["min_shape"] + ): raise TypeError( "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " + str(type(self.shape["min_shape"])) @@ -62,7 +68,9 @@ def _to_internal(self) -> _C.Input: else: internal_in.min = self.shape["min_shape"] - if not Input._supported_input_size_type(self.shape["opt_shape"]): + if not TorchScriptInput._supported_input_size_type( + self.shape["opt_shape"] + ): raise TypeError( "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " + str(type(self.shape["opt_shape"])) @@ -71,7 +79,9 @@ def _to_internal(self) -> _C.Input: else: internal_in.opt = self.shape["opt_shape"] - if not Input._supported_input_size_type(self.shape["max_shape"]): + if not TorchScriptInput._supported_input_size_type( + self.shape["max_shape"] + ): raise TypeError( "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " + str(type(self.shape["max_shape"])) @@ -81,7 +91,7 @@ def _to_internal(self) -> _C.Input: internal_in.max = self.shape["max_shape"] internal_in.input_is_dynamic = True else: - if not Input._supported_input_size_type(self.shape): + if not TorchScriptInput._supported_input_size_type(self.shape): raise TypeError( "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " + str(type(self.shape)) @@ -91,14 +101,11 @@ def _to_internal(self) -> _C.Input: internal_in.opt = self.shape internal_in.input_is_dynamic = False - if self.dtype != _enums.dtype.unknown: - self._explicit_set_dtype = True - else: - self._explicit_set_dtype = False - - internal_in.dtype = Input._parse_dtype(self.dtype) + internal_in.dtype = self.dtype.to(_C.dtype) internal_in._explicit_set_dtype = self._explicit_set_dtype - internal_in.format = Input._parse_format(self.format) 
+ internal_in.format = self.format.to(_C.TensorFormat) - internal_in.tensor_domain = Input._parse_tensor_domain(self.tensor_domain) + internal_in.tensor_domain = TorchScriptInput._parse_tensor_domain( + self.tensor_domain + ) return internal_in diff --git a/py/torch_tensorrt/ts/__init__.py b/py/torch_tensorrt/ts/__init__.py index 5cb45cba5c..d11db42c68 100644 --- a/py/torch_tensorrt/ts/__init__.py +++ b/py/torch_tensorrt/ts/__init__.py @@ -1,3 +1,5 @@ from torch_tensorrt.ts._compile_spec import TensorRTCompileSpec # noqa: F401 from torch_tensorrt.ts._compiler import * # noqa: F403 +from torch_tensorrt.ts._Device import TorchScriptDevice # noqa: F401 +from torch_tensorrt.ts._enums import * # noqa: F403 from torch_tensorrt.ts._Input import TorchScriptInput # noqa: F401 diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index 37f5fb79e3..d2b493c5ca 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -6,11 +6,13 @@ import tensorrt as trt import torch import torch_tensorrt._C.ts as _ts_C -from torch_tensorrt import _C, _enums +from torch_tensorrt import _C from torch_tensorrt._Device import Device +from torch_tensorrt._enums import DeviceType, EngineCapability, dtype from torch_tensorrt._Input import Input -from torch_tensorrt.logging import Level, log +from torch_tensorrt.ts._Device import TorchScriptDevice from torch_tensorrt.ts._Input import TorchScriptInput +from torch_tensorrt.ts.logging import Level, log def _internal_input_to_torch_class_input(i: _C.Input) -> torch.classes.tensorrt._Input: @@ -40,31 +42,11 @@ def _supported_input_size_type(input_size: Any) -> bool: ) -def _parse_op_precision(precision: Any) -> _enums.dtype: - if isinstance(precision, torch.dtype): - if precision == torch.int8: - return _enums.dtype.int8 - elif precision == torch.half: - return _enums.dtype.half - elif precision == torch.float: - return _enums.dtype.float - else: - raise TypeError( - "Provided an unsupported dtype as operating precision (support: int8, half, float), got: " - + str(precision) - ) +def _parse_op_precision(precision: Any) -> _C.dtype: + return dtype._from(precision).to(_C.dtype) - elif isinstance(precision, _enums.dtype): - return precision - else: - raise TypeError( - "Op precision type needs to be specified with a torch.dtype or a torch_tensorrt.dtype, got: " - + str(type(precision)) - ) - - -def _parse_enabled_precisions(precisions: Any) -> Set[_enums.dtype]: +def _parse_enabled_precisions(precisions: Any) -> Set[_C.dtype]: parsed_precisions = set() if any(isinstance(precisions, type) for type in [list, tuple, set]): for p in precisions: @@ -74,36 +56,8 @@ def _parse_enabled_precisions(precisions: Any) -> Set[_enums.dtype]: return parsed_precisions -def _parse_device_type(device: Any) -> _enums.DeviceType: - if isinstance(device, torch.device): - if device.type == "cuda": - return _C.DeviceType.gpu - else: - ValueError( - "Got a device type other than GPU or DLA (type: " - + str(device.type) - + ")" - ) - elif isinstance(device, _C.DeviceType): - return device - elif isinstance(device, trt.DeviceType): - if device == trt.DeviceType.DLA: - return _C.DeviceType.DLA - return _C.DeviceType.GPU - elif isinstance(device, str): - if device == "gpu" or device == "GPU": - return _C.DeviceType.GPU - elif device == "dla" or device == "DLA": - return _C.DeviceType.DLA - else: - ValueError( - "Got a device type other than GPU or DLA (type: " + str(device) + ")" - ) - else: - raise TypeError( - "Device specification 
must be of type torch.device, string or torch_tensorrt.DeviceType, but got: " - + str(type(device)) - ) +def _parse_device_type(device: Any) -> _C.DeviceType: + return DeviceType._from(device).to(_C.DeviceType) def _parse_device(device_info: Any) -> _C.Device: @@ -128,9 +82,11 @@ def _parse_device(device_info: Any) -> _C.Device: return info elif isinstance(device_info, Device): + return TorchScriptDevice._from(device_info)._to_internal() + elif isinstance(device_info, TorchScriptDevice): return device_info._to_internal() elif isinstance(device_info, torch.device): - return (Device._from_torch_device(device_info))._to_internal() + return TorchScriptDevice._from(device_info)._to_internal() else: raise ValueError( "Unsupported data for device specification. Expected either a dict, torch_tensorrt.Device or torch.Device" @@ -184,9 +140,11 @@ def _parse_input_signature(input_signature: Any, depth: int = 0) -> Any: else input_signature ) - if not i.is_trt_dtype(): + if not i.dtype.try_to(trt.DataType, use_default=True): raise TypeError( - "Using non-TRT input types with input_signature is not currently " + "Using non-TRT input types ({}) with input_signature is not currently ".format( + i.dtype + ) + "supported. Please specify inputs individually to use " + "non-TRT types." ) @@ -246,7 +204,9 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: if i.shape_mode == Input._ShapeMode.STATIC: ts_inputs.append( TorchScriptInput( - shape=i.shape, dtype=i.dtype, format=i.format + shape=i.shape, + dtype=i.dtype.to(_C.dtype), + format=i.format.to(_C.TensorFormat), )._to_internal() ) elif i.shape_mode == Input._ShapeMode.DYNAMIC: @@ -255,8 +215,8 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: min_shape=i.shape["min_shape"], opt_shape=i.shape["opt_shape"], max_shape=i.shape["max_shape"], - dtype=i.dtype, - format=i.format, + dtype=i.dtype.to(_C.dtype), + format=i.format.to(_C.TensorFormat), )._to_internal() ) info.inputs = ts_inputs @@ -306,8 +266,10 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: info.device = _parse_device(compile_spec["device"]) if "capability" in compile_spec: - assert isinstance(compile_spec["capability"], _enums.EngineCapability) - info.capability = compile_spec["capability"] + capability = EngineCapability._from(compile_spec["capability"]).to( + _C.EngineCapability + ) + info.capability = capability if "num_avg_timing_iters" in compile_spec: assert type(compile_spec["num_avg_timing_iters"]) is int @@ -347,10 +309,10 @@ def TensorRTCompileSpec( device: torch.device | Device = Device._current_device(), disable_tf32: bool = False, sparse_weights: bool = False, - enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None, + enabled_precisions: Optional[Set[torch.dtype | dtype]] = None, refit: bool = False, debug: bool = False, - capability: _enums.EngineCapability = _enums.EngineCapability.default, + capability: EngineCapability = EngineCapability.STANDARD, num_avg_timing_iters: int = 1, workspace_size: int = 0, dla_sram_size: int = 1048576, diff --git a/py/torch_tensorrt/ts/_compiler.py b/py/torch_tensorrt/ts/_compiler.py index 4a9bb53dc0..802976cb27 100644 --- a/py/torch_tensorrt/ts/_compiler.py +++ b/py/torch_tensorrt/ts/_compiler.py @@ -4,8 +4,8 @@ import torch import torch_tensorrt._C.ts as _C -from torch_tensorrt import _enums from torch_tensorrt._Device import Device +from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt._Input import Input from 
torch_tensorrt.ts._compile_spec import _parse_compile_spec, _parse_device @@ -17,10 +17,10 @@ def compile( device: Device = Device._current_device(), disable_tf32: bool = False, sparse_weights: bool = False, - enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None, + enabled_precisions: Optional[Set[torch.dtype | dtype]] = None, refit: bool = False, debug: bool = False, - capability: _enums.EngineCapability = _enums.EngineCapability.default, + capability: EngineCapability = EngineCapability.STANDARD, num_avg_timing_iters: int = 1, workspace_size: int = 0, dla_sram_size: int = 1048576, @@ -162,10 +162,10 @@ def convert_method_to_trt_engine( device: Device = Device._current_device(), disable_tf32: bool = False, sparse_weights: bool = False, - enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None, + enabled_precisions: Optional[Set[torch.dtype | dtype]] = None, refit: bool = False, debug: bool = False, - capability: _enums.EngineCapability = _enums.EngineCapability.default, + capability: EngineCapability = EngineCapability.STANDARD, num_avg_timing_iters: int = 1, workspace_size: int = 0, dla_sram_size: int = 1048576, diff --git a/py/torch_tensorrt/ts/_enums.py b/py/torch_tensorrt/ts/_enums.py new file mode 100644 index 0000000000..4fc9236848 --- /dev/null +++ b/py/torch_tensorrt/ts/_enums.py @@ -0,0 +1,2 @@ +from tensorrt import DeviceType # noqa: F401 +from torch_tensorrt._C import EngineCapability, TensorFormat, dtype # noqa: F401 diff --git a/py/torch_tensorrt/ts/_utils.py b/py/torch_tensorrt/ts/_utils.py new file mode 100644 index 0000000000..89625e5e86 --- /dev/null +++ b/py/torch_tensorrt/ts/_utils.py @@ -0,0 +1,31 @@ +import torch +from torch_tensorrt import _C +from torch_tensorrt._version import __version__ + + +def dump_build_info() -> None: + """Prints build information about the torch_tensorrt distribution to stdout""" + print(get_build_info()) + + +def get_build_info() -> str: + """Returns a string containing the build information of torch_tensorrt distribution + + Returns: + str: String containing the build information for torch_tensorrt distribution + """ + core_build_info = _C.get_build_info() + build_info = str( + "Torch-TensorRT Version: " + + str(__version__) + + "\n" + + "Using PyTorch Version: " + + str(torch.__version__) + + "\n" + + core_build_info + ) + return build_info + + +def set_device(gpu_id: int) -> None: + _C.set_device(gpu_id) diff --git a/py/torch_tensorrt/ts/logging.py b/py/torch_tensorrt/ts/logging.py new file mode 100644 index 0000000000..4220df7f19 --- /dev/null +++ b/py/torch_tensorrt/ts/logging.py @@ -0,0 +1,115 @@ +from enum import Enum + +from torch_tensorrt._C import ( + LogLevel, + _get_is_colored_output_on, + _get_logging_prefix, + _get_reportable_log_level, + _log, + _set_is_colored_output_on, + _set_logging_prefix, + _set_reportable_log_level, +) + + +class Level(Enum): + """Enum to set the minimum required logging level to print a message to stdout""" + + InternalError = LogLevel.INTERNAL_ERROR + Error = LogLevel.ERROR + Warning = LogLevel.WARNING + Info = LogLevel.INFO + Debug = LogLevel.DEBUG + Graph = LogLevel.GRAPH + + @staticmethod + def _to_internal_level(external: "Level") -> LogLevel: + if external == Level.InternalError: + return LogLevel.INTERNAL_ERROR + elif external == Level.Error: + return LogLevel.ERROR + elif external == Level.Warning: + return LogLevel.WARNING + elif external == Level.Info: + return LogLevel.INFO + elif external == Level.Debug: + return LogLevel.DEBUG + elif external == Level.Graph: + 
return LogLevel.GRAPH + else: + print(external) + raise ValueError("Unknown log severity") + + +def get_logging_prefix() -> str: + """Get the prefix set for logging messages + + Returns: + str: Prefix used for logger + """ + return str(_get_logging_prefix()) + + +def set_logging_prefix(prefix: str) -> None: + """Set the prefix used when logging messages + + Args: + prefix (str): Prefix to use for logging messages + """ + _set_logging_prefix(prefix) + + +def get_reportable_log_level() -> Level: + """Get the level required for a message to be printed in the log + + Returns: + torch_tensorrt.logging.Level: The enum representing the level required to print + """ + return Level(_get_reportable_log_level()) + + +def set_reportable_log_level(level: Level) -> None: + """Set the level required for a message to be printed to the log + + Args: + level (torch_tensorrt.logging.Level): The enum representing the level required to print + """ + _set_reportable_log_level(Level._to_internal_level(level)) + + +def get_is_colored_output_on() -> bool: + """Get if colored output is enabled for logging + + Returns: + bool: If colored output is on + """ + return bool(_get_is_colored_output_on()) + + +def set_is_colored_output_on(colored_output_on: bool) -> None: + """Enable or disable color in the log output + + Args: + colored_output_on (bool): If colored output should be enabled or not + """ + _set_is_colored_output_on(colored_output_on) + + +def log(level: Level, msg: str) -> None: + """Add a new message to the log + + Adds a new message to the log at a specified level. The message + will only get printed out if Level > reportable_log_level + + Args: + level (torch_tensorrt.logging.Level): Severity of the message + msg (str): Actual message text + """ + _log(Level._to_internal_level(level), msg) + + InternalError = LogLevel.INTERNAL_ERROR + Error = LogLevel.ERROR + Warning = LogLevel.WARNING + Info = LogLevel.INFO + Debug = LogLevel.DEBUG + Graph = LogLevel.GRAPH diff --git a/py/torch_tensorrt/ptq.py b/py/torch_tensorrt/ts/ptq.py similarity index 99% rename from py/torch_tensorrt/ptq.py rename to py/torch_tensorrt/ts/ptq.py index 5d13ab9108..aad53d5f8e 100644 --- a/py/torch_tensorrt/ptq.py +++ b/py/torch_tensorrt/ts/ptq.py @@ -11,7 +11,8 @@ import torch from torch_tensorrt import _C -from torch_tensorrt.logging import Level, log + +from torch_tensorrt.ts.logging import Level, log class CalibrationAlgo(Enum): diff --git a/pyproject.toml b/pyproject.toml index 1e29a43565..2496491cf8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,7 @@ include-package-data = false [tool.ruff] # NOTE: Synchoronize the ignores with .flake8 -ignore = [ +lint.ignore = [ # these ignores are from flake8-bugbear; please fix! "B007", "B008", "B017", "B018", # Useless expression @@ -97,7 +97,7 @@ ignore = [ "SIM118", ] #line-length = 120 -select = [ +lint.select = [ "B", "C4", "G", @@ -112,11 +112,11 @@ select = [ ] # Allow unused variables when underscore-prefixed. -dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" target-version = "py311" # Allow autofix for all enabled rules (when `--fix`) is provided.
-fixable = [ +lint.fixable = [ "A","B","C","D","E","F","G", "I","N","Q","S","T","W", "ANN", "ARG", "BLE", "COM", "DJ", @@ -125,7 +125,7 @@ fixable = [ "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"] -unfixable = [] +lint.unfixable = [] # Exclude a variety of commonly ignored directories. exclude = [ @@ -164,7 +164,7 @@ exclude = [ "__init__.py" ] -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] # Unlike Flake8, default to a complexity level of 10. max-complexity = 10 diff --git a/requirements-dev.txt b/requirements-dev.txt index 052751afec..7f97a8e276 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,3 +9,4 @@ transformers timm parameterized expecttest==0.1.6 +pyyaml diff --git a/setup.py b/setup.py index 38d2121461..494eaa7ee1 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,5 @@ +# type: ignore + import glob import os import platform @@ -80,15 +82,24 @@ def load_dep_info(): CXX11_ABI = False JETPACK_VERSION = None -FX_ONLY = False +PY_ONLY = False +NO_TS = False LEGACY = False RELEASE = False CI_BUILD = False if "--fx-only" in sys.argv: - FX_ONLY = True + PY_ONLY = True sys.argv.remove("--fx-only") +if "--py-only" in sys.argv: + PY_ONLY = True + sys.argv.remove("--py-only") + +if "--no-ts" in sys.argv: + NO_TS = True + sys.argv.remove("--no-ts") + if "--legacy" in sys.argv: LEGACY = True sys.argv.remove("--legacy") @@ -97,6 +108,14 @@ def load_dep_info(): RELEASE = True sys.argv.remove("--release") +if (no_ts_env_var := os.environ.get("NO_TORCHSCRIPT")) is not None: + if no_ts_env_var == "1": + NO_TS = True + +if (py_only_env_var := os.environ.get("PYTHON_ONLY")) is not None: + if py_only_env_var == "1": + PY_ONLY = True + if (release_env_var := os.environ.get("RELEASE")) is not None: if release_env_var == "1": RELEASE = True @@ -168,7 +187,7 @@ def is_exe(fpath): BAZEL_EXE = None -if not FX_ONLY: +if not PY_ONLY: BAZEL_EXE = which("bazelisk") if BAZEL_EXE is None: @@ -177,9 +196,15 @@ def is_exe(fpath): sys.exit("Could not find bazel in PATH") -def build_libtorchtrt_pre_cxx11_abi(develop=True, use_dist_dir=True, cxx11_abi=False): +def build_libtorchtrt_pre_cxx11_abi( + develop=True, use_dist_dir=True, cxx11_abi=False, rt_only=False +): cmd = [BAZEL_EXE, "build"] - cmd.append("//:libtorchtrt") + if rt_only: + cmd.append("//:libtorchtrt_runtime") + else: + cmd.append("//:libtorchtrt") + if develop: cmd.append("--compilation_mode=dbg") else: @@ -224,7 +249,7 @@ def gen_version_file(): f.write('__tensorrt_version__ = "' + __tensorrt_version__ + '"\n') -def copy_libtorchtrt(multilinux=False): +def copy_libtorchtrt(multilinux=False, rt_only=False): if not os.path.exists(dir_path + "/torch_tensorrt/lib"): os.makedirs(dir_path + "/torch_tensorrt/lib") @@ -234,6 +259,14 @@ def copy_libtorchtrt(multilinux=False): dir_path + "/build/libtrtorch_build/libtrtorch.so", dir_path + "/trtorch/lib/libtrtorch.so", ) + elif rt_only: + os.system( + "tar -xzf " + + dir_path + + "/../bazel-bin/libtorchtrt_runtime.tar.gz --strip-components=1 -C " + + dir_path + + "/torch_tensorrt" + ) else: os.system( "tar -xzf " @@ -252,17 +285,20 @@ def initialize_options(self): def finalize_options(self): develop.finalize_options(self) + if NO_TS or PY_ONLY: + self.root_is_pure = False def run(self): - if FX_ONLY: - gen_version_file() - develop.run(self) - else: + + if not PY_ONLY: global CXX11_ABI - build_libtorchtrt_pre_cxx11_abi(develop=True, cxx11_abi=CXX11_ABI) - gen_version_file() - copy_libtorchtrt() - develop.run(self) + 
build_libtorchtrt_pre_cxx11_abi( + develop=True, cxx11_abi=CXX11_ABI, rt_only=NO_TS + ) + copy_libtorchtrt(rt_only=NO_TS) + + gen_version_file() + develop.run(self) class InstallCommand(install): @@ -273,17 +309,20 @@ def initialize_options(self): def finalize_options(self): install.finalize_options(self) + if NO_TS or PY_ONLY: + self.root_is_pure = False def run(self): - if FX_ONLY: - gen_version_file() - install.run(self) - else: + + if not PY_ONLY: global CXX11_ABI - build_libtorchtrt_pre_cxx11_abi(develop=False, cxx11_abi=CXX11_ABI) - gen_version_file() - copy_libtorchtrt() - install.run(self) + build_libtorchtrt_pre_cxx11_abi( + develop=False, cxx11_abi=CXX11_ABI, rt_only=NO_TS + ) + copy_libtorchtrt(rt_only=NO_TS) + + gen_version_file() + install.run(self) class BdistCommand(bdist_wheel): @@ -294,12 +333,18 @@ def initialize_options(self): def finalize_options(self): bdist_wheel.finalize_options(self) + if NO_TS or PY_ONLY: + self.root_is_pure = False def run(self): - global CXX11_ABI - build_libtorchtrt_pre_cxx11_abi(develop=False, cxx11_abi=CXX11_ABI) + if not PY_ONLY: + global CXX11_ABI + build_libtorchtrt_pre_cxx11_abi( + develop=False, cxx11_abi=CXX11_ABI, rt_only=NO_TS + ) + copy_libtorchtrt(rt_only=NO_TS) + gen_version_file() - copy_libtorchtrt() bdist_wheel.run(self) @@ -311,16 +356,20 @@ def initialize_options(self): def finalize_options(self): editable_wheel.finalize_options(self) + if NO_TS or PY_ONLY: + self.root_is_pure = False def run(self): - if FX_ONLY: + if PY_ONLY: gen_version_file() editable_wheel.run(self) else: global CXX11_ABI - build_libtorchtrt_pre_cxx11_abi(develop=True, cxx11_abi=CXX11_ABI) + build_libtorchtrt_pre_cxx11_abi( + develop=True, cxx11_abi=CXX11_ABI, rt_only=NO_TS + ) gen_version_file() - copy_libtorchtrt() + copy_libtorchtrt(rt_only=NO_TS) editable_wheel.run(self) @@ -436,7 +485,7 @@ def run(self): package_data = {} -if not FX_ONLY: +if not (PY_ONLY or NO_TS): ext_modules += [ cpp_extension.CUDAExtension( "torch_tensorrt._C", @@ -536,6 +585,19 @@ def run(self): ] } ) +elif NO_TS: + package_data.update( + { + "torch_tensorrt": [ + "BUILD", + "WORKSPACE", + "include/torch_tensorrt/*.h", + "include/torch_tensorrt/core/*.h", + "include/torch_tensorrt/core/runtime/*.h", + "lib/*", + ] + } + ) with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() diff --git a/tests/py/core/test_classes.py b/tests/py/core/test_classes.py new file mode 100644 index 0000000000..a8922f2924 --- /dev/null +++ b/tests/py/core/test_classes.py @@ -0,0 +1,57 @@ +import copy +import unittest +from typing import Dict + +import tensorrt as trt +import torch +import torch_tensorrt as torchtrt +import torchvision.models as models +from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule + + +class TestDevice(unittest.TestCase): + def test_from_string_constructor(self): + device = torchtrt.Device("cuda:0") + self.assertEqual(device.device_type, torchtrt.DeviceType.GPU) + self.assertEqual(device.gpu_id, 0) + + device = torchtrt.Device("gpu:1") + self.assertEqual(device.device_type, torchtrt.DeviceType.GPU) + self.assertEqual(device.gpu_id, 1) + + def test_from_string_constructor_dla(self): + device = torchtrt.Device("dla:0") + self.assertEqual(device.device_type, torchtrt.DeviceType.DLA) + self.assertEqual(device.gpu_id, 0) + self.assertEqual(device.dla_core, 0) + + device = torchtrt.Device("dla:1", allow_gpu_fallback=True) + self.assertEqual(device.device_type, torchtrt.DeviceType.DLA) + self.assertEqual(device.gpu_id, 0) + 
self.assertEqual(device.dla_core, 1) + self.assertEqual(device.allow_gpu_fallback, True) + + def test_kwargs_gpu(self): + device = torchtrt.Device(gpu_id=0) + self.assertEqual(device.device_type, torchtrt.DeviceType.GPU) + self.assertEqual(device.gpu_id, 0) + + def test_kwargs_dla_and_settings(self): + device = torchtrt.Device(dla_core=1, allow_gpu_fallback=False) + self.assertEqual(device.device_type, torchtrt.DeviceType.DLA) + self.assertEqual(device.gpu_id, 0) + self.assertEqual(device.dla_core, 1) + self.assertEqual(device.allow_gpu_fallback, False) + + device = torchtrt.Device(gpu_id=1, dla_core=0, allow_gpu_fallback=True) + self.assertEqual(device.device_type, torchtrt.DeviceType.DLA) + self.assertEqual( + device.gpu_id, 0 + ) # Override since AGX platforms use iGPU to manage DLA + self.assertEqual(device.dla_core, 0) + self.assertEqual(device.allow_gpu_fallback, True) + + def test_from_torch(self): + device = torchtrt.Device._from_torch_device(torch.device("cuda:0")) + self.assertEqual(device.device_type, torchtrt.DeviceType.GPU) + self.assertEqual(device.gpu_id, 0) diff --git a/tests/py/dynamo/backend/test_backend_compiler.py b/tests/py/dynamo/backend/test_backend_compiler.py index a958d03120..506c9a1959 100644 --- a/tests/py/dynamo/backend/test_backend_compiler.py +++ b/tests/py/dynamo/backend/test_backend_compiler.py @@ -1,11 +1,11 @@ +# type: ignore from copy import deepcopy import torch +import torch_tensorrt from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt.dynamo.partitioning import fast_partition -import torch_tensorrt - from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing diff --git a/tests/py/dynamo/backend/test_specialized_models.py b/tests/py/dynamo/backend/test_specialized_models.py index 3885627b5f..613fc167bb 100644 --- a/tests/py/dynamo/backend/test_specialized_models.py +++ b/tests/py/dynamo/backend/test_specialized_models.py @@ -1,3 +1,4 @@ +# type: ignore import torch import torch_tensorrt from torch.testing._internal.common_utils import TestCase, run_tests diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 404f50a187..ef034c914f 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -1,3 +1,5 @@ +# type: ignore + import logging import time import unittest @@ -6,10 +8,12 @@ import torch from torch.testing._internal.common_utils import TestCase from torch_tensorrt import Input +from torch_tensorrt._enums import dtype from torch_tensorrt.dynamo._settings import CompilationSettings # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry from torch_tensorrt.dynamo.conversion import TRTInterpreter +from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes from torch_tensorrt.dynamo.lowering import apply_lowering_passes from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule @@ -68,7 +72,8 @@ def run_test( interpreter_result.output_names, ) - ref_outputs = mod(*inputs) + mod = mod.cuda() + ref_outputs = mod(*cuda_inputs) torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) @@ -144,7 +149,7 @@ def run_test_custom_compare_results( interpreter_result.output_names, ) res_trt = trt_mod(*cuda_inputs).cpu() - res_cpu = mod(*inputs) + res_cpu = mod(*cuda_inputs).cpu() assert len(res_trt) == len(res_cpu) assert len(res_cpu) == len(comparators) for output_trt, output_cpu, comparator in zip( @@ -208,7 +213,6 @@ def generate_graph( 
fx_module = torch.fx.symbolic_trace(mod) if enable_passes: fx_module = apply_lowering_passes(fx_module, original_inputs) - _LOGGER.info(f"FX graph= {fx_module.graph}") return fx_module def run_test( @@ -217,9 +221,8 @@ def run_test( inputs, rtol=1e-03, atol=1e-03, - precision=torch.float, + precision=dtype.f32, check_dtype=True, - output_dtypes=None, use_dynamo_tracer=False, enable_passes=False, ): @@ -234,12 +237,25 @@ def run_test( # Previous instance of the interpreter auto-casted 64-bit inputs # We replicate this behavior here compilation_settings = CompilationSettings( - precision=precision, truncate_long_and_double=True + enabled_precisions={dtype._from(precision)}, + truncate_long_and_double=True, + debug=True, ) + input_specs = [Input.from_tensor(i) for i in inputs] + + output_dtypes = None + if check_dtype: + output_dtypes = infer_module_output_dtypes( + mod, + input_specs, + compilation_settings.device, + truncate_long_and_double=compilation_settings.truncate_long_and_double, + ) + interp = TRTInterpreter( mod, - Input.from_tensors(inputs), + input_specs, output_dtypes=output_dtypes, compilation_settings=compilation_settings, ) diff --git a/tests/py/dynamo/conversion/test_abs_aten.py b/tests/py/dynamo/conversion/test_abs_aten.py index 13beeb3bfa..5778110106 100644 --- a/tests/py/dynamo/conversion/test_abs_aten.py +++ b/tests/py/dynamo/conversion/test_abs_aten.py @@ -42,7 +42,6 @@ def forward(self, input): self.run_test( abs(), inputs, - output_dtypes=[torch.int], ) diff --git a/tests/py/dynamo/conversion/test_any.py b/tests/py/dynamo/conversion/test_any.py index f82e2465be..29522145da 100644 --- a/tests/py/dynamo/conversion/test_any.py +++ b/tests/py/dynamo/conversion/test_any.py @@ -26,7 +26,7 @@ def forward(self, x): return torch.ops.aten.any.default(x) inputs = [torch.randn(*input_shape)] - self.run_test(Any(), inputs, output_dtypes=[torch.bool]) + self.run_test(Any(), inputs) @parameterized.expand( [ @@ -43,7 +43,7 @@ def forward(self, x): return torch.ops.aten.any.dim(x, dim, keep_dims) inputs = [torch.randn(*input_shape)] - self.run_test(AnyDim(), inputs, output_dtypes=[torch.bool]) + self.run_test(AnyDim(), inputs) @parameterized.expand( [ @@ -59,7 +59,7 @@ def forward(self, x): return torch.ops.aten.any.dims(x, dims, keep_dims) inputs = [torch.randn(*input_shape)] - self.run_test(AnyDims(), inputs, output_dtypes=[torch.bool]) + self.run_test(AnyDims(), inputs) @parameterized.expand( [ @@ -79,7 +79,6 @@ def forward(self, x): self.run_test( Any(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -100,7 +99,6 @@ def forward(self, x): self.run_test( AnyDim(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -123,7 +121,6 @@ def forward(self, x): self.run_test( AnyDims(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -142,7 +139,6 @@ def forward(self, x): self.run_test( Any(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -163,7 +159,6 @@ def forward(self, x): self.run_test( AnyDim(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -186,7 +181,6 @@ def forward(self, x): self.run_test( AnyDims(), inputs, - output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_bitwise_not_aten.py b/tests/py/dynamo/conversion/test_bitwise_not_aten.py index 6dd512ef16..b811f1e51a 100644 --- a/tests/py/dynamo/conversion/test_bitwise_not_aten.py +++ b/tests/py/dynamo/conversion/test_bitwise_not_aten.py @@ -25,7 +25,6 @@ def forward(self, val): bitwise_not(), inputs, enable_passes=True, - 
output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_casts.py b/tests/py/dynamo/conversion/test_casts.py index c067a0b9ad..84234db857 100644 --- a/tests/py/dynamo/conversion/test_casts.py +++ b/tests/py/dynamo/conversion/test_casts.py @@ -1,6 +1,9 @@ +# type: ignore + import torch import torch.nn as nn from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import dtype from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException from .harness import DispatchTestCase diff --git a/tests/py/dynamo/conversion/test_eq_aten.py b/tests/py/dynamo/conversion/test_eq_aten.py index 17a372182c..3adc6774d6 100644 --- a/tests/py/dynamo/conversion/test_eq_aten.py +++ b/tests/py/dynamo/conversion/test_eq_aten.py @@ -25,7 +25,6 @@ def forward(self, lhs_val, rhs_val): self.run_test( eq(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -43,7 +42,6 @@ def forward(self, lhs_val): self.run_test( eq(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -61,7 +59,6 @@ def forward(self, lhs_val): self.run_test( eq(), inputs, - output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_ge_aten.py b/tests/py/dynamo/conversion/test_ge_aten.py index 6b1ee6d440..bacfedafc8 100644 --- a/tests/py/dynamo/conversion/test_ge_aten.py +++ b/tests/py/dynamo/conversion/test_ge_aten.py @@ -25,7 +25,6 @@ def forward(self, lhs_val, rhs_val): self.run_test( ge(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -43,7 +42,6 @@ def forward(self, lhs_val): self.run_test( ge(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -61,7 +59,6 @@ def forward(self, lhs_val): self.run_test( ge(), inputs, - output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_gt_aten.py b/tests/py/dynamo/conversion/test_gt_aten.py index 8d9ae24f80..0eab7c84ff 100644 --- a/tests/py/dynamo/conversion/test_gt_aten.py +++ b/tests/py/dynamo/conversion/test_gt_aten.py @@ -22,7 +22,6 @@ def forward(self, lhs_val, rhs_val): self.run_test( gt(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -40,7 +39,6 @@ def forward(self, lhs_val): self.run_test( gt(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -58,7 +56,6 @@ def forward(self, lhs_val): self.run_test( gt(), inputs, - output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_isinf_aten.py b/tests/py/dynamo/conversion/test_isinf_aten.py index 78695dbe21..d0dce59a60 100644 --- a/tests/py/dynamo/conversion/test_isinf_aten.py +++ b/tests/py/dynamo/conversion/test_isinf_aten.py @@ -35,7 +35,6 @@ def forward(self, input): self.run_test( isinf(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -54,7 +53,6 @@ def forward(self, input): self.run_test( isinf(), inputs, - output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_le_aten.py b/tests/py/dynamo/conversion/test_le_aten.py index 373384c6f9..5b725213a3 100644 --- a/tests/py/dynamo/conversion/test_le_aten.py +++ b/tests/py/dynamo/conversion/test_le_aten.py @@ -25,7 +25,6 @@ def forward(self, lhs_val, rhs_val): self.run_test( le(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -43,7 +42,6 @@ def forward(self, lhs_val): self.run_test( le(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -61,7 +59,6 @@ def forward(self, lhs_val): self.run_test( le(), inputs, - output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_logical_not_aten.py 
b/tests/py/dynamo/conversion/test_logical_not_aten.py index a36a8dbf72..b03fbc777e 100644 --- a/tests/py/dynamo/conversion/test_logical_not_aten.py +++ b/tests/py/dynamo/conversion/test_logical_not_aten.py @@ -22,7 +22,6 @@ def forward(self, input): self.run_test( logical_not(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -41,7 +40,6 @@ def forward(self, input): self.run_test( logical_not(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -60,7 +58,6 @@ def forward(self, input): self.run_test( logical_not(), inputs, - output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_lt_aten.py b/tests/py/dynamo/conversion/test_lt_aten.py index 89cb7f42c5..bd4b8f1b21 100644 --- a/tests/py/dynamo/conversion/test_lt_aten.py +++ b/tests/py/dynamo/conversion/test_lt_aten.py @@ -22,7 +22,6 @@ def forward(self, lhs_val, rhs_val): self.run_test( lt(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -40,7 +39,6 @@ def forward(self, lhs_val): self.run_test( lt(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -58,7 +56,6 @@ def forward(self, lhs_val): self.run_test( lt(), inputs, - output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_ne_aten.py b/tests/py/dynamo/conversion/test_ne_aten.py index 2450ac0945..d2f7421848 100644 --- a/tests/py/dynamo/conversion/test_ne_aten.py +++ b/tests/py/dynamo/conversion/test_ne_aten.py @@ -25,7 +25,6 @@ def forward(self, lhs_val, rhs_val): self.run_test( ne(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -43,7 +42,6 @@ def forward(self, lhs_val): self.run_test( ne(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -61,7 +59,6 @@ def forward(self, lhs_val): self.run_test( ne(), inputs, - output_dtypes=[torch.bool], ) diff --git a/tests/py/dynamo/conversion/test_pad_aten.py b/tests/py/dynamo/conversion/test_pad_aten.py index 2803736ad0..ca0b01b5d2 100644 --- a/tests/py/dynamo/conversion/test_pad_aten.py +++ b/tests/py/dynamo/conversion/test_pad_aten.py @@ -1,3 +1,4 @@ +# type: ignore import torch from parameterized import parameterized from torch.testing._internal.common_utils import run_tests diff --git a/tests/py/dynamo/conversion/test_scalar_tensor_aten.py b/tests/py/dynamo/conversion/test_scalar_tensor_aten.py index 28c3d7f481..d0146ed720 100644 --- a/tests/py/dynamo/conversion/test_scalar_tensor_aten.py +++ b/tests/py/dynamo/conversion/test_scalar_tensor_aten.py @@ -87,7 +87,6 @@ def forward(self): self.run_test( ScalarTensor(), inputs, - output_dtypes=None if dtype is None else [dtype], ) diff --git a/tests/py/dynamo/conversion/test_sum_aten.py b/tests/py/dynamo/conversion/test_sum_aten.py index 999b8b4997..bac8c7edf1 100644 --- a/tests/py/dynamo/conversion/test_sum_aten.py +++ b/tests/py/dynamo/conversion/test_sum_aten.py @@ -85,7 +85,6 @@ def forward(self, x): self.run_test( Sum(), inputs, - output_dtypes=[torch.int32], ) @parameterized.expand( @@ -108,7 +107,6 @@ def forward(self, x): self.run_test( Sum(), inputs, - output_dtypes=[torch.int32], ) diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py index b7c895ec11..2d7a4731f5 100644 --- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py +++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py @@ -1,3 +1,5 @@ +import unittest + import torch import torch_tensorrt from torch.testing._internal.common_utils import TestCase, run_tests @@ -267,6 +269,10 @@ def forward(self, q, k, v): 
torch._dynamo.reset() +@unittest.skipIf( + torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8, + "GPU compute capability is too low to run flash attention, need Ampere (8.0) or greater", +) class TestLowerFlashAttention(TestCase): def test_lower_flash_attention(self): class FlashAttention(torch.nn.Module): diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 84f6bf7a36..38889a3df8 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -1,3 +1,4 @@ +# type: ignore import unittest import pytest diff --git a/tests/py/dynamo/runtime/test_hw_compat.py b/tests/py/dynamo/runtime/test_hw_compat.py index 4218cc7de0..29bd17cfde 100644 --- a/tests/py/dynamo/runtime/test_hw_compat.py +++ b/tests/py/dynamo/runtime/test_hw_compat.py @@ -7,6 +7,14 @@ class TestHardwareCompatibility(TestCase): + @unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT Runtime is not available", + ) + @unittest.skipIf( + not torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8, + "HW Compatibility is not supported on cards older than Ampere", + ) def test_hw_compat_enabled(self): class SampleModel(torch.nn.Module): def forward(self, x): @@ -58,6 +66,14 @@ def forward(self, x): torch.ops.tensorrt.ABI_VERSION() != "5", "Detected incorrect ABI version, please update this test case", ) + @unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT runtime is not available", + ) + @unittest.skipIf( + not torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8, + "HW Compatibility is not supported on cards older than Ampere", + ) def test_hw_compat_3080_build(self): inputs = [torch.randn(5, 7).cuda()] diff --git a/tests/py/dynamo/runtime/test_safe_mode.py b/tests/py/dynamo/runtime/test_safe_mode.py index bd196b12f0..5842b3ddc5 100644 --- a/tests/py/dynamo/runtime/test_safe_mode.py +++ b/tests/py/dynamo/runtime/test_safe_mode.py @@ -1,11 +1,16 @@ -import torch -from torch.testing._internal.common_utils import TestCase, run_tests +import unittest +import torch import torch_tensorrt +from torch.testing._internal.common_utils import TestCase, run_tests from ..testing_utilities import DECIMALS_OF_AGREEMENT +@unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT runtime is not available", +) class TestSafeMode(TestCase): def test_multi_device_safe_mode_on(self): torch_tensorrt.runtime.set_multi_device_safe_mode(True) diff --git a/tests/py/ts/api/test_classes.py b/tests/py/ts/api/test_classes.py index 01c805d9a1..2a152cdec7 100644 --- a/tests/py/ts/api/test_classes.py +++ b/tests/py/ts/api/test_classes.py @@ -1,58 +1,17 @@ -import unittest -import torch_tensorrt as torchtrt -from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule -import torch -import torchvision.models as models import copy +import unittest from typing import Dict - -class TestDevice(unittest.TestCase): - def test_from_string_constructor(self): - device = torchtrt.Device("cuda:0") - self.assertEqual(device.device_type, torchtrt.DeviceType.GPU) - self.assertEqual(device.gpu_id, 0) - - device = torchtrt.Device("gpu:1") - self.assertEqual(device.device_type, torchtrt.DeviceType.GPU) - self.assertEqual(device.gpu_id, 1) - - def test_from_string_constructor_dla(self): - device = torchtrt.Device("dla:0") - self.assertEqual(device.device_type, 
torchtrt.DeviceType.DLA) - self.assertEqual(device.gpu_id, 0) - self.assertEqual(device.dla_core, 0) - - device = torchtrt.Device("dla:1", allow_gpu_fallback=True) - self.assertEqual(device.device_type, torchtrt.DeviceType.DLA) - self.assertEqual(device.gpu_id, 0) - self.assertEqual(device.dla_core, 1) - self.assertEqual(device.allow_gpu_fallback, True) - - def test_kwargs_gpu(self): - device = torchtrt.Device(gpu_id=0) - self.assertEqual(device.device_type, torchtrt.DeviceType.GPU) - self.assertEqual(device.gpu_id, 0) - - def test_kwargs_dla_and_settings(self): - device = torchtrt.Device(dla_core=1, allow_gpu_fallback=False) - self.assertEqual(device.device_type, torchtrt.DeviceType.DLA) - self.assertEqual(device.gpu_id, 0) - self.assertEqual(device.dla_core, 1) - self.assertEqual(device.allow_gpu_fallback, False) - - device = torchtrt.Device(gpu_id=1, dla_core=0, allow_gpu_fallback=True) - self.assertEqual(device.device_type, torchtrt.DeviceType.DLA) - self.assertEqual(device.gpu_id, 1) - self.assertEqual(device.dla_core, 0) - self.assertEqual(device.allow_gpu_fallback, True) - - def test_from_torch(self): - device = torchtrt.Device._from_torch_device(torch.device("cuda:0")) - self.assertEqual(device.device_type, torchtrt.DeviceType.GPU) - self.assertEqual(device.gpu_id, 0) +import torch +import torch_tensorrt as torchtrt +import torchvision.models as models +from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestInput(unittest.TestCase): def _verify_correctness(self, struct: torchtrt.Input, target: Dict) -> bool: internal = struct._to_internal() @@ -80,10 +39,16 @@ def field_is_correct(field, equal_fn, a1, a2): target["explicit_set_dtype"], ) dtype_ = field_is_correct( - "dtype", eq, int(internal.dtype), int(target["dtype"]) + "dtype", + eq, + torchtrt.dtype._from(internal.dtype), + torchtrt.dtype._from(target["dtype"]), ) format_ = field_is_correct( - "format", eq, int(internal.format), int(target["format"]) + "format", + eq, + torchtrt.memory_format._from(internal.format), + torchtrt.memory_format._from(target["format"]), ) return all( @@ -98,7 +63,7 @@ def test_infer_from_example_tensor(self): "max": shape, "input_is_dynamic": False, "dtype": torchtrt.dtype.half, - "format": torchtrt.TensorFormat.contiguous, + "format": torchtrt.memory_format.contiguous, "explicit_set_dtype": True, } @@ -117,7 +82,7 @@ def test_static_shape(self): "max": shape, "input_is_dynamic": False, "dtype": torchtrt.dtype.unknown, - "format": torchtrt.TensorFormat.contiguous, + "format": torchtrt.memory_format.contiguous, "explicit_set_dtype": False, } @@ -165,7 +130,7 @@ def test_data_type(self): "max": shape, "input_is_dynamic": False, "dtype": torchtrt.dtype.half, - "format": torchtrt.TensorFormat.contiguous, + "format": torchtrt.memory_format.contiguous, "explicit_set_dtype": True, } @@ -189,11 +154,11 @@ def test_tensor_format(self): "max": shape, "input_is_dynamic": False, "dtype": torchtrt.dtype.unknown, - "format": torchtrt.TensorFormat.channels_last, + "format": torchtrt.memory_format.channels_last, "explicit_set_dtype": False, } - i = torchtrt.Input(shape, format=torchtrt.TensorFormat.channels_last) + i = torchtrt.Input(shape, format=torchtrt.memory_format.channels_last) ts_i = torchtrt.ts.TorchScriptInput( shape=i.shape, dtype=i.dtype, format=i.format ) @@ -215,7 +180,7 @@ def test_dynamic_shape(self): "max": max_shape, "input_is_dynamic": True, 
"dtype": torchtrt.dtype.unknown, - "format": torchtrt.TensorFormat.contiguous, + "format": torchtrt.memory_format.contiguous, "explicit_set_dtype": False, } @@ -261,6 +226,10 @@ def test_dynamic_shape(self): self.assertTrue(self._verify_correctness(ts_i, target)) +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestTorchTensorRTModule(unittest.TestCase): @staticmethod def _get_trt_mod(): diff --git a/tests/py/ts/api/test_collections.py b/tests/py/ts/api/test_collections.py index eab67679ed..7dc79b09b4 100644 --- a/tests/py/ts/api/test_collections.py +++ b/tests/py/ts/api/test_collections.py @@ -1,9 +1,12 @@ +# type: ignore + +import os import unittest -import torch_tensorrt as torchtrt + import torch +import torch_tensorrt as torchtrt import torchvision.models as models -import os -from utils import cosine_similarity, COSINE_THRESHOLD +from utils import COSINE_THRESHOLD, cosine_similarity def find_repo_root(max_depth=10): @@ -21,6 +24,10 @@ def find_repo_root(max_depth=10): MODULE_DIR = find_repo_root() + "/tests/modules" +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestStandardTensorInput(unittest.TestCase): def test_compile(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") @@ -49,9 +56,14 @@ def test_compile(self): ) +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) +@unittest.skip("TODO: @bowang007, Invalid test case, needs fixing") class TestStandardTensorInputLong(unittest.TestCase): def test_compile(self): - self.input = torch.randn((1, 3, 224, 224)).to("cuda") + self.input = torch.randn((1, 3, 224, 224)).to("cuda").to(torch.int32) self.model = ( torch.jit.load(MODULE_DIR + "/standard_tensor_input_scripted.jit.pt") .eval() @@ -66,6 +78,7 @@ def test_compile(self): "device": torchtrt.Device("gpu:0"), "enabled_precisions": {torch.float}, "truncate_long_and_double": True, + "require_full_compilation": True, } trt_mod = torchtrt.ts.compile(self.model, **compile_spec) @@ -78,6 +91,10 @@ def test_compile(self): ) +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestStandardTensorInputDomain(unittest.TestCase): def test_compile(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") @@ -106,6 +123,10 @@ def test_compile(self): ) +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestTupleInput(unittest.TestCase): def test_compile(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") @@ -134,6 +155,10 @@ def test_compile(self): ) +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestListInput(unittest.TestCase): def test_compile(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") @@ -160,6 +185,10 @@ def test_compile(self): ) +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestTupleInputOutput(unittest.TestCase): def test_compile(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") @@ -217,6 +246,10 @@ def test_compile_full_compilation(self): ) +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestListInputOutput(unittest.TestCase): def test_compile(self): 
self.input = torch.randn((1, 3, 224, 224)).to("cuda") @@ -276,6 +309,10 @@ def test_compile_full_compilation(self): ) +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestListInputTupleOutput(unittest.TestCase): def test_compile(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") diff --git a/tests/py/ts/api/test_e2e_behavior.py b/tests/py/ts/api/test_e2e_behavior.py index 499106e9ca..7e1f3dd538 100644 --- a/tests/py/ts/api/test_e2e_behavior.py +++ b/tests/py/ts/api/test_e2e_behavior.py @@ -1,9 +1,10 @@ +import copy import unittest -import torch_tensorrt as torchtrt +from typing import Dict + import torch +import torch_tensorrt as torchtrt import torchvision.models as models -import copy -from typing import Dict from utils import same_output_format @@ -39,7 +40,7 @@ def test_input_respect_user_setting_fp32_weights_fp16_in_non_constructor(self): ts_model = torch.jit.script(self.model) input_spec = torchtrt.Input(self.input.shape) - input_spec.dtype = torch.half + input_spec.dtype = torchtrt.dtype.half trt_mod = torchtrt.ts.compile( ts_model, @@ -100,7 +101,7 @@ def test_input_respect_user_setting_fp16_weights_fp32_in_non_constuctor(self): half_mod.half() input_spec = torchtrt.Input(self.input.shape) - input_spec.dtype = torch.float + input_spec.dtype = torchtrt.dtype.float trt_mod = torchtrt.ts.compile( half_mod, diff --git a/tests/py/ts/api/test_logging.py b/tests/py/ts/api/test_logging.py index cc10fa9cc9..b07be3bff5 100644 --- a/tests/py/ts/api/test_logging.py +++ b/tests/py/ts/api/test_logging.py @@ -1,71 +1,76 @@ +import copy import unittest -import torch_tensorrt as torchtrt +from typing import Dict + import torch +import torch_tensorrt as torchtrt import torchvision.models as models -import copy -from typing import Dict +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestLoggingAPIs(unittest.TestCase): def test_logging_prefix(self): new_prefix = "Python API Test: " - torchtrt.logging.set_logging_prefix(new_prefix) - logging_prefix = torchtrt.logging.get_logging_prefix() + torchtrt.ts.logging.set_logging_prefix(new_prefix) + logging_prefix = torchtrt.ts.logging.get_logging_prefix() self.assertEqual(new_prefix, logging_prefix) def test_reportable_log_level(self): - new_level = torchtrt.logging.Level.Error - torchtrt.logging.set_reportable_log_level(new_level) - level = torchtrt.logging.get_reportable_log_level() + new_level = torchtrt.ts.logging.Level.Error + torchtrt.ts.logging.set_reportable_log_level(new_level) + level = torchtrt.ts.logging.get_reportable_log_level() self.assertEqual(new_level, level) def test_is_colored_output_on(self): - torchtrt.logging.set_is_colored_output_on(True) - color = torchtrt.logging.get_is_colored_output_on() + torchtrt.ts.logging.set_is_colored_output_on(True) + color = torchtrt.ts.logging.get_is_colored_output_on() self.assertTrue(color) def test_context_managers(self): - base_lvl = torchtrt.logging.get_reportable_log_level() + base_lvl = torchtrt.ts.logging.get_reportable_log_level() with torchtrt.logging.internal_errors(): - lvl = torchtrt.logging.get_reportable_log_level() - self.assertEqual(torchtrt.logging.Level.InternalError, lvl) + lvl = torchtrt.ts.logging.get_reportable_log_level() + self.assertEqual(torchtrt.ts.logging.Level.InternalError, lvl) - lvl = torchtrt.logging.get_reportable_log_level() + lvl = torchtrt.ts.logging.get_reportable_log_level() self.assertEqual(base_lvl, lvl) 
with torchtrt.logging.errors(): - lvl = torchtrt.logging.get_reportable_log_level() - self.assertEqual(torchtrt.logging.Level.Error, lvl) + lvl = torchtrt.ts.logging.get_reportable_log_level() + self.assertEqual(torchtrt.ts.logging.Level.Error, lvl) - lvl = torchtrt.logging.get_reportable_log_level() + lvl = torchtrt.ts.logging.get_reportable_log_level() self.assertEqual(base_lvl, lvl) with torchtrt.logging.warnings(): - lvl = torchtrt.logging.get_reportable_log_level() - self.assertEqual(torchtrt.logging.Level.Warning, lvl) + lvl = torchtrt.ts.logging.get_reportable_log_level() + self.assertEqual(torchtrt.ts.logging.Level.Warning, lvl) - lvl = torchtrt.logging.get_reportable_log_level() + lvl = torchtrt.ts.logging.get_reportable_log_level() self.assertEqual(base_lvl, lvl) with torchtrt.logging.info(): - lvl = torchtrt.logging.get_reportable_log_level() - self.assertEqual(torchtrt.logging.Level.Info, lvl) + lvl = torchtrt.ts.logging.get_reportable_log_level() + self.assertEqual(torchtrt.ts.logging.Level.Info, lvl) - lvl = torchtrt.logging.get_reportable_log_level() + lvl = torchtrt.ts.logging.get_reportable_log_level() self.assertEqual(base_lvl, lvl) with torchtrt.logging.debug(): - lvl = torchtrt.logging.get_reportable_log_level() - self.assertEqual(torchtrt.logging.Level.Debug, lvl) + lvl = torchtrt.ts.logging.get_reportable_log_level() + self.assertEqual(torchtrt.ts.logging.Level.Debug, lvl) - lvl = torchtrt.logging.get_reportable_log_level() + lvl = torchtrt.ts.logging.get_reportable_log_level() self.assertEqual(base_lvl, lvl) with torchtrt.logging.graphs(): - lvl = torchtrt.logging.get_reportable_log_level() - self.assertEqual(torchtrt.logging.Level.Graph, lvl) + lvl = torchtrt.ts.logging.get_reportable_log_level() + self.assertEqual(torchtrt.ts.logging.Level.Graph, lvl) - lvl = torchtrt.logging.get_reportable_log_level() + lvl = torchtrt.ts.logging.get_reportable_log_level() self.assertEqual(base_lvl, lvl) diff --git a/tests/py/ts/hw/test_api_dla.py b/tests/py/ts/hw/test_api_dla.py index 5328b92233..0bf3b74010 100644 --- a/tests/py/ts/hw/test_api_dla.py +++ b/tests/py/ts/hw/test_api_dla.py @@ -1,10 +1,15 @@ import unittest -import torch_tensorrt as torchtrt + import torch +import torch_tensorrt as torchtrt import torchvision.models as models -from utils import cosine_similarity, COSINE_THRESHOLD +from utils import COSINE_THRESHOLD, cosine_similarity +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class ModelTestCaseOnDLA(unittest.TestCase): def __init__(self, methodName="runTest", model=None): super(ModelTestCaseOnDLA, self).__init__(methodName) @@ -21,6 +26,10 @@ def parametrize(testcase_class, model=None): return suite +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestCompile(ModelTestCaseOnDLA): def setUp(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda").half() diff --git a/tests/py/ts/hw/test_multi_gpu.py b/tests/py/ts/hw/test_multi_gpu.py index b6fa3f220b..4fc2bd9223 100644 --- a/tests/py/ts/hw/test_multi_gpu.py +++ b/tests/py/ts/hw/test_multi_gpu.py @@ -1,11 +1,15 @@ import unittest -import torch_tensorrt as torchtrt + import torch +import torch_tensorrt as torchtrt import torchvision.models as models - from model_test_case import ModelTestCase +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestMultiGpuSwitching(ModelTestCase): def 
setUp(self): if torch.cuda.device_count() < 2: @@ -65,6 +69,10 @@ def test_compile_script(self): ) +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestMultiGpuSerializeDeserializeSwitching(ModelTestCase): def setUp(self): if torch.cuda.device_count() < 2: diff --git a/tests/py/ts/integrations/test_to_backend_api.py b/tests/py/ts/integrations/test_to_backend_api.py index 0f74a3af15..6e974ba2c8 100644 --- a/tests/py/ts/integrations/test_to_backend_api.py +++ b/tests/py/ts/integrations/test_to_backend_api.py @@ -1,10 +1,16 @@ +# type: ignore import unittest -import torch_tensorrt as torchtrt + import torch +import torch_tensorrt as torchtrt import torchvision.models as models -from utils import cosine_similarity, COSINE_THRESHOLD +from utils import COSINE_THRESHOLD, cosine_similarity +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestToBackendLowering(unittest.TestCase): def setUp(self): self.input = torch.randn((1, 3, 300, 300)).to("cuda") @@ -23,7 +29,9 @@ def setUp(self): "dla_core": 0, "allow_gpu_fallback": True, }, - "capability": torchtrt.EngineCapability.default, + "capability": torchtrt.EngineCapability.STANDARD.to( + torchtrt._C.EngineCapability + ), "num_avg_timing_iters": 1, "disable_tf32": False, } diff --git a/tests/py/ts/integrations/test_trt_intercompatibility.py b/tests/py/ts/integrations/test_trt_intercompatibility.py index b938e4a1ac..6afe9d0428 100644 --- a/tests/py/ts/integrations/test_trt_intercompatibility.py +++ b/tests/py/ts/integrations/test_trt_intercompatibility.py @@ -1,11 +1,16 @@ import unittest -import torch_tensorrt as torchtrt + +import tensorrt as trt import torch +import torch_tensorrt as torchtrt import torchvision.models as models -import tensorrt as trt -from utils import cosine_similarity, COSINE_THRESHOLD +from utils import COSINE_THRESHOLD, cosine_similarity +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestPyTorchToTRTEngine(unittest.TestCase): def test_pt_to_trt(self): self.model = models.resnet18(pretrained=True).eval().to("cuda:0") diff --git a/tests/py/ts/models/test_models.py b/tests/py/ts/models/test_models.py index 5678e8f648..1d5c3bae3b 100644 --- a/tests/py/ts/models/test_models.py +++ b/tests/py/ts/models/test_models.py @@ -10,6 +10,10 @@ from utils import COSINE_THRESHOLD, cosine_similarity +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestModels(unittest.TestCase): def test_resnet18(self): self.model = models.resnet18(pretrained=True).eval().to("cuda") diff --git a/tests/py/ts/models/test_multiple_registered_engines.py b/tests/py/ts/models/test_multiple_registered_engines.py index e8c1f95433..407502f04a 100644 --- a/tests/py/ts/models/test_multiple_registered_engines.py +++ b/tests/py/ts/models/test_multiple_registered_engines.py @@ -1,14 +1,19 @@ +import copy import unittest -import torch_tensorrt as torchtrt +from typing import Dict + +import custom_models as cm +import timm import torch +import torch_tensorrt as torchtrt import torchvision.models as models -import copy -import timm -import custom_models as cm -from typing import Dict -from utils import cosine_similarity, COSINE_THRESHOLD +from utils import COSINE_THRESHOLD, cosine_similarity +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript 
Frontend is not available", +) class TestModelToEngineToModel(unittest.TestCase): def test_multiple_engines(self): self.resnet18 = models.resnet18(pretrained=True).eval().to("cuda") diff --git a/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py b/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py index c5a84f301d..2fac02f542 100644 --- a/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py +++ b/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py @@ -1,13 +1,13 @@ +import os import unittest -import torch_tensorrt as torchtrt -from torch_tensorrt.logging import * + import torch import torch.nn as nn -from torch.nn import functional as F +import torch_tensorrt as torchtrt import torchvision import torchvision.transforms as transforms - -import os +from torch.nn import functional as F +from torch_tensorrt.ts.logging import * def find_repo_root(max_depth=10): @@ -49,6 +49,10 @@ def compute_accuracy(testing_dataloader, model): return correct / total +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestAccuracy(unittest.TestCase): def test_compile_script(self): self.model = ( diff --git a/tests/py/ts/ptq/test_ptq_to_backend.py b/tests/py/ts/ptq/test_ptq_to_backend.py index 3a0a5bf336..d016dedb15 100644 --- a/tests/py/ts/ptq/test_ptq_to_backend.py +++ b/tests/py/ts/ptq/test_ptq_to_backend.py @@ -1,12 +1,13 @@ +import os import unittest -import torch_tensorrt as torchtrt -from torch_tensorrt.logging import * + import torch import torch.nn as nn -from torch.nn import functional as F +import torch_tensorrt as torchtrt import torchvision import torchvision.transforms as transforms -import os +from torch.nn import functional as F +from torch_tensorrt.ts.logging import * def find_repo_root(max_depth=10): @@ -48,6 +49,10 @@ def compute_accuracy(testing_dataloader, model): return correct / total +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestAccuracy(unittest.TestCase): def test_compile_script(self): self.model = ( diff --git a/tests/py/ts/ptq/test_ptq_trt_calibrator.py b/tests/py/ts/ptq/test_ptq_trt_calibrator.py index 93596c895d..8564d14b71 100644 --- a/tests/py/ts/ptq/test_ptq_trt_calibrator.py +++ b/tests/py/ts/ptq/test_ptq_trt_calibrator.py @@ -1,13 +1,14 @@ -import unittest import os -import torch_tensorrt as torchtrt -from torch_tensorrt.logging import * -import torch +import unittest + import tensorrt as trt +import torch import torch.nn as nn -from torch.nn import functional as F +import torch_tensorrt as torchtrt import torchvision import torchvision.transforms as transforms +from torch.nn import functional as F +from torch_tensorrt.ts.logging import * def find_repo_root(max_depth=10): @@ -49,6 +50,10 @@ def compute_accuracy(testing_dataloader, model): return correct / total +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TRTEntropyCalibrator(trt.IInt8EntropyCalibrator2): def __init__(self, dataloader, **kwargs): trt.IInt8EntropyCalibrator2.__init__(self) @@ -94,6 +99,10 @@ def write_calibration_cache(self, cache): f.write(cache) +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestAccuracy(unittest.TestCase): def test_compile_script(self): self.model = ( diff --git a/tests/py/ts/qat/test_qat_trt_accuracy.py b/tests/py/ts/qat/test_qat_trt_accuracy.py index ce574c57fe..ade2cfc865 100644 --- 
a/tests/py/ts/qat/test_qat_trt_accuracy.py +++ b/tests/py/ts/qat/test_qat_trt_accuracy.py @@ -1,13 +1,14 @@ +import os +import sys import unittest -import torch_tensorrt as torchtrt -from torch_tensorrt.logging import * + import torch import torch.nn as nn -from torch.nn import functional as F +import torch_tensorrt as torchtrt import torchvision import torchvision.transforms as transforms -import os -import sys +from torch.nn import functional as F +from torch_tensorrt.ts.logging import * def find_repo_root(max_depth=10): @@ -51,6 +52,10 @@ def compute_accuracy(testing_dataloader, model): return correct / total +@unittest.skipIf( + not torchtrt.ENABLED_FEATURES.torchscript_frontend, + "TorchScript Frontend is not available", +) class TestAccuracy(unittest.TestCase): def test_compile_script(self): self.model = ( diff --git a/toolchains/legacy/pyproject.toml b/toolchains/legacy/pyproject.toml index ce9e6423cb..90b2d4f2ec 100644 --- a/toolchains/legacy/pyproject.toml +++ b/toolchains/legacy/pyproject.toml @@ -64,7 +64,7 @@ include-package-data = false [tool.ruff] # NOTE: Synchoronize the ignores with .flake8 -ignore = [ +lint.ignore = [ # these ignores are from flake8-bugbear; please fix! "B007", "B008", "B017", "B018", # Useless expression @@ -97,7 +97,7 @@ ignore = [ "SIM118", ] #line-length = 120 -select = [ +lint.select = [ "B", "C4", "G", @@ -112,11 +112,11 @@ select = [ ] # Allow unused variables when underscore-prefixed. -dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" target-version = "py311" # Allow autofix for all enabled rules (when `--fix`) is provided. -fixable = [ +lint.fixable = [ "A","B","C","D","E","F","G", "I","N","Q","S","T","W", "ANN", "ARG", "BLE", "COM", "DJ", @@ -125,10 +125,10 @@ fixable = [ "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"] -unfixable = [] +lint.unfixable = [] # Exclude a variety of commonly ignored directories. -exclude = [ +lint.exclude = [ ".bzr", ".direnv", ".eggs", @@ -164,7 +164,7 @@ exclude = [ "__init__.py" ] -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] # Unlike Flake8, default to a complexity level of 10. max-complexity = 10
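
For context, a minimal usage sketch of the relocated TorchScript-frontend logging API added in py/torch_tensorrt/ts/logging.py above, guarded by the same ENABLED_FEATURES flag the updated tests rely on; module, class, and function names are taken from the diff, while the prefix string and messages are purely illustrative.

import torch_tensorrt as torchtrt

if torchtrt.ENABLED_FEATURES.torchscript_frontend:
    # Logging controls for the TorchScript frontend now live under torch_tensorrt.ts.logging
    from torch_tensorrt.ts import logging as ts_logging

    ts_logging.set_logging_prefix("Example: ")
    ts_logging.set_reportable_log_level(ts_logging.Level.Info)

    # Messages more verbose than the reportable level are not printed
    ts_logging.log(ts_logging.Level.Info, "TorchScript frontend logging is active")
    ts_logging.log(ts_logging.Level.Debug, "suppressed at the Info threshold")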
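
Likewise, a hedged sketch of driving the new setup.py toggles introduced above (assuming it is run from the repository root with the usual build prerequisites available): NO_TORCHSCRIPT=1 mirrors the new --no-ts flag and builds only the libtorchtrt_runtime tarball, while PYTHON_ONLY=1 mirrors --py-only and skips the bazel build entirely.

import os
import subprocess

# Runtime-only build: same effect as `python setup.py bdist_wheel --no-ts`
subprocess.run(
    ["python", "setup.py", "bdist_wheel"],
    env={**os.environ, "NO_TORCHSCRIPT": "1"},
    check=True,
)

# Pure-Python build (no bazel / libtorchtrt): same effect as `--py-only`
subprocess.run(
    ["python", "setup.py", "bdist_wheel"],
    env={**os.environ, "PYTHON_ONLY": "1"},
    check=True,
)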