
Commit c2d0744

16x8x16 trans working

1 parent c2e314c

3 files changed: +405 −24 lines

tests/python/unittest/test_mma_16x8x16_4k_tune_simple.py

Lines changed: 14 additions & 14 deletions
@@ -411,20 +411,20 @@ def tile_wmma_fragment(block_read, height):
 
 f = tvm.build(sch.mod["main"], target="cuda", name="dense")
 
-# dev = tvm.device("cuda", 0)
-# a_np = np.random.uniform(size=(N, K)).astype("float16")
-# b_np = np.random.uniform(size=(K, M)).astype("float16")
-# c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
-# a = tvm.nd.array(a_np, dev)
-# b = tvm.nd.array(b_np, dev)
-# c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
-
-# print(f.imported_modules[0].get_source())
-# f(a, b, c)
+dev = tvm.device("cuda", 0)
+a_np = np.random.uniform(size=(N, K)).astype("float16")
+b_np = np.random.uniform(size=(K, M)).astype("float16")
+c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
+a = tvm.nd.array(a_np, dev)
+b = tvm.nd.array(b_np, dev)
+c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
+
+print(f.imported_modules[0].get_source())
+f(a, b, c)
 # tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
 # print("ok")
 
-# evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
-# gflops = (N * M * K) * 2 / 1e9
-# time_ms = evaluator(a, b, c).mean * 1e3
-# print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
+evaluator = f.time_evaluator(f.entry_name, dev, number=10)
+gflops = (N * M * K) * 2 / 1e9
+time_ms = evaluator(a, b, c).mean * 1e3
+print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))

tests/python/unittest/test_mma_16x8x16_simple.py

Lines changed: 11 additions & 10 deletions
@@ -317,16 +317,17 @@ def fetch_to_shared(block, idx):
 target = "cuda"
 
 f = tvm.build(sch.mod["main"], target=target, name="dense")
-# dev = tvm.device(target, 0)
+dev = tvm.device(target, 0)
 
-# a_np = np.random.uniform(size=(16, K)).astype("float16")
-# b_np = np.random.uniform(size=(K, K)).astype("float16")
-# c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
+a_np = np.random.uniform(size=(16, K)).astype("float16")
+b_np = np.random.uniform(size=(K, K)).astype("float16")
+c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
 
-# a = tvm.nd.array(a_np, dev)
-# b = tvm.nd.array(b_np, dev)
-# c = tvm.nd.array(np.zeros((16, K), dtype="float32"), dev)
+a = tvm.nd.array(a_np, dev)
+b = tvm.nd.array(b_np, dev)
+c = tvm.nd.array(np.zeros((16, K), dtype="float32"), dev)
 
-# # print(f.imported_modules[0].get_source())
-# f(a, b, c)
-# tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+# print(f.imported_modules[0].get_source())
+f(a, b, c)
+tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+print("ok")
