@@ -1529,5 +1529,54 @@ def after(
15291529 assert_structural_equal_ignore_global_symbol (after , sch .mod ["main" ])
15301530
15311531
def test_inline_with_reduction():
    """Check ``reverse_compute_inline`` on a consumer of a block that has ``T.init``.

    ``before`` computes "bmm" into the intermediate buffer ``T_batch_matmul_NN``
    and then copies that buffer into ``T_transpose`` in a separate "transpose"
    block. After reverse-inlining "transpose", the scheduled module is expected
    to be structurally equal to ``after``: the intermediate buffer is gone and
    "bmm" (both its init statement and its update statement) reads and writes
    ``T_transpose`` directly.
    """

    # NOTE: "bmm" carries a T.init() region although both of its block axes are
    # remapped as spatial ("SS"); the accumulate-into-self update is what the
    # inlined consumer has to be folded into.
    @T.prim_func
    def before(
        T_softmax_norm: T.Buffer((T.int64(6), T.int64(1), T.int64(1)), "float32"),
        T_reshape_2: T.Buffer((T.int64(6), T.int64(1), T.int64(64)), "float32"),
        T_transpose: T.Buffer((T.int64(1), T.int64(1), T.int64(6), T.int64(64)), "float32"),
    ):
        T_batch_matmul_NN = T.alloc_buffer((T.int64(6), T.int64(1), T.int64(64)))
        for ax0, ax1 in T.grid(T.int64(6), T.int64(64)):
            with T.block("bmm"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_softmax_norm[v0, T.int64(0), T.int64(0)], T_reshape_2[v0, T.int64(0), v1])
                T.writes(T_batch_matmul_NN[v0, T.int64(0), v1])
                with T.init():
                    T_batch_matmul_NN[v0, T.int64(0), v1] = T.float32(0.0)
                T_batch_matmul_NN[v0, T.int64(0), v1] = (
                    T_batch_matmul_NN[v0, T.int64(0), v1]
                    + T_softmax_norm[v0, T.int64(0), T.int64(0)] * T_reshape_2[v0, T.int64(0), v1]
                )
        for ax0, ax1 in T.grid(T.int64(6), T.int64(64)):
            with T.block("transpose"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_batch_matmul_NN[v0, T.int64(0), v1])
                T.writes(T_transpose[T.int64(0), T.int64(0), v0, v1])
                T_transpose[T.int64(0), T.int64(0), v0, v1] = T_batch_matmul_NN[v0, T.int64(0), v1]

    # Expected result: same "bmm" block, with every access to the intermediate
    # buffer at [v0, 0, v1] rewritten to T_transpose[0, 0, v0, v1].
    @T.prim_func
    def after(
        T_softmax_norm: T.Buffer((T.int64(6), T.int64(1), T.int64(1)), "float32"),
        T_reshape_2: T.Buffer((T.int64(6), T.int64(1), T.int64(64)), "float32"),
        T_transpose: T.Buffer((T.int64(1), T.int64(1), T.int64(6), T.int64(64)), "float32"),
    ):
        for ax0, ax1 in T.grid(T.int64(6), T.int64(64)):
            with T.block("bmm"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_softmax_norm[v0, T.int64(0), T.int64(0)], T_reshape_2[v0, T.int64(0), v1])
                T.writes(T_transpose[T.int64(0), T.int64(0), v0, v1])
                with T.init():
                    T_transpose[T.int64(0), T.int64(0), v0, v1] = T.float32(0.0)
                T_transpose[T.int64(0), T.int64(0), v0, v1] = (
                    T_transpose[T.int64(0), T.int64(0), v0, v1]
                    + T_softmax_norm[v0, T.int64(0), T.int64(0)] * T_reshape_2[v0, T.int64(0), v1]
                )

    sch = tir.Schedule(before)
    # Inline the consumer ("transpose") backwards into its producer ("bmm").
    sch.reverse_compute_inline(sch.get_block("transpose"))
    assert_structural_equal_ignore_global_symbol(after, sch.mod["main"])
1579+
1580+
# When executed as a script, hand control to tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()
0 commit comments