Skip to content

Commit a8b43af

Browse files
authored
Merge pull request #155 from theabhirath/revert-conv_bn
2 parents 792076f + f066d5a commit a8b43af

File tree

8 files changed

+80
-81
lines changed

8 files changed

+80
-81
lines changed

src/convnets/convmixer.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ Creates a ConvMixer model.
1717
function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9),
1818
patch_size::Dims{2} = (7, 7), activation = gelu, nclasses = 1000)
1919
stem = conv_bn(patch_size, inchannels, planes, activation; preact = true, stride = patch_size[1])
20-
blocks = [Chain(SkipConnection(conv_bn(kernel_size, planes, planes, activation;
21-
preact = true, groups = planes, pad = SamePad()), +),
22-
conv_bn((1, 1), planes, planes, activation; preact = true)) for _ in 1:depth]
20+
blocks = [Chain(SkipConnection(Chain(conv_bn(kernel_size, planes, planes, activation;
21+
preact = true, groups = planes, pad = SamePad())), +),
22+
conv_bn((1, 1), planes, planes, activation; preact = true)...) for _ in 1:depth]
2323
head = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(planes, nclasses))
24-
return Chain(Chain(stem, Chain(blocks)), head)
24+
return Chain(Chain(stem..., Chain(blocks)), head)
2525
end
2626

2727
convmixer_config = Dict(:base => Dict(:planes => 1536, :depth => 20, :kernel_size => (9, 9),

src/convnets/densenet.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ Create a Densenet bottleneck layer
1111
"""
1212
function dense_bottleneck(inplanes, outplanes)
1313
inner_channels = 4 * outplanes
14-
m = Chain(conv_bn((1, 1), inplanes, inner_channels; bias = false, rev = true),
15-
conv_bn((3, 3), inner_channels, outplanes; pad = 1, bias = false, rev = true))
14+
m = Chain(conv_bn((1, 1), inplanes, inner_channels; bias = false, rev = true)...,
15+
conv_bn((3, 3), inner_channels, outplanes; pad = 1, bias = false, rev = true)...)
1616

1717
SkipConnection(m, cat_channels)
1818
end
@@ -28,7 +28,7 @@ Create a DenseNet transition sequence
2828
- `outplanes`: number of output feature maps
2929
"""
3030
transition(inplanes, outplanes) =
31-
Chain(conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true), MeanPool((2, 2)))
31+
Chain(conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true)..., MeanPool((2, 2)))
3232

3333
"""
3434
dense_block(inplanes, growth_rates)
@@ -60,7 +60,7 @@ Create a DenseNet model
6060
"""
6161
function densenet(inplanes, growth_rates; reduction = 0.5, nclasses = 1000)
6262
layers = []
63-
push!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = (3, 3), bias = false))
63+
append!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = (3, 3), bias = false))
6464
push!(layers, MaxPool((3, 3), stride = 2, pad = (1, 1)))
6565

6666
outplanes = 0

src/convnets/inception.jl

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,17 @@ Create an Inception-v3 style-A module
99
- `pool_proj`: the number of output feature maps for the pooling projection
1010
"""
1111
function inception_a(inplanes, pool_proj)
12-
branch1x1 = conv_bn((1, 1), inplanes, 64)
12+
branch1x1 = Chain(conv_bn((1, 1), inplanes, 64))
1313

14-
branch5x5 = Chain(conv_bn((1, 1), inplanes, 48),
15-
conv_bn((5, 5), 48, 64; pad = 2))
14+
branch5x5 = Chain(conv_bn((1, 1), inplanes, 48)...,
15+
conv_bn((5, 5), 48, 64; pad = 2)...)
1616

17-
branch3x3 = Chain(conv_bn((1, 1), inplanes, 64),
18-
conv_bn((3, 3), 64, 96; pad = 1),
19-
conv_bn((3, 3), 96, 96; pad = 1))
17+
branch3x3 = Chain(conv_bn((1, 1), inplanes, 64)...,
18+
conv_bn((3, 3), 64, 96; pad = 1)...,
19+
conv_bn((3, 3), 96, 96; pad = 1)...)
2020

2121
branch_pool = Chain(MeanPool((3, 3), pad = 1, stride = 1),
22-
conv_bn((1, 1), inplanes, pool_proj))
22+
conv_bn((1, 1), inplanes, pool_proj)...)
2323

2424
return Parallel(cat_channels,
2525
branch1x1, branch5x5, branch3x3, branch_pool)
@@ -35,11 +35,11 @@ Create an Inception-v3 style-B module
3535
- `inplanes`: number of input feature maps
3636
"""
3737
function inception_b(inplanes)
38-
branch3x3_1 = conv_bn((3, 3), inplanes, 384; stride = 2)
38+
branch3x3_1 = Chain(conv_bn((3, 3), inplanes, 384; stride = 2))
3939

