Skip to content

Commit c170843

Browse files
LunderbergMikael Sevenier
authored andcommitted
[TIR] HoistExpression, generalization of HoistIfThenElse (apache#11592)
* [TIR][Arith] Use non-inlined bindings when proving conditional * [TIR][Arith] Recognize Var when used as a literal constraint * [TIR][Arith] Added simplification of constrained if_then_else op This feels like it should definitely be part of RewriteSimplify, but that will require making CanInlineLet be a virtual function. * [TIR] Implemented HoistExpression transformation This is a generalized form of HoistIfThenElse, which can also hoist Let bindings, or portions of conditional expressions. This will be used in upcoming changes to separate compute loops into a slow loop that handles edge cases and a fast branchless loop. * [TIR] Expressed HoistIfThenElse as special case of HoistExpression * Lint fixes * Fixed breakage in tvmc unit test that relied on pass type * More accurate handling of kUsingBlockVar Didn't correctly reproduce previous behavior. In addition to preventing hoisting of expressions that use a block variable (e.g. threadIdx.x), should also prevent hoisting of expressions across a "thread_extent" AttrStmt. * Updated comment for HoistExpression pass * Fix linting error
1 parent 453689d commit c170843

File tree

10 files changed

+1335
-463
lines changed

10 files changed

+1335
-463
lines changed

include/tvm/tir/transform.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,19 @@ TVM_DLL Pass PointerValueTypeRewrite();
364364
*/
365365
TVM_DLL Pass HoistIfThenElse();
366366

367+
/*!
368+
* \brief Hoist loop-invariant expressions nodes to
369+
* outside the elligible loops.
370+
*
371+
* Can hoist conditionals used in IfThenElse statements and
372+
* expressions, bindings of variables in Let statements and
373+
* expressions, or boolean expressions, configurable to enable/disable
374+
* each hoistable type.
375+
*
376+
* \return The pass.
377+
*/
378+
TVM_DLL Pass HoistExpression();
379+
367380
/*!
368381
* \brief Lower cross-thread reduction from thread
369382
* bindings to intrinsic function calls.

python/tvm/tir/transform/transform.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
# under the License.
1717
"""Wrapping existing transformations."""
1818
# pylint: disable=invalid-name
19+
20+
21+
import enum
1922
from typing import Callable, Optional
2023

2124
from . import _ffi_api
@@ -593,6 +596,74 @@ def HoistIfThenElse(variant: Optional[str] = None):
593596
return _ffi_api.HoistIfThenElse() # type: ignore
594597

595598

599+
class HoistedConditionals(enum.Flag):
600+
"""Flags for use in HoistExpressionConfig.conditional_types
601+
602+
Each bitflag represents a type of expression that should be
603+
hoisted to the outermost loop possible.
604+
"""
605+
606+
Never = 0
607+
""" No hoisting of conditionals """
608+
609+
IfElseStmt = 1
610+
""" If set, look for hoist candidates in IfElseStmt """
611+
612+
IfElseExpr = 2
613+
""" If set, look for hoist candidates in tir.if_then_else """
614+
615+
BooleanExpression = 4
616+
""" If set, look for hoist candidates in all boolean expressions """
617+
618+
UsingBlockVar = 8
619+
""" If set, allow hoisting of conditionals that use a block variable (e.g. threadIdx.x) """
620+
621+
All = IfElseStmt | IfElseExpr | BooleanExpression | UsingBlockVar
622+
""" Enable all hoisting of conditionals"""
623+
624+
625+
class HoistedLetBindings(enum.Flag):
626+
"""Flags for use in HoistExpressionConfig.let_binding_types
627+
628+
Each bitflag represents a type of let binding expression that should be
629+
hoisted to the outermost loop possible.
630+
"""
631+
632+
Never = 0
633+
""" No hoisting of let bindings """
634+
635+
RequiredByConditional = 1
636+
""" Bindings that are used by a hoisted conditional """
637+
638+
LetStmt = 2
639+
""" Bindings occuring in LetStmt """
640+
641+
LetExpr = 4
642+
""" Bindings occuring in Let expressions """
643+
644+
All = RequiredByConditional | LetStmt | LetExpr
645+
""" Enable all hoisting of let bindings """
646+
647+
648+
def HoistExpression():
649+
"""Generalized verison of HoistIfThenElse.
650+
651+
Hoist loop-invariant expressions to outside the eligible loops.
652+
Searches for expressions in:
653+
654+
* LetStmt bindings
655+
* IfThenElse conditions
656+
* Boolean operators
657+
658+
Returns
659+
-------
660+
fpass : tvm.transform.Pass
661+
The result pass
662+
663+
"""
664+
return _ffi_api.HoistExpression() # type: ignore
665+
666+
596667
def LowerCrossThreadReduction():
597668
"""Lower cross-thread reduction from thread bindings to
598669
intrinsic function calls.

src/arith/rewrite_simplify.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
228228
std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint) {
229229
size_t old_literal_size = literal_constraints_.size();
230230
// we will compare the already simplified result with the constraint,
231-
// so simplify the constarint as well
231+
// so simplify the constraint as well
232232
PrimExpr new_constraint = operator()(constraint);
233233
for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint)) {
234234
if (SideEffect(subconstraint) <= CallEffectKind::kPure) {
@@ -1673,6 +1673,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
16731673

16741674
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) {
16751675
Var var = GetRef<Var>(op);
1676+
if (op->dtype == DataType::Bool()) {
1677+
if (auto match = TryMatchLiteralConstraint(var)) {
1678+
return match.value();
1679+
}
1680+
}
1681+
16761682
auto it = var_map_.find(var);
16771683
if (it != var_map_.end()) {
16781684
return it->second;

0 commit comments

Comments
 (0)