Skip to content

Commit 79f323d

Browse files
Authored and committed by Li Xiaoquan
[Relay][TensorFlow] Remove 'input_0d_mismatch' special handling
1 parent 3f835bd commit 79f323d

File tree

1 file changed

+6
-29
lines changed

1 file changed

+6
-29
lines changed

python/tvm/relay/frontend/tensorflow.py

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ def __call__(self, inputs, attrs, *args):
9999
self._ignores.append('_node_name')
100100
self._ignores.append('is_training')
101101
self._ignores.append('_target_layout')
102-
self._ignores.append('_input_0d_mismatch')
103102

104103
# apply custom check
105104
if self._custom_check:
@@ -458,9 +457,9 @@ def _impl(inputs, attr, params):
458457
def _expand_dims():
459458
def _impl(inputs, attr, params):
460459
dim_input = inputs.pop(1)
461-
axis = params[dim_input.name_hint]
462-
params.pop(dim_input.name_hint)
463-
return _expand_dims_0d_aware(inputs[0], attr, axis=axis.asnumpy()[0])
460+
axis = params.pop(_get_name_hint(dim_input)).asnumpy()[0]
461+
return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
462+
extras={'axis': axis, 'num_newaxis': 1})(inputs, attr)
464463
return _impl
465464

466465
def _resize_bilinear():
@@ -528,7 +527,7 @@ def _impl(inputs, attr, params):
528527
def _pack():
529528
def _impl(inputs, attr, params):
530529
axis = int(attr["axis"])
531-
inputs_reshaped = [_expand_dims_0d_aware(i, attr, axis=axis, num_newaxis=1) for i in inputs]
530+
inputs_reshaped = [_op.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs]
532531
return _op.concatenate(inputs_reshaped, axis)
533532
return _impl
534533

@@ -820,9 +819,9 @@ def _transform_mask(stride_dim, ellipsis_mask):
820819
pass
821820
else:
822821
final_output.append(out_shape[gather_index])
823-
# Prevent 0-dim tensors which are not accepted by Relay
822+
824823
if not final_output:
825-
final_output.append(1)
824+
return out
826825
return _op.reshape(out, newshape=tuple(final_output))
827826
return _impl
828827

@@ -984,16 +983,6 @@ def _impl(inputs, attr, params):
984983
for split_item in splitted]), len(splitted))
985984
return _impl
986985

987-
def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):
988-
if data in attr['_input_0d_mismatch']:
989-
return data if num_newaxis == 1 else \
990-
AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
991-
extras={'axis': int(axis), 'num_newaxis': int(num_newaxis-1)})([data], attr)
992-
993-
return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
994-
extras={'axis': int(axis), 'num_newaxis': int(num_newaxis)})([data], attr)
995-
996-
997986
def _softmax():
998987
def _impl(inputs, attr, params):
999988
return AttrCvt(op_name='softmax',
@@ -1647,7 +1636,6 @@ def __init__(self):
16471636
self._output_shapes = {}
16481637
self._num_param = 0
16491638
self._num_rnn_layer = False
1650-
self._outputs_are_0d = {}
16511639
self._input_shapes = {}
16521640
self._loops = {}
16531641
self._branches = {}
@@ -1737,7 +1725,6 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
17371725
# Operator name 'Const' is treated as a parameter to build params dict.
17381726

17391727
input_shapes = {}
1740-
input_0d_mismatch = set()
17411728
attr = self._parse_attr(node.attr)
17421729

17431730
# Variable converted to Const will not have only value attr
@@ -1753,10 +1740,6 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
17531740
# Will infer shapes if the graph is not frozen with add_shapes=True
17541741
self._output_shapes[node.name] = [None]
17551742

1756-
self._outputs_are_0d[node.name] = [ \
1757-
not shape if isinstance(tshape, list) else False \
1758-
for tshape in self._output_shapes[node.name]]
1759-
17601743
if node.op == "Const":
17611744
# All Const nodes are Param nodes, lets parse
17621745
self._num_param += 1
@@ -1810,14 +1793,8 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
18101793
input_shape = self._output_shapes[node_name][0]
18111794
inputs.append(in_sym[0])
18121795
input_shapes[in_sym[0]] = input_shape
1813-
# This means the node is 1d in Relay and 0d in TF.
1814-
# See `_expand_dims_0d_aware`.
1815-
if node_name in self._outputs_are_0d \
1816-
and self._outputs_are_0d[node_name][tensor_slot] and input_shape:
1817-
input_0d_mismatch.add(in_sym[0])
18181796

18191797
attr['_input_shapes'] = input_shapes
1820-
attr['_input_0d_mismatch'] = input_0d_mismatch
18211798

18221799
if node.op in _control_flow_nodes:
18231800
op = self._convert_control_flow_operator(node, inputs,

0 commit comments

Comments (0)