Skip to content

Commit 43923b6

Browse files
committed
[FFI][ABI] Append symbol prefix for ffi exported functions
Previously we simply take the raw symbol for DSO libraries. This can cause symbol conflict of functions that take the ffi calling convention and those that are not. This PR updates the convention to ask for LLVM and libary module to always append a prefix __tvm_ffi_ to function symbols, this way we will no longer have conflict in TVM_FFI_EXPORT_DLL_TYPED macro
1 parent 86b391a commit 43923b6

23 files changed

+101
-82
lines changed

ffi/include/tvm/ffi/extra/module.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -223,14 +223,19 @@ class Module : public ObjectRef {
223223
* \brief Symbols for library module.
224224
*/
225225
namespace symbol {
226+
/*!\ brief symbol prefix for tvm ffi related function symbols */
227+
constexpr const char* tvm_ffi_symbol_prefix = "__tvm_ffi_";
228+
// Special symbols have one extra _ prefix to avoid conflict with user symbols
229+
/*!
230+
* \brief Default entry function of a library module is tvm_ffi_symbol_prefix + "main"
231+
*/
232+
constexpr const char* tvm_ffi_main = "__tvm_ffi_main";
226233
/*! \brief Global variable to store context pointer for a library module. */
227-
constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi_library_ctx";
234+
constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi__library_ctx";
228235
/*! \brief Global variable to store binary data alongside a library module. */
229-
constexpr const char* tvm_ffi_library_bin = "__tvm_ffi_library_bin";
236+
constexpr const char* tvm_ffi_library_bin = "__tvm_ffi__library_bin";
230237
/*! \brief Optional metadata prefix of a symbol. */
231-
constexpr const char* tvm_ffi_metadata_prefix = "__tvm_ffi_metadata_";
232-
/*! \brief Default entry function of a library module. */
233-
constexpr const char* tvm_ffi_main = "__tvm_ffi_main__";
238+
constexpr const char* tvm_ffi_metadata_prefix = "__tvm_ffi__metadata_";
234239
} // namespace symbol
235240
} // namespace ffi
236241
} // namespace tvm

ffi/include/tvm/ffi/function.h

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -800,19 +800,19 @@ inline int32_t TypeKeyToIndex(std::string_view type_key) {
800800
*
801801
* \endcode
802802
*/
803-
#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \
804-
extern "C" { \
805-
TVM_FFI_DLL_EXPORT int ExportName(void* self, TVMFFIAny* args, int32_t num_args, \
806-
TVMFFIAny* result) { \
807-
TVM_FFI_SAFE_CALL_BEGIN(); \
808-
using FuncInfo = ::tvm::ffi::details::FunctionInfo<decltype(Function)>; \
809-
static std::string name = #ExportName; \
810-
::tvm::ffi::details::unpack_call<typename FuncInfo::RetType>( \
811-
std::make_index_sequence<FuncInfo::num_args>{}, &name, Function, \
812-
reinterpret_cast<const ::tvm::ffi::AnyView*>(args), num_args, \
813-
reinterpret_cast<::tvm::ffi::Any*>(result)); \
814-
TVM_FFI_SAFE_CALL_END(); \
815-
} \
803+
#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \
804+
extern "C" { \
805+
TVM_FFI_DLL_EXPORT int __tvm_ffi_##ExportName(void* self, TVMFFIAny* args, int32_t num_args, \
806+
TVMFFIAny* result) { \
807+
TVM_FFI_SAFE_CALL_BEGIN(); \
808+
using FuncInfo = ::tvm::ffi::details::FunctionInfo<decltype(Function)>; \
809+
static std::string name = #ExportName; \
810+
::tvm::ffi::details::unpack_call<typename FuncInfo::RetType>( \
811+
std::make_index_sequence<FuncInfo::num_args>{}, &name, Function, \
812+
reinterpret_cast<const ::tvm::ffi::AnyView*>(args), num_args, \
813+
reinterpret_cast<::tvm::ffi::Any*>(result)); \
814+
TVM_FFI_SAFE_CALL_END(); \
815+
} \
816816
}
817817
} // namespace ffi
818818
} // namespace tvm

ffi/python/tvm_ffi/module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class Module(core.Object):
4040

4141
def __new__(cls):
4242
instance = super(Module, cls).__new__(cls) # pylint: disable=no-value-for-parameter
43-
instance.entry_name = "__tvm_ffi_main__"
43+
instance.entry_name = "main"
4444
instance._entry = None
4545
return instance
4646