40-
branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 64),
41-
conv_bn((3, 3), 64, 96; pad = 1),
42-
conv_bn((3, 3), 96, 96; stride = 2))
40+
branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 64)...,
41+
conv_bn((3, 3), 64, 96; pad = 1)...,
42+
conv_bn((3, 3), 96, 96; stride = 2)...)
4343

4444
branch_pool = MaxPool((3, 3), stride = 2)
4545

@@ -59,20 +59,20 @@ Create an Inception-v3 style-C module
5959
- `n`: the "grid size" (kernel size) for the convolution layers
6060
"""
6161
function inception_c(inplanes, inner_planes, n = 7)
62-
branch1x1 = conv_bn((1, 1), inplanes, 192)
62+
branch1x1 = Chain(conv_bn((1, 1), inplanes, 192))
6363

64-
branch7x7_1 = Chain(conv_bn((1, 1), inplanes, inner_planes),
65-
conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3)),
66-
conv_bn((n, 1), inner_planes, 192; pad = (3, 0)))
64+
branch7x7_1 = Chain(conv_bn((1, 1), inplanes, inner_planes)...,
65+
conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3))...,
66+
conv_bn((n, 1), inner_planes, 192; pad = (3, 0))...)
6767

68-
branch7x7_2 = Chain(conv_bn((1, 1), inplanes, inner_planes),
69-
conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0)),
70-
conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3)),
71-
conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0)),
72-
conv_bn((1, n), inner_planes, 192; pad = (0, 3)))
68+
branch7x7_2 = Chain(conv_bn((1, 1), inplanes, inner_planes)...,
69+
conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0))...,
70+
conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3))...,
71+
conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0))...,
72+
conv_bn((1, n), inner_planes, 192; pad = (0, 3))...)
7373

7474
branch_pool = Chain(MeanPool((3, 3), pad = 1, stride=1),
75-
conv_bn((1, 1), inplanes, 192))
75+
conv_bn((1, 1), inplanes, 192)...)
7676

7777
return Parallel(cat_channels,
7878
branch1x1, branch7x7_1, branch7x7_2, branch_pool)
@@ -88,13 +88,13 @@ Create an Inception-v3 style-D module
8888
- `inplanes`: number of input feature maps
8989
"""
9090
function inception_d(inplanes)
91-
branch3x3 = Chain(conv_bn((1, 1), inplanes, 192),
92-
conv_bn((3, 3), 192, 320; stride = 2))
91+
branch3x3 = Chain(conv_bn((1, 1), inplanes, 192)...,
92+
conv_bn((3, 3), 192, 320; stride = 2)...)
9393

94-
branch7x7x3 = Chain(conv_bn((1, 1), inplanes, 192),
95-
conv_bn((1, 7), 192, 192; pad = (0, 3)),
96-
conv_bn((7, 1), 192, 192; pad = (3, 0)),
97-
conv_bn((3, 3), 192, 192; stride = 2))
94+
branch7x7x3 = Chain(conv_bn((1, 1), inplanes, 192)...,
95+
conv_bn((1, 7), 192, 192; pad = (0, 3))...,
96+
conv_bn((7, 1), 192, 192; pad = (3, 0))...,
97+
conv_bn((3, 3), 192, 192; stride = 2)...)
9898

9999
branch_pool = MaxPool((3, 3), stride=2)
100100

