@@ -143,6 +143,7 @@ def test_device_api_hooks_unpacked_api(device_api_main_func):
143143 + " device_context_ethos_u))\n "
144144 )
145145 # Open Device
146+ print ("main func" , repr (main_func .body ))
146147 assert (
147148 str (main_func .body [1 ][0 ][0 ][0 ])
148149 == "tir.tvm_check_return(0, -1, tir.call_extern("
@@ -239,23 +240,11 @@ def test_without_device_api_packed_api(non_device_api_main_func):
239240
240241 main_func = non_device_api_main_func (interface_api = "packed" , use_unpacked_api = False )
241242 assert str (main_func .body ) == (
242- 'let tvm_value_3 = tir.tvm_stack_alloca("array", 1)\n '
243- 'let tvm_value_2 = tir.tvm_stack_alloca("array", 1)\n '
244- 'let tvm_value_1 = tir.tvm_stack_alloca("array", 1)\n '
245- 'let tvm_value_0 = tir.tvm_stack_alloca("array", 1)\n '
246- "tir.tvm_struct_set(tvm_value_0, 0, 1, x_buffer_var)\n "
247- "tir.tvm_struct_set(tvm_value_0, 0, 10, 1)\n "
248- "tir.tvm_struct_set(tvm_value_0, 0, 9, 0)\n "
249- "tir.tvm_struct_set(tvm_value_1, 0, 1, y_buffer_var)\n "
250- "tir.tvm_struct_set(tvm_value_1, 0, 10, 1)\n "
251- "tir.tvm_struct_set(tvm_value_1, 0, 9, 0)\n "
252- "tir.tvm_struct_set(tvm_value_2, 0, 1, output_buffer_var)\n "
253- "tir.tvm_struct_set(tvm_value_2, 0, 10, 1)\n "
254- "tir.tvm_struct_set(tvm_value_2, 0, 9, 0)\n "
255- "tir.tvm_struct_set(tvm_value_3, 0, 1, tir.reinterpret((uint64)0))\n "
256- "tir.tvm_struct_set(tvm_value_3, 0, 10, 1)\n "
257- "tir.tvm_struct_set(tvm_value_3, 0, 9, 0)\n "
258- 'tir.tvm_call_cpacked("tvmgen_default_fused_multiply", tvm_value_0, tvm_value_1, tvm_value_2, tvm_value_3)\n '
243+ 'tir.tvm_call_cpacked("tvmgen_default_fused_multiply", '
244+ "tir.tvm_stack_make_array(x_buffer_var, tir.tvm_stack_make_shape(10, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), "
245+ "tir.tvm_stack_make_array(y_buffer_var, tir.tvm_stack_make_shape(1, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), "
246+ "tir.tvm_stack_make_array(output_buffer_var, tir.tvm_stack_make_shape(10, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), "
247+ "tir.reinterpret((uint64)0))\n "
259248 )
260249
261250
0 commit comments