Skip to content

Commit 9c81e7d

Browse files
author
qsqqsqqsq-intellif
committed
[TIR][Schedule] Add annotate_buffer_access primitive
1 parent d9ee637 commit 9c81e7d

File tree

12 files changed

+620
-6
lines changed

12 files changed

+620
-6
lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,17 @@ class ScheduleNode : public runtime::Object {
834834
*/
835835
virtual void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) = 0;
836836

837+
/*!
838+
* \brief Annotate the buffer access of a block
839+
* \param block_rv The block to be annotated
840+
* \param buffer_index The index of the buffer in block's read or write region
841+
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
842+
* \param index_map The index map that defines the new read or write region
843+
*/
844+
virtual void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index,
845+
BufferIndexType buffer_index_type,
846+
const IndexMap& index_map) = 0;
847+
837848
/******** Schedule: Misc ********/
838849
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
839850
virtual void EnterPostproc() = 0;

include/tvm/tir/stmt.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1664,6 +1664,16 @@ constexpr const char* warp_execution = "warp_execution";
16641664
/*! \brief Mark that a block is disallowed in auto inline. */
16651665
constexpr const char* meta_schedule_inline_rule = "meta_schedule.inline_rule";
16661666

1667+
/*! \brief Mark that a block has an explicitly specified read region.
1668+
* This is used to override the default read region inference in TIR.
1669+
*/
1670+
constexpr const char* explicit_read_region = "explicit_read_region";
1671+
1672+
/*! \brief Mark that a block has an explicitly specified write region.
1673+
* This is used to override the default write region inference in TIR.
1674+
*/
1675+
constexpr const char* explicit_write_region = "explicit_write_region";
1676+
16671677
/*!
16681678
* \brief Check if attr_key is a pragma key extension
16691679
* \param attr_key The attr key to be compared

python/tvm/tir/schedule/schedule.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3904,3 +3904,134 @@ def unsafe_hide_buffer_access(
39043904
buf_type,
39053905
buf_index_array,
39063906
)
3907+
3908+
@type_checked
3909+
def annotate_buffer_access(
3910+
self, block: BlockRV, buffer_index: int, buf_type: str, gen_new_ranges: Callable
3911+
) -> None:
3912+
"""Annotate the read or write region of a block
3913+
3914+
Parameters
3915+
----------
3916+
block : BlockRV
3917+
The block to be annotated
3918+
buffer_index : int
3919+
The index of the buffer in block's read or write region
3920+
buf_type : str
3921+
The buffer type: "read" or "write"
3922+
gen_new_ranges : Callable
3923+
A function that takes the block's iter_vars and returns a Tuple[Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], ...]
3924+
which defines the new read or write region for the buffer.
3925+
Each element in the tuple can be:
3926+
- A single PrimExpr representing the iter_var itself
3927+
- A tuple of two PrimExprs representing the range (begin, end)
3928+
3929+
Examples
3930+
--------
3931+
Annotate a 2D read region for a buffer.
3932+
Before annotate_buffer_access, in TensorIR, the IR is:
3933+
3934+
.. code-block:: python
3935+
3936+
@T.prim_func
3937+
def before_annotate_buffer_access(
3938+
A: T.Buffer((128, 128), "float32"),
3939+
C: T.Buffer((128, 128), "float32")
3940+
) -> None:
3941+
B = T.alloc_buffer((128, 128), "float32")
3942+
for i, j in T.grid(128, 128):
3943+
with T.block("B"):
3944+
vi, vj = T.axis.remap("SS", [i, j])
3945+
B[vi, vj] = A[vi, vj] * 2.0
3946+
for i, j in T.grid(128, 128):
3947+
with T.block("C"):
3948+
vi, vj = T.axis.remap("SS", [i, j])
3949+
C[vi, vj] = B[vi, vj] + 1.0
3950+
3951+
Create the schedule and do annotate_buffer_access:
3952+
3953+
.. code-block:: python
3954+
3955+
sch = tir.Schedule(before_annotate_buffer_access)
3956+
block = sch.get_block("B")
3957+
sch.annotate_buffer_access(block, 0, "read", lambda vi, vj: ((vi - 1, vi + 1), (vj - 1, vj + 1)))
3958+
print(sch.mod["main"].script())
3959+
3960+
After applying annotate_buffer_access, the IR becomes:
3961+
3962+
.. code-block:: python
3963+
3964+
@T.prim_func
3965+
def after_annotate_buffer_access(
3966+
A: T.Buffer((128, 128), "float32"),
3967+
C: T.Buffer((128, 128), "float32")
3968+
) -> None:
3969+
B = T.alloc_buffer((128, 128), "float32")
3970+
for i, j in T.grid(128, 128):
3971+
with T.block("B"):
3972+
vi, vj = T.axis.remap("SS", [i, j])
3973+
T.reads(A[vi - 1:vi + 1, vj - 1:vj + 1])
3974+
T.writes(B[vi, vj])
3975+
T.block_attr({"explicit_read_region": 0})
3976+
B[vi, vj] = A[vi, vj] * 2.0
3977+
for i, j in T.grid(128, 128):
3978+
with T.block("C"):
3979+
vi, vj = T.axis.remap("SS", [i, j])
3980+
C[vi, vj] = B[vi, vj] + 1.0
3981+
3982+
This annotates the read region for buffer A (index 0) in block "B" to be
3983+
[vi-1:vi+1, vj-1:vj+1] for each (vi, vj) in the block's iteration domain.
3984+
3985+
Note
3986+
----
3987+
This function allows manual specification of read or write regions, which can be useful in cases where
3988+
the compiler cannot accurately infer the access pattern, such as complex data-dependent accesses.
3989+
It overrides the automatically inferred region for the specified buffer.
3990+
3991+
The function adds an annotation to the block, indicating that an explicit region has been
3992+
provided for the buffer at the given index. This annotation is used in the CompactBufferAllocation pass
3993+
to respect the manually specified region instead of relying on automatic inference.
3994+
3995+
Caution should be exercised when using this function, as incorrect annotations may lead to
3996+
incorrect code generation or runtime errors. It's crucial to ensure that the specified
3997+
region covers all actual reads or writes performed by the block for the given buffer.
3998+
3999+
"""
4000+
block_obj = self.get(block)
4001+
iter_vars = [x.var for x in block_obj.iter_vars]
4002+
new_ranges_spec = gen_new_ranges(*iter_vars)
4003+
if len(iter_vars) != len(new_ranges_spec):
4004+
raise ValueError(
4005+
f"Number of iter_vars ({len(iter_vars)}) must match number of new_ranges_spec ({len(new_ranges_spec)})"
4006+
)
4007+
4008+
result = []
4009+
for rng in new_ranges_spec:
4010+
if isinstance(rng, (tuple, list)):
4011+
if len(rng) != 2:
4012+
raise ValueError(
4013+
f"Tuple must have exactly 2 elements to represent (begin, end)."
4014+
)
4015+
result.extend(rng)
4016+
elif isinstance(rng, PrimExpr):
4017+
result.extend([rng, rng]) # Single point represented as (rng, rng)
4018+
else:
4019+
raise TypeError(f"Expected PrimExpr or tuple of PrimExpr, got {type(rng)}")
4020+
4021+
# Create index_map using IndexMap constructor
4022+
index_map = IndexMap(
4023+
initial_indices=iter_vars,
4024+
final_indices=result,
4025+
inverse_index_map=None,
4026+
)
4027+
4028+
if buf_type == "read":
4029+
buffer_index_type = 0
4030+
elif buf_type == "write":
4031+
buffer_index_type = 1
4032+
else:
4033+
raise ValueError(f"Invalid buf_type: {buf_type}. Expected 'read' or 'write'.")
4034+
4035+
return _ffi_api.ScheduleAnnotateBufferAccess(
4036+
self, block, buffer_index, buffer_index_type, index_map
4037+
)

src/tir/schedule/concrete_schedule.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,5 +1059,15 @@ void ConcreteScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, const
10591059
this->state_->DebugVerify();
10601060
}
10611061

1062+
void ConcreteScheduleNode::AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index,
1063+
BufferIndexType buffer_index_type,
1064+
const IndexMap& index_map) {
1065+
TVM_TIR_SCHEDULE_BEGIN();
1066+
tir::AnnotateBufferAccess(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type,
1067+
index_map);
1068+
TVM_TIR_SCHEDULE_END("annotate-buffer-access", this->error_render_level_);
1069+
this->state_->DebugVerify();
1070+
}
1071+
10621072
} // namespace tir
10631073
} // namespace tvm

src/tir/schedule/concrete_schedule.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ class ConcreteScheduleNode : public ScheduleNode {
183183
void EnterPostproc() override {}
184184
void UnsafeHideBufferAccess(const BlockRV& block_rv, const String& buf_type,
185185
const Array<IntImm>& buf_index_array) override;
186+
void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index,
187+
BufferIndexType buffer_index_type, const IndexMap& index_map) override;
186188

187189
protected:
188190
/******** Utility functions ********/

