Conversation

yzh119 (Member) commented Mar 1, 2023

Motivation

Currently, we have the schedule primitives cache_read/cache_write, which allocate cache buffers and create cache stages copying data from the buffer being accessed into the cache buffer. However, cache_read/cache_write do not support customized indices for accessing the cache buffer. For the following block:

@T.prim_func
def func(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (129, 129))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi + 1, vj + 1] * 2.0

after cache_read("B", 0, "shared"), we get:

# from tvm.script import tir as T
@T.prim_func
def main(A: T.Buffer((129, 129), "float32"), B: T.Buffer((128, 128), "float32")):
    # with T.block("root"):
    A_shared = T.alloc_buffer((129, 129), scope="shared")
    for ax0, ax1 in T.grid(129, 129):
        with T.block("A_shared"):
            v0, v1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A[v0, v1])
            T.writes(A_shared[v0, v1])
            A_shared[v0, v1] = A[v0, v1]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A_shared[vi + 1, vj + 1])
            T.writes(B[vi, vj])
            B[vi, vj] = A_shared[vi + 1, vj + 1] * T.float32(2)

where we access A_shared with the same indices (vi + 1, vj + 1) as the original block. This is not flexible, especially when we want to perform a layout transformation while copying data from the original buffer to the cache buffer (e.g., for MMA tensorization and FlashAttention).
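For reference, the IR above can be reproduced with a short sketch (assuming func is the PrimFunc defined at the top of this section; printing APIs may vary slightly across TVM versions):

import tvm
from tvm import tir

sch = tir.Schedule(func)          # `func` is the 129x129 -> 128x128 example above
sch.cache_read("B", 0, "shared")  # cache the first read buffer (A) into shared memory
print(sch.mod["main"].script())   # prints the A_shared version shown above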

This PR proposes a new interface that lets us customize the indices used to access the cache buffer, which is expressive enough to describe transposing and blocking.

Proposed API

Below is the proposed interface of reindex_cache_read (reindex_cache_write has similar interface):

def reindex_cache_read(
    self,
    block: Union[BlockRV, str],
    read_buffer_index: int,
    storage_scope: str,
    index_map: Union[IndexMap, Callable],
) -> BlockRV:
    ...

Here block, read_buffer_index, and storage_scope have the same meaning as in cache_read. The additional argument index_map specifies which indices to use when accessing the cache buffer, in the form of an index map from the current block itervars to the target indices. Suppose the block has itervars vi, vj and the user wants to access the cache buffer with the customized indices [vi // 16, vj // 16, vi % 16, vj % 16]; the user should then set index_map to lambda vi, vj: (vi // 16, vj // 16, vi % 16, vj % 16).
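As a usage sketch (not part of the PR itself), assuming sch is a tir.Schedule whose block "B" has itervars vi, vj, the blocked cache layout described in the previous paragraph would be requested like this:

# Hedged sketch of the proposed API; `sch` and block "B" are assumed to exist.
cached_block = sch.reindex_cache_read(
    block="B",
    read_buffer_index=0,        # the first buffer read by block "B"
    storage_scope="shared",
    # map the block itervars (vi, vj) to the indices used on the cache buffer
    index_map=lambda vi, vj: (vi // 16, vj // 16, vi % 16, vj % 16),
)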

Example

By applying reindex_cache_read("B", 0, "shared", lambda i, j: (j, i)) to func, we get:

@T.prim_func
def main(A: T.Buffer((129, 129), "float32"), B: T.Buffer((128, 128), "float32")):
    # with T.block("root"):
    A_shared = T.alloc_buffer((128, 128), scope="shared")
    for i, j in T.grid(128, 128):
        with T.block("A_shared"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi + 1, vj + 1])
            T.writes(A_shared[vj, vi])
            A_shared[vj, vi] = A[vi + 1, vj + 1]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A_shared[vj, vi])
            T.writes(B[vi, vj])
            B[vi, vj] = A_shared[vj, vi] * T.float32(2)
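The same reproduction pattern as in the earlier sketch applies, now with the new primitive and the transpose index map (again a sketch, assuming a fresh schedule over func):

from tvm import tir

sch = tir.Schedule(func)  # fresh schedule over the original `func`
sch.reindex_cache_read("B", 0, "shared", lambda vi, vj: (vj, vi))
print(sch.mod["main"].script())  # prints the transposed A_shared version shown above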

Notes

Unlike cache_read/cache_write, which allow caching a rectangular region, reindex_cache_read only allows caching a single point; this is enough to cover most use cases.

The cache stage block follows the original order of loops and block itervars in the block. If a block itervar does not appear in the buffer access region, it and its corresponding loop variables are omitted. Users can then use the transform_block_layout primitive to reorder the block itervars and surrounding loops of the cache read/write block, as in the sketch below.
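A hedged sketch of such a reordering with transform_block_layout (illustrative only; assumes a fresh tir.Schedule sch over func):

# Reindex-cache-read A with transposed indices, then swap the itervars
# (and the surrounding loops) of the generated cache block so that its
# loop order follows the cache buffer's (vj, vi) layout.
cache_block = sch.reindex_cache_read("B", 0, "shared", lambda vi, vj: (vj, vi))
sch.transform_block_layout(cache_block, lambda vi, vj: (vj, vi))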

Relations to Existing Schedule Primitives

  • Relation with reindex
    • reindex only supports the special case of reindex_cache_read/reindex_cache_write where index_map is the identity map, and reindex does not have a storage_scope field.
  • Relation with transform_layout
    • transform_layout is designed to transform the layout of intermediate buffers rather than input buffers, and it does not have a storage_scope field.
  • Relation with cache_read/cache_write
    • cache_read/cache_write do not support customized indices.

tvm-bot (Collaborator) commented Mar 1, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

  • No users to tag found in teams: tensorir, primitive See #10317 for details

Generated by tvm-bot

yzh119 (Member, Author) commented Mar 1, 2023

Some additional context (credit to discussion with @andy-yang-1):

In some cases, reindex_cache_read/write has the same effect as first applying cache_read and then applying transform_layout; however, that is not always the case, especially when the buffer region to cache-read is non-trivial.

For example, suppose we want to cache-read buffer A in block B in the example above with indices vi // 4, vj // 4, vi % 4, vj % 4. We can use reindex_cache_read("B", 0, "shared", lambda i, j: (i // 4, j // 4, i % 4, j % 4)), which transforms the program to:

@T.prim_func
def main(A: T.Buffer((129, 129), "float32"), B: T.Buffer((128, 128), "float32")):
    # with T.block("root"):
    B_shared = T.alloc_buffer((32, 32, 4, 4), scope="shared")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi + 1, vj + 1])
            T.writes(B_shared[vi // 4, vj // 4, vi % 4, vj % 4])
            B_shared[vi // 4, vj // 4, vi % 4, vj % 4] = A[vi + 1, vj + 1] * T.float32(2)
    for i, j in T.grid(128, 128):
        with T.block("B_shared"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B_shared[vi // 4, vj // 4, vi % 4, vj % 4])
            T.writes(B[vi, vj])
            B[vi, vj] = B_shared[vi // 4, vj // 4, vi % 4, vj % 4]

while applying cache_read followed by transform_layout triggers the following error:

>>> sch.cache_read("B", 0, "shared")
>>> sch.transform_layout("B", ("read", 0), lambda i, j: (i // 4, j //4, i % 4, j % 4))
ScheduleError: An error occurred in the schedule primitive 'transform_layout'.
The IR with diagnostic is:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(a: T.handle, c: T.handle):
        A = T.match_buffer(a, (129, 129))
        B = T.match_buffer(c, (128, 128))
        with T.block("root"):
            T.reads()
            T.writes()
            A_shared = T.alloc_buffer((129, 129), scope="shared")
            for ax0 in range(129):
                for ax1 in range(129):
                    with T.block("A_shared"):
                        v0 = T.axis.spatial(129, ax0)
                        v1 = T.axis.spatial(129, ax1)
                        T.reads(A[v0, v1])
                        T.writes(A_shared[v0, v1])
                        A_shared[v0, v1] = A[v0, v1]
            for i in range(128):
                for j in range(128):
                    with T.block("B"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j)
                        T.reads(A_shared[vi + 1, vj + 1])
                        T.writes(B[vi, vj])
                        B[vi, vj] = A_shared[vi + 1, vj + 1] * T.float32(2)
Error message: The transformation T.index_map(lambda i, j: (i // 4, j // 4, i % 4, j % 4)) applied on buffer A_shared of shape [129, 129] would result in shape [33, 33, 4, 4].  However, this would introduce padding wherever axis0 == 32 and 1 <= axis2 or axis1 == 32 and 1 <= axis3 is true.

The reason is that, by default, cache_read/cache_write allocates a cache buffer covering the accessed region A[vi + 1, vj + 1] for vi ∈ [0, 128), vj ∈ [0, 128), which results in shape (129, 129); this shape does not satisfy the padding-free requirement of transform_layout.

Another frequent use case is sparse access. Suppose we want to cache-read the following function:

@T.prim_func
def func(a: T.handle, b: T.handle, f: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    F = T.match_buffer(f, (128,), "int32")
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[F[vi], vj] * 2.0

Using reindex_cache_read("B", 0, "shared", lambda i, j: (i // 4, j // 4, i % 4, j % 4)) produces a correct transformation:

@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), F: T.Buffer((128,), "int32")):
    # with T.block("root"):
    B_shared = T.alloc_buffer((32, 32, 4, 4), scope="shared")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[F[vi], vj], F[vi])
            T.writes(B_shared[vi // 4, vj // 4, vi % 4, vj % 4])
            B_shared[vi // 4, vj // 4, vi % 4, vj % 4] = A[F[vi], vj] * T.float32(2)
    for i, j in T.grid(128, 128):
        with T.block("B_shared"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B_shared[vi // 4, vj // 4, vi % 4, vj % 4])
            T.writes(B[vi, vj])
            B[vi, vj] = B_shared[vi // 4, vj // 4, vi % 4, vj % 4]

whereas cache_read + transform_layout generates a wrong transformation:

>>> sch.cache_read("B", 0, "shared")
>>> sch.transform_layout("B", ("read", 0), lambda i, j: (i // 4, j //4, i % 4, j % 4))
>>> print(sch.mod["main"].script())
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), F: T.Buffer((128,), "int32")):
    # with T.block("root"):
    A_shared = T.alloc_buffer((32, 32, 4, 4), scope="shared")
    for ax0, ax1 in T.grid(128, 128):
        with T.block("A_shared"):
            v0, v1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A[v0, v1])
            T.writes(A_shared[v0 // 4, v1 // 4, v0 % 4, v1 % 4])
            A_shared[v0 // 4, v1 // 4, v0 % 4, v1 % 4] = A[v0, v1]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A_shared[F[vi] // 4, vj // 4, F[vi] % 4, vj % 4], F[vi])
            T.writes(B[vi, vj])
            B[vi, vj] = A_shared[F[vi] // 4, vj // 4, F[vi] % 4, vj % 4] * T.float32(2)

tqchen changed the title from "[Primitive] New schedule primitive reindex_cache_read/write" to "[TensorIR][Primitive] New schedule primitive reindex_cache_read/write" on Mar 1, 2023
yzh119 (Member, Author) commented Mar 1, 2023

cc @spectrometerHBH @vinx13

quic-sanirudh (Contributor) commented:

This seems like a useful schedule primitive, thanks for this. I have a question: in the first example above that uses reindex_cache_read, the original buffer A has dimensions (129, 129), but the A_shared that is created is only (128, 128).

I understand that in this case the extra elements are not needed, but does the primitive automatically verify that those elements are not accessed before copying a smaller chunk to "shared" memory?

yzh119 (Member, Author) commented Mar 2, 2023

Hi @quic-sanirudh, this schedule primitive allocates a buffer whose volume equals the Cartesian product of the block itervar domains, regardless of the original buffer size. It does not perform that check, because those data are not used in the current block.
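As a small illustration of this point (not from the thread; it assumes tvm.tir.IndexMap.map_shape behaves as in recent TVM versions), the allocated cache shape is the image of the itervar domain under the index map, independent of A's (129, 129) shape:

from tvm import tir

# Block itervar domains are (128, 128); the blocked index map yields a 4-D cache.
imap = tir.IndexMap.from_func(lambda i, j: (i // 4, j // 4, i % 4, j % 4))
print(imap.map_shape([128, 128]))  # expected: a (32, 32, 4, 4) shape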

Hzfengsy (Member) left a review comment:


Some nits

Hzfengsy (Member) commented Mar 2, 2023

quic-sanirudh (Contributor) commented:

> Hi @quic-sanirudh, this schedule primitive allocates a buffer whose volume equals the Cartesian product of the block itervar domains, regardless of the original buffer size. It does not perform that check, because those data are not used in the current block.

Thanks for the reply @yzh119. If it does not perform correctness checks and generates the loop based on the surrounding itervar domains, then couldn't that lead to potentially incorrect code?

For example, applying reindex_cache_read to the following modification of your initial example would probably lead to incorrect code:

@T.prim_func
def func(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (129, 129))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (129, 129))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi + 1, vj + 1] * 2.0
    for i, j in T.grid(129, 129):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = A[vi, vj] * 3.0

When we apply sch.reindex_cache_read("B", 0, "shared", lambda i, j: (j, i)), we get the below IR:

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((129, 129), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((129, 129), "float32")):
        # with T.block("root"):
        A_shared = T.alloc_buffer((128, 128), scope="shared")
        for i, j in T.grid(128, 128):
            with T.block("A_shared"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vi + 1, vj + 1])
                T.writes(A_shared[vj, vi])
                A_shared[vj, vi] = A[vi + 1, vj + 1]
        for i, j in T.grid(128, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A_shared[vj, vi])
                T.writes(B[vi, vj])
                B[vi, vj] = A_shared[vj, vi] * T.float32(2)
        vj_1 = T.int32()
        vi_1 = T.int32()
        for i, j in T.grid(129, 129):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A_shared[vj_1, vi_1])
                T.writes(C[vi, vj])
                C[vi, vj] = A_shared[vj_1, vi_1] * T.float32(3)

Note that the iteration domain of block "A_shared" is (128, 128), but when A_shared is used in block "C" the iteration domain is (129,129). This means the values of C from C[0,0] to C[0,128] and C[i,0] for 0 <= i <= 128 would be invalid.

yzh119 (Member, Author) commented Mar 2, 2023

Hi @quic-sanirudh, yes, you are correct. Such cases occur when there are multiple consumer blocks, and we should suggest that users use cache_read/cache_write in these cases, since those primitives analyze the union of the access regions of all consumer blocks. Supporting this for reindex_cache_read/reindex_cache_write would complicate things: for example, in your program you reindex-cache-read buffer A in block B, so we would inevitably have to change the indices of buffer A in block C as well, which is not desired.

Considering that most use cases of reindex_cache_read orchestrate data movement from shared memory to local registers and do not involve multiple consumer blocks, I suggest we:

  1. only support a single consumer block;
  2. remove the consumer_blocks argument from the API;
  3. emphasize this point in the docstring and suggest that users use cache_read/cache_write in that case.

Do you think this is acceptable?

quic-sanirudh (Contributor) commented:

Hi @yzh119, thanks for the detailed explanation. Yes, I understand that reindex_cache_read/write might not need the extended checks. In that case, as you suggested, constraining the schedule primitive to cases where there is only one consumer block makes sense to me.

My only concern was that applying a schedule primitive that leads to incorrect output would confuse most users. As long as we avoid that, in this case by limiting the applicability of reindex_cache_read and throwing an error for illegal uses, we should be good.

Besides, most other primitives also have similar restrictions on the cases in which they work and the cases in which they are not expected to work.

BTW, thanks again for adding this, looks really cool and useful.

yzh119 (Member, Author) commented Mar 2, 2023

Thank you @quic-sanirudh. Currently we have these checks:

  1. whether the access region is a single point;
  2. whether the set of block itervars appearing on the lhs and rhs of the cache stage matches.

There are some other checks that I could add:

  1. whether we can guarantee that the indices used to access the cache buffer are >= 0;
  2. whether the user-provided index map is bijective (I notice that [TIR] Allow TransformLayout with non-inversible index map #14095 is relaxing this check, so I doubt whether it is necessary).

vinx13 (Member) left a review comment:


overall LGTM

quic-sanirudh (Contributor) commented:

> whether the access region is a single point

IIUC, this check ensures that there's a single consumer block, right? Thanks for the change.

yzh119 (Member, Author) commented Mar 3, 2023

Hi @quic-sanirudh, "single point" means that the block accesses a single-point region like A[vi, vj] rather than a rectangular region like A[vi : vi + 10, vj : vj + 10]; we reject such cases here: https://github.com/apache/tvm/pull/14161/files#diff-c7feaaf793277a9477f822dcdc1b0b44c8b3ad7254800240996dcf8a0553b663R1449-R1452
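For illustration only (this example is not from the PR; the function name and shapes are made up), here is a block whose read region of A is a two-element slice rather than a single point, which this check is expected to reject with a ScheduleError:

from tvm.script import tir as T

@T.prim_func
def pairwise_sum(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 256))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            # The read region of A in this block is A[vi, 2 * vj : 2 * vj + 2],
            # a rectangular region, so reindex_cache_read is expected to fail here.
            B[vi, vj] = A[vi, 2 * vj] + A[vi, 2 * vj + 1]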

Regarding the single-consumer issue, we have removed the consumer_blocks argument from the reindex_cache_read/reindex_cache_write API, and in the case where there are multiple possible consumer blocks, we only consider the input block as a consumer block: https://github.com/apache/tvm/pull/14161/files#diff-c7feaaf793277a9477f822dcdc1b0b44c8b3ad7254800240996dcf8a0553b663R1431-R1434

quic-sanirudh (Contributor) commented:

> Regarding the single-consumer issue, we have removed the consumer_blocks argument from the reindex_cache_read/reindex_cache_write API, and in the case where there are multiple possible consumer blocks, we only consider the input block as a consumer block: https://github.com/apache/tvm/pull/14161/files#diff-c7feaaf793277a9477f822dcdc1b0b44c8b3ad7254800240996dcf8a0553b663R1431-R1434

Thanks, that sounds good to me. So now when we apply reindex_cache_read, the new buffer will only be used in that block? That definitely solves our problem, so thanks a lot for the change.

yzh119 (Member, Author) commented Mar 4, 2023

@quic-sanirudh yes that's exactly right.

Hzfengsy (Member) commented Mar 6, 2023

cc @rainy-memory @andy-yang-1 if you are interested.

Hzfengsy merged commit 9d732d0 into apache:main on Mar 8, 2023.