Skip to content

Commit a44cc6e

Browse files
junrushao, spectrometerHBH, MasterJH5574, jinhongyii, vinx13
authored
[TensorIR][M2a] Compute-At (#8943)
This PR is part of the TensorIR upstreaming effort (#7527), which adds the following schedule primitives: * `compute-at` * `reverse-compute-at` Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Wuwei Lin <[email protected]> Co-authored-by: Siyuan Feng <[email protected]>
1 parent 2232399 commit a44cc6e

30 files changed

+2526
-343
lines changed

include/tvm/arith/int_set.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,17 +121,24 @@ class IntSet : public ObjectRef {
121121
* \return The result set containing the indices in the vector.
122122
*/
123123
static IntSet Vector(PrimExpr vec);
124+
/*!
125+
* \brief Construct a set representing a range [min, min + extent).
126+
* \param min The minimum value of the range
127+
* \param extent The extent of the range.
128+
* \return The constructed set.
129+
*/
130+
static IntSet FromMinExtent(PrimExpr min, PrimExpr extent);
124131
/*!
125132
* \brief Construct a set representing a range.
126133
* \param r The range
127-
* \return constructed set.
134+
* \return The constructed set.
128135
*/
129136
static IntSet FromRange(tvm::Range r);
130137
/*!
131138
* \brief Construct a set representing an interval.
132139
* \param min The minimum value of the interval.
133140
* \param max The maximum value of the interval.
134-
* \return constructed set.
141+
* \return The constructed set.
135142
*/
136143
static IntSet Interval(PrimExpr min, PrimExpr max);
137144

include/tvm/tir/schedule/schedule.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,41 @@ class ScheduleNode : public runtime::Object {
305305
virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
306306
const String& storage_scope) = 0;
307307
/******** Schedule: Compute location ********/
308+
/*!
309+
* \brief Move a producer block under the specific loop, and regenerate the
310+
* loops induced by the block so that the buffer region produced by the producer block could
311+
* cover those regions consumed by its consumer blocks under the given loop. It requires:
312+
* 1) `block` and `loop` are under the same scope, `loop` is not the ancestor of `block`
313+
* 2) The scope block has stage-pipeline property
314+
* 3) The subtree of the scope block, where the given block is in, satisfies the compact dataflow
315+
* condition. i.e. all the blocks in the scope block's subtree must be either complete block or
316+
* reduction block
317+
* 4) The block is not an output block with regard to the scope block, i.e. the buffers written by
318+
* the block are allocated under the scope block
319+
* 5) All the consumers of the block are under the given loop
320+
* \param block_rv The block to be moved
321+
* \param loop_rv The loop where the block to be moved under
322+
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
323+
*/
324+
virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
325+
bool preserve_unit_loops) = 0;
326+
/*!
327+
* \brief Move a consumer block under the specific loop, and regenerate the
328+
* loops induced by the block so that the buffer region consumed by the consumer block could
329+
* cover those regions produced by its producer blocks under the given loop. It requires:
330+
* 1) `block` and `loop` are under the same scope, `loop` is not the ancestor of `block`
331+
* 2) The scope block has stage-pipeline property
332+
* 3) The subtree of the scope block, where the given block is in, satisfies the compact dataflow
333+
* condition. i.e. all the blocks in the scope block's subtree must be either complete block or
334+
* reduction block
335+
* 4) All the producers of the block are under the given loop
336+
*
337+
* \param block_rv The block to be moved
338+
* \param loop_rv The loop where the block to be moved under
339+
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
340+
*/
341+
virtual void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
342+
bool preserve_unit_loops) = 0;
308343
/*!
309344
* \brief Inline a block into its consumer(s). It requires:
310345
* 1) The block is a complete non-root block, which only produces one buffer

include/tvm/tir/schedule/state.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,6 @@ class ScheduleStateNode : public Object {
128128
*/
129129
TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt,
130130
const Map<Block, Block>& block_sref_reuse);
131-
/*!
132-
* \brief Recalculate the `affine_binding` flag of the scope block info.
133-
* \param scope_sref The sref to the interested scope block.
134-
*/
135-
TVM_DLL void UpdateAffineFlag(const StmtSRef& scope_sref);
136131
/*!
137132
* \brief Trigger the verification according to the `debug_mask` bitmask.
138133
* 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree.

python/tvm/tir/schedule/schedule.py

Lines changed: 184 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -927,6 +927,183 @@ def after_cache_write(a: ty.handle, b: ty.handle) -> None:
927927

928928
########## Schedule: Compute location ##########
929929

930+
def compute_at(
931+
self,
932+
block: BlockRV,
933+
loop: LoopRV,
934+
preserve_unit_loops: bool = False,
935+
) -> None:
936+
"""Compute-At. Move a producer block under the specific loop, and regenerate the
937+
loops induced by the block so that the buffer region produced by the producer block could
938+
cover those regions consumed by its consumer blocks under the given loop. It requires:
939+
940+
1) `block` and `loop` are under the same scope, `loop` is not the ancestor of `block`
941+
942+
2) The scope block has stage-pipeline property
943+
944+
3) The subtree of the scope block, where the given block is in, satisfies the compact
945+
dataflow condition. i.e. all the blocks in the scope block's subtree must be either
946+
complete block or reduction block
947+
948+
4) The block is not an output block with regard to the scope block, i.e. the buffers written
949+
by the block are allocated under the scope block
950+
951+
5) All the consumers of the block are under the given loop
952+
953+
Parameters
954+
----------
955+
block : BlockRV
956+
The block to be moved
957+
958+
loop: LoopRV
959+
The loop where the block to be moved under
960+
961+
preserve_unit_loops: bool
962+
Whether to keep the trivial loops whose extents are 1
963+
964+
Examples
965+
--------
966+
967+
Before compute-at, in TensorIR, the IR is:
968+
969+
.. code-block:: python
970+
971+
@tvm.script.tir
972+
def before_compute_at(a: ty.handle, c: ty.handle) -> None:
973+
A = tir.match_buffer(a, (128, 128), "float32")
974+
B = tir.alloc_buffer((128, 128), "float32")
975+
C = tir.match_buffer(c, (128, 128), "float32")
976+
with tir.block([128, 128], "B") as [vi, vj]:
977+
B[vi, vj] = A[vi, vj] * 2.0
978+
with tir.block([128, 128], "C") as [vi, vj]:
979+
C[vi, vj] = B[vi, vj] + 1.0
980+
981+
Create the schedule and do compute-at:
982+
983+
.. code-block:: python
984+
985+
sch = tir.Schedule(before_compute_at)
986+
block = sch.get_block("B")
987+
loop, _ = sch.get_loops(sch.get_block("C"))
988+
sch.compute_at(block, loop, preserve_unit_loops=False)
989+
print(tvm.script.asscript(sch.mod["main"]))
990+
991+
After applying compute-at, the IR becomes:
992+
993+
.. code-block:: python
994+
995+
@tvm.script.tir
996+
def after_compute_at(a: ty.handle, c: ty.handle) -> None:
997+
A = tir.match_buffer(a, (128, 128), "float32")
998+
B = tir.alloc_buffer((128, 128), "float32")
999+
C = tir.match_buffer(c, (128, 128), "float32")
1000+
for i in tir.serial(0, 128):
1001+
for j in tir.serial(0, 128):
1002+
with tir.block([128, 128], "B") as [vi, vj]:
1003+
tir.bind(vi, i)
1004+
tir.bind(vj, j)
1005+
B[vi, vj] = A[vi, vj] * 2.0
1006+
for j in tir.serial(0, 128):
1007+
with tir.block([128, 128], "C") as [vi, vj]:
1008+
tir.bind(vi, i)
1009+
tir.bind(vj, j)
1010+
C[vi, vj] = B[vi, vj] + 1.0
1011+
1012+
"""
1013+
_ffi_api.ScheduleComputeAt( # type: ignore # pylint: disable=no-member
1014+
self,
1015+
block,
1016+
loop,
1017+
preserve_unit_loops,
1018+
)
1019+
1020+
def reverse_compute_at(
1021+
self,
1022+
block: BlockRV,
1023+
loop: LoopRV,
1024+
preserve_unit_loops: bool = False,
1025+
) -> None:
1026+
"""Reverse-Compute-At. Move a consumer block under the specific loop, and regenerate the
1027+
loops induced by the block so that the buffer region consumed by the consumer block could
1028+
cover those regions produced by its producer blocks under the given loop. It requires:
1029+
1030+
1) `block` and `loop` are under the same scope, `loop` is not the ancestor of `block`
1031+
1032+
2) The scope block has stage-pipeline property
1033+
1034+
3) The subtree of the scope block, where the given block is in, satisfies the compact
1035+
dataflow condition. i.e. all the blocks in the scope block's subtree must be either
1036+
complete block or reduction block
1037+
1038+
4) All the producers of the block are under the given loop
1039+
1040+
Parameters
1041+
----------
1042+
block : BlockRV
1043+
The block to be moved
1044+
1045+
loop: LoopRV
1046+
The loop where the block to be moved under
1047+
1048+
preserve_unit_loops: bool
1049+
Whether to keep the trivial loops whose extents are 1
1050+
1051+
Examples
1052+
--------
1053+
1054+
Before reverse-compute-at, in TensorIR, the IR is:
1055+
1056+
.. code-block:: python
1057+
1058+
@tvm.script.tir
1059+
def before_reverse_compute_at(a: ty.handle, c: ty.handle) -> None:
1060+
A = tir.match_buffer(a, (128, 128), "float32")
1061+
B = tir.alloc_buffer((128, 128), "float32")
1062+
C = tir.match_buffer(c, (128, 128), "float32")
1063+
with tir.block([128, 128], "B") as [vi, vj]:
1064+
B[vi, vj] = A[vi, vj] * 2.0
1065+
with tir.block([128, 128], "C") as [vi, vj]:
1066+
C[vi, vj] = B[vi, vj] + 1.0
1067+
1068+
Create the schedule and do reverse-compute-at:
1069+
1070+
.. code-block:: python
1071+
1072+
sch = tir.Schedule(before_reverse_compute_at)
1073+
block = sch.get_block("C")
1074+
loop, _ = sch.get_loops(sch.get_block("B"))
1075+
sch.reverse_compute_at(block, loop, preserve_unit_loops=False)
1076+
print(tvm.script.asscript(sch.mod["main"]))
1077+
1078+
After applying reverse-compute-at, the IR becomes:
1079+
1080+
.. code-block:: python
1081+
1082+
@tvm.script.tir
1083+
def after_reverse_compute_at(a: ty.handle, c: ty.handle) -> None:
1084+
A = tir.match_buffer(a, (128, 128), "float32")
1085+
B = tir.alloc_buffer((128, 128), "float32")
1086+
C = tir.match_buffer(c, (128, 128), "float32")
1087+
for i in tir.serial(0, 128):
1088+
for j in tir.serial(0, 128):
1089+
with tir.block([128, 128], "B") as [vi, vj]:
1090+
tir.bind(vi, i)
1091+
tir.bind(vj, j)
1092+
B[vi, vj] = A[vi, vj] * 2.0
1093+
for j in tir.serial(0, 128):
1094+
with tir.block([128, 128], "C") as [vi, vj]:
1095+
tir.bind(vi, i)
1096+
tir.bind(vj, j)
1097+
C[vi, vj] = B[vi, vj] + 1.0
1098+
1099+
"""
1100+
_ffi_api.ScheduleReverseComputeAt( # type: ignore # pylint: disable=no-member
1101+
self,
1102+
block,
1103+
loop,
1104+
preserve_unit_loops,
1105+
)
1106+
9301107
def compute_inline(self, block: BlockRV) -> None:
9311108
"""Inline a block into its consumer(s). It requires:
9321109
@@ -1189,10 +1366,15 @@ def after_rfactor(a: ty.handle, b: ty.handle) -> None:
11891366
"""
11901367
return _ffi_api.ScheduleRFactor(self, loop, factor_axis) # type: ignore # pylint: disable=no-member
11911368

1192-
######## Schedule: Block annotatoin ########
1369+
######## Schedule: Block annotation ########
11931370

11941371
def storage_align( # pylint: disable=too-many-arguments
1195-
self, block: BlockRV, buffer_index: int, axis: int, factor: int, offset: int
1372+
self,
1373+
block: BlockRV,
1374+
buffer_index: int,
1375+
axis: int,
1376+
factor: int,
1377+
offset: int,
11961378
) -> None:
11971379
"""Set alignment requirement for specific dimension such that
11981380
stride[axis] == k * factor + offset for some k. This is useful to set memory layout for more

src/arith/int_set.cc

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,13 @@ inline bool ProveEqual(Analyzer* analyzer, PrimExpr lhs, PrimExpr rhs) {
607607
return is_zero(analyzer->Simplify(lhs - rhs));
608608
}
609609

610+
IntSet IntSet::FromMinExtent(PrimExpr min, PrimExpr extent) {
611+
if (is_one(extent)) {
612+
return IntSet::SinglePoint(min);
613+
}
614+
return IntervalSet(min, extent + min - 1);
615+
}
616+
610617
IntSet IntSet::FromRange(Range r) {
611618
// must make sure it can be matched back by MatchRange.
612619
if (is_one(r->extent)) {
@@ -815,46 +822,45 @@ IntSet EvalSet(Range r, const Map<IterVar, IntSet>& dom_map) {
815822
return EvalSet(r, ConvertDomMap(dom_map));
816823
}
817824

818-
Optional<Array<arith::IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
819-
const Map<Var, Range>& var_dom,
820-
const PrimExpr& predicate,
821-
arith::Analyzer* analyzer) {
825+
Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
826+
const Map<Var, Range>& var_dom,
827+
const PrimExpr& predicate, Analyzer* analyzer) {
822828
int ndim = region.size();
823-
Array<arith::IterSumExpr> iter_sum_exprs{nullptr};
829+
Array<IterSumExpr> iter_sum_exprs{nullptr};
824830
{
825831
Array<PrimExpr> affine_indices;
826832
affine_indices.reserve(ndim);
827833
for (const Range& range : region) {
828834
affine_indices.push_back(range->min);
829835
}
830-
iter_sum_exprs = arith::DetectIterMap(
836+
iter_sum_exprs = DetectIterMap(
831837
/*indices=*/affine_indices, /*input_iters=*/var_dom,
832838
/*predicate=*/predicate, /*require_bijective=*/false, analyzer);
833839
}
834840
if (iter_sum_exprs.empty()) {
835841
return NullOpt;
836842
}
837843
ICHECK_EQ(iter_sum_exprs.size(), ndim);
838-
Array<arith::IntSet> result;
844+
Array<IntSet> result;
839845
result.reserve(ndim);
840846
for (int i = 0; i < ndim; ++i) {
841-
const arith::IterSumExpr& sum_expr = iter_sum_exprs[i];
847+
const IterSumExpr& sum_expr = iter_sum_exprs[i];
842848
const Range& range = region[i];
843849
if (sum_expr->args.empty()) {
844-
result.push_back(arith::IntSet::Interval(sum_expr->base, sum_expr->base + range->extent));
850+
result.push_back(IntSet::FromMinExtent(sum_expr->base, range->extent));
845851
continue;
846852
}
847853
ICHECK_EQ(sum_expr->args.size(), 1);
848-
const arith::IterSplitExpr& split = sum_expr->args[0];
854+
const IterSplitExpr& split = sum_expr->args[0];
849855
if (!analyzer->CanProve(range->extent >= split->scale)) {
850856
return NullOpt;
851857
}
852858
const PrimExpr& base = sum_expr->base;
853859
// IterSplitExpr: (source // lower_factor) % extent * scale
854860
// where `(source // lower_factor) % extent` is within [0, extent - 1]
855861
// Therefore, the range of `region[i]->min` is `base + [0, (extent - 1) * scale]`
856-
result.push_back(arith::IntSet::Interval(
857-
base, split->extent * split->scale + base + (range->extent - split->scale) - 1));
862+
result.push_back(
863+
IntSet::FromMinExtent(base, split->extent * split->scale + (range->extent - split->scale)));
858864
}
859865
return result;
860866
}

