Skip to content

Commit 10557aa

Browse files
committed
[TIR][USMP] Added buffer info extraction pass
Adding more documentation for data structures and the approach Change-Id: Ide2bfffaeff9add86853b6992017264e5d796299
1 parent 9935d7c commit 10557aa

File tree

6 files changed

+95
-20
lines changed

6 files changed

+95
-20
lines changed

include/tvm/tir/usmp/utils.h

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,17 @@ namespace tvm {
3333
namespace tir {
3434
namespace usmp {
3535

36+
/*!
37+
* \brief The string parameter to indicate read and write access to a pool
38+
* This needs to be kept in sync with PoolInfo.READ_WRITE_ACCESS in
39+
* python/tvm/tir/usmp/utils.py
40+
*/
3641
static constexpr const char* kTargetPoolReadWriteAccess = "rw";
42+
/*!
43+
* \brief The string parameter to indicate read only access to a pool
44+
* This needs to be kept in sync with PoolInfo.READ_ONLY_ACCESS in
45+
* python/tvm/tir/usmp/utils.py
46+
*/
3747
static constexpr const char* kTargetPoolReadOnlyAccess = "ro";
3848

3949
/*!
@@ -43,8 +53,8 @@ struct PoolInfoNode : public Object {
4353
/*! \brief The name of the memory pool */
4454
String pool_name;
4555
/*! \brief The expected size hint to be used by the allocator.
46-
* The size_hint is defaulted to -1 to indicate the pool is not
47-
* size restricted.
56+
* The size_hint_bytes is defaulted to kUnrestrictedPoolSizeHint
57+
* to indicate the pool is not size restricted.
4858
*/
4959
Integer size_hint_bytes;
5060
/*! \brief The accessibility from each Target*/
@@ -71,10 +81,15 @@ struct PoolInfoNode : public Object {
7181
TVM_DECLARE_FINAL_OBJECT_INFO(PoolInfoNode, Object);
7282
};
7383

84+
/*!
85+
* \brief The PoolSize is unrestricted for the memory planner
86+
*/
87+
static const int kUnrestrictedPoolSizeHint = -1;
88+
7489
class PoolInfo : public ObjectRef {
7590
public:
7691
TVM_DLL PoolInfo(String pool_name, Map<Target, String> target_access,
77-
Integer size_hint_bytes = -1);
92+
Integer size_hint_bytes = kUnrestrictedPoolSizeHint);
7893
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolInfo, ObjectRef, PoolInfoNode);
7994
};
8095

@@ -172,7 +187,12 @@ class PoolAllocation : public ObjectRef {
172187
*/
173188
Array<BufferInfo> CreateArrayBufferInfo(const Map<Stmt, BufferInfo>& buffer_info_map);
174189

175-
static constexpr const char* kPoolCandidatesIRModAttr = "candidate_memory_pools";
190+
/*!
191+
* \brief The allocate node attribute to indicate candidate memory pools.
192+
* This needs to be kept in sync with CANDIDATE_MEMORY_POOL_ATTR in
193+
* python/tvm/tir/usmp/utils.py.
194+
*/
195+
static constexpr const char* kPoolCandidatesAllocateAttr = "candidate_memory_pools";
176196

177197
} // namespace usmp
178198
} // namespace tir

python/tvm/tir/usmp/utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
from tvm.target import Target
2525
from . import _ffi_api
2626

27+
28+
# The allocate node attribute to indicate candidate memory pools.
29+
# This needs to be kept in sync with CANDIDATE_MEMORY_POOL_ATTR in
30+
# include/tvm/tir/usmp/utils.h
2731
CANDIDATE_MEMORY_POOL_ATTR = "candidate_memory_pools"
2832

2933

