-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[TensorIR][Transform] Enable warp shuffling for LowerWarpMemory
#14280
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
3aea3ff to
7d8eff7
Compare
|
It seems the unit test still works if I add the |
|
Its fine to keep it open as a draft PR :-) |
Lunderberg
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Made a couple of comments as I was reading through, though your comment about the tests passing even with the ICHECK present is interesting. It looks like local_index in the case of warp shuffle is 0, and is used to build the A_warp[0] argument to tvm_warp_shuffle.
| B[vi] = A[(vi % 4) * 8 + vi // 4] + T.float32(1) | ||
|
|
||
|
|
||
| def test_warp_shuffle_transform(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test looks reasonable as-is, though there's also a tvm.testing.CompareBeforeAfter that you could use to further reduce the boilerplate.
class TestWarpShuffleTransform(tvm.testing.CompareBeforeAfter):
transform = tvm.tir.transform.LowerWarpMemory()
def before(A: T.handle("float32", "global"), B: T.handle("float32", "global")):
...
def expected(A: T.handle("float32", "global"), B: T.handle("float32", "global")):
...| def main(A: T.handle("float32", "global"), B: T.handle("float32", "global")): | ||
| blockIdx_x = T.env_thread("blockIdx.x") | ||
| threadIdx_x = T.env_thread("threadIdx.x") | ||
| T.func_attr( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like the test case only requires the "target" attribute, and only requires "kind" and "thread_warp_size" within that. Can we remove the extra attributes from the unit test?
| B_warp = T.allocate([32], "float32", "warp") | ||
| T.launch_thread(threadIdx_x, 32) | ||
| A_warp_1 = T.Buffer((32,), data=A_warp, scope="warp") | ||
| A_1 = T.Buffer((32,), data=A) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of having a separate A: T.handle and A_1: T.Buffer, the buffer could be declared as a parameter A_1: T.Buffer(32). It does result in slightly different TIR, as it follows the style from before MakePackedAPI is applied, but for a unit test would help to emphasize the change being tested.
| A_warp_1[threadIdx_x] = A_1[threadIdx_x] | ||
| B_warp_1 = T.Buffer((32,), data=B_warp, scope="warp") | ||
| T.tvm_storage_sync("warp") | ||
| B_warp_1[threadIdx_x] = A_warp_1[threadIdx_x % 4 * 8 + threadIdx_x // 4] + T.float32(1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we add a comment here, indicating that this line is the one that should be updated correctly?
|
(Also, it looks like the initial check dates back to PR#1050, just with different refactorings that touched that line along the way.) |
Motivation
The
LowerWarpMemorypass cannot emitshfl_syncinstructions because of an internal check introduced in #9727 . Actually if we load value from another lane in the warp, thelocal_indexwould inevitably carry the warp index, and this case would be disabled by the check.This PR fix the issue by disabling the check and add an unit test for warp shuffling.
The PR depends on #14279 , I'll rebase to upstream/main after that PR is merged.
@Lunderberg @masahi @tqchen