@@ -632,6 +632,33 @@ TVM_DLL const Op& ptx_mma_sp();
632632 */
633633TVM_DLL const Op& ptx_ldmatrix ();
634634
635+ /* !
636+ * \brief tvm intrinsic for storing the result of PTX MMA into a destination pointer.
637+ * For example, if each thread in a warp of size 32 has 4 elements from the result of
638+ * m16xn8xk16 MMA in its registers, this intrinsic can be used to store the result in a
639+ * 16x8 region in shared or global memory.
640+ *
641+ * There is no real PTX instruction that does that, but we want to hide details of
642+ * complex index manipulation behind this intrinsic to simplify TIR lowering passes (e.g.
643+ * LowerWarpMemory).
644+ *
645+ * void mma_store(IntImm m, IntImm n, Var dst_ptr, Var src_ptr, Expr src_offset, Var dst_stride);
646+ */
647+ TVM_DLL const Op& mma_store ();
648+
649+ /* !
650+ * \brief tvm intrinsic for zero-initalizing an MMA accumulation registor.
651+ * For example, if each thread in a warp of size 32 has 8 elements from the A matrix in
652+ * m16xn8xk16 MMA in its registers, this intrinsic can be used to zero-initialize its
653+ * 4 accumulation registers.
654+ *
655+ * There is no real PTX instruction that does that, but we introduce this intrinsic for the
656+ * same reason as mma_store above.
657+ *
658+ * void mma_fill(IntImm local_size, Var local_ptr, Expr offset);
659+ */
660+ TVM_DLL const Op& mma_fill ();
661+
635662// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
636663/* !
637664 * \brief Get the high level half of the vector
0 commit comments