@@ -502,6 +502,83 @@ def test_forward_gather():
502502 _test_gather ((4 ,3 ,5 ,6 ), (1 ,4 ), [[2 ,1 ,0 ,0 ]], 0 , 'float32' )
503503
504504
505+ #######################################################################
506+ # Split
507+ # -----
508+
509+ def _test_split (in_shape , axis , num_split , dtype ):
510+ """ One iteration of a Split """
511+
512+ with tf .Graph ().as_default ():
513+ in_data = tf .placeholder (dtype , in_shape , name = "in_data" )
514+ tf .split (in_data , num_split , axis )
515+ np_data = np .random .uniform (size = in_shape ).astype (dtype )
516+ compare_tf_with_tvm (np_data , 'in_data:0' , 'split:0' )
517+
518+ def test_forward_split ():
519+ '''test split layer'''
520+ # rank 1
521+ _test_split ((3 ,), 0 , 1 , 'float32' )
522+ _test_split ((3 ,), 0 , 3 , 'float32' )
523+ _test_split ((6 ,), 0 , 3 , 'float32' )
524+ # rank 2
525+ _test_split ((6 , 2 ), 0 , 3 , 'float32' )
526+ _test_split ((2 , 6 ), 1 , 3 , 'float32' )
527+ # rank 3
528+ _test_split ((6 , 2 , 4 ), 0 , 3 , 'float32' )
529+ _test_split ((2 , 6 , 4 ), 1 , 3 , 'float32' )
530+ _test_split ((2 , 4 , 6 ), 2 , 3 , 'float32' )
531+ # rank 4
532+ _test_split ((6 , 1 , 3 , 5 ), 0 , 3 , 'float32' )
533+ _test_split ((1 , 6 , 3 , 5 ), 1 , 3 , 'float32' )
534+ _test_split ((1 , 3 , 6 , 5 ), 2 , 3 , 'float32' )
535+ _test_split ((1 , 3 , 5 , 6 ), 3 , 3 , 'float32' )
536+ # split along negative axis
537+ _test_split ((6 , 1 , 3 , 5 ), - 4 , 3 , 'float32' )
538+ _test_split ((1 , 6 , 3 , 5 ), - 3 , 3 , 'float32' )
539+ _test_split ((1 , 3 , 6 , 5 ), - 2 , 3 , 'float32' )
540+ _test_split ((1 , 3 , 5 , 6 ), - 1 , 3 , 'float32' )
541+
542+
543+ #######################################################################
544+ # Split followed by concat
545+ # ------------------------
546+
547+ def _test_split_concat (in_shape , axis , num_split , dtype ):
548+ """ One iteration of a split_concat pair"""
549+
550+ with tf .Graph ().as_default ():
551+ in_data = tf .placeholder (dtype , in_shape , name = "in_data" )
552+ splitted = tf .split (in_data , num_split , axis )
553+ tf .concat (splitted , axis )
554+ np_data = np .random .uniform (size = in_shape ).astype (dtype )
555+ compare_tf_with_tvm (np_data , 'in_data:0' , 'concat:0' )
556+
557+ def test_forward_split_concat ():
558+ '''test split followed by concat layers'''
559+ # rank 1
560+ _test_split_concat ((3 ,), 0 , 1 , 'float32' )
561+ _test_split_concat ((3 ,), 0 , 3 , 'float32' )
562+ _test_split_concat ((6 ,), 0 , 3 , 'float32' )
563+ # rank 2
564+ _test_split_concat ((6 , 2 ), 0 , 3 , 'float32' )
565+ _test_split_concat ((2 , 6 ), 1 , 3 , 'float32' )
566+ # rank 3
567+ _test_split_concat ((6 , 2 , 4 ), 0 , 3 , 'float32' )
568+ _test_split_concat ((2 , 6 , 4 ), 1 , 3 , 'float32' )
569+ _test_split_concat ((2 , 4 , 6 ), 2 , 3 , 'float32' )
570+ # rank 4
571+ _test_split ((6 , 1 , 3 , 5 ), 0 , 3 , 'float32' )
572+ _test_split ((1 , 6 , 3 , 5 ), 1 , 3 , 'float32' )
573+ _test_split ((1 , 3 , 6 , 5 ), 2 , 3 , 'float32' )
574+ _test_split ((1 , 3 , 5 , 6 ), 3 , 3 , 'float32' )
575+ # split along negative axis
576+ _test_split ((6 , 1 , 3 , 5 ), - 4 , 3 , 'float32' )
577+ _test_split ((1 , 6 , 3 , 5 ), - 3 , 3 , 'float32' )
578+ _test_split ((1 , 3 , 6 , 5 ), - 2 , 3 , 'float32' )
579+ _test_split ((1 , 3 , 5 , 6 ), - 1 , 3 , 'float32' )
580+
581+
505582#######################################################################
506583# Multi Input to graph
507584# --------------------
@@ -1061,6 +1138,8 @@ def test_forward_rel_ops():
10611138 test_forward_pad ()
10621139 test_forward_gather ()
10631140 test_forward_stridedslice ()
1141+ test_forward_split ()
1142+ test_forward_split_concat ()
10641143
10651144 # Activations
10661145 test_forward_sigmoid ()
0 commit comments