|
34 | 34 | import os |
35 | 35 | import numpy as np |
36 | 36 | import torch |
37 | | -from torch import fx |
| 37 | +from torch.export import export |
38 | 38 | from torchvision.models.resnet import ResNet18_Weights, resnet18 |
39 | 39 |
|
40 | 40 | torch_model = resnet18(weights=ResNet18_Weights.DEFAULT) |
|
63 | 63 | # Convert the model to IRModule |
64 | 64 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
65 | 65 | # Next step, we convert the model to an IRModule using the Relax frontend for PyTorch for further |
66 | | -# optimization. Besides the model, we also need to provide the input shape and data type. |
| 66 | +# optimization. |
67 | 67 |
|
68 | 68 | import tvm |
69 | 69 | from tvm import relax |
70 | | -from tvm.relax.frontend.torch import from_fx |
| 70 | +from tvm.relax.frontend.torch import from_exported_program |
71 | 71 |
|
72 | | -torch_model = resnet18(weights=ResNet18_Weights.DEFAULT) |
73 | | - |
74 | | -# Give the input shape and data type |
| 72 | +# Give an example argument to torch.export |
| 73 | +example_args = (torch.randn(1, 3, 224, 224),) |
75 | 74 | input_info = [((1, 3, 224, 224), "float32")] |
76 | 75 |
|
77 | 76 | # Convert the model to IRModule |
78 | 77 | with torch.no_grad(): |
79 | | - torch_fx_model = fx.symbolic_trace(torch_model) |
80 | | - mod = from_fx(torch_fx_model, input_info, keep_params_as_input=True) |
| 78 | + exported_program = export(torch_model, example_args) |
| 79 | + mod = from_exported_program(exported_program, keep_params_as_input=True) |
81 | 80 |
|
82 | 81 | mod, params = relax.frontend.detach_params(mod) |
83 | 82 | mod.show() |
|
0 commit comments