Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/target/llvm/llvm_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,23 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) {
}
}

// Target options
// LLVM JIT engine options
if (const Optional<String>& v = target->GetAttr<String>("jit")) {
String value = v.value();
if ((value == "mcjit") || (value == "orcjit")) {
jit_engine_ = value;
} else {
LOG(FATAL) << "invalid jit option " << value << " (can be `mcjit` or `orcjit`).";
}
}

// RISCV code model
auto arch = llvm::Triple(triple_).getArch();
if (arch == llvm::Triple::riscv32 || arch == llvm::Triple::riscv64) {
code_model_ = llvm::CodeModel::Medium;
}

// Target options
#if TVM_LLVM_VERSION < 50
target_options_.LessPreciseFPMADOption = true;
#endif
Expand Down Expand Up @@ -525,6 +540,10 @@ std::string LLVMTargetInfo::str() const {
os << quote << Join(",", opts) << quote;
}

if (jit_engine_ != "mcjit") {
os << " -jit=" << jit_engine_;
}

return os.str();
}

Expand Down
6 changes: 6 additions & 0 deletions src/target/llvm/llvm_instance.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,11 @@ class LLVMTargetInfo {
* \return `llvm::FastMathFlags` for this target
*/
llvm::FastMathFlags GetFastMathFlags() const { return fast_math_flags_; }
/*!
* \brief Get the LLVM JIT engine type
* \return the type name of the JIT engine (default "mcjit" or "orcjit")
*/
const std::string GetJITEngine() const { return jit_engine_; }
/*!
* \brief Get the LLVM optimization level
* \return optimization level for this target
Expand Down Expand Up @@ -324,6 +329,7 @@ class LLVMTargetInfo {
llvm::Reloc::Model reloc_model_ = llvm::Reloc::PIC_;
llvm::CodeModel::Model code_model_ = llvm::CodeModel::Small;
std::shared_ptr<llvm::TargetMachine> target_machine_;
std::string jit_engine_ = "mcjit";
};

/*!
Expand Down
197 changes: 177 additions & 20 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
#include <llvm/ADT/StringRef.h>
#include <llvm/Bitcode/BitcodeWriter.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
#include <llvm/ExecutionEngine/MCJIT.h> // Force linking of MCJIT
#include <llvm/ExecutionEngine/MCJIT.h>
#include <llvm/ExecutionEngine/Orc/LLJIT.h>
#include <llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h>
#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
#include <llvm/IR/DataLayout.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/Intrinsics.h>
Expand Down Expand Up @@ -113,8 +116,11 @@ class LLVMModuleNode final : public runtime::ModuleNode {

bool ImplementsFunction(const String& name, bool query_imports) final;

void SetJITEngine(const std::string& jit_engine) { jit_engine_ = jit_engine; }

private:
void LazyInitJIT();
void InitMCJIT();
void InitORCJIT();
bool IsCompatibleWithHost(const llvm::TargetMachine* tm) const;
void* GetGlobalAddr(const std::string& name, const LLVMTarget& llvm_target) const;
void* GetFunctionAddr(const std::string& name, const LLVMTarget& llvm_target) const;
Expand All @@ -123,21 +129,31 @@ class LLVMModuleNode final : public runtime::ModuleNode {
std::unique_ptr<LLVMInstance> llvm_instance_;
// JIT lock
std::mutex mutex_;
// execution engine
llvm::ExecutionEngine* ee_{nullptr};
// jit execution engines
llvm::ExecutionEngine* mcjit_ee_{nullptr};
std::unique_ptr<llvm::orc::LLJIT> orcjit_ee_{nullptr};
// The raw pointer to the module.
llvm::Module* module_{nullptr};
// The unique_ptr owning the module. This becomes empty once JIT has been initialized
// (EngineBuilder takes ownership of the module).
std::unique_ptr<llvm::Module> module_owning_ptr_;
/* \brief names of the external functions declared in this module */
Array<String> function_names_;
std::string jit_engine_;
};

