Skip to content

Commit f95f642

Browse files
committed
[TIR][USMP] adding the pass to convert to pool offsets
* Adding a toggle to produce TIR that is TVMScript printable for unit testing * Fixing the unit tests * Ensure deterministic pool variable ordering. Change-Id: I317675df03327b0ebbf4ca074255384e63f07cd6
1 parent c2accec commit f95f642

File tree

6 files changed

+303
-152
lines changed

6 files changed

+303
-152
lines changed

include/tvm/tir/usmp/utils.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,44 @@ class PoolAllocation : public ObjectRef {
186186
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolAllocation, ObjectRef, PoolAllocationNode);
187187
};
188188

189+
/*!
190+
* \brief This object contains information post-allocation for PoolInfo objects
191+
*/
192+
struct AllocatedPoolInfoNode : public Object {
193+
/*! \brief The assigned PoolInfo object */
194+
PoolInfo pool_info;
195+
/*! \brief The allocated size into this pool */
196+
Integer allocated_size;
197+
/*! \brief An optional associated pool Var*/
198+
Optional<Var> pool_var;
199+
200+
void VisitAttrs(tvm::AttrVisitor* v) {
201+
v->Visit("pool_info", &pool_info);
202+
v->Visit("allocated_size", &allocated_size);
203+
v->Visit("pool_var", &pool_var);
204+
}
205+
206+
bool SEqualReduce(const AllocatedPoolInfoNode* other, SEqualReducer equal) const {
207+
return equal(pool_info, other->pool_info) && equal(allocated_size, other->allocated_size) &&
208+
equal(pool_var, other->pool_var);
209+
}
210+
211+
void SHashReduce(SHashReducer hash_reduce) const {
212+
hash_reduce(pool_info);
213+
hash_reduce(allocated_size);
214+
hash_reduce(pool_var);
215+
}
216+
217+
static constexpr const char* _type_key = "tir.usmp.AllocatedPoolInfo";
218+
TVM_DECLARE_FINAL_OBJECT_INFO(AllocatedPoolInfoNode, Object);
219+
};
220+
221+
class AllocatedPoolInfo : public ObjectRef {
222+
public:
223+
TVM_DLL AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, Var pool_var = Var());
224+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AllocatedPoolInfo, ObjectRef, AllocatedPoolInfoNode);
225+
};
226+
189227
/*!
190228
* \brief Convert the IR-bound BufferInfo map to an array of BufferInfo
191229
*
@@ -209,6 +247,20 @@ Integer CalculateExtentsSize(const AllocateNode* op);
209247

210248
} // namespace usmp
211249
} // namespace tir
250+
251+
namespace attr {
252+
/*!
253+
* \brief This is a BaseFunc attribute to indicate which input var represent
254+
* a PoolInfo Object in the form of a Map<Var, PoolInfo>.
255+
*/
256+
static constexpr const char* kPoolArgs = "pool_args";
257+
/*!
258+
* \brief This is a BaseFunc attribute to indicate which input var represent
259+
* a PoolInfo Object in the form of a Map<Var, PoolInfo>.
260+
*/
261+
static constexpr const char* kPoolInfoIRModuleAttr = "pool_infos";
262+
} // namespace attr
263+
212264
} // namespace tvm
213265

214266
#endif // TVM_TIR_USMP_UTILS_H_

python/tvm/tir/usmp/transform/transform.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,22 @@
2424
from ..utils import PoolAllocation
2525

2626

27-
def convert_pool_allocations_to_offsets(pool_allocations: Dict[Stmt, PoolAllocation]):
27+
def convert_pool_allocations_to_offsets(
28+
pool_allocations: Dict[Stmt, PoolAllocation], emit_tvmscript_printable: bool = False
29+
):
2830
"""Convert pool allocations to Load nodes with offsets from pools.
2931
3032
Parameters
3133
----------
3234
pool_allocations : Dict[Stmt, PoolAllocation]
3335
Allocate or AllocateConst node to pool allocation mapping
36+
emit_tvmscript_printable : bool
37+
A toggle to emit TVMScript printable IRModule for unit tests
38+
removing all attributes that should be attached for integration
3439
3540
Returns
3641
-------
3742
ret: tvm.transform.Pass
3843
The registered pass that converts the allocations to offsets.
3944
"""
40-
return _ffi_api.ConvertPoolAllocationsToOffsets(pool_allocations)
45+
return _ffi_api.ConvertPoolAllocationsToOffsets(pool_allocations, emit_tvmscript_printable)

