Skip to content

Commit 57fdeb3

Browse files
committed
[ARITH] Enhance CanProve to handle symbolic bound
This PR enhances CanProve to handle symbolic bound. Such analysis is essential to eliminate predicates in dynamic shape workloads. We also the int set analysis singlepoint check to avoid recursion and improve the overall analysis speed. Added CanProveSinglePoint to serve previous stronger checks. The new CanProve comes with additinal strength argument that can only be used in top-level setting with stronger analysis. Added comment for future implementation efficiency. Testcases are added to cover the cases.
1 parent af39b34 commit 57fdeb3

File tree

13 files changed

+178
-19
lines changed

13 files changed

+178
-19
lines changed

include/tvm/arith/analyzer.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,22 @@ enum DivMode {
5959
kFloorDiv
6060
};
6161

62+
/*!
63+
* \brief The strength used in top-level condition proves
64+
* \note The higher, the more time consuming it can be.
65+
*
66+
* Do not use level beyond kDefault in internal recursive rewriting in arith
67+
* analysis and only use it at top-level simplification to avoid speed issues.
68+
*/
69+
enum class ProofStrength : int {
70+
/*! \brief default strength, can be used in. */
71+
kDefault = 0,
72+
/*!
73+
* \brief Prove using symbolic bound analysis
74+
*/
75+
kSymbolicBound = 1
76+
};
77+
6278
/*!
6379
* \brief Constant integer up and lower bound(inclusive).
6480
* Useful for value bound analysis.
@@ -656,11 +672,16 @@ class TVM_DLL Analyzer {
656672
* \brief Whether can we prove condition.
657673
*
658674
* \param cond The expression to be proved.
675+
* \param strength the strength of the prove.
676+
*
659677
* \return The result.
660678
*
661679
* \note Analyzer will call into sub-analyzers to get the result.
680+
* Do not use strength beyond default in sub-analyzers and
681+
* only use it in top-level predicate analysis.
662682
*/
663-
bool CanProve(const PrimExpr& cond);
683+
bool CanProve(const PrimExpr& cond, ProofStrength strength = ProofStrength::kDefault);
684+
664685
/*!
665686
* \brief Simplify expr.
666687
*

include/tvm/arith/int_set.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,22 @@ class IntSet : public ObjectRef {
8585
bool IsEverything() const;
8686
/*! \return Whether the set is a single point */
8787
bool IsSinglePoint() const;
88+
/*!
89+
* \brief Check if we can prove it is a single point.
90+
*
91+
* Unlike IsSinglePoint, which only checks ptr equality
92+
* this function will invoke analyzer to do stonger proofs
93+
* but also takes longer time.
94+
*
95+
* Use this function in some of the primitives but do not
96+
* use it in the inner loop of simplification.
97+
*
98+
* \param ana Analyzer used in the proof.
99+
* \return Whether we can prove it is a single point
100+
*/
101+
bool CanProveSinglePoint(Analyzer* ana) const;
102+
// TODO(tvm-team): update all CanProve to explicitly take
103+
// analyzer to encourage more analyzer reuse
88104
/*! \return Whether the set is proved to be bigger than 0 */
89105
bool CanProvePositive() const;
90106
/*! \return Whether the set is proved to be smaller than 0 */

python/tvm/arith/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
estimate_region_strict_bound,
2424
estimate_region_upper_bound,
2525
)
26-
from .analyzer import ModularSet, ConstIntBound, Analyzer
26+
from .analyzer import ModularSet, ConstIntBound, Analyzer, ProofStrength
2727
from .bound import deduce_bound
2828
from .pattern import detect_linear_equation, detect_clip_bound, detect_common_subexpr
2929
from .int_solver import solve_linear_equations, solve_linear_inequalities

