@@ -206,6 +206,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
206206 Doc PrintBlockVarRemaps ();
207207 Doc PrintBlockVars (const BlockRealizeNode* op);
208208 Doc PrintBlockAttr (const BlockRealizeNode* op);
209+ Doc PrintExpandedArray (const ArrayNode* op);
209210 Doc PrintBlockBody (const BlockNode* op);
210211 virtual Doc PrintBlockName (const BlockNode* block_op);
211212 Doc PrintBufferRegion (const BufferRegionNode* op);
@@ -220,6 +221,13 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
220221 Doc AllocBuf (const Buffer& buffer);
221222 void TryDeallocVar (const Var& var);
222223 bool ContainsOptionalInfo (const Stmt& stmt);
224+ /* !
225+ * \brief check if a buffer declaration has only 'shape' and 'dtype' arguments specified
226+ * \param buffer The match buffer to be checked
227+ */
228+ bool IsSimpleBuffer (const Buffer& buffer);
229+ Doc PrintInlineBufferBind (const Buffer& buffer);
230+ Doc PrintTuple (const ArrayNode* op);
223231
224232 /* ! Helper functions for loop printing. */
225233 /* !
@@ -404,7 +412,7 @@ Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) {
404412 if (buf->offset_factor != 1 || print_factor_explicitly) {
405413 doc << " , offset_factor=" << buf->offset_factor ;
406414 }
407- if (buf->buffer_type != 1 ) {
415+ if (buf->buffer_type != BufferType:: kDefault ) {
408416 doc << " , type=" << Doc::StrLiteral (" auto" );
409417 }
410418 return doc;
@@ -471,6 +479,60 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {
471479 return doc;
472480}
473481
482+ // check if all arguments, except the first two, are specified for T.match_buffer
483+ // if not, then this match buffer is printed out as T.buffer in prim_func arguments
484+ bool TVMScriptPrinter::IsSimpleBuffer (const Buffer& buf) {
485+ if (memo_var_.find (buf->data ) != memo_var_.end ()) {
486+ return false ;
487+ }
488+ if (!buf->strides .empty ()) {
489+ return false ;
490+ }
491+ if (buf->elem_offset ->IsInstance <VarNode>()) {
492+ return false ;
493+ } else if (buf->elem_offset ->IsInstance <IntImmNode>()) {
494+ IntImm elem_offset = Downcast<IntImm>(buf->elem_offset );
495+ if (elem_offset->value != 0 ) {
496+ return false ;
497+ }
498+ }
499+ if (buf.scope () != " global" ) {
500+ return false ;
501+ }
502+ if (buf->data_alignment != runtime::kAllocAlignment ) {
503+ return false ;
504+ }
505+ if (buf->offset_factor != 1 ) {
506+ return false ;
507+ }
508+ if (buf->buffer_type != BufferType::kDefault ) {
509+ return false ;
510+ }
511+ return true ;
512+ }
513+
514+ Doc TVMScriptPrinter::PrintInlineBufferBind (const Buffer& buffer) {
515+ Doc doc;
516+ doc << tir_prefix_ << " .Buffer[" << PrintTuple (buffer->shape .as <ArrayNode>());
517+ doc << " , " << PrintDType (buffer->dtype ) << " ]" ;
518+ return doc;
519+ }
520+
521+ // print array out as tuple with parentheses
522+ Doc TVMScriptPrinter::PrintTuple (const ArrayNode* op) {
523+ Doc doc;
524+ doc << ' (' ;
525+ for (size_t i = 0 ; i < op->size (); ++i) {
526+ if (i != 0 ) {
527+ doc << " , " ;
528+ }
529+ doc << Print (op->at (i));
530+ }
531+ if (op->size () == 1 ) doc << " ," ;
532+ doc << ' )' ;
533+ return doc;
534+ }
535+
474536Doc TVMScriptPrinter::PrintCommReducer (const CommReducerNode* op) {
475537 Doc doc;
476538 int n_var = static_cast <int >(op->rhs .size ());
@@ -1095,8 +1157,10 @@ Doc TVMScriptPrinter::PrintBlockAttr(const BlockRealizeNode* op) {
10951157 if (!is_one (op->predicate )) {
10961158 block_attr_doc << Doc::NewLine () << tir_prefix_ << " .where(" << Print (op->predicate ) << " )" ;
10971159 }
1098- block_attr_doc << Doc::NewLine () << tir_prefix_ << " .reads(" << Print (block_op->reads ) << " )" ;
1099- block_attr_doc << Doc::NewLine () << tir_prefix_ << " .writes(" << Print (block_op->writes ) << " )" ;
1160+ block_attr_doc << Doc::NewLine () << tir_prefix_ << " .reads("
1161+ << PrintExpandedArray (block_op->reads .as <ArrayNode>()) << " )" ;
1162+ block_attr_doc << Doc::NewLine () << tir_prefix_ << " .writes("
1163+ << PrintExpandedArray (block_op->writes .as <ArrayNode>()) << " )" ;
11001164 if (!block_op->annotations .empty ()) {
11011165 block_attr_doc << Doc::NewLine () << tir_prefix_ << " .block_attr({" ;
11021166 block_attr_doc << PrintAnnotations (block_op->annotations );
@@ -1105,6 +1169,19 @@ Doc TVMScriptPrinter::PrintBlockAttr(const BlockRealizeNode* op) {
11051169 return block_attr_doc;
11061170}
11071171
1172+ // This function is to make sure arguments of T.reads() and T.writes() is not parsed by printer as a
1173+ // List. Therefore the brackets are removed before and after printing arguments out
1174+ Doc TVMScriptPrinter::PrintExpandedArray (const ArrayNode* op) {
1175+ Doc doc;
1176+ for (size_t i = 0 ; i < op->size (); ++i) {
1177+ if (i != 0 ) {
1178+ doc << " , " ;
1179+ }
1180+ doc << Print (op->at (i));
1181+ }
1182+ return doc;
1183+ }
1184+
11081185Doc TVMScriptPrinter::PrintBlockBody (const BlockNode* op) {
11091186 Doc body;
11101187 for (const auto & alloc_buf : op->alloc_buffers ) {
@@ -1218,8 +1295,21 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
12181295 doc << " def " << (func2var_.find (op) == func2var_.end () ? " func" : func2var_[op]->name_hint )
12191296 << " (" ;
12201297 std::vector<Doc> params;
1298+ std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> simple_buf;
12211299 for (const auto & param : op->params ) {
12221300 var_not_in_headers_.insert (param.get ());
1301+ auto it = op->buffer_map .find (param);
1302+ // check if this param is a T.handle
1303+ if (it != op->buffer_map .end ()) {
1304+ // check if this match_buffer has only the first two arguments specified
1305+ const Buffer& buf = (*it).second ;
1306+ if (IsSimpleBuffer (buf)) {
1307+ simple_buf.insert (buf);
1308+ buf_not_in_headers_.insert (buf.get ());
1309+ params.push_back (Print (buf) << " : " << PrintInlineBufferBind (buf));
1310+ continue ;
1311+ }
1312+ }
12231313 params.push_back (Print (param) << " : " << Print (GetType (param)));
12241314 }
12251315 doc << PrintSep (params, Doc::Text (" , " )) << " ) -> " << Print (primFunc->ret_type ) << " :" ;
@@ -1229,9 +1319,11 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
12291319 for (const auto & param : op->params ) {
12301320 auto it = op->buffer_map .find (param);
12311321 if (it == op->buffer_map .end ()) continue ;
1232- buf_not_in_headers_.insert ((*it).second .get ());
1233- body << Print ((*it).second ) << " = " << tir_prefix_ << " .match_buffer(" ;
1234- body << Print ((*it).first ) << " , " << memo_buf_decl_[(*it).second ];
1322+ const Buffer& buf = (*it).second ;
1323+ if (simple_buf.count (buf)) continue ;
1324+ buf_not_in_headers_.insert (buf.get ());
1325+ body << Print (buf) << " = " << tir_prefix_ << " .match_buffer(" ;
1326+ body << Print ((*it).first ) << " , " << memo_buf_decl_[buf];
12351327 body << " )" << Doc::NewLine ();
12361328 }
12371329 // print body
@@ -1392,8 +1484,12 @@ Doc TVMScriptPrinter::PrintAnnotations(const Map<String, ObjectRef>& annotations
13921484Doc TVMScriptPrinter::PrintLoop (const For& loop) {
13931485 Doc res;
13941486 res << " for " << Print (loop->loop_var ) << " in " << tir_prefix_
1395- << " ." + std::string (ForKind2String (loop->kind )) + " (" << Print (loop->min ) << " , "
1396- << Print (loop->min + loop->extent );
1487+ << " ." + std::string (ForKind2String (loop->kind )) + " (" ;
1488+ if (is_zero (loop->min )) {
1489+ res << Print (loop->extent );
1490+ } else {
1491+ res << Print (loop->min ) << " , " << Print (loop->min + loop->extent );
1492+ }
13971493 if (loop->thread_binding .defined ()) {
13981494 res << " , thread=" ;
13991495 res << Print (loop->thread_binding .value ()->thread_tag );
0 commit comments