@@ -335,6 +335,9 @@ def inject_dma_intrin(stmt_in):
335335 Transformed statement
336336 """
337337 env = get_env ()
338+ idxd = tvm .indexdiv
339+ idxm = tvm .indexmod
340+
338341 def _check_compact (buf ):
339342 ndim = len (buf .shape )
340343 size = tvm .const (1 , buf .shape [0 ].dtype )
@@ -369,7 +372,7 @@ def _fold_buffer_dim(buf, scope, elem_block):
369372 x_size = 1
370373 x_stride = buf .strides [ndim - base ]
371374 next_base = base
372- if not util .equal_const_int (x_stride % elem_block , 0 ):
375+ if not util .equal_const_int (idxm ( x_stride , elem_block ) , 0 ):
373376 raise RuntimeError (
374377 "scope %s need to have block=%d, shape=%s, strides=%s" % (
375378 scope , elem_block , buf .shape , buf .strides ))
@@ -394,7 +397,7 @@ def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold):
394397 raise RuntimeError ("Expect buffer type to be %s instead of %s" %
395398 (dtype , buf .dtype ))
396399 shape , strides = buf .shape , buf .strides
397- if not util .equal_const_int (buf .elem_offset % elem_block , 0 ):
400+ if not util .equal_const_int (idxm ( buf .elem_offset , elem_block ) , 0 ):
398401 raise RuntimeError ("scope %s need to have block=%d" % (scope , elem_block ))
399402 if allow_fold :
400403 shape , strides = _fold_buffer_dim (buf , scope , elem_block )
@@ -421,23 +424,23 @@ def raise_error():
421424 x_size = 1
422425 x_stride = 1
423426 y_size = 1
424- return x_size , y_size , x_stride , buf .elem_offset / elem_block
427+ return x_size , y_size , x_stride , idxd ( buf .elem_offset , elem_block )
425428 if not util .equal_const_int (strides [- 2 ] - elem_block , 0 ):
426429 raise_error ()
427430
428431 if ndim == 2 :
429432 x_size = shape [- 2 ]
430433 x_stride = shape [- 2 ]
431434 y_size = 1
432- return x_size , y_size , x_stride , buf .elem_offset / elem_block
433- if not util .equal_const_int (strides [- 3 ] % elem_block , 0 ):
435+ return x_size , y_size , x_stride , idxd ( buf .elem_offset , elem_block )
436+ if not util .equal_const_int (idxm ( strides [- 3 ], elem_block ) , 0 ):
434437 raise_error ()
435438
436439 if ndim == 3 :
437440 x_size = shape [- 2 ]
438- x_stride = strides [- 3 ] / elem_block
441+ x_stride = idxd ( strides [- 3 ], elem_block )
439442 y_size = shape [- 3 ]
440- return x_size , y_size , x_stride , buf .elem_offset / elem_block
443+ return x_size , y_size , x_stride , idxd ( buf .elem_offset , elem_block )
441444
442445 else :
443446 if not util .equal_const_int (strides [- 1 ], 1 ):
@@ -451,23 +454,23 @@ def raise_error():
451454 x_size = 1
452455 x_stride = 1
453456 y_size = 1
454- return x_size , y_size , x_stride , buf .elem_offset / elem_block
457+ return x_size , y_size , x_stride , idxd ( buf .elem_offset , elem_block )
455458 if not util .equal_const_int (strides [- 3 ], elem_block ):
456459 raise_error ()
457460
458461 if ndim == 3 :
459462 x_size = shape [- 3 ]
460463 x_stride = shape [- 3 ]
461464 y_size = 1
462- return x_size , y_size , x_stride , buf .elem_offset / elem_block
463- if not util .equal_const_int (strides [- 4 ] % elem_block , 0 ):
465+ return x_size , y_size , x_stride , idxd ( buf .elem_offset , elem_block )
466+ if not util .equal_const_int (idxm ( strides [- 4 ], elem_block ) , 0 ):
464467 raise_error ()
465468
466469 if ndim == 4 :
467470 x_size = shape [- 3 ]
468- x_stride = strides [- 4 ] / elem_block
471+ x_stride = idxd ( strides [- 4 ], elem_block )
469472 y_size = shape [- 4 ]
470- return x_size , y_size , x_stride , buf .elem_offset / elem_block
473+ return x_size , y_size , x_stride , idxd ( buf .elem_offset , elem_block )
471474
472475 raise_error ()
473476
@@ -765,6 +768,8 @@ def inject_alu_intrin(stmt_in):
765768 Transformed statement
766769 """
767770 env = get_env ()
771+ idxm = tvm .indexmod
772+
768773 def _do_fold (stmt ):
769774 def _equal (x , y ):
770775 return tvm .ir_pass .Equal (tvm .ir_pass .Simplify (x - y ), 0 )
@@ -910,10 +915,10 @@ def _flatten_loop(src_coeff, dst_coeff, extents):
910915 assert len (extents ) != 0
911916 assert tvm .ir_pass .Equal (
912917 tvm .ir_pass .Simplify (
913- src_coeff [- 1 ] % ( env .BATCH * env .BLOCK_OUT )), 0 )
918+ idxm ( src_coeff [- 1 ], env .BATCH * env .BLOCK_OUT )), 0 )
914919 assert tvm .ir_pass .Equal (
915920 tvm .ir_pass .Simplify (
916- dst_coeff [- 1 ] % ( env .BATCH * env .BLOCK_OUT )), 0 )
921+ idxm ( dst_coeff [- 1 ], env .BATCH * env .BLOCK_OUT )), 0 )
917922 assert tvm .ir_pass .Equal (src_coeff [- 2 ], 1 )
918923 assert tvm .ir_pass .Equal (dst_coeff [- 2 ], 1 )
919924 if env .BATCH > 1 :
0 commit comments