src/tir/schedule/primitive.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,16 @@ TVM_DLL void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int w
718718
TVM_DLL void UnsafeHideBufferAccess(ScheduleState self, const StmtSRef& block_sref,
719719
const String& buf_type, const Array<IntImm>& buf_index_array);
720720

721+
/*!
722+
* \brief Annotate the read or write region of a specific buffer in a block
723+
* \param self The state of the schedule
724+
* \param block_sref The sref of the block to be annotated
725+
* \param buffer_index The index of the buffer in block's read or write region
726+
* \param buffer_index_type The type of the buffer index, kRead or kWrite
727+
* \param index_map The IndexMap that defines the new read or write region for the buffer
728+
*/
729+
TVM_DLL void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
730+
BufferIndexType buffer_index_type, const IndexMap& index_map);
721731
} // namespace tir
722732
} // namespace tvm
723733

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
#include "../utils.h"
2+
3+
namespace tvm {
4+
namespace tir {
5+
6+
class AnnotateRegionRewriter : public StmtExprMutator {
7+
public:
8+
AnnotateRegionRewriter(Buffer buffer, int buffer_index, BufferRegion new_region,
9+
BufferIndexType buffer_index_type)
10+
: buffer_(buffer),
11+
buffer_index_(buffer_index),
12+
new_region_(new_region),
13+
buffer_index_type_(buffer_index_type) {}
14+
15+
Stmt VisitStmt_(const BlockNode* op) final {
16+
Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
17+
18+
Array<BufferRegion> regions =
19+
buffer_index_type_ == BufferIndexType::kWrite ? block->writes : block->reads;
20+
ICHECK_GE(buffer_index_, 0) << "Buffer index must be non-negative";
21+
ICHECK_LT(buffer_index_, static_cast<int>(regions.size())) << "Buffer index out of range";
22+
regions.Set(buffer_index_, new_region_);
23+
24+
ObjectPtr<BlockNode> n = CopyOnWrite(block.get());
25+
if (buffer_index_type_ == BufferIndexType::kWrite) {
26+
n->writes = std::move(regions);
27+
} else {
28+
n->reads = std::move(regions);
29+
}
30+
31+
// Annotate the block with explicit_read_region or explicit_write_region
32+
Map<String, ObjectRef> new_annotations = n->annotations;
33+
String annotation_key = buffer_index_type_ == BufferIndexType::kWrite
34+
? attr::explicit_write_region
35+
: attr::explicit_read_region;
36+
if (new_annotations.count(annotation_key)) {
37+
Array<Integer> buffer_indices = Downcast<Array<Integer>>(new_annotations[annotation_key]);
38+
bool found = false;
39+
for (const Integer& index : buffer_indices) {
40+
if (index->value == buffer_index_) {
41+
found = true;
42+
break;
43+
}
44+
}
45+
if (!found) {
46+
buffer_indices.push_back(Integer(buffer_index_));
47+
new_annotations.Set(annotation_key, buffer_indices);
48+
}
49+
} else {
50+
new_annotations.Set(annotation_key, Array<Integer>{Integer(buffer_index_)});
51+
}
52+
n->annotations = std::move(new_annotations);
53+
54+
return Block(n);
55+
}
56+
57+
private:
58+
Buffer buffer_;
59+
int buffer_index_;
60+
BufferRegion new_region_;
61+
BufferIndexType buffer_index_type_;
62+
};
63+
64+
void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
65+
BufferIndexType buffer_index_type, const IndexMap& index_map) {
66+
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
67+
Buffer buffer = GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index, buffer_index_type);
68+
69+
arith::Analyzer analyzer;
70+
Array<PrimExpr> block_iter_vars;
71+
for (const IterVar& iter_var : block->iter_vars) {
72+
block_iter_vars.push_back(iter_var->var);
73+
}
74+
Array<PrimExpr> new_indices = index_map->MapIndices(block_iter_vars, &analyzer);
75+
ICHECK_EQ(new_indices.size() % 2, 0) << "The size of new_indices should be even.";
76+
Array<Range> new_ranges;
77+
for (size_t i = 0; i < new_indices.size(); i += 2) {
78+
// (iter_var, iter_var) represents a single point
79+
if (analyzer.CanProveEqual(new_indices[i], new_indices[i + 1])) {
80+
new_ranges.push_back(Range::FromMinExtent(new_indices[i], 1));
81+
}
82+
// (begin, end) represents a region
83+
else {
84+
new_ranges.push_back(Range::FromMinExtent(
85+
new_indices[i], analyzer.Simplify(new_indices[i + 1] - new_indices[i])));
86+
}
87+
}
88+
89+
BufferRegion new_region(buffer, new_ranges);
90+
91+
AnnotateRegionRewriter mutator(buffer, buffer_index, new_region, buffer_index_type);
92+
Stmt new_stmt = mutator(GetRef<Stmt>(block_sref->stmt));
93+
94+
self->Replace(block_sref, new_stmt, {{GetRef<Block>(block), Downcast<Block>(new_stmt)}});
95+
}
96+
97+
struct AnnotateBufferAccessTraits : public UnpackedInstTraits<AnnotateBufferAccessTraits> {
98+
static constexpr const char* kName = "AnnotateBufferAccess";
99+
static constexpr bool kIsPure = false;
100+
101+
private:
102+
static constexpr size_t kNumInputs = 4;
103+
static constexpr size_t kNumAttrs = 0;
104+
static constexpr size_t kNumDecisions = 0;
105+
106+
static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer buffer_index,
107+
Integer buffer_index_type, IndexMap index_map) {
108+
return sch->AnnotateBufferAccess(block, buffer_index->value,
109+
static_cast<BufferIndexType>(buffer_index_type->value),
110+
index_map);
111+
}
112+
113+
static String IndexMap2GenNewRangesLambda(const IndexMap& index_map) {
114+
std::ostringstream oss;
115+
oss << "lambda ";
116+
for (size_t i = 0; i < index_map->initial_indices.size(); ++i) {
117+
if (i != 0) oss << ", ";
118+
oss << index_map->initial_indices[i];
119+
}
120+
oss << ": [";
121+
for (size_t i = 0; i < index_map->final_indices.size(); i += 2) {
122+
if (i != 0) oss << ", ";
123+
if (index_map->final_indices[i].same_as(index_map->final_indices[i + 1])) {
124+
oss << index_map->final_indices[i];
125+
} else {
126+
oss << "(" << index_map->final_indices[i] << ", " << index_map->final_indices[i + 1] << ")";
127+
}
128+
}
129+
oss << "]";
130+
return String(oss.str());
131+
}
132+
133+
static String UnpackedAsPython(Array<String> outputs, String block, Integer buffer_index,
134+
Integer buffer_index_type, IndexMap index_map) {
135+
PythonAPICall py("annotate_buffer_access");
136+
py.Input("block", block);
137+
py.Input("buffer_index", buffer_index->value);
138+
139+
std::ostringstream os;
140+
os << "\"" << BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value))
141+
<< "\"";
142+
py.Input("buf_type", os.str());
143+
144+
py.Input("gen_new_ranges", IndexMap2GenNewRangesLambda(index_map));
145+
return py.Str();
146+
}
147+
148+
template <typename>
149+
friend struct ::tvm::tir::UnpackedInstTraits;
150+
};
151+
152+
TVM_REGISTER_INST_KIND_TRAITS(AnnotateBufferAccessTraits);
153+
154+
} // namespace tir
155+
} // namespace tvm

