@@ -309,5 +309,68 @@ def pooling_decompose_3(
309309 check_decompose_padding (sum_pool_2d , sch .mod ["main" ], pooling_decompose_3 , check_run = True )
310310
311311
def test_decompose_wrt_single_child_subtree():
    """Decompose padding at a loop located under a single-child loop subtree.

    The padded block sits inside one 4-level loop nest.  Decomposition is
    requested at the third loop, so the expected result keeps the two outer
    loops shared and places a constant-fill nest followed by a copy nest
    underneath them.
    """

    # Input workload: a single block that pads 225x225 up to 231x231 with zeros.
    @T.prim_func
    def pad_op(
        x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer([1, 16, 231, 231], dtype="int8")
    ):
        for i0, i1, i2, i3 in T.grid(1, 16, 231, 231):
            with T.block("pad_temp"):
                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                y[ax0, ax1, ax2, ax3] = T.if_then_else(
                    3 <= ax2 and ax2 < 228 and 3 <= ax3 and ax3 < 228,
                    x[ax0, ax1, ax2 - 3, ax3 - 3],
                    T.int8(0),
                    dtype="int8",
                )

    # Expected result: the outer two loops are shared, the inner two are split
    # into a zero-fill over the full output and a copy of the input interior.
    @T.prim_func
    def pad_op_after(
        x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer[(1, 16, 231, 231), "int8"]
    ):
        for i0, i1 in T.grid(1, 16):
            for i2, i3 in T.grid(231, 231):
                with T.block("pad_temp_pad_const"):
                    ax0 = T.axis.spatial(1, 0)
                    ax1, ax2, ax3 = T.axis.remap("SSS", [i1, i2, i3])
                    y[ax0, ax1, ax2, ax3] = T.int8(0)
            for i2, i3 in T.grid(225, 225):
                with T.block("pad_temp"):
                    ax0 = T.axis.spatial(1, 0)
                    ax1, ax2, ax3 = T.axis.remap("SSS", [i1, i2, i3])
                    y[ax0, ax1, ax2 + 3, ax3 + 3] = x[ax0, ax1, ax2, ax3]

    schedule = tir.Schedule(pad_op, debug_mask="all")
    pad_block = schedule.get_block("pad_temp")
    # The block has exactly four enclosing loops; the third one is the
    # decomposition position.
    _, _, target_loop, _ = schedule.get_loops(pad_block)
    schedule.decompose_padding(pad_block, target_loop)
    check_decompose_padding(pad_op, schedule.mod["main"], pad_op_after, check_run=True)
350+
351+
def test_not_to_decompose_trivial_predicate():
    """A padding predicate that is always true must not be decomposable.

    Input and output have the same 225x225 extent and the ``if_then_else``
    condition only restates the loop bounds, so there is no real padding
    region; ``can_decompose_padding`` is expected to return a falsy value.
    """

    @T.prim_func
    def trivial_pad(
        x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer([1, 16, 225, 225], dtype="int8")
    ):
        for i0, i1, i2, i3 in T.grid(1, 16, 225, 225):
            with T.block("pad_temp"):
                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                y[ax0, ax1, ax2, ax3] = T.if_then_else(
                    0 <= ax2 and ax2 < 225 and 0 <= ax3 and ax3 < 225,
                    x[ax0, ax1, ax2, ax3],
                    T.int8(0),
                    dtype="int8",
                )

    schedule = tir.Schedule(trivial_pad, debug_mask="all")
    pad_block = schedule.get_block("pad_temp")
    # Same decomposition position as the non-trivial case: the third loop.
    _, _, target_loop, _ = schedule.get_loops(pad_block)
    decomposable = schedule.can_decompose_padding(pad_block, target_loop)
    assert not decomposable
373+
374+
if __name__ == "__main__":
    # When executed as a script, delegate to TVM's test entry point.
    tvm.testing.main()
0 commit comments