Skip to content

Commit 6790af8

Browse files
authored
[LLVM] Fixes up to the latest LLVM21 (#18204)
This PR fix TVM use with the latest LLVM version 21. - At this time LLVM21 is available as a release candidate. - Double checks for backward compatibility down to LLVM10
1 parent 789e0b8 commit 6790af8

File tree

7 files changed

+56
-3
lines changed

7 files changed

+56
-3
lines changed

src/target/llvm/codegen_amdgpu.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,11 @@ runtime::Module BuildAMDGPU(IRModule mod, Target target) {
284284

285285
for (auto& bitcode_path : bitcode_files) {
286286
std::unique_ptr<llvm::Module> mlib = llvm_instance.LoadIR(bitcode_path);
287+
#if TVM_LLVM_VERSION >= 210
288+
mlib->setTargetTriple(llvm::Triple(llvm_target->GetTargetTriple()));
289+
#else
287290
mlib->setTargetTriple(llvm_target->GetTargetTriple());
291+
#endif
288292
mlib->setDataLayout(tm->createDataLayout());
289293

290294
for (llvm::Function& f : mlib->functions()) {

src/target/llvm/codegen_blob.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,11 @@ std::unique_ptr<llvm::Module> CodeGenBlob(const std::string& data, bool system_l
6969
llvm::LLVMContext* ctx = llvm_target->GetContext();
7070
std::string module_name = c_symbol_prefix + "devc";
7171
auto module = std::make_unique<llvm::Module>(module_name, *ctx);
72+
#if TVM_LLVM_VERSION >= 210
73+
module->setTargetTriple(triple);
74+
#else
7275
module->setTargetTriple(triple.str());
76+
#endif
7377
llvm_target->SetTargetMetadata(module.get());
7478
module->setDataLayout(tm->createDataLayout());
7579
auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false);

src/target/llvm/codegen_llvm.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,11 @@ void CodeGenLLVM::SetFastMathFlags(llvm::FastMathFlags fmf) { builder_->setFastM
168168

169169
void CodeGenLLVM::InitTarget() {
170170
llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine();
171+
#if TVM_LLVM_VERSION >= 210
172+
module_->setTargetTriple(tm->getTargetTriple());
173+
#else
171174
module_->setTargetTriple(tm->getTargetTriple().str());
175+
#endif
172176
module_->setDataLayout(tm->createDataLayout());
173177
#if TVM_LLVM_VERSION >= 200
174178
data_layout_.reset(new llvm::DataLayout(module_.get()->getDataLayout()));
@@ -374,7 +378,11 @@ void CodeGenLLVM::HandleImport(const std::string& code) {
374378
mlib = llvm_target_->GetInstance().ParseIR(code);
375379
}
376380

381+
#if TVM_LLVM_VERSION >= 210
382+
mlib->setTargetTriple(llvm::Triple(llvm_target_->GetTargetTriple()));
383+
#else
377384
mlib->setTargetTriple(llvm_target_->GetTargetTriple());
385+
#endif
378386
mlib->setDataLayout(llvm_target_->GetOrCreateTargetMachine()->createDataLayout());
379387
// mark all the functions as force inline
380388
for (llvm::Function& f : mlib->functions()) {

src/target/llvm/codegen_nvptx.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,11 @@ class CodeGenNVPTX : public CodeGenLLVM {
189189
} else if (sync == "shared" || sync == "shared.dyn") {
190190
#if TVM_LLVM_VERSION >= 200
191191
llvm::Function* f = llvm::cast<llvm::Function>(llvm::Intrinsic::getOrInsertDeclaration(
192+
#if TVM_LLVM_VERSION >= 210
193+
module_.get(), llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all, {}));
194+
#else
192195
module_.get(), llvm::Intrinsic::nvvm_barrier0, {}));
196+
#endif
193197
#else
194198
llvm::Function* f =
195199
llvm::Intrinsic::getDeclaration(module_.get(), llvm::Intrinsic::nvvm_barrier0);
@@ -335,7 +339,11 @@ runtime::Module BuildNVPTX(IRModule mod, Target target) {
335339
std::string path = (*flibdevice_path)(compute_ver).cast<std::string>();
336340
if (path.length() != 0) {
337341
std::unique_ptr<llvm::Module> mlib = llvm_instance.LoadIR(path);
342+
#if TVM_LLVM_VERSION >= 210
343+
mlib->setTargetTriple(llvm::Triple(llvm_target->GetTargetTriple()));
344+
#else
338345
mlib->setTargetTriple(llvm_target->GetTargetTriple());
346+
#endif
339347
mlib->setDataLayout(tm->createDataLayout());
340348
cg->AddLinkModule(std::move(mlib));
341349
}

src/target/llvm/llvm_instance.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,7 +981,11 @@ std::string LLVMTarget::GetTargetMetadata(const llvm::Module& module) {
981981
return meta.str();
982982
}
983983
}
984+
#if TVM_LLVM_VERSION >= 210
985+
return "llvm -mtriple " + module.getTargetTriple().str();
986+
#else
984987
return "llvm -mtriple " + module.getTargetTriple();
988+
#endif
985989
}
986990

987991
void LLVMTarget::SetTargetMetadata(llvm::Module* module) const {

src/target/llvm/llvm_module.cc

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,11 @@ void LLVMModuleNode::InitMCJIT() {
429429
// create MCJIT
430430
mcjit_ee_ = builder.create(tm.release());
431431
ICHECK(mcjit_ee_ != nullptr) << "Failed to initialize LLVM MCJIT engine for "
432+
#if TVM_LLVM_VERSION >= 210
433+
<< module_->getTargetTriple().str();
434+
#else
432435
<< module_->getTargetTriple();
436+
#endif
433437

434438
VLOG(2) << "LLVM MCJIT execute " << module_->getModuleIdentifier() << " for triple `"
435439
<< llvm_target->GetTargetTriple() << "`"
@@ -503,21 +507,34 @@ void LLVMModuleNode::InitORCJIT() {
503507
#if TVM_LLVM_VERSION >= 130
504508
// linker
505509
const auto linkerBuilder =
510+
#if TVM_LLVM_VERSION >= 210
511+
[&](llvm::orc::ExecutionSession& session)
512+
-> llvm::Expected<std::unique_ptr<llvm::orc::ObjectLayer>> {
513+
#else
506514
[&](llvm::orc::ExecutionSession& session,
507515
const llvm::Triple& triple) -> std::unique_ptr<llvm::orc::ObjectLayer> {
516+
#endif
508517
#if _WIN32
509518
auto GetMemMgr = []() { return std::make_unique<llvm::SectionMemoryManager>(); };
510519
auto ObjLinkingLayer =
511520
std::make_unique<llvm::orc::RTDyldObjectLinkingLayer>(session, std::move(GetMemMgr));
512521
#else
513522
auto ObjLinkingLayer = std::make_unique<llvm::orc::ObjectLinkingLayer>(session);
514523
#endif
524+
#if TVM_LLVM_VERSION >= 210
525+
if (tm_builder.getTargetTriple().isOSBinFormatCOFF()) {
526+
#else
515527
if (triple.isOSBinFormatCOFF()) {
528+
#endif
516529
ObjLinkingLayer->setOverrideObjectFlagsWithResponsibilityFlags(true);
517530
ObjLinkingLayer->setAutoClaimResponsibilityForObjectSymbols(true);
518531
}
532+
#if TVM_LLVM_VERSION >= 210
533+
return llvm::Expected<std::unique_ptr<llvm::orc::ObjectLayer>>(std::move(ObjLinkingLayer));
534+
#else
519535
return ObjLinkingLayer;
520-
};
536+
#endif
537+
}; // NOLINT(readability/braces)
521538
#endif
522539

523540
// create LLJIT
@@ -532,7 +549,11 @@ void LLVMModuleNode::InitORCJIT() {
532549
.create());
533550

534551
ICHECK(orcjit_ee_ != nullptr) << "Failed to initialize LLVM ORCJIT engine for "
552+
#if TVM_LLVM_VERSION >= 210
553+
<< module_->getTargetTriple().str();
554+
#else
535555
<< module_->getTargetTriple();
556+
#endif
536557

537558
// store ctors
538559
auto ctors = llvm::orc::getConstructors(*module_);
@@ -638,7 +659,11 @@ static void LLVMReflectionRegister() {
638659
// Generate a LLVM module from an input target string
639660
auto module = std::make_unique<llvm::Module>(module_name, *llvm_target->GetContext());
640661
llvm_target->SetTargetMetadata(module.get());
662+
#if TVM_LLVM_VERSION >= 210
663+
module->setTargetTriple(llvm::Triple(llvm_target->GetTargetTriple()));
664+
#else
641665
module->setTargetTriple(llvm_target->GetTargetTriple());
666+
#endif
642667
module->setDataLayout(llvm_target->GetOrCreateTargetMachine()->createDataLayout());
643668
n->Init(std::move(module), std::move(llvm_instance));
644669
n->SetJITEngine(llvm_target->GetJITEngine());

tests/cpp/target/parsers/aprofile_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ TEST_F(AProfileParser, DefaultSVESupportSVESupport) {
317317
TargetJSON target = ParseTargetWithAttrs("", "aarch64-arm-none-eabi", {arch_attr});
318318
TargetFeatures features = Downcast<TargetFeatures>(target.at("features"));
319319
EXPECT_TRUE(IsArch(target));
320-
#if TVM_LLVM_VERSION >= 190
320+
#if TVM_LLVM_VERSION >= 190 || (TVM_LLVM_VERSION / 10) == 13
321321
// The generic aarch64 should not have SVE enabled
322322
EXPECT_FALSE(Downcast<Bool>(features.at("has_sve")));
323323
#else
@@ -364,7 +364,7 @@ TEST_F(AProfileParser, DefaultFP16Support) {
364364
TargetJSON target = ParseTargetWithAttrs("", "aarch64-arm-none-eabi", {arch_attr});
365365
TargetFeatures features = Downcast<TargetFeatures>(target.at("features"));
366366
EXPECT_TRUE(IsArch(target));
367-
#if TVM_LLVM_VERSION >= 190
367+
#if TVM_LLVM_VERSION >= 190 || (TVM_LLVM_VERSION / 10) == 13
368368
// The generic aarch64 should not have FP16 enabled
369369
EXPECT_FALSE(Downcast<Bool>(features.at("has_fp16_simd")));
370370
#else

0 commit comments

Comments
 (0)