
Commit 072991e

[Fix][Dlight] (Low-batched-)GeMV on small spatial loops
This PR fixes an issue in the dlight GeMV rule and the low-batch GeMV rule. The issue arises when the inner spatial loop has a small length (e.g., in the MoE gate layer this length is usually 8). The error occurs because the GeMV scheduling does not ensure that every TIR block reads/writes the same number of local registers, and this inconsistency leads to wrong generated code. For example, prior to this fix, the schedule assigned each thread 2 local registers in the first TIR block but only 1 local register in the second block, which is incorrect. The error only shows up when the spatial loop has a small length. One regression test is added.
1 parent 77a7b01 commit 072991e
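
As a rough illustration of the fix (not part of this commit), the sketch below shows how splitting a fused spatial loop with factors=[None, TS, TILE_S] pins the per-thread tile to TILE_S even when the loop extent is small, whereas the previous factors=[TS, None] pattern let the tile extent depend on the fused loop's extent. The toy prim_func and the concrete TS/TILE_S values are assumptions for demonstration; only the split/assert/fuse pattern mirrors the changed code.

# Minimal sketch, assuming a hypothetical 8-element workload (illustration only).
import tvm
from tvm.script import tir as T


@T.prim_func
def toy(A: T.Buffer((8,), "float32"), B: T.Buffer((8,), "float32")):
    for i in range(8):  # small spatial loop, like the 8-wide MoE gate mentioned above
        with T.block("copy"):
            vi = T.axis.spatial(8, i)
            B[vi] = A[vi]


TS, TILE_S = 16, 2  # assumed values; the real rule derives them from the target
sch = tvm.tir.Schedule(toy)
(i,) = sch.get_loops(sch.get_block("copy"))

# Old pattern: factors=[TS, None] makes the inner tile ceildiv(8, 16) == 1 here,
# so the per-thread register count silently depends on the loop extent.
# New pattern: factors=[None, TS, TILE_S] fixes the tile extent to TILE_S; the
# leading factor collapses to 1 (asserted, as in the commit) and is fused back
# with the TS loop before thread binding.
i_o, i_ts, i_tile = sch.split(i, factors=[None, TS, TILE_S], preserve_unit_iters=True)
assert sch.get(i_o).extent.value == 1
i_ts = sch.fuse(i_o, i_ts)
print(sch.get(i_ts).extent, sch.get(i_tile).extent)  # 16 and 2, independent of the small extent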

File tree

3 files changed: +138 -7 lines changed


python/tvm/dlight/gpu/gemv.py

Lines changed: 15 additions & 3 deletions
@@ -342,12 +342,16 @@ def apply(
         sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True)
         tr, vec_c, *ts_tile_s = sch.get_loops(block=rf2)[1:]
         ts_tile_s = sch.fuse(*ts_tile_s)
-        ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True)
+        ts_o, ts_i, tile_s = sch.split(
+            ts_tile_s, factors=[None, TS, TILE_S], preserve_unit_iters=True
+        )
         tile_s, vec_s = sch.split(
             tile_s,
             factors=[None, get_max_factor(TILE_S, [1, 2, 4, 8])],
             preserve_unit_iters=True,
         )
+        assert sch.get(ts_o).extent.value == 1
+        ts = sch.fuse(ts_o, ts_i)
         sch.reorder(ts, tr, tile_s, vec_s, vec_c)
         sch.bind(ts, TAG_S)
         sch.bind(tr, TAG_R)
@@ -357,7 +361,11 @@ def apply(
         sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True)
         tr, *ts_tile_s = sch.get_loops(block=gemv)[1:]
         ts_tile_s = sch.fuse(*ts_tile_s)
-        ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True)
+        ts_o, ts_i, tile_s = sch.split(
+            ts_tile_s, factors=[None, TS, TILE_S], preserve_unit_iters=True
+        )
+        assert sch.get(ts_o).extent.value == 1
+        ts = sch.fuse(ts_o, ts_i)
         sch.reorder(tile_s, ts, tr)
         sch.bind(ts, TAG_S)
         sch.bind(tr, TAG_R)
@@ -411,7 +419,11 @@ def apply(
         sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True)
         ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[1:])
         ts_tile_s = sch.get_loops(epilogue)[-1]
-        ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True)
+        ts_o, ts_i, tile_s = sch.split(
+            ts_tile_s, factors=[None, TS, TILE_S], preserve_unit_iters=True
+        )
+        assert sch.get(ts_o).extent.value == 1
+        ts = sch.fuse(ts_o, ts_i)
         sch.bind(ts, TAG_S)
         sch.set_scope(block, 0, "local")
         # pylint: enable=invalid-name
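
A usage sketch (not part of the diff) of how the patched rule is typically driven: dlight's ApplyDefaultSchedule applies dl.gpu.GEMV() to every prim_func in a module, so a GeMV whose spatial axis is only 8 wide, like the MoE gate case from the commit message, now gets a consistent per-thread tile in every block. The moe_gate workload and its shapes are illustrative assumptions; ApplyDefaultSchedule and dl.gpu.GEMV are the existing dlight entry points touched here.

import tvm
from tvm import dlight as dl
from tvm.script import tir as T
from tvm.target import Target


@T.prim_func(private=True)
def moe_gate(A: T.Buffer((T.int64(1), T.int64(4096)), "float16"),
             B: T.Buffer((T.int64(8), T.int64(4096)), "float16"),
             C: T.Buffer((T.int64(1), T.int64(8)), "float16")):
    # Hypothetical GeMV with a small (8-wide) spatial axis, for illustration only.
    T.func_attr({"tir.noalias": T.bool(True)})
    for i0, i1, k in T.grid(T.int64(1), T.int64(8), T.int64(4096)):
        with T.block("NT_matmul"):
            v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
            with T.init():
                C[v_i0, v_i1] = T.float16(0)
            C[v_i0, v_i1] = C[v_i0, v_i1] + A[v_i0, v_k] * B[v_i1, v_k]


mod = tvm.IRModule({"main": moe_gate})
with Target("cuda"):
    # ApplyDefaultSchedule runs the dlight rules; GEMV() is the rule patched above.
    mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
mod.show()  # scheduled kernel; all blocks now agree on the per-thread tile size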

python/tvm/dlight/gpu/low_batch_gemv.py

Lines changed: 16 additions & 4 deletions
@@ -17,7 +17,7 @@
 """A rule for low-batch GEMM / decode-GEMM using GEMV schedule."""
 import re
 from functools import reduce
-from typing import List, Optional, Union, Set
+from typing import List, Optional, Set, Union
 
 from tvm import DataType, arith, ir, tir
 from tvm.target import Target
@@ -428,12 +428,16 @@ def apply(
         sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True)
         tr, vec_c, batch_loop, *ts_tile_s = sch.get_loops(block=rf2)[2:]
         ts_tile_s = sch.fuse(*ts_tile_s)
-        ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True)
+        ts_o, ts_i, tile_s = sch.split(
+            ts_tile_s, factors=[None, TS, TILE_S], preserve_unit_iters=True
+        )
         tile_s, vec_s = sch.split(
             tile_s,
             factors=[None, get_max_factor(TILE_S, [1, 2, 4, 8])],
             preserve_unit_iters=True,
         )
+        assert sch.get(ts_o).extent.value == 1
+        ts = sch.fuse(ts_o, ts_i)
         sch.reorder(ts, tr, tile_s, batch_loop, vec_s, vec_c)
         sch.bind(ts, TAG_S)
         sch.bind(tr, TAG_R)
@@ -444,7 +448,11 @@ def apply(
 
         tr, batch_loop, *ts_tile_s = sch.get_loops(block=gemv)[2:]
         ts_tile_s = sch.fuse(*ts_tile_s)
-        ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True)
+        ts_o, ts_i, tile_s = sch.split(
+            ts_tile_s, factors=[None, TS, TILE_S], preserve_unit_iters=True
+        )
+        assert sch.get(ts_o).extent.value == 1
+        ts = sch.fuse(ts_o, ts_i)
         sch.reorder(tile_s, batch_loop, ts, tr)
         sch.bind(ts, TAG_S)
         sch.bind(tr, TAG_R)
@@ -499,7 +507,11 @@ def apply(
         sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True)
         ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[3:])
         ts_tile_s = sch.get_loops(epilogue)[-1]
-        ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True)
+        ts_o, ts_i, tile_s = sch.split(
+            ts_tile_s, factors=[None, TS, TILE_S], preserve_unit_iters=True
+        )
+        assert sch.get(ts_o).extent.value == 1
+        ts = sch.fuse(ts_o, ts_i)
         sch.bind(ts, TAG_S)
         sch.set_scope(block, 0, "local")

tests/python/dlight/test_gpu_low_batch_gemv.py

Lines changed: 107 additions & 0 deletions
@@ -275,5 +275,112 @@ def before(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int
     tvm.ir.assert_structural_equal(mod["main"], before)
 
 
+def test_small_spatial_axis():
+
+    @T.prim_func(private=True)
+    def func(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16"), var_C: T.handle):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        batch_size = T.int64()
+        A = T.match_buffer(var_A, (batch_size, T.int64(4096)), "float16")
+        C = T.match_buffer(var_C, (batch_size, T.int64(8)), "float16")
+        for i0, i1, k in T.grid(batch_size, T.int64(8), T.int64(4096)):
+            with T.block("NT_matmul"):
+                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+                T.reads(A[v_i0, v_k], B[v_i1, v_k])
+                T.writes(C[v_i0, v_i1])
+                with T.init():
+                    C[v_i0, v_i1] = T.float16(0)
+                C[v_i0, v_i1] = C[v_i0, v_i1] + A[v_i0, v_k] * B[v_i1, v_k]
+
+    # fmt: off
+    @T.prim_func(private=True)
+    def expected(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16"), var_C: T.handle):
+        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+        batch_size = T.int64()
+        A = T.match_buffer(var_A, (batch_size, T.int64(4096)), "float16")
+        C = T.match_buffer(var_C, (batch_size, T.int64(8)), "float16")
+        # with T.block("root"):
+        C_pad_local = T.alloc_buffer(((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(8)), "float16", scope="local")
+        C_pad_rf_local = T.alloc_buffer((T.int64(128), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(8)), "float16", scope="local")
+        C_pad_rf_local_1 = T.alloc_buffer((T.int64(32), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(8)), "float16", scope="local")
+        for ax0_0 in T.thread_binding((batch_size + T.int64(3)) // T.int64(4), thread="blockIdx.y"):
+            for u_fused_ax1_fused_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
+                for u_fused_ax1_fused_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                    for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
+                        for ax0_1_init, u_fused_ax1_fused_fused_2_init in T.grid(T.int64(4), T.int64(2)):
+                            for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)):
+                                with T.block("NT_matmul_rf_init"):
+                                    vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init)
+                                    v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init)
+                                    v1 = T.axis.spatial(T.int64(8), u_fused_ax1_fused_fused_0 * T.int64(32) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2_init)
+                                    T.where((u_fused_ax1_fused_fused_0 * T.int64(16) + u_fused_ax1_fused_fused_1) * T.int64(2) + u_fused_ax1_fused_fused_2_init < T.int64(8))
+                                    T.reads()
+                                    T.writes(C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1])
+                                    C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1] = T.float16(0)
+                        for ax2_fused_u_fused_0 in T.serial(T.int64(16), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                            for ax0_1, u_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(2)):
+                                for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)):
+                                    with T.block("NT_matmul_rf_update"):
+                                        vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(128), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1)
+                                        v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1)
+                                        v1 = T.axis.spatial(T.int64(8), u_fused_ax1_fused_fused_0 * T.int64(32) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2)
+                                        vax2_fused_u_fused_0, vax2_fused_u_fused_2 = T.axis.remap("RR", [ax2_fused_u_fused_0, ax2_fused_u_fused_2])
+                                        T.where((u_fused_ax1_fused_fused_0 * T.int64(16) + u_fused_ax1_fused_fused_1) * T.int64(2) + u_fused_ax1_fused_fused_2 < T.int64(8))
+                                        T.reads(C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1], A[v0, vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)], B[v1, vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)])
+                                        T.writes(C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1])
+                                        C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1] = C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, v1] + T.if_then_else(v0 < batch_size, A[v0, vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)], T.float16(0)) * B[v1, vax2_fused_u_fused_0 * T.int64(256) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax2_fused_u_fused_2 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused % T.int64(4)]
+                for ax3_fused_0_ax3_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
+                        for ax3_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                            for ax2 in range(T.int64(4)):
+                                for ax3_fused_2_1 in T.vectorized(T.int64(2)):
+                                    with T.block("NT_matmul_rf_init"):
+                                        vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
+                                        v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
+                                        v1 = T.axis.spatial(T.int64(8), ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
+                                        T.where((T.Mul(T.int64(0), T.int64(16)) + ax3_fused_0_ax3_fused_1_fused % T.int64(16)) * T.int64(2) + (ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1) < T.int64(8))
+                                        T.reads()
+                                        T.writes(C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1])
+                                        C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1] = T.float16(0)
+                                    for ax1 in range(T.int64(4)):
+                                        with T.block("NT_matmul_rf_update"):
+                                            vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
+                                            v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
+                                            v1 = T.axis.spatial(T.int64(8), ax3_fused_0_ax3_fused_1_fused * T.int64(2) + ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1)
+                                            T.where((T.Mul(T.int64(0), T.int64(16)) + ax3_fused_0_ax3_fused_1_fused % T.int64(16)) * T.int64(2) + (ax3_fused_2_0 * T.int64(2) + ax3_fused_2_1) < T.int64(8))
+                                            T.reads(C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1], C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, v1])
+                                            T.writes(C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1])
+                                            C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1] = C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1] + C_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(4) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, v1]
+                for ax2_fused_2, ax1 in T.grid(T.int64(2), T.int64(4)):
+                    for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                        for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
+                            with T.block("NT_matmul"):
+                                vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
+                                v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1)
+                                v1 = T.axis.spatial(T.int64(8), ax2_fused_0_ax2_fused_1_fused * T.int64(2) + ax2_fused_2)
+                                T.where((T.Mul(T.int64(0), T.int64(16)) + ax2_fused_0_ax2_fused_1_fused % T.int64(16)) * T.int64(2) + ax2_fused_2 < T.int64(8))
+                                T.reads(C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1])
+                                T.writes(C_pad_local[v0, v1])
+                                with T.init():
+                                    C_pad_local[v0, v1] = T.float16(0)
+                                C_pad_local[v0, v1] = C_pad_local[v0, v1] + C_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, v1]
+                for ax0 in range(T.int64(4)):
+                    for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                        for ax1_fused_2 in range(T.int64(2)):
+                            with T.block("C_pad"):
+                                v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
+                                v1 = T.axis.spatial(T.int64(8), ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2)
+                                T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size and (T.Mul(T.int64(0), T.int64(16)) + ax1_fused_0_ax1_fused_1_fused % T.int64(16)) * T.int64(2) + ax1_fused_2 < T.int64(8))
+                                T.reads(C_pad_local[v0, v1])
+                                T.writes(C[v0, v1])
+                                C[v0, v1] = C_pad_local[v0, v1]
+    # fmt: on
+
+    mod = tvm.IRModule({"main": func})
+    with Target("cuda"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod)
+    tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
