Skip to content

Commit 8e6c309

Browse files
tqchen
authored and committed
[LLVM/RUNTIME] Support Parallel for on CPU
1 parent 2f462cc commit 8e6c309

File tree

15 files changed

+298
-86
lines changed

15 files changed

+298
-86
lines changed

include/tvm/ir_pass.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,11 +173,12 @@ LoweredFunc MakeAPI(Stmt body,
173173
int num_unpacked_args);
174174

175175
/*!
176-
* \brief Count number of undefined vars in f.
177-
* \param f The function to be checked.
178-
* \return Number of undefined vars.
176+
* \brief Find undefined vars in the statement.
177+
* \param stmt The statement to be checked.
178+
* \param defs The vars that are defined.
179+
* \return Array of undefined vars.
179180
*/
180-
Array<Var> UndefinedVars(const LoweredFunc& f);
181+
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
181182

182183
/*!
183184
* \brief Split the function into a host function and device functions.

include/tvm/runtime/c_runtime_api.h

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,18 @@ TVM_DLL int TVMModPreCompile(TVMModuleHandle mod,
225225
const char* func_name,
226226
TVMContext ctx);
227227

228+
/*!
229+
* \brief Free the Module
230+
* \param mod The module to be freed.
231+
*
232+
* \note This may not free up the module's resources.
233+
* If there is an active TVMFunctionHandle that uses the module
234+
* Or if this module is imported by another active module.
235+
*
236+
* All functions remain valid until TVMFuncFree is called.
237+
*/
238+
TVM_DLL int TVMModFree(TVMModuleHandle mod);
239+
228240
/*!
229241
* \brief Backend function for modules to get function
230242
* from its environment mod_node (its imports and global function).
@@ -242,17 +254,25 @@ TVM_DLL int TVMModPreCompile(TVMModuleHandle mod,
242254
TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node,
243255
const char* func_name,
244256
TVMFunctionHandle *out);
257+
245258
/*!
246-
* \brief Free the Module
247-
* \param mod The module to be freed.
259+
* \brief Backend function for running parallel for loop.
248260
*
249-
* \note This may not free up the module's resources.
250-
* If there is active TVMFunctionHandle uses the module
251-
* Or if this module is imported by another active module.
261+
* \note This API is supposed to be used by backend,
262+
* it is not supposed to be used by user.
252263
*
253-
* The all functions remains valid until TVMFuncFree is called.
264+
* \param begin The start of iteration.
265+
* \param end The end of iteration.
266+
* \param lambda The lambda function to be executed.
267+
* \param env The environment of lambda function.
268+
*
269+
* \return 0 when no error is thrown, -1 when failure happens
254270
*/
255-
TVM_DLL int TVMModFree(TVMModuleHandle mod);
271+
TVM_DLL int TVMBackendParallelFor(
272+
int64_t begin,
273+
int64_t end,
274+
int (*lambda)(int64_t begin, int64_t end, void* env),
275+
void* env);
256276

257277
/*!
258278
* \brief Free the function when it is no longer needed.

include/tvm/schedule.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ enum AttachType : int {
3434
/*! \brief IterVar type */
3535
enum IterVarType : int {
3636
kUnrolled = 1,
37-
kVectorized = 2
37+
kVectorized = 2,
38+
kParallel = 3
3839
};
3940

