
Commit d4dd3c8

MasterJH5574, junrushao, Siyuan Feng, spectrometerHBH, and jinhongyii authored and committed
[TensorIR][M2a] Reduction Factoring (RFactor) (apache#8544)
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
1 parent 4ce3af4 commit d4dd3c8

29 files changed: 2,656 additions and 165 deletions

include/tvm/tir/analysis.h

Lines changed: 10 additions & 12 deletions
@@ -96,22 +96,20 @@ TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
 TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);
 
 /*!
- * \brief Whether e expression used any var in variable set.
- * \param expr The expression to be checked.
- * \param vset_contains The check function to see if var is in the vset.
- * \return Whether e uses vset.
+ * \brief Whether the given Stmt uses any var in the given variable set.
+ * \param stmt The Stmt to be checked.
+ * \param vset_contains The check function to see if a var is in the variable set.
+ * \return Whether `stmt` uses any var in the given variable set.
  */
-TVM_DLL bool ExprUseVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);
+TVM_DLL bool UsesVar(const Stmt& stmt, std::function<bool(const VarNode*)> vset_contains);
 
 /*!
- * \brief Whether e expression used var.
- * \param expr The expression to be checked.
- * \param var The variable.
- * \return Whether e uses v.
+ * \brief Whether the given PrimExpr uses any var in the given variable set.
+ * \param expr The PrimExpr to be checked.
+ * \param vset_contains The check function to see if var is in the variable set.
+ * \return Whether `expr` uses any var in the given variable set.
  */
-inline bool ExprUseVar(const PrimExpr& expr, const Var& var) {
-  return ExprUseVar(expr, [&](const VarNode* node) { return var.get() == node; });
-}
+TVM_DLL bool UsesVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);
 
 /*!
  * \brief Verifies whether the IR stmt or Expr is in SSA form.

include/tvm/tir/schedule/block_scope.h

Lines changed: 1 addition & 1 deletion
@@ -262,7 +262,7 @@ class BlockScope : public ObjectRef {
   * \param child_block_srefs The srefs to the leaf blocks
   * \note We assume the leaf blocks are given in pre-DFS order
   */
-  TVM_DLL BlockScope(const Array<StmtSRef>& child_block_srefs);
+  TVM_DLL explicit BlockScope(const Array<StmtSRef>& child_block_srefs);
 
   TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockScope, ObjectRef, BlockScopeNode);
 };

include/tvm/tir/schedule/schedule.h

Lines changed: 18 additions & 0 deletions
@@ -242,6 +242,24 @@ class ScheduleNode : public runtime::Object {
   /******** Schedule: loop binding/annotation ********/
   /******** Schedule: cache read/write ********/
   /******** Schedule: reduction ********/
+  /*!
+   * \brief Factorize an associative reduction block by the specified loop.
+   * \details An associative reduction cannot be parallelized directly,
+   * because it leads to a potential race condition during accumulation.
+   * Alternatively, the reduction could be factorized on a loop with the following steps:
+   * - Step 1: evenly slice the reduction into `n` separate chunks, where `n` is the loop extent;
+   * - Step 2: compute the chunks separately and write the results into `n` intermediate buffers;
+   * - Step 3: accumulate the `n` separate buffers into the result buffer.
+   * Note that Step 2 above introduces opportunities for parallelization.
+   * RFactor is a schedule primitive that implements the transformation described above.
+   * \param loop_rv The loop outside the block on which we want to do rfactor
+   * \param factor_axis The position where the new dimension is placed in the newly introduced rfactor
+   *                    buffer. Suppose the original reduction block writes to buffer `B` with
+   *                    ndim(B) dimensions, then `factor_axis` should be in range `[-ndim(B) - 1,
+   *                    ndim(B)]`, and a negative index will be normalized to a non-negative one
+   * \return The rfactor block
+   */
+  virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0;
   /******** Schedule: blockize & tensorize ********/
 };
 

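To make the three-step transformation described in the new `RFactor` documentation concrete, here is a minimal NumPy sketch. It is not part of this commit and uses hypothetical array names (`A`, `B`, `B_rf`, mirroring the example in the docs); it only checks that computing the rf-block followed by the wb-block reproduces the original reduction.

    # Hypothetical NumPy illustration of the rfactor decomposition; not TVM code.
    import numpy as np

    n = 8  # extent of the loop k being factorized
    A = np.random.rand(16, 16, n)

    # Original reduction: B[i] = sum over j and k of A[i, j, k]
    B_ref = A.sum(axis=(1, 2))

    # Step 2 (rf-block): reduce only over j; k becomes a data parallel axis,
    # so each k-chunk B_rf[:, k] can be computed independently (in parallel).
    B_rf = A.sum(axis=1)  # shape (16, n), matching B_rf[i, k] with factor_axis = 1

    # Step 3 (wb-block): accumulate the n chunks back into the result buffer.
    B = B_rf.sum(axis=1)

    assert np.allclose(B, B_ref)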
include/tvm/tir/stmt.h

Lines changed: 19 additions & 0 deletions
@@ -865,6 +865,7 @@ class For : public Stmt {
                Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode);
 };
 
 /*!
@@ -1359,6 +1360,24 @@ TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span());
 // overload printing of for type.
 TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind);
 
+// inline implementations
+inline const char* ForKind2String(ForKind t) {
+  switch (t) {
+    case ForKind::kSerial:
+      return "serial";
+    case ForKind::kParallel:
+      return "parallel";
+    case ForKind::kVectorized:
+      return "vectorized";
+    case ForKind::kUnrolled:
+      return "unroll";
+    case ForKind::kThreadBinding:
+      return "thread_binding";
+  }
+  LOG(FATAL) << "Unknown ForKind" << t;
+  return "Unknown";
+}
+
 }  // namespace tir
 }  // namespace tvm
 #endif  // TVM_TIR_STMT_H_

python/tvm/script/special_stmt.py

Lines changed: 1 addition & 1 deletion
@@ -225,7 +225,7 @@ def alloc_buffer(
             data=None,
             strides=None,
             elem_offset=None,
-            scope="",
+            scope="global",
             align=-1,
             offset_factor=0,
             buffer_type="default",

python/tvm/tir/schedule/schedule.py

Lines changed: 145 additions & 2 deletions
@@ -431,7 +431,7 @@ def before_inline(a: ty.handle, c: ty.handle) -> None:
 
         .. code-block:: python
 
-            sch = tir.Schedule(before_inline, debug_mode=True)
+            sch = tir.Schedule(before_inline)
             sch.compute_inline(sch.get_block("B"))
             print(tvm.script.asscript(sch.mod["main"]))
 
@@ -491,7 +491,7 @@ def before_inline(a: ty.handle, c: ty.handle) -> None:
 
         .. code-block:: python
 
-            sch = tir.Schedule(before_inline, debug_mode=True)
+            sch = tir.Schedule(before_inline)
            sch.reverse_compute_inline(sch.get_block("C"))
            print(tvm.script.asscript(sch.mod["main"]))
 
@@ -512,6 +512,149 @@ def after_inline(a: ty.handle, c: ty.handle) -> None:
     ########## Schedule: loop binding/annotation ##########
     ########## Schedule: cache read/write ##########
     ########## Schedule: reduction ##########
+    def rfactor(self, loop: LoopRV, factor_axis: int) -> BlockRV:
+        """Factorize an associative reduction block by the specified loop.
+
+        An associative reduction cannot be parallelized directly,
+        because it leads to a potential race condition during accumulation.
+        Alternatively, the reduction could be factorized on a loop with the following steps:
+        - Step 1: evenly slice the reduction into `n` separate chunks, where `n` is the loop extent;
+        - Step 2: compute the chunks separately and write the results into `n` intermediate buffers;
+        - Step 3: accumulate the `n` separate buffers into the result buffer.
+        Note that Step 2 above introduces opportunities for parallelization.
+
+        RFactor is a schedule primitive that implements the transformation described above:
+        Given a block that writes to buffer `B`, it factorizes a loop of extent `n`.
+
+        For example, the pseudocode below accumulates `B[i] = sum(A[i, :, :])`:
+
+        .. code-block:: python
+
+            for i in range(128):                    # loop i is a data parallel loop
+                for j in range(128):                # loop j is a reduction loop
+                    for k in range(128):            # loop k is a reduction loop
+                        B[i] = B[i] + A[i, j, k]
+
+        Suppose RFactor is applied on the innermost loop `k` and `factor_axis = 1`.
+        RFactor then creates an intermediate buffer and two blocks.
+
+        1. The intermediate buffer, or "rf-buffer", is a buffer of rank `ndim(B) + 1` and
+        size `size(B) * n`, whose shape expands from `shape(B)` by adding an axis of `n`
+        at the position specified by `factor_axis`. For example,
+
+        * shape(B) = [1, 2, 3], factor_axis = 0 => shape(B_rf) = [n, 1, 2, 3]
+        * shape(B) = [1, 2, 3], factor_axis = 1 => shape(B_rf) = [1, n, 2, 3]
+        * shape(B) = [1, 2, 3], factor_axis = 2 => shape(B_rf) = [1, 2, n, 3]
+        * shape(B) = [1, 2, 3], factor_axis = 3 => shape(B_rf) = [1, 2, 3, n]
+
+        2. The rfactor block, or "rf-block", is a block that writes to the `rf-buffer` without
+        accumulating over the loop `k`, i.e. the loop `k` is converted from a reduction loop
+        to a data parallel loop. In our example, the rf-block is:
+
+        .. code-block:: python
+
+            B_rf = np.zeros((128, 128))     # the rf-buffer
+            for k in range(128):            # loop k is converted to a data parallel loop
+                for i in range(128):        # loop i is a data parallel loop (unchanged)
+                    for j in range(128):    # loop j is a reduction loop (unchanged)
+                        B_rf[i, k] = B_rf[i, k] + A[i, j, k]
+
+        3. The write-back block, or `wb-block`, is a block that accumulates the rf-buffer into
+        the result buffer. All the reduction loops are removed except the loop `k` for accumulation.
+        In our example, the wb-block is:
+
+        .. code-block:: python
+
+            for i in range(128):    # loop i is a data parallel loop (unchanged)
+                                    # loop j is removed because it is a reduction loop
+                for k in range(128):  # loop k is a reduction loop (unchanged)
+                    B[i] = B[i] + B_rf[i, k]
+
+        Parameters
+        ----------
+        loop : LoopRV
+            The loop outside the block for which we want to do rfactor
+        factor_axis : int
+            The position where the new dimension is placed in the newly introduced rfactor buffer
+
+        Returns
+        -------
+        rf_block : BlockRV
+            The block which computes partial results over each slice (i.e., the first block
+            as described in the above illustration)
+
+        Examples
+        --------
+
+        Before rfactor, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_rfactor(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128, 128))
+                B = tir.match_buffer(b, (128,))
+                with tir.block([128, tir.reduce_axis(0, 128),
+                                tir.reduce_axis(0, 128)], "B") as [vii, vi, vj]:
+                    with tir.init():
+                        B[vii] = 0.0
+                    B[vii] = B[vii] + A[vii, vi, vj]
+
+        Create the schedule and do rfactor:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_rfactor)
+            _, _, k = sch.get_loops(sch.get_block("B"))
+            sch.rfactor(k, 0)
+            print(tvm.script.asscript(sch.mod["main"]))
+
+        After applying rfactor, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_rfactor(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, [128, 128, 128])
+                B = tir.match_buffer(b, [128])
+                B_rf = tir.alloc_buffer([128, 128])
+                with tir.block([128, 128, tir.reduce_axis(0, 128)], "B_rf") as [vi2, vii, vi]:
+                    with tir.init():
+                        B_rf[vi2, vii] = 0.0
+                    B_rf[vi2, vii] = (B_rf[vi2, vii] + A[vii, vi, vi2])
+                with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vii_1, vi2_1]:
+                    with tir.init():
+                        B[vii_1] = 0.0
+                    B[vii_1] = (B[vii_1] + B_rf[vi2_1, vii_1])
+
+        Note
+        ----
+
+        Rfactor requires:
+        1) `loop` has only one child block, and it is a reduction block;
+        2) `loop` is a reduction loop, i.e. the loop variable is bound to only reduction
+           variables in the block binding;
+        3) `loop` is not parallelized, vectorized, unrolled or bound to any thread axis;
+        4) The block scope that `loop` is in is a staged-pipeline;
+        5) The outermost loop outside the reduction block should have the reduction block as
+           its first child block;
+        6) The outermost reduction loop should have only one child block;
+        7) A unary-extent loop that is not bound to any reduction or data parallel variables
+           in the block binding should not appear under some reduction loop;
+        8) The reduction block should write to only one buffer, and its init and body are both
+           simple `BufferStore`s, and the pattern is registered as an associative reducer.
+           The pre-defined patterns include: plus, multiplication, min and max;
+        9) Each of the loops on top of the block cannot be bound to a data parallel and a
+           reduction block binding at the same time;
+        10) `factor_axis` should be in range `[-ndim(B) - 1, ndim(B)]`,
+            where `B` is the buffer that the reduction block writes to.
+            Negative indexing is normalized according to numpy convention.
+        """
+        return _ffi_api_schedule.ScheduleRFactor(self, loop, factor_axis)  # type: ignore # pylint: disable=no-member
+
     ########## Schedule: blockize & tensorize ##########
 

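The `factor_axis` rule in point 10 of the docstring above can be illustrated with a small standalone sketch. The helpers `normalize_factor_axis` and `rf_buffer_shape` below are hypothetical, not TVM APIs; they only assume the numpy-style normalization and the shape examples given in the docstring.

    # Hypothetical helpers illustrating the factor_axis convention described in the
    # rfactor docstring; these are not part of the commit or of TVM's API.
    from typing import List


    def normalize_factor_axis(factor_axis: int, ndim: int) -> int:
        """Map factor_axis from [-ndim - 1, ndim] to a non-negative position."""
        if not -ndim - 1 <= factor_axis <= ndim:
            raise ValueError(f"factor_axis {factor_axis} out of range for ndim {ndim}")
        # Negative indexing follows the numpy convention for inserting a new axis.
        return factor_axis if factor_axis >= 0 else factor_axis + ndim + 1


    def rf_buffer_shape(shape: List[int], factor_axis: int, n: int) -> List[int]:
        """Shape of the rf-buffer: insert an axis of extent n at factor_axis."""
        axis = normalize_factor_axis(factor_axis, len(shape))
        return shape[:axis] + [n] + shape[axis:]


    # The shape examples from the docstring, with shape(B) = [1, 2, 3]:
    assert rf_buffer_shape([1, 2, 3], 0, 7) == [7, 1, 2, 3]
    assert rf_buffer_shape([1, 2, 3], 1, 7) == [1, 7, 2, 3]
    assert rf_buffer_shape([1, 2, 3], 3, 7) == [1, 2, 3, 7]
    assert rf_buffer_shape([1, 2, 3], -1, 7) == [1, 2, 3, 7]  # -1 normalizes to ndim(B)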
src/arith/canonical_simplify.cc

Lines changed: 4 additions & 2 deletions
@@ -1137,8 +1137,10 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op)
     // and recursively mark the corresponding components
     for (size_t i = 0; i < simplified_result.size(); ++i)
       if (!used[i]) {
-        if (ExprUseVar(simplified_result[idx], op->combiner->lhs[i]) ||
-            ExprUseVar(simplified_result[idx], op->combiner->rhs[i]))
+        if (UsesVar(simplified_result[idx],
+                    [v = op->combiner->lhs[i].get()](const VarNode* var) { return var == v; }) ||
+            UsesVar(simplified_result[idx],
+                    [v = op->combiner->rhs[i].get()](const VarNode* var) { return var == v; }))
           mark_used(i);
       }
   };

src/arith/detect_linear_equation.cc

Lines changed: 2 additions & 2 deletions
@@ -108,7 +108,7 @@ class LinearEqDetector : public ExprFunctor<LinearEqEntry(const PrimExpr&, const
   }
   LinearEqEntry VisitExprDefault_(const Object* op, const PrimExpr& e) final {
     if (fail_) return LinearEqEntry();
-    if (ExprUseVar(e, var_)) {
+    if (UsesVar(e, [this](const VarNode* var) { return var == var_.get(); })) {
       fail_ = true;
       return LinearEqEntry();
     } else {
@@ -159,7 +159,7 @@ Array<PrimExpr> DetectLinearEquation(const PrimExpr& e, const Array<Var>& vars)
   for (size_t i = vars.size(); i > 1; --i) {
     vset.insert(vars[i - 1].get());
     // The previous coeff contains the variable
-    if (ExprUseVar(coeff[i - 2], vset_contains)) {
+    if (UsesVar(coeff[i - 2], vset_contains)) {
       return Array<PrimExpr>();
     }
   }

src/printer/tir_text_printer.cc

Lines changed: 0 additions & 17 deletions
@@ -487,23 +487,6 @@ Doc TIRTextPrinter::VisitStmt_(const EvaluateNode* op) {
   return doc;
 }
 
-inline const char* ForKind2String(ForKind t) {
-  switch (t) {
-    case ForKind::kSerial:
-      return "serial";
-    case ForKind::kParallel:
-      return "parallel";
-    case ForKind::kVectorized:
-      return "vectorized";
-    case ForKind::kUnrolled:
-      return "unroll";
-    case ForKind::kThreadBinding:
-      return "thread_binding";
-  }
-  LOG(FATAL) << "Unknown ForKind";
-  return "Unknown";
-}
-
 Doc TIRTextPrinter::VisitStmt_(const ForNode* op) {
   Doc doc;
   doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", "

src/printer/tvmscript_printer.cc

Lines changed: 0 additions & 17 deletions
@@ -704,23 +704,6 @@ Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) {
   return doc;
 }
 
-inline const char* ForKind2String(ForKind t) {
-  switch (t) {
-    case ForKind::kSerial:
-      return "serial";
-    case ForKind::kParallel:
-      return "parallel";
-    case ForKind::kVectorized:
-      return "vectorized";
-    case ForKind::kUnrolled:
-      return "unroll";
-    case ForKind::kThreadBinding:
-      return "thread_binding";
-  }
-  LOG(FATAL) << "Unknown ForKind";
-  return "Unknown";
-}
-
 Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) {
   Doc doc;
   var_not_in_headers.insert(op->loop_var.get());
