Skip to content

Conversation

@yzh119
Copy link
Member

@yzh119 yzh119 commented Mar 12, 2023

Motivation

The LowerWarpMemory pass cannot emit shfl_sync instructions because of an internal check introduced in #9727 . Actually if we load value from another lane in the warp, the local_index would inevitably carry the warp index, and this case would be disabled by the check.

This PR fix the issue by disabling the check and add an unit test for warp shuffling.

The PR depends on #14279 , I'll rebase to upstream/main after that PR is merged.

@Lunderberg @masahi @tqchen

@tvm-bot
Copy link
Collaborator

tvm-bot commented Mar 12, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

  • No users to tag found in teams: tensorir, transform See #10317 for details

Generated by tvm-bot

@yzh119 yzh119 force-pushed the enable-warp-shuffle branch from 3aea3ff to 7d8eff7 Compare March 12, 2023 15:40
@yzh119 yzh119 closed this Mar 12, 2023
@yzh119
Copy link
Member Author

yzh119 commented Mar 12, 2023

It seems the unit test still works if I add the ICHECK back, I'll close the PR first.

@junrushao
Copy link
Member

Its fine to keep it open as a draft PR :-)

Copy link
Contributor

@Lunderberg Lunderberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made a couple of comments as I was reading through, though your comment about the tests passing even with the ICHECK present is interesting. It looks like local_index in the case of warp shuffle is 0, and is used to build the A_warp[0] argument to tvm_warp_shuffle.

B[vi] = A[(vi % 4) * 8 + vi // 4] + T.float32(1)


def test_warp_shuffle_transform():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test looks reasonable as-is, though there's also a tvm.testing.CompareBeforeAfter that you could use to further reduce the boilerplate.

class TestWarpShuffleTransform(tvm.testing.CompareBeforeAfter):
    transform = tvm.tir.transform.LowerWarpMemory()

    def before(A: T.handle("float32", "global"), B: T.handle("float32", "global")):
        ...

    def expected(A: T.handle("float32", "global"), B: T.handle("float32", "global")):
        ...

def main(A: T.handle("float32", "global"), B: T.handle("float32", "global")):
blockIdx_x = T.env_thread("blockIdx.x")
threadIdx_x = T.env_thread("threadIdx.x")
T.func_attr(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the test case only requires the "target" attribute, and only requires "kind" and "thread_warp_size" within that. Can we remove the extra attributes from the unit test?

B_warp = T.allocate([32], "float32", "warp")
T.launch_thread(threadIdx_x, 32)
A_warp_1 = T.Buffer((32,), data=A_warp, scope="warp")
A_1 = T.Buffer((32,), data=A)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of having a separate A: T.handle and A_1: T.Buffer, the buffer could be declared as a parameter A_1: T.Buffer(32). It does result in slightly different TIR, as it follows the style from before MakePackedAPI is applied, but for a unit test would help to emphasize the change being tested.

A_warp_1[threadIdx_x] = A_1[threadIdx_x]
B_warp_1 = T.Buffer((32,), data=B_warp, scope="warp")
T.tvm_storage_sync("warp")
B_warp_1[threadIdx_x] = A_warp_1[threadIdx_x % 4 * 8 + threadIdx_x // 4] + T.float32(1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add a comment here, indicating that this line is the one that should be updated correctly?

@Lunderberg
Copy link
Contributor

(Also, it looks like the initial check dates back to PR#1050, just with different refactorings that touched that line along the way.)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants