Skip to content

Commit 262f1a8

Browse files
author
qsqqsqqsq-intellif
committed
[TIR][Schedule] Add annotate_buffer_access primitive
1 parent e3faa55 commit 262f1a8

File tree

12 files changed

+736
-6
lines changed

12 files changed

+736
-6
lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,17 @@ class ScheduleNode : public runtime::Object {
834834
*/
835835
virtual void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) = 0;
836836

837+
/*!
838+
* \brief Annotate the buffer access of a block
839+
* \param block_rv The block to be annotated
840+
* \param buffer_index The index of the buffer in block's read or write region
841+
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
842+
* \param index_map The index map that defines the new read or write region
843+
*/
844+
virtual void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index,
845+
BufferIndexType buffer_index_type,
846+
const IndexMap& index_map) = 0;
847+
837848
/******** Schedule: Misc ********/
838849
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
839850
virtual void EnterPostproc() = 0;

include/tvm/tir/stmt.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1664,6 +1664,16 @@ constexpr const char* warp_execution = "warp_execution";
16641664
/*! \brief Mark that a block is disallowed in auto inline. */
16651665
constexpr const char* meta_schedule_inline_rule = "meta_schedule.inline_rule";
16661666

1667+
/*! \brief Mark that a block has an explicitly specified read region.
1668+
* This is used to override the default read region inference in TIR.
1669+
*/
1670+
constexpr const char* explicit_read_region = "explicit_read_region";
1671+
1672+
/*! \brief Mark that a block has an explicitly specified write region.
1673+
* This is used to override the default write region inference in TIR.
1674+
*/
1675+
constexpr const char* explicit_write_region = "explicit_write_region";
1676+
16671677
/*!
16681678
* \brief Check if attr_key is a pragma key extension
16691679
* \param attr_key The attr key to be compared

python/tvm/tir/schedule/schedule.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3907,3 +3907,139 @@ def unsafe_hide_buffer_access(
39073907
buf_type,
39083908
buf_index_array,
39093909
)
3910+
3911+
@type_checked
def annotate_buffer_access(
    self, block: BlockRV, buffer_index: int, buf_type: str, gen_new_ranges: Callable
) -> None:
    """Annotate the read or write region of a block

    Parameters
    ----------
    block : BlockRV
        The block to be annotated
    buffer_index : int
        The index of the buffer in block's read or write region
    buf_type : str
        The buffer type: "read" or "write"
    gen_new_ranges : Callable
        A function that takes the block's iter_vars and returns a
        Tuple[Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], ...]
        which defines the new read or write region for the buffer.
        The returned sequence must contain exactly one entry per block iter_var.
        Each element in the tuple can be:
        - A single PrimExpr ``e``, shorthand for the single-point range (e, e + 1)
        - A tuple of two PrimExprs (begin, end) describing the half-open
          range [begin, end)

    Raises
    ------
    ValueError
        If the number of returned entries differs from the number of block
        iter_vars, if a tuple entry does not have exactly two elements, or if
        ``buf_type`` is neither "read" nor "write".
    TypeError
        If an entry is neither a PrimExpr nor a tuple/list.

    Examples
    --------
    Annotate a 2D read region for a buffer.
    Before annotate_buffer_access, in TensorIR, the IR is:

    .. code-block:: python

        @T.prim_func
        def before_annotate_buffer_access(
            A: T.Buffer((128, 128), "float32"),
            C: T.Buffer((128, 128), "float32")
        ) -> None:
            B = T.alloc_buffer((128, 128), "float32")
            for i, j in T.grid(128, 128):
                with T.block("B"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj] * 2.0
            for i, j in T.grid(128, 128):
                with T.block("C"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    C[vi, vj] = B[vi, vj] + 1.0

    Create the schedule and do annotate_buffer_access:

    .. code-block:: python

        sch = tir.Schedule(before_annotate_buffer_access)
        block = sch.get_block("B")
        sch.annotate_buffer_access(block, 0, "read",
        lambda vi, vj: ((vi - 1, vi + 1), (vj - 1, vj + 1)))
        print(sch.mod["main"].script())

    After applying annotate_buffer_access, the IR becomes:

    .. code-block:: python

        @T.prim_func
        def after_annotate_buffer_access(
            A: T.Buffer((128, 128), "float32"),
            C: T.Buffer((128, 128), "float32")
        ) -> None:
            B = T.alloc_buffer((128, 128), "float32")
            for i, j in T.grid(128, 128):
                with T.block("B"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    T.reads(A[vi - 1:vi + 1, vj - 1:vj + 1])
                    T.writes(B[vi, vj])
                    T.block_attr({"explicit_read_region": [0]})
                    B[vi, vj] = A[vi, vj] * 2.0
            for i, j in T.grid(128, 128):
                with T.block("C"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    C[vi, vj] = B[vi, vj] + 1.0

    This annotates the read region for buffer A (index 0) in block "B" to be
    [vi-1:vi+1, vj-1:vj+1] for each (vi, vj) in the block's iteration domain.
    The block attribute holds the list of buffer indices whose region was
    explicitly specified.

    Note
    ----
    This function allows manual specification of read or write regions, which
    can be useful in cases where the compiler cannot accurately infer the
    access pattern, such as complex data-dependent accesses.
    It overrides the automatically inferred region for the specified buffer.
    The function adds an annotation to the block, indicating that an explicit
    region has been provided for the buffer at the given index. This annotation
    is used in the CompactBufferAllocation pass to respect the manually specified
    region instead of relying on automatic inference.

    Caution should be exercised when using this function, as incorrect annotations
    may lead to incorrect code generation or runtime errors. It's crucial to
    ensure that the specified region covers all actual reads or writes performed
    by the block for the given buffer.

    """
    block_obj = self.get(block)
    iter_vars = [x.var for x in block_obj.iter_vars]
    new_ranges_spec = gen_new_ranges(*iter_vars)
    if len(iter_vars) != len(new_ranges_spec):
        raise ValueError(
            f"Number of iter_vars ({len(iter_vars)}) must match "
            f"number of new_ranges_spec ({len(new_ranges_spec)})"
        )

    # Flatten the spec into [begin_0, end_0, begin_1, end_1, ...]; this is the
    # layout expected for the final indices of the IndexMap passed to the FFI.
    result = []
    for rng in new_ranges_spec:
        if isinstance(rng, (tuple, list)):
            if len(rng) != 2:
                raise ValueError(
                    "Tuple must have exactly 2 elements to represent (begin, end)."
                )
            result.extend(rng)
        elif isinstance(rng, PrimExpr):
            result.extend([rng, rng + 1])  # Single point represented as (rng, rng + 1)
        else:
            raise TypeError(f"Expected PrimExpr or tuple of PrimExpr, got {type(rng)}")

    # Create index_map using IndexMap constructor
    index_map = IndexMap(
        initial_indices=iter_vars,
        final_indices=result,
        inverse_index_map=None,
    )

    # The FFI takes the buffer index type as an integer: 0 for read, 1 for write.
    if buf_type == "read":
        buffer_index_type = 0
    elif buf_type == "write":
        buffer_index_type = 1
    else:
        raise ValueError(f"Invalid buf_type: {buf_type}. Expected 'read' or 'write'.")

    _ffi_api.ScheduleAnnotateBufferAccess(
        self, block, buffer_index, buffer_index_type, index_map
    )

src/tir/schedule/concrete_schedule.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,5 +1059,15 @@ void ConcreteScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, const
10591059
this->state_->DebugVerify();
10601060
}
10611061

1062+
void ConcreteScheduleNode::AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index,
1063+
BufferIndexType buffer_index_type,
1064+
const IndexMap& index_map) {
1065+
TVM_TIR_SCHEDULE_BEGIN();
1066+
tir::AnnotateBufferAccess(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type,
1067+
index_map);
1068+
TVM_TIR_SCHEDULE_END("annotate-buffer-access", this->error_render_level_);
1069+
this->state_->DebugVerify();
1070+
}
1071+
10621072
} // namespace tir
10631073
} // namespace tvm

src/tir/schedule/concrete_schedule.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ class ConcreteScheduleNode : public ScheduleNode {
183183
void EnterPostproc() override {}
184184
void UnsafeHideBufferAccess(const BlockRV& block_rv, const String& buf_type,
185185
const Array<IntImm>& buf_index_array) override;
186+
void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index,
187+
BufferIndexType buffer_index_type, const IndexMap& index_map) override;
186188

187189
protected:
188190
/******** Utility functions ********/