4041
/*! \brief Stage, contains scheduling for a stage of computation. */
@@ -152,6 +153,12 @@ class Stage : public NodeRef {
152153
* \return reference to self.
153154
*/
154155
Stage& unroll(IterVar var); // NOLINT(*)
156+
/*!
157+
* \brief Parallelize iteration.
158+
* \param var The axis to be parallelized.
159+
* \return reference to self.
160+
*/
161+
Stage& parallel(IterVar var); // NOLINT(*)
155162
/*!
156163
* \brief whether the stage has been scheduled.
157164
* \return whether the stage has been scheduled.

python/tvm/schedule.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,13 @@ def unroll(self, var):
257257
The iteration to be unrolled.
258258
"""
259259
_api_internal._StageUnroll(self, var)
260+
261+
def parallel(self, var):
262+
"""Parallelize the iteration.
263+
264+
Parameters
265+
----------
266+
var : IterVar
267+
The iteration to be parallelized.
268+
"""
269+
_api_internal._StageParallel(self, var)

src/api/api_lang.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,12 @@ TVM_REGISTER_API(_StageVectorize)
280280
.vectorize(args[1]);
281281
});
282282

283+
TVM_REGISTER_API(_StageParallel)
284+
.set_body([](TVMArgs args, TVMRetValue* ret) {
285+
args[0].operator Stage()
286+
.parallel(args[1]);
287+
});
288+
283289
TVM_REGISTER_API(_ScheduleNormalize)
284290
.set_body([](TVMArgs args, TVMRetValue* ret) {
285291
args[0].operator Schedule()

src/codegen/llvm/codegen_llvm.cc

Lines changed: 111 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#ifdef TVM_LLVM_VERSION
66

77
#include <tvm/runtime/c_runtime_api.h>
8+
#include <tvm/ir_pass.h>
89
#include "./codegen_llvm.h"
910
#include "../../arithmetic/compute_expr.h"
1011

@@ -30,6 +31,7 @@ void CodeGenLLVM::Init(const std::string& module_name,
3031
t_int8_ = llvm::Type::getInt8Ty(*ctx);
3132
t_int16_ = llvm::Type::getInt16Ty(*ctx);
3233
t_int32_ = llvm::Type::getInt32Ty(*ctx);
34+
t_int64_ = llvm::Type::getInt64Ty(*ctx);
3335
t_float64_ = llvm::Type::getDoubleTy(*ctx);
3436
t_tvm_index_ = llvm::Type::getIntNTy(*ctx, sizeof(tvm_index_t) * 8);
3537
t_tvm_context_ = llvm::StructType::create({t_int_, t_int_});
@@ -43,6 +45,8 @@ void CodeGenLLVM::Init(const std::string& module_name,
4345
t_tvm_type_,
4446
t_tvm_context_});
4547
t_tvm_value_ = llvm::StructType::create({t_float64_});
48+
t_f_tvm_par_for_lambda_ = llvm::FunctionType::get(
49+
t_int_, {t_int64_, t_int64_, t_void_p_}, false);
4650
md_builder_.reset(new llvm::MDBuilder(*ctx));
4751
md_very_likely_branch_ =
4852
md_builder_->createBranchWeights(1 << 30, 0);
@@ -70,7 +74,11 @@ void CodeGenLLVM::Init(const std::string& module_name,
7074
f_tvm_api_set_last_error_ = llvm::Function::Create(
7175
llvm::FunctionType::get(t_void_, {t_char_->getPointerTo()}, false),
7276
llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());
73-
77+
f_tvm_parallel_for_ = llvm::Function::Create(
78+
llvm::FunctionType::get(t_int_, {
79+
t_int64_, t_int64_, t_f_tvm_par_for_lambda_->getPointerTo(), t_void_p_}
80+
, false),
81+
llvm::Function::ExternalLinkage, "TVMBackendParallelFor", module_.get());
7482
this->InitTarget(target_triple);
7583
// initialize builder
7684
builder_.reset(new IRBuilder(*ctx));
@@ -141,7 +149,9 @@ void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
141149
}
142150
llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_);
143151
builder_->SetInsertPoint(block);
144-
builder_->CreateRet(builder_->CreateCall(f, args));
152+
llvm::CallInst* call = builder_->CreateCall(f, args);
153+
call->setTailCall(true);
154+
builder_->CreateRet(call);
145155
}
146156

147157
class FPassManager : public llvm::legacy::FunctionPassManager {
@@ -545,7 +555,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
545555
return nullptr;
546556
}
547557

548-
llvm::BasicBlock* CodeGenLLVM::CheckPackedCallSuccess(llvm::Value* retcode) {
558+
llvm::BasicBlock* CodeGenLLVM::CheckCallSuccess(llvm::Value* retcode) {
549559
// create emit codes that checks and load the function.
550560
using llvm::BasicBlock;
551561
BasicBlock* fail_block = BasicBlock::Create(
@@ -563,34 +573,15 @@ llvm::BasicBlock* CodeGenLLVM::CheckPackedCallSuccess(llvm::Value* retcode) {
563573
return end_block;
564574
}
565575
void CodeGenLLVM::Visit_(const For* op) {
566-
using llvm::BasicBlock;
567-
BasicBlock* for_head = BasicBlock::Create(
568-
*ctx_, "for_head", function_);
569-
BasicBlock* for_body = BasicBlock::Create(
570-
*ctx_, "for_body", function_);
571-
BasicBlock* for_end = BasicBlock::Create(
572-
*ctx_, "for_end", function_);
573-
BasicBlock* pre_block = builder_->GetInsertBlock();
574576
CHECK(is_zero(op->min));
575-
Type t = op->min.type();
576-
llvm::Value* init = ConstInt32(0);
577-
llvm::Value* extent = MakeValue(op->extent);
578-
builder_->CreateBr(for_head);
579-
580-
builder_->SetInsertPoint(for_head);
581-
llvm::PHINode* index = builder_->CreatePHI(LLVMType(t), 2);
582-
index->addIncoming(init, pre_block);
583-
llvm::Value* cond = CreateLT(t, index, extent);
584-
builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_);
585-
// body of for
586-
builder_->SetInsertPoint(for_body);
587-
var_map_[op->loop_var.get()] = index;
588-
this->Visit(op->body);
589-
llvm::Value* next_index = CreateAdd(t, index, ConstInt32(1));
590-
index->addIncoming(next_index, builder_->GetInsertBlock());
591-
builder_->CreateBr(for_head);
592-
// end of for
593-
builder_->SetInsertPoint(for_end);
577+
if (op->for_type == ForType::Serial) {
578+
CreateSerialFor(ConstInt32(0), MakeValue(op->extent),
579+
op->loop_var, op->body);
580+
} else if (op->for_type == ForType::Parallel) {
581+
CreateParallelFor(op);
582+
} else {
583+
LOG(FATAL) << "cannot handle for type " << op->for_type;
584+
}
594585
}
595586

596587
void CodeGenLLVM::Visit_(const IfThenElse* op) {
@@ -807,7 +798,7 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
807798
llvm::Value* ctx = builder_->CreateLoad(gv_mod_ctx_);
808799
llvm::Value* retcode = builder_->CreateCall(
809800
f_tvm_get_func_from_env_, {ctx, GetConstString(fname), out});
810-
init_block = CheckPackedCallSuccess(retcode);
801+
init_block = CheckCallSuccess(retcode);
811802
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
812803
builder_->CreateBr(end_block);
813804
// end block
@@ -846,7 +837,7 @@ llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
846837
}
847838
llvm::Value* ret_value = builder_->CreateAlloca(t_tvm_value_);
848839
llvm::Value* ret_tcode = builder_->CreateAlloca(t_int_);
849-
CheckPackedCallSuccess(
840+
CheckCallSuccess(
850841
builder_->CreateCall(
851842
f_tvm_func_call_,
852843
{handle, targs, tcodes, ConstInt32(nargs), ret_value, ret_tcode}));
@@ -934,6 +925,94 @@ llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
934925
}
935926
}
936927

928+
void CodeGenLLVM::CreateParallelFor(const For* op) {
929+
using llvm::BasicBlock;
930+
llvm::Value* min = MakeValue(op->min);
931+
llvm::Value* extent = MakeValue(op->extent);
932+
min = builder_->CreateIntCast(min, t_int64_, op->min.type().is_int());
933+
extent = builder_->CreateIntCast(extent, t_int64_, op->min.type().is_int());
934+
// fields to be packed into closure.
935+
Var loop_var(op->loop_var.node_);
936+
Array<Var> vfields = ir::UndefinedVars(op->body, {loop_var});
937+
std::vector<llvm::Type*> fields;
938+
for (Var v : vfields) {
939+
auto it = var_map_.find(v.get());
940+
CHECK(it != var_map_.end());
941+
fields.push_back(it->second->getType());
942+
}
943+
// closure data
944+
llvm::StructType* tcdata = llvm::StructType::create(fields);
945+
llvm::Function* f = llvm::Function::Create(
946+
t_f_tvm_par_for_lambda_,
947+
llvm::Function::PrivateLinkage,
948+
"__tvm_par_for_lambda", module_.get());
949+
// allocate and setup the closure, call the closure.
950+
llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1));
951+
llvm::Value* zero = ConstInt32(0);
952+
953+
for (size_t i = 0; i < vfields.size(); ++i) {
954+
builder_->CreateStore(
955+
var_map_.at(vfields[i].get()),
956+
builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)}));
957+
}
958+
BasicBlock* par_for_end = CheckCallSuccess(
959+
builder_->CreateCall(
960+
f_tvm_parallel_for_,
961+
{min, extent, f, builder_->CreatePointerCast(cdata, t_void_p_)}));
962+
// Setup the closure function.
963+
BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
964+
builder_->SetInsertPoint(lambda_entry);
965+
auto it = f->arg_begin();
966+
llvm::Value* begin = &(*it++);
967+
llvm::Value* end = &(*it++);
968+
cdata = &(*it++);
969+
begin = CreateCast(Int(64), op->loop_var.type(), begin);
970+
end = CreateCast(Int(64), op->loop_var.type(), end);
971+
cdata = builder_->CreatePointerCast(cdata, tcdata->getPointerTo());
972+
// setup new variable map, swap it with current var context.
973+
std::unordered_map<const Variable*, llvm::Value*> new_vmap;
974+
for (size_t i = 0; i < vfields.size(); ++i) {
975+
new_vmap[vfields[i].get()] =
976+
builder_->CreateLoad(builder_->CreateInBoundsGEP(
977+
cdata, {zero, ConstInt32(i)}));
978+
}
979+
std::swap(function_, f);
980+
std::swap(new_vmap, var_map_);
981+
CreateSerialFor(begin, end, op->loop_var, op->body);
982+
builder_->CreateRet(ConstInt32(0));
983+
// swap the var map back, now we are back on track.
984+
std::swap(new_vmap, var_map_);
985+
std::swap(function_, f);
986+
builder_->SetInsertPoint(par_for_end);
987+
}
988+
989+
/*!
 * \brief Emit a serial loop that runs body with loop_var taking the values
 *  begin, begin + 1, ... while loop_var < end.
 *
 *  On return the builder is positioned at the loop exit block.
 *
 * \param begin The first value of the loop variable.
 * \param end The exclusive upper bound of the loop variable.
 *  NOTE(review): presumably expected to have the same type as begin, confirm.
 * \param loop_var The loop variable, bound to the induction value inside body.
 * \param body The statement executed on every iteration.
 */
void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end,
                                  const VarExpr& loop_var, const Stmt& body) {
  using llvm::BasicBlock;
  Type t = loop_var.type();
  BasicBlock* for_head = BasicBlock::Create(
      *ctx_, "for_head", function_);
  BasicBlock* for_body = BasicBlock::Create(
      *ctx_, "for_body", function_);
  BasicBlock* for_end = BasicBlock::Create(
      *ctx_, "for_end", function_);
  BasicBlock* pre_block = builder_->GetInsertBlock();
  builder_->CreateBr(for_head);
  // head of for: induction variable and exit test
  builder_->SetInsertPoint(for_head);
  llvm::PHINode* index = builder_->CreatePHI(begin->getType(), 2);
  index->addIncoming(begin, pre_block);
  llvm::Value* cond = CreateLT(t, index, end);
  builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_);
  // body of for
  builder_->SetInsertPoint(for_body);
  var_map_[loop_var.get()] = index;
  this->Visit(body);
  // The step must have the same integer type as the induction variable;
  // this cast is a no-op when the index is already 32 bit.
  llvm::Value* step = builder_->CreateIntCast(
      ConstInt32(1), index->getType(), true);
  llvm::Value* next_index = CreateAdd(t, index, step);
  // the body may have created new blocks, so the back edge comes from the
  // block the builder currently points at rather than from for_body.
  index->addIncoming(next_index, builder_->GetInsertBlock());
  builder_->CreateBr(for_head);
  // end of for
  builder_->SetInsertPoint(for_end);
}
9371016
} // namespace codegen
9381017
} // namespace tvm
9391018
#endif // TVM_LLVM_VERSION

