Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions ffi/include/tvm/ffi/extra/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,14 +223,19 @@ class Module : public ObjectRef {
* \brief Symbols for library module.
*/
namespace symbol {
/*!\ brief symbol prefix for tvm ffi related function symbols */
constexpr const char* tvm_ffi_symbol_prefix = "__tvm_ffi_";
// Special symbols have one extra _ prefix to avoid conflict with user symbols
/*!
* \brief Default entry function of a library module is tvm_ffi_symbol_prefix + "main"
*/
constexpr const char* tvm_ffi_main = "__tvm_ffi_main";
/*! \brief Global variable to store context pointer for a library module. */
constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi_library_ctx";
constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi__library_ctx";
/*! \brief Global variable to store binary data alongside a library module. */
constexpr const char* tvm_ffi_library_bin = "__tvm_ffi_library_bin";
constexpr const char* tvm_ffi_library_bin = "__tvm_ffi__library_bin";
/*! \brief Optional metadata prefix of a symbol. */
constexpr const char* tvm_ffi_metadata_prefix = "__tvm_ffi_metadata_";
/*! \brief Default entry function of a library module. */
constexpr const char* tvm_ffi_main = "__tvm_ffi_main__";
constexpr const char* tvm_ffi_metadata_prefix = "__tvm_ffi__metadata_";
} // namespace symbol
} // namespace ffi
} // namespace tvm
Expand Down
26 changes: 13 additions & 13 deletions ffi/include/tvm/ffi/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -800,19 +800,19 @@ inline int32_t TypeKeyToIndex(std::string_view type_key) {
*
* \endcode
*/
#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \
extern "C" { \
TVM_FFI_DLL_EXPORT int ExportName(void* self, TVMFFIAny* args, int32_t num_args, \
TVMFFIAny* result) { \
TVM_FFI_SAFE_CALL_BEGIN(); \
using FuncInfo = ::tvm::ffi::details::FunctionInfo<decltype(Function)>; \
static std::string name = #ExportName; \
::tvm::ffi::details::unpack_call<typename FuncInfo::RetType>( \
std::make_index_sequence<FuncInfo::num_args>{}, &name, Function, \
reinterpret_cast<const ::tvm::ffi::AnyView*>(args), num_args, \
reinterpret_cast<::tvm::ffi::Any*>(result)); \
TVM_FFI_SAFE_CALL_END(); \
} \
#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \
extern "C" { \
TVM_FFI_DLL_EXPORT int __tvm_ffi_##ExportName(void* self, TVMFFIAny* args, int32_t num_args, \
TVMFFIAny* result) { \
TVM_FFI_SAFE_CALL_BEGIN(); \
using FuncInfo = ::tvm::ffi::details::FunctionInfo<decltype(Function)>; \
static std::string name = #ExportName; \
::tvm::ffi::details::unpack_call<typename FuncInfo::RetType>( \
std::make_index_sequence<FuncInfo::num_args>{}, &name, Function, \
reinterpret_cast<const ::tvm::ffi::AnyView*>(args), num_args, \
reinterpret_cast<::tvm::ffi::Any*>(result)); \
TVM_FFI_SAFE_CALL_END(); \
} \
}
} // namespace ffi
} // namespace tvm
Expand Down
4 changes: 2 additions & 2 deletions ffi/python/tvm_ffi/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class Module(core.Object):

def __new__(cls):
instance = super(Module, cls).__new__(cls) # pylint: disable=no-value-for-parameter
instance.entry_name = "__tvm_ffi_main__"
instance.entry_name = "main"
instance._entry = None
return instance

