Skip to content

Commit 0c1a4d8

Browse files
yongwwwwweic
authored andcommitted
[Relay][Frontend] Support TF Gather (apache#2935)
* [Relay][Frontend] Support TF Gather * fix comments
1 parent 1ce512a commit 0c1a4d8

File tree

2 files changed

+52
-9
lines changed

2 files changed

+52
-9
lines changed

python/tvm/relay/frontend/tensorflow.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -673,10 +673,13 @@ def _impl(inputs, attr, params):
673673
return _op.multiply(inputs[0], inputs[0])
674674
return _impl
675675

676-
def _gather_v2():
677-
"Tensorflow now support only gatherv2"
676+
def _gather():
677+
"GatherV2, Gather"
678678
def _impl(inputs, attr, params):
679-
axis = params[inputs.pop(2).name_hint].asnumpy()[0]
679+
680+
axis = 0
681+
if len(inputs) > 2:
682+
axis = params[inputs.pop(2).name_hint].asnumpy()[0]
680683
new_input = []
681684
new_input.append(inputs.pop(0))
682685
new_input.append(inputs.pop(0))
@@ -1013,7 +1016,8 @@ def _impl(inputs, attr, params):
10131016
'Shape' : _shape(),
10141017
'Sigmoid' : AttrCvt('sigmoid'),
10151018
'Fill' : _fill(),
1016-
'GatherV2' : _gather_v2(),
1019+
'GatherV2' : _gather(),
1020+
'Gather' : _gather(),
10171021
'StridedSlice' : _stridedSlice(),
10181022
'LRN' : _lrn(),
10191023
'Pad' : _pad('Pad'),

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from tensorflow.python.ops import variable_scope
2020
from tensorflow.python.ops import variables
2121
from tensorflow.python.ops import init_ops
22-
from tensorflow.core.framework import graph_pb2
2322

23+
from distutils.version import LooseVersion
2424
import tvm.relay.testing.tf as tf_testing
2525

2626
#######################################################################
@@ -473,11 +473,11 @@ def test_forward_stridedslice():
473473

474474

475475
#######################################################################
476-
# Gather
477-
# ------
476+
# Gather, GatherV2
477+
# ----------------
478478

479479
def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
480-
""" One iteration of a Gather """
480+
""" One iteration of a GatherV2 """
481481

482482
tf.reset_default_graph()
483483
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
@@ -497,7 +497,7 @@ def _fill_indices(indice_value):
497497
compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'indices:0'], 'GatherV2:0')
498498

499499
def test_forward_gather():
500-
'''test gather layer'''
500+
'''test GatherV2 layer'''
501501
_test_gather((4,), (1,), 1, 0, 'int32')
502502
_test_gather((4,), (1,), 1, 0, 'float32')
503503
_test_gather((1,4), (1,), [0], 0, 'int32')
@@ -509,6 +509,44 @@ def test_forward_gather():
509509
_test_gather((3,3,3), (1,1,2), [[[1,0]]], 2, 'int32')
510510
_test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32')
511511

512+
513+
def _test_gather_v1(ip_shape, indice_shape, indice_value, dtype):
514+
""" One iteration of a Gather"""
515+
tf.reset_default_graph()
516+
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
517+
indices = tf.placeholder("int32", indice_shape, name="indices")
518+
tf.gather(in_data, indices)
519+
np_data = np.random.uniform(size=ip_shape).astype(dtype)
520+
521+
def _fill_indices(indice_value):
522+
indices = np.array(ip_shape, dtype=dtype)
523+
if isinstance(indice_value, int):
524+
indices = np.array([indice_value], dtype='int32')
525+
else:
526+
indices = np.asarray(indice_value, dtype='int32')
527+
return indices
528+
np_indices = _fill_indices(indice_value)
529+
530+
compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'indices:0'], 'Gather:0')
531+
532+
533+
def test_forward_gather_v1():
534+
'''test gather layer'''
535+
536+
if tf.__version__ < LooseVersion('1.7'):
537+
_test_gather_v1((4,), (1, 2, 2), [[[1, 0], [0, 1]]], 'float32')
538+
_test_gather_v1((4,), (1,), 1, 'int32')
539+
_test_gather_v1((4,), (1,), 1, 'float32')
540+
_test_gather_v1((1, 4), (1,), [0], 'int32')
541+
_test_gather_v1((4,), (1, 2, 2), [[[1, 0], [0, 1]]], 'float32')
542+
_test_gather_v1((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 'int32')
543+
_test_gather_v1((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 'int32')
544+
_test_gather_v1((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 'float32')
545+
_test_gather_v1((3, 3, 3), (1, 1, 2), [[[1, 0]]], 'int32')
546+
_test_gather_v1((3, 3, 3), (1, 1, 2), [[[1, 0]]], 'int32')
547+
_test_gather_v1((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 'float32')
548+
549+
512550
#######################################################################
513551
# Split
514552
# -----
@@ -1213,6 +1251,7 @@ def test_forward_rel_ops():
12131251
test_forward_crop()
12141252
test_forward_pad()
12151253
test_forward_gather()
1254+
test_forward_gather_v1()
12161255
test_forward_stridedslice()
12171256
test_forward_split()
12181257
test_forward_unstack()

0 commit comments

Comments
 (0)