diff --git a/.github/workflows/build-test-linux.yml b/.github/workflows/build-test-linux.yml index 3f3add0655..1da61a033c 100644 --- a/.github/workflows/build-test-linux.yml +++ b/.github/workflows/build-test-linux.yml @@ -77,6 +77,7 @@ jobs: pre-script: ${{ matrix.pre-script }} script: | export USE_HOST_DEPS=1 + export CI_BUILD=1 export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH pushd . cd tests/modules @@ -112,6 +113,7 @@ jobs: pre-script: ${{ matrix.pre-script }} script: | export USE_HOST_DEPS=1 + export CI_BUILD=1 pushd . cd tests/py/dynamo python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 8 conversion/ @@ -140,6 +142,7 @@ jobs: pre-script: ${{ matrix.pre-script }} script: | export USE_HOST_DEPS=1 + export CI_BUILD=1 pushd . cd tests/py/dynamo python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/ @@ -168,6 +171,7 @@ jobs: pre-script: ${{ matrix.pre-script }} script: | export USE_HOST_DEPS=1 + export CI_BUILD=1 pushd . cd tests/py/dynamo python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py @@ -196,6 +200,7 @@ jobs: pre-script: ${{ matrix.pre-script }} script: | export USE_HOST_DEPS=1 + export CI_BUILD=1 pushd . cd tests/py/dynamo python -m pytest -ra -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/ @@ -226,6 +231,7 @@ jobs: pre-script: ${{ matrix.pre-script }} script: | export USE_HOST_DEPS=1 + export CI_BUILD=1 pushd . cd tests/py/dynamo python -m pytest -ra -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_test_results.xml --ignore runtime/test_002_cudagraphs_py.py --ignore runtime/test_002_cudagraphs_cpp.py runtime/ @@ -256,6 +262,7 @@ jobs: pre-script: ${{ matrix.pre-script }} script: | export USE_HOST_DEPS=1 + export CI_BUILD=1 pushd . cd tests/py/dynamo nvidia-smi @@ -286,6 +293,7 @@ jobs: pre-script: ${{ matrix.pre-script }} script: | export USE_HOST_DEPS=1 + export CI_BUILD=1 pushd . cd tests/py/core python -m pytest -ra -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_core_test_results.xml . diff --git a/.github/workflows/build-test-windows.yml b/.github/workflows/build-test-windows.yml index 6123e0b2cc..53db564322 100644 --- a/.github/workflows/build-test-windows.yml +++ b/.github/workflows/build-test-windows.yml @@ -83,6 +83,7 @@ jobs: pre-script: packaging/driver_upgrade.bat script: | export USE_HOST_DEPS=1 + export CI_BUILD=1 pushd . cd tests/modules python hub.py @@ -114,6 +115,7 @@ jobs: pre-script: packaging/driver_upgrade.bat script: | export USE_HOST_DEPS=1 + export CI_BUILD=1 pushd . cd tests/py/dynamo python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 10 conversion/ @@ -139,6 +141,7 @@ jobs: pre-script: packaging/driver_upgrade.bat script: | export USE_HOST_DEPS=1 + export CI_BUILD=1 pushd . cd tests/py/dynamo python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/ @@ -164,6 +167,7 @@ jobs: pre-script: packaging/driver_upgrade.bat script: | export USE_HOST_DEPS=1 + export CI_BUILD=1 pushd . cd tests/py/dynamo python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py @@ -189,6 +193,7 @@ jobs: pre-script: packaging/driver_upgrade.bat script: | export USE_HOST_DEPS=1 + export CI_BUILD=1 pushd . cd tests/py/dynamo python -m pytest -ra -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/ @@ -216,6 +221,7 @@ jobs: pre-script: packaging/driver_upgrade.bat script: | export USE_HOST_DEPS=1 + export CI_BUILD=1 pushd . cd tests/py/dynamo python -m pytest -ra -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_test_results.xml --ignore runtime/test_002_cudagraphs_py.py --ignore runtime/test_002_cudagraphs_cpp.py runtime/ @@ -246,6 +252,7 @@ jobs: pre-script: ${{ matrix.pre-script }} script: | export USE_HOST_DEPS=1 + export CI_BUILD=1 pushd . cd tests/py/dynamo python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_cudagraphs_cpp_test_results.xml runtime/test_002_cudagraphs_cpp.py @@ -272,6 +279,7 @@ jobs: pre-script: packaging/driver_upgrade.bat script: | export USE_HOST_DEPS=1 + export CI_BUILD=1 pushd . cd tests/py/core python -m pytest -ra -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_core_test_results.xml . diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTBuilderMonitor.py b/py/torch_tensorrt/dynamo/conversion/_TRTBuilderMonitor.py new file mode 100644 index 0000000000..9a1189e44a --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/_TRTBuilderMonitor.py @@ -0,0 +1,159 @@ +import os +import sys +from typing import Any, Dict, Optional + +import tensorrt as trt + + +class _ASCIIMonitor(trt.IProgressMonitor): # type: ignore + def __init__(self, engine_name: str = "") -> None: + trt.IProgressMonitor.__init__(self) + self._active_phases: Dict[str, Dict[str, Any]] = {} + self._step_result = True + + self._render = True + if (ci_env_var := os.environ.get("CI_BUILD")) is not None: + if ci_env_var == "1": + self._render = False + + def phase_start( + self, phase_name: str, parent_phase: Optional[str], num_steps: int + ) -> None: + try: + if parent_phase is not None: + nbIndents = 1 + self._active_phases[parent_phase]["nbIndents"] + else: + nbIndents = 0 + self._active_phases[phase_name] = { + "title": phase_name, + "steps": 0, + "num_steps": num_steps, + "nbIndents": nbIndents, + } + self._redraw() + except KeyboardInterrupt: + _step_result = False + + def phase_finish(self, phase_name: str) -> None: + try: + del self._active_phases[phase_name] + self._redraw(blank_lines=1) # Clear the removed phase. + except KeyboardInterrupt: + _step_result = False + + def step_complete(self, phase_name: str, step: int) -> bool: + try: + self._active_phases[phase_name]["steps"] = step + self._redraw() + return self._step_result + except KeyboardInterrupt: + return False + + def _redraw(self, *, blank_lines: int = 0) -> None: + if self._render: + + def clear_line() -> None: + print("\x1B[2K", end="") + + def move_to_start_of_line() -> None: + print("\x1B[0G", end="") + + def move_cursor_up(lines: int) -> None: + print("\x1B[{}A".format(lines), end="") + + def progress_bar(steps: int, num_steps: int) -> str: + INNER_WIDTH = 10 + completed_bar_chars = int(INNER_WIDTH * steps / float(num_steps)) + return "[{}{}]".format( + "=" * completed_bar_chars, "-" * (INNER_WIDTH - completed_bar_chars) + ) + + # Set max_cols to a default of 200 if not run in interactive mode. + max_cols = os.get_terminal_size().columns if sys.stdout.isatty() else 200 + + move_to_start_of_line() + for phase in self._active_phases.values(): + phase_prefix = "{indent}{bar} {title}".format( + indent=" " * phase["nbIndents"], + bar=progress_bar(phase["steps"], phase["num_steps"]), + title=phase["title"], + ) + phase_suffix = "{steps}/{num_steps}".format(**phase) + allowable_prefix_chars = max_cols - len(phase_suffix) - 2 + if allowable_prefix_chars < len(phase_prefix): + phase_prefix = phase_prefix[0 : allowable_prefix_chars - 3] + "..." + clear_line() + print(phase_prefix, phase_suffix) + for line in range(blank_lines): + clear_line() + print() + move_cursor_up(len(self._active_phases) + blank_lines) + sys.stdout.flush() + + +try: + from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeElapsedColumn + + class _RichMonitor(trt.IProgressMonitor): # type: ignore + def __init__(self, engine_name: str = "") -> None: + trt.IProgressMonitor.__init__(self) + self._active_phases: Dict[str, TaskID] = {} + self._step_result = True + + self._progress_monitors = Progress( + TextColumn(" "), + TimeElapsedColumn(), + TextColumn("{task.description}: "), + BarColumn(), + TextColumn(" {task.percentage:.0f}% ({task.completed}/{task.total})"), + ) + + self._render = True + if (ci_env_var := os.environ.get("CI_BUILD")) is not None: + if ci_env_var == "1": + self._render = False + + if self._render: + self._progress_monitors.start() + + def phase_start( + self, phase_name: str, parent_phase: Optional[str], num_steps: int + ) -> None: + try: + self._active_phases[phase_name] = self._progress_monitors.add_task( + phase_name, total=num_steps + ) + self._progress_monitors.refresh() + except KeyboardInterrupt: + # The phase_start callback cannot directly cancel the build, so request the cancellation from within step_complete. + _step_result = False + + def phase_finish(self, phase_name: str) -> None: + try: + self._progress_monitors.update( + self._active_phases[phase_name], visible=False + ) + self._progress_monitors.stop_task(self._active_phases[phase_name]) + self._progress_monitors.remove_task(self._active_phases[phase_name]) + self._progress_monitors.refresh() + except KeyboardInterrupt: + _step_result = False + + def step_complete(self, phase_name: str, step: int) -> bool: + try: + self._progress_monitors.update( + self._active_phases[phase_name], completed=step + ) + self._progress_monitors.refresh() + return self._step_result + except KeyboardInterrupt: + # There is no need to propagate this exception to TensorRT. We can simply cancel the build. + return False + + def __del__(self) -> None: + if self._progress_monitors: + self._progress_monitors.stop() + + TRTBulderMonitor: trt.IProgressMonitor = _RichMonitor +except ImportError: + TRTBulderMonitor: trt.IProgressMonitor = _ASCIIMonitor # type: ignore[no-redef] diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 9a3cace599..17437ceb6e 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -6,7 +6,6 @@ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple import numpy as np -import tensorrt as trt import torch import torch.fx from torch.fx.node import _get_qualified_name @@ -21,6 +20,7 @@ DYNAMO_CONVERTERS as CONVERTERS, ) from torch_tensorrt.dynamo.conversion._ConverterRegistry import CallingConvention +from torch_tensorrt.dynamo.conversion._TRTBuilderMonitor import TRTBulderMonitor from torch_tensorrt.dynamo.conversion.converter_utils import ( get_node_io, get_node_name, @@ -30,6 +30,7 @@ from torch_tensorrt.fx.observer import Observer from torch_tensorrt.logging import TRT_LOGGER +import tensorrt as trt from packaging import version _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -146,7 +147,7 @@ def clean_repr(x: Any, depth: int = 0) -> Any: else: return "(...)" else: - return x + return f"{x} <{type(x).__name__}>" str_args = [clean_repr(a) for a in args] return repr(tuple(str_args)) @@ -176,6 +177,10 @@ def _populate_trt_builder_config( ) -> trt.IBuilderConfig: builder_config = self.builder.create_builder_config() + + if self.compilation_settings.debug: + builder_config.progress_monitor = TRTBulderMonitor() + if self.compilation_settings.workspace_size != 0: builder_config.set_memory_pool_limit( trt.MemoryPoolType.WORKSPACE, self.compilation_settings.workspace_size @@ -516,18 +521,18 @@ def run_node(self, n: torch.fx.Node) -> torch.fx.Node: kwargs["_itensor_to_tensor_meta"] = self._itensor_to_tensor_meta n.kwargs = kwargs - # run the node - _LOGGER.debug( - f"Running node {self._cur_node_name}, a {self._cur_node.op} node " - f"with target {self._cur_node.target} in the TensorRT Interpreter" - ) + if _LOGGER.isEnabledFor(logging.DEBUG): + _LOGGER.debug( + f"Converting node {self._cur_node_name} (kind: {n.target}, args: {TRTInterpreter._args_str(n.args)})" + ) + trt_node: torch.fx.Node = super().run_node(n) if n.op == "get_attr": self.const_mapping[str(n)] = (tuple(trt_node.shape), str(trt_node.dtype)) - _LOGGER.debug( - f"Ran node {self._cur_node_name} with properties: {get_node_io(n, self.const_mapping)}" + _LOGGER.info( + f"Converted node {self._cur_node_name} [{n.target}] ({get_node_io(n, self.const_mapping)})" ) # remove "_itensor_to_tensor_meta" @@ -611,9 +616,7 @@ def call_module( converter, calling_convention = converter_packet assert self._cur_node_name is not None - _LOGGER.debug( - f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})" - ) + if calling_convention is CallingConvention.LEGACY: return converter(self.ctx.net, submod, args, kwargs, self._cur_node_name) else: @@ -629,10 +632,6 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any: converter, calling_convention = converter_packet - assert self._cur_node_name is not None - _LOGGER.debug( - f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})" - ) if calling_convention is CallingConvention.LEGACY: return converter(self.ctx.net, target, args, kwargs, self._cur_node_name) else: @@ -663,10 +662,6 @@ def call_method(self, target: str, args: Any, kwargs: Any) -> Any: ) converter, calling_convention = converter_packet - assert self._cur_node_name is not None - _LOGGER.debug( - f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})" - ) if calling_convention is CallingConvention.LEGACY: return converter(self.ctx.net, target, args, kwargs, self._cur_node_name) else: diff --git a/pyproject.toml b/pyproject.toml index 1d6570db90..a84724f968 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,11 +17,9 @@ build-backend = "setuptools.build_meta" [project] name = "torch_tensorrt" -authors = [ - {name="NVIDIA Corporation", email="narens@nvidia.com"} -] +authors = [{ name = "NVIDIA Corporation", email = "narens@nvidia.com" }] description = "Torch-TensorRT is a package which allows users to automatically compile PyTorch and TorchScript modules to TensorRT while remaining in PyTorch" -license = {file = "LICENSE"} +license = { file = "LICENSE" } classifiers = [ "Development Status :: 5 - Production/Stable", "Environment :: GPU :: NVIDIA CUDA", @@ -37,9 +35,24 @@ classifiers = [ "Topic :: Software Development", "Topic :: Software Development :: Libraries", ] -readme = {file = "README.md", content-type = "text/markdown"} +readme = { file = "README.md", content-type = "text/markdown" } requires-python = ">=3.8" -keywords = ["pytorch", "torch", "tensorrt", "trt", "ai", "artificial intelligence", "ml", "machine learning", "dl", "deep learning", "compiler", "dynamo", "torchscript", "inference"] +keywords = [ + "pytorch", + "torch", + "tensorrt", + "trt", + "ai", + "artificial intelligence", + "ml", + "machine learning", + "dl", + "deep learning", + "compiler", + "dynamo", + "torchscript", + "inference", +] dependencies = [ "torch >=2.5.0.dev,<2.6.0", "tensorrt==10.1.0", @@ -53,6 +66,9 @@ dynamic = ["version"] [project.optional-dependencies] torchvision = ["torchvision >=0.20.dev,<0.21.0"] +quantization = ["nvidia-modelopt[all]>=0.15.1"] +monitoring-tools = ["rich >= 13.7.1"] +jupyter = ["rich[jupyter] >= 13.7.1"] [project.urls] Homepage = "https://pytorch.org/tensorrt" @@ -61,40 +77,53 @@ Repository = "https://github.com/pytorch/tensorrt.git" Changelog = "https://github.com/pytorch/tensorrt/releases" [tool.setuptools] -package-dir = {"" = "py"} +package-dir = { "" = "py" } include-package-data = false [tool.ruff] # NOTE: Synchoronize the ignores with .flake8 lint.ignore = [ # these ignores are from flake8-bugbear; please fix! - "B007", "B008", "B017", - "B018", # Useless expression - "B019", "B020", - "B023", "B024", "B026", - "B028", # No explicit `stacklevel` keyword argument found - "B904", "B905", + "B007", + "B008", + "B017", + "B018", # Useless expression + "B019", + "B020", + "B023", + "B024", + "B026", + "B028", # No explicit `stacklevel` keyword argument found + "B904", + "B905", "E402", - "C408", # C408 ignored because we like the dict keyword argument syntax - "E501", # E501 is not flexible enough, we're using B950 instead + "C408", # C408 ignored because we like the dict keyword argument syntax + "E501", # E501 is not flexible enough, we're using B950 instead "E721", - "E731", # Assign lambda expression + "E731", # Assign lambda expression "E741", "EXE001", "F405", "F821", "F841", # these ignores are from flake8-logging-format; please fix! - "G101", "G201", "G202", "G003", "G004", + "G101", + "G201", + "G202", + "G003", + "G004", # these ignores are from RUFF perf; please fix! - "PERF203", "PERF4", - "SIM102", "SIM103", "SIM112", # flake8-simplify code styles - "SIM105", # these ignores are from flake8-simplify. please fix or ignore with commented reason + "PERF203", + "PERF4", + "SIM102", + "SIM103", + "SIM112", # flake8-simplify code styles + "SIM105", # these ignores are from flake8-simplify. please fix or ignore with commented reason "SIM108", "SIM110", - "SIM114", # Combine `if` branches using logical `or` operator + "SIM114", # Combine `if` branches using logical `or` operator "SIM115", - "SIM116", # Disable Use a dictionary instead of consecutive `if` statements + "SIM116", # Disable Use a dictionary instead of consecutive `if` statements "SIM117", "SIM118", ] @@ -119,14 +148,51 @@ target-version = "py311" # Allow autofix for all enabled rules (when `--fix`) is provided. lint.fixable = [ - "A","B","C","D","E","F","G", - "I","N","Q","S","T","W", - "ANN", "ARG", "BLE", "COM", "DJ", - "DTZ", "EM", "ERA", "EXE", "FBT", - "ICN", "INP", "ISC", "NPY", "PD", - "PGH", "PIE", "PL", "PT", "PTH", - "PYI", "RET", "RSE", "RUF", "SIM", - "SLF", "TCH", "TID", "TRY", "UP", "YTT"] + "A", + "B", + "C", + "D", + "E", + "F", + "G", + "I", + "N", + "Q", + "S", + "T", + "W", + "ANN", + "ARG", + "BLE", + "COM", + "DJ", + "DTZ", + "EM", + "ERA", + "EXE", + "FBT", + "ICN", + "INP", + "ISC", + "NPY", + "PD", + "PGH", + "PIE", + "PL", + "PT", + "PTH", + "PYI", + "RET", + "RSE", + "RUF", + "SIM", + "SLF", + "TCH", + "TID", + "TRY", + "UP", + "YTT", +] lint.unfixable = [] # Exclude a variety of commonly ignored directories. @@ -163,7 +229,7 @@ exclude = [ "tests", "setup.py", "noxfile.py", - "__init__.py" + "__init__.py", ] [tool.ruff.lint.mccabe] @@ -173,9 +239,7 @@ max-complexity = 10 [tool.isort] profile = "black" py_version = 311 -skip = [ - "py/torch_tensorrt/fx", -] +skip = ["py/torch_tensorrt/fx"] [tool.black] #line-length = 120 @@ -200,7 +264,7 @@ exclude = [ "docsrc", "tests", "setup.py", - "noxfile.py" + "noxfile.py", ] python_version = "3.11" @@ -230,14 +294,14 @@ files.extend-exclude = [ "CHANGELOG.md", "*.ipynb", "cpp/", - "py/torch_tensorrt/fx/" + "py/torch_tensorrt/fx/", ] [tool.typos.default] extend-ignore-identifiers-re = [ "^([A-z]|[a-z])*Nd*", "^([A-z]|[a-z])*nd*", - "activ*([A-z]|[a-z]|[0-9])*," + "activ*([A-z]|[a-z]|[0-9])*,", ] [tool.typos.default.extend-words]