python/tvm/arith/analyzer.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,19 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Arithmetic data structure and utility"""
18+
from enum import IntEnum
1819
import tvm._ffi
1920
from tvm.runtime import Object
2021
from . import _ffi_api
2122

2223

24+
class ProofStrength(IntEnum):
25+
"""Proof strength of the analysis"""
26+
27+
DEFAULT = 0
28+
SYMBOLIC_BOUND = 1
29+
30+
2331
@tvm._ffi.register_object("arith.ModularSet")
2432
class ModularSet(Object):
2533
"""Represent range of (coeff * x + base) for x in Z"""
@@ -91,6 +99,7 @@ def __init__(self):
9199
self._int_set = _mod("int_set")
92100
self._enter_constraint_context = _mod("enter_constraint_context")
93101
self._can_prove_equal = _mod("can_prove_equal")
102+
self._can_prove = _mod("can_prove")
94103

95104
def const_int_bound(self, expr):
96105
"""Find constant integer bound for expr.
@@ -190,6 +199,24 @@ def int_set(self, expr, dom_map):
190199
"""
191200
return self._int_set(expr, dom_map)
192201

202+
def can_prove(self, expr, strength=ProofStrength.DEFAULT):
203+
"""Check whether we can prove expr to be true.
204+
205+
Parameters
206+
----------
207+
expr : PrimExpr
208+
The expression.
209+
210+
strength: ProofStrength
211+
The proof strength
212+
213+
Returns
214+
-------
215+
result : Expr
216+
The result.
217+
"""
218+
return self._can_prove(expr, strength)
219+
193220
def bind(self, var, expr):
194221
"""Bind a variable to the expression.
195222

src/arith/analyzer.cc

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,15 +115,47 @@ bool Analyzer::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) {
115115
return CanProve(lhs - rhs == 0);
116116
}
117117

118-
bool Analyzer::CanProve(const PrimExpr& expr) {
118+
bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
119119
// Avoid potentially expensive simplification unless required.
120120
if (const auto* ptr = expr.as<IntImmNode>()) {
121121
return ptr->value != 0;
122122
}
123-
124123
PrimExpr simplified = Simplify(expr);
125124
const int64_t* as_int = tir::as_const_int(simplified);
126-
return as_int && *as_int;
125+
if (as_int && *as_int) return true;
126+
if (strength >= ProofStrength::kSymbolicBound) {
127+
// NOTE: we intentionally only pattern match common bound predicate i < bound
128+
// and put this implementation at the top-level.
129+
// This is to avoid repeatitive calling of this function
130+
// that causes speed issues.
131+
// This strategy can only be called from top-level and not from sub-analyzers.
132+
Optional<PrimExpr> pos_diff;
133+
int lower_bound = 0;
134+
if (const auto* ptr_lt = expr.as<tir::LTNode>()) {
135+
pos_diff = ptr_lt->b - ptr_lt->a;
136+
lower_bound = 1;
137+
}
138+
if (const auto* ptr_le = expr.as<tir::LENode>()) {
139+
pos_diff = ptr_le->b - ptr_le->a;
140+
lower_bound = 0;
141+
}
142+
if (const auto* ptr_gt = expr.as<tir::GTNode>()) {
143+
pos_diff = ptr_gt->a - ptr_gt->b;
144+
lower_bound = 1;
145+
}
146+
if (const auto* ptr_ge = expr.as<tir::GENode>()) {
147+
pos_diff = ptr_ge->a - ptr_ge->b;
148+
lower_bound = 0;
149+
}
150+
if (pos_diff) {
151+
IntSet iset = this->int_set(this->Simplify(pos_diff.value()));
152+
if (iset.HasLowerBound()) {
153+
ConstIntBound relaxed_lower_bound = this->const_int_bound(this->Simplify(iset.min()));
154+
if (relaxed_lower_bound->min_value >= lower_bound) return true;
155+
}
156+
}
157+
}
158+
return false;
127159
}
128160