src/tir/schedule/primitive.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,16 @@ TVM_DLL void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int w
718718
TVM_DLL void UnsafeHideBufferAccess(ScheduleState self, const StmtSRef& block_sref,
719719
const String& buf_type, const Array<IntImm>& buf_index_array);
720720

721+
/*!
722+
* \brief Annotate the read or write region of a specific buffer in a block
723+
* \param self The state of the schedule
724+
* \param block_sref The sref of the block to be annotated
725+
* \param buffer_index The index of the buffer in block's read or write region
726+
* \param buffer_index_type The type of the buffer index, kRead or kWrite
727+
* \param index_map The IndexMap that defines the new read or write region for the buffer
728+
*/
729+
TVM_DLL void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
730+
BufferIndexType buffer_index_type, const IndexMap& index_map);
721731
} // namespace tir
722732
} // namespace tvm
723733

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#include "../utils.h"
20+
21+
namespace tvm {
22+
namespace tir {
23+
24+
class AnnotateRegionRewriter : public StmtExprMutator {
25+
public:
26+
AnnotateRegionRewriter(Buffer buffer, int buffer_index, BufferRegion new_region,
27+
BufferIndexType buffer_index_type)
28+
: buffer_(buffer),
29+
buffer_index_(buffer_index),
30+
new_region_(new_region),
31+
buffer_index_type_(buffer_index_type) {}
32+
33+
Stmt VisitStmt_(const BlockNode* op) final {
34+
Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
35+
36+
Array<BufferRegion> regions =
37+
buffer_index_type_ == BufferIndexType::kWrite ? block->writes : block->reads;
38+
ICHECK_GE(buffer_index_, 0) << "Buffer index must be non-negative";
39+
ICHECK_LT(buffer_index_, static_cast<int>(regions.size())) << "Buffer index out of range";
40+
regions.Set(buffer_index_, new_region_);
41+
42+
ObjectPtr<BlockNode> n = CopyOnWrite(block.get());
43+
if (buffer_index_type_ == BufferIndexType::kWrite) {
44+
n->writes = std::move(regions);
45+
} else {
46+
n->reads = std::move(regions);
47+
}
48+
49+
// Annotate the block with explicit_read_region or explicit_write_region
50+
Map<String, ObjectRef> new_annotations = n->annotations;
51+
String annotation_key = buffer_index_type_ == BufferIndexType::kWrite
52+
? attr::explicit_write_region
53+
: attr::explicit_read_region;
54+
if (new_annotations.count(annotation_key)) {
55+
Array<Integer> buffer_indices = Downcast<Array<Integer>>(new_annotations[annotation_key]);
56+
bool found = false;
57+
for (const Integer& index : buffer_indices) {
58+
if (index->value == buffer_index_) {
59+
found = true;
60+
break;
61+
}
62+
}
63+
if (!found) {
64+
buffer_indices.push_back(Integer(buffer_index_));
65+
new_annotations.Set(annotation_key, buffer_indices);
66+
}
67+
} else {
68+
new_annotations.Set(annotation_key, Array<Integer>{Integer(buffer_index_)});
69+
}
70+
n->annotations = std::move(new_annotations);
71+
72+
return Block(n);
73+
}
74+
75+
private:
76+
Buffer buffer_;
77+
int buffer_index_;
78+
BufferRegion new_region_;
79+
BufferIndexType buffer_index_type_;
80+
};
81+
82+
/*!
 * \brief Replace the read or write region of one buffer of a block with the
 *  region described by an index map, and mark the block accordingly.
 *
 * The index map is evaluated on the block iter vars; its outputs are consumed
 * pairwise as (begin, end) and converted to ranges with extent `end - begin`.
 */
void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
                          BufferIndexType buffer_index_type, const IndexMap& index_map) {
  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
  Block old_block = GetRef<Block>(block);
  Buffer buffer = GetNthAccessBuffer(self, old_block, buffer_index, buffer_index_type);

  // Feed the block iter vars through the index map to obtain the flat
  // [begin_0, end_0, begin_1, end_1, ...] list.
  Array<PrimExpr> iter_exprs;
  for (const IterVar& iv : block->iter_vars) {
    iter_exprs.push_back(iv->var);
  }
  arith::Analyzer analyzer;
  Array<PrimExpr> bounds = index_map->MapIndices(iter_exprs, &analyzer);
  ICHECK_EQ(bounds.size() % 2, 0) << "The size of new_indices should be even.";

  Array<Range> region;
  for (size_t i = 0; i < bounds.size(); i += 2) {
    PrimExpr begin = bounds[i];
    PrimExpr end = bounds[i + 1];
    PrimExpr extent = analyzer.Simplify(end - begin);
    region.push_back(Range::FromMinExtent(begin, extent));
  }
  BufferRegion new_region(buffer, region);

  AnnotateRegionRewriter mutator(buffer, buffer_index, new_region, buffer_index_type);
  Stmt new_stmt = mutator(GetRef<Stmt>(block_sref->stmt));
  Block new_block = Downcast<Block>(new_stmt);

  self->Replace(block_sref, new_stmt, {{old_block, new_block}});
}
108+
109+
/*! \brief Instruction traits of the AnnotateBufferAccess schedule primitive. */
struct AnnotateBufferAccessTraits : public UnpackedInstTraits<AnnotateBufferAccessTraits> {
  static constexpr const char* kName = "AnnotateBufferAccess";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 4;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  /*! \brief Replay the instruction on a schedule. */
  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer buffer_index,
                                      Integer buffer_index_type, IndexMap index_map) {
    return sch->AnnotateBufferAccess(block, buffer_index->value,
                                     static_cast<BufferIndexType>(buffer_index_type->value),
                                     index_map);
  }