@@ -112,19 +112,19 @@ Create an Inception-v3 style-E module
112112
- `inplanes`: number of input feature maps
113113
"""
114114
function inception_e(inplanes)
115-
branch1x1 = conv_bn((1, 1), inplanes, 320)
115+
branch1x1 = Chain(conv_bn((1, 1), inplanes, 320))
116116

117-
branch3x3_1 = conv_bn((1, 1), inplanes, 384)
118-
branch3x3_1a = conv_bn((1, 3), 384, 384; pad = (0, 1))
119-
branch3x3_1b = conv_bn((3, 1), 384, 384; pad = (1, 0))
117+
branch3x3_1 = Chain(conv_bn((1, 1), inplanes, 384))
118+
branch3x3_1a = Chain(conv_bn((1, 3), 384, 384; pad = (0, 1)))
119+
branch3x3_1b = Chain(conv_bn((3, 1), 384, 384; pad = (1, 0)))
120120

121-
branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 448),
122-
conv_bn((3, 3), 448, 384; pad = 1))
123-
branch3x3_2a = conv_bn((1, 3), 384, 384; pad = (0, 1))
124-
branch3x3_2b = conv_bn((3, 1), 384, 384; pad = (1, 0))
121+
branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 448)...,
122+
conv_bn((3, 3), 448, 384; pad = 1)...)
123+
branch3x3_2a = Chain(conv_bn((1, 3), 384, 384; pad = (0, 1)))
124+
branch3x3_2b = Chain(conv_bn((3, 1), 384, 384; pad = (1, 0)))
125125

126126
branch_pool = Chain(MeanPool((3, 3), pad = 1, stride = 1),
127-
conv_bn((1, 1), inplanes, 192))
127+
conv_bn((1, 1), inplanes, 192)...)
128128

129129
return Parallel(cat_channels,
130130
branch1x1,
@@ -150,12 +150,12 @@ Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)).
150150
`inception3` does not currently support pretrained weights.
151151
"""
152152
function inception3(; nclasses = 1000)
153-
layer = Chain(Chain(conv_bn((3, 3), 3, 32; stride = 2),
154-
conv_bn((3, 3), 32, 32),
155-
conv_bn((3, 3), 32, 64; pad = 1),
153+
layer = Chain(Chain(conv_bn((3, 3), 3, 32; stride = 2)...,
154+
conv_bn((3, 3), 32, 32)...,
155+
conv_bn((3, 3), 32, 64; pad = 1)...,
156156
MaxPool((3, 3), stride = 2),
157-
conv_bn((1, 1), 64, 80),
158-
conv_bn((3, 3), 80, 192),
157+
conv_bn((1, 1), 64, 80)...,
158+
conv_bn((3, 3), 80, 192)...,
159159
MaxPool((3, 3), stride = 2),
160160
inception_a(192, 32),
161161
inception_a(256, 64),

src/convnets/mobilenet.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function mobilenetv1(width_mult, config;
3434
layer = dw ? depthwise_sep_conv_bn((3, 3), inchannels, outch, activation;
3535
stride = stride, pad = 1) :
3636
conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1)
37-
push!(layers, layer)
37+
append!(layers, layer)
3838
inchannels = outch
3939
end
4040
end
@@ -118,7 +118,7 @@ function mobilenetv2(width_mult, configs; max_width = 1280, nclasses = 1000)
118118
# building first layer
119119
inplanes = _round_channels(32 * width_mult, width_mult == 0.1 ? 4 : 8)
120120
layers = []
121-
push!(layers, conv_bn((3, 3), 3, inplanes, stride = 2))
121+
append!(layers, conv_bn((3, 3), 3, inplanes, stride = 2))
122122

123123
# building inverted residual blocks
124124
for (t, c, n, s, a) in configs
@@ -134,7 +134,7 @@ function mobilenetv2(width_mult, configs; max_width = 1280, nclasses = 1000)
134134
outplanes = (width_mult > 1) ? _round_channels(max_width * width_mult, width_mult == 0.1 ? 4 : 8) :
135135
max_width
136136

137-
return Chain(Chain(Chain(layers), conv_bn((1, 1), inplanes, outplanes, relu6, bias = false)),
137+
return Chain(Chain(Chain(layers), conv_bn((1, 1), inplanes, outplanes, relu6, bias = false)...),
138138
Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(outplanes, nclasses)))
139139
end
140140

@@ -211,7 +211,7 @@ function mobilenetv3(width_mult, configs; max_width = 1024, nclasses = 1000)
211211
# building first layer
212212
inplanes = _round_channels(16 * width_mult, 8)
213213
layers = []
214-
push!(layers, conv_bn((3, 3), 3, inplanes, hardswish; stride = 2))
214+
append!(layers, conv_bn((3, 3), 3, inplanes, hardswish; stride = 2))
215215
explanes = 0
216216
# building inverted residual blocks
217217
for (k, t, c, r, a, s) in configs
@@ -230,7 +230,7 @@ function mobilenetv3(width_mult, configs; max_width = 1024, nclasses = 1000)
230230
Dropout(0.2),
231231
Dense(output_channel, nclasses))
232232

233-
return Chain(Chain(Chain(layers), conv_bn((1, 1), inplanes, explanes, hardswish, bias = false)),
233+
return Chain(Chain(Chain(layers), conv_bn((1, 1), inplanes, explanes, hardswish, bias = false)...),
234234
Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier))
235235
end
236236

src/convnets/resnet.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ Create a basic residual block
1212
"""
1313
function basicblock(inplanes, outplanes, downsample = false)
1414
stride = downsample ? 2 : 1
15-
Chain(conv_bn((3, 3), inplanes, outplanes[1]; stride = stride, pad = 1, bias = false),
16-
conv_bn((3, 3), outplanes[1], outplanes[2], identity; stride = 1, pad = 1, bias = false))
15+
Chain(conv_bn((3, 3), inplanes, outplanes[1]; stride = stride, pad = 1, bias = false)...,
16+
conv_bn((3, 3), outplanes[1], outplanes[2], identity; stride = 1, pad = 1, bias = false)...)
1717
end
1818

1919
"""
@@ -36,12 +36,11 @@ The original paper uses `stride == [2, 1, 1]` when `downsample == true` instead.
3636
"""
3737
function bottleneck(inplanes, outplanes, downsample = false;
3838
stride = [1, (downsample ? 2 : 1), 1])
39-
Chain(conv_bn((1, 1), inplanes, outplanes[1]; stride = stride[1], bias = false),
40-
conv_bn((3, 3), outplanes[1], outplanes[2]; stride = stride[2], pad = 1, bias = false),
41-
conv_bn((1, 1), outplanes[2], outplanes[3], identity; stride = stride[3], bias = false))
39+
Chain(conv_bn((1, 1), inplanes, outplanes[1]; stride = stride[1], bias = false)...,
40+
conv_bn((3, 3), outplanes[1], outplanes[2]; stride = stride[2], pad = 1, bias = false)...,
41+
conv_bn((1, 1), outplanes[2], outplanes[3], identity; stride = stride[3], bias = false)...)
4242
end
4343

44-
4544
"""
4645
bottleneck_v1(inplanes, outplanes, downsample = false)
4746
@@ -82,7 +81,7 @@ function resnet(block, residuals::AbstractVector{<:NTuple{2, Any}}, connection =
8281
inplanes = 64
8382
baseplanes = 64
8483
layers = []
85-
push!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = 3, bias = false))
84+
append!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = 3, bias = false))
8685
push!(layers, MaxPool((3, 3), stride = (2, 2), pad = (1, 1)))
8786
for (i, nrepeats) in enumerate(block_config)
8887
# output planes within a block

