|
| 1 | +# or more contributor license agreements. See the NOTICE file |
| 2 | +# distributed with this work for additional information |
| 3 | +# regarding copyright ownership. The ASF licenses this file |
| 4 | +# to you under the Apache License, Version 2.0 (the |
| 5 | +# "License"); you may not use this file except in compliance |
| 6 | +# with the License. You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, |
| 11 | +# software distributed under the License is distributed on an |
| 12 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 13 | +# KIND, either express or implied. See the License for the |
| 14 | +# specific language governing permissions and limitations |
| 15 | +# under the License. |
| 16 | +# pylint: disable=missing-function-docstring,missing-module-docstring |
| 17 | +import sys |
| 18 | + |
| 19 | +import pytest |
| 20 | + |
| 21 | +import tvm |
| 22 | +from tvm import tir |
| 23 | +from tvm.script import tir as T |
| 24 | +from tvm.tir.schedule.testing import verify_trace_roundtrip |
| 25 | + |
| 26 | + |
| 27 | +# fmt: off |
| 28 | +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable |
| 29 | + |
| 30 | +@T.prim_func |
| 31 | +def cuda_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=undefined-loop-variable |
| 32 | + A = T.match_buffer(a, [2048, 2048], "float32") |
| 33 | + B = T.match_buffer(b, [2048, 2048], "float32") |
| 34 | + C = T.match_buffer(c, [2048, 2048], "float32") |
| 35 | + for by in T.thread_binding(0, 32, thread = "blockIdx.y"): |
| 36 | + for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): |
| 37 | + for vy in T.thread_binding(0, 2, thread = "vthread.y"): |
| 38 | + for vx in T.thread_binding(0, 2, thread = "vthread.x"): |
| 39 | + for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): |
| 40 | + for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): |
| 41 | + for k0 in T.serial(0, 256): |
| 42 | + for k1 in T.unroll(0, 8): |
| 43 | + for _, i, j in T.grid(1, 4, 4): |
| 44 | + with T.block("C"): |
| 45 | + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) |
| 46 | + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) |
| 47 | + vk = T.axis.R(2048, k0 * 8 + k1) |
| 48 | + T.reads([C[vi, vj], A[vi, vk], B[vk, vj]]) |
| 49 | + T.writes([C[vi, vj]]) |
| 50 | + with T.init(): |
| 51 | + C[vi, vj] = 0.0 |
| 52 | + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] |
| 53 | + |
| 54 | + |
| 55 | +@T.prim_func |
| 56 | +def cuda_matmul_read_at_a(a: T.handle, b: T.handle, c: T.handle) -> None: |
| 57 | + A = T.match_buffer(a, [2048, 2048], dtype="float32") |
| 58 | + B = T.match_buffer(b, [2048, 2048], dtype="float32") |
| 59 | + C = T.match_buffer(c, [2048, 2048], dtype="float32") |
| 60 | + A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") |
| 61 | + for by in T.thread_binding(0, 32, thread="blockIdx.y"): |
| 62 | + for bx in T.thread_binding(0, 32, thread="blockIdx.x"): |
| 63 | + for vy in T.thread_binding(0, 2, thread="vthread.y"): |
| 64 | + for vx in T.thread_binding(0, 2, thread="vthread.x"): |
| 65 | + for ty in T.thread_binding(0, 8, thread="threadIdx.y"): |
| 66 | + for tx in T.thread_binding(0, 8, thread="threadIdx.x"): |
| 67 | + for k0 in T.serial(0, 256): |
| 68 | + with T.block("A_shared"): |
| 69 | + v0 = T.axis.S(32, by) |
| 70 | + v1 = T.axis.S(256, k0) |
| 71 | + T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) |
| 72 | + T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) |
| 73 | + T.block_attr({"auto_copy":1}) |
| 74 | + for ax0, ax1 in T.grid(64, 8): |
| 75 | + A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] |
| 76 | + for k1 in T.unroll(0, 8): |
| 77 | + for v_, i, j in T.grid(1, 4, 4): |
| 78 | + with T.block("C"): |
| 79 | + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) |
| 80 | + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) |
| 81 | + vk = T.axis.R(2048, k0 * 8 + k1) |
| 82 | + T.reads([C[vi, vj], A_shared[vi, vk], B[vk, vj]]) |
| 83 | + T.writes([C[vi, vj]]) |
| 84 | + with T.init(): |
| 85 | + C[vi, vj] = T.float32(0) |
| 86 | + C[vi, vj] = C[vi, vj] + A_shared[vi, vk] * B[vk, vj] |
| 87 | + |
| 88 | + |
| 89 | +@T.prim_func |
| 90 | +def cuda_matmul_read_at_ab(a: T.handle, b: T.handle, c: T.handle) -> None: |
| 91 | + A = T.match_buffer(a, [2048, 2048], dtype="float32") |
| 92 | + B = T.match_buffer(b, [2048, 2048], dtype="float32") |
| 93 | + C = T.match_buffer(c, [2048, 2048], dtype="float32") |
| 94 | + A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") |
| 95 | + B_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") |
| 96 | + for by in T.thread_binding(0, 32, thread="blockIdx.y"): |
| 97 | + for bx in T.thread_binding(0, 32, thread="blockIdx.x"): |
| 98 | + for vy in T.thread_binding(0, 2, thread="vthread.y"): |
| 99 | + for vx in T.thread_binding(0, 2, thread="vthread.x"): |
| 100 | + for ty in T.thread_binding(0, 8, thread="threadIdx.y"): |
| 101 | + for tx in T.thread_binding(0, 8, thread="threadIdx.x"): |
| 102 | + for k0 in T.serial(0, 256): |
| 103 | + with T.block("A_shared"): |
| 104 | + v0 = T.axis.S(32, by) |
| 105 | + v1 = T.axis.S(256, k0) |
| 106 | + T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) |
| 107 | + T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) |
| 108 | + T.block_attr({"auto_copy":1}) |
| 109 | + for ax0, ax1 in T.grid(64, 8): |
| 110 | + A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] |
| 111 | + with T.block("B_shared"): |
| 112 | + v0 = T.axis.S(256, k0) |
| 113 | + v1 = T.axis.S(32, bx) |
| 114 | + T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) |
| 115 | + T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) |
| 116 | + T.block_attr({"auto_copy":1}) |
| 117 | + for ax0, ax1 in T.grid(8, 64): |
| 118 | + B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1] |
| 119 | + for k1 in T.unroll(0, 8): |
| 120 | + for v_, i, j in T.grid(1, 4, 4): |
| 121 | + with T.block("C"): |
| 122 | + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) |
| 123 | + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) |
| 124 | + vk = T.axis.R(2048, k0 * 8 + k1) |
| 125 | + T.reads([C[vi, vj], A_shared[vi, vk], B_shared[vk, vj]]) |
| 126 | + T.writes([C[vi, vj]]) |
| 127 | + with T.init(): |
| 128 | + C[vi, vj] = T.float32(0) |
| 129 | + C[vi, vj] = C[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj] |
| 130 | + |
| 131 | +@T.prim_func |
| 132 | +def cuda_matmul_write_at_c(a: T.handle, b: T.handle, c: T.handle) -> None: |
| 133 | + A = T.match_buffer(a, [2048, 2048], dtype="float32") |
| 134 | + B = T.match_buffer(b, [2048, 2048], dtype="float32") |
| 135 | + C = T.match_buffer(c, [2048, 2048], dtype="float32") |
| 136 | + A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") |
| 137 | + B_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") |
| 138 | + C_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") |
| 139 | + for by in T.thread_binding(0, 32, thread="blockIdx.y"): |
| 140 | + for bx in T.thread_binding(0, 32, thread="blockIdx.x"): |
| 141 | + for vy in T.thread_binding(0, 2, thread="vthread.y"): |
| 142 | + for vx in T.thread_binding(0, 2, thread="vthread.x"): |
| 143 | + for ty in T.thread_binding(0, 8, thread="threadIdx.y"): |
| 144 | + for tx in T.thread_binding(0, 8, thread="threadIdx.x"): |
| 145 | + for k0 in T.serial(0, 256): |
| 146 | + with T.block("A_shared"): |
| 147 | + v0 = T.axis.S(32, by) |
| 148 | + v1 = T.axis.S(256, k0) |
| 149 | + T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) |
| 150 | + T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) |
| 151 | + T.block_attr({"auto_copy":1}) |
| 152 | + for ax0, ax1 in T.grid(64, 8): |
| 153 | + A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] |
| 154 | + with T.block("B_shared"): |
| 155 | + v0 = T.axis.S(256, k0) |
| 156 | + v1 = T.axis.S(32, bx) |
| 157 | + T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) |
| 158 | + T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) |
| 159 | + T.block_attr({"auto_copy":1}) |
| 160 | + for ax0, ax1 in T.grid(8, 64): |
| 161 | + B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1] |
| 162 | + for k1 in T.unroll(0, 8): |
| 163 | + for v_, i, j in T.grid(1, 4, 4): |
| 164 | + with T.block("C"): |
| 165 | + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) |
| 166 | + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) |
| 167 | + vk = T.axis.R(2048, k0 * 8 + k1) |
| 168 | + T.reads([C_shared[vi, vj], A_shared[vi, vk], B_shared[vk, vj]]) |
| 169 | + T.writes([C_shared[vi, vj]]) |
| 170 | + with T.init(): |
| 171 | + C_shared[vi, vj] = T.float32(0) |
| 172 | + C_shared[vi, vj] = C_shared[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj] |
| 173 | + with T.block("C_shared"): |
| 174 | + v0 = T.axis.S(32, by) |
| 175 | + v1 = T.axis.S(32, bx) |
| 176 | + T.reads([C_shared[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]]) |
| 177 | + T.writes([C[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]]) |
| 178 | + T.block_attr({"auto_copy":1}) |
| 179 | + for ax0, ax1 in T.grid(64, 64): |
| 180 | + C[v0 * 64 + ax0, v1 * 64 + ax1] = C_shared[v0 * 64 + ax0, v1 * 64 + ax1] |
| 181 | + |
| 182 | + |
| 183 | +# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable |
| 184 | +# fmt: on |
| 185 | + |
| 186 | + |
| 187 | +def test_read_at_global_to_shared_a(): |
| 188 | + sch = tir.Schedule(cuda_matmul, debug_mask="all") |
| 189 | + block = sch.get_block("C") |
| 190 | + # pylint: disable=invalid-name |
| 191 | + _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = sch.get_loops(block) |
| 192 | + # pylint: enable=invalid-name |
| 193 | + sch.read_at(k0, block, 1, "shared") |
| 194 | + tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_a) |
| 195 | + verify_trace_roundtrip(sch, cuda_matmul) |
| 196 | + |
| 197 | + |
| 198 | +def test_read_at_global_to_shared_ab(): |
| 199 | + sch = tir.Schedule(cuda_matmul_read_at_a, debug_mask="all") |
| 200 | + block = sch.get_block("C") |
| 201 | + # pylint: disable=invalid-name |
| 202 | + _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = sch.get_loops(block) |
| 203 | + # pylint: enable=invalid-name |
| 204 | + sch.read_at(k0, block, 2, "shared") |
| 205 | + tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_ab) |
| 206 | + verify_trace_roundtrip(sch, cuda_matmul_read_at_a) |
| 207 | + |
| 208 | + |
| 209 | +def test_read_at_local_to_shared_c(): |
| 210 | + sch = tir.Schedule(cuda_matmul_read_at_ab, debug_mask="all") |
| 211 | + block = sch.get_block("C") |
| 212 | + # pylint: disable=invalid-name |
| 213 | + _by, _bx, _vy, _vx, _ty, tx, _k0, _k1, _, _i, _j = sch.get_loops(block) |
| 214 | + # pylint: enable=invalid-name |
| 215 | + sch.write_at(tx, block, 0, "shared") |
| 216 | + tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_write_at_c) |
| 217 | + verify_trace_roundtrip(sch, cuda_matmul_read_at_ab) |
| 218 | + |
| 219 | + |
| 220 | +if __name__ == "__main__": |
| 221 | + tvm.testing.main() |
0 commit comments