|
@@ -26,6 +26,7 @@
 from torch.fx.passes.shape_prop import TensorMetadata
 from torch.utils._python_dispatch import _disable_current_modes
 from torch_tensorrt._enums import dtype
+from torch_tensorrt._features import needs_refit
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
 from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
|
@@ -44,7 +45,7 @@
     get_trt_tensor,
     to_torch,
 )
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device, to_torch_device
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
 
|
@@ -434,6 +435,7 @@ def check_weight_equal(
         except Exception:
             return torch.all(sd_weight == network_weight)
 
+    @needs_refit
     def _save_weight_mapping(self) -> None:
         """
         Construct the weight name mapping from engine weight name to state_dict weight name.
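
Both `_save_weight_mapping` here and the engine-cache helpers further down are now gated behind `@needs_refit`, so refit-only code paths fail fast with a clear error instead of partway through. As a rough illustration only (not the actual `torch_tensorrt._features` implementation), a feature-gating decorator of this kind can be sketched as:

```python
from functools import wraps
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])

# Hypothetical stand-in for the real feature flag; torch_tensorrt derives its
# value from whether the TensorRT refitter bindings are available.
_REFIT_AVAILABLE = True


def needs_refit(f: F) -> F:
    """Raise a clear error if a refit-only code path is called without refit support.

    Minimal sketch; the real decorator lives in torch_tensorrt._features and may
    differ in naming and error type.
    """

    @wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if not _REFIT_AVAILABLE:
            raise NotImplementedError(
                f"{f.__name__} requires refit support, which is not enabled in this build"
            )
        return f(*args, **kwargs)

    return wrapper  # type: ignore[return-value]
```
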
@@ -491,15 +493,10 @@ def _save_weight_mapping(self) -> None:
         _LOGGER.info("Building weight name mapping...")
         # Stage 1: Name mapping
         torch_device = to_torch_device(self.compilation_settings.device)
-        gm_is_on_cuda = get_model_device(self.module).type == "cuda"
-        if not gm_is_on_cuda:
-            # If the model original position is on CPU, move it GPU
-            sd = {
-                k: v.reshape(-1).to(torch_device)
-                for k, v in self.module.state_dict().items()
-            }
-        else:
-            sd = {k: v.reshape(-1) for k, v in self.module.state_dict().items()}
+        sd = {
+            k: v.reshape(-1).to(torch_device)
+            for k, v in self.module.state_dict().items()
+        }
         weight_name_map: dict[str, Any] = {}
         np_map = {}
         constant_mapping = {}
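
The simplified comprehension relies on `Tensor.to(device)` returning the tensor itself when it already has the target device and dtype, so unconditionally flattening and moving the state_dict is equivalent to the old CPU/GPU branch while avoiding the extra `get_model_device` query. A quick illustrative check (assumes a CUDA-capable machine):

```python
import torch

if torch.cuda.is_available():
    cuda = torch.device("cuda")

    # .to() on a tensor that already lives on the target device (same dtype)
    # returns the very same tensor object: no copy, no extra memory traffic.
    w = torch.randn(4, 4, device=cuda)
    assert w.to(cuda) is w

    # A CPU tensor, by contrast, is actually copied to the GPU.
    w_cpu = torch.randn(4, 4)
    assert w_cpu.to(cuda).device.type == "cuda"
```
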
@@ -583,6 +580,7 @@ def _save_weight_mapping(self) -> None:
         gc.collect()
         torch.cuda.empty_cache()
 
+    @needs_refit
     def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
         # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
         # if not self.compilation_settings.strip_engine_weights:
@@ -610,6 +608,7 @@ def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
             ),
         )
 
+    @needs_refit
     def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
         # query the cached TRT engine
         cached_data = self.engine_cache.check(hash_val)  # type: ignore[union-attr]
@@ -720,7 +719,7 @@ def run(
                 if self.compilation_settings.reuse_cached_engines:
                     interpreter_result = self._pull_cached_engine(hash_val)
                     if interpreter_result is not None:  # hit the cache
-                        return interpreter_result
+                        return interpreter_result  # type: ignore[no-any-return]
 
         self._construct_trt_network_def()
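
The added `# type: ignore[no-any-return]` is most likely a knock-on effect of the decorator: if `needs_refit` is not typed to preserve the wrapped signature, mypy treats the decorated `_pull_cached_engine` as returning `Any`. A `ParamSpec`-typed gate keeps the precise return type and would not need the suppression; the sketch below is illustrative only and not the project's actual typing:

```python
from functools import wraps
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def typed_gate(f: Callable[P, R]) -> Callable[P, R]:
    """Feature-gating decorator that preserves the wrapped function's signature.

    Because the return type stays `R` rather than `Any`, callers of the
    decorated function do not trip mypy's `no-any-return` check.
    """

    @wraps(f)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        # A hypothetical feature-availability check would go here.
        return f(*args, **kwargs)

    return wrapper
```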
726 | 725 |
|
|