Skip to content

Commit 3aea3ff

Browse files
committed
upd
1 parent 1f0a25c commit 3aea3ff

File tree

2 files changed

+3
-34
lines changed

2 files changed

+3
-34
lines changed

src/driver/driver_api.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -659,9 +659,7 @@ transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target)
659659

660660
device_pass_list.push_back(tir::transform::BindTarget(target));
661661

662-
device_pass_list.push_back(transform::PrintIR());
663662
device_pass_list.push_back(tir::transform::LowerWarpMemory());
664-
device_pass_list.push_back(transform::PrintIR());
665663
device_pass_list.push_back(tir::transform::Simplify());
666664
device_pass_list.push_back(tir::transform::LowerCustomDatatypes());
667665
device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());

tests/python/unittest/test_tir_transform_lower_warp_memory.py

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -363,13 +363,12 @@ def test_warp_shuffle_transform():
363363
class Before:
364364
@T.prim_func
365365
def main(A: T.handle("float32", "global"), B: T.handle("float32", "global")):
366-
# blockIdx_x = T.int32()
367366
blockIdx_x = T.env_thread("blockIdx.x")
368367
threadIdx_x = T.env_thread("threadIdx.x")
369368
T.func_attr(
370369
{
371370
"calling_conv": 2,
372-
"global_symbol": "default_function_kernel0",
371+
"global_symbol": "main",
373372
"target": T.target(
374373
{
375374
"host": {"keys": ["cpu"], "kind": "llvm", "tag": ""},
@@ -410,7 +409,7 @@ def main(A: T.handle("float32", "global"), B: T.handle("float32", "global")):
410409
T.func_attr(
411410
{
412411
"calling_conv": 2,
413-
"global_symbol": "default_function_kernel0",
412+
"global_symbol": "main",
414413
"target": T.target(
415414
{
416415
"host": {"keys": ["cpu"], "kind": "llvm", "tag": ""},
@@ -448,34 +447,6 @@ def main(A: T.handle("float32", "global"), B: T.handle("float32", "global")):
448447

449448
tvm.ir.assert_structural_equal(after, Expected)
450449

451-
@T.prim_func
452-
def warp_shuffle(A: T.Buffer([32], "float32"), B: T.Buffer([32], "float32")) -> None:
453-
for i in range(32):
454-
with T.block("warp_shuffle"):
455-
vi = T.axis.spatial(32, i)
456-
B[vi] = A[vi % 4 * 8 + vi // 4] + T.float32(1)
457-
458-
459-
@tvm.testing.requires_cuda
460-
def test_warp_shuffle():
461-
mod = tvm.IRModule.from_expr(warp_shuffle)
462-
sch = tvm.tir.Schedule(mod["main"])
463-
blk = sch.get_block("warp_shuffle")
464-
i, = sch.get_loops(blk)
465-
io, ii = sch.split(i, [1, 32])
466-
sch.bind(ii, "threadIdx.x")
467-
A_warp = sch.cache_read(blk, 0, "warp")
468-
sch.compute_at(A_warp, io)
469-
sch.bind(sch.get_loops(A_warp)[-1], "threadIdx.x")
470-
B_warp = sch.cache_write(blk, 0, "warp")
471-
sch.reverse_compute_at(B_warp, io)
472-
sch.bind(sch.get_loops(B_warp)[-1], "threadIdx.x")
473-
sch.bind(io, "blockIdx.x")
474-
print(sch.mod["main"].script())
475-
f = tvm.build(sch.mod["main"], target="cuda")
476-
477450

478451
if __name__ == "__main__":
479-
# tvm.testing.main()
480-
test_warp_shuffle_transform()
481-
# test_warp_shuffle()
452+
tvm.testing.main()

0 commit comments

Comments
 (0)