@@ -209,7 +209,7 @@ def check_tir_const_fold(
209209 if isinstance (x_range , (int , float )):
210210 x = x_range
211211 elif dtype .startswith ("int" ) or dtype .startswith ("uint" ):
212- x = np .random .randint (x_range [0 ], x_range [1 ] + 1 )
212+ x = np .random .randint (x_range [0 ], x_range [1 ] + 1 , dtype = dtype )
213213 else :
214214 x = np .random .uniform (x_range [0 ], x_range [1 ])
215215
@@ -218,7 +218,7 @@ def check_tir_const_fold(
218218 if isinstance (y_range , (int , float )):
219219 y = y_range
220220 elif dtype .startswith ("int" ) or dtype .startswith ("uint" ):
221- y = np .random .randint (y_range [0 ], y_range [1 ] + 1 )
221+ y = np .random .randint (y_range [0 ], y_range [1 ] + 1 , dtype = dtype )
222222 else :
223223 y = np .random .uniform (y_range [0 ], y_range [1 ])
224224
@@ -249,25 +249,35 @@ def test_tir_floatimm_const_fold():
249249 """Behavior check: folding fp32 match platform f32 arithmetic"""
250250
251251 @T .prim_func
252- def float_imm_multiply (x : T .float32 , y : T .float32 ) -> T .float32 :
253- T . evaluate ( T . ret ( x * y , dtype = "float32" ))
252+ def float_imm_multiply (x : T .float32 , y : T .float32 , z : T .Buffer [(), " float32" ]) :
253+ z [()] = x * y
254254
255255 @T .prim_func
256- def float_imm_add (x : T .float32 , y : T .float32 ) -> T .float32 :
257- T . evaluate ( T . ret ( x + y , dtype = "float32" ))
256+ def float_imm_add (x : T .float32 , y : T .float32 , z : T .Buffer [(), " float32" ]) :
257+ z [()] = x + y
258258
259259 @T .prim_func
260- def float_imm_sub (x : T .float32 , y : T .float32 ) -> T .float32 :
261- T . evaluate ( T . ret ( x - y , dtype = "float32" ))
260+ def float_imm_sub (x : T .float32 , y : T .float32 , z : T .Buffer [(), " float32" ]) :
261+ z [()] = x - y
262262
263263 @T .prim_func
264- def float_imm_div (x : T .float32 , y : T .float32 ) -> T .float32 :
265- T . evaluate ( T . ret ( x / y , dtype = "float32" ))
264+ def float_imm_div (x : T .float32 , y : T .float32 , z : T .Buffer [(), " float32" ]) :
265+ z [()] = x / y
266266
267- fmul = tvm .build (float_imm_multiply , target = "llvm" )
268- fadd = tvm .build (float_imm_add , target = "llvm" )
269- fsub = tvm .build (float_imm_sub , target = "llvm" )
270- fdiv = tvm .build (float_imm_div , target = "llvm" )
267+ def __wrap_build (f ):
268+ lib = tvm .build (f , target = "llvm" )
269+ z = tvm .nd .array (np .zeros ([]).astype ("float32" ))
270+
271+ def _func (x , y ):
272+ lib (x , y , z )
273+ return z .numpy ()
274+
275+ return _func
276+
277+ fmul = __wrap_build (float_imm_multiply )
278+ fadd = __wrap_build (float_imm_add )
279+ fsub = __wrap_build (float_imm_sub )
280+ fdiv = __wrap_build (float_imm_div )
271281
272282 # overflow
273283 check_tir_const_fold ("float32" , lambda x , y : x * y , fmul , 3.0e30 , 3.0e30 , np .inf )
0 commit comments