Skip to content

Commit 7dc9c7e

Browse files
yongwwwwweic
authored andcommitted
[Relay][Keras] Permute, Softmax support (apache#3618)
1 parent 671fd77 commit 7dc9c7e

File tree

2 files changed

+47
-30
lines changed

2 files changed

+47
-30
lines changed

python/tvm/relay/frontend/keras.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,9 @@ def _convert_activation(inexpr, keras_layer, _):
115115

116116
def _convert_advanced_activation(inexpr, keras_layer, etab):
117117
act_type = type(keras_layer).__name__
118+
119+
if act_type == 'Softmax':
120+
return _op.nn.softmax(inexpr, axis=1)
118121
if act_type == 'ReLU':
119122
if keras_layer.max_value:
120123
return _op.clip(inexpr, a_min=0., a_max=float(keras_layer.max_value))
@@ -160,6 +163,8 @@ def _convert_merge(inexpr, keras_layer, _):
160163
'Operator {} is not supported in frontend Keras.'.format(merge_type))
161164
return ret
162165

166+
def _convert_permute(inexpr, keras_layer, _):
167+
return _op.transpose(inexpr, axes=(0,) + keras_layer.dims)
163168

164169
def _convert_dense(inexpr, keras_layer, etab):
165170
weightList = keras_layer.get_weights()
@@ -574,6 +579,7 @@ def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument
574579
_convert_map = {
575580
'Dense' : _convert_dense,
576581
'Activation' : _convert_activation,
582+
'Softmax' : _convert_advanced_activation,
577583
'ReLU' : _convert_advanced_activation,
578584
'LeakyReLU' : _convert_advanced_activation,
579585
'PReLU' : _convert_advanced_activation,
@@ -620,7 +626,7 @@ def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument
620626
'Average' : _convert_merge,
621627
'Maximum' : _convert_merge,
622628
# 'Dot' : _convert_merge,
623-
# 'Permute' : _convert_permute,
629+
'Permute' : _convert_permute,
624630
# 'Embedding' : _convert_embedding,
625631
# 'RepeatVector' : _convert_repeat_vector,
626632

@@ -632,11 +638,15 @@ def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument
632638

633639

634640
def _check_unsupported_layers(model):
641+
missing_ops = set()
635642
for layer in model.layers:
636643
op_name = type(layer).__name__
637644
if op_name not in _convert_map:
638-
raise tvm.error.OpNotImplemented(
639-
'Operator {} is not supported in frontend Keras.'.format(op_name))
645+
missing_ops.add(op_name)
646+
647+
if missing_ops:
648+
raise NotImplementedError( \
649+
"The following operators are not implemented: {}".format(missing_ops))
640650

641651

642652
def keras_op_to_relay(inexpr, keras_layer, outname, etab):

tests/python/frontend/keras/test_forward.py

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def to_channels_last(arr):
7373

7474

7575
def test_forward_merge():
76-
data = keras.layers.Input(shape=(32,32,3))
76+
data = keras.layers.Input(shape=(32, 32, 3))
7777
x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
7878
y = keras.layers.Conv2D(8, (3, 3), padding="same")(x)
7979
z = keras.layers.Conv2D(8, (3, 3), padding="same")(y)
@@ -93,7 +93,7 @@ def test_forward_merge():
9393

9494

9595
def test_forward_activations():
96-
data = keras.layers.Input(shape=(32,32,3))
96+
data = keras.layers.Input(shape=(32, 32, 3))
9797
act_funcs = [keras.layers.Activation('softmax'),
9898
keras.layers.Activation('softplus'),
9999
keras.layers.Activation('relu'),
@@ -103,6 +103,7 @@ def test_forward_activations():
103103
keras.layers.Activation('tanh'),
104104
keras.layers.Activation('linear'),
105105
keras.layers.Activation('selu'),
106+
keras.layers.Softmax(),
106107
keras.layers.ReLU(),
107108
keras.layers.ReLU(max_value=6.),
108109
keras.layers.LeakyReLU(alpha=0.3),
@@ -116,13 +117,18 @@ def test_forward_activations():
116117

117118

118119
def test_forward_dense():
119-
data = keras.layers.Input(shape=(32,32,1))
120+
data = keras.layers.Input(shape=(32, 32, 1))
120121
x = keras.layers.Flatten()(data)
121122
x = keras.layers.Dropout(0.5)(x)
122123
x = keras.layers.Dense(10, activation='relu', kernel_initializer='uniform')(x)
123124
keras_model = keras.models.Model(data, x)
124125
verify_keras_frontend(keras_model)
125126

127+
def test_forward_permute():
128+
data = keras.layers.Input(shape=(2, 3, 4))
129+
x = keras.layers.Permute([2, 3, 1])(data)
130+
keras_model = keras.models.Model(data, x)
131+
verify_keras_frontend(keras_model, need_transpose=False)
126132

127133
def test_forward_sequential():
128134
keras_model = keras.models.Sequential([
@@ -136,7 +142,7 @@ def test_forward_sequential():
136142

137143

138144
def test_forward_pool():
139-
data = keras.layers.Input(shape=(32,32,1))
145+
data = keras.layers.Input(shape=(32, 32, 1))
140146
# maxpool
141147
x = keras.layers.MaxPooling2D((3, 3), strides=(1, 1), padding='same')(data)
142148
keras_model = keras.models.Model(data, x)
@@ -148,36 +154,36 @@ def test_forward_pool():
148154

149155

150156
def test_forward_conv():
151-
data = keras.layers.Input(shape=(32,32,3))
152-
conv_funcs = [keras.layers.Conv2D(filters=10, kernel_size=(3,3),
153-
strides=(2,2), padding='same'),
154-
keras.layers.Conv2D(filters=10, kernel_size=(3,3),
155-
dilation_rate=(2,2), padding='same'),
156-
keras.layers.DepthwiseConv2D(kernel_size=(3,3), padding='same'),
157-
keras.layers.Conv2DTranspose(filters=10, kernel_size=(3,3), padding='valid'),
158-
keras.layers.SeparableConv2D(filters=10, kernel_size=(3,3), padding='same')]
157+
data = keras.layers.Input(shape=(32, 32, 3))
158+
conv_funcs = [keras.layers.Conv2D(filters=10, kernel_size=(3, 3),
159+
strides=(2, 2), padding='same'),
160+
keras.layers.Conv2D(filters=10, kernel_size=(3, 3),
161+
dilation_rate=(2, 2), padding='same'),
162+
keras.layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same'),
163+
keras.layers.Conv2DTranspose(filters=10, kernel_size=(3, 3), padding='valid'),
164+
keras.layers.SeparableConv2D(filters=10, kernel_size=(3, 3), padding='same')]
159165
for conv_func in conv_funcs:
160166
x = conv_func(data)
161167
keras_model = keras.models.Model(data, x)
162168
verify_keras_frontend(keras_model)
163169

164170

165171
def test_forward_upsample(interpolation='nearest'):
166-
data = keras.layers.Input(shape=(32,32,3))
167-
x = keras.layers.UpSampling2D(size=(3,3), interpolation=interpolation)(data)
172+
data = keras.layers.Input(shape=(32, 32, 3))
173+
x = keras.layers.UpSampling2D(size=(3, 3), interpolation=interpolation)(data)
168174
keras_model = keras.models.Model(data, x)
169175
verify_keras_frontend(keras_model)
170176

171177

172178
def test_forward_reshape():
173-
data = keras.layers.Input(shape=(32,32,3))
174-
x = keras.layers.Reshape(target_shape=(32,32,3))(data)
179+
data = keras.layers.Input(shape=(32, 32, 3))
180+
x = keras.layers.Reshape(target_shape=(32, 32, 3))(data)
175181
keras_model = keras.models.Model(data, x)
176182
verify_keras_frontend(keras_model)
177183

178184

179185
def test_forward_crop():
180-
data = keras.layers.Input(shape=(32,32,3))
186+
data = keras.layers.Input(shape=(32, 32, 3))
181187
x = keras.layers.Cropping2D(cropping=((1, 1), (1, 1)))(data)
182188
x = keras.layers.Cropping2D(cropping=(1, 1))(x)
183189
x = keras.layers.Cropping2D(cropping=1)(x)
@@ -190,8 +196,8 @@ def test_forward_crop():
190196

191197

192198
def test_forward_multi_inputs():
193-
data1 = keras.layers.Input(shape=(32,32,3))
194-
data2 = keras.layers.Input(shape=(32,32,3))
199+
data1 = keras.layers.Input(shape=(32, 32, 3))
200+
data2 = keras.layers.Input(shape=(32, 32, 3))
195201
x = keras.layers.Conv2D(8, (3, 3), padding="same")(data1)
196202
y = keras.layers.Conv2D(8, (3, 3), padding="same")(data2)
197203
z = keras.layers.Average()([x, y])
@@ -201,7 +207,7 @@ def test_forward_multi_inputs():
201207

202208

203209
def test_forward_multi_outputs():
204-
data = keras.layers.Input(shape=(32,32,3))
210+
data = keras.layers.Input(shape=(32, 32, 3))
205211
x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
206212
x = keras.layers.GlobalAveragePooling2D()(x)
207213
y = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
@@ -212,7 +218,7 @@ def test_forward_multi_outputs():
212218

213219
def test_forward_reuse_layers():
214220
# reuse conv2d
215-
data = keras.layers.Input(shape=(32,32,3))
221+
data = keras.layers.Input(shape=(32, 32, 3))
216222
conv2d = keras.layers.Conv2D(8, (3, 3), padding="same")
217223
x = conv2d(data)
218224
y = conv2d(data)
@@ -221,7 +227,7 @@ def test_forward_reuse_layers():
221227
keras_model = keras.models.Model(data, z)
222228
verify_keras_frontend(keras_model)
223229
# reuse add
224-
data = keras.layers.Input(shape=(32,32,3))
230+
data = keras.layers.Input(shape=(32, 32, 3))
225231
x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
226232
add = keras.layers.Add()
227233
x = add([x, x])
@@ -232,7 +238,7 @@ def test_forward_reuse_layers():
232238

233239

234240
def test_forward_rnn():
235-
data = keras.layers.Input(shape=(1,32))
241+
data = keras.layers.Input(shape=(1, 32))
236242
rnn_funcs = [keras.layers.LSTM(units=16, return_state=False,
237243
recurrent_activation='sigmoid', activation='tanh'),
238244
keras.layers.SimpleRNN(units=16, return_state=False,
@@ -247,32 +253,33 @@ def test_forward_rnn():
247253

248254
def test_forward_vgg16():
249255
keras_model = keras.applications.VGG16(include_top=True, weights='imagenet',
250-
input_shape=(224,224,3), classes=1000)
256+
input_shape=(224, 224, 3), classes=1000)
251257
verify_keras_frontend(keras_model)
252258

253259

254260
def test_forward_xception():
255261
keras_model = keras.applications.Xception(include_top=True, weights='imagenet',
256-
input_shape=(299,299,3), classes=1000)
262+
input_shape=(299, 299, 3), classes=1000)
257263
verify_keras_frontend(keras_model)
258264

259265

260266
def test_forward_resnet50():
261267
keras_model = keras.applications.ResNet50(include_top=True, weights='imagenet',
262-
input_shape=(224,224,3), classes=1000)
268+
input_shape=(224, 224, 3), classes=1000)
263269
verify_keras_frontend(keras_model)
264270

265271

266272
def test_forward_mobilenet():
267273
keras_model = keras.applications.MobileNet(include_top=True, weights='imagenet',
268-
input_shape=(224,224,3), classes=1000)
274+
input_shape=(224, 224, 3), classes=1000)
269275
verify_keras_frontend(keras_model)
270276

271277

272278
if __name__ == '__main__':
273279
test_forward_merge()
274280
test_forward_activations()
275281
test_forward_dense()
282+
test_forward_permute()
276283
test_forward_sequential()
277284
test_forward_pool()
278285
test_forward_conv()

0 commit comments

Comments
 (0)