Skip to content

Commit 6b42ee9

Browse files
committed
[LLVM][RUNTIME] Add optional LLVM ORCJIT runtime executor
1 parent 78a6146 commit 6b42ee9

File tree

7 files changed

+232
-36
lines changed

7 files changed

+232
-36
lines changed

src/target/llvm/llvm_instance.cc

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,23 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) {
252252
}
253253
}
254254

255-
// Target options
255+
// LLVM JIT engine options
256+
if (const Optional<String>& v = target->GetAttr<String>("jit")) {
257+
String value = v.value();
258+
if ((value == "mcjit") || (value == "orcjit")) {
259+
jit_engine_ = value;
260+
} else {
261+
LOG(FATAL) << "invalid jit option " << value << " (can be `mcjit` or `orcjit`).";
262+
}
263+
}
256264

265+
// RISCV code model
266+
auto arch = llvm::Triple(triple_).getArch();
267+
if (arch == llvm::Triple::riscv32 || arch == llvm::Triple::riscv64) {
268+
code_model_ = llvm::CodeModel::Medium;
269+
}
270+
271+
// Target options
257272
#if TVM_LLVM_VERSION < 50
258273
target_options_.LessPreciseFPMADOption = true;
259274
#endif
@@ -521,6 +536,10 @@ std::string LLVMTargetInfo::str() const {
521536
os << quote << Join(",", opts) << quote;
522537
}
523538

539+
if (jit_engine_ != "mcjit") {
540+
os << " -jit=" << jit_engine_;
541+
}
542+
524543
return os.str();
525544
}
526545

src/target/llvm/llvm_instance.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,11 @@ class LLVMTargetInfo {
212212
* \return `llvm::FastMathFlags` for this target
213213
*/
214214
llvm::FastMathFlags GetFastMathFlags() const { return fast_math_flags_; }
215+
/*!
216+
* \brief Get the LLVM JIT engine type
217+
* \return the type name of the JIT engine (default "mcjit" or "orcjit")
218+
*/
219+
const std::string GetJITEngine() const { return jit_engine_; }
215220
/*!
216221
* \brief Get the LLVM optimization level
217222
* \return optimization level for this target
@@ -324,6 +329,7 @@ class LLVMTargetInfo {
324329
llvm::Reloc::Model reloc_model_ = llvm::Reloc::PIC_;
325330
llvm::CodeModel::Model code_model_ = llvm::CodeModel::Small;
326331
std::shared_ptr<llvm::TargetMachine> target_machine_;
332+
std::string jit_engine_ = "mcjit";
327333
};
328334

329335
/*!

src/target/llvm/llvm_module.cc

Lines changed: 178 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@
3030
#include <llvm/ADT/StringRef.h>
3131
#include <llvm/Bitcode/BitcodeWriter.h>
3232
#include <llvm/ExecutionEngine/ExecutionEngine.h>
33-
#include <llvm/ExecutionEngine/MCJIT.h> // Force linking of MCJIT
33+
#include <llvm/ExecutionEngine/MCJIT.h>
34+
#include <llvm/ExecutionEngine/Orc/LLJIT.h>
35+
#include <llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h>
36+
#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
3437
#include <llvm/IR/DataLayout.h>
3538
#include <llvm/IR/Function.h>
3639
#include <llvm/IR/Intrinsics.h>
@@ -109,8 +112,11 @@ class LLVMModuleNode final : public runtime::ModuleNode {
109112

110113
bool ImplementsFunction(const String& name, bool query_imports) final;
111114

115+
void SetJITEngine(const std::string& jit_engine) { jit_engine_ = jit_engine; }
116+
112117
private:
113-
void LazyInitJIT();
118+
void InitMCJIT();
119+
void InitORCJIT();
114120
bool IsCompatibleWithHost(const llvm::TargetMachine* tm) const;
115121
void* GetGlobalAddr(const std::string& name, const LLVMTarget& llvm_target) const;
116122
void* GetFunctionAddr(const std::string& name, const LLVMTarget& llvm_target) const;
@@ -119,21 +125,31 @@ class LLVMModuleNode final : public runtime::ModuleNode {
119125
std::unique_ptr<LLVMInstance> llvm_instance_;
120126
// JIT lock
121127
std::mutex mutex_;
122-
// execution engine
123-
llvm::ExecutionEngine* ee_{nullptr};
128+
// jit execution engines
129+
llvm::ExecutionEngine* mcjit_ee_{nullptr};
130+
std::unique_ptr<llvm::orc::LLJIT> orcjit_ee_{nullptr};
124131
// The raw pointer to the module.
125132
llvm::Module* module_{nullptr};
126133
// The unique_ptr owning the module. This becomes empty once JIT has been initialized
127134
// (EngineBuilder takes ownership of the module).
128135
std::unique_ptr<llvm::Module> module_owning_ptr_;
129136
/* \brief names of the external functions declared in this module */
130137
Array<String> function_names_;
138+
std::string jit_engine_;
131139
};
132140

