Skip to content

Commit 9642548

Browse files
committed
Fix invoking C device API.
1 parent 4705a18 commit 9642548

File tree

2 files changed

+10
-21
lines changed

2 files changed

+10
-21
lines changed

src/relay/backend/aot_executor_codegen.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,6 @@ class AOTExecutorCodegen : public MixedModeVisitor {
438438

439439
GlobalVar global_var = call_lowered_props.lowered_func;
440440
bool has_c_device_api_context = device_contexts_.count(global_var) != 0;
441-
bool call_device_hooks = false;
442441
tir::Var device_context;
443442
tir::Stmt func_call;
444443

@@ -480,7 +479,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
480479

481480
ICHECK(func_call.defined()) << "Must define func_call";
482481

483-
if (call_device_hooks) {
482+
if (has_c_device_api_context) {
484483
func_call = tir::SeqStmt(Array<tir::Stmt>({
485484
GenerateDeviceHook(device_context, "Open"),
486485
func_call,
@@ -582,8 +581,9 @@ class AOTExecutorCodegen : public MixedModeVisitor {
582581
Array<String> sections = {"Device", device_name, hook};
583582
String device_hook = ToCFunctionStyle(PrefixName(sections));
584583

585-
return tir::Evaluate(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
586-
{tvm::tir::StringImm(device_hook), context}));
584+
return tir::Evaluate(
585+
AddCheckReturn(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
586+
{tvm::tir::StringImm(device_hook), context})));
587587
}
588588

589589
/*!

tests/python/relay/aot/test_c_device_api.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def test_device_api_hooks_unpacked_api(device_api_main_func):
143143
+ " device_context_ethos_u))\n"
144144
)
145145
# Open Device
146+
print("main func", repr(main_func.body))
146147
assert (
147148
str(main_func.body[1][0][0][0])
148149
== "tir.tvm_check_return(0, -1, tir.call_extern("
@@ -239,23 +240,11 @@ def test_without_device_api_packed_api(non_device_api_main_func):
239240

240241
main_func = non_device_api_main_func(interface_api="packed", use_unpacked_api=False)
241242
assert str(main_func.body) == (
242-
'let tvm_value_3 = tir.tvm_stack_alloca("array", 1)\n'
243-
'let tvm_value_2 = tir.tvm_stack_alloca("array", 1)\n'
244-
'let tvm_value_1 = tir.tvm_stack_alloca("array", 1)\n'
245-
'let tvm_value_0 = tir.tvm_stack_alloca("array", 1)\n'
246-
"tir.tvm_struct_set(tvm_value_0, 0, 1, x_buffer_var)\n"
247-
"tir.tvm_struct_set(tvm_value_0, 0, 10, 1)\n"
248-
"tir.tvm_struct_set(tvm_value_0, 0, 9, 0)\n"
249-
"tir.tvm_struct_set(tvm_value_1, 0, 1, y_buffer_var)\n"
250-
"tir.tvm_struct_set(tvm_value_1, 0, 10, 1)\n"
251-
"tir.tvm_struct_set(tvm_value_1, 0, 9, 0)\n"
252-
"tir.tvm_struct_set(tvm_value_2, 0, 1, output_buffer_var)\n"
253-
"tir.tvm_struct_set(tvm_value_2, 0, 10, 1)\n"
254-
"tir.tvm_struct_set(tvm_value_2, 0, 9, 0)\n"
255-
"tir.tvm_struct_set(tvm_value_3, 0, 1, tir.reinterpret((uint64)0))\n"
256-
"tir.tvm_struct_set(tvm_value_3, 0, 10, 1)\n"
257-
"tir.tvm_struct_set(tvm_value_3, 0, 9, 0)\n"
258-
'tir.tvm_call_cpacked("tvmgen_default_fused_multiply", tvm_value_0, tvm_value_1, tvm_value_2, tvm_value_3)\n'
243+
'tir.tvm_call_cpacked("tvmgen_default_fused_multiply", '
244+
"tir.tvm_stack_make_array(x_buffer_var, tir.tvm_stack_make_shape(10, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), "
245+
"tir.tvm_stack_make_array(y_buffer_var, tir.tvm_stack_make_shape(1, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), "
246+
"tir.tvm_stack_make_array(output_buffer_var, tir.tvm_stack_make_shape(10, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), "
247+
"tir.reinterpret((uint64)0))\n"
259248
)
260249

261250

0 commit comments

Comments
 (0)