Commit e7f8c09

chore: additional fixes
1 parent 29d6c23 commit e7f8c09

File tree: 4 files changed, +87 −17 lines changed

.github/workflows/build-test.yml

Lines changed: 1 addition & 0 deletions
@@ -142,6 +142,7 @@ jobs:
           ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
           ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
           ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
+          ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/output_format.xml --ir dynamo models/test_output_format.py
           popd

   tests-py-dynamo-serde:

docsrc/user_guide/saving_models.rst

Lines changed: 22 additions & 11 deletions
@@ -22,11 +22,10 @@ The `output_format` can take the following options

 * `torchscript` (or) `ts` : This returns a TorchScript module
 * `graph_module` (or) `fx` : This returns a torch.fx.GraphModule which can be traced into Torchscript to save to disk.

-a) Converting to Torchscript
+a) Torchscript
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-`torch.fx.GraphModule` objects cannot be serialized directly. Hence we use `torch.jit.trace` to convert this into a `ScriptModule` object which can be saved to disk.
-The following code illustrates this approach.
+If you set `output_format="torchscript"`, compilation returns a `ScriptModule`, which can be serialized via `torch.jit.save`.

 .. code-block:: python

@@ -35,6 +34,7 @@ The following code illustrates this approach.

     model = MyModel().eval().cuda()
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
+    # trt_ts is a torch.jit.ScriptModule object
     trt_ts = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, output_format="torchscript")
     torch.jit.save(trt_ts, "trt_model.ts")

@@ -45,8 +45,7 @@ The following code illustrates this approach.

 b) ExportedProgram
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-`torch.export.ExportedProgram` is a new format introduced in Pytorch 2.1. After we compile a Pytorch module using Torch-TensorRT, the resultant
-`torch.fx.GraphModule` along with additional metadata can be used to create `ExportedProgram` which can be saved and loaded from disk.
+`torch.export.ExportedProgram`, a new format introduced in Pytorch 2.X, is the default return type of Torch-TensorRT compilation.

 .. code-block:: python

@@ -55,24 +54,36 @@ b) ExportedProgram

     model = MyModel().eval().cuda()
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) # Output is a torch.fx.GraphModule
-    torch.export.save(trt_exp_program, "trt_model.ep")
+    # trt_ep is a torch.export.ExportedProgram object
+    trt_ep = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
+    torch.export.save(trt_ep, "trt_model.ep")

     # Later, you can load it and run inference
     model = torch.export.load("trt_model.ep")
     model(*inputs)

-`torch_tensorrt.dynamo.export` inlines the submodules within a GraphModule to their corresponding nodes and stitches all the nodes together.
-This is needed as `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).
+c) GraphModule
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+We can also return a `torch.fx.GraphModule` object as the output of Torch-TensorRT compilation by setting `output_format="graph_module"`.
+Internally, the partitioning, lowering, and conversion phases operate on GraphModule objects. These can be either traced into TorchScript modules or
+exported into `ExportedProgram` objects.
+
+.. code-block:: python

-.. note:: This way of saving models using `ExportedProgram` is experimental. Here is a known issue: https://github.com/pytorch/TensorRT/issues/2341
+    import torch
+    import torch_tensorrt

+    model = MyModel().eval().cuda()
+    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
+    # trt_gm is a torch.fx.GraphModule object
+    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, output_format="graph_module")

 Torchscript IR
 --------------

 In Torch-TensorRT 1.X versions, the primary way to compile and run inference with Torch-TensorRT is using Torchscript IR.
-This behavior stays the same in 2.X versions as well.
+For `ir=ts`, this behavior stays the same in 2.X versions as well.

 .. code-block:: python

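The new GraphModule section above stops at compilation, while its prose notes the module can still be traced into TorchScript. A minimal sketch of that follow-on step, assuming the `MyModel` and `inputs` definitions used throughout the guide (`torch.jit.trace` and `torch.jit.save` are standard PyTorch APIs; whether tracing suits a given TRT-partitioned module should be verified per model):

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
    # Compile, asking for a GraphModule back
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, output_format="graph_module")

    # Trace the GraphModule into a ScriptModule and save it to disk,
    # as the updated docs describe for the graph_module output format
    trt_script = torch.jit.trace(trt_gm, inputs)
    torch.jit.save(trt_script, "trt_model_from_gm.ts")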
py/torch_tensorrt/dynamo/_exporter.py

Lines changed: 8 additions & 6 deletions
@@ -1,4 +1,3 @@

-import copy
 import operator
 from typing import Any, Dict, Sequence, Tuple, cast

@@ -86,21 +85,23 @@ def lift(gm: torch.fx.GraphModule, graph_signature: Any) -> torch.fx.GraphModule

     # Get the state_dict of graph_module. This is different from exported_program.state_dict
     # exp_program.state_dict contains parameters and buffers whereas a graph_module's state_dict
     # has all parameters registered as torch.tensors.
-    state_dict = copy.deepcopy(gm.state_dict())
+    state_dict = gm.state_dict()

     fake_mode = detect_fake_mode(
         tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder")
     )
     assert fake_mode is not None

     # Locate the user input to insert new placeholders before them
-    first_user_input_loc, first_user_input = 0, None
+    first_user_input = None
     for node in gm.graph.nodes:
         if node.op == "placeholder" and node.name in graph_signature.user_inputs:
             first_user_input = node
             break
-        first_user_input_loc += 1

+    # At first, the user inputs are the only entries in graph_signature.input_specs, hence non_user_input_idx = 0
+    # The input_specs should be of the form [params, buffers, constant_tensors, user_inputs]
+    non_user_input_idx = 0
     for node in gm.graph.nodes:
         if node.op == "get_attr":
             constant_tensor = getattr(gm, node.target)

@@ -130,14 +131,14 @@ def lift(gm: torch.fx.GraphModule, graph_signature: Any) -> torch.fx.GraphModule

             # Add these parameters/buffers/constants to the existing graph signature
             # before user inputs. These specs are looked up in the state_dict during ExportedProgram creation.
             graph_signature.input_specs.insert(
-                first_user_input_loc,
+                non_user_input_idx,
                 InputSpec(
                     kind=input_kind,
                     arg=TensorArgument(name=const_placeholder_node.name),
                     target=node.target,
                 ),
             )
-            first_user_input_loc += 1
+            non_user_input_idx += 1

     gm.graph.eliminate_dead_code()
     gm.graph.lint()

@@ -257,6 +258,7 @@ def create_trt_exp_program(

     """Creates a new Exported Program. This function takes a torch.fx.GraphModule which has TRT engines
     and constructs an Exported Program object with the new IO node names and state_dict
     """
+
     input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
     output_nodes = [node for node in gm.graph.nodes if node.op == "output"]
     assert output_nodes

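The `first_user_input_loc` → `non_user_input_idx` rename in `lift` reflects an ordering invariant rather than a graph position: `input_specs` initially holds only the user inputs, and each lifted parameter/buffer/constant is inserted ahead of them. A toy sketch of that bookkeeping (plain strings stand in for the real `InputSpec` objects, purely for illustration):

    # Illustrative only: input_specs starts with just the user inputs
    input_specs = ["user_input_0", "user_input_1"]

    non_user_input_idx = 0
    for lifted in ["param_conv_weight", "buffer_bn_mean", "constant_0"]:
        # Each lifted tensor goes in front of the user inputs, keeping the
        # expected [params, buffers, constant_tensors, user_inputs] layout
        input_specs.insert(non_user_input_idx, lifted)
        non_user_input_idx += 1

    print(input_specs)
    # ['param_conv_weight', 'buffer_bn_mean', 'constant_0', 'user_input_0', 'user_input_1']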
tests/py/dynamo/models/test_export_serde.py

Lines changed: 56 additions & 0 deletions
@@ -116,6 +116,62 @@ def forward(self, x):

     )


+@pytest.mark.unit
+def test_no_compile(ir):
+    """
+    This tests export serde functionality on a model
+    which won't convert to TRT because of the min_block_size=5 constraint
+    """
+
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, x):
+            conv = self.conv(x)
+            conv = conv * 0.5
+            relu = self.relu(conv)
+            return conv, relu
+
+    model = MyModule().eval().cuda()
+    input = torch.randn((1, 3, 224, 224)).to("cuda")
+
+    compile_spec = {
+        "inputs": [
+            torchtrt.Input(
+                input.shape, dtype=torch.float, format=torch.contiguous_format
+            )
+        ],
+        "ir": ir,
+        "debug": True,
+    }
+
+    exp_program = torchtrt.dynamo.trace(model, **compile_spec)
+    trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
+    torch.export.save(trt_exp_program, "/tmp/trt.ep")
+    deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
+    # Check Pyt and TRT exported program outputs
+    outputs_pyt = model(input)
+    outputs_trt = trt_exp_program(input)
+    for idx in range(len(outputs_pyt)):
+        cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"test_no_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+    # Check Pyt and deserialized TRT exported program outputs
+    outputs_trt_deser = deser_trt_exp_program(input)
+    for idx in range(len(outputs_pyt)):
+        cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"test_no_compile deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+
 @pytest.mark.unit
 def test_hybrid_relu_fallback(ir):
     """

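To run just the new test locally, a sketch mirroring the CI invocation in build-test.yml (it assumes pytest, torch_tensorrt, and a CUDA GPU are available; `--ir` is the custom option these suites already pass in the workflow above):

    import pytest

    # Run from the repository root; "-k test_no_compile" selects only the new test
    pytest.main(
        [
            "tests/py/dynamo/models/test_export_serde.py",
            "-k", "test_no_compile",
            "--ir", "dynamo",
        ]
    )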