
Commit 5e086cf

add doc for mma_fill and mma_store intrin

1 parent 4f945c4

File tree

2 files changed: +25 -1 lines changed

include/tvm/tir/builtin.h

Lines changed: 24 additions & 0 deletions
@@ -632,7 +632,31 @@ TVM_DLL const Op& ptx_mma_sp();
  */
 TVM_DLL const Op& ptx_ldmatrix();
 
+/*!
+ * \brief tvm intrinsic for storing the result of PTX MMA into a destination pointer.
+ * For example, if each thread in a warp of size 32 has 4 elements from the result of
+ * m16xn8xk16 MMA in its registers, this intrinsic can be used to store the result in a
+ * 16x8 region in shared or global memory.
+ *
+ * There is no real PTX instruction that does that, but we want to hide details of
+ * complex index manipulation behind this intrinsic to simplify TIR lowering passes (e.g.
+ * LowerWarpMemory).
+ *
+ * void mma_store(IntImm m, IntImm n, Var dst_ptr, Var src_ptr, Expr src_offset, Var dst_stride);
+ */
 TVM_DLL const Op& mma_store();
+
+/*!
+ * \brief tvm intrinsic for zero-initializing an MMA accumulation register.
+ * For example, if each thread in a warp of size 32 has 8 elements from the A matrix in
+ * m16xn8xk16 MMA in its registers, this intrinsic can be used to zero-initialize its
+ * 4 accumulation registers.
+ *
+ * There is no real PTX instruction that does that, but we introduce this intrinsic for the
+ * same reason as mma_store above.
+ *
+ * void mma_fill(IntImm local_size, Var local_ptr, Expr offset);
+ */
 TVM_DLL const Op& mma_fill();
 
 // TODO(tvm-team) replace the usage of the vector operations by Shuffle.
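
As a quick illustration (not part of this commit), here is a minimal Python sketch of how the two documented signatures might be constructed as TIR calls. It assumes the ops are registered under TVM's conventional "tir." prefix and uses hypothetical variable names and dtypes; only the argument order follows the doc comments above.

    import tvm
    from tvm import tir

    # Hypothetical handles standing in for the buffers of a lowered PrimFunc.
    dst_ptr = tir.Var("dst_ptr", "handle")       # 16x8 region in shared/global memory
    src_ptr = tir.Var("src_ptr", "handle")       # per-thread accumulation registers
    dst_stride = tir.Var("dst_stride", "int32")

    # void mma_fill(IntImm local_size, Var local_ptr, Expr offset);
    fill = tir.call_intrin("float32", "tir.mma_fill", 4, src_ptr, 0)

    # void mma_store(IntImm m, IntImm n, Var dst_ptr, Var src_ptr, Expr src_offset, Var dst_stride);
    store = tir.call_intrin("float32", "tir.mma_store",
                            16, 8, dst_ptr, src_ptr, 0, dst_stride)

    # Both are void-like calls, wrapped in Evaluate when emitted as statements.
    body = tir.SeqStmt([tir.Evaluate(fill), tir.Evaluate(store)])

In practice these calls are generated inside tensor intrinsic definitions and consumed by sch.tensorize, as in the test below, rather than written by hand.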

tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py

Lines changed: 1 addition & 1 deletion
@@ -180,7 +180,7 @@ def tile_wmma_fragment(block_read, height, width):
     sch.tensorize(sch.get_loops(block_init_c)[-2], mma_fill_intrin)
     sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin)
 
-    print(sch.mod.script())
+    # print(sch.mod.script())
 
     f = tvm.build(sch.mod["main"], target="cuda", name="dense")
     dev = tvm.device("cuda", 0)

0 commit comments