@cehongwang (Collaborator) commented Jul 3, 2024

Description

Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

Refit state transition graph

[Image: refit state transition graph]

@github-actions github-actions bot added the component: api [Python], component: runtime, and component: dynamo labels Jul 3, 2024
@github-actions github-actions bot requested a review from narendasan July 3, 2024 00:33
@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/mutable_torchtrt_module_example.py	2024-07-03 00:33:30.637263+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/mutable_torchtrt_module_example.py	2024-07-03 00:35:20.326993+00:00
@@ -36,20 +36,18 @@

# %%
# Compile the module for the first time and save it.
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
kwargs = {
-    'use_python': False,
-    'enabled_precisions': {torch.float16, torch.float32},
-    'mutable': True
+    "use_python": False,
+    "enabled_precisions": {torch.float16, torch.float32},
+    "mutable": True,
}

model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
-mutable_module = torch_trt.compile(inputs=inputs, 
-                                   module=model, 
-                                   **kwargs)
+mutable_module = torch_trt.compile(inputs=inputs, module=model, **kwargs)

# Save the graph module as an exported program
# This is only supported when use_python_runtime = False
mutable_module.load_state_dict(model2.state_dict())

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py	2024-07-03 00:33:30.645263+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py	2024-07-03 00:35:20.875148+00:00
@@ -9,11 +9,13 @@
import torch.fx
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import _defaults
-from torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule import MutableTorchTensorRTModule
+from torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule import (
+    MutableTorchTensorRTModule,
+)

from torch_tensorrt.fx import InputTensorSpec
from torch_tensorrt.fx.lower import compile as fx_compile
from torch_tensorrt.fx.utils import LowerPrecision
from typing_extensions import TypeGuard
@@ -238,15 +240,17 @@
            **kwargs,
        )
        return compiled_fx_module
    elif target_ir == _IRType.dynamo:
        if kwargs["mutable"]:
-            mutable_trt_graph_module = MutableTorchTensorRTModule(module, input_list, enabled_precisions_set, **kwargs)
+            mutable_trt_graph_module = MutableTorchTensorRTModule(
+                module, input_list, enabled_precisions_set, **kwargs
+            )
            mutable_trt_graph_module.compile()
            return mutable_trt_graph_module
        else:
-        # Prepare torch and torchtrt inputs
+            # Prepare torch and torchtrt inputs
            from torch_tensorrt.dynamo.utils import prepare_inputs

            if not isinstance(input_list, collections.abc.Sequence):
                input_list = [input_list]

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py	2024-07-03 00:33:30.653263+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py	2024-07-03 00:35:21.749591+00:00
@@ -7,10 +7,11 @@
import torch_tensorrt as torch_trt
import torch
from torch_tensorrt.dynamo.utils import prepare_inputs
from torch_tensorrt.dynamo._tracer import trace as dynamo_trace
from torch_tensorrt.dynamo._compiler import compile as dynamo_compile
+

class RefitFlag:
    def __init__(self):
        self.flag = False

@@ -20,12 +21,15 @@

    def set_off(self):
        self.flag = False
        print("RefitFlag is set to OFF.")

+
class MutableTorchTensorRTModule(object):
-    def __init__(self, pytorch_model, sample_inputs, enabled_precisions_set, **kwargs) -> None:
+    def __init__(
+        self, pytorch_model, sample_inputs, enabled_precisions_set, **kwargs
+    ) -> None:
        self.refit_flag = RefitFlag()
        self.original_inputs = sample_inputs
        if not isinstance(sample_inputs, collections.abc.Sequence):
            sample_inputs = [sample_inputs]
        self.sample_inputs = tuple(sample_inputs)
@@ -40,84 +44,89 @@
        self.refit_flag.set_on()
        self.pytorch_model.load_state_dict(sd)

    def refit_gm(self):
        if self.exp_program is None:
-            self.exp_program = torch.export.export(self.pytorch_model, self.sample_inputs)
+            self.exp_program = torch.export.export(
+                self.pytorch_model, self.sample_inputs
+            )
        # TODO: Check refit condition and fallback to recompile
-        self.exp_program._state_dict = MutableTorchTensorRTModule._transform_state_dict(self.pytorch_model.state_dict())
+        self.exp_program._state_dict = MutableTorchTensorRTModule._transform_state_dict(
+            self.pytorch_model.state_dict()
+        )
        self.gm = refit_module_weights(self.gm, self.exp_program, self.sample_inputs)
-        

    def compile(self):
-        
+
        # Export the module
-        self.exp_program = dynamo_trace(self.original_model, self.torchtrt_inputs, **self.kwargs)
+        self.exp_program = dynamo_trace(
+            self.original_model, self.torchtrt_inputs, **self.kwargs
+        )
        self.gm = dynamo_compile(
            self.exp_program,
            inputs=self.torchtrt_inputs,
            enabled_precisions=self.enabled_precisions_set,
            make_refitable=True,
            **self.kwargs,
        )
-        
+
    def _transform_state_dict(sd):
        return {k: torch.nn.Parameter(v, requires_grad=False) for k, v in sd.items()}
-        
+
    def __getattr__(self, name):
-    
+
        if name in self.__dict__:
            # this object has it
            return getattr(self, name)

        return getattr(self.pytorch_model, name)
-        
+
        # raise AttributeError(f"'{type(self.pytorch_model)}' object has no attribute '{name}'")

    def __call__(self, *args, **kwargs):
        # We can update this once the kwarg pull request got merged
        return self.forward(*args, **kwargs)
-    
+
    def forward(self, *args, **kwargs):
        # TODO: Check the inputs is the same as the sample input
        if self.refit_flag.flag:
            print("Model weight change detected. Refitting the module...")
            self.refit_flag.set_off()
            self.refit_gm()

        return self.gm(*args, **kwargs)
-    
-    


def _make_refit_change_trigger(obj: Any, refit_flag: RefitFlag) -> Any:

    class ChangeTriggerWrapper(obj.__class__):
        def __init__(self, obj: Any):
-            object.__setattr__(self, 'instance', obj)
+            object.__setattr__(self, "instance", obj)

        def __getattr__(self, name: str):
            # This will cause infinte loop if there is a cycle
            obj = getattr(self.instance, name)
-            if not hasattr(obj, '__dict__'):
-                return obj 
+            if not hasattr(obj, "__dict__"):
+                return obj
            else:
                return _make_refit_change_trigger(obj, refit_flag)
-            
-        def __setattr__(self, name:str, value: Any):
+
+        def __setattr__(self, name: str, value: Any):
            self._on_change()
            setattr(self.instance, name, value)

        def __delattr__(self, name: str):
            self._on_change()
-            delattr(self.instance, name, )
+            delattr(
+                self.instance,
+                name,
+            )

        def _on_change(self):
            refit_flag.set_on()
            print("Change!")

        def __call__(self, *args: Any, **kwargs: Any) -> Any:
            print("Warning: uncatched change in function!")
            self.instance(*args, **kwargs)

    return ChangeTriggerWrapper(obj)
-
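
The `_make_refit_change_trigger` helper in the diff above wraps an object so that attribute writes raise a shared flag, which `forward` checks before deciding to refit. A minimal, framework-free sketch of that pattern follows. It is not the PR's implementation: `make_change_trigger`, `Layer`, and `Model` are hypothetical names, and unlike the diff it proxies by composition rather than subclassing `obj.__class__` and leaves callables unwrapped.

```python
from typing import Any


class RefitFlag:
    """Tiny mutable flag shared between the proxy and its owner."""

    def __init__(self) -> None:
        self.flag = False

    def set_on(self) -> None:
        self.flag = True

    def set_off(self) -> None:
        self.flag = False


def make_change_trigger(obj: Any, refit_flag: RefitFlag) -> Any:
    """Return a proxy for `obj` that raises `refit_flag` on attribute writes."""

    class ChangeTrigger:
        def __init__(self, target: Any) -> None:
            # Bypass our own __setattr__ so construction does not raise the flag.
            object.__setattr__(self, "_target", target)

        def __getattr__(self, name: str) -> Any:
            # Only reached when normal lookup fails, i.e. for the target's attributes.
            value = getattr(object.__getattribute__(self, "_target"), name)
            # Wrap nested plain objects so writes like proxy.layer.weight = ... are seen.
            if hasattr(value, "__dict__") and not callable(value):
                return make_change_trigger(value, refit_flag)
            return value

        def __setattr__(self, name: str, value: Any) -> None:
            refit_flag.set_on()
            setattr(object.__getattribute__(self, "_target"), name, value)

        def __delattr__(self, name: str) -> None:
            refit_flag.set_on()
            delattr(object.__getattribute__(self, "_target"), name)

    return ChangeTrigger(obj)


# Hypothetical toy classes standing in for a model and one of its submodules.
class Layer:
    def __init__(self) -> None:
        self.weight = 1.0


class Model:
    def __init__(self) -> None:
        self.layer = Layer()


flag = RefitFlag()
model = Model()
proxy = make_change_trigger(model, flag)

_ = proxy.layer.weight          # reads leave the flag untouched
flag_after_read = flag.flag
proxy.layer.weight = 2.0        # a nested write raises the flag
```

Skipping callables keeps methods usable without going through a wrapper `__call__`, at the cost of not tracking changes made inside callable objects; a real module type would need a more careful rule.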

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py	2024-07-05 18:47:36.308343+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py	2024-07-05 18:49:28.137706+00:00
@@ -9,11 +9,13 @@
import torch.fx
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import _defaults
-from torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule import MutableTorchTensorRTModule
+from torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule import (
+    MutableTorchTensorRTModule,
+)

from torch_tensorrt.fx import InputTensorSpec
from torch_tensorrt.fx.lower import compile as fx_compile
from torch_tensorrt.fx.utils import LowerPrecision
from typing_extensions import TypeGuard
@@ -238,15 +240,17 @@
            **kwargs,
        )
        return compiled_fx_module
    elif target_ir == _IRType.dynamo:
        if kwargs["mutable"]:
-            mutable_trt_graph_module = MutableTorchTensorRTModule(module, input_list, enabled_precisions_set, **kwargs)
+            mutable_trt_graph_module = MutableTorchTensorRTModule(
+                module, input_list, enabled_precisions_set, **kwargs
+            )
            mutable_trt_graph_module.compile()
            return mutable_trt_graph_module
        else:
-        # Prepare torch and torchtrt inputs
+            # Prepare torch and torchtrt inputs
            from torch_tensorrt.dynamo.utils import prepare_inputs

            if not isinstance(input_list, collections.abc.Sequence):
                input_list = [input_list]

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py	2024-07-08 20:35:49.249381+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py	2024-07-08 20:37:40.201769+00:00
@@ -473,6 +473,6 @@

                with enable_torchbind_tracing():
                    exp_program = torch.export.export(
                        module, tuple(inputs), strict=False
                    )
-                    torch.export.save(exp_program, file_path)
\ No newline at end of file
+                    torch.export.save(exp_program, file_path)

@cehongwang cehongwang requested a review from narendasan July 8, 2024 20:39
@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/mutable_torchtrt_module_stable_diffusion_example.py	2024-07-11 21:29:42.735735+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/mutable_torchtrt_module_stable_diffusion_example.py	2024-07-11 21:31:37.121983+00:00
@@ -52,23 +52,21 @@

pipe = DiffusionPipeline.from_pretrained(
    model_id, revision="fp16", torch_dtype=torch.float16, safety_checker=None
)
pipe.to(device)
-backend = "torch_tensorrt" 
+backend = "torch_tensorrt"

pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **kwargs)
image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
-image.save('./without_LoRA.jpg')
+image.save("./without_LoRA.jpg")


pipe.load_lora_weights("/opt/torch_tensorrt/moxin.safetensors", adapter_name="lora1")
pipe.set_adapters(["lora1"], adapter_weights=[1])
-pipe.fuse_lora(['lora1'], 1)
+pipe.fuse_lora(["lora1"], 1)
pipe.unload_lora_weights()


# Check the output
image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
-image.save('./with_LoRA.jpg')
-
-
+image.save("./with_LoRA.jpg")

@narendasan (Collaborator) left a comment

LGTM

hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
**kwargs: Any,
) -> None:

Add a docstring.

@narendasan (Collaborator) left a comment

Tests:

  1. Basic E2E workflow (ingest model and run)
  2. Serialization for both runtimes (C++ and Python): build, save, and load
  3. Test refitting behavior (changing weights via attribute assignment)
  4. Test recompile ???
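
As a rough illustration of items 1 and 3, the sketch below exercises the deferred-refit contract against a toy stand-in. `ToyMutableModule` is hypothetical and involves no TensorRT engine; it only mirrors the flag check that `forward` performs in the `_MutableTorchTensorRTModule.py` diff earlier in this thread. Items 2 and 4 need the real runtimes and are not covered.

```python
class ToyMutableModule:
    """Hypothetical stand-in: a 'compiled engine' that multiplies by a weight."""

    def __init__(self, weight: float) -> None:
        self.weight = weight          # weight held by the "PyTorch" model
        self.engine_weight = weight   # weight baked into the "engine"
        self.needs_refit = False
        self.refit_count = 0

    def load_state_dict(self, sd: dict) -> None:
        # Changing weights only marks the engine as stale.
        self.weight = sd["weight"]
        self.needs_refit = True

    def forward(self, x: float) -> float:
        # Refit lazily, on the first call after a weight change.
        if self.needs_refit:
            self.engine_weight = self.weight
            self.refit_count += 1
            self.needs_refit = False
        return self.engine_weight * x


def test_basic_e2e() -> None:
    module = ToyMutableModule(2.0)
    assert module.forward(3.0) == 6.0
    assert module.refit_count == 0


def test_refit_after_weight_change() -> None:
    module = ToyMutableModule(2.0)
    module.load_state_dict({"weight": 5.0})
    assert module.refit_count == 0      # refit is deferred until the next call
    assert module.forward(3.0) == 15.0
    assert module.forward(3.0) == 15.0
    assert module.refit_count == 1      # a second call does not refit again


if __name__ == "__main__":
    test_basic_e2e()
    test_refit_after_weight_change()
```

A real test for item 3 would presumably compare the refitted module's output against the updated PyTorch model within a tolerance rather than check a counter.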

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py	2024-08-13 21:55:55.416655+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py	2024-08-13 22:00:22.024407+00:00
@@ -532,6 +532,6 @@

                with enable_torchbind_tracing():
                    exp_program = torch.export.export(
                        module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
                    )
-                    torch.export.save(exp_program, file_path)
\ No newline at end of file
+                    torch.export.save(exp_program, file_path)

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py	2024-08-15 00:10:53.392627+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py	2024-08-15 00:16:07.394087+00:00
@@ -532,6 +532,6 @@

                with enable_torchbind_tracing():
                    exp_program = torch.export.export(
                        module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
                    )
-                    torch.export.save(exp_program, file_path)
\ No newline at end of file
+                    torch.export.save(exp_program, file_path)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py	2024-08-15 00:10:53.396627+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py	2024-08-15 00:16:08.728715+00:00
@@ -372,11 +372,11 @@
        return result

    def to(self, device: str):
        logger.warning("Original PyTorch model is moved. CPU offload may failed.")
        self.orignial_model.to(device)
-        
+
    def __deepcopy__(self, memo: Any) -> Any:
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
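
The `__deepcopy__` context lines above are cut off at the loop header. As a minimal sketch of the standard memo-based pattern they appear to follow, the block below uses a hypothetical `Wrapper` class (not the PR's module), and the loop body is an assumption since the diff does not show it:

```python
import copy
from typing import Any


class Wrapper:
    """Hypothetical stand-in class; not the PR's mutable module."""

    def __init__(self, payload: list) -> None:
        self.payload = payload
        self.me = self  # self-reference, to exercise the memo dictionary

    def __deepcopy__(self, memo: Any) -> Any:
        cls = self.__class__
        result = cls.__new__(cls)  # allocate without running __init__
        memo[id(self)] = result    # register before recursing so cycles resolve
        for k, v in self.__dict__.items():
            # Assumed loop body: deep-copy each attribute, sharing the memo.
            setattr(result, k, copy.deepcopy(v, memo))
        return result


original = Wrapper([1, 2, 3])
clone = copy.deepcopy(original)
```

Registering `result` in `memo` before copying the attributes is what lets a self-referencing attribute such as `me` point at the clone rather than trigger infinite recursion.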

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py	2024-08-15 18:26:39.178114+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py	2024-08-15 18:28:37.121097+00:00
@@ -532,6 +532,6 @@

                with enable_torchbind_tracing():
                    exp_program = torch.export.export(
                        module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
                    )
-                    torch.export.save(exp_program, file_path)
\ No newline at end of file
+                    torch.export.save(exp_program, file_path)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py	2024-08-15 18:26:39.186114+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py	2024-08-15 18:28:38.508444+00:00
@@ -372,11 +372,11 @@
        return result

    def to(self, device: str):
        logger.warning("Original PyTorch model is moved. CPU offload may failed.")
        self.orignial_model.to(device)
-        
+
    def __deepcopy__(self, memo: Any) -> Any:
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():

@cehongwang cehongwang merged commit cee4914 into main Aug 15, 2024

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py	2024-08-15 19:35:17.585819+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py	2024-08-15 19:37:02.052771+00:00
@@ -532,6 +532,6 @@

                with enable_torchbind_tracing():
                    exp_program = torch.export.export(
                        module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
                    )
-                    torch.export.save(exp_program, file_path)
\ No newline at end of file
+                    torch.export.save(exp_program, file_path)

Labels
cla signed
component: api [Python] (Issues re: Python API)
component: dynamo (Issues relating to the `torch.compile` or `torch._dynamo.export` paths)
component: runtime
component: tests (Issues re: Tests)
documentation (Improvements or additions to documentation)