Skip to content

Commit ec29b89

Browse files
committed
[TIR][USMP] greedy_by_size usmp algo
* Implementation of greedy by size memory planning algorithm * Added a test case of linear sequence of operators with two pools * Added a test case with residual structures Change-Id: I03b41292eab85ddb43710356c23dd123beb24462
1 parent de8e4f1 commit ec29b89

File tree

2 files changed

+498
-0
lines changed

2 files changed

+498
-0
lines changed
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tir/analysis/usmp/algo/greedy_by_size.cc
22+
* \brief Implement greedy by size memory planning algorithm
23+
*/
24+
25+
#include <tvm/arith/analyzer.h>
26+
#include <tvm/runtime/device_api.h>
27+
#include <tvm/tir/builtin.h>
28+
#include <tvm/tir/function.h>
29+
#include <tvm/tir/stmt_functor.h>
30+
#include <tvm/tir/usmp/utils.h>
31+
32+
namespace tvm {
33+
namespace tir {
34+
namespace usmp {
35+
namespace algo {
36+
37+
size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset,
38+
const int& byte_alignment) {
39+
return ((non_aligned_byte_offset + byte_alignment - 1) / byte_alignment) * byte_alignment;
40+
}
41+
42+
bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
43+
const size_t& size_bytes) {
44+
if (candidate_pool->size_hint_bytes == -1) {
45+
// this means pool is not bounded
46+
return true;
47+
}
48+
auto pool_size = static_cast<size_t>(candidate_pool->size_hint_bytes->value);
49+
auto max_address = next_offset + size_bytes;
50+
if (max_address <= pool_size) {
51+
return true;
52+
}
53+
return false;
54+
}
55+
56+
PoolInfo SelectPlacementPool(
57+
const Array<PoolInfo>& pool_candidates,
58+
const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) {
59+
for (const auto& pool_info : pool_candidates) {
60+
if (pool_offsets.count(pool_info)) {
61+
return pool_info;
62+
}
63+
}
64+
ICHECK(false) << "TVM USMP Internal Error: no candidate have been selected!";
65+
return PoolInfo();
66+
}
67+
68+
Map<BufferInfo, PoolAllocation> GreedyBySize(const Array<BufferInfo>& buffer_info_arr) {
69+
std::vector<BufferInfo> buffer_info_vec;
70+
Map<BufferInfo, PoolAllocation> pool_allocations;
71+
for (const auto& buffer_info : buffer_info_arr) {
72+
buffer_info_vec.push_back(std::move(buffer_info));
73+
}
74+
std::sort(buffer_info_vec.begin(), buffer_info_vec.end(),
75+
[](const BufferInfo& a, const BufferInfo& b) {
76+
if (a->size_bytes->value == b->size_bytes->value) {
77+
if (a->conflicts.size() == b->conflicts.size()) {
78+
auto a_name_hash = std::hash<std::string>{}(a->name_hint->data);
79+
auto b_name_hash = std::hash<std::string>{}(b->name_hint->data);
80+
return a_name_hash > b_name_hash;
81+
} else {
82+
return a->conflicts.size() > b->conflicts.size();
83+
}
84+
}
85+
return a->size_bytes > b->size_bytes;
86+
});
87+
88+
for (const auto& buf_info : buffer_info_vec) {
89+
std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> pool_offset_candidates;
90+
for (const auto& pool_info : buf_info->pool_candidates) {
91+
if (IsValidPlacement(pool_info, 0, buf_info->size_bytes->value)) {
92+
pool_offset_candidates[pool_info] = 0;
93+
}
94+
}
95+
96+
for (const auto& conflict_buf_info_obj : buf_info->conflicts) {
97+
auto conflict_buf_info = Downcast<BufferInfo>(conflict_buf_info_obj);
98+
size_t next_offset = 0;
99+
if (pool_allocations.count(conflict_buf_info)) {
100+
auto pool_allocation = pool_allocations[conflict_buf_info];
101+
next_offset = pool_allocation->byte_offset + conflict_buf_info->size_bytes;
102+
next_offset = round_up_to_byte_alignment(next_offset, conflict_buf_info->alignment->value);
103+
if (IsValidPlacement(pool_allocation->pool_info, next_offset,
104+
buf_info->size_bytes->value)) {
105+
if (next_offset > pool_offset_candidates[pool_allocation->pool_info]) {
106+
pool_offset_candidates[pool_allocation->pool_info] = next_offset;
107+
}
108+
} else {
109+
pool_offset_candidates.erase(pool_allocation->pool_info);
110+
}
111+
}
112+
}
113+
auto selected_pool = SelectPlacementPool(buf_info->pool_candidates, pool_offset_candidates);
114+
pool_allocations.Set(
115+
buf_info, PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool])));
116+
}
117+
return pool_allocations;
118+
}
119+
120+
TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_size")
121+
.set_body_typed([](Array<BufferInfo> buffer_info_arr) {
122+
return GreedyBySize(buffer_info_arr);
123+
});
124+
125+
} // namespace algo
126+
} // namespace usmp
127+
} // namespace tir
128+
} // namespace tvm

0 commit comments

Comments
 (0)