diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 0e76e3d86d6e..c1d3d56a683f 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1595,6 +1595,10 @@ def wrapped(*args, **kwargs): atan = _op_wrapper(_tir_op.atan) atan2 = _op_wrapper(_tir_op.atan2) atanh = _op_wrapper(_tir_op.atanh) +bitwise_and = _op_wrapper(_tir_op.bitwise_and) +bitwise_not = _op_wrapper(_tir_op.bitwise_not) +bitwise_or = _op_wrapper(_tir_op.bitwise_or) +bitwise_xor = _op_wrapper(_tir_op.bitwise_xor) ceil = _op_wrapper(_tir_op.ceil) clz = _op_wrapper(_tir_op.clz) copysign = _op_wrapper(_tir_op.copysign) @@ -1846,6 +1850,10 @@ def wrapped(*args, **kwargs): "atan", "atan2", "atanh", + "bitwise_and", + "bitwise_not", + "bitwise_or", + "bitwise_xor", "ceil", "clz", "copysign", diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 9522181432f2..a77f11862def 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -68,6 +68,7 @@ from .op import sin, sinh, asin, asinh from .op import cos, cosh, acos, acosh from .op import tan, tanh, atan, atan2, atanh +from .op import bitwise_and, bitwise_not, bitwise_or, bitwise_xor from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot from .op import trunc, abs, round, nextafter, nearbyint, power, pow, popcount, fmod, if_then_else from .op import likely, isnan, isnullptr, isfinite, isinf, copysign diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 131e91de876e..0a9c4fdfaa52 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -1997,6 +1997,91 @@ def abs(x, span=None): return _ffi_api.abs(x, span) # type: ignore +def bitwise_and(x, y, span=None): + """Take bitwise and of two values + + Parameters + ---------- + x : PrimExpr + Left operand + + y : PrimExpr + Right operand + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + res : PrimExpr + The result. + """ + return _ffi_api.bitwise_and(x, y, span) + + +def bitwise_not(x, span=None): + """Take bitwise not of input value + + Parameters + ---------- + x : PrimExpr + Input operand + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + res : PrimExpr + The result. + """ + return _ffi_api.bitwise_not(x, span) + + +def bitwise_or(x, y, span=None): + """Take bitwise or of two values + + Parameters + ---------- + x : PrimExpr + Left operand + + y : PrimExpr + Right operand + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + res : PrimExpr + The result. + """ + return _ffi_api.bitwise_or(x, y, span) + + +def bitwise_xor(x, y, span=None): + """Take bitwise xor of two values + + Parameters + ---------- + x : PrimExpr + Left operand + + y : PrimExpr + Right operand + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + res : PrimExpr + The result. + """ + return _ffi_api.bitwise_xor(x, y, span) + + def round(x, span=None): """Round elements of the array to the nearest integer. diff --git a/tests/python/unittest/test_tir_op_types.py b/tests/python/unittest/test_tir_op_types.py index 23a264bef75a..58954e745948 100644 --- a/tests/python/unittest/test_tir_op_types.py +++ b/tests/python/unittest/test_tir_op_types.py @@ -279,6 +279,19 @@ def test_tir_op_shift_right(): assert expr.op.name == "tir.shift_right" +def test_tir_op_bitwise(): + x = tir.Var("x", dtype="int32") + y = tir.Var("y", dtype="int32") + expr = tir.bitwise_and(x, y) + assert expr.op.name == "tir.bitwise_and" + expr = tir.bitwise_or(x, y) + assert expr.op.name == "tir.bitwise_or" + expr = tir.bitwise_not(x) + assert expr.op.name == "tir.bitwise_not" + expr = tir.bitwise_xor(x, y) + assert expr.op.name == "tir.bitwise_xor" + + def test_tir_op_TVMBackendAllocWorkspace(): expr = tir.TVMBackendAllocWorkspace(0, 1, 2, 3, 4) assert expr.op.name == "tir.TVMBackendAllocWorkspace"