@@ -120,7 +120,7 @@ def _impl(inputs, attr, params):
         attr['data_format'] = attr['data_format'].decode("utf-8")
         flip_layout = False
 
-        input_shape = attr['_input_shapes'][inputs[0]][0]
+        input_shape = attr['_input_shapes'][inputs[0]]
 
         if attr['data_format'] == 'NHWC':
             attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
@@ -132,7 +132,7 @@ def _impl(inputs, attr, params):
             raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))
 
         if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
-            tmp_shape = attr['_input_shapes'][inputs[0]][0]
+            tmp_shape = attr['_input_shapes'][inputs[0]]
             input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
             inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
             attr['data_format'] = "NCHW"
@@ -185,13 +185,13 @@ def _impl(inputs, attr, params):
 
         # NCHW Layout require weights transpose
         if attr['data_format'] == 'NCHW':
-            tmp_shape = attr['_input_shapes'][inputs[1]][0]
+            tmp_shape = attr['_input_shapes'][inputs[1]]
             tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)]
             inputs[1] = _sym.transpose(inputs[1], axes=(3, 2, 0, 1))
-            attr['_input_shapes'][inputs[1]] = [tmp_shape]
+            attr['_input_shapes'][inputs[1]] = tmp_shape
 
-        input_shape = attr['_input_shapes'][inputs[0]][0]
-        weights_shape = attr['_input_shapes'][inputs[1]][0]
+        input_shape = attr['_input_shapes'][inputs[0]]
+        weights_shape = attr['_input_shapes'][inputs[1]]
 
         if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
             input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
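
A note on the permutation tuples in the two hunks above (shapes here are illustrative, not taken from the commit): (0, 3, 1, 2) reorders an NHWC data shape into NCHW, and (3, 2, 0, 1) reorders an HWIO convolution kernel into OIHW, matching what the accompanying _sym.transpose calls do to the tensors themselves.

    # Illustrative shapes only; not part of the commit.
    nhwc = [1, 224, 224, 3]   # batch, height, width, channels
    hwio = [3, 3, 3, 64]      # kernel height, kernel width, in channels, out channels

    assert [nhwc[ii] for ii in (0, 3, 1, 2)] == [1, 3, 224, 224]  # NHWC -> NCHW
    assert [hwio[ii] for ii in (3, 2, 0, 1)] == [64, 3, 3, 3]     # HWIO -> OIHW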
@@ -484,7 +484,7 @@ def _impl(inputs, attr, params):
 
 def _shape():
     def _impl(inputs, attr, params):
-        return np.array(attr['_input_shapes'][inputs[0]][0], dtype='int32')
+        return np.array(attr['_input_shapes'][inputs[0]], dtype='int32')
     return _impl
 
 def _fill():
@@ -565,7 +565,7 @@ def _impl(inputs, attr, params):
         new_axis_mask = int(attr.get('new_axis_mask', 0))
         shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
         data_shape = attr['_input_shapes'][inputs[0]]
-        data_dim = len(data_shape[0])
+        data_dim = len(data_shape)
         stride_dim = len(stride)
 
         def _transform_mask(stride_dim, ellipsis_mask):
@@ -596,7 +596,7 @@ def _transform_mask(stride_dim, ellipsis_mask):
                                     + new_axes_after_ellipsis), data_dim)
                     for i in range(final_index, to_index):
                         m_begin[final_index] = 0
-                        m_end[final_index] = data_shape[0][final_index]
+                        m_end[final_index] = data_shape[final_index]
                         m_stride[final_index] = 1
                         fshape_indices.append(final_index)
                         final_index += 1
@@ -606,19 +606,19 @@ def _transform_mask(stride_dim, ellipsis_mask):
                     if final_index == len(m_begin):
                         break
                     if mask & begin_mask:
-                        m_begin[final_index] = data_shape[0][final_index] \
+                        m_begin[final_index] = data_shape[final_index] \
                             if stride[index] < 0 else 0
                     elif begin[index]:
                         m_begin[final_index] = begin[index]
                     if mask & end_mask:
                         m_end[final_index] = 0 if stride[index] < 0 \
-                            else data_shape[0][final_index]
+                            else data_shape[final_index]
                     elif end[index]:
                         m_end[final_index] = end[index]
                     m_stride[final_index] = stride[index]
                     if mask & shrink_axis_mask:
                         #Tensorflow make axis with shrink_axis_mask as dimension 1
-                        m_begin[final_index] = data_shape[0][final_index] + begin[index] \
+                        m_begin[final_index] = data_shape[final_index] + begin[index] \
                             if begin[index] < 0 else begin[index]
                         m_end[final_index] = begin[index] + 1
                         m_stride[final_index] = 1
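
For context on the shrink_axis_mask branch touched above: TensorFlow encodes each mask as a bit field over the slice dimensions, and an axis flagged for shrinking is rewritten as a size-1 slice [begin, begin + 1) with stride 1, with a negative begin first normalized by adding the axis length. A standalone sketch of that arithmetic with illustrative values:

    # Illustrative values only; not part of the commit.
    data_shape = [4, 6]
    begin = [2]
    final_index = index = 0

    # Mirrors the shrink_axis_mask branch above.
    m_begin = data_shape[final_index] + begin[index] \
        if begin[index] < 0 else begin[index]
    m_end = begin[index] + 1
    m_stride = 1
    assert (m_begin, m_end, m_stride) == (2, 3, 1)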
@@ -684,8 +684,8 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
         forget_bias = attr.pop('forget_bias')
         input_shape = attr['_input_shapes'][inputs[0]]
         weight_shape = attr['_input_shapes'][inputs[3]]
-        batch_size, input_size = input_shape[0][0], input_shape[0][1]
-        num_hidden_layers = weight_shape[0][1]
+        batch_size, input_size = input_shape[0], input_shape[1]
+        num_hidden_layers = weight_shape[1]
         num_hidden = num_hidden_layers // 4
 
         in_data = _sym.reshape(in_data,
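
A side note on why the hunk above divides by four: TensorFlow's LSTM block cell stores the kernels of all four gates (input, cell, forget, output) fused along the output dimension, so the hidden size is recovered from the second axis of the weight shape. A small sketch with illustrative sizes:

    # Illustrative sizes only; not part of the commit.
    input_size, num_hidden = 16, 32
    weight_shape = [input_size + num_hidden, 4 * num_hidden]  # four gates fused

    # Mirrors num_hidden = num_hidden_layers // 4 in the frontend.
    assert weight_shape[1] // 4 == num_hidden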
@@ -741,11 +741,10 @@ def _impl(inputs, attr, params):
 
 def _rank():
     def _impl(inputs, attr, params):
-        input_shapes = attr['_input_shapes'][inputs[0]]
-        assert len(inputs) == 1
+        input_shape = attr['_input_shapes'][inputs[0]]
 
         name = attr["_node_name"]
-        params[name] = tvm.nd.array([len(input_shapes[0])])
+        params[name] = tvm.nd.array([len(input_shape)])
         return _sym.Variable(name=name, shape=params[name].shape)
     return _impl
 
@@ -829,7 +828,7 @@ def _unpack():
     def _impl(inputs, attr, params):
         input_node = inputs[0]
         axis = attr['axis']
-        input_shape = attr['_input_shapes'][input_node][0]
+        input_shape = attr['_input_shapes'][input_node]
         axis_length = input_shape[axis]
         if axis_length < 0:
             raise TypeError("Unstack with unknown axis length")
@@ -1018,8 +1017,8 @@ def _LSTMBlockCellWrapper(inputs, attr, params,
     """LSTM cell warapper to prepare the inputs"""
     input_shape = attr['_input_shapes'][inputs[0]]
     weight_shape = attr['_input_shapes'][inputs[3]]
-    batch_size = input_shape[0][0]
-    num_hidden = weight_shape[0][1] // 4
+    batch_size = input_shape[0]
+    num_hidden = weight_shape[1] // 4
 
     if layer == 0:
         #Create initial states placeholder in case of first layer
@@ -1240,7 +1239,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
                         tensor_slot = 0
                         input_shape = self._output_shapes[node_name][0]
                     inputs.append(in_sym)
-                    input_shapes[in_sym] = [input_shape]
+                    input_shapes[in_sym] = input_shape
                     # This means the node is 1d in NNVM and 0d in TF.
                     # See `_expand_dims_0d_aware`.
                     if self._outputs_are_0d[node_name][tensor_slot] and input_shape:
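
Taken together, the hunks apply one mechanical refactor: attr['_input_shapes'] (populated as input_shapes in from_tensorflow above) now maps each input symbol directly to its shape instead of to a one-element list of shapes, so every consumer drops the trailing [0] unwrap. A minimal sketch of the convention change, using a hypothetical 'data' key in place of a real NNVM symbol:

    # Hypothetical key; the real dict is keyed by NNVM symbols.
    attr = {'_input_shapes': {}}

    # Before this change: one-element list, unwrapped with [0] at every use.
    attr['_input_shapes']['data'] = [[1, 224, 224, 3]]
    input_shape = attr['_input_shapes']['data'][0]

    # After this change: the shape itself.
    attr['_input_shapes']['data'] = [1, 224, 224, 3]
    input_shape = attr['_input_shapes']['data']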