src/tir/ir/stmt.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,12 @@ namespace tir {
3535
LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) {
3636
ICHECK(value.defined());
3737
ICHECK(body.defined());
38-
ICHECK_EQ(value.dtype(), var.dtype());
38+
auto vdtype = value.dtype();
39+
if (var->type_annotation.as<PointerTypeNode>()) {
40+
ICHECK(vdtype.is_handle());
41+
} else {
42+
ICHECK_EQ(value.dtype(), var.dtype());
43+
}
3944

4045
ObjectPtr<LetStmtNode> node = make_object<LetStmtNode>();
4146
node->var = std::move(var);

src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc

Lines changed: 108 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@ namespace usmp {
3939
class PoolAllocationToOffsetConverter : public StmtExprMutator {
4040
public:
4141
explicit PoolAllocationToOffsetConverter(const IRModule& module,
42-
const Map<tir::Stmt, PoolAllocation>& pool_allocations)
43-
: pool_allocations_(pool_allocations) {
42+
const Map<tir::Stmt, PoolAllocation>& pool_allocations,
43+
bool emit_tvmscript_printable = false)
44+
: pool_allocations_(pool_allocations), emit_tvmscript_printable_(emit_tvmscript_printable) {
4445
module_ = module->ShallowCopy();
4546
for (const auto& gv_func : module_->functions) {
4647
function_global_vars_.Set(gv_func.first->name_hint, gv_func.first);
@@ -51,7 +52,6 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
5152
Allocate allocate_node = Downcast<Allocate>(kv.first);
5253
PoolAllocation pool_allocation = kv.second;
5354
PoolInfo pool_info = pool_allocation->pool_info;
54-
pool_ordering_.insert(pool_info);
5555
int byte_pool_offset = pool_allocation->byte_offset->value;
5656
int required_pool_size_for_allocation =
5757
byte_pool_offset + CalculateExtentsSize(allocate_node.operator->());
@@ -64,12 +64,26 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
6464
}
6565
}
6666
}
67+
68+
for (const auto& kv : all_pools_sizes_) {
69+
PoolInfo pi = kv.first;
70+
int allocated_size = kv.second;
71+
allocated_pool_ordering_.push_back(AllocatedPoolInfo(pi, allocated_size));
72+
}
73+
std::sort(allocated_pool_ordering_.begin(), allocated_pool_ordering_.end(),
74+
[](const AllocatedPoolInfo& lhs, const AllocatedPoolInfo& rhs) {
75+
if (lhs->pool_info->pool_name < rhs->pool_info->pool_name) {
76+
return true;
77+
}
78+
return false;
79+
});
6780
}
6881
IRModule operator()();
6982

7083
private:
7184
PrimExpr VisitExpr_(const CallNode* op) override;
7285
Stmt VisitStmt_(const AllocateNode* op) override;
86+
// PrimExpr VisitExpr_(const VarNode* op) override;
7387
PrimExpr VisitExpr_(const LoadNode* op) override;
7488
Stmt VisitStmt_(const StoreNode* op) override;
7589

@@ -79,6 +93,7 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
7993
struct ScopeInfo {
8094
Array<tir::Var> params;
8195
Map<PoolInfo, tir::Var> pools_to_params;
96+
Array<AllocatedPoolInfo> allocated_pool_params;
8297
Map<tir::Var, Buffer> buffer_map;
8398
};
8499

@@ -101,7 +116,7 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
101116
/*! \brief This is a helper to append the pool args to
102117
* the callsite of the function.
103118
*/
104-
Array<PrimExpr> AppendPoolParamsToArgs(const CallNode* op);
119+
Array<PrimExpr> AppendPoolParamsToArgs(const Array<PrimExpr>& args);
105120
/*! \brief Some arguments that used to be Allocate nodes
106121
* should be replaced by Let nodes in the pass that loads
107122
* the space from a pool variable.
@@ -117,7 +132,7 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
117132
/*! \brief The input allocate node to PoolAllocation map */
118133
Map<tir::Stmt, PoolAllocation> pool_allocations_;
119134
/*! \brief The set of ordered pools to ensure an unique order of args for functions */
120-
std::set<PoolInfo> pool_ordering_;
135+
std::vector<AllocatedPoolInfo> allocated_pool_ordering_;
121136
/*! \brief The storage of calculated pool size at init */
122137
std::unordered_map<PoolInfo, int, ObjectPtrHash, ObjectPtrEqual> all_pools_sizes_;
123138
/*! \brief The AoT codegen uses extern_calls due to some functions not being exposed in the TIR
@@ -130,6 +145,10 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
130145
Map<tir::Var, tir::Var> allocate_buf_to_let_var_;
131146
/*! \brief A counter to give references to pools a reproducible unique set of names */
132147
int pool_var_count_ = 0;
148+
/*! \brief This toggles to remove non tvmscript printable items for IRModule for unit tests */
149+
bool emit_tvmscript_printable_ = false;
150+
/*! \brief A counter to give references to pools a reproducible unique set of names */
151+
std::unordered_set<PrimFunc, ObjectPtrHash, ObjectPtrEqual> visited_primfuncs;
133152
};
134153