  /*!
   * \brief Render an index map as the Python `gen_new_ranges` lambda.
   *
   * The final indices are a flat [begin_0, end_0, begin_1, end_1, ...] list.
   * Every pair is printed as an explicit `(begin, end)` tuple. A bare
   * expression is never printed, because the Python API reads a bare
   * expression `e` as the range `(e, e + 1)`, which would not reproduce a
   * pair whose begin and end are the same expression.
   */
  static String IndexMap2GenNewRangesLambda(const IndexMap& index_map) {
    const Array<PrimExpr>& bounds = index_map->final_indices;
    ICHECK_EQ(bounds.size() % 2, 0)
        << "The final indices of the index map should come in (begin, end) pairs.";

    std::ostringstream oss;
    oss << "lambda ";
    for (size_t i = 0; i < index_map->initial_indices.size(); ++i) {
      if (i != 0) oss << ", ";
      oss << index_map->initial_indices[i];
    }
    oss << ": [";
    for (size_t i = 0; i < bounds.size(); i += 2) {
      if (i != 0) oss << ", ";
      oss << "(" << bounds[i] << ", " << bounds[i + 1] << ")";
    }
    oss << "]";
    return String(oss.str());
  }

  /*! \brief Render the instruction as a Python API call. */
  static String UnpackedAsPython(Array<String> outputs, String block, Integer buffer_index,
                                 Integer buffer_index_type, IndexMap index_map) {
    PythonAPICall py("annotate_buffer_access");
    py.Input("block", block);
    py.Input("buffer_index", buffer_index->value);

    // The Python API takes the buffer type as a quoted string.
    std::ostringstream os;
    os << "\"" << BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value))
       << "\"";
    py.Input("buf_type", os.str());

    py.Input("gen_new_ranges", IndexMap2GenNewRangesLambda(index_map));
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};
163+
164+
TVM_REGISTER_INST_KIND_TRAITS(AnnotateBufferAccessTraits);
165+
166+
} // namespace tir
167+
} // namespace tvm

src/tir/schedule/schedule.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,13 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc")
310310
.set_body_method<Schedule>(&ScheduleNode::EnterPostproc);
311311
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeHideBufferAccess")
312312
.set_body_method<Schedule>(&ScheduleNode::UnsafeHideBufferAccess);
313+
/******** (FFI) Annotate buffer access ********/
314+
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotateBufferAccess")
315+
.set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index,
316+
int buffer_index_type, const IndexMap& index_map) {
317+
return self->AnnotateBufferAccess(block_rv, buffer_index,
318+
static_cast<BufferIndexType>(buffer_index_type), index_map);
319+
});
313320

314321
} // namespace tir
315322
} // namespace tvm

0 commit comments

Comments
 (0)