Skip to content

Commit 40ecfec

Browse files
author
Yuanjing Shi
authored
[TVMScript] Improve printer for TIR syntax sugar (#9680)
1 parent 404d9cf commit 40ecfec

File tree

4 files changed

+138
-43
lines changed

4 files changed

+138
-43
lines changed

python/tvm/script/tir/special_stmt.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -317,8 +317,7 @@ class BlockReads(SpecialStmt):
317317

318318
def __init__(self):
319319
def reads(
320-
read_regions: Union[BufferSlice, List[BufferSlice]],
321-
*other_regions: BufferSlice,
320+
*read_regions: Union[BufferSlice, List[BufferSlice]],
322321
span: Span = None,
323322
):
324323
assert self.context, "call 'exit_scope' before 'enter_scope'"
@@ -335,16 +334,18 @@ def reads(
335334
+ str(", ".join(str(x) for x in block_scope.reads)),
336335
span,
337336
)
338-
if isinstance(read_regions, BufferSlice):
339-
read_regions = [read_regions]
340-
for region in other_regions:
341-
read_regions.append(region)
342-
if not isinstance(read_regions, list):
343-
self.context.report_error(
344-
"Incorrect input type. "
345-
+ f"Expected BufferSlice or List[BufferSlice], but got {type(read_regions)}",
346-
span,
347-
)
337+
if len(read_regions) > 1:
338+
for read_region in read_regions:
339+
if not isinstance(read_region, BufferSlice):
340+
self.context.report_error(
341+
"Incorrect input type. Expected *BufferSlice or List[BufferSlice],"
342+
+ f" but got {type(read_regions)}",
343+
span,
344+
)
345+
elif len(read_regions) == 1:
346+
if isinstance(read_regions[0], list):
347+
read_regions = read_regions[0]
348+
348349
block_scope.reads = read_regions
349350

350351
super().__init__(reads, def_symbol=False)
@@ -368,8 +369,7 @@ class BlockWrites(SpecialStmt):
368369

369370
def __init__(self):
370371
def writes(
371-
write_region: Union[BufferSlice, List[BufferSlice]],
372-
*other_region: BufferSlice,
372+
*write_regions: Union[BufferSlice, List[BufferSlice]],
373373
span: Span = None,
374374
):
375375
assert self.context, "call 'exit_scope' before 'enter_scope'"
@@ -386,19 +386,18 @@ def writes(
386386
+ str(", ".join(str(x) for x in block_scope.writes)),
387387
span,
388388
)
389-
if isinstance(write_region, list):
390-
pass
391-
elif isinstance(write_region, BufferSlice):
392-
write_region = [write_region]
393-
for region in other_region:
394-
write_region.append(region)
395-
else:
396-
self.context.report_error(
397-
"Incorrect input type. "
398-
+ f"Expected BufferSlice or List[BufferSlice], but got {type(write_region)}",
399-
span,
400-
)
401-
block_scope.writes = write_region
389+
if len(write_regions) > 1:
390+
for write_region in write_regions:
391+
if not isinstance(write_region, BufferSlice):
392+
self.context.report_error(
393+
"Incorrect input type. Expected *BufferSlice or List[BufferSlice],"
394+
+ f" but got {type(write_regions)}",
395+
span,
396+
)
397+
elif len(write_regions) == 1:
398+
if isinstance(write_regions[0], list):
399+
write_regions = write_regions[0]
400+
block_scope.writes = write_regions
402401

403402
super().__init__(writes, def_symbol=False)
404403

src/printer/tvmscript_printer.cc

Lines changed: 104 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
206206
Doc PrintBlockVarRemaps();
207207
Doc PrintBlockVars(const BlockRealizeNode* op);
208208
Doc PrintBlockAttr(const BlockRealizeNode* op);
209+
Doc PrintExpandedArray(const ArrayNode* op);
209210
Doc PrintBlockBody(const BlockNode* op);
210211
virtual Doc PrintBlockName(const BlockNode* block_op);
211212
Doc PrintBufferRegion(const BufferRegionNode* op);
@@ -220,6 +221,13 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
220221
Doc AllocBuf(const Buffer& buffer);
221222
void TryDeallocVar(const Var& var);
222223
bool ContainsOptionalInfo(const Stmt& stmt);
224+
/*!
225+
* \brief check if a buffer declaration has only 'shape' and 'dtype' arguments specified
226+
* \param buffer The match buffer to be checked
227+
*/
228+
bool IsSimpleBuffer(const Buffer& buffer);
229+
Doc PrintInlineBufferBind(const Buffer& buffer);
230+
Doc PrintTuple(const ArrayNode* op);
223231

224232
/*! Helper functions for loop printing. */
225233
/*!
@@ -404,7 +412,7 @@ Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) {
404412
if (buf->offset_factor != 1 || print_factor_explicitly) {
405413
doc << ", offset_factor=" << buf->offset_factor;
406414
}
407-
if (buf->buffer_type != 1) {
415+
if (buf->buffer_type != BufferType::kDefault) {
408416
doc << ", type=" << Doc::StrLiteral("auto");
409417
}
410418
return doc;
@@ -471,6 +479,60 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {
471479
return doc;
472480
}
473481

482+
// check if all arguments, except the first two, are specified for T.match_buffer
483+
// if not, then this match buffer is printed out as T.buffer in prim_func arguments
484+
bool TVMScriptPrinter::IsSimpleBuffer(const Buffer& buf) {
485+
if (memo_var_.find(buf->data) != memo_var_.end()) {
486+
return false;
487+
}
488+
if (!buf->strides.empty()) {
489+
return false;
490+
}
491+
if (buf->elem_offset->IsInstance<VarNode>()) {
492+
return false;
493+
} else if (buf->elem_offset->IsInstance<IntImmNode>()) {
494+
IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
495+
if (elem_offset->value != 0) {
496+
return false;
497+
}
498+
}
499+
if (buf.scope() != "global") {
500+
return false;
501+
}
502+
if (buf->data_alignment != runtime::kAllocAlignment) {
503+
return false;
504+
}
505+
if (buf->offset_factor != 1) {
506+
return false;
507+
}
508+
if (buf->buffer_type != BufferType::kDefault) {
509+
return false;
510+
}
511+
return true;
512+
}
513+
514+
Doc TVMScriptPrinter::PrintInlineBufferBind(const Buffer& buffer) {
515+
Doc doc;
516+
doc << tir_prefix_ << ".Buffer[" << PrintTuple(buffer->shape.as<ArrayNode>());
517+
doc << ", " << PrintDType(buffer->dtype) << "]";
518+
return doc;
519+
}
520+
521+
// print array out as tuple with parentheses
522+
Doc TVMScriptPrinter::PrintTuple(const ArrayNode* op) {
523+
Doc doc;
524+
doc << '(';
525+
for (size_t i = 0; i < op->size(); ++i) {
526+
if (i != 0) {
527+
doc << ", ";
528+
}
529+
doc << Print(op->at(i));
530+
}
531+
if (op->size() == 1) doc << ",";
532+
doc << ')';
533+
return doc;
534+
}
535+
474536
Doc TVMScriptPrinter::PrintCommReducer(const CommReducerNode* op) {
475537
Doc doc;
476538
int n_var = static_cast<int>(op->rhs.size());
@@ -1095,8 +1157,10 @@ Doc TVMScriptPrinter::PrintBlockAttr(const BlockRealizeNode* op) {
10951157
if (!is_one(op->predicate)) {
10961158
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".where(" << Print(op->predicate) << ")";
10971159
}
1098-
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".reads(" << Print(block_op->reads) << ")";
1099-
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".writes(" << Print(block_op->writes) << ")";
1160+
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".reads("
1161+
<< PrintExpandedArray(block_op->reads.as<ArrayNode>()) << ")";
1162+
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".writes("
1163+
<< PrintExpandedArray(block_op->writes.as<ArrayNode>()) << ")";
11001164
if (!block_op->annotations.empty()) {
11011165
block_attr_doc << Doc::NewLine() << tir_prefix_ << ".block_attr({";
11021166
block_attr_doc << PrintAnnotations(block_op->annotations);
@@ -1105,6 +1169,19 @@ Doc TVMScriptPrinter::PrintBlockAttr(const BlockRealizeNode* op) {
11051169
return block_attr_doc;
11061170
}
11071171

1172+
// This function is to make sure arguments of T.reads() and T.writes() is not parsed by printer as a
1173+
// List. Therefore the brackets are removed before and after printing arguments out
1174+
Doc TVMScriptPrinter::PrintExpandedArray(const ArrayNode* op) {
1175+
Doc doc;
1176+
for (size_t i = 0; i < op->size(); ++i) {
1177+
if (i != 0) {
1178+
doc << ", ";
1179+
}
1180+
doc << Print(op->at(i));
1181+
}
1182+
return doc;
1183+
}
1184+
11081185
Doc TVMScriptPrinter::PrintBlockBody(const BlockNode* op) {
11091186
Doc body;
11101187
for (const auto& alloc_buf : op->alloc_buffers) {
@@ -1218,8 +1295,21 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
12181295
doc << "def " << (func2var_.find(op) == func2var_.end() ? "func" : func2var_[op]->name_hint)
12191296
<< "(";
12201297
std::vector<Doc> params;
1298+
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> simple_buf;
12211299
for (const auto& param : op->params) {
12221300
var_not_in_headers_.insert(param.get());
1301+
auto it = op->buffer_map.find(param);
1302+
// check if this param is a T.handle
1303+
if (it != op->buffer_map.end()) {
1304+
// check if this match_buffer has only the first two arguments specified
1305+
const Buffer& buf = (*it).second;
1306+
if (IsSimpleBuffer(buf)) {
1307+
simple_buf.insert(buf);
1308+
buf_not_in_headers_.insert(buf.get());
1309+
params.push_back(Print(buf) << ": " << PrintInlineBufferBind(buf));
1310+
continue;
1311+
}
1312+
}
12231313
params.push_back(Print(param) << ": " << Print(GetType(param)));
12241314
}
12251315
doc << PrintSep(params, Doc::Text(", ")) << ") -> " << Print(primFunc->ret_type) << ":";
@@ -1229,9 +1319,11 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
12291319
for (const auto& param : op->params) {
12301320
auto it = op->buffer_map.find(param);
12311321
if (it == op->buffer_map.end()) continue;
1232-
buf_not_in_headers_.insert((*it).second.get());
1233-
body << Print((*it).second) << " = " << tir_prefix_ << ".match_buffer(";
1234-
body << Print((*it).first) << ", " << memo_buf_decl_[(*it).second];
1322+
const Buffer& buf = (*it).second;
1323+
if (simple_buf.count(buf)) continue;
1324+
buf_not_in_headers_.insert(buf.get());
1325+
body << Print(buf) << " = " << tir_prefix_ << ".match_buffer(";
1326+
body << Print((*it).first) << ", " << memo_buf_decl_[buf];
12351327
body << ")" << Doc::NewLine();
12361328
}
12371329
// print body
@@ -1392,8 +1484,12 @@ Doc TVMScriptPrinter::PrintAnnotations(const Map<String, ObjectRef>& annotations
13921484
Doc TVMScriptPrinter::PrintLoop(const For& loop) {
13931485
Doc res;
13941486
res << "for " << Print(loop->loop_var) << " in " << tir_prefix_
1395-
<< "." + std::string(ForKind2String(loop->kind)) + "(" << Print(loop->min) << ", "
1396-
<< Print(loop->min + loop->extent);
1487+
<< "." + std::string(ForKind2String(loop->kind)) + "(";
1488+
if (is_zero(loop->min)) {
1489+
res << Print(loop->extent);
1490+
} else {
1491+
res << Print(loop->min) << ", " << Print(loop->min + loop->extent);
1492+
}
13971493
if (loop->thread_binding.defined()) {
13981494
res << ", thread=";
13991495
res << Print(loop->thread_binding.value()->thread_tag);

tests/python/unittest/test_tvmscript_error_report.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -544,10 +544,10 @@ def test_reorder_fail_nested_loop_inner():
544544
with pytest.raises(tvm.tir.ScheduleError) as execinfo:
545545
sch.reorder(k, i)
546546
expected_sub_error_message = (
547-
" for i in T.serial(0, 128):\n"
547+
" for i in T.serial(128):\n"
548548
" # tir.For#0\n"
549-
" for j in T.serial(0, 128):\n"
550-
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
549+
" for j in T.serial(128):\n"
550+
" ^^^^^^^^^^^^^^^^^^^^^^^\n"
551551
)
552552
assert expected_sub_error_message in str(execinfo.value)
553553

@@ -560,9 +560,9 @@ def test_fuse_fail_nested_loop_outer():
560560
sch.fuse(k, i)
561561
expected_sub_error_message = (
562562
" # tir.For#1\n"
563-
" for i in T.serial(0, 128):\n"
564-
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
565-
" for j in T.serial(0, 128):\n"
563+
" for i in T.serial(128):\n"
564+
" ^^^^^^^^^^^^^^^^^^^^^^^\n"
565+
" for j in T.serial(128):\n"
566566
)
567567
assert expected_sub_error_message in str(execinfo.value)
568568

tests/python/unittest/test_tvmscript_syntax_sugar.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ def elementwise_handle(
118118
# match buffer - use buffer with kwargs
119119
@T.prim_func
120120
def elementwise_buffer_kwargs(
121-
a: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=None),
122-
b: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=None),
121+
a: T.Buffer(shape=(128, 128, 128, 128), dtype="float32"),
122+
b: T.Buffer(shape=(128, 128, 128, 128), dtype="float32"),
123123
) -> None:
124124
for i, j, k, l in T.grid(128, 128, 128, 128):
125125
with T.block("B"):

0 commit comments

Comments
 (0)