Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions include/tvm/te/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,21 +131,29 @@ class Stage : public ObjectRef {
* \param factor The split factor of the loop.
* \param p_outer The result outer domain
* \param p_inner The result inner domain.
* \param disable_predication If enabled, don't create a predicate for guarding the
* loop. This can be useful when splitting with scalable factors that the schedule writer
* knows are divisible by the loop bound.
* Warning: enabling this feature may result in incorrect code generation if not used carefully.
* \return reference to self.
*/
TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer,
IterVar* p_inner); // NOLINT(*)
TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner,
bool disable_predication = false); // NOLINT(*)
/*!
* \brief Split the iteration with given number of parts.
*
* \param parent The parent domain.
* \param nparts The number of parts in the outer domain.
* \param p_outer The result outer domain.
* \param p_inner The result inner domain.
* \param disable_predication If enabled, don't create a predicate for guarding the
* loop. This can be useful when splitting with scalable factors that the schedule writer
* knows are divisible by the loop bound.
* Warning: enabling this feature may result in incorrect code generation if not used carefully.
* \return reference to self.
*/
TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer,
IterVar* p_inner); // NOLINT(*)
IterVar* p_inner, bool disable_predication = false); // NOLINT(*)
/*!
* \brief Fuse the inner outer domain to the target
* \param outer The outer domain to be fused.
Expand Down Expand Up @@ -761,13 +769,16 @@ class SplitNode : public IterVarRelationNode {
PrimExpr factor;
/*! \brief Number of parts, only factor or nparts can be given */
PrimExpr nparts;
/*! \brief Whether to disable generation of predication. */
bool disable_predication;

void VisitAttrs(AttrVisitor* v) {
v->Visit("parent", &parent);
v->Visit("outer", &outer);
v->Visit("inner", &inner);
v->Visit("factor", &factor);
v->Visit("nparts", &nparts);
v->Visit("disable_predication", &disable_predication);
}

static constexpr const char* _type_key = "Split";
Expand All @@ -780,7 +791,8 @@ class SplitNode : public IterVarRelationNode {
*/
class Split : public IterVarRelation {
public:
TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts);
TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts,
bool disable_predication);

