Merged
35 changes: 6 additions & 29 deletions python/tvm/relay/frontend/tensorflow.py
@@ -99,7 +99,6 @@ def __call__(self, inputs, attrs, *args):
         self._ignores.append('_node_name')
         self._ignores.append('is_training')
         self._ignores.append('_target_layout')
-        self._ignores.append('_input_0d_mismatch')
 
         # apply custom check
         if self._custom_check:
@@ -458,9 +457,9 @@ def _impl(inputs, attr, params):
 def _expand_dims():
     def _impl(inputs, attr, params):
         dim_input = inputs.pop(1)
-        axis = params[dim_input.name_hint]
-        params.pop(dim_input.name_hint)
-        return _expand_dims_0d_aware(inputs[0], attr, axis=axis.asnumpy()[0])
+        axis = params.pop(_get_name_hint(dim_input)).asnumpy()[0]
+        return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
+                       extras={'axis': int(axis), 'num_newaxis': 1})(inputs, attr)
     return _impl
 
 def _resize_bilinear():
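Note: the rewritten converter maps TF's ExpandDims straight onto Relay's expand_dims with num_newaxis=1, with no 0-d special case. A minimal numpy sketch of the semantics being mapped (numpy's expand_dims follows the same single-axis convention; this snippet is illustrative, not part of the diff):

    import numpy as np

    # ExpandDims inserts exactly one length-1 axis at `axis`;
    # negative axes count from the end of the expanded shape.
    x = np.zeros((3, 4))
    assert np.expand_dims(x, axis=0).shape == (1, 3, 4)
    assert np.expand_dims(x, axis=-1).shape == (3, 4, 1)
    # A 0-d input is fine too: the result is a one-element vector.
    assert np.expand_dims(np.float32(1.0), axis=0).shape == (1,)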
@@ -528,7 +527,7 @@ def _impl(inputs, attr, params):
 def _pack():
     def _impl(inputs, attr, params):
         axis = int(attr["axis"])
-        inputs_reshaped = [_expand_dims_0d_aware(i, attr, axis=axis, num_newaxis=1) for i in inputs]
+        inputs_reshaped = [_op.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs]
         return _op.concatenate(inputs_reshaped, axis)
     return _impl
 
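Note: _pack now uses plain _op.expand_dims on every input. The identity it relies on, sketched with numpy (np.stack stands in for TF's Pack; illustrative only):

    import numpy as np

    # Pack/stack along `axis` == give each input one new length-1 axis
    # at `axis`, then concatenate along that axis.
    xs = [np.zeros((2, 3)), np.ones((2, 3))]
    manual = np.concatenate([np.expand_dims(x, axis=1) for x in xs], axis=1)
    assert manual.shape == (2, 2, 3)
    assert (manual == np.stack(xs, axis=1)).all()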
@@ -820,9 +819,9 @@ def _transform_mask(stride_dim, ellipsis_mask):
                 pass
             else:
                 final_output.append(out_shape[gather_index])
-        # Prevent 0-dim tensors which are not accepted by Relay
+
         if not final_output:
-            final_output.append(1)
+            return out
         return _op.reshape(out, newshape=tuple(final_output))
     return _impl
 
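Note: with the `final_output.append(1)` workaround gone, a slice that shrinks away every axis now returns `out` unreshaped, i.e. a true 0-d tensor; the premise of this PR is that Relay accepts such scalars. A numpy illustration of the shape this path produces:

    import numpy as np

    x = np.arange(2.0)   # shape (2,), mirrors the new (2) stridedslice test
    y = x[1]             # shrinking the only axis leaves a 0-d result
    assert np.ndim(y) == 0 and np.shape(y) == ()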
@@ -984,16 +983,6 @@ def _impl(inputs, attr, params):
                                  for split_item in splitted]), len(splitted))
     return _impl
 
-def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):
-    if data in attr['_input_0d_mismatch']:
-        return data if num_newaxis == 1 else \
-            AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
-                    extras={'axis': int(axis), 'num_newaxis': int(num_newaxis-1)})([data], attr)
-
-    return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
-                   extras={'axis': int(axis), 'num_newaxis': int(num_newaxis)})([data], attr)
-
-
 def _softmax():
     def _impl(inputs, attr, params):
         return AttrCvt(op_name='softmax',
@@ -1647,7 +1636,6 @@ def __init__(self):
         self._output_shapes = {}
         self._num_param = 0
         self._num_rnn_layer = False
-        self._outputs_are_0d = {}
         self._input_shapes = {}
         self._loops = {}
         self._branches = {}
@@ -1737,7 +1725,6 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
             # Operator name 'Const' is treated as a parameter to build params dict.
 
             input_shapes = {}
-            input_0d_mismatch = set()
             attr = self._parse_attr(node.attr)
 
             # Variable converted to Const will not have only value attr
@@ -1753,10 +1740,6 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
                     # Will infer shapes if the graph is not frozen with add_shapes=True
                     self._output_shapes[node.name] = [None]
 
-                self._outputs_are_0d[node.name] = [ \
-                    not tshape if isinstance(tshape, list) else False \
-                    for tshape in self._output_shapes[node.name]]
-
                 if node.op == "Const":
                     # All Const nodes are Param nodes, lets parse
                     self._num_param += 1
@@ -1810,14 +1793,8 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
                         input_shape = self._output_shapes[node_name][0]
                     inputs.append(in_sym[0])
                     input_shapes[in_sym[0]] = input_shape
-                    # This means the node is 1d in Relay and 0d in TF.
-                    # See `_expand_dims_0d_aware`.
-                    if node_name in self._outputs_are_0d \
-                            and self._outputs_are_0d[node_name][tensor_slot] and input_shape:
-                        input_0d_mismatch.add(in_sym[0])
 
             attr['_input_shapes'] = input_shapes
-            attr['_input_0d_mismatch'] = input_0d_mismatch
 
             if node.op in _control_flow_nodes:
                 op = self._convert_control_flow_operator(node, inputs,
17 changes: 17 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
@@ -580,6 +580,7 @@ def _test_stridedslice(ip_shape, begin, end, stride, dtype,
 def test_forward_stridedslice():
     '''test StridedSlice'''
 
+    _test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1)
     _test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32')
     _test_stridedslice((3, 4, 3), [1, 0], [4, 3], [2, 1], 'float32', ellipsis_mask=8)
     _test_stridedslice((3, 4, 3), [1, 0], [4, 2], [2, 1], 'float32', ellipsis_mask=2)
@@ -1475,6 +1476,21 @@ def test_forward_rel_ops():
     _test_forward_rel_op([t1, t2], math_ops.equal)
     _test_forward_rel_op([t1, t2], math_ops.not_equal)
 
+#######################################################################
+# ExpandDims
+# ----------
+def _test_forward_expand_dims(data, axis):
+    in1 = tf.placeholder(shape=data.shape, dtype=data.dtype, name='in1')
+    out = tf.expand_dims(in1, axis)
+    compare_tf_with_tvm([data], [in1.name], out.name)
+
+def test_forward_expand_dims():
+    _test_forward_expand_dims(np.int32(1), 0)
+    _test_forward_expand_dims(np.array([1]), 0)
+    _test_forward_expand_dims(np.array([1]), -1)
+    _test_forward_expand_dims(np.array([[1], [2]]), 0)
+    _test_forward_expand_dims(np.array([[1], [2]]), 1)
Member (review comment): not sure if "axis = -1" is supported here

Contributor Author (reply): It's added
+    _test_forward_expand_dims(np.array([[1], [2]]), -1)
 
 #######################################################################
 # Main
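Note on the review question: tf.expand_dims accepts negative axes, counted from the end of the expanded shape, so the added axis = -1 cases exercise a real path. A quick numpy check of the expected result shapes (numpy uses the same convention; illustrative only):

    import numpy as np

    assert np.expand_dims(np.array([1]), -1).shape == (1, 1)
    assert np.expand_dims(np.array([[1], [2]]), -1).shape == (2, 1, 1)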
@@ -1509,6 +1525,7 @@ def test_forward_rel_ops():
     test_forward_reverse_v2()
     test_forward_pow_exp()
     test_forward_sign()
+    test_forward_expand_dims()
 
     # Reductions
     test_forward_argminmax()