[TensorIR][Primitive] New schedule primitive reindex_cache_read/write #14161
Conversation
---

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot
---

Some additional context (credit to discussion with @andy-yang-1): in some cases the new primitives can express schedules that `cache_read`/`cache_write` combined with `transform_layout` cannot. For example, suppose we want to cache an access in a blocked layout. With `reindex_cache_write` and the index map `lambda vi, vj: (vi // 4, vj // 4, vi % 4, vj % 4)` we can get:

```python
@T.prim_func
def main(A: T.Buffer((129, 129), "float32"), B: T.Buffer((128, 128), "float32")):
    # with T.block("root"):
    B_shared = T.alloc_buffer((32, 32, 4, 4), scope="shared")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi + 1, vj + 1])
            T.writes(B_shared[vi // 4, vj // 4, vi % 4, vj % 4])
            B_shared[vi // 4, vj // 4, vi % 4, vj % 4] = A[vi + 1, vj + 1] * T.float32(2)
    for i, j in T.grid(128, 128):
        with T.block("B_shared"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B_shared[vi // 4, vj // 4, vi % 4, vj % 4])
            T.writes(B[vi, vj])
            B[vi, vj] = B_shared[vi // 4, vj // 4, vi % 4, vj % 4]
```

while applying the equivalent `cache_read` + `transform_layout` sequence fails:

```
>>> sch.cache_read("B", 0, "shared")
>>> sch.transform_layout("B", ("read", 0), lambda i, j: (i // 4, j // 4, i % 4, j % 4))
ScheduleError: An error occurred in the schedule primitive 'transform_layout'.
The IR with diagnostic is:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(a: T.handle, c: T.handle):
        A = T.match_buffer(a, (129, 129))
        B = T.match_buffer(c, (128, 128))
        with T.block("root"):
            T.reads()
            T.writes()
            A_shared = T.alloc_buffer((129, 129), scope="shared")
            for ax0 in range(129):
                for ax1 in range(129):
                    with T.block("A_shared"):
                        v0 = T.axis.spatial(129, ax0)
                        v1 = T.axis.spatial(129, ax1)
                        T.reads(A[v0, v1])
                        T.writes(A_shared[v0, v1])
                        A_shared[v0, v1] = A[v0, v1]
            for i in range(128):
                for j in range(128):
                    with T.block("B"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j)
                        T.reads(A_shared[vi + 1, vj + 1])
                        T.writes(B[vi, vj])
                        B[vi, vj] = A_shared[vi + 1, vj + 1] * T.float32(2)

Error message: The transformation T.index_map(lambda i, j: (i // 4, j // 4, i % 4, j % 4)) applied on buffer A_shared of shape [129, 129] would result in shape [33, 33, 4, 4]. However, this would introduce padding wherever axis0 == 32 and 1 <= axis2 or axis1 == 32 and 1 <= axis3 is true.
```

The reason is that by default `transform_layout` transforms the entire buffer, and the buffer shape (129, 129) is not evenly covered by the index map, so padding would have to be introduced. `reindex_cache_read`/`reindex_cache_write` instead size the cache buffer from the block's iteration domain, which sidesteps the problem.

Another frequent use case is sparse access. Suppose we want to cache read the following function:
```python
@T.prim_func
def func(a: T.handle, b: T.handle, f: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    F = T.match_buffer(f, (128,), "int32")
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[F[vi], vj] * 2.0
```

while `reindex_cache_write` with the same index map gives:
```python
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), F: T.Buffer((128,), "int32")):
    # with T.block("root"):
    B_shared = T.alloc_buffer((32, 32, 4, 4), scope="shared")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[F[vi], vj], F[vi])
            T.writes(B_shared[vi // 4, vj // 4, vi % 4, vj % 4])
            B_shared[vi // 4, vj // 4, vi % 4, vj % 4] = A[F[vi], vj] * T.float32(2)
    for i, j in T.grid(128, 128):
        with T.block("B_shared"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B_shared[vi // 4, vj // 4, vi % 4, vj % 4])
            T.writes(B[vi, vj])
            B[vi, vj] = B_shared[vi // 4, vj // 4, vi % 4, vj % 4]
```
whereas `cache_read` + `transform_layout` transform the whole buffer, so the indirect index `F[vi]` ends up inside the transformed cache access:

```
>>> sch.cache_read("B", 0, "shared")
>>> sch.transform_layout("B", ("read", 0), lambda i, j: (i // 4, j // 4, i % 4, j % 4))
>>> print(sch.mod["main"].script())
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), F: T.Buffer((128,), "int32")):
    # with T.block("root"):
    A_shared = T.alloc_buffer((32, 32, 4, 4), scope="shared")
    for ax0, ax1 in T.grid(128, 128):
        with T.block("A_shared"):
            v0, v1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A[v0, v1])
            T.writes(A_shared[v0 // 4, v1 // 4, v0 % 4, v1 % 4])
            A_shared[v0 // 4, v1 // 4, v0 % 4, v1 % 4] = A[v0, v1]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A_shared[F[vi] // 4, vj // 4, F[vi] % 4, vj % 4], F[vi])
            T.writes(B[vi, vj])
            B[vi, vj] = A_shared[F[vi] // 4, vj // 4, F[vi] % 4, vj % 4] * T.float32(2)
```
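The invocations that produce the transformed programs above are presumably along these lines (a sketch assuming the API proposed in this PR; `func` is the prim_func being scheduled):

```python
import tvm

sch = tvm.tir.Schedule(func)
block = sch.get_block("B")
# Write-cache block "B"'s output (write buffer index 0) into shared memory,
# laid out in 4x4 blocks as chosen by the index map.
sch.reindex_cache_write(block, 0, "shared", lambda vi, vj: (vi // 4, vj // 4, vi % 4, vj % 4))
print(sch.mod["main"].script())
```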

---
This seems like a useful schedule primitive, thanks for this. I have a doubt: in the first example above, which uses a (129, 129) input buffer, I understand that in this case the remaining elements are not needed, but is the primitive automatically verifying that those elements are not accessed before copying a smaller chunk to shared memory?

---

Hi @quic-sanirudh, this schedule allocates a buffer with the same volume as the Cartesian product of the block itervar domains, regardless of the original buffer size, and it does not perform that check, because those data are not used in the current block.
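To illustrate the sizing rule, a minimal sketch (assuming `tvm.tir.IndexMap` from mainline TVM):

```python
from tvm.tir import IndexMap

# The index map used in the examples above.
index_map = IndexMap.from_func(lambda vi, vj: (vi // 4, vj // 4, vi % 4, vj % 4))

# The cache buffer is sized from the block itervar domains (128 x 128 here),
# not from the original (129, 129) buffer shape:
print(index_map.map_shape([128, 128]))  # -> [32, 32, 4, 4]
```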

---

Hzfengsy left a comment:
Some nits

---

Thanks for the reply @yzh119. If it does not perform correctness checks and generates the cache loop based on the surrounding itervar domains, then couldn't that lead to potentially incorrect code? For example, take the following function:

```python
@T.prim_func
def func(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (129, 129))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (129, 129))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi + 1, vj + 1] * 2.0
    for i, j in T.grid(129, 129):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = A[vi, vj] * 3.0
```

When we apply `reindex_cache_read` on block "B" (with index map `lambda vi, vj: (vj, vi)`, inferred from the output below), we get:

```python
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((129, 129), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((129, 129), "float32")):
        # with T.block("root"):
        A_shared = T.alloc_buffer((128, 128), scope="shared")
        for i, j in T.grid(128, 128):
            with T.block("A_shared"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vi + 1, vj + 1])
                T.writes(A_shared[vj, vi])
                A_shared[vj, vi] = A[vi + 1, vj + 1]
        for i, j in T.grid(128, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A_shared[vj, vi])
                T.writes(B[vi, vj])
                B[vi, vj] = A_shared[vj, vi] * T.float32(2)
        vj_1 = T.int32()
        vi_1 = T.int32()
        for i, j in T.grid(129, 129):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A_shared[vj_1, vi_1])
                T.writes(C[vi, vj])
                C[vi, vj] = A_shared[vj_1, vi_1] * T.float32(3)
```

Note that the iteration domain of block "C" is (129, 129), larger than the region cached from block "B", and block "C" now reads `A_shared[vj_1, vi_1]` through the undefined variables `vi_1` and `vj_1`, so the transformed program is incorrect.
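For concreteness, a sketch of the sequence that would produce the module above (the exact call was not shown; the index map is inferred from the `A_shared[vj, vi]` accesses):

```python
import tvm

sch = tvm.tir.Schedule(func)  # func: the two-consumer prim_func above
block_b = sch.get_block("B")
# Cache the read of A (read buffer index 0) with a transposed cache layout.
sch.reindex_cache_read(block_b, 0, "shared", lambda vi, vj: (vj, vi))
```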

---

Hi @quic-sanirudh, yes you are correct: such cases occur when there are multiple consumer blocks, and in that situation we should suggest users use `cache_read` instead. Considering that most use cases of `reindex_cache_read`/`reindex_cache_write` involve a single consumer block, we can restrict the primitive to that setting.

Do you think this is acceptable?

---

Hi @yzh119, thanks for the detailed explanation. Yes, I understand. My only concern was that applying a schedule primitive that leads to incorrect output would confuse most users. As long as we avoid that, in this case by limiting the applicability of the primitive, I think this is a good addition. Besides, most other primitives also have such restrictions on the cases in which they work and the cases where they're not expected to work.

BTW, thanks again for adding this, looks really cool and useful.

---

Thank you @quic-sanirudh. Currently we have these checks:
There are some other checks that I can add:

---

vinx13 left a comment:
overall LGTM

---

IIUC, this check ensures that there's a single consumer block, right? Thanks for the change.

---

Hi @quic-sanirudh, "single point" means the block accesses a single-point region such as `A[vi, vj]`, as opposed to a rectangular region. Regarding the single consumer issue, we have removed the check mentioned above: the primitive now only rewrites the accesses of the scheduled block, so other consumer blocks keep reading the original buffer.
Thanks, that sounds good to me. So now when we apply the primitive, only the scheduled block's accesses are redirected to the cache buffer, and other consumers are left untouched?

---

@quic-sanirudh yes, that's exactly right.

---

cc @rainy-memory @andy-yang-1 if you are interested.

---
Motivation
Currently, we have the schedule primitives `cache_read`/`cache_write`, which allocate cache buffers and create cache stages that copy data between the buffer being accessed and the cache buffer. However, `cache_read`/`cache_write` do not support customized indices. For the following block:
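A sketch of the block in question, matching the example used throughout the discussion above:

```python
from tvm.script import tir as T

@T.prim_func
def func(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (129, 129))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            # Shifted access: reads A[1:129, 1:129].
            B[vi, vj] = A[vi + 1, vj + 1] * 2.0
```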
cache_read("B", 0, "share"), we get:where we access
where we access `A_shared` using the same indices `(vi + 1, vj + 1)` as the original block. This is not flexible, especially when we want to perform a layout transformation while copying data from the original buffer to the cache buffer (as in MMA tensorization and in FlashAttention). This PR proposes a new interface that enables us to customize the indices used to access the cache buffer, which is expressive enough to describe transposing and blocking.
Proposed API
Below is the proposed interface of `reindex_cache_read` (`reindex_cache_write` has a similar interface):
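A sketch of the signature, inferred from the parameter description that follows (the concrete type annotations are assumptions):

```python
from typing import Callable, Union
from tvm.tir import IndexMap
from tvm.tir.schedule import BlockRV

def reindex_cache_read(
    self,
    block: Union[BlockRV, str],
    read_buffer_index: int,
    storage_scope: str,
    index_map: Union[Callable, IndexMap],
) -> BlockRV:
    """Create a cache-read stage whose cache buffer is accessed via index_map."""
    ...
```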
where `block`, `read_buffer_index`, and `storage_scope` have the same meaning as in `cache_read`. The additional argument `index_map` specifies which indices to use to access the cache buffer, in the form of an index map from the current block itervars to the target indices. Suppose the block has itervars `vi, vj` and the user wants to access the cache buffer with the customized indices `[vi // 16, vj // 16, vi % 16, vj % 16]`; the user should then set `index_map` to `lambda vi, vj: (vi // 16, vj // 16, vi % 16, vj % 16)`.

Example
By applying `reindex_cache_read("B", 0, lambda i, j: (j, i))` to `func`, we get:
Notes

- Unlike `cache_read`/`cache_write`, which allow caching a rectangular region, `reindex_cache_read` only allows a single-point access region, but this is enough to cover most use cases.
- The cache stage block follows the original order of loops and block itervars in the block. If a block itervar does not appear in the buffer access region, it and its corresponding loop variables will be omitted. Users can then use the `transform_block_layout` primitive to reorder the block itervars and surrounding loops of the cache read/write block.

Relations to Existing Schedule Primitives

- `reindex`: `reindex` only supports the special case of `reindex_cache_read`/`reindex_cache_write` where `index_map` is the identity map, and `reindex` does not have a `storage_scope` field.
- `transform_layout`: `transform_layout` is designed to transform the layout of intermediate buffers rather than input buffers, and it does not have a `storage_scope` field.
- `cache_read`/`cache_write`: `cache_read`/`cache_write` do not support customized indices.
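As a rough illustration of these relations, a minimal sketch (assuming `sch` is a `tvm.tir.Schedule` whose block "B" has itervars `vi, vj`; the exact call forms are assumptions based on this PR's description):

```python
# reindex: the identity-map special case, with no storage scope.
sch.reindex("B", ("read", 0))

# reindex_cache_read: additionally takes a storage scope and a custom index map.
sch.reindex_cache_read("B", 0, "shared", lambda vi, vj: (vi, vj))
```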