Skip to content

Commit 2448e1e

Browse files
authored
use torch.export
1 parent 5298b12 commit 2448e1e

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

docs/how_to/tutorials/e2e_opt_model.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
import os
3535
import numpy as np
3636
import torch
37-
from torch import fx
37+
from torch.export import export
3838
from torchvision.models.resnet import ResNet18_Weights, resnet18
3939

4040
torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)
@@ -63,21 +63,20 @@
6363
# Convert the model to IRModule
6464
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
6565
# Next step, we convert the model to an IRModule using the Relax frontend for PyTorch for further
66-
# optimization. Besides the model, we also need to provide the input shape and data type.
66+
# optimization.
6767

6868
import tvm
6969
from tvm import relax
70-
from tvm.relax.frontend.torch import from_fx
70+
from tvm.relax.frontend.torch import from_exported_program
7171

72-
torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)
73-
74-
# Give the input shape and data type
72+
# Give an example argument to torch.export
73+
example_args = (torch.randn(1, 3, 224, 224),)
7574
input_info = [((1, 3, 224, 224), "float32")]
7675

7776
# Convert the model to IRModule
7877
with torch.no_grad():
79-
torch_fx_model = fx.symbolic_trace(torch_model)
80-
mod = from_fx(torch_fx_model, input_info, keep_params_as_input=True)
78+
exported_program = export(torch_model, example_args)
79+
mod = from_exported_program(exported_program, keep_params_as_input=True)
8180

8281
mod, params = relax.frontend.detach_params(mod)
8382
mod.show()

0 commit comments

Comments
 (0)