Skip to content

Commit cbcd41d

Browse files
sgrechanik-hwweic
authored andcommitted
[ARITH] Simplify casts of constants 0 and 1 (apache#3758)
* [ARITH] Simplify casts of constants 0 and 1 * [EXPR] is_const_value to check whether non-ints are consts * Revert "[EXPR] is_const_value to check whether non-ints are consts" This reverts commit 7e1b346. * Use tvm::cast
1 parent 189b59f commit cbcd41d

File tree

4 files changed

+26
-0
lines changed

4 files changed

+26
-0
lines changed

src/arithmetic/rewrite_simplify.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1757,6 +1757,13 @@ Mutate_(const Variable* op, const Expr& self) {
17571757
return self;
17581758
}
17591759

1760+
Expr RewriteSimplifier::Impl::
1761+
Mutate_(const Cast* op, const Expr& self) {
1762+
Expr ret = IRMutator::Mutate_(op, self);
1763+
op = ret.as<Cast>();
1764+
return cast(op->type, op->value);
1765+
}
1766+
17601767
Expr RewriteSimplifier::operator()(const Expr& expr) {
17611768
// Run simplification in post order
17621769
Expr res = expr;

src/arithmetic/rewrite_simplify.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class RewriteSimplifier::Impl : public IRMutator {
7070
Expr Mutate_(const Call* op, const Expr& self) override;
7171
Expr Mutate_(const Let* op, const Expr& self) override;
7272
Expr Mutate_(const Variable* op, const Expr& self) override;
73+
Expr Mutate_(const Cast* op, const Expr& self) override;
7374

7475
protected:
7576
/*! \brief internal structure for comparison. */

src/lang/expr_operator.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,15 @@ bool is_const_power_of_two_integer(const Expr& x, int* shift) {
105105

106106
Expr cast(const Type& t, Expr value) {
107107
using ir::IntImm;
108+
using ir::UIntImm;
108109
using ir::FloatImm;
109110
if (value.type() == t) return value;
110111
// const fold IntImm as they are used in index computations
111112
if (t.lanes() == 1) {
112113
if (const IntImm* op = value.as<IntImm>()) {
113114
return make_const(t, op->value);
115+
} else if (const UIntImm* op = value.as<UIntImm>()) {
116+
return make_const(t, op->value);
114117
} else if (const FloatImm* op = value.as<FloatImm>()) {
115118
return make_const(t, op->value);
116119
}
@@ -122,6 +125,8 @@ Expr cast(const Type& t, Expr value) {
122125
if (value.type() != vtype) {
123126
if (const IntImm* op = value.as<IntImm>()) {
124127
value = make_const(vtype, op->value);
128+
} else if (const UIntImm* op = value.as<UIntImm>()) {
129+
return make_const(t, op->value);
125130
} else if (const FloatImm* op = value.as<FloatImm>()) {
126131
value = make_const(vtype, op->value);
127132
} else {

tests/python/unittest/test_arith_rewrite_simplify.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,18 @@ def test_let_simplify():
804804
z = tvm.expr.Let(x, 1, x + 1)
805805
ck.verify(z + z, 4)
806806

807+
def test_cast_simplify():
808+
ck = RewriteChecker()
809+
x = tvm.var("x")
810+
811+
dtypes = ["float32", "float16", "int32", "int8", "bool"]
812+
for dtype1 in dtypes:
813+
ck.verify(tvm.expr.Cast(dtype1, x - x), tvm.const(0, dtype1))
814+
ck.verify(tvm.expr.Cast(dtype1, x == x), tvm.const(1, dtype1))
815+
for dtype2 in dtypes:
816+
for i in [0, 1, 2, 3]:
817+
ck.verify(tvm.expr.Cast(dtype1, tvm.const(i, dtype2)), tvm.const(i, dtype1))
818+
807819
if __name__ == "__main__":
808820
test_floordiv_index_simplify()
809821
test_floormod_index_simplify()
@@ -819,3 +831,4 @@ def test_let_simplify():
819831
test_select_simplify()
820832
test_logical_simplify()
821833
test_let_simplify()
834+
test_cast_simplify()

0 commit comments

Comments
 (0)