Skip to content

Commit 2fc118b

Browse files
committed
clean up headers
1 parent d8b2aa3 commit 2fc118b

File tree

5 files changed

+172
-165
lines changed

5 files changed

+172
-165
lines changed

src/meta_schedule/schedule_rule/multi_level_tiling.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
*/
1919
#include "multi_level_tiling.h"
2020

21-
#include <unordered_map>
22-
2321
#include "../utils.h"
2422

2523
namespace tvm {

src/meta_schedule/schedule_rule/multi_level_tiling.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@
1616
* specific language governing permissions and limitations
1717
* under the License.
1818
*/
19-
#include <unordered_map>
19+
#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_H_
20+
#define TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_H_
2021

21-
#include "../utils.h"
22+
#include <tvm/meta_schedule/schedule_rule.h>
23+
#include <tvm/tir/schedule/schedule.h>
24+
#include "../../support/array.h"
2225

2326
namespace tvm {
2427
namespace meta_schedule {
@@ -206,3 +209,5 @@ ScheduleRule MultiLevelTilingInitCommon(String structure, Optional<Array<String>
206209

207210
} // namespace meta_schedule
208211
} // namespace tvm
212+
213+
#endif // TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_H_

src/meta_schedule/schedule_rule/multi_level_tiling_vnni.cc

Lines changed: 3 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -16,180 +16,22 @@
1616
* specific language governing permissions and limitations
1717
* under the License.
1818
*/
19-
#include <unordered_map>
2019

2120
#include "../utils.h"
2221
#include "multi_level_tiling.h"
22+
#include "../../tir/schedule/analysis.h"
2323

2424
namespace tvm {
2525
namespace meta_schedule {
2626

2727
using tir::LoopRV;
2828

29-
/*! \brief Necessary information used for tensorization */
30-
class TensorizeInfoNode : public Object {
31-
public:
32-
/*! \brief Maps block loops to desc loops */
33-
Map<tir::StmtSRef, tir::For> loop_map;
34-
/*! \brief Maps loops in desc to its index, outer to inner */
35-
Map<tir::For, Integer> desc_loop_indexer;
36-
37-
void VisitAttrs(AttrVisitor* v) {
38-
v->Visit("loop_map", &loop_map);
39-
v->Visit("desc_loop_indexer", &desc_loop_indexer);
40-
}
41-
42-
static constexpr const char* _type_key = "tir.analysis.TensorizeInfo";
43-
TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object);
44-
};
45-
46-
class TensorizeInfo : public ObjectRef {
47-
public:
48-
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode);
49-
};
50-
51-
TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);
52-
53-
Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
54-
const tir::StmtSRef& block_sref,
55-
const tir::PrimFunc& desc_func) {
56-
// Try to do tiling automatically if possible
57-
// Now the heuristic is that if block's block var binding is constant + loop var,
58-
// in other words, with tir.block(..., vi=Ci+i, vj=Cj+j, vk=Ck+k), then we split and reorder
59-
// i, j, k according to the loops outside desc_block
60-
// Collect the loops outside block
61-
arith::Analyzer analyzer;
62-
const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
63-
// Step 1. Analyze desc_func, extract its block, loops and loop vars
64-
const tir::BlockRealizeNode* desc_block = nullptr;
65-
std::vector<const tir::ForNode*> desc_loops;
66-
std::unordered_set<const tir::VarNode*> desc_loop_vars;
67-
const auto* desc_scope_realize = desc_func->body.as<tir::BlockRealizeNode>();
68-
ICHECK(desc_scope_realize);
69-
{
70-
auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars,
71-
&analyzer](const ObjectRef& obj) -> bool {
72-
// Extract the block
73-
if (const auto* block = obj.as<tir::BlockRealizeNode>()) {
74-
desc_block = block;
75-
return false;
76-
}
77-
// Extract the loops
78-
if (const auto* loop = obj.as<tir::ForNode>()) {
79-
desc_loops.push_back(loop);
80-
desc_loop_vars.insert(loop->loop_var.get());
81-
if (!analyzer.CanProve(loop->min == 0)) {
82-
return false;
83-
}
84-
}
85-
return true;
86-
};
87-
tir::PostOrderVisit(desc_scope_realize->block->body, f_visit);
88-
std::reverse(desc_loops.begin(), desc_loops.end());
89-
ICHECK(desc_block);
90-
}
91-
// Step 2. Check if `desc_block` matches `block`
92-
// Ignore the scope of buffers when comparing, since we can do cache_read/write
93-
const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
94-
const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
95-
std::vector<const tir::ForNode*> block_loops;
96-
std::unordered_set<const tir::VarNode*> block_loop_vars;
97-
{
98-
for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) {
99-
const auto* loop = loop_sref->StmtAs<tir::ForNode>();
100-
if (loop == nullptr || loop->body->IsInstance<tir::SeqStmtNode>()) {
101-
break;
102-
}
103-
block_loops.push_back(loop);
104-
block_loop_vars.insert(loop->loop_var.get());
105-
if (!analyzer.CanProve(loop->min == 0)) {
106-
return NullOpt;
107-
}
108-
}
109-
std::reverse(block_loops.begin(), block_loops.end());
110-
}
111-
// Step 4. Map from block loops to desc block loops
112-
ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>();
113-
int n_block_vars = block->iter_values.size();
114-
int n_desc_vars = desc_block->iter_values.size();
115-
int offset = n_block_vars - n_desc_vars;
116-
if (offset < 0) {
117-
return NullOpt;
118-
}
119-
// We align the block and desc block's bindings from the right side
120-
// block (v0=..., v1=..., v2=...)
121-
// ^ i_block
122-
// desc_block( v1=..., v2=...)
123-
// ^ i_desc
124-
for (int i_desc = 0, i_block = offset; i_desc < n_desc_vars; ++i_desc, ++i_block) {
125-
// For each block var binding, we find
126-
const PrimExpr& block_bind = block->iter_values[i_block];
127-
const PrimExpr& desc_bind = desc_block->iter_values[i_desc];
128-
// Step 4.1. Find the corresponding loop of the i-th block var of block
129-
const tir::ForNode* block_loop = nullptr;
130-
for (int i = 0, n = block_loops.size(); i < n; ++i) {
131-
// Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars
132-
PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
133-
if (!tir::UsesVar(r, [&block_loop_vars](const tir::VarNode* var) {
134-
return block_loop_vars.count(var);
135-
})) {
136-
block_loop = block_loops[i];
137-
break;
138-
}
139-
}
140-
if (block_loop == nullptr) {
141-
return NullOpt;
142-
}
143-
// Step 4.2. Find the corresponding loop of the i-th block var of desc
144-
const tir::ForNode* desc_loop = nullptr;
145-
for (int i = 0, n = desc_loops.size(); i < n; ++i) {
146-
// Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars
147-
PrimExpr r = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var);
148-
if (!tir::UsesVar(r, [&desc_loop_vars](const tir::VarNode* var) {
149-
return desc_loop_vars.count(var);
150-
})) {
151-
desc_loop = desc_loops[i];
152-
break;
153-
}
154-
}
155-
if (block_loop == nullptr) {
156-
return NullOpt;
157-
}
158-
// Step 4.3. Check divisibility of loop extents
159-
PrimExpr block_extent = analyzer.Simplify(block_loop->extent);
160-
PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent);
161-
if (const auto* int_block_extent = block_extent.as<IntImmNode>()) {
162-
if (const auto* int_desc_extent = desc_extent.as<IntImmNode>()) {
163-
if (int_block_extent->value % int_desc_extent->value != 0) {
164-
return NullOpt;
165-
}
166-
} else {
167-
return NullOpt;
168-
}
169-
} else {
170-
return NullOpt;
171-
}
172-
// Step 4.4. Maps the result of Step 4.1 to Step 4.2
173-
const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop];
174-
auto it = ret->loop_map.find(block_loop_sref);
175-
if (it == ret->loop_map.end()) {
176-
ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
177-
} else if ((*it).second.get() != desc_loop) {
178-
return NullOpt;
179-
}
180-
}
181-
for (int i = 0, n = desc_loops.size(); i < n; ++i) {
182-
ret->desc_loop_indexer.Set(GetRef<tir::For>(desc_loops[i]), Integer(i));
183-
}
184-
return TensorizeInfo(ret);
185-
}
186-
18729
Optional<LoopRV> TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
18830
const String& intrin_name) {
189-
Optional<TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
31+
Optional<tir::TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
19032
sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc);
19133
if (!opt_tensorize_info) return NullOpt;
192-
const TensorizeInfoNode* info = opt_tensorize_info.value().get();
34+
const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get();
19335
// Construct a mapping from tir loops back to LoopRVs
19436
Map<tir::StmtSRef, LoopRV> loop2rv;
19537
{

src/tir/schedule/analysis.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,32 @@ Array<arith::IntSet> AnalyzeRegionLowerBound(const BufferRegion& region, const P
645645
const StmtSRef& dom_high_exclusive,
646646
arith::Analyzer* analyzer);
647647

648+
/*! \brief Necessary information used for tensorization */
649+
class TensorizeInfoNode : public Object {
650+
public:
651+
/*! \brief Maps block loops to desc loops */
652+
Map<tir::StmtSRef, tir::For> loop_map;
653+
/*! \brief Maps loops in desc to its index, outer to inner */
654+
Map<tir::For, Integer> desc_loop_indexer;
655+
656+
void VisitAttrs(AttrVisitor* v) {
657+
v->Visit("loop_map", &loop_map);
658+
v->Visit("desc_loop_indexer", &desc_loop_indexer);
659+
}
660+
661+
static constexpr const char* _type_key = "tir.analysis.TensorizeInfo";
662+
TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object);
663+
};
664+
665+
class TensorizeInfo : public ObjectRef {
666+
public:
667+
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode);
668+
};
669+
670+
Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
671+
const tir::StmtSRef& block_sref,
672+
const tir::PrimFunc& desc_func);
673+
648674
} // namespace tir
649675
} // namespace tvm
650676