LLVMModuleNode::~LLVMModuleNode() {
if (ee_ != nullptr) {
ee_->runStaticConstructorsDestructors(true);
delete ee_;
if (mcjit_ee_ != nullptr) {
mcjit_ee_->runStaticConstructorsDestructors(true);
delete mcjit_ee_;
}
if (orcjit_ee_ != nullptr) {
auto dtors = llvm::orc::getDestructors(*module_);
auto dtorRunner = std::make_unique<llvm::orc::CtorDtorRunner>(orcjit_ee_->getMainJITDylib());
dtorRunner->add(dtors);
auto err = dtorRunner->run();
ICHECK(!err) << llvm::toString(std::move(err));
orcjit_ee_.reset();
}
module_owning_ptr_.reset();
}
Expand Down Expand Up @@ -166,7 +182,9 @@ PackedFunc LLVMModuleNode::GetFunction(const String& name, const ObjectPtr<Objec
std::string target_string = LLVMTarget::GetTargetMetadata(*module_);
return PackedFunc([target_string](TVMArgs args, TVMRetValue* rv) { *rv = target_string; });
}
if (ee_ == nullptr) LazyInitJIT();
ICHECK(jit_engine_.size()) << "JIT engine type is missing";
if ((jit_engine_ == "mcjit") && (mcjit_ee_ == nullptr)) InitMCJIT();
if ((jit_engine_ == "orcjit") && (orcjit_ee_ == nullptr)) InitORCJIT();

std::lock_guard<std::mutex> lock(mutex_);

Expand Down Expand Up @@ -353,6 +371,7 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) {

module_owning_ptr_ = cg->Finish();
module_ = module_owning_ptr_.get();
jit_engine_ = llvm_target->GetJITEngine();
llvm_target->SetTargetMetadata(module_);
module_->addModuleFlag(llvm::Module::Override, "Debug Info Version",
llvm::DEBUG_METADATA_VERSION);
Expand Down Expand Up @@ -384,13 +403,16 @@ bool LLVMModuleNode::ImplementsFunction(const String& name, bool query_imports)
return std::find(function_names_.begin(), function_names_.end(), name) != function_names_.end();
}

