Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,19 @@ TVM_DLL Pass PointerValueTypeRewrite();
*/
TVM_DLL Pass HoistIfThenElse();

/*!
* \brief Hoist loop-invariant expressions nodes to
* outside the elligible loops.
*
* Can hoist conditionals used in IfThenElse statements and
* expressions, bindings of variables in Let statements and
* expressions, or boolean expressions, configurable to enable/disable
* each hoistable type.
*
* \return The pass.
*/
TVM_DLL Pass HoistExpression();

/*!
* \brief Lower cross-thread reduction from thread
* bindings to intrinsic function calls.
Expand Down
71 changes: 71 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
# under the License.
"""Wrapping existing transformations."""
# pylint: disable=invalid-name


import enum
from typing import Callable, Optional

from . import _ffi_api
Expand Down Expand Up @@ -593,6 +596,74 @@ def HoistIfThenElse(variant: Optional[str] = None):
return _ffi_api.HoistIfThenElse() # type: ignore


class HoistedConditionals(enum.Flag):
"""Flags for use in HoistExpressionConfig.conditional_types

Each bitflag represents a type of expression that should be
hoisted to the outermost loop possible.
"""

Never = 0
""" No hoisting of conditionals """

IfElseStmt = 1
""" If set, look for hoist candidates in IfElseStmt """

IfElseExpr = 2
""" If set, look for hoist candidates in tir.if_then_else """

BooleanExpression = 4
""" If set, look for hoist candidates in all boolean expressions """

UsingBlockVar = 8
""" If set, allow hoisting of conditionals that use a block variable (e.g. threadIdx.x) """

All = IfElseStmt | IfElseExpr | BooleanExpression | UsingBlockVar
""" Enable all hoisting of conditionals"""


class HoistedLetBindings(enum.Flag):
"""Flags for use in HoistExpressionConfig.let_binding_types

Each bitflag represents a type of let binding expression that should be
hoisted to the outermost loop possible.
"""

Never = 0
""" No hoisting of let bindings """

RequiredByConditional = 1
""" Bindings that are used by a hoisted conditional """

LetStmt = 2
""" Bindings occuring in LetStmt """

LetExpr = 4
""" Bindings occuring in Let expressions """

All = RequiredByConditional | LetStmt | LetExpr
""" Enable all hoisting of let bindings """


def HoistExpression():
"""Generalized verison of HoistIfThenElse.

Hoist loop-invariant expressions to outside the eligible loops.
Searches for expressions in:

* LetStmt bindings
* IfThenElse conditions
* Boolean operators

Returns
-------
fpass : tvm.transform.Pass
The result pass

"""
return _ffi_api.HoistExpression() # type: ignore


def LowerCrossThreadReduction():
"""Lower cross-thread reduction from thread bindings to
intrinsic function calls.
Expand Down
8 changes: 7 additions & 1 deletion src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint) {
size_t old_literal_size = literal_constraints_.size();
// we will compare the already simplified result with the constraint,
// so simplify the constarint as well
// so simplify the constraint as well
PrimExpr new_constraint = operator()(constraint);
for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint)) {
if (SideEffect(subconstraint) <= CallEffectKind::kPure) {
Expand Down Expand Up @@ -1673,6 +1673,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) {
Var var = GetRef<Var>(op);
if (op->dtype == DataType::Bool()) {
if (auto match = TryMatchLiteralConstraint(var)) {
return match.value();
}
}

auto it = var_map_.find(var);
if (it != var_map_.end()) {
return it->second;
Expand Down
Loading