7 changes: 7 additions & 0 deletions include/tvm/arith/iter_affine_map.h
@@ -350,6 +350,13 @@ Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
bool require_bijective, arith::Analyzer* analyzer,
DiagnosticContext diag_ctx);

/*!
* \brief Given an IterMapExpr, transform it into a normal PrimExpr.
* \param expr The input IterMapExpr.
* \return The corresponding normal PrimExpr.
*/
PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr);

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_ITER_AFFINE_MAP_H_
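For context, `NormalizeIterMapToExpr` is the reverse direction of the iterator-map analysis: it converts an analyzed `IterMapExpr` back into an ordinary `PrimExpr`. Below is a minimal usage sketch, assuming the function is exposed on the Python side as `tvm.arith.normalize_iter_map_to_expr` and that `tvm.arith.detect_iter_map` returns a list of `IterMapExpr` for the given bindings; the index expression and iteration domains are made up for illustration.

```python
# Hypothetical usage sketch (not part of this diff): round-trip an affine
# index through the iterator-map analysis and back to a plain PrimExpr.
import tvm
from tvm import tir
from tvm.ir import Range

i = tir.Var("i", "int32")
j = tir.Var("j", "int32")

# Analyze the fused index i * 16 + j over the domain i in [0, 8), j in [0, 16).
iter_map = tvm.arith.detect_iter_map(
    [i * 16 + j],
    {i: Range.from_min_extent(0, 8), j: Range.from_min_extent(0, 16)},
)

# NormalizeIterMapToExpr turns the analyzed IterMapExpr back into a normal
# PrimExpr equivalent to the original binding i * 16 + j.
expr = tvm.arith.normalize_iter_map_to_expr(iter_map[0])
print(expr)
```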
52 changes: 52 additions & 0 deletions include/tvm/tir/function.h
@@ -187,6 +187,58 @@ class LinkedParam : public ObjectRef {
TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
};

/*!
* \brief A tensor intrinsic for tensorization, consisting of a description function and an implementation function.
*/
class TensorIntrinNode : public Object {
public:
/*! \brief The function to describe the computation. */
PrimFunc desc;
/*! \brief The function that implements the execution. */
PrimFunc impl;

void VisitAttrs(AttrVisitor* v) {
v->Visit("desc", &desc);
v->Visit("impl", &impl);
}

static constexpr const char* _type_key = "tir.TensorIntrin";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object);
};

/*!
* \brief Managed reference to TensorIntrinNode.
*/
class TensorIntrin : public ObjectRef {
public:
/*!
* \brief Constructor
* \param desc The function to describe the computation.
* \param impl The function that implements the execution.
*/
TVM_DLL explicit TensorIntrin(PrimFunc desc, PrimFunc impl);

/*!
* \brief Create and register a TensorIntrin. After registration, the TensorIntrin can be looked
* up with its name.
* \param name The name of the TensorIntrin to register.
* \param intrin The TensorIntrin to register.
* \throws This method throws an exception if a TensorIntrin with the specified name already
* exists.
*/
TVM_DLL static void Register(String name, TensorIntrin intrin);

/*!
* \brief Look up a TensorIntrin by its name. Raises an exception if not found.
* \param name The name of the TensorIntrin.
* \return The TensorIntrin with the specified name.
* \throws This method throws an exception if the TensorIntrin does not exist.
*/
TVM_DLL static TensorIntrin Get(String name);

TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode)
};

/*!
* \brief Specialize parameters of PrimFunc.
* \param func The PrimFunc to be specialized.
19 changes: 19 additions & 0 deletions include/tvm/tir/schedule/schedule.h
@@ -473,6 +473,25 @@ class ScheduleNode : public runtime::Object {
*/
virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0;
/******** Schedule: Blockize & Tensorize ********/
/*!
* \brief Convert the subtree rooted at a specific loop into a block.
* \param loop_rv The root of the subtree.
* \return The new block.
*/
virtual BlockRV Blockize(const LoopRV& loop_rv) = 0;
/*!
* \brief Tensorize the computation enclosed by the given loop with a tensor intrinsic.
* \param loop_rv The loop to be tensorized.
* \param intrin The name of the tensor intrinsic.
*/
virtual void Tensorize(const LoopRV& loop_rv, const String& intrin) = 0;
/*!
* \brief Tensorize the computation enclosed by the given block with a tensor intrinsic.
* \param block_rv The block to be tensorized.
* \param intrin The name of the tensor intrinsic.
*/
virtual void Tensorize(const BlockRV& block_rv, const String& intrin) = 0;

/******** Schedule: Annotation ********/
/*!
* \brief Annotate a loop with a key value pair
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
@@ -33,7 +33,7 @@
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize

from .function import PrimFunc
from .function import PrimFunc, TensorIntrin

from .op import call_packed, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
48 changes: 48 additions & 0 deletions python/tvm/tir/function.py
@@ -162,3 +162,51 @@ def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str:
return tvm._ffi.get_global_func("script.AsTVMScript")(
self, tir_prefix, show_meta
) # type: ignore


@tvm._ffi.register_object("tir.TensorIntrin")
class TensorIntrin(Object):
"""A tensor intrinsic.

Parameters
----------
desc : PrimFunc
The function to describe the computation.

impl : PrimFunc
The function that implements the execution.
"""

def __init__(self, desc, impl):
self.__init_handle_by_constructor__(_ffi_api.TensorIntrin, desc, impl)

@staticmethod
def register(name: str, desc: PrimFunc, impl: PrimFunc):
"""Register a tensor intrinsic with its name.

Parameters
----------
name : str
The name of the TensorIntrin to register.
desc : PrimFunc
The function to describe the computation.
impl : PrimFunc
The function that implements the execution.
"""
return _ffi_api.TensorIntrinRegister(name, TensorIntrin(desc, impl)) # type: ignore

@staticmethod
def get(name: str):
"""Look up a tensor intrinsic by its name.

Parameters
----------
name : str
The name of the TensorIntrin to look up.

Returns
-------
result : TensorIntrin
The TensorIntrin with the specified name.
"""
return _ffi_api.TensorIntrinGet(name)  # type: ignore
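As a quick illustration of the new Python API, a hypothetical intrinsic can be registered and then looked up by name. The intrinsic name `demo_copy_intrin`, the `copy_desc`/`copy_impl` functions, and the extern call `fast_copy_16` below are invented for this sketch and are not part of this change:

```python
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def copy_desc(a: T.handle, b: T.handle) -> None:
    # Describes the computation: copy 16 elements from A to B.
    A = T.match_buffer(a, (16,), align=128, offset_factor=1)
    B = T.match_buffer(b, (16,), align=128, offset_factor=1)
    with T.block("root"):
        T.reads(A[0:16])
        T.writes(B[0:16])
        for i in T.serial(0, 16):
            with T.block("copy"):
                vi = T.axis.spatial(16, i)
                B[vi] = A[vi]


@T.prim_func
def copy_impl(a: T.handle, b: T.handle) -> None:
    # Implements the same computation with a (made-up) extern call.
    A = T.match_buffer(a, (16,), align=128, offset_factor=1)
    B = T.match_buffer(b, (16,), align=128, offset_factor=1)
    with T.block("root"):
        T.reads(A[0:16])
        T.writes(B[0:16])
        T.evaluate(T.call_extern("fast_copy_16", A.data, B.data, dtype="float32"))


# Register under a made-up name, then look the intrinsic back up by that name.
tir.TensorIntrin.register("demo_copy_intrin", copy_desc, copy_impl)
intrin = tir.TensorIntrin.get("demo_copy_intrin")
assert intrin.desc is not None and intrin.impl is not None
```

Once registered, the name can be passed to `Schedule.tensorize`, as shown in the `schedule.py` docstring below.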
229 changes: 229 additions & 0 deletions python/tvm/tir/schedule/schedule.py
@@ -1759,6 +1759,235 @@ def after_set_scope(

########## Schedule: Blockize & Tensorize ##########

@type_checked
def blockize(self, loop: LoopRV) -> BlockRV:
"""Convert the subtree rooted at a specific loop into a block.

Parameters
----------
loop : LoopRV
The root of the subtree.

Returns
-------
result : BlockRV
The new block.

Examples
--------

Before blockize, in TensorIR, the IR is:

.. code-block:: python

@T.prim_func
def before_blockize(
A: T.Buffer[(128, 128), "float32"],
B: T.Buffer[(128, 128), "float32"]
) -> None:
for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16):
with T.block("B"):
vi = T.axis.spatial(128, i_0 * 16 + i_1)
vj = T.axis.spatial(128, j_0 * 16 + j_1)
T.reads(A[vi, vj])
T.writes(B[vi, vj])
B[vi, vj] = A[vi, vj] * T.float32(2)

Create the schedule and do blockize:

.. code-block:: python

sch = tir.Schedule(before_blockize)
B = sch.get_block("B")
_, _, i1, _ = sch.get_loops(B)
sch.blockize(i1)
print(sch.mod["main"].script())

After applying blockize, the IR becomes:

.. code-block:: python

@T.prim_func
def after_blockize(
A: T.Buffer[(128, 128), "float32"],
B: T.Buffer[(128, 128), "float32"]
) -> None:
for i_0, j_0 in T.grid(8, 8):
with T.block("B_o"):
vio, vjo = T.axis.remap("SS", [i_0, j_0])
T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
for i_1, j_1 in T.grid(16, 16):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i_1, j_1])
T.reads(A[vio * 16 + vi, vjo * 16 + vj])
T.writes(B[vio * 16 + vi, vjo * 16 + vj])
B[vio * 16 + vi, vjo * 16 + vj] = A[vio * 16 + vi, vjo * 16 + vj] \
* T.float32(2)

Note
----
blockize requires that there is exactly one block under the given loop and that the block's
bindings are divisible by the subspace represented by the loops starting at the given loop.
"""

return _ffi_api.ScheduleBlockize(self, loop) # type: ignore # pylint: disable=no-member

@type_checked
def tensorize(self, block_or_loop: Union[BlockRV, LoopRV], tensor_intrin: str) -> None:
"""Tensorize the computation enclosed by loop with the tensor intrinsic.

Parameters
----------
block_or_loop : Union[BlockRV, LoopRV]
The block or loop to be tensorized.
tensor_intrin : str
The name of the tensor intrinsic.

Examples
--------

Before tensorize, in TensorIR, the IR is:

.. code-block:: python

@T.prim_func
def before_tensorize(
A: T.Buffer[(128, 128), "float32"],
B: T.Buffer[(128, 128), "float32"],
C: T.Buffer[(128, 128), "float32"],
) -> None:
# body
# with T.block("root")
for i_0, j_0, k_0, i_1, j_1, k_1 in T.grid(8, 8, 8, 16, 16, 16):
with T.block("update"):
vi = T.axis.spatial(128, i_0 * 16 + i_1)
vj = T.axis.spatial(128, j_0 * 16 + j_1)
vk = T.axis.reduce(128, k_0 * 16 + k_1)
T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
T.writes(C[vi, vj])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

Declare and register the tensor intrinsic:

.. code-block:: python

@T.prim_func
def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)

with T.block("root"):
T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16])
T.writes(C[0 : 16, 0 : 16])
for i, j, k in T.grid(16, 16, 16):
with T.block("update"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]


@T.prim_func
def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)

with T.block("root"):
T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16])
T.writes(C[0 : 16, 0 : 16])
T.evaluate(
T.tvm_mma_sync(
C.data,
C.elem_offset // 256,
A.data,
A.elem_offset // 256,
B.data,
B.elem_offset // 256,
C.data,
C.elem_offset // 256,
dtype="handle",
)
)

tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin)

Create the schedule and do tensorize:

.. code-block:: python

sch = tir.Schedule(before_tensorize)
update = sch.get_block("update")
_, _, _, i1, _, _ = sch.get_loops(update)
sch.tensorize(i1, "test_mma_intrin")
print(sch.mod["main"].script())

After applying tensorize, the IR becomes:

.. code-block:: python

@T.prim_func
def after_tensorize(
A: T.Buffer[(128, 128), "float32"],
B: T.Buffer[(128, 128), "float32"],
C: T.Buffer[(128, 128), "float32"],
) -> None:
# body
# with T.block("root")
for i_0, j_0, k_0 in T.grid(8, 8, 8):
with T.block("update_o"):
vio, vjo, vko = T.axis.remap("SSR", [i_0, j_0, k_0])
T.reads(
C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16],
A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 16 + 16],
B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 16 + 16],
)
T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
A_1 = T.match_buffer(
A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 16 + 16],
[16, 16],
dtype="float32",
offset_factor=1,
)
B_1 = T.match_buffer(
B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 16 + 16],
[16, 16],
dtype="float32",
offset_factor=1,
)
C_1 = T.match_buffer(
C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16],
[16, 16],
dtype="float32",
offset_factor=1,
)
with T.init():
for i_1, j_1 in T.grid(16, 16):
with T.block("update_init"):
vi_init, vj_init = T.axis.remap("SS", [i_1, j_1])
T.reads()
T.writes(C[vio * 16 + vi_init, vjo * 16 + vj_init])
C[vio * 16 + vi_init, vjo * 16 + vj_init] = T.float32(0)
T.evaluate(
T.tvm_mma_sync(
C_1.data,
C_1.elem_offset // 256,
A_1.data,
A_1.elem_offset // 256,
B_1.data,
B_1.elem_offset // 256,
C_1.data,
C_1.elem_offset // 256,
dtype="handle",
)
)
"""
_ffi_api.ScheduleTensorize( # type: ignore # pylint: disable=no-member
self, block_or_loop, tensor_intrin
)

########## Schedule: Annotation ##########

@type_checked