Skip to content

Commit c9d40b6

Browse files
committed
clean up
1 parent 5b2d486 commit c9d40b6

File tree

1 file changed

+2
-20
lines changed

1 file changed

+2
-20
lines changed

tests/python/unittest/test_mma_16x8x16_4k_tune.py

Lines changed: 2 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -285,7 +285,6 @@ def mma_fill_impl(a: T.handle) -> None:
285 285
workload = te.create_prim_func(te_workload.matmul_fp16(n=N, m=M, k=K))
286 286

287 287
tune = False
288-
use_ldmatrix = True
289 288

290 289

291 290
def schedule(sch: tir.Schedule):
@@ -408,25 +407,8 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
408 407
sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
409 408
sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)
410 409

411-
# return
412-
413-
if use_ldmatrix:
414-
sch.tensorize(loop_a, "mma.ldmatrix_a")
415-
sch.tensorize(loop_b, "mma.ldmatrix_b")
416-
else:
417-
warp_loop1, warp_loop2 = sch.get_loops(A_warp)[-2:]
418-
f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
419-
f_2, f_3 = sch.split(warp_loop2, factors=[None, 2])
420-
sch.reorder(f_1, f_2, f_0, f_3)
421-
fused_1 = sch.fuse(f_1, f_2)
422-
fused_2 = sch.fuse(f_0, f_3)
423-
sch.bind(fused_1, "threadIdx.x")
424-
425-
warp_loop1, warp_loop2 = sch.get_loops(B_warp)[-2:]
426-
f_0, f_1 = sch.split(warp_loop1, factors=[4, 2])
427-
sch.reorder(warp_loop2, f_0, f_1)
428-
fused_1 = sch.fuse(warp_loop2, f_0)
429-
sch.bind(fused_1, "threadIdx.x")
410+
sch.tensorize(loop_a, "mma.ldmatrix_a")
411+
sch.tensorize(loop_b, "mma.ldmatrix_b")
430 412

431 413
mma_loop = sch.get_loops(block_inner)[-3]
432 414
sch.tensorize(mma_loop, "mma_sync")

0 commit comments

Comments
 (0)