Skip to content

Commit 7fd9cd3

Browse files
committed
[TVMScript] Default to T.Buffer than T.buffer_decl
TVMScript parser supports both `T.Buffer` and `T.buffer_decl` interchangeably, which share the same semantics in TIR AST. However, `T.buffer_decl` is usually confused with `T.decl_buffer`. To clarify the semantics, we decide to print `T.Buffer` instead. Note that this PR is backward compatible with the previous behavior, i.e. the parser still parses TVMScript with `T.decl_buffer`, and the only difference is the print now produces `T.Buffer` instead by default.
1 parent fd3f803 commit 7fd9cd3

32 files changed

+607
-631
lines changed

python/tvm/script/ir_builder/tir/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@
1717
"""Package tvm.script.ir_builder.tir"""
1818
from .ir import * # pylint: disable=wildcard-import,redefined-builtin
1919
from .ir import boolean as bool # pylint: disable=redefined-builtin
20+
from .ir import buffer_decl as Buffer

python/tvm/script/parser/tir/entry.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class BufferProxy:
5555
def __call__(
5656
self,
5757
shape,
58-
dtype=None,
58+
dtype="float32",
5959
data=None,
6060
strides=None,
6161
elem_offset=None,
@@ -65,8 +65,6 @@ def __call__(
6565
buffer_type="",
6666
axis_separators=None,
6767
) -> Buffer:
68-
if dtype is None:
69-
raise ValueError("Data type must be specified when constructing buffer")
7068
return buffer_decl(
7169
shape,
7270
dtype=dtype,

src/script/printer/tir/buffer.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) //
209209
if (!d->IsVarDefined(buffer)) {
210210
if (Optional<Frame> opt_f = FindLowestVarDef(buffer, d)) {
211211
ExprDoc lhs = DefineBuffer(buffer, opt_f.value(), d);
212-
ExprDoc rhs = BufferDecl(buffer, "buffer_decl", // TODO(@junrushao): name confusing
213-
{}, p, opt_f.value(), d);
212+
ExprDoc rhs = BufferDecl(buffer, "Buffer", {}, p, opt_f.value(), d);
214213
opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
215214
}
216215
}

src/script/printer/tir/ir.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
3434
} else if (dtype == DataType::Bool()) {
3535
return LiteralDoc::Boolean(imm->value, imm_p->Attr("value"));
3636
} else {
37-
return TIR(d, runtime::DLDataType2String(dtype)) //
38-
->Call({LiteralDoc::Int(imm->value, imm_p->Attr("value"))});
37+
return TIR(d, DType2Str(dtype))->Call({LiteralDoc::Int(imm->value, imm_p->Attr("value"))});
3938
}
4039
});
4140

@@ -45,7 +44,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
4544
if (dtype == d->cfg->float_dtype) {
4645
return LiteralDoc::Float(imm->value, imm_p->Attr("value"));
4746
} else {
48-
return TIR(d, runtime::DLDataType2String(dtype)) //
47+
return TIR(d, DType2Str(dtype))
4948
->Call({LiteralDoc::Float(imm->value, imm_p->Attr("value"))});
5049
}
5150
});
@@ -61,8 +60,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
6160

6261
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
6362
.set_dispatch<PrimType>("", [](PrimType ty, ObjectPath p, IRDocsifier d) -> Doc {
64-
std::string dtype = ty->dtype.is_void() ? "void" : runtime::DLDataType2String(ty->dtype);
65-
return TIR(d, dtype);
63+
return TIR(d, DType2Str(ty->dtype));
6664
});
6765

6866
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)

src/script/printer/utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Fra
6565
return DocToPythonScript(StmtBlockDoc(f->stmts), cfg);
6666
}
6767

68+
inline std::string DType2Str(const runtime::DataType& dtype) {
69+
return dtype.is_void() ? "void" : runtime::DLDataType2String(dtype);
70+
}
71+
6872
} // namespace printer
6973
} // namespace script
7074
} // namespace tvm

tests/python/contrib/test_ethosu/test_copy_compute_reordering.py

