Skip to content

Commit 22db840

Browse files
fix i386 fp16 cases
1 parent 59f7db4 commit 22db840

File tree

1 file changed

+24
-14
lines changed

1 file changed

+24
-14
lines changed

tests/python/unittest/test_tir_imm_values.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def check_tir_const_fold(
209209
if isinstance(x_range, (int, float)):
210210
x = x_range
211211
elif dtype.startswith("int") or dtype.startswith("uint"):
212-
x = np.random.randint(x_range[0], x_range[1] + 1)
212+
x = np.random.randint(x_range[0], x_range[1] + 1, dtype=dtype)
213213
else:
214214
x = np.random.uniform(x_range[0], x_range[1])
215215

@@ -218,7 +218,7 @@ def check_tir_const_fold(
218218
if isinstance(y_range, (int, float)):
219219
y = y_range
220220
elif dtype.startswith("int") or dtype.startswith("uint"):
221-
y = np.random.randint(y_range[0], y_range[1] + 1)
221+
y = np.random.randint(y_range[0], y_range[1] + 1, dtype=dtype)
222222
else:
223223
y = np.random.uniform(y_range[0], y_range[1])
224224

@@ -249,25 +249,35 @@ def test_tir_floatimm_const_fold():
249249
"""Behavior check: folding fp32 match platform f32 arithmetic"""
250250

251251
@T.prim_func
252-
def float_imm_multiply(x: T.float32, y: T.float32) -> T.float32:
253-
T.evaluate(T.ret(x * y, dtype="float32"))
252+
def float_imm_multiply(x: T.float32, y: T.float32, z: T.Buffer[(), "float32"]):
253+
z[()] = x * y
254254

255255
@T.prim_func
256-
def float_imm_add(x: T.float32, y: T.float32) -> T.float32:
257-
T.evaluate(T.ret(x + y, dtype="float32"))
256+
def float_imm_add(x: T.float32, y: T.float32, z: T.Buffer[(), "float32"]):
257+
z[()] = x + y
258258

259259
@T.prim_func
260-
def float_imm_sub(x: T.float32, y: T.float32) -> T.float32:
261-
T.evaluate(T.ret(x - y, dtype="float32"))
260+
def float_imm_sub(x: T.float32, y: T.float32, z: T.Buffer[(), "float32"]):
261+
z[()] = x - y
262262

263263
@T.prim_func
264-
def float_imm_div(x: T.float32, y: T.float32) -> T.float32:
265-
T.evaluate(T.ret(x / y, dtype="float32"))
264+
def float_imm_div(x: T.float32, y: T.float32, z: T.Buffer[(), "float32"]):
265+
z[()] = x / y
266266

267-
fmul = tvm.build(float_imm_multiply, target="llvm")
268-
fadd = tvm.build(float_imm_add, target="llvm")
269-
fsub = tvm.build(float_imm_sub, target="llvm")
270-
fdiv = tvm.build(float_imm_div, target="llvm")
267+
def __wrap_build(f):
268+
lib = tvm.build(f, target="llvm")
269+
z = tvm.nd.array(np.zeros([]).astype("float32"))
270+
271+
def _func(x, y):
272+
lib(x, y, z)
273+
return z.numpy()
274+
275+
return _func
276+
277+
fmul = __wrap_build(float_imm_multiply)
278+
fadd = __wrap_build(float_imm_add)
279+
fsub = __wrap_build(float_imm_sub)
280+
fdiv = __wrap_build(float_imm_div)
271281

272282
# overflow
273283
check_tir_const_fold("float32", lambda x, y: x * y, fmul, 3.0e30, 3.0e30, np.inf)

0 commit comments

Comments
 (0)