@@ -1353,5 +1353,94 @@ def _create_prim_func():
13531353 verify_trace_roundtrip (sch = sch , mod = mod )
13541354
13551355
def test_compute_at_to_early_stage():
    """Check `compute_at(..., to_early_stage=True)` on a consumer with two producers.

    The "conv" block reads both the "pad" and "wbuf" buffers.  The test moves
    "pad" under the outermost loop of "conv" with `to_early_stage=True` and
    compares the scheduled module against a hand-written expected PrimFunc in
    which "pad" sits ahead of "wbuf" inside that loop.
    """

    # Input workload: "pad" is computed in its own top-level loop nest, while
    # "wbuf" and "conv" already share the outer `i0` loop.
    @T.prim_func
    def multi_producers_conv(
        data: T.Buffer[(1, 3, 224, 224), "int8"],
        w: T.Buffer[(16, 3, 7, 7), "int8"],
        conv: T.Buffer[(1, 16, 112, 112), "int32"],
    ) -> None:
        pad = T.alloc_buffer([1, 3, 230, 230], dtype="int8")
        wbuf = T.alloc_buffer([16, 3, 7, 7], dtype="int8")
        for i0, i1, i2, i3 in T.grid(1, 3, 230, 230):
            with T.block("pad"):
                i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(data[i0_1, i1_1, i2_1 - 3, i3_1 - 3])
                T.writes(pad[i0_1, i1_1, i2_1, i3_1])
                # Zero-fill a 3-element border around the 224x224 input.
                pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                    3 <= i2_1 and i2_1 < 227 and 3 <= i3_1 and i3_1 < 227,
                    data[i0_1, i1_1, i2_1 - 3, i3_1 - 3],
                    T.int8(0),
                    dtype="int8",
                )
        for i0 in T.serial(1):
            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 7, 7):
                with T.block("wbuf"):
                    v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(w[v0, v1, v2, v3])
                    T.writes(wbuf[v0, v1, v2, v3])
                    wbuf[v0, v1, v2, v3] = w[v0, v1, v2, v3]
            for i1, i2, i3, i4, i5, i6 in T.grid(16, 112, 112, 3, 7, 7):
                with T.block("conv"):
                    nn, ff, yy, xx, rc, ry, rx = T.axis.remap(
                        "SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]
                    )
                    T.reads(pad[nn, rc, yy * 2 + ry, xx * 2 + rx], wbuf[ff, rc, ry, rx])
                    T.writes(conv[nn, ff, yy, xx])
                    with T.init():
                        conv[nn, ff, yy, xx] = 0
                    conv[nn, ff, yy, xx] = conv[nn, ff, yy, xx] + T.cast(
                        pad[nn, rc, yy * 2 + ry, xx * 2 + rx], "int32"
                    ) * T.cast(wbuf[ff, rc, ry, rx], "int32")

    # Expected result: "pad" now lives under `i0`, ahead of "wbuf" and "conv".
    @T.prim_func
    def multi_producers_after_compute_at(
        data: T.Buffer[(1, 3, 224, 224), "int8"],
        w: T.Buffer[(16, 3, 7, 7), "int8"],
        conv: T.Buffer[(1, 16, 112, 112), "int32"],
    ) -> None:
        pad = T.alloc_buffer([1, 3, 230, 230], dtype="int8")
        wbuf = T.alloc_buffer([16, 3, 7, 7], dtype="int8")
        for i0 in T.serial(1):
            # 229 rather than 230: "conv" reads indices up to 111 * 2 + 6 = 228.
            for ax0, ax1, ax2 in T.grid(3, 229, 229):
                with T.block("pad"):
                    i0_1 = T.axis.spatial(1, 0)
                    i1_1 = T.axis.spatial(3, ax0)
                    i2_1 = T.axis.spatial(230, ax1)
                    i3_1 = T.axis.spatial(230, ax2)
                    T.reads(data[i0_1, i1_1, i2_1 - 3, i3_1 - 3])
                    T.writes(pad[i0_1, i1_1, i2_1, i3_1])
                    pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                        3 <= i2_1 and i2_1 < 227 and 3 <= i3_1 and i3_1 < 227,
                        data[i0_1, i1_1, i2_1 - 3, i3_1 - 3],
                        T.int8(0),
                        dtype="int8",
                    )
            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 7, 7):
                with T.block("wbuf"):
                    v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(w[v0, v1, v2, v3])
                    T.writes(wbuf[v0, v1, v2, v3])
                    wbuf[v0, v1, v2, v3] = w[v0, v1, v2, v3]
            for i1, i2, i3, i4, i5, i6 in T.grid(16, 112, 112, 3, 7, 7):
                with T.block("conv"):
                    nn, ff, yy, xx, rc, ry, rx = T.axis.remap(
                        "SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]
                    )
                    T.reads(pad[nn, rc, yy * 2 + ry, xx * 2 + rx], wbuf[ff, rc, ry, rx])
                    T.writes(conv[nn, ff, yy, xx])
                    with T.init():
                        conv[nn, ff, yy, xx] = 0
                    conv[nn, ff, yy, xx] = conv[nn, ff, yy, xx] + T.cast(
                        pad[nn, rc, yy * 2 + ry, xx * 2 + rx], "int32"
                    ) * T.cast(wbuf[ff, rc, ry, rx], "int32")

    sch = tir.Schedule(multi_producers_conv, debug_mask="all")
    pad_block = sch.get_block("pad")
    # Loop 0 of "conv" is the outer `i0` loop that also encloses "wbuf".
    conv_outer_loop = sch.get_loops("conv")[0]
    sch.compute_at(pad_block, conv_outer_loop, to_early_stage=True)
    scheduled_func = sch.mod["main"]
    tvm.ir.assert_structural_equal(multi_producers_after_compute_at, scheduled_func)
1443+
1444+
if __name__ == "__main__":
    # When executed as a script, delegate to TVM's test runner helper.
    tvm.testing.main()
0 commit comments