133141
LLVMModuleNode::~LLVMModuleNode() {
134-
if (ee_ != nullptr) {
135-
ee_->runStaticConstructorsDestructors(true);
136-
delete ee_;
142+
if (mcjit_ee_ != nullptr) {
143+
mcjit_ee_->runStaticConstructorsDestructors(true);
144+
delete mcjit_ee_;
145+
}
146+
if (orcjit_ee_ != nullptr) {
147+
auto dtors = llvm::orc::getDestructors(*module_);
148+
auto dtorRunner = std::make_unique<llvm::orc::CtorDtorRunner>(orcjit_ee_->getMainJITDylib());
149+
dtorRunner->add(dtors);
150+
auto err = dtorRunner->run();
151+
ICHECK(!err) << llvm::toString(std::move(err));
152+
orcjit_ee_.reset();
137153
}
138154
module_owning_ptr_.reset();
139155
}
@@ -162,7 +178,9 @@ PackedFunc LLVMModuleNode::GetFunction(const String& name, const ObjectPtr<Objec
162178
std::string target_string = LLVMTarget::GetTargetMetadata(*module_);
163179
return PackedFunc([target_string](TVMArgs args, TVMRetValue* rv) { *rv = target_string; });
164180
}
165-
if (ee_ == nullptr) LazyInitJIT();
181+
ICHECK(jit_engine_.size()) << "JIT engine type is missing";
182+
if ((jit_engine_ == "mcjit") && (mcjit_ee_ == nullptr)) InitMCJIT();
183+
if ((jit_engine_ == "orcjit") && (orcjit_ee_ == nullptr)) InitORCJIT();
166184

167185
std::lock_guard<std::mutex> lock(mutex_);
168186

@@ -349,6 +367,7 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) {
349367

350368
module_owning_ptr_ = cg->Finish();
351369
module_ = module_owning_ptr_.get();
370+
jit_engine_ = llvm_target->GetJITEngine();
352371
llvm_target->SetTargetMetadata(module_);
353372
module_->addModuleFlag(llvm::Module::Override, "Debug Info Version",
354373
llvm::DEBUG_METADATA_VERSION);
@@ -381,13 +400,16 @@ bool LLVMModuleNode::ImplementsFunction(const String& name, bool query_imports)
381400
return std::find(function_names_.begin(), function_names_.end(), name) != function_names_.end();
382401
}
383402