src/relay/transforms/fold_scale_axis.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,9 @@ class ForwardPrep : private MixedModeVisitor {
243243
}
244244
}
245245
// Visitor pattern override.
246-
void VisitExpr_(const LetNode* op) {
246+
void VisitExpr_(const TupleGetItemNode* op) final { MixedModeVisitor::VisitExpr_(op); }
247+
248+
void VisitExpr_(const LetNode* op) final {
247249
ExprVisitor::VisitExpr_(op);
248250
// do pass through condition
249251
// by assigning NullValue<Message>
@@ -256,13 +258,13 @@ class ForwardPrep : private MixedModeVisitor {
256258
flist_.push_back(flazy);
257259
}
258260

259-
void VisitExpr_(const FunctionNode* op) {
261+
void VisitExpr_(const FunctionNode* op) final {
260262
ExprVisitor::VisitExpr_(op);
261263
auto flazy = [this, op] { this->Update(op->body, NullValue<Message>()); };
262264
flist_.push_back(flazy);
263265
}
264266

265-
void VisitExpr_(const CallNode* call) {
267+
void VisitExpr_(const CallNode* call) final {
266268
ExprVisitor::VisitExpr_(call);
267269
// function to be lazily invoked
268270
auto flazy = [this, call]() {
@@ -292,7 +294,7 @@ class ForwardPrep : private MixedModeVisitor {
292294
flist_.push_back(flazy);
293295
}
294296

295-
void VisitExpr_(const TupleNode* op) {
297+
void VisitExpr_(const TupleNode* op) final {
296298
ExprVisitor::VisitExpr_(op);
297299
// do not support pass scale through tuple for now.
298300
auto flazy = [this, op]() {
@@ -303,7 +305,7 @@ class ForwardPrep : private MixedModeVisitor {
303305
flist_.push_back(flazy);
304306
}
305307

306-
void VisitExpr_(const IfNode* op) {
308+
void VisitExpr_(const IfNode* op) final {
307309
ExprVisitor::VisitExpr_(op);
308310
// do pass through condition
309311
// by assigning NullValue<Message>

0 commit comments

Comments
 (0)