Skip to content

Commit f6c3f99

Browse files
alexeyrsrkreddy1238
authored andcommitted
[FRONTEND][TENSORFLOW] Use input shapes directly instead of 1-element lists (#2242)
1 parent 6d1f4c0 commit f6c3f99

File tree

1 file changed

+20
-21
lines changed

1 file changed

+20
-21
lines changed

nnvm/python/nnvm/frontend/tensorflow.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def _impl(inputs, attr, params):
120120
attr['data_format'] = attr['data_format'].decode("utf-8")
121121
flip_layout = False
122122

123-
input_shape = attr['_input_shapes'][inputs[0]][0]
123+
input_shape = attr['_input_shapes'][inputs[0]]
124124

125125
if attr['data_format'] == 'NHWC':
126126
attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
@@ -132,7 +132,7 @@ def _impl(inputs, attr, params):
132132
raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))
133133

134134
if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
135-
tmp_shape = attr['_input_shapes'][inputs[0]][0]
135+
tmp_shape = attr['_input_shapes'][inputs[0]]
136136
input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
137137
inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
138138
attr['data_format'] = "NCHW"
@@ -185,13 +185,13 @@ def _impl(inputs, attr, params):
185185

186186
# NCHW Layout require weights transpose
187187
if attr['data_format'] == 'NCHW':
188-
tmp_shape = attr['_input_shapes'][inputs[1]][0]
188+
tmp_shape = attr['_input_shapes'][inputs[1]]
189189
tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)]
190190
inputs[1] = _sym.transpose(inputs[1], axes=(3, 2, 0, 1))
191-
attr['_input_shapes'][inputs[1]] = [tmp_shape]
191+
attr['_input_shapes'][inputs[1]] = tmp_shape
192192

193-
input_shape = attr['_input_shapes'][inputs[0]][0]
194-
weights_shape = attr['_input_shapes'][inputs[1]][0]
193+
input_shape = attr['_input_shapes'][inputs[0]]
194+
weights_shape = attr['_input_shapes'][inputs[1]]
195195

196196
if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
197197
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
@@ -484,7 +484,7 @@ def _impl(inputs, attr, params):
484484

485485
def _shape():
486486
def _impl(inputs, attr, params):
487-
return np.array(attr['_input_shapes'][inputs[0]][0], dtype='int32')
487+
return np.array(attr['_input_shapes'][inputs[0]], dtype='int32')
488488
return _impl
489489

490490
def _fill():
@@ -565,7 +565,7 @@ def _impl(inputs, attr, params):
565565
new_axis_mask = int(attr.get('new_axis_mask', 0))
566566
shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
567567
data_shape = attr['_input_shapes'][inputs[0]]
568-
data_dim = len(data_shape[0])
568+
data_dim = len(data_shape)
569569
stride_dim = len(stride)
570570

571571
def _transform_mask(stride_dim, ellipsis_mask):
@@ -596,7 +596,7 @@ def _transform_mask(stride_dim, ellipsis_mask):
596596
+ new_axes_after_ellipsis), data_dim)
597597
for i in range(final_index, to_index):
598598
m_begin[final_index] = 0
599-
m_end[final_index] = data_shape[0][final_index]
599+
m_end[final_index] = data_shape[final_index]
600600
m_stride[final_index] = 1
601601
fshape_indices.append(final_index)
602602
final_index += 1
@@ -606,19 +606,19 @@ def _transform_mask(stride_dim, ellipsis_mask):
606606
if final_index == len(m_begin):
607607
break
608608
if mask & begin_mask:
609-
m_begin[final_index] = data_shape[0][final_index] \
609+
m_begin[final_index] = data_shape[final_index] \
610610
if stride[index] < 0 else 0
611611
elif begin[index]:
612612
m_begin[final_index] = begin[index]
613613
if mask & end_mask:
614614
m_end[final_index] = 0 if stride[index] < 0 \
615-
else data_shape[0][final_index]
615+
else data_shape[final_index]
616616
elif end[index]:
617617
m_end[final_index] = end[index]
618618
m_stride[final_index] = stride[index]
619619
if mask & shrink_axis_mask:
620620
#Tensorflow make axis with shrink_axis_mask as dimension 1
621-
m_begin[final_index] = data_shape[0][final_index] + begin[index] \
621+
m_begin[final_index] = data_shape[final_index] + begin[index] \
622622
if begin[index] < 0 else begin[index]
623623
m_end[final_index] = begin[index] + 1
624624
m_stride[final_index] = 1
@@ -684,8 +684,8 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
684684
forget_bias = attr.pop('forget_bias')
685685
input_shape = attr['_input_shapes'][inputs[0]]
686686
weight_shape = attr['_input_shapes'][inputs[3]]
687-
batch_size, input_size = input_shape[0][0], input_shape[0][1]
688-
num_hidden_layers = weight_shape[0][1]
687+
batch_size, input_size = input_shape[0], input_shape[1]
688+
num_hidden_layers = weight_shape[1]
689689
num_hidden = num_hidden_layers // 4
690690

691691
in_data = _sym.reshape(in_data,
@@ -741,11 +741,10 @@ def _impl(inputs, attr, params):
741741

742742
def _rank():
743743
def _impl(inputs, attr, params):
744-
input_shapes = attr['_input_shapes'][inputs[0]]
745-
assert len(inputs) == 1
744+
input_shape = attr['_input_shapes'][inputs[0]]
746745

747746
name = attr["_node_name"]
748-
params[name] = tvm.nd.array([len(input_shapes[0])])
747+
params[name] = tvm.nd.array([len(input_shape)])
749748
return _sym.Variable(name=name, shape=params[name].shape)
750749
return _impl
751750

@@ -829,7 +828,7 @@ def _unpack():
829828
def _impl(inputs, attr, params):
830829
input_node = inputs[0]
831830
axis = attr['axis']
832-
input_shape = attr['_input_shapes'][input_node][0]
831+
input_shape = attr['_input_shapes'][input_node]
833832
axis_length = input_shape[axis]
834833
if axis_length < 0:
835834
raise TypeError("Unstack with unknown axis length")
@@ -1018,8 +1017,8 @@ def _LSTMBlockCellWrapper(inputs, attr, params,
10181017
"""LSTM cell warapper to prepare the inputs"""
10191018
input_shape = attr['_input_shapes'][inputs[0]]
10201019
weight_shape = attr['_input_shapes'][inputs[3]]
1021-
batch_size = input_shape[0][0]
1022-
num_hidden = weight_shape[0][1] // 4
1020+
batch_size = input_shape[0]
1021+
num_hidden = weight_shape[1] // 4
10231022

10241023
if layer == 0:
10251024
#Create initial states placeholder in case of first layer
@@ -1240,7 +1239,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
12401239
tensor_slot = 0
12411240
input_shape = self._output_shapes[node_name][0]
12421241
inputs.append(in_sym)
1243-
input_shapes[in_sym] = [input_shape]
1242+
input_shapes[in_sym] = input_shape
12441243
# This means the node is 1d in NNVM and 0d in TF.
12451244
# See `_expand_dims_0d_aware`.
12461245
if self._outputs_are_0d[node_name][tensor_slot] and input_shape:

0 commit comments

Comments
 (0)