@@ -176,16 +176,14 @@ def expected(dshape):
176176 f0 = relay .Function ([x ], pooled )
177177
178178 p0 = relay .var ("p0" , shape = (dshape [0 ], dshape [1 ], dshape [2 ]// 2 , dshape [3 ]// 2 ))
179- p1 = relay .var ("p1" , shape = (dshape [0 ], dshape [1 ], dshape [2 ], dshape [3 ]))
180- p1_copy = relay .copy (p1 )
181179 upsampled = relay .nn .upsampling (p0 , scale = 2 , layout = "NCHW" )
182- out = relay .Tuple ((upsampled , p1_copy ))
183- f1 = relay .Function ([p0 , p1 ], out )
180+ f1 = relay .Function ([p0 ], upsampled )
184181
185182 x = relay .var ("x" , shape = dshape )
186183 y = relay .Call (f0 , [x ])
187- z = relay .Call (f1 , [y , x ])
188- return relay .Function ([x ], z )
184+ z = relay .Call (f1 , [y ])
185+ tup = relay .Tuple ((z , x ))
186+ return relay .Function ([x ], tup )
189187
190188 dshape = (1 , 16 , 64 , 64 )
191189 z = before (dshape )
@@ -199,41 +197,6 @@ def expected(dshape):
199197 assert relay .ir_pass .alpha_equal (zz , after )
200198
201199
202- def test_tuple_strided_slice ():
203- """
204- Test fusion case where the number of fields of tuple and
205- the number of parameters to the function containing the tuple are different
206- """
207-
208- def before (dshape ):
209- x = relay .var ("x" , shape = dshape )
210- slice1 = relay .strided_slice (x , begin = [0 , 0 ], end = [dshape [1 ]// 2 , dshape [1 ]], strides = [1 ,1 ])
211- slice2 = relay .strided_slice (x , begin = [dshape [1 ]// 2 , 0 ], end = [dshape [0 ], dshape [1 ]], strides = [1 ,1 ])
212- out = relay .Tuple ((slice1 , slice2 ))
213- return relay .Function ([x ], out )
214-
215- def expected (dshape ):
216- x = relay .var ("x" , shape = dshape )
217- slice1 = relay .strided_slice (x , begin = [0 , 0 ], end = [dshape [1 ]// 2 , dshape [1 ]], strides = [1 ,1 ])
218- slice2 = relay .strided_slice (x , begin = [dshape [1 ]// 2 , 0 ], end = [dshape [0 ], dshape [1 ]], strides = [1 ,1 ])
219- out = relay .Tuple ((slice1 , slice2 ))
220- f0 = relay .Function ([x ], out )
221-
222- x = relay .var ("x" , shape = dshape )
223- y = relay .Call (f0 , [x ])
224- return relay .Function ([x ], y )
225-
226- dshape = (64 , 64 )
227- z = before (dshape )
228- z = relay .ir_pass .infer_type (z )
229- zz = relay .ir_pass .fuse_ops (z , opt_level = 0 )
230- assert not relay .ir_pass .free_vars (zz )
231- zz = relay .ir_pass .fuse_ops (z , opt_level = 2 )
232- zz = relay .ir_pass .infer_type (zz )
233- assert not relay .ir_pass .free_vars (zz )
234- after = relay .ir_pass .infer_type (expected (dshape ))
235- assert relay .ir_pass .alpha_equal (zz , after )
236-
237200
238201def test_stop_fusion ():
239202 def before (dshape ):
@@ -377,13 +340,178 @@ def expected(dim):
377340 assert relay .ir_pass .alpha_equal (zz , after )
378341
379342
def test_tuple_intermediate():
    """Fuse a chain of injective ops whose intermediate results feed a tuple
    (via concatenate) that is itself consumed by further injective ops.

    Expected outcome: the whole graph fuses into one function, so the fused
    module is alpha-equal to ``before`` wrapped in a single call.
    """
    def before(x):
        inj = relay.squeeze(x)
        y1 = relay.add(inj, relay.const(1, "float32"))
        tmp = relay.squeeze(inj)
        tmp = relay.add(tmp, relay.const(1, "float32"))
        y2 = relay.add(tmp, relay.const(1, "float32"))
        y3 = relay.add(inj, relay.const(1, "float32"))
        concat = relay.concatenate((y1, y2, y3), axis=1)
        out_inj = relay.squeeze(concat)
        out = relay.add(out_inj, relay.const(1, "float32"))
        return relay.Function(relay.ir_pass.free_vars(out), out)

    def expected(p0):
        # The entire ``before`` graph becomes the single fused function f0.
        f0 = before(p0)
        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        return relay.Function([x], y)

    dshape = (1, 16, 64, 64)
    x = relay.var("x", shape=dshape)
    z = before(x)
    z = relay.ir_pass.infer_type(z)
    # opt_level=0 must not leave dangling free variables.
    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
    assert not relay.ir_pass.free_vars(zz)
    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
    # Sanity check: the fused module must still compile end to end.
    relay.build(zz, 'llvm')
    zz = relay.ir_pass.infer_type(zz)
    assert not relay.ir_pass.free_vars(zz)
    after = relay.ir_pass.infer_type(expected(x))
    assert relay.ir_pass.alpha_equal(zz, after)
375+
def test_tuple_consecutive():
    """Fuse a graph containing consecutive tuples (concat of concats)
    followed by a pooling op, where the final output is itself a tuple.

    Expected outcome: three fused functions — the injective/tuple region,
    the pooling region (pool + one add), and a trailing elementwise add —
    with the root tuple kept outside any fused function.
    """
    def gen_intermediate_tuple(x):
        y1 = relay.add(x, relay.const(1, "float32"))
        y2 = relay.add(x, relay.const(1, "float32"))
        y3 = relay.add(x, relay.const(1, "float32"))
        concat = relay.concatenate((y1, y2, y3), axis=1)
        out = relay.add(concat, relay.const(1, "float32"))
        return out

    def gen_consecutive_tuple(x):
        y1 = gen_intermediate_tuple(x)
        y2 = gen_intermediate_tuple(x)
        y3 = gen_intermediate_tuple(x)
        concat = relay.concatenate((y1, y2, y3), axis=1)
        return concat

    def before(x):
        concat = gen_consecutive_tuple(x)
        pooled = relay.nn.max_pool2d(concat, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        out = relay.add(pooled, relay.const(1, "float32"))
        out2 = relay.add(out, relay.const(1, "float32"))
        out_tup = relay.Tuple((out, out2))
        return relay.Function(relay.ir_pass.free_vars(out_tup), out_tup)

    def expected(dshape):
        # f0: all the injective ops and nested concats fuse together.
        p0 = relay.var("p0", shape=dshape)
        concat = gen_consecutive_tuple(p0)
        f0 = relay.Function([p0], concat)

        # f1: the pooling op plus the first add. Channel count is dshape[1]*9
        # because two levels of 3-way concat multiply channels by 3*3.
        p01 = relay.var("p01", shape=(1, dshape[1]*9, dshape[2], dshape[3]))
        pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        out = relay.add(pooled, relay.const(1, "float32"))
        f1 = relay.Function([p01], out)

        # f2: the trailing add on the pooled (spatially halved) result.
        p02 = relay.var("p02", shape=(1, dshape[1]*9, dshape[2]//2, dshape[3]//2))
        out = relay.add(p02, relay.const(1, "float32"))
        f2 = relay.Function([p02], out)

        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        z = relay.Call(f1, [y])
        z2 = relay.Call(f2, [z])

        # The root tuple stays outside the fused functions.
        return relay.Function([x], relay.Tuple((z, z2)))

    dshape = (1, 16, 64, 64)
    x = relay.var("x", shape=dshape)
    z = before(x)
    z = relay.ir_pass.infer_type(z)
    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
    assert not relay.ir_pass.free_vars(zz)
    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
    # Sanity check: the fused module must still compile end to end.
    relay.build(zz, 'llvm')
    zz = relay.ir_pass.infer_type(zz)
    assert not relay.ir_pass.free_vars(zz)
    after = relay.ir_pass.infer_type(expected(dshape))
    assert relay.ir_pass.alpha_equal(zz, after)
433+
434+
def test_inception_like():
    """Fuse two stacked inception-style branches (two convs feeding a
    concat, twice).

    Expected outcome: each conv+relu fuses into its own function, and each
    concatenate becomes its own fused function, matching ``expected``.
    """
    def conv(data):
        # Note: each call creates a fresh weight var, so every conv in the
        # graph has its own free weight parameter.
        y = relay.nn.conv2d(data, relay.var("w"),
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            channels=16)
        return relay.nn.relu(data=y)

    def inception_like(data):
        c0 = conv(data)
        c1 = conv(data)
        return relay.concatenate((c0, c1), axis=1)

    def before(dshape):
        x = relay.var("x", shape=dshape)
        in1 = inception_like(x)
        in2 = inception_like(in1)
        return relay.Function(relay.ir_pass.free_vars(in2), in2)

    def expected(dshape):
        # First branch pair: one fused conv+relu per branch.
        p0 = relay.var("p0", shape=dshape)
        c = conv(p0)
        f0 = relay.Function(relay.ir_pass.free_vars(c), c)

        p01 = relay.var("p01", shape=dshape)
        c = conv(p01)
        f1 = relay.Function(relay.ir_pass.free_vars(c), c)

        # First concatenate is fused on its own.
        p02 = relay.var("p02", shape=dshape)
        p12 = relay.var("p12", shape=dshape)
        concat1 = relay.concatenate((p02, p12), axis=1)
        f_concat1 = relay.Function([p02, p12], concat1)

        # After the first concat the channel dimension doubles.
        dshape2 = (dshape[0], dshape[1]*2, dshape[2], dshape[3])

        # Second branch pair operates on the doubled-channel input.
        p03 = relay.var("p03", shape=dshape2)
        c = conv(p03)
        f2 = relay.Function(relay.ir_pass.free_vars(c), c)

        p04 = relay.var("p04", shape=dshape2)
        c = conv(p04)
        f3 = relay.Function(relay.ir_pass.free_vars(c), c)

        p05 = relay.var("p05", shape=dshape)
        p15 = relay.var("p15", shape=dshape)
        concat2 = relay.concatenate((p05, p15), axis=1)
        f_concat2 = relay.Function([p05, p15], concat2)

        # Wire the fused functions together; weights are passed explicitly.
        x = relay.var("x", shape=dshape)
        c1 = relay.Call(f0, [x, relay.var("w1")])
        c2 = relay.Call(f1, [x, relay.var("w2")])
        concat = relay.Call(f_concat1, [c1, c2])
        c3 = relay.Call(f2, [concat, relay.var("w3")])
        c4 = relay.Call(f3, [concat, relay.var("w4")])
        out = relay.Call(f_concat2, [c3, c4])

        return relay.Function(relay.ir_pass.free_vars(out), out)

    dshape = (1, 16, 64, 64)
    z = before(dshape)
    z = relay.ir_pass.infer_type(z)
    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
    assert not relay.ir_pass.free_vars(zz)
    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
    # Sanity check: the fused module must still compile end to end.
    relay.build(zz, 'llvm')
    zz = relay.ir_pass.infer_type(zz)
    assert not relay.ir_pass.free_vars(zz)
    after = relay.ir_pass.infer_type(expected(dshape))
    assert relay.ir_pass.alpha_equal(zz, after)
504+
505+
if __name__ == "__main__":
    # Post-commit test roster: test_tuple_strided_slice was removed by this
    # change; the three new tuple/inception fusion tests were added.
    test_fuse_simple()
    test_conv2d_fuse()
    test_concatenate()
    test_tuple_root()
    test_stop_fusion()
    test_fuse_myia_regression()
    test_fuse_tuple_get_elemwise()
    test_tuple_get_root()
    test_tuple_intermediate()
    test_tuple_consecutive()
    test_inception_like()
0 commit comments