
Commit 30bf013

cblmemo and Siyuan Feng authored

[TIR][Schedule] Add unittest for read_write_at (#14395)

This PR adds unit tests for the schedule primitives read_at and write_at.

Co-authored-by: Siyuan Feng <[email protected]>

1 parent 6e70e79 · commit 30bf013
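
Both primitives share one calling convention, which the tests in this diff exercise: given a loop, a block, the index of one of the block's read/write regions, and a storage scope, they stage that region through a buffer of the given scope at the given loop. A minimal sketch inferred from the test code below (the loop unpacking mirrors the tests; this sketch is illustrative and not part of the commit):

    sch = tir.Schedule(cuda_matmul)
    block = sch.get_block("C")
    _by, _bx, _vy, _vx, _ty, tx, k0, *_rest = sch.get_loops(block)
    sch.read_at(k0, block, 1, "shared")   # stage read region 1 (buffer A) into shared memory at loop k0
    sch.write_at(tx, block, 0, "shared")  # stage write region 0 (buffer C) through shared memory at loop tx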

1 file changed: 221 additions, 0 deletions

@@ -0,0 +1,221 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
import sys

import pytest

import tvm
import tvm.testing
from tvm import tir
from tvm.script import tir as T
from tvm.tir.schedule.testing import verify_trace_roundtrip


# fmt: off
# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable

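# Input workload: a 2048x2048 float32 matmul, tiled and bound to CUDA
# block/vthread/thread axes, with block "C" reading A and B directly
# from global memory.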
@T.prim_func
def cuda_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:  # pylint: disable=undefined-loop-variable
    A = T.match_buffer(a, [2048, 2048], "float32")
    B = T.match_buffer(b, [2048, 2048], "float32")
    C = T.match_buffer(c, [2048, 2048], "float32")
    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread="vthread.y"):
                for vx in T.thread_binding(0, 2, thread="vthread.x"):
                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
                            for k0 in T.serial(0, 256):
                                for k1 in T.unroll(0, 8):
                                    for _, i, j in T.grid(1, 4, 4):
                                        with T.block("C"):
                                            vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
                                            vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
                                            vk = T.axis.R(2048, k0 * 8 + k1)
                                            T.reads([C[vi, vj], A[vi, vk], B[vk, vj]])
                                            T.writes([C[vi, vj]])
                                            with T.init():
                                                C[vi, vj] = 0.0
                                            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

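# Expected IR after sch.read_at(k0, block, 1, "shared") on the workload
# above: reads of A are staged through an "A_shared" copy block (tagged
# auto_copy) inserted at loop k0.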
@T.prim_func
def cuda_matmul_read_at_a(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [2048, 2048], dtype="float32")
    B = T.match_buffer(b, [2048, 2048], dtype="float32")
    C = T.match_buffer(c, [2048, 2048], dtype="float32")
    A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread="vthread.y"):
                for vx in T.thread_binding(0, 2, thread="vthread.x"):
                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
                            for k0 in T.serial(0, 256):
                                with T.block("A_shared"):
                                    v0 = T.axis.S(32, by)
                                    v1 = T.axis.S(256, k0)
                                    T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
                                    T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
                                    T.block_attr({"auto_copy": 1})
                                    for ax0, ax1 in T.grid(64, 8):
                                        A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1]
                                for k1 in T.unroll(0, 8):
                                    for v_, i, j in T.grid(1, 4, 4):
                                        with T.block("C"):
                                            vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
                                            vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
                                            vk = T.axis.R(2048, k0 * 8 + k1)
                                            T.reads([C[vi, vj], A_shared[vi, vk], B[vk, vj]])
                                            T.writes([C[vi, vj]])
                                            with T.init():
                                                C[vi, vj] = T.float32(0)
                                            C[vi, vj] = C[vi, vj] + A_shared[vi, vk] * B[vk, vj]

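# Expected IR after additionally applying sch.read_at(k0, block, 2,
# "shared"): B is likewise staged through a "B_shared" copy block at k0.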
@T.prim_func
def cuda_matmul_read_at_ab(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [2048, 2048], dtype="float32")
    B = T.match_buffer(b, [2048, 2048], dtype="float32")
    C = T.match_buffer(c, [2048, 2048], dtype="float32")
    A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
    B_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread="vthread.y"):
                for vx in T.thread_binding(0, 2, thread="vthread.x"):
                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
                            for k0 in T.serial(0, 256):
                                with T.block("A_shared"):
                                    v0 = T.axis.S(32, by)
                                    v1 = T.axis.S(256, k0)
                                    T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
                                    T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
                                    T.block_attr({"auto_copy": 1})
                                    for ax0, ax1 in T.grid(64, 8):
                                        A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1]
                                with T.block("B_shared"):
                                    v0 = T.axis.S(256, k0)
                                    v1 = T.axis.S(32, bx)
                                    T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]])
                                    T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]])
                                    T.block_attr({"auto_copy": 1})
                                    for ax0, ax1 in T.grid(8, 64):
                                        B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1]
                                for k1 in T.unroll(0, 8):
                                    for v_, i, j in T.grid(1, 4, 4):
                                        with T.block("C"):
                                            vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
                                            vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
                                            vk = T.axis.R(2048, k0 * 8 + k1)
                                            T.reads([C[vi, vj], A_shared[vi, vk], B_shared[vk, vj]])
                                            T.writes([C[vi, vj]])
                                            with T.init():
                                                C[vi, vj] = T.float32(0)
                                            C[vi, vj] = C[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj]

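# Expected IR after sch.write_at(tx, block, 0, "shared"): the computation
# writes into "C_shared", and a copy block placed after loop k0 (inside tx)
# writes each 64x64 tile back to global C.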
@T.prim_func
def cuda_matmul_write_at_c(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [2048, 2048], dtype="float32")
    B = T.match_buffer(b, [2048, 2048], dtype="float32")
    C = T.match_buffer(c, [2048, 2048], dtype="float32")
    A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
    B_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
    C_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread="vthread.y"):
                for vx in T.thread_binding(0, 2, thread="vthread.x"):
                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
                            for k0 in T.serial(0, 256):
                                with T.block("A_shared"):
                                    v0 = T.axis.S(32, by)
                                    v1 = T.axis.S(256, k0)
                                    T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
                                    T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
                                    T.block_attr({"auto_copy": 1})
                                    for ax0, ax1 in T.grid(64, 8):
                                        A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1]
                                with T.block("B_shared"):
                                    v0 = T.axis.S(256, k0)
                                    v1 = T.axis.S(32, bx)
                                    T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]])
                                    T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]])
                                    T.block_attr({"auto_copy": 1})
                                    for ax0, ax1 in T.grid(8, 64):
                                        B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1]
                                for k1 in T.unroll(0, 8):
                                    for v_, i, j in T.grid(1, 4, 4):
                                        with T.block("C"):
                                            vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
                                            vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
                                            vk = T.axis.R(2048, k0 * 8 + k1)
                                            T.reads([C_shared[vi, vj], A_shared[vi, vk], B_shared[vk, vj]])
                                            T.writes([C_shared[vi, vj]])
                                            with T.init():
                                                C_shared[vi, vj] = T.float32(0)
                                            C_shared[vi, vj] = C_shared[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj]
                            with T.block("C_shared"):
                                v0 = T.axis.S(32, by)
                                v1 = T.axis.S(32, bx)
                                T.reads([C_shared[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]])
                                T.writes([C[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]])
                                T.block_attr({"auto_copy": 1})
                                for ax0, ax1 in T.grid(64, 64):
                                    C[v0 * 64 + ax0, v1 * 64 + ax1] = C_shared[v0 * 64 + ax0, v1 * 64 + ax1]

# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable
# fmt: on

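# Each test below applies a single primitive to the schedule, checks the
# rewritten module against the corresponding expected IR above, and verifies
# that the recorded schedule trace round-trips.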
def test_read_at_global_to_shared_a():
    sch = tir.Schedule(cuda_matmul, debug_mask="all")
    block = sch.get_block("C")
    # pylint: disable=invalid-name
    _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = sch.get_loops(block)
    # pylint: enable=invalid-name
    sch.read_at(k0, block, 1, "shared")
    tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_a)
    verify_trace_roundtrip(sch, cuda_matmul)


def test_read_at_global_to_shared_ab():
    sch = tir.Schedule(cuda_matmul_read_at_a, debug_mask="all")
    block = sch.get_block("C")
    # pylint: disable=invalid-name
    _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = sch.get_loops(block)
    # pylint: enable=invalid-name
    sch.read_at(k0, block, 2, "shared")
    tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_ab)
    verify_trace_roundtrip(sch, cuda_matmul_read_at_a)

def test_write_at_local_to_shared_c():
    sch = tir.Schedule(cuda_matmul_read_at_ab, debug_mask="all")
    block = sch.get_block("C")
    # pylint: disable=invalid-name
    _by, _bx, _vy, _vx, _ty, tx, _k0, _k1, _, _i, _j = sch.get_loops(block)
    # pylint: enable=invalid-name
    sch.write_at(tx, block, 0, "shared")
    tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_write_at_c)
    verify_trace_roundtrip(sch, cuda_matmul_read_at_ab)


if __name__ == "__main__":
    tvm.testing.main()
