-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[TensorIR][Transform] Enable warp shuffling for LowerWarpMemory
#14280
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e70d8eb
b3b98fa
409f7c5
9b92252
ec53460
4cc6681
4a5396a
fdabd56
5c4228a
c75c233
107b21e
8ed3953
7d8eff7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ | |
| import tvm | ||
| import tvm.testing | ||
| from tvm import te | ||
| from tvm.script import tir as T | ||
| from tvm.contrib.nvcc import have_fp16 | ||
|
|
||
|
|
||
|
|
@@ -347,5 +348,105 @@ def test_lower_warp_memory_divide_by_factor(): | |
| tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"] | ||
|
|
||
|
|
||
| @T.prim_func | ||
| def func(a: T.handle, b: T.handle) -> None: | ||
| A = T.match_buffer(a, [32], "float32") | ||
| B = T.match_buffer(b, [32], "float32") | ||
| for i in range(32): | ||
| with T.block("warp_shuffle"): | ||
| vi = T.axis.spatial(32, i) | ||
| B[vi] = A[(vi % 4) * 8 + vi // 4] + T.float32(1) | ||
|
|
||
|
|
||
| def test_warp_shuffle_transform(): | ||
| @tvm.script.ir_module | ||
| class Before: | ||
| @T.prim_func | ||
| def main(A: T.handle("float32", "global"), B: T.handle("float32", "global")): | ||
| blockIdx_x = T.env_thread("blockIdx.x") | ||
| threadIdx_x = T.env_thread("threadIdx.x") | ||
| T.func_attr( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like the test case only requires the |
||
| { | ||
| "calling_conv": 2, | ||
| "global_symbol": "main", | ||
| "target": T.target( | ||
| { | ||
| "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, | ||
| "keys": ["cuda", "gpu"], | ||
| "kind": "cuda", | ||
| "max_num_threads": 1024, | ||
| "tag": "", | ||
| "thread_warp_size": 32, | ||
| } | ||
| ), | ||
| "tir.device_thread_axis": [ | ||
| T.iter_var(blockIdx_x, [0, 1], "ThreadIndex", "blockIdx.x"), | ||
| T.iter_var(threadIdx_x, [0, 32], "ThreadIndex", "threadIdx.x"), | ||
| ], | ||
| "tir.is_global_func": 1, | ||
| "tir.noalias": 1, | ||
| } | ||
| ) | ||
| T.launch_thread(blockIdx_x, 1) | ||
| A_warp = T.allocate([32], "float32", "warp") | ||
| B_warp = T.allocate([32], "float32", "warp") | ||
| T.launch_thread(threadIdx_x, 32) | ||
| A_warp_1 = T.Buffer((32,), data=A_warp, scope="warp") | ||
| A_1 = T.Buffer((32,), data=A) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of having a separate |
||
| A_warp_1[threadIdx_x] = A_1[threadIdx_x] | ||
| B_warp_1 = T.Buffer((32,), data=B_warp, scope="warp") | ||
| T.tvm_storage_sync("warp") | ||
| B_warp_1[threadIdx_x] = A_warp_1[threadIdx_x % 4 * 8 + threadIdx_x // 4] + T.float32(1) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we add a comment here, indicating that this line is the one that should be updated correctly? |
||
| B_1 = T.Buffer((32,), data=B) | ||
| B_1[threadIdx_x] = B_warp_1[threadIdx_x] | ||
|
|
||
| @tvm.script.ir_module | ||
| class Expected: | ||
| @T.prim_func | ||
| def main(A: T.handle("float32", "global"), B: T.handle("float32", "global")): | ||
| blockIdx_x = T.env_thread("blockIdx.x") | ||
| threadIdx_x = T.env_thread("threadIdx.x") | ||
| T.func_attr( | ||
| { | ||
| "calling_conv": 2, | ||
| "global_symbol": "main", | ||
| "target": T.target( | ||
| { | ||
| "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, | ||
| "keys": ["cuda", "gpu"], | ||
| "kind": "cuda", | ||
| "max_num_threads": 1024, | ||
| "tag": "", | ||
| "thread_warp_size": 32, | ||
| } | ||
| ), | ||
| "tir.device_thread_axis": [ | ||
| T.iter_var(blockIdx_x, [0, 1], "ThreadIndex", "blockIdx.x"), | ||
| T.iter_var(threadIdx_x, [0, 32], "ThreadIndex", "threadIdx.x"), | ||
| ], | ||
| "tir.is_global_func": 1, | ||
| "tir.noalias": 1, | ||
| } | ||
| ) | ||
| T.launch_thread(blockIdx_x, 1) | ||
| A_warp = T.allocate([1], "float32", "local") | ||
| B_warp = T.allocate([1], "float32", "local") | ||
| T.launch_thread(threadIdx_x, 32) | ||
| A_warp_1 = T.Buffer((32,), data=A_warp, scope="local") | ||
| A_1 = T.Buffer((32,), data=A) | ||
| A_warp_1[0] = A_1[threadIdx_x] | ||
| B_warp_1 = T.Buffer((32,), data=B_warp, scope="local") | ||
| T.tvm_storage_sync("warp") | ||
| B_warp_1[0] = T.tvm_warp_shuffle( | ||
| T.tvm_warp_activemask(), A_warp_1[0], threadIdx_x % 4 * 8 + threadIdx_x // 4, 32, 32 | ||
| ) + T.float32(1) | ||
| B_1 = T.Buffer((32,), data=B) | ||
| B_1[threadIdx_x] = B_warp_1[0] | ||
|
|
||
| after = tvm.tir.transform.LowerWarpMemory()(Before) | ||
|
|
||
| tvm.ir.assert_structural_equal(after, Expected) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| tvm.testing.main() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test looks reasonable as-is, though there's also a
tvm.testing.CompareBeforeAfterthat you could use to further reduce the boilerplate.