Skip to content

Commit bcf212d

Browse files
committed
seems to work
1 parent dd8ccf9 commit bcf212d

File tree

2 files changed

+5
-27
lines changed

2 files changed

+5
-27
lines changed

python/tvm/script/parser.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -554,14 +554,12 @@ def transform_Assign(self, node):
554554
4.1 var = T.allocate()
555555
"""
556556

557-
print("parsing ", node.rhs.func_name)
558557
if isinstance(node.rhs, ast.Call):
559558
# Pattern 1 & Pattern 4
560559
if isinstance(node.rhs.func_name, ast.Op):
561560
func = None
562561
else:
563562
func = self.transform(node.rhs.func_name)
564-
print(func)
565563

566564
if isinstance(func, WithScopeHandler):
567565
if not func.concise_scope or not func.def_symbol:
@@ -582,27 +580,12 @@ def transform_Assign(self, node):
582580
elif callable(func):
583581
args = [self.transform(arg) for arg in node.rhs.params]
584582
out = func(*args)
585-
print(out)
586-
print(node.lhs)
587583
assert len(out) == len(node.lhs)
588584

589-
lhs_vars = []
590585
for ast_var, value in zip(node.lhs, out):
591-
var = tvm.te.var(
592-
ast_var.id.name,
593-
"int32",
594-
span=tvm_span_from_synr(ast_var.span),
595-
)
596-
self.context.update_symbol(var.name, var, node)
597-
lhs_vars.append(var)
598-
599-
body = self.parse_body(node)
586+
self.context.update_symbol(ast_var.id.name, value, node)
600587

601-
for var, value in reversed(list(zip(lhs_vars, out))):
602-
self.context.remove_symbol(var.name)
603-
body = tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
604-
605-
return body
588+
return self.parse_body(node)
606589

607590
if isinstance(node.rhs, (ast.Call, ast.Constant)):
608591
# Pattern 4 of let binding

tests/python/unittest/test_mma_16x8x16_4k_tune.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,6 @@ def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
3131
T.writes(A_warp[thread_id, y])
3232
A_warp[thread_id, y] = A_shared[v0, v1]
3333

34-
# T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
35-
# A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
36-
# v0, v1
37-
# ]
3834

3935
@T.prim_func
4036
def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
@@ -83,10 +79,9 @@ def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
8379
with T.block("B_shared_warp"):
8480
v0, v1 = T.axis.remap("SS", [ax0, ax1])
8581
T.reads(B_shared[v0, v1])
86-
T.writes(B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
87-
B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = B_shared[
88-
v0, v1
89-
]
82+
thread_id, y = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
83+
T.writes(B_warp[thread_id, y])
84+
B_warp[thread_id, y] = B_shared[v0, v1]
9085

9186

9287
@T.prim_func

0 commit comments

Comments
 (0)