129161
PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) {
@@ -189,6 +221,11 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu
189221
self->Bind(args[0], args[1].operator PrimExpr());
190222
}
191223
});
224+
} else if (name == "can_prove") {
225+
return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
226+
int strength = args[1];
227+
*ret = self->CanProve(args[0], static_cast<ProofStrength>(strength));
228+
});
192229
} else if (name == "enter_constraint_context") {
193230
return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
194231
// can't use make_shared due to noexcept(false) decl in destructor,

src/arith/int_set.cc

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,11 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
492492

493493
IntervalSet VisitExpr_(const CastNode* op) final {
494494
IntervalSet value_set = this->Eval(op->value);
495+
// short cut for the int set.
496+
if (value_set->min_value.same_as(value_set->max_value)) {
497+
if (value_set->IsEmpty()) return value_set;
498+
return IntervalSet::SinglePoint(cast(op->dtype, value_set->min_value));
499+
}
495500
PrimExpr min_value =
496501
value_set->HasLowerBound() ? cast(op->dtype, value_set->min_value) : neg_inf();
497502
PrimExpr max_value =
@@ -723,6 +728,13 @@ bool IntSet::IsSinglePoint() const {
723728
return (s_int && s_int->IsSinglePoint());
724729
}
725730

731+
bool IntSet::CanProveSinglePoint(Analyzer* ana) const {
732+
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
733+
if (!s_int) return false;
734+
if (s_int->IsSinglePoint()) return true;
735+
return ana->CanProveEqual(s_int->min_value, s_int->max_value);
736+
}
737+
726738
bool IntSet::CanProvePositive() const {
727739
Analyzer analyzer;
728740
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
@@ -943,9 +955,15 @@ IntSet EvalSet(PrimExpr e, const Map<Var, IntSet>& dom_map) {
943955
}
944956

945957
IntSet IntSet::Vector(PrimExpr x) {
946-
Analyzer ana;
947-
Map<Var, IntSet> dmap;
948-
return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x);
958+
// short cut: simply get single point
959+
if (x.dtype().lanes() == 1) {
960+
return IntSet::SinglePoint(x);
961+
} else {
962+
// vector case.
963+
Analyzer ana;
964+
Map<Var, IntSet> dmap;
965+
return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x);
966+
}
949967
}
950968

951969
IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map) {

src/arith/interval_set.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,11 @@ class IntervalSetNode : public IntSetNode {
6060
bool HasLowerBound() const { return !is_neg_inf(min_value) && !IsEmpty(); }
6161
/*! \return Whether the interval is a single point. */
6262
bool IsSinglePoint() const {
63-
if (min_value.same_as(max_value)) {
64-
return true;
65-
}
66-
Analyzer analyzer;
67-
return analyzer.CanProveEqual(min_value, max_value);
63+
// NOTE: we are only doing cheap check as this is a frequently called routine,
64+
// do manual prove of min and max for stronger single point check.
65+
return min_value.same_as(max_value);
6866
}
67+
6968
/*! \return whether interval represent nothing */
7069
bool IsEmpty() const {
7170
// during computations, either extreme could occur.

src/arith/rewrite_simplify.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,12 @@ CompareResult RewriteSimplifier::Impl::TryCompareUsingKnownInequalities(const Pr
150150

151151
// try to prove x equals val
152152
CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val) {
153+
// NOTE on implementation: this function can be called many times and can be a bottleneck,
154+
// As a result, we keep comparison here lightweight.
155+
// We only do constant int bound analysis here.
156+
//
157+
// For stronger comparison proof that is out of the recursive simplifcation
158+
// consider look at analyzer::CanProveStrong
153159
PrimExpr diff = this->VisitExpr(x);
154160
if (const auto* ptr = diff.as<IntImmNode>()) {
155161
if (ptr->value == val) {
@@ -176,6 +182,8 @@ CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val
176182
if (dbound->max_value <= val) {
177183
return CompareResult::kLE;
178184
}
185+
186+
// modular analysis
179187
if (val == 0) {
180188
ModularSet dmod = analyzer_->modular_set(diff);
181189
if (dmod->base != 0) {

src/arith/rewrite_simplify.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
104104

105105
// maximum number of recursion allowed during a single pass.
106106
static const constexpr int kMaxRecurDepth = 5;
107-
108107
/*!
109108
* \brief try to compare x against val.
110109
* \param x The expression to be evaluated.

src/tir/analysis/block_access_region_detector.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ class BlockReadWriteDetector : public StmtExprVisitor {
7676
Map<Var, Buffer> buffer_var_map_;
7777
/*! \brief The target buffer var mapping to its matching */
7878
std::unordered_map<const VarNode*, MatchBufferRegion> match_buffers_;
79+
/*!\ brief Internal analyzer. */
80+
arith::Analyzer ana_;
7981

8082
/*!
8183
* \brief Update read/write buffers and regions with provided buffer and region
@@ -318,7 +320,7 @@ Array<BufferRegion> BlockReadWriteDetector::CollectRegions(
318320
ICHECK_EQ(buffers[i]->shape.size(), regions[i].size());
319321
for (size_t j = 0; j < regions[i].size(); j++) {
320322
const tvm::arith::IntSet& range = regions[i][j];
321-
if (range.IsSinglePoint()) {
323+
if (range.CanProveSinglePoint(&ana_)) {
322324
PrimExpr min = range.min();
323325
region.push_back(Range::FromMinExtent(min, make_const(min.dtype(), 1)));
324326
} else {

0 commit comments

Comments
 (0)