@@ -864,29 +864,14 @@ def _convert_reshape(inexpr, keras_layer, etab):
864864 _check_data_format (keras_layer )
865865 inshape = keras_layer .input_shape # includes batch
866866 tshape = keras_layer .target_shape # no batch
867- if len (inshape ) == 3 and len (tshape ) == 1 :
868- # (?, a, b) -> (-1, ab)
869- shape = (- 1 , tshape [0 ])
870- elif len (inshape ) in [2 , 3 ] and len (tshape ) == 2 :
871- # (?, cc) -> (-1, c, c)
872- # (?, a, b) -> (-1, c, c)
873- assert tshape [0 ] == tshape [1 ], "Only supports square target shapes, but got {}" .format (
874- tshape
875- )
876- shape = (- 1 ,) + tshape
877- else :
878- # (?, h, w, c) -> (-1, c, H, W)
879- # (?, h, w, c) -> (-1, c, hw)
880- # (?, hw, c) -> (-1, c, h, w)
881- ch = inshape [- 1 ]
882- assert ch == tshape [- 1 ], (
883- "Only supports last dimension in target shape being equal to "
884- "the channel number of input tensor."
885- )
886- if etab .data_layout == "NCHW" :
887- shape = (- 1 , ch ) + tshape [:- 1 ]
888- else :
889- shape = (- 1 ,) + tshape [:- 1 ] + (ch ,)
867+ shape = (- 1 ,) + tshape
868+
869+ if etab .data_layout == "NCHW" and (len (inshape ) > 3 or len (tshape ) > 2 ):
870+ # Perform reshape in original NHWC format.
871+ inexpr = _op .transpose (inexpr , [0 ] + list (range (2 , len (inshape ))) + [1 ])
872+ inexpr = _op .reshape (inexpr , newshape = shape )
873+ return _op .transpose (inexpr , axes = [0 , - 1 ] + list (range (1 , len (shape ) - 1 )))
874+
890875 return _op .reshape (inexpr , newshape = shape )
891876
892877
0 commit comments