src/convnets/resnext.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ Create a basic residual block as defined in the paper for ResNeXt
1414
function resnextblock(inplanes, outplanes, cardinality, width, downsample = false)
1515
stride = downsample ? 2 : 1
1616
hidden_channels = cardinality * width
17-
return Chain(conv_bn((1, 1), inplanes, hidden_channels; stride = 1, bias = false),
17+
return Chain(conv_bn((1, 1), inplanes, hidden_channels; stride = 1, bias = false)...,
1818
conv_bn((3, 3), hidden_channels, hidden_channels;
19-
stride = stride, pad = 1, bias = false, groups = cardinality),
20-
conv_bn((1, 1), hidden_channels, outplanes; stride = 1, bias = false))
19+
stride = stride, pad = 1, bias = false, groups = cardinality)...,
20+
conv_bn((1, 1), hidden_channels, outplanes; stride = 1, bias = false)...)
2121
end
2222

2323
"""
@@ -40,7 +40,7 @@ function resnext(cardinality, width, widen_factor = 2, connection = (x, y) -> @.
4040
inplanes = 64
4141
baseplanes = 128
4242
layers = []
43-
push!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = (3, 3)))
43+
append!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = (3, 3)))
4444
push!(layers, MaxPool((3, 3), stride = (2, 2), pad = (1, 1)))
4545
for (i, nrepeats) in enumerate(block_config)
4646
# output planes within a block

src/convnets/vgg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function vgg_block(ifilters, ofilters, depth, batchnorm)
1616
layers = []
1717
for _ in 1:depth
1818
if batchnorm
19-
push!(layers, conv_bn(k, ifilters, ofilters; pad = p, bias = false))
19+
append!(layers, conv_bn(k, ifilters, ofilters; pad = p, bias = false))
2020
else
2121
push!(layers, Conv(k, ifilters => ofilters, relu, pad = p))
2222
end

src/layers/conv.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function conv_bn(kernelsize, inplanes, outplanes, activation = relu;
4545
push!(layers, BatchNorm(Int(bnplanes), activations.bn;
4646
initβ = initβ, initγ = initγ, ϵ = ϵ, momentum = momentum))
4747

48-
return rev ? Chain(reverse(layers)) : Chain(layers)
48+
return rev ? reverse(layers) : layers
4949
end
5050

5151
"""
@@ -82,13 +82,13 @@ depthwise_sep_conv_bn(kernelsize, inplanes, outplanes, activation = relu;
8282
initβ = Flux.zeros32, initγ = Flux.ones32,
8383
ϵ = 1f-5, momentum = 1f-1,
8484
stride = 1, kwargs...) =
85-
Chain(vcat(conv_bn(kernelsize, inplanes, inplanes, activation;
86-
rev = rev, initβ = initβ, initγ = initγ,
87-
ϵ = ϵ, momentum = momentum,
88-
stride = stride, groups = Int(inplanes), kwargs...),
89-
conv_bn((1, 1), inplanes, outplanes, activation;
90-
rev = rev, initβ = initβ, initγ = initγ,
91-
ϵ = ϵ, momentum = momentum)))
85+
vcat(conv_bn(kernelsize, inplanes, inplanes, activation;
86+
rev = rev, initβ = initβ, initγ = initγ,
87+
ϵ = ϵ, momentum = momentum,
88+
stride = stride, groups = Int(inplanes), kwargs...),
89+
conv_bn((1, 1), inplanes, outplanes, activation;
90+
rev = rev, initβ = initβ, initγ = initγ,
91+
ϵ = ϵ, momentum = momentum))
9292

