From 43923b6861ffb199ff394f8e2112299b570ff3a3 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 5 Sep 2025 21:57:33 -0400 Subject: [PATCH] [FFI][ABI] Append symbol prefix for ffi exported functions Previously we simply take the raw symbol for DSO libraries. This can cause symbol conflict of functions that take the ffi calling convention and those that are not. This PR updates the convention to ask for LLVM and libary module to always append a prefix __tvm_ffi_ to function symbols, this way we will no longer have conflict in TVM_FFI_EXPORT_DLL_TYPED macro --- ffi/include/tvm/ffi/extra/module.h | 15 +++++++---- ffi/include/tvm/ffi/function.h | 26 +++++++++---------- ffi/python/tvm_ffi/module.py | 4 +-- ffi/src/ffi/extra/library_module.cc | 4 +-- .../ffi/extra/library_module_dynamic_lib.cc | 2 +- .../ffi/extra/library_module_system_lib.cc | 17 ++++++------ ffi/src/ffi/extra/module_internal.h | 12 ++++++++- .../src/main/java/org/apache/tvm/Module.java | 2 +- src/target/llvm/codegen_cpu.cc | 8 +++++- src/target/llvm/llvm_module.cc | 6 +++-- src/target/source/codegen_c.cc | 4 ++- src/target/source/codegen_c.h | 2 ++ src/target/source/codegen_c_host.cc | 8 +++--- src/target/source/codegen_c_host.h | 3 +++ src/tir/transforms/make_packed_api.cc | 14 ++++++---- .../codegen/test_target_codegen_c_host.py | 10 +------ .../codegen/test_target_codegen_llvm.py | 5 +++- .../test_hexagon/test_async_dma_pipeline.py | 10 +++---- .../contrib/test_hexagon/test_parallel_hvx.py | 4 +-- .../test_parallel_hvx_load_vtcm.py | 10 +++---- .../test_hexagon/test_parallel_scalar.py | 6 ++--- .../test_hexagon/test_vtcm_bandwidth.py | 8 ++---- .../test_tir_transform_make_packed_api.py | 3 +++ 23 files changed, 101 insertions(+), 82 deletions(-) diff --git a/ffi/include/tvm/ffi/extra/module.h b/ffi/include/tvm/ffi/extra/module.h index bc7dff159cda..1af2c2b6b2c0 100644 --- a/ffi/include/tvm/ffi/extra/module.h +++ b/ffi/include/tvm/ffi/extra/module.h @@ -223,14 +223,19 @@ class Module : public ObjectRef { * \brief Symbols for library module. */ namespace symbol { +/*!\ brief symbol prefix for tvm ffi related function symbols */ +constexpr const char* tvm_ffi_symbol_prefix = "__tvm_ffi_"; +// Special symbols have one extra _ prefix to avoid conflict with user symbols +/*! + * \brief Default entry function of a library module is tvm_ffi_symbol_prefix + "main" + */ +constexpr const char* tvm_ffi_main = "__tvm_ffi_main"; /*! \brief Global variable to store context pointer for a library module. */ -constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi_library_ctx"; +constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi__library_ctx"; /*! \brief Global variable to store binary data alongside a library module. */ -constexpr const char* tvm_ffi_library_bin = "__tvm_ffi_library_bin"; +constexpr const char* tvm_ffi_library_bin = "__tvm_ffi__library_bin"; /*! \brief Optional metadata prefix of a symbol. */ -constexpr const char* tvm_ffi_metadata_prefix = "__tvm_ffi_metadata_"; -/*! \brief Default entry function of a library module. */ -constexpr const char* tvm_ffi_main = "__tvm_ffi_main__"; +constexpr const char* tvm_ffi_metadata_prefix = "__tvm_ffi__metadata_"; } // namespace symbol } // namespace ffi } // namespace tvm diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h index 5a30f25a7b5b..f84978800e36 100644 --- a/ffi/include/tvm/ffi/function.h +++ b/ffi/include/tvm/ffi/function.h @@ -800,19 +800,19 @@ inline int32_t TypeKeyToIndex(std::string_view type_key) { * * \endcode */ -#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ - extern "C" { \ - TVM_FFI_DLL_EXPORT int ExportName(void* self, TVMFFIAny* args, int32_t num_args, \ - TVMFFIAny* result) { \ - TVM_FFI_SAFE_CALL_BEGIN(); \ - using FuncInfo = ::tvm::ffi::details::FunctionInfo; \ - static std::string name = #ExportName; \ - ::tvm::ffi::details::unpack_call( \ - std::make_index_sequence{}, &name, Function, \ - reinterpret_cast(args), num_args, \ - reinterpret_cast<::tvm::ffi::Any*>(result)); \ - TVM_FFI_SAFE_CALL_END(); \ - } \ +#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ + extern "C" { \ + TVM_FFI_DLL_EXPORT int __tvm_ffi_##ExportName(void* self, TVMFFIAny* args, int32_t num_args, \ + TVMFFIAny* result) { \ + TVM_FFI_SAFE_CALL_BEGIN(); \ + using FuncInfo = ::tvm::ffi::details::FunctionInfo; \ + static std::string name = #ExportName; \ + ::tvm::ffi::details::unpack_call( \ + std::make_index_sequence{}, &name, Function, \ + reinterpret_cast(args), num_args, \ + reinterpret_cast<::tvm::ffi::Any*>(result)); \ + TVM_FFI_SAFE_CALL_END(); \ + } \ } } // namespace ffi } // namespace tvm diff --git a/ffi/python/tvm_ffi/module.py b/ffi/python/tvm_ffi/module.py index 56aa15348e8c..c3c1d089c612 100644 --- a/ffi/python/tvm_ffi/module.py +++ b/ffi/python/tvm_ffi/module.py @@ -40,7 +40,7 @@ class Module(core.Object): def __new__(cls): instance = super(Module, cls).__new__(cls) # pylint: disable=no-value-for-parameter - instance.entry_name = "__tvm_ffi_main__" + instance.entry_name = "main" instance._entry = None return instance @@ -55,7 +55,7 @@ def entry_func(self): """ if self._entry: return self._entry - self._entry = self.get_function("__tvm_ffi_main__") + self._entry = self.get_function("main") return self._entry @property diff --git a/ffi/src/ffi/extra/library_module.cc b/ffi/src/ffi/extra/library_module.cc index 71c6da6f7cc4..2864cdb5904a 100644 --- a/ffi/src/ffi/extra/library_module.cc +++ b/ffi/src/ffi/extra/library_module.cc @@ -42,7 +42,7 @@ class LibraryModuleObj final : public ModuleObj { Optional GetFunction(const String& name) final { TVMFFISafeCallType faddr; - faddr = reinterpret_cast(lib_->GetSymbol(name.c_str())); + faddr = reinterpret_cast(lib_->GetSymbolWithSymbolPrefix(name)); // ensure the function keeps the Library Module alive Module self_strong_ref = GetRef(this); if (faddr != nullptr) { @@ -140,7 +140,7 @@ class ContextSymbolRegistry { public: void InitContextSymbols(ObjectPtr lib) { for (const auto& [name, symbol] : context_symbols_) { - if (void** symbol_addr = reinterpret_cast(lib->GetSymbol(name.c_str()))) { + if (void** symbol_addr = reinterpret_cast(lib->GetSymbol(name))) { *symbol_addr = symbol; } } diff --git a/ffi/src/ffi/extra/library_module_dynamic_lib.cc b/ffi/src/ffi/extra/library_module_dynamic_lib.cc index 25463a7e5f92..e85b05180baf 100644 --- a/ffi/src/ffi/extra/library_module_dynamic_lib.cc +++ b/ffi/src/ffi/extra/library_module_dynamic_lib.cc @@ -49,7 +49,7 @@ class DSOLibrary final : public Library { if (lib_handle_) Unload(); } - void* GetSymbol(const char* name) final { return GetSymbol_(name); } + void* GetSymbol(const String& name) final { return GetSymbol_(name.c_str()); } private: // private system dependent implementation diff --git a/ffi/src/ffi/extra/library_module_system_lib.cc b/ffi/src/ffi/extra/library_module_system_lib.cc index cdc932cba292..e93c6602c267 100644 --- a/ffi/src/ffi/extra/library_module_system_lib.cc +++ b/ffi/src/ffi/extra/library_module_system_lib.cc @@ -45,7 +45,7 @@ class SystemLibSymbolRegistry { symbol_table_.Set(name, ptr); } - void* GetSymbol(const char* name) { + void* GetSymbol(const String& name) { auto it = symbol_table_.find(name); if (it != symbol_table_.end()) { return (*it).second; @@ -68,13 +68,14 @@ class SystemLibrary final : public Library { public: explicit SystemLibrary(const String& symbol_prefix) : symbol_prefix_(symbol_prefix) {} - void* GetSymbol(const char* name) { - if (symbol_prefix_.length() != 0) { - String name_with_prefix = symbol_prefix_ + name; - void* symbol = reg_->GetSymbol(name_with_prefix.c_str()); - if (symbol != nullptr) return symbol; - } - return reg_->GetSymbol(name); + void* GetSymbol(const String& name) final { + String name_with_prefix = symbol_prefix_ + name; + return reg_->GetSymbol(name_with_prefix); + } + + void* GetSymbolWithSymbolPrefix(const String& name) final { + String name_with_prefix = symbol::tvm_ffi_symbol_prefix + symbol_prefix_ + name; + return reg_->GetSymbol(name_with_prefix); } private: diff --git a/ffi/src/ffi/extra/module_internal.h b/ffi/src/ffi/extra/module_internal.h index 472d531f4b51..86cb6b66c1f6 100644 --- a/ffi/src/ffi/extra/module_internal.h +++ b/ffi/src/ffi/extra/module_internal.h @@ -48,7 +48,17 @@ class Library : public Object { * \param name The name of the symbol. * \return The symbol. */ - virtual void* GetSymbol(const char* name) = 0; + virtual void* GetSymbol(const String& name) = 0; + /*! + * \brief Get the symbol address for a given name with the tvm ffi symbol prefix. + * \param name The name of the symbol. + * \return The symbol. + * \note This function will be overloaded by systemlib implementation. + */ + virtual void* GetSymbolWithSymbolPrefix(const String& name) { + String name_with_prefix = symbol::tvm_ffi_symbol_prefix + name; + return GetSymbol(name_with_prefix); + } // NOTE: we do not explicitly create an type index and type_key here for libary. // This is because we do not need dynamic type downcasting and only need to use the refcounting }; diff --git a/jvm/core/src/main/java/org/apache/tvm/Module.java b/jvm/core/src/main/java/org/apache/tvm/Module.java index 46a74346760e..174457131f05 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Module.java +++ b/jvm/core/src/main/java/org/apache/tvm/Module.java @@ -46,7 +46,7 @@ private static Function getApi(String name) { } private Function entry = null; - private final String entryName = "__tvm_ffi_main__"; + private final String entryName = "main"; /** diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 5ce8b1ec6584..34e9e8381898 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -229,6 +229,11 @@ void CodeGenCPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { } void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) { + if (module_->getFunction(ffi::symbol::tvm_ffi_main) != nullptr) { + // main already exists, no need to create a wrapper function + // main takes precedence over other entry functions + return; + } // create a wrapper function with tvm_ffi_main name and redirects to the entry function llvm::Function* target_func = module_->getFunction(entry_func_name); ICHECK(target_func) << "Function " << entry_func_name << " does not exist in module"; @@ -857,8 +862,9 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& call_args.push_back(GetPackedFuncHandle(func_name)); call_args.insert(call_args.end(), {packed_args, ConstInt32(nargs), result}); } else { + // directly call into symbol, needs to prefix with tvm_ffi_symbol_prefix callee_ftype = ftype_tvm_ffi_c_func_; - callee_value = module_->getFunction(func_name); + callee_value = module_->getFunction(ffi::symbol::tvm_ffi_symbol_prefix + func_name); if (callee_value == nullptr) { callee_value = llvm::Function::Create(ftype_tvm_ffi_c_func_, llvm::Function::ExternalLinkage, func_name, module_.get()); diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 8ea438626532..6c88d6943423 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -189,7 +189,8 @@ Optional LLVMModuleNode::GetFunction(const String& name) { TVMFFISafeCallType faddr; With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); - faddr = reinterpret_cast(GetFunctionAddr(name, *llvm_target)); + String name_with_prefix = ffi::symbol::tvm_ffi_symbol_prefix + name; + faddr = reinterpret_cast(GetFunctionAddr(name_with_prefix, *llvm_target)); if (faddr == nullptr) return std::nullopt; ffi::Module self_strong_ref = GetRef(this); return ffi::Function::FromPacked([faddr, self_strong_ref](ffi::PackedArgs args, ffi::Any* rv) { @@ -386,7 +387,8 @@ void LLVMModuleNode::LoadIR(const std::string& file_name) { } bool LLVMModuleNode::ImplementsFunction(const String& name) { - return std::find(function_names_.begin(), function_names_.end(), name) != function_names_.end(); + return std::find(function_names_.begin(), function_names_.end(), + ffi::symbol::tvm_ffi_symbol_prefix + name) != function_names_.end(); } void LLVMModuleNode::InitMCJIT() { diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index acc05cf96c08..65c57cf882b4 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -149,7 +149,9 @@ void CodeGenC::DeclareFunction(const GlobalVar& gvar, const PrimFunc& func) { return gvar->name_hint; } }(); - + if (function_name == ffi::symbol::tvm_ffi_main) { + has_tvm_ffi_main_func_ = true; + } internal_functions_.insert({gvar, function_name}); InitFuncState(func); diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 8c5e1ffd897b..02cb4cd9a779 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -319,6 +319,8 @@ class CodeGenC : public ExprFunctor, Integer constants_byte_alignment_ = 16; /*! \brief whether to print in SSA form */ bool print_ssa_form_{false}; + /*! \brief whether the module has a main function declared */ + bool has_tvm_ffi_main_func_{false}; private: /*! \brief set of volatile buf access */ diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index e18ba0128d6b..a4cbc46f0cca 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -35,7 +35,9 @@ namespace tvm { namespace codegen { -CodeGenCHost::CodeGenCHost() { module_name_ = name_supply_->FreshName("__tvm_ffi_library_ctx"); } +CodeGenCHost::CodeGenCHost() { + module_name_ = name_supply_->FreshName(ffi::symbol::tvm_ffi_library_ctx); +} void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl, std::string target_str, const std::unordered_set& devices) { @@ -72,7 +74,7 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func, emit_fwd_func_decl_ = emit_fwd_func_decl; CodeGenC::AddFunction(gvar, func); - if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc) && !has_tvm_ffi_main_func_) { ICHECK(global_symbol.has_value()) << "CodeGenCHost: The entry func must have the global_symbol attribute, " << "but function " << gvar << " only has attributes " << func->attrs; @@ -235,7 +237,7 @@ void CodeGenCHost::PrintCallPacked(const CallNode* op) { } else { // directly use the original symbol ICHECK(op->op.same_as(builtin::tvm_call_cpacked_lowered())); - packed_func_name = func_name->value; + packed_func_name = ffi::symbol::tvm_ffi_symbol_prefix + func_name->value; } std::string args_stack = PrintExpr(op->args[1]); diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 4a2f530e2f98..1c7e65b3b2cb 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -44,6 +44,7 @@ class CodeGenCHost : public CodeGenC { const std::unordered_set& devices); void InitGlobalContext(); + void AddFunction(const GlobalVar& gvar, const PrimFunc& f) override; void AddFunction(const GlobalVar& gvar, const PrimFunc& f, bool emit_fwd_func_decl); /*! @@ -83,6 +84,8 @@ class CodeGenCHost : public CodeGenC { bool emit_asserts_; /*! \brief whether to emit forwared function declarations in the resulting C code */ bool emit_fwd_func_decl_; + /*! \brief whether to generate the entry function if encountered */ + bool has_main_func_ = false; std::string GetPackedName(const CallNode* op); void PrintGetFuncFromBackend(const std::string& func_name, const std::string& packed_func_name); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index e6c6e9aa0275..f557cab91ad8 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -20,6 +20,7 @@ /*! * \file make_packed_api.cc Lower PrimFunc to use the packed function API. */ +#include #include #include #include @@ -196,7 +197,7 @@ Optional RequiresPackedAPI(const PrimFunc& func) { return std::nullopt; } - return global_symbol; + return global_symbol.value(); } PrimFunc MakePackedAPI(PrimFunc func) { @@ -223,6 +224,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { } auto* func_ptr = func.CopyOnWrite(); + // set the global symbol to the packed function name const Stmt nop = Evaluate(0); int num_args = static_cast(func_ptr->params.size()); @@ -362,10 +364,12 @@ PrimFunc MakePackedAPI(PrimFunc func) { binder.BindDLTensor(buffer, device_type, device_id, var, name_hint + "." + var->name_hint); arg_buffer_declarations.push_back(DeclBuffer(buffer, nop)); } - - func = WithAttrs(std::move(func), - {{tvm::attr::kCallingConv, static_cast(CallingConv::kCPackedFunc)}, - {tvm::attr::kTarget, target_host}}); + // reset global symbol to attach prefix + func = WithAttrs( + std::move(func), + {{tvm::attr::kCallingConv, static_cast(CallingConv::kCPackedFunc)}, + {tvm::attr::kTarget, target_host}, + {tvm::attr::kGlobalSymbol, ffi::symbol::tvm_ffi_symbol_prefix + global_symbol.value()}}); Stmt body = ReturnRewriter(v_result)(func_ptr->body); body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope, diff --git a/tests/python/codegen/test_target_codegen_c_host.py b/tests/python/codegen/test_target_codegen_c_host.py index 3c80cfbeb0b4..8f3798861f46 100644 --- a/tests/python/codegen/test_target_codegen_c_host.py +++ b/tests/python/codegen/test_target_codegen_c_host.py @@ -184,17 +184,9 @@ def subroutine(A_data: T.handle("float32")): built = tvm.tir.build(mod, target="c") - func_names = list(built["get_func_names"]()) - assert ( - "main" in func_names - ), "Externally exposed functions should be listed in available functions." - assert ( - "subroutine" not in func_names - ), "Internal function should not be listed in available functions." - source = built.inspect_source() assert ( - source.count("main(void*") == 2 + source.count("__tvm_ffi_main(void*") == 2 ), "Expected two occurrences, for forward-declaration and definition" assert ( source.count("subroutine(float*") == 2 diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index 953adf78b342..b303cf289eca 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -953,7 +953,10 @@ def test_llvm_target_attributes(): assert re.match('.*"target-cpu"="skylake".*', attribute_definitions[k]) assert re.match('.*"target-features"=".*[+]avx512f.*".*', attribute_definitions[k]) - expected_functions = ["test_func", "test_func_compute_", "__tvm_parallel_lambda"] + expected_functions = [ + "__tvm_ffi_test_func", + "__tvm_parallel_lambda", + ] for n in expected_functions: assert n in functions_with_target diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py index 965795d29e02..ab1cce52eac8 100644 --- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py +++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -""" Test different strategies for loading data into vtcm before running HVX workloads. """ +"""Test different strategies for loading data into vtcm before running HVX workloads.""" import numpy as np import pytest @@ -287,13 +287,9 @@ def evaluate( if tvm.testing.utils.IS_IN_CI: # Run with reduced number and repeat for CI - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=1, repeat=1 - ) + timer = module.time_evaluator("main", hexagon_session.device, number=1, repeat=1) else: - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=10, repeat=10 - ) + timer = module.time_evaluator("main", hexagon_session.device, number=10, repeat=10) time = timer(a_hexagon, b_hexagon, c_hexagon) if expected_output is not None: diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx.py b/tests/python/contrib/test_hexagon/test_parallel_hvx.py index 6e1b7db4d5c5..cab3f7d64f9b 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx.py +++ b/tests/python/contrib/test_hexagon/test_parallel_hvx.py @@ -156,9 +156,7 @@ def evaluate(hexagon_session, shape_dtypes, expected_output_producer, sch): number = 1 repeat = 1 - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat - ) + timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat) runtime = timer(a_hexagon, b_hexagon, c_hexagon) tvm.testing.assert_allclose(c_hexagon.numpy(), expected_output_producer(c_shape, a, b)) diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py index a0b94d89cfa6..89385b2aeb8f 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py +++ b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -""" Test different strategies for loading data into vtcm before running HVX workloads. """ +"""Test different strategies for loading data into vtcm before running HVX workloads.""" import numpy as np import tvm @@ -326,9 +326,7 @@ def setup_and_run(hexagon_session, sch, a, b, c, operations, mem_scope="global") number = 1 repeat = 1 - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat - ) + timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat) time = timer(a_hexagon, b_hexagon, c_hexagon) gops = round(operations * 128 * 3 / time.mean / 1e9, 4) return gops, c_hexagon.numpy() @@ -360,9 +358,7 @@ def setup_and_run_preallocated(hexagon_session, sch, a, b, c, operations): number = 1 repeat = 1 - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat - ) + timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat) time = timer(a_hexagon, b_hexagon, c_hexagon, a_vtcm_hexagon, b_vtcm_hexagon, c_vtcm_hexagon) gops = round(operations * 128 * 3 / time.mean / 1e9, 4) return gops, c_hexagon.numpy() diff --git a/tests/python/contrib/test_hexagon/test_parallel_scalar.py b/tests/python/contrib/test_hexagon/test_parallel_scalar.py index dd765178dc32..d9b9a2480312 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_scalar.py +++ b/tests/python/contrib/test_hexagon/test_parallel_scalar.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -""" Test parallelism for multiple different scalar workloads. """ +"""Test parallelism for multiple different scalar workloads.""" import numpy as np @@ -104,9 +104,7 @@ def evaluate(hexagon_session, operations, expected, sch): number = 1 repeat = 1 - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat - ) + timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat) runtime = timer(a_hexagon, b_hexagon, c_hexagon) tvm.testing.assert_allclose(c_hexagon.numpy(), expected(a, b)) diff --git a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py index 265f2bf5fd2d..015a9f0656ed 100644 --- a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py +++ b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py @@ -108,13 +108,9 @@ def evaluate(hexagon_session, sch, size): if tvm.testing.utils.IS_IN_CI: # Run with reduced number and repeat for CI - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=1, repeat=1 - ) + timer = module.time_evaluator("main", hexagon_session.device, number=1, repeat=1) else: - timer = module.time_evaluator( - "__tvm_ffi_main__", hexagon_session.device, number=10, repeat=10 - ) + timer = module.time_evaluator("main", hexagon_session.device, number=10, repeat=10) runtime = timer(a_hexagon, a_vtcm_hexagon) diff --git a/tests/python/tir-transform/test_tir_transform_make_packed_api.py b/tests/python/tir-transform/test_tir_transform_make_packed_api.py index dd7bd3bf54a2..4fecafef1d15 100644 --- a/tests/python/tir-transform/test_tir_transform_make_packed_api.py +++ b/tests/python/tir-transform/test_tir_transform_make_packed_api.py @@ -261,6 +261,7 @@ def func_without_arg( { "calling_conv": 1, "target": T.target("llvm"), + "global_symbol": "__tvm_ffi_func_without_arg", } ) assert num_args == 0, "func_without_arg: num_args should be 0" @@ -315,6 +316,7 @@ def main( { "calling_conv": 1, "target": T.target("llvm"), + "global_symbol": "__tvm_ffi_main", } ) assert num_args == 1, "main: num_args should be 1" @@ -372,6 +374,7 @@ def main( { "calling_conv": 1, "target": T.target("llvm"), + "global_symbol": "__tvm_ffi_main", } ) assert num_args == 1, "main: num_args should be 1"