Skip to content

Commit 01c6e0d

Browse files
committed
Add pretrained weights for MobileNet v2 and v3
* Adds pre-trained weights for MobileNet v3 small and large models * Adds pre-trained weights for MobileNet v2 * Adds, commented, lines to the model_list in the port_torchvision.jl script for generating these weights from torchvision models * Replaces the test image (of a guitar) in the pytorch2flux.jl script, the link is currently unavailable
1 parent 75a7159 commit 01c6e0d

File tree

6 files changed

+32
-12
lines changed

6 files changed

+32
-12
lines changed

Artifacts.toml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,27 @@
1+
[mobilenet_v2]
2+
git-tree-sha1 = "1552eb3a913b8ea964482bd68a8fd6e623f57b45"
3+
lazy = true
4+
5+
[[mobilenet_v2.download]]
6+
sha256 = "24d8afb8897cb82825c51e65c7a02f4b2117581294de25ac7fd5222bbc15ae6e"
7+
url = ""
8+
9+
[mobilenet_v3_small]
10+
git-tree-sha1 = "9b1f0550051c731056cda7739fcdd115c36c04ad"
11+
lazy = true
12+
13+
[[mobilenet_v3_small.download]]
14+
sha256 = "35b7f6d733bfbd1621349a0ec27391c2714add7ca80006fb4032d8bc66629c97"
15+
url = ""
16+
17+
[mobilenet_v3_large]
18+
git-tree-sha1 = "49971ff8327bc591885e78ff94140ee472f77329"
19+
lazy = true
20+
21+
[[mobilenet_v3_large.download]]
22+
sha256 = "555fcb5f4f6574d77b603b2fc6672ab437ef60a1a20bc2f951122c91aaaf2f69"
23+
url = ""
24+
125
[resnet101]
226
git-tree-sha1 = "68d563526ab34d3e5aa66b7d96278d2acde212f9"
327
lazy = true

scripts/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ ArtifactUtils = "8b73e784-e7d8-4ea5-973d-377fed4e3bce"
33
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
44
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
55
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
6+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
67
HuggingFaceApi = "3cc741c3-0c9d-4fbe-84fa-cdec264173de"
78
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
89
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
910
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1011
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
1112
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
1213
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
14+
TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"

scripts/port_torchvision.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ model_list = [
3333
# ("vit_b_32", "IMAGENET1K_V1", () -> ViT(:base, patch_size=(32,32)), weights -> tvmodels.vit_b_32(; weights)),
3434
# ("vit_l_16", "IMAGENET1K_V1", () -> ViT(:large), weights -> tvmodels.vit_l_16(; weights)),
3535
# ("vit_l_32", "IMAGENET1K_V1", () -> ViT(:large, patch_size=(32,32)), weights -> tvmodels.vit_l_32(; weights)),
36+
# ("mobilenet_v2", "IMAGENET1K_V2", () -> MobileNetv2(), weights -> tvmodels.mobilenet_v2(; weights)),
37+
# ("mobilenet_v3_small", "IMAGENET1K_V1", () -> MobileNetv3(:small), weights -> tvmodels.mobilenet_v3_small(; weights)),
38+
# ("mobilenet_v3_large", "IMAGENET1K_V2", () -> MobileNetv3(:large), weights -> tvmodels.mobilenet_v3_large(; weights)),
3639
## NOT WORKING:
3740
("densenet121", "IMAGENET1K_V1", () -> DenseNet(121), weights -> tvmodels.densenet121(; weights)),
3841
# ("squeezenet1_0", "IMAGENET1K_V1", () -> SqueezeNet(), weights -> tvmodels.squeezenet1_0(; weights)),

scripts/pytorch2flux.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,18 @@ using BSON
1010
using PythonCall
1111
using Images
1212
using Test
13+
using TestImages
1314

1415
include("utils.jl")
1516

1617
const torch = pyimport("torch")
1718
const torchvision = pyimport("torchvision")
1819

19-
# test image
20-
const GUITAR_PATH = download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg")
2120
const IMAGENET_LABELS = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"))
2221

2322
function compare_pytorch(jlmodel, pymodel; rtol = 1e-4)
2423
sz = (224, 224)
25-
img = Images.load(GUITAR_PATH);
24+
img = testimage("monarch_color_256")
2625
img = imresize(img, sz);
2726
# CHW -> WHC
2827
data = permutedims(convert(Array{Float32}, channelview(img)), (3,2,1))

src/convnets/mobilenets/mobilenetv2.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,6 @@ Create a MobileNetv2 model with the specified configuration.
5757
- `inchannels`: The number of input channels.
5858
- `nclasses`: The number of output classes
5959
60-
!!! warning
61-
62-
`MobileNetv2` does not currently support pretrained weights.
63-
6460
See also [`Metalhead.mobilenetv2`](@ref).
6561
"""
6662
struct MobileNetv2

src/convnets/mobilenets/mobilenetv3.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,6 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet.
7878
- `inchannels`: number of input channels
7979
- `nclasses`: the number of output classes
8080
81-
!!! warning
82-
83-
`MobileNetv3` does not currently support pretrained weights.
84-
8581
See also [`Metalhead.mobilenetv3`](@ref).
8682
"""
8783
struct MobileNetv3
@@ -94,7 +90,7 @@ function MobileNetv3(config::Symbol; width_mult::Real = 1, pretrain::Bool = fals
9490
layers = mobilenetv3(config; width_mult, inchannels, nclasses)
9591
model = MobileNetv3(layers)
9692
if pretrain
97-
loadpretrain!(model, "mobilenet_v3")
93+
loadpretrain!(model, string("mobilenet_v3_", config))
9894
end
9995
return model
10096
end

0 commit comments

Comments
 (0)