diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 3704eff33ec2..ceb9f574f2c9 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -135,7 +135,7 @@ class ConstIntBoundAnalyzer { * * \param var The variable of interest. * \param info The bound information. - * \param allow_override Whether do we allow override of existing information. + * \param allow_override whether we allow override of existing information. */ TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool allow_override = false); /*! @@ -224,7 +224,7 @@ class ModularSetAnalyzer { * * \param var The variable of interest. * \param info The bound information. - * \param allow_override Whether do we allow override of existing information. + * \param allow_override whether we allow override of existing information. */ TVM_DLL void Update(const Var& var, const ModularSet& info, bool allow_override = false); @@ -263,10 +263,16 @@ class RewriteSimplifier { * * \param var The variable of interest. * \param new_expr - * \param allow_override Whether do we allow override of existing information. + * \param allow_override Whether we allow override of existing information. */ TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false); + /*! + * \brief Update the internal state to enter constraint. + * \param constraint A constraint expression. + * + * \return an exit function that must be called to cleanup the constraint can be nullptr. + */ std::function EnterConstraint(const PrimExpr& constraint); private: @@ -297,7 +303,7 @@ class CanonicalSimplifier { * * \param var The variable of interest. * \param new_expr - * \param allow_override Whether do we allow override of existing information. + * \param allow_override whether we allow override of existing information. */ TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false); @@ -347,7 +353,7 @@ class ConstraintContext { /*! \brief The constraint */ PrimExpr constraint_; /*! \brief function to be called in recovery */ - std::function exit_; + std::vector> recovery_functions_; }; /*! @@ -365,6 +371,36 @@ class IntSetAnalyzer { */ TVM_DLL IntSet operator()(const PrimExpr& expr, const Map& dom_map); + /*! + * \brief Find a symbolic integer set that contains all possible + * values of expr given the domain of each variables, using + * the domain map defined by bound variables. + * + * \param expr The expression of interest. + * \return the result of the analysis. + */ + TVM_DLL IntSet operator()(const PrimExpr& expr); + + /*! + * \brief Update binding of var to a new expression. + * + * \param var The variable of interest. + * \param new_interval_set The set of allowed values for this var. + * \param allow_override whether we allow override of existing information. + */ + TVM_DLL void Update(const Var& var, const IntSet& new_interval_set, bool allow_override = false); + + /*! + * \brief Update binding of var to a new expression. + * + * \param var The variable of interest. + * \param new_range The range of allowed values for this var. + * \param allow_override whether we allow override of existing information. + */ + TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false); + + std::function EnterConstraint(const PrimExpr& constraint); + private: friend class Analyzer; explicit IntSetAnalyzer(Analyzer* parent); diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index b922138057e9..f32c9b2ff4cf 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -44,6 +44,7 @@ void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { this->modular_set.Update(var, this->modular_set(new_expr), allow_override); this->rewrite_simplify.Update(var, new_expr, allow_override); this->canonical_simplify.Update(var, new_expr, allow_override); + this->int_set.Update(var, this->int_set(new_expr), allow_override); } void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) { @@ -52,6 +53,7 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) { this->Bind(var, range->min, allow_override); } else { this->const_int_bound.Bind(var, range, allow_override); + this->int_set.Bind(var, range, allow_override); } // skip modular_set // skip rewrite simplify @@ -64,22 +66,22 @@ void Analyzer::Bind(const Map& variables, bool allow_override) { } void ConstraintContext::EnterWithScope() { - ICHECK(exit_ == nullptr); + ICHECK(recovery_functions_.size() == 0); // entering the scope. - auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_); - auto f1 = analyzer_->modular_set.EnterConstraint(constraint_); - auto f2 = analyzer_->rewrite_simplify.EnterConstraint(constraint_); - // recovery function. - exit_ = [f0, f1, f2]() { - if (f2 != nullptr) f2(); - if (f1 != nullptr) f1(); - if (f0 != nullptr) f0(); - }; + recovery_functions_.push_back(analyzer_->const_int_bound.EnterConstraint(constraint_)); + recovery_functions_.push_back(analyzer_->modular_set.EnterConstraint(constraint_)); + recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_)); + recovery_functions_.push_back(analyzer_->int_set.EnterConstraint(constraint_)); } void ConstraintContext::ExitWithScope() { - ICHECK(exit_ != nullptr); - exit_(); + while (recovery_functions_.size()) { + auto& func = recovery_functions_.back(); + if (func) { + func(); + } + recovery_functions_.pop_back(); + } } bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) { diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index 403ea47f4e61..d2c5d79a0960 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -30,6 +30,8 @@ #include #include +#include "ir_visitor_with_analyzer.h" + namespace tvm { namespace arith { @@ -56,7 +58,7 @@ using BufferDomainAccess = std::tuple; } // namespace // Find Read region of the tensor in the stmt. -class BufferTouchedDomain final : public StmtExprVisitor { +class BufferTouchedDomain final : public IRVisitorWithAnalyzer { public: BufferTouchedDomain(const Stmt& stmt) { operator()(stmt); } @@ -90,39 +92,17 @@ class BufferTouchedDomain final : public StmtExprVisitor { return ret; } - void VisitStmt_(const ForNode* op) final { - const VarNode* var = op->loop_var.get(); - dom_map_[var] = IntSet::FromRange(Range::FromMinExtent(op->min, op->extent)); - StmtExprVisitor::VisitStmt_(op); - dom_map_.erase(var); - } - - void VisitStmt_(const LetStmtNode* op) final { - dom_map_[op->var.get()] = arith::EvalSet(op->value, dom_map_); - StmtExprVisitor::VisitStmt_(op); - dom_map_.erase(op->var.get()); - } - - /* TODO: Thread extent unitest not generated.*/ - void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::thread_extent) { - const IterVarNode* thread_axis = op->node.as(); - ICHECK(thread_axis); - const VarNode* var = thread_axis->var.get(); - dom_map_[var] = IntSet::FromRange(Range(make_zero(op->value.dtype()), op->value)); - StmtExprVisitor::VisitStmt_(op); - dom_map_.erase(var); - } else { - StmtExprVisitor::VisitStmt_(op); - } - } + private: + using Parent = IRVisitorWithAnalyzer; + using Parent::VisitExpr_; + using Parent::VisitStmt_; void VisitExpr_(const BufferLoadNode* op) final { // Record load-exclusive buffer access Touch(&std::get(buffer_access_map_[op->buffer.get()]).set, op->indices); // Record load-store inclusive buffer access Touch(&std::get(buffer_access_map_[op->buffer.get()]).set, op->indices); - StmtExprVisitor::VisitExpr_(op); + Parent::VisitExpr_(op); } void VisitStmt_(const BufferStoreNode* op) final { @@ -130,11 +110,11 @@ class BufferTouchedDomain final : public StmtExprVisitor { Touch(&std::get(buffer_access_map_[op->buffer.get()]).set, op->indices); // Record load-store inclusive buffer access Touch(&std::get(buffer_access_map_[op->buffer.get()]).set, op->indices); - StmtExprVisitor::VisitStmt_(op); + Parent::VisitStmt_(op); } private: - void Touch(BufferTouches* bounds, const Array& args) const { + void Touch(BufferTouches* bounds, const Array& args) { if (args.size() > bounds->size()) { bounds->resize(args.size()); } @@ -142,13 +122,12 @@ class BufferTouchedDomain final : public StmtExprVisitor { if (args[i].as()) { (*bounds)[i].emplace_back(IntSet::Vector(args[i])); } else { - (*bounds)[i].emplace_back(EvalSet(args[i], dom_map_)); + (*bounds)[i].emplace_back(analyzer_.int_set(args[i])); } } } std::unordered_map buffer_access_map_; - std::unordered_map dom_map_; }; Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads, diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 48fae479b042..6d48ad1ed151 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -31,6 +31,7 @@ #include #include +#include "constraint_extract.h" #include "interval_set.h" #include "pattern_match.h" @@ -63,7 +64,7 @@ IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) { PrimExpr min_value = max(a->min_value, b->min_value); if ((max_value.dtype().is_int() || max_value.dtype().is_uint()) && (min_value.dtype().is_int() || min_value.dtype().is_uint()) && - analyzer->CanProveGreaterEqual(min_value - max_value, 1)) { + analyzer->CanProve(max_value < min_value)) { return IntervalSet::Empty(); } else { return IntervalSet(min_value, max_value); @@ -105,14 +106,14 @@ TVM_DECLARE_LOGICAL_OP(Not); * \note this can possibly relax the set. */ template -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, DataType dtype) { if (a->IsSinglePoint() && b->IsSinglePoint()) { PrimExpr res = TryConstFold(a->min_value, b->min_value); if (!res.defined()) res = Op(a->min_value, b->min_value); return IntervalSet::SinglePoint(res); } if (is_logical_op::value) { - return IntervalSet(make_const(a->min_value.dtype(), 0), make_const(a->min_value.dtype(), 1)); + return IntervalSet(make_const(dtype, 0), make_const(dtype, 1)); } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; @@ -122,7 +123,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { } template <> -inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b, + DataType /* dtype */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value + b->min_value); } @@ -136,7 +138,8 @@ inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalS } template <> -inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b, + DataType /* dtype */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value - b->min_value); } @@ -150,7 +153,8 @@ inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalS } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, + DataType /* dtype */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value * b->min_value); } @@ -183,7 +187,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, + DataType /* dtype */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value / b->min_value); } @@ -216,7 +221,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, + DataType /* dtype */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value)); } @@ -244,7 +250,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, + DataType /* dtype */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value)); } @@ -277,7 +284,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, + DataType /* dtype */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value)); } @@ -294,7 +302,10 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int // a mod b = a - (a / b) * b if a_max / b == a_min / b auto qmax = a->HasUpperBound() ? floordiv(a->max_value, divisor) : pos_inf(); auto qmin = a->HasLowerBound() ? floordiv(a->min_value, divisor) : neg_inf(); - if (analyzer->CanProve(qmax == qmin)) { + // We can compare +/- inf against each other, but cannot use + // operator== between the symbolic limits and an integer. + bool compatible_dtypes = !(qmin.dtype().is_handle() ^ qmax.dtype().is_handle()); + if (compatible_dtypes && analyzer->CanProve(qmax == qmin)) { auto tmax = a->max_value - divisor * qmin; auto tmin = a->min_value - divisor * qmin; return IntervalSet(tmin, tmax); @@ -311,7 +322,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int } template <> -inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b, + DataType /* dtype */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(max(a->min_value, b->min_value)); } @@ -321,7 +333,8 @@ inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, Interval } template <> -inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b, + DataType /* dtype */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(min(a->min_value, b->min_value)); } @@ -423,10 +436,12 @@ class IntervalSetEvaluator : public ExprFunctor { int64_t vstride = stride.Eval()->value; if (vstride > 0) { return Combine(analyzer_, base, - IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1))); + IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1)), + op->dtype); } else { return Combine(analyzer_, base, - IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t))); + IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t)), + op->dtype); } } DLOG(WARNING) << "cannot evaluate set on expression " << GetRef(op); @@ -490,7 +505,7 @@ class IntervalSetEvaluator : public ExprFunctor { if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { return IntervalSet::SinglePoint(GetRef(op)); } - return Combine(analyzer_, a, b); + return Combine(analyzer_, a, b, op->dtype); } // recursive depth @@ -509,8 +524,37 @@ class IntSetAnalyzer::Impl { return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr); } + IntSet Eval(const PrimExpr& expr) const { + return IntervalSetEvaluator(analyzer_, GetCurrentBounds(), true).Eval(expr); + } + + void Bind(const Var& var, const Range& range, bool allow_override) { + Update(var, IntSet::FromRange(range), allow_override); + } + + void Update(const Var& var, const IntSet& info, bool override_info); + void Bind(const Var& var, const PrimExpr& expr, bool override_info); + std::function EnterConstraint(const PrimExpr& constraint); + private: + // Get the current variable bounds, including both global bounds and + // scope-dependent bounds. + Map GetCurrentBounds() const; + + // Utility function to split a boolean condition into the domain + // bounds implied by that condition. + static std::vector> DetectBoundInfo(const PrimExpr& cond); + + // The parent arith::Analyzer Analyzer* analyzer_; + + // Map of variables to global variable bounds (e.g. loop iterator + // ranges) + Map dom_map_; + + // Map of variables to implicit scope-dependent bounds (e.g. inside + // the body of an if-statement) + Map constraints_; }; IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} @@ -521,6 +565,141 @@ IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, const Map& return impl_->Eval(expr, dom_map); } +IntSet IntSetAnalyzer::operator()(const PrimExpr& expr) { return impl_->Eval(expr); } + +void IntSetAnalyzer::Update(const Var& var, const IntSet& info, bool allow_override) { + impl_->Update(var, info, allow_override); +} + +void IntSetAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) { + impl_->Bind(var, range, allow_override); +} + +void IntSetAnalyzer::Impl::Update(const Var& var, const IntSet& info, bool can_override) { + if (!can_override) { + auto it = dom_map_.find(var); + if (it != dom_map_.end()) { + const IntSet& old_info = (*it).second; + + ICHECK(ExprDeepEqual()(old_info.min(), info.min())) + << "Trying to update var \'" << var << "\'" + << " with a different minimum value: " + << "original=" << old_info.min() << ", new=" << info.min(); + + ICHECK(ExprDeepEqual()(old_info.max(), info.max())) + << "Trying to update var \'" << var << "\'" + << " with a different maximum value: " + << "original=" << old_info.max() << ", new=" << info.max(); + } + } + dom_map_.Set(var, info); +} + +void IntSetAnalyzer::Impl::Bind(const Var& var, const PrimExpr& expr, bool can_override) { + Update(var, Eval(expr), can_override); +} + +Map IntSetAnalyzer::Impl::GetCurrentBounds() const { + // If either constraints_ or dom_map_ is empty, return the other to + // avoid constructing a new map. + if (constraints_.empty()) { + return dom_map_; + } else if (dom_map_.empty()) { + return constraints_; + } + + // If neither is empty, construct a merged domain map with + // information from both sources. + Map merged = dom_map_; + for (const auto& pair : constraints_) { + auto it = merged.find(pair.first); + if (it == merged.end()) { + merged.Set(pair.first, pair.second); + } else { + merged.Set(pair.first, Intersect({pair.second, (*it).second})); + } + } + return merged; +} + +std::vector> IntSetAnalyzer::Impl::DetectBoundInfo( + const PrimExpr& constraint) { + PVar x; + PVar limit; + + std::vector> bounds; + for (const PrimExpr& subconstraint : ExtractConstraints(constraint)) { + if ((x <= limit).Match(subconstraint)) { + bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, limit.Eval())}); + } else if ((x < limit).Match(subconstraint)) { + bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, limit.Eval() - 1)}); + } else if ((x >= limit).Match(subconstraint)) { + bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval(), SymbolicLimits::pos_inf_)}); + } else if ((x > limit).Match(subconstraint)) { + bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval() + 1, SymbolicLimits::pos_inf_)}); + } else if ((x == limit).Match(subconstraint)) { + bounds.push_back({x.Eval(), IntSet::SinglePoint(limit.Eval())}); + } + + if ((limit >= x).Match(subconstraint)) { + bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, limit.Eval())}); + } else if ((limit > x).Match(subconstraint)) { + bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, limit.Eval() - 1)}); + } else if ((limit <= x).Match(subconstraint)) { + bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval(), SymbolicLimits::pos_inf_)}); + } else if ((limit < x).Match(subconstraint)) { + bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval() + 1, SymbolicLimits::pos_inf_)}); + } else if ((limit == x).Match(subconstraint)) { + bounds.push_back({x.Eval(), IntSet::SinglePoint(limit.Eval())}); + } + } + return bounds; +} + +std::function IntSetAnalyzer::EnterConstraint(const PrimExpr& constraint) { + return impl_->EnterConstraint(constraint); +} + +std::function IntSetAnalyzer::Impl::EnterConstraint(const PrimExpr& constraint) { + Map cached_values; + + auto bounds = DetectBoundInfo(constraint); + + if (bounds.size() == 0) return nullptr; + + // Collect the current values of each var that is changes by this + // constraint. + for (const auto& pair : bounds) { + auto it = constraints_.find(pair.first); + if (it == constraints_.end()) { + cached_values.Set(pair.first, IntSet()); + } else { + cached_values.Set(pair.first, (*it).second); + } + } + + // Update all constraints + for (const auto& pair : bounds) { + auto it = constraints_.find(pair.first); + if (it == constraints_.end()) { + constraints_.Set(pair.first, pair.second); + } else { + constraints_.Set(pair.first, Intersect({pair.second, (*it).second})); + } + } + + auto frecover = [cached_values, this]() { + for (const auto& it : cached_values) { + if (it.second.defined()) { + constraints_.Set(it.first, it.second); + } else { + constraints_.erase(it.first); + } + } + }; + return frecover; +} + // Quickly adapt to IntSet interface // TODO(tqchen): revisit IntSet interface as well. Range IntSet::CoverRange(Range max_range) const { diff --git a/src/arith/ir_visitor_with_analyzer.cc b/src/arith/ir_visitor_with_analyzer.cc new file mode 100644 index 000000000000..75ae22ef9915 --- /dev/null +++ b/src/arith/ir_visitor_with_analyzer.cc @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/arith/ir_visitor_with_analyzer.cc + */ +#include "ir_visitor_with_analyzer.h" + +#include +#include +#include + +namespace tvm { +namespace arith { + +using namespace tir; + +void IRVisitorWithAnalyzer::VisitStmt_(const ForNode* op) { + analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + StmtExprVisitor::VisitStmt_(op); +} + +void IRVisitorWithAnalyzer::VisitStmt_(const BlockNode* op) { + for (const auto& iter_var : op->iter_vars) { + analyzer_.Bind(iter_var->var, iter_var->dom); + } + StmtExprVisitor::VisitStmt_(op); +} + +void IRVisitorWithAnalyzer::VisitStmt_(const LetStmtNode* op) { + this->VisitExpr(op->value); + analyzer_.Bind(op->var, op->value); + this->VisitStmt(op->body); +} + +void IRVisitorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { + this->VisitExpr(op->condition); + + PrimExpr real_condition = ExtractRealCondition(op->condition); + + { + With constraint(&analyzer_, real_condition); + this->VisitStmt(op->then_case); + } + if (op->else_case.defined()) { + With constraint(&analyzer_, analyzer_.rewrite_simplify(Not(real_condition))); + this->VisitStmt(op->else_case); + } +} + +void IRVisitorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { + if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) { + IterVar iv = Downcast(op->node); + ICHECK_NE(iv->thread_tag.length(), 0U); + analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value)); + } + StmtExprVisitor::VisitStmt_(op); +} + +void IRVisitorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) { + this->VisitExpr(op->condition); + this->VisitExpr(op->message); + With constraint(&analyzer_, op->condition); + this->VisitStmt(op->body); +} + +void IRVisitorWithAnalyzer::VisitExpr_(const CallNode* op) { + // add condition context to if_then_else + static auto op_if_then_else = Op::Get("tir.if_then_else"); + if (op->op.same_as(op_if_then_else)) { + PrimExpr cond = op->args[0]; + this->VisitExpr(op->args[0]); + { + With constraint(&analyzer_, cond); + this->VisitExpr(op->args[1]); + } + { + With constraint(&analyzer_, analyzer_.rewrite_simplify(Not(cond))); + this->VisitExpr(op->args[2]); + } + } else { + StmtExprVisitor::VisitExpr_(op); + } +} + +void IRVisitorWithAnalyzer::VisitExpr_(const LetNode* op) { + this->VisitExpr(op->value); + analyzer_.Bind(op->var, op->value); + this->VisitExpr(op->body); +} + +void IRVisitorWithAnalyzer::VisitExpr_(const ReduceNode* op) { + for (const IterVar& iv : op->axis) { + analyzer_.Bind(iv->var, iv->dom); + } + StmtExprVisitor::VisitExpr_(op); +} + +PrimExpr IRVisitorWithAnalyzer::ExtractRealCondition(PrimExpr condition) const { + if (auto call = condition.as()) { + if (call->op.same_as(builtin::likely())) { + return call->args[0]; + } + } + + return condition; +} + +} // namespace arith +} // namespace tvm diff --git a/src/arith/ir_visitor_with_analyzer.h b/src/arith/ir_visitor_with_analyzer.h index 058abc8c7d20..f41a628f3cc6 100644 --- a/src/arith/ir_visitor_with_analyzer.h +++ b/src/arith/ir_visitor_with_analyzer.h @@ -30,42 +30,37 @@ #include namespace tvm { -namespace tir { +namespace arith { -class IRVisitorWithAnalyzer final : public StmtExprVisitor { +class IRVisitorWithAnalyzer : public tir::StmtExprVisitor { public: PrimExpr Simplify(const PrimExpr& expr) { return analyzer_.Simplify(expr); } - void VisitStmt_(const ForNode* op) { - analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); - return StmtExprVisitor::VisitStmt_(op); - } + using StmtExprVisitor::VisitExpr_; + using StmtExprVisitor::VisitStmt_; - void VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { - IterVar iv = Downcast(op->node); - ICHECK_NE(iv->thread_tag.length(), 0U); - analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value)); - StmtExprVisitor::VisitStmt_(op); - } else { - StmtExprVisitor::VisitStmt_(op); - } - } + void VisitStmt_(const tir::ForNode* op); + void VisitStmt_(const tir::BlockNode* op); + void VisitStmt_(const tir::LetStmtNode* op); + void VisitStmt_(const tir::IfThenElseNode* op); + void VisitStmt_(const tir::AttrStmtNode* op); + void VisitStmt_(const tir::AssertStmtNode* op); + void VisitExpr_(const tir::CallNode* op); + void VisitExpr_(const tir::LetNode* op); + void VisitExpr_(const tir::ReduceNode* op); - void VisitExpr_(const ReduceNode* op) { - // Setup the domain information before simplification. - for (const IterVar& iv : op->axis) { - analyzer_.Bind(iv->var, iv->dom); - } - // Recursively call simplification when necessary. - StmtExprVisitor::VisitExpr_(op); - } + // IRVisitorWithAnalyzer deliberately does not handle Select nodes, + // because both sides of a Select node are visited regardless of the + // condition. protected: /*! \brief internal analyzer field. */ arith::Analyzer analyzer_; + + private: + PrimExpr ExtractRealCondition(PrimExpr condition) const; }; -} // namespace tir +} // namespace arith } // namespace tvm #endif // TVM_ARITH_IR_VISITOR_WITH_ANALYZER_H_ diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index f2d9aba4fba8..dd236537e9c2 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -47,6 +47,7 @@ namespace tvm { namespace tir { +using arith::IRVisitorWithAnalyzer; using runtime::StorageRank; using runtime::StorageScope; using runtime::ThreadScope; diff --git a/src/tir/transforms/texture_flatten.cc b/src/tir/transforms/texture_flatten.cc index a607e5914b39..3c35b73bc8d7 100644 --- a/src/tir/transforms/texture_flatten.cc +++ b/src/tir/transforms/texture_flatten.cc @@ -38,6 +38,7 @@ namespace tvm { namespace tir { +using arith::IRVisitorWithAnalyzer; using runtime::ApplyTexture2DFlattening; using runtime::DefaultTextureLayoutSeparator; using runtime::IsTextureStorage;