Skip to content

Commit 700bac9

Browse files
committed
[ARITH] cleanup the indexmod/div on python side
1 parent 368a4ae commit 700bac9

File tree

19 files changed

+144
-89
lines changed

19 files changed

+144
-89
lines changed

python/tvm/autotvm/task/task.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -350,7 +350,9 @@ def _count_flop(exp):
350350
return _count_flop(exp.value)
351351
if isinstance(exp, expr.Var):
352352
return 0
353-
if isinstance(exp, (expr.Add, expr.Sub, expr.Mul, expr.Div, expr.Mod,
353+
if isinstance(exp, (expr.Add, expr.Sub, expr.Mul,
354+
expr.Div, expr.Mod,
355+
expr.FloorDiv, expr.FloorMod,
354356
expr.Max, expr.Min,
355357
expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE,
356358
expr.And, expr.Or, expr.Not)):

python/tvm/expr.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -72,23 +72,23 @@ def __rmul__(self, other):
7272
return _generic.multiply(other, self)
7373

7474
def __div__(self, other):
75-
# if _dtype_is_int(self) and _dtype_is_int(other):
76-
# raise div_ambiguity_error()
75+
if _dtype_is_int(self) and _dtype_is_int(other):
76+
raise div_ambiguity_error()
7777
return _generic.divide(self, other)
7878

7979
def __rdiv__(self, other):
80-
# if _dtype_is_int(self) and _dtype_is_int(other):
81-
# raise div_ambiguity_error()
80+
if _dtype_is_int(self) and _dtype_is_int(other):
81+
raise div_ambiguity_error()
8282
return _generic.divide(other, self)
8383

8484
def __truediv__(self, other):
85-
# if _dtype_is_int(self) and _dtype_is_int(other):
86-
# raise div_ambiguity_error()
85+
if _dtype_is_int(self) and _dtype_is_int(other):
86+
raise div_ambiguity_error()
8787
return _generic.divide(self, other)
8888

8989
def __rtruediv__(self, other):
90-
# if _dtype_is_int(self) and _dtype_is_int(other):
91-
# raise div_ambiguity_error()
90+
if _dtype_is_int(self) and _dtype_is_int(other):
91+
raise div_ambiguity_error()
9292
return _generic.divide(other, self)
9393

9494
def __floordiv__(self, other):
@@ -100,8 +100,8 @@ def __rfloordiv__(self, other):
100100
return _generic.divide(other, self)
101101

102102
def __mod__(self, other):
103-
# raise div_ambiguity_error()
104-
return _make._OpMod(self, other)
103+
raise div_ambiguity_error()
104+
# return _make._OpMod(self, other)
105105

106106
def __neg__(self):
107107
neg_one = _api_internal._const(-1, self.dtype)

src/pass/rewrite_unsafe_select.cc

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,9 +6,9 @@
66
* to you under the Apache License, Version 2.0 (the
77
* "License"); you may not use this file except in compliance
88
* with the License. You may obtain a copy of the License at
9-
*
9+
*
1010
* http://www.apache.org/licenses/LICENSE-2.0
11-
*
11+
*
1212
* Unless required by applicable law or agreed to in writing,
1313
* software distributed under the License is distributed on an
1414
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -64,6 +64,8 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
6464
bool VisitExpr_(const Mul* op) final { return BinaryOp(op); }
6565
bool VisitExpr_(const Div* op) final { return BinaryOp(op); }
6666
bool VisitExpr_(const Mod* op) final { return BinaryOp(op); }
67+
bool VisitExpr_(const FloorDiv* op) final { return BinaryOp(op); }
68+
bool VisitExpr_(const FloorMod* op) final { return BinaryOp(op); }
6769
bool VisitExpr_(const Min* op) final { return BinaryOp(op); }
6870
bool VisitExpr_(const Max* op) final { return BinaryOp(op); }
6971
bool VisitExpr_(const EQ* op) final { return BinaryOp(op); }

tests/python/relay/test_op_level3.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -373,6 +373,8 @@ def verify_split(dshape, indices_or_sections, ret_type, axis=None):
373373
yy = run_infer_type(y.astuple())
374374
assert yy.checked_type == ret_type
375375

376+
idxd = tvm.indexdiv
377+
376378
d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
377379
axis = tvm.var("axis")
378380
verify_split((5, 5, 2, 2), 5,
@@ -393,15 +395,15 @@ def verify_split(dshape, indices_or_sections, ret_type, axis=None):
393395
axis=0)
394396
verify_split((d1, d2, d3, d4), 4,
395397
relay.ty.TupleType(tvm.convert([
396-
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
397-
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
398-
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
399-
relay.ty.TensorType((d1, d2, d3/4, d4), "float32")])),
398+
relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
399+
relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
400+
relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
401+
relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32")])),
400402
axis=2)
401403
verify_split((d1, d2, d3, d4), 2,
402404
relay.ty.TupleType(tvm.convert([
403-
relay.ty.TensorType((d1/2, d2, d3, d4), "float32"),
404-
relay.ty.TensorType((d1/2, d2, d3, d4), "float32")])),
405+
relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32"),
406+
relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32")])),
405407
axis=0)
406408
verify_split((d1, d2, d3, d4), (2, 4, 7),
407409
relay.ty.TupleType(tvm.convert([

tests/python/relay/test_op_level5.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -487,8 +487,9 @@ def verify_yolo_reorg(shape, stride, out_shape):
487487
assert zz.checked_type == relay.ty.TensorType(out_shape, "float32")
488488

489489
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
490+
idxd = tvm.indexdiv
490491
verify_yolo_reorg((n, c, 20, 20), 10, (n, c*10*10, 2, 2))
491-
verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, h/2, w/2))
492+
verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, idxd(h, 2), idxd(w, 2)))
492493

493494
def test_yolo_reorg():
494495
def verify_yolo_reorg(shape, stride):

tests/python/unittest/test_autotvm_flop_calculator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,14 @@ def test_pack_gemm():
6060
k = tvm.reduce_axis((0, L))
6161

6262
bn = 4
63-
fld = tvm.floordiv
64-
flm = tvm.floormod
63+
idxd = tvm.indexdiv
64+
idxm = tvm.indexmod
6565

6666
A_pack = tvm.compute((N // bn, L, bn), lambda i, j, k: A[i * bn + k][j])
6767
B_pack = tvm.compute((M // bn, L, bn), lambda i, j, k: B[i * bn + k][j])
6868
C_pack = tvm.compute((N // bn, M // bn, bn, bn), lambda i, j, ii, jj:
6969
tvm.sum(A_pack[i, k, ii].astype(acc_dtype) * B_pack[j, k, jj].astype(acc_dtype), axis=[k]))
70-
C = tvm.compute((N, M), lambda i, j: C_pack[fld(i, bn)][fld(j, bn)][flm(i, bn)][flm(j, bn)])
70+
C = tvm.compute((N, M), lambda i, j: C_pack[idxd(i, bn)][idxd(j, bn)][idxm(i, bn)][idxm(j, bn)])
7171

7272
s = tvm.create_schedule([C.op])
7373
assert compute_flop(s) == 2 * N * L * M

tests/python/unittest/test_ir_builder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,15 @@ def test_gpu():
109109
dtype = "float32"
110110
A = tvm.placeholder((n,), name='A')
111111
B = tvm.placeholder((n,), name='B')
112-
fld = tvm.floordiv
112+
idxd = tvm.indexdiv
113+
113114
def test_device_ir(A, B, C):
114115
n = A.shape[0]
115116
max_threads = 32
116117
ib = tvm.ir_builder.create()
117118
bx = tvm.thread_axis("blockIdx.x")
118119
tx = tvm.thread_axis("threadIdx.x")
119-
ib.scope_attr(bx, "thread_extent", fld(n+max_threads-1, max_threads))
120+
ib.scope_attr(bx, "thread_extent", idxd(n+max_threads-1, max_threads))
120121
ib.scope_attr(tx, "thread_extent", max_threads)
121122
idx = bx.var * max_threads + tx.var
122123
Aptr = ib.buffer_ptr(A)

tests/python/unittest/test_lang_buffer.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -94,31 +94,31 @@ def test_buffer_index_merge_mult_mod():
9494
def assert_simplified_equal(index_simplified, index_direct):
9595
assert tvm.ir_pass.Equal(index_simplified, index_direct),\
9696
"index_simplified=%s, index_direct=%s" %(index_simplified, index_direct)
97-
idxdiv = tvm.indexdiv
98-
idxmod = tvm.indexmod
97+
idxd = tvm.indexdiv
98+
idxm = tvm.indexmod
9999
# Test Case1
100100
index_simplified = A_stride.vload(
101-
(idxdiv(idxmod(k0, k1), s), idxmod(idxmod(k0, k1), s) + idxdiv(k0, k1) * k1))
101+
(idxd(idxm(k0, k1), s), idxm(idxm(k0, k1), s) + idxd(k0, k1) * k1))
102102
index_direct = A_stride.vload((0, k0))
103103
assert_simplified_equal(index_simplified, index_direct)
104104

105105
# Test Case2
106-
index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n),
107-
idxmod(idxmod(k0, idxdiv(k1, s)), n) + idxmod(k0, k1)))
108-
index_direct = A.vload((0, idxmod(k0, k1) + idxmod(k0, idxdiv(k1, s))))
106+
index_simplified = A.vload((idxd(idxm(k0, idxd(k1, s)), n),
107+
idxm(idxm(k0, idxd(k1, s)), n) + idxm(k0, k1)))
108+
index_direct = A.vload((0, idxm(k0, k1) + idxm(k0, idxd(k1, s))))
109109
assert_simplified_equal(index_simplified, index_direct)
110110
# Test Case3
111-
index_simplified = A.vload((idxdiv((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) +
112-
idxdiv(idxmod(k0, idxdiv(k1, s)), n),
113-
idxmod((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) +
114-
idxmod(idxmod(k0, idxdiv(k1, s)), n)))
111+
index_simplified = A.vload((idxd((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) +
112+
idxd(idxm(k0, idxd(k1, s)), n),
113+
idxm((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) +
114+
idxm(idxm(k0, idxd(k1, s)), n)))
115115
index_direct = A.vload((0, k0))
116116
assert_simplified_equal(index_simplified, index_direct)
117117
# Test Case4 (not able to simplify)
118-
index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n),
119-
idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1)))
120-
index_direct = A.vload((0, idxdiv(idxmod(k0, idxdiv(k1, s)), n) * n +
121-
(idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1))))
118+
index_simplified = A.vload((idxd(idxm(k0, idxd(k1, s)), n),
119+
idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1)))
120+
index_direct = A.vload((0, idxd(idxm(k0, idxd(k1, s)), n) * n +
121+
(idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1))))
122122
assert_simplified_equal(index_simplified, index_direct)
123123

