|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one |
| 3 | + * or more contributor license agreements. See the NOTICE file |
| 4 | + * distributed with this work for additional information |
| 5 | + * regarding copyright ownership. The ASF licenses this file |
| 6 | + * to you under the Apache License, Version 2.0 (the |
| 7 | + * "License"); you may not use this file except in compliance |
| 8 | + * with the License. You may obtain a copy of the License at |
| 9 | + * |
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | + * |
| 12 | + * Unless required by applicable law or agreed to in writing, |
| 13 | + * software distributed under the License is distributed on an |
| 14 | + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | + * KIND, either express or implied. See the License for the |
| 16 | + * specific language governing permissions and limitations |
| 17 | + * under the License. |
| 18 | + */ |
| 19 | + |
| 20 | +/*! |
| 21 | + * \file tir/analysis/usmp/algo/greedy_by_size.cc |
| 22 | + * \brief Implement greedy by size memory planning algorithm |
| 23 | + */ |
| 24 | + |
| 25 | +#include <tvm/arith/analyzer.h> |
| 26 | +#include <tvm/runtime/device_api.h> |
| 27 | +#include <tvm/tir/builtin.h> |
| 28 | +#include <tvm/tir/function.h> |
| 29 | +#include <tvm/tir/stmt_functor.h> |
| 30 | +#include <tvm/tir/usmp/utils.h> |
| 31 | + |
| 32 | +namespace tvm { |
| 33 | +namespace tir { |
| 34 | +namespace usmp { |
| 35 | +namespace algo { |
| 36 | + |
| 37 | +size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset, |
| 38 | + const int& byte_alignment) { |
| 39 | + return ((non_aligned_byte_offset + byte_alignment - 1) / byte_alignment) * byte_alignment; |
| 40 | +} |
| 41 | + |
| 42 | +bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset, |
| 43 | + const size_t& size_bytes) { |
| 44 | + if (candidate_pool->size_hint_bytes == -1) { |
| 45 | + // this means pool is not bounded |
| 46 | + return true; |
| 47 | + } |
| 48 | + auto pool_size = static_cast<size_t>(candidate_pool->size_hint_bytes->value); |
| 49 | + auto max_address = next_offset + size_bytes; |
| 50 | + if (max_address <= pool_size) { |
| 51 | + return true; |
| 52 | + } |
| 53 | + return false; |
| 54 | +} |
| 55 | + |
| 56 | +PoolInfo SelectPlacementPool( |
| 57 | + const Array<PoolInfo>& pool_candidates, |
| 58 | + const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) { |
| 59 | + for (const auto& pool_info : pool_candidates) { |
| 60 | + if (pool_offsets.count(pool_info)) { |
| 61 | + return pool_info; |
| 62 | + } |
| 63 | + } |
| 64 | + ICHECK(false) << "TVM USMP Internal Error: no candidate have been selected!"; |
| 65 | + return PoolInfo(); |
| 66 | +} |
| 67 | + |
| 68 | +Map<BufferInfo, PoolAllocation> GreedyBySize(const Array<BufferInfo>& buffer_info_arr) { |
| 69 | + std::vector<BufferInfo> buffer_info_vec; |
| 70 | + Map<BufferInfo, PoolAllocation> pool_allocations; |
| 71 | + for (const auto& buffer_info : buffer_info_arr) { |
| 72 | + buffer_info_vec.push_back(std::move(buffer_info)); |
| 73 | + } |
| 74 | + std::sort(buffer_info_vec.begin(), buffer_info_vec.end(), |
| 75 | + [](const BufferInfo& a, const BufferInfo& b) { |
| 76 | + if (a->size_bytes->value == b->size_bytes->value) { |
| 77 | + if (a->conflicts.size() == b->conflicts.size()) { |
| 78 | + auto a_name_hash = std::hash<std::string>{}(a->name_hint->data); |
| 79 | + auto b_name_hash = std::hash<std::string>{}(b->name_hint->data); |
| 80 | + return a_name_hash > b_name_hash; |
| 81 | + } else { |
| 82 | + return a->conflicts.size() > b->conflicts.size(); |
| 83 | + } |
| 84 | + } |
| 85 | + return a->size_bytes > b->size_bytes; |
| 86 | + }); |
| 87 | + |
| 88 | + for (const auto& buf_info : buffer_info_vec) { |
| 89 | + std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> pool_offset_candidates; |
| 90 | + for (const auto& pool_info : buf_info->pool_candidates) { |
| 91 | + if (IsValidPlacement(pool_info, 0, buf_info->size_bytes->value)) { |
| 92 | + pool_offset_candidates[pool_info] = 0; |
| 93 | + } |
| 94 | + } |
| 95 | + |
| 96 | + for (const auto& conflict_buf_info_obj : buf_info->conflicts) { |
| 97 | + auto conflict_buf_info = Downcast<BufferInfo>(conflict_buf_info_obj); |
| 98 | + size_t next_offset = 0; |
| 99 | + if (pool_allocations.count(conflict_buf_info)) { |
| 100 | + auto pool_allocation = pool_allocations[conflict_buf_info]; |
| 101 | + next_offset = pool_allocation->byte_offset + conflict_buf_info->size_bytes; |
| 102 | + next_offset = round_up_to_byte_alignment(next_offset, conflict_buf_info->alignment->value); |
| 103 | + if (IsValidPlacement(pool_allocation->pool_info, next_offset, |
| 104 | + buf_info->size_bytes->value)) { |
| 105 | + if (next_offset > pool_offset_candidates[pool_allocation->pool_info]) { |
| 106 | + pool_offset_candidates[pool_allocation->pool_info] = next_offset; |
| 107 | + } |
| 108 | + } else { |
| 109 | + pool_offset_candidates.erase(pool_allocation->pool_info); |
| 110 | + } |
| 111 | + } |
| 112 | + } |
| 113 | + auto selected_pool = SelectPlacementPool(buf_info->pool_candidates, pool_offset_candidates); |
| 114 | + pool_allocations.Set( |
| 115 | + buf_info, PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool]))); |
| 116 | + } |
| 117 | + return pool_allocations; |
| 118 | +} |
| 119 | + |
| 120 | +TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_size") |
| 121 | + .set_body_typed([](Array<BufferInfo> buffer_info_arr) { |
| 122 | + return GreedyBySize(buffer_info_arr); |
| 123 | + }); |
| 124 | + |
| 125 | +} // namespace algo |
| 126 | +} // namespace usmp |
| 127 | +} // namespace tir |
| 128 | +} // namespace tvm |
0 commit comments