@@ -611,6 +611,34 @@ def elementwise_overcomputed_producer_reverse_inlined(
611611 C [vi , vj ] = A [vi , vj ] * 2.0 + 1.0
612612
613613
614+ @T .prim_func
615+ def elementwise_overcomputed_producer_simplify_predicate (
616+ A : T .Buffer ((128 , 128 ), "float32" ), C : T .Buffer ((127 , 127 ), "float32" )
617+ ) -> None :
618+ B = T .alloc_buffer ((128 , 128 ))
619+ for i in T .grid (16384 ):
620+ with T .block ("B" ):
621+ vi = T .axis .spatial (128 , i // 128 )
622+ vj = T .axis .spatial (128 , i % 128 )
623+ B [vi , vj ] = A [vi , vj ] * 2.0
624+ for i , j in T .grid (127 , 127 ):
625+ with T .block ("C" ):
626+ cvi , cvj = T .axis .remap ("SS" , [i , j ])
627+ C [cvi , cvj ] = B [cvi , cvj ] + 1.0
628+
629+
630+ @T .prim_func
631+ def elementwise_overcomputed_producer_simplify_predicate_reverse_inlined (
632+ A : T .Buffer ((128 , 128 ), "float32" ), C : T .Buffer ((127 , 127 ), "float32" )
633+ ) -> None :
634+ for i in T .grid (16384 ):
635+ with T .block ("B" ):
636+ vi = T .axis .spatial (128 , i // 128 )
637+ vj = T .axis .spatial (128 , i % 128 )
638+ T .where (i < 16255 and i % 128 < 127 )
639+ C [vi , vj ] = A [vi , vj ] * 2.0 + 1.0
640+
641+
614642@T .prim_func
615643def elementwise_producer_not_cover_consumer (
616644 A : T .Buffer ((128 , 128 ), "float32" ), D : T .Buffer ((256 , 128 ), "float32" )
@@ -1025,6 +1053,15 @@ def test_reverse_compute_inline_overcomputed_producer(use_block_name):
10251053 )
10261054
10271055
1056+ def test_reverse_compute_inline_overcomputed_producer_simplify_predicate (use_block_name ):
1057+ """Test reverse compute inline overcomputed producer where the predicate should be simplified"""
1058+ sch = tir .Schedule (elementwise_overcomputed_producer_simplify_predicate , debug_mask = "all" )
1059+ compute = "C" if use_block_name else sch .get_block ("C" )
1060+ sch .reverse_compute_inline (compute )
1061+ tvm .ir .assert_structural_equal (
1062+ elementwise_overcomputed_producer_simplify_predicate_reverse_inlined , sch .mod ["main" ]
1063+ )
1064+
10281065def test_reverse_compute_inline_error_producer_not_cover_consumer (use_block_name ):
10291066 """Test reverse compute inline failure when the inlined block iter domains are not covered by
10301067 its producer
0 commit comments