
Commit 2335ece

Rebased to main after refit acceleration is merged.

1 parent: bd685ae

2 files changed: +7, -5 lines

examples/dynamo/mutable_torchtrt_module_example.py (4 additions, 4 deletions)

@@ -21,8 +21,8 @@
 import torch_tensorrt as torch_trt
 import torchvision.models as models
 
-np.random.seed(0)
-torch.manual_seed(0)
+np.random.seed(5)
+torch.manual_seed(5)
 inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
 
 # %%

@@ -76,7 +76,7 @@
 from diffusers import DiffusionPipeline
 
 with torch.no_grad():
-    kwargs = {
+    settings = {
         "use_python_runtime": True,
         "enabled_precisions": {torch.float16},
         "debug": True,

@@ -95,7 +95,7 @@
     pipe.to(device)
 
     # The only extra line you need
-    pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **kwargs)
+    pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)
 
     image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
     image.save("./without_LoRA_mutable.jpg")
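Note: the example file above demonstrates wrapping a module so that later weight changes trigger an engine refit rather than a full recompilation. Below is a minimal sketch of that workflow with a torchvision ResNet, assuming the wrapper compiles lazily on the first forward call and picks up a load_state_dict update before the next one; the settings dict mirrors the one in the diff, and the model names are illustrative only.

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

np.random.seed(5)
torch.manual_seed(5)
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]

settings = {
    "use_python_runtime": True,
    "enabled_precisions": {torch.float16},
    "debug": True,
}

# Wrap an eager model; TensorRT compilation happens on the first call.
model = models.resnet18(pretrained=True).eval().to("cuda")
mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
mutable_module(*inputs)

# Loading new weights marks the module as stale; the next forward call
# refits the existing engine instead of rebuilding it from scratch.
new_weights = models.resnet18(pretrained=False).eval().to("cuda").state_dict()
mutable_module.load_state_dict(new_weights)
mutable_module(*inputs)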

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py (3 additions, 1 deletion)

@@ -267,7 +267,9 @@ def refit_gm(self) -> None:
                 self.original_model.state_dict()
             )
         )
-        self.gm = refit_module_weights(self.gm, self.exp_program)
+        self.gm = refit_module_weights(
+            self.gm, self.exp_program, use_weight_map_cache=True, in_place=True
+        )
 
         self.original_model.cpu()
         torch.cuda.empty_cache()
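Note: this runtime change switches MutableTorchTensorRTModule onto the accelerated refit path referenced in the commit message. A minimal standalone sketch of the same call is below, assuming refit_module_weights is importable from torch_tensorrt.dynamo; compiled_trt_module and updated_exp_program stand in for a previously compiled module and an exported program carrying the new weights (both names are illustrative, and the keyword arguments are taken from the diff above).

from torch_tensorrt.dynamo import refit_module_weights

compiled_trt_module = refit_module_weights(
    compiled_trt_module,        # previously compiled TensorRT module
    updated_exp_program,        # exported program holding the updated weights
    use_weight_map_cache=True,  # reuse the cached engine-to-weight mapping to speed up refit
    in_place=True,              # patch the existing engine rather than returning a new module
)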
