Skip to content

Commit 937a14f

Browse files
vinx13Siyuan FengspectrometerHBHjinhongyiiMasterJH5574
authored
[TIR][Analysis] Add SuggestIndexMap for layout rewriting (#10732)
This PR added an analysis function `SuggestIndexMap` to analyze buffer access pattern and suggest index map for layout transformations. Co-authored-by: Siyuan Feng <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Junru Shao <[email protected]> Co-authored-by: Xiyou Zhou <[email protected]>
1 parent 67da111 commit 937a14f

File tree

8 files changed

+411
-1
lines changed

8 files changed

+411
-1
lines changed

python/tvm/tir/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
4343
from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize
4444

45-
from .function import PrimFunc, TensorIntrin
45+
from .function import PrimFunc, TensorIntrin, IndexMap
4646

4747
from .op import call_packed, call_intrin, call_pure_extern, call_extern
4848
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace

python/tvm/tir/function.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,3 +295,18 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None):
295295

296296
final_indices = mapping_function(*args)
297297
return IndexMap(args, final_indices)
298+
299+
def map_indices(self, indices: List[PrimExpr]) -> List[PrimExpr]:
300+
"""Apply the index map to a set of indices
301+
302+
Parameters
303+
----------
304+
indices : List[PrimExpr]
305+
The indices to be mapped
306+
307+
Returns
308+
-------
309+
result : List[PrimExpr]
310+
The mapped indices
311+
"""
312+
return _ffi_api.IndexMapMapIndices(self, indices)

