Skip to content

Commit bf3c8ef

Browse files
vinx13Siyuan FengspectrometerHBHjinhongyiiMasterJH5574
committed
[TIR][Schedule] Add Annotate/Unannotate primitive
Co-authored-by: Siyuan Feng <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Junru Shao <[email protected]> Co-authored-by: Wuwei Lin <[email protected]> Co-authored-by: Xiyou Zhou <[email protected]>
1 parent 7279c9d commit bf3c8ef

File tree

12 files changed

+529
-6
lines changed

12 files changed

+529
-6
lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,34 @@ class ScheduleNode : public runtime::Object {
450450
int offset) = 0;
451451
/******** Schedule: Blockize & Tensorize ********/
452452
/******** Schedule: Annotation ********/
453+
/*!
454+
* \brief Annotate a loop with a key value pair
455+
* \param loop The loop to be annotated
456+
* \param ann_key The annotation key
457+
* \param ann_val The annotation value, a string or a ExprRV
458+
*/
459+
virtual void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) = 0;
460+
/*!
461+
* \brief Annotate a block with a key value pair
462+
* \param loop The block to be annotated
463+
* \param ann_key The annotation key
464+
* \param ann_val The annotation value, a string or a ExprRV
465+
*/
466+
virtual void Annotate(const BlockRV& block_rv, const String& ann_key,
467+
const ObjectRef& ann_val) = 0;
468+
/*!
469+
* \brief Unannotate a loop's annotation with key ann_key
470+
* \param loop The loop to be unannotated
471+
* \param ann_key The annotation key
472+
*/
473+
virtual void Unannotate(const LoopRV& loop_rv, const String& ann_key) = 0;
474+
/*!
475+
* \brief Unannotate a block's annotation with key ann_key
476+
* \param loop The block to be unannotated
477+
* \param ann_key The annotation key
478+
*/
479+
virtual void Unannotate(const BlockRV& block_rv, const String& ann_key) = 0;
480+
453481
/******** Schedule: Misc ********/
454482
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
455483
virtual void EnterPostproc() = 0;

python/tvm/script/tir/scope_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import synr
2222
import tvm.tir
23-
from tvm.runtime import Object
23+
from tvm.runtime import Object, String
2424
from tvm.ir import Span, Range
2525
from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
2626

@@ -486,7 +486,7 @@ def create_loop_info(
486486
self.annotations: Mapping[str, Object] = {}
487487
if annotations is not None:
488488
self.annotations = {
489-
key: tvm.tir.StringImm(val) if isinstance(val, str) else val
489+
key: String(val) if isinstance(val, str) else val
490490
for key, val in annotations.items()
491491
}
492492

python/tvm/script/tir/special_stmt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from tvm.ir.expr import PrimExpr, Range
2525

2626
import tvm.tir
27-
from tvm.runtime import Object
27+
from tvm.runtime import Object, String
2828
from tvm import te
2929
from tvm.target import Target
3030
from tvm.ir import Span
@@ -430,7 +430,7 @@ def block_attr(attrs: Mapping[str, Object], span: Span = None):
430430
span,
431431
)
432432
attrs = {
433-
key: tvm.tir.StringImm(val) if isinstance(val, str) else val
433+
key: String(val) if isinstance(val, str) else val
434434
for key, val in attrs.items()
435435
}
436436
block_scope.annotations = attrs

python/tvm/tir/schedule/schedule.py

Lines changed: 119 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
from tvm._ffi import register_object as _register_object
2121
from tvm.error import TVMError, register_error
2222
from tvm.ir import IRModule, PrimExpr
23-
from tvm.runtime import Object
24-
from tvm.tir import Block, For, IntImm, PrimFunc
23+
from tvm.runtime import Object, String
24+
from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc
2525

2626
from . import _ffi_api
2727
from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
@@ -1664,6 +1664,123 @@ def after_storage_align(a: T.handle, c: T.handle) -> None:
16641664

16651665
########## Schedule: Annotation ##########
16661666