135154
PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::UpdateFunctionScopeInfo(
@@ -138,14 +157,22 @@ PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::Upda
138157
si.params = original_func->params;
139158
si.buffer_map = original_func->buffer_map;
140159
Map<tir::Var, PoolInfo> ret;
141-
for (const PoolInfo& pool_info : pool_ordering_) {
160+
for (const AllocatedPoolInfo& allocated_pool_info : allocated_pool_ordering_) {
161+
PoolInfo pool_info = allocated_pool_info->pool_info;
142162
String pool_ref_name = pool_info->pool_name + "_" + std::to_string(pool_var_count_++);
143163
String var_name = pool_ref_name + "_var";
144164
DataType elem_dtype = DataType::UInt(8);
145165
Var buffer_var(var_name, PointerType(PrimType(elem_dtype), "global"));
146-
Var pool_var(var_name, DataType::Handle());
166+
Var pool_var;
167+
if (!emit_tvmscript_printable_) {
168+
pool_var = Var(var_name, PointerType(PrimType(elem_dtype), "global"));
169+
} else {
170+
pool_var = Var(var_name, DataType::Handle(8));
171+
}
147172
si.params.push_back(pool_var);
148173
si.pools_to_params.Set(pool_info, pool_var);
174+
si.allocated_pool_params.push_back(AllocatedPoolInfo(
175+
allocated_pool_info->pool_info, allocated_pool_info->allocated_size, pool_var));
149176

150177
int pool_size = all_pools_sizes_[pool_info];
151178
String buffer_var_name = pool_ref_name + "_buffer_var";
@@ -157,22 +184,40 @@ PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::Upda
157184

158185
PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams(
159186
const PrimFunc& original_primfunc) {
160-
ScopeInfo si = UpdateFunctionScopeInfo(original_primfunc);
161-
this->scope_stack.push(si);
162-
Stmt new_body = this->VisitStmt(original_primfunc->body);
163-
this->scope_stack.pop();
164-
return PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map,
165-
original_primfunc->attrs);
187+
// Only create the new function if it was not modified with pool params
188+
if (visited_primfuncs.find(original_primfunc) == visited_primfuncs.end()) {
189+
ScopeInfo si = UpdateFunctionScopeInfo(original_primfunc);
190+
this->scope_stack.push(si);
191+
Stmt new_body = this->VisitStmt(original_primfunc->body);
192+
this->scope_stack.pop();
193+
DictAttrs original_attrs = original_primfunc->attrs;
194+
// We dont need attrs of PrimFunc that might include non printable attrs such as target
195+
// for unit tests where emit_tvmscript_printable_ is to be used.
196+
if (emit_tvmscript_printable_) {
197+
original_attrs = DictAttrs();
198+
}
199+
PrimFunc ret =
200+
PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, original_attrs);
201+
if (!emit_tvmscript_printable_) {
202+
return WithAttr(ret, tvm::attr::kPoolArgs, si.allocated_pool_params);
203+
}
204+
visited_primfuncs.insert(ret);
205+
return ret;
206+
}
207+
return original_primfunc;
166208
}
167209

168-
Array<PrimExpr> PoolAllocationToOffsetConverter::AppendPoolParamsToArgs(const CallNode* op) {
210+
Array<PrimExpr> PoolAllocationToOffsetConverter::AppendPoolParamsToArgs(
211+
const Array<PrimExpr>& args) {
169212
Array<PrimExpr> new_args;
170-
for (const auto& arg : op->args) {
213+
for (const auto& arg : args) {
171214
new_args.push_back(VisitExpr(arg));
172215
}
173-
for (const auto& pools_vars : this->scope_stack.top().pools_to_params) {
216+
ScopeInfo top_scope = this->scope_stack.top();
217+
for (const auto& pools_vars : top_scope.pools_to_params) {
174218
tir::Var pool_var = pools_vars.second;
175-
new_args.push_back(pool_var);
219+
Buffer buffer_var = top_scope.buffer_map[pool_var];
220+
new_args.push_back(buffer_var->data);
176221
}
177222
return new_args;
178223
}
@@ -192,24 +237,30 @@ Array<PrimExpr> PoolAllocationToOffsetConverter::ReplaceAllocateArgsWithLetArgs(
192237
}
193238

194239
PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const CallNode* op) {
195-
if (op->op.same_as(builtin::call_extern())) {
240+
if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) {
196241
String func_name = Downcast<StringImm>(op->args[0])->value;
197-
GlobalVar gv = function_global_vars_.at(func_name);
198-
PrimFunc func = Downcast<PrimFunc>(module_->Lookup(gv));
199-
PrimFunc prim_func = CreatePrimFuncWithPoolParams(func);
200-
module_->Update(gv, prim_func);
201-
Array<PrimExpr> new_args = AppendPoolParamsToArgs(op);
202-
new_args = ReplaceAllocateArgsWithLetArgs(new_args);
203-
return Call(op->dtype, builtin::call_extern(), new_args);
204-
} else if (op->op->IsInstance<PrimFuncNode>()) {
242+
Array<PrimExpr> new_args;
243+
if (function_global_vars_.find(func_name) != function_global_vars_.end()) {
244+
GlobalVar gv = function_global_vars_.at(func_name);
245+
PrimFunc func = Downcast<PrimFunc>(module_->Lookup(gv));
246+
PrimFunc prim_func = CreatePrimFuncWithPoolParams(func);
247+
module_->Update(gv, prim_func);
248+
new_args = AppendPoolParamsToArgs(op->args);
249+
new_args = ReplaceAllocateArgsWithLetArgs(new_args);
250+
} else {
251+
new_args = ReplaceAllocateArgsWithLetArgs(op->args);
252+
}
253+
return Call(op->dtype, op->op, new_args);
254+
}
255+
if (op->op->IsInstance<PrimFuncNode>()) {
205256
PrimFunc func = Downcast<PrimFunc>(op->op);
206257
PrimFunc prim_func = CreatePrimFuncWithPoolParams(func);
207-
Array<PrimExpr> new_args = AppendPoolParamsToArgs(op);
258+
Array<PrimExpr> new_args = AppendPoolParamsToArgs(op->args);
259+
new_args = AppendPoolParamsToArgs(new_args);
208260
new_args = ReplaceAllocateArgsWithLetArgs(new_args);
209261
return Call(op->dtype, prim_func, new_args);
210-
} else {
211-
return StmtExprMutator::VisitExpr_(op);
212262
}
263+
return StmtExprMutator::VisitExpr_(op);
213264
}
214265

215266
Stmt PoolAllocationToOffsetConverter::VisitStmt_(const AllocateNode* op) {
@@ -219,12 +270,19 @@ Stmt PoolAllocationToOffsetConverter::VisitStmt_(const AllocateNode* op) {
219270
Var param = scope_info.pools_to_params[pool_allocation->pool_info];
220271
Buffer buffer_var = scope_info.buffer_map[param];
221272
ICHECK(pool_allocation->byte_offset < all_pools_sizes_[pool_allocation->pool_info]);
222-
Load load_node = Load(op->dtype, buffer_var->data, pool_allocation->byte_offset, op->condition);
223-
Var tir_var(op->buffer_var->name_hint + "_let", op->dtype);
273+
Load load_node =
274+
Load(DataType::UInt(8), buffer_var->data, pool_allocation->byte_offset, op->condition);
275+
Call address_of_load = Call(DataType::Handle(8), builtin::address_of(), {load_node});
276+
Var tir_var;
277+
if (!emit_tvmscript_printable_) {
278+
tir_var = Var(op->buffer_var->name_hint + "_let", op->buffer_var->type_annotation);
279+
} else {
280+
tir_var = Var(op->buffer_var->name_hint + "_let", DataType::Handle(8));
281+
}
224282
allocate_buf_to_let_var_.Set(op->buffer_var, tir_var);
225283
Stmt new_body = VisitStmt(op->body);
226284
allocate_buf_to_let_var_.erase(op->buffer_var);
227-
return LetStmt(tir_var, load_node, new_body);
285+
return LetStmt(tir_var, address_of_load, new_body);
228286
}
229287
return StmtExprMutator::VisitStmt_(op);
230288
}
@@ -252,17 +310,31 @@ IRModule PoolAllocationToOffsetConverter::operator()() {
252310
this->scope_stack.push(si);
253311
Stmt main_func_body = this->VisitStmt(main_func->body);
254312
this->scope_stack.pop();
255-
module_->Update(gv, PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map,
256-
main_func->attrs));
313+
// We dont need attrs of PrimFunc that might include non printable attrs such as target
314+
// for unit tests where emit_tvmscript_printable_ is to be used.
315+
if (!emit_tvmscript_printable_) {
316+
main_func =
317+
PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, main_func->attrs);
318+
main_func = WithAttr(main_func, tvm::attr::kPoolArgs, si.allocated_pool_params);
319+
} else {
320+
main_func =
321+
PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, DictAttrs());
322+
}
323+
module_->Update(gv, main_func);
324+
if (!emit_tvmscript_printable_) {
325+
return WithAttr(this->module_, tvm::attr::kPoolArgs, si.allocated_pool_params);
326+
}
257327
return this->module_;
258328
}
259329

