1919from tensorflow .python .ops import variable_scope
2020from tensorflow .python .ops import variables
2121from tensorflow .python .ops import init_ops
22- from tensorflow .core .framework import graph_pb2
2322
23+ from distutils .version import LooseVersion
2424import tvm .relay .testing .tf as tf_testing
2525
2626#######################################################################
@@ -473,11 +473,11 @@ def test_forward_stridedslice():
473473
474474
475475#######################################################################
476- # Gather
477- # ------
476+ # Gather, GatherV2
477+ # ----------------
478478
479479def _test_gather (ip_shape , indice_shape , indice_value , axis , dtype ):
480- """ One iteration of a Gather """
480+ """ One iteration of a GatherV2 """
481481
482482 tf .reset_default_graph ()
483483 in_data = tf .placeholder (dtype , ip_shape , name = "in_data" )
@@ -497,7 +497,7 @@ def _fill_indices(indice_value):
497497 compare_tf_with_tvm ([np_data , np_indices ], ['in_data:0' , 'indices:0' ], 'GatherV2:0' )
498498
499499def test_forward_gather ():
500- '''test gather layer'''
500+ '''test GatherV2 layer'''
501501 _test_gather ((4 ,), (1 ,), 1 , 0 , 'int32' )
502502 _test_gather ((4 ,), (1 ,), 1 , 0 , 'float32' )
503503 _test_gather ((1 ,4 ), (1 ,), [0 ], 0 , 'int32' )
@@ -509,6 +509,44 @@ def test_forward_gather():
509509 _test_gather ((3 ,3 ,3 ), (1 ,1 ,2 ), [[[1 ,0 ]]], 2 , 'int32' )
510510 _test_gather ((4 ,3 ,5 ,6 ), (1 ,4 ), [[2 ,1 ,0 ,0 ]], 0 , 'float32' )
511511
512+
513+ def _test_gather_v1 (ip_shape , indice_shape , indice_value , dtype ):
514+ """ One iteration of a Gather"""
515+ tf .reset_default_graph ()
516+ in_data = tf .placeholder (dtype , ip_shape , name = "in_data" )
517+ indices = tf .placeholder ("int32" , indice_shape , name = "indices" )
518+ tf .gather (in_data , indices )
519+ np_data = np .random .uniform (size = ip_shape ).astype (dtype )
520+
521+ def _fill_indices (indice_value ):
522+ indices = np .array (ip_shape , dtype = dtype )
523+ if isinstance (indice_value , int ):
524+ indices = np .array ([indice_value ], dtype = 'int32' )
525+ else :
526+ indices = np .asarray (indice_value , dtype = 'int32' )
527+ return indices
528+ np_indices = _fill_indices (indice_value )
529+
530+ compare_tf_with_tvm ([np_data , np_indices ], ['in_data:0' , 'indices:0' ], 'Gather:0' )
531+
532+
533+ def test_forward_gather_v1 ():
534+ '''test gather layer'''
535+
536+ if tf .__version__ < LooseVersion ('1.7' ):
537+ _test_gather_v1 ((4 ,), (1 , 2 , 2 ), [[[1 , 0 ], [0 , 1 ]]], 'float32' )
538+ _test_gather_v1 ((4 ,), (1 ,), 1 , 'int32' )
539+ _test_gather_v1 ((4 ,), (1 ,), 1 , 'float32' )
540+ _test_gather_v1 ((1 , 4 ), (1 ,), [0 ], 'int32' )
541+ _test_gather_v1 ((4 ,), (1 , 2 , 2 ), [[[1 , 0 ], [0 , 1 ]]], 'float32' )
542+ _test_gather_v1 ((2 , 2 ), (1 , 2 , 2 ), [[[1 , 0 ], [0 , 1 ]]], 'int32' )
543+ _test_gather_v1 ((2 , 2 ), (1 , 2 , 2 ), [[[1 , 0 ], [0 , 1 ]]], 'int32' )
544+ _test_gather_v1 ((2 , 2 ), (1 , 2 , 2 ), [[[1 , 0 ], [0 , 1 ]]], 'float32' )
545+ _test_gather_v1 ((3 , 3 , 3 ), (1 , 1 , 2 ), [[[1 , 0 ]]], 'int32' )
546+ _test_gather_v1 ((3 , 3 , 3 ), (1 , 1 , 2 ), [[[1 , 0 ]]], 'int32' )
547+ _test_gather_v1 ((4 , 3 , 5 , 6 ), (1 , 4 ), [[2 , 1 , 0 , 0 ]], 'float32' )
548+
549+
512550#######################################################################
513551# Split
514552# -----
@@ -1213,6 +1251,7 @@ def test_forward_rel_ops():
12131251 test_forward_crop ()
12141252 test_forward_pad ()
12151253 test_forward_gather ()
1254+ test_forward_gather_v1 ()
12161255 test_forward_stridedslice ()
12171256 test_forward_split ()
12181257 test_forward_unstack ()
0 commit comments