1515from keras import layers
1616
1717from keras_nlp .src .api_export import keras_nlp_export
18- from keras_nlp .src .models .backbone import Backbone
18+ from keras_nlp .src .models .feature_pyramid_backbone import FeaturePyramidBackbone
1919
2020
2121@keras_nlp_export ("keras_nlp.models.CSPDarkNetBackbone" )
22- class CSPDarkNetBackbone (Backbone ):
22+ class CSPDarkNetBackbone (FeaturePyramidBackbone ):
2323 """This class represents Keras Backbone of CSPDarkNet model.
2424
2525 This class implements a CSPDarkNet backbone as described in
@@ -65,12 +65,15 @@ def __init__(
6565 self ,
6666 stackwise_num_filters ,
6767 stackwise_depth ,
68- include_rescaling ,
68+ include_rescaling = True ,
6969 block_type = "basic_block" ,
70- image_shape = (224 , 224 , 3 ),
70+ image_shape = (None , None , 3 ),
7171 ** kwargs ,
7272 ):
7373 # === Functional Model ===
74+ channel_axis = (
75+ - 1 if keras .config .image_data_format () == "channels_last" else 1
76+ )
7477 apply_ConvBlock = (
7578 apply_darknet_conv_block_depthwise
7679 if block_type == "depthwise_block"
@@ -83,15 +86,22 @@ def __init__(
8386 if include_rescaling :
8487 x = layers .Rescaling (scale = 1 / 255.0 )(x )
8588
86- x = apply_focus (name = "stem_focus" )(x )
89+ x = apply_focus (channel_axis , name = "stem_focus" )(x )
8790 x = apply_darknet_conv_block (
88- base_channels , kernel_size = 3 , strides = 1 , name = "stem_conv"
91+ base_channels ,
92+ channel_axis ,
93+ kernel_size = 3 ,
94+ strides = 1 ,
95+ name = "stem_conv" ,
8996 )(x )
97+
98+ pyramid_outputs = {}
9099 for index , (channels , depth ) in enumerate (
91100 zip (stackwise_num_filters , stackwise_depth )
92101 ):
93102 x = apply_ConvBlock (
94103 channels ,
104+ channel_axis ,
95105 kernel_size = 3 ,
96106 strides = 2 ,
97107 name = f"dark{ index + 2 } _conv" ,
@@ -100,17 +110,20 @@ def __init__(
100110 if index == len (stackwise_depth ) - 1 :
101111 x = apply_spatial_pyramid_pooling_bottleneck (
102112 channels ,
113+ channel_axis ,
103114 hidden_filters = channels // 2 ,
104115 name = f"dark{ index + 2 } _spp" ,
105116 )(x )
106117
107118 x = apply_cross_stage_partial (
108119 channels ,
120+ channel_axis ,
109121 num_bottlenecks = depth ,
110122 block_type = "basic_block" ,
111123 residual = (index != len (stackwise_depth ) - 1 ),
112124 name = f"dark{ index + 2 } _csp" ,
113125 )(x )
126+ pyramid_outputs [f"P{ index + 2 } " ] = x
114127
115128 super ().__init__ (inputs = image_input , outputs = x , ** kwargs )
116129
@@ -120,6 +133,7 @@ def __init__(
120133 self .include_rescaling = include_rescaling
121134 self .block_type = block_type
122135 self .image_shape = image_shape
136+ self .pyramid_outputs = pyramid_outputs
123137
124138 def get_config (self ):
125139 config = super ().get_config ()
@@ -135,7 +149,7 @@ def get_config(self):
135149 return config
136150
137151
138- def apply_focus (name = None ):
152+ def apply_focus (channel_axis , name = None ):
139153 """A block used in CSPDarknet to focus information into channels of the
140154 image.
141155
@@ -151,7 +165,7 @@ def apply_focus(name=None):
151165 """
152166
153167 def apply (x ):
154- return layers .Concatenate (name = name )(
168+ return layers .Concatenate (axis = channel_axis , name = name )(
155169 [
156170 x [..., ::2 , ::2 , :],
157171 x [..., 1 ::2 , ::2 , :],
@@ -164,7 +178,13 @@ def apply(x):
164178
165179
166180def apply_darknet_conv_block (
167- filters , kernel_size , strides , use_bias = False , activation = "silu" , name = None
181+ filters ,
182+ channel_axis ,
183+ kernel_size ,
184+ strides ,
185+ use_bias = False ,
186+ activation = "silu" ,
187+ name = None ,
168188):
169189 """
170190 The basic conv block used in Darknet. Applies Conv2D followed by a
@@ -193,11 +213,12 @@ def apply(inputs):
193213 kernel_size ,
194214 strides ,
195215 padding = "same" ,
216+ data_format = keras .config .image_data_format (),
196217 use_bias = use_bias ,
197218 name = name + "_conv" ,
198219 )(inputs )
199220
200- x = layers .BatchNormalization (name = name + "_bn" )(x )
221+ x = layers .BatchNormalization (axis = channel_axis , name = name + "_bn" )(x )
201222
202223 if activation == "silu" :
203224 x = layers .Lambda (lambda x : keras .activations .silu (x ))(x )
@@ -212,7 +233,7 @@ def apply(inputs):
212233
213234
214235def apply_darknet_conv_block_depthwise (
215- filters , kernel_size , strides , activation = "silu" , name = None
236+ filters , channel_axis , kernel_size , strides , activation = "silu" , name = None
216237):
217238 """
218239 The depthwise conv block used in CSPDarknet.
@@ -236,9 +257,13 @@ def apply_darknet_conv_block_depthwise(
236257
237258 def apply (inputs ):
238259 x = layers .DepthwiseConv2D (
239- kernel_size , strides , padding = "same" , use_bias = False
260+ kernel_size ,
261+ strides ,
262+ padding = "same" ,
263+ data_format = keras .config .image_data_format (),
264+ use_bias = False ,
240265 )(inputs )
241- x = layers .BatchNormalization ()(x )
266+ x = layers .BatchNormalization (axis = channel_axis )(x )
242267
243268 if activation == "silu" :
244269 x = layers .Lambda (lambda x : keras .activations .swish (x ))(x )
@@ -248,7 +273,11 @@ def apply(inputs):
248273 x = layers .LeakyReLU (0.1 )(x )
249274
250275 x = apply_darknet_conv_block (
251- filters , kernel_size = 1 , strides = 1 , activation = activation
276+ filters ,
277+ channel_axis ,
278+ kernel_size = 1 ,
279+ strides = 1 ,
280+ activation = activation ,
252281 )(x )
253282
254283 return x
@@ -258,6 +287,7 @@ def apply(inputs):
258287
259288def apply_spatial_pyramid_pooling_bottleneck (
260289 filters ,
290+ channel_axis ,
261291 hidden_filters = None ,
262292 kernel_sizes = (5 , 9 , 13 ),
263293 activation = "silu" ,
@@ -291,6 +321,7 @@ def apply_spatial_pyramid_pooling_bottleneck(
291321 def apply (x ):
292322 x = apply_darknet_conv_block (
293323 hidden_filters ,
324+ channel_axis ,
294325 kernel_size = 1 ,
295326 strides = 1 ,
296327 activation = activation ,
@@ -304,13 +335,15 @@ def apply(x):
304335 kernel_size ,
305336 strides = 1 ,
306337 padding = "same" ,
338+ data_format = keras .config .image_data_format (),
307339 name = f"{ name } _maxpool_{ kernel_size } " ,
308340 )(x [0 ])
309341 )
310342
311- x = layers .Concatenate (name = f"{ name } _concat" )(x )
343+ x = layers .Concatenate (axis = channel_axis , name = f"{ name } _concat" )(x )
312344 x = apply_darknet_conv_block (
313345 filters ,
346+ channel_axis ,
314347 kernel_size = 1 ,
315348 strides = 1 ,
316349 activation = activation ,
@@ -324,6 +357,7 @@ def apply(x):
324357
325358def apply_cross_stage_partial (
326359 filters ,
360+ channel_axis ,
327361 num_bottlenecks ,
328362 residual = True ,
329363 block_type = "basic_block" ,
@@ -361,6 +395,7 @@ def apply(inputs):
361395
362396 x1 = apply_darknet_conv_block (
363397 hidden_channels ,
398+ channel_axis ,
364399 kernel_size = 1 ,
365400 strides = 1 ,
366401 activation = activation ,
@@ -369,6 +404,7 @@ def apply(inputs):
369404
370405 x2 = apply_darknet_conv_block (
371406 hidden_channels ,
407+ channel_axis ,
372408 kernel_size = 1 ,
373409 strides = 1 ,
374410 activation = activation ,
@@ -379,13 +415,15 @@ def apply(inputs):
379415 residual_x = x1
380416 x1 = apply_darknet_conv_block (
381417 hidden_channels ,
418+ channel_axis ,
382419 kernel_size = 1 ,
383420 strides = 1 ,
384421 activation = activation ,
385422 name = f"{ name } _bottleneck_{ i } _conv1" ,
386423 )(x1 )
387424 x1 = ConvBlock (
388425 hidden_channels ,
426+ channel_axis ,
389427 kernel_size = 3 ,
390428 strides = 1 ,
391429 activation = activation ,
@@ -399,6 +437,7 @@ def apply(inputs):
399437 x = layers .Concatenate (name = f"{ name } _concat" )([x1 , x2 ])
400438 x = apply_darknet_conv_block (
401439 filters ,
440+ channel_axis ,
402441 kernel_size = 1 ,
403442 strides = 1 ,
404443 activation = activation ,
0 commit comments