src/tir/schedule/analysis/analysis.cc

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1992,5 +1992,141 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, //
19921992
}
19931993
}
19941994

1995+
TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);
1996+
1997+
Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
1998+
const tir::StmtSRef& block_sref,
1999+
const tir::PrimFunc& desc_func) {
2000+
// Try to do tiling automatically if possible
2001+
// Now the heuristic is that if block's block var binding is constant + loop var,
2002+
// in other words, with tir.block(..., vi=Ci+i, vj=Cj+j, vk=Ck+k), then we split and reorder
2003+
// i, j, k according to the loops outside desc_block
2004+
// Collect the loops outside block
2005+
arith::Analyzer analyzer;
2006+
const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
2007+
// Step 1. Analyze desc_func, extract its block, loops and loop vars
2008+
const tir::BlockRealizeNode* desc_block = nullptr;
2009+
std::vector<const tir::ForNode*> desc_loops;
2010+
std::unordered_set<const tir::VarNode*> desc_loop_vars;
2011+
const auto* desc_scope_realize = desc_func->body.as<tir::BlockRealizeNode>();
2012+
ICHECK(desc_scope_realize);
2013+
{
2014+
auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars,
2015+
&analyzer](const ObjectRef& obj) -> bool {
2016+
// Extract the block
2017+
if (const auto* block = obj.as<tir::BlockRealizeNode>()) {
2018+
desc_block = block;
2019+
return false;
2020+
}
2021+
// Extract the loops
2022+
if (const auto* loop = obj.as<tir::ForNode>()) {
2023+
desc_loops.push_back(loop);
2024+
desc_loop_vars.insert(loop->loop_var.get());
2025+
if (!analyzer.CanProve(loop->min == 0)) {
2026+
return false;
2027+
}
2028+
}
2029+
return true;
2030+
};
2031+
tir::PostOrderVisit(desc_scope_realize->block->body, f_visit);
2032+
std::reverse(desc_loops.begin(), desc_loops.end());
2033+
ICHECK(desc_block);
2034+
}
2035+
// Step 2. Check if `desc_block` matches `block`
2036+
// Ignore the scope of buffers when comparing, since we can do cache_read/write
2037+
const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
2038+
const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
2039+
std::vector<const tir::ForNode*> block_loops;
2040+
std::unordered_set<const tir::VarNode*> block_loop_vars;
2041+
{
2042+
for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) {
2043+
const auto* loop = loop_sref->StmtAs<tir::ForNode>();
2044+
if (loop == nullptr || loop->body->IsInstance<tir::SeqStmtNode>()) {
2045+
break;
2046+
}
2047+
block_loops.push_back(loop);
2048+
block_loop_vars.insert(loop->loop_var.get());
2049+
if (!analyzer.CanProve(loop->min == 0)) {
2050+
return NullOpt;
2051+
}
2052+
}
2053+
std::reverse(block_loops.begin(), block_loops.end());
2054+
}
2055+
// Step 4. Map from block loops to desc block loops
2056+
ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>();
2057+
int n_block_vars = block->iter_values.size();
2058+
int n_desc_vars = desc_block->iter_values.size();
2059+
int offset = n_block_vars - n_desc_vars;
2060+
if (offset < 0) {
2061+
return NullOpt;
2062+
}
2063+
// We align the block and desc block's bindings from the right side
2064+
// block (v0=..., v1=..., v2=...)
2065+
// ^ i_block
2066+
// desc_block( v1=..., v2=...)
2067+
// ^ i_desc
2068+
for (int i_desc = 0, i_block = offset; i_desc < n_desc_vars; ++i_desc, ++i_block) {
2069+
// For each block var binding, we find
2070+
const PrimExpr& block_bind = block->iter_values[i_block];
2071+
const PrimExpr& desc_bind = desc_block->iter_values[i_desc];
2072+
// Step 4.1. Find the corresponding loop of the i-th block var of block
2073+
const tir::ForNode* block_loop = nullptr;
2074+
for (int i = 0, n = block_loops.size(); i < n; ++i) {
2075+
// Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars
2076+
PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
2077+
if (!tir::UsesVar(r, [&block_loop_vars](const tir::VarNode* var) {
2078+
return block_loop_vars.count(var);
2079+
})) {
2080+
block_loop = block_loops[i];
2081+
break;
2082+
}
2083+
}
2084+
if (block_loop == nullptr) {
2085+
return NullOpt;
2086+
}
2087+
// Step 4.2. Find the corresponding loop of the i-th block var of desc
2088+
const tir::ForNode* desc_loop = nullptr;
2089+
for (int i = 0, n = desc_loops.size(); i < n; ++i) {
2090+
// Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars
2091+
PrimExpr r = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var);
2092+
if (!tir::UsesVar(r, [&desc_loop_vars](const tir::VarNode* var) {
2093+
return desc_loop_vars.count(var);
2094+
})) {
2095+
desc_loop = desc_loops[i];
2096+
break;
2097+
}
2098+
}
2099+
if (block_loop == nullptr) {
2100+
return NullOpt;
2101+
}
2102+
// Step 4.3. Check divisibility of loop extents
2103+
PrimExpr block_extent = analyzer.Simplify(block_loop->extent);
2104+
PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent);
2105+
if (const auto* int_block_extent = block_extent.as<IntImmNode>()) {
2106+
if (const auto* int_desc_extent = desc_extent.as<IntImmNode>()) {
2107+
if (int_block_extent->value % int_desc_extent->value != 0) {
2108+
return NullOpt;
2109+
}
2110+
} else {
2111+
return NullOpt;
2112+
}
2113+
} else {
2114+
return NullOpt;
2115+
}
2116+
// Step 4.4. Maps the result of Step 4.1 to Step 4.2
2117+
const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop];
2118+
auto it = ret->loop_map.find(block_loop_sref);
2119+
if (it == ret->loop_map.end()) {
2120+
ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
2121+
} else if ((*it).second.get() != desc_loop) {
2122+
return NullOpt;
2123+
}
2124+
}
2125+
for (int i = 0, n = desc_loops.size(); i < n; ++i) {
2126+
ret->desc_loop_indexer.Set(GetRef<tir::For>(desc_loops[i]), Integer(i));
2127+
}
2128+
return TensorizeInfo(ret);
2129+
}
2130+
19952131
} // namespace tir
19962132
} // namespace tvm

0 commit comments

Comments
 (0)