9393
"""
9494
skip_projection(inplanes, outplanes, downsample = false)
@@ -102,8 +102,8 @@ Create a skip projection
102102
- `downsample`: set to `true` to downsample the input
103103
"""
104104
skip_projection(inplanes, outplanes, downsample = false) = downsample ?
105-
conv_bn((1, 1), inplanes, outplanes, identity; stride = 2, bias = false) :
106-
conv_bn((1, 1), inplanes, outplanes, identity; stride = 1, bias = false)
105+
Chain(conv_bn((1, 1), inplanes, outplanes, identity; stride = 2, bias = false)) :
106+
Chain(conv_bn((1, 1), inplanes, outplanes, identity; stride = 1, bias = false))
107107

108108
# array -> PaddedView(0, array, outplanes) for zero padding arrays
109109
"""
@@ -144,8 +144,8 @@ Squeeze and excitation layer used by MobileNet variants
144144
function squeeze_excite(channels, reduction = 4)
145145
@assert (reduction >= 1) "`reduction` must be >= 1"
146146
SkipConnection(Chain(AdaptiveMeanPool((1, 1)),
147-
conv_bn((1, 1), channels, channels ÷ reduction, relu; bias = false),
148-
conv_bn((1, 1), channels ÷ reduction, channels, hardσ)), .*)
147+
conv_bn((1, 1), channels, channels ÷ reduction, relu; bias = false)...,
148+
conv_bn((1, 1), channels ÷ reduction, channels, hardσ)...), .*)
149149
end
150150

151151
"""
@@ -171,14 +171,14 @@ function invertedresidual(kernel_size, inplanes, hidden_planes, outplanes, activ
171171
@assert stride in [1, 2] "`stride` has to be 1 or 2"
172172

173173
pad = @. (kernel_size - 1) ÷ 2
174-
conv1 = (inplanes == hidden_planes) ? identity : conv_bn((1, 1), inplanes, hidden_planes, activation; bias = false)
174+
conv1 = (inplanes == hidden_planes) ? identity : Chain(conv_bn((1, 1), inplanes, hidden_planes, activation; bias = false))
175175
selayer = isnothing(reduction) ? identity : squeeze_excite(hidden_planes, reduction)
176176

177177
invres = Chain(conv1,
178178
conv_bn(kernel_size, hidden_planes, hidden_planes, activation;
179-
bias = false, stride, pad = pad, groups = hidden_planes),
179+
bias = false, stride, pad = pad, groups = hidden_planes)...,
180180
selayer,
181-
conv_bn((1, 1), hidden_planes, outplanes, identity; bias = false))
181+
conv_bn((1, 1), hidden_planes, outplanes, identity; bias = false)...)
182182

183183
(stride == 1 && inplanes == outplanes) ? SkipConnection(invres, +) : invres
184184
end

0 commit comments

Comments
 (0)