Skip to content

Commit fe576b5

Browse files
Trevor MorrisLokiiiiii
authored andcommitted
Make keras reshape less restrictive (apache#7446)
1 parent f61325b commit fe576b5

File tree

2 files changed

+18
-23
lines changed

2 files changed

+18
-23
lines changed

python/tvm/relay/frontend/keras.py

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -864,29 +864,14 @@ def _convert_reshape(inexpr, keras_layer, etab):
864864
_check_data_format(keras_layer)
865865
inshape = keras_layer.input_shape # includes batch
866866
tshape = keras_layer.target_shape # no batch
867-
if len(inshape) == 3 and len(tshape) == 1:
868-
# (?, a, b) -> (-1, ab)
869-
shape = (-1, tshape[0])
870-
elif len(inshape) in [2, 3] and len(tshape) == 2:
871-
# (?, cc) -> (-1, c, c)
872-
# (?, a, b) -> (-1, c, c)
873-
assert tshape[0] == tshape[1], "Only supports square target shapes, but got {}".format(
874-
tshape
875-
)
876-
shape = (-1,) + tshape
877-
else:
878-
# (?, h, w, c) -> (-1, c, H, W)
879-
# (?, h, w, c) -> (-1, c, hw)
880-
# (?, hw, c) -> (-1, c, h, w)
881-
ch = inshape[-1]
882-
assert ch == tshape[-1], (
883-
"Only supports last dimension in target shape being equal to "
884-
"the channel number of input tensor."
885-
)
886-
if etab.data_layout == "NCHW":
887-
shape = (-1, ch) + tshape[:-1]
888-
else:
889-
shape = (-1,) + tshape[:-1] + (ch,)
867+
shape = (-1,) + tshape
868+
869+
if etab.data_layout == "NCHW" and (len(inshape) > 3 or len(tshape) > 2):
870+
# Perform reshape in original NHWC format.
871+
inexpr = _op.transpose(inexpr, [0] + list(range(2, len(inshape))) + [1])
872+
inexpr = _op.reshape(inexpr, newshape=shape)
873+
return _op.transpose(inexpr, axes=[0, -1] + list(range(1, len(shape) - 1)))
874+
890875
return _op.reshape(inexpr, newshape=shape)
891876

892877

tests/python/frontend/keras/test_forward.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,16 @@ def test_forward_reshape(self, keras):
350350
x = keras.layers.Reshape(target_shape=(4, 4))(data)
351351
keras_model = keras.models.Model(data, x)
352352
verify_keras_frontend(keras_model, need_transpose=False)
353+
# "non-square" target shape
354+
data = keras.layers.Input(shape=(15,))
355+
x = keras.layers.Reshape(target_shape=(5, 3))(data)
356+
keras_model = keras.models.Model(data, x)
357+
verify_keras_frontend(keras_model, need_transpose=False)
358+
# modify channel dim
359+
data = keras.layers.Input(shape=(3, 2, 4))
360+
x = keras.layers.Reshape(target_shape=(3, 8))(data)
361+
keras_model = keras.models.Model(data, x)
362+
verify_keras_frontend(keras_model)
353363

354364
def test_forward_crop(self, keras):
355365
data = keras.layers.Input(shape=(32, 32, 3))

0 commit comments

Comments
 (0)