|
20 | 20 | from tvm._ffi import register_object as _register_object |
21 | 21 | from tvm.error import TVMError, register_error |
22 | 22 | from tvm.ir import IRModule, PrimExpr |
23 | | -from tvm.runtime import Object |
24 | | -from tvm.tir import Block, For, IntImm, PrimFunc |
| 23 | +from tvm.runtime import Object, String |
| 24 | +from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc |
25 | 25 |
|
26 | 26 | from . import _ffi_api |
27 | 27 | from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod |
@@ -1664,6 +1664,123 @@ def after_storage_align(a: T.handle, c: T.handle) -> None: |
1664 | 1664 |
|
1665 | 1665 | ########## Schedule: Annotation ########## |
1666 | 1666 |
|
| 1667 | + def annotate( |
| 1668 | + self, |
| 1669 | + block_or_loop: Union[BlockRV, LoopRV], |
| 1670 | + ann_key: str, |
| 1671 | + ann_val: Union[str, int, float, ExprRV], |
| 1672 | + ) -> None: |
| 1673 | + """Annotate a block/loop with a key value pair |
| 1674 | +
|
| 1675 | + Parameters |
| 1676 | + ---------- |
| 1677 | + block_or_loop: Union[BlockRV, LoopRV] |
| 1678 | + The block/loop to be annotated |
| 1679 | + ann_key : str |
| 1680 | + The annotation key |
| 1681 | + ann_val : Union[str, int, float, ExprRV] |
| 1682 | + The annotation value |
| 1683 | +
|
| 1684 | + Examples |
| 1685 | + -------- |
| 1686 | +
|
| 1687 | + Before annotate, in TensorIR, the IR is: |
| 1688 | +
|
| 1689 | + .. code-block:: python |
| 1690 | +
|
| 1691 | + @T.prim_func |
| 1692 | + def before_annotate(a: T.handle, b: T.handle) -> None: |
| 1693 | + A = T.match_buffer(a, (128, 128)) |
| 1694 | + B = T.match_buffer(b, (128, 128)) |
| 1695 | + for i, j in T.grid(128, 128): |
| 1696 | + with T.block("B"): |
| 1697 | + vi, vj = T.axis.remap("SS", [i, j]) |
| 1698 | + B[vi, vj] = A[vi, vj] * 2.0 |
| 1699 | +
|
| 1700 | + Create the schedule and do annotate: |
| 1701 | +
|
| 1702 | + .. code-block:: python |
| 1703 | +
|
| 1704 | + sch = tir.Schedule(before_annotate) |
| 1705 | + sch.annotate(sch.get_block("B"), "ann_key", "ann_value") |
| 1706 | + print(sch.mod["main"].script()) |
| 1707 | +
|
| 1708 | + After applying annotate, the IR becomes: |
| 1709 | +
|
| 1710 | + .. code-block:: python |
| 1711 | +
|
| 1712 | + @T.prim_func |
| 1713 | + def after_annotate(a: T.handle, b: T.handle) -> None: |
| 1714 | + A = T.match_buffer(a, (128, 128)) |
| 1715 | + B = T.match_buffer(b, (128, 128)) |
| 1716 | + for i, j in T.grid(128, 128): |
| 1717 | + with T.block("B"): |
| 1718 | + vi, vj = T.axis.remap("SS", [i, j]) |
| 1719 | + T.block_attr({"ann_key", "ann_value"}) |
| 1720 | + B[vi, vj] = A[vi, vj] * 2.0 |
| 1721 | +
|
| 1722 | + """ |
| 1723 | + if isinstance(ann_val, str): |
| 1724 | + ann_val = String(ann_val) |
| 1725 | + elif isinstance(ann_val, int): |
| 1726 | + ann_val = IntImm("int32", ann_val) |
| 1727 | + elif isinstance(ann_val, float): |
| 1728 | + ann_val = FloatImm("float32", ann_val) |
| 1729 | + _ffi_api.ScheduleAnnotate( # pylint: disable=no-member |
| 1730 | + self, block_or_loop, ann_key, ann_val |
| 1731 | + ) |
| 1732 | + |
| 1733 | + def unannotate(self, block_or_loop: Union[BlockRV, LoopRV], ann_key: str) -> None: |
| 1734 | + """Unannotate a block/loop's annotation with key ann_key |
| 1735 | +
|
| 1736 | + Parameters |
| 1737 | + ---------- |
| 1738 | + block_or_loop: Union[BlockRV, LoopRV] |
| 1739 | + The block/loop to be unannotated |
| 1740 | + ann_key : str |
| 1741 | + The annotation key |
| 1742 | +
|
| 1743 | + Examples |
| 1744 | + -------- |
| 1745 | +
|
| 1746 | + Before unannotate, in TensorIR, the IR is: |
| 1747 | +
|
| 1748 | + .. code-block:: python |
| 1749 | +
|
| 1750 | + @T.prim_func |
| 1751 | + def before_unannotate(a: T.handle, b: T.handle) -> None: |
| 1752 | + A = T.match_buffer(a, (128, 128)) |
| 1753 | + B = T.match_buffer(b, (128, 128)) |
| 1754 | + for i, j in T.grid(128, 128): |
| 1755 | + with T.block("B"): |
| 1756 | + vi, vj = T.axis.remap("SS", [i, j]) |
| 1757 | + T.block_attr({"ann_key", "ann_value"}) |
| 1758 | + B[vi, vj] = A[vi, vj] * 2.0 |
| 1759 | +
|
| 1760 | + Create the schedule and do annotate: |
| 1761 | +
|
| 1762 | + .. code-block:: python |
| 1763 | +
|
| 1764 | + sch = tir.Schedule(before_unannotate) |
| 1765 | + sch.unannotate(sch.get_block("B"), "ann_key") |
| 1766 | + print(sch.mod["main"].script()) |
| 1767 | +
|
| 1768 | + After applying unannotate, the IR becomes: |
| 1769 | +
|
| 1770 | + .. code-block:: python |
| 1771 | +
|
| 1772 | + @T.prim_func |
| 1773 | + def after_unannotate(a: T.handle, b: T.handle) -> None: |
| 1774 | + A = T.match_buffer(a, (128, 128)) |
| 1775 | + B = T.match_buffer(b, (128, 128)) |
| 1776 | + for i, j in T.grid(128, 128): |
| 1777 | + with T.block("B"): |
| 1778 | + vi, vj = T.axis.remap("SS", [i, j]) |
| 1779 | + B[vi, vj] = A[vi, vj] * 2.0 |
| 1780 | +
|
| 1781 | + """ |
| 1782 | + _ffi_api.ScheduleUnannotate(self, block_or_loop, ann_key) # pylint: disable=no-member |
| 1783 | + |
1667 | 1784 | ########## Schedule: Misc ########## |
1668 | 1785 |
|
1669 | 1786 | @type_checked |
|
0 commit comments