1818 */
1919
2020/* !
21- * \file tir/analysis/usmp/convert_for_loops_serial.cc
22- * \brief Convert all for loops to serial for lesser memory consumption
21+ * \file tir/analysis/usmp/extract_buffer_info.cc
22+ *
23+ * \brief This analysis pass consumes a TIR IRModule with a main function
24+ * that defines a ordering in the calles to operators and produces BufferInfo
25+ * objects that contains information about tir.allocate nodes and liveness
26+ * conflicts between other tir.allocate nodes.
2327 */
2428#include < tvm/arith/analyzer.h>
2529#include < tvm/runtime/device_api.h>
@@ -34,6 +38,12 @@ namespace tvm {
3438namespace tir {
3539namespace usmp {
3640
41+ /* ! \brief This class takes a TIR IRModule and a main PrimFunc that contains
42+ * that defines a ordering in the calles to operators and produces BufferInfo
43+ * objects that contains information about tir.allocate nodes and liveness
44+ * conflicts between other tir.allocate nodes.
45+ */
46+
3747class BufferInfoExtractor : public StmtExprVisitor {
3848 public:
3949 explicit BufferInfoExtractor (const IRModule& module ) : module_(module ) {
@@ -63,6 +73,11 @@ class BufferInfoExtractor : public StmtExprVisitor {
6373
6474 std::unordered_set<Stmt, ObjectPtrHash, ObjectPtrEqual> currently_live_allocates;
6575 int current_stmt_idx_ = 0 ;
76+ // This structure is supposed to contain information
77+ // around the scope the visitor is currently in.
78+ // We only check whether the current scope belong to
79+ // a Serial ForKind. We are not planning for Parallel
80+ // ForKind just yet.
6681 struct ScopeInfo {
6782 For for_loop;
6883 };
@@ -77,39 +92,44 @@ void BufferInfoExtractor::VisitStmt(const Stmt& n) {
7792 StmtExprVisitor::VisitStmt (n);
7893}
7994
80- static size_t CalculateExtentsSize (const AllocateNode* op) {
95+ static Integer CalculateExtentsSize (const AllocateNode* op) {
8196 size_t element_size_bytes = op->dtype .bytes ();
8297 size_t num_elements = 1 ;
8398 for (const auto & ext : op->extents ) {
8499 if (ext->IsInstance <IntImmNode>()) {
85100 num_elements *= Downcast<IntImm>(ext)->value ;
86101 } else {
87102 // We can't statically calculate workspace for dynamic shapes
88- num_elements = 0 ;
103+ return Integer () ;
89104 }
90105 }
91- return (num_elements * element_size_bytes);
106+ return Integer (num_elements * element_size_bytes);
92107}
93108
94109void BufferInfoExtractor::VisitStmt_ (const AllocateNode* op) {
95110 const auto & currect_scope_info = scope_stack_.top ();
96111 const auto & type = Downcast<PointerType>(op->buffer_var ->type_annotation );
97112 const auto & storage_scope = type->storage_scope ;
98113
99- // If the allocate is in a for loop,
100- // USMP currently only looks at serial for loops.
114+ // If the allocate is in a for loop, USMP currently only looks at serial for loops.
115+ // If its not a serial for loop, then memory planner will omit them in the current memory planning
116+ // process leaving them to as tir.allocate nodes for codegen. Additionally, the USMP can only work
117+ // with buffers that have global storage_scope
101118 if ((!currect_scope_info.for_loop .defined ()) ||
102119 (currect_scope_info.for_loop .defined () &&
103120 currect_scope_info.for_loop ->kind == ForKind::kSerial && storage_scope == " global" )) {
104- // USMP can only work with buffers that have global storage_scope
105121 auto size_bytes = CalculateExtentsSize (op);
106122 // We only statically memory plan only allocates with known
107123 // compile time sizes.
108- if (size_bytes) {
124+ if (size_bytes. defined () ) {
109125 // By default, the core compiler is assumed to attach the a default pool to each allocate.
110- ICHECK (op->annotations .count (kPoolCandidatesIRModAttr ))
126+ ICHECK (op->annotations .count (kPoolCandidatesAllocateAttr ))
111127 << " Every statically sized allocate node needs an pool candidate attribute" ;
112- auto pool_candidates = Downcast<Array<PoolInfo>>(op->annotations [kPoolCandidatesIRModAttr ]);
128+ auto pool_candidates =
129+ Downcast<Array<PoolInfo>>(op->annotations [kPoolCandidatesAllocateAttr ]);
130+
131+ // TODO(@manupa-arm): improve the error when the responsible component for attaching a single
132+ // pool is added
113133 ICHECK (pool_candidates.size () > 0 )
114134 << " The core compiler should at least attach a single PoolInfo. If there were no "
115135 " user-given arguments for memory pools, the default behaviour is a single size "
@@ -203,6 +223,13 @@ void BufferInfoExtractor::VisitExpr_(const CallNode* op) {
203223Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator ()(const PrimFunc& main_func) {
204224 this ->VisitStmt (main_func->body );
205225
226+ // A liveness event is an event that when
227+ // traversing the tir.Stmts where tir.allocate node
228+ // begins or ceases to be Live. This particular struct
229+ // is used to solve interval overlap problem using
230+ // a sweep-line algorithm. For that, we need to record
231+ // where the liveness event occurred in a chronological
232+ // order.
206233 enum LivenessEventType { START = 0 , END = 1 };
207234 struct LivenessEvent {
208235 size_t tick;
@@ -216,6 +243,8 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_
216243 }
217244 };
218245
246+ // Create a vector of liveness events
247+ // associated with each BufferNodes.
219248 std::vector<LivenessEvent> le_events;
220249 for (const auto & kv : buffer_info_map_) {
221250 if (!kv.second ->IsInstance <AllocateNode>()) {
@@ -240,6 +269,9 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_
240269 le_events.push_back (le_event_end);
241270 }
242271
272+ // Sort the liveness events based on the chronological
273+ // ordering. For events that are simultaneous, START event
274+ // takes precedence.
243275 std::sort (le_events.begin (), le_events.end (),
244276 [](const LivenessEvent& lhs, const LivenessEvent& rhs) {
245277 if (lhs.tick < rhs.tick ) {
@@ -249,6 +281,9 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_
249281 }
250282 return false ;
251283 });
284+
285+ // Traverse the liveness events using a open set to track what
286+ // is live while updating the conflicts through out the linear traversal
252287 std::unordered_set<BufferInfo, ObjectPtrHash, ObjectPtrEqual> open_set;
253288 for (const auto & le_event : le_events) {
254289 if (le_event.le_type == START) {
@@ -258,7 +293,6 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_
258293 }
259294 open_set.insert (le_event.buffer_info );
260295 } else {
261- ICHECK (le_event.le_type == END);
262296 open_set.erase (le_event.buffer_info );
263297 }
264298 }
0 commit comments