260330
namespace transform {
261331

262332
tvm::transform::Pass ConvertPoolAllocationsToOffsets(
263-
const Map<tir::Stmt, PoolAllocation>& pool_allocations) {
333+
const Map<tir::Stmt, PoolAllocation>& pool_allocations,
334+
Bool emit_tvmscript_printable = Bool(false)) {
264335
auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) {
265-
return Downcast<IRModule>(PoolAllocationToOffsetConverter(m, pool_allocations)());
336+
return Downcast<IRModule>(PoolAllocationToOffsetConverter(
337+
m, pool_allocations, emit_tvmscript_printable->value != 0)());
266338
};
267339
return tvm::transform::CreateModulePass(pass_func, 0, "tir.usmp.ConvertPoolAllocationsToOffsets",
268340
{});

src/tir/usmp/utils.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,30 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
113113
<< ")";
114114
});
115115

116+
AllocatedPoolInfo::AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, Var pool_var) {
117+
auto allocated_poolinfo_node = make_object<AllocatedPoolInfoNode>();
118+
allocated_poolinfo_node->pool_info = pool_info;
119+
allocated_poolinfo_node->allocated_size = allocated_size;
120+
if (pool_var.defined()) {
121+
allocated_poolinfo_node->pool_var = pool_var;
122+
}
123+
data_ = std::move(allocated_poolinfo_node);
124+
}
125+
126+
TVM_REGISTER_NODE_TYPE(AllocatedPoolInfoNode);
127+
TVM_REGISTER_GLOBAL("tir.usmp.AllocatedPoolInfo")
128+
.set_body_typed([](PoolInfo pool_info, Integer allocated_size) {
129+
return AllocatedPoolInfo(pool_info, allocated_size);
130+
});
131+
132+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
133+
.set_dispatch<AllocatedPoolInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
134+
auto* node = static_cast<const AllocatedPoolInfoNode*>(ref.get());
135+
p->stream << "AllocatedPoolInfoNode(\n"
136+
<< "pool_info=" << node->pool_info << ",\n allocated_size=" << node->allocated_size
137+
<< ")";
138+
});
139+
116140
Array<BufferInfo> CreateArrayBufferInfo(const Map<BufferInfo, Stmt>& buffer_info_map) {
117141
Array<BufferInfo> ret;
118142
for (const auto& kv : buffer_info_map) {

0 commit comments

Comments
 (0)