void LLVMModuleNode::LazyInitJIT() {
void LLVMModuleNode::InitMCJIT() {
std::lock_guard<std::mutex> lock(mutex_);
if (ee_) {
if (mcjit_ee_) {
return;
}
// MCJIT builder
With<LLVMTarget> llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_));
llvm::EngineBuilder builder(std::move(module_owning_ptr_));

// set options
builder.setEngineKind(llvm::EngineKind::JIT);
#if TVM_LLVM_VERSION <= 170
builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
Expand All @@ -400,18 +422,31 @@ void LLVMModuleNode::LazyInitJIT() {
builder.setMCPU(llvm_target->GetCPU());
builder.setMAttrs(llvm_target->GetTargetFeatures());
builder.setTargetOptions(llvm_target->GetTargetOptions());

// create the taget machine
auto tm = std::unique_ptr<llvm::TargetMachine>(builder.selectTarget());
if (!IsCompatibleWithHost(tm.get())) {
LOG(FATAL) << "Cannot run module, architecture mismatch";
}

// data layout
llvm::DataLayout layout(tm->createDataLayout());
ICHECK(layout == module_->getDataLayout())
<< "Data layout mismatch between module("
<< module_->getDataLayout().getStringRepresentation() << ")"
<< " and ExecutionEngine (" << layout.getStringRepresentation() << ")";
ee_ = builder.create(tm.release());
ICHECK(ee_ != nullptr) << "Failed to initialize jit engine for " << module_->getTargetTriple();
ee_->runStaticConstructorsDestructors(false);

// create MCJIT
mcjit_ee_ = builder.create(tm.release());
ICHECK(mcjit_ee_ != nullptr) << "Failed to initialize LLVM MCJIT engine for "
<< module_->getTargetTriple();

VLOG(2) << "LLVM MCJIT execute " << module_->getModuleIdentifier() << " for triple `"
<< llvm_target->GetTargetTriple() << "`"
<< " on cpu `" << llvm_target->GetCPU() << "`";

// run ctors
mcjit_ee_->runStaticConstructorsDestructors(false);

if (void** ctx_addr =
reinterpret_cast<void**>(GetGlobalAddr(runtime::symbol::tvm_module_ctx, *llvm_target))) {
Expand All @@ -424,7 +459,104 @@ void LLVMModuleNode::LazyInitJIT() {
// lead to a runtime crash.
// Do name lookup on a symbol that doesn't exist. This will force MCJIT to finalize
// all loaded objects, which will resolve symbols in JITed code.
ee_->getFunctionAddress("__some_name_that_hopefully_doesnt_exist__b49f8aaade5877eaba7583b91");
mcjit_ee_->getFunctionAddress(
"__some_name_that_hopefully_doesnt_exist__b49f8aaade5877eaba7583b91");
}

void LLVMModuleNode::InitORCJIT() {
std::lock_guard<std::mutex> lock(mutex_);
if (orcjit_ee_) {
return;
}
// ORCJIT builder
With<LLVMTarget> llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_));
llvm::orc::JITTargetMachineBuilder tm_builder(llvm::Triple(llvm_target->GetTargetTriple()));

// set options
tm_builder.setCPU(llvm_target->GetCPU());
tm_builder.setFeatures(llvm_target->GetTargetFeatureString());
tm_builder.setOptions(llvm_target->GetTargetOptions());
#if TVM_LLVM_VERSION <= 170
tm_builder.setCodeGenOptLevel(llvm::CodeGenOpt::Aggressive);
#else
tm_builder.setCodeGenOptLevel(llvm::CodeGenOptLevel::Aggressive);
#endif

// create the taget machine
std::unique_ptr<llvm::TargetMachine> tm = llvm::cantFail(tm_builder.createTargetMachine());
if (!IsCompatibleWithHost(tm.get())) {
LOG(FATAL) << "Cannot run module, architecture mismatch";
}

// data layout
String module_name = module_->getModuleIdentifier();
llvm::DataLayout layout(tm->createDataLayout());
ICHECK(layout == module_->getDataLayout())
<< "Data layout mismatch between module("
<< module_->getDataLayout().getStringRepresentation() << ")"
<< " and ExecutionEngine (" << layout.getStringRepresentation() << ")";

// compiler
const auto compilerBuilder = [&](const llvm::orc::JITTargetMachineBuilder&)
-> llvm::Expected<std::unique_ptr<llvm::orc::IRCompileLayer::IRCompiler>> {
return std::make_unique<llvm::orc::TMOwningSimpleCompiler>(std::move(tm));
};

#if TVM_LLVM_VERSION >= 130
// linker
const auto linkerBuilder = [&](llvm::orc::ExecutionSession& session, const llvm::Triple&) {
return std::make_unique<llvm::orc::ObjectLinkingLayer>(session);
};
#endif

// create LLJIT
orcjit_ee_ = llvm::cantFail(llvm::orc::LLJITBuilder()
#if TVM_LLVM_VERSION >= 110
.setDataLayout(layout)
#endif
.setCompileFunctionCreator(compilerBuilder)
#if TVM_LLVM_VERSION >= 130
.setObjectLinkingLayerCreator(linkerBuilder)
#endif
.create());

ICHECK(orcjit_ee_ != nullptr) << "Failed to initialize LLVM ORCJIT engine for "
<< module_->getTargetTriple();

// store ctors
auto ctors = llvm::orc::getConstructors(*module_);
llvm::orc::CtorDtorRunner ctorRunner(orcjit_ee_->getMainJITDylib());
ctorRunner.add(ctors);

// resolve system symbols (like pthread, dl, m, etc.)
auto gen =
llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(layout.getGlobalPrefix());
ICHECK(gen) << llvm::toString(gen.takeError()) << "\n";
orcjit_ee_->getMainJITDylib().addGenerator(std::move(gen.get()));

// transfer module to a clone
auto uctx = std::make_unique<llvm::LLVMContext>();
auto umod = llvm::CloneModule(*(std::move(module_owning_ptr_)));

// add the llvm module to run
llvm::orc::ThreadSafeModule tsm(std::move(umod), std::move(uctx));
auto err = orcjit_ee_->addIRModule(std::move(tsm));
ICHECK(!err) << llvm::toString(std::move(err));

VLOG(2) << "LLVM ORCJIT execute " << module_->getModuleIdentifier() << " for triple `"
<< llvm_target->GetTargetTriple() << "`"
<< " on cpu `" << llvm_target->GetCPU() << "`";

// run ctors
err = ctorRunner.run();
ICHECK(!err) << llvm::toString(std::move(err));

if (void** ctx_addr =
reinterpret_cast<void**>(GetGlobalAddr(runtime::symbol::tvm_module_ctx, *llvm_target))) {
*ctx_addr = this;
}
runtime::InitContextFunctions(
[this, &llvm_target](const char* name) { return GetGlobalAddr(name, *llvm_target); });
}

bool LLVMModuleNode::IsCompatibleWithHost(const llvm::TargetMachine* tm) const {
Expand All @@ -442,20 +574,40 @@ bool LLVMModuleNode::IsCompatibleWithHost(const llvm::TargetMachine* tm) const {
void* LLVMModuleNode::GetGlobalAddr(const std::string& name, const LLVMTarget& llvm_target) const {
// first verifies if GV exists.
if (module_->getGlobalVariable(name) != nullptr) {
return reinterpret_cast<void*>(ee_->getGlobalValueAddress(name));
} else {
return nullptr;
if (jit_engine_ == "mcjit") {
return reinterpret_cast<void*>(mcjit_ee_->getGlobalValueAddress(name));
} else if (jit_engine_ == "orcjit") {
#if TVM_LLVM_VERSION >= 150
auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getValue();
#else
auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getAddress();
#endif
return reinterpret_cast<void*>(addr);
} else {
LOG(FATAL) << "Either `mcjit` or `orcjit` are not initialized.";
}
}
return nullptr;
}

void* LLVMModuleNode::GetFunctionAddr(const std::string& name,
const LLVMTarget& llvm_target) const {
// first verifies if GV exists.
if (module_->getFunction(name) != nullptr) {
return reinterpret_cast<void*>(ee_->getFunctionAddress(name));
} else {
return nullptr;
if (jit_engine_ == "mcjit") {
return reinterpret_cast<void*>(mcjit_ee_->getFunctionAddress(name));
} else if (jit_engine_ == "orcjit") {
#if TVM_LLVM_VERSION >= 150
auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getValue();
#else
auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getAddress();
#endif
return reinterpret_cast<void*>(addr);
} else {
LOG(FATAL) << "Either `mcjit` or `orcjit` are not initialized.";
}
}
return nullptr;
}

TVM_REGISTER_GLOBAL("target.build.llvm")
Expand All @@ -476,6 +628,7 @@ TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate")
module->setTargetTriple(llvm_target->GetTargetTriple());
module->setDataLayout(llvm_target->GetOrCreateTargetMachine()->createDataLayout());
n->Init(std::move(module), std::move(llvm_instance));
n->SetJITEngine(llvm_target->GetJITEngine());
return runtime::Module(n);
});

Expand Down Expand Up @@ -595,6 +748,7 @@ TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int {
TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll")
.set_body_typed([](std::string filename, std::string fmt) -> runtime::Module {
auto n = make_object<LLVMModuleNode>();
n->SetJITEngine("mcjit");
n->LoadIR(filename);
return runtime::Module(n);
});
Expand All @@ -616,6 +770,7 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob")
std::unique_ptr<llvm::Module> blob =
CodeGenBlob(data, system_lib, llvm_target.get(), c_symbol_prefix);
n->Init(std::move(blob), std::move(llvm_instance));
n->SetJITEngine(llvm_target->GetJITEngine());
return runtime::Module(n);
});

Expand Down Expand Up @@ -645,6 +800,7 @@ runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata

auto n = make_object<LLVMModuleNode>();
n->Init(std::move(mod), std::move(llvm_instance));
n->SetJITEngine(llvm_target->GetJITEngine());

auto meta_mod = MetadataModuleCreate(metadata);
meta_mod->Import(runtime::Module(n));
Expand Down Expand Up @@ -691,6 +847,7 @@ runtime::Module CreateLLVMCrtMetadataModule(const Array<runtime::Module>& module

auto n = make_object<LLVMModuleNode>();
n->Init(std::move(mod), std::move(llvm_instance));
n->SetJITEngine(llvm_target->GetJITEngine());
for (auto m : modules) {
n->Import(m);
}
Expand Down
2 changes: 2 additions & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
.add_attr_option<Integer>("opt-level")
// LLVM command line flags, see below
.add_attr_option<Array<String>>("cl-opt")
// LLVM JIT engine mcjit/orcjit
.add_attr_option<String>("jit")
.set_default_keys({"cpu"})
// Force the external codegen kind attribute to be registered, even if no external
// codegen targets are enabled by the TVM build.
Expand Down
Loading