TVM_DEFINE_OBJECT_REF_METHODS(Split, IterVarRelation, SplitNode);
};
Expand Down
11 changes: 8 additions & 3 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,16 @@ class ScheduleNode : public runtime::Object {
* \param loop_rv The loop to be split
* \param factors The positive tiling factors, and at most one of which is `NullOpt`, which means
* that factor is inferred.
* \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
* \return The new loops after split
* \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings.
* \param disable_predication If enabled, don't create a predicate for guarding the
* loop. This can be useful when splitting with scalable factors that the schedule writer
* knows are divisible by the loop bound.
* Warning: enabling this feature may result in incorrect code generation if not used carefully.
* \return The new loops after split.
*/
virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors,
bool preserve_unit_iters = true) = 0;
bool preserve_unit_iters = true,
bool disable_predication = false) = 0;
/*!
* \brief Partition the loops into sequence of multiple loops
* 1) The loop can't have annotation or thread binding.
Expand Down
14 changes: 11 additions & 3 deletions python/tvm/te/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def rfactor(self, tensor, axis, factor_axis=0):
class Stage(Object):
"""A Stage represents schedule for one operation."""

def split(self, parent, factor=None, nparts=None):
def split(self, parent, factor=None, nparts=None, disable_predication=False):
"""Split the stage either by factor providing outer scope, or both

Parameters
Expand All @@ -215,6 +215,14 @@ def split(self, parent, factor=None, nparts=None):
nparts : Expr, optional
The number of outer parts.

disable_predication : bool, optional
If enabled, don't create a predicate for guarding the loop. This can
be useful when splitting with scalable factors that the schedule writer
knows are divisible by the loop bound.

Warning: enabling this feature may result in incorrect code generation
if not used carefully.

Returns
-------
outer : IterVar
Expand All @@ -226,11 +234,11 @@ def split(self, parent, factor=None, nparts=None):
if nparts is not None:
if factor is not None:
raise ValueError("Do not need to provide both outer and nparts")
outer, inner = _ffi_api.StageSplitByNParts(self, parent, nparts)
outer, inner = _ffi_api.StageSplitByNParts(self, parent, nparts, disable_predication)
else:
if factor is None:
raise ValueError("Either nparts or factor need to be provided")
outer, inner = _ffi_api.StageSplitByFactor(self, parent, factor)
outer, inner = _ffi_api.StageSplitByFactor(self, parent, factor, disable_predication)
return outer, inner

def fuse(self, *args):
Expand Down
15 changes: 14 additions & 1 deletion python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,7 @@ def split(
loop: LoopRV,
factors: List[Union[int, ExprRV, None]],
preserve_unit_iters: bool = True,
disable_predication: bool = False,
) -> List[LoopRV]:
"""Split a loop into a list of consecutive loops. It requires:
1) The loop can't have annotation or thread binding.
Expand All @@ -759,6 +760,14 @@ def split(
preserve_unit_iters : bool
Whether or not to preserve unit iterators in block bindings

disable_predication : bool
If enabled, don't create a predicate for guarding the loop. This can
be useful when splitting with scalable factors that the schedule writer
knows are divisible by the loop bound.

Warning: enabling this feature may result in incorrect code generation
if not used carefully.

Returns
-------
split_loops : List[LoopRV]
Expand Down Expand Up @@ -809,7 +818,11 @@ def after_split(a: T.handle, b: T.handle) -> None:
# that there is at most one None in `factors`
return list(
_ffi_api.ScheduleSplit( # type: ignore # pylint: disable=no-member
self, loop, factors, preserve_unit_iters
self,
loop,
factors,
preserve_unit_iters,
disable_predication,
)
)

Expand Down
20 changes: 20 additions & 0 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include "../tir/analysis/check_contains.h"
#include "./scalable_expression.h"
#include "const_fold.h"
#include "product_normal_form.h"

Expand Down Expand Up @@ -227,6 +229,24 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
}
}

// Current analysis may not be powerful enough to prove expressions containing
// the same symbolic value multiple times. However, when the symbolic values are
// "T.vscale" and the compile target uses a scalable architecture extension like
// SVE, we can make some assumptions about the value of vscale and iterate over a
// space of pre-defined values to attempt to prove the expression.
if (tir::CheckContains::ExprContains(expr, IsVScaleCall)) {
Target curr_target = tvm::Target::Current();
if (curr_target.defined() && curr_target->features.defined() &&
(curr_target->features.find("has_sve") != curr_target->features.end()) &&
curr_target->GetFeature<Bool>("has_sve").value_or(Bool(false)).operator bool()) {
return CanProveVscaleExpressionFromKnownValues(this, simplified, kAArch64VScaleValues);
}
LOG(WARNING)
<< "The expression contains scalable values. An attempt to prove by substituting "
"with known values of vscale was not performed. This proof currently only supports "
"AArch64 SVE targets, but the target was "
<< curr_target;
}
return false;
}

Expand Down
38 changes: 38 additions & 0 deletions src/arith/scalable_expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include <vector>

#include "../tir/transforms/replace_selected_expr.h"
#include "./pattern_match.h"

namespace tvm {
Expand All @@ -39,6 +42,19 @@ bool IsVScaleCall(const PrimExpr& expr) {
return false;
}

PrimExpr SubstituteVScaleWithKnownValue(const PrimExpr& expr, unsigned int vscale_value) {
std::function<bool(const PrimExpr&)> predicate_selector = [](const PrimExpr& current_expr) {
return IsVScaleCall(current_expr);
};
std::function<bool(const PrimExpr&)> can_replace_inside = [](const PrimExpr& current_expr) {
return true;
};

return tir::ReplaceSelectedExpr::ReplaceSelectedExprInExpr(
expr, predicate_selector, tir::MakeConstScalar(DataType::Int(32), vscale_value),
can_replace_inside);
}

std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes) {
PVar<IntImm> multiplier;
PCallExpr<PVscaleOp> vscale;
Expand All @@ -50,5 +66,27 @@ std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes) {
}
}

bool IsComparison(const PrimExpr& expr) {
return expr->IsInstance<tir::LENode>() || expr->IsInstance<tir::LTNode>() ||
expr->IsInstance<tir::GENode>() || expr->IsInstance<tir::GTNode>() ||
expr->IsInstance<tir::EQNode>() || expr->IsInstance<tir::NENode>();
}

bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const PrimExpr& expr,
const std::vector<unsigned int>& vscale_values) {
ICHECK(IsComparison(expr)) << "Expected comparison but got: " << expr;
bool can_prove_expr = true;
for (const unsigned int vscale_value : vscale_values) {
PrimExpr result = SubstituteVScaleWithKnownValue(expr, vscale_value);
result = analyzer->Simplify(result);
const int64_t* as_int = tir::as_const_int(result);
if (!as_int || *as_int == 0) {
can_prove_expr = false;
break;
}
}
return can_prove_expr;
}

} // namespace arith
} // namespace tvm
25 changes: 25 additions & 0 deletions src/arith/scalable_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,52 @@
#ifndef TVM_ARITH_SCALABLE_EXPRESSION_H_
#define TVM_ARITH_SCALABLE_EXPRESSION_H_

#include <tvm/arith/analyzer.h>
#include <tvm/ir/expr.h>

#include <optional>
#include <vector>

namespace tvm {
namespace arith {

/*! \brief A list of known vscale values to try for an AArch64 SVE target. */
static const std::vector<unsigned int> kAArch64VScaleValues = {1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16};

/*!
* \brief Check if an expr is a call to the vscale intrinsic.
* \param expr The expr to check
* \return True if the expr is a call to the vscale intrinsic, false if not.
*/
bool IsVScaleCall(const PrimExpr& expr);

