55#ifdef TVM_LLVM_VERSION
66
77#include < tvm/runtime/c_runtime_api.h>
8+ #include < tvm/ir_pass.h>
89#include " ./codegen_llvm.h"
910#include " ../../arithmetic/compute_expr.h"
1011
@@ -30,6 +31,7 @@ void CodeGenLLVM::Init(const std::string& module_name,
3031 t_int8_ = llvm::Type::getInt8Ty (*ctx);
3132 t_int16_ = llvm::Type::getInt16Ty (*ctx);
3233 t_int32_ = llvm::Type::getInt32Ty (*ctx);
34+ t_int64_ = llvm::Type::getInt64Ty (*ctx);
3335 t_float64_ = llvm::Type::getDoubleTy (*ctx);
3436 t_tvm_index_ = llvm::Type::getIntNTy (*ctx, sizeof (tvm_index_t ) * 8 );
3537 t_tvm_context_ = llvm::StructType::create ({t_int_, t_int_});
@@ -43,6 +45,8 @@ void CodeGenLLVM::Init(const std::string& module_name,
4345 t_tvm_type_,
4446 t_tvm_context_});
4547 t_tvm_value_ = llvm::StructType::create ({t_float64_});
48+ t_f_tvm_par_for_lambda_ = llvm::FunctionType::get (
49+ t_int_, {t_int64_, t_int64_, t_void_p_}, false );
4650 md_builder_.reset (new llvm::MDBuilder (*ctx));
4751 md_very_likely_branch_ =
4852 md_builder_->createBranchWeights (1 << 30 , 0 );
@@ -70,7 +74,11 @@ void CodeGenLLVM::Init(const std::string& module_name,
7074 f_tvm_api_set_last_error_ = llvm::Function::Create (
7175 llvm::FunctionType::get (t_void_, {t_char_->getPointerTo ()}, false ),
7276 llvm::Function::ExternalLinkage, " TVMAPISetLastError" , module_.get ());
73-
77+ f_tvm_parallel_for_ = llvm::Function::Create (
78+ llvm::FunctionType::get (t_int_, {
79+ t_int64_, t_int64_, t_f_tvm_par_for_lambda_->getPointerTo (), t_void_p_}
80+ , false ),
81+ llvm::Function::ExternalLinkage, " TVMBackendParallelFor" , module_.get ());
7482 this ->InitTarget (target_triple);
7583 // initialize builder
7684 builder_.reset (new IRBuilder (*ctx));
@@ -141,7 +149,9 @@ void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
141149 }
142150 llvm::BasicBlock* block = llvm::BasicBlock::Create (*ctx_, " entry" , function_);
143151 builder_->SetInsertPoint (block);
144- builder_->CreateRet (builder_->CreateCall (f, args));
152+ llvm::CallInst* call = builder_->CreateCall (f, args);
153+ call->setTailCall (true );
154+ builder_->CreateRet (call);
145155}
146156
147157class FPassManager : public llvm ::legacy::FunctionPassManager {
@@ -545,7 +555,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
545555 return nullptr ;
546556}
547557
548- llvm::BasicBlock* CodeGenLLVM::CheckPackedCallSuccess (llvm::Value* retcode) {
558+ llvm::BasicBlock* CodeGenLLVM::CheckCallSuccess (llvm::Value* retcode) {
549559 // create emit codes that checks and load the function.
550560 using llvm::BasicBlock;
551561 BasicBlock* fail_block = BasicBlock::Create (
@@ -563,34 +573,15 @@ llvm::BasicBlock* CodeGenLLVM::CheckPackedCallSuccess(llvm::Value* retcode) {
563573 return end_block;
564574}
565575void CodeGenLLVM::Visit_ (const For* op) {
566- using llvm::BasicBlock;
567- BasicBlock* for_head = BasicBlock::Create (
568- *ctx_, " for_head" , function_);
569- BasicBlock* for_body = BasicBlock::Create (
570- *ctx_, " for_body" , function_);
571- BasicBlock* for_end = BasicBlock::Create (
572- *ctx_, " for_end" , function_);
573- BasicBlock* pre_block = builder_->GetInsertBlock ();
574576 CHECK (is_zero (op->min ));
575- Type t = op->min .type ();
576- llvm::Value* init = ConstInt32 (0 );
577- llvm::Value* extent = MakeValue (op->extent );
578- builder_->CreateBr (for_head);
579-
580- builder_->SetInsertPoint (for_head);
581- llvm::PHINode* index = builder_->CreatePHI (LLVMType (t), 2 );
582- index->addIncoming (init, pre_block);
583- llvm::Value* cond = CreateLT (t, index, extent);
584- builder_->CreateCondBr (cond, for_body, for_end, md_very_likely_branch_);
585- // body of for
586- builder_->SetInsertPoint (for_body);
587- var_map_[op->loop_var .get ()] = index;
588- this ->Visit (op->body );
589- llvm::Value* next_index = CreateAdd (t, index, ConstInt32 (1 ));
590- index->addIncoming (next_index, builder_->GetInsertBlock ());
591- builder_->CreateBr (for_head);
592- // end of for
593- builder_->SetInsertPoint (for_end);
577+ if (op->for_type == ForType::Serial) {
578+ CreateSerialFor (ConstInt32 (0 ), MakeValue (op->extent ),
579+ op->loop_var , op->body );
580+ } else if (op->for_type == ForType::Parallel) {
581+ CreateParallelFor (op);
582+ } else {
583+ LOG (FATAL) << " cannot handle for type " << op->for_type ;
584+ }
594585}
595586
596587void CodeGenLLVM::Visit_ (const IfThenElse* op) {
@@ -807,7 +798,7 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
807798 llvm::Value* ctx = builder_->CreateLoad (gv_mod_ctx_);
808799 llvm::Value* retcode = builder_->CreateCall (
809800 f_tvm_get_func_from_env_, {ctx, GetConstString (fname), out});
810- init_block = CheckPackedCallSuccess (retcode);
801+ init_block = CheckCallSuccess (retcode);
811802 llvm::Value* loaded_handle = builder_->CreateAlignedLoad (out, align);
812803 builder_->CreateBr (end_block);
813804 // end block
@@ -846,7 +837,7 @@ llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
846837 }
847838 llvm::Value* ret_value = builder_->CreateAlloca (t_tvm_value_);
848839 llvm::Value* ret_tcode = builder_->CreateAlloca (t_int_);
849- CheckPackedCallSuccess (
840+ CheckCallSuccess (
850841 builder_->CreateCall (
851842 f_tvm_func_call_,
852843 {handle, targs, tcodes, ConstInt32 (nargs), ret_value, ret_tcode}));
@@ -934,6 +925,94 @@ llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
934925 }
935926}
936927
928+ void CodeGenLLVM::CreateParallelFor (const For* op) {
929+ using llvm::BasicBlock;
930+ llvm::Value* min = MakeValue (op->min );
931+ llvm::Value* extent = MakeValue (op->extent );
932+ min = builder_->CreateIntCast (min, t_int64_, op->min .type ().is_int ());
933+ extent = builder_->CreateIntCast (extent, t_int64_, op->min .type ().is_int ());
934+ // fields to be packed into closure.
935+ Var loop_var (op->loop_var .node_ );
936+ Array<Var> vfields = ir::UndefinedVars (op->body , {loop_var});
937+ std::vector<llvm::Type*> fields;
938+ for (Var v : vfields) {
939+ auto it = var_map_.find (v.get ());
940+ CHECK (it != var_map_.end ());
941+ fields.push_back (it->second ->getType ());
942+ }
943+ // closure data
944+ llvm::StructType* tcdata = llvm::StructType::create (fields);
945+ llvm::Function* f = llvm::Function::Create (
946+ t_f_tvm_par_for_lambda_,
947+ llvm::Function::PrivateLinkage,
948+ " __tvm_par_for_lambda" , module_.get ());
949+ // allocate and setup the closure, call the closure.
950+ llvm::Value* cdata = builder_->CreateAlloca (tcdata, ConstInt32 (1 ));
951+ llvm::Value* zero = ConstInt32 (0 );
952+
953+ for (size_t i = 0 ; i < vfields.size (); ++i) {
954+ builder_->CreateStore (
955+ var_map_.at (vfields[i].get ()),
956+ builder_->CreateInBoundsGEP (cdata, {zero, ConstInt32 (i)}));
957+ }
958+ BasicBlock* par_for_end = CheckCallSuccess (
959+ builder_->CreateCall (
960+ f_tvm_parallel_for_,
961+ {min, extent, f, builder_->CreatePointerCast (cdata, t_void_p_)}));
962+ // Setup the closure function.
963+ BasicBlock *lambda_entry = BasicBlock::Create (*ctx_, " entry" , f);
964+ builder_->SetInsertPoint (lambda_entry);
965+ auto it = f->arg_begin ();
966+ llvm::Value* begin = &(*it++);
967+ llvm::Value* end = &(*it++);
968+ cdata = &(*it++);
969+ begin = CreateCast (Int (64 ), op->loop_var .type (), begin);
970+ end = CreateCast (Int (64 ), op->loop_var .type (), end);
971+ cdata = builder_->CreatePointerCast (cdata, tcdata->getPointerTo ());
972+ // setup new variable map, swap it with current var context.
973+ std::unordered_map<const Variable*, llvm::Value*> new_vmap;
974+ for (size_t i = 0 ; i < vfields.size (); ++i) {
975+ new_vmap[vfields[i].get ()] =
976+ builder_->CreateLoad (builder_->CreateInBoundsGEP (
977+ cdata, {zero, ConstInt32 (i)}));
978+ }
979+ std::swap (function_, f);
980+ std::swap (new_vmap, var_map_);
981+ CreateSerialFor (begin, end, op->loop_var , op->body );
982+ builder_->CreateRet (ConstInt32 (0 ));
983+ // swap the var map back, now we are back on track.
984+ std::swap (new_vmap, var_map_);
985+ std::swap (function_, f);
986+ builder_->SetInsertPoint (par_for_end);
987+ }
988+
989+ void CodeGenLLVM::CreateSerialFor (llvm::Value* begin, llvm::Value* end,
990+ const VarExpr& loop_var, const Stmt& body) {
991+ using llvm::BasicBlock;
992+ Type t = loop_var.type ();
993+ BasicBlock* for_head = BasicBlock::Create (
994+ *ctx_, " for_head" , function_);
995+ BasicBlock* for_body = BasicBlock::Create (
996+ *ctx_, " for_body" , function_);
997+ BasicBlock* for_end = BasicBlock::Create (
998+ *ctx_, " for_end" , function_);
999+ BasicBlock* pre_block = builder_->GetInsertBlock ();
1000+ builder_->CreateBr (for_head);
1001+ builder_->SetInsertPoint (for_head);
1002+ llvm::PHINode* index = builder_->CreatePHI (begin->getType (), 2 );
1003+ index->addIncoming (begin, pre_block);
1004+ llvm::Value* cond = CreateLT (t, index, end);
1005+ builder_->CreateCondBr (cond, for_body, for_end, md_very_likely_branch_);
1006+ // body of for
1007+ builder_->SetInsertPoint (for_body);
1008+ var_map_[loop_var.get ()] = index;
1009+ this ->Visit (body);
1010+ llvm::Value* next_index = CreateAdd (t, index, ConstInt32 (1 ));
1011+ index->addIncoming (next_index, builder_->GetInsertBlock ());
1012+ builder_->CreateBr (for_head);
1013+ // end of for
1014+ builder_->SetInsertPoint (for_end);
1015+ }
9371016} // namespace codegen
9381017} // namespace tvm
9391018#endif // TVM_LLVM_VERSION
0 commit comments