|
16 | 16 | * specific language governing permissions and limitations |
17 | 17 | * under the License. |
18 | 18 | */ |
19 | | -#include <unordered_map> |
20 | 19 |
|
21 | 20 | #include "../utils.h" |
22 | 21 | #include "multi_level_tiling.h" |
| 22 | +#include "../../tir/schedule/analysis.h" |
23 | 23 |
|
24 | 24 | namespace tvm { |
25 | 25 | namespace meta_schedule { |
26 | 26 |
|
27 | 27 | using tir::LoopRV; |
28 | 28 |
|
29 | | -/*! \brief Necessary information used for tensorization */ |
30 | | -class TensorizeInfoNode : public Object { |
31 | | - public: |
32 | | - /*! \brief Maps block loops to desc loops */ |
33 | | - Map<tir::StmtSRef, tir::For> loop_map; |
34 | | - /*! \brief Maps loops in desc to its index, outer to inner */ |
35 | | - Map<tir::For, Integer> desc_loop_indexer; |
36 | | - |
37 | | - void VisitAttrs(AttrVisitor* v) { |
38 | | - v->Visit("loop_map", &loop_map); |
39 | | - v->Visit("desc_loop_indexer", &desc_loop_indexer); |
40 | | - } |
41 | | - |
42 | | - static constexpr const char* _type_key = "tir.analysis.TensorizeInfo"; |
43 | | - TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object); |
44 | | -}; |
45 | | - |
46 | | -class TensorizeInfo : public ObjectRef { |
47 | | - public: |
48 | | - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode); |
49 | | -}; |
50 | | - |
51 | | -TVM_REGISTER_NODE_TYPE(TensorizeInfoNode); |
52 | | - |
53 | | -Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self, |
54 | | - const tir::StmtSRef& block_sref, |
55 | | - const tir::PrimFunc& desc_func) { |
56 | | - // Try to do tiling automatically if possible |
57 | | - // Now the heuristic is that if block's block var binding is constant + loop var, |
58 | | - // in other words, with tir.block(..., vi=Ci+i, vj=Cj+j, vk=Ck+k), then we split and reorder |
59 | | - // i, j, k according to the loops outside desc_block |
60 | | - // Collect the loops outside block |
61 | | - arith::Analyzer analyzer; |
62 | | - const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref); |
63 | | - // Step 1. Analyze desc_func, extract its block, loops and loop vars |
64 | | - const tir::BlockRealizeNode* desc_block = nullptr; |
65 | | - std::vector<const tir::ForNode*> desc_loops; |
66 | | - std::unordered_set<const tir::VarNode*> desc_loop_vars; |
67 | | - const auto* desc_scope_realize = desc_func->body.as<tir::BlockRealizeNode>(); |
68 | | - ICHECK(desc_scope_realize); |
69 | | - { |
70 | | - auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars, |
71 | | - &analyzer](const ObjectRef& obj) -> bool { |
72 | | - // Extract the block |
73 | | - if (const auto* block = obj.as<tir::BlockRealizeNode>()) { |
74 | | - desc_block = block; |
75 | | - return false; |
76 | | - } |
77 | | - // Extract the loops |
78 | | - if (const auto* loop = obj.as<tir::ForNode>()) { |
79 | | - desc_loops.push_back(loop); |
80 | | - desc_loop_vars.insert(loop->loop_var.get()); |
81 | | - if (!analyzer.CanProve(loop->min == 0)) { |
82 | | - return false; |
83 | | - } |
84 | | - } |
85 | | - return true; |
86 | | - }; |
87 | | - tir::PostOrderVisit(desc_scope_realize->block->body, f_visit); |
88 | | - std::reverse(desc_loops.begin(), desc_loops.end()); |
89 | | - ICHECK(desc_block); |
90 | | - } |
91 | | - // Step 2. Check if `desc_block` matches `block` |
92 | | - // Ignore the scope of buffers when comparing, since we can do cache_read/write |
93 | | - const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false); |
94 | | - const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); |
95 | | - std::vector<const tir::ForNode*> block_loops; |
96 | | - std::unordered_set<const tir::VarNode*> block_loop_vars; |
97 | | - { |
98 | | - for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) { |
99 | | - const auto* loop = loop_sref->StmtAs<tir::ForNode>(); |
100 | | - if (loop == nullptr || loop->body->IsInstance<tir::SeqStmtNode>()) { |
101 | | - break; |
102 | | - } |
103 | | - block_loops.push_back(loop); |
104 | | - block_loop_vars.insert(loop->loop_var.get()); |
105 | | - if (!analyzer.CanProve(loop->min == 0)) { |
106 | | - return NullOpt; |
107 | | - } |
108 | | - } |
109 | | - std::reverse(block_loops.begin(), block_loops.end()); |
110 | | - } |
111 | | - // Step 4. Map from block loops to desc block loops |
112 | | - ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>(); |
113 | | - int n_block_vars = block->iter_values.size(); |
114 | | - int n_desc_vars = desc_block->iter_values.size(); |
115 | | - int offset = n_block_vars - n_desc_vars; |
116 | | - if (offset < 0) { |
117 | | - return NullOpt; |
118 | | - } |
119 | | - // We align the block and desc block's bindings from the right side |
120 | | - // block (v0=..., v1=..., v2=...) |
121 | | - // ^ i_block |
122 | | - // desc_block( v1=..., v2=...) |
123 | | - // ^ i_desc |
124 | | - for (int i_desc = 0, i_block = offset; i_desc < n_desc_vars; ++i_desc, ++i_block) { |
125 | | - // For each block var binding, we find |
126 | | - const PrimExpr& block_bind = block->iter_values[i_block]; |
127 | | - const PrimExpr& desc_bind = desc_block->iter_values[i_desc]; |
128 | | - // Step 4.1. Find the corresponding loop of the i-th block var of block |
129 | | - const tir::ForNode* block_loop = nullptr; |
130 | | - for (int i = 0, n = block_loops.size(); i < n; ++i) { |
131 | | - // Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars |
132 | | - PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var); |
133 | | - if (!tir::UsesVar(r, [&block_loop_vars](const tir::VarNode* var) { |
134 | | - return block_loop_vars.count(var); |
135 | | - })) { |
136 | | - block_loop = block_loops[i]; |
137 | | - break; |
138 | | - } |
139 | | - } |
140 | | - if (block_loop == nullptr) { |
141 | | - return NullOpt; |
142 | | - } |
143 | | - // Step 4.2. Find the corresponding loop of the i-th block var of desc |
144 | | - const tir::ForNode* desc_loop = nullptr; |
145 | | - for (int i = 0, n = desc_loops.size(); i < n; ++i) { |
146 | | - // Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars |
147 | | - PrimExpr r = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var); |
148 | | - if (!tir::UsesVar(r, [&desc_loop_vars](const tir::VarNode* var) { |
149 | | - return desc_loop_vars.count(var); |
150 | | - })) { |
151 | | - desc_loop = desc_loops[i]; |
152 | | - break; |
153 | | - } |
154 | | - } |
155 | | - if (block_loop == nullptr) { |
156 | | - return NullOpt; |
157 | | - } |
158 | | - // Step 4.3. Check divisibility of loop extents |
159 | | - PrimExpr block_extent = analyzer.Simplify(block_loop->extent); |
160 | | - PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent); |
161 | | - if (const auto* int_block_extent = block_extent.as<IntImmNode>()) { |
162 | | - if (const auto* int_desc_extent = desc_extent.as<IntImmNode>()) { |
163 | | - if (int_block_extent->value % int_desc_extent->value != 0) { |
164 | | - return NullOpt; |
165 | | - } |
166 | | - } else { |
167 | | - return NullOpt; |
168 | | - } |
169 | | - } else { |
170 | | - return NullOpt; |
171 | | - } |
172 | | - // Step 4.4. Maps the result of Step 4.1 to Step 4.2 |
173 | | - const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop]; |
174 | | - auto it = ret->loop_map.find(block_loop_sref); |
175 | | - if (it == ret->loop_map.end()) { |
176 | | - ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop)); |
177 | | - } else if ((*it).second.get() != desc_loop) { |
178 | | - return NullOpt; |
179 | | - } |
180 | | - } |
181 | | - for (int i = 0, n = desc_loops.size(); i < n; ++i) { |
182 | | - ret->desc_loop_indexer.Set(GetRef<tir::For>(desc_loops[i]), Integer(i)); |
183 | | - } |
184 | | - return TensorizeInfo(ret); |
185 | | -} |
186 | | - |
187 | 29 | Optional<LoopRV> TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, |
188 | 30 | const String& intrin_name) { |
189 | | - Optional<TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping( |
| 31 | + Optional<tir::TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping( |
190 | 32 | sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc); |
191 | 33 | if (!opt_tensorize_info) return NullOpt; |
192 | | - const TensorizeInfoNode* info = opt_tensorize_info.value().get(); |
| 34 | + const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get(); |
193 | 35 | // Construct a mapping from tir loops back to LoopRVs |
194 | 36 | Map<tir::StmtSRef, LoopRV> loop2rv; |
195 | 37 | { |
|
0 commit comments