Skip to content

Commit d026d06

Browse files
[CodeGen] avoid crash if an exception is raised during llvm cpu codegen (#9786)
* avoid crash if an exception is raised during llvm cpu codegen
* use pytest.raises
1 parent 4e0bf23 commit d026d06

File tree

2 files changed

+44
-11
lines changed

2 files changed

+44
-11
lines changed

src/target/llvm/codegen_cpu.cc

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,28 @@ llvm::BasicBlock* CodeGenCPU::CheckCallSuccess(llvm::Value* retcode) {
492492
}
493493

494494
void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
495+
/*! \brief maintain states that should be guarded when step into compute scope */
496+
struct ComputeScopeStates {
497+
explicit ComputeScopeStates(CodeGenCPU* parent) : parent_(parent) {}
498+
499+
void EnterWithScope() {
500+
std::swap(function_, parent_->function_);
501+
std::swap(analyzer_, parent_->analyzer_);
502+
std::swap(var_map_, parent_->var_map_);
503+
}
504+
505+
void ExitWithScope() {
506+
std::swap(function_, parent_->function_);
507+
std::swap(analyzer_, parent_->analyzer_);
508+
std::swap(var_map_, parent_->var_map_);
509+
}
510+
511+
llvm::Function* function_{nullptr};
512+
std::unordered_map<const VarNode*, llvm::Value*> var_map_;
513+
std::unique_ptr<arith::Analyzer> analyzer_{std::make_unique<arith::Analyzer>()};
514+
CodeGenCPU* parent_;
515+
};
516+
495517
// There are two reasons why we create another function for compute_scope
496518
// - Make sure the generated compute function is clearly separately(though it can get inlined)
497519
// - Set noalias on all the pointer arguments, some of them are loaded from TVMArgs.
@@ -515,13 +537,13 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
515537
llvm::Function::Create(ftype, llvm::Function::InternalLinkage,
516538
value->value.operator llvm::StringRef(), module_.get());
517539
BasicBlock* compute_call_end = CheckCallSuccess(builder_->CreateCall(fcompute, arg_values));
518-
// setup compute function.
519-
std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
540+
// enter compute scope and setup compute function.
541+
With<ComputeScopeStates> scope_states_guard(this);
520542
size_t idx = 0;
521543
for (auto it = fcompute->arg_begin(); it != fcompute->arg_end(); ++it, ++idx) {
522544
llvm::Argument* v = &(*it);
523545
const Var& var = vargs[idx];
524-
new_vmap[var.get()] = v;
546+
var_map_[var.get()] = v;
525547
if (var.dtype().is_handle() && !alias_var_set_.count(var.get())) {
526548
// set non alias.
527549
#if TVM_LLVM_VERSION >= 50
@@ -544,18 +566,11 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
544566
}
545567
#endif
546568
}
547-
auto new_analyzer = std::make_unique<arith::Analyzer>();
548-
std::swap(function_, fcompute);
549-
std::swap(analyzer_, new_analyzer);
550-
std::swap(var_map_, new_vmap);
569+
function_ = fcompute;
551570
BasicBlock* compute_entry = BasicBlock::Create(*ctx_, "entry", function_);
552571
builder_->SetInsertPoint(compute_entry);
553572
this->VisitStmt(op->body);
554573
builder_->CreateRet(ConstInt32(0));
555-
// swap the var map back, now we are back on track.
556-
std::swap(var_map_, new_vmap);
557-
std::swap(analyzer_, new_analyzer);
558-
std::swap(function_, fcompute);
559574
builder_->SetInsertPoint(compute_call_end);
560575
}
561576

tests/python/unittest/test_target_codegen_llvm.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from tvm import te
2525
from tvm.relay.backend import Runtime
2626
from tvm.contrib import utils, clang
27+
import tvm.script.tir as T
2728
import numpy as np
2829

2930
import math
@@ -906,5 +907,22 @@ def test_llvm_scalar_concat():
906907
m = tvm.build(mod, [x, y, z], target="llvm")
907908

908909

910+
@tvm.testing.requires_llvm
911+
def test_raise_exception_during_codegen():
912+
@T.prim_func
913+
def threadpool_nested_parallel_loop(
914+
A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]
915+
) -> None:
916+
T.func_attr({"global_symbol": "main", "tir.noalias": True})
917+
for i in T.parallel(4):
918+
for j in T.parallel(4):
919+
T.store(B.data, i * 4 + j, T.load("float32", A.data, i * 4 + j) * 2.0)
920+
921+
with pytest.raises(tvm.TVMError) as e:
922+
tvm.build({"llvm": tvm.IRModule.from_expr(threadpool_nested_parallel_loop)})
923+
msg = str(e)
924+
assert msg.find("Nested parallel loop is not supported") != -1
925+
926+
909927
if __name__ == "__main__":
910928
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 commit comments

Comments (0)