@@ -196,13 +196,43 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None:
         tx = T.env_thread("threadIdx.x")
         T.launch_thread(tx, 32)

-        T.evaluate(T.mma_store("m16n8", C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float32"))
+        T.evaluate(T.mma_store(16, 8, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float32"))
+
+
+@T.prim_func
+def mma_fill_desc(a: T.handle) -> None:
+    C_warp = T.match_buffer(a, [32, 4], dtype="float32", scope="warp")
+
+    with T.block("root"):
+        T.reads()
+        T.writes(C_warp[0:32, 0:4])
+        for i0, i1 in T.grid(32, 4):
+            with T.block("C_warp"):
+                i_init = T.axis.spatial(16, i1 // 2 * 8 + i0 // 4)
+                j_init = T.axis.spatial(8, (i0 % 4) * 2 + i1 % 2)
+                T.reads()
+                T.writes(C_warp[i_init % 8 * 4 + j_init % 8 // 2, i_init % 16 // 8 * 2 + j_init % 2])
+                C_warp[i_init % 8 * 4 + j_init % 8 // 2, i_init % 16 // 8 * 2 + j_init % 2] = T.float32(0)
+
+
+@T.prim_func
+def mma_fill_impl(a: T.handle) -> None:
+    C_warp = T.match_buffer(a, [32, 4], dtype="float32", scope="warp", offset_factor=1)
+
+    with T.block("root"):
+        T.reads()
+        T.writes(C_warp[0:32, 0:4])
+        tx = T.env_thread("threadIdx.x")
+        T.launch_thread(tx, 32)
+
+        T.evaluate(T.mma_fill(4, C_warp.data, C_warp.elem_offset, dtype="float32"))


 tir.TensorIntrin.register("mma.ldmatrix_a", ldmatrix_a_desc, ldmatrix_a_impl)
 tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl)
 tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl)
 tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
+tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)

 N = 4096
 M = 4096
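
In this hunk, T.mma_store switches from the packed "m16n8" shape string to two integer arguments (16, 8), and a new mma_fill desc/impl pair is added and registered so the accumulator zero-init can be tensorized. The desc encodes how the 16x8 C tile is spread over a warp: 32 lanes times 4 float32 registers per lane. The following standalone sanity check (plain Python, not part of the commit; the loop bounds and index expressions are copied from mma_fill_desc) verifies that the lane/register-to-tile mapping and the inverse indexing used on C_warp round-trip exactly:

# Standalone sanity check, not part of the commit. The index expressions
# below are copied from mma_fill_desc; this verifies that the
# (lane, register) -> (row, col) mapping into the 16x8 C tile and the
# inverse indexing used on C_warp are exact inverses.
seen = set()
for i0 in range(32):      # lane id within the warp
    for i1 in range(4):   # register index within the lane
        i = i1 // 2 * 8 + i0 // 4        # i_init: row of the 16x8 tile
        j = (i0 % 4) * 2 + i1 % 2        # j_init: column of the tile
        lane = i % 8 * 4 + j % 8 // 2    # first C_warp index
        reg = i % 16 // 8 * 2 + j % 2    # second C_warp index
        assert (lane, reg) == (i0, i1)   # round-trips to the same slot
        seen.add((i, j))
assert len(seen) == 16 * 8               # every tile element covered once
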
@@ -381,7 +411,8 @@ def lambda_b(i, j):
     sch.reorder(f_1, f_2, f_0, f_3)
     fused_1 = sch.fuse(f_1, f_2)
     fused_2 = sch.fuse(f_0, f_3)
-    sch.bind(fused_1, "threadIdx.x")
+    # sch.bind(fused_1, "threadIdx.x")
+    sch.tensorize(fused_1, "mma_fill")

     warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
     f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
@@ -394,7 +425,6 @@ def lambda_b(i, j):
     # return

     sch.tensorize(fused_1, "mma_store")
-    # sch.bind(fused_1, "threadIdx.x")


 ir_module = tvm.IRModule({"main": workload})
@@ -440,7 +470,7 @@ def lambda_b(i, j):
 tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
 print("ok")

-# evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
-# gflops = (N * M * K) * 2 / 1e9
-# time_ms = evaluator(a, b, c).mean * 1e3
-# print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
+evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
+gflops = (N * M * K) * 2 / 1e9
+time_ms = evaluator(a, b, c).mean * 1e3
+print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
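
This re-enables the benchmark: time_evaluator runs the compiled kernel 1000 times and reports the mean runtime in seconds; with N = M = K = 4096 each call performs 2 * 4096^3 ≈ 137.4 GFLOP, which is the numerator of the printed GFLOPS figure.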