@@ -55,7 +55,7 @@ def entry_func(self):
5555
"""
5656
if self._entry:
5757
return self._entry
58-
self._entry = self.get_function("__tvm_ffi_main__")
58+
self._entry = self.get_function("main")
5959
return self._entry
6060

6161
@property

ffi/src/ffi/extra/library_module.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class LibraryModuleObj final : public ModuleObj {
4242

4343
Optional<ffi::Function> GetFunction(const String& name) final {
4444
TVMFFISafeCallType faddr;
45-
faddr = reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(name.c_str()));
45+
faddr = reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbolWithSymbolPrefix(name));
4646
// ensure the function keeps the Library Module alive
4747
Module self_strong_ref = GetRef<Module>(this);
4848
if (faddr != nullptr) {
@@ -140,7 +140,7 @@ class ContextSymbolRegistry {
140140
public:
141141
void InitContextSymbols(ObjectPtr<Library> lib) {
142142
for (const auto& [name, symbol] : context_symbols_) {
143-
if (void** symbol_addr = reinterpret_cast<void**>(lib->GetSymbol(name.c_str()))) {
143+
if (void** symbol_addr = reinterpret_cast<void**>(lib->GetSymbol(name))) {
144144
*symbol_addr = symbol;
145145
}
146146
}

ffi/src/ffi/extra/library_module_dynamic_lib.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class DSOLibrary final : public Library {
4949
if (lib_handle_) Unload();
5050
}
5151

52-
void* GetSymbol(const char* name) final { return GetSymbol_(name); }
52+
void* GetSymbol(const String& name) final { return GetSymbol_(name.c_str()); }
5353

5454
private:
5555
// private system dependent implementation

ffi/src/ffi/extra/library_module_system_lib.cc

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class SystemLibSymbolRegistry {
4545
symbol_table_.Set(name, ptr);
4646
}
4747

48-
void* GetSymbol(const char* name) {
48+
void* GetSymbol(const String& name) {
4949
auto it = symbol_table_.find(name);
5050
if (it != symbol_table_.end()) {
5151
return (*it).second;
@@ -68,13 +68,14 @@ class SystemLibrary final : public Library {
6868
public:
6969
explicit SystemLibrary(const String& symbol_prefix) : symbol_prefix_(symbol_prefix) {}
7070

71-
void* GetSymbol(const char* name) {
72-
if (symbol_prefix_.length() != 0) {
73-
String name_with_prefix = symbol_prefix_ + name;
74-
void* symbol = reg_->GetSymbol(name_with_prefix.c_str());
75-
if (symbol != nullptr) return symbol;
76-
}
77-
return reg_->GetSymbol(name);
71+
void* GetSymbol(const String& name) final {
72+
String name_with_prefix = symbol_prefix_ + name;
73+
return reg_->GetSymbol(name_with_prefix);
74+
}
75+
76+
void* GetSymbolWithSymbolPrefix(const String& name) final {
77+
String name_with_prefix = symbol::tvm_ffi_symbol_prefix + symbol_prefix_ + name;
78+
return reg_->GetSymbol(name_with_prefix);
7879
}
7980

8081
private:

ffi/src/ffi/extra/module_internal.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,17 @@ class Library : public Object {
4848
* \param name The name of the symbol.
4949
* \return The symbol.
5050
*/
51-
virtual void* GetSymbol(const char* name) = 0;
51+
virtual void* GetSymbol(const String& name) = 0;
52+
/*!
53+
* \brief Get the symbol address for a given name with the tvm ffi symbol prefix.
54+
* \param name The name of the symbol.
55+
* \return The symbol.
56+
* \note This function will be overloaded by systemlib implementation.
57+
*/
58+
virtual void* GetSymbolWithSymbolPrefix(const String& name) {
59+
String name_with_prefix = symbol::tvm_ffi_symbol_prefix + name;
60+
return GetSymbol(name_with_prefix);
61+
}
5262
// NOTE: we do not explicitly create an type index and type_key here for libary.
5363
// This is because we do not need dynamic type downcasting and only need to use the refcounting
5464
};

jvm/core/src/main/java/org/apache/tvm/Module.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ private static Function getApi(String name) {
4646
}
4747

4848
private Function entry = null;
49-
private final String entryName = "__tvm_ffi_main__";
49+
private final String entryName = "main";
5050

5151

5252
/**

src/target/llvm/codegen_cpu.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,11 @@ void CodeGenCPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
229229
}
230230

231231
void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) {
232+
if (module_->getFunction(ffi::symbol::tvm_ffi_main) != nullptr) {
233+
// main already exists, no need to create a wrapper function
234+
// main takes precedence over other entry functions
235+
return;
236+
}
232237
// create a wrapper function with tvm_ffi_main name and redirects to the entry function
233238
llvm::Function* target_func = module_->getFunction(entry_func_name);
234239
ICHECK(target_func) << "Function " << entry_func_name << " does not exist in module";
@@ -857,8 +862,9 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>&
857862
call_args.push_back(GetPackedFuncHandle(func_name));
858863
call_args.insert(call_args.end(), {packed_args, ConstInt32(nargs), result});
859864
} else {
865+
// directly call into symbol, needs to prefix with tvm_ffi_symbol_prefix
860866
callee_ftype = ftype_tvm_ffi_c_func_;
861-
callee_value = module_->getFunction(func_name);
867+
callee_value = module_->getFunction(ffi::symbol::tvm_ffi_symbol_prefix + func_name);
862868
if (callee_value == nullptr) {
863869
callee_value = llvm::Function::Create(ftype_tvm_ffi_c_func_, llvm::Function::ExternalLinkage,
864870
func_name, module_.get());

src/target/llvm/llvm_module.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,8 @@ Optional<ffi::Function> LLVMModuleNode::GetFunction(const String& name) {
189189

190190
TVMFFISafeCallType faddr;
191191
With<LLVMTarget> llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_));
192-
faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(name, *llvm_target));
192+
String name_with_prefix = ffi::symbol::tvm_ffi_symbol_prefix + name;
193+
faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(name_with_prefix, *llvm_target));
193194
if (faddr == nullptr) return std::nullopt;
194195
ffi::Module self_strong_ref = GetRef<ffi::Module>(this);
195196
return ffi::Function::FromPacked([faddr, self_strong_ref](ffi::PackedArgs args, ffi::Any* rv) {
@@ -386,7 +387,8 @@ void LLVMModuleNode::LoadIR(const std::string& file_name) {
386387
}
387388

388389
bool LLVMModuleNode::ImplementsFunction(const String& name) {
389-
return std::find(function_names_.begin(), function_names_.end(), name) != function_names_.end();
390+
return std::find(function_names_.begin(), function_names_.end(),
391+
ffi::symbol::tvm_ffi_symbol_prefix + name) != function_names_.end();
390392
}
391393

392394
void LLVMModuleNode::InitMCJIT() {

0 commit comments

Comments
 (0)