3030#include < llvm/ADT/StringRef.h>
3131#include < llvm/Bitcode/BitcodeWriter.h>
3232#include < llvm/ExecutionEngine/ExecutionEngine.h>
33- #include < llvm/ExecutionEngine/MCJIT.h> // Force linking of MCJIT
33+ #include < llvm/ExecutionEngine/MCJIT.h>
34+ #include < llvm/ExecutionEngine/Orc/LLJIT.h>
35+ #include < llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h>
36+ #include < llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
3437#include < llvm/IR/DataLayout.h>
3538#include < llvm/IR/Function.h>
3639#include < llvm/IR/Intrinsics.h>
@@ -109,8 +112,11 @@ class LLVMModuleNode final : public runtime::ModuleNode {
109112
110113 bool ImplementsFunction (const String& name, bool query_imports) final ;
111114
115+ void SetJITEngine (const std::string& jit_engine) { jit_engine_ = jit_engine; }
116+
112117 private:
113- void LazyInitJIT ();
118+ void InitMCJIT ();
119+ void InitORCJIT ();
114120 bool IsCompatibleWithHost (const llvm::TargetMachine* tm) const ;
115121 void * GetGlobalAddr (const std::string& name, const LLVMTarget& llvm_target) const ;
116122 void * GetFunctionAddr (const std::string& name, const LLVMTarget& llvm_target) const ;
@@ -119,21 +125,31 @@ class LLVMModuleNode final : public runtime::ModuleNode {
119125 std::unique_ptr<LLVMInstance> llvm_instance_;
120126 // JIT lock
121127 std::mutex mutex_;
122- // execution engine
123- llvm::ExecutionEngine* ee_{nullptr };
128+ // jit execution engines
129+ llvm::ExecutionEngine* mcjit_ee_{nullptr };
130+ std::unique_ptr<llvm::orc::LLJIT> orcjit_ee_{nullptr };
124131 // The raw pointer to the module.
125132 llvm::Module* module_{nullptr };
126133 // The unique_ptr owning the module. This becomes empty once JIT has been initialized
127134 // (EngineBuilder takes ownership of the module).
128135 std::unique_ptr<llvm::Module> module_owning_ptr_;
129136 /* \brief names of the external functions declared in this module */
130137 Array<String> function_names_;
138+ std::string jit_engine_;
131139};
132140
133141LLVMModuleNode::~LLVMModuleNode () {
134- if (ee_ != nullptr ) {
135- ee_->runStaticConstructorsDestructors (true );
136- delete ee_;
142+ if (mcjit_ee_ != nullptr ) {
143+ mcjit_ee_->runStaticConstructorsDestructors (true );
144+ delete mcjit_ee_;
145+ }
146+ if (orcjit_ee_ != nullptr ) {
147+ auto dtors = llvm::orc::getDestructors (*module_);
148+ auto dtorRunner = std::make_unique<llvm::orc::CtorDtorRunner>(orcjit_ee_->getMainJITDylib ());
149+ dtorRunner->add (dtors);
150+ auto err = dtorRunner->run ();
151+ ICHECK (!err) << llvm::toString (std::move (err));
152+ orcjit_ee_.reset ();
137153 }
138154 module_owning_ptr_.reset ();
139155}
@@ -162,7 +178,9 @@ PackedFunc LLVMModuleNode::GetFunction(const String& name, const ObjectPtr<Objec
162178 std::string target_string = LLVMTarget::GetTargetMetadata (*module_);
163179 return PackedFunc ([target_string](TVMArgs args, TVMRetValue* rv) { *rv = target_string; });
164180 }
165- if (ee_ == nullptr ) LazyInitJIT ();
181+ ICHECK (jit_engine_.size ()) << " JIT engine type is missing" ;
182+ if ((jit_engine_ == " mcjit" ) && (mcjit_ee_ == nullptr )) InitMCJIT ();
183+ if ((jit_engine_ == " orcjit" ) && (orcjit_ee_ == nullptr )) InitORCJIT ();
166184
167185 std::lock_guard<std::mutex> lock (mutex_);
168186
@@ -349,6 +367,7 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) {
349367
350368 module_owning_ptr_ = cg->Finish ();
351369 module_ = module_owning_ptr_.get ();
370+ jit_engine_ = llvm_target->GetJITEngine ();
352371 llvm_target->SetTargetMetadata (module_);
353372 module_->addModuleFlag (llvm::Module::Override, " Debug Info Version" ,
354373 llvm::DEBUG_METADATA_VERSION);
@@ -381,13 +400,16 @@ bool LLVMModuleNode::ImplementsFunction(const String& name, bool query_imports)
381400 return std::find (function_names_.begin (), function_names_.end (), name) != function_names_.end ();
382401}
383402
384- void LLVMModuleNode::LazyInitJIT () {
403+ void LLVMModuleNode::InitMCJIT () {
385404 std::lock_guard<std::mutex> lock (mutex_);
386- if (ee_ ) {
405+ if (mcjit_ee_ ) {
387406 return ;
388407 }
408+ // MCJIT builder
389409 With<LLVMTarget> llvm_target (*llvm_instance_, LLVMTarget::GetTargetMetadata (*module_));
390410 llvm::EngineBuilder builder (std::move (module_owning_ptr_));
411+
412+ // set options
391413 builder.setEngineKind (llvm::EngineKind::JIT);
392414#if TVM_LLVM_VERSION <= 170
393415 builder.setOptLevel (llvm::CodeGenOpt::Aggressive);
@@ -397,18 +419,32 @@ void LLVMModuleNode::LazyInitJIT() {
397419 builder.setMCPU (llvm_target->GetCPU ());
398420 builder.setMAttrs (llvm_target->GetTargetFeatures ());
399421 builder.setTargetOptions (llvm_target->GetTargetOptions ());
422+
423+ // create the taget machine
400424 auto tm = std::unique_ptr<llvm::TargetMachine>(builder.selectTarget ());
401425 if (!IsCompatibleWithHost (tm.get ())) {
402426 LOG (FATAL) << " Cannot run module, architecture mismatch" ;
403427 }
428+
429+ // data layout
404430 llvm::DataLayout layout (tm->createDataLayout ());
405431 ICHECK (layout == module_->getDataLayout ())
406432 << " Data layout mismatch between module("
407433 << module_->getDataLayout ().getStringRepresentation () << " )"
408434 << " and ExecutionEngine (" << layout.getStringRepresentation () << " )" ;
409- ee_ = builder.create (tm.release ());
410- ICHECK (ee_ != nullptr ) << " Failed to initialize jit engine for " << module_->getTargetTriple ();
411- ee_->runStaticConstructorsDestructors (false );
435+
436+ // create MCJIT
437+ mcjit_ee_ = builder.create (tm.release ());
438+ ICHECK (mcjit_ee_ != nullptr ) << " Failed to initialize LLVM MCJIT engine for "
439+ << module_->getTargetTriple ();
440+
441+ VLOG (2 ) << " LLVM MCJIT execute " << module_->getModuleIdentifier () << " for triple `"
442+ << llvm_target->GetTargetTriple () << " `"
443+ << " on cpu `" << llvm_target->GetCPU () << " `" ;
444+
445+ // run ctors
446+ module_->getTargetTriple ();
447+ mcjit_ee_->runStaticConstructorsDestructors (false );
412448
413449 if (void ** ctx_addr =
414450 reinterpret_cast <void **>(GetGlobalAddr (runtime::symbol::tvm_module_ctx, *llvm_target))) {
@@ -421,7 +457,104 @@ void LLVMModuleNode::LazyInitJIT() {
421457 // lead to a runtime crash.
422458 // Do name lookup on a symbol that doesn't exist. This will force MCJIT to finalize
423459 // all loaded objects, which will resolve symbols in JITed code.
424- ee_->getFunctionAddress (" __some_name_that_hopefully_doesnt_exist__b49f8aaade5877eaba7583b91" );
460+ mcjit_ee_->getFunctionAddress (
461+ " __some_name_that_hopefully_doesnt_exist__b49f8aaade5877eaba7583b91" );
462+ }
463+
464+ void LLVMModuleNode::InitORCJIT () {
465+ std::lock_guard<std::mutex> lock (mutex_);
466+ if (orcjit_ee_) {
467+ return ;
468+ }
469+ // ORCJIT builder
470+ With<LLVMTarget> llvm_target (*llvm_instance_, LLVMTarget::GetTargetMetadata (*module_));
471+ llvm::orc::JITTargetMachineBuilder tm_builder (llvm::Triple (llvm_target->GetTargetTriple ()));
472+
473+ // set options
474+ tm_builder.setCPU (llvm_target->GetCPU ());
475+ tm_builder.setFeatures (llvm_target->GetTargetFeatureString ());
476+ tm_builder.setOptions (llvm_target->GetTargetOptions ());
477+ #if TVM_LLVM_VERSION <= 170
478+ tm_builder.setCodeGenOptLevel (llvm::CodeGenOpt::Aggressive);
479+ #else
480+ tm_builder.setCodeGenOptLevel (llvm::CodeGenOptLevel::Aggressive);
481+ #endif
482+
483+ // create the taget machine
484+ auto tm = tm_builder.createTargetMachine ();
485+ if (!IsCompatibleWithHost (tm->get ())) {
486+ LOG (FATAL) << " Cannot run module, architecture mismatch" ;
487+ }
488+
489+ // data layout
490+ String module_name = module_->getModuleIdentifier ();
491+ llvm::DataLayout layout (tm->get ()->createDataLayout ());
492+ ICHECK (layout == module_->getDataLayout ())
493+ << " Data layout mismatch between module("
494+ << module_->getDataLayout ().getStringRepresentation () << " )"
495+ << " and ExecutionEngine (" << layout.getStringRepresentation () << " )" ;
496+
497+ // compiler
498+ const auto compilerBuilder = [&](const llvm::orc::JITTargetMachineBuilder&)
499+ -> llvm::Expected<std::unique_ptr<llvm::orc::IRCompileLayer::IRCompiler>> {
500+ return std::make_unique<llvm::orc::TMOwningSimpleCompiler>(std::move (*tm));
501+ };
502+
503+ #if TVM_LLVM_VERSION >= 130
504+ // linker
505+ const auto linkerBuilder = [&](llvm::orc::ExecutionSession& session, const llvm::Triple&) {
506+ return std::make_unique<llvm::orc::ObjectLinkingLayer>(session);
507+ };
508+ #endif
509+
510+ // create LLJIT
511+ orcjit_ee_ = llvm::cantFail (llvm::orc::LLJITBuilder ()
512+ #if TVM_LLVM_VERSION >= 110
513+ .setDataLayout (layout)
514+ #endif
515+ .setCompileFunctionCreator (compilerBuilder)
516+ #if TVM_LLVM_VERSION >= 130
517+ .setObjectLinkingLayerCreator (linkerBuilder)
518+ #endif
519+ .create ());
520+
521+ ICHECK (orcjit_ee_ != nullptr ) << " Failed to initialize LLVM ORCJIT engine for "
522+ << module_->getTargetTriple ();
523+
524+ // store ctors
525+ auto ctors = llvm::orc::getConstructors (*module_);
526+ llvm::orc::CtorDtorRunner ctorRunner (orcjit_ee_->getMainJITDylib ());
527+ ctorRunner.add (ctors);
528+
529+ // resolve system symbols (like pthread, dl, m, etc.)
530+ auto gen =
531+ llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess (layout.getGlobalPrefix ());
532+ ICHECK (gen) << llvm::toString (gen.takeError ()) << " \n " ;
533+ orcjit_ee_->getMainJITDylib ().addGenerator (std::move (gen.get ()));
534+
535+ // transfer module to a clone
536+ auto uctx = std::make_unique<llvm::LLVMContext>();
537+ auto umod = llvm::CloneModule (*(std::move (module_owning_ptr_)));
538+
539+ // add the llvm module to run
540+ llvm::orc::ThreadSafeModule tsm (std::move (umod), std::move (uctx));
541+ auto err = orcjit_ee_->addIRModule (std::move (tsm));
542+ ICHECK (!err) << llvm::toString (std::move (err));
543+
544+ VLOG (2 ) << " LLVM ORCJIT execute " << module_->getModuleIdentifier () << " for triple `"
545+ << llvm_target->GetTargetTriple () << " `"
546+ << " on cpu `" << llvm_target->GetCPU () << " `" ;
547+
548+ // run ctors
549+ err = ctorRunner.run ();
550+ ICHECK (!err) << llvm::toString (std::move (err));
551+
552+ if (void ** ctx_addr =
553+ reinterpret_cast <void **>(GetGlobalAddr (runtime::symbol::tvm_module_ctx, *llvm_target))) {
554+ *ctx_addr = this ;
555+ }
556+ runtime::InitContextFunctions (
557+ [this , &llvm_target](const char * name) { return GetGlobalAddr (name, *llvm_target); });
425558}
426559
427560bool LLVMModuleNode::IsCompatibleWithHost (const llvm::TargetMachine* tm) const {
@@ -439,20 +572,40 @@ bool LLVMModuleNode::IsCompatibleWithHost(const llvm::TargetMachine* tm) const {
439572void * LLVMModuleNode::GetGlobalAddr (const std::string& name, const LLVMTarget& llvm_target) const {
440573 // first verifies if GV exists.
441574 if (module_->getGlobalVariable (name) != nullptr ) {
442- return reinterpret_cast <void *>(ee_->getGlobalValueAddress (name));
443- } else {
444- return nullptr ;
575+ if (jit_engine_ == " mcjit" ) {
576+ return reinterpret_cast <void *>(mcjit_ee_->getGlobalValueAddress (name));
577+ } else if (jit_engine_ == " orcjit" ) {
578+ #if TVM_LLVM_VERSION >= 150
579+ auto addr = llvm::cantFail (orcjit_ee_->lookup (name)).getValue ();
580+ #else
581+ auto addr = llvm::cantFail (orcjit_ee_->lookup (name)).getAddress ();
582+ #endif
583+ return reinterpret_cast <void *>(addr);
584+ } else {
585+ LOG (FATAL) << " Either `mcjit` or `orcjit` are not initialized." ;
586+ }
445587 }
588+ return nullptr ;
446589}
447590
448591void * LLVMModuleNode::GetFunctionAddr (const std::string& name,
449592 const LLVMTarget& llvm_target) const {
450593 // first verifies if GV exists.
451594 if (module_->getFunction (name) != nullptr ) {
452- return reinterpret_cast <void *>(ee_->getFunctionAddress (name));
453- } else {
454- return nullptr ;
595+ if (jit_engine_ == " mcjit" ) {
596+ return reinterpret_cast <void *>(mcjit_ee_->getFunctionAddress (name));
597+ } else if (jit_engine_ == " orcjit" ) {
598+ #if TVM_LLVM_VERSION >= 150
599+ auto addr = llvm::cantFail (orcjit_ee_->lookup (name)).getValue ();
600+ #else
601+ auto addr = llvm::cantFail (orcjit_ee_->lookup (name)).getAddress ();
602+ #endif
603+ return reinterpret_cast <void *>(addr);
604+ } else {
605+ LOG (FATAL) << " Either `mcjit` or `orcjit` are not initialized." ;
606+ }
455607 }
608+ return nullptr ;
456609}
457610
458611TVM_REGISTER_GLOBAL (" target.build.llvm" )
@@ -473,6 +626,7 @@ TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate")
473626 module ->setTargetTriple (llvm_target->GetTargetTriple ());
474627 module ->setDataLayout (llvm_target->GetOrCreateTargetMachine ()->createDataLayout ());
475628 n->Init (std::move (module ), std::move (llvm_instance));
629+ n->SetJITEngine (llvm_target->GetJITEngine ());
476630 return runtime::Module (n);
477631 });
478632
@@ -592,6 +746,7 @@ TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int {
592746TVM_REGISTER_GLOBAL (" runtime.module.loadfile_ll" )
593747 .set_body_typed([](std::string filename, std::string fmt) -> runtime::Module {
594748 auto n = make_object<LLVMModuleNode>();
749+ n->SetJITEngine (" mcjit" );
595750 n->LoadIR (filename);
596751 return runtime::Module (n);
597752 });
@@ -613,6 +768,7 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob")
613768 std::unique_ptr<llvm::Module> blob =
614769 CodeGenBlob (data, system_lib, llvm_target.get (), c_symbol_prefix);
615770 n->Init (std::move (blob), std::move (llvm_instance));
771+ n->SetJITEngine (llvm_target->GetJITEngine ());
616772 return runtime::Module (n);
617773 });
618774
@@ -642,6 +798,7 @@ runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata
642798
643799 auto n = make_object<LLVMModuleNode>();
644800 n->Init (std::move (mod), std::move (llvm_instance));
801+ n->SetJITEngine (llvm_target->GetJITEngine ());
645802
646803 auto meta_mod = MetadataModuleCreate (metadata);
647804 meta_mod->Import (runtime::Module (n));
@@ -688,6 +845,7 @@ runtime::Module CreateLLVMCrtMetadataModule(const Array<runtime::Module>& module
688845
689846 auto n = make_object<LLVMModuleNode>();
690847 n->Init (std::move (mod), std::move (llvm_instance));
848+ n->SetJITEngine (llvm_target->GetJITEngine ());
691849 for (auto m : modules) {
692850 n->Import (m);
693851 }
0 commit comments