41 changes: 20 additions & 21 deletions nnvm/python/nnvm/frontend/tensorflow.py
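Every hunk in this file is the same refactor: `attr['_input_shapes']` now maps each input symbol to a flat shape list instead of a singleton list wrapping it, so the extra `[0]` un-nesting at every call site is dropped. A minimal sketch of the convention change (the dictionary and key here are illustrative, not the actual frontend objects):

```python
# Hypothetical illustration of the _input_shapes convention change.
input_shapes = {}

# Before: each value was a singleton list wrapping the shape,
# so every consumer had to peel off an extra [0].
input_shapes['conv_input'] = [[1, 224, 224, 3]]
shape_before = input_shapes['conv_input'][0]   # -> [1, 224, 224, 3]

# After: the shape list is stored directly.
input_shapes['conv_input'] = [1, 224, 224, 3]
shape_after = input_shapes['conv_input']       # -> [1, 224, 224, 3]

assert shape_before == shape_after
```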
@@ -120,7 +120,7 @@ def _impl(inputs, attr, params):
attr['data_format'] = attr['data_format'].decode("utf-8")
flip_layout = False

- input_shape = attr['_input_shapes'][inputs[0]][0]
+ input_shape = attr['_input_shapes'][inputs[0]]

if attr['data_format'] == 'NHWC':
attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
@@ -132,7 +132,7 @@ def _impl(inputs, attr, params):
raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))

if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
- tmp_shape = attr['_input_shapes'][inputs[0]][0]
+ tmp_shape = attr['_input_shapes'][inputs[0]]
input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
attr['data_format'] = "NCHW"
@@ -185,13 +185,13 @@ def _impl(inputs, attr, params):

# NCHW layout requires weights transpose
if attr['data_format'] == 'NCHW':
- tmp_shape = attr['_input_shapes'][inputs[1]][0]
+ tmp_shape = attr['_input_shapes'][inputs[1]]
tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)]
inputs[1] = _sym.transpose(inputs[1], axes=(3, 2, 0, 1))
- attr['_input_shapes'][inputs[1]] = [tmp_shape]
+ attr['_input_shapes'][inputs[1]] = tmp_shape

- input_shape = attr['_input_shapes'][inputs[0]][0]
- weights_shape = attr['_input_shapes'][inputs[1]][0]
+ input_shape = attr['_input_shapes'][inputs[0]]
+ weights_shape = attr['_input_shapes'][inputs[1]]

if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
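For reference, the layout conversion in this hunk permutes both the NHWC data shape and the HWIO weight shape with index gathers; a standalone sketch with made-up sizes:

```python
# Illustrative only: permute a data shape from NHWC to NCHW and a
# TensorFlow weight shape from HWIO to OIHW, as the hunk above does.
nhwc_shape = [1, 224, 224, 3]             # N, H, W, C
nchw_shape = [nhwc_shape[ii] for ii in (0, 3, 1, 2)]
assert nchw_shape == [1, 3, 224, 224]     # N, C, H, W

hwio_shape = [3, 3, 3, 64]                # H, W, I, O
oihw_shape = [hwio_shape[ii] for ii in (3, 2, 0, 1)]
assert oihw_shape == [64, 3, 3, 3]        # O, I, H, W
```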
@@ -484,7 +484,7 @@ def _impl(inputs, attr, params):

def _shape():
def _impl(inputs, attr, params):
- return np.array(attr['_input_shapes'][inputs[0]][0], dtype='int32')
+ return np.array(attr['_input_shapes'][inputs[0]], dtype='int32')
return _impl

def _fill():
@@ -565,7 +565,7 @@ def _impl(inputs, attr, params):
new_axis_mask = int(attr.get('new_axis_mask', 0))
shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
data_shape = attr['_input_shapes'][inputs[0]]
- data_dim = len(data_shape[0])
+ data_dim = len(data_shape)
stride_dim = len(stride)

def _transform_mask(stride_dim, ellipsis_mask):
@@ -596,7 +596,7 @@ def _transform_mask(stride_dim, ellipsis_mask):
+ new_axes_after_ellipsis), data_dim)
for i in range(final_index, to_index):
m_begin[final_index] = 0
- m_end[final_index] = data_shape[0][final_index]
+ m_end[final_index] = data_shape[final_index]
m_stride[final_index] = 1
fshape_indices.append(final_index)
final_index += 1
@@ -606,19 +606,19 @@ def _transform_mask(stride_dim, ellipsis_mask):
if final_index == len(m_begin):
break
if mask & begin_mask:
- m_begin[final_index] = data_shape[0][final_index] \
+ m_begin[final_index] = data_shape[final_index] \
if stride[index] < 0 else 0
elif begin[index]:
m_begin[final_index] = begin[index]
if mask & end_mask:
m_end[final_index] = 0 if stride[index] < 0 \
- else data_shape[0][final_index]
+ else data_shape[final_index]
elif end[index]:
m_end[final_index] = end[index]
m_stride[final_index] = stride[index]
if mask & shrink_axis_mask:
# TensorFlow treats the axis with shrink_axis_mask as dimension 1
- m_begin[final_index] = data_shape[0][final_index] + begin[index] \
+ m_begin[final_index] = data_shape[final_index] + begin[index] \
if begin[index] < 0 else begin[index]
m_end[final_index] = begin[index] + 1
m_stride[final_index] = 1
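The shrink_axis_mask branch above reduces one axis to a single-element slice, normalizing a negative begin index against the (now flat) data_shape. A standalone sketch of that arithmetic with toy values:

```python
# Illustrative only: mirror the shrink_axis_mask arithmetic above.
data_shape = [2, 5, 7]
begin = [3]          # slice position along axis 0
index = final_index = 0

m_begin = data_shape[final_index] + begin[index] \
    if begin[index] < 0 else begin[index]        # -> 3
m_end = begin[index] + 1                         # -> 4: one element
m_stride = 1
assert (m_begin, m_end, m_stride) == (3, 4, 1)
```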
@@ -684,8 +684,8 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
forget_bias = attr.pop('forget_bias')
input_shape = attr['_input_shapes'][inputs[0]]
weight_shape = attr['_input_shapes'][inputs[3]]
- batch_size, input_size = input_shape[0][0], input_shape[0][1]
- num_hidden_layers = weight_shape[0][1]
+ batch_size, input_size = input_shape[0], input_shape[1]
+ num_hidden_layers = weight_shape[1]
num_hidden = num_hidden_layers // 4

in_data = _sym.reshape(in_data,
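Both LSTM hunks lean on the TensorFlow BlockLSTM weight packing: the second weight dimension stacks the four gates side by side, so the hidden size is a quarter of it. A small sketch with illustrative sizes:

```python
# Illustrative only: LSTMBlockCell packs the i, c, f, o gate weights
# along the second dimension, so it equals 4 * num_hidden.
input_size, num_hidden = 32, 128
weight_shape = [input_size + num_hidden, 4 * num_hidden]  # [160, 512]

num_hidden_layers = weight_shape[1]   # as read in the hunk above
assert num_hidden_layers // 4 == num_hidden
```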
@@ -741,11 +741,10 @@ def _impl(inputs, attr, params):

def _rank():
def _impl(inputs, attr, params):
- input_shapes = attr['_input_shapes'][inputs[0]]
- assert len(inputs) == 1
+ input_shape = attr['_input_shapes'][inputs[0]]

name = attr["_node_name"]
- params[name] = tvm.nd.array([len(input_shapes[0])])
+ params[name] = tvm.nd.array([len(input_shape)])
return _sym.Variable(name=name, shape=params[name].shape)
return _impl
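With the flat shape convention, the rank is simply the length of the shape list, materialized as a one-element tensor parameter. A trivial sketch (values made up, `np.array` standing in for `tvm.nd.array`):

```python
import numpy as np

# Illustrative only: a tensor's rank is the length of its shape list.
input_shape = [1, 224, 224, 3]
rank = np.array([len(input_shape)])   # stand-in for tvm.nd.array
assert rank.tolist() == [4]
```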

@@ -829,7 +828,7 @@ def _unpack():
def _impl(inputs, attr, params):
input_node = inputs[0]
axis = attr['axis']
- input_shape = attr['_input_shapes'][input_node][0]
+ input_shape = attr['_input_shapes'][input_node]
axis_length = input_shape[axis]
if axis_length < 0:
raise TypeError("Unstack with unknown axis length")
@@ -1018,8 +1017,8 @@ def _LSTMBlockCellWrapper(inputs, attr, params,
"""LSTM cell warapper to prepare the inputs"""
input_shape = attr['_input_shapes'][inputs[0]]
weight_shape = attr['_input_shapes'][inputs[3]]
- batch_size = input_shape[0][0]
- num_hidden = weight_shape[0][1] // 4
+ batch_size = input_shape[0]
+ num_hidden = weight_shape[1] // 4

if layer == 0:
#Create initial states placeholder in case of first layer
@@ -1240,7 +1239,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
tensor_slot = 0
input_shape = self._output_shapes[node_name][0]
inputs.append(in_sym)
- input_shapes[in_sym] = [input_shape]
+ input_shapes[in_sym] = input_shape
# This means the node is 1d in NNVM and 0d in TF.
# See `_expand_dims_0d_aware`.
if self._outputs_are_0d[node_name][tensor_slot] and input_shape: