Problem with equality operator of expr.Mod? #540

@wweic

I'm looking at the unit test for the IR builder, and it looks like the == operator does not work as expected for Mod:

    ib = tvm.ir_builder.create()
    n = tvm.var("n")
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, n, name="i") as i:
        with ib.if_scope((i % 2) == 0):
            A[i] = A[i] + 1
        with ib.else_scope():
            A[0] = A[i] + 2

The generated IR is:

~ ipython
Python 2.7.10 (default, Jul 14 2015, 19:46:27)
Type "copyright", "credits" or "license" for more information.

IPython 5.1.0 -- An enhanced Interactive Python.
?         -> Introduction and overview of IPython's features.
%quickref -> Quick reference.
help      -> Python's own help system.
object?   -> Details about 'object', use 'object??' for extra details.

In [1]: import tvm
   ...:

In [2]:     ib = tvm.ir_builder.create()
   ...:     n = tvm.var("n")
   ...:     A = ib.pointer("float32", name="A")
   ...:     with ib.for_range(0, n, name="i") as i:
   ...:         with ib.if_scope((i % 2) == 0):
   ...:             A[i] = A[i] + 1
   ...:         with ib.else_scope():
   ...:             A[0] = A[i] + 2
   ...:
   ...:     body = ib.get()
   ...:

In [3]: body
Out[3]:
for (i, 0, n) {
  if (0) {
    A[i] = (A[i] + 1.000000f)
  } else {
    A[0] = (A[i] + 2.000000f)
  }
}

Note that the condition is the constant 0 rather than an equality-check expression. The __eq__ operator for Mod comes from tvm/_ffi/node.py:

In [4]: i = ib.for_range(0, n, name="i")

In [5]: a = i._enter_value % 1

In [6]: a.__eq__.im_func
Out[6]: <function tvm._ffi.node.__eq__>

I assume the one we actually want is in tvm/expr.py?
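For anyone who wants to repeat the lookup above outside of TVM, here is a minimal sketch of how to find which class in the inheritance chain supplies `__eq__`. The classes `NodeBase`, `ExprOp` and `Mod` below are hypothetical stand-ins, not TVM's real class hierarchy, and `defining_class` is a helper made up for this illustration. Note that `im_func` is Python 2 only; `__func__` is the spelling that also works on Python 3.

```python
import inspect


def defining_class(obj, name):
    """Return the first class in type(obj)'s MRO that defines `name` itself."""
    for cls in inspect.getmro(type(obj)):
        if name in vars(cls):
            return cls
    return None


class NodeBase(object):
    """Hypothetical generic node with identity-based equality."""

    def __eq__(self, other):
        return self is other

    # Keep instances hashable; Python 3 drops __hash__ when __eq__ is defined.
    __hash__ = object.__hash__


class ExprOp(object):
    """Hypothetical mixin that only provides arithmetic operators."""

    def __mod__(self, other):
        return self


class Mod(NodeBase, ExprOp):
    """Hypothetical expression node inheriting from both classes."""


a = Mod()
print(defining_class(a, "__eq__").__name__)   # NodeBase
print(defining_class(a, "__mod__").__name__)  # ExprOp
```

In this toy hierarchy `==` resolves to the generic node's comparison because no expression-level class overrides it. Whether the same holds for the real `Mod` is exactly the question raised above.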

I modified the unit test to use self.equal directly, and the generated IR now makes sense to me:

In [7]: %paste
    ib = tvm.ir_builder.create()
    n = tvm.var("n")
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, n, name="i") as i:
        with ib.if_scope((i % 2).equal(0)):
            A[i] = A[i] + 1
        with ib.else_scope():
            A[0] = A[i] + 2

    body = ib.get()
## -- End pasted text --

In [8]: body
Out[8]:
for (i, 0, n) {
  if (((i % 2) == 0)) {
    A[i] = (A[i] + 1.000000f)
  } else {
    A[0] = (A[i] + 2.000000f)
  }
}
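The difference between the two outputs can be reproduced in isolation. The sketch below assumes that `==` falls through to a generic comparison returning a plain Python bool, while `equal` builds an expression node. `NodeBase` and `Expr` here are hypothetical toy classes written for this illustration, not TVM's implementation.

```python
class NodeBase(object):
    """Hypothetical generic node: == compares identity and returns a bool."""

    def __eq__(self, other):
        return self is other

    def __ne__(self, other):
        return not self.__eq__(other)

    # Keep instances hashable; Python 3 drops __hash__ when __eq__ is defined.
    __hash__ = object.__hash__


class Expr(NodeBase):
    """Hypothetical symbolic expression that records its own text."""

    def __init__(self, text):
        self.text = text

    def __mod__(self, other):
        # Build a new symbolic node instead of computing a value.
        return Expr("(%s %% %s)" % (self.text, other))

    def equal(self, other):
        # Explicit method that builds a symbolic equality node.
        return Expr("(%s == %s)" % (self.text, other))

    def __repr__(self):
        return self.text


i = Expr("i")
cond_operator = (i % 2) == 0    # inherited __eq__, evaluated immediately
cond_method = (i % 2).equal(0)  # symbolic comparison, kept as an expression

print(cond_operator)  # False
print(cond_method)    # ((i % 2) == 0)
```

If the builder then converts the Python `False` into a constant, the condition would show up as `if (0)`, which matches the first IR dump. That last step is an assumption about what happens inside `if_scope`, not something verified here.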

Let me know if you think this is a problem. I'd be interested to help figure out why.
