Skip to content

Commit fa3b34c

Browse files
authored
[TVMScript] Fix printing ForNode annotations (#8891)
1 parent 2f1c845 commit fa3b34c

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

src/printer/tvmscript_printer.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1069,7 +1069,7 @@ Doc TVMScriptPrinter::PrintLoop(const For& loop) {
10691069
res << Print(loop->thread_binding.value()->thread_tag);
10701070
}
10711071
if (!loop->annotations.empty()) {
1072-
res << ", annotation = {";
1072+
res << ", annotations = {";
10731073
res << PrintAnnotations(loop->annotations);
10741074
res << "}";
10751075
}

tests/python/unittest/test_tvmscript_roundtrip.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2803,7 +2803,9 @@ def for_thread_binding(a: ty.handle, b: ty.handle) -> None:
28032803
B = tir.match_buffer(b, (16, 16), "float32")
28042804

28052805
for i in tir.thread_binding(0, 16, thread="threadIdx.x"):
2806-
for j in tir.thread_binding(0, 16, thread="threadIdx.y"):
2806+
for j in tir.thread_binding(
2807+
0, 16, thread="threadIdx.y", annotations={"attr_key": "attr_value"}
2808+
):
28072809
A[i, j] = B[i, j] + tir.float32(1)
28082810

28092811

@@ -2818,6 +2820,7 @@ def test_for_thread_binding():
28182820
assert isinstance(rt_func.body.body, tir.stmt.For)
28192821
assert rt_func.body.body.kind == 4
28202822
assert rt_func.body.body.thread_binding.thread_tag == "threadIdx.y"
2823+
assert rt_func.body.body.annotations["attr_key"] == "attr_value"
28212824

28222825

28232826
@tvm.script.tir

0 commit comments

Comments
 (0)