Skip to content

Commit 443e7a8

Browse files
author
Giuseppe Rossini
committed
Decoupling AOT from graph memory planner
In this PR we are decoupling AOT from the Graph Memory Planner. Since AOT has the runner expressed in TIR, we can get rid of the GMP in Relay and use the Storage Rewrite pass to do memory planning on the runner function. This also sorts out the issue mentioned in #8062. Change-Id: I6e33fadbf0462edf0366ee37e84ffde26123d3cb
1 parent dbd076a commit 443e7a8

File tree

3 files changed

+257
-56
lines changed

3 files changed

+257
-56
lines changed

src/relay/backend/aot_executor_codegen.cc

Lines changed: 196 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <tvm/tir/expr.h>
3232
#include <tvm/tir/function.h>
3333
#include <tvm/tir/stmt.h>
34+
#include <tvm/tir/transform.h>
3435

3536
#include <algorithm>
3637
#include <list>
@@ -44,52 +45,179 @@ namespace tvm {
4445
namespace relay {
4546
namespace backend {
4647

48+
/**
 * Struct to contain information about intermediate variables in the
 * runner function
 */
struct StorageInfo {
  /*! \brief unique integer identifier of the particular intermediate variable */
  std::vector<int> ids;
  /*! \brief exact size of the temporary, in bytes */
  std::vector<int> sizes_bytes;
  /*! \brief device type of the temporary variable */
  std::vector<int> dev_types;
};
60+
4761
using IntegerArray = Array<Integer>;
4862
using TargetsMap = std::unordered_map<int, Target>;
63+
using StorageMap =
64+
std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
4965

50-
class AotReturnSidVisitor : public ExprVisitor {
66+
/**
67+
* This is an on demand allocator for AOT. A new temporary
68+
* (storage allocator identifier) is allocated for each operation.
69+
*/
70+
class AOTOnDemandAllocator : public ExprVisitor {
5171
public:
52-
explicit AotReturnSidVisitor(Map<Expr, Array<IntegerArray>> storage_device_map)
53-
: storage_device_map_{storage_device_map}, return_sid_{-1} {}
72+
// run the visitor on a function.
73+
void Run(const Function& func) {
74+
node_device_map_ = CollectDeviceInfo(func);
5475

55-
IntegerArray FindReturnSid(Function func) {
56-
VisitExpr(func->body);
57-
return return_sid_;
76+
for (Expr param : func->params) {
77+
CreateSid(param.operator->());
78+
}
79+
80+
GetSid(func->body);
5881
}
5982

60-
protected:
61-
void AssignReturnSid(Expr e) {
62-
auto iter = storage_device_map_.find(e);
63-
if (iter != storage_device_map_.end()) {
64-
return_sid_ = (*iter).second[0];
83+
std::vector<int> GetReturnIds() const { return return_ids_; }
84+
85+
StorageMap GetStorageMap() const { return storage_device_map_; }
86+
87+
void VisitExpr_(const ConstantNode* op) final {
88+
CreateSid(op);
89+
AssignReturnSid(GetRef<Expr>(op));
90+
}
91+
92+
void VisitExpr_(const CallNode* op) final {
93+
// create token for the call node.
94+
CreateSid(op);
95+
for (Expr arg : op->args) {
96+
GetSid(arg);
6597
}
98+
AssignReturnSid(GetRef<Expr>(op));
6699
}
67100

68-
void VisitExpr_(const ConstantNode* cn) override {
69-
ExprVisitor::VisitExpr_(cn);
70-
AssignReturnSid(GetRef<Expr>(cn));
101+
void VisitExpr_(const VarNode* op) final {
102+
ExprVisitor::VisitExpr_(op);
103+
AssignReturnSid(GetRef<Expr>(op));
71104
}
72105

73-
void VisitExpr_(const VarNode* vn) override {
74-
ExprVisitor::VisitExpr_(vn);
75-
AssignReturnSid(GetRef<Expr>(vn));
106+
void VisitExpr_(const FunctionNode* op) final {
107+
// do not recurse into sub function.
76108
}
77109

78-
void VisitExpr_(const CallNode* cn) override {
79-
ExprVisitor::VisitExpr_(cn);
80-
AssignReturnSid(GetRef<Expr>(cn));
110+
void VisitExpr_(const GlobalVarNode* op) final {
111+
// Do nothing.
81112
}
82113

83-
void VisitExpr_(const LetNode* op) override { VisitExpr(op->body); }
114+
void VisitExpr_(const OpNode* op) final {
115+
// Do nothing.
116+
}
117+
118+
void VisitExpr_(const TupleNode* op) final {
119+
StorageInfo field_sid;
120+
Expr expr = GetRef<Expr>(op);
121+
for (Expr field : op->fields) {
122+
auto sid = GetSid(field);
123+
field_sid.ids.insert(field_sid.ids.end(), sid.ids.begin(), sid.ids.end());
124+
field_sid.dev_types.insert(field_sid.dev_types.end(), sid.dev_types.begin(),
125+
sid.dev_types.end());
126+
field_sid.sizes_bytes.insert(field_sid.sizes_bytes.end(), sid.sizes_bytes.begin(),
127+
sid.sizes_bytes.end());
128+
}
129+
130+
storage_device_map_[expr] = field_sid;
131+
AssignReturnSid(expr);
132+
}
84133

85-
void VisitExpr_(const TupleNode* tn) override {
86-
ExprVisitor::VisitExpr_(tn);
87-
AssignReturnSid(GetRef<Expr>(tn));
134+
void VisitExpr_(const TupleGetItemNode* op) final {
135+
Expr expr = GetRef<Expr>(op);
136+
const auto& sid = GetSid(op->tuple);
137+
ICHECK_LT(static_cast<size_t>(op->index), sid.ids.size());
138+
storage_device_map_[expr].ids = {sid.ids[op->index]};
139+
storage_device_map_[expr].sizes_bytes = {sid.sizes_bytes[op->index]};
140+
storage_device_map_[expr].dev_types = {sid.dev_types[op->index]};
141+
AssignReturnSid(expr);
88142
}
89143

144+
void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; }
145+
146+
void VisitExpr_(const LetNode* op) final { LOG(FATAL) << "if is not supported."; }
147+
90148
private:
91-
Map<Expr, Array<IntegerArray>> storage_device_map_;
92-
IntegerArray return_sid_;
149+
void AssignReturnSid(Expr e) {
150+
auto iter = storage_device_map_.find(e);
151+
if (iter != storage_device_map_.end()) {
152+
return_ids_ = (*iter).second.ids;
153+
}
154+
}
155+
/*!
156+
* \brief ceil(size/word_size) to get number of words.
157+
* \param size The original size.
158+
* \param word_size The element size.
159+
*/
160+
static size_t DivRoundUp(size_t size, size_t word_size) {
161+
return (size + word_size - 1) / word_size;
162+
}
163+
/*!
164+
* \brief Get the memory requirement.
165+
* \param prototype The prototype token.
166+
* \return The required memory size.
167+
*/
168+
size_t GetMemorySize(const TensorTypeNode* ttype) {
169+
ICHECK(ttype != nullptr);
170+
size_t size = 1;
171+
for (IndexExpr dim : ttype->shape) {
172+
const int64_t* pval = tir::as_const_int(dim);
173+
ICHECK(pval != nullptr) << "Cannot allocate memory symbolic tensor shape " << ttype->shape;
174+
ICHECK_GE(*pval, 0) << "Cannot allocate memory for tensor with negative shape" << *pval;
175+
size *= static_cast<size_t>(pval[0]);
176+
}
177+
size *= DivRoundUp(ttype->dtype.bits() * ttype->dtype.lanes(), 8);
178+
return size;
179+
}
180+
/*!
181+
* \brief Get the necessary token.
182+
* \param expr The expression.
183+
* \return The corresponding token.
184+
*/
185+
StorageInfo GetSid(const Expr& expr) {
186+
this->VisitExpr(expr);
187+
auto it = storage_device_map_.find(expr);
188+
ICHECK(it != storage_device_map_.end());
189+
return it->second;
190+
}
191+
192+
void CreateSid(const ExprNode* op) {
193+
StorageInfo sid;
194+
Expr expr = GetRef<Expr>(op);
195+
int device_type = node_device_map_.count(GetRef<Expr>(op)) ? node_device_map_[expr]->value : 0;
196+
if (const auto* tuple_type = op->checked_type().as<TupleTypeNode>()) {
197+
for (Type t : tuple_type->fields) {
198+
const auto* ttype = t.as<TensorTypeNode>();
199+
ICHECK(ttype);
200+
sid.ids.push_back(sid_++);
201+
sid.dev_types.push_back(device_type);
202+
sid.sizes_bytes.push_back(GetMemorySize(ttype));
203+
}
204+
} else {
205+
const auto* ttype = op->checked_type().as<TensorTypeNode>();
206+
ICHECK(ttype);
207+
sid.ids.push_back(sid_++);
208+
sid.dev_types.push_back(device_type);
209+
sid.sizes_bytes.push_back(GetMemorySize(ttype));
210+
}
211+
storage_device_map_[expr] = sid;
212+
}
213+
/*! \brief mapping of expression -> storageInfo*/
214+
StorageMap storage_device_map_;
215+
/*! \brief mapping of expression -> device type*/
216+
Map<Expr, Integer> node_device_map_;
217+
/*! \brief current id of the temporary allocated*/
218+
int sid_{0};
219+
/*! \brief the set of identifiers that are return variables */
220+
std::vector<int> return_ids_;
93221
};
94222

95223
/*! \brief Code generator for AOT executor */
@@ -120,14 +248,14 @@ class AOTExecutorCodegen : public ExprVisitor {
120248
* \brief Return a vector of variables that represents the sids for the given Relay Expr
121249
*/
122250
std::vector<tir::Var> PackSid(Expr expr) {
123-
Array<IntegerArray> sids = storage_device_map_[expr];
251+
auto sids = storage_device_map_[expr];
124252
std::vector<tir::Var> sid_vars;
125253

126254
// Note that an expression can have multiple sids associated with it
127255
// e.g., returning multiple values from a function
128-
for (const auto& sid : sids[0]) {
256+
for (const auto& sid : sids.ids) {
129257
// Determine if an sid is an output buffer
130-
int sid_int = static_cast<int>((sid.as<IntImmNode>())->value);
258+
int sid_int = sid;
131259
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid_int);
132260
if (output_iter != return_sid_.end()) {
133261
int output_index = std::distance(return_sid_.begin(), output_iter);
@@ -390,8 +518,8 @@ class AOTExecutorCodegen : public ExprVisitor {
390518
}
391519

392520
ICHECK_GE(storage_device_map_.count(expr), 0);
393-
auto& device_type = storage_device_map_[expr][1];
394-
auto call_dev_type = device_type[0]->value;
521+
auto& device_type = storage_device_map_[expr].dev_types;
522+
auto call_dev_type = device_type[0];
395523
// Normal Relay Function
396524
if (targets_.size() == 1) {
397525
// homogeneous execution.
@@ -428,14 +556,14 @@ class AOTExecutorCodegen : public ExprVisitor {
428556

429557
// If the Var node is an output node we need to copy the content of the variable to the output
430558
// It's safe to check the SID here because Var StorageToken are never reallocated
431-
Array<IntegerArray> sids = storage_device_map_[expr];
559+
auto sids = storage_device_map_[expr];
432560

433-
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(),
434-
static_cast<int>((sids[0][0].as<IntImmNode>())->value));
561+
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sids.ids[0]);
435562
if (output_iter != return_sid_.end()) {
436563
int output_index = std::distance(return_sid_.begin(), output_iter);
437564
auto var_expr = FindExpr(expr);
438-
CopyToOutput(main_signature_[input_vars_.size() + output_index], var_expr[0], sids[2][0]);
565+
CopyToOutput(main_signature_[input_vars_.size() + output_index], var_expr[0],
566+
sids.sizes_bytes[0]);
439567
}
440568
}
441569

@@ -444,18 +572,18 @@ class AOTExecutorCodegen : public ExprVisitor {
444572
size_t index = params_.size();
445573
std::string name = "p" + std::to_string(index);
446574

447-
param_storage_ids_[name] = storage_device_map_[expr][0][0]->value;
575+
param_storage_ids_[name] = storage_device_map_[expr].ids[0];
448576
params_[name] = op->data;
449577
params_by_expr_.Set(expr, name);
450578

451579
// If the Constant node is an output node we need to copy the content of the parameter to the
452580
// output A Var node can only produce a single output
453-
Array<IntegerArray> sids = storage_device_map_[expr];
454-
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(),
455-
static_cast<int>((sids[0][0].as<IntImmNode>())->value));
581+
auto sids = storage_device_map_[expr];
582+
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sids.ids[0]);
456583
if (output_iter != return_sid_.end()) {
457584
int output_index = std::distance(return_sid_.begin(), output_iter);
458-
CopyToOutput(main_signature_[input_vars_.size() + output_index], PackParam(expr), sids[2][0]);
585+
CopyToOutput(main_signature_[input_vars_.size() + output_index], PackParam(expr),
586+
sids.sizes_bytes[0]);
459587
}
460588
}
461589

@@ -511,9 +639,9 @@ class AOTExecutorCodegen : public ExprVisitor {
511639
continue;
512640
}
513641

514-
for (unsigned int i = 0; i < kv.second[0].size(); i++) {
515-
int size = kv.second[2][i];
516-
int sid = static_cast<int>((kv.second[0][i].as<IntImmNode>())->value);
642+
for (unsigned int i = 0; i < kv.second.ids.size(); i++) {
643+
int size = kv.second.sizes_bytes[i];
644+
int sid = kv.second.ids[i];
517645

518646
if (std::find(return_sid_.begin(), return_sid_.end(), sid) != return_sid_.end()) {
519647
continue;
@@ -523,6 +651,8 @@ class AOTExecutorCodegen : public ExprVisitor {
523651
// so we don't pay the price of allocation for every inference
524652
if (!allocated[sid]) {
525653
body = tir::Allocate(sids_table_[sid], DataType::Int(8), {size}, tir::const_true(), body);
654+
body = tir::AttrStmt(sids_table_[sid], tir::attr::storage_scope, tir::StringImm("global"),
655+
body);
526656
}
527657
allocated[sid] = true;
528658
}
@@ -566,7 +696,8 @@ class AOTExecutorCodegen : public ExprVisitor {
566696
std::unordered_map<std::string, int64_t> param_storage_ids_;
567697

568698
/*! \brief plan memory of device result */
569-
Map<Expr, Array<IntegerArray>> storage_device_map_;
699+
StorageMap storage_device_map_;
700+
/*! \brief mapping sid -> tir::Var */
570701
std::unordered_map<int, te::Var> sids_table_;
571702
/*! \brief lowered funcs */
572703
std::unordered_map<std::string, IRModule> lowered_funcs_;
@@ -577,7 +708,7 @@ class AOTExecutorCodegen : public ExprVisitor {
577708
/*! \brief the set of statements that make the program */
578709
std::vector<tir::Stmt> stmts_;
579710
/*! \brief the list of return sids (note that the function might return more than one output) */
580-
IntegerArray return_sid_;
711+
std::vector<int> return_sid_;
581712

582713
public:
583714
AOTExecutorCodegen(runtime::Module* mod, const TargetsMap& targets, Target target_host)
@@ -588,9 +719,11 @@ class AOTExecutorCodegen : public ExprVisitor {
588719
}
589720

590721
LoweredOutput Codegen(relay::Function func) {
591-
// Get the module, storage map and token sizes
592-
auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
593-
storage_device_map_ = (*pf)(func);
722+
auto aot_allocator = AOTOnDemandAllocator();
723+
aot_allocator.Run(func);
724+
725+
// Retrieve the storage map
726+
storage_device_map_ = aot_allocator.GetStorageMap();
594727

595728
int input_index = 0;
596729
for (auto input : func->params) {
@@ -600,14 +733,14 @@ class AOTExecutorCodegen : public ExprVisitor {
600733

601734
// Define the storage allocator ids
602735
for (auto kv : storage_device_map_) {
603-
for (const auto& sid : kv.second[0]) {
736+
for (const auto& sid : kv.second.ids) {
604737
te::Var sid_var(MakeString("sid_", sid), PointerType(PrimType(DataType::Int(8))));
605738
sids_table_[sid] = sid_var;
606739
}
607740
}
608741

609-
// Find the return sid
610-
return_sid_ = AotReturnSidVisitor(storage_device_map_).FindReturnSid(func);
742+
// Retrieve the return sids
743+
return_sid_ = aot_allocator.GetReturnIds();
611744
for (unsigned int output_index = 0; output_index < return_sid_.size(); output_index++) {
612745
main_signature_.push_back(tir::Var(MakeString("output_", output_index), DataType::Handle()));
613746
}
@@ -635,14 +768,21 @@ class AOTExecutorCodegen : public ExprVisitor {
635768
}
636769
ret.external_mods = compile_engine_->LowerExternalFunctions();
637770

771+
// Build the TIR IRModule
772+
Map<GlobalVar, BaseFunc> symbol_map;
773+
symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
774+
IRModule mod_run(symbol_map);
775+
776+
// Apply storage rewrite pass to the runner function to do memory planning
777+
auto storage_rewrite = tir::transform::StorageRewrite();
778+
mod_run = storage_rewrite(mod_run);
779+
780+
// Update the lowered functions
638781
auto target_host_str = target_host_->str();
639782
if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
640-
ret.lowered_funcs[target_host_str]->Add(
641-
GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
783+
ret.lowered_funcs[target_host_str]->Update(mod_run);
642784
} else {
643-
Map<GlobalVar, BaseFunc> symbol_map;
644-
symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
645-
ret.lowered_funcs.Set(target_host_str, IRModule(symbol_map));
785+
ret.lowered_funcs.Set(target_host_str, mod_run);
646786
}
647787
ret.function_metadata = std::move(function_metadata_);
648788
ret.metadata =

0 commit comments

Comments
 (0)