
Commit dd8ccf9

poking with the parser
1 parent 596582c commit dd8ccf9

File tree

2 files changed: +112 -58 lines changed


python/tvm/script/parser.py

Lines changed: 28 additions & 0 deletions
@@ -554,12 +554,14 @@ def transform_Assign(self, node):
                 4.1 var = T.allocate()
         """
 
+        print("parsing ", node.rhs.func_name)
         if isinstance(node.rhs, ast.Call):
             # Pattern 1 & Pattern 4
             if isinstance(node.rhs.func_name, ast.Op):
                 func = None
             else:
                 func = self.transform(node.rhs.func_name)
+                print(func)
 
             if isinstance(func, WithScopeHandler):
                 if not func.concise_scope or not func.def_symbol:
@@ -577,6 +579,31 @@ def transform_Assign(self, node):
                 arg_list = self.parse_arg_list(func, node.rhs)
                 func.handle(node, self.context, arg_list, node.rhs.func_name.span)
                 return self.parse_body(node)
+            elif callable(func):
+                args = [self.transform(arg) for arg in node.rhs.params]
+                out = func(*args)
+                print(out)
+                print(node.lhs)
+                assert len(out) == len(node.lhs)
+
+                lhs_vars = []
+                for ast_var, value in zip(node.lhs, out):
+                    var = tvm.te.var(
+                        ast_var.id.name,
+                        "int32",
+                        span=tvm_span_from_synr(ast_var.span),
+                    )
+                    self.context.update_symbol(var.name, var, node)
+                    lhs_vars.append(var)
+
+                body = self.parse_body(node)
+
+                for var, value in reversed(list(zip(lhs_vars, out))):
+                    self.context.remove_symbol(var.name)
+                    body = tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
+
+                return body
+
         if isinstance(node.rhs, (ast.Call, ast.Constant)):
             # Pattern 4 of let binding
             value = self.transform(node.rhs)
@@ -593,6 +620,7 @@ def transform_Assign(self, node):
             if node.ty is None and hasattr(value, "dtype"):
                 var_ty = value.dtype
             else:
+                print(node.ty, ast_var)
                 var_ty = self.parse_type(node.ty, ast_var)
 
             var = tvm.te.var(
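
For reference, a minimal standalone sketch (not part of this commit) of the structure the new `elif callable(func)` branch builds: each value returned by a plain Python callable on the right-hand side is bound to a fresh int32 variable, and the bindings wrap the parsed body from the inside out as nested tvm.tir.LetStmt nodes. The helper and variable names below are illustrative only.

    import tvm

    def shared_16x16_to_ldmatrix_32x8_layout(i, j):
        # Same mapping as in the test below; any Python callable returning a tuple works.
        thread_id = 4 * (i % 8) + (j % 8) // 2
        return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)

    v0 = tvm.te.var("v0", "int32")
    v1 = tvm.te.var("v1", "int32")
    out = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)

    # Fresh vars for the left-hand side names, as the parser would create them.
    lhs_vars = [tvm.te.var(name, "int32") for name in ("thread_id", "y")]

    body = tvm.tir.Evaluate(0)  # stand-in for the parsed body
    for var, value in reversed(list(zip(lhs_vars, out))):
        body = tvm.tir.LetStmt(var, value, body)

    print(body)  # thread_id is bound outermost, y innermost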

tests/python/unittest/test_mma_16x8x16_4k_tune.py

Lines changed: 84 additions & 58 deletions
@@ -8,6 +8,11 @@
 import numpy as np
 
 
+def shared_16x16_to_ldmatrix_32x8_layout(i, j):
+    thread_id = 4 * (i % 8) + (j % 8) // 2
+    return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)
+
+
 @T.prim_func
 def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
     A_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
@@ -21,11 +26,15 @@ def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
         with T.block("A_shared_warp"):
             v0, v1 = T.axis.remap("SS", [ax0, ax1])
             T.reads(A_shared[v0, v1])
-            T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
-            A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
-                v0, v1
-            ]
 
+            thread_id, y = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
+            T.writes(A_warp[thread_id, y])
+            A_warp[thread_id, y] = A_shared[v0, v1]
+
+            # T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
+            # A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
+            #     v0, v1
+            # ]
 
 @T.prim_func
 def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
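
As a quick side check (not part of the diff): the 16x16 -> 32x8 index map used above should place every element of the 16x16 tile on a distinct (thread id, local index) slot across 32 threads x 8 elements, which a few lines of plain Python confirm.

    def shared_16x16_to_ldmatrix_32x8_layout(i, j):
        thread_id = 4 * (i % 8) + (j % 8) // 2
        return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)

    slots = {shared_16x16_to_ldmatrix_32x8_layout(i, j) for i in range(16) for j in range(16)}
    assert len(slots) == 16 * 16                              # no two elements collide
    assert all(0 <= t < 32 and 0 <= x < 8 for t, x in slots)  # fits the 32x8 warp buffer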
@@ -390,22 +399,39 @@ def tile_wmma_fragment(block_read, height):
         sch.reorder(i0, j0, i1, j1)
         return i1
 
-    def shared_16x16_to_ldmatrix_32x8_layout(i, j):
-        i_0 = i // 16
-        j_0 = j // 16
-
-        i = i % 16
-        j = j % 16
-
-        thread_id = 4 * (i % 8) + (j % 8) // 2
-        return i_0, j_0, thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2
-
     loop_a = tile_wmma_fragment(A_warp, 16)
     loop_b = tile_wmma_fragment(B_warp, 16)
 
-    sch.transform_layout(A_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
-    sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
-    sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)
+    sch.transform_layout(
+        A_warp,
+        0,
+        "write",
+        index_map=lambda i, j: (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
+        ),
+    )
+    sch.transform_layout(
+        B_warp,
+        0,
+        "write",
+        index_map=lambda i, j: (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
+        ),
+    )
+    sch.transform_layout(
+        C_warp,
+        0,
+        "read",
+        index_map=lambda i, j: (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
+        ),
+    )
 
     sch.tensorize(loop_a, "mma.ldmatrix_a")
     sch.tensorize(loop_b, "mma.ldmatrix_b")
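
For orientation (not from the commit): the old code folded the i // 16, j // 16 tiling into the layout helper itself, while the new index_map lambdas keep the helper as a pure 16x16 -> 32x8 map and prepend the tile indices. A plain-Python evaluation, using concrete ints in place of TIR expressions, shows the resulting 4-component index.

    def shared_16x16_to_ldmatrix_32x8_layout(i, j):
        thread_id = 4 * (i % 8) + (j % 8) // 2
        return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)

    index_map = lambda i, j: (i // 16, j // 16, *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16))

    # Element (20, 35) of the warp buffer lands in tile (1, 2), thread 17, slot 1.
    assert index_map(20, 35) == (1, 2, 17, 1)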
@@ -438,44 +464,44 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
 schedule(sch)
 print(sch.mod.script())
 
-if tune:
-    with tempfile.TemporaryDirectory() as work_dir:
-        sch = ms.tune_tir(
-            mod=workload,
-            target=tvm.target.Target("nvidia/geforce-rtx-3070"),
-            config=ms.TuneConfig(
-                strategy="evolutionary",
-                num_trials_per_iter=32,
-                max_trials_per_task=128,
-                max_trials_global=128,
-            ),
-            work_dir=work_dir,
-            space=ms.space_generator.ScheduleFn(schedule),
-        )
-    if sch is None:
-        print("No valid schedule found!")
-    else:
-        print(sch.mod.script())
-        print(sch.trace)
-else:
-    target = "cuda"
-    f = tvm.build(sch.mod["main"], target=target, name="dense")
-
-    dev = tvm.device("cuda", 0)
-    a_np = np.random.uniform(size=(N, K)).astype("float16")
-    b_np = np.random.uniform(size=(K, M)).astype("float16")
-    c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
-    a = tvm.nd.array(a_np, dev)
-    b = tvm.nd.array(b_np, dev)
-    c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
-    f = tvm.build(sch.mod["main"], target="cuda", name="dense")
-
-    print(f.imported_modules[0].get_source())
-    f(a, b, c)
-    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
-    print("ok")
-
-    evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
-    gflops = (N * M * K) * 2 / 1e9
-    time_ms = evaluator(a, b, c).mean * 1e3
-    print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
+# if tune:
+#     with tempfile.TemporaryDirectory() as work_dir:
+#         sch = ms.tune_tir(
+#             mod=workload,
+#             target=tvm.target.Target("nvidia/geforce-rtx-3070"),
+#             config=ms.TuneConfig(
+#                 strategy="evolutionary",
+#                 num_trials_per_iter=32,
+#                 max_trials_per_task=128,
+#                 max_trials_global=128,
+#             ),
+#             work_dir=work_dir,
+#             space=ms.space_generator.ScheduleFn(schedule),
+#         )
+#     if sch is None:
+#         print("No valid schedule found!")
+#     else:
+#         print(sch.mod.script())
+#         print(sch.trace)
+# else:
+#     target = "cuda"
+#     f = tvm.build(sch.mod["main"], target=target, name="dense")
+
+#     dev = tvm.device("cuda", 0)
+#     a_np = np.random.uniform(size=(N, K)).astype("float16")
+#     b_np = np.random.uniform(size=(K, M)).astype("float16")
+#     c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
+#     a = tvm.nd.array(a_np, dev)
+#     b = tvm.nd.array(b_np, dev)
+#     c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
+#     f = tvm.build(sch.mod["main"], target="cuda", name="dense")
+
+#     print(f.imported_modules[0].get_source())
+#     f(a, b, c)
+#     tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+#     print("ok")
+
+#     evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
+#     gflops = (N * M * K) * 2 / 1e9
+#     time_ms = evaluator(a, b, c).mean * 1e3
+#     print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
