|
1 | 1 | import json |
2 | 2 | import os |
3 | 3 |
|
4 | | -import custom_models as cm |
5 | | -import timm |
6 | 4 | import torch |
7 | | -import torch.nn as nn |
8 | | -import torch.nn.functional as F |
9 | | -import torchvision.models as models |
10 | | -from transformers import BertConfig, BertModel, BertTokenizer |
11 | 5 |
|
12 | 6 | torch.hub._validate_not_a_forked_repo = lambda a, b, c: True |
13 | 7 |
|
|
26 | 20 | VALID_PATHS = ("script", "trace", "torchscript", "pytorch", "all") |
27 | 21 |
|
28 | 22 | # Key models selected for benchmarking with their respective paths |
29 | | -BENCHMARK_MODELS = { |
30 | | - "vgg16": { |
31 | | - "model": models.vgg16(weights=models.VGG16_Weights.DEFAULT), |
32 | | - "path": ["script", "pytorch"], |
33 | | - }, |
34 | | - "resnet50": { |
35 | | - "model": models.resnet50(weights=None), |
36 | | - "path": ["script", "pytorch"], |
37 | | - }, |
38 | | - "efficientnet_b0": { |
39 | | - "model": timm.create_model("efficientnet_b0", pretrained=True), |
40 | | - "path": ["script", "pytorch"], |
41 | | - }, |
42 | | - "vit": { |
43 | | - "model": timm.create_model("vit_base_patch16_224", pretrained=True), |
44 | | - "path": ["script", "pytorch"], |
45 | | - }, |
46 | | - "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"}, |
47 | | -} |
| 23 | +from utils import BENCHMARK_MODELS |
48 | 24 |
|
49 | 25 |
|
50 | 26 | def get(n, m, manifest): |
51 | 27 | print("Downloading {}".format(n)) |
52 | 28 | traced_filename = "models/" + n + "_traced.jit.pt" |
53 | 29 | script_filename = "models/" + n + "_scripted.jit.pt" |
54 | 30 | pytorch_filename = "models/" + n + "_pytorch.pt" |
55 | | - x = torch.ones((1, 3, 300, 300)).cuda() |
56 | | - if n == "bert_base_uncased": |
57 | | - traced_model = m["model"] |
58 | | - torch.jit.save(traced_model, traced_filename) |
| 31 | + |
| 32 | + m["model"] = m["model"].eval().cuda() |
| 33 | + |
| 34 | + # Get all desired model save specifications as list |
| 35 | + paths = [m["path"]] if isinstance(m["path"], str) else m["path"] |
| 36 | + |
| 37 | + # Depending on specified model save specifications, save desired model formats |
| 38 | + if any(path in ("all", "torchscript", "trace") for path in paths): |
| 39 | + # (TorchScript) Traced model |
| 40 | + trace_model = torch.jit.trace(m["model"], [inp.cuda() for inp in m["inputs"]]) |
| 41 | + torch.jit.save(trace_model, traced_filename) |
59 | 42 | manifest.update({n: [traced_filename]}) |
60 | | - else: |
61 | | - m["model"] = m["model"].eval().cuda() |
62 | | - |
63 | | - # Get all desired model save specifications as list |
64 | | - paths = [m["path"]] if isinstance(m["path"], str) else m["path"] |
65 | | - |
66 | | - # Depending on specified model save specifications, save desired model formats |
67 | | - if any(path in ("all", "torchscript", "trace") for path in paths): |
68 | | - # (TorchScript) Traced model |
69 | | - trace_model = torch.jit.trace(m["model"], [x]) |
70 | | - torch.jit.save(trace_model, traced_filename) |
71 | | - manifest.update({n: [traced_filename]}) |
72 | | - if any(path in ("all", "torchscript", "script") for path in paths): |
73 | | - # (TorchScript) Scripted model |
74 | | - script_model = torch.jit.script(m["model"]) |
75 | | - torch.jit.save(script_model, script_filename) |
76 | | - if n in manifest.keys(): |
77 | | - files = list(manifest[n]) if type(manifest[n]) != list else manifest[n] |
78 | | - files.append(script_filename) |
79 | | - manifest.update({n: files}) |
80 | | - else: |
81 | | - manifest.update({n: [script_filename]}) |
82 | | - if any(path in ("all", "pytorch") for path in paths): |
83 | | - # (PyTorch Module) model |
84 | | - torch.save(m["model"], pytorch_filename) |
85 | | - if n in manifest.keys(): |
86 | | - files = list(manifest[n]) if type(manifest[n]) != list else manifest[n] |
87 | | - files.append(script_filename) |
88 | | - manifest.update({n: files}) |
89 | | - else: |
90 | | - manifest.update({n: [script_filename]}) |
| 43 | + if any(path in ("all", "torchscript", "script") for path in paths): |
| 44 | + # (TorchScript) Scripted model |
| 45 | + script_model = torch.jit.script(m["model"]) |
| 46 | + torch.jit.save(script_model, script_filename) |
| 47 | + if n in manifest.keys(): |
| 48 | + files = list(manifest[n]) if type(manifest[n]) != list else manifest[n] |
| 49 | + files.append(script_filename) |
| 50 | + manifest.update({n: files}) |
| 51 | + else: |
| 52 | + manifest.update({n: [script_filename]}) |
| 53 | + if any(path in ("all", "pytorch") for path in paths): |
| 54 | + # (PyTorch Module) model |
| 55 | + torch.save(m["model"], pytorch_filename) |
| 56 | + if n in manifest.keys(): |
| 57 | + files = list(manifest[n]) if type(manifest[n]) != list else manifest[n] |
| 58 | + files.append(script_filename) |
| 59 | + manifest.update({n: files}) |
| 60 | + else: |
| 61 | + manifest.update({n: [script_filename]}) |
| 62 | + |
91 | 63 | return manifest |
92 | 64 |
|
93 | 65 |
|
|
0 commit comments