/*!
* \brief Substitute a vscale intrinsic call with a known scalar value.
* \param expr The expr to apply substitutions to.
* \param vscale_value The scalar value to replace vscale with.
* \return A rewritten expression with vscale values replaced with a scalar value.
*/
PrimExpr SubstituteVScaleWithKnownValue(const PrimExpr& expr, unsigned int vscale_value);

/*!
* \brief Returns the vscale multiplier as a nullable type
* \param lanes The scalable lanes as a PrimExpr
* \return vscale multiplier as std::optional<int>
*/
std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes);

/*!
* \brief Check if the expression can be proven when evaluating it on all possible values
of vscale.
* \param analyzer An analyzer instance.
* \param expr The expression to try to prove.
* \param vscale_values A list of values to substitute vscale with.
* \return Whether or not the expression can be proven with this technique.
*/
bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const PrimExpr& expr,
const std::vector<unsigned int>& vscale_values);

} // namespace arith
} // namespace tvm

Expand Down
3 changes: 2 additions & 1 deletion src/te/schedule/message_passing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,8 @@ void PassUpBoundCheck(const Stage& s, const Map<IterVar, Range>& dom_map,
if (outer || inner) {
state[s->parent] = true;
} else {
if (analyzer->CanProve(dom_map.at(s->parent)->extent == factor * step)) {
if (analyzer->CanProve(dom_map.at(s->parent)->extent == factor * step) ||
s->disable_predication) {
state[s->parent] = false;
} else {
state[s->parent] = true;
Expand Down
30 changes: 17 additions & 13 deletions src/te/schedule/schedule_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ DataType MatchDataType(std::vector<DataType> dtypes) {
}

void SplitHelper(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts,
IterVar* p_outer, IterVar* p_inner) {
IterVar* p_outer, IterVar* p_inner, bool disable_predication) {
// Check if split is valid.
ICHECK(parent->iter_type == kDataPar || parent->iter_type == kCommReduce ||
parent->iter_type == kOrdered)
Expand All @@ -83,7 +83,7 @@ void SplitHelper(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr npar
Array<IterVar>& all_vars = self->all_iter_vars;
Array<IterVar>& leaf_vars = self->leaf_iter_vars;
size_t pos = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), parent);
self->relations.push_back(Split(parent, outer, inner, factor, nparts));
self->relations.push_back(Split(parent, outer, inner, factor, nparts, disable_predication));
// add vars to all vars
all_vars.push_back(outer);
all_vars.push_back(inner);
Expand Down Expand Up @@ -226,17 +226,17 @@ Stage& Stage::set_store_predicate(PrimExpr predicate) {
return *this;
}

Stage& Stage::split(IterVar parent, PrimExpr factor, IterVar* p_outer,
IterVar* p_inner) { // NOLINT(*)
Stage& Stage::split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner,
bool disable_predication) { // NOLINT(*)
With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
SplitHelper(operator->(), parent, factor, PrimExpr(), p_outer, p_inner);
SplitHelper(operator->(), parent, factor, PrimExpr(), p_outer, p_inner, disable_predication);
return *this;
}

Stage& Stage::split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer,
IterVar* p_inner) { // NOLINT(*)
Stage& Stage::split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner,
bool disable_predication) { // NOLINT(*)
With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
SplitHelper(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner);
SplitHelper(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner, disable_predication);
return *this;
}

Expand Down Expand Up @@ -805,13 +805,15 @@ void ScheduleContext::ExitWithScope() {
}
}

Split::Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts) {
Split::Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts,
bool disable_predication) {
auto n = make_object<SplitNode>();
n->parent = parent;
n->outer = outer;
n->inner = inner;
n->factor = factor;
n->nparts = nparts;
n->disable_predication = disable_predication;
data_ = std::move(n);
}

Expand Down Expand Up @@ -927,6 +929,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ", nparts=";
p->Print(op->nparts);
}
p->stream << ", disable_predication=";
p->stream << op->disable_predication;
p->stream << ')';
})
.set_dispatch<FuseNode>([](const ObjectRef& node, ReprPrinter* p) {
Expand Down Expand Up @@ -973,16 +977,16 @@ TVM_REGISTER_GLOBAL("te.StageSetScope").set_body_method(&Stage::set_scope);
TVM_REGISTER_GLOBAL("te.StageBind").set_body_method(&Stage::bind);

TVM_REGISTER_GLOBAL("te.StageSplitByFactor")
.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) {
.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor, bool disable_predication) {
IterVar outer, inner;
stage.split(parent, factor, &outer, &inner);
stage.split(parent, factor, &outer, &inner, disable_predication);
return Array<IterVar>({outer, inner});
});

TVM_REGISTER_GLOBAL("te.StageSplitByNParts")
.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) {
.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts, bool disable_predication) {
IterVar outer, inner;
stage.split_by_nparts(parent, nparts, &outer, &inner);
stage.split_by_nparts(parent, nparts, &outer, &inner, disable_predication);
return Array<IterVar>({outer, inner});
});

Expand Down
Loading