@@ -363,13 +363,12 @@ def test_warp_shuffle_transform():
363363 class Before :
364364 @T .prim_func
365365 def main (A : T .handle ("float32" , "global" ), B : T .handle ("float32" , "global" )):
366- # blockIdx_x = T.int32()
367366 blockIdx_x = T .env_thread ("blockIdx.x" )
368367 threadIdx_x = T .env_thread ("threadIdx.x" )
369368 T .func_attr (
370369 {
371370 "calling_conv" : 2 ,
372- "global_symbol" : "default_function_kernel0 " ,
371+ "global_symbol" : "main " ,
373372 "target" : T .target (
374373 {
375374 "host" : {"keys" : ["cpu" ], "kind" : "llvm" , "tag" : "" },
@@ -410,7 +409,7 @@ def main(A: T.handle("float32", "global"), B: T.handle("float32", "global")):
410409 T .func_attr (
411410 {
412411 "calling_conv" : 2 ,
413- "global_symbol" : "default_function_kernel0 " ,
412+ "global_symbol" : "main " ,
414413 "target" : T .target (
415414 {
416415 "host" : {"keys" : ["cpu" ], "kind" : "llvm" , "tag" : "" },
@@ -448,34 +447,6 @@ def main(A: T.handle("float32", "global"), B: T.handle("float32", "global")):
448447
449448 tvm .ir .assert_structural_equal (after , Expected )
450449
451- @T .prim_func
452- def warp_shuffle (A : T .Buffer ([32 ], "float32" ), B : T .Buffer ([32 ], "float32" )) -> None :
453- for i in range (32 ):
454- with T .block ("warp_shuffle" ):
455- vi = T .axis .spatial (32 , i )
456- B [vi ] = A [vi % 4 * 8 + vi // 4 ] + T .float32 (1 )
457-
458-
459- @tvm .testing .requires_cuda
460- def test_warp_shuffle ():
461- mod = tvm .IRModule .from_expr (warp_shuffle )
462- sch = tvm .tir .Schedule (mod ["main" ])
463- blk = sch .get_block ("warp_shuffle" )
464- i , = sch .get_loops (blk )
465- io , ii = sch .split (i , [1 , 32 ])
466- sch .bind (ii , "threadIdx.x" )
467- A_warp = sch .cache_read (blk , 0 , "warp" )
468- sch .compute_at (A_warp , io )
469- sch .bind (sch .get_loops (A_warp )[- 1 ], "threadIdx.x" )
470- B_warp = sch .cache_write (blk , 0 , "warp" )
471- sch .reverse_compute_at (B_warp , io )
472- sch .bind (sch .get_loops (B_warp )[- 1 ], "threadIdx.x" )
473- sch .bind (io , "blockIdx.x" )
474- print (sch .mod ["main" ].script ())
475- f = tvm .build (sch .mod ["main" ], target = "cuda" )
476-
477450
478451if __name__ == "__main__" :
479- # tvm.testing.main()
480- test_warp_shuffle_transform ()
481- # test_warp_shuffle()
452+ tvm .testing .main ()
0 commit comments