Skip to content

Commit 89be933

Browse files
authored
[ARITH] cleanup the indexmod/div on python side (apache#4028)
1 parent 9b58279 commit 89be933

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

python/vta/ir_pass.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,9 @@ def inject_dma_intrin(stmt_in):
335335
Transformed statement
336336
"""
337337
env = get_env()
338+
idxd = tvm.indexdiv
339+
idxm = tvm.indexmod
340+
338341
def _check_compact(buf):
339342
ndim = len(buf.shape)
340343
size = tvm.const(1, buf.shape[0].dtype)
@@ -369,7 +372,7 @@ def _fold_buffer_dim(buf, scope, elem_block):
369372
x_size = 1
370373
x_stride = buf.strides[ndim - base]
371374
next_base = base
372-
if not util.equal_const_int(x_stride % elem_block, 0):
375+
if not util.equal_const_int(idxm(x_stride, elem_block), 0):
373376
raise RuntimeError(
374377
"scope %s need to have block=%d, shape=%s, strides=%s" % (
375378
scope, elem_block, buf.shape, buf.strides))
@@ -394,7 +397,7 @@ def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold):
394397
raise RuntimeError("Expect buffer type to be %s instead of %s" %
395398
(dtype, buf.dtype))
396399
shape, strides = buf.shape, buf.strides
397-
if not util.equal_const_int(buf.elem_offset % elem_block, 0):
400+
if not util.equal_const_int(idxm(buf.elem_offset, elem_block), 0):
398401
raise RuntimeError("scope %s need to have block=%d" % (scope, elem_block))
399402
if allow_fold:
400403
shape, strides = _fold_buffer_dim(buf, scope, elem_block)
@@ -421,23 +424,23 @@ def raise_error():
421424
x_size = 1
422425
x_stride = 1
423426
y_size = 1
424-
return x_size, y_size, x_stride, buf.elem_offset / elem_block
427+
return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
425428
if not util.equal_const_int(strides[-2] - elem_block, 0):
426429
raise_error()
427430

428431
if ndim == 2:
429432
x_size = shape[-2]
430433
x_stride = shape[-2]
431434
y_size = 1
432-
return x_size, y_size, x_stride, buf.elem_offset / elem_block
433-
if not util.equal_const_int(strides[-3] % elem_block, 0):
435+
return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
436+
if not util.equal_const_int(idxm(strides[-3], elem_block), 0):
434437
raise_error()
435438

436439
if ndim == 3:
437440
x_size = shape[-2]
438-
x_stride = strides[-3] / elem_block
441+
x_stride = idxd(strides[-3], elem_block)
439442
y_size = shape[-3]
440-
return x_size, y_size, x_stride, buf.elem_offset / elem_block
443+
return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
441444

442445
else:
443446
if not util.equal_const_int(strides[-1], 1):
@@ -451,23 +454,23 @@ def raise_error():
451454
x_size = 1
452455
x_stride = 1
453456
y_size = 1
454-
return x_size, y_size, x_stride, buf.elem_offset / elem_block
457+
return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
455458
if not util.equal_const_int(strides[-3], elem_block):
456459
raise_error()
457460

458461
if ndim == 3:
459462
x_size = shape[-3]
460463
x_stride = shape[-3]
461464
y_size = 1
462-
return x_size, y_size, x_stride, buf.elem_offset / elem_block
463-
if not util.equal_const_int(strides[-4] % elem_block, 0):
465+
return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
466+
if not util.equal_const_int(idxm(strides[-4], elem_block), 0):
464467
raise_error()
465468

466469
if ndim == 4:
467470
x_size = shape[-3]
468-
x_stride = strides[-4] / elem_block
471+
x_stride = idxd(strides[-4], elem_block)
469472
y_size = shape[-4]
470-
return x_size, y_size, x_stride, buf.elem_offset / elem_block
473+
return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
471474

472475
raise_error()
473476

@@ -765,6 +768,8 @@ def inject_alu_intrin(stmt_in):
765768
Transformed statement
766769
"""
767770
env = get_env()
771+
idxm = tvm.indexmod
772+
768773
def _do_fold(stmt):
769774
def _equal(x, y):
770775
return tvm.ir_pass.Equal(tvm.ir_pass.Simplify(x - y), 0)
@@ -910,10 +915,10 @@ def _flatten_loop(src_coeff, dst_coeff, extents):
910915
assert len(extents) != 0
911916
assert tvm.ir_pass.Equal(
912917
tvm.ir_pass.Simplify(
913-
src_coeff[-1] % (env.BATCH * env.BLOCK_OUT)), 0)
918+
idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
914919
assert tvm.ir_pass.Equal(
915920
tvm.ir_pass.Simplify(
916-
dst_coeff[-1] % (env.BATCH * env.BLOCK_OUT)), 0)
921+
idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
917922
assert tvm.ir_pass.Equal(src_coeff[-2], 1)
918923
assert tvm.ir_pass.Equal(dst_coeff[-2], 1)
919924
if env.BATCH > 1:

0 commit comments

Comments
 (0)