Expand All @@ -55,7 +55,7 @@ def entry_func(self):
"""
if self._entry:
return self._entry
self._entry = self.get_function("__tvm_ffi_main__")
self._entry = self.get_function("main")
return self._entry

@property
Expand Down
4 changes: 2 additions & 2 deletions ffi/src/ffi/extra/library_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class LibraryModuleObj final : public ModuleObj {

Optional<ffi::Function> GetFunction(const String& name) final {
TVMFFISafeCallType faddr;
faddr = reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(name.c_str()));
faddr = reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbolWithSymbolPrefix(name));
// ensure the function keeps the Library Module alive
Module self_strong_ref = GetRef<Module>(this);
if (faddr != nullptr) {
Expand Down Expand Up @@ -140,7 +140,7 @@ class ContextSymbolRegistry {
public:
void InitContextSymbols(ObjectPtr<Library> lib) {
for (const auto& [name, symbol] : context_symbols_) {
if (void** symbol_addr = reinterpret_cast<void**>(lib->GetSymbol(name.c_str()))) {
if (void** symbol_addr = reinterpret_cast<void**>(lib->GetSymbol(name))) {
*symbol_addr = symbol;
}
}
Expand Down
2 changes: 1 addition & 1 deletion ffi/src/ffi/extra/library_module_dynamic_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class DSOLibrary final : public Library {
if (lib_handle_) Unload();
}

void* GetSymbol(const char* name) final { return GetSymbol_(name); }
void* GetSymbol(const String& name) final { return GetSymbol_(name.c_str()); }

private:
// private system dependent implementation
Expand Down
17 changes: 9 additions & 8 deletions ffi/src/ffi/extra/library_module_system_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class SystemLibSymbolRegistry {
symbol_table_.Set(name, ptr);
}

void* GetSymbol(const char* name) {
void* GetSymbol(const String& name) {
auto it = symbol_table_.find(name);
if (it != symbol_table_.end()) {
return (*it).second;
Expand All @@ -68,13 +68,14 @@ class SystemLibrary final : public Library {
public:
explicit SystemLibrary(const String& symbol_prefix) : symbol_prefix_(symbol_prefix) {}

void* GetSymbol(const char* name) {
if (symbol_prefix_.length() != 0) {
String name_with_prefix = symbol_prefix_ + name;
void* symbol = reg_->GetSymbol(name_with_prefix.c_str());
if (symbol != nullptr) return symbol;
}
return reg_->GetSymbol(name);
void* GetSymbol(const String& name) final {
String name_with_prefix = symbol_prefix_ + name;
return reg_->GetSymbol(name_with_prefix);
}

void* GetSymbolWithSymbolPrefix(const String& name) final {
String name_with_prefix = symbol::tvm_ffi_symbol_prefix + symbol_prefix_ + name;
return reg_->GetSymbol(name_with_prefix);
}

private:
Expand Down
12 changes: 11 additions & 1 deletion ffi/src/ffi/extra/module_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,17 @@ class Library : public Object {
* \param name The name of the symbol.
* \return The symbol.
*/
virtual void* GetSymbol(const char* name) = 0;
virtual void* GetSymbol(const String& name) = 0;
/*!
* \brief Get the symbol address for a given name with the tvm ffi symbol prefix.
* \param name The name of the symbol.
* \return The symbol.
* \note This function will be overloaded by systemlib implementation.
*/
virtual void* GetSymbolWithSymbolPrefix(const String& name) {
String name_with_prefix = symbol::tvm_ffi_symbol_prefix + name;
return GetSymbol(name_with_prefix);
}
// NOTE: we do not explicitly create an type index and type_key here for libary.
// This is because we do not need dynamic type downcasting and only need to use the refcounting
};
Expand Down
2 changes: 1 addition & 1 deletion jvm/core/src/main/java/org/apache/tvm/Module.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ private static Function getApi(String name) {
}

private Function entry = null;
private final String entryName = "__tvm_ffi_main__";
private final String entryName = "main";


/**
Expand Down
8 changes: 7 additions & 1 deletion src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,11 @@ void CodeGenCPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
}

void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) {
if (module_->getFunction(ffi::symbol::tvm_ffi_main) != nullptr) {
// main already exists, no need to create a wrapper function
// main takes precedence over other entry functions
return;
}
// create a wrapper function with tvm_ffi_main name and redirects to the entry function
llvm::Function* target_func = module_->getFunction(entry_func_name);
ICHECK(target_func) << "Function " << entry_func_name << " does not exist in module";
Expand Down Expand Up @@ -857,8 +862,9 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>&
call_args.push_back(GetPackedFuncHandle(func_name));
call_args.insert(call_args.end(), {packed_args, ConstInt32(nargs), result});
} else {
// directly call into symbol, needs to prefix with tvm_ffi_symbol_prefix
callee_ftype = ftype_tvm_ffi_c_func_;
callee_value = module_->getFunction(func_name);
callee_value = module_->getFunction(ffi::symbol::tvm_ffi_symbol_prefix + func_name);
if (callee_value == nullptr) {
callee_value = llvm::Function::Create(ftype_tvm_ffi_c_func_, llvm::Function::ExternalLinkage,
func_name, module_.get());
Expand Down
6 changes: 4 additions & 2 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ Optional<ffi::Function> LLVMModuleNode::GetFunction(const String& name) {

TVMFFISafeCallType faddr;
With<LLVMTarget> llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_));
faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(name, *llvm_target));
String name_with_prefix = ffi::symbol::tvm_ffi_symbol_prefix + name;
faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(name_with_prefix, *llvm_target));
if (faddr == nullptr) return std::nullopt;
ffi::Module self_strong_ref = GetRef<ffi::Module>(this);
return ffi::Function::FromPacked([faddr, self_strong_ref](ffi::PackedArgs args, ffi::Any* rv) {
Expand Down Expand Up @@ -386,7 +387,8 @@ void LLVMModuleNode::LoadIR(const std::string& file_name) {
}

bool LLVMModuleNode::ImplementsFunction(const String& name) {
return std::find(function_names_.begin(), function_names_.end(), name) != function_names_.end();
return std::find(function_names_.begin(), function_names_.end(),
ffi::symbol::tvm_ffi_symbol_prefix + name) != function_names_.end();
}

void LLVMModuleNode::InitMCJIT() {
Expand Down
4 changes: 3 additions & 1 deletion src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ void CodeGenC::DeclareFunction(const GlobalVar& gvar, const PrimFunc& func) {
return gvar->name_hint;
}
}();

if (function_name == ffi::symbol::tvm_ffi_main) {
has_tvm_ffi_main_func_ = true;
}
internal_functions_.insert({gvar, function_name});

InitFuncState(func);
Expand Down
2 changes: 2 additions & 0 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
Integer constants_byte_alignment_ = 16;
/*! \brief whether to print in SSA form */
bool print_ssa_form_{false};
/*! \brief whether the module has a main function declared */
bool has_tvm_ffi_main_func_{false};

private:
/*! \brief set of volatile buf access */
Expand Down
8 changes: 5 additions & 3 deletions src/target/source/codegen_c_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
namespace tvm {
namespace codegen {

CodeGenCHost::CodeGenCHost() { module_name_ = name_supply_->FreshName("__tvm_ffi_library_ctx"); }
CodeGenCHost::CodeGenCHost() {
module_name_ = name_supply_->FreshName(ffi::symbol::tvm_ffi_library_ctx);
}

void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl,
std::string target_str, const std::unordered_set<std::string>& devices) {
Expand Down Expand Up @@ -72,7 +74,7 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func,

emit_fwd_func_decl_ = emit_fwd_func_decl;
CodeGenC::AddFunction(gvar, func);
if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc) && !has_tvm_ffi_main_func_) {
ICHECK(global_symbol.has_value())
<< "CodeGenCHost: The entry func must have the global_symbol attribute, "
<< "but function " << gvar << " only has attributes " << func->attrs;
Expand Down Expand Up @@ -235,7 +237,7 @@ void CodeGenCHost::PrintCallPacked(const CallNode* op) {
} else {
// directly use the original symbol
ICHECK(op->op.same_as(builtin::tvm_call_cpacked_lowered()));
packed_func_name = func_name->value;
packed_func_name = ffi::symbol::tvm_ffi_symbol_prefix + func_name->value;
}

std::string args_stack = PrintExpr(op->args[1]);
Expand Down
3 changes: 3 additions & 0 deletions src/target/source/codegen_c_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class CodeGenCHost : public CodeGenC {
const std::unordered_set<std::string>& devices);

void InitGlobalContext();

void AddFunction(const GlobalVar& gvar, const PrimFunc& f) override;
void AddFunction(const GlobalVar& gvar, const PrimFunc& f, bool emit_fwd_func_decl);
/*!
Expand Down Expand Up @@ -83,6 +84,8 @@ class CodeGenCHost : public CodeGenC {
bool emit_asserts_;
/*! \brief whether to emit forwared function declarations in the resulting C code */
bool emit_fwd_func_decl_;
/*! \brief whether to generate the entry function if encountered */
bool has_main_func_ = false;

std::string GetPackedName(const CallNode* op);
void PrintGetFuncFromBackend(const std::string& func_name, const std::string& packed_func_name);
Expand Down
14 changes: 9 additions & 5 deletions src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
/*!
* \file make_packed_api.cc Lower PrimFunc to use the packed function API.
*/
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/device_api.h>
Expand Down Expand Up @@ -196,7 +197,7 @@ Optional<String> RequiresPackedAPI(const PrimFunc& func) {
return std::nullopt;
}

return global_symbol;
return global_symbol.value();
}

PrimFunc MakePackedAPI(PrimFunc func) {
Expand All @@ -223,6 +224,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
}

auto* func_ptr = func.CopyOnWrite();
// set the global symbol to the packed function name
const Stmt nop = Evaluate(0);
int num_args = static_cast<int>(func_ptr->params.size());

Expand Down Expand Up @@ -362,10 +364,12 @@ PrimFunc MakePackedAPI(PrimFunc func) {
binder.BindDLTensor(buffer, device_type, device_id, var, name_hint + "." + var->name_hint);
arg_buffer_declarations.push_back(DeclBuffer(buffer, nop));
}

func = WithAttrs(std::move(func),
{{tvm::attr::kCallingConv, static_cast<int>(CallingConv::kCPackedFunc)},
{tvm::attr::kTarget, target_host}});
// reset global symbol to attach prefix
func = WithAttrs(
std::move(func),
{{tvm::attr::kCallingConv, static_cast<int>(CallingConv::kCPackedFunc)},
{tvm::attr::kTarget, target_host},
{tvm::attr::kGlobalSymbol, ffi::symbol::tvm_ffi_symbol_prefix + global_symbol.value()}});

Stmt body = ReturnRewriter(v_result)(func_ptr->body);
body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
Expand Down
10 changes: 1 addition & 9 deletions tests/python/codegen/test_target_codegen_c_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,17 +184,9 @@ def subroutine(A_data: T.handle("float32")):

built = tvm.tir.build(mod, target="c")

func_names = list(built["get_func_names"]())
assert (
"main" in func_names
), "Externally exposed functions should be listed in available functions."
assert (
"subroutine" not in func_names
), "Internal function should not be listed in available functions."

source = built.inspect_source()
assert (
source.count("main(void*") == 2
source.count("__tvm_ffi_main(void*") == 2
), "Expected two occurrences, for forward-declaration and definition"
assert (
source.count("subroutine(float*") == 2
Expand Down
5 changes: 4 additions & 1 deletion tests/python/codegen/test_target_codegen_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,10 @@ def test_llvm_target_attributes():
assert re.match('.*"target-cpu"="skylake".*', attribute_definitions[k])
assert re.match('.*"target-features"=".*[+]avx512f.*".*', attribute_definitions[k])

expected_functions = ["test_func", "test_func_compute_", "__tvm_parallel_lambda"]
expected_functions = [
"__tvm_ffi_test_func",
"__tvm_parallel_lambda",
]
for n in expected_functions:
assert n in functions_with_target

Expand Down
10 changes: 3 additions & 7 deletions tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

""" Test different strategies for loading data into vtcm before running HVX workloads. """
"""Test different strategies for loading data into vtcm before running HVX workloads."""

import numpy as np
import pytest
Expand Down Expand Up @@ -287,13 +287,9 @@ def evaluate(

if tvm.testing.utils.IS_IN_CI:
# Run with reduced number and repeat for CI
timer = module.time_evaluator(
"__tvm_ffi_main__", hexagon_session.device, number=1, repeat=1
)
timer = module.time_evaluator("main", hexagon_session.device, number=1, repeat=1)
else:
timer = module.time_evaluator(
"__tvm_ffi_main__", hexagon_session.device, number=10, repeat=10
)
timer = module.time_evaluator("main", hexagon_session.device, number=10, repeat=10)

time = timer(a_hexagon, b_hexagon, c_hexagon)
if expected_output is not None:
Expand Down
4 changes: 1 addition & 3 deletions tests/python/contrib/test_hexagon/test_parallel_hvx.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,7 @@ def evaluate(hexagon_session, shape_dtypes, expected_output_producer, sch):
number = 1
repeat = 1

timer = module.time_evaluator(
"__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat
)
timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat)
runtime = timer(a_hexagon, b_hexagon, c_hexagon)
tvm.testing.assert_allclose(c_hexagon.numpy(), expected_output_producer(c_shape, a, b))

Expand Down
Loading
Loading