@@ -289,7 +289,7 @@ def lambda_b(i, j):
 
     # fetch to C_warp 16 * 8 -> 32 * 4
     C_warp = sch.cache_write(block, 0, "warp")
-    # sch.reverse_compute_at(C_warp, sch.get_loops(block)[0])
+    sch.reverse_compute_at(C_warp, sch.get_loops(block)[0])
     # need to do a reverse_compute_at to place it under blockidx.x
 
     sch.transform_layout(
@@ -307,6 +307,7 @@ def lambda_b(i, j):
     fused_2 = sch.fuse(f_0, f_3)
     sch.bind(fused_1, "threadIdx.x")
 
+
     block_init_c = sch.decompose_reduction(block, sch.get_loops(block)[1])
 
     block_init_c = sch.get_block("C_init")
@@ -345,6 +346,13 @@ def lambda_b(i, j):
     print(f.imported_modules[0].get_source())
     tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
 
+    print("ok")
+
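+    # time the compiled kernel; time_evaluator averages over number=100 runs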
+    evaluator = f.time_evaluator(f.entry_name, dev, number=100)
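+    # a matmul of shape (N, K) x (K, M) performs 2 * N * M * K floating-point ops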
+    gflops = (N * M * K) * 2 / 1e9
+    time_ms = evaluator(a, b, c).mean * 1e3
+    print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
+
 
 if __name__ == "__main__":
     test_integration_matmul()