@@ -285,7 +285,6 @@ def mma_fill_impl(a: T.handle) -> None:
285285workload = te .create_prim_func (te_workload .matmul_fp16 (n = N , m = M , k = K ))
286286
287287tune = False
288- use_ldmatrix = True
289288
290289
291290def schedule (sch : tir .Schedule ):
@@ -408,25 +407,8 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
408407 sch .transform_layout (B_warp , 0 , "write" , index_map = shared_16x16_to_ldmatrix_32x8_layout )
409408 sch .transform_layout (C_warp , 0 , "read" , index_map = shared_16x16_to_ldmatrix_32x8_layout )
410409
411- # return
412-
413- if use_ldmatrix :
414- sch .tensorize (loop_a , "mma.ldmatrix_a" )
415- sch .tensorize (loop_b , "mma.ldmatrix_b" )
416- else :
417- warp_loop1 , warp_loop2 = sch .get_loops (A_warp )[- 2 :]
418- f_0 , f_1 = sch .split (warp_loop1 , factors = [None , 8 ])
419- f_2 , f_3 = sch .split (warp_loop2 , factors = [None , 2 ])
420- sch .reorder (f_1 , f_2 , f_0 , f_3 )
421- fused_1 = sch .fuse (f_1 , f_2 )
422- fused_2 = sch .fuse (f_0 , f_3 )
423- sch .bind (fused_1 , "threadIdx.x" )
424-
425- warp_loop1 , warp_loop2 = sch .get_loops (B_warp )[- 2 :]
426- f_0 , f_1 = sch .split (warp_loop1 , factors = [4 , 2 ])
427- sch .reorder (warp_loop2 , f_0 , f_1 )
428- fused_1 = sch .fuse (warp_loop2 , f_0 )
429- sch .bind (fused_1 , "threadIdx.x" )
410+ sch .tensorize (loop_a , "mma.ldmatrix_a" )
411+ sch .tensorize (loop_b , "mma.ldmatrix_b" )
430412
431413 mma_loop = sch .get_loops (block_inner )[- 3 ]
432414 sch .tensorize (mma_loop , "mma_sync" )
0 commit comments