Skip to content

Commit b039aef

Browse files
paulnovo authored and CarloLucibello committed
Match torchvision MobileNet V3 implementation
* Set epsilon to 0.001 and momentum to 0.01 in all BatchNorm layers.
* MobileNet V3 uses a ReLU activation in all squeeze-and-excitation layers.
* Fix the stride in the final mbconv block of both the small and large models; it should be 2 instead of 1.
1 parent c533d51 commit b039aef

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

src/convnets/mobilenets/mobilenetv3.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ const MOBILENETV3_CONFIGS = Dict(:small => (1024,
1616
(mbconv, 5, 40, 4, 2, 1, 4, hardswish),
1717
(mbconv, 5, 40, 6, 1, 2, 4, hardswish),
1818
(mbconv, 5, 48, 3, 1, 2, 4, hardswish),
19-
(mbconv, 5, 96, 6, 1, 3, 4, hardswish),
19+
(mbconv, 5, 96, 6, 2, 3, 4, hardswish),
2020
]),
2121
:large => (1280,
2222
[
@@ -31,7 +31,7 @@ const MOBILENETV3_CONFIGS = Dict(:small => (1024,
3131
(mbconv, 3, 80, 2.3, 1, 2, nothing,
3232
hardswish),
3333
(mbconv, 3, 112, 6, 1, 2, 4, hardswish),
34-
(mbconv, 5, 160, 6, 1, 3, 4, hardswish),
34+
(mbconv, 5, 160, 6, 2, 3, 4, hardswish),
3535
]))
3636

3737
"""
@@ -54,9 +54,10 @@ function mobilenetv3(config::Symbol; width_mult::Real = 1, dropout_prob = 0.2,
5454
inchannels::Integer = 3, nclasses::Integer = 1000)
5555
_checkconfig(config, [:small, :large])
5656
max_width, block_configs = MOBILENETV3_CONFIGS[config]
57+
norm_layer = (args...; kwargs...) -> BatchNorm(args...; momentum=1.0f-2, eps=1.0f-3, kwargs...)
5758
return build_invresmodel(width_mult, block_configs; inplanes = 16,
58-
headplanes = max_width, activation = relu,
59-
se_from_explanes = true, se_round_fn = _round_channels,
59+
headplanes = max_width, activation = hardswish, norm_layer,
60+
se_activation = relu, se_from_explanes = true, se_round_fn = _round_channels,
6061
expanded_classifier = true, dropout_prob, inchannels, nclasses)
6162
end
6263

src/layers/mbconv.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ First introduced in the MobileNetv2 paper.
8181
function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer,
8282
outplanes::Integer, activation = relu; stride::Integer,
8383
reduction::Union{Nothing, Real} = nothing,
84+
se_activation = activation,
8485
se_round_fn = x -> round(Int, x), norm_layer = BatchNorm)
8586
@assert stride in [1, 2] "`stride` has to be 1 or 2 for `mbconv`"
8687
layers = []
@@ -97,10 +98,10 @@ function mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer,
9798
if !isnothing(reduction)
9899
push!(layers,
99100
squeeze_excite(explanes; round_fn = se_round_fn, reduction,
100-
activation, gate_activation = hardσ))
101+
activation=se_activation, gate_activation = hardσ))
101102
end
102103
# project
103-
append!(layers, conv_norm((1, 1), explanes, outplanes, identity))
104+
append!(layers, conv_norm((1, 1), explanes, outplanes, identity; norm_layer))
104105
return Chain(layers...)
105106
end
106107

0 commit comments

Comments
 (0)