From 36e7cc80101b51448645b6d8193d00f57a4c2cf6 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 15 Aug 2024 14:14:55 -0700 Subject: [PATCH 01/11] Fixed device issue when model is on CPU --- py/torch_tensorrt/dynamo/_compiler.py | 6 +++--- py/torch_tensorrt/dynamo/_refit.py | 10 +++++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index c97c3a6229..0e5e09de8a 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -56,9 +56,9 @@ def compile( disable_tf32: bool = _defaults.DISABLE_TF32, assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT, sparse_weights: bool = _defaults.SPARSE_WEIGHTS, - enabled_precisions: ( - Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] - ) = _defaults.ENABLED_PRECISIONS, + enabled_precisions: Union[ + Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]] + ] = _defaults.ENABLED_PRECISIONS, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, make_refitable: bool = _defaults.MAKE_REFITABLE, debug: bool = _defaults.DEBUG, diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 660cb8a875..f03009c3d8 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -157,13 +157,17 @@ def _refit_single_trt_engine_with_gm( """ refitted = set() - + torch_device = list(new_gm.state_dict().values())[0].device.type refitter = trt.Refitter(old_engine, TRT_LOGGER) weight_list = refitter.get_all_weights() if weight_name_map: # Get the refitting mapping - trt_wt_location = trt.TensorLocation.DEVICE + trt_wt_location = ( + trt.TensorLocation.DEVICE + if torch_device == "cuda" + else trt.TensorLocation.HOST + ) mapping = construct_refit_mapping_from_weight_name_map( weight_name_map, new_gm.state_dict() ) @@ -235,7 +239,7 @@ def refit_module_weights( compiled_module = copy.deepcopy(compiled_module) elif inline_module: raise AssertionError( - "Exported program does not support modifying in place. Please set inplace to false and use the returned graph module." + "Exported program does not support modifying in place. Please set in_place to false and use the returned graph module." 
)

    # Get the settings and check the setting to be uniform

From 2998c9158fe90f403fad0e230c3cbb7b92e711c1 Mon Sep 17 00:00:00 2001
From: cehongwang <cehongwang@gmail.com>
Date: Fri, 16 Aug 2024 16:39:32 -0700
Subject: [PATCH 02/11] Reduced weight name mapping construction time

---
 .../dynamo/conversion/_TRTInterpreter.py     | 42 +++++++++++++++----
 1 file changed, 34 insertions(+), 8 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 17437ceb6e..6e06a2c479 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -3,7 +3,18 @@
 import os
 import warnings
 from datetime import datetime
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)
 
 import numpy as np
 import torch
@@ -26,9 +37,10 @@
     get_node_name,
     get_trt_tensor,
 )
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
+from tqdm import tqdm
 
 import tensorrt as trt
 from packaging import version
@@ -339,18 +351,22 @@ def _save_weight_mapping(self) -> None:
         def find_weight(
             weight_name: str, np_map: dict[str, Any], sd: dict[str, Any]
         ) -> str:
-            network_weight = np_map[weight_name]
+            network_weight = torch.from_numpy(np_map[weight_name]).cuda()
             for sd_w_name, sd_weight in sd.items():
                 if check_weight_equal(sd_weight, network_weight):
+                    del sd[sd_w_name]
                     return sd_w_name
             return ""
 
         def check_weight_equal(
-            sd_weight: torch.tensor, network_weight: np.ndarray
+            sd_weight: torch.tensor, network_weight: Union[np.ndarray, torch.Tensor]
         ) -> Any:
-            sd_weight = sd_weight.reshape(-1).cpu().numpy()
-            return sd_weight.size == network_weight.size and np.allclose(
-                sd_weight, network_weight, 1e-1, 1e-1
+            sd_weight = sd_weight.reshape(-1)
+            if not isinstance(network_weight, torch.Tensor):
+                network_weight = torch.from_numpy(network_weight).cuda()
+            return (
+                sd_weight.shape == network_weight.shape
+                and torch.all(torch.abs(sd_weight - network_weight) < 0.1).cpu()
             )
 
         MODULE_MAP = {
@@ -398,8 +414,14 @@ def check_weight_equal(
             )
         }
         """
+        _LOGGER.info("building weight name mapping...")
         # Stage 1: Name mapping
         sd = self.module.state_dict()
+        gm_is_on_cuda = list(sd.values())[0].device.type == "cuda"
+        # If the model was originally on CPU, move it to the GPU
+        if not gm_is_on_cuda:
+            self.module.to(to_torch_device(self.compilation_settings.device))
+            sd = self.module.state_dict()
         weight_name_map: dict[str, Any] = {}
         np_map = {}
         net = self.ctx.net
@@ -444,7 +466,7 @@ def check_weight_equal(
             np_map[engine_weight_name] = weight
 
         # Stage 2: Value mapping
-        for engine_weight_name, sd_weight_name in weight_name_map.items():
+        for engine_weight_name, sd_weight_name in tqdm(weight_name_map.items()):
             if "SCALE" in engine_weight_name:
                 # There is no direct connection in the batch_norm layer, so skip it
                 pass
@@ -461,6 +483,10 @@ def check_weight_equal(
             ]
 
         self.weight_name_map = weight_name_map
+        # If the model was originally on CPU, move it back to CPU to free GPU memory
+        if not gm_is_on_cuda:
+            self.module.to("cpu")
+            torch.cuda.empty_cache()
 
     def run(
         self,
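# ---------------------------------------------------------------------------
# [Note] PATCH 01 and PATCH 02 both hinge on probing where a module's weights
# currently live. A minimal standalone sketch of that idea (illustrative only;
# `pick_weight_location` is not a library helper):

import tensorrt as trt
import torch

def pick_weight_location(module: torch.nn.Module) -> "trt.TensorLocation":
    # Assume all parameters share one device and probe the first entry,
    # exactly like the `device.type` checks the two patches add.
    device_type = next(iter(module.state_dict().values())).device.type
    return (
        trt.TensorLocation.DEVICE
        if device_type == "cuda"
        else trt.TensorLocation.HOST
    )

# Refitting with HOST-located weights when the model sits on CPU is, in
# essence, what the PATCH 01 fix amounts to.
# ---------------------------------------------------------------------------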
From e68a7e3c27b2783548d1573d41eb39b93d82d226 Mon Sep 17 00:00:00 2001
From: cehongwang <cehongwang@gmail.com>
Date: Sat, 17 Aug 2024 12:13:25 -0700
Subject: [PATCH 03/11] Handled the memory issue

---
 .../dynamo/conversion/_TRTInterpreter.py           |  6 +++---
 py/torch_tensorrt/dynamo/conversion/_conversion.py | 11 +++++++----
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 6e06a2c479..0baa2b9e83 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -364,9 +364,8 @@ def check_weight_equal(
             sd_weight = sd_weight.reshape(-1)
             if not isinstance(network_weight, torch.Tensor):
                 network_weight = torch.from_numpy(network_weight).cuda()
-            return (
-                sd_weight.shape == network_weight.shape
-                and torch.all(torch.abs(sd_weight - network_weight) < 0.1).cpu()
+            return sd_weight.shape == network_weight.shape and torch.all(
+                torch.abs(sd_weight - network_weight) < 0.1
             )
 
         MODULE_MAP = {
@@ -486,6 +485,7 @@ def check_weight_equal(
         # If the model was originally on CPU, move it back to CPU to free GPU memory
         if not gm_is_on_cuda:
             self.module.to("cpu")
+        del np_map, sd
         torch.cuda.empty_cache()
 
     def run(
         self,
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index e03c6cf832..6c96347239 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -130,13 +130,13 @@ def convert_module(
         from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm
         from torch_tensorrt.logging import TRT_LOGGER
 
-        runtime = trt.Runtime(TRT_LOGGER)
-        refit_test_engine = runtime.deserialize_cuda_engine(
-            interpreter_result.serialized_engine
-        )
         weight_name_map: Any = None
         # Do the test refit with cached map if make_refitable is enabled
         if settings.make_refitable:
+            runtime = trt.Runtime(TRT_LOGGER)
+            refit_test_engine = runtime.deserialize_cuda_engine(
+                interpreter_result.serialized_engine
+            )
             weight_name_map = interpreter_result.weight_name_map
             try:
                 _refit_single_trt_engine_with_gm(
@@ -149,6 +149,9 @@ def convert_module(
             except AssertionError:
                 logger.warning("Fast refit test failed. Removing the weight map caching.")
+            del refit_test_engine
+            torch.cuda.empty_cache()
+
     rt_cls = PythonTorchTensorRTModule
 
     if ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime:
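# ---------------------------------------------------------------------------
# [Note] The memory handling in PATCH 03 follows an allocate-late/free-early
# pattern: deserialize the test engine only when the fast-refit check will
# actually run, and drop it as soon as the check is done. A hedged sketch of
# the same idea as a context manager (not library code):

from contextlib import contextmanager

import tensorrt as trt
import torch

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

@contextmanager
def temporary_engine(serialized_engine: bytes):
    # Deserialize on entry...
    runtime = trt.Runtime(TRT_LOGGER)
    engine = runtime.deserialize_cuda_engine(serialized_engine)
    try:
        yield engine
    finally:
        # ...and release the device memory as soon as the caller is done,
        # mirroring the patch's `del refit_test_engine` + empty_cache().
        del engine
        torch.cuda.empty_cache()
# ---------------------------------------------------------------------------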
From e8dc166627bc72f1c92ab423b98cda43d64c9ee4 Mon Sep 17 00:00:00 2001
From: cehongwang <cehongwang@gmail.com>
Date: Sat, 17 Aug 2024 15:48:03 -0700
Subject: [PATCH 04/11] Seriously optimized weight name map construction

---
 .../dynamo/conversion/_TRTInterpreter.py     | 64 +++++++++++--------
 1 file changed, 36 insertions(+), 28 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 0baa2b9e83..780a866a38 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -339,6 +339,31 @@ def _construct_trt_network_def(self) -> None:
             f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}"
         )
 
+    @staticmethod
+    def find_weight(
+        weight_name: str, np_map: dict[str, Any], sd: dict[str, Any]
+    ) -> str:
+        network_weight = torch.from_numpy(np_map[weight_name]).cuda()
+        for sd_w_name, sd_weight in sd.items():
+            if TRTInterpreter.check_weight_equal(sd_weight, network_weight):
+                del sd[sd_w_name]
+                return sd_w_name
+        return ""
+
+    @staticmethod
+    def check_weight_equal(
+        sd_weight: torch.tensor, network_weight: Union[torch.Tensor, np.ndarray]
+    ) -> Any:
+        if not isinstance(network_weight, torch.Tensor):
+            network_weight = torch.from_numpy(network_weight).cuda()
+        try:
+            return sd_weight.shape == network_weight.shape and torch.all(
+                torch.abs(sd_weight - network_weight) < 0.1
+            )
+        except Exception:
+            return torch.all(sd_weight == network_weight)
+
     def _save_weight_mapping(self) -> None:
         """
         Construct the weight name mapping from engine weight name to state_dict weight name.
@@ -348,26 +373,6 @@ def _save_weight_mapping(self) -> None:
         2. Value mapping that, for each weight in the INetworkDefinition, searches for the identical weight in the state_dict
         """
-        def find_weight(
-            weight_name: str, np_map: dict[str, Any], sd: dict[str, Any]
-        ) -> str:
-            network_weight = torch.from_numpy(np_map[weight_name]).cuda()
-            for sd_w_name, sd_weight in sd.items():
-                if check_weight_equal(sd_weight, network_weight):
-                    del sd[sd_w_name]
-                    return sd_w_name
-            return ""
-
-        def check_weight_equal(
-            sd_weight: torch.tensor, network_weight: Union[np.ndarray, torch.Tensor]
-        ) -> Any:
-            sd_weight = sd_weight.reshape(-1)
-            if not isinstance(network_weight, torch.Tensor):
-                network_weight = torch.from_numpy(network_weight).cuda()
-            return sd_weight.shape == network_weight.shape and torch.all(
-                torch.abs(sd_weight - network_weight) < 0.1
-            )
-
         MODULE_MAP = {
             "SCALE": (
                 trt.IScaleLayer,
@@ -416,11 +421,16 @@ def check_weight_equal(
         _LOGGER.info("building weight name mapping...")
         # Stage 1: Name mapping
         sd = self.module.state_dict()
+        torch_device = to_torch_device(self.compilation_settings.device)
         gm_is_on_cuda = list(sd.values())[0].device.type == "cuda"
-        # If the model was originally on CPU, move it to the GPU
         if not gm_is_on_cuda:
-            self.module.to(to_torch_device(self.compilation_settings.device))
-            sd = self.module.state_dict()
+            # If the model was originally on CPU, move it to the GPU
+            sd = {
+                k: v.reshape(-1).to(torch_device)
+                for k, v in self.module.state_dict().items()
+            }
+        else:
+            sd = {k: v.reshape(-1) for k, v in self.module.state_dict().items()}
         weight_name_map: dict[str, Any] = {}
         np_map = {}
         net = self.ctx.net
@@ -469,10 +479,10 @@ def check_weight_equal(
             if "SCALE" in engine_weight_name:
                 # There is no direct connection in the batch_norm layer, so skip it
                 pass
-            elif sd_weight_name not in sd or not check_weight_equal(
+            elif sd_weight_name not in sd or not TRTInterpreter.check_weight_equal(
                 sd[sd_weight_name], np_map[engine_weight_name]
             ):
-                weight_name_map[engine_weight_name] = find_weight(
+                weight_name_map[engine_weight_name] = TRTInterpreter.find_weight(
                     engine_weight_name, np_map, sd
                 )
 
@@ -482,9 +492,7 @@ def check_weight_equal(
             ]
 
         self.weight_name_map = weight_name_map
-        # If the model was originally on CPU, move it back to CPU to free GPU memory
-        if not gm_is_on_cuda:
-            self.module.to("cpu")
+        del np_map, sd
         torch.cuda.empty_cache()
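# ---------------------------------------------------------------------------
# [Note] The speedup in PATCH 04 comes from doing the reshape (and, for CPU
# models, the host-to-device copy) once per tensor up front instead of inside
# every comparison. A hedged sketch of that preprocessing step (illustrative
# only; `preflatten_state_dict` is not a library helper):

import torch

def preflatten_state_dict(module: torch.nn.Module, device: torch.device):
    # One flatten + one transfer per tensor, done in a single pass, so the
    # fallback search in find_weight compares ready-made GPU tensors instead
    # of reshaping and copying on every probe.
    return {k: v.reshape(-1).to(device) for k, v in module.state_dict().items()}
# ---------------------------------------------------------------------------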
From 9b801e7bb6a61f3947ff18a6de561f9275e2219b Mon Sep 17 00:00:00 2001
From: cehongwang <cehongwang@gmail.com>
Date: Sun, 18 Aug 2024 14:18:30 -0700
Subject: [PATCH 05/11] Fixed accuracy issue of fast refit

---
 py/torch_tensorrt/dynamo/_refit.py         | 17 ++++++++++++++---
 tests/py/dynamo/models/test_model_refit.py |  4 ++--
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index f03009c3d8..039426248c 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -117,10 +117,16 @@ def construct_refit_mapping_from_weight_name_map(
         torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
         if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
             # Batch Norm Layer
-            params = {}
+            params = {
+                "weight": 1.0,
+                "bias": 0.0,
+                "running_mean": 0.0,
+                "running_var": 1.0,
+            }
             for w in sd_weight_name:
-                params[w.split(".")[-1]] = state_dict[w]
-            scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-7)
+                if w in state_dict:
+                    params[w.split(".")[-1]] = state_dict[w]
+            scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-5)
             shift = params["bias"] - params["running_mean"] * scale
             # Set scale to scale or shift to shift
             engine_weight_map[engine_weight_name] = eval(
@@ -171,6 +177,11 @@ def _refit_single_trt_engine_with_gm(
         mapping = construct_refit_mapping_from_weight_name_map(
             weight_name_map, new_gm.state_dict()
         )
+
+        # Debug Use
+        # correct = construct_refit_mapping(new_gm, input_list, settings)
+        # {k: np.allclose(correct[k][0], mapping[k][0].cpu().numpy(), 1e-2, 1e-2) for k in mapping if k in correct}
+
         for layer_name in weight_list:
             if layer_name not in mapping:
                 logger.warning(f"{layer_name} is not found in weight mapping.")
diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py
index c642ae0675..a3518e818a 100644
--- a/tests/py/dynamo/models/test_model_refit.py
+++ b/tests/py/dynamo/models/test_model_refit.py
@@ -91,8 +91,8 @@ def test_mapping():
 
 @pytest.mark.unit
 def test_refit_one_engine_with_weightmap():
-    model = models.resnet18(pretrained=False).eval().to("cuda")
-    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    model = models.resnet152(pretrained=False).eval().to("cuda")
+    model2 = models.resnet152(pretrained=True).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
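# ---------------------------------------------------------------------------
# [Note] The arithmetic PATCH 05 fixes is standard batch-norm folding: TensorRT
# stores a batch norm as a scale layer computing y = x * scale + shift, so the
# refit path has to reproduce those folded values from the state_dict. A small
# self-checking sketch (illustrative; `fold_batch_norm` is not library code,
# and eps=1e-5 matches both nn.BatchNorm defaults and the patched constant):

import torch

def fold_batch_norm(weight, bias, running_mean, running_var, eps=1e-5):
    scale = weight / torch.sqrt(running_var + eps)
    shift = bias - running_mean * scale
    return scale, shift

# Sanity check against PyTorch's own eval-mode batch norm:
bn = torch.nn.BatchNorm1d(4).eval()
x = torch.randn(2, 4)
scale, shift = fold_batch_norm(
    bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps
)
assert torch.allclose(bn(x), x * scale + shift, atol=1e-6)
# ---------------------------------------------------------------------------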
From 6bfd978feaee73f37bf4df05f90c6aac988b4432 Mon Sep 17 00:00:00 2001
From: cehongwang <cehongwang@gmail.com>
Date: Mon, 19 Aug 2024 10:35:45 -0700
Subject: [PATCH 06/11] Made revisions according to review comments

---
 .../dynamo/conversion/_TRTInterpreter.py     | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 780a866a38..f99beca1a2 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -40,7 +40,6 @@
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
-from tqdm import tqdm
 
 import tensorrt as trt
 from packaging import version
@@ -341,13 +340,21 @@ def _construct_trt_network_def(self) -> None:
 
     @staticmethod
     def find_weight(
-        weight_name: str, np_map: dict[str, Any], sd: dict[str, Any]
+        weight_name: str, np_map: dict[str, Any], state_dict: dict[str, Any]
     ) -> str:
+        """
+        We need to build a map from engine weight name to state_dict weight name.
+        The purpose of this function is to find the corresponding weight name in the module state_dict.
+
+        weight_name: the target weight name we want to search for
+        np_map: the map from weight name to np values in INetworkDefinition
+        state_dict: state of the graph module
+        """
         network_weight = torch.from_numpy(np_map[weight_name]).cuda()
-        for sd_w_name, sd_weight in sd.items():
+        for sd_w_name, sd_weight in state_dict.items():
             if TRTInterpreter.check_weight_equal(sd_weight, network_weight):
-                del sd[sd_w_name]
+                del state_dict[sd_w_name]
                 return sd_w_name
         return ""
 
     @staticmethod
@@ -475,7 +482,7 @@ def _save_weight_mapping(self) -> None:
             np_map[engine_weight_name] = weight
 
         # Stage 2: Value mapping
-        for engine_weight_name, sd_weight_name in tqdm(weight_name_map.items()):
+        for engine_weight_name, sd_weight_name in weight_name_map.items():
             if "SCALE" in engine_weight_name:
                 # There is no direct connection in the batch_norm layer, so skip it
                 pass
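# ---------------------------------------------------------------------------
# [Note] find_weight is a fallback linear scan: when an engine weight's guessed
# state_dict name does not check out, the (pre-flattened) state_dict is
# searched by value, and each match is consumed so no entry is claimed twice.
# A toy version of that search-and-consume idea (assumes flattened tensors on
# one device; not the library implementation):

import torch

def find_by_value(target: torch.Tensor, candidates: dict) -> str:
    for name, cand in list(candidates.items()):
        if cand.shape == target.shape and bool(
            torch.all(torch.abs(cand - target) < 0.01)
        ):
            del candidates[name]  # consume the match, mirroring `del state_dict[...]`
            return name
    return ""  # empty string signals "no state_dict counterpart found"
# ---------------------------------------------------------------------------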
From df95f002d80317c82d587c98c54f1570a953363f Mon Sep 17 00:00:00 2001
From: cehongwang <cehongwang@gmail.com>
Date: Mon, 19 Aug 2024 10:39:01 -0700
Subject: [PATCH 07/11] Fixed style issue

---
 py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index f99beca1a2..9fef61961b 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -1,3 +1,4 @@
+import gc
 import io
 import logging
 import os
@@ -366,7 +367,7 @@ def check_weight_equal(
             network_weight = torch.from_numpy(network_weight).cuda()
         try:
             return sd_weight.shape == network_weight.shape and torch.all(
-                torch.abs(sd_weight - network_weight) < 0.1
+                torch.abs(sd_weight - network_weight) < 0.01
             )
         except Exception:
             return torch.all(sd_weight == network_weight)
@@ -425,7 +426,7 @@ def _save_weight_mapping(self) -> None:
             )
         }
         """
-        _LOGGER.info("building weight name mapping...")
+        _LOGGER.info("Building weight name mapping...")
         # Stage 1: Name mapping
         sd = self.module.state_dict()
         torch_device = to_torch_device(self.compilation_settings.device)
@@ -501,6 +502,7 @@ def _save_weight_mapping(self) -> None:
 
         self.weight_name_map = weight_name_map
         del np_map, sd
+        gc.collect()
         torch.cuda.empty_cache()
 
     def run(

From da4542c7d68a456de503c4e0ae883637a8af7b40 Mon Sep 17 00:00:00 2001
From: cehongwang <cehongwang@gmail.com>
Date: Thu, 22 Aug 2024 17:26:56 -0700
Subject: [PATCH 08/11] Fixed batchnorm converter change

---
 .../dynamo/mutable_torchtrt_module_example.py |  4 +--
 examples/dynamo/refit_engine_example.py       |  4 +--
 py/torch_tensorrt/dynamo/_refit.py            | 21 ++---------
 .../dynamo/conversion/_conversion.py          |  2 +-
 tests/py/dynamo/models/test_model_refit.py    | 36 +++++++++----------
 5 files changed, 25 insertions(+), 42 deletions(-)

diff --git a/examples/dynamo/mutable_torchtrt_module_example.py b/examples/dynamo/mutable_torchtrt_module_example.py
index 84122e074b..a10c0e17ae 100644
--- a/examples/dynamo/mutable_torchtrt_module_example.py
+++ b/examples/dynamo/mutable_torchtrt_module_example.py
@@ -34,7 +34,7 @@
     "make_refitable": True,
 }
 
-model = models.resnet18(pretrained=False).eval().to("cuda")
+model = models.resnet18(pretrained=True).eval().to("cuda")
 mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
 # You can use the mutable module just like the original pytorch module. The compilation happens when you first call the mutable module.
 mutable_module(*inputs)
 
@@ -45,7 +45,7 @@
 # Making changes to mutable module can trigger refit or re-compilation. For example, loading a different state_dict and setting new weight values will trigger refit, and adding a module to the model will trigger re-compilation.
 
-model2 = models.resnet18(pretrained=True).eval().to("cuda")
+model2 = models.resnet18(pretrained=False).eval().to("cuda")
 mutable_module.load_state_dict(model2.state_dict())
 
diff --git a/examples/dynamo/refit_engine_example.py b/examples/dynamo/refit_engine_example.py
index c47ed19ebb..c8cd5590d3 100644
--- a/examples/dynamo/refit_engine_example.py
+++ b/examples/dynamo/refit_engine_example.py
@@ -39,7 +39,7 @@
 # Compile the module for the first time and save it.
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -model = models.resnet18(pretrained=False).eval().to("cuda") +model = models.resnet18(pretrained=True).eval().to("cuda") exp_program = torch.export.export(model, tuple(inputs)) enabled_precisions = {torch.float} debug = False @@ -68,7 +68,7 @@ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # Create and compile the updated model -model2 = models.resnet18(pretrained=True).eval().to("cuda") +model2 = models.resnet18(pretrained=False).eval().to("cuda") exp_program2 = torch.export.export(model2, tuple(inputs)) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 039426248c..2f7f22becc 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -115,25 +115,8 @@ def construct_refit_mapping_from_weight_name_map( for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items(): trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType) torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype) - if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]: - # Batch Norm Layer - params = { - "weight": 1.0, - "bias": 0.0, - "running_mean": 0.0, - "running_var": 1.0, - } - for w in sd_weight_name: - if w in state_dict: - params[w.split(".")[-1]] = state_dict[w] - scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-5) - shift = params["bias"] - params["running_mean"] * scale - # Set scale to scale or shift to shift - engine_weight_map[engine_weight_name] = eval( - engine_weight_name.split(" ")[-1].lower() - ) - elif sd_weight_name not in state_dict: + if sd_weight_name not in state_dict: # If weights is not in sd, we can leave it unchanged continue else: @@ -180,7 +163,7 @@ def _refit_single_trt_engine_with_gm( # Debug Use # correct = construct_refit_mapping(new_gm, input_list, settings) - # {k: np.allclose(correct[k][0], mapping[k][0].cpu().numpy(), 1e-2, 1e-2) for k in mapping if k in correct} + # comparison = {k: (np.allclose(correct[k][0], mapping[k][0].cpu().numpy(), 1e-2, 1e-2), correct[k][0], mapping[k][0]) for k in mapping if k in correct} for layer_name in weight_list: if layer_name not in mapping: diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 6c96347239..4cedcb80cb 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -137,7 +137,6 @@ def convert_module( refit_test_engine = runtime.deserialize_cuda_engine( interpreter_result.serialized_engine ) - weight_name_map = interpreter_result.weight_name_map try: _refit_single_trt_engine_with_gm( new_gm=module, @@ -146,6 +145,7 @@ def convert_module( settings=settings, weight_name_map=interpreter_result.weight_name_map, ) + weight_name_map = interpreter_result.weight_name_map except AssertionError: logger.warning("Fast refit test failed. 
Removing the weight map caching.") diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index a3518e818a..a63476adf5 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -35,8 +35,8 @@ @pytest.mark.unit def test_mapping(): - model = models.resnet18(pretrained=False).eval().to("cuda") - model2 = models.resnet18(pretrained=True).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") + model2 = models.resnet18(pretrained=False).eval().to("cuda") inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] trt_input = [ torchtrt.Input(i.shape, dtype=torch.float, format=torch.contiguous_format) @@ -91,8 +91,8 @@ def test_mapping(): @pytest.mark.unit def test_refit_one_engine_with_weightmap(): - model = models.resnet152(pretrained=False).eval().to("cuda") - model2 = models.resnet152(pretrained=True).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") + model2 = models.resnet18(pretrained=False).eval().to("cuda") inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] enabled_precisions = {torch.float} debug = False @@ -140,8 +140,8 @@ def test_refit_one_engine_with_weightmap(): @pytest.mark.unit def test_refit_one_engine_no_map_with_weightmap(): - model = models.resnet18(pretrained=False).eval().to("cuda") - model2 = models.resnet18(pretrained=True).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") + model2 = models.resnet18(pretrained=False).eval().to("cuda") inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] enabled_precisions = {torch.float} debug = False @@ -191,8 +191,8 @@ def test_refit_one_engine_no_map_with_weightmap(): @pytest.mark.unit def test_refit_one_engine_with_wrong_weightmap(): - model = models.resnet18(pretrained=False).eval().to("cuda") - model2 = models.resnet18(pretrained=True).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") + model2 = models.resnet18(pretrained=False).eval().to("cuda") inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] enabled_precisions = {torch.float} debug = False @@ -301,8 +301,8 @@ def test_refit_one_engine_bert_with_weightmap(): @pytest.mark.unit def test_refit_one_engine_inline_runtime__with_weightmap(): trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") - model = models.resnet18(pretrained=False).eval().to("cuda") - model2 = models.resnet18(pretrained=True).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") + model2 = models.resnet18(pretrained=False).eval().to("cuda") inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] enabled_precisions = {torch.float} debug = False @@ -347,8 +347,8 @@ def test_refit_one_engine_inline_runtime__with_weightmap(): @pytest.mark.unit def test_refit_one_engine_python_runtime_with_weightmap(): - model = models.resnet18(pretrained=False).eval().to("cuda") - model2 = models.resnet18(pretrained=True).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") + model2 = models.resnet18(pretrained=False).eval().to("cuda") inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] enabled_precisions = {torch.float} debug = False @@ -467,8 +467,8 @@ def forward(self, x): @pytest.mark.unit def test_refit_one_engine_without_weightmap(): - model = models.resnet18(pretrained=False).eval().to("cuda") - model2 = models.resnet18(pretrained=True).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") + model2 = 
models.resnet18(pretrained=False).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -571,8 +571,8 @@ def test_refit_one_engine_bert_without_weightmap():
 @pytest.mark.unit
 def test_refit_one_engine_inline_runtime_without_weightmap():
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    model = models.resnet18(pretrained=False).eval().to("cuda")
-    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    model = models.resnet18(pretrained=True).eval().to("cuda")
+    model2 = models.resnet18(pretrained=False).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -617,8 +617,8 @@ def test_refit_one_engine_python_runtime_without_weightmap():
-    model = models.resnet18(pretrained=False).eval().to("cuda")
-    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    model = models.resnet18(pretrained=True).eval().to("cuda")
+    model2 = models.resnet18(pretrained=False).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False

From da4542c7d68a456de503c4e0ae883637a8af7b40 Mon Sep 17 00:00:00 2001
From: cehongwang <cehongwang@gmail.com>
Date: Fri, 23 Aug 2024 13:35:03 -0700
Subject: [PATCH 09/11] Moved output check to the utils module. Redo refit if fast refit yields an incorrect result when verify_output is set to True

---
 py/torch_tensorrt/dynamo/_refit.py          | 19 ++++++++++++++---
 .../runtime/_MutableTorchTensorRTModule.py  | 55 +++-----------
 py/torch_tensorrt/dynamo/utils.py           | 48 ++++++++++--
 tests/py/dynamo/models/test_model_refit.py  |  5 +-
 .../runtime/test_mutable_torchtrt_module.py | 75 ++++++++-----------
 5 files changed, 105 insertions(+), 97 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index 2f7f22becc..4ce7d0b150 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -34,7 +34,7 @@
     TorchTensorRTModule,
 )
 from torch_tensorrt.dynamo.utils import (
-    check_output,
+    check_module_output,
     get_torch_inputs,
     set_log_level,
     to_torch_device,
)
@@ -281,6 +281,7 @@ def refit_module_weights(
         arg_inputs = [arg_inputs]
     torch_inputs = get_torch_inputs(arg_inputs, device)
 
+    torch_kwarg_inputs: Any = {}
     if kwarg_inputs:
         torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
     runtime = trt.Runtime(TRT_LOGGER)
@@ -434,6 +435,7 @@ def refit_module_weights(
                     settings=settings,
                     weight_name_map=weight_name_map,
                 )
+
             except AssertionError as e:
                 # If fast_refit is used and failed, we fall back to regular refit
                 logger.warning(e)
@@ -461,7 +463,7 @@ def refit_module_weights(
             setattr(compiled_module, f"{name}_engine", refitted_engine)
 
     if verify_output and arg_inputs is not None:
-        if check_output(
+        if check_module_output(
             new_module=new_gm,
             refitted_module=compiled_module,
             arg_inputs=torch_inputs,
@@ -469,6 +471,19 @@ def refit_module_weights(
         ):
             logger.info("Refitting Succeed!")
         else:
+            if weight_name_map:
+                logger.warning(
+                    "Refitting with weight_name_map yielded incorrect result! The outputs do not match."
+                )
+                return refit_module_weights(
+                    compiled_module,
+                    new_weight_module,
+                    arg_inputs,
+                    kwarg_inputs,
+                    verify_output,
+                    use_weight_map_cache=False,
+                    in_place=in_place,
+                )
             logger.error("Refitting Failed! The outputs do not match.")
     else:
         logger.info("Refitting Completed!
Output verification skipped.") diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index a437057d04..672a7e267d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -12,7 +12,11 @@ from torch_tensorrt.dynamo._compiler import compile as dynamo_compile from torch_tensorrt.dynamo._refit import refit_module_weights from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.utils import to_torch_device, to_torch_tensorrt_device +from torch_tensorrt.dynamo.utils import ( + check_output_equal, + to_torch_device, + to_torch_tensorrt_device, +) logger = logging.getLogger(__name__) @@ -228,7 +232,7 @@ def update_refit_condition(self) -> None: new_result = self.original_model(*args, **kwargs) self.original_model.cpu() torch.cuda.empty_cache() - if MutableTorchTensorRTModule.check_output_equal(result, new_result): + if check_output_equal(result, new_result): self.refit_state.set_state(RefitFlag.LIVE) return @@ -268,7 +272,12 @@ def refit_gm(self) -> None: ) ) self.gm = refit_module_weights( - self.gm, self.exp_program, use_weight_map_cache=True, in_place=True + self.gm, + self.exp_program, + self.arg_inputs, + self.kwarg_inputs, + use_weight_map_cache=True, + in_place=True, ) self.original_model.cpu() @@ -426,46 +435,6 @@ def __setattr__(self, name: str, value: Any) -> None: else: object.__setattr__(self, name, value) - @staticmethod - def check_output_equal( - output1: Any, - output2: Any, - ) -> bool: - # TODO: Move this to utils when all PRs are merged. This can be used by other functions. - if type(output1) != type(output2): - logger.warning( - "This module does not support using output verification to skip refit. Refit will be performed \ - whenever the state is UNKNOWN" - ) - return False - - if isinstance(output1, torch.Tensor): - if output1.shape != output2.shape: - return False - return torch.allclose(output1, output2, 1e-2, 1e-2) # type: ignore - - elif isinstance(output1, (tuple, list)): - if len(output1) != len(output2): - return False - for a, b in zip(output1, output2): - if not MutableTorchTensorRTModule.check_output_equal(a, b): - return False - return True - - elif isinstance(output1, dict): - if output1.keys() != output2.keys(): - return False - for a, b in zip(output1.values(), output2.values()): - if not MutableTorchTensorRTModule.check_output_equal(a, b): - return False - return True - - logger.warning( - "This module does not support using output verification to skip refit. 
Refit will be performed \ - whenever the state is UNKNOWN" - ) - return False - @staticmethod def check_inputs_equal( input1: Any, diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 1d7785717b..31b8c40852 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -394,7 +394,7 @@ def function_wrapper(*args: Any, **kwargs: Any) -> Any: return nested_decorator -def check_output( +def check_module_output( new_module: torch.fx.GraphModule, refitted_module: torch.fx.GraphModule, arg_inputs: Any, @@ -403,14 +403,48 @@ def check_output( old_outputs, new_outputs = refitted_module(*arg_inputs), new_module( *arg_inputs, **kwarg_inputs ) - for old_output, new_output in zip(old_outputs, new_outputs): - if isinstance(old_output, torch.Tensor) and isinstance( - new_outputs, torch.Tensor - ): - if not torch.allclose(old_output, new_output, 1e-2, 1e-2): + if type(old_outputs) != type(new_outputs): + logger.warning("The output types are different. Output check is skipped.") + return True + return check_output_equal(old_outputs, new_outputs) + + +def check_output_equal( + output1: Any, + output2: Any, +) -> bool: + + if type(output1) != type(output2): + logger.warning( + "The output types are different. Check_output_equal will always return false." + ) + return False + + if isinstance(output1, torch.Tensor): + if output1.shape != output2.shape: + return False + return torch.allclose(output1, output2, 1e-2, 1e-2) # type: ignore + + elif isinstance(output1, (tuple, list)): + if len(output1) != len(output2): + return False + for a, b in zip(output1, output2): + if not check_output_equal(a, b): + return False + return True + + elif isinstance(output1, dict): + if output1.keys() != output2.keys(): + return False + for a, b in zip(output1.values(), output2.values()): + if not check_output_equal(a, b): return False + return True - return True + logger.warning( + "The output type is not supported to be checked. Check_output_equal will always return false." 
+ ) + return False def get_flat_args_with_check( diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index a63476adf5..9782cd829c 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -91,8 +91,8 @@ def test_mapping(): @pytest.mark.unit def test_refit_one_engine_with_weightmap(): - model = models.resnet18(pretrained=True).eval().to("cuda") - model2 = models.resnet18(pretrained=False).eval().to("cuda") + model = models.resnet18(pretrained=False).eval().to("cuda") + model2 = models.resnet18(pretrained=True).eval().to("cuda") inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] enabled_precisions = {torch.float} debug = False @@ -117,6 +117,7 @@ def test_refit_one_engine_with_weightmap(): new_weight_module=exp_program2, arg_inputs=inputs, use_weight_map_cache=True, + verify_output=True, ) # Check the output diff --git a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py index 593a4322b7..86e7678a66 100644 --- a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py +++ b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py @@ -11,10 +11,31 @@ import torchvision.models as models from torch import nn from torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule import RefitFlag +from torch_tensorrt.dynamo.utils import check_output_equal assertions = unittest.TestCase() +@pytest.mark.unit +def test_check_output_equal(): + torch.manual_seed(0) + a = { + "a": torch.rand(10, 30), + "b": [torch.rand(10, 30), torch.rand(5, 5)], + "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]}, + } + torch.manual_seed(0) + b = { + "a": torch.rand(10, 30), + "b": [torch.rand(10, 30), torch.rand(5, 5)], + "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]}, + } + assertions.assertTrue( + check_output_equal(a, b), + msg=f"test_check_output_equal is not correct.", + ) + + @unittest.skipIf( not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", @@ -31,8 +52,8 @@ def test_resnet18(): "make_refitable": True, } - model = models.resnet18(pretrained=False).eval().to("cuda") - model2 = models.resnet18(pretrained=True).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") + model2 = models.resnet18(pretrained=False).eval().to("cuda") mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) mutable_module(*inputs) @@ -44,9 +65,7 @@ def test_resnet18(): # Check the output expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs) assertions.assertTrue( - torch_trt.MutableTorchTensorRTModule.check_output_equal( - expected_outputs, refitted_outputs - ), + check_output_equal(expected_outputs, refitted_outputs), msg=f"The output of saved and reloaded Mutable Module is not correct.", ) @@ -73,7 +92,7 @@ def test_save(): "make_refitable": True, } - model = models.resnet18(pretrained=False).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) mutable_module(*inputs) @@ -83,9 +102,7 @@ def test_save(): loaded_outputs, trt_gm_outputs = reload(*inputs), mutable_module(*inputs) assertions.assertTrue( - torch_trt.MutableTorchTensorRTModule.check_output_equal( - loaded_outputs, trt_gm_outputs - ), + check_output_equal(loaded_outputs, trt_gm_outputs), msg=f"The output of saved and reloaded Mutable Module is not correct.", 
) @@ -109,7 +126,7 @@ def test_resnet18_modify_attribute(): "make_refitable": True, } - model = models.resnet18(pretrained=False).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) mutable_module(*inputs) @@ -150,7 +167,7 @@ def test_resnet18_modify_attribute_no_refit(): "make_refitable": True, } - model = models.resnet18(pretrained=False).eval().to("cuda") + model = models.resnet18(pretrained=True).eval().to("cuda") mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) mutable_module(*inputs) @@ -241,9 +258,7 @@ def forward(self, x, b=5, c=None, d=None): *args, **kwargs ) assertions.assertTrue( - torch_trt.MutableTorchTensorRTModule.check_output_equal( - expected_outputs, refitted_outputs - ), + check_output_equal(expected_outputs, refitted_outputs), msg=f"The output of saved and reloaded Mutable Module is not correct.", ) @@ -306,9 +321,7 @@ def set_weights(self): model.cuda() expected_outputs, refitted_outputs = model(*args), mutable_module(*args) assertions.assertTrue( - torch_trt.MutableTorchTensorRTModule.check_output_equal( - expected_outputs, refitted_outputs - ), + check_output_equal(expected_outputs, refitted_outputs), msg=f"The output of saved and reloaded Mutable Module is not correct.", ) @@ -371,9 +384,7 @@ def set_layer(self): model.cuda() # move offloaded model from cpu to cuda expected_outputs, refitted_outputs = model(*args), mutable_module(*args) assertions.assertTrue( - torch_trt.MutableTorchTensorRTModule.check_output_equal( - expected_outputs, refitted_outputs - ), + check_output_equal(expected_outputs, refitted_outputs), msg=f"The output of saved and reloaded Mutable Module is not correct.", ) @@ -443,31 +454,9 @@ def forward(self, x, b=5, c=None, d=None): *args, **kwargs ) assertions.assertTrue( - torch_trt.MutableTorchTensorRTModule.check_output_equal( - expected_outputs, refitted_outputs - ), + check_output_equal(expected_outputs, refitted_outputs), msg=f"The output of saved and reloaded Mutable Module is not correct.", ) # Clean up model env torch._dynamo.reset() - - -@pytest.mark.unit -def test_check_output_equal(): - torch.manual_seed(0) - a = { - "a": torch.rand(10, 30), - "b": [torch.rand(10, 30), torch.rand(5, 5)], - "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]}, - } - torch.manual_seed(0) - b = { - "a": torch.rand(10, 30), - "b": [torch.rand(10, 30), torch.rand(5, 5)], - "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]}, - } - assertions.assertTrue( - torch_trt.MutableTorchTensorRTModule.check_output_equal(a, b), - msg=f"test_check_output_equal is not correct.", - ) From d585a13661556eea5a92f9e40748300325b16023 Mon Sep 17 00:00:00 2001 From: Adrian Wang <123616592+cehongwang@users.noreply.github.com> Date: Wed, 28 Aug 2024 00:02:34 -0400 Subject: [PATCH 10/11] Changed the precision tolerance from 1e-2 to 5e-3 --- py/torch_tensorrt/dynamo/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 31b8c40852..6a99262421 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -423,7 +423,7 @@ def check_output_equal( if isinstance(output1, torch.Tensor): if output1.shape != output2.shape: return False - return torch.allclose(output1, output2, 1e-2, 1e-2) # type: ignore + return torch.allclose(output1, output2, 5e-3, 5e-3) # type: ignore elif 
isinstance(output1, (tuple, list)): if len(output1) != len(output2): From 762d965110b5deb154cce6981e326c19023a55a7 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 28 Aug 2024 00:16:46 -0700 Subject: [PATCH 11/11] chore: move threshold values to variables and rebase with main --- py/torch_tensorrt/dynamo/utils.py | 6 +++++- tests/py/dynamo/conversion/harness.py | 14 +++++++------- .../py/dynamo/conversion/test_bitwise_and_aten.py | 5 +++-- .../dynamo/conversion/test_embedding_bag_aten.py | 5 +++-- .../py/dynamo/conversion/test_index_select_aten.py | 5 +++-- 5 files changed, 21 insertions(+), 14 deletions(-) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 6a99262421..6d74ab61bf 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -22,6 +22,8 @@ COSINE_THRESHOLD = 0.99 DYNAMIC_DIM = -1 +RTOL = 5e-3 +ATOL = 5e-3 class Frameworks(Enum): @@ -412,6 +414,8 @@ def check_module_output( def check_output_equal( output1: Any, output2: Any, + rtol: float = RTOL, + atol: float = ATOL, ) -> bool: if type(output1) != type(output2): @@ -423,7 +427,7 @@ def check_output_equal( if isinstance(output1, torch.Tensor): if output1.shape != output2.shape: return False - return torch.allclose(output1, output2, 5e-3, 5e-3) # type: ignore + return torch.allclose(output1, output2, rtol, atol) # type: ignore elif isinstance(output1, (tuple, list)): if len(output1) != len(output2): diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 6cdee663e6..df1e4ee934 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -23,7 +23,7 @@ pre_export_lowering, ) from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule -from torch_tensorrt.dynamo.utils import get_torch_inputs +from torch_tensorrt.dynamo.utils import ATOL, RTOL, get_torch_inputs _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -60,8 +60,8 @@ def run_test( mod, inputs, interpreter, - rtol, - atol, + rtol=RTOL, + atol=ATOL, check_dtype=True, pyt_inputs=None, rt_cls=PythonTorchTensorRTModule, @@ -254,8 +254,8 @@ def run_test( self, mod, inputs, - rtol=5e-3, - atol=5e-3, + rtol=RTOL, + atol=ATOL, precision=dtype.f32, check_dtype=True, use_dynamo_tracer=False, @@ -374,8 +374,8 @@ def run_test_with_dynamic_shape( self, mod, input_specs, - rtol=5e-3, - atol=5e-3, + rtol=RTOL, + atol=ATOL, output_dtypes=None, use_dynamo_tracer=False, enable_passes=False, diff --git a/tests/py/dynamo/conversion/test_bitwise_and_aten.py b/tests/py/dynamo/conversion/test_bitwise_and_aten.py index 9cb63f4fdc..a29a8061db 100644 --- a/tests/py/dynamo/conversion/test_bitwise_and_aten.py +++ b/tests/py/dynamo/conversion/test_bitwise_and_aten.py @@ -5,6 +5,7 @@ from torch.export import Dim from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input +from torch_tensorrt.dynamo.utils import ATOL, RTOL from .harness import DispatchTestCase @@ -152,8 +153,8 @@ def forward(self, lhs_val, rhs_val): torch.testing.assert_close( out, ref, - rtol=5e-3, - atol=5e-3, + rtol=RTOL, + atol=ATOL, equal_nan=True, check_dtype=True, ) diff --git a/tests/py/dynamo/conversion/test_embedding_bag_aten.py b/tests/py/dynamo/conversion/test_embedding_bag_aten.py index 03bae9b68b..d935134ff2 100644 --- a/tests/py/dynamo/conversion/test_embedding_bag_aten.py +++ b/tests/py/dynamo/conversion/test_embedding_bag_aten.py @@ -3,6 +3,7 @@ from parameterized import param, parameterized from 
torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input +from torch_tensorrt.dynamo.utils import ATOL, RTOL from .harness import DispatchTestCase @@ -501,8 +502,8 @@ def forward(self, weights, indices, offsets, per_sample_weights=None): torch.testing.assert_close( out, ref, - rtol=5e-3, - atol=5e-3, + rtol=RTOL, + atol=ATOL, equal_nan=True, check_dtype=True, ) diff --git a/tests/py/dynamo/conversion/test_index_select_aten.py b/tests/py/dynamo/conversion/test_index_select_aten.py index 839474a0dd..3d0b41b791 100644 --- a/tests/py/dynamo/conversion/test_index_select_aten.py +++ b/tests/py/dynamo/conversion/test_index_select_aten.py @@ -4,6 +4,7 @@ from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input +from torch_tensorrt.dynamo.utils import ATOL, RTOL from .harness import DispatchTestCase @@ -122,8 +123,8 @@ def forward(self, source_tensor, indice_tensor): torch.testing.assert_close( out, ref, - rtol=5e-3, - atol=5e-3, + rtol=RTOL, + atol=ATOL, equal_nan=True, check_dtype=True, )
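# ---------------------------------------------------------------------------
# [Note] Taken together, PATCHES 09-11 leave a single shared verification
# helper, torch_tensorrt.dynamo.utils.check_output_equal, which recursively
# compares tensors, tuples/lists, and dicts using the centralized
# RTOL = ATOL = 5e-3 tolerances. A hedged usage sketch (the output structures
# below are made up for illustration):

import torch
from torch_tensorrt.dynamo.utils import check_output_equal

ref = {"logits": torch.rand(2, 10), "aux": [torch.rand(4), torch.rand(4)]}
# A copy perturbed well below the 5e-3 tolerance still compares equal...
near = {
    "logits": ref["logits"] + 1e-4,
    "aux": [t + 1e-4 for t in ref["aux"]],
}
assert check_output_equal(ref, near)
# ...while a structural mismatch (different keys) compares unequal.
assert not check_output_equal(ref, {"logits": ref["logits"]})
# ---------------------------------------------------------------------------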