|
| 1 | +import importlib |
1 | 2 | import json
|
2 | 3 | import os
|
3 | 4 |
|
4 | 5 | import custom_models as cm
|
5 |
| -import timm |
6 | 6 | import torch
|
7 |
| -import torchvision.models as models |
| 7 | + |
| 8 | +if importlib.util.find_spec("torchvision"): |
| 9 | + import timm |
| 10 | + import torchvision.models as models |
8 | 11 |
|
9 | 12 | torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
|
10 | 13 |
|
|
19 | 22 | # Downloads all model files again if manifest file is not present
|
20 | 23 | MANIFEST_FILE = "model_manifest.json"
|
21 | 24 |
|
22 |
| -models = { |
23 |
| - "alexnet": {"model": models.alexnet(pretrained=True), "path": "both"}, |
24 |
| - "vgg16": {"model": models.vgg16(pretrained=True), "path": "both"}, |
25 |
| - "squeezenet": {"model": models.squeezenet1_0(pretrained=True), "path": "both"}, |
26 |
| - "densenet": {"model": models.densenet161(pretrained=True), "path": "both"}, |
27 |
| - "inception_v3": {"model": models.inception_v3(pretrained=True), "path": "both"}, |
28 |
| - "shufflenet": {"model": models.shufflenet_v2_x1_0(pretrained=True), "path": "both"}, |
29 |
| - "mobilenet_v2": {"model": models.mobilenet_v2(pretrained=True), "path": "both"}, |
30 |
| - "resnext50_32x4d": { |
31 |
| - "model": models.resnext50_32x4d(pretrained=True), |
32 |
| - "path": "both", |
33 |
| - }, |
34 |
| - "wideresnet50_2": { |
35 |
| - "model": models.wide_resnet50_2(pretrained=True), |
36 |
| - "path": "both", |
37 |
| - }, |
38 |
| - "mnasnet": {"model": models.mnasnet1_0(pretrained=True), "path": "both"}, |
39 |
| - "resnet18": { |
40 |
| - "model": torch.hub.load("pytorch/vision:v0.9.0", "resnet18", pretrained=True), |
41 |
| - "path": "both", |
42 |
| - }, |
43 |
| - "resnet50": { |
44 |
| - "model": torch.hub.load("pytorch/vision:v0.9.0", "resnet50", pretrained=True), |
45 |
| - "path": "both", |
46 |
| - }, |
47 |
| - "efficientnet_b0": { |
48 |
| - "model": timm.create_model("efficientnet_b0", pretrained=True), |
49 |
| - "path": "script", |
50 |
| - }, |
51 |
| - "vit": { |
52 |
| - "model": timm.create_model("vit_base_patch16_224", pretrained=True), |
53 |
| - "path": "script", |
54 |
| - }, |
55 |
| - "pooling": {"model": cm.Pool(), "path": "trace"}, |
56 |
| - "module_fallback": {"model": cm.ModuleFallbackMain(), "path": "script"}, |
57 |
| - "loop_fallback_eval": {"model": cm.LoopFallbackEval(), "path": "script"}, |
58 |
| - "loop_fallback_no_eval": {"model": cm.LoopFallbackNoEval(), "path": "script"}, |
59 |
| - "conditional": {"model": cm.FallbackIf(), "path": "script"}, |
60 |
| - "inplace_op_if": {"model": cm.FallbackInplaceOPIf(), "path": "script"}, |
61 |
| - "standard_tensor_input": {"model": cm.StandardTensorInput(), "path": "script"}, |
62 |
| - "tuple_input": {"model": cm.TupleInput(), "path": "script"}, |
63 |
| - "list_input": {"model": cm.ListInput(), "path": "script"}, |
64 |
| - "tuple_input_output": {"model": cm.TupleInputOutput(), "path": "script"}, |
65 |
| - "list_input_output": {"model": cm.ListInputOutput(), "path": "script"}, |
66 |
| - "list_input_tuple_output": {"model": cm.ListInputTupleOutput(), "path": "script"}, |
67 |
| - # "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"}, |
68 |
| -} |
| 25 | +models = {} |
69 | 26 |
|
70 | 27 |
|
71 | 28 | def get(n, m, manifest):
|
@@ -120,6 +77,68 @@ def download_models(version_matches, manifest):
|
120 | 77 |
|
121 | 78 |
|
122 | 79 | def main():
|
| 80 | + if not importlib.util.find_spec("torchvision"): |
| 81 | + print(f"torchvision is not installed, skip models download") |
| 82 | + return |
| 83 | + |
| 84 | + models = { |
| 85 | + "alexnet": {"model": models.alexnet(pretrained=True), "path": "both"}, |
| 86 | + "vgg16": {"model": models.vgg16(pretrained=True), "path": "both"}, |
| 87 | + "squeezenet": {"model": models.squeezenet1_0(pretrained=True), "path": "both"}, |
| 88 | + "densenet": {"model": models.densenet161(pretrained=True), "path": "both"}, |
| 89 | + "inception_v3": {"model": models.inception_v3(pretrained=True), "path": "both"}, |
| 90 | + "shufflenet": { |
| 91 | + "model": models.shufflenet_v2_x1_0(pretrained=True), |
| 92 | + "path": "both", |
| 93 | + }, |
| 94 | + "mobilenet_v2": {"model": models.mobilenet_v2(pretrained=True), "path": "both"}, |
| 95 | + "resnext50_32x4d": { |
| 96 | + "model": models.resnext50_32x4d(pretrained=True), |
| 97 | + "path": "both", |
| 98 | + }, |
| 99 | + "wideresnet50_2": { |
| 100 | + "model": models.wide_resnet50_2(pretrained=True), |
| 101 | + "path": "both", |
| 102 | + }, |
| 103 | + "mnasnet": {"model": models.mnasnet1_0(pretrained=True), "path": "both"}, |
| 104 | + "resnet18": { |
| 105 | + "model": torch.hub.load( |
| 106 | + "pytorch/vision:v0.9.0", "resnet18", pretrained=True |
| 107 | + ), |
| 108 | + "path": "both", |
| 109 | + }, |
| 110 | + "resnet50": { |
| 111 | + "model": torch.hub.load( |
| 112 | + "pytorch/vision:v0.9.0", "resnet50", pretrained=True |
| 113 | + ), |
| 114 | + "path": "both", |
| 115 | + }, |
| 116 | + "efficientnet_b0": { |
| 117 | + "model": timm.create_model("efficientnet_b0", pretrained=True), |
| 118 | + "path": "script", |
| 119 | + }, |
| 120 | + "vit": { |
| 121 | + "model": timm.create_model("vit_base_patch16_224", pretrained=True), |
| 122 | + "path": "script", |
| 123 | + }, |
| 124 | + "pooling": {"model": cm.Pool(), "path": "trace"}, |
| 125 | + "module_fallback": {"model": cm.ModuleFallbackMain(), "path": "script"}, |
| 126 | + "loop_fallback_eval": {"model": cm.LoopFallbackEval(), "path": "script"}, |
| 127 | + "loop_fallback_no_eval": {"model": cm.LoopFallbackNoEval(), "path": "script"}, |
| 128 | + "conditional": {"model": cm.FallbackIf(), "path": "script"}, |
| 129 | + "inplace_op_if": {"model": cm.FallbackInplaceOPIf(), "path": "script"}, |
| 130 | + "standard_tensor_input": {"model": cm.StandardTensorInput(), "path": "script"}, |
| 131 | + "tuple_input": {"model": cm.TupleInput(), "path": "script"}, |
| 132 | + "list_input": {"model": cm.ListInput(), "path": "script"}, |
| 133 | + "tuple_input_output": {"model": cm.TupleInputOutput(), "path": "script"}, |
| 134 | + "list_input_output": {"model": cm.ListInputOutput(), "path": "script"}, |
| 135 | + "list_input_tuple_output": { |
| 136 | + "model": cm.ListInputTupleOutput(), |
| 137 | + "path": "script", |
| 138 | + }, |
| 139 | + # "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"}, |
| 140 | + } |
| 141 | + |
123 | 142 | manifest = None
|
124 | 143 | version_matches = False
|
125 | 144 | manifest_exists = False
|
|
0 commit comments