|
1 | 1 | import importlib |
| 2 | +import importlib.util |
2 | 3 | import json |
3 | 4 | import os |
4 | 5 |
|
5 | 6 | import custom_models as cm |
6 | 7 | import torch |
7 | 8 |
|
8 | | -if importlib.util.find_spec("torchvision"): |
9 | | - import timm |
10 | | - import torchvision.models as models |
11 | | - |
12 | 9 | torch.hub._validate_not_a_forked_repo = lambda a, b, c: True |
13 | 10 |
|
| 11 | + |
14 | 12 | torch_version = torch.__version__ |
15 | 13 |
|
16 | 14 | # Detect case of no GPU before deserialization of models on GPU |
|
22 | 20 | # Downloads all model files again if manifest file is not present |
23 | 21 | MANIFEST_FILE = "model_manifest.json" |
24 | 22 |
|
25 | | -to_test_models = {} |
| 23 | +to_test_models = { |
| 24 | + "pooling": {"model": cm.Pool(), "path": "trace"}, |
| 25 | + "module_fallback": {"model": cm.ModuleFallbackMain(), "path": "script"}, |
| 26 | + "loop_fallback_eval": {"model": cm.LoopFallbackEval(), "path": "script"}, |
| 27 | + "loop_fallback_no_eval": {"model": cm.LoopFallbackNoEval(), "path": "script"}, |
| 28 | + "conditional": {"model": cm.FallbackIf(), "path": "script"}, |
| 29 | + "inplace_op_if": {"model": cm.FallbackInplaceOPIf(), "path": "script"}, |
| 30 | + "standard_tensor_input": {"model": cm.StandardTensorInput(), "path": "script"}, |
| 31 | + "tuple_input": {"model": cm.TupleInput(), "path": "script"}, |
| 32 | + "list_input": {"model": cm.ListInput(), "path": "script"}, |
| 33 | + "tuple_input_output": {"model": cm.TupleInputOutput(), "path": "script"}, |
| 34 | + "list_input_output": {"model": cm.ListInputOutput(), "path": "script"}, |
| 35 | + "list_input_tuple_output": { |
| 36 | + "model": cm.ListInputTupleOutput(), |
| 37 | + "path": "script", |
| 38 | + }, |
| 39 | + # "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"}, |
| 40 | +} |
| 41 | + |
| 42 | +if importlib.util.find_spec("torchvision"): |
| 43 | + import timm |
| 44 | + import torchvision.models as models |
| 45 | + |
| 46 | + torchvision_models = { |
| 47 | + "alexnet": {"model": models.alexnet(pretrained=True), "path": "both"}, |
| 48 | + "vgg16": {"model": models.vgg16(pretrained=True), "path": "both"}, |
| 49 | + "squeezenet": {"model": models.squeezenet1_0(pretrained=True), "path": "both"}, |
| 50 | + "densenet": {"model": models.densenet161(pretrained=True), "path": "both"}, |
| 51 | + "inception_v3": {"model": models.inception_v3(pretrained=True), "path": "both"}, |
| 52 | + "shufflenet": { |
| 53 | + "model": models.shufflenet_v2_x1_0(pretrained=True), |
| 54 | + "path": "both", |
| 55 | + }, |
| 56 | + "mobilenet_v2": {"model": models.mobilenet_v2(pretrained=True), "path": "both"}, |
| 57 | + "resnext50_32x4d": { |
| 58 | + "model": models.resnext50_32x4d(pretrained=True), |
| 59 | + "path": "both", |
| 60 | + }, |
| 61 | + "wideresnet50_2": { |
| 62 | + "model": models.wide_resnet50_2(pretrained=True), |
| 63 | + "path": "both", |
| 64 | + }, |
| 65 | + "mnasnet": {"model": models.mnasnet1_0(pretrained=True), "path": "both"}, |
| 66 | + "resnet18": { |
| 67 | + "model": torch.hub.load( |
| 68 | + "pytorch/vision:v0.9.0", "resnet18", pretrained=True |
| 69 | + ), |
| 70 | + "path": "both", |
| 71 | + }, |
| 72 | + "resnet50": { |
| 73 | + "model": torch.hub.load( |
| 74 | + "pytorch/vision:v0.9.0", "resnet50", pretrained=True |
| 75 | + ), |
| 76 | + "path": "both", |
| 77 | + }, |
| 78 | + "efficientnet_b0": { |
| 79 | + "model": timm.create_model("efficientnet_b0", pretrained=True), |
| 80 | + "path": "script", |
| 81 | + }, |
| 82 | + "vit": { |
| 83 | + "model": timm.create_model("vit_base_patch16_224", pretrained=True), |
| 84 | + "path": "script", |
| 85 | + }, |
| 86 | + } |
| 87 | + to_test_models.update(torchvision_models) |
26 | 88 |
|
27 | 89 |
|
28 | 90 | def get(n, m, manifest): |
@@ -77,67 +139,6 @@ def download_models(version_matches, manifest): |
77 | 139 |
|
78 | 140 |
|
79 | 141 | def main(): |
80 | | - if not importlib.util.find_spec("torchvision"): |
81 | | - print(f"torchvision is not installed, skip models download") |
82 | | - return |
83 | | - |
84 | | - to_test_models = { |
85 | | - "alexnet": {"model": models.alexnet(pretrained=True), "path": "both"}, |
86 | | - "vgg16": {"model": models.vgg16(pretrained=True), "path": "both"}, |
87 | | - "squeezenet": {"model": models.squeezenet1_0(pretrained=True), "path": "both"}, |
88 | | - "densenet": {"model": models.densenet161(pretrained=True), "path": "both"}, |
89 | | - "inception_v3": {"model": models.inception_v3(pretrained=True), "path": "both"}, |
90 | | - "shufflenet": { |
91 | | - "model": models.shufflenet_v2_x1_0(pretrained=True), |
92 | | - "path": "both", |
93 | | - }, |
94 | | - "mobilenet_v2": {"model": models.mobilenet_v2(pretrained=True), "path": "both"}, |
95 | | - "resnext50_32x4d": { |
96 | | - "model": models.resnext50_32x4d(pretrained=True), |
97 | | - "path": "both", |
98 | | - }, |
99 | | - "wideresnet50_2": { |
100 | | - "model": models.wide_resnet50_2(pretrained=True), |
101 | | - "path": "both", |
102 | | - }, |
103 | | - "mnasnet": {"model": models.mnasnet1_0(pretrained=True), "path": "both"}, |
104 | | - "resnet18": { |
105 | | - "model": torch.hub.load( |
106 | | - "pytorch/vision:v0.9.0", "resnet18", pretrained=True |
107 | | - ), |
108 | | - "path": "both", |
109 | | - }, |
110 | | - "resnet50": { |
111 | | - "model": torch.hub.load( |
112 | | - "pytorch/vision:v0.9.0", "resnet50", pretrained=True |
113 | | - ), |
114 | | - "path": "both", |
115 | | - }, |
116 | | - "efficientnet_b0": { |
117 | | - "model": timm.create_model("efficientnet_b0", pretrained=True), |
118 | | - "path": "script", |
119 | | - }, |
120 | | - "vit": { |
121 | | - "model": timm.create_model("vit_base_patch16_224", pretrained=True), |
122 | | - "path": "script", |
123 | | - }, |
124 | | - "pooling": {"model": cm.Pool(), "path": "trace"}, |
125 | | - "module_fallback": {"model": cm.ModuleFallbackMain(), "path": "script"}, |
126 | | - "loop_fallback_eval": {"model": cm.LoopFallbackEval(), "path": "script"}, |
127 | | - "loop_fallback_no_eval": {"model": cm.LoopFallbackNoEval(), "path": "script"}, |
128 | | - "conditional": {"model": cm.FallbackIf(), "path": "script"}, |
129 | | - "inplace_op_if": {"model": cm.FallbackInplaceOPIf(), "path": "script"}, |
130 | | - "standard_tensor_input": {"model": cm.StandardTensorInput(), "path": "script"}, |
131 | | - "tuple_input": {"model": cm.TupleInput(), "path": "script"}, |
132 | | - "list_input": {"model": cm.ListInput(), "path": "script"}, |
133 | | - "tuple_input_output": {"model": cm.TupleInputOutput(), "path": "script"}, |
134 | | - "list_input_output": {"model": cm.ListInputOutput(), "path": "script"}, |
135 | | - "list_input_tuple_output": { |
136 | | - "model": cm.ListInputTupleOutput(), |
137 | | - "path": "script", |
138 | | - }, |
139 | | - # "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"}, |
140 | | - } |
141 | 142 |
|
142 | 143 | manifest = None |
143 | 144 | version_matches = False |
|
0 commit comments