1667+
def annotate(
1668+
self,
1669+
block_or_loop: Union[BlockRV, LoopRV],
1670+
ann_key: str,
1671+
ann_val: Union[str, int, float, ExprRV],
1672+
) -> None:
1673+
"""Annotate a block/loop with a key value pair
1674+
1675+
Parameters
1676+
----------
1677+
block_or_loop: Union[BlockRV, LoopRV]
1678+
The block/loop to be annotated
1679+
ann_key : str
1680+
The annotation key
1681+
ann_val : Union[str, int, float, ExprRV]
1682+
The annotation value
1683+
1684+
Examples
1685+
--------
1686+
1687+
Before annotate, in TensorIR, the IR is:
1688+
1689+
.. code-block:: python
1690+
1691+
@T.prim_func
1692+
def before_annotate(a: T.handle, b: T.handle) -> None:
1693+
A = T.match_buffer(a, (128, 128))
1694+
B = T.match_buffer(b, (128, 128))
1695+
for i, j in T.grid(128, 128):
1696+
with T.block("B"):
1697+
vi, vj = T.axis.remap("SS", [i, j])
1698+
B[vi, vj] = A[vi, vj] * 2.0
1699+
1700+
Create the schedule and do annotate:
1701+
1702+
.. code-block:: python
1703+
1704+
sch = tir.Schedule(before_annotate)
1705+
sch.annotate(sch.get_block("B"), "ann_key", "ann_value")
1706+
print(sch.mod["main"].script())
1707+
1708+
After applying annotate, the IR becomes:
1709+
1710+
.. code-block:: python
1711+
1712+
@T.prim_func
1713+
def after_annotate(a: T.handle, b: T.handle) -> None:
1714+
A = T.match_buffer(a, (128, 128))
1715+
B = T.match_buffer(b, (128, 128))
1716+
for i, j in T.grid(128, 128):
1717+
with T.block("B"):
1718+
vi, vj = T.axis.remap("SS", [i, j])
1719+
T.block_attr({"ann_key", "ann_value"})
1720+
B[vi, vj] = A[vi, vj] * 2.0
1721+
1722+
"""
1723+
if isinstance(ann_val, str):
1724+
ann_val = String(ann_val)
1725+
elif isinstance(ann_val, int):
1726+
ann_val = IntImm("int32", ann_val)
1727+
elif isinstance(ann_val, float):
1728+
ann_val = FloatImm("float32", ann_val)
1729+
_ffi_api.ScheduleAnnotate( # pylint: disable=no-member
1730+
self, block_or_loop, ann_key, ann_val
1731+
)
1732+
1733+
def unannotate(self, block_or_loop: Union[BlockRV, LoopRV], ann_key: str) -> None:
1734+
"""Unannotate a block/loop's annotation with key ann_key
1735+
1736+
Parameters
1737+
----------
1738+
block_or_loop: Union[BlockRV, LoopRV]
1739+
The block/loop to be unannotated
1740+
ann_key : str
1741+
The annotation key
1742+
1743+
Examples
1744+
--------
1745+
1746+
Before unannotate, in TensorIR, the IR is:
1747+
1748+
.. code-block:: python
1749+
1750+
@T.prim_func
1751+
def before_unannotate(a: T.handle, b: T.handle) -> None:
1752+
A = T.match_buffer(a, (128, 128))
1753+
B = T.match_buffer(b, (128, 128))
1754+
for i, j in T.grid(128, 128):
1755+
with T.block("B"):
1756+
vi, vj = T.axis.remap("SS", [i, j])
1757+
T.block_attr({"ann_key", "ann_value"})
1758+
B[vi, vj] = A[vi, vj] * 2.0
1759+
1760+
Create the schedule and do annotate:
1761+
1762+
.. code-block:: python
1763+
1764+
sch = tir.Schedule(before_unannotate)
1765+
sch.unannotate(sch.get_block("B"), "ann_key")
1766+
print(sch.mod["main"].script())
1767+
1768+
After applying unannotate, the IR becomes:
1769+
1770+
.. code-block:: python
1771+
1772+
@T.prim_func
1773+
def after_unannotate(a: T.handle, b: T.handle) -> None:
1774+
A = T.match_buffer(a, (128, 128))
1775+
B = T.match_buffer(b, (128, 128))
1776+
for i, j in T.grid(128, 128):
1777+
with T.block("B"):
1778+
vi, vj = T.axis.remap("SS", [i, j])
1779+
B[vi, vj] = A[vi, vj] * 2.0
1780+
1781+
"""
1782+
_ffi_api.ScheduleUnannotate(self, block_or_loop, ann_key) # pylint: disable=no-member
1783+
16671784
########## Schedule: Misc ##########
16681785

16691786
@type_checked

