@@ -217,7 +217,6 @@ def expected(dshape):
217217 assert not relay .ir_pass .free_vars (zz )
218218 after = relay .ir_pass .infer_type (expected (dshape ))
219219 assert relay .ir_pass .alpha_equal (zz , after )
220- print (zz .astext ())
221220
222221
223222def test_stop_fusion ():
@@ -287,6 +286,81 @@ def expected(dshape, dtype):
287286 assert relay .ir_pass .alpha_equal (f , after )
288287
289288
289+ def test_fuse_tuple_get_elemwise ():
290+ def before (dim ):
291+ X = relay .var ("X" , shape = (1 , dim ))
292+ W = relay .var ("W" , shape = (3 * dim , dim ))
293+ matmul = relay .nn .dense (X , W )
294+ splitted = relay .split (matmul , indices_or_sections = 3 , axis = 1 )
295+ out = relay .sigmoid (splitted [0 ]) + relay .tanh (splitted [1 ]) * relay .exp (splitted [2 ])
296+ return relay .Function ([X , W ], out )
297+
298+ def expected (dim ):
299+ p0 = relay .var ("p0" , shape = (1 , dim ))
300+ p1 = relay .var ("p1" , shape = (3 * dim , dim ))
301+ matmul = relay .nn .dense (p0 , p1 )
302+ f0 = relay .Function ([p0 , p1 ], matmul )
303+
304+ p01 = relay .var ("p01" , shape = (1 , 3 * dim ))
305+ splitted = relay .split (p01 , indices_or_sections = 3 , axis = 1 )
306+ out = relay .sigmoid (splitted [0 ]) + relay .tanh (splitted [1 ]) * relay .exp (splitted [2 ])
307+ f1 = relay .Function ([p01 ], out )
308+
309+ X = relay .var ("X" , shape = (1 , dim ))
310+ W = relay .var ("W" , shape = (3 * dim , dim ))
311+ y = relay .Call (f0 , [X , W ])
312+ z = relay .Call (f1 , [y ])
313+ return relay .Function ([X , W ], z )
314+
315+ dim = 10
316+ z = before (dim )
317+ z = relay .ir_pass .infer_type (z )
318+ zz = relay .ir_pass .fuse_ops (z , opt_level = 0 )
319+ assert not relay .ir_pass .free_vars (zz )
320+ zz = relay .ir_pass .fuse_ops (z , opt_level = 2 )
321+ zz = relay .ir_pass .infer_type (zz )
322+ assert not relay .ir_pass .free_vars (zz )
323+ after = relay .ir_pass .infer_type (expected (dim ))
324+ assert relay .ir_pass .alpha_equal (zz , after )
325+
326+
327+ def test_tuple_get_root ():
328+ def before (dim ):
329+ X = relay .var ("X" , shape = (1 , 3 * dim ))
330+ W = relay .var ("W" , shape = (dim , dim ))
331+ splitted = relay .split (X , indices_or_sections = 3 , axis = 1 )
332+ out = relay .nn .dense (splitted [0 ], W )
333+ return relay .Function ([X , W ], out )
334+
335+ def expected (dim ):
336+ p0 = relay .var ("p0" , shape = (1 , 3 * dim ))
337+ splitted = relay .split (p0 , indices_or_sections = 3 , axis = 1 )
338+ out = splitted [0 ]
339+ f0 = relay .Function ([p0 ], out )
340+
341+ p01 = relay .var ("p01" , shape = (1 , dim ))
342+ p1 = relay .var ("p1" , shape = (dim , dim ))
343+ out = relay .nn .dense (p01 , p1 )
344+ f1 = relay .Function ([p01 , p1 ], out )
345+
346+ X = relay .var ("X" , shape = (1 , 3 * dim ))
347+ W = relay .var ("W" , shape = (dim , dim ))
348+ y = relay .Call (f0 , [X ])
349+ z = relay .Call (f1 , [y , W ])
350+ return relay .Function ([X , W ], z )
351+
352+ dim = 10
353+ z = before (dim )
354+ z = relay .ir_pass .infer_type (z )
355+ zz = relay .ir_pass .fuse_ops (z , opt_level = 0 )
356+ assert not relay .ir_pass .free_vars (zz )
357+ zz = relay .ir_pass .fuse_ops (z , opt_level = 2 )
358+ zz = relay .ir_pass .infer_type (zz )
359+ assert not relay .ir_pass .free_vars (zz )
360+ after = relay .ir_pass .infer_type (expected (dim ))
361+ assert relay .ir_pass .alpha_equal (zz , after )
362+
363+
290364if __name__ == "__main__" :
291365 test_fuse_simple ()
292366 test_conv2d_fuse ()
@@ -295,3 +369,5 @@ def expected(dshape, dtype):
295369 test_tuple_strided_slice ()
296370 test_stop_fusion ()
297371 test_fuse_myia_regression ()
372+ test_fuse_tuple_get_elemwise ()
373+ test_tuple_get_root ()
0 commit comments