@@ -492,6 +492,28 @@ llvm::BasicBlock* CodeGenCPU::CheckCallSuccess(llvm::Value* retcode) {
492492}
493493
494494void CodeGenCPU::CreateComputeScope (const AttrStmtNode* op) {
495+ /* ! \brief maintain states that should be guarded when step into compute scope */
496+ struct ComputeScopeStates {
497+ explicit ComputeScopeStates (CodeGenCPU* parent) : parent_(parent) {}
498+
499+ void EnterWithScope () {
500+ std::swap (function_, parent_->function_ );
501+ std::swap (analyzer_, parent_->analyzer_ );
502+ std::swap (var_map_, parent_->var_map_ );
503+ }
504+
505+ void ExitWithScope () {
506+ std::swap (function_, parent_->function_ );
507+ std::swap (analyzer_, parent_->analyzer_ );
508+ std::swap (var_map_, parent_->var_map_ );
509+ }
510+
511+ llvm::Function* function_{nullptr };
512+ std::unordered_map<const VarNode*, llvm::Value*> var_map_;
513+ std::unique_ptr<arith::Analyzer> analyzer_{std::make_unique<arith::Analyzer>()};
514+ CodeGenCPU* parent_;
515+ };
516+
495517 // There are two reasons why we create another function for compute_scope
496518 // - Make sure the generated compute function is clearly separately(though it can get inlined)
497519 // - Set noalias on all the pointer arguments, some of them are loaded from TVMArgs.
@@ -515,13 +537,13 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
515537 llvm::Function::Create (ftype, llvm::Function::InternalLinkage,
516538 value->value .operator llvm::StringRef (), module_.get ());
517539 BasicBlock* compute_call_end = CheckCallSuccess (builder_->CreateCall (fcompute, arg_values));
518- // setup compute function.
519- std::unordered_map< const VarNode*, llvm::Value*> new_vmap ;
540+ // enter compute scope and setup compute function.
541+ With<ComputeScopeStates> scope_states_guard ( this ) ;
520542 size_t idx = 0 ;
521543 for (auto it = fcompute->arg_begin (); it != fcompute->arg_end (); ++it, ++idx) {
522544 llvm::Argument* v = &(*it);
523545 const Var& var = vargs[idx];
524- new_vmap [var.get ()] = v;
546+ var_map_ [var.get ()] = v;
525547 if (var.dtype ().is_handle () && !alias_var_set_.count (var.get ())) {
526548 // set non alias.
527549#if TVM_LLVM_VERSION >= 50
@@ -544,18 +566,11 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
544566 }
545567#endif
546568 }
547- auto new_analyzer = std::make_unique<arith::Analyzer>();
548- std::swap (function_, fcompute);
549- std::swap (analyzer_, new_analyzer);
550- std::swap (var_map_, new_vmap);
569+ function_ = fcompute;
551570 BasicBlock* compute_entry = BasicBlock::Create (*ctx_, " entry" , function_);
552571 builder_->SetInsertPoint (compute_entry);
553572 this ->VisitStmt (op->body );
554573 builder_->CreateRet (ConstInt32 (0 ));
555- // swap the var map back, now we are back on track.
556- std::swap (var_map_, new_vmap);
557- std::swap (analyzer_, new_analyzer);
558- std::swap (function_, fcompute);
559574 builder_->SetInsertPoint (compute_call_end);
560575}
561576
0 commit comments