Skip to content

Commit 00df308

Browse files
committed
fixed missing reverse_compute_at
1 parent 93f9fe7 commit 00df308

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

tests/python/unittest/test_mma_16x8x8_4k.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def lambda_b(i, j):
289289

290290
# fetch to C_warp 16 * 8 -> 32 * 4
291291
C_warp = sch.cache_write(block, 0, "warp")
292-
# sch.reverse_compute_at(C_warp, sch.get_loops(block)[0])
292+
sch.reverse_compute_at(C_warp, sch.get_loops(block)[0])
293293
# the reverse_compute_at above is needed to place it under blockIdx.x
294294

295295
sch.transform_layout(
@@ -307,6 +307,7 @@ def lambda_b(i, j):
307307
fused_2 = sch.fuse(f_0, f_3)
308308
sch.bind(fused_1, "threadIdx.x")
309309

310+
310311
block_init_c = sch.decompose_reduction(block, sch.get_loops(block)[1])
311312

312313
block_init_c = sch.get_block("C_init")
@@ -345,6 +346,13 @@ def lambda_b(i, j):
345346
print(f.imported_modules[0].get_source())
346347
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
347348

349+
print("ok")
350+
351+
evaluator = f.time_evaluator(f.entry_name, dev, number=100)
352+
gflops = (N * M * K) * 2 / 1e9
353+
time_ms = evaluator(a, b, c).mean * 1e3
354+
print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
355+
348356

349357
if __name__ == "__main__":
350358
test_integration_matmul()

0 commit comments

Comments
 (0)