Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class ConstIntBoundAnalyzer {
*
* \param var The variable of interest.
* \param info The bound information.
* \param allow_override Whether do we allow override of existing information.
* \param allow_override whether we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool allow_override = false);
/*!
Expand Down Expand Up @@ -224,7 +224,7 @@ class ModularSetAnalyzer {
*
* \param var The variable of interest.
* \param info The bound information.
* \param allow_override Whether do we allow override of existing information.
* \param allow_override whether we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const ModularSet& info, bool allow_override = false);

Expand Down Expand Up @@ -263,10 +263,16 @@ class RewriteSimplifier {
*
* \param var The variable of interest.
* \param new_expr
* \param allow_override Whether do we allow override of existing information.
* \param allow_override Whether we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);

/*!
* \brief Update the internal state to enter constraint.
* \param constraint A constraint expression.
*
* \return an exit function that must be called to cleanup the constraint can be nullptr.
*/
std::function<void()> EnterConstraint(const PrimExpr& constraint);

private:
Expand Down Expand Up @@ -297,7 +303,7 @@ class CanonicalSimplifier {
*
* \param var The variable of interest.
* \param new_expr
* \param allow_override Whether do we allow override of existing information.
* \param allow_override whether we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);

Expand Down Expand Up @@ -347,7 +353,7 @@ class ConstraintContext {
/*! \brief The constraint */
PrimExpr constraint_;
/*! \brief function to be called in recovery */
std::function<void()> exit_;
std::vector<std::function<void()>> recovery_functions_;
};

/*!
Expand All @@ -365,6 +371,36 @@ class IntSetAnalyzer {
*/
TVM_DLL IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map);

/*!
* \brief Find a symbolic integer set that contains all possible
* values of expr given the domain of each variables, using
* the domain map defined by bound variables.
*
* \param expr The expression of interest.
* \return the result of the analysis.
*/
TVM_DLL IntSet operator()(const PrimExpr& expr);

/*!
* \brief Update binding of var to a new expression.
*
* \param var The variable of interest.
* \param new_interval_set The set of allowed values for this var.
* \param allow_override whether we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const IntSet& new_interval_set, bool allow_override = false);

/*!
* \brief Update binding of var to a new expression.
*
* \param var The variable of interest.
* \param new_range The range of allowed values for this var.
* \param allow_override whether we allow override of existing information.
*/
TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false);

std::function<void()> EnterConstraint(const PrimExpr& constraint);

private:
friend class Analyzer;
explicit IntSetAnalyzer(Analyzer* parent);
Expand Down
26 changes: 14 additions & 12 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
this->modular_set.Update(var, this->modular_set(new_expr), allow_override);
this->rewrite_simplify.Update(var, new_expr, allow_override);
this->canonical_simplify.Update(var, new_expr, allow_override);
this->int_set.Update(var, this->int_set(new_expr), allow_override);
}

void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
Expand All @@ -52,6 +53,7 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
this->Bind(var, range->min, allow_override);
} else {
this->const_int_bound.Bind(var, range, allow_override);
this->int_set.Bind(var, range, allow_override);
}
// skip modular_set
// skip rewrite simplify
Expand All @@ -64,22 +66,22 @@ void Analyzer::Bind(const Map<Var, Range>& variables, bool allow_override) {
}

void ConstraintContext::EnterWithScope() {
ICHECK(exit_ == nullptr);
ICHECK(recovery_functions_.size() == 0);
// entering the scope.
auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_);
auto f1 = analyzer_->modular_set.EnterConstraint(constraint_);
auto f2 = analyzer_->rewrite_simplify.EnterConstraint(constraint_);
// recovery function.
exit_ = [f0, f1, f2]() {
if (f2 != nullptr) f2();
if (f1 != nullptr) f1();
if (f0 != nullptr) f0();
};
recovery_functions_.push_back(analyzer_->const_int_bound.EnterConstraint(constraint_));
recovery_functions_.push_back(analyzer_->modular_set.EnterConstraint(constraint_));
recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_));
recovery_functions_.push_back(analyzer_->int_set.EnterConstraint(constraint_));
}

void ConstraintContext::ExitWithScope() {
ICHECK(exit_ != nullptr);
exit_();
while (recovery_functions_.size()) {
auto& func = recovery_functions_.back();
if (func) {
func();
}
recovery_functions_.pop_back();
}
}

bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) {
Expand Down
43 changes: 11 additions & 32 deletions src/arith/domain_touched.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include <unordered_map>
#include <unordered_set>

#include "ir_visitor_with_analyzer.h"

namespace tvm {
namespace arith {

Expand All @@ -56,7 +58,7 @@ using BufferDomainAccess = std::tuple<LoadAccess, StoreAccess, CombinedAccess>;
} // namespace

// Find Read region of the tensor in the stmt.
class BufferTouchedDomain final : public StmtExprVisitor {
class BufferTouchedDomain final : public IRVisitorWithAnalyzer {
public:
BufferTouchedDomain(const Stmt& stmt) { operator()(stmt); }

Expand Down Expand Up @@ -90,65 +92,42 @@ class BufferTouchedDomain final : public StmtExprVisitor {
return ret;
}

void VisitStmt_(const ForNode* op) final {
const VarNode* var = op->loop_var.get();
dom_map_[var] = IntSet::FromRange(Range::FromMinExtent(op->min, op->extent));
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var);
}

void VisitStmt_(const LetStmtNode* op) final {
dom_map_[op->var.get()] = arith::EvalSet(op->value, dom_map_);
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(op->var.get());
}

/* TODO: Thread extent unitest not generated.*/
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tir::attr::thread_extent) {
const IterVarNode* thread_axis = op->node.as<IterVarNode>();
ICHECK(thread_axis);
const VarNode* var = thread_axis->var.get();
dom_map_[var] = IntSet::FromRange(Range(make_zero(op->value.dtype()), op->value));
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var);
} else {
StmtExprVisitor::VisitStmt_(op);
}
}
private:
using Parent = IRVisitorWithAnalyzer;
using Parent::VisitExpr_;
using Parent::VisitStmt_;

void VisitExpr_(const BufferLoadNode* op) final {
// Record load-exclusive buffer access
Touch(&std::get<LoadAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices);
// Record load-store inclusive buffer access
Touch(&std::get<CombinedAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices);
StmtExprVisitor::VisitExpr_(op);
Parent::VisitExpr_(op);
}

void VisitStmt_(const BufferStoreNode* op) final {
// Record store-exclusive buffer access
Touch(&std::get<StoreAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices);
// Record load-store inclusive buffer access
Touch(&std::get<CombinedAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices);
StmtExprVisitor::VisitStmt_(op);
Parent::VisitStmt_(op);
}

private:
void Touch(BufferTouches* bounds, const Array<PrimExpr>& args) const {
void Touch(BufferTouches* bounds, const Array<PrimExpr>& args) {
if (args.size() > bounds->size()) {
bounds->resize(args.size());
}
for (size_t i = 0; i < args.size(); ++i) {
if (args[i].as<RampNode>()) {
(*bounds)[i].emplace_back(IntSet::Vector(args[i]));
} else {
(*bounds)[i].emplace_back(EvalSet(args[i], dom_map_));
(*bounds)[i].emplace_back(analyzer_.int_set(args[i]));
}
}
}

std::unordered_map<const BufferNode*, BufferDomainAccess> buffer_access_map_;
std::unordered_map<const VarNode*, IntSet> dom_map_;
};

Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads,
Expand Down
Loading