src/tir/schedule/schedule.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,13 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc")
310310
.set_body_method<Schedule>(&ScheduleNode::EnterPostproc);
311311
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeHideBufferAccess")
312312
.set_body_method<Schedule>(&ScheduleNode::UnsafeHideBufferAccess);
313+
/******** (FFI) Annotate buffer access ********/
314+
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotateBufferAccess")
315+
.set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index,
316+
int buffer_index_type, const IndexMap& index_map) {
317+
return self->AnnotateBufferAccess(block_rv, buffer_index,
318+
static_cast<BufferIndexType>(buffer_index_type), index_map);
319+
});
313320

314321
} // namespace tir
315322
} // namespace tvm

src/tir/schedule/traced_schedule.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,5 +769,17 @@ void TracedScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, const S
769769
/*outputs=*/{}));
770770
}
771771

772+
void TracedScheduleNode::AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index,
773+
BufferIndexType buffer_index_type,
774+
const IndexMap& index_map) {
775+
ConcreteScheduleNode::AnnotateBufferAccess(block_rv, buffer_index, buffer_index_type, index_map);
776+
static const InstructionKind& kind = InstructionKind::Get("AnnotateBufferAccess");
777+
trace_->Append(/*inst=*/Instruction(
778+
/*kind=*/kind,
779+
/*inputs=*/{block_rv, Integer(buffer_index), Integer(buffer_index_type), index_map},
780+
/*attrs=*/{},
781+
/*outputs=*/{}));
782+
}
783+
772784
} // namespace tir
773785
} // namespace tvm

0 commit comments

Comments
 (0)