@@ -927,6 +927,183 @@ def after_cache_write(a: ty.handle, b: ty.handle) -> None:
 
 ########## Schedule: Compute location ##########
 
+    def compute_at(
+        self,
+        block: BlockRV,
+        loop: LoopRV,
+        preserve_unit_loops: bool = False,
+    ) -> None:
+        """Compute-At. Move a producer block under the specific loop, and regenerate the
+        loops induced by the block so that the buffer region produced by the producer block
+        covers the regions consumed by its consumer blocks under the given loop. It requires:
+
+        1) `block` and `loop` are under the same scope, and `loop` is not an ancestor of
+        `block`
+
+        2) The scope block has the stage-pipeline property
+
+        3) The subtree of the scope block that contains the given block satisfies the compact
+        dataflow condition, i.e. all the blocks in the scope block's subtree must be either
+        complete blocks or reduction blocks
+
+        4) The block is not an output block with regard to the scope block, i.e. the buffers
+        written by the block are allocated under the scope block
+
+        5) All the consumers of the block are under the given loop
+
+        Parameters
+        ----------
+        block : BlockRV
+            The block to be moved
+
+        loop : LoopRV
+            The loop under which the block is to be moved
+
+        preserve_unit_loops : bool
+            Whether to keep the trivial loops whose extents are 1
+
+        Examples
+        --------
+
+        Before compute-at, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_compute_at(a: ty.handle, c: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128), "float32")
+                B = tir.alloc_buffer((128, 128), "float32")
+                C = tir.match_buffer(c, (128, 128), "float32")
+                with tir.block([128, 128], "B") as [vi, vj]:
+                    B[vi, vj] = A[vi, vj] * 2.0
+                with tir.block([128, 128], "C") as [vi, vj]:
+                    C[vi, vj] = B[vi, vj] + 1.0
+
+        Create the schedule and do compute-at:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_compute_at)
+            block = sch.get_block("B")
+            loop, _ = sch.get_loops(sch.get_block("C"))
+            sch.compute_at(block, loop, preserve_unit_loops=False)
+            print(tvm.script.asscript(sch.mod["main"]))
+
+        After applying compute-at, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_compute_at(a: ty.handle, c: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128), "float32")
+                B = tir.alloc_buffer((128, 128), "float32")
+                C = tir.match_buffer(c, (128, 128), "float32")
+                for i in tir.serial(0, 128):
+                    for j in tir.serial(0, 128):
+                        with tir.block([128, 128], "B") as [vi, vj]:
+                            tir.bind(vi, i)
+                            tir.bind(vj, j)
+                            B[vi, vj] = A[vi, vj] * 2.0
+                    for j in tir.serial(0, 128):
+                        with tir.block([128, 128], "C") as [vi, vj]:
+                            tir.bind(vi, i)
+                            tir.bind(vj, j)
+                            C[vi, vj] = B[vi, vj] + 1.0
+
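+        As a further illustration (a sketch only: `split` with a `factors` list and the
+        tile sizes below are assumptions, not part of the example above), compute-at
+        composes with loop tiling:
+
+        .. code-block:: python
+
+            # Hedged sketch: tile the consumer, then anchor the producer at the
+            # outer tile loop so "B" produces one 16x128 tile per iteration.
+            sch = tir.Schedule(before_compute_at)
+            i, _ = sch.get_loops(sch.get_block("C"))
+            i_outer, _ = sch.split(i, factors=[8, 16])
+            sch.compute_at(sch.get_block("B"), i_outer, preserve_unit_loops=False)
+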
1012+ """
1013+ _ffi_api .ScheduleComputeAt ( # type: ignore # pylint: disable=no-member
1014+ self ,
1015+ block ,
1016+ loop ,
1017+ preserve_unit_loops ,
1018+ )
1019+
+    def reverse_compute_at(
+        self,
+        block: BlockRV,
+        loop: LoopRV,
+        preserve_unit_loops: bool = False,
+    ) -> None:
+        """Reverse-Compute-At. Move a consumer block under the specific loop, and regenerate
+        the loops induced by the block so that the buffer region consumed by the consumer
+        block covers the regions produced by its producer blocks under the given loop. It
+        requires:
+
+        1) `block` and `loop` are under the same scope, and `loop` is not an ancestor of
+        `block`
+
+        2) The scope block has the stage-pipeline property
+
+        3) The subtree of the scope block that contains the given block satisfies the compact
+        dataflow condition, i.e. all the blocks in the scope block's subtree must be either
+        complete blocks or reduction blocks
+
+        4) All the producers of the block are under the given loop
+
+        Parameters
+        ----------
+        block : BlockRV
+            The block to be moved
+
+        loop : LoopRV
+            The loop under which the block is to be moved
+
+        preserve_unit_loops : bool
+            Whether to keep the trivial loops whose extents are 1
+
+        Examples
+        --------
+
+        Before reverse-compute-at, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_reverse_compute_at(a: ty.handle, c: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128), "float32")
+                B = tir.alloc_buffer((128, 128), "float32")
+                C = tir.match_buffer(c, (128, 128), "float32")
+                with tir.block([128, 128], "B") as [vi, vj]:
+                    B[vi, vj] = A[vi, vj] * 2.0
+                with tir.block([128, 128], "C") as [vi, vj]:
+                    C[vi, vj] = B[vi, vj] + 1.0
+
+        Create the schedule and do reverse-compute-at:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_reverse_compute_at)
+            block = sch.get_block("C")
+            loop, _ = sch.get_loops(sch.get_block("B"))
+            sch.reverse_compute_at(block, loop, preserve_unit_loops=False)
+            print(tvm.script.asscript(sch.mod["main"]))
+
+        After applying reverse-compute-at, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_reverse_compute_at(a: ty.handle, c: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128), "float32")
+                B = tir.alloc_buffer((128, 128), "float32")
+                C = tir.match_buffer(c, (128, 128), "float32")
+                for i in tir.serial(0, 128):
+                    for j in tir.serial(0, 128):
+                        with tir.block([128, 128], "B") as [vi, vj]:
+                            tir.bind(vi, i)
+                            tir.bind(vj, j)
+                            B[vi, vj] = A[vi, vj] * 2.0
+                    for j in tir.serial(0, 128):
+                        with tir.block([128, 128], "C") as [vi, vj]:
+                            tir.bind(vi, i)
+                            tir.bind(vj, j)
+                            C[vi, vj] = B[vi, vj] + 1.0
+
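+        Analogously (a sketch only: `split` and the factors below are assumptions, not
+        part of the example above), the consumer can be re-anchored after tiling the
+        producer:
+
+        .. code-block:: python
+
+            # Hedged sketch: tile the producer, then move the consumer under the
+            # outer tile loop so "C" consumes one 16x128 tile per iteration.
+            sch = tir.Schedule(before_reverse_compute_at)
+            i, j = sch.get_loops(sch.get_block("B"))
+            i_outer, _ = sch.split(i, factors=[8, 16])
+            sch.reverse_compute_at(sch.get_block("C"), i_outer, preserve_unit_loops=False)
+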
1099+ """
1100+ _ffi_api .ScheduleReverseComputeAt ( # type: ignore # pylint: disable=no-member
1101+ self ,
1102+ block ,
1103+ loop ,
1104+ preserve_unit_loops ,
1105+ )
1106+
     def compute_inline(self, block: BlockRV) -> None:
         """Inline a block into its consumer(s). It requires:
 
@@ -1189,10 +1366,15 @@ def after_rfactor(a: ty.handle, b: ty.handle) -> None:
11891366 """
11901367 return _ffi_api .ScheduleRFactor (self , loop , factor_axis ) # type: ignore # pylint: disable=no-member
11911368
-    ######## Schedule: Block annotatoin ########
+    ######## Schedule: Block annotation ########
 
     def storage_align(  # pylint: disable=too-many-arguments
-        self, block: BlockRV, buffer_index: int, axis: int, factor: int, offset: int
+        self,
+        block: BlockRV,
+        buffer_index: int,
+        axis: int,
+        factor: int,
+        offset: int,
     ) -> None:
         """Set alignment requirement for specific dimension such that
         stride[axis] == k * factor + offset for some k. This is useful to set memory layout for more