Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 55 additions & 17 deletions src/arithmetic/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ std::vector<const Node*> GetPath(Expr target, Expr expr) {
return v.path_;
}

enum CompareOp {kGreater, kLess, kEqual};

// a visitor to deduce the bound of a variable from a expression
class BoundDeducer: public IRVisitor {
public:
Expand Down Expand Up @@ -120,7 +122,7 @@ class BoundDeducer: public IRVisitor {
} else {
result_ -= op->a;
result_ = - result_;
is_greater_ = !is_greater_;
comp_op = ReverseOp(comp_op);
}
Visit(left ? op->a : op->b);
}
Expand All @@ -138,7 +140,7 @@ class BoundDeducer: public IRVisitor {
}

if (sign_operand == SignType::kNegative) {
is_greater_ = !is_greater_;
comp_op = ReverseOp(comp_op);
} else if (sign_operand == SignType::kUnknown) {
// unable to get the sign of operand
success_ = false;
Expand All @@ -151,11 +153,15 @@ class BoundDeducer: public IRVisitor {

if (!divided) {
// Handle non-divisible case
// NOTE: this accounts for truc div behavior.
// NOTE: this accounts for trunc div behavior.
bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative();

if (is_greater_) {
if (comp_op == kGreater) {
result_ += 1;
} else if (comp_op == kEqual) {
// condition unsatisfiable as with trunc div, it will change the expression
success_ = false;
return;
} else {
// NOTE: this is a bit sutble hack.
//
Expand Down Expand Up @@ -185,14 +191,14 @@ class BoundDeducer: public IRVisitor {
}

Expr result_;
bool is_greater_{true};
CompareOp comp_op{kGreater};
bool success_{true};

private:
void Init();
void Transform();
void Relax();

CompareOp ReverseOp(CompareOp comp_op);
Expr target_;
Expr expr_;
const std::unordered_map<const Variable*, IntSet>& hint_map_;
Expand Down Expand Up @@ -228,51 +234,72 @@ void BoundDeducer::Init() {
Transform();
}

CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) {
switch (comp_op) {
case kEqual: return kEqual; // IntSet can not represent range for `NE
case kGreater: return kLess;
case kLess: return kGreater;
default:
LOG(FATAL) << "Not a valid compare op";
return kGreater; // return some default value
}
}

void BoundDeducer::Transform() {
// We will ensure to set expr_ such that it contains target_
if (const LT* op = expr_.as<LT>()) {
if (GetPath(target_, op->a).empty()) {
// a < b -> b >= a + 1
is_greater_ = true;
comp_op = kGreater;
expr_ = op->b;
result_ = op->a + 1;
} else {
// a < b -> a <= b - 1
is_greater_ = false;
comp_op = kLess;
expr_ = op->a;
result_ = op->b - 1;
}
} else if (const LE* op = expr_.as<LE>()) {
if (GetPath(target_, op->a).empty()) {
// a <= b -> b >= a
is_greater_ = true;
comp_op = kGreater;
expr_ = op->b;
result_ = op->a;
} else {
is_greater_ = false;
comp_op = kLess;
expr_ = op->a;
result_ = op->b;
}
} else if (const GT* op = expr_.as<GT>()) {
if (GetPath(target_, op->a).empty()) {
// a > b -> b <= a - 1
is_greater_ = false;
comp_op = kLess;
expr_ = op->b;
result_ = op->a - 1;
} else {
// a > b -> a >= b + 1
is_greater_ = true;
comp_op = kGreater;
expr_ = op->a;
result_ = op->b + 1;
}
} else if (const GE* op = expr_.as<GE>()) {
if (GetPath(target_, op->a).empty()) {
// a >= b -> b <= a
is_greater_ = false;
comp_op = kLess;
expr_ = op->b;
result_ = op->a;
} else {
comp_op = kGreater;
expr_ = op->a;
result_ = op->b;
}
} else if (const EQ* op = expr_.as<EQ>()) {
comp_op = kEqual;
if (GetPath(target_, op->a).empty()) {
// if the b == a -> a == b
expr_ = op->b;
result_ = op->a;
} else {
is_greater_ = true;
expr_ = op->a;
result_ = op->b;
}
Expand Down Expand Up @@ -304,8 +331,16 @@ void BoundDeducer::Relax() {
success_ = false;
return;
}
expr_ = is_greater_ ? a.min() : a.max();
result_ = is_greater_ ? b.max() : b.min();
// Both LHS and RHS of the EQ should behave as constants e.g. i == j,
// can not be resolved when either `i` or `j` or both are variables with
// some Range OR `i` and `j` both should be a single point in IntSet
if (comp_op == kEqual && (!analyzer_.CanProve(b.min() == b.max())
|| !analyzer_.CanProve(a.min() == a.max()))) {
success_ = false;
return;
}
expr_ = (comp_op == kGreater) ? a.min() : a.max();
result_ = (comp_op == kGreater) ? b.max() : b.min();
}

IntSet DeduceBound(Expr v, Expr e,
Expand All @@ -315,7 +350,10 @@ IntSet DeduceBound(Expr v, Expr e,
d.Deduce();
if (!d.success_) return IntSet::nothing();
Expr min = neg_inf(), max = pos_inf();
if (d.is_greater_) {
if (d.comp_op == kEqual) {
min = d.result_;
max = d.result_;
} else if (d.comp_op == kGreater) {
min = d.result_;
} else {
max = d.result_;
Expand Down
8 changes: 6 additions & 2 deletions src/pass/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,6 @@ class PartitionFinder : public IRVisitor {

private:
Expr InverseCond(const Expr& cond) {
// We expect most condition not to be of EQ or NE form.
// Currently we do not handle inversing EQ or NE.
Expr inverse_cond;
if (const LT* op = cond.as<LT>()) {
// a < b -> a >= b
Expand All @@ -241,6 +239,12 @@ class PartitionFinder : public IRVisitor {
} else if (const GE* op = cond.as<GE>()) {
// a >= b -> a < b
inverse_cond = LT::make(op->a, op->b);
} else if (const EQ* op = cond.as<EQ>()) {
// a == b -> a != b
inverse_cond = NE::make(op->a, op->b);
// a != b -> a == b
} else if (const NE* op = cond.as<NE>()) {
inverse_cond = EQ::make(op->a, op->b);
}
return inverse_cond;
}
Expand Down
39 changes: 39 additions & 0 deletions tests/python/unittest/test_arith_deduce_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,44 @@ def test_deduce():
res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)

# tests for `EQ` op
res4 = tvm.arith.DeduceBound(a, a == b, {}, {})
assert_expr_equal(res4.max_value, b)
assert_expr_equal(res4.min_value, b)

# Unsatisfiable `EQ`, variable as one of the Operand
res5 = tvm.arith.DeduceBound(a, (a == b), {b: b_s}, {b: b_s})
assert str(res5.max_value) == "neg_inf"
assert str(res5.min_value) == "pos_inf"

# variable `a` on the RHS side
res6 = tvm.arith.DeduceBound(a, 10 == a, {}, {})
assert_expr_equal(res6.max_value, 10)
assert_expr_equal(res6.min_value, 10)

# Add, Sub in `EQ`
e4 = ((a - c) == (b + d))
ans4 = (b + d + c)
res7 = tvm.arith.DeduceBound(a, e4, {b: b_s, c: c_s, d: d_s}, {})
assert_expr_equal(res7.max_value, ans4)
assert_expr_equal(res7.min_value, ans4)

# Satisfiable Mul in `EQ` with negative sign
res8 = tvm.arith.DeduceBound(a, (5 * a == -10), {}, {})
assert_expr_equal(res8.max_value, -2)
assert_expr_equal(res8.min_value, -2)

# Unsatisfiable Mul in `EQ`
e5 = (4 * a == b)
res9 = tvm.arith.DeduceBound(a, e5, {b: b_s}, {})
assert str(res9.max_value) == "neg_inf"
assert str(res9.min_value) == "pos_inf"

# Unsatisfiable Mul in `EQ`
res10 = tvm.arith.DeduceBound(a, (b * a == b), {b: b_s}, {}) # simplifier is not able to prove that (b % b == 0)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tqchen @xqdan I noticed that TVM wasn't able to prove that b % b == 0 where b is a variable. Is it something expected ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this can be improved @tqchen

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be improved by adding a rewrite pattern

assert str(res10.max_value) == "neg_inf"
assert str(res10.min_value) == "pos_inf"


def test_check():
a = tvm.var('a')
Expand Down Expand Up @@ -175,5 +213,6 @@ def test_complex(a1, a2, coff):

if __name__ == "__main__":
test_check()
test_deduce()
test_deduce_basic()
test_deduce_complex()
13 changes: 13 additions & 0 deletions tests/python/unittest/test_pass_loop_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,18 @@ def test_condition():
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select))))

def test_condition_EQ():
ib = tvm.ir_builder.create()
m = tvm.var('m')
n = tvm.var('n')
with ib.for_range(0, 10, 'i') as i:
ib.emit(tvm.make.Evaluate(
tvm.make.Select(ib.likely(tvm.expr.EQ(i, 5)), m, n)))
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt, True)
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select))))

def test_thread_axis2():
n = tvm.convert(4096)
m = tvm.var('m')
Expand Down Expand Up @@ -420,6 +432,7 @@ def test_simple_rfactor():
test_thread_axis()
test_vectorize()
test_condition()
test_condition_EQ()
test_thread_axis2()
test_everything_during_deduction()
test_single_likely()
Expand Down