Skip to content

Commit 5fb518e

Browse files
soiferjkevinthesun
authored andcommitted
[Relay][Frontend][TF] Fix Size operator (apache#4175)
* [Relay][Frontend][TF] Fix Size operator * Uncomment tests
1 parent 78af92b commit 5fb518e

File tree

3 files changed

+17
-7
lines changed

3 files changed

+17
-7
lines changed

python/tvm/relay/frontend/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def get_relay_op(op_name):
259259
op = None
260260
else:
261261
# try search op in various modules
262-
for candidate in (_op, _op.nn, _op.image, _op.vision):
262+
for candidate in (_op, _op.nn, _op.image, _op.vision, _op.contrib):
263263
op = getattr(candidate, op_name, None)
264264
if op is not None:
265265
break

python/tvm/relay/frontend/tensorflow.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1305,6 +1305,13 @@ def _impl(inputs, attr, params):
13051305
return _op.multiply(difference, difference)
13061306
return _impl
13071307

1308+
def _size():
1309+
def _impl(inputs, attr, params):
1310+
new_attr = attr
1311+
new_attr['out_type'] = attr['out_type'].name
1312+
return AttrCvt('ndarray_size', transforms={'out_type' : 'dtype'})(inputs, new_attr)
1313+
return _impl
1314+
13081315
# compatible operators that do NOT require any conversion.
13091316
_identity_list = []
13101317

@@ -1410,7 +1417,7 @@ def _impl(inputs, attr, params):
14101417
'Shape' : _shape(),
14111418
'Sigmoid' : AttrCvt('sigmoid'),
14121419
'Sign' : AttrCvt('sign'),
1413-
'Size' : AttrCvt('ndarray_size'),
1420+
'Size' : _size(),
14141421
'Slice' : _slice(),
14151422
'Softmax' : _softmax(),
14161423
'Softplus' : _softplus(),

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2184,15 +2184,18 @@ def check_mean(ishape, **kwargs):
21842184
def test_forward_size():
21852185
def check_size(ishape):
21862186
np_input = np.random.uniform(size=ishape).astype(np.float32)
2187+
2188+
# if all dimensions are constant, TF will optimize away size operator into constant
2189+
tf_input_shape = list(np_input.shape)
2190+
tf_input_shape[0] = None
2191+
21872192
with tf.Graph().as_default():
2188-
input = tf.placeholder(shape=np_input.shape, dtype=np_input.dtype, name='input')
2193+
input = tf.placeholder(shape=tf_input_shape, dtype=np_input.dtype, name='input')
21892194
tf.size(input, name='size')
21902195
compare_tf_with_tvm([np_input], ['input:0'], 'size:0')
21912196

2192-
if tf.__version__ < LooseVersion('1.1'):
2193-
check_size((10, 8, 16, 32))
2194-
check_size((10,))
2195-
check_size(())
2197+
check_size((10, 8, 16, 32))
2198+
check_size((10,))
21962199

21972200
#######################################################################
21982201
# All, Max, Min

0 commit comments

Comments
 (0)