@@ -50,11 +54,20 @@ class PoolInfo(Object):
5054
5155
"""
5256

57+
# The string parameter to indicate read and write access to a pool
58+
# This needs to be kept in sync with kTargetPoolReadWriteAccess in
59+
# include/tvm/tir/usmp/utils.h
5360
READ_WRITE_ACCESS = "rw"
61+
# The string parameter to indicate read only access to a pool
62+
# This needs to be kept in sync with kTargetPoolReadOnlyAccess in
63+
# include/tvm/tir/usmp/utils.h
5464
READ_ONLY_ACCESS = "ro"
5565

5666
def __init__(
57-
self, pool_name: str, target_access: Dict[Target, str], size_hint_bytes: Optional[int] = -1
67+
self,
68+
pool_name: str,
69+
target_access: Dict[Target, str],
70+
size_hint_bytes: Optional[int] = None,
5871
):
5972
self.__init_handle_by_constructor__(
6073
_ffi_api.PoolInfo, # type: ignore # pylint: disable=no-member

src/tir/usmp/analysis/extract_buffer_info.cc

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,12 @@
1818
*/
1919

2020
/*!
21-
* \file tir/analysis/usmp/convert_for_loops_serial.cc
22-
* \brief Convert all for loops to serial for lesser memory consumption
21+
* \file tir/analysis/usmp/extract_buffer_info.cc
22+
*
23+
* \brief This analysis pass consumes a TIR IRModule with a main function
24+
* that defines a ordering in the calles to operators and produces BufferInfo
25+
* objects that contains information about tir.allocate nodes and liveness
26+
* conflicts between other tir.allocate nodes.
2327
*/
2428
#include <tvm/arith/analyzer.h>
2529
#include <tvm/runtime/device_api.h>
@@ -34,6 +38,12 @@ namespace tvm {
3438
namespace tir {
3539
namespace usmp {
3640

41+
/*! \brief This class takes a TIR IRModule and a main PrimFunc that contains
42+
* that defines a ordering in the calles to operators and produces BufferInfo
43+
* objects that contains information about tir.allocate nodes and liveness
44+
* conflicts between other tir.allocate nodes.
45+
*/
46+
3747
class BufferInfoExtractor : public StmtExprVisitor {
3848
public:
3949
explicit BufferInfoExtractor(const IRModule& module) : module_(module) {
@@ -63,6 +73,11 @@ class BufferInfoExtractor : public StmtExprVisitor {
6373

6474
std::unordered_set<Stmt, ObjectPtrHash, ObjectPtrEqual> currently_live_allocates;
6575
int current_stmt_idx_ = 0;
76+
// This structure is supposed to contain information
77+
// around the scope the visitor is currently in.
78+
// We only check whether the current scope belong to
79+
// a Serial ForKind. We are not planning for Parallel
80+
// ForKind just yet.
6681
struct ScopeInfo {
6782
For for_loop;
6883
};
@@ -77,39 +92,44 @@ void BufferInfoExtractor::VisitStmt(const Stmt& n) {
7792
StmtExprVisitor::VisitStmt(n);
7893
}
7994

80-
static size_t CalculateExtentsSize(const AllocateNode* op) {
95+
static Integer CalculateExtentsSize(const AllocateNode* op) {
8196
size_t element_size_bytes = op->dtype.bytes();
8297
size_t num_elements = 1;
8398
for (const auto& ext : op->extents) {
8499
if (ext->IsInstance<IntImmNode>()) {
85100
num_elements *= Downcast<IntImm>(ext)->value;
86101
} else {
87102
// We can't statically calculate workspace for dynamic shapes
88-
num_elements = 0;
103+
return Integer();
89104
}
90105
}
91-
return (num_elements * element_size_bytes);
106+
return Integer(num_elements * element_size_bytes);
92107
}
93108

94109
void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) {
95110
const auto& currect_scope_info = scope_stack_.top();
96111
const auto& type = Downcast<PointerType>(op->buffer_var->type_annotation);
97112
const auto& storage_scope = type->storage_scope;
98113

99-
// If the allocate is in a for loop,
100-
// USMP currently only looks at serial for loops.
114+
// If the allocate is in a for loop, USMP currently only looks at serial for loops.
115+
// If its not a serial for loop, then memory planner will omit them in the current memory planning
116+
// process leaving them to as tir.allocate nodes for codegen. Additionally, the USMP can only work
117+
// with buffers that have global storage_scope
101118
if ((!currect_scope_info.for_loop.defined()) ||
102119
(currect_scope_info.for_loop.defined() &&
103120
currect_scope_info.for_loop->kind == ForKind::kSerial && storage_scope == "global")) {
104-
// USMP can only work with buffers that have global storage_scope
105121
auto size_bytes = CalculateExtentsSize(op);
106122
// We only statically memory plan only allocates with known
107123
// compile time sizes.
108-
if (size_bytes) {
124+
if (size_bytes.defined()) {
109125
// By default, the core compiler is assumed to attach the a default pool to each allocate.
110-
ICHECK(op->annotations.count(kPoolCandidatesIRModAttr))
126+
ICHECK(op->annotations.count(kPoolCandidatesAllocateAttr))
111127
<< "Every statically sized allocate node needs an pool candidate attribute";
112-
auto pool_candidates = Downcast<Array<PoolInfo>>(op->annotations[kPoolCandidatesIRModAttr]);
128+
auto pool_candidates =
129+
Downcast<Array<PoolInfo>>(op->annotations[kPoolCandidatesAllocateAttr]);
130+
131+
// TODO(@manupa-arm): improve the error when the responsible component for attaching a single
132+
// pool is added
113133
ICHECK(pool_candidates.size() > 0)
114134
<< "The core compiler should at least attach a single PoolInfo. If there were no "
115135
"user-given arguments for memory pools, the default behaviour is a single size "
@@ -203,6 +223,13 @@ void BufferInfoExtractor::VisitExpr_(const CallNode* op) {
203223
Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_func) {
204224
this->VisitStmt(main_func->body);
205225

226+
// A liveness event is an event that when
227+
// traversing the tir.Stmts where tir.allocate node
228+
// begins or ceases to be Live. This particular struct
229+
// is used to solve interval overlap problem using
230+
// a sweep-line algorithm. For that, we need to record
231+
// where the liveness event occurred in a chronological
232+
// order.
206233
enum LivenessEventType { START = 0, END = 1 };
207234
struct LivenessEvent {
208235
size_t tick;
@@ -216,6 +243,8 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_
216243
}
217244
};
218245

246+
// Create a vector of liveness events
247+
// associated with each BufferNodes.
219248
std::vector<LivenessEvent> le_events;
220249
for (const auto& kv : buffer_info_map_) {
221250
if (!kv.second->IsInstance<AllocateNode>()) {
@@ -240,6 +269,9 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_
240269
le_events.push_back(le_event_end);
241270
}
242271

272+
// Sort the liveness events based on the chronological
273+
// ordering. For events that are simultaneous, START event
274+
// takes precedence.
243275
std::sort(le_events.begin(), le_events.end(),
244276
[](const LivenessEvent& lhs, const LivenessEvent& rhs) {
245277
if (lhs.tick < rhs.tick) {
@@ -249,6 +281,9 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_
249281
}
250282
return false;
251283
});
284+
285+
// Traverse the liveness events using a open set to track what
286+
// is live while updating the conflicts through out the linear traversal
252287
std::unordered_set<BufferInfo, ObjectPtrHash, ObjectPtrEqual> open_set;
253288
for (const auto& le_event : le_events) {
254289
if (le_event.le_type == START) {
@@ -258,7 +293,6 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_
258293
}
259294
open_set.insert(le_event.buffer_info);
260295
} else {
261-
ICHECK(le_event.le_type == END);
262296
open_set.erase(le_event.buffer_info);
263297
}
264298
}

src/tir/usmp/utils.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,10 @@ TVM_REGISTER_NODE_TYPE(PoolInfoNode);
7878
TVM_REGISTER_GLOBAL("tir.usmp.PoolInfo")
7979
.set_body_typed([](String pool_name, Map<Target, String> target_access,
8080
Integer size_hint_bytes) {
81-
return PoolInfo(pool_name, target_access, size_hint_bytes);
81+
if (size_hint_bytes.defined()) {
82+
return PoolInfo(pool_name, target_access, size_hint_bytes);
83+
}
84+
return PoolInfo(pool_name, target_access);
8285
});
8386

8487
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
import pytest
18+
import sys
1819

1920
import tvm
2021
from tvm import tir, script
@@ -79,6 +80,9 @@ def _assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos):
7980
return ret
8081

8182

83+
# These are test IRModules that contains varied topologies of operator graphs
84+
# that includes a main TIR function that includes call to such operators.
85+
8286
# fmt: off
8387
@tvm.script.ir_module
8488
class LinearStructure:
@@ -846,4 +850,4 @@ def test_inception_structure():
846850

847851

848852
if __name__ == "__main__":
849-
pytest.main([__file__])
853+
pytest.main([__file__] + sys.argv[1:])

tests/python/unittest/test_tir_usmp_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
import pytest
18+
import sys
1819

1920
import tvm
2021
from tvm.script import tir as T
@@ -188,4 +189,4 @@ def test_create_array_buffer_info():
188189

189190

190191
if __name__ == "__main__":
191-
pytest.main([__file__])
192+
pytest.main([__file__] + sys.argv[1:])

0 commit comments

Comments
 (0)