@@ -23,13 +23,15 @@ def before():
2323        x  =  relay .var ("x" , shape = (10 , 20 ))
2424        y  =  relay .add (x , relay .const (1 , "float32" ))
2525        z  =  relay .exp (y )
26-         return  relay .Function ([x ], z )
26+         w  =  relay .squeeze (z )
27+         return  relay .Function ([x ], w )
2728
2829    def  expected ():
2930        x  =  relay .var ("p" , shape = (10 , 20 ))
3031        y  =  relay .add (x , relay .const (1 , "float32" ))
3132        z  =  relay .exp (y )
32-         f1  =  relay .Function ([x ], z )
33+         w  =  relay .squeeze (z )
34+         f1  =  relay .Function ([x ], w )
3335        x  =  relay .var ("x" , shape = (10 , 20 ))
3436        y  =  relay .Call (f1 , [x ])
3537        return  relay .Function ([x ], y )
@@ -503,6 +505,38 @@ def expected(dshape):
503505    assert  relay .ir_pass .alpha_equal (zz , after )
504506
505507
508+ def  test_fuse_parallel_injective ():
509+     """Test fusing parallel injective ops to an elemwise op.""" 
510+     def  before ():
511+         x  =  relay .var ("x" , shape = (10 , 20 ))
512+         y  =  relay .add (x , relay .const (1 , "float32" ))
513+         z  =  relay .squeeze (y )
514+         u  =  relay .transpose (y , axes = [0 , 1 ])
515+         w  =  relay .left_shift (z , u )
516+         return  relay .Function ([x ], w )
517+ 
518+     def  expected ():
519+         x  =  relay .var ("p" , shape = (10 , 20 ))
520+         y  =  relay .add (x , relay .const (1 , "float32" ))
521+         z  =  relay .squeeze (y )
522+         u  =  relay .transpose (y , axes = [0 , 1 ])
523+         w  =  relay .left_shift (z , u )
524+         f1  =  relay .Function ([x ], w )
525+         x  =  relay .var ("x" , shape = (10 , 20 ))
526+         y  =  relay .Call (f1 , [x ])
527+         return  relay .Function ([x ], y )
528+ 
529+     z  =  before ()
530+     z  =  relay .ir_pass .infer_type (z )
531+     zz  =  relay .ir_pass .fuse_ops (z , opt_level = 0 )
532+     assert  not  relay .ir_pass .free_vars (zz )
533+     zz  =  relay .ir_pass .fuse_ops (z , opt_level = 2 )
534+     zz  =  relay .ir_pass .infer_type (zz )
535+     assert  not  relay .ir_pass .free_vars (zz )
536+     after  =  relay .ir_pass .infer_type (expected ())
537+     assert  relay .ir_pass .alpha_equal (zz , after )
538+ 
539+ 
506540if  __name__  ==  "__main__" :
507541    test_fuse_simple ()
508542    test_conv2d_fuse ()
@@ -515,3 +549,4 @@ def expected(dshape):
515549    test_tuple_intermediate ()
516550    test_tuple_consecutive ()
517551    test_inception_like ()
552+     test_fuse_parallel_injective ()
0 commit comments