384-
void LLVMModuleNode::LazyInitJIT() {
403+
void LLVMModuleNode::InitMCJIT() {
385404
std::lock_guard<std::mutex> lock(mutex_);
386-
if (ee_) {
405+
if (mcjit_ee_) {
387406
return;
388407
}
408+
// MCJIT builder
389409
With<LLVMTarget> llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_));
390410
llvm::EngineBuilder builder(std::move(module_owning_ptr_));
411+
412+
// set options
391413
builder.setEngineKind(llvm::EngineKind::JIT);
392414
#if TVM_LLVM_VERSION <= 170
393415
builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
@@ -397,18 +419,32 @@ void LLVMModuleNode::LazyInitJIT() {
397419
builder.setMCPU(llvm_target->GetCPU());
398420
builder.setMAttrs(llvm_target->GetTargetFeatures());
399421
builder.setTargetOptions(llvm_target->GetTargetOptions());
422+
423+
// create the taget machine
400424
auto tm = std::unique_ptr<llvm::TargetMachine>(builder.selectTarget());
401425
if (!IsCompatibleWithHost(tm.get())) {
402426
LOG(FATAL) << "Cannot run module, architecture mismatch";
403427
}
428+
429+
// data layout
404430
llvm::DataLayout layout(tm->createDataLayout());
405431
ICHECK(layout == module_->getDataLayout())
406432
<< "Data layout mismatch between module("
407433
<< module_->getDataLayout().getStringRepresentation() << ")"
408434
<< " and ExecutionEngine (" << layout.getStringRepresentation() << ")";
409-
ee_ = builder.create(tm.release());
410-
ICHECK(ee_ != nullptr) << "Failed to initialize jit engine for " << module_->getTargetTriple();
411-
ee_->runStaticConstructorsDestructors(false);
435+
436+
// create MCJIT
437+
mcjit_ee_ = builder.create(tm.release());
438+
ICHECK(mcjit_ee_ != nullptr) << "Failed to initialize LLVM MCJIT engine for "
439+
<< module_->getTargetTriple();
440+
441+
VLOG(2) << "LLVM MCJIT execute " << module_->getModuleIdentifier() << " for triple `"
442+
<< llvm_target->GetTargetTriple() << "`"
443+
<< " on cpu `" << llvm_target->GetCPU() << "`";
444+
445+
// run ctors
446+
module_->getTargetTriple();
447+
mcjit_ee_->runStaticConstructorsDestructors(false);
412448

413449
if (void** ctx_addr =
414450
reinterpret_cast<void**>(GetGlobalAddr(runtime::symbol::tvm_module_ctx, *llvm_target))) {
@@ -421,7 +457,104 @@ void LLVMModuleNode::LazyInitJIT() {
421457
// lead to a runtime crash.
422458
// Do name lookup on a symbol that doesn't exist. This will force MCJIT to finalize
423459
// all loaded objects, which will resolve symbols in JITed code.
424-
ee_->getFunctionAddress("__some_name_that_hopefully_doesnt_exist__b49f8aaade5877eaba7583b91");
460+
mcjit_ee_->getFunctionAddress(
461+
"__some_name_that_hopefully_doesnt_exist__b49f8aaade5877eaba7583b91");
462+
}
463+
464+
void LLVMModuleNode::InitORCJIT() {
465+
std::lock_guard<std::mutex> lock(mutex_);
466+
if (orcjit_ee_) {
467+
return;
468+
}
469+
// ORCJIT builder
470+
With<LLVMTarget> llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_));
471+
llvm::orc::JITTargetMachineBuilder tm_builder(llvm::Triple(llvm_target->GetTargetTriple()));
472+
473+
// set options
474+
tm_builder.setCPU(llvm_target->GetCPU());
475+
tm_builder.setFeatures(llvm_target->GetTargetFeatureString());
476+
tm_builder.setOptions(llvm_target->GetTargetOptions());
477+
#if TVM_LLVM_VERSION <= 170
478+
tm_builder.setCodeGenOptLevel(llvm::CodeGenOpt::Aggressive);
479+
#else
480+
tm_builder.setCodeGenOptLevel(llvm::CodeGenOptLevel::Aggressive);
481+
#endif
482+
483+
// create the taget machine
484+
auto tm = tm_builder.createTargetMachine();
485+
if (!IsCompatibleWithHost(tm->get())) {
486+
LOG(FATAL) << "Cannot run module, architecture mismatch";
487+
}
488+
489+
// data layout
490+
String module_name = module_->getModuleIdentifier();
491+
llvm::DataLayout layout(tm->get()->createDataLayout());
492+
ICHECK(layout == module_->getDataLayout())
493+
<< "Data layout mismatch between module("
494+
<< module_->getDataLayout().getStringRepresentation() << ")"
495+
<< " and ExecutionEngine (" << layout.getStringRepresentation() << ")";
496+
497+
// compiler
498+
const auto compilerBuilder = [&](const llvm::orc::JITTargetMachineBuilder&)
499+
-> llvm::Expected<std::unique_ptr<llvm::orc::IRCompileLayer::IRCompiler>> {
500+
return std::make_unique<llvm::orc::TMOwningSimpleCompiler>(std::move(*tm));
501+
};
502+
503+
#if TVM_LLVM_VERSION >= 130
504+
// linker
505+
const auto linkerBuilder = [&](llvm::orc::ExecutionSession& session, const llvm::Triple&) {
506+
return std::make_unique<llvm::orc::ObjectLinkingLayer>(session);
507+
};
508+
#endif
509+
510+
// create LLJIT
511+
orcjit_ee_ = llvm::cantFail(llvm::orc::LLJITBuilder()
512+
#if TVM_LLVM_VERSION >= 110
513+
.setDataLayout(layout)
514+
#endif
515+
.setCompileFunctionCreator(compilerBuilder)
516+
#if TVM_LLVM_VERSION >= 130
517+
.setObjectLinkingLayerCreator(linkerBuilder)
518+
#endif
519+
.create());
520+
521+
ICHECK(orcjit_ee_ != nullptr) << "Failed to initialize LLVM ORCJIT engine for "
522+
<< module_->getTargetTriple();
523+
524+
// store ctors
525+
auto ctors = llvm::orc::getConstructors(*module_);
526+
llvm::orc::CtorDtorRunner ctorRunner(orcjit_ee_->getMainJITDylib());
527+
ctorRunner.add(ctors);
528+
529+
// resolve system symbols (like pthread, dl, m, etc.)
530+
auto gen =
531+
llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(layout.getGlobalPrefix());
532+
ICHECK(gen) << llvm::toString(gen.takeError()) << "\n";
533+
orcjit_ee_->getMainJITDylib().addGenerator(std::move(gen.get()));
534+
535+
// transfer module to a clone
536+
auto uctx = std::make_unique<llvm::LLVMContext>();
537+
auto umod = llvm::CloneModule(*(std::move(module_owning_ptr_)));
538+
539+
// add the llvm module to run
540+
llvm::orc::ThreadSafeModule tsm(std::move(umod), std::move(uctx));
541+
auto err = orcjit_ee_->addIRModule(std::move(tsm));
542+
ICHECK(!err) << llvm::toString(std::move(err));
543+
544+
VLOG(2) << "LLVM ORCJIT execute " << module_->getModuleIdentifier() << " for triple `"
545+
<< llvm_target->GetTargetTriple() << "`"
546+
<< " on cpu `" << llvm_target->GetCPU() << "`";
547+
548+
// run ctors
549+
err = ctorRunner.run();
550+
ICHECK(!err) << llvm::toString(std::move(err));
551+
552+
if (void** ctx_addr =
553+
reinterpret_cast<void**>(GetGlobalAddr(runtime::symbol::tvm_module_ctx, *llvm_target))) {
554+
*ctx_addr = this;
555+
}
556+
runtime::InitContextFunctions(
557+
[this, &llvm_target](const char* name) { return GetGlobalAddr(name, *llvm_target); });
425558
}
426559

427560
bool LLVMModuleNode::IsCompatibleWithHost(const llvm::TargetMachine* tm) const {
@@ -439,20 +572,40 @@ bool LLVMModuleNode::IsCompatibleWithHost(const llvm::TargetMachine* tm) const {
439572
void* LLVMModuleNode::GetGlobalAddr(const std::string& name, const LLVMTarget& llvm_target) const {
440573
// first verifies if GV exists.
441574
if (module_->getGlobalVariable(name) != nullptr) {
442-
return reinterpret_cast<void*>(ee_->getGlobalValueAddress(name));
443-
} else {
444-
return nullptr;
575+
if (jit_engine_ == "mcjit") {
576+
return reinterpret_cast<void*>(mcjit_ee_->getGlobalValueAddress(name));
577+
} else if (jit_engine_ == "orcjit") {
578+
#if TVM_LLVM_VERSION >= 150
579+
auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getValue();
580+
#else
581+
auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getAddress();
582+
#endif
583+
return reinterpret_cast<void*>(addr);
584+
} else {
585+
LOG(FATAL) << "Either `mcjit` or `orcjit` are not initialized.";
586+
}
445587
}
588+
return nullptr;
446589
}
447590

448591
void* LLVMModuleNode::GetFunctionAddr(const std::string& name,
449592
const LLVMTarget& llvm_target) const {
450593
// first verifies if GV exists.
451594
if (module_->getFunction(name) != nullptr) {
452-
return reinterpret_cast<void*>(ee_->getFunctionAddress(name));
453-
} else {
454-
return nullptr;
595+
if (jit_engine_ == "mcjit") {
596+
return reinterpret_cast<void*>(mcjit_ee_->getFunctionAddress(name));
597+
} else if (jit_engine_ == "orcjit") {
598+
#if TVM_LLVM_VERSION >= 150
599+
auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getValue();
600+
#else
601+
auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getAddress();
602+
#endif
603+
return reinterpret_cast<void*>(addr);
604+
} else {
605+
LOG(FATAL) << "Either `mcjit` or `orcjit` are not initialized.";
606+
}
455607
}
608+
return nullptr;
456609
}
457610

458611
TVM_REGISTER_GLOBAL("target.build.llvm")
@@ -473,6 +626,7 @@ TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate")
473626
module->setTargetTriple(llvm_target->GetTargetTriple());
474627
module->setDataLayout(llvm_target->GetOrCreateTargetMachine()->createDataLayout());
475628
n->Init(std::move(module), std::move(llvm_instance));
629+
n->SetJITEngine(llvm_target->GetJITEngine());
476630
return runtime::Module(n);
477631
});
478632

@@ -592,6 +746,7 @@ TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int {
592746
TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll")
593747
.set_body_typed([](std::string filename, std::string fmt) -> runtime::Module {
594748
auto n = make_object<LLVMModuleNode>();
749+
n->SetJITEngine("mcjit");
595750
n->LoadIR(filename);
596751
return runtime::Module(n);
597752
});
@@ -613,6 +768,7 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob")
613768
std::unique_ptr<llvm::Module> blob =
614769
CodeGenBlob(data, system_lib, llvm_target.get(), c_symbol_prefix);
615770
n->Init(std::move(blob), std::move(llvm_instance));
771+
n->SetJITEngine(llvm_target->GetJITEngine());
616772
return runtime::Module(n);
617773
});
618774

