
Commit bf23fc5

mma intrin generation with meta programming
1 parent 5afb5f0 commit bf23fc5

File tree

8 files changed: +188 -489 lines changed


python/tvm/tir/tensor_intrin/cuda.py

Lines changed: 162 additions & 1 deletion
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name,missing-function-docstring
 """Intrinsics for tensorization on NVIDIA GPU."""
+from .. import Cast
 from ..._ffi import register_func
 from ...runtime import convert
 from .. import TensorIntrin
@@ -46,6 +47,7 @@ def index_map_shared_16x16_to_ldmatrix_32x8_layout(i, j):
 lift = convert

 M_DIM = 16
+N_DIM = 16
 WARP_SIZE = 32
 HALF_WARP = WARP_SIZE // 2
 HALF_WARP_expr = lift(HALF_WARP)
@@ -81,7 +83,6 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed):
         assert dtype == "int8"

         if ldmatrix_col_major:
-            print("foo")
             index_map = shared_32x16_to_ldmatrix_32x16_layout
             shared_offset = (
                 lambda _, stride: stride
@@ -172,6 +173,148 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
     return ldmatrix_desc, ldmatrix_impl


+def get_mma_intrin(k_dim, out_dtype, transposed):
+    local_size = (M_DIM * k_dim) // WARP_SIZE
+    local_size_out = (M_DIM * N_DIM) // 32
+
+    index_map_C = shared_16x16_to_ldmatrix_32x8_layout
+
+    if k_dim == 16:
+        index_map_A = shared_16x16_to_ldmatrix_32x8_layout
+        index_map_B = shared_16x16_to_ldmatrix_32x8_layout
+        mma_prefix = "m16n8k16"
+    elif k_dim == 32 and transposed:
+        index_map_A = index_map_B = shared_16x32_to_ldmatrix_32x16_layout
+        mma_prefix = "m16n8k32"
+    elif k_dim == 32 and not transposed:
+        index_map_A = shared_16x32_to_ldmatrix_32x16_layout
+        index_map_B = shared_32x16_to_ldmatrix_32x16_layout
+        mma_prefix = "m16n8k32"
+    else:
+        assert False
+
+    out_dtype_abbrv = {"float16": "fp16", "float32": "fp32", "int32": "int32"}[out_dtype]
+
+    if out_dtype in ["float16", "float32"]:
+        in_dtype = "float16"
+        in_dtype_abbrv = "fp16"
+    else:
+        in_dtype = "int8"
+        in_dtype_abbrv = "int8"
+
+    def maybe_cast(v):
+        if out_dtype in ["float32", "int32"]:
+            return Cast(out_dtype, v)
+        return v
+
+    def maybe_swap(i, j):
+        if transposed:
+            return j, i
+        return i, j
+
+    @T.prim_func
+    def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(
+            a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
+        )
+        B = T.match_buffer(
+            b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
+        )
+        C = T.match_buffer(
+            c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp"
+        )
+
+        with T.block("root"):
+            T.reads(
+                C[0:WARP_SIZE, 0:local_size_out],
+                A[0:WARP_SIZE, 0:local_size],
+                B[0:WARP_SIZE, 0:local_size],
+            )
+            T.writes(C[0:WARP_SIZE, 0:local_size_out])
+
+            for i, j, k in T.grid(M_DIM, N_DIM, k_dim):
+                with T.block("C"):
+                    i, j, k = T.axis.remap("SSR", [i, j, k])
+                    b_row_ind, b_col_ind = maybe_swap(k, j)
+
+                    thread_id_C, local_id_C = index_map_C(i, j)
+                    thread_id_A, local_id_A = index_map_A(i, k)
+                    thread_id_B, local_id_B = index_map_B(b_row_ind, b_col_ind)
+
+                    T.reads(
+                        C[thread_id_C, local_id_C],
+                        A[thread_id_A, local_id_A],
+                        B[thread_id_B, local_id_B],
+                    )
+                    T.writes(C[thread_id_C, local_id_C])
+
+                    C[thread_id_C, local_id_C] += maybe_cast(
+                        A[thread_id_A, local_id_A]
+                    ) * maybe_cast(B[thread_id_B, local_id_B])
+
+    @T.prim_func
+    def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(
+            a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
+        )
+        B = T.match_buffer(
+            b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
+        )
+        C = T.match_buffer(
+            c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp"
+        )
+
+        with T.block("root"):
+            T.reads(
+                C[0:WARP_SIZE, 0:local_size_out],
+                A[0:WARP_SIZE, 0:local_size],
+                B[0:WARP_SIZE, 0:local_size],
+            )
+            T.writes(C[0:WARP_SIZE, 0:local_size_out])
+            tx = T.env_thread("threadIdx.x")
+            T.launch_thread(tx, WARP_SIZE)
+
+            T.evaluate(
+                T.ptx_mma(
+                    mma_prefix,
+                    "row",
+                    "col",
+                    in_dtype_abbrv,
+                    in_dtype_abbrv,
+                    out_dtype_abbrv,
+                    A.data,
+                    A.elem_offset + tx * lift(local_size),
+                    B.data,
+                    B.elem_offset + tx * lift(local_size),
+                    C.data,
+                    C.elem_offset + tx * lift(local_size_out),
+                    False,
+                    dtype=out_dtype,
+                )
+            )
+
+            T.evaluate(
+                T.ptx_mma(
+                    mma_prefix,
+                    "row",
+                    "col",
+                    in_dtype_abbrv,
+                    in_dtype_abbrv,
+                    out_dtype_abbrv,
+                    A.data,
+                    A.elem_offset + tx * lift(local_size),
+                    B.data,
+                    B.elem_offset + tx * lift(local_size) + lift(local_size) // 2,
+                    C.data,
+                    C.elem_offset + tx * lift(local_size_out) + lift(local_size_out) // 2,
+                    False,
+                    dtype=out_dtype,
+                )
+            )
+
+    return mma_sync_desc, mma_sync_impl
+
+
 LDMATRIX_16x16_A_INTRIN = "mma.ldmatrix_16x16_a"
 TensorIntrin.register(LDMATRIX_16x16_A_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False))

@@ -191,3 +334,21 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:

 LDMATRIX_16x32_B_TRANS_INTRIN = "mma.ldmatrix_16x32_b_trans"
 TensorIntrin.register(LDMATRIX_16x32_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", True, True))
+
+MMA_f16f16f32_INTRIN = "mma_f16f16f32"
+TensorIntrin.register(MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float32", False))
+
+MMA_f16f16f32_TRANS_INTRIN = "mma_f16f16f32_trans"
+TensorIntrin.register(MMA_f16f16f32_TRANS_INTRIN, *get_mma_intrin(16, "float32", True))
+
+MMA_f16f16f16_INTRIN = "mma_f16f16f16"
+TensorIntrin.register(MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", False))
+
+MMA_f16f16f16_TRANS_INTRIN = "mma_f16f16f16_trans"
+TensorIntrin.register(MMA_f16f16f16_TRANS_INTRIN, *get_mma_intrin(16, "float16", True))
+
+MMA_i8i8i32_INTRIN = "mma_i8i8i32"
+TensorIntrin.register(MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int32", False))
+
+MMA_i8i8i32_TRANS_INTRIN = "mma_i8i8i32_trans"
+TensorIntrin.register(MMA_i8i8i32_TRANS_INTRIN, *get_mma_intrin(32, "int32", True))
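As a side note for readers of this diff, a minimal sketch of how the generated pair can be inspected from Python (illustrative only, not part of the commit; it assumes a TVM build that includes this file):

    from tvm.tir.tensor_intrin.cuda import get_mma_intrin

    # Build the same desc/impl pair that is registered above as "mma_f16f16f32".
    desc, impl = get_mma_intrin(16, "float32", False)

    # Both are TVMScript PrimFuncs; printing them shows the k_dim / out_dtype
    # specialization produced by the generator above.
    print(desc.script())
    print(impl.script())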

tests/python/unittest/test_mma_16x8x16_4k_tune.py

Lines changed: 2 additions & 80 deletions
@@ -7,90 +7,13 @@
 from tvm.tir.tensor_intrin.cuda import (
     LDMATRIX_16x16_A_INTRIN,
     LDMATRIX_16x16_B_INTRIN,
+    MMA_f16f16f32_INTRIN,
     shared_16x16_to_ldmatrix_32x8_layout,
 )
 import tvm.testing
 import numpy as np


-@T.prim_func
-def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
-    B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
-    C = T.match_buffer(c, (32, 8), "float32", align=128, offset_factor=16, scope="warp")
-
-    with T.block("root"):
-        T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
-        T.writes(C[0:32, 0:8])
-        for i, j, k in T.grid(16, 16, 16):
-            with T.block("C"):
-                i, j, k = T.axis.remap("SSR", [i, j, k])
-                thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
-                thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k)
-                thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k, j)
-
-                T.reads(
-                    C[thread_id_C, local_id_C],
-                    A[thread_id_A, local_id_A],
-                    B[thread_id_B, local_id_B],
-                )
-                T.writes(C[thread_id_C, local_id_C])
-                C[thread_id_C, local_id_C] += T.cast(
-                    A[thread_id_A, local_id_A], "float32"
-                ) * T.cast(B[thread_id_B, local_id_B], "float32")
-
-
-@T.prim_func
-def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
-    B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
-    C = T.match_buffer(c, (32, 8), "float32", align=128, offset_factor=16, scope="warp")
-
-    with T.block("root"):
-        T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
-        T.writes(C[0:32, 0:8])
-        tx = T.env_thread("threadIdx.x")
-        T.launch_thread(tx, 32)
-
-        T.evaluate(
-            T.ptx_mma(
-                "m16n8k16",
-                "row",
-                "col",
-                "fp16",
-                "fp16",
-                "fp32",
-                A.data,
-                A.elem_offset + tx * 8,
-                B.data,
-                B.elem_offset + tx * 8,
-                C.data,
-                C.elem_offset + tx * 8,
-                False,
-                dtype="float32",
-            )
-        )
-
-        T.evaluate(
-            T.ptx_mma(
-                "m16n8k16",
-                "row",
-                "col",
-                "fp16",
-                "fp16",
-                "fp32",
-                A.data,
-                A.elem_offset + tx * 8,
-                B.data,
-                B.elem_offset + tx * 8 + 4,
-                C.data,
-                C.elem_offset + tx * 8 + 4,
-                False,
-                dtype="float32",
-            )
-        )
-
-
 @T.prim_func
 def mma_store_desc(a: T.handle, c: T.handle) -> None:
     C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp")
@@ -160,7 +83,6 @@ def mma_fill_impl(a: T.handle) -> None:
         T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float32"))


-tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl)
 tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
 tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)

@@ -291,7 +213,7 @@ def index_map(i, j):

     sch.tensorize(loop_a, LDMATRIX_16x16_A_INTRIN)
     sch.tensorize(loop_b, LDMATRIX_16x16_B_INTRIN)
-    sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync")
+    sch.tensorize(sch.get_loops(block_inner)[-3], MMA_f16f16f32_INTRIN)
     sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
     sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")
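The other generated variants are consumed the same way, purely by their registered names. A hedged sketch, assuming an int8 schedule prepared analogously to the fp16 test above (`sch` and `block_inner` stand in for that schedule's objects; they are not defined in this commit):

    from tvm import tir
    from tvm.tir.tensor_intrin.cuda import MMA_i8i8i32_INTRIN

    def tensorize_int8_mma(sch: tir.Schedule, block_inner) -> None:
        # Only the intrinsic name changes relative to the fp16 -> fp32 case above.
        sch.tensorize(sch.get_loops(block_inner)[-3], MMA_i8i8i32_INTRIN)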
