import numpy as np


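+# Maps element (i, j) of a 16x16 fp16 tile to its position in the 32x8 warp-level
+# buffer: the first value is the owning lane id (0-31), the second the per-lane
+# element index (0-7) in the layout that ldmatrix produces.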
+def shared_16x16_to_ldmatrix_32x8_layout(i, j):
+    thread_id = 4 * (i % 8) + (j % 8) // 2
+    return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)
+
+
@T.prim_func
def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
    A_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
@@ -21,11 +26,15 @@ def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
            with T.block("A_shared_warp"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A_shared[v0, v1])
-                T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
-                A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
-                    v0, v1
-                ]

+                thread_id, y = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
+                T.writes(A_warp[thread_id, y])
+                A_warp[thread_id, y] = A_shared[v0, v1]
+
+                # T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
+                # A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
+                #     v0, v1
+                # ]

@T.prim_func
def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
@@ -390,22 +399,39 @@ def tile_wmma_fragment(block_read, height):
        sch.reorder(i0, j0, i1, j1)
        return i1

-    def shared_16x16_to_ldmatrix_32x8_layout(i, j):
-        i_0 = i // 16
-        j_0 = j // 16
-
-        i = i % 16
-        j = j % 16
-
-        thread_id = 4 * (i % 8) + (j % 8) // 2
-        return i_0, j_0, thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2
-
    loop_a = tile_wmma_fragment(A_warp, 16)
    loop_b = tile_wmma_fragment(B_warp, 16)

-    sch.transform_layout(A_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
-    sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
-    sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)
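+    # The lambdas below compose the 16x16 -> (lane, element) map with the outer
+    # tile indices (i // 16, j // 16) instead of folding the tiling into the helper.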
+    sch.transform_layout(
+        A_warp,
+        0,
+        "write",
+        index_map=lambda i, j: (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
+        ),
+    )
+    sch.transform_layout(
+        B_warp,
+        0,
+        "write",
+        index_map=lambda i, j: (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
+        ),
+    )
+    sch.transform_layout(
+        C_warp,
+        0,
+        "read",
+        index_map=lambda i, j: (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
+        ),
+    )

    sch.tensorize(loop_a, "mma.ldmatrix_a")
    sch.tensorize(loop_b, "mma.ldmatrix_b")
@@ -438,44 +464,44 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
schedule(sch)
print(sch.mod.script())

-if tune:
-    with tempfile.TemporaryDirectory() as work_dir:
-        sch = ms.tune_tir(
-            mod=workload,
-            target=tvm.target.Target("nvidia/geforce-rtx-3070"),
-            config=ms.TuneConfig(
-                strategy="evolutionary",
-                num_trials_per_iter=32,
-                max_trials_per_task=128,
-                max_trials_global=128,
-            ),
-            work_dir=work_dir,
-            space=ms.space_generator.ScheduleFn(schedule),
-        )
-        if sch is None:
-            print("No valid schedule found!")
-        else:
-            print(sch.mod.script())
-            print(sch.trace)
-else:
-    target = "cuda"
-    f = tvm.build(sch.mod["main"], target=target, name="dense")
-
-    dev = tvm.device("cuda", 0)
-    a_np = np.random.uniform(size=(N, K)).astype("float16")
-    b_np = np.random.uniform(size=(K, M)).astype("float16")
-    c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
-    a = tvm.nd.array(a_np, dev)
-    b = tvm.nd.array(b_np, dev)
-    c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
-    f = tvm.build(sch.mod["main"], target="cuda", name="dense")
-
-    print(f.imported_modules[0].get_source())
-    f(a, b, c)
-    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
-    print("ok")
-
-    evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
-    gflops = (N * M * K) * 2 / 1e9
-    time_ms = evaluator(a, b, c).mean * 1e3
-    print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
+# if tune:
+#     with tempfile.TemporaryDirectory() as work_dir:
+#         sch = ms.tune_tir(
+#             mod=workload,
+#             target=tvm.target.Target("nvidia/geforce-rtx-3070"),
+#             config=ms.TuneConfig(
+#                 strategy="evolutionary",
+#                 num_trials_per_iter=32,
+#                 max_trials_per_task=128,
+#                 max_trials_global=128,
+#             ),
+#             work_dir=work_dir,
+#             space=ms.space_generator.ScheduleFn(schedule),
+#         )
+#         if sch is None:
+#             print("No valid schedule found!")
+#         else:
+#             print(sch.mod.script())
+#             print(sch.trace)
+# else:
+#     target = "cuda"
+#     f = tvm.build(sch.mod["main"], target=target, name="dense")
+
+#     dev = tvm.device("cuda", 0)
+#     a_np = np.random.uniform(size=(N, K)).astype("float16")
+#     b_np = np.random.uniform(size=(K, M)).astype("float16")
+#     c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
+#     a = tvm.nd.array(a_np, dev)
+#     b = tvm.nd.array(b_np, dev)
+#     c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
+#     f = tvm.build(sch.mod["main"], target="cuda", name="dense")
+
+#     print(f.imported_modules[0].get_source())
+#     f(a, b, c)
+#     tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+#     print("ok")
+
+#     evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
+#     gflops = (N * M * K) * 2 / 1e9
+#     time_ms = evaluator(a, b, c).mean * 1e3
+#     print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
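
As a quick sanity check (a minimal standalone sketch, not part of the diff itself; it assumes only NumPy and the helper defined above), the new index map can be verified to hit every (lane, element) slot of the 32x8 warp buffer exactly once:

import numpy as np


def shared_16x16_to_ldmatrix_32x8_layout(i, j):
    thread_id = 4 * (i % 8) + (j % 8) // 2
    return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)


# Each of the 256 elements of the 16x16 tile must land on a distinct slot.
hits = np.zeros((32, 8), dtype=np.int32)
for i in range(16):
    for j in range(16):
        t, e = shared_16x16_to_ldmatrix_32x8_layout(i, j)
        hits[t, e] += 1
assert (hits == 1).all()
print("16x16 -> 32x8 layout map is a bijection")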