python/tvm/tir/schedule/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,5 @@
2222
from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError
2323
from .state import ScheduleDebugMask, ScheduleState
2424
from .trace import Trace
25+
26+
from . import analysis
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Analysis used in TensorIR scheduling"""
18+
from typing import List, Optional
19+
20+
from ..buffer import Buffer
21+
from ..stmt import For
22+
from ..expr import PrimExpr
23+
from ..function import IndexMap
24+
25+
from . import _ffi_api
26+
27+
28+
def suggest_index_map(
29+
buffer: Buffer,
30+
indices: List[PrimExpr],
31+
loops: List[For],
32+
predicate: PrimExpr,
33+
) -> Optional[IndexMap]:
34+
"""Provided the access pattern to a buffer, suggest one of the possible layout
35+
transformations to maximize the locality of the access pattern.
36+
37+
Parameters
38+
----------
39+
buffer : Buffer
40+
The buffer to be transformed.
41+
indices : List[PrimExpr]
42+
The access pattern to the buffer.
43+
loops : List[For]
44+
The loops above the buffer.
45+
predicate : PrimExpr
46+
The predicate of the access.
47+
48+
Returns
49+
-------
50+
index_map : Optional[IndexMap]
51+
The suggested index map. None if no transformation is suggested.
52+
"""
53+
return _ffi_api.SuggestIndexMap( # type: ignore # pylint: disable=no-member
54+
buffer,
55+
indices,
56+
loops,
57+
predicate,
58+
)

src/tir/ir/index_map.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,5 +201,7 @@ TVM_REGISTER_GLOBAL("tir.IndexMap")
201201
return IndexMap(initial_indices, final_indices);
202202
});
203203

204+
TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices").set_body_method<IndexMap>(&IndexMapNode::MapIndices);
205+
204206
} // namespace tir
205207
} // namespace tvm

src/tir/schedule/analysis.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include <tvm/arith/analyzer.h>
2323
#include <tvm/ir/op.h>
24+
#include <tvm/tir/index_map.h>
2425
#include <tvm/tir/schedule/state.h>
2526

2627
#include <tuple>
@@ -520,6 +521,19 @@ bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const S
520521
bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
521522
const StmtSRef& loop_sref, bool preserve_unit_loops);
522523

524+
/*!
525+
* \brief Provided the access pattern to a buffer, suggest one of the possible layout
526+
* transformations to maximize the locality of the access pattern.
527+
* \param buffer The buffer to be transformed
528+
* \param indices The access pattern to the buffer
529+
* \param loops The loops above the buffer
530+
* \param predicate The predicate of the access
531+
* \param analyzer Arithmetic analyzer
532+
*/
533+
Optional<IndexMap> SuggestIndexMap(const Buffer& buffer, const Array<PrimExpr>& indices,
534+
const Array<For>& loops, const PrimExpr& predicate,
535+
arith::Analyzer* analyzer);
536+
523537
/*!
524538
* \brief Checks if the given AST contains the specific operators
525539
* \param stmt The AST statement to be checked
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#include "../utils.h"
20+
21+
namespace tvm {
22+
namespace tir {
23+
24+
/*!
25+
* \brief Calculate the strides of the buffer
26+
* \param buffer The buffer
27+
* \return The strides
28+
*/
29+
Array<PrimExpr> GetStrides(const Buffer& buffer) {
30+
if (!buffer->strides.empty()) {
31+
ICHECK_EQ(buffer->strides.size(), buffer->shape.size());
32+
return buffer->strides;
33+
}
34+
int ndim = buffer->shape.size();
35+
if (ndim == 0) {
36+
return {};
37+
}
38+
Array<PrimExpr> strides(ndim, PrimExpr{nullptr});
39+
PrimExpr stride = make_const(buffer->DefaultIndexType(), 1);
40+
for (int i = ndim - 1; i >= 0; --i) {
41+
strides.Set(i, stride);
42+
stride = stride * buffer->shape[i];
43+
}
44+
return strides;
45+
}
46+
47+
/*!
48+
* \brief Auxiliary class that collects the IterSplitExpr in the indexing pattern
49+
* to help decision making in layout transformation
50+
*/
51+
class SplitExprCollector {
52+
public:
53+
/*!
54+
* \brief The corresponding IterSplitExpr, simplified for our case
55+
* The pattern is `source // lower_factor % extent * scale`
56+
*/
57+
struct SplitExpr {
58+
/*! \brief The source variable */
59+
Var source;
60+
/*! \brief The lower factor of the split expression */
61+
int64_t lower_factor;
62+
/*! \brief The extent of the split expression */
63+
int64_t extent;
64+
};
65+
66+
/*!
67+
* \brief Collect the split expressions in the indexing pattern
68+
* \param index The indexing pattern
69+
* \param input_iters The input iterators' domain
70+
* \param predicate The predicate of the affine map
71+
* \param require_bijective Whether the affine map is required to be bijective
72+
* \param analyzer The analyzer
73+
* \return The collected split expressions
74+
*/
75+
static std::vector<SplitExpr> Collect(const PrimExpr& index,
76+
const Map<Var, Range>& input_iters, //
77+
const PrimExpr& predicate, //
78+
bool require_bijective, //
79+
arith::Analyzer* analyzer) {
80+
DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule()));
81+
Array<arith::IterSumExpr> iter_sum_exprs = arith::DetectIterMap(
82+
{analyzer->Simplify(index)}, input_iters, predicate, require_bijective, analyzer, diag_ctx);
83+
if (iter_sum_exprs.empty()) {
84+
return {};
85+
}
86+
ICHECK_EQ(iter_sum_exprs.size(), 1);
87+
if (iter_sum_exprs[0]->args.size() == 0) {
88+
return {};
89+
}
90+
SplitExprCollector collector;
91+
collector.Visit(iter_sum_exprs[0]);
92+
if (collector.failed_) {
93+
return {};
94+
}
95+
return std::move(collector.exprs_);
96+
}
97+
98+
private:
99+
void Visit(const arith::IterSplitExpr& expr) {
100+
if (const auto* var = expr->source->source.as<tir::VarNode>()) {
101+
const int64_t* lower_factor = as_const_int(expr->lower_factor);
102+
const int64_t* extent = as_const_int(expr->extent);
103+
if (lower_factor == nullptr || extent == nullptr) {
104+
failed_ = true;
105+
return;
106+
}
107+
exprs_.push_back(SplitExpr{GetRef<Var>(var), *lower_factor, *extent});
108+
} else if (const auto* iter_sum_expr = expr->source->source.as<arith::IterSumExprNode>()) {
109+
Visit(GetRef<arith::IterSumExpr>(iter_sum_expr));
110+
} else {
111+
ICHECK(false) << "Unexpected type: " << expr->source->source->GetTypeKey();
112+
}
113+
}
114+
115+
void Visit(const arith::IterSumExpr& expr) {
116+
for (const arith::IterSplitExpr& arg : expr->args) {
117+
Visit(arg);
118+
}
119+
}
120+
121+
/*! \brief Whether the analysis failed */
122+
bool failed_ = false;
123+
/*! \brief The collected split expressions */
124+
std::vector<SplitExpr> exprs_;
125+
};
126+
127+
Optional<IndexMap> SuggestIndexMap(const Buffer& buffer, const Array<PrimExpr>& indices,
128+
const Array<For>& loops, const PrimExpr& predicate,
129+
arith::Analyzer* analyzer) {
130+
int ndim = buffer->shape.size();
131+
int n_loops = loops.size();
132+
// Step 1. Collect the domains and indices of loop variables
133+
Map<Var, Range> input_iters;
134+
std::unordered_map<const VarNode*, int> var2id;
135+
var2id.reserve(n_loops);
136+
for (int i = 0; i < n_loops; ++i) {
137+
const For& loop = loops[i];
138+
input_iters.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
139+
var2id.emplace(loop->loop_var.get(), i);
140+
}
141+
// Step 2. Calculate a functor that flattens a multi-dimensional index
142+
auto f_flatten_index = [ndim, strides = GetStrides(buffer), dtype = buffer->DefaultIndexType()](
143+
const Array<PrimExpr>& indices) -> PrimExpr {
144+
PrimExpr flatten_index = make_const(dtype, 0);
145+
for (int i = 0; i < ndim; ++i) {
146+
flatten_index = flatten_index + strides[i] * indices[i];
147+
}
148+
return flatten_index;
149+
};
150+
// Step 3. Detect the IterSplitExpr of the indexing pattern
151+
std::vector<SplitExprCollector::SplitExpr> split_exprs = SplitExprCollector::Collect(
152+
/*index=*/f_flatten_index(indices), input_iters, predicate,
153+
/*require_bijective=*/false, analyzer);
154+
if (split_exprs.empty()) {
155+
return NullOpt;
156+
}
157+
// Step 4. Sort the order of the split expressions
158+
std::vector<int> order(split_exprs.size(), 0);
159+
std::generate(order.begin(), order.end(), [n = 0]() mutable { return n++; });
160+
std::sort(order.begin(), order.end(), [&split_exprs, &var2id](int _a, int _b) -> bool {
161+
const SplitExprCollector::SplitExpr& a = split_exprs[_a];
162+
const SplitExprCollector::SplitExpr& b = split_exprs[_b];
163+
int a_var_id = var2id.at(a.source.get());
164+
int b_var_id = var2id.at(b.source.get());
165+
if (a_var_id != b_var_id) {
166+
return a_var_id < b_var_id;
167+
}
168+
return a.lower_factor > b.lower_factor;
169+
});
170+
// Step 5. Create the indexing mapping
171+
auto f_alter_layout = [f_flatten_index = std::move(f_flatten_index), //
172+
split_exprs = std::move(split_exprs), //
173+
order = std::move(order), //
174+
shape = buffer->shape, //
175+
analyzer //
176+
](Array<Var> indices) -> Array<PrimExpr> {
177+
ICHECK_EQ(indices.size(), shape.size());
178+
for (int i = 0, n = indices.size(); i < n; ++i) {
179+
analyzer->Bind(indices[i], Range::FromMinExtent(0, shape[i]));
180+
}
181+
PrimExpr index = f_flatten_index({indices.begin(), indices.end()});
182+
int ndim = split_exprs.size();
183+
// Step 5.1. Split the flattened index according to `split_exprs`
184+
std::vector<PrimExpr> split;
185+
split.reserve(ndim);
186+
for (int i = ndim - 1; i >= 0; --i) {
187+
index = analyzer->Simplify(index);
188+
int64_t extent = split_exprs[i].extent;
189+
split.push_back(analyzer->Simplify(floormod(index, extent)));
190+
index = floordiv(index, extent);
191+
}
192+
std::reverse(split.begin(), split.end());
193+
// Step 5.2. Reorder the indexing pattern according to `order`
194+
Array<PrimExpr> results;
195+
results.reserve(ndim);
196+
for (int i = 0; i < ndim; ++i) {
197+
results.push_back(split[order[i]]);
198+
}
199+
return results;
200+
};
201+
return IndexMap::FromFunc(ndim, f_alter_layout);
202+
}
203+
204+
TVM_REGISTER_GLOBAL("tir.schedule.SuggestIndexMap")
205+
.set_body_typed([](Buffer buffer, Array<PrimExpr> indices, Array<For> loops,
206+
PrimExpr predicate) {
207+
arith::Analyzer analyzer;
208+
return SuggestIndexMap(buffer, indices, loops, predicate, &analyzer);
209+
});
210+
211+
} // namespace tir
212+
} // namespace tvm

0 commit comments

Comments
 (0)