Lines changed: 73 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,16 @@ class AllOperatorsWithWeights:
2929
def main() -> None:
3030
# function attr dict
3131
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
32-
buffer1 = T.buffer_decl([8192], "int8")
33-
buffer2 = T.buffer_decl([128], "uint8")
34-
buffer3 = T.buffer_decl([32], "uint8")
35-
buffer4 = T.buffer_decl([112], "uint8")
36-
buffer5 = T.buffer_decl([32], "uint8")
37-
buffer6 = T.buffer_decl([112], "uint8")
38-
buffer7 = T.buffer_decl([32], "uint8")
39-
buffer8 = T.buffer_decl([112], "uint8")
40-
buffer9 = T.buffer_decl([32], "uint8")
41-
buffer10 = T.buffer_decl([2048], "int8")
32+
buffer1 = T.Buffer([8192], "int8")
33+
buffer2 = T.Buffer([128], "uint8")
34+
buffer3 = T.Buffer([32], "uint8")
35+
buffer4 = T.Buffer([112], "uint8")
36+
buffer5 = T.Buffer([32], "uint8")
37+
buffer6 = T.Buffer([112], "uint8")
38+
buffer7 = T.Buffer([32], "uint8")
39+
buffer8 = T.Buffer([112], "uint8")
40+
buffer9 = T.Buffer([32], "uint8")
41+
buffer10 = T.Buffer([2048], "int8")
4242
# body
4343
p1 = T.decl_buffer([128], "uint8")
4444
p2 = T.decl_buffer([112], "uint8")
@@ -77,16 +77,16 @@ class ReferenceModule:
7777
def main() -> None:
7878
# function attr dict
7979
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
80-
buffer1 = T.buffer_decl([8192], "int8")
81-
buffer2 = T.buffer_decl([128], "uint8")
82-
buffer3 = T.buffer_decl([32], "uint8")
83-
buffer4 = T.buffer_decl([112], "uint8")
84-
buffer5 = T.buffer_decl([32], "uint8")
85-
buffer6 = T.buffer_decl([112], "uint8")
86-
buffer7 = T.buffer_decl([32], "uint8")
87-
buffer8 = T.buffer_decl([112], "uint8")
88-
buffer9 = T.buffer_decl([32], "uint8")
89-
buffer10 = T.buffer_decl([2048], "int8")
80+
buffer1 = T.Buffer([8192], "int8")
81+
buffer2 = T.Buffer([128], "uint8")
82+
buffer3 = T.Buffer([32], "uint8")
83+
buffer4 = T.Buffer([112], "uint8")
84+
buffer5 = T.Buffer([32], "uint8")
85+
buffer6 = T.Buffer([112], "uint8")
86+
buffer7 = T.Buffer([32], "uint8")
87+
buffer8 = T.Buffer([112], "uint8")
88+
buffer9 = T.Buffer([32], "uint8")
89+
buffer10 = T.Buffer([2048], "int8")
9090
# body
9191
p1 = T.decl_buffer([128], "uint8")
9292
p2 = T.decl_buffer([112], "uint8")
@@ -123,16 +123,16 @@ class ReferenceModule:
123123
def main() -> None:
124124
# function attr dict
125125
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
126-
buffer1 = T.buffer_decl([8192], "int8")
127-
buffer2 = T.buffer_decl([128], "uint8")
128-
buffer3 = T.buffer_decl([32], "uint8")
129-
buffer4 = T.buffer_decl([112], "uint8")
130-
buffer5 = T.buffer_decl([32], "uint8")
131-
buffer6 = T.buffer_decl([112], "uint8")
132-
buffer7 = T.buffer_decl([32], "uint8")
133-
buffer8 = T.buffer_decl([112], "uint8")
134-
buffer9 = T.buffer_decl([32], "uint8")
135-
buffer10 = T.buffer_decl([2048], "int8")
126+
buffer1 = T.Buffer([8192], "int8")
127+
buffer2 = T.Buffer([128], "uint8")
128+
buffer3 = T.Buffer([32], "uint8")
129+
buffer4 = T.Buffer([112], "uint8")
130+
buffer5 = T.Buffer([32], "uint8")
131+
buffer6 = T.Buffer([112], "uint8")
132+
buffer7 = T.Buffer([32], "uint8")
133+
buffer8 = T.Buffer([112], "uint8")
134+
buffer9 = T.Buffer([32], "uint8")
135+
buffer10 = T.Buffer([2048], "int8")
136136
# body
137137
p1 = T.decl_buffer([128], "uint8")
138138
p2 = T.decl_buffer([112], "uint8")
@@ -167,8 +167,8 @@ class AllOperatorsWithoutWeights:
167167
@T.prim_func
168168
def main() -> None:
169169
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
170-
buffer1 = T.buffer_decl([36], "int8")
171-
buffer2 = T.buffer_decl([9], "int8")
170+
buffer1 = T.Buffer([36], "int8")
171+
buffer2 = T.Buffer([9], "int8")
172172
# body
173173
p1 = T.decl_buffer([96], "int8")
174174
T.evaluate(T.call_extern("ethosu_pooling", "int8", 3, 4, 3, 3, 0, 4, buffer1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 12, 3, 1, "int8", 3, 2, 3, 3, 0, 2, p1[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 32, 16, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
@@ -189,11 +189,11 @@ class OperatorsWithAndWithoutWeights:
189189
@T.prim_func
190190
def main() -> None:
191191
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
192-
buffer1 = T.buffer_decl([97156], "int8")
193-
buffer2 = T.buffer_decl([80], "uint8")
194-
buffer3 = T.buffer_decl([64], "uint8")
195-
buffer4 = T.buffer_decl([96], "uint8")
196-
buffer5 = T.buffer_decl([32], "uint8")
192+
buffer1 = T.Buffer([97156], "int8")
193+
buffer2 = T.Buffer([80], "uint8")
194+
buffer3 = T.Buffer([64], "uint8")
195+
buffer4 = T.Buffer([96], "uint8")
196+
buffer5 = T.Buffer([32], "uint8")
197197
# body
198198
p1 = T.decl_buffer([390336], "int8")
199199
p2 = T.decl_buffer([80], "uint8")
@@ -224,11 +224,11 @@ class ReferenceModule:
224224
@T.prim_func
225225
def main() -> None:
226226
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
227-
buffer1 = T.buffer_decl([97156], "int8")
228-
buffer2 = T.buffer_decl([80], "uint8")
229-
buffer3 = T.buffer_decl([64], "uint8")
230-
buffer4 = T.buffer_decl([96], "uint8")
231-
buffer5 = T.buffer_decl([32], "uint8")
227+
buffer1 = T.Buffer([97156], "int8")
228+
buffer2 = T.Buffer([80], "uint8")
229+
buffer3 = T.Buffer([64], "uint8")
230+
buffer4 = T.Buffer([96], "uint8")
231+
buffer5 = T.Buffer([32], "uint8")
232232
# body
233233
p1 = T.decl_buffer([390336], "int8")
234234
p2 = T.decl_buffer([80], "uint8")
@@ -257,11 +257,11 @@ class ReferenceModule:
257257
@T.prim_func
258258
def main() -> None:
259259
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
260-
buffer1 = T.buffer_decl([97156], "int8")
261-
buffer2 = T.buffer_decl([80], "uint8")
262-
buffer3 = T.buffer_decl([64], "uint8")
263-
buffer4 = T.buffer_decl([96], "uint8")
264-
buffer5 = T.buffer_decl([32], "uint8")
260+
buffer1 = T.Buffer([97156], "int8")
261+
buffer2 = T.Buffer([80], "uint8")
262+
buffer3 = T.Buffer([64], "uint8")
263+
buffer4 = T.Buffer([96], "uint8")
264+
buffer5 = T.Buffer([32], "uint8")
265265
# body
266266
p1 = T.decl_buffer([390336], "int8")
267267
p2 = T.decl_buffer([80], "uint8")
@@ -289,14 +289,14 @@ class CopyToBufferWithLocalScope:
289289
@T.prim_func
290290
def main() -> None:
291291
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
292-
buffer1 = T.buffer_decl([64], "uint8")
293-
buffer2 = T.buffer_decl([48], "uint8")
294-
buffer3 = T.buffer_decl([48], "uint8")
295-
buffer4 = T.buffer_decl([256], "uint8")
296-
buffer5 = T.buffer_decl([16], "uint8")
297-
buffer6 = T.buffer_decl([48], "uint8")
298-
buffer7 = T.buffer_decl([256], "uint8")
299-
buffer8 = T.buffer_decl([64], "uint8")
292+
buffer1 = T.Buffer([64], "uint8")
293+
buffer2 = T.Buffer([48], "uint8")
294+
buffer3 = T.Buffer([48], "uint8")
295+
buffer4 = T.Buffer([256], "uint8")
296+
buffer5 = T.Buffer([16], "uint8")
297+
buffer6 = T.Buffer([48], "uint8")
298+
buffer7 = T.Buffer([256], "uint8")
299+
buffer8 = T.Buffer([64], "uint8")
300300
# body
301301
p1 = T.decl_buffer([48], "uint8")
302302
p2 = T.decl_buffer([48], "uint8")
@@ -330,14 +330,14 @@ class ReferenceModule:
330330
@T.prim_func
331331
def main() -> None:
332332
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
333-
buffer1 = T.buffer_decl([64], "uint8")
334-
buffer2 = T.buffer_decl([48], "uint8")
335-
buffer3 = T.buffer_decl([48], "uint8")
336-
buffer4 = T.buffer_decl([256], "uint8")
337-
buffer5 = T.buffer_decl([16], "uint8")
338-
buffer6 = T.buffer_decl([48], "uint8")
339-
buffer7 = T.buffer_decl([256], "uint8")
340-
buffer8 = T.buffer_decl([64], "uint8")
333+
buffer1 = T.Buffer([64], "uint8")
334+
buffer2 = T.Buffer([48], "uint8")
335+
buffer3 = T.Buffer([48], "uint8")
336+
buffer4 = T.Buffer([256], "uint8")
337+
buffer5 = T.Buffer([16], "uint8")
338+
buffer6 = T.Buffer([48], "uint8")
339+
buffer7 = T.Buffer([256], "uint8")
340+
buffer8 = T.Buffer([64], "uint8")
341341
# body
342342
p1 = T.decl_buffer([48], "uint8")
343343
p2 = T.decl_buffer([48], "uint8")
@@ -406,11 +406,11 @@ class ReferenceModule:
406406
@T.prim_func
407407
def main() -> None:
408408
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
409-
buffer1 = T.buffer_decl([97156], "int8")
410-
buffer2 = T.buffer_decl([80], "uint8")
411-
buffer3 = T.buffer_decl([64], "uint8")
412-
buffer4 = T.buffer_decl([96], "uint8")
413-
buffer5 = T.buffer_decl([32], "uint8")
409+
buffer1 = T.Buffer([97156], "int8")
410+
buffer2 = T.Buffer([80], "uint8")
411+
buffer3 = T.Buffer([64], "uint8")
412+
buffer4 = T.Buffer([96], "uint8")
413+
buffer5 = T.Buffer([32], "uint8")
414414
# body
415415
p1 = T.decl_buffer([390336], "int8")
416416
p2 = T.decl_buffer([80], "uint8")
@@ -439,11 +439,11 @@ class ReferenceModule:
439439
@T.prim_func
440440
def main() -> None:
441441
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
442-
buffer1 = T.buffer_decl([97156], "int8")
443-
buffer2 = T.buffer_decl([80], "uint8")
444-
buffer3 = T.buffer_decl([64], "uint8")
445-
buffer4 = T.buffer_decl([96], "uint8")
446-
buffer5 = T.buffer_decl([32], "uint8")
442+
buffer1 = T.Buffer([97156], "int8")
443+
buffer2 = T.Buffer([80], "uint8")
444+
buffer3 = T.Buffer([64], "uint8")
445+
buffer4 = T.Buffer([96], "uint8")
446+
buffer5 = T.Buffer([32], "uint8")
447447
# body
448448
p1 = T.decl_buffer([390336], "int8")
449449
p2 = T.decl_buffer([80], "uint8")

0 commit comments

Comments
 (0)