src/tir/schedule/concrete_schedule.cc

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,53 @@ BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {
563563

564564
/******** Schedule: Blockize & Tensorize ********/
565565
/******** Schedule: Annotation ********/
566+
567+
ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& ann_val) {
568+
if (ann_val.as<StringObj>()) {
569+
return ann_val;
570+
}
571+
if (const auto* expr = ann_val.as<PrimExprNode>()) {
572+
ICHECK(!ann_val->IsInstance<StringImmNode>())
573+
<< "TypeError: runtime::String is expected, but gets StringImm";
574+
return this->Get(GetRef<PrimExpr>(expr));
575+
}
576+
LOG(FATAL)
577+
<< "TypeError: Only strings, integers, floats, ExprRVs and Arrays are supported for now, but "
578+
<< "gets: " << ann_val->GetTypeKey();
579+
throw;
580+
}
581+
582+
void ConcreteScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key,
583+
const ObjectRef& ann_val) {
584+
TVM_TIR_SCHEDULE_BEGIN();
585+
tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, this->CheckAndGetAnnotationValue(ann_val));
586+
this->state_->DebugVerify();
587+
TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_);
588+
}
589+
590+
void ConcreteScheduleNode::Unannotate(const LoopRV& loop_rv, const String& ann_key) {
591+
TVM_TIR_SCHEDULE_BEGIN();
592+
tir::Unannotate(state_, this->GetSRef(loop_rv), ann_key);
593+
this->state_->DebugVerify();
594+
TVM_TIR_SCHEDULE_END("unannotate", this->error_render_level_);
595+
}
596+
597+
void ConcreteScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_key,
598+
const ObjectRef& ann_val) {
599+
TVM_TIR_SCHEDULE_BEGIN();
600+
tir::Annotate(state_, this->GetSRef(block_rv), ann_key,
601+
this->CheckAndGetAnnotationValue(ann_val));
602+
this->state_->DebugVerify();
603+
TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_);
604+
}
605+
606+
void ConcreteScheduleNode::Unannotate(const BlockRV& loop_rv, const String& ann_key) {
607+
TVM_TIR_SCHEDULE_BEGIN();
608+
tir::Unannotate(state_, this->GetSRef(loop_rv), ann_key);
609+
this->state_->DebugVerify();
610+
TVM_TIR_SCHEDULE_END("unannotate", this->error_render_level_);
611+
}
612+
566613
/******** Schedule: Misc ********/
567614

568615
} // namespace tir

src/tir/schedule/concrete_schedule.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ class ConcreteScheduleNode : public ScheduleNode {
120120
int offset) override;
121121
/******** Schedule: Blockize & Tensorize ********/
122122
/******** Schedule: Annotation ********/
123+
void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override;
124+
void Unannotate(const LoopRV& loop_rv, const String& ann_key) override;
125+
void Annotate(const BlockRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override;
126+
void Unannotate(const BlockRV& loop_rv, const String& ann_key) override;
127+
123128
/******** Schedule: Misc ********/
124129
void EnterPostproc() override {}
125130

@@ -161,6 +166,13 @@ class ConcreteScheduleNode : public ScheduleNode {
161166
inline Array<ExprRV> CreateRV(const std::vector<int64_t>& value);
162167
/*! \brief Remove a random variable from the symbol table */
163168
inline void RemoveFromSymbolTable(const ObjectRef& rv);
169+
/*!
170+
* \brief Check the annotation value is valid and look up the random variable. Raises an exception
171+
* if the type of the annotation value is not allowed.
172+
* \param The annotation value.
173+
* \return The annotation value with random variables substituted with their values.
174+
*/
175+
ObjectRef CheckAndGetAnnotationValue(const ObjectRef& ann_val);
164176
};
165177

166178
// implementations

src/tir/schedule/primitive.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,23 @@ TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int bu
340340

341341
/******** Schedule: Blockize & Tensorize ********/
342342
/******** Schedule: Annotation ********/
343+
/*!
344+
* \brief Annotate a block/loop with a key value pair
345+
* \param self The state of the schedule
346+
* \param sref The block/loop sref to be annotated
347+
* \param ann_key The annotation key
348+
* \param ann_val The annotation value
349+
*/
350+
TVM_DLL void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key,
351+
const ObjectRef& ann_val);
352+
/*!
353+
* \brief Unannotate a block/loop's annotation with key ann_key
354+
* \param self The state of the schedule
355+
* \param sref The block/loop to be unannotated
356+
* \param ann_key The annotation key
357+
*/
358+
TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key);
359+
343360
/******** Schedule: Misc ********/
344361

345362
} // namespace tir

0 commit comments

Comments
 (0)