124124

tests/python/unittest/test_pass_rewrite_unsafe_select.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_rewrite_Select():
2828
tvm.expr.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1)
2929
zz = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(z)).value
3030

31-
a = tvm.expr.Select(i>10, y, z)
31+
a = tvm.expr.Select(tvm.floordiv(i, 4) > 10, y, z)
3232
aa = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(a)).value
3333
assert yy.name == "tvm_if_then_else"
3434
assert zz.name == "tvm_if_then_else"

tests/python/unittest/test_schedule_tensorize.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -221,14 +221,15 @@ def check_rfactor_no_reset_multi_reduction(factor, rfactor):
221221
# This tests whether algorithm and intrinsics expressions are simplified
222222
# as much as possible first and then checked for equality. See Issue #696
223223
def test_tensorize_op():
224-
tdiv = tvm.truncdiv
225-
tmod = tvm.truncmod
224+
idxd = tvm.indexdiv
225+
idxm = tvm.indexmod
226+
226227
def op_intrin():
227228
bh = 9
228229
bw = 9
229230
x = tvm.placeholder((5, 5), name='A')
230231
y = tvm.compute((bh, bw),
231-
lambda i, j: x[tdiv(j,3) + tmod(i,3), tmod(j,3)+ tdiv(i,3)])
232+
lambda i, j: x[idxd(j,3) + idxm(i,3), idxm(j,3)+ idxd(i,3)])
232233

233234
def intrin_func(ins, outs):
234235
xx, = ins
@@ -239,7 +240,7 @@ def intrin_func(ins, outs):
239240
return tvm.decl_tensor_intrin(y.op, intrin_func)
240241

241242
A = tvm.placeholder((5, 5), name='A')
242-
B = tvm.compute((9,9), lambda i, j: A[tdiv(j,3) + tmod(i,3), tmod(j,3) + tdiv(i,3)])
243+
B = tvm.compute((9,9), lambda i, j: A[idxd(j,3) + idxm(i,3), idxm(j,3) + idxd(i,3)])
243244
bt = op_intrin()
244245
s = tvm.create_schedule(B.op)
245246

0 commit comments

Comments
 (0)