@@ -642,6 +798,7 @@ runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata
642798

643799
auto n = make_object<LLVMModuleNode>();
644800
n->Init(std::move(mod), std::move(llvm_instance));
801+
n->SetJITEngine(llvm_target->GetJITEngine());
645802

646803
auto meta_mod = MetadataModuleCreate(metadata);
647804
meta_mod->Import(runtime::Module(n));
@@ -688,6 +845,7 @@ runtime::Module CreateLLVMCrtMetadataModule(const Array<runtime::Module>& module
688845

689846
auto n = make_object<LLVMModuleNode>();
690847
n->Init(std::move(mod), std::move(llvm_instance));
848+
n->SetJITEngine(llvm_target->GetJITEngine());
691849
for (auto m : modules) {
692850
n->Import(m);
693851
}

src/target/target_kind.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,8 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
291291
.add_attr_option<Integer>("opt-level")
292292
// LLVM command line flags, see below
293293
.add_attr_option<Array<String>>("cl-opt")
294+
// LLVM JIT engine mcjit/orcjit
295+
.add_attr_option<String>("jit")
294296
.set_default_keys({"cpu"})
295297
// Force the external codegen kind attribute to be registered, even if no external
296298
// codegen targets are enabled by the TVM build.

0 commit comments

Comments
 (0)