src/codegen/llvm/codegen_llvm.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,12 @@ class CodeGenLLVM : public IRVisitor {
152152
llvm::StructType* t_tvm_type_{nullptr};
153153
llvm::StructType* t_tvm_array_{nullptr};
154154
llvm::StructType* t_tvm_value_{nullptr};
155+
llvm::FunctionType* t_f_tvm_par_for_lambda_{nullptr};
155156
// tvm api functions
156157
llvm::Function* f_tvm_func_call_{nullptr};
157158
llvm::Function* f_tvm_get_func_from_env_{nullptr};
158159
llvm::Function* f_tvm_api_set_last_error_{nullptr};
160+
llvm::Function* f_tvm_parallel_for_{nullptr};
159161
// The acting body
160162
llvm::BasicBlock* block_{nullptr};
161163
// Last value returned codegen call.
@@ -176,10 +178,15 @@ class CodeGenLLVM : public IRVisitor {
176178
llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
177179
llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
178180
llvm::Value* GetPackedFuncHandle(const std::string& str);
181+
// Create parallel for.
182+
void CreateParallelFor(const For* op);
183+
// Create serial for
184+
void CreateSerialFor(llvm::Value* begin, llvm::Value* end,
185+
const VarExpr& loop_var, const Stmt& body);
179186
// Check if the call to packed function is successful
180187
// if not directly finalize function and pass on return code.
181188
// return the end block after the check
182-
llvm::BasicBlock* CheckPackedCallSuccess(llvm::Value* retcode);
189+
llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode);
183190
// Initialize target
184191
void InitTarget(const std::string& target);
185192
// Add a function to set global module context

src/lang/lowered_func.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
/*!
2+
* Copyright (c) 2017 by Contributors
3+
* \file lowered_func.cc
4+
*/
5+
#include <tvm/lowered_func.h>
6+
7+
namespace tvm {
8+
9+
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
10+
.set_dispatch<LoweredFuncNode>([](const LoweredFuncNode *op, IRPrinter *p) {
11+
p->stream << "LoweredFunc(" << op->name << ", " << op << ")";
12+
});
13+
14+
TVM_REGISTER_NODE_TYPE(LoweredFuncNode);
15+
16+
} // namespace tvm

src/pass/make_api.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ LoweredFunc MakeAPI(Stmt body,
188188
n->is_packed_func = num_unpacked_args == 0;
189189
n->body = MergeNest({seq_init, seq_check}, body);
190190
LoweredFunc f(n);
191-
Array<Var> undefined = UndefinedVars(f);
191+
Array<Var> undefined = UndefinedVars(f->body, f->args);
192192
if (undefined.size() != 0) {
193193
std::ostringstream os;
194194
for (Var v : undefined) {

0 commit comments

Comments
 (0)