diff --git a/Artifacts.toml b/Artifacts.toml index edd81b16..e4eb2ec8 100644 --- a/Artifacts.toml +++ b/Artifacts.toml @@ -1,3 +1,27 @@ +[mobilenet_v2] +git-tree-sha1 = "1552eb3a913b8ea964482bd68a8fd6e623f57b45" +lazy = true + + [[mobilenet_v2.download]] + sha256 = "24d8afb8897cb82825c51e65c7a02f4b2117581294de25ac7fd5222bbc15ae6e" + url = "https://huggingface.co/FluxML/mobilenet/resolve/ccf1ffc5bbb4f2d3b9b7fb7e1285e52de337774e/mobilenet_v2-IMAGENET1K_V2.tar.gz" + +[mobilenet_v3_small] +git-tree-sha1 = "9b1f0550051c731056cda7739fcdd115c36c04ad" +lazy = true + + [[mobilenet_v3_small.download]] + sha256 = "35b7f6d733bfbd1621349a0ec27391c2714add7ca80006fb4032d8bc66629c97" + url = "https://huggingface.co/FluxML/mobilenet/resolve/ccf1ffc5bbb4f2d3b9b7fb7e1285e52de337774e/mobilenet_v3_small-IMAGENET1K_V1.tar.gz" + +[mobilenet_v3_large] +git-tree-sha1 = "49971ff8327bc591885e78ff94140ee472f77329" +lazy = true + + [[mobilenet_v3_large.download]] + sha256 = "555fcb5f4f6574d77b603b2fc6672ab437ef60a1a20bc2f951122c91aaaf2f69" + url = "https://huggingface.co/FluxML/mobilenet/resolve/ccf1ffc5bbb4f2d3b9b7fb7e1285e52de337774e/mobilenet_v3_large-IMAGENET1K_V2.tar.gz" + [resnet101] git-tree-sha1 = "68d563526ab34d3e5aa66b7d96278d2acde212f9" lazy = true diff --git a/README.md b/README.md index c25e01a9..f3f5e123 100644 --- a/README.md +++ b/README.md @@ -37,8 +37,8 @@ To contribute new models, see our [contributing docs](https://fluxml.ai/Metalhea | [InceptionResNet-v2](https://arxiv.org/abs/1602.07261) | [`InceptionResNetv2`](https://fluxml.ai/Metalhead.jl/dev/api/inception/#Metalhead.InceptionResNetv2) | N | | [MLPMixer](https://arxiv.org/pdf/2105.01601) | [`MLPMixer`](https://fluxml.ai/Metalhead.jl/dev/api/mixers/#Metalhead.MLPMixer) | N | | [MobileNetv1](https://arxiv.org/abs/1704.04861) | [`MobileNetv1`](https://fluxml.ai/Metalhead.jl/dev/api/mobilenet/#Metalhead.MobileNetv1) | N | -| [MobileNetv2](https://arxiv.org/abs/1801.04381) | [`MobileNetv2`](https://fluxml.ai/Metalhead.jl/dev/api/mobilenet/#Metalhead.MobileNetv2) | N | -| [MobileNetv3](https://arxiv.org/abs/1905.02244) | [`MobileNetv3`](https://fluxml.ai/Metalhead.jl/dev/api/mobilenet/#Metalhead.MobileNetv3) | N | +| [MobileNetv2](https://arxiv.org/abs/1801.04381) | [`MobileNetv2`](https://fluxml.ai/Metalhead.jl/dev/api/mobilenet/#Metalhead.MobileNetv2) | Y | +| [MobileNetv3](https://arxiv.org/abs/1905.02244) | [`MobileNetv3`](https://fluxml.ai/Metalhead.jl/dev/api/mobilenet/#Metalhead.MobileNetv3) | Y | | [MNASNet](https://arxiv.org/abs/1807.11626) | [`MNASNet`](https://fluxml.ai/Metalhead.jl/dev/api/efficientnet/#Metalhead.MNASNet) | N | | [ResMLP](https://arxiv.org/abs/2105.03404) | [`ResMLP`](https://fluxml.ai/Metalhead.jl/dev/api/mixers/#Metalhead.ResMLP) | N | | [ResNet](https://arxiv.org/abs/1512.03385) | [`ResNet`](https://fluxml.ai/Metalhead.jl/dev/api/resnet/#Metalhead.ResNet) | Y | diff --git a/scripts/Project.toml b/scripts/Project.toml index ba44ab17..18e0525b 100644 --- a/scripts/Project.toml +++ b/scripts/Project.toml @@ -3,6 +3,7 @@ ArtifactUtils = "8b73e784-e7d8-4ea5-973d-377fed4e3bce" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" HuggingFaceApi = "3cc741c3-0c9d-4fbe-84fa-cdec264173de" Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" @@ -10,3 +11,4 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990" diff --git a/scripts/port_torchvision.jl b/scripts/port_torchvision.jl index 76e26b7d..17b7e2fa 100644 --- a/scripts/port_torchvision.jl +++ b/scripts/port_torchvision.jl @@ -33,6 +33,9 @@ model_list = [ # ("vit_b_32", "IMAGENET1K_V1", () -> ViT(:base, patch_size=(32,32)), weights -> tvmodels.vit_b_32(; weights)), # ("vit_l_16", "IMAGENET1K_V1", () -> ViT(:large), weights -> tvmodels.vit_l_16(; weights)), # ("vit_l_32", "IMAGENET1K_V1", () -> ViT(:large, patch_size=(32,32)), weights -> tvmodels.vit_l_32(; weights)), + # ("mobilenet_v2", "IMAGENET1K_V2", () -> MobileNetv2(), weights -> tvmodels.mobilenet_v2(; weights)), + # ("mobilenet_v3_small", "IMAGENET1K_V1", () -> MobileNetv3(:small), weights -> tvmodels.mobilenet_v3_small(; weights)), + # ("mobilenet_v3_large", "IMAGENET1K_V2", () -> MobileNetv3(:large), weights -> tvmodels.mobilenet_v3_large(; weights)), ## NOT WORKING: ("densenet121", "IMAGENET1K_V1", () -> DenseNet(121), weights -> tvmodels.densenet121(; weights)), # ("squeezenet1_0", "IMAGENET1K_V1", () -> SqueezeNet(), weights -> tvmodels.squeezenet1_0(; weights)), diff --git a/scripts/pytorch2flux.jl b/scripts/pytorch2flux.jl index 92f68647..d2b5e394 100644 --- a/scripts/pytorch2flux.jl +++ b/scripts/pytorch2flux.jl @@ -10,19 +10,18 @@ using BSON using PythonCall using Images using Test +using TestImages include("utils.jl") const torch = pyimport("torch") const torchvision = pyimport("torchvision") -# test image -const GUITAR_PATH = download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg") const IMAGENET_LABELS = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")) function compare_pytorch(jlmodel, pymodel; rtol = 1e-4) sz = (224, 224) - img = Images.load(GUITAR_PATH); + img = testimage("monarch_color_256") img = imresize(img, sz); # CHW -> WHC data = permutedims(convert(Array{Float32}, channelview(img)), (3,2,1)) diff --git a/src/convnets/mobilenets/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl index 702aebb3..ff3f5d7a 100644 --- a/src/convnets/mobilenets/mobilenetv2.jl +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -57,10 +57,6 @@ Create a MobileNetv2 model with the specified configuration. - `inchannels`: The number of input channels. - `nclasses`: The number of output classes -!!! warning - - `MobileNetv2` does not currently support pretrained weights. - See also [`Metalhead.mobilenetv2`](@ref). """ struct MobileNetv2 @@ -73,6 +69,9 @@ function MobileNetv2(width_mult::Real = 1; pretrain::Bool = false, layers = mobilenetv2(width_mult; inchannels, nclasses) model = MobileNetv2(layers) if pretrain + if width_mult != 1.0 + throw(ArgumentError("No pre-trained weights available for width_mult=$width_mult.")) + end loadpretrain!(model, "mobilenet_v2") end return model diff --git a/src/convnets/mobilenets/mobilenetv3.jl b/src/convnets/mobilenets/mobilenetv3.jl index 49f56dcb..84fe6041 100644 --- a/src/convnets/mobilenets/mobilenetv3.jl +++ b/src/convnets/mobilenets/mobilenetv3.jl @@ -16,7 +16,7 @@ const MOBILENETV3_CONFIGS = Dict(:small => (1024, (mbconv, 5, 40, 4, 2, 1, 4, hardswish), (mbconv, 5, 40, 6, 1, 2, 4, hardswish), (mbconv, 5, 48, 3, 1, 2, 4, hardswish), - (mbconv, 5, 96, 6, 1, 3, 4, hardswish), + (mbconv, 5, 96, 6, 2, 3, 4, hardswish), ]), :large => (1280, [ @@ -31,7 +31,7 @@ const MOBILENETV3_CONFIGS = Dict(:small => (1024, (mbconv, 3, 80, 2.3, 1, 2, nothing, hardswish), (mbconv, 3, 112, 6, 1, 2, 4, hardswish), - (mbconv, 5, 160, 6, 1, 3, 4, hardswish), + (mbconv, 5, 160, 6, 2, 3, 4, hardswish), ])) """ @@ -54,9 +54,10 @@ function mobilenetv3(config::Symbol; width_mult::Real = 1, dropout_prob = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, [:small, :large]) max_width, block_configs = MOBILENETV3_CONFIGS[config] + norm_layer = (args...; kwargs...) -> BatchNorm(args...; momentum=1.0f-2, eps=1.0f-3, kwargs...) return build_invresmodel(width_mult, block_configs; inplanes = 16, - headplanes = max_width, activation = relu, - se_from_explanes = true, se_round_fn = _round_channels, + headplanes = max_width, activation = hardswish, norm_layer, + se_activation = relu, se_from_explanes = true, se_round_fn = _round_channels, expanded_classifier = true, dropout_prob, inchannels, nclasses) end @@ -77,10 +78,6 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet. - `inchannels`: number of input channels - `nclasses`: the number of output classes -!!! warning - - `MobileNetv3` does not currently support pretrained weights. - See also [`Metalhead.mobilenetv3`](@ref). """ struct MobileNetv3 @@ -93,7 +90,10 @@ function MobileNetv3(config::Symbol; width_mult::Real = 1, pretrain::Bool = fals layers = mobilenetv3(config; width_mult, inchannels, nclasses) model = MobileNetv3(layers) if pretrain - loadpretrain!(model, "mobilenet_v3") + if width_mult != 1.0 + throw(ArgumentError("No pre-trained weights available for width_mult=$width_mult.")) + end + loadpretrain!(model, string("mobilenet_v3_", config)) end return model end diff --git a/src/layers/mbconv.jl b/src/layers/mbconv.jl index 6300880e..79fb4925 100644 --- a/src/layers/mbconv.jl +++ b/src/layers/mbconv.jl @@ -81,6 +81,7 @@ First introduced in the MobileNetv2 paper. function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, outplanes::Integer, activation = relu; stride::Integer, reduction::Union{Nothing, Real} = nothing, + se_activation = activation, se_round_fn = x -> round(Int, x), norm_layer = BatchNorm) @assert stride in [1, 2] "`stride` has to be 1 or 2 for `mbconv`" layers = [] @@ -97,10 +98,10 @@ function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, if !isnothing(reduction) push!(layers, squeeze_excite(explanes; round_fn = se_round_fn, reduction, - activation, gate_activation = hardσ)) + activation=se_activation, gate_activation = hardσ)) end # project - append!(layers, conv_norm((1, 1), explanes, outplanes, identity)) + append!(layers, conv_norm((1, 1), explanes, outplanes, identity; norm_layer)) return Chain(layers...) end diff --git a/test/convnet_tests.jl b/test/convnet_tests.jl index 630de19e..7526cdf7 100644 --- a/test/convnet_tests.jl +++ b/test/convnet_tests.jl @@ -307,9 +307,9 @@ end m = MobileNetv2(width_mult) |> gpu @test size(m(x_224)) == (1000, 1) if (MobileNetv2, width_mult) in PRETRAINED_MODELS - @test acctest(MobileNetv2(; pretrain = true)) + @test acctest(MobileNetv2(width_mult; pretrain = true)) else - @test_throws ArgumentError MobileNetv2(pretrain = true) + @test_throws ArgumentError MobileNetv2(width_mult; pretrain = true) end @test gradtest(m, x_224) end @@ -321,9 +321,9 @@ end m = MobileNetv3(config; width_mult) |> gpu @test size(m(x_224)) == (1000, 1) if (MobileNetv3, config, width_mult) in PRETRAINED_MODELS - @test acctest(MobileNetv3(config; pretrain = true)) + @test acctest(MobileNetv3(config; width_mult, pretrain = true)) else - @test_throws ArgumentError MobileNetv3(config; pretrain = true) + @test_throws ArgumentError MobileNetv3(config; width_mult, pretrain = true) end @test gradtest(m, x_224) _gc() diff --git a/test/model_tests.jl b/test/model_tests.jl index 7df7a7a4..51413ea8 100644 --- a/test/model_tests.jl +++ b/test/model_tests.jl @@ -24,6 +24,9 @@ const PRETRAINED_MODELS = [ # (DenseNet, 161), # (DenseNet, 169), # (DenseNet, 201), + (MobileNetv2, 1.0), + (MobileNetv3, :small, 1.0), + (MobileNetv3, :large, 1.0), (ResNet, 18), (ResNet, 34), (ResNet, 50), @@ -85,4 +88,4 @@ end const x_224 = rand(Float32, 224, 224, 3, 1) |> gpu const x_256 = rand(Float32, 256, 256, 3, 1) |> gpu -end \ No newline at end of file +end