From c13b2aeedf7d2ab8e53882dfed0d7e842f4890f7 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Thu, 14 Feb 2019 21:08:28 -0800 Subject: [PATCH 01/73] minor bug fix and beginnings of hypothesis testing --- python/tvm/relay/_parser.py | 3 +++ .../relay/test_ir_parser_printer_roundtrip.py | 26 +++++++++++++++++++ 2 files changed, 29 insertions(+) create mode 100644 tests/python/relay/test_ir_parser_printer_roundtrip.py diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 9fdffab4e62e..de7e2ae24959 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -512,6 +512,9 @@ def make_parser(data): def fromtext(data, source_name=None): # type: (str, str) -> Union[expr.Expr, module.Module] """Parse a Relay program.""" + if data == "": + raise ParseError("Cannot parse the empty string.") + global __source_name_counter__ if source_name is None: diff --git a/tests/python/relay/test_ir_parser_printer_roundtrip.py b/tests/python/relay/test_ir_parser_printer_roundtrip.py new file mode 100644 index 000000000000..bb1f18cb692a --- /dev/null +++ b/tests/python/relay/test_ir_parser_printer_roundtrip.py @@ -0,0 +1,26 @@ +import tvm +from tvm import relay +from hypothesis import given, reject +from hypothesis.strategies import text, lists, integers, composite, recursive + +@composite +def constants(draw): + # python_tensor = draw(recursive(integers(), lists)) + python_tensor = draw(lists(integers())) + # TODO: generate higher dimensional and 0D tensors. must be box shaped + return relay.Constant(tvm.nd.array(python_tensor)) + +@given(constants()) +def test_roundtrip(e): + relay.fromtext(e.astext()) + +# @given(text()) +# def test_fuzz(s): +# try: +# relay.fromtext(s) +# except tvm._ffi.base.TVMError: +# reject() + +if __name__ == "__main__": + for _ in range(10): + print(constants().example()) From aec4724112e8156e8a0485b29cdb3cf9de025277 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Fri, 15 Feb 2019 14:05:21 -0800 Subject: [PATCH 02/73] bump semver. add inline_meta_data flag. implement tuples --- include/tvm/relay/expr.h | 1 + python/tvm/relay/base.py | 6 +- python/tvm/relay/grammar/Relay.g4 | 2 +- src/relay/ir/error.cc | 2 +- src/relay/ir/text_printer.cc | 76 +++++++++++-------- src/relay/pass/fuse_ops.cc | 2 +- tests/python/relay/test_ir_parser.py | 2 +- .../relay/test_ir_parser_printer_roundtrip.py | 29 ++++--- 8 files changed, 71 insertions(+), 49 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 06a1aa1ac9ef..5be11dcfb77c 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -553,6 +553,7 @@ inline const TTypeNode* ExprNode::type_as() const { */ std::string RelayPrint( const NodeRef& node, + bool inline_meta_data = false, bool show_meta_data = true, runtime::TypedPackedFunc annotate = nullptr); } // namespace relay diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index e0491d62f552..ca9e36c88f73 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -38,7 +38,7 @@ def register_relay_attr_node(type_key=None): class RelayNode(NodeBase): """Base class of all Relay nodes.""" - def astext(self, show_meta_data=True, annotate=None): + def astext(self, inline_meta_data=False, show_meta_data=True, annotate=None): """Get the text format of the expression. Parameters @@ -62,13 +62,13 @@ def astext(self, show_meta_data=True, annotate=None): text : str The text format of the expression. """ - return _expr.RelayPrint(self, show_meta_data, annotate) + return _expr.RelayPrint(self, inline_meta_data, show_meta_data, annotate) def set_span(self, span): _base.set_span(self, span) def __str__(self): - return self.astext(show_meta_data=False) + return self.astext(inline_meta_data=False, show_meta_data=False) @register_relay_node diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index 0a2206265502..18b05f2761e8 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -1,6 +1,6 @@ grammar Relay; -SEMVER: 'v0.0.1' ; +SEMVER: 'v0.0.2' ; // Lexing // comments diff --git a/src/relay/ir/error.cc b/src/relay/ir/error.cc index 24f8d1c49b6b..be0d11064209 100644 --- a/src/relay/ir/error.cc +++ b/src/relay/ir/error.cc @@ -91,7 +91,7 @@ void ErrorReporter::RenderErrors(const Module& module, bool use_color) { // // The annotation callback will annotate the error messages // contained in the map. - annotated_prog << RelayPrint(func, false, [&err_map](tvm::relay::Expr expr) { + annotated_prog << RelayPrint(func, false, false, [&err_map](tvm::relay::Expr expr) { auto it = err_map.find(expr); if (it != err_map.end()) { return it->second; diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 932856a2055d..7ea4b69d70a0 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -73,9 +73,9 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO * fn (%x: Tensor[(meta[Variable][0],), float32]) { * %x * } - * # Meta data section is a json-serialized string - * # of the following array. - * # [tvm.var("n")] + * // Meta data section is a json-serialized string + * // of the following array. + * // [tvm.var("n")] * * \endcode * @@ -139,15 +139,17 @@ class TextPrinter : public TypeFunctor, // NOLINT(*) public AttrFunctor { // NOLINT(*) public: - explicit TextPrinter(bool show_meta_data, + explicit TextPrinter(bool inline_meta_data, + bool show_meta_data, runtime::TypedPackedFunc annotate) - : show_meta_data_(show_meta_data), annotate_(annotate) {} + : inline_meta_data_(inline_meta_data), show_meta_data_(show_meta_data), annotate_(annotate) {} /*! * \brief Print a node to string. * \param node. * \return The string representation. */ std::string Print(const NodeRef& node) { + stream_ << "v0.0.2\n"; if (node.as()) { this->PrintFunc(Downcast(node)); } else if (node.as()) { @@ -163,12 +165,12 @@ class TextPrinter : if (show_meta_data_) { std::string meta_json = meta_.GetMetaSection(); // append meta data in the end. - stream_ << "# meta data\n" + stream_ << "// meta data\n" << "r\"\"\"\n" << meta_json << "\n" << "\"\"\""; } else { - stream_ << "# meta data omitted. you can use show_meta_data=True to include meta-data\n"; + stream_ << "// meta data omitted. you can use show_meta_data=True to include meta-data\n"; } } return stream_.str(); @@ -256,27 +258,38 @@ class TextPrinter : } TextValue VisitExpr_(const TupleNode* op) final { - std::vector fields; - for (Expr field : op->fields) { - fields.push_back(GetValue(field)); - } - // NOTE: always recursively visit to get ids, - // before print out the current line - TextValue id = this->AllocTempVar(); - this->PrintIndent(); - stream_ << id << " = ("; - for (size_t i = 0; i < fields.size(); ++i) { - stream_ << fields[i]; - if (i + 1 != fields.size()) { - stream_ << ", "; + if (inline_meta_data_) { + stream_ << "("; + for (size_t i = 0; i < op->fields.size(); i++) { + stream_ << GetValue(op->fields[i]); + if (i + 1 != op->fields.size()) + stream_ << ", "; } + stream_ << ")"; + return TextValue(""); + } else { + std::vector fields; + for (Expr field : op->fields) { + fields.push_back(GetValue(field)); + } + // NOTE: always recursively visit to get ids, + // before print out the current line + TextValue id = this->AllocTempVar(); + this->PrintIndent(); + stream_ << id << " = ("; + for (size_t i = 0; i < fields.size(); ++i) { + stream_ << fields[i]; + if (i + 1 != fields.size()) { + stream_ << ", "; + } + } + if (fields.size() == 1) { + stream_ << ','; + } + stream_ << ')'; + this->PrintEndInst("\n"); + return id; } - if (fields.size() == 1) { - stream_ << ','; - } - stream_ << ')'; - this->PrintEndInst("\n"); - return id; } TextValue VisitExpr_(const VarNode* op) final { @@ -637,9 +650,9 @@ class TextPrinter : void PrintOptionalInfo(const Expr& expr) { // additional information in comment. if (annotate_ != nullptr) { - stream_ << " # " << annotate_(expr); + stream_ << " // " << annotate_(expr); } else if (expr->checked_type_.defined()) { - stream_ << " # ty="; + stream_ << " // ty="; this->PrintType(expr->checked_type(), stream_); } } @@ -795,6 +808,8 @@ class TextPrinter : private: class AttrPrinter; friend class AttrPrinter; + /*! \brief Whether to inline meta data. If enabled, ignores show_meta_data_ flag. */ + bool inline_meta_data_; /*! \brief Whether to print meta data. */ bool show_meta_data_; /*! \brief additional comment function */ @@ -890,14 +905,15 @@ void TextPrinter::PrintCallAttrs(const Expr& op, } std::string RelayPrint(const NodeRef& node, + bool inline_meta_data, bool show_meta_data, runtime::TypedPackedFunc annotate) { - return TextPrinter(show_meta_data, annotate).Print(node); + return TextPrinter(inline_meta_data, show_meta_data, annotate).Print(node); } TVM_REGISTER_API("relay._expr.RelayPrint") .set_body_typed)>(RelayPrint); } // namespace relay diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 66ff9caf4ae4..6efcd4464e75 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -835,7 +835,7 @@ class FuseMutator : private ExprMutator { // Debug function, dump the group assignment in text. void DebugDumpGroup(const Expr& body) { - std::string text = RelayPrint(body, false, [this](const Expr& expr) -> std::string { + std::string text = RelayPrint(body, false, false, [this](const Expr& expr) -> std::string { auto it = gmap_.find(expr.get()); if (it == gmap_.end()) return ""; std::ostringstream os; diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 08d4c430101b..9b725234a33c 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -12,7 +12,7 @@ else: raises_parse_error = lambda x: x -SEMVER = "v0.0.1" +SEMVER = "v0.0.2" BINARY_OPS = { "*": relay.multiply, diff --git a/tests/python/relay/test_ir_parser_printer_roundtrip.py b/tests/python/relay/test_ir_parser_printer_roundtrip.py index bb1f18cb692a..1bcd89698416 100644 --- a/tests/python/relay/test_ir_parser_printer_roundtrip.py +++ b/tests/python/relay/test_ir_parser_printer_roundtrip.py @@ -1,26 +1,31 @@ import tvm from tvm import relay +from tvm.relay.ir_pass import alpha_equal +import numpy as np + from hypothesis import given, reject from hypothesis.strategies import text, lists, integers, composite, recursive @composite def constants(draw): # python_tensor = draw(recursive(integers(), lists)) - python_tensor = draw(lists(integers())) + # python_tensor = draw(lists(integers(min_value=-1000, max_value=1000))) + python_tensor = draw(integers(min_value=-1000, max_value=1000)) # TODO: generate higher dimensional and 0D tensors. must be box shaped - return relay.Constant(tvm.nd.array(python_tensor)) + return relay.Constant(tvm.nd.array(np.array(python_tensor).astype("int32"))) -@given(constants()) -def test_roundtrip(e): - relay.fromtext(e.astext()) +@composite +def tuples(draw): + # TODO: replace constants with exprs + return relay.Tuple(draw(lists(constants()))) -# @given(text()) -# def test_fuzz(s): -# try: -# relay.fromtext(s) -# except tvm._ffi.base.TVMError: -# reject() +@given(tuples()) +def test_roundtrip(e): + print(e.astext(inline_meta_data=True)) + alpha_equal(relay.fromtext(e.astext(inline_meta_data=True)), e) + # e.astext() if __name__ == "__main__": for _ in range(10): - print(constants().example()) + # print(constants().example().astext()) + print(tuples().example().astext(inline_meta_data=True)) From 1595a46b6a80926f0c3885a15592b0a235b4b256 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Sat, 16 Feb 2019 23:58:11 -0800 Subject: [PATCH 03/73] add simple wadler-style printer infrastructure --- src/relay/ir/pretty_printer.cc | 95 ++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 src/relay/ir/pretty_printer.cc diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc new file mode 100644 index 000000000000..6352a53a3f25 --- /dev/null +++ b/src/relay/ir/pretty_printer.cc @@ -0,0 +1,95 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file pretty_printer.cc + * \brief Pretty printer for Relay programs + * Supports ANF and GNF formats and metadata. + */ +#include + +namespace tvm { +namespace relay { + +namespace doc { + +// Doc model based on Section 1 of https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf. + +enum DocType { NIL, TEXT, LINE }; + +struct Doc { + virtual DocType getType(); +}; + +struct Nil : Doc { + DocType getType() { return NIL; } +}; + +struct Text : Doc { + std::string str; + Doc doc; + + Text(std::string str) : str(str), doc(Nil()) { } + Text(std::string str, Doc doc) : str(str), doc(doc) { } + + DocType getType() { return TEXT; } +}; + +struct Line : Doc { + size_t indent; + Doc doc; + + Line() : indent(0), doc(Nil()) { } + Line(size_t indent, Doc doc) : indent(indent), doc(doc) { } + + DocType getType() { return LINE; } +}; + +// concatenate two documents +Doc Concat(Doc &left, Doc &right) { + if (left.getType() == TEXT) { + Text &text = static_cast(left); + return Text(text.str, Concat(text.doc, right)); + } else if (left.getType() == LINE) { + Line &line = static_cast(left); + return Line(line.indent, Concat(line.doc, right)); + } else if (left.getType() == NIL) { + return right; + } else { assert(false); } +} + +// overload + to concatenate documents +Doc operator+(Doc& left, Doc& right) { + Concat(left, right); +} + +// add indentation to a document +Doc Nest(size_t indent, Doc doc) { + if (doc.getType() == TEXT) { + Text &text = static_cast(doc); + return Text(text.str, Nest(indent, text.doc)); + } else if (doc.getType() == LINE) { + Line &line = static_cast(doc); + return Line(indent + line.indent, Nest(indent, line.doc)); + } else if (doc.getType() == NIL) { + return Nil(); + } else { assert(false); } +} + +// print a document to the given ostream +void Layout(Doc doc, std::ostream& os) { + if (doc.getType() == TEXT) { + Text &text = static_cast(doc); + os << text.str; + Layout(text.doc, os); + } else if (doc.getType() == LINE) { + Line &line = static_cast(doc); + os << std::endl << std::string(line.indent, ' '); + Layout(line.doc, os); + } else if (doc.getType() == NIL) { + // do nothing! + } else { assert(false); } +} + +} // doc + +} // relay +} // tvm \ No newline at end of file From 7b6ce8b457764f89860299502ad7cd88d939c83c Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Sun, 17 Feb 2019 00:02:31 -0800 Subject: [PATCH 04/73] make things build --- src/relay/ir/pretty_printer.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 6352a53a3f25..79431c6a1881 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -16,7 +16,7 @@ namespace doc { enum DocType { NIL, TEXT, LINE }; struct Doc { - virtual DocType getType(); + virtual DocType getType() { assert(false); }; }; struct Nil : Doc { @@ -58,7 +58,7 @@ Doc Concat(Doc &left, Doc &right) { // overload + to concatenate documents Doc operator+(Doc& left, Doc& right) { - Concat(left, right); + return Concat(left, right); } // add indentation to a document From c11260a1f241a2754e1b0db506185d5f5e801ec7 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Sun, 17 Feb 2019 13:35:14 -0800 Subject: [PATCH 05/73] remove stringstreams. add pretty printer class --- src/relay/ir/pretty_printer.cc | 43 ++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 79431c6a1881..fdb6c417d6d9 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -74,22 +74,51 @@ Doc Nest(size_t indent, Doc doc) { } else { assert(false); } } -// print a document to the given ostream -void Layout(Doc doc, std::ostream& os) { +// convert a document to a string +std::string Layout(Doc doc) { if (doc.getType() == TEXT) { Text &text = static_cast(doc); - os << text.str; - Layout(text.doc, os); + return text.str + Layout(text.doc); } else if (doc.getType() == LINE) { Line &line = static_cast(doc); - os << std::endl << std::string(line.indent, ' '); - Layout(line.doc, os); + return "\n" + std::string(line.indent, ' ') + Layout(line.doc); } else if (doc.getType() == NIL) { - // do nothing! + return ""; } else { assert(false); } } } // doc +class PrettyPrinter : + public ExprFunctor { + public: + explicit PrettyPrinter() {} + + std::string Print(const NodeRef& node) { + if (node.as_derived()) { + return doc::Layout(this->PrintExpr(Downcast(node))); + } else { assert(false); } + } + + doc::Doc PrintExpr(const Expr& expr) { + auto it = memo_.find(expr); + if (it != memo_.end()) return it->second; + doc::Doc val = this->VisitExpr(expr); + memo_[expr] = val; + return val; + } + + private: + /*! \brief Map from Expr to Doc */ + std::unordered_map memo_; +}; + +std::string RelayPrettyPrint(const NodeRef& node) { + return PrettyPrinter().Print(node); +} + +TVM_REGISTER_API("relay._expr.RelayPrettyPrint") +.set_body_typed(RelayPrettyPrint); + } // relay } // tvm \ No newline at end of file From 0a3210b71f24695d24b2bd4c93dd24b0d1c55173 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Sun, 17 Feb 2019 14:59:21 -0800 Subject: [PATCH 06/73] commit failing refactoring things. can't return abstract class. may need to rethink design --- src/relay/ir/pretty_printer.cc | 157 +++++++++++++----- .../relay/test_ir_parser_printer_roundtrip.py | 6 + 2 files changed, 118 insertions(+), 45 deletions(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index fdb6c417d6d9..4844ffad3e01 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -16,72 +16,72 @@ namespace doc { enum DocType { NIL, TEXT, LINE }; struct Doc { - virtual DocType getType() { assert(false); }; + virtual DocType getType() const = 0; }; -struct Nil : Doc { - DocType getType() { return NIL; } +struct Nil : public Doc { + DocType getType() const { return NIL; } }; -struct Text : Doc { - std::string str; - Doc doc; +struct Text : public Doc { + const std::string str; + const Doc* doc; - Text(std::string str) : str(str), doc(Nil()) { } - Text(std::string str, Doc doc) : str(str), doc(doc) { } + Text(const std::string str) : str(str), doc(new Nil()) { } + Text(const std::string str, const Doc* doc) : str(str), doc(doc) { } - DocType getType() { return TEXT; } + DocType getType() const { return TEXT; } }; -struct Line : Doc { - size_t indent; - Doc doc; +struct Line : public Doc { + const size_t indent; + const Doc* doc; - Line() : indent(0), doc(Nil()) { } - Line(size_t indent, Doc doc) : indent(indent), doc(doc) { } + Line() : indent(0), doc(new Nil()) { } + Line(const size_t indent, const Doc* doc) : indent(indent), doc(doc) { } - DocType getType() { return LINE; } + DocType getType() const { return LINE; } }; // concatenate two documents -Doc Concat(Doc &left, Doc &right) { - if (left.getType() == TEXT) { - Text &text = static_cast(left); - return Text(text.str, Concat(text.doc, right)); - } else if (left.getType() == LINE) { - Line &line = static_cast(left); - return Line(line.indent, Concat(line.doc, right)); - } else if (left.getType() == NIL) { +const Doc* Concat(const Doc* left, const Doc* right) { + if (left->getType() == TEXT) { + const Text* text = static_cast(left); + return new Text(text->str, Concat(text->doc, right)); + } else if (left->getType() == LINE) { + const Line* line = static_cast(left); + return new Line(line->indent, Concat(line->doc, right)); + } else if (left->getType() == NIL) { return right; } else { assert(false); } } // overload + to concatenate documents -Doc operator+(Doc& left, Doc& right) { - return Concat(left, right); +const Doc& operator+(const Doc& left, const Doc& right) { + return *Concat(&left, &right); } // add indentation to a document -Doc Nest(size_t indent, Doc doc) { - if (doc.getType() == TEXT) { - Text &text = static_cast(doc); - return Text(text.str, Nest(indent, text.doc)); - } else if (doc.getType() == LINE) { - Line &line = static_cast(doc); - return Line(indent + line.indent, Nest(indent, line.doc)); - } else if (doc.getType() == NIL) { - return Nil(); +const Doc* Nest(size_t indent, const Doc* doc) { + if (doc->getType() == TEXT) { + const Text* text = static_cast(doc); + return new Text(text->str, Nest(indent, text->doc)); + } else if (doc->getType() == LINE) { + const Line* line = static_cast(doc); + return new Line(indent + line->indent, Nest(indent, line->doc)); + } else if (doc->getType() == NIL) { + return new Nil(); } else { assert(false); } } // convert a document to a string -std::string Layout(Doc doc) { +std::string Layout(const Doc& doc) { if (doc.getType() == TEXT) { - Text &text = static_cast(doc); - return text.str + Layout(text.doc); + const Text& text = static_cast(doc); + return text.str + Layout(*text.doc); } else if (doc.getType() == LINE) { - Line &line = static_cast(doc); - return "\n" + std::string(line.indent, ' ') + Layout(line.doc); + const Line& line = static_cast(doc); + return "\n" + std::string(line.indent, ' ') + Layout(*line.doc); } else if (doc.getType() == NIL) { return ""; } else { assert(false); } @@ -89,35 +89,102 @@ std::string Layout(Doc doc) { } // doc +using namespace doc; + class PrettyPrinter : - public ExprFunctor { + public ExprFunctor { public: explicit PrettyPrinter() {} std::string Print(const NodeRef& node) { if (node.as_derived()) { - return doc::Layout(this->PrintExpr(Downcast(node))); + return Layout(this->PrintExpr(Downcast(node))); } else { assert(false); } } - doc::Doc PrintExpr(const Expr& expr) { + const Doc& PrintExpr(const Expr& expr) { auto it = memo_.find(expr); if (it != memo_.end()) return it->second; - doc::Doc val = this->VisitExpr(expr); + const Doc& val = this->VisitExpr(expr); memo_[expr] = val; return val; } + // render a tvm array with open and closing brackets and a separator + // we use docs instead of strings for input to allow the caller to use + // newlines where desired + template + const Doc& PrintArray(const Doc& open, const tvm::Array& arr, const Doc& sep, const Doc& close) { + Doc seq; + if (arr.size() == 0) { + seq = Nil(); + } else { + seq = Text(this->Print(arr[0])); + for (size_t i = 1; i < arr.size(); i++) { + seq = seq + sep + Text(this->Print(arr[i])); + } + } + + return open + seq + close; + } + + /*! + * \brief special method to print out const scalar + * \param dtype The data type + * \param data The pointer to hold the data. + */ + template + const Doc& PrintConstScalar(DataType dtype, const T* data) { // NOLINT(*) + std::stringstream ss; + if (dtype == Int(32)) { + ss << data[0]; + } else if (dtype == Float(32)) { + ss << data[0] << 'f'; + } else if (dtype == Bool()) { + // ss << PrintBool(data[0] != 0); + assert(false); + } else { + ss << dtype << "(" << data[0] << ")"; + } + return new Text(ss.str()); + } + + const Doc& VisitExpr_(const ConstantNode* op) final { + // Print out simple scalar directly. + if (op->is_scalar()) { + std::ostringstream os; + DataType dtype = TVMType2Type(op->data->dtype); + CHECK_EQ(op->data->ctx.device_type, kDLCPU); + if (dtype == Int(32)) { + return PrintConstScalar(dtype, static_cast(op->data->data)); + } else if (dtype == Int(64)) { + return PrintConstScalar(dtype, static_cast(op->data->data)); + } else if (dtype == Float(32)) { + return PrintConstScalar(dtype, static_cast(op->data->data)); + } else if (dtype == Float(64)) { + return PrintConstScalar(dtype, static_cast(op->data->data)); + } else if (dtype == Bool()) { + return PrintConstScalar(dtype, static_cast(op->data->data)); + } + } + // TODO: handle complicated scalars + assert(false); + } + + const Doc& VisitExpr_(const TupleNode* op) final { + return PrintArray(Text("("), op->fields, Text(", "), Text(")")); + } + private: /*! \brief Map from Expr to Doc */ - std::unordered_map memo_; + std::unordered_map memo_; }; std::string RelayPrettyPrint(const NodeRef& node) { return PrettyPrinter().Print(node); } -TVM_REGISTER_API("relay._expr.RelayPrettyPrint") +TVM_REGISTER_API("relay._expr.pretty_print") .set_body_typed(RelayPrettyPrint); } // relay diff --git a/tests/python/relay/test_ir_parser_printer_roundtrip.py b/tests/python/relay/test_ir_parser_printer_roundtrip.py index 1bcd89698416..929b29aeadfd 100644 --- a/tests/python/relay/test_ir_parser_printer_roundtrip.py +++ b/tests/python/relay/test_ir_parser_printer_roundtrip.py @@ -1,6 +1,7 @@ import tvm from tvm import relay from tvm.relay.ir_pass import alpha_equal +from tvm.relay._expr import pretty_print import numpy as np from hypothesis import given, reject @@ -25,6 +26,11 @@ def test_roundtrip(e): alpha_equal(relay.fromtext(e.astext(inline_meta_data=True)), e) # e.astext() +@given(tuples()) +def test_roundtrip_pp(e): + print(pretty_print(e)) + alpha_equal(relay.fromtext(pretty_print(e)), e) + if __name__ == "__main__": for _ in range(10): # print(constants().example().astext()) From 9f5667272f7790b0df404c231390f88dbc38c3bc Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Sun, 17 Feb 2019 18:12:25 -0800 Subject: [PATCH 07/73] switch to tvm-style adts --- include/tvm/relay/doc.h | 112 +++++++++++++++ src/relay/ir/doc.cc | 127 ++++++++++++++++++ src/relay/ir/pretty_printer.cc | 103 ++------------ .../relay/test_ir_parser_printer_roundtrip.py | 8 +- 4 files changed, 255 insertions(+), 95 deletions(-) create mode 100644 include/tvm/relay/doc.h create mode 100644 src/relay/ir/doc.cc diff --git a/include/tvm/relay/doc.h b/include/tvm/relay/doc.h new file mode 100644 index 000000000000..1b7303ddfd37 --- /dev/null +++ b/include/tvm/relay/doc.h @@ -0,0 +1,112 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file tvm/relay/doc.h + * \brief Doc ADT used for pretty printing. + * Based on Section 1 of https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf. + */ +#ifndef TVM_RELAY_DOC_H_ +#define TVM_RELAY_DOC_H_ + +#include + +namespace tvm { +namespace relay { + +// ADT +class Doc; + +class DocNode : public Node { + public: + static constexpr const char* _type_key = "Doc"; + TVM_DECLARE_BASE_NODE_INFO(DocNode, Node); +}; + +class Doc : public NodeRef { + public: + Doc() {} + explicit Doc(NodePtr n) : NodeRef(n) {} + const DocNode* operator->() const { + return static_cast(node_.get()); + } + + using ContainerType = DocNode; +}; + +class Nil_; + +class Nil_Node : public DocNode { + public: + Nil_Node() {} + + void VisitAttrs(tvm::AttrVisitor* v) final {} + + TVM_DLL static Nil_ make(); + + static constexpr const char* _type_key = "Nil_"; + TVM_DECLARE_NODE_TYPE_INFO(Nil_Node, DocNode); +}; + +RELAY_DEFINE_NODE_REF(Nil_, Nil_Node, Doc); + +class Text_; + +class Text_Node : public DocNode { + public: + std::string str; + Doc doc; + + Text_Node() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("str", &str); + v->Visit("doc", &doc); + } + + TVM_DLL static Text_ make(std::string str, Doc doc); + + static constexpr const char* _type_key = "Text_"; + TVM_DECLARE_NODE_TYPE_INFO(Text_Node, DocNode); +}; + +RELAY_DEFINE_NODE_REF(Text_, Text_Node, Doc); + +class Line_; + +class Line_Node : public DocNode { + public: + int indent; + Doc doc; + + Line_Node() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("indent", &indent); + v->Visit("doc", &doc); + } + + TVM_DLL static Line_ make(int indent, Doc doc); + + static constexpr const char* _type_key = "Line_"; + TVM_DECLARE_NODE_TYPE_INFO(Line_Node, DocNode); +}; + +RELAY_DEFINE_NODE_REF(Line_, Line_Node, Doc); + +// empty doc +Doc Nil(); +// lift string to text +Doc Text(const std::string str); +// new line +Doc Line(); +// concat two docs +Doc Concat(const Doc& left, const Doc& right); +// sugar for Concat +Doc operator+(const Doc& left, const Doc& right); +// indent a doc +Doc Nest(int indent, const Doc& doc); +// convert doc to a string +std::string Layout(const Doc& doc); + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_DOC_H_ diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc new file mode 100644 index 000000000000..d94eab8c5aaf --- /dev/null +++ b/src/relay/ir/doc.cc @@ -0,0 +1,127 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file src/tvm/relay/doc.cc + * \brief Doc ADT used for pretty printing. + * Based on Section 1 of https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf. + */ +#include + +namespace tvm { +namespace relay { + +// Doc ADT implementation +Nil_ Nil_Node::make() { + NodePtr n = make_node(); + return Nil_(n); +} + +TVM_REGISTER_API("relay._make.Nil_") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = Nil_Node::make(); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const Nil_Node* node, tvm::IRPrinter* p) { + p->stream << "Nil_Node()"; + }); + +Text_ Text_Node::make(std::string str, Doc doc) { + NodePtr n = make_node(); + n->str = std::move(str); + n->doc = std::move(doc); + return Text_(n); +} + +TVM_REGISTER_API("relay._make.Text_") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = Text_Node::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const Text_Node* node, tvm::IRPrinter* p) { + p->stream << "Text_Node(" << node->str << ", " << node->doc << ")"; + }); + +Line_ Line_Node::make(int indent, Doc doc) { + NodePtr n = make_node(); + n->indent = indent; + n->doc = std::move(doc); + return Line_(n); +} + +TVM_REGISTER_API("relay._make.Line_") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = Line_Node::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const Line_Node* node, tvm::IRPrinter* p) { + p->stream << "Line_Node(" << node->indent << ", " << node->doc << ")"; + }); + +// DSL functions + +// empty doc +Doc Nil() { + return Nil_Node::make(); +} + +// lift string to text +Doc Text(const std::string str) { + return Text_Node::make(str, Nil()); +} + +// new line +Doc Line() { + return Line_Node::make(0, Nil()); +} + +// concat two docs +Doc Concat(const Doc& left, const Doc& right) { + if (const Text_Node* text = left.as()) { + // push right into text continuation + return Text_Node::make(text->str, Concat(text->doc, right)); + } else if (const Line_Node* line = left.as()) { + // push right into line continuation + return Line_Node::make(line->indent, Concat(line->doc, right)); + } else if (const Nil_Node* nil = left.as()) { + // throwaway nils on the left + return right; + } else {assert(false);} +} + +// sugar for Concat +Doc operator+(const Doc& left, const Doc& right) { + return Concat(left, right); +} + +// indent a doc +Doc Nest(int indent, const Doc& doc) { + if (const Text_Node* text = doc.as()) { + // push nest through + return Text_Node::make(text->str, Nest(indent, text->doc)); + } else if (const Line_Node* line = doc.as()) { + // add indent to lines and continue + return Line_Node::make(indent + line->indent, Nest(indent, line->doc)); + } else if (const Nil_Node* nil = doc.as()) { + // absorb it + return Nil_(); + } else {assert(false);} +} + +// convert a doc to a string +std::string Layout(const Doc& doc) { + if (const Text_Node* text = doc.as()) { + // add text and continue + return text->str + Layout(text->doc); + } else if (const Line_Node* line = doc.as()) { + // add a newline and indents, then continue + return "\n" + std::string(line->indent, ' ') + Layout(line->doc); + } else if (const Nil_Node* nil = doc.as()) { + // empty string + return ""; + } else {assert(false);} +} + +} // relay +} // tvm \ No newline at end of file diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 4844ffad3e01..280859a535ae 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -4,95 +4,14 @@ * \brief Pretty printer for Relay programs * Supports ANF and GNF formats and metadata. */ +#include #include namespace tvm { namespace relay { -namespace doc { - -// Doc model based on Section 1 of https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf. - -enum DocType { NIL, TEXT, LINE }; - -struct Doc { - virtual DocType getType() const = 0; -}; - -struct Nil : public Doc { - DocType getType() const { return NIL; } -}; - -struct Text : public Doc { - const std::string str; - const Doc* doc; - - Text(const std::string str) : str(str), doc(new Nil()) { } - Text(const std::string str, const Doc* doc) : str(str), doc(doc) { } - - DocType getType() const { return TEXT; } -}; - -struct Line : public Doc { - const size_t indent; - const Doc* doc; - - Line() : indent(0), doc(new Nil()) { } - Line(const size_t indent, const Doc* doc) : indent(indent), doc(doc) { } - - DocType getType() const { return LINE; } -}; - -// concatenate two documents -const Doc* Concat(const Doc* left, const Doc* right) { - if (left->getType() == TEXT) { - const Text* text = static_cast(left); - return new Text(text->str, Concat(text->doc, right)); - } else if (left->getType() == LINE) { - const Line* line = static_cast(left); - return new Line(line->indent, Concat(line->doc, right)); - } else if (left->getType() == NIL) { - return right; - } else { assert(false); } -} - -// overload + to concatenate documents -const Doc& operator+(const Doc& left, const Doc& right) { - return *Concat(&left, &right); -} - -// add indentation to a document -const Doc* Nest(size_t indent, const Doc* doc) { - if (doc->getType() == TEXT) { - const Text* text = static_cast(doc); - return new Text(text->str, Nest(indent, text->doc)); - } else if (doc->getType() == LINE) { - const Line* line = static_cast(doc); - return new Line(indent + line->indent, Nest(indent, line->doc)); - } else if (doc->getType() == NIL) { - return new Nil(); - } else { assert(false); } -} - -// convert a document to a string -std::string Layout(const Doc& doc) { - if (doc.getType() == TEXT) { - const Text& text = static_cast(doc); - return text.str + Layout(*text.doc); - } else if (doc.getType() == LINE) { - const Line& line = static_cast(doc); - return "\n" + std::string(line.indent, ' ') + Layout(*line.doc); - } else if (doc.getType() == NIL) { - return ""; - } else { assert(false); } -} - -} // doc - -using namespace doc; - class PrettyPrinter : - public ExprFunctor { + public ExprFunctor { public: explicit PrettyPrinter() {} @@ -102,10 +21,10 @@ class PrettyPrinter : } else { assert(false); } } - const Doc& PrintExpr(const Expr& expr) { + const Doc PrintExpr(const Expr& expr) { auto it = memo_.find(expr); if (it != memo_.end()) return it->second; - const Doc& val = this->VisitExpr(expr); + Doc val = this->VisitExpr(expr); memo_[expr] = val; return val; } @@ -114,7 +33,7 @@ class PrettyPrinter : // we use docs instead of strings for input to allow the caller to use // newlines where desired template - const Doc& PrintArray(const Doc& open, const tvm::Array& arr, const Doc& sep, const Doc& close) { + const Doc PrintArray(const Doc& open, const tvm::Array& arr, const Doc& sep, const Doc& close) { Doc seq; if (arr.size() == 0) { seq = Nil(); @@ -134,7 +53,7 @@ class PrettyPrinter : * \param data The pointer to hold the data. */ template - const Doc& PrintConstScalar(DataType dtype, const T* data) { // NOLINT(*) + Doc PrintConstScalar(DataType dtype, const T* data) { // NOLINT(*) std::stringstream ss; if (dtype == Int(32)) { ss << data[0]; @@ -146,10 +65,10 @@ class PrettyPrinter : } else { ss << dtype << "(" << data[0] << ")"; } - return new Text(ss.str()); + return Text(ss.str()); } - const Doc& VisitExpr_(const ConstantNode* op) final { + Doc VisitExpr_(const ConstantNode* op) final { // Print out simple scalar directly. if (op->is_scalar()) { std::ostringstream os; @@ -167,11 +86,11 @@ class PrettyPrinter : return PrintConstScalar(dtype, static_cast(op->data->data)); } } - // TODO: handle complicated scalars + // TODO: handle tensors assert(false); } - const Doc& VisitExpr_(const TupleNode* op) final { + Doc VisitExpr_(const TupleNode* op) final { return PrintArray(Text("("), op->fields, Text(", "), Text(")")); } @@ -181,7 +100,7 @@ class PrettyPrinter : }; std::string RelayPrettyPrint(const NodeRef& node) { - return PrettyPrinter().Print(node); + return "v0.0.2\n" + PrettyPrinter().Print(node); } TVM_REGISTER_API("relay._expr.pretty_print") diff --git a/tests/python/relay/test_ir_parser_printer_roundtrip.py b/tests/python/relay/test_ir_parser_printer_roundtrip.py index 929b29aeadfd..a4d72c14652a 100644 --- a/tests/python/relay/test_ir_parser_printer_roundtrip.py +++ b/tests/python/relay/test_ir_parser_printer_roundtrip.py @@ -20,17 +20,19 @@ def tuples(draw): # TODO: replace constants with exprs return relay.Tuple(draw(lists(constants()))) -@given(tuples()) +""" @given(tuples()) def test_roundtrip(e): print(e.astext(inline_meta_data=True)) alpha_equal(relay.fromtext(e.astext(inline_meta_data=True)), e) - # e.astext() + # e.astext() """ @given(tuples()) def test_roundtrip_pp(e): - print(pretty_print(e)) alpha_equal(relay.fromtext(pretty_print(e)), e) +# def test_roundtrip_pp_simple(): +# print(pretty_print(relay.const(1))) + if __name__ == "__main__": for _ in range(10): # print(constants().example().astext()) From 1f8ff51a12e96af43baa75542367a374ab01ad7c Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Sun, 17 Feb 2019 19:02:15 -0800 Subject: [PATCH 08/73] tuple projection and more testing --- src/relay/ir/doc.cc | 2 +- src/relay/ir/pretty_printer.cc | 8 +++++-- ...undtrip.py => test_ir_parser_roundtrip.py} | 23 +++++++++++++------ 3 files changed, 23 insertions(+), 10 deletions(-) rename tests/python/relay/{test_ir_parser_printer_roundtrip.py => test_ir_parser_roundtrip.py} (64%) diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc index d94eab8c5aaf..e16a64b9c1ab 100644 --- a/src/relay/ir/doc.cc +++ b/src/relay/ir/doc.cc @@ -124,4 +124,4 @@ std::string Layout(const Doc& doc) { } } // relay -} // tvm \ No newline at end of file +} // tvm diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 280859a535ae..465edd902ead 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -91,7 +91,11 @@ class PrettyPrinter : } Doc VisitExpr_(const TupleNode* op) final { - return PrintArray(Text("("), op->fields, Text(", "), Text(")")); + return PrintArray(Text("("), op->fields, Text(", "), Text(")")); + } + + Doc VisitExpr_(const TupleGetItemNode* op) final { + return this->VisitExpr(op->tuple) + Text(".") + Text(std::to_string(op->index)); } private: @@ -107,4 +111,4 @@ TVM_REGISTER_API("relay._expr.pretty_print") .set_body_typed(RelayPrettyPrint); } // relay -} // tvm \ No newline at end of file +} // tvm diff --git a/tests/python/relay/test_ir_parser_printer_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py similarity index 64% rename from tests/python/relay/test_ir_parser_printer_roundtrip.py rename to tests/python/relay/test_ir_parser_roundtrip.py index a4d72c14652a..eab287f43fec 100644 --- a/tests/python/relay/test_ir_parser_printer_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -4,8 +4,12 @@ from tvm.relay._expr import pretty_print import numpy as np -from hypothesis import given, reject -from hypothesis.strategies import text, lists, integers, composite, recursive +from hypothesis import given, reject, settings +from hypothesis.strategies import text, lists, integers, composite, recursive, deferred + +exprs = deferred(lambda: constants() + # | projections(exprs) + | tuples(exprs)) @composite def constants(draw): @@ -16,9 +20,12 @@ def constants(draw): return relay.Constant(tvm.nd.array(np.array(python_tensor).astype("int32"))) @composite -def tuples(draw): - # TODO: replace constants with exprs - return relay.Tuple(draw(lists(constants()))) +def tuples(draw, field_type): + return relay.Tuple(draw(lists(field_type, max_size=5))) + +@composite +def projections(draw, field_type): + return relay.TupleGetItem(draw(field_type), draw(integers(min_value=-1000, max_value=1000))) """ @given(tuples()) def test_roundtrip(e): @@ -26,7 +33,8 @@ def test_roundtrip(e): alpha_equal(relay.fromtext(e.astext(inline_meta_data=True)), e) # e.astext() """ -@given(tuples()) +@settings(deadline=500) +@given(exprs) def test_roundtrip_pp(e): alpha_equal(relay.fromtext(pretty_print(e)), e) @@ -36,4 +44,5 @@ def test_roundtrip_pp(e): if __name__ == "__main__": for _ in range(10): # print(constants().example().astext()) - print(tuples().example().astext(inline_meta_data=True)) + # print(tuples().example().astext(inline_meta_data=True)) + print(pretty_print(exprs.example())) From 414fab40ff27ae401f2c730cbcef4d50c04fc1fe Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Sun, 17 Feb 2019 19:19:53 -0800 Subject: [PATCH 09/73] revert text_printer changes --- include/tvm/relay/expr.h | 1 - python/tvm/relay/base.py | 6 +- src/relay/ir/error.cc | 2 +- src/relay/ir/text_printer.cc | 76 ++++++++----------- src/relay/pass/fuse_ops.cc | 2 +- .../python/relay/test_ir_parser_roundtrip.py | 11 --- 6 files changed, 35 insertions(+), 63 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 5be11dcfb77c..06a1aa1ac9ef 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -553,7 +553,6 @@ inline const TTypeNode* ExprNode::type_as() const { */ std::string RelayPrint( const NodeRef& node, - bool inline_meta_data = false, bool show_meta_data = true, runtime::TypedPackedFunc annotate = nullptr); } // namespace relay diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index ca9e36c88f73..e0491d62f552 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -38,7 +38,7 @@ def register_relay_attr_node(type_key=None): class RelayNode(NodeBase): """Base class of all Relay nodes.""" - def astext(self, inline_meta_data=False, show_meta_data=True, annotate=None): + def astext(self, show_meta_data=True, annotate=None): """Get the text format of the expression. Parameters @@ -62,13 +62,13 @@ def astext(self, inline_meta_data=False, show_meta_data=True, annotate=None): text : str The text format of the expression. """ - return _expr.RelayPrint(self, inline_meta_data, show_meta_data, annotate) + return _expr.RelayPrint(self, show_meta_data, annotate) def set_span(self, span): _base.set_span(self, span) def __str__(self): - return self.astext(inline_meta_data=False, show_meta_data=False) + return self.astext(show_meta_data=False) @register_relay_node diff --git a/src/relay/ir/error.cc b/src/relay/ir/error.cc index be0d11064209..24f8d1c49b6b 100644 --- a/src/relay/ir/error.cc +++ b/src/relay/ir/error.cc @@ -91,7 +91,7 @@ void ErrorReporter::RenderErrors(const Module& module, bool use_color) { // // The annotation callback will annotate the error messages // contained in the map. - annotated_prog << RelayPrint(func, false, false, [&err_map](tvm::relay::Expr expr) { + annotated_prog << RelayPrint(func, false, [&err_map](tvm::relay::Expr expr) { auto it = err_map.find(expr); if (it != err_map.end()) { return it->second; diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 7ea4b69d70a0..932856a2055d 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -73,9 +73,9 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO * fn (%x: Tensor[(meta[Variable][0],), float32]) { * %x * } - * // Meta data section is a json-serialized string - * // of the following array. - * // [tvm.var("n")] + * # Meta data section is a json-serialized string + * # of the following array. + * # [tvm.var("n")] * * \endcode * @@ -139,17 +139,15 @@ class TextPrinter : public TypeFunctor, // NOLINT(*) public AttrFunctor { // NOLINT(*) public: - explicit TextPrinter(bool inline_meta_data, - bool show_meta_data, + explicit TextPrinter(bool show_meta_data, runtime::TypedPackedFunc annotate) - : inline_meta_data_(inline_meta_data), show_meta_data_(show_meta_data), annotate_(annotate) {} + : show_meta_data_(show_meta_data), annotate_(annotate) {} /*! * \brief Print a node to string. * \param node. * \return The string representation. */ std::string Print(const NodeRef& node) { - stream_ << "v0.0.2\n"; if (node.as()) { this->PrintFunc(Downcast(node)); } else if (node.as()) { @@ -165,12 +163,12 @@ class TextPrinter : if (show_meta_data_) { std::string meta_json = meta_.GetMetaSection(); // append meta data in the end. - stream_ << "// meta data\n" + stream_ << "# meta data\n" << "r\"\"\"\n" << meta_json << "\n" << "\"\"\""; } else { - stream_ << "// meta data omitted. you can use show_meta_data=True to include meta-data\n"; + stream_ << "# meta data omitted. you can use show_meta_data=True to include meta-data\n"; } } return stream_.str(); @@ -258,38 +256,27 @@ class TextPrinter : } TextValue VisitExpr_(const TupleNode* op) final { - if (inline_meta_data_) { - stream_ << "("; - for (size_t i = 0; i < op->fields.size(); i++) { - stream_ << GetValue(op->fields[i]); - if (i + 1 != op->fields.size()) - stream_ << ", "; - } - stream_ << ")"; - return TextValue(""); - } else { - std::vector fields; - for (Expr field : op->fields) { - fields.push_back(GetValue(field)); - } - // NOTE: always recursively visit to get ids, - // before print out the current line - TextValue id = this->AllocTempVar(); - this->PrintIndent(); - stream_ << id << " = ("; - for (size_t i = 0; i < fields.size(); ++i) { - stream_ << fields[i]; - if (i + 1 != fields.size()) { - stream_ << ", "; - } - } - if (fields.size() == 1) { - stream_ << ','; + std::vector fields; + for (Expr field : op->fields) { + fields.push_back(GetValue(field)); + } + // NOTE: always recursively visit to get ids, + // before print out the current line + TextValue id = this->AllocTempVar(); + this->PrintIndent(); + stream_ << id << " = ("; + for (size_t i = 0; i < fields.size(); ++i) { + stream_ << fields[i]; + if (i + 1 != fields.size()) { + stream_ << ", "; } - stream_ << ')'; - this->PrintEndInst("\n"); - return id; } + if (fields.size() == 1) { + stream_ << ','; + } + stream_ << ')'; + this->PrintEndInst("\n"); + return id; } TextValue VisitExpr_(const VarNode* op) final { @@ -650,9 +637,9 @@ class TextPrinter : void PrintOptionalInfo(const Expr& expr) { // additional information in comment. if (annotate_ != nullptr) { - stream_ << " // " << annotate_(expr); + stream_ << " # " << annotate_(expr); } else if (expr->checked_type_.defined()) { - stream_ << " // ty="; + stream_ << " # ty="; this->PrintType(expr->checked_type(), stream_); } } @@ -808,8 +795,6 @@ class TextPrinter : private: class AttrPrinter; friend class AttrPrinter; - /*! \brief Whether to inline meta data. If enabled, ignores show_meta_data_ flag. */ - bool inline_meta_data_; /*! \brief Whether to print meta data. */ bool show_meta_data_; /*! \brief additional comment function */ @@ -905,15 +890,14 @@ void TextPrinter::PrintCallAttrs(const Expr& op, } std::string RelayPrint(const NodeRef& node, - bool inline_meta_data, bool show_meta_data, runtime::TypedPackedFunc annotate) { - return TextPrinter(inline_meta_data, show_meta_data, annotate).Print(node); + return TextPrinter(show_meta_data, annotate).Print(node); } TVM_REGISTER_API("relay._expr.RelayPrint") .set_body_typed)>(RelayPrint); } // namespace relay diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 6efcd4464e75..66ff9caf4ae4 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -835,7 +835,7 @@ class FuseMutator : private ExprMutator { // Debug function, dump the group assignment in text. void DebugDumpGroup(const Expr& body) { - std::string text = RelayPrint(body, false, false, [this](const Expr& expr) -> std::string { + std::string text = RelayPrint(body, false, [this](const Expr& expr) -> std::string { auto it = gmap_.find(expr.get()); if (it == gmap_.end()) return ""; std::ostringstream os; diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index eab287f43fec..15fcf61e5533 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -27,22 +27,11 @@ def tuples(draw, field_type): def projections(draw, field_type): return relay.TupleGetItem(draw(field_type), draw(integers(min_value=-1000, max_value=1000))) -""" @given(tuples()) -def test_roundtrip(e): - print(e.astext(inline_meta_data=True)) - alpha_equal(relay.fromtext(e.astext(inline_meta_data=True)), e) - # e.astext() """ - @settings(deadline=500) @given(exprs) def test_roundtrip_pp(e): alpha_equal(relay.fromtext(pretty_print(e)), e) -# def test_roundtrip_pp_simple(): -# print(pretty_print(relay.const(1))) - if __name__ == "__main__": for _ in range(10): - # print(constants().example().astext()) - # print(tuples().example().astext(inline_meta_data=True)) print(pretty_print(exprs.example())) From 497f4eb81178d614bf189cbf70b6a45989202478 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Sun, 17 Feb 2019 19:32:18 -0800 Subject: [PATCH 10/73] derandomize --- tests/python/relay/test_ir_parser_roundtrip.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index 15fcf61e5533..a4c1b322bbaf 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -27,7 +27,8 @@ def tuples(draw, field_type): def projections(draw, field_type): return relay.TupleGetItem(draw(field_type), draw(integers(min_value=-1000, max_value=1000))) -@settings(deadline=500) +# TODO: figure out a way to not have to derandomize all the time +@settings(deadline=500, derandomize=True) @given(exprs) def test_roundtrip_pp(e): alpha_equal(relay.fromtext(pretty_print(e)), e) From c921cd0768f6fe169b540c4b2b46ff58cf8be745 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Sun, 17 Feb 2019 19:36:39 -0800 Subject: [PATCH 11/73] revert version number bump --- include/tvm/relay/doc.h | 2 ++ python/tvm/relay/grammar/Relay.g4 | 2 +- src/relay/ir/pretty_printer.cc | 2 +- tests/python/relay/test_ir_parser.py | 2 +- 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/doc.h b/include/tvm/relay/doc.h index 1b7303ddfd37..449fcd17aa58 100644 --- a/include/tvm/relay/doc.h +++ b/include/tvm/relay/doc.h @@ -92,6 +92,8 @@ class Line_Node : public DocNode { RELAY_DEFINE_NODE_REF(Line_, Line_Node, Doc); +// DSL functions + // empty doc Doc Nil(); // lift string to text diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index 18b05f2761e8..0a2206265502 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -1,6 +1,6 @@ grammar Relay; -SEMVER: 'v0.0.2' ; +SEMVER: 'v0.0.1' ; // Lexing // comments diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 465edd902ead..0ced4f73d259 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -104,7 +104,7 @@ class PrettyPrinter : }; std::string RelayPrettyPrint(const NodeRef& node) { - return "v0.0.2\n" + PrettyPrinter().Print(node); + return "v0.0.1\n" + PrettyPrinter().Print(node); } TVM_REGISTER_API("relay._expr.pretty_print") diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 9b725234a33c..08d4c430101b 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -12,7 +12,7 @@ else: raises_parse_error = lambda x: x -SEMVER = "v0.0.2" +SEMVER = "v0.0.1" BINARY_OPS = { "*": relay.multiply, From b7cdf43f5cadb07d535e27c309bc06f00642fef6 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Thu, 21 Feb 2019 12:40:34 -0800 Subject: [PATCH 12/73] adds separate gnf printer file and factors out some common functions. refuses to build tho --- include/tvm/relay/doc.h | 7 ++++ src/relay/ir/doc.cc | 47 +++++++++++++++++++++++ src/relay/ir/gnf_printer.cc | 68 ++++++++++++++++++++++++++++++++++ src/relay/ir/pretty_printer.cc | 55 +++++---------------------- 4 files changed, 132 insertions(+), 45 deletions(-) create mode 100644 src/relay/ir/gnf_printer.cc diff --git a/include/tvm/relay/doc.h b/include/tvm/relay/doc.h index 449fcd17aa58..5826edcbfc56 100644 --- a/include/tvm/relay/doc.h +++ b/include/tvm/relay/doc.h @@ -108,6 +108,13 @@ Doc operator+(const Doc& left, const Doc& right); Doc Nest(int indent, const Doc& doc); // convert doc to a string std::string Layout(const Doc& doc); +// render array-like things: e.g. (1, 2, 3) +Doc PrintArray(const Doc& open, const tvm::Array& arr, const Doc& sep, const Doc& close); +// Print constant bool value. +Doc PrintBool(bool value); +// special method to print out const scalar +template +Doc PrintConstScalar(DataType dtype, const T* data); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc index e16a64b9c1ab..5930609f945e 100644 --- a/src/relay/ir/doc.cc +++ b/src/relay/ir/doc.cc @@ -123,5 +123,52 @@ std::string Layout(const Doc& doc) { } else {assert(false);} } +// render array-like things: e.g. (1, 2, 3) +Doc PrintArray(const Doc& open, const tvm::Array& arr, const Doc& sep, const Doc& close) { + Doc seq; + if (arr.size() == 0) { + seq = Nil(); + } else { + seq = arr[0]; + for (size_t i = 1; i < arr.size(); i++) { + seq = seq + sep + arr[i]; + } + } + + return open + seq + close; +} + +/*! + * \brief Print constant bool value. + * \param value The value to be printed. + */ +Doc PrintBool(bool value) { + if (value) { + return Text("True"); + } else { + return Text("False"); + } +} + +/*! + * \brief special method to print out const scalar + * \param dtype The data type + * \param data The pointer to hold the data. + */ +template +Doc PrintConstScalar(DataType dtype, const T* data) { + std::ostringstream os; + if (dtype == Int(32)) { + os << data[0]; + } else if (dtype == Float(32)) { + os << data[0] << 'f'; + } else if (dtype == Bool()) { + return PrintBool(data[0] != 0); + } else { + os << dtype << "(" << data[0] << ")"; + } + return Text(os.str()); +} + } // relay } // tvm diff --git a/src/relay/ir/gnf_printer.cc b/src/relay/ir/gnf_printer.cc new file mode 100644 index 000000000000..3b9277b6c554 --- /dev/null +++ b/src/relay/ir/gnf_printer.cc @@ -0,0 +1,68 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file gnf_printer.cc + * \brief GNF printer for Relay programs + * Supports GNF and metadata. + */ +#include +#include + +namespace tvm { +namespace relay { + +class GNFPrinter : + public ExprFunctor { + public: + explicit GNFPrinter() {} + + std::string Print(const NodeRef& node) { + if (node.as_derived()) { + return Layout(this->PrintExpr(Downcast(node))); + } else { assert(false); } + } + + const Doc PrintExpr(const Expr& expr) { + auto it = memo_.find(expr); + if (it != memo_.end()) return it->second; + Doc val = this->VisitExpr(expr); + memo_[expr] = val; + return val; + } + + Doc VisitExpr_(const ConstantNode* op) final { + // Print out simple scalars directly. + if (op->is_scalar()) { + std::ostringstream os; + DataType dtype = TVMType2Type(op->data->dtype); + CHECK_EQ(op->data->ctx.device_type, kDLCPU); + if (dtype == Int(32)) { + return PrintConstScalar(dtype, static_cast(op->data->data)); + } else if (dtype == Int(64)) { + return PrintConstScalar(dtype, static_cast(op->data->data)); + } else if (dtype == Float(32)) { + return PrintConstScalar(dtype, static_cast(op->data->data)); + } else if (dtype == Float(64)) { + return PrintConstScalar(dtype, static_cast(op->data->data)); + } else if (dtype == Bool()) { + return PrintConstScalar(dtype, static_cast(op->data->data)); + } + } + // TODO: handle tensors + assert(false); + } + + private: + /*! \brief Map from Expr to Doc */ + std::unordered_map memo_; + size_t temp_var_counter_{0}; +}; + +std::string RelayPrettyPrint(const NodeRef& node) { + return "v0.0.1\n" + GNFPrinter().Print(node); +} + +TVM_REGISTER_API("relay._expr.pretty_print") +.set_body_typed(RelayPrettyPrint); + +} // relay +} // tvm diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 0ced4f73d259..f3b3a0c17cd4 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -2,7 +2,7 @@ * Copyright (c) 2019 by Contributors * \file pretty_printer.cc * \brief Pretty printer for Relay programs - * Supports ANF and GNF formats and metadata. + * Supports functional style and metadata. */ #include #include @@ -15,9 +15,9 @@ class PrettyPrinter : public: explicit PrettyPrinter() {} - std::string Print(const NodeRef& node) { + const Doc Print(const NodeRef& node) { if (node.as_derived()) { - return Layout(this->PrintExpr(Downcast(node))); + return this->PrintExpr(Downcast(node)); } else { assert(false); } } @@ -29,45 +29,6 @@ class PrettyPrinter : return val; } - // render a tvm array with open and closing brackets and a separator - // we use docs instead of strings for input to allow the caller to use - // newlines where desired - template - const Doc PrintArray(const Doc& open, const tvm::Array& arr, const Doc& sep, const Doc& close) { - Doc seq; - if (arr.size() == 0) { - seq = Nil(); - } else { - seq = Text(this->Print(arr[0])); - for (size_t i = 1; i < arr.size(); i++) { - seq = seq + sep + Text(this->Print(arr[i])); - } - } - - return open + seq + close; - } - - /*! - * \brief special method to print out const scalar - * \param dtype The data type - * \param data The pointer to hold the data. - */ - template - Doc PrintConstScalar(DataType dtype, const T* data) { // NOLINT(*) - std::stringstream ss; - if (dtype == Int(32)) { - ss << data[0]; - } else if (dtype == Float(32)) { - ss << data[0] << 'f'; - } else if (dtype == Bool()) { - // ss << PrintBool(data[0] != 0); - assert(false); - } else { - ss << dtype << "(" << data[0] << ")"; - } - return Text(ss.str()); - } - Doc VisitExpr_(const ConstantNode* op) final { // Print out simple scalar directly. if (op->is_scalar()) { @@ -91,11 +52,15 @@ class PrettyPrinter : } Doc VisitExpr_(const TupleNode* op) final { - return PrintArray(Text("("), op->fields, Text(", "), Text(")")); + std::vector fields; + for (Expr field : op->fields) { + fields.push_back(this->Print(field)); + } + return PrintArray(Text("("), fields, Text(", "), Text(")")); } Doc VisitExpr_(const TupleGetItemNode* op) final { - return this->VisitExpr(op->tuple) + Text(".") + Text(std::to_string(op->index)); + return this->Print(op->tuple) + Text(".") + Text(std::to_string(op->index)); } private: @@ -104,7 +69,7 @@ class PrettyPrinter : }; std::string RelayPrettyPrint(const NodeRef& node) { - return "v0.0.1\n" + PrettyPrinter().Print(node); + return "v0.0.1\n" + Layout(PrettyPrinter().Print(node)); } TVM_REGISTER_API("relay._expr.pretty_print") From eabd52a9613c0c52dc3c89952b1b93edc2cc2679 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Thu, 21 Feb 2019 15:54:29 -0800 Subject: [PATCH 13/73] fix some stuff and get compiling again --- include/tvm/relay/doc.h | 20 ++++++++++++++++++-- src/relay/ir/doc.cc | 20 -------------------- src/relay/ir/gnf_printer.cc | 12 ++++++------ 3 files changed, 24 insertions(+), 28 deletions(-) diff --git a/include/tvm/relay/doc.h b/include/tvm/relay/doc.h index 5826edcbfc56..eda2c26b5096 100644 --- a/include/tvm/relay/doc.h +++ b/include/tvm/relay/doc.h @@ -112,9 +112,25 @@ std::string Layout(const Doc& doc); Doc PrintArray(const Doc& open, const tvm::Array& arr, const Doc& sep, const Doc& close); // Print constant bool value. Doc PrintBool(bool value); -// special method to print out const scalar +/*! + * \brief special method to print out const scalar + * \param dtype The data type + * \param data The pointer to hold the data. + */ template -Doc PrintConstScalar(DataType dtype, const T* data); +Doc PrintConstScalar(DataType dtype, const T* data) { + std::ostringstream os; + if (dtype == Int(32)) { + os << data[0]; + } else if (dtype == Float(32)) { + os << data[0] << 'f'; + } else if (dtype == Bool()) { + return PrintBool(data[0] != 0); + } else { + os << dtype << "(" << data[0] << ")"; + } + return Text(os.str()); +} } // namespace relay } // namespace tvm diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc index 5930609f945e..5f8a51a14a4b 100644 --- a/src/relay/ir/doc.cc +++ b/src/relay/ir/doc.cc @@ -150,25 +150,5 @@ Doc PrintBool(bool value) { } } -/*! - * \brief special method to print out const scalar - * \param dtype The data type - * \param data The pointer to hold the data. - */ -template -Doc PrintConstScalar(DataType dtype, const T* data) { - std::ostringstream os; - if (dtype == Int(32)) { - os << data[0]; - } else if (dtype == Float(32)) { - os << data[0] << 'f'; - } else if (dtype == Bool()) { - return PrintBool(data[0] != 0); - } else { - os << dtype << "(" << data[0] << ")"; - } - return Text(os.str()); -} - } // relay } // tvm diff --git a/src/relay/ir/gnf_printer.cc b/src/relay/ir/gnf_printer.cc index 3b9277b6c554..2d6a19d3d34d 100644 --- a/src/relay/ir/gnf_printer.cc +++ b/src/relay/ir/gnf_printer.cc @@ -15,9 +15,9 @@ class GNFPrinter : public: explicit GNFPrinter() {} - std::string Print(const NodeRef& node) { + const Doc Print(const NodeRef& node) { if (node.as_derived()) { - return Layout(this->PrintExpr(Downcast(node))); + return this->PrintExpr(Downcast(node)); } else { assert(false); } } @@ -57,12 +57,12 @@ class GNFPrinter : size_t temp_var_counter_{0}; }; -std::string RelayPrettyPrint(const NodeRef& node) { - return "v0.0.1\n" + GNFPrinter().Print(node); +std::string RelayGNFPrint(const NodeRef& node) { + return "v0.0.1\n" + Layout(GNFPrinter().Print(node)); } -TVM_REGISTER_API("relay._expr.pretty_print") -.set_body_typed(RelayPrettyPrint); +TVM_REGISTER_API("relay._expr.gnf_print") +.set_body_typed(RelayGNFPrint); } // relay } // tvm From fd254560a5bd57bac1eaec88445e59539e133da7 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Thu, 21 Feb 2019 21:12:38 -0800 Subject: [PATCH 14/73] simplify Doc infrastructure --- include/tvm/relay/doc.h | 105 ++++++++++------------------- src/relay/ir/doc.cc | 119 +++++++++++---------------------- src/relay/ir/pretty_printer.cc | 2 +- 3 files changed, 74 insertions(+), 152 deletions(-) diff --git a/include/tvm/relay/doc.h b/include/tvm/relay/doc.h index eda2c26b5096..b2843164a331 100644 --- a/include/tvm/relay/doc.h +++ b/include/tvm/relay/doc.h @@ -13,91 +13,54 @@ namespace tvm { namespace relay { // ADT -class Doc; - -class DocNode : public Node { - public: - static constexpr const char* _type_key = "Doc"; - TVM_DECLARE_BASE_NODE_INFO(DocNode, Node); -}; - -class Doc : public NodeRef { - public: - Doc() {} - explicit Doc(NodePtr n) : NodeRef(n) {} - const DocNode* operator->() const { - return static_cast(node_.get()); - } - - using ContainerType = DocNode; -}; - -class Nil_; - -class Nil_Node : public DocNode { - public: - Nil_Node() {} - - void VisitAttrs(tvm::AttrVisitor* v) final {} - - TVM_DLL static Nil_ make(); - - static constexpr const char* _type_key = "Nil_"; - TVM_DECLARE_NODE_TYPE_INFO(Nil_Node, DocNode); +struct DocNode { + virtual ~DocNode() = default; }; -RELAY_DEFINE_NODE_REF(Nil_, Nil_Node, Doc); - -class Text_; - -class Text_Node : public DocNode { - public: - std::string str; - Doc doc; +using Doc = std::shared_ptr; - Text_Node() {} +struct NilNode : DocNode { }; - void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("str", &str); - v->Visit("doc", &doc); - } +struct TextNode : DocNode { + std::string str; + Doc doc; - TVM_DLL static Text_ make(std::string str, Doc doc); - - static constexpr const char* _type_key = "Text_"; - TVM_DECLARE_NODE_TYPE_INFO(Text_Node, DocNode); + TextNode(const std::string& str, const Doc& doc) : str(str), doc(doc) {} }; -RELAY_DEFINE_NODE_REF(Text_, Text_Node, Doc); - -class Line_; - -class Line_Node : public DocNode { - public: - int indent; - Doc doc; - - Line_Node() {} - - void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("indent", &indent); - v->Visit("doc", &doc); - } - - TVM_DLL static Line_ make(int indent, Doc doc); +struct LineNode : DocNode { + int indent; + Doc doc; - static constexpr const char* _type_key = "Line_"; - TVM_DECLARE_NODE_TYPE_INFO(Line_Node, DocNode); + LineNode(int indent, const Doc& doc) : indent(indent), doc(doc) {} }; -RELAY_DEFINE_NODE_REF(Line_, Line_Node, Doc); +/* template +T Match(const Doc& doc, + const T& case_nil, + const std::function& case_text, + const std::function& case_line) { + if (auto nil = std::dynamic_pointer_cast(doc)) { + return case_nil; + } else if (auto text = std::dynamic_pointer_cast(doc)) { + return case_text(text->str, text->doc); + } else if (auto line = std::dynamic_pointer_cast(doc)) { + return case_line(line->indent, line->doc); + } else {assert(false);} +} */ + +// text constructor +Doc Text(const std::string& str, const Doc& doc); + +// line constructor +Doc Line(int indent, const Doc& doc); // DSL functions -// empty doc +// empty doc/nil constructor Doc Nil(); // lift string to text -Doc Text(const std::string str); +Doc Text(const std::string& str); // new line Doc Line(); // concat two docs @@ -109,7 +72,7 @@ Doc Nest(int indent, const Doc& doc); // convert doc to a string std::string Layout(const Doc& doc); // render array-like things: e.g. (1, 2, 3) -Doc PrintArray(const Doc& open, const tvm::Array& arr, const Doc& sep, const Doc& close); +Doc PrintVec(const Doc& open, const std::vector& arr, const Doc& sep, const Doc& close); // Print constant bool value. Doc PrintBool(bool value); /*! diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc index 5f8a51a14a4b..937dec38921b 100644 --- a/src/relay/ir/doc.cc +++ b/src/relay/ir/doc.cc @@ -9,84 +9,44 @@ namespace tvm { namespace relay { -// Doc ADT implementation -Nil_ Nil_Node::make() { - NodePtr n = make_node(); - return Nil_(n); -} - -TVM_REGISTER_API("relay._make.Nil_") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = Nil_Node::make(); - }); +// DSL function implementations -TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const Nil_Node* node, tvm::IRPrinter* p) { - p->stream << "Nil_Node()"; - }); - -Text_ Text_Node::make(std::string str, Doc doc) { - NodePtr n = make_node(); - n->str = std::move(str); - n->doc = std::move(doc); - return Text_(n); +// empty doc/nil constructor +Doc Nil() { + return std::make_shared(); } -TVM_REGISTER_API("relay._make.Text_") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = Text_Node::make(args[0], args[1]); - }); - -TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const Text_Node* node, tvm::IRPrinter* p) { - p->stream << "Text_Node(" << node->str << ", " << node->doc << ")"; - }); - -Line_ Line_Node::make(int indent, Doc doc) { - NodePtr n = make_node(); - n->indent = indent; - n->doc = std::move(doc); - return Line_(n); +// text constructor +Doc Text(const std::string& str, const Doc& doc) { + return std::make_shared(str, doc); } -TVM_REGISTER_API("relay._make.Line_") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = Line_Node::make(args[0], args[1]); - }); - -TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const Line_Node* node, tvm::IRPrinter* p) { - p->stream << "Line_Node(" << node->indent << ", " << node->doc << ")"; - }); - -// DSL functions - -// empty doc -Doc Nil() { - return Nil_Node::make(); +// lift string to text +Doc Text(const std::string& str) { + return Text(str, Nil()); } -// lift string to text -Doc Text(const std::string str) { - return Text_Node::make(str, Nil()); +// line constructor +Doc Line(int indent, const Doc& doc) { + return std::make_shared(indent, doc); } // new line -Doc Line() { - return Line_Node::make(0, Nil()); +Doc Line(const Doc& doc) { + return Line(0, doc); } // concat two docs Doc Concat(const Doc& left, const Doc& right) { - if (const Text_Node* text = left.as()) { + if (auto nil = std::dynamic_pointer_cast(left)) { + // throw away nil + return right; + } else if (auto text = std::dynamic_pointer_cast(left)) { // push right into text continuation - return Text_Node::make(text->str, Concat(text->doc, right)); - } else if (const Line_Node* line = left.as()) { + return Text(text->str, Concat(text->doc, right)); + } else if (auto line = std::dynamic_pointer_cast(left)) { // push right into line continuation - return Line_Node::make(line->indent, Concat(line->doc, right)); - } else if (const Nil_Node* nil = left.as()) { - // throwaway nils on the left - return right; + return Line(line->indent, Concat(line->doc, right)); } else {assert(false);} } @@ -97,41 +57,40 @@ Doc operator+(const Doc& left, const Doc& right) { // indent a doc Doc Nest(int indent, const Doc& doc) { - if (const Text_Node* text = doc.as()) { + if (auto nil = std::dynamic_pointer_cast(doc)) { + // absorb nest + return nil; + } else if (auto text = std::dynamic_pointer_cast(doc)) { // push nest through - return Text_Node::make(text->str, Nest(indent, text->doc)); - } else if (const Line_Node* line = doc.as()) { - // add indent to lines and continue - return Line_Node::make(indent + line->indent, Nest(indent, line->doc)); - } else if (const Nil_Node* nil = doc.as()) { - // absorb it - return Nil_(); + return Text(text->str, Nest(indent, text->doc)); + } else if (auto line = std::dynamic_pointer_cast(doc)) { + // add indent to line and continue + return Line(indent + line->indent, Nest(indent, line->doc)); } else {assert(false);} } // convert a doc to a string std::string Layout(const Doc& doc) { - if (const Text_Node* text = doc.as()) { + if (auto nil = std::dynamic_pointer_cast(doc)) { + return ""; + } else if (auto text = std::dynamic_pointer_cast(doc)) { // add text and continue return text->str + Layout(text->doc); - } else if (const Line_Node* line = doc.as()) { + } else if (auto line = std::dynamic_pointer_cast(doc)) { // add a newline and indents, then continue return "\n" + std::string(line->indent, ' ') + Layout(line->doc); - } else if (const Nil_Node* nil = doc.as()) { - // empty string - return ""; } else {assert(false);} } // render array-like things: e.g. (1, 2, 3) -Doc PrintArray(const Doc& open, const tvm::Array& arr, const Doc& sep, const Doc& close) { +Doc PrintVec(const Doc& open, const std::vector& vec, const Doc& sep, const Doc& close) { Doc seq; - if (arr.size() == 0) { + if (vec.size() == 0) { seq = Nil(); } else { - seq = arr[0]; - for (size_t i = 1; i < arr.size(); i++) { - seq = seq + sep + arr[i]; + seq = vec[0]; + for (size_t i = 1; i < vec.size(); i++) { + seq = seq + sep + vec[i]; } } diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index f3b3a0c17cd4..3c33db635ae6 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -56,7 +56,7 @@ class PrettyPrinter : for (Expr field : op->fields) { fields.push_back(this->Print(field)); } - return PrintArray(Text("("), fields, Text(", "), Text(")")); + return PrintVec(Text("("), fields, Text(", "), Text(")")); } Doc VisitExpr_(const TupleGetItemNode* op) final { From bd21270209422bcc4e096e32a1b1987108c323a1 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Thu, 21 Feb 2019 22:19:33 -0800 Subject: [PATCH 15/73] fix bugs. improve gnf printer --- src/relay/ir/doc.cc | 4 +- src/relay/ir/gnf_printer.cc | 42 ++++++++++++++++--- src/relay/ir/pretty_printer.cc | 4 +- .../python/relay/test_ir_parser_roundtrip.py | 6 +++ 4 files changed, 46 insertions(+), 10 deletions(-) diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc index 937dec38921b..9426153e2e37 100644 --- a/src/relay/ir/doc.cc +++ b/src/relay/ir/doc.cc @@ -32,8 +32,8 @@ Doc Line(int indent, const Doc& doc) { } // new line -Doc Line(const Doc& doc) { - return Line(0, doc); +Doc Line() { + return Line(0, Nil()); } // concat two docs diff --git a/src/relay/ir/gnf_printer.cc b/src/relay/ir/gnf_printer.cc index 2d6a19d3d34d..41ce353b23a4 100644 --- a/src/relay/ir/gnf_printer.cc +++ b/src/relay/ir/gnf_printer.cc @@ -15,18 +15,39 @@ class GNFPrinter : public: explicit GNFPrinter() {} - const Doc Print(const NodeRef& node) { + Doc PrintFinal(const NodeRef& node) { + Print(node); + return doc + TempVar(temp_var_counter_ - 1); + } + + Doc Print(const NodeRef& node) { if (node.as_derived()) { return this->PrintExpr(Downcast(node)); } else { assert(false); } } - const Doc PrintExpr(const Expr& expr) { + Doc TempVar(int n) { + std::ostringstream os; + os << n; + return Text("\%") + Text(os.str()); + } + + Doc AllocTemp() { + return TempVar(temp_var_counter_++); + } + + Doc PrintExpr(const Expr& expr) { + // Exploit memoization to print GNF. + // The first time we visit an expression, we need to allocate a temp var + // for it. Every subsequent time we can just use its assigned variable. + // This works since hashing uses pointer equality. auto it = memo_.find(expr); if (it != memo_.end()) return it->second; - Doc val = this->VisitExpr(expr); - memo_[expr] = val; - return val; + Doc printed_expr = this->VisitExpr(expr); + Doc temp_var = AllocTemp(); + memo_[expr] = temp_var; + doc = doc + temp_var + Text(" = ") + printed_expr + Line(); + return temp_var; } Doc VisitExpr_(const ConstantNode* op) final { @@ -51,14 +72,23 @@ class GNFPrinter : assert(false); } + Doc VisitExpr_(const TupleNode* op) final { + std::vector fields; + for (Expr field : op->fields) { + fields.push_back(this->Print(field)); + } + return PrintVec(Text("("), fields, Text(", "), Text(")")); + } + private: /*! \brief Map from Expr to Doc */ + Doc doc = Nil(); std::unordered_map memo_; size_t temp_var_counter_{0}; }; std::string RelayGNFPrint(const NodeRef& node) { - return "v0.0.1\n" + Layout(GNFPrinter().Print(node)); + return "v0.0.1\n" + Layout(GNFPrinter().PrintFinal(node)); } TVM_REGISTER_API("relay._expr.gnf_print") diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 3c33db635ae6..66ba6f20a5be 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -15,13 +15,13 @@ class PrettyPrinter : public: explicit PrettyPrinter() {} - const Doc Print(const NodeRef& node) { + Doc Print(const NodeRef& node) { if (node.as_derived()) { return this->PrintExpr(Downcast(node)); } else { assert(false); } } - const Doc PrintExpr(const Expr& expr) { + Doc PrintExpr(const Expr& expr) { auto it = memo_.find(expr); if (it != memo_.end()) return it->second; Doc val = this->VisitExpr(expr); diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index a4c1b322bbaf..3284452d42e2 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -33,6 +33,12 @@ def projections(draw, field_type): def test_roundtrip_pp(e): alpha_equal(relay.fromtext(pretty_print(e)), e) +def test_gnf(): + assert relay._expr.gnf_print(relay.const(1)) == "v0.0.1\n%0 = 1\n%0" + assert relay._expr.gnf_print(relay.Tuple([relay.const(1), relay.const(1)])) == "v0.0.1\n%0 = 1\n%1 = 1\n%2 = (%0, %1)\n%2" + one = relay.const(1) + assert relay._expr.gnf_print(relay.Tuple([one, one])) == "v0.0.1\n%0 = 1\n%1 = (%0, %0)\n%1" + if __name__ == "__main__": for _ in range(10): print(pretty_print(exprs.example())) From 1a58f64324324b502e25cac4f0a6825744b282d6 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Fri, 22 Feb 2019 17:51:06 -0800 Subject: [PATCH 16/73] add stream-like interface for docs --- include/tvm/relay/doc.h | 15 ++++++++++-- src/relay/ir/doc.cc | 23 +++++++++++++++---- src/relay/ir/gnf_printer.cc | 7 ++++-- src/relay/ir/pretty_printer.cc | 6 +++-- .../python/relay/test_ir_parser_roundtrip.py | 1 + 5 files changed, 42 insertions(+), 10 deletions(-) diff --git a/include/tvm/relay/doc.h b/include/tvm/relay/doc.h index b2843164a331..ed41b6883126 100644 --- a/include/tvm/relay/doc.h +++ b/include/tvm/relay/doc.h @@ -67,12 +67,23 @@ Doc Line(); Doc Concat(const Doc& left, const Doc& right); // sugar for Concat Doc operator+(const Doc& left, const Doc& right); +// sugar for Concat with result stored in left +Doc& operator<<(Doc& left, const Doc& right); +// like above, but automatically lifts string to a doc +Doc& operator<<(Doc& left, const std::string& right); +// like above, but converts right to a string +template +Doc& operator<<(Doc& left, const T& right) { + std::ostringstream os; + os << right; + return left << os.str(); +} // indent a doc Doc Nest(int indent, const Doc& doc); // convert doc to a string std::string Layout(const Doc& doc); -// render array-like things: e.g. (1, 2, 3) -Doc PrintVec(const Doc& open, const std::vector& arr, const Doc& sep, const Doc& close); +// render vectors of docs with a separator. e.g. [1, 2, 3], f -> 1f2f3 +Doc PrintVec(const std::vector& arr, const Doc& sep); // Print constant bool value. Doc PrintBool(bool value); /*! diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc index 9426153e2e37..060cd04c1851 100644 --- a/src/relay/ir/doc.cc +++ b/src/relay/ir/doc.cc @@ -55,6 +55,21 @@ Doc operator+(const Doc& left, const Doc& right) { return Concat(left, right); } +// sugar for Concat with result stored in left +Doc& operator<<(Doc& left, const Doc& right) { + left = left + right; + return left; +} + +// like above, but automatically lifts string to a doc +Doc& operator<<(Doc& left, const std::string& right) { + if (right == "\n") { + return left << Line(); + } else { + return left << Text(right); + } +} + // indent a doc Doc Nest(int indent, const Doc& doc) { if (auto nil = std::dynamic_pointer_cast(doc)) { @@ -82,19 +97,19 @@ std::string Layout(const Doc& doc) { } else {assert(false);} } -// render array-like things: e.g. (1, 2, 3) -Doc PrintVec(const Doc& open, const std::vector& vec, const Doc& sep, const Doc& close) { +// render vectors of docs with a separator. e.g. [1, 2, 3], f -> 1f2f3 +Doc PrintVec(const std::vector& vec, const Doc& sep) { Doc seq; if (vec.size() == 0) { seq = Nil(); } else { seq = vec[0]; for (size_t i = 1; i < vec.size(); i++) { - seq = seq + sep + vec[i]; + seq << sep << vec[i]; } } - return open + seq + close; + return seq; } /*! diff --git a/src/relay/ir/gnf_printer.cc b/src/relay/ir/gnf_printer.cc index 41ce353b23a4..e6178cdedf4e 100644 --- a/src/relay/ir/gnf_printer.cc +++ b/src/relay/ir/gnf_printer.cc @@ -46,7 +46,7 @@ class GNFPrinter : Doc printed_expr = this->VisitExpr(expr); Doc temp_var = AllocTemp(); memo_[expr] = temp_var; - doc = doc + temp_var + Text(" = ") + printed_expr + Line(); + doc << temp_var << " = " << printed_expr << "\n"; return temp_var; } @@ -77,9 +77,12 @@ class GNFPrinter : for (Expr field : op->fields) { fields.push_back(this->Print(field)); } - return PrintVec(Text("("), fields, Text(", "), Text(")")); + Doc doc = Nil(); + return doc << "(" << PrintVec(fields, Text(", ")) << ")"; } + + private: /*! \brief Map from Expr to Doc */ Doc doc = Nil(); diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 66ba6f20a5be..bb97b6cfbf21 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -56,11 +56,13 @@ class PrettyPrinter : for (Expr field : op->fields) { fields.push_back(this->Print(field)); } - return PrintVec(Text("("), fields, Text(", "), Text(")")); + Doc doc = Nil(); + return doc << "(" << PrintVec(fields, Text(", ")) << ")"; } Doc VisitExpr_(const TupleGetItemNode* op) final { - return this->Print(op->tuple) + Text(".") + Text(std::to_string(op->index)); + Doc doc = Nil(); + return doc << this->Print(op->tuple) << "." << op->index; } private: diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index 3284452d42e2..35c9b0791b53 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -42,3 +42,4 @@ def test_gnf(): if __name__ == "__main__": for _ in range(10): print(pretty_print(exprs.example())) + relay._expr.gnf_print(relay.const(1)) \ No newline at end of file From 073029d725ef8872e456012e22a6b42341418c80 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Fri, 22 Feb 2019 18:25:50 -0800 Subject: [PATCH 17/73] if nodes in gnf --- src/relay/ir/gnf_printer.cc | 24 ++++++++++++++++--- .../python/relay/test_ir_parser_roundtrip.py | 10 +++++--- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/src/relay/ir/gnf_printer.cc b/src/relay/ir/gnf_printer.cc index e6178cdedf4e..f1fec90c7f2f 100644 --- a/src/relay/ir/gnf_printer.cc +++ b/src/relay/ir/gnf_printer.cc @@ -13,7 +13,9 @@ namespace relay { class GNFPrinter : public ExprFunctor { public: - explicit GNFPrinter() {} + explicit GNFPrinter(const std::unordered_map& memo_, size_t temp_var_counter_) : memo_(memo_), temp_var_counter_(temp_var_counter_) {} + + explicit GNFPrinter() : temp_var_counter_(0) {} Doc PrintFinal(const NodeRef& node) { Print(node); @@ -81,13 +83,29 @@ class GNFPrinter : return doc << "(" << PrintVec(fields, Text(", ")) << ")"; } + Doc VisitExpr_(const TupleGetItemNode* op) final { + Doc doc = Nil(); + return doc << this->Print(op->tuple) << "." << op->index; + } - + Doc VisitExpr_(const IfNode* op) final { + Doc doc = Nil(); + Doc true_b = Nil(); + Doc false_b = Nil(); + doc << "if (" << this->Print(op->cond) << ") {"; + // create a new scope by creating a new printer object. + doc << Nest(2, true_b << "\n" << GNFPrinter(memo_, temp_var_counter_).PrintFinal(op->true_branch)) << "\n"; + doc << "} else {"; + doc << Nest(2, false_b << "\n" << GNFPrinter(memo_, temp_var_counter_).PrintFinal(op->false_branch)) << "\n"; + doc << "}"; + return doc; + } +g private: /*! \brief Map from Expr to Doc */ Doc doc = Nil(); std::unordered_map memo_; - size_t temp_var_counter_{0}; + size_t temp_var_counter_; }; std::string RelayGNFPrint(const NodeRef& node) { diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index 35c9b0791b53..d17cb9ebeb43 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -38,8 +38,12 @@ def test_gnf(): assert relay._expr.gnf_print(relay.Tuple([relay.const(1), relay.const(1)])) == "v0.0.1\n%0 = 1\n%1 = 1\n%2 = (%0, %1)\n%2" one = relay.const(1) assert relay._expr.gnf_print(relay.Tuple([one, one])) == "v0.0.1\n%0 = 1\n%1 = (%0, %0)\n%1" + + assert relay._expr.gnf_print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0))) == "v0.0.1\n%0 = True\n%1 = if (%0) {\n %1 = 1\n %2 = (%1, %1)\n %3 = %2.0\n %3\n} else {\n %1 = 1\n %2 = 1\n %3 = (%1, %1, %2)\n %4 = %3.0\n %4\n}\n%1" if __name__ == "__main__": - for _ in range(10): - print(pretty_print(exprs.example())) - relay._expr.gnf_print(relay.const(1)) \ No newline at end of file + # for _ in range(10): + # print(pretty_print(exprs.example())) + one = relay.const(1) + print(relay._expr.gnf_print(relay.TupleGetItem(relay.Tuple([one, one]), 0))) + print(relay._expr.gnf_print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0)))) From 14e083f1098d40ad1b0ac6de1cf3c02e53ddd735 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Fri, 22 Feb 2019 18:51:11 -0800 Subject: [PATCH 18/73] semiworking let. kinda --- src/relay/ir/gnf_printer.cc | 32 +++++++++++++++---- .../python/relay/test_ir_parser_roundtrip.py | 2 ++ 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/src/relay/ir/gnf_printer.cc b/src/relay/ir/gnf_printer.cc index f1fec90c7f2f..eb3da2d8c4ea 100644 --- a/src/relay/ir/gnf_printer.cc +++ b/src/relay/ir/gnf_printer.cc @@ -38,7 +38,7 @@ class GNFPrinter : return TempVar(temp_var_counter_++); } - Doc PrintExpr(const Expr& expr) { + Doc PrintExpr(const Expr& expr, bool gnf) { // Exploit memoization to print GNF. // The first time we visit an expression, we need to allocate a temp var // for it. Every subsequent time we can just use its assigned variable. @@ -46,10 +46,20 @@ class GNFPrinter : auto it = memo_.find(expr); if (it != memo_.end()) return it->second; Doc printed_expr = this->VisitExpr(expr); - Doc temp_var = AllocTemp(); - memo_[expr] = temp_var; - doc << temp_var << " = " << printed_expr << "\n"; - return temp_var; + if (gnf && + !expr.as()) { + Doc temp_var = AllocTemp(); + memo_[expr] = temp_var; + doc << temp_var << " = " << printed_expr << "\n"; + return temp_var; + } else { + memo_[expr] = printed_expr; + return printed_expr; + } + } + + Doc PrintExpr(const Expr& expr) { + return this->PrintExpr(expr, true); } Doc VisitExpr_(const ConstantNode* op) final { @@ -100,7 +110,17 @@ class GNFPrinter : doc << "}"; return doc; } -g + + Doc VisitExpr_(const LetNode* op) final { + Doc ret = Nil(); + // TODO: this should call a var printer + // TODO: this should have a type annotation + ret << "let \%" << op->var->name_hint() << " = " << this->PrintExpr(op->value, false) << ";" << "\n"; + ret << this->PrintExpr(op->body); + doc << ret; + return ret; + } + private: /*! \brief Map from Expr to Doc */ Doc doc = Nil(); diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index d17cb9ebeb43..c25953396fdc 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -47,3 +47,5 @@ def test_gnf(): one = relay.const(1) print(relay._expr.gnf_print(relay.TupleGetItem(relay.Tuple([one, one]), 0))) print(relay._expr.gnf_print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0)))) + SEMVER = "v0.0.1" + print(relay._expr.gnf_print(relay.fromtext(SEMVER+"let %x = 1; 5"))) From 1ba8a82b7a0353e704c897f6cdb685a05b7ea9e2 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Fri, 22 Feb 2019 19:17:32 -0800 Subject: [PATCH 19/73] fix let and refactor --- src/relay/ir/doc.cc | 2 +- {include/tvm/relay => src/relay/ir}/doc.h | 0 src/relay/ir/gnf_printer.cc | 36 +++++++++++-------- src/relay/ir/pretty_printer.cc | 2 +- .../python/relay/test_ir_parser_roundtrip.py | 3 ++ 5 files changed, 26 insertions(+), 17 deletions(-) rename {include/tvm/relay => src/relay/ir}/doc.h (100%) diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc index 060cd04c1851..87f2abbfc2f5 100644 --- a/src/relay/ir/doc.cc +++ b/src/relay/ir/doc.cc @@ -4,7 +4,7 @@ * \brief Doc ADT used for pretty printing. * Based on Section 1 of https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf. */ -#include +#include "doc.h" namespace tvm { namespace relay { diff --git a/include/tvm/relay/doc.h b/src/relay/ir/doc.h similarity index 100% rename from include/tvm/relay/doc.h rename to src/relay/ir/doc.h diff --git a/src/relay/ir/gnf_printer.cc b/src/relay/ir/gnf_printer.cc index eb3da2d8c4ea..5c0f9a284d48 100644 --- a/src/relay/ir/gnf_printer.cc +++ b/src/relay/ir/gnf_printer.cc @@ -4,7 +4,7 @@ * \brief GNF printer for Relay programs * Supports GNF and metadata. */ -#include +#include "doc.h" #include namespace tvm { @@ -17,17 +17,26 @@ class GNFPrinter : explicit GNFPrinter() : temp_var_counter_(0) {} + Doc PrintNestedScope(const NodeRef& node) { + return GNFPrinter(memo_, temp_var_counter_).PrintFinal(node); + } + Doc PrintFinal(const NodeRef& node) { - Print(node); - return doc + TempVar(temp_var_counter_ - 1); + Print(node, false); + return doc; } - Doc Print(const NodeRef& node) { + // note: gnf flag is only one level deep + Doc Print(const NodeRef& node, bool gnf) { if (node.as_derived()) { - return this->PrintExpr(Downcast(node)); + return this->PrintExpr(Downcast(node), gnf); } else { assert(false); } } + Doc Print(const NodeRef& node) { + return this->Print(node, true); + } + Doc TempVar(int n) { std::ostringstream os; os << n; @@ -54,14 +63,11 @@ class GNFPrinter : return temp_var; } else { memo_[expr] = printed_expr; + doc << printed_expr; return printed_expr; } } - Doc PrintExpr(const Expr& expr) { - return this->PrintExpr(expr, true); - } - Doc VisitExpr_(const ConstantNode* op) final { // Print out simple scalars directly. if (op->is_scalar()) { @@ -104,20 +110,20 @@ class GNFPrinter : Doc false_b = Nil(); doc << "if (" << this->Print(op->cond) << ") {"; // create a new scope by creating a new printer object. - doc << Nest(2, true_b << "\n" << GNFPrinter(memo_, temp_var_counter_).PrintFinal(op->true_branch)) << "\n"; + doc << Nest(2, true_b << "\n" << PrintNestedScope(op->true_branch)) << "\n"; doc << "} else {"; - doc << Nest(2, false_b << "\n" << GNFPrinter(memo_, temp_var_counter_).PrintFinal(op->false_branch)) << "\n"; + doc << Nest(2, false_b << "\n" << PrintNestedScope(op->false_branch)) << "\n"; doc << "}"; return doc; } Doc VisitExpr_(const LetNode* op) final { Doc ret = Nil(); - // TODO: this should call a var printer + // TODO: this should call a var printer, which needs to differentiate + // between free and bound vars // TODO: this should have a type annotation - ret << "let \%" << op->var->name_hint() << " = " << this->PrintExpr(op->value, false) << ";" << "\n"; - ret << this->PrintExpr(op->body); - doc << ret; + ret << "let \%" << op->var->name_hint() << " = " << PrintNestedScope(op->value) << ";" << "\n"; + ret << PrintNestedScope(op->body); return ret; } diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index bb97b6cfbf21..ab577fe82550 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -4,7 +4,7 @@ * \brief Pretty printer for Relay programs * Supports functional style and metadata. */ -#include +#include "doc.h" #include namespace tvm { diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index c25953396fdc..ff7365e945aa 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -46,6 +46,9 @@ def test_gnf(): # print(pretty_print(exprs.example())) one = relay.const(1) print(relay._expr.gnf_print(relay.TupleGetItem(relay.Tuple([one, one]), 0))) + print() print(relay._expr.gnf_print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0)))) + print() SEMVER = "v0.0.1" print(relay._expr.gnf_print(relay.fromtext(SEMVER+"let %x = 1; 5"))) + print() From 3ad0ad1cb2893a0df62689dc0817ae60d7e2ff15 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Fri, 22 Feb 2019 19:19:59 -0800 Subject: [PATCH 20/73] tempvar refactor --- src/relay/ir/gnf_printer.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/relay/ir/gnf_printer.cc b/src/relay/ir/gnf_printer.cc index 5c0f9a284d48..5dd2933e718f 100644 --- a/src/relay/ir/gnf_printer.cc +++ b/src/relay/ir/gnf_printer.cc @@ -38,9 +38,8 @@ class GNFPrinter : } Doc TempVar(int n) { - std::ostringstream os; - os << n; - return Text("\%") + Text(os.str()); + Doc doc = Nil(); + return doc << "\%" << n; } Doc AllocTemp() { From fb3d910acb10521354dd7c522ba6b6ca605357f5 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Fri, 22 Feb 2019 19:22:27 -0800 Subject: [PATCH 21/73] Nest -> Indent to avoid confusion with nested scopes --- src/relay/ir/doc.cc | 10 +++++----- src/relay/ir/doc.h | 2 +- src/relay/ir/gnf_printer.cc | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc index 87f2abbfc2f5..3515d3d5653f 100644 --- a/src/relay/ir/doc.cc +++ b/src/relay/ir/doc.cc @@ -71,16 +71,16 @@ Doc& operator<<(Doc& left, const std::string& right) { } // indent a doc -Doc Nest(int indent, const Doc& doc) { +Doc Indent(int indent, const Doc& doc) { if (auto nil = std::dynamic_pointer_cast(doc)) { - // absorb nest + // absorb indent return nil; } else if (auto text = std::dynamic_pointer_cast(doc)) { - // push nest through - return Text(text->str, Nest(indent, text->doc)); + // push indent through + return Text(text->str, Indent(indent, text->doc)); } else if (auto line = std::dynamic_pointer_cast(doc)) { // add indent to line and continue - return Line(indent + line->indent, Nest(indent, line->doc)); + return Line(indent + line->indent, Indent(indent, line->doc)); } else {assert(false);} } diff --git a/src/relay/ir/doc.h b/src/relay/ir/doc.h index ed41b6883126..4e88d48afea0 100644 --- a/src/relay/ir/doc.h +++ b/src/relay/ir/doc.h @@ -79,7 +79,7 @@ Doc& operator<<(Doc& left, const T& right) { return left << os.str(); } // indent a doc -Doc Nest(int indent, const Doc& doc); +Doc Indent(int indent, const Doc& doc); // convert doc to a string std::string Layout(const Doc& doc); // render vectors of docs with a separator. e.g. [1, 2, 3], f -> 1f2f3 diff --git a/src/relay/ir/gnf_printer.cc b/src/relay/ir/gnf_printer.cc index 5dd2933e718f..524cca4a8307 100644 --- a/src/relay/ir/gnf_printer.cc +++ b/src/relay/ir/gnf_printer.cc @@ -17,6 +17,7 @@ class GNFPrinter : explicit GNFPrinter() : temp_var_counter_(0) {} + // create a new scope by creating a new printer object. Doc PrintNestedScope(const NodeRef& node) { return GNFPrinter(memo_, temp_var_counter_).PrintFinal(node); } @@ -108,10 +109,9 @@ class GNFPrinter : Doc true_b = Nil(); Doc false_b = Nil(); doc << "if (" << this->Print(op->cond) << ") {"; - // create a new scope by creating a new printer object. - doc << Nest(2, true_b << "\n" << PrintNestedScope(op->true_branch)) << "\n"; + doc << Indent(2, true_b << "\n" << PrintNestedScope(op->true_branch)) << "\n"; doc << "} else {"; - doc << Nest(2, false_b << "\n" << PrintNestedScope(op->false_branch)) << "\n"; + doc << Indent(2, false_b << "\n" << PrintNestedScope(op->false_branch)) << "\n"; doc << "}"; return doc; } From bba6b97f31e8aa731ef69c29eb0a6733db634dd6 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Fri, 22 Feb 2019 19:56:40 -0800 Subject: [PATCH 22/73] minor changes --- src/relay/ir/gnf_printer.cc | 8 ++++---- tests/python/relay/test_ir_parser_roundtrip.py | 4 ++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/relay/ir/gnf_printer.cc b/src/relay/ir/gnf_printer.cc index 524cca4a8307..d01708c7d5fa 100644 --- a/src/relay/ir/gnf_printer.cc +++ b/src/relay/ir/gnf_printer.cc @@ -117,13 +117,13 @@ class GNFPrinter : } Doc VisitExpr_(const LetNode* op) final { - Doc ret = Nil(); + Doc doc = Nil(); // TODO: this should call a var printer, which needs to differentiate // between free and bound vars // TODO: this should have a type annotation - ret << "let \%" << op->var->name_hint() << " = " << PrintNestedScope(op->value) << ";" << "\n"; - ret << PrintNestedScope(op->body); - return ret; + doc << "let \%" << op->var->name_hint() << " = " << PrintNestedScope(op->value) << ";" << "\n"; + doc << PrintNestedScope(op->body); + return doc; } private: diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index ff7365e945aa..8b2b96ab17ce 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -52,3 +52,7 @@ def test_gnf(): SEMVER = "v0.0.1" print(relay._expr.gnf_print(relay.fromtext(SEMVER+"let %x = 1; 5"))) print() + print(relay.fromtext(SEMVER+"let %x = 1; %x").astext()) + print(relay.fromtext(SEMVER+"let %x = (1, 1); %x").astext()) + print(relay.TupleGetItem(relay.Tuple([one, one]), 0).astext()) + print() From 1d78b979ff5e81f3238072a3dae2f2f17c83f5d7 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Sat, 23 Feb 2019 14:44:38 -0800 Subject: [PATCH 23/73] fix lets for real (mostly) and combine anf and gnf printing --- src/relay/ir/gnf_printer.cc | 144 ------------------ src/relay/ir/pretty_printer.cc | 136 +++++++++++++++-- .../python/relay/test_ir_parser_roundtrip.py | 31 ++-- 3 files changed, 140 insertions(+), 171 deletions(-) delete mode 100644 src/relay/ir/gnf_printer.cc diff --git a/src/relay/ir/gnf_printer.cc b/src/relay/ir/gnf_printer.cc deleted file mode 100644 index d01708c7d5fa..000000000000 --- a/src/relay/ir/gnf_printer.cc +++ /dev/null @@ -1,144 +0,0 @@ -/*! - * Copyright (c) 2019 by Contributors - * \file gnf_printer.cc - * \brief GNF printer for Relay programs - * Supports GNF and metadata. - */ -#include "doc.h" -#include - -namespace tvm { -namespace relay { - -class GNFPrinter : - public ExprFunctor { - public: - explicit GNFPrinter(const std::unordered_map& memo_, size_t temp_var_counter_) : memo_(memo_), temp_var_counter_(temp_var_counter_) {} - - explicit GNFPrinter() : temp_var_counter_(0) {} - - // create a new scope by creating a new printer object. - Doc PrintNestedScope(const NodeRef& node) { - return GNFPrinter(memo_, temp_var_counter_).PrintFinal(node); - } - - Doc PrintFinal(const NodeRef& node) { - Print(node, false); - return doc; - } - - // note: gnf flag is only one level deep - Doc Print(const NodeRef& node, bool gnf) { - if (node.as_derived()) { - return this->PrintExpr(Downcast(node), gnf); - } else { assert(false); } - } - - Doc Print(const NodeRef& node) { - return this->Print(node, true); - } - - Doc TempVar(int n) { - Doc doc = Nil(); - return doc << "\%" << n; - } - - Doc AllocTemp() { - return TempVar(temp_var_counter_++); - } - - Doc PrintExpr(const Expr& expr, bool gnf) { - // Exploit memoization to print GNF. - // The first time we visit an expression, we need to allocate a temp var - // for it. Every subsequent time we can just use its assigned variable. - // This works since hashing uses pointer equality. - auto it = memo_.find(expr); - if (it != memo_.end()) return it->second; - Doc printed_expr = this->VisitExpr(expr); - if (gnf && - !expr.as()) { - Doc temp_var = AllocTemp(); - memo_[expr] = temp_var; - doc << temp_var << " = " << printed_expr << "\n"; - return temp_var; - } else { - memo_[expr] = printed_expr; - doc << printed_expr; - return printed_expr; - } - } - - Doc VisitExpr_(const ConstantNode* op) final { - // Print out simple scalars directly. - if (op->is_scalar()) { - std::ostringstream os; - DataType dtype = TVMType2Type(op->data->dtype); - CHECK_EQ(op->data->ctx.device_type, kDLCPU); - if (dtype == Int(32)) { - return PrintConstScalar(dtype, static_cast(op->data->data)); - } else if (dtype == Int(64)) { - return PrintConstScalar(dtype, static_cast(op->data->data)); - } else if (dtype == Float(32)) { - return PrintConstScalar(dtype, static_cast(op->data->data)); - } else if (dtype == Float(64)) { - return PrintConstScalar(dtype, static_cast(op->data->data)); - } else if (dtype == Bool()) { - return PrintConstScalar(dtype, static_cast(op->data->data)); - } - } - // TODO: handle tensors - assert(false); - } - - Doc VisitExpr_(const TupleNode* op) final { - std::vector fields; - for (Expr field : op->fields) { - fields.push_back(this->Print(field)); - } - Doc doc = Nil(); - return doc << "(" << PrintVec(fields, Text(", ")) << ")"; - } - - Doc VisitExpr_(const TupleGetItemNode* op) final { - Doc doc = Nil(); - return doc << this->Print(op->tuple) << "." << op->index; - } - - Doc VisitExpr_(const IfNode* op) final { - Doc doc = Nil(); - Doc true_b = Nil(); - Doc false_b = Nil(); - doc << "if (" << this->Print(op->cond) << ") {"; - doc << Indent(2, true_b << "\n" << PrintNestedScope(op->true_branch)) << "\n"; - doc << "} else {"; - doc << Indent(2, false_b << "\n" << PrintNestedScope(op->false_branch)) << "\n"; - doc << "}"; - return doc; - } - - Doc VisitExpr_(const LetNode* op) final { - Doc doc = Nil(); - // TODO: this should call a var printer, which needs to differentiate - // between free and bound vars - // TODO: this should have a type annotation - doc << "let \%" << op->var->name_hint() << " = " << PrintNestedScope(op->value) << ";" << "\n"; - doc << PrintNestedScope(op->body); - return doc; - } - - private: - /*! \brief Map from Expr to Doc */ - Doc doc = Nil(); - std::unordered_map memo_; - size_t temp_var_counter_; -}; - -std::string RelayGNFPrint(const NodeRef& node) { - return "v0.0.1\n" + Layout(GNFPrinter().PrintFinal(node)); -} - -TVM_REGISTER_API("relay._expr.gnf_print") -.set_body_typed(RelayGNFPrint); - -} // relay -} // tvm diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index ab577fe82550..28995584a991 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -2,7 +2,7 @@ * Copyright (c) 2019 by Contributors * \file pretty_printer.cc * \brief Pretty printer for Relay programs - * Supports functional style and metadata. + * Supports ANF, GNF, and metadata. */ #include "doc.h" #include @@ -13,24 +13,79 @@ namespace relay { class PrettyPrinter : public ExprFunctor { public: - explicit PrettyPrinter() {} + explicit PrettyPrinter(const std::unordered_map& memo_, size_t temp_var_counter_, bool GNF_) : memo_(memo_), temp_var_counter_(temp_var_counter_), GNF_(GNF_) {} - Doc Print(const NodeRef& node) { + explicit PrettyPrinter() : temp_var_counter_(0), GNF_(true) {} + + explicit PrettyPrinter(bool GNF_) : temp_var_counter_(0), GNF_(GNF_) {} + + // indent a new body + Doc PrintBody(const NodeRef& node, int indent = 2) { + Doc doc = Nil(); + Doc body = Nil(); + doc << "{"; + doc << Indent(indent, body << "\n" << PrintNestedScope(node)) << "\n"; + doc << "}"; + return doc; + } + + // create a new scope by creating a new printer object. + Doc PrintNestedScope(const NodeRef& node) { + if (GNF_) { + return PrettyPrinter(memo_, temp_var_counter_, GNF_).PrintFinal(node); + } else { + return Print(node); + } + } + + Doc PrintFinal(const NodeRef& node) { + Print(node, true, false); + return doc; + } + + // note: gnf flag is only one level deep + Doc Print(const NodeRef& node, bool gnf = true, bool hoist = true) { if (node.as_derived()) { - return this->PrintExpr(Downcast(node)); + return PrintExpr(Downcast(node), gnf, hoist); } else { assert(false); } } - Doc PrintExpr(const Expr& expr) { + Doc TempVar(int n) { + Doc doc = Nil(); + return doc << "\%" << n; + } + + Doc AllocTemp() { + return TempVar(temp_var_counter_++); + } + + Doc PrintExpr(const Expr& expr, bool gnf = true, bool hoist = true) { + // Exploit memoization to print GNF. + // The first time we visit an expression, we need to allocate a temp var + // for it. Every subsequent time we can just use its assigned variable. + // This works since hashing uses pointer equality. auto it = memo_.find(expr); if (it != memo_.end()) return it->second; - Doc val = this->VisitExpr(expr); - memo_[expr] = val; - return val; + Doc printed_expr = VisitExpr(expr); + if (gnf && GNF_) { + if (hoist) { + Doc temp_var = AllocTemp(); + memo_[expr] = temp_var; + doc << temp_var << " = " << printed_expr << "\n"; + return temp_var; + } else { + memo_[expr] = printed_expr; + doc << printed_expr; + return printed_expr; + } + } else { + memo_[expr] = printed_expr; + return printed_expr; + } } Doc VisitExpr_(const ConstantNode* op) final { - // Print out simple scalar directly. + // Print out simple scalars directly. if (op->is_scalar()) { std::ostringstream os; DataType dtype = TVMType2Type(op->data->dtype); @@ -54,7 +109,7 @@ class PrettyPrinter : Doc VisitExpr_(const TupleNode* op) final { std::vector fields; for (Expr field : op->fields) { - fields.push_back(this->Print(field)); + fields.push_back(Print(field)); } Doc doc = Nil(); return doc << "(" << PrintVec(fields, Text(", ")) << ")"; @@ -62,20 +117,71 @@ class PrettyPrinter : Doc VisitExpr_(const TupleGetItemNode* op) final { Doc doc = Nil(); - return doc << this->Print(op->tuple) << "." << op->index; + return doc << Print(op->tuple) << "." << op->index; } + Doc VisitExpr_(const IfNode* op) final { + Doc doc = Nil(); + doc << "if (" << Print(op->cond) << ") "; + doc << PrintBody(op->true_branch); + doc << " else "; + doc << PrintBody(op->false_branch); + return doc; + } + + Doc VisitExpr_(const LetNode* op) final { + Doc doc = Nil(); + // TODO: this should call a var printer, which needs to differentiate + // between free and bound vars + // TODO: this should have a type annotation + // TODO: lets in value position need to be scoped + // + // we use ANF mode for the first level of the value position so the final + // expression isn't hoisted or added to the doc stream + doc << "let \%" << op->var->name_hint() << " = " << Print(op->value, false) << ";" << "\n"; + doc << PrintNestedScope(op->body); + return doc; + } + + // Doc PrintFunc(const Doc& prefix, const FunctionNode* fn) { + // // TODO(tqchen, M.K.) support generic function + // // Possibly through meta-data + // CHECK_EQ(fn->type_params.size(), 0U) + // << "generic fn not yet supported"; + // Doc doc = Nil(); + // doc << prefix << "("; + // AllocVarName(fn->params[i]); + // this->PrintVarDecl(fn->params[i], stream_); + // doc << ')'; + // /* if (fn->ret_type.defined()) { + // doc << " -> "; + // this->PrintType(fn->ret_type, stream_); + // } */ + // doc << PrintBody(fn->body); + // return doc; + // } + private: /*! \brief Map from Expr to Doc */ + Doc doc = Nil(); std::unordered_map memo_; + size_t temp_var_counter_; + bool GNF_; }; -std::string RelayPrettyPrint(const NodeRef& node) { - return "v0.0.1\n" + Layout(PrettyPrinter().Print(node)); +std::string RelayGNFPrint(const NodeRef& node) { + return "v0.0.1\n" + Layout(PrettyPrinter().PrintFinal(node)) + "\n"; } -TVM_REGISTER_API("relay._expr.pretty_print") -.set_body_typed(RelayPrettyPrint); +std::string RelayANFPrint(const NodeRef& node) { + return "v0.0.1\n" + Layout(PrettyPrinter(false).Print(node)) + "\n"; +} + +TVM_REGISTER_API("relay._expr.gnf_print") +.set_body_typed(RelayGNFPrint); + +TVM_REGISTER_API("relay._expr.anf_print") +.set_body_typed(RelayANFPrint); } // relay } // tvm diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index 8b2b96ab17ce..f20a93c9070b 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -1,7 +1,7 @@ import tvm from tvm import relay from tvm.relay.ir_pass import alpha_equal -from tvm.relay._expr import pretty_print +from tvm.relay._expr import anf_print, gnf_print import numpy as np from hypothesis import given, reject, settings @@ -31,28 +31,35 @@ def projections(draw, field_type): @settings(deadline=500, derandomize=True) @given(exprs) def test_roundtrip_pp(e): - alpha_equal(relay.fromtext(pretty_print(e)), e) + alpha_equal(relay.fromtext(anf_print(e)), e) def test_gnf(): - assert relay._expr.gnf_print(relay.const(1)) == "v0.0.1\n%0 = 1\n%0" - assert relay._expr.gnf_print(relay.Tuple([relay.const(1), relay.const(1)])) == "v0.0.1\n%0 = 1\n%1 = 1\n%2 = (%0, %1)\n%2" + assert gnf_print(relay.const(1)) == "v0.0.1\n1\n" + assert gnf_print(relay.Tuple([relay.const(1), relay.const(1)])) == "v0.0.1\n%0 = 1\n%1 = 1\n(%0, %1)\n" one = relay.const(1) - assert relay._expr.gnf_print(relay.Tuple([one, one])) == "v0.0.1\n%0 = 1\n%1 = (%0, %0)\n%1" + assert gnf_print(relay.Tuple([one, one])) == "v0.0.1\n%0 = 1\n(%0, %0)\n" - assert relay._expr.gnf_print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0))) == "v0.0.1\n%0 = True\n%1 = if (%0) {\n %1 = 1\n %2 = (%1, %1)\n %3 = %2.0\n %3\n} else {\n %1 = 1\n %2 = 1\n %3 = (%1, %1, %2)\n %4 = %3.0\n %4\n}\n%1" + assert gnf_print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0))) == "v0.0.1\n%0 = True\nif (%0) {\n %1 = 1\n %2 = (%1, %1)\n %2.0\n} else {\n %1 = 1\n %2 = 1\n %3 = (%1, %1, %2)\n %3.0\n}\n" if __name__ == "__main__": # for _ in range(10): - # print(pretty_print(exprs.example())) + # print(anf_print(exprs.example())) one = relay.const(1) - print(relay._expr.gnf_print(relay.TupleGetItem(relay.Tuple([one, one]), 0))) + tup = relay.Tuple([relay.const(1), relay.const(1)]) + print(gnf_print(relay.TupleGetItem(relay.Tuple([one, one]), 0))) print() - print(relay._expr.gnf_print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0)))) + print(gnf_print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0)))) print() - SEMVER = "v0.0.1" - print(relay._expr.gnf_print(relay.fromtext(SEMVER+"let %x = 1; 5"))) + print(anf_print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0)))) print() + SEMVER = "v0.0.1" + print(gnf_print(relay.fromtext(SEMVER+"let %x = 1; 5"))) print(relay.fromtext(SEMVER+"let %x = 1; %x").astext()) print(relay.fromtext(SEMVER+"let %x = (1, 1); %x").astext()) print(relay.TupleGetItem(relay.Tuple([one, one]), 0).astext()) - print() + print(relay.fromtext(SEMVER+"let %x = 1; let %x = 2; %x").astext()) + print(relay.Let(relay.var("x"), relay.Tuple([tup, tup]), relay.const(5)).astext()) + print(gnf_print(relay.Let(relay.var("x"), relay.Tuple([tup, tup]), relay.const(5)))) + print(anf_print(relay.Let(relay.var("x"), relay.Tuple([tup, tup]), relay.const(5)))) + print(anf_print(relay.fromtext(SEMVER+"let %x = 1; let %x = 2; 3"))) + print(gnf_print(relay.fromtext(SEMVER+"let %x = 1; let %x = 2; 3"))) From 24f255fadfaa0631f923f419c02799215c92e321 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Sat, 23 Feb 2019 18:57:51 -0800 Subject: [PATCH 24/73] support variables, functions, and fix a corner-case bug when only a top-level node has been seen before. remove hoist path --- src/relay/ir/doc.h | 14 -- src/relay/ir/pretty_printer.cc | 123 ++++++++++++------ .../python/relay/test_ir_parser_roundtrip.py | 8 +- 3 files changed, 91 insertions(+), 54 deletions(-) diff --git a/src/relay/ir/doc.h b/src/relay/ir/doc.h index 4e88d48afea0..65830fe91397 100644 --- a/src/relay/ir/doc.h +++ b/src/relay/ir/doc.h @@ -35,20 +35,6 @@ struct LineNode : DocNode { LineNode(int indent, const Doc& doc) : indent(indent), doc(doc) {} }; -/* template -T Match(const Doc& doc, - const T& case_nil, - const std::function& case_text, - const std::function& case_line) { - if (auto nil = std::dynamic_pointer_cast(doc)) { - return case_nil; - } else if (auto text = std::dynamic_pointer_cast(doc)) { - return case_text(text->str, text->doc); - } else if (auto line = std::dynamic_pointer_cast(doc)) { - return case_line(line->indent, line->doc); - } else {assert(false);} -} */ - // text constructor Doc Text(const std::string& str, const Doc& doc); diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 28995584a991..34f3af2d5123 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -13,7 +13,7 @@ namespace relay { class PrettyPrinter : public ExprFunctor { public: - explicit PrettyPrinter(const std::unordered_map& memo_, size_t temp_var_counter_, bool GNF_) : memo_(memo_), temp_var_counter_(temp_var_counter_), GNF_(GNF_) {} + explicit PrettyPrinter(const std::unordered_map& memo_, const std::unordered_map& name_alloc_map_, size_t temp_var_counter_, bool GNF_) : memo_(memo_), name_alloc_map_(name_alloc_map_), temp_var_counter_(temp_var_counter_), GNF_(GNF_) {} explicit PrettyPrinter() : temp_var_counter_(0), GNF_(true) {} @@ -29,37 +29,81 @@ class PrettyPrinter : return doc; } - // create a new scope by creating a new printer object. + // create a new scope by creating a new printer object. This allows temp var + // numbers to be reused and prevents hoisted vars from escaping too far Doc PrintNestedScope(const NodeRef& node) { if (GNF_) { - return PrettyPrinter(memo_, temp_var_counter_, GNF_).PrintFinal(node); + return PrettyPrinter(memo_, name_alloc_map_, temp_var_counter_, GNF_).PrintFinal(node); } else { return Print(node); } } Doc PrintFinal(const NodeRef& node) { - Print(node, true, false); + doc << Print(node, false); return doc; } // note: gnf flag is only one level deep - Doc Print(const NodeRef& node, bool gnf = true, bool hoist = true) { + Doc Print(const NodeRef& node, bool gnf = true) { if (node.as_derived()) { - return PrintExpr(Downcast(node), gnf, hoist); + return PrintExpr(Downcast(node), gnf); } else { assert(false); } } Doc TempVar(int n) { Doc doc = Nil(); - return doc << "\%" << n; + return doc << "%" << n; } Doc AllocTemp() { return TempVar(temp_var_counter_++); } - Doc PrintExpr(const Expr& expr, bool gnf = true, bool hoist = true) { + /*! + * \brief get a unique name with the corresponding prefix + * \param prefix The prefix of the name + * \return The returned name. + */ + Doc GetUniqueName(std::string prefix) { + auto it = name_alloc_map_.find(prefix); + if (it != name_alloc_map_.end()) { + while (true) { + std::ostringstream os; + os << prefix << (++it->second); + std::string name = os.str(); + if (name_alloc_map_.count(name) == 0) { + prefix = name; + break; + } + } + } + name_alloc_map_[prefix] = 0; + return Text(prefix); + } + + /*! + * \brief Allocate name to a variable. + * \param var The input variable. + * \return The corresponding name. + */ + Doc AllocVar(const Var& var) { + std::string name = var->name_hint(); + // always make sure first name is alpha + if (name.length() != 0 && !std::isalpha(name[0])) { + name = "v" + name; + } + Doc val = GetUniqueName("%" + name); + // still print if ir is malformed, but show the error. + if (memo_.count(var)) { + memo_[var] = val + Text("-malformed-ir"); + } + memo_[var] = val; + // TODO: should also return type annotation + return val; + } + + Doc PrintExpr(const Expr& expr, bool gnf = true) { // Exploit memoization to print GNF. // The first time we visit an expression, we need to allocate a temp var // for it. Every subsequent time we can just use its assigned variable. @@ -68,16 +112,10 @@ class PrettyPrinter : if (it != memo_.end()) return it->second; Doc printed_expr = VisitExpr(expr); if (gnf && GNF_) { - if (hoist) { - Doc temp_var = AllocTemp(); - memo_[expr] = temp_var; - doc << temp_var << " = " << printed_expr << "\n"; - return temp_var; - } else { - memo_[expr] = printed_expr; - doc << printed_expr; - return printed_expr; - } + Doc temp_var = AllocTemp(); + memo_[expr] = temp_var; + doc << temp_var << " = " << printed_expr << "\n"; + return temp_var; } else { memo_[expr] = printed_expr; return printed_expr; @@ -135,36 +173,47 @@ class PrettyPrinter : // between free and bound vars // TODO: this should have a type annotation // TODO: lets in value position need to be scoped - // + // we use ANF mode for the first level of the value position so the final // expression isn't hoisted or added to the doc stream - doc << "let \%" << op->var->name_hint() << " = " << Print(op->value, false) << ";" << "\n"; + doc << "let %" << op->var->name_hint() << " = " << Print(op->value, false) << ";" << "\n"; + // we use a nested scope here so GNF hoisting doesn't escape too far + // and so consecutive lets don't get hoisted doc << PrintNestedScope(op->body); return doc; } - // Doc PrintFunc(const Doc& prefix, const FunctionNode* fn) { - // // TODO(tqchen, M.K.) support generic function - // // Possibly through meta-data - // CHECK_EQ(fn->type_params.size(), 0U) - // << "generic fn not yet supported"; - // Doc doc = Nil(); - // doc << prefix << "("; - // AllocVarName(fn->params[i]); - // this->PrintVarDecl(fn->params[i], stream_); - // doc << ')'; - // /* if (fn->ret_type.defined()) { - // doc << " -> "; - // this->PrintType(fn->ret_type, stream_); - // } */ - // doc << PrintBody(fn->body); - // return doc; - // } + Doc PrintFunc(const Doc& prefix, const FunctionNode* fn) { + // TODO(tqchen, M.K.) support generic function + // Possibly through meta-data + CHECK_EQ(fn->type_params.size(), 0U) + << "generic fn not yet supported"; + Doc doc = Nil(); + doc << prefix << "("; + // TODO: need nested var scopes for this!! + std::vector params; + for (Var param : fn->params) { + params.push_back(AllocVar(param)); + } + doc << PrintVec(params, Text(", ")); + doc << ") "; + /* if (fn->ret_type.defined()) { + doc << " -> "; + this->PrintType(fn->ret_type, stream_); + } */ + doc << PrintBody(fn->body); + return doc; + } + + Doc VisitExpr_(const FunctionNode* op) final { + return PrintFunc(Text("fn "), op); + } private: /*! \brief Map from Expr to Doc */ Doc doc = Nil(); std::unordered_map memo_; + std::unordered_map name_alloc_map_; size_t temp_var_counter_; bool GNF_; }; diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index f20a93c9070b..5e1195177a4f 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -47,11 +47,9 @@ def test_gnf(): one = relay.const(1) tup = relay.Tuple([relay.const(1), relay.const(1)]) print(gnf_print(relay.TupleGetItem(relay.Tuple([one, one]), 0))) - print() + print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0)).astext()) print(gnf_print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0)))) - print() print(anf_print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0)))) - print() SEMVER = "v0.0.1" print(gnf_print(relay.fromtext(SEMVER+"let %x = 1; 5"))) print(relay.fromtext(SEMVER+"let %x = 1; %x").astext()) @@ -63,3 +61,7 @@ def test_gnf(): print(anf_print(relay.Let(relay.var("x"), relay.Tuple([tup, tup]), relay.const(5)))) print(anf_print(relay.fromtext(SEMVER+"let %x = 1; let %x = 2; 3"))) print(gnf_print(relay.fromtext(SEMVER+"let %x = 1; let %x = 2; 3"))) + print(anf_print(relay.fromtext(SEMVER+"fn(%x) { %x }"))) + print(gnf_print(relay.fromtext(SEMVER+"fn(%x) { %x }"))) + print(gnf_print(relay.fromtext(SEMVER+"fn(%x) { (%x, %x) }"))) + print(gnf_print(relay.If(one, relay.TupleGetItem(relay.Tuple([one, one]), 0), one))) \ No newline at end of file From 61033d06c9a717e44edd009b812db46c1253d837 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Thu, 28 Feb 2019 15:44:53 -0800 Subject: [PATCH 25/73] fix if scope in GNF --- src/relay/ir/pretty_printer.cc | 27 ++++++++++++------- .../python/relay/test_ir_parser_roundtrip.py | 4 ++- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 34f3af2d5123..92abfb031ab6 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -33,15 +33,21 @@ class PrettyPrinter : // numbers to be reused and prevents hoisted vars from escaping too far Doc PrintNestedScope(const NodeRef& node) { if (GNF_) { - return PrettyPrinter(memo_, name_alloc_map_, temp_var_counter_, GNF_).PrintFinal(node); + // print in a new scope + doc_stack_.push_back(Nil()); + Doc doc = PrintFinal(node); + doc_stack_.pop_back(); + return doc; } else { return Print(node); } } Doc PrintFinal(const NodeRef& node) { - doc << Print(node, false); - return doc; + // TODO(@jmp): If these lines are combined it segfaults?? + Doc doc = Print(node, false); + doc_stack_.back() << doc; + return doc_stack_.back(); } // note: gnf flag is only one level deep @@ -82,11 +88,11 @@ class PrettyPrinter : return Text(prefix); } - /*! - * \brief Allocate name to a variable. - * \param var The input variable. - * \return The corresponding name. - */ + /*! + * \brief Allocate name to a variable. + * \param var The input variable. + * \return The corresponding name. + */ Doc AllocVar(const Var& var) { std::string name = var->name_hint(); // always make sure first name is alpha @@ -114,7 +120,7 @@ class PrettyPrinter : if (gnf && GNF_) { Doc temp_var = AllocTemp(); memo_[expr] = temp_var; - doc << temp_var << " = " << printed_expr << "\n"; + doc_stack_.back() << temp_var << " = " << printed_expr << "\n"; return temp_var; } else { memo_[expr] = printed_expr; @@ -210,8 +216,9 @@ class PrettyPrinter : } private: + /*! \brief Stack of docs to implement scoped GNFing. */ + std::vector doc_stack_{Nil()}; /*! \brief Map from Expr to Doc */ - Doc doc = Nil(); std::unordered_map memo_; std::unordered_map name_alloc_map_; size_t temp_var_counter_; diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index 5e1195177a4f..a1a98b62e2ed 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -48,6 +48,7 @@ def test_gnf(): tup = relay.Tuple([relay.const(1), relay.const(1)]) print(gnf_print(relay.TupleGetItem(relay.Tuple([one, one]), 0))) print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0)).astext()) + print(gnf_print(relay.If(relay.const(True), relay.const(1), relay.const(1)))) print(gnf_print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0)))) print(anf_print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0)))) SEMVER = "v0.0.1" @@ -64,4 +65,5 @@ def test_gnf(): print(anf_print(relay.fromtext(SEMVER+"fn(%x) { %x }"))) print(gnf_print(relay.fromtext(SEMVER+"fn(%x) { %x }"))) print(gnf_print(relay.fromtext(SEMVER+"fn(%x) { (%x, %x) }"))) - print(gnf_print(relay.If(one, relay.TupleGetItem(relay.Tuple([one, one]), 0), one))) \ No newline at end of file + print(gnf_print(relay.If(one, relay.TupleGetItem(relay.Tuple([one, one]), 0), one))) + print(relay.If(relay.const(True), tup, tup).astext()) \ No newline at end of file From 73ee892cbd5cda37f5fc68ae445d6cdda41dcefe Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Thu, 28 Feb 2019 17:52:32 -0800 Subject: [PATCH 26/73] add global var --- src/relay/ir/pretty_printer.cc | 11 +++++++---- tests/python/relay/test_ir_parser_roundtrip.py | 3 ++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 92abfb031ab6..7eb2246c485c 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -46,8 +46,7 @@ class PrettyPrinter : Doc PrintFinal(const NodeRef& node) { // TODO(@jmp): If these lines are combined it segfaults?? Doc doc = Print(node, false); - doc_stack_.back() << doc; - return doc_stack_.back(); + return doc_stack_.back() << doc; } // note: gnf flag is only one level deep @@ -117,7 +116,8 @@ class PrettyPrinter : auto it = memo_.find(expr); if (it != memo_.end()) return it->second; Doc printed_expr = VisitExpr(expr); - if (gnf && GNF_) { + // we choose to inline some nodes + if (GNF_ && gnf && !expr.as() && !expr.as()) { Doc temp_var = AllocTemp(); memo_[expr] = temp_var; doc_stack_.back() << temp_var << " = " << printed_expr << "\n"; @@ -196,7 +196,6 @@ class PrettyPrinter : << "generic fn not yet supported"; Doc doc = Nil(); doc << prefix << "("; - // TODO: need nested var scopes for this!! std::vector params; for (Var param : fn->params) { params.push_back(AllocVar(param)); @@ -215,6 +214,10 @@ class PrettyPrinter : return PrintFunc(Text("fn "), op); } + Doc VisitExpr_(const GlobalVarNode* op) final { + return Text('@' + op->name_hint); + } + private: /*! \brief Stack of docs to implement scoped GNFing. */ std::vector doc_stack_{Nil()}; diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index a1a98b62e2ed..14033b93c753 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -66,4 +66,5 @@ def test_gnf(): print(gnf_print(relay.fromtext(SEMVER+"fn(%x) { %x }"))) print(gnf_print(relay.fromtext(SEMVER+"fn(%x) { (%x, %x) }"))) print(gnf_print(relay.If(one, relay.TupleGetItem(relay.Tuple([one, one]), 0), one))) - print(relay.If(relay.const(True), tup, tup).astext()) \ No newline at end of file + print(relay.If(relay.const(True), tup, tup).astext()) + print(gnf_print(relay.If(relay.GlobalVar("foo"), relay.TupleGetItem(relay.Tuple([one, one]), 0), one))) \ No newline at end of file From 00cbd9f77806b862617fca2e50e98d3a00963958 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Thu, 28 Feb 2019 21:31:40 -0800 Subject: [PATCH 27/73] add simple call and op --- src/relay/ir/pretty_printer.cc | 18 ++++++++++++++++-- tests/python/relay/test_ir_parser_roundtrip.py | 8 +++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 7eb2246c485c..434e929b6f53 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -44,7 +44,7 @@ class PrettyPrinter : } Doc PrintFinal(const NodeRef& node) { - // TODO(@jmp): If these lines are combined it segfaults?? + // must print first so doc_stack_.back() reference doesn't become stale Doc doc = Print(node, false); return doc_stack_.back() << doc; } @@ -117,7 +117,7 @@ class PrettyPrinter : if (it != memo_.end()) return it->second; Doc printed_expr = VisitExpr(expr); // we choose to inline some nodes - if (GNF_ && gnf && !expr.as() && !expr.as()) { + if (GNF_ && gnf && !expr.as() && !expr.as() && !expr.as()) { Doc temp_var = AllocTemp(); memo_[expr] = temp_var; doc_stack_.back() << temp_var << " = " << printed_expr << "\n"; @@ -218,6 +218,20 @@ class PrettyPrinter : return Text('@' + op->name_hint); } + Doc VisitExpr_(const OpNode* op) final { + return Text(op->name); + } + + Doc VisitExpr_(const CallNode* op) final { + Doc doc = Nil(); + doc << Print(op->op); + std::vector args; + for (Expr arg : op->args) { + args.push_back(Print(arg)); + } + return doc << "(" << PrintVec(args, Text(", ")) << ")"; + } + private: /*! \brief Stack of docs to implement scoped GNFing. */ std::vector doc_stack_{Nil()}; diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index 14033b93c753..8a4db4f3fc8c 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -67,4 +67,10 @@ def test_gnf(): print(gnf_print(relay.fromtext(SEMVER+"fn(%x) { (%x, %x) }"))) print(gnf_print(relay.If(one, relay.TupleGetItem(relay.Tuple([one, one]), 0), one))) print(relay.If(relay.const(True), tup, tup).astext()) - print(gnf_print(relay.If(relay.GlobalVar("foo"), relay.TupleGetItem(relay.Tuple([one, one]), 0), one))) \ No newline at end of file + print(gnf_print(relay.If(relay.GlobalVar("foo"), relay.TupleGetItem(relay.Tuple([one, one]), 0), one))) + print(anf_print(relay.fromtext(SEMVER+"(fn(%x, %y) { %x })(1, 2)"))) + print(gnf_print(relay.fromtext(SEMVER+"(fn(%x, %y) { %x })(1, 2)"))) + print(relay.fromtext(SEMVER+"(fn(%x, %y) { %x })(1, 2)").astext()) + print(relay.fromtext(SEMVER+"fn(%x, %y) { %x + %y }").astext()) + print(anf_print(relay.fromtext(SEMVER+"fn(%x, %y) { %x + %y }"))) + print(gnf_print(relay.fromtext(SEMVER+"fn(%x, %y) { %x + %y }"))) From 9795c10f4a36b4ead9ac195dd85a61054ca3d88d Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Fri, 1 Mar 2019 18:30:45 -0800 Subject: [PATCH 28/73] add attrs and small refactoring --- src/relay/ir/doc.h | 2 +- src/relay/ir/pretty_printer.cc | 145 +++++++++++++++++- .../python/relay/test_ir_parser_roundtrip.py | 4 + 3 files changed, 145 insertions(+), 6 deletions(-) diff --git a/src/relay/ir/doc.h b/src/relay/ir/doc.h index 65830fe91397..aee45b5befdc 100644 --- a/src/relay/ir/doc.h +++ b/src/relay/ir/doc.h @@ -69,7 +69,7 @@ Doc Indent(int indent, const Doc& doc); // convert doc to a string std::string Layout(const Doc& doc); // render vectors of docs with a separator. e.g. [1, 2, 3], f -> 1f2f3 -Doc PrintVec(const std::vector& arr, const Doc& sep); +Doc PrintVec(const std::vector& arr, const Doc& sep = Text(", ")); // Print constant bool value. Doc PrintBool(bool value); /*! diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 434e929b6f53..88d3f68d7c96 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -6,14 +6,17 @@ */ #include "doc.h" #include +#include "type_functor.h" +#include "../../lang/attr_functor.h" namespace tvm { namespace relay { class PrettyPrinter : - public ExprFunctor { + public ExprFunctor, + public TypeFunctor { public: - explicit PrettyPrinter(const std::unordered_map& memo_, const std::unordered_map& name_alloc_map_, size_t temp_var_counter_, bool GNF_) : memo_(memo_), name_alloc_map_(name_alloc_map_), temp_var_counter_(temp_var_counter_), GNF_(GNF_) {} + explicit PrettyPrinter(const std::unordered_map& memo_, const std::unordered_map& memo_type_, const std::unordered_map& name_alloc_map_, size_t temp_var_counter_, bool GNF_) : memo_(memo_), memo_type_(memo_type_), name_alloc_map_(name_alloc_map_), temp_var_counter_(temp_var_counter_), GNF_(GNF_) {} explicit PrettyPrinter() : temp_var_counter_(0), GNF_(true) {} @@ -49,10 +52,14 @@ class PrettyPrinter : return doc_stack_.back() << doc; } + Doc PrintAttrs(const Attrs& attrs); + // note: gnf flag is only one level deep Doc Print(const NodeRef& node, bool gnf = true) { if (node.as_derived()) { return PrintExpr(Downcast(node), gnf); + } else if (node.as_derived()) { + return PrintType(Downcast(node)); } else { assert(false); } } @@ -108,6 +115,9 @@ class PrettyPrinter : return val; } + //------------------------------------ + // Overload of Expr printing functions + //------------------------------------ Doc PrintExpr(const Expr& expr, bool gnf = true) { // Exploit memoization to print GNF. // The first time we visit an expression, we need to allocate a temp var @@ -156,7 +166,7 @@ class PrettyPrinter : fields.push_back(Print(field)); } Doc doc = Nil(); - return doc << "(" << PrintVec(fields, Text(", ")) << ")"; + return doc << "(" << PrintVec(fields) << ")"; } Doc VisitExpr_(const TupleGetItemNode* op) final { @@ -200,7 +210,7 @@ class PrettyPrinter : for (Var param : fn->params) { params.push_back(AllocVar(param)); } - doc << PrintVec(params, Text(", ")); + doc << PrintVec(params) << PrintAttrs(fn->attrs); doc << ") "; /* if (fn->ret_type.defined()) { doc << " -> "; @@ -229,7 +239,18 @@ class PrettyPrinter : for (Expr arg : op->args) { args.push_back(Print(arg)); } - return doc << "(" << PrintVec(args, Text(", ")) << ")"; + return doc << "(" << PrintVec(args) << PrintAttrs(op->attrs) << ")"; + } + + //------------------------------------ + // Overload of Expr printing functions + //------------------------------------ + Doc PrintType(const Type& type) { + auto it = memo_type_.find(type); + if (it != memo_type_.end()) return it->second; + Doc printed_type = VisitType(type); + memo_type_[type] = printed_type; + return printed_type; } private: @@ -237,11 +258,125 @@ class PrettyPrinter : std::vector doc_stack_{Nil()}; /*! \brief Map from Expr to Doc */ std::unordered_map memo_; + /*! \brief Map from Type to Doc */ + std::unordered_map memo_type_; std::unordered_map name_alloc_map_; size_t temp_var_counter_; bool GNF_; + class AttrPrinter; + friend class AttrPrinter; +}; + +/*! + * \brief Attribute printer which prints the attributes in the call. + */ +class PrettyPrinter::AttrPrinter : + public AttrVisitor, + public AttrFunctor { + public: + AttrPrinter(Doc& doc_) : doc_(doc_) {} + + template + Doc PrintKV(const char* key, const T& value) { + Doc doc = Nil(); + return doc << ", " << key << "=" << value; + } + + void Visit(const char* key, double* value) final { + doc_ << PrintKV(key, value[0]); + } + void Visit(const char* key, int64_t* value) final { + doc_ << PrintKV(key, value[0]); + } + void Visit(const char* key, uint64_t* value) final { + doc_ << PrintKV(key, value[0]); + } + void Visit(const char* key, int* value) final { + doc_ << PrintKV(key, value[0]); + } + void Visit(const char* key, bool* value) final { + doc_ << PrintKV(key, PrintBool(value[0])); + } + void Visit(const char* key, std::string* value) final { + doc_ << PrintKV(key, PrintString(value[0])); + } + void Visit(const char* key, void** value) final { + LOG(FATAL) << "do not allow void as argument"; + } + void Visit(const char* key, DataType* value) final { + doc_ << PrintKV(key, PrintString(runtime::TVMType2String(Type2TVMType(value[0])))); + } + void Visit(const char* key, NodeRef* value) final { + doc_ << PrintKV(key, PrintAttr(value[0])); + } + void Visit(const char* key, runtime::NDArray* value) final { + LOG(FATAL) << "do not allow NDarray as argument"; + } + + //------------------------------------ + // Overload of Attr printing functions + //------------------------------------ + + Doc PrintAttr(const NodeRef& value) { // NOLINT(*) + if (value.defined()) { + return VisitAttr(value); + } else { + return Text("None"); + } + } + + Doc VisitAttr_(const ArrayNode* op) final { // NOLINT(*) + Doc doc = Nil(); + doc << "["; + std::vector arr_vals; + for (NodePtr val : op->data) { + arr_vals.push_back(PrintAttr(NodeRef(val))); + } + doc << PrintVec(arr_vals); + doc << "]"; + return doc; + } + + Doc VisitAttrDefault_(const Node* op) final { // NOLINT(*) + // os << meta_.GetMetaNode(GetRef(op)); + assert(false); + } + + Doc VisitAttr_(const ir::IntImm* op) final { // NOLINT(*) + return PrintConstScalar(op->type, &(op->value)); + } + + Doc VisitAttr_(const ir::UIntImm* op) final { // NOLINT(*) + return PrintConstScalar(op->type, &(op->value)); + } + + Doc VisitAttr_(const ir::FloatImm* op) final { // NOLINT(*) + return PrintConstScalar(op->type, &(op->value)); + } + + Doc PrintString(const std::string& value) { // NOLINT(*) + // TODO(M.K.): add escape. + Doc doc = Nil(); + return doc << "\"" << value << "\""; + } + + Doc VisitAttr_(const ir::StringImm* op) final { // NOLINT(*) + return PrintString(op->value); + } + + private: + Doc& doc_; }; +Doc PrettyPrinter::PrintAttrs(const Attrs& attrs) { // NOLINT(*) + // TODO: meta + if (!attrs.defined()) return Nil(); + Doc doc = Nil(); + AttrPrinter printer(doc); + const_cast(attrs.operator->())->VisitNonDefaultAttrs(&printer); + return doc; +} + std::string RelayGNFPrint(const NodeRef& node) { return "v0.0.1\n" + Layout(PrettyPrinter().PrintFinal(node)) + "\n"; } diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index 8a4db4f3fc8c..ea167ff43e43 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -74,3 +74,7 @@ def test_gnf(): print(relay.fromtext(SEMVER+"fn(%x, %y) { %x + %y }").astext()) print(anf_print(relay.fromtext(SEMVER+"fn(%x, %y) { %x + %y }"))) print(gnf_print(relay.fromtext(SEMVER+"fn(%x, %y) { %x + %y }"))) + print(relay.Call(relay.fromtext(SEMVER+"fn(%x) { %x }"), [relay.const(1)], attrs=tvm.make.node("DictAttrs", n="foo")).astext()) + print(anf_print(relay.Call(relay.fromtext(SEMVER+"fn(%x) { %x }"), [relay.const(1)], attrs=tvm.make.node("DictAttrs", n="foo")))) + # print(relay.fromtext(SEMVER+"add(n=5)").astext()) + # print(anf_print(relay.fromtext(SEMVER+"fn (n=5) { () }"))) \ No newline at end of file From 2a7e94d442d25294e1cc8aca263d1d5a898ab716 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Sun, 3 Mar 2019 21:41:26 -0800 Subject: [PATCH 29/73] add tensor type printing and refactor attr printing to get it to work --- src/relay/ir/doc.cc | 10 ++ src/relay/ir/doc.h | 2 + src/relay/ir/pretty_printer.cc | 131 ++++++++++-------- .../python/relay/test_ir_parser_roundtrip.py | 5 +- 4 files changed, 88 insertions(+), 60 deletions(-) diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc index 3515d3d5653f..f72a0e95b193 100644 --- a/src/relay/ir/doc.cc +++ b/src/relay/ir/doc.cc @@ -124,5 +124,15 @@ Doc PrintBool(bool value) { } } +Doc PrintDType(DataType dtype) { + return Text(runtime::TVMType2String(Type2TVMType(dtype))); +} + +Doc PrintString(const std::string& value) { // NOLINT(*) + // TODO(M.K.): add escape. + Doc doc = Nil(); + return doc << "\"" << value << "\""; +} + } // relay } // tvm diff --git a/src/relay/ir/doc.h b/src/relay/ir/doc.h index aee45b5befdc..9140459a31c8 100644 --- a/src/relay/ir/doc.h +++ b/src/relay/ir/doc.h @@ -72,6 +72,8 @@ std::string Layout(const Doc& doc); Doc PrintVec(const std::vector& arr, const Doc& sep = Text(", ")); // Print constant bool value. Doc PrintBool(bool value); +Doc PrintDType(DataType dtype); +Doc PrintString(const std::string& value); /*! * \brief special method to print out const scalar * \param dtype The data type diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 88d3f68d7c96..e4cc70db8b24 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -14,7 +14,8 @@ namespace relay { class PrettyPrinter : public ExprFunctor, - public TypeFunctor { + public TypeFunctor, + public AttrFunctor { public: explicit PrettyPrinter(const std::unordered_map& memo_, const std::unordered_map& memo_type_, const std::unordered_map& name_alloc_map_, size_t temp_var_counter_, bool GNF_) : memo_(memo_), memo_type_(memo_type_), name_alloc_map_(name_alloc_map_), temp_var_counter_(temp_var_counter_), GNF_(GNF_) {} @@ -243,7 +244,7 @@ class PrettyPrinter : } //------------------------------------ - // Overload of Expr printing functions + // Overload of Type printing functions //------------------------------------ Doc PrintType(const Type& type) { auto it = memo_type_.find(type); @@ -253,6 +254,70 @@ class PrettyPrinter : return printed_type; } + Doc VisitType_(const TensorTypeNode* node) final { // NOLINT(*) + // scalar type + if (node->shape.size() == 0) { + return PrintDType(node->dtype); + } + Doc doc = Nil(); + doc << "Tensor[("; + std::vector shapes; + for (NodeRef shape : node->shape) { + shapes.push_back(PrintAttr(shape)); + } + doc << PrintVec(shapes); + // conform to python tuple format (1,) + if (node->shape.size() == 1) { + doc << ","; + } + return doc << "), " << PrintDType(node->dtype) << "]"; + } + + //------------------------------------ + // Overload of Attr printing functions + //------------------------------------ + + Doc PrintAttr(const NodeRef& value) { // NOLINT(*) + if (value.defined()) { + return VisitAttr(value); + } else { + return Text("None"); + } + } + + Doc VisitAttr_(const ArrayNode* op) final { // NOLINT(*) + Doc doc = Nil(); + doc << "["; + std::vector arr_vals; + for (NodePtr val : op->data) { + arr_vals.push_back(PrintAttr(NodeRef(val))); + } + doc << PrintVec(arr_vals); + doc << "]"; + return doc; + } + + Doc VisitAttrDefault_(const Node* op) final { // NOLINT(*) + // os << meta_.GetMetaNode(GetRef(op)); + assert(false); + } + + Doc VisitAttr_(const ir::IntImm* op) final { // NOLINT(*) + return PrintConstScalar(op->type, &(op->value)); + } + + Doc VisitAttr_(const ir::UIntImm* op) final { // NOLINT(*) + return PrintConstScalar(op->type, &(op->value)); + } + + Doc VisitAttr_(const ir::FloatImm* op) final { // NOLINT(*) + return PrintConstScalar(op->type, &(op->value)); + } + + Doc VisitAttr_(const ir::StringImm* op) final { // NOLINT(*) + return PrintString(op->value); + } + private: /*! \brief Stack of docs to implement scoped GNFing. */ std::vector doc_stack_{Nil()}; @@ -270,11 +335,9 @@ class PrettyPrinter : /*! * \brief Attribute printer which prints the attributes in the call. */ -class PrettyPrinter::AttrPrinter : - public AttrVisitor, - public AttrFunctor { +class PrettyPrinter::AttrPrinter : public AttrVisitor { public: - AttrPrinter(Doc& doc_) : doc_(doc_) {} + AttrPrinter(Doc& doc_, PrettyPrinter* parent_) : doc_(doc_), parent_(parent_) {} template Doc PrintKV(const char* key, const T& value) { @@ -307,72 +370,22 @@ class PrettyPrinter::AttrPrinter : doc_ << PrintKV(key, PrintString(runtime::TVMType2String(Type2TVMType(value[0])))); } void Visit(const char* key, NodeRef* value) final { - doc_ << PrintKV(key, PrintAttr(value[0])); + doc_ << PrintKV(key, parent_->PrintAttr(value[0])); } void Visit(const char* key, runtime::NDArray* value) final { LOG(FATAL) << "do not allow NDarray as argument"; } - //------------------------------------ - // Overload of Attr printing functions - //------------------------------------ - - Doc PrintAttr(const NodeRef& value) { // NOLINT(*) - if (value.defined()) { - return VisitAttr(value); - } else { - return Text("None"); - } - } - - Doc VisitAttr_(const ArrayNode* op) final { // NOLINT(*) - Doc doc = Nil(); - doc << "["; - std::vector arr_vals; - for (NodePtr val : op->data) { - arr_vals.push_back(PrintAttr(NodeRef(val))); - } - doc << PrintVec(arr_vals); - doc << "]"; - return doc; - } - - Doc VisitAttrDefault_(const Node* op) final { // NOLINT(*) - // os << meta_.GetMetaNode(GetRef(op)); - assert(false); - } - - Doc VisitAttr_(const ir::IntImm* op) final { // NOLINT(*) - return PrintConstScalar(op->type, &(op->value)); - } - - Doc VisitAttr_(const ir::UIntImm* op) final { // NOLINT(*) - return PrintConstScalar(op->type, &(op->value)); - } - - Doc VisitAttr_(const ir::FloatImm* op) final { // NOLINT(*) - return PrintConstScalar(op->type, &(op->value)); - } - - Doc PrintString(const std::string& value) { // NOLINT(*) - // TODO(M.K.): add escape. - Doc doc = Nil(); - return doc << "\"" << value << "\""; - } - - Doc VisitAttr_(const ir::StringImm* op) final { // NOLINT(*) - return PrintString(op->value); - } - private: Doc& doc_; + PrettyPrinter* parent_; }; Doc PrettyPrinter::PrintAttrs(const Attrs& attrs) { // NOLINT(*) // TODO: meta if (!attrs.defined()) return Nil(); Doc doc = Nil(); - AttrPrinter printer(doc); + AttrPrinter printer(doc, this); const_cast(attrs.operator->())->VisitNonDefaultAttrs(&printer); return doc; } diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index ea167ff43e43..d5b9c8ff137b 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -75,6 +75,9 @@ def test_gnf(): print(anf_print(relay.fromtext(SEMVER+"fn(%x, %y) { %x + %y }"))) print(gnf_print(relay.fromtext(SEMVER+"fn(%x, %y) { %x + %y }"))) print(relay.Call(relay.fromtext(SEMVER+"fn(%x) { %x }"), [relay.const(1)], attrs=tvm.make.node("DictAttrs", n="foo")).astext()) - print(anf_print(relay.Call(relay.fromtext(SEMVER+"fn(%x) { %x }"), [relay.const(1)], attrs=tvm.make.node("DictAttrs", n="foo")))) + # print(anf_print(relay.Call(relay.fromtext(SEMVER+"fn(%x) { %x }"), [relay.const(1)], attrs=tvm.make.node("DictAttrs", n="foo")))) + print(relay.TensorType([5, 5]).astext()) + print(anf_print(relay.TensorType([5, 5]))) + print(gnf_print(relay.TensorType([5, 5]))) # print(relay.fromtext(SEMVER+"add(n=5)").astext()) # print(anf_print(relay.fromtext(SEMVER+"fn (n=5) { () }"))) \ No newline at end of file From 1d4524174bf600327451da49bb244c48f0dba6c5 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Sun, 3 Mar 2019 21:54:54 -0800 Subject: [PATCH 30/73] indenting and add tests --- src/relay/ir/pretty_printer.cc | 114 ++++++++++-------- .../python/relay/test_ir_parser_roundtrip.py | 17 ++- 2 files changed, 75 insertions(+), 56 deletions(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index e4cc70db8b24..a3aea69fd1b0 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -255,68 +255,82 @@ class PrettyPrinter : } Doc VisitType_(const TensorTypeNode* node) final { // NOLINT(*) - // scalar type - if (node->shape.size() == 0) { - return PrintDType(node->dtype); - } - Doc doc = Nil(); - doc << "Tensor[("; - std::vector shapes; - for (NodeRef shape : node->shape) { - shapes.push_back(PrintAttr(shape)); + // scalar type + if (node->shape.size() == 0) { + return PrintDType(node->dtype); + } + Doc doc = Nil(); + doc << "Tensor[("; + std::vector shapes; + for (NodeRef shape : node->shape) { + shapes.push_back(PrintAttr(shape)); + } + doc << PrintVec(shapes); + // conform to python tuple format (1,) + if (node->shape.size() == 1) { + doc << ","; + } + return doc << "), " << PrintDType(node->dtype) << "]"; } - doc << PrintVec(shapes); - // conform to python tuple format (1,) - if (node->shape.size() == 1) { - doc << ","; + + Doc VisitType_(const TupleTypeNode* node) final { + std::vector fields; + for (NodeRef field : node->fields) { + fields.push_back(Print(field)); + } + Doc doc = Nil(); + doc << "(" << PrintVec(fields); + // conform to python tuple format (1,) + if (node->fields.size() == 1) { + doc << ","; + } + return doc << ")"; } - return doc << "), " << PrintDType(node->dtype) << "]"; - } - //------------------------------------ - // Overload of Attr printing functions - //------------------------------------ + //------------------------------------ + // Overload of Attr printing functions + //------------------------------------ - Doc PrintAttr(const NodeRef& value) { // NOLINT(*) - if (value.defined()) { - return VisitAttr(value); - } else { - return Text("None"); + Doc PrintAttr(const NodeRef& value) { // NOLINT(*) + if (value.defined()) { + return VisitAttr(value); + } else { + return Text("None"); + } } - } - Doc VisitAttr_(const ArrayNode* op) final { // NOLINT(*) - Doc doc = Nil(); - doc << "["; - std::vector arr_vals; - for (NodePtr val : op->data) { - arr_vals.push_back(PrintAttr(NodeRef(val))); + Doc VisitAttr_(const ArrayNode* op) final { // NOLINT(*) + Doc doc = Nil(); + doc << "["; + std::vector arr_vals; + for (NodePtr val : op->data) { + arr_vals.push_back(PrintAttr(NodeRef(val))); + } + doc << PrintVec(arr_vals); + doc << "]"; + return doc; } - doc << PrintVec(arr_vals); - doc << "]"; - return doc; - } - Doc VisitAttrDefault_(const Node* op) final { // NOLINT(*) - // os << meta_.GetMetaNode(GetRef(op)); - assert(false); - } + Doc VisitAttrDefault_(const Node* op) final { // NOLINT(*) + // os << meta_.GetMetaNode(GetRef(op)); + assert(false); + } - Doc VisitAttr_(const ir::IntImm* op) final { // NOLINT(*) - return PrintConstScalar(op->type, &(op->value)); - } + Doc VisitAttr_(const ir::IntImm* op) final { // NOLINT(*) + return PrintConstScalar(op->type, &(op->value)); + } - Doc VisitAttr_(const ir::UIntImm* op) final { // NOLINT(*) - return PrintConstScalar(op->type, &(op->value)); - } + Doc VisitAttr_(const ir::UIntImm* op) final { // NOLINT(*) + return PrintConstScalar(op->type, &(op->value)); + } - Doc VisitAttr_(const ir::FloatImm* op) final { // NOLINT(*) - return PrintConstScalar(op->type, &(op->value)); - } + Doc VisitAttr_(const ir::FloatImm* op) final { // NOLINT(*) + return PrintConstScalar(op->type, &(op->value)); + } - Doc VisitAttr_(const ir::StringImm* op) final { // NOLINT(*) - return PrintString(op->value); - } + Doc VisitAttr_(const ir::StringImm* op) final { // NOLINT(*) + return PrintString(op->value); + } private: /*! \brief Stack of docs to implement scoped GNFing. */ diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index d5b9c8ff137b..bbea46d1d81f 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -35,11 +35,19 @@ def test_roundtrip_pp(e): def test_gnf(): assert gnf_print(relay.const(1)) == "v0.0.1\n1\n" - assert gnf_print(relay.Tuple([relay.const(1), relay.const(1)])) == "v0.0.1\n%0 = 1\n%1 = 1\n(%0, %1)\n" + assert gnf_print(relay.Tuple([relay.const(1), relay.const(1)])) == "v0.0.1\n(1, 1)\n" one = relay.const(1) - assert gnf_print(relay.Tuple([one, one])) == "v0.0.1\n%0 = 1\n(%0, %0)\n" + assert gnf_print(relay.Tuple([one, one])) == "v0.0.1\n(1, 1)\n" - assert gnf_print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0))) == "v0.0.1\n%0 = True\nif (%0) {\n %1 = 1\n %2 = (%1, %1)\n %2.0\n} else {\n %1 = 1\n %2 = 1\n %3 = (%1, %1, %2)\n %3.0\n}\n" + # assert gnf_print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0))) == "v0.0.1\n%0 = True\nif (%0) {\n %1 = 1\n %2 = (%1, %1)\n %2.0\n} else {\n %1 = 1\n %2 = 1\n %3 = (%1, %1, %2)\n %3.0\n}\n" + +def test_tensor_type(): + assert gnf_print(relay.TensorType([5, 5])) == "v0.0.1\nTensor[(5, 5), float32]\n" + +def test_tuple_type(): + assert gnf_print(relay.TupleType([])) == "v0.0.1\n()\n" + assert gnf_print(relay.TupleType([relay.scalar_type("int32")])) == "v0.0.1\n(int32,)\n" + assert gnf_print(relay.TupleType([relay.scalar_type("int32"),relay.scalar_type("int32")])) == "v0.0.1\n(int32, int32)\n" if __name__ == "__main__": # for _ in range(10): @@ -76,8 +84,5 @@ def test_gnf(): print(gnf_print(relay.fromtext(SEMVER+"fn(%x, %y) { %x + %y }"))) print(relay.Call(relay.fromtext(SEMVER+"fn(%x) { %x }"), [relay.const(1)], attrs=tvm.make.node("DictAttrs", n="foo")).astext()) # print(anf_print(relay.Call(relay.fromtext(SEMVER+"fn(%x) { %x }"), [relay.const(1)], attrs=tvm.make.node("DictAttrs", n="foo")))) - print(relay.TensorType([5, 5]).astext()) - print(anf_print(relay.TensorType([5, 5]))) - print(gnf_print(relay.TensorType([5, 5]))) # print(relay.fromtext(SEMVER+"add(n=5)").astext()) # print(anf_print(relay.fromtext(SEMVER+"fn (n=5) { () }"))) \ No newline at end of file From f6c26f5e7b88a385479940e7b1a5e9941e257904 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Sun, 3 Mar 2019 22:06:15 -0800 Subject: [PATCH 31/73] func type --- src/relay/ir/pretty_printer.cc | 11 ++++++++++- tests/python/relay/test_ir_parser_roundtrip.py | 3 +++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index a3aea69fd1b0..fff894b3f786 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -275,7 +275,7 @@ class PrettyPrinter : Doc VisitType_(const TupleTypeNode* node) final { std::vector fields; - for (NodeRef field : node->fields) { + for (Type field : node->fields) { fields.push_back(Print(field)); } Doc doc = Nil(); @@ -287,6 +287,15 @@ class PrettyPrinter : return doc << ")"; } + Doc VisitType_(const FuncTypeNode* node) final { + Doc doc = Nil(); + std::vector arg_types; + for (Type arg_type : node->arg_types) { + arg_types.push_back(Print(arg_type)); + } + return doc << "fn (" << PrintVec(arg_types) << ") -> " << Print(node->ret_type); + } + //------------------------------------ // Overload of Attr printing functions //------------------------------------ diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index bbea46d1d81f..ce220451ccf3 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -49,6 +49,9 @@ def test_tuple_type(): assert gnf_print(relay.TupleType([relay.scalar_type("int32")])) == "v0.0.1\n(int32,)\n" assert gnf_print(relay.TupleType([relay.scalar_type("int32"),relay.scalar_type("int32")])) == "v0.0.1\n(int32, int32)\n" +def test_func_type(): + assert gnf_print(relay.FuncType([relay.scalar_type("int32"), relay.scalar_type("int32")], relay.scalar_type("int32"))) == "v0.0.1\nfn (int32, int32) -> int32\n" + if __name__ == "__main__": # for _ in range(10): # print(anf_print(exprs.example())) From 10fe7bf4d8cb14b759b8269ccea229aee8973cff Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Sun, 3 Mar 2019 22:12:04 -0800 Subject: [PATCH 32/73] set up PrintMod --- src/relay/ir/pretty_printer.cc | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index fff894b3f786..cd096ef7e06e 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -61,6 +61,8 @@ class PrettyPrinter : return PrintExpr(Downcast(node), gnf); } else if (node.as_derived()) { return PrintType(Downcast(node)); + // } else if (node.as_derived()) { + // return PrintMod(Downcast(node)); } else { assert(false); } } @@ -221,6 +223,20 @@ class PrettyPrinter : return doc; } + /* Doc PrintMod(const Module& mod) { + Doc doc = Nil(); + int counter = 0; + for (const auto& kv : mod->functions) { + std::ostringstream os; + if (counter++ != 0) { + doc << "\n"; + } + os << "def @" << kv.first->name_hint; + doc << PrintFunc(os.str(), kv.second); + return doc << "\n"; + } + } */ + Doc VisitExpr_(const FunctionNode* op) final { return PrintFunc(Text("fn "), op); } From 7aa71cfe37ba40c45b667035770a7063b336a92d Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Mon, 4 Mar 2019 14:19:41 -0800 Subject: [PATCH 33/73] free vars and modules --- src/relay/ir/pretty_printer.cc | 30 +++++++++++-------- .../python/relay/test_ir_parser_roundtrip.py | 26 ++++++++++++++-- 2 files changed, 42 insertions(+), 14 deletions(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index cd096ef7e06e..5a6f3af1c6d8 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -61,8 +61,8 @@ class PrettyPrinter : return PrintExpr(Downcast(node), gnf); } else if (node.as_derived()) { return PrintType(Downcast(node)); - // } else if (node.as_derived()) { - // return PrintMod(Downcast(node)); + } else if (node.as_derived()) { + return PrintMod(Downcast(node)); } else { assert(false); } } @@ -130,17 +130,24 @@ class PrettyPrinter : if (it != memo_.end()) return it->second; Doc printed_expr = VisitExpr(expr); // we choose to inline some nodes - if (GNF_ && gnf && !expr.as() && !expr.as() && !expr.as()) { + if (GNF_ && gnf && !expr.as() && !expr.as() && !expr.as() && !expr.as()) { Doc temp_var = AllocTemp(); memo_[expr] = temp_var; doc_stack_.back() << temp_var << " = " << printed_expr << "\n"; return temp_var; } else { memo_[expr] = printed_expr; + if (expr.as()) { + doc_stack_.back() << "free_var " << printed_expr << "\n"; + } return printed_expr; } } + Doc VisitExpr_(const VarNode* op) final { + return AllocVar(GetRef(op)); + } + Doc VisitExpr_(const ConstantNode* op) final { // Print out simple scalars directly. if (op->is_scalar()) { @@ -188,21 +195,19 @@ class PrettyPrinter : Doc VisitExpr_(const LetNode* op) final { Doc doc = Nil(); - // TODO: this should call a var printer, which needs to differentiate - // between free and bound vars // TODO: this should have a type annotation // TODO: lets in value position need to be scoped // we use ANF mode for the first level of the value position so the final // expression isn't hoisted or added to the doc stream - doc << "let %" << op->var->name_hint() << " = " << Print(op->value, false) << ";" << "\n"; + doc << "let " << AllocVar(op->var) << " = " << Print(op->value, false) << ";" << "\n"; // we use a nested scope here so GNF hoisting doesn't escape too far // and so consecutive lets don't get hoisted doc << PrintNestedScope(op->body); return doc; } - Doc PrintFunc(const Doc& prefix, const FunctionNode* fn) { + Doc PrintFunc(const Doc& prefix, const Function& fn) { // TODO(tqchen, M.K.) support generic function // Possibly through meta-data CHECK_EQ(fn->type_params.size(), 0U) @@ -223,7 +228,7 @@ class PrettyPrinter : return doc; } - /* Doc PrintMod(const Module& mod) { + Doc PrintMod(const Module& mod) { Doc doc = Nil(); int counter = 0; for (const auto& kv : mod->functions) { @@ -232,13 +237,14 @@ class PrettyPrinter : doc << "\n"; } os << "def @" << kv.first->name_hint; - doc << PrintFunc(os.str(), kv.second); - return doc << "\n"; + doc << PrintFunc(Text(os.str()), kv.second); + doc << "\n"; } - } */ + return doc; + } Doc VisitExpr_(const FunctionNode* op) final { - return PrintFunc(Text("fn "), op); + return PrintFunc(Text("fn "), GetRef(op)); } Doc VisitExpr_(const GlobalVarNode* op) final { diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index ce220451ccf3..f2f055fbac3c 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -71,7 +71,7 @@ def test_func_type(): print(relay.Let(relay.var("x"), relay.Tuple([tup, tup]), relay.const(5)).astext()) print(gnf_print(relay.Let(relay.var("x"), relay.Tuple([tup, tup]), relay.const(5)))) print(anf_print(relay.Let(relay.var("x"), relay.Tuple([tup, tup]), relay.const(5)))) - print(anf_print(relay.fromtext(SEMVER+"let %x = 1; let %x = 2; 3"))) + print(anf_print(relay.fromtext(SEMVER+"let %x = 1; let %x = 2; %x"))) print(gnf_print(relay.fromtext(SEMVER+"let %x = 1; let %x = 2; 3"))) print(anf_print(relay.fromtext(SEMVER+"fn(%x) { %x }"))) print(gnf_print(relay.fromtext(SEMVER+"fn(%x) { %x }"))) @@ -88,4 +88,26 @@ def test_func_type(): print(relay.Call(relay.fromtext(SEMVER+"fn(%x) { %x }"), [relay.const(1)], attrs=tvm.make.node("DictAttrs", n="foo")).astext()) # print(anf_print(relay.Call(relay.fromtext(SEMVER+"fn(%x) { %x }"), [relay.const(1)], attrs=tvm.make.node("DictAttrs", n="foo")))) # print(relay.fromtext(SEMVER+"add(n=5)").astext()) - # print(anf_print(relay.fromtext(SEMVER+"fn (n=5) { () }"))) \ No newline at end of file + # print(anf_print(relay.fromtext(SEMVER+"fn (n=5) { () }"))) + x = relay.var("x", shape=(3, 2)) + y = relay.var("y") + one = relay.const(10e10, dtype="float32") + z = relay.add(x, one) + z = relay.add(z, z) + f = relay.Function([x, y], z) + print(z.astext()) + print(f.astext()) + print(gnf_print(z)) + print(gnf_print(f)) + print(anf_print(z)) + print(anf_print(f)) + x = relay.var("x", "float32") + y = relay.var("y", "float32") + z = relay.add(x, y) + z = relay.add(z, z) + f = relay.Function([x, y], z) + env = relay.Module() + env["myf"] = f + print(env.astext()) + print(gnf_print(env)) + print(anf_print(env)) \ No newline at end of file From cf91b0288028b3a06bb20dfcf6d3d3d78cd838d5 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Mon, 4 Mar 2019 16:12:23 -0800 Subject: [PATCH 34/73] type annotations --- src/relay/ir/pretty_printer.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 5a6f3af1c6d8..96526a05e942 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -114,7 +114,9 @@ class PrettyPrinter : memo_[var] = val + Text("-malformed-ir"); } memo_[var] = val; - // TODO: should also return type annotation + if (var->type_annotation.defined()) { + val << ": " << Print(var->type_annotation); + } return val; } @@ -220,10 +222,9 @@ class PrettyPrinter : } doc << PrintVec(params) << PrintAttrs(fn->attrs); doc << ") "; - /* if (fn->ret_type.defined()) { - doc << " -> "; - this->PrintType(fn->ret_type, stream_); - } */ + if (fn->ret_type.defined()) { + doc << "-> " << Print(fn->ret_type) << " "; + } doc << PrintBody(fn->body); return doc; } From 46835e04b1fc6ee70f8e16af35ef9797441f5c6d Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Mon, 4 Mar 2019 16:20:30 -0800 Subject: [PATCH 35/73] scope in let value position --- src/relay/ir/pretty_printer.cc | 15 +++++++++------ tests/python/relay/test_ir_parser_roundtrip.py | 5 ++++- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 96526a05e942..49a28e84b4c2 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -197,12 +197,15 @@ class PrettyPrinter : Doc VisitExpr_(const LetNode* op) final { Doc doc = Nil(); - // TODO: this should have a type annotation - // TODO: lets in value position need to be scoped - - // we use ANF mode for the first level of the value position so the final - // expression isn't hoisted or added to the doc stream - doc << "let " << AllocVar(op->var) << " = " << Print(op->value, false) << ";" << "\n"; + doc << "let " << AllocVar(op->var) << " = "; + if (op->value.as()) { + doc << PrintBody(op->value); + } else { + // we use ANF mode for the first level of the value position so the + // final expression isn't hoisted or added to the doc stream + doc << Print(op->value, false); + } + doc << ";" << "\n"; // we use a nested scope here so GNF hoisting doesn't escape too far // and so consecutive lets don't get hoisted doc << PrintNestedScope(op->body); diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index f2f055fbac3c..d4ad74b61054 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -110,4 +110,7 @@ def test_func_type(): env["myf"] = f print(env.astext()) print(gnf_print(env)) - print(anf_print(env)) \ No newline at end of file + print(anf_print(env)) + print(gnf_print(relay.fromtext(SEMVER+"let %x = { let %y = 2; %y }; %x"))) + print(gnf_print(relay.fromtext(SEMVER+"let %x = { let %y = 2; ((%y + %y, %y * %y), 1) }; %x"))) + print(anf_print(relay.fromtext(SEMVER+"let %x = { let %y = 2; ((%y + %y, %y * %y), 1) }; %x"))) \ No newline at end of file From d2cfe74441cd2e5249eb9d9f3b438975ca7c3d47 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Mon, 4 Mar 2019 22:24:51 -0800 Subject: [PATCH 36/73] meta except it doesn't work and it broke old stuff --- src/relay/ir/pretty_printer.cc | 188 +++++++++++++++--- .../python/relay/test_ir_parser_roundtrip.py | 3 +- 2 files changed, 166 insertions(+), 25 deletions(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 49a28e84b4c2..03733334f3fb 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -12,16 +12,109 @@ namespace tvm { namespace relay { +/*! + * \brief Meta data context for TextPrinter. + * + * This is an important part to enable bi-directional serializability. + * We use tvm's Node system to build the current IR. + * It can be hard to design a text format for all the possible nodes + * as the set of nodes can grow when we do more extensions. + * + * Instead of trying to design readable text format for every node, + * we support a meta data section in the text format. + * We allow the text format to refer to a node in the meta data section. + * + * The meta data section is a json serialized string of an Map>. + * Each element in the meta data section can be referenced by the text format. + * Each meta data node is printed in the following format. + * + * meta[type-key-of-node>][] + * + * Specifically, consider the following IR(constructed by python). + * + * \code + * + * n = tvm.var("n") + * x = tvm.relay.var("x", shape=(n, 1)) + * f = tvm.relay.Function([x], x) + * print(f.astext()) + * + * \endcode + * + * The corresponding text format is shown in the following code block. + * + * \code + * + * fn (%x: Tensor[(meta[Variable][0],), float32]) { + * %x + * } + * # Meta data section is a json-serialized string + * # of the following array. + * # [tvm.var("n")] + * + * \endcode + * + * Note that we store tvm.var("n") in the meta data section. + * Since it is stored in the index-0 in the meta data section, + * we print it as meta[Variable][0]. + * + * The text parser can recover this object by loading from the corresponding + * location in the meta data section. + * + * This is is a design trade-off. + * It allows us to embedded any meta data in the text format, + * while still being able to tweak the text part of the printed IR easily. + */ +class TextMetaDataContext { + public: + /*! + * \brief Get text representation of meta node. + * \param node The node to be converted to meta node. + * \return A string representation of the meta node. + */ + Doc GetMetaNode(const NodeRef& node) { + auto it = meta_repr_.find(node); + if (it != meta_repr_.end()) { + return it->second; + } + Array& mvector = + meta_data_[node->type_key()]; + int64_t index = static_cast(mvector.size()); + mvector.push_back(node); + Doc doc = Nil(); + doc << "meta[" << node->type_key() << "][" << index << "]"; + meta_repr_[node] = doc; + return meta_repr_[node]; + } + /*! + * \brief Get the metadata section in json format. + * \return the meta datastring. + */ + std::string GetMetaSection() const { + if (meta_data_.size() == 0) return std::string(); + return SaveJSON(Map( + meta_data_.begin(), meta_data_.end())); + } + + /*! \return whether the meta data context is empty. */ + bool empty() const { + return meta_data_.empty(); + } + + private: + /*! \brief additional metadata stored in TVM json format */ + std::unordered_map > meta_data_; + /*! \brief map from meta data into its string representation */ + std::unordered_map meta_repr_; +}; + class PrettyPrinter : public ExprFunctor, public TypeFunctor, public AttrFunctor { public: - explicit PrettyPrinter(const std::unordered_map& memo_, const std::unordered_map& memo_type_, const std::unordered_map& name_alloc_map_, size_t temp_var_counter_, bool GNF_) : memo_(memo_), memo_type_(memo_type_), name_alloc_map_(name_alloc_map_), temp_var_counter_(temp_var_counter_), GNF_(GNF_) {} - explicit PrettyPrinter() : temp_var_counter_(0), GNF_(true) {} - - explicit PrettyPrinter(bool GNF_) : temp_var_counter_(0), GNF_(GNF_) {} + explicit PrettyPrinter(bool GNF, bool show_meta_data) : show_meta_data_(show_meta_data), temp_var_counter_(0), GNF_(GNF) {} // indent a new body Doc PrintBody(const NodeRef& node, int indent = 2) { @@ -39,7 +132,8 @@ class PrettyPrinter : if (GNF_) { // print in a new scope doc_stack_.push_back(Nil()); - Doc doc = PrintFinal(node); + Doc doc = Print(node, false); + doc_stack_.back() << doc; doc_stack_.pop_back(); return doc; } else { @@ -50,17 +144,26 @@ class PrettyPrinter : Doc PrintFinal(const NodeRef& node) { // must print first so doc_stack_.back() reference doesn't become stale Doc doc = Print(node, false); + if (!meta_.empty()) { + if (show_meta_data_) { + std::string meta_json = meta_.GetMetaSection(); + // append meta data in the end. + doc << "/* meta data */" << "\n" << meta_json << "\n"; + } else { + doc << "// meta data omitted. you can use show_meta_data=True to include meta data\n"; + } + }; return doc_stack_.back() << doc; } Doc PrintAttrs(const Attrs& attrs); // note: gnf flag is only one level deep - Doc Print(const NodeRef& node, bool gnf = true) { + Doc Print(const NodeRef& node, bool gnf = true, bool meta = false) { if (node.as_derived()) { - return PrintExpr(Downcast(node), gnf); + return PrintExpr(Downcast(node), gnf, meta); } else if (node.as_derived()) { - return PrintType(Downcast(node)); + return PrintType(Downcast(node), meta); } else if (node.as_derived()) { return PrintMod(Downcast(node)); } else { assert(false); } @@ -123,14 +226,21 @@ class PrettyPrinter : //------------------------------------ // Overload of Expr printing functions //------------------------------------ - Doc PrintExpr(const Expr& expr, bool gnf = true) { + Doc PrintExpr(const Expr& expr, bool gnf, bool meta) { // Exploit memoization to print GNF. // The first time we visit an expression, we need to allocate a temp var // for it. Every subsequent time we can just use its assigned variable. // This works since hashing uses pointer equality. auto it = memo_.find(expr); if (it != memo_.end()) return it->second; - Doc printed_expr = VisitExpr(expr); + Doc printed_expr; + if (meta) { + printed_expr = meta_.GetMetaNode(GetRef(expr.get())); + std::cerr << Layout(printed_expr) << "\n"; + assert(false); + } else { + printed_expr = VisitExpr(expr); + } // we choose to inline some nodes if (GNF_ && gnf && !expr.as() && !expr.as() && !expr.as() && !expr.as()) { Doc temp_var = AllocTemp(); @@ -168,8 +278,8 @@ class PrettyPrinter : return PrintConstScalar(dtype, static_cast(op->data->data)); } } - // TODO: handle tensors - assert(false); + // default fall-back, record it as meta node. + return Print(GetRef(op), true, true); } Doc VisitExpr_(const TupleNode* op) final { @@ -214,7 +324,7 @@ class PrettyPrinter : Doc PrintFunc(const Doc& prefix, const Function& fn) { // TODO(tqchen, M.K.) support generic function - // Possibly through meta-data + // Possibly through meta data CHECK_EQ(fn->type_params.size(), 0U) << "generic fn not yet supported"; Doc doc = Nil(); @@ -272,14 +382,24 @@ class PrettyPrinter : //------------------------------------ // Overload of Type printing functions //------------------------------------ - Doc PrintType(const Type& type) { + Doc PrintType(const Type& type, bool meta) { auto it = memo_type_.find(type); if (it != memo_type_.end()) return it->second; - Doc printed_type = VisitType(type); + Doc printed_type; + if (meta) { + printed_type = meta_.GetMetaNode(GetRef(type.get())); + } else { + printed_type = VisitType(type); + } memo_type_[type] = printed_type; return printed_type; } + Doc VisitTypeDefault_(const Node* node) final { // NOLINT(*) + // by default always print as meta data + return Print(GetRef(node), true, true); + } + Doc VisitType_(const TensorTypeNode* node) final { // NOLINT(*) // scalar type if (node->shape.size() == 0) { @@ -326,14 +446,24 @@ class PrettyPrinter : // Overload of Attr printing functions //------------------------------------ - Doc PrintAttr(const NodeRef& value) { // NOLINT(*) + Doc PrintAttr(const NodeRef& value, bool meta = false) { // NOLINT(*) if (value.defined()) { - return VisitAttr(value); + Doc printed_attr; + if (meta) { + printed_attr = meta_.GetMetaNode(value); + } else { + printed_attr = VisitAttr(value); + } + return printed_attr; } else { return Text("None"); } } + Doc VisitAttrDefault_(const Node* op) final { // NOLINT(*) + return PrintAttr(GetRef(op), true); + } + Doc VisitAttr_(const ArrayNode* op) final { // NOLINT(*) Doc doc = Nil(); doc << "["; @@ -346,11 +476,6 @@ class PrettyPrinter : return doc; } - Doc VisitAttrDefault_(const Node* op) final { // NOLINT(*) - // os << meta_.GetMetaNode(GetRef(op)); - assert(false); - } - Doc VisitAttr_(const ir::IntImm* op) final { // NOLINT(*) return PrintConstScalar(op->type, &(op->value)); } @@ -368,6 +493,8 @@ class PrettyPrinter : } private: + /*! \brief Whether to print meta data. */ + bool show_meta_data_; /*! \brief Stack of docs to implement scoped GNFing. */ std::vector doc_stack_{Nil()}; /*! \brief Map from Expr to Doc */ @@ -375,6 +502,8 @@ class PrettyPrinter : /*! \brief Map from Type to Doc */ std::unordered_map memo_type_; std::unordered_map name_alloc_map_; + /*! \brief meta data context */ + TextMetaDataContext meta_; size_t temp_var_counter_; bool GNF_; class AttrPrinter; @@ -440,11 +569,19 @@ Doc PrettyPrinter::PrintAttrs(const Attrs& attrs) { // NOLINT(*) } std::string RelayGNFPrint(const NodeRef& node) { - return "v0.0.1\n" + Layout(PrettyPrinter().PrintFinal(node)) + "\n"; + Doc doc = Nil(); + doc << "v0.0.1" << "\n" << PrettyPrinter(true, false).PrintFinal(node) << "\n"; + return Layout(doc); } std::string RelayANFPrint(const NodeRef& node) { - return "v0.0.1\n" + Layout(PrettyPrinter(false).Print(node)) + "\n"; + return "v0.0.1\n" + Layout(PrettyPrinter(false, false).Print(node)) + "\n"; +} + +std::string RelayPrettyPrint(const NodeRef& node, bool gnf, bool show_meta_data) { + Doc doc = Nil(); + doc << "v0.0.1" << "\n" << PrettyPrinter(gnf, show_meta_data).PrintFinal(node) << "\n"; + return Layout(doc); } TVM_REGISTER_API("relay._expr.gnf_print") @@ -453,5 +590,8 @@ TVM_REGISTER_API("relay._expr.gnf_print") TVM_REGISTER_API("relay._expr.anf_print") .set_body_typed(RelayANFPrint); +TVM_REGISTER_API("relay._expr.pretty_print") +.set_body_typed(RelayPrettyPrint); + } // relay } // tvm diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index d4ad74b61054..4465901cf9c4 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -113,4 +113,5 @@ def test_func_type(): print(anf_print(env)) print(gnf_print(relay.fromtext(SEMVER+"let %x = { let %y = 2; %y }; %x"))) print(gnf_print(relay.fromtext(SEMVER+"let %x = { let %y = 2; ((%y + %y, %y * %y), 1) }; %x"))) - print(anf_print(relay.fromtext(SEMVER+"let %x = { let %y = 2; ((%y + %y, %y * %y), 1) }; %x"))) \ No newline at end of file + print(anf_print(relay.fromtext(SEMVER+"let %x = { let %y = 2; ((%y + %y, %y * %y), 1) }; %x"))) + print(relay.const([1,2,3]).astext()) \ No newline at end of file From 9b4bad70561b5b50e68dd2aa93ff15d5a2d44c5b Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Tue, 5 Mar 2019 11:47:35 -0800 Subject: [PATCH 37/73] fix bugs except for meta bugs --- src/relay/ir/pretty_printer.cc | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 03733334f3fb..e42bfcaf0db7 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -121,19 +121,20 @@ class PrettyPrinter : Doc doc = Nil(); Doc body = Nil(); doc << "{"; - doc << Indent(indent, body << "\n" << PrintNestedScope(node)) << "\n"; + doc << Indent(indent, body << "\n" << PrintScope(node)) << "\n"; doc << "}"; return doc; } // create a new scope by creating a new printer object. This allows temp var // numbers to be reused and prevents hoisted vars from escaping too far - Doc PrintNestedScope(const NodeRef& node) { + Doc PrintScope(const NodeRef& node) { if (GNF_) { // print in a new scope doc_stack_.push_back(Nil()); + // must print first so doc_stack_.back() reference doesn't become stale Doc doc = Print(node, false); - doc_stack_.back() << doc; + doc = doc_stack_.back() << doc; doc_stack_.pop_back(); return doc; } else { @@ -142,8 +143,8 @@ class PrettyPrinter : } Doc PrintFinal(const NodeRef& node) { - // must print first so doc_stack_.back() reference doesn't become stale - Doc doc = Print(node, false); + Doc doc = Nil(); + doc << PrintScope(node); if (!meta_.empty()) { if (show_meta_data_) { std::string meta_json = meta_.GetMetaSection(); @@ -153,7 +154,7 @@ class PrettyPrinter : doc << "// meta data omitted. you can use show_meta_data=True to include meta data\n"; } }; - return doc_stack_.back() << doc; + return doc; } Doc PrintAttrs(const Attrs& attrs); @@ -318,7 +319,7 @@ class PrettyPrinter : doc << ";" << "\n"; // we use a nested scope here so GNF hoisting doesn't escape too far // and so consecutive lets don't get hoisted - doc << PrintNestedScope(op->body); + doc << PrintScope(op->body); return doc; } @@ -496,7 +497,7 @@ class PrettyPrinter : /*! \brief Whether to print meta data. */ bool show_meta_data_; /*! \brief Stack of docs to implement scoped GNFing. */ - std::vector doc_stack_{Nil()}; + std::vector doc_stack_{}; /*! \brief Map from Expr to Doc */ std::unordered_map memo_; /*! \brief Map from Type to Doc */ @@ -575,7 +576,7 @@ std::string RelayGNFPrint(const NodeRef& node) { } std::string RelayANFPrint(const NodeRef& node) { - return "v0.0.1\n" + Layout(PrettyPrinter(false, false).Print(node)) + "\n"; + return "v0.0.1\n" + Layout(PrettyPrinter(false, false).PrintFinal(node)) + "\n"; } std::string RelayPrettyPrint(const NodeRef& node, bool gnf, bool show_meta_data) { From 0549d26e23376058ac03b4f97cb75811c8f15120 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Tue, 5 Mar 2019 15:09:57 -0800 Subject: [PATCH 38/73] fix printing of free variables. meta still broken --- src/relay/ir/pretty_printer.cc | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index e42bfcaf0db7..ea0dcccd303d 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -65,7 +65,7 @@ namespace relay { * It allows us to embedded any meta data in the text format, * while still being able to tweak the text part of the printed IR easily. */ -class TextMetaDataContext { +class TextMetaDataContextFoo { public: /*! * \brief Get text representation of meta node. @@ -201,12 +201,7 @@ class PrettyPrinter : return Text(prefix); } - /*! - * \brief Allocate name to a variable. - * \param var The input variable. - * \return The corresponding name. - */ - Doc AllocVar(const Var& var) { + Doc PrintVar(const Var& var) { std::string name = var->name_hint(); // always make sure first name is alpha if (name.length() != 0 && !std::isalpha(name[0])) { @@ -215,13 +210,23 @@ class PrettyPrinter : Doc val = GetUniqueName("%" + name); // still print if ir is malformed, but show the error. if (memo_.count(var)) { - memo_[var] = val + Text("-malformed-ir"); + val << Text("-malformed-ir"); } memo_[var] = val; + return val; + } + + /*! + * \brief Allocate name to a variable. + * \param var The input variable. + * \return The corresponding name. + */ + Doc AllocVar(const Var& var) { + Doc doc = PrintVar(var); if (var->type_annotation.defined()) { - val << ": " << Print(var->type_annotation); + doc << ": " << Print(var->type_annotation); } - return val; + return doc; } //------------------------------------ @@ -237,8 +242,6 @@ class PrettyPrinter : Doc printed_expr; if (meta) { printed_expr = meta_.GetMetaNode(GetRef(expr.get())); - std::cerr << Layout(printed_expr) << "\n"; - assert(false); } else { printed_expr = VisitExpr(expr); } @@ -257,8 +260,10 @@ class PrettyPrinter : } } + // Should only be triggered when op is a free variable being visited for the + // first time. Doc VisitExpr_(const VarNode* op) final { - return AllocVar(GetRef(op)); + return PrintVar(GetRef(op)); } Doc VisitExpr_(const ConstantNode* op) final { @@ -504,7 +509,7 @@ class PrettyPrinter : std::unordered_map memo_type_; std::unordered_map name_alloc_map_; /*! \brief meta data context */ - TextMetaDataContext meta_; + TextMetaDataContextFoo meta_; size_t temp_var_counter_; bool GNF_; class AttrPrinter; From d117835743250ede973630af384e725f6fb04946 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Tue, 5 Mar 2019 15:17:43 -0800 Subject: [PATCH 39/73] fix free vars for real --- src/relay/ir/pretty_printer.cc | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index ea0dcccd303d..92c55ba56e8c 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -201,7 +201,12 @@ class PrettyPrinter : return Text(prefix); } - Doc PrintVar(const Var& var) { + /*! + * \brief Allocate name to a variable. + * \param var The input variable. + * \return The corresponding name. + */ + Doc AllocVar(const Var& var) { std::string name = var->name_hint(); // always make sure first name is alpha if (name.length() != 0 && !std::isalpha(name[0])) { @@ -213,20 +218,10 @@ class PrettyPrinter : val << Text("-malformed-ir"); } memo_[var] = val; - return val; - } - - /*! - * \brief Allocate name to a variable. - * \param var The input variable. - * \return The corresponding name. - */ - Doc AllocVar(const Var& var) { - Doc doc = PrintVar(var); if (var->type_annotation.defined()) { - doc << ": " << Print(var->type_annotation); + val << ": " << Print(var->type_annotation); } - return doc; + return val; } //------------------------------------ @@ -251,11 +246,11 @@ class PrettyPrinter : memo_[expr] = temp_var; doc_stack_.back() << temp_var << " = " << printed_expr << "\n"; return temp_var; + } else if (expr.as()) { + doc_stack_.back() << "free_var " << printed_expr << "\n"; + return memo_[expr]; } else { memo_[expr] = printed_expr; - if (expr.as()) { - doc_stack_.back() << "free_var " << printed_expr << "\n"; - } return printed_expr; } } @@ -263,7 +258,7 @@ class PrettyPrinter : // Should only be triggered when op is a free variable being visited for the // first time. Doc VisitExpr_(const VarNode* op) final { - return PrintVar(GetRef(op)); + return AllocVar(GetRef(op)); } Doc VisitExpr_(const ConstantNode* op) final { From d6d4065f1263b365cef259fb85bf5ef41a289361 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Tue, 5 Mar 2019 15:45:13 -0800 Subject: [PATCH 40/73] fix segfault --- src/relay/ir/pretty_printer.cc | 25 +++++++++---------- .../python/relay/test_ir_parser_roundtrip.py | 3 ++- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 92c55ba56e8c..3f2a301a6099 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -129,17 +129,13 @@ class PrettyPrinter : // create a new scope by creating a new printer object. This allows temp var // numbers to be reused and prevents hoisted vars from escaping too far Doc PrintScope(const NodeRef& node) { - if (GNF_) { - // print in a new scope - doc_stack_.push_back(Nil()); - // must print first so doc_stack_.back() reference doesn't become stale - Doc doc = Print(node, false); - doc = doc_stack_.back() << doc; - doc_stack_.pop_back(); - return doc; - } else { - return Print(node); - } + // print in a new scope + doc_stack_.push_back(Nil()); + // must print first so doc_stack_.back() reference doesn't become stale + Doc doc = Print(node, false); + doc = doc_stack_.back() << doc; + doc_stack_.pop_back(); + return doc; } Doc PrintFinal(const NodeRef& node) { @@ -149,9 +145,9 @@ class PrettyPrinter : if (show_meta_data_) { std::string meta_json = meta_.GetMetaSection(); // append meta data in the end. - doc << "/* meta data */" << "\n" << meta_json << "\n"; + doc << "\n" << "/* meta data */" << "\n" << meta_json; } else { - doc << "// meta data omitted. you can use show_meta_data=True to include meta data\n"; + doc << "\n" << "// meta data omitted. you can use show_meta_data=True to include meta data"; } }; return doc; @@ -247,7 +243,10 @@ class PrettyPrinter : doc_stack_.back() << temp_var << " = " << printed_expr << "\n"; return temp_var; } else if (expr.as()) { + // This is our first time visiting the var and we hit the VarNode case + // in the visitor. Thus the variable is free. doc_stack_.back() << "free_var " << printed_expr << "\n"; + // Memoization is done in AllocVar. return memo_[expr]; } else { memo_[expr] = printed_expr; diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index 4465901cf9c4..5432cfe7dbb0 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -114,4 +114,5 @@ def test_func_type(): print(gnf_print(relay.fromtext(SEMVER+"let %x = { let %y = 2; %y }; %x"))) print(gnf_print(relay.fromtext(SEMVER+"let %x = { let %y = 2; ((%y + %y, %y * %y), 1) }; %x"))) print(anf_print(relay.fromtext(SEMVER+"let %x = { let %y = 2; ((%y + %y, %y * %y), 1) }; %x"))) - print(relay.const([1,2,3]).astext()) \ No newline at end of file + print(relay.const([1,2,3]).astext()) + print(gnf_print(relay.const([1,2,3]))) \ No newline at end of file From 772aad951741444916c0d739da12d54033a7d967 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Tue, 5 Mar 2019 17:45:11 -0800 Subject: [PATCH 41/73] complete meta --- src/relay/ir/pretty_printer.cc | 22 +++++++++---------- src/relay/ir/text_printer.cc | 4 ++-- .../python/relay/test_ir_parser_roundtrip.py | 3 ++- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 3f2a301a6099..fc49a59c8082 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -13,7 +13,7 @@ namespace tvm { namespace relay { /*! - * \brief Meta data context for TextPrinter. + * \brief Meta data context for PrettyPrinter. * * This is an important part to enable bi-directional serializability. * We use tvm's Node system to build the current IR. @@ -65,7 +65,7 @@ namespace relay { * It allows us to embedded any meta data in the text format, * while still being able to tweak the text part of the printed IR easily. */ -class TextMetaDataContextFoo { +class TextMetaDataContext { public: /*! * \brief Get text representation of meta node. @@ -503,7 +503,7 @@ class PrettyPrinter : std::unordered_map memo_type_; std::unordered_map name_alloc_map_; /*! \brief meta data context */ - TextMetaDataContextFoo meta_; + TextMetaDataContext meta_; size_t temp_var_counter_; bool GNF_; class AttrPrinter; @@ -560,7 +560,7 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor { }; Doc PrettyPrinter::PrintAttrs(const Attrs& attrs) { // NOLINT(*) - // TODO: meta + // TODO: meta? if (!attrs.defined()) return Nil(); Doc doc = Nil(); AttrPrinter printer(doc, this); @@ -568,20 +568,18 @@ Doc PrettyPrinter::PrintAttrs(const Attrs& attrs) { // NOLINT(*) return doc; } -std::string RelayGNFPrint(const NodeRef& node) { +std::string RelayPrettyPrint(const NodeRef& node, bool gnf, bool show_meta_data) { Doc doc = Nil(); - doc << "v0.0.1" << "\n" << PrettyPrinter(true, false).PrintFinal(node) << "\n"; + doc << "v0.0.1" << "\n" << PrettyPrinter(gnf, show_meta_data).PrintFinal(node) << "\n"; return Layout(doc); } -std::string RelayANFPrint(const NodeRef& node) { - return "v0.0.1\n" + Layout(PrettyPrinter(false, false).PrintFinal(node)) + "\n"; +std::string RelayGNFPrint(const NodeRef& node) { + return RelayPrettyPrint(node, true, true); } -std::string RelayPrettyPrint(const NodeRef& node, bool gnf, bool show_meta_data) { - Doc doc = Nil(); - doc << "v0.0.1" << "\n" << PrettyPrinter(gnf, show_meta_data).PrintFinal(node) << "\n"; - return Layout(doc); +std::string RelayANFPrint(const NodeRef& node) { + return RelayPrettyPrint(node, false, true); } TVM_REGISTER_API("relay._expr.gnf_print") diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 932856a2055d..3eaad13d2f9a 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -90,7 +90,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO * It allows us to embedded any meta-data in the text format, * while still being able to tweak the text part of the printed IR easily. */ -class TextMetaDataContext { +class TextMetaDataContextFoo { public: /*! * \brief Get text representation of meta node. @@ -800,7 +800,7 @@ class TextPrinter : /*! \brief additional comment function */ runtime::TypedPackedFunc annotate_; /*! \brief meta data context */ - TextMetaDataContext meta_; + TextMetaDataContextFoo meta_; /*! \brief Check whether scope is still valid */ std::vector scope_valid_; /*! \brief The current indentation value */ diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index 5432cfe7dbb0..90f7b3fdf47c 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -115,4 +115,5 @@ def test_func_type(): print(gnf_print(relay.fromtext(SEMVER+"let %x = { let %y = 2; ((%y + %y, %y * %y), 1) }; %x"))) print(anf_print(relay.fromtext(SEMVER+"let %x = { let %y = 2; ((%y + %y, %y * %y), 1) }; %x"))) print(relay.const([1,2,3]).astext()) - print(gnf_print(relay.const([1,2,3]))) \ No newline at end of file + print(gnf_print(relay.const([1,2,3]))) + print(anf_print(relay.const([1,2,3]))) \ No newline at end of file From 6d5767ca1bbf530767f21d540634eb47602946bf Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Tue, 5 Mar 2019 18:34:03 -0800 Subject: [PATCH 42/73] bye bye text printer. hello pretty printer --- include/tvm/relay/expr.h | 4 +- python/tvm/relay/base.py | 9 +- src/relay/ir/pretty_printer.cc | 63 +- src/relay/ir/text_printer.cc | 904 ------------------ .../python/relay/test_ir_parser_roundtrip.py | 9 +- 5 files changed, 54 insertions(+), 935 deletions(-) delete mode 100644 src/relay/ir/text_printer.cc diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 06a1aa1ac9ef..ad6c328b9d91 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -549,12 +549,14 @@ inline const TTypeNode* ExprNode::type_as() const { * \param show_meta_data Whether to print meta data section. * \param annotate An optional callback function for attaching * additional comment block to an expr. + * \param gnf Whether to print in GNF. * \return The text representation. */ std::string RelayPrint( const NodeRef& node, bool show_meta_data = true, - runtime::TypedPackedFunc annotate = nullptr); + runtime::TypedPackedFunc annotate = nullptr, + bool gnf = true); } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_H_ diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index e0491d62f552..bde2be95ae8f 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -38,7 +38,7 @@ def register_relay_attr_node(type_key=None): class RelayNode(NodeBase): """Base class of all Relay nodes.""" - def astext(self, show_meta_data=True, annotate=None): + def astext(self, show_meta_data=True, annotate=None, gnf=True): """Get the text format of the expression. Parameters @@ -51,9 +51,12 @@ def astext(self, show_meta_data=True, annotate=None): Optional annotate function to provide additional information in the comment block. + gnf : bool + Whether to print in GNF. + Note ---- - The metadata section is necessary to fully parse the text format. + The meta data section is necessary to fully parse the text format. However, it can contain dumps that are big (e.g constant weights), so it can be helpful to skip printing the meta data section. @@ -62,7 +65,7 @@ def astext(self, show_meta_data=True, annotate=None): text : str The text format of the expression. """ - return _expr.RelayPrint(self, show_meta_data, annotate) + return _expr.RelayPrint(self, show_meta_data, annotate, gnf) def set_span(self, span): _base.set_span(self, span) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index fc49a59c8082..e3dc71e54852 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -113,8 +113,24 @@ class PrettyPrinter : public TypeFunctor, public AttrFunctor { public: - explicit PrettyPrinter() : temp_var_counter_(0), GNF_(true) {} - explicit PrettyPrinter(bool GNF, bool show_meta_data) : show_meta_data_(show_meta_data), temp_var_counter_(0), GNF_(GNF) {} + explicit PrettyPrinter(bool GNF, bool show_meta_data, runtime::TypedPackedFunc annotate) : GNF_(GNF), show_meta_data_(show_meta_data), annotate_(annotate) {} + + /*! + * \brief Print additional info about expr in comment. + * \param expr The expression. + */ + Doc PrintOptionalInfo(const Expr& expr) { + Doc doc = Nil(); + // additional information in comment. + if (annotate_ != nullptr) { + return doc << " # " << annotate_(expr); + } else if (expr->checked_type_.defined()) { + doc << " # ty="; + return doc << Print(expr->checked_type()); + } else { + return Nil(); + } + } // indent a new body Doc PrintBody(const NodeRef& node, int indent = 2) { @@ -240,7 +256,11 @@ class PrettyPrinter : if (GNF_ && gnf && !expr.as() && !expr.as() && !expr.as() && !expr.as()) { Doc temp_var = AllocTemp(); memo_[expr] = temp_var; - doc_stack_.back() << temp_var << " = " << printed_expr << "\n"; + doc_stack_.back() << temp_var << " = " << printed_expr; + if (expr.as()) { + doc_stack_.back() << PrintOptionalInfo(expr); + } + doc_stack_.back() << "\n"; return temp_var; } else if (expr.as()) { // This is our first time visiting the var and we hit the VarNode case @@ -279,7 +299,9 @@ class PrettyPrinter : } } // default fall-back, record it as meta node. - return Print(GetRef(op), true, true); + Doc doc = Nil(); + return doc << Print(GetRef(op), true, true) + << PrintOptionalInfo(GetRef(op)); } Doc VisitExpr_(const TupleNode* op) final { @@ -493,19 +515,24 @@ class PrettyPrinter : } private: + /*! \brief Whether to use GNF. */ + bool GNF_; /*! \brief Whether to print meta data. */ bool show_meta_data_; + /*! \brief additional comment function */ + runtime::TypedPackedFunc annotate_; /*! \brief Stack of docs to implement scoped GNFing. */ std::vector doc_stack_{}; /*! \brief Map from Expr to Doc */ std::unordered_map memo_; /*! \brief Map from Type to Doc */ std::unordered_map memo_type_; + /*! \brief name allocation map */ std::unordered_map name_alloc_map_; /*! \brief meta data context */ TextMetaDataContext meta_; - size_t temp_var_counter_; - bool GNF_; + /*! \brief counter of temporary variable */ + size_t temp_var_counter_{0}; class AttrPrinter; friend class AttrPrinter; }; @@ -560,7 +587,7 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor { }; Doc PrettyPrinter::PrintAttrs(const Attrs& attrs) { // NOLINT(*) - // TODO: meta? + // TODO: fallback meta? if (!attrs.defined()) return Nil(); Doc doc = Nil(); AttrPrinter printer(doc, this); @@ -568,28 +595,14 @@ Doc PrettyPrinter::PrintAttrs(const Attrs& attrs) { // NOLINT(*) return doc; } -std::string RelayPrettyPrint(const NodeRef& node, bool gnf, bool show_meta_data) { +std::string RelayPrint(const NodeRef& node, bool show_meta_data, runtime::TypedPackedFunc annotate, bool gnf) { Doc doc = Nil(); - doc << "v0.0.1" << "\n" << PrettyPrinter(gnf, show_meta_data).PrintFinal(node) << "\n"; + doc << "v0.0.1" << "\n" << PrettyPrinter(gnf, show_meta_data, annotate).PrintFinal(node) << "\n"; return Layout(doc); } -std::string RelayGNFPrint(const NodeRef& node) { - return RelayPrettyPrint(node, true, true); -} - -std::string RelayANFPrint(const NodeRef& node) { - return RelayPrettyPrint(node, false, true); -} - -TVM_REGISTER_API("relay._expr.gnf_print") -.set_body_typed(RelayGNFPrint); - -TVM_REGISTER_API("relay._expr.anf_print") -.set_body_typed(RelayANFPrint); - -TVM_REGISTER_API("relay._expr.pretty_print") -.set_body_typed(RelayPrettyPrint); +TVM_REGISTER_API("relay._expr.RelayPrint") +.set_body_typed, bool)>(RelayPrint); } // relay } // tvm diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc deleted file mode 100644 index 3eaad13d2f9a..000000000000 --- a/src/relay/ir/text_printer.cc +++ /dev/null @@ -1,904 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file text_printer.cc - * \brief Text printer to print relay in text form. - */ -#include -#include -#include -#include -#include "type_functor.h" -#include "../../lang/attr_functor.h" - -namespace tvm { -namespace relay { - -/*! - * \brief the text value used in text printer. - * Defined as a struct for future compatibility reason - */ -struct TextValue { - /*! \brief The str representation */ - std::string name; - // constructor - TextValue() {} - // constructor - explicit TextValue(std::string name) : name(name) {} - TextValue operator+(const TextValue& rhs) const { - return TextValue(name + rhs.name); - } - TextValue operator+(const std::string& str) const { - return TextValue(name + str); - } -}; - -// operator overloading -inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NOLINT(*) - return os << val.name; -} - -/*! - * \brief Meta data context for TextPrinter. - * - * This is an important part to enable bi-directional serializability. - * We use tvm's Node system to build the current IR. - * It can be hard to design a text format for all the possible nodes - * as the set of nodes can grow when we do more extensions. - * - * Instead of trying to design readable text format for every node, - * we support a meta-data section in the text format. - * We allow the text format to refer to a node in the meta-data section. - * - * The meta-data section is a json serialized string of an Map>. - * Each element in the meta-data section can be referenced by the text format. - * Each meta data node is printed in the following format. - * - * meta[type-key-of-node>][] - * - * Specifically, consider the following IR(constructed by python). - * - * \code - * - * n = tvm.var("n") - * x = tvm.relay.var("x", shape=(n, 1)) - * f = tvm.relay.Function([x], x) - * print(f.astext()) - * - * \endcode - * - * The corresponding text format is shown in the following code block. - * - * \code - * - * fn (%x: Tensor[(meta[Variable][0],), float32]) { - * %x - * } - * # Meta data section is a json-serialized string - * # of the following array. - * # [tvm.var("n")] - * - * \endcode - * - * Note that we store tvm.var("n") in the meta data section. - * Since it is stored in the index-0 in the meta-data section, - * we print it as meta[Variable][0]. - * - * The text parser can recover this object by loading from the corresponding - * location in the meta data section. - * - * This is is a design trade-off. - * It allows us to embedded any meta-data in the text format, - * while still being able to tweak the text part of the printed IR easily. - */ -class TextMetaDataContextFoo { - public: - /*! - * \brief Get text representation of meta node. - * \param node The node to be converted to meta node. - * \return A string representation of the meta node. - */ - std::string GetMetaNode(const NodeRef& node) { - auto it = meta_repr_.find(node); - if (it != meta_repr_.end()) { - return it->second; - } - Array& mvector = - meta_data_[node->type_key()]; - int64_t index = static_cast(mvector.size()); - mvector.push_back(node); - std::ostringstream os; - os << "meta[" << node->type_key() << "][" << index << "]"; - meta_repr_[node] = os.str(); - return meta_repr_[node]; - } - /*! - * \brief Get the metadata section in json format. - * \return the meta datastring. - */ - std::string GetMetaSection() const { - if (meta_data_.size() == 0) return std::string(); - return SaveJSON(Map( - meta_data_.begin(), meta_data_.end())); - } - - /*! \return whether the meta data context is empty. */ - bool empty() const { - return meta_data_.empty(); - } - - private: - /*! \brief additional metadata stored in TVM json format */ - std::unordered_map > meta_data_; - /*! \brief map from meta data into its string representation */ - std::unordered_map meta_repr_; -}; - -class TextPrinter : - public ExprFunctor, - public PatternFunctor, - public TypeFunctor, // NOLINT(*) - public AttrFunctor { // NOLINT(*) - public: - explicit TextPrinter(bool show_meta_data, - runtime::TypedPackedFunc annotate) - : show_meta_data_(show_meta_data), annotate_(annotate) {} - /*! - * \brief Print a node to string. - * \param node. - * \return The string representation. - */ - std::string Print(const NodeRef& node) { - if (node.as()) { - this->PrintFunc(Downcast(node)); - } else if (node.as()) { - this->PrintEnv(Downcast(node)); - } else if (node.as_derived()) { - this->PrintType(Downcast(node), stream_); - } else if (node.as_derived()) { - this->PrintExpr(Downcast(node)); - } else { - stream_ << node; - } - if (!meta_.empty()) { - if (show_meta_data_) { - std::string meta_json = meta_.GetMetaSection(); - // append meta data in the end. - stream_ << "# meta data\n" - << "r\"\"\"\n" - << meta_json << "\n" - << "\"\"\""; - } else { - stream_ << "# meta data omitted. you can use show_meta_data=True to include meta-data\n"; - } - } - return stream_.str(); - } - - void PrintFunc(const Function& func) { - this->PrintFuncInternal("fn ", func); - stream_ << "\n"; - } - - void PrintEnv(const Module& mod) { - int counter = 0; - for (const auto& kv : mod->functions) { - std::ostringstream os; - if (counter++ != 0) { - stream_ << "\n"; - } - os << "def @" << kv.first->name_hint; - this->PrintFuncInternal(os.str(), kv.second); - stream_ << "\n"; - } - } - - void PrintExpr(const Expr& expr) { - TextValue val = GetValue(expr); - stream_ << val << "\n"; - } - - /*! - * \brief Get text representation of expr. - * - * This function may generate additional instructions - * in order to compute the final result id of expr. - * - * When trying to recursively print out an Expr. - * The caller should always call GetValue of its children first. - * Then the caller can print out to stream_ using the obtained value. - * - * This is to avoid the call of subsequent GetValue print out - * additional instructions which get mixed with the partial instruction - * printed by the caller. - * - * \param expr The input expression. - * \return The text value of Expr. - */ - TextValue GetValue(const Expr& expr) { - auto it = memo_.find(expr); - if (it != memo_.end()) return it->second; - TextValue val = this->VisitExpr(expr); - memo_[expr] = val; - return val; - } - TextValue GetValue(const Pattern& p) { - return this->VisitPattern(p); - } - //------------------------------------ - // Overload of Expr printing functions - //------------------------------------ - TextValue VisitExpr_(const ConstantNode* op) final { - // Print out simple scalar directly. - if (op->is_scalar()) { - std::ostringstream os; - DataType dtype = TVMType2Type(op->data->dtype); - CHECK_EQ(op->data->ctx.device_type, kDLCPU); - if (dtype == Int(32)) { - return ConstScalar(dtype, static_cast(op->data->data)); - } else if (dtype == Int(64)) { - return ConstScalar(dtype, static_cast(op->data->data)); - } else if (dtype == Float(32)) { - return ConstScalar(dtype, static_cast(op->data->data)); - } else if (dtype == Float(64)) { - return ConstScalar(dtype, static_cast(op->data->data)); - } else if (dtype == Bool()) { - return ConstScalar(dtype, static_cast(op->data->data)); - } - } - // default fall-back, record it as meta node. - TextValue id = this->AllocTempVar(); - this->PrintIndent(); - stream_ << id << " = " << meta_.GetMetaNode(GetRef(op)); - this->PrintEndInst(""); - this->PrintOptionalInfo(GetRef(op)); - stream_ << '\n'; - return id; - } - - TextValue VisitExpr_(const TupleNode* op) final { - std::vector fields; - for (Expr field : op->fields) { - fields.push_back(GetValue(field)); - } - // NOTE: always recursively visit to get ids, - // before print out the current line - TextValue id = this->AllocTempVar(); - this->PrintIndent(); - stream_ << id << " = ("; - for (size_t i = 0; i < fields.size(); ++i) { - stream_ << fields[i]; - if (i + 1 != fields.size()) { - stream_ << ", "; - } - } - if (fields.size() == 1) { - stream_ << ','; - } - stream_ << ')'; - this->PrintEndInst("\n"); - return id; - } - - TextValue VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); - // This is an unbounded var. - TextValue val = AllocVarName(var); - this->PrintIndent(); - stream_ << "free_var "; - this->PrintVarDecl(var, stream_); - this->PrintEndInst("\n"); - return val; - } - - TextValue VisitExpr_(const GlobalVarNode* op) final { - return TextValue('@' + op->name_hint); - } - - TextValue VisitExpr_(const FunctionNode* op) final { - TextValue id = AllocTempVar(); - std::ostringstream os; - os << id << " = fn"; - this->PrintFuncInternal(os.str(), GetRef(op)); - this->PrintEndInst("\n"); - return id; - } - - TextValue VisitExpr_(const CallNode* op) final { - // possibly through meta-data - std::vector args; - for (Expr arg : op->args) { - args.emplace_back(GetValue(arg)); - } - TextValue call_op = GetValue(op->op); - TextValue id = this->AllocTempVar(); - this->PrintIndent(); - - stream_ << id << " = " << call_op; - - auto type_args = op->type_args; - - if (!IsPrimitiveOp(op->op) && type_args.size() > 0U) { - stream_ << "<"; - for (size_t i = 0; i < op->type_args.size(); ++i) { - this->PrintType(type_args[i], stream_); - if (i + 1 != type_args.size()) { - stream_ << ", "; - } - } - stream_ << ">"; - } - - stream_ << "("; - for (size_t i = 0; i < args.size(); ++i) { - stream_ << args[i]; - if (i + 1 != args.size()) { - stream_ << ", "; - } - } - this->PrintCallAttrs(op->op, op->attrs, stream_); - stream_ << ")"; - this->PrintEndInst(""); - this->PrintOptionalInfo(GetRef(op)); - stream_ << '\n'; - return id; - } - - TextValue VisitExpr_(const LetNode* op) final { - TextValue id = this->AllocTempVar(); - this->PrintIndent(); - stream_ << id << " = "; - this->PrintScope(GetRef(op)); - this->PrintEndInst("\n"); - return id; - } - - TextValue VisitExpr_(const IfNode* op) final { - TextValue id = this->AllocTempVar(); - this->PrintIndent(); - stream_ << id << " = "; - this->PrintScope(GetRef(op)); - this->PrintEndInst("\n"); - return id; - } - - TextValue VisitExpr_(const OpNode* op) final { - return TextValue(op->name); - } - - TextValue VisitExpr_(const TupleGetItemNode* op) final { - TextValue tuple = GetValue(op->tuple); - TextValue id = this->AllocTempVar(); - this->PrintIndent(); - stream_ << id << " = " << tuple << "." << op->index << ""; - this->PrintEndInst("\n"); - return id; - } - - TextValue VisitExpr_(const RefCreateNode* op) final { - TextValue value = GetValue(op->value); - TextValue id = this->AllocTempVar(); - this->PrintIndent(); - stream_ << id << " = " << "RefCreate(" << op->value << ")"; - this->PrintEndInst("\n"); - return id; - } - - TextValue VisitExpr_(const RefReadNode* op) final { - TextValue ref = GetValue(op->ref); - TextValue id = this->AllocTempVar(); - this->PrintIndent(); - stream_ << id << " = " << "RefRead(" << ref << ")"; - this->PrintEndInst("\n"); - return id; - } - - TextValue VisitExpr_(const RefWriteNode* op) final { - TextValue ref = GetValue(op->ref); - TextValue value = GetValue(op->value); - TextValue id = this->AllocTempVar(); - this->PrintIndent(); - stream_ << id << " = " << "RefWrite(" << ref << ", " << value << ")"; - this->PrintEndInst("\n"); - return id; - } - - TextValue VisitExpr_(const MatchNode* op) final { - TextValue data = GetValue(op->data); - this->PrintIndent(); - TextValue id = this->AllocTempVar(); - stream_ << id << " = " << "Match " << data << " with"; - this->PrintEndInst("\n"); - for (const auto& c : op->clauses) { - this->PrintIndent(); - stream_ << GetValue(c->lhs) << " to " << GetValue(c->rhs); - this->PrintEndInst("\n"); - } - return id; - } - - TextValue VisitPattern_(const PatternConstructorNode* p) final { - TextValue ret(p->constructor->name_hint + "("); - for (const Pattern& pat : p->patterns) { - ret = ret + " " + GetValue(pat); - } - return ret + ")"; - } - - TextValue VisitPattern_(const PatternVarNode* pv) final { - return GetValue(pv->var); - } - - TextValue VisitExpr_(const ConstructorNode* n) final { - return TextValue(n->name_hint); - } - - /*! - * \brief Print the type to os - * \param type The type to be printed. - * \param os The output type. - */ - void PrintType(const Type& type, std::ostream& os) { // NOLINT(*) - this->VisitType(type, os); - } - //------------------------------------ - // Overload of Expr printing functions - //------------------------------------ - void VisitType_(const TensorTypeNode* node, std::ostream& os) final { // NOLINT(*) - // scalar type - if (node->shape.size() == 0) { - os << runtime::TVMType2String(Type2TVMType(node->dtype)); - return; - } - os << "Tensor[("; - for (size_t i = 0; i < node->shape.size(); ++i) { - this->PrintAttr(node->shape[i], os); - if (i + 1 != node->shape.size()) { - os << ", "; - } - } - // conform to python tuple format (1,) - if (node->shape.size() == 1) { - os << ","; - } - os << "), " << runtime::TVMType2String(Type2TVMType(node->dtype)) << "]"; - } - - void VisitType_(const TupleTypeNode* node, std::ostream& os) final { // NOLINT(*) - os << "Tuple["; - for (size_t i = 0; i < node->fields.size(); ++i) { - this->PrintType(node->fields[i], os); - if (i + 1 != node->fields.size()) { - os << ", "; - } - } - os << "]"; - } - - void VisitType_(const RefTypeNode* node, std::ostream& os) final { - VisitTypeDefault_(node, os); - } - - void VisitType_(const TypeCallNode* node, std::ostream& os) final { - os << node->func << "(" << node->args << ")"; - } - - void VisitType_(const GlobalTypeVarNode* node, std::ostream& os) final { - VisitTypeDefault_(node, os); - } - - void VisitType_(const TypeDataNode* node, std::ostream& os) final { - VisitTypeDefault_(node, os); - } - - void VisitTypeDefault_(const Node* node, std::ostream& os) final { // NOLINT(*) - // by default always print as meta-data - os << meta_.GetMetaNode(GetRef(node)); - } - - /*! - * \brief Print an attribute value to os. - * \param value The value to be printed. - * \param os The output type. - */ - void PrintAttr(const NodeRef& value, std::ostream& os) { // NOLINT(*) - if (value.defined()) { - this->VisitAttr(value, os); - } else { - os << "None"; - } - } - //------------------------------------ - // Overload of Attr printing functions - //------------------------------------ - void VisitAttr_(const ArrayNode* op, std::ostream& os) final { // NOLINT(*) - os << "["; - for (size_t i = 0; i < op->data.size(); ++i) { - this->PrintAttr(NodeRef(op->data[i]), os); - if (i + 1 != op->data.size()) { - os << ", "; - } - } - os << "]"; - } - void VisitAttrDefault_(const Node* op, std::ostream& os) final { // NOLINT(*) - os << meta_.GetMetaNode(GetRef(op)); - } - - void VisitAttr_(const ir::IntImm* op, std::ostream& os) final { // NOLINT(*) - this->PrintConstScalar(op->type, &(op->value), os); - } - - void VisitAttr_(const ir::UIntImm* op, std::ostream& os) final { // NOLINT(*) - this->PrintConstScalar(op->type, &(op->value), os); - } - - void VisitAttr_(const ir::FloatImm* op, std::ostream& os) final { // NOLINT(*) - this->PrintConstScalar(op->type, &(op->value), os); - } - - void VisitAttr_(const ir::StringImm* op, std::ostream& os) final { // NOLINT(*) - this->PrintString(op->value, os); - } - - protected: - /*! - * \brief Print attributes after call. - * \param op The operator to be called. - * \param attrs The attributes. - * \param os The output stream. - */ - void PrintCallAttrs(const Expr& op, const Attrs& attrs, std::ostream& os); // NOLINT(*) - - /*! - * \brief Print the a new scopr. - * \param body The body. - */ - void PrintScope(Expr body) { - stream_ << "{\n"; - int sid = this->BeginScope(); - this->PrintScopeBody(body); - this->EndScope(sid); - this->PrintIndent(); - stream_ << "}"; - } - /*! - * \brief Print the body of a new scope without {} - * - * This function will keep printing continuous sequence - * of let/if scope without introducing a new scope in the text. - * - * \param body The body. - */ - void PrintScopeBody(Expr body) { - if (const LetNode* let = body.as()) { - TextValue value = GetValue(let->value); - AllocVarName(let->var); - // let var = value; - this->PrintIndent(); - stream_ << "let "; - this->PrintVarDecl(let->var, stream_); - stream_ << " = " << value; - this->PrintEndInst("\n"); - this->PrintScopeBody(let->body); - } else if (const IfNode* ifnode = body.as()) { - TextValue cond = GetValue(ifnode->cond); - this->PrintIndent(); - stream_ << "if (" << cond << ") "; - this->PrintScope(ifnode->true_branch); - this->PrintIndent(); - stream_ << "else "; - this->PrintScope(ifnode->false_branch); - this->PrintEndInst("\n"); - } else { - TextValue value = GetValue(body); - this->PrintIndent(); - stream_ << value; - this->PrintEndInst("\n"); - } - } - - /*! - * \brief Internal function to print a function argument list and its body. - * \param prefix The prefix before argument list. - * \param fn The function to be printed. - */ - void PrintFuncInternal(std::string prefix, const Function& fn) { - // TODO(tqchen, M.K.) support generic function - // Possibly through meta-data - CHECK_EQ(fn->type_params.size(), 0U) - << "generic fn not yet supported"; - this->PrintIndent(); - stream_ << prefix << "("; - size_t decl_indent = prefix.length() + 1; - for (size_t i = 0; i < fn->params.size(); ++i) { - if (i != 0) { - this->PrintIndent(decl_indent); - } - AllocVarName(fn->params[i]); - this->PrintVarDecl(fn->params[i], stream_); - if (i + 1 != fn->params.size()) { - stream_ << ",\n"; - } - } - stream_ << ')'; - if (fn->ret_type.defined()) { - stream_ << '\n'; - this->PrintIndent(decl_indent); - stream_ << "-> "; - this->PrintType(fn->ret_type, stream_); - } - stream_ << ' '; - this->PrintScope(fn->body); - } - /*! - * \brief Print additional info about expr in comment. - * \param expr The expression. - */ - void PrintOptionalInfo(const Expr& expr) { - // additional information in comment. - if (annotate_ != nullptr) { - stream_ << " # " << annotate_(expr); - } else if (expr->checked_type_.defined()) { - stream_ << " # ty="; - this->PrintType(expr->checked_type(), stream_); - } - } - /*! - * \brief print var_name[:type] - * \param var The variable to be printed - * \param os The output stream - */ - void PrintVarDecl(const Var& var, std::ostream& os) { // NOLINT(*) - TextValue v = GetValue(var); - os << v; - if (var->type_annotation.defined()) { - os << ": "; - this->PrintType(var->type_annotation, os); - } - } - /*! - * \brief Get a constant scalar value. - * \param dtype The data type. - * \param data The pointer to the data. - * \tparam T the content data type holding the data. - */ - template - TextValue ConstScalar(DataType dtype, const T* data) { - std::ostringstream os; - PrintConstScalar(dtype, data, os); - return TextValue(os.str()); - } - /*! - * \brief special method to print out const scalar - * \param dtype The data type - * \param data The pointer to hold the data. - * \param os The output stream. - */ - template - void PrintConstScalar(DataType dtype, const T* data, std::ostream& os) { // NOLINT(*) - if (dtype == Int(32)) { - os << data[0]; - } else if (dtype == Float(32)) { - os << data[0] << 'f'; - } else if (dtype == Bool()) { - PrintBool(data[0] != 0, os); - } else { - os << dtype << "(" << data[0] << ")"; - } - } - /*! - * \brief Print constant bool value. - * \param value The value to be printed. - * \param os The output stream - */ - void PrintBool(bool value, std::ostream& os) { // NOLINT(*) - if (value) { - os << "True"; - } else { - os << "False"; - } - } - /*! - * \brief Print constant string. - * \param value The value to be printed. - * \param os The output stream - */ - void PrintString(const std::string& value, std::ostream& os) { // NOLINT(*) - // TODO(M.K.): add escape. - os << "\"" << value << "\""; - } - /*! - * \brief get a unique name with the corresponding prefix - * \param prefix The prefix of the name - * \return The returned name. - */ - std::string GetUniqueName(std::string prefix) { - auto it = name_alloc_map_.find(prefix); - if (it != name_alloc_map_.end()) { - while (true) { - std::ostringstream os; - os << prefix << (++it->second); - std::string name = os.str(); - if (name_alloc_map_.count(name) == 0) { - prefix = name; - break; - } - } - } - name_alloc_map_[prefix] = 0; - return prefix; - } - /*! - * \brief mark the beginning of a new scope - * \return The scope id. - */ - int BeginScope() { - int sid = static_cast(scope_valid_.size()); - scope_valid_.push_back(true); - indent_ += 2; - return sid; - } - /*! - * \brief mark the end of an old scope. - * \param scope_id The scope id to be ended. - */ - void EndScope(int scope_id) { - scope_valid_[scope_id] = false; - indent_ -= 2; - } - /*! - * \brief Print the indent to the stream. - * \param more_indent More indentation besides the current one. - */ - void PrintIndent(int64_t more_indent = 0) { - for (int i = 0; i < indent_ + more_indent; ++i) { - stream_ << ' '; - } - } - /*! - * \brief print end of the line. - */ - void PrintEndInst(const char* suffix) { - stream_ << suffix; - } - /*! - * \brief Allocate temporary value - * \return A new text value. - */ - TextValue AllocTempVar() { - std::ostringstream os; - os << '%' << temp_var_counter_++; - return TextValue(os.str()); - } - /*! - * \brief Allocate name to a variable. - * \param var The input variable. - * \return The corresponding name. - */ - TextValue AllocVarName(const Var& var) { - std::string name = var->name_hint(); - // always make sure first name is alpha - if (name.length() != 0 && !std::isalpha(name[0])) { - name = "%v" + name; - } else { - name = "%" + name; - } - TextValue val(GetUniqueName(name)); - // still print if ir is malformed, but show the error. - if (memo_.count(var)) { - memo_[var] = TextValue(val.name + "-malformed-ir"); - } - memo_[var] = val; - return val; - } - - private: - class AttrPrinter; - friend class AttrPrinter; - /*! \brief Whether to print meta data. */ - bool show_meta_data_; - /*! \brief additional comment function */ - runtime::TypedPackedFunc annotate_; - /*! \brief meta data context */ - TextMetaDataContextFoo meta_; - /*! \brief Check whether scope is still valid */ - std::vector scope_valid_; - /*! \brief The current indentation value */ - int indent_{0}; - /*! \brief name allocation map */ - std::unordered_map name_alloc_map_; - /*! \brief Map from expression to its text value */ - std::unordered_map memo_; - /*! \brief counter of temporary variable */ - int64_t temp_var_counter_{0}; - /*! \brief Output stream */ - std::ostringstream stream_; -}; - -/*! - * \brief Attribute printer which prints the attributes in the call. - */ -class TextPrinter::AttrPrinter: public AttrVisitor { - public: - AttrPrinter(std::ostream& stream, TextPrinter* parent) // NOLINT(*) - : stream_(stream), parent_(parent) {} - - void Visit(const char* key, double* value) final { - PrintSep(); - stream_ << key << "=" << value[0]; - } - void Visit(const char* key, int64_t* value) final { - PrintSep(); - stream_ << key << "=" << value[0]; - } - void Visit(const char* key, uint64_t* value) final { - PrintSep(); - stream_ << key << "=" << value[0]; - } - void Visit(const char* key, int* value) final { - PrintSep(); - stream_ << key << "=" << value[0]; - } - void Visit(const char* key, bool* value) final { - PrintSep(); - stream_ << key << "="; - parent_->PrintBool(value[0], stream_); - } - void Visit(const char* key, std::string* value) final { - PrintSep(); - stream_ << key << "="; - parent_->PrintString(value[0], stream_); - } - void Visit(const char* key, void** value) final { - LOG(FATAL) << "do not allow void as argument"; - } - void Visit(const char* key, DataType* value) final { - PrintSep(); - stream_ << key << "="; - parent_->PrintString(runtime::TVMType2String(Type2TVMType(value[0])), stream_); - } - void Visit(const char* key, NodeRef* value) final { - PrintSep(); - stream_ << key << "="; - parent_->PrintAttr(value[0], stream_); - } - void Visit(const char* key, runtime::NDArray* value) final { - LOG(FATAL) << "do not allow NDarray as argument"; - } - - private: - void PrintSep() { - stream_ << ", "; - } - std::ostream& stream_; // NOLINT(*) - TextPrinter* parent_; -}; - -void TextPrinter::PrintCallAttrs(const Expr& op, - const Attrs& attrs, - std::ostream& os) { // NOLINT(*) - if (!attrs.defined()) return; - if (const auto* op_node = op.as()) { - if (attrs->type_index() == op_node->attrs_type_index) { - AttrPrinter printer(os, this); - const_cast(attrs.operator->()) - ->VisitNonDefaultAttrs(&printer); - return; - } - } - os << ", " << meta_.GetMetaNode(attrs); -} - -std::string RelayPrint(const NodeRef& node, - bool show_meta_data, - runtime::TypedPackedFunc annotate) { - return TextPrinter(show_meta_data, annotate).Print(node); -} - -TVM_REGISTER_API("relay._expr.RelayPrint") -.set_body_typed)>(RelayPrint); - -} // namespace relay -} // namespace tvm diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index 90f7b3fdf47c..f53a6ba0cb4d 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -1,12 +1,17 @@ import tvm from tvm import relay from tvm.relay.ir_pass import alpha_equal -from tvm.relay._expr import anf_print, gnf_print import numpy as np from hypothesis import given, reject, settings from hypothesis.strategies import text, lists, integers, composite, recursive, deferred +def gnf_print(expr): + return expr.astext(gnf=True) + +def anf_print(expr): + return expr.astext(gnf=False) + exprs = deferred(lambda: constants() # | projections(exprs) | tuples(exprs)) @@ -116,4 +121,4 @@ def test_func_type(): print(anf_print(relay.fromtext(SEMVER+"let %x = { let %y = 2; ((%y + %y, %y * %y), 1) }; %x"))) print(relay.const([1,2,3]).astext()) print(gnf_print(relay.const([1,2,3]))) - print(anf_print(relay.const([1,2,3]))) \ No newline at end of file + print(anf_print(relay.const([1,2,3]))) From 1bb55bb03462acd3f1e5d337afc840e8b3d461c9 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Tue, 5 Mar 2019 20:59:35 -0800 Subject: [PATCH 43/73] pass text printer tests --- src/relay/ir/pretty_printer.cc | 10 ++++++++++ tests/python/relay/test_ir_text_printer.py | 4 ++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index e3dc71e54852..8c159f7f6b19 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -249,6 +249,13 @@ class PrettyPrinter : Doc printed_expr; if (meta) { printed_expr = meta_.GetMetaNode(GetRef(expr.get())); + } else if (GNF_ && gnf && expr.as()) { + // wrap GNFed let in brackets + printed_expr = Nil(); + Doc body = Nil(); + printed_expr << "{"; + printed_expr << Indent(2, body << "\n" << VisitExpr(expr)) << "\n"; + printed_expr << "}"; } else { printed_expr = VisitExpr(expr); } @@ -270,6 +277,9 @@ class PrettyPrinter : return memo_[expr]; } else { memo_[expr] = printed_expr; + if (GNF_ && expr.as()) { + printed_expr << PrintOptionalInfo(expr); + } return printed_expr; } } diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 21bd85a3eb37..a32b6c2b608a 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -33,8 +33,8 @@ def test_env(): text = env.astext() assert "def @myf" in text assert "def @myf" in str(env) - assert "%1 = add(%0, %0) # ty=float32" in text - assert "%1 = add(%0, %0) # ty=float32" in str(env) + assert "add(%0, %0) # ty=float32" in text + assert "add(%0, %0) # ty=float32" in str(env) show(env.astext(annotate=lambda x: str(x.checked_type.dtype))) show(text) From 69370d2ab04ddf92c1cf7a8742fe7c76afe1d870 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Wed, 6 Mar 2019 14:02:41 -0800 Subject: [PATCH 44/73] pretty printer lint --- src/relay/ir/pretty_printer.cc | 785 +++++++++++++++++---------------- 1 file changed, 399 insertions(+), 386 deletions(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 8c159f7f6b19..bd8332a1a538 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -4,8 +4,8 @@ * \brief Pretty printer for Relay programs * Supports ANF, GNF, and metadata. */ -#include "doc.h" #include +#include "doc.h" #include "type_functor.h" #include "../../lang/attr_functor.h" @@ -112,439 +112,446 @@ class PrettyPrinter : public ExprFunctor, public TypeFunctor, public AttrFunctor { - public: - explicit PrettyPrinter(bool GNF, bool show_meta_data, runtime::TypedPackedFunc annotate) : GNF_(GNF), show_meta_data_(show_meta_data), annotate_(annotate) {} - - /*! - * \brief Print additional info about expr in comment. - * \param expr The expression. - */ - Doc PrintOptionalInfo(const Expr& expr) { - Doc doc = Nil(); - // additional information in comment. - if (annotate_ != nullptr) { - return doc << " # " << annotate_(expr); - } else if (expr->checked_type_.defined()) { - doc << " # ty="; - return doc << Print(expr->checked_type()); - } else { - return Nil(); - } - } + public: + explicit PrettyPrinter(bool GNF, + bool show_meta_data, + runtime::TypedPackedFunc annotate) : GNF_(GNF), + show_meta_data_(show_meta_data), + annotate_(annotate) {} - // indent a new body - Doc PrintBody(const NodeRef& node, int indent = 2) { - Doc doc = Nil(); - Doc body = Nil(); - doc << "{"; - doc << Indent(indent, body << "\n" << PrintScope(node)) << "\n"; - doc << "}"; - return doc; + /*! + * \brief Print additional info about expr in comment. + * \param expr The expression. + */ + Doc PrintOptionalInfo(const Expr& expr) { + Doc doc = Nil(); + // additional information in comment. + if (annotate_ != nullptr) { + return doc << " # " << annotate_(expr); + } else if (expr->checked_type_.defined()) { + doc << " # ty="; + return doc << Print(expr->checked_type()); + } else { + return Nil(); } + } - // create a new scope by creating a new printer object. This allows temp var - // numbers to be reused and prevents hoisted vars from escaping too far - Doc PrintScope(const NodeRef& node) { - // print in a new scope - doc_stack_.push_back(Nil()); - // must print first so doc_stack_.back() reference doesn't become stale - Doc doc = Print(node, false); - doc = doc_stack_.back() << doc; - doc_stack_.pop_back(); - return doc; - } + // indent a new body + Doc PrintBody(const NodeRef& node, int indent = 2) { + Doc doc = Nil(); + Doc body = Nil(); + doc << "{"; + doc << Indent(indent, body << "\n" << PrintScope(node)) << "\n"; + doc << "}"; + return doc; + } - Doc PrintFinal(const NodeRef& node) { - Doc doc = Nil(); - doc << PrintScope(node); - if (!meta_.empty()) { - if (show_meta_data_) { - std::string meta_json = meta_.GetMetaSection(); - // append meta data in the end. - doc << "\n" << "/* meta data */" << "\n" << meta_json; - } else { - doc << "\n" << "// meta data omitted. you can use show_meta_data=True to include meta data"; - } - }; - return doc; - } + // create a new scope by creating a new printer object. This allows temp var + // numbers to be reused and prevents hoisted vars from escaping too far + Doc PrintScope(const NodeRef& node) { + // print in a new scope + doc_stack_.push_back(Nil()); + // must print first so doc_stack_.back() reference doesn't become stale + Doc doc = Print(node, false); + doc = doc_stack_.back() << doc; + doc_stack_.pop_back(); + return doc; + } - Doc PrintAttrs(const Attrs& attrs); - - // note: gnf flag is only one level deep - Doc Print(const NodeRef& node, bool gnf = true, bool meta = false) { - if (node.as_derived()) { - return PrintExpr(Downcast(node), gnf, meta); - } else if (node.as_derived()) { - return PrintType(Downcast(node), meta); - } else if (node.as_derived()) { - return PrintMod(Downcast(node)); - } else { assert(false); } + Doc PrintFinal(const NodeRef& node) { + Doc doc = Nil(); + doc << PrintScope(node); + if (!meta_.empty()) { + if (show_meta_data_) { + std::string meta_json = meta_.GetMetaSection(); + // append meta data in the end. + doc << "\n" << "/* meta data */" << "\n" << meta_json; + } else { + doc << "\n" + << "// meta data omitted. you can use show_meta_data=True to include meta data"; + } } + return doc; + } - Doc TempVar(int n) { - Doc doc = Nil(); - return doc << "%" << n; - } + Doc PrintAttrs(const Attrs& attrs); + + // note: gnf flag is only one level deep + Doc Print(const NodeRef& node, bool gnf = true, bool meta = false) { + if (node.as_derived()) { + return PrintExpr(Downcast(node), gnf, meta); + } else if (node.as_derived()) { + return PrintType(Downcast(node), meta); + } else if (node.as_derived()) { + return PrintMod(Downcast(node)); + } else { assert(false); } + } - Doc AllocTemp() { - return TempVar(temp_var_counter_++); - } + Doc TempVar(int n) { + Doc doc = Nil(); + return doc << "%" << n; + } - /*! - * \brief get a unique name with the corresponding prefix - * \param prefix The prefix of the name - * \return The returned name. - */ - Doc GetUniqueName(std::string prefix) { - auto it = name_alloc_map_.find(prefix); - if (it != name_alloc_map_.end()) { - while (true) { - std::ostringstream os; - os << prefix << (++it->second); - std::string name = os.str(); - if (name_alloc_map_.count(name) == 0) { - prefix = name; - break; - } + Doc AllocTemp() { + return TempVar(temp_var_counter_++); + } + + /*! + * \brief get a unique name with the corresponding prefix + * \param prefix The prefix of the name + * \return The returned name. + */ + Doc GetUniqueName(std::string prefix) { + auto it = name_alloc_map_.find(prefix); + if (it != name_alloc_map_.end()) { + while (true) { + std::ostringstream os; + os << prefix << (++it->second); + std::string name = os.str(); + if (name_alloc_map_.count(name) == 0) { + prefix = name; + break; } } - name_alloc_map_[prefix] = 0; - return Text(prefix); } + name_alloc_map_[prefix] = 0; + return Text(prefix); + } - /*! - * \brief Allocate name to a variable. - * \param var The input variable. - * \return The corresponding name. - */ - Doc AllocVar(const Var& var) { - std::string name = var->name_hint(); - // always make sure first name is alpha - if (name.length() != 0 && !std::isalpha(name[0])) { - name = "v" + name; - } - Doc val = GetUniqueName("%" + name); - // still print if ir is malformed, but show the error. - if (memo_.count(var)) { - val << Text("-malformed-ir"); - } - memo_[var] = val; - if (var->type_annotation.defined()) { - val << ": " << Print(var->type_annotation); - } - return val; - } + /*! + * \brief Allocate name to a variable. + * \param var The input variable. + * \return The corresponding name. + */ + Doc AllocVar(const Var& var) { + std::string name = var->name_hint(); + // always make sure first name is alpha + if (name.length() != 0 && !std::isalpha(name[0])) { + name = "v" + name; + } + Doc val = GetUniqueName("%" + name); + // still print if ir is malformed, but show the error. + if (memo_.count(var)) { + val << Text("-malformed-ir"); + } + memo_[var] = val; + if (var->type_annotation.defined()) { + val << ": " << Print(var->type_annotation); + } + return val; + } - //------------------------------------ - // Overload of Expr printing functions - //------------------------------------ - Doc PrintExpr(const Expr& expr, bool gnf, bool meta) { - // Exploit memoization to print GNF. - // The first time we visit an expression, we need to allocate a temp var - // for it. Every subsequent time we can just use its assigned variable. - // This works since hashing uses pointer equality. - auto it = memo_.find(expr); - if (it != memo_.end()) return it->second; - Doc printed_expr; - if (meta) { - printed_expr = meta_.GetMetaNode(GetRef(expr.get())); - } else if (GNF_ && gnf && expr.as()) { - // wrap GNFed let in brackets - printed_expr = Nil(); - Doc body = Nil(); - printed_expr << "{"; - printed_expr << Indent(2, body << "\n" << VisitExpr(expr)) << "\n"; - printed_expr << "}"; - } else { - printed_expr = VisitExpr(expr); + //------------------------------------ + // Overload of Expr printing functions + //------------------------------------ + Doc PrintExpr(const Expr& expr, bool gnf, bool meta) { + // Exploit memoization to print GNF. + // The first time we visit an expression, we need to allocate a temp var + // for it. Every subsequent time we can just use its assigned variable. + // This works since hashing uses pointer equality. + auto it = memo_.find(expr); + if (it != memo_.end()) return it->second; + Doc printed_expr; + if (meta) { + printed_expr = meta_.GetMetaNode(GetRef(expr.get())); + } else if (GNF_ && gnf && expr.as()) { + // wrap GNFed let in brackets + printed_expr = Nil(); + Doc body = Nil(); + printed_expr << "{"; + printed_expr << Indent(2, body << "\n" << VisitExpr(expr)) << "\n"; + printed_expr << "}"; + } else { + printed_expr = VisitExpr(expr); + } + // we choose to inline some nodes + if (GNF_ && gnf && + !expr.as() && !expr.as() && + !expr.as() && !expr.as()) { + Doc temp_var = AllocTemp(); + memo_[expr] = temp_var; + doc_stack_.back() << temp_var << " = " << printed_expr; + if (expr.as()) { + doc_stack_.back() << PrintOptionalInfo(expr); } - // we choose to inline some nodes - if (GNF_ && gnf && !expr.as() && !expr.as() && !expr.as() && !expr.as()) { - Doc temp_var = AllocTemp(); - memo_[expr] = temp_var; - doc_stack_.back() << temp_var << " = " << printed_expr; - if (expr.as()) { - doc_stack_.back() << PrintOptionalInfo(expr); - } - doc_stack_.back() << "\n"; - return temp_var; - } else if (expr.as()) { - // This is our first time visiting the var and we hit the VarNode case - // in the visitor. Thus the variable is free. - doc_stack_.back() << "free_var " << printed_expr << "\n"; - // Memoization is done in AllocVar. - return memo_[expr]; - } else { - memo_[expr] = printed_expr; - if (GNF_ && expr.as()) { - printed_expr << PrintOptionalInfo(expr); - } - return printed_expr; + doc_stack_.back() << "\n"; + return temp_var; + } else if (expr.as()) { + // This is our first time visiting the var and we hit the VarNode case + // in the visitor. Thus the variable is free. + doc_stack_.back() << "free_var " << printed_expr << "\n"; + // Memoization is done in AllocVar. + return memo_[expr]; + } else { + memo_[expr] = printed_expr; + if (GNF_ && expr.as()) { + printed_expr << PrintOptionalInfo(expr); } + return printed_expr; } + } - // Should only be triggered when op is a free variable being visited for the - // first time. - Doc VisitExpr_(const VarNode* op) final { - return AllocVar(GetRef(op)); - } + // Should only be triggered when op is a free variable being visited for the + // first time. + Doc VisitExpr_(const VarNode* op) final { + return AllocVar(GetRef(op)); + } - Doc VisitExpr_(const ConstantNode* op) final { - // Print out simple scalars directly. - if (op->is_scalar()) { - std::ostringstream os; - DataType dtype = TVMType2Type(op->data->dtype); - CHECK_EQ(op->data->ctx.device_type, kDLCPU); - if (dtype == Int(32)) { - return PrintConstScalar(dtype, static_cast(op->data->data)); - } else if (dtype == Int(64)) { - return PrintConstScalar(dtype, static_cast(op->data->data)); - } else if (dtype == Float(32)) { - return PrintConstScalar(dtype, static_cast(op->data->data)); - } else if (dtype == Float(64)) { - return PrintConstScalar(dtype, static_cast(op->data->data)); - } else if (dtype == Bool()) { - return PrintConstScalar(dtype, static_cast(op->data->data)); - } + Doc VisitExpr_(const ConstantNode* op) final { + // Print out simple scalars directly. + if (op->is_scalar()) { + std::ostringstream os; + DataType dtype = TVMType2Type(op->data->dtype); + CHECK_EQ(op->data->ctx.device_type, kDLCPU); + if (dtype == Int(32)) { + return PrintConstScalar(dtype, static_cast(op->data->data)); + } else if (dtype == Int(64)) { + return PrintConstScalar(dtype, static_cast(op->data->data)); + } else if (dtype == Float(32)) { + return PrintConstScalar(dtype, static_cast(op->data->data)); + } else if (dtype == Float(64)) { + return PrintConstScalar(dtype, static_cast(op->data->data)); + } else if (dtype == Bool()) { + return PrintConstScalar(dtype, static_cast(op->data->data)); } - // default fall-back, record it as meta node. - Doc doc = Nil(); - return doc << Print(GetRef(op), true, true) - << PrintOptionalInfo(GetRef(op)); } + // default fall-back, record it as meta node. + Doc doc = Nil(); + return doc << Print(GetRef(op), true, true) + << PrintOptionalInfo(GetRef(op)); + } - Doc VisitExpr_(const TupleNode* op) final { - std::vector fields; - for (Expr field : op->fields) { - fields.push_back(Print(field)); - } - Doc doc = Nil(); - return doc << "(" << PrintVec(fields) << ")"; + Doc VisitExpr_(const TupleNode* op) final { + std::vector fields; + for (Expr field : op->fields) { + fields.push_back(Print(field)); } + Doc doc = Nil(); + return doc << "(" << PrintVec(fields) << ")"; + } - Doc VisitExpr_(const TupleGetItemNode* op) final { - Doc doc = Nil(); - return doc << Print(op->tuple) << "." << op->index; - } + Doc VisitExpr_(const TupleGetItemNode* op) final { + Doc doc = Nil(); + return doc << Print(op->tuple) << "." << op->index; + } - Doc VisitExpr_(const IfNode* op) final { - Doc doc = Nil(); - doc << "if (" << Print(op->cond) << ") "; - doc << PrintBody(op->true_branch); - doc << " else "; - doc << PrintBody(op->false_branch); - return doc; - } + Doc VisitExpr_(const IfNode* op) final { + Doc doc = Nil(); + doc << "if (" << Print(op->cond) << ") "; + doc << PrintBody(op->true_branch); + doc << " else "; + doc << PrintBody(op->false_branch); + return doc; + } + + Doc VisitExpr_(const LetNode* op) final { + Doc doc = Nil(); + doc << "let " << AllocVar(op->var) << " = "; + if (op->value.as()) { + doc << PrintBody(op->value); + } else { + // we use ANF mode for the first level of the value position so the + // final expression isn't hoisted or added to the doc stream + doc << Print(op->value, false); + } + doc << ";" << "\n"; + // we use a nested scope here so GNF hoisting doesn't escape too far + // and so consecutive lets don't get hoisted + doc << PrintScope(op->body); + return doc; + } - Doc VisitExpr_(const LetNode* op) final { + Doc PrintFunc(const Doc& prefix, const Function& fn) { + // TODO(tqchen, M.K.) support generic function + // Possibly through meta data + CHECK_EQ(fn->type_params.size(), 0U) + << "generic fn not yet supported"; Doc doc = Nil(); - doc << "let " << AllocVar(op->var) << " = "; - if (op->value.as()) { - doc << PrintBody(op->value); - } else { - // we use ANF mode for the first level of the value position so the - // final expression isn't hoisted or added to the doc stream - doc << Print(op->value, false); + doc << prefix << "("; + std::vector params; + for (Var param : fn->params) { + params.push_back(AllocVar(param)); + } + doc << PrintVec(params) << PrintAttrs(fn->attrs); + doc << ") "; + if (fn->ret_type.defined()) { + doc << "-> " << Print(fn->ret_type) << " "; } - doc << ";" << "\n"; - // we use a nested scope here so GNF hoisting doesn't escape too far - // and so consecutive lets don't get hoisted - doc << PrintScope(op->body); + doc << PrintBody(fn->body); return doc; - } - - Doc PrintFunc(const Doc& prefix, const Function& fn) { - // TODO(tqchen, M.K.) support generic function - // Possibly through meta data - CHECK_EQ(fn->type_params.size(), 0U) - << "generic fn not yet supported"; - Doc doc = Nil(); - doc << prefix << "("; - std::vector params; - for (Var param : fn->params) { - params.push_back(AllocVar(param)); - } - doc << PrintVec(params) << PrintAttrs(fn->attrs); - doc << ") "; - if (fn->ret_type.defined()) { - doc << "-> " << Print(fn->ret_type) << " "; - } - doc << PrintBody(fn->body); - return doc; - } + } - Doc PrintMod(const Module& mod) { - Doc doc = Nil(); - int counter = 0; - for (const auto& kv : mod->functions) { - std::ostringstream os; - if (counter++ != 0) { - doc << "\n"; - } - os << "def @" << kv.first->name_hint; - doc << PrintFunc(Text(os.str()), kv.second); + Doc PrintMod(const Module& mod) { + Doc doc = Nil(); + int counter = 0; + for (const auto& kv : mod->functions) { + std::ostringstream os; + if (counter++ != 0) { doc << "\n"; } - return doc; + os << "def @" << kv.first->name_hint; + doc << PrintFunc(Text(os.str()), kv.second); + doc << "\n"; } + return doc; + } - Doc VisitExpr_(const FunctionNode* op) final { - return PrintFunc(Text("fn "), GetRef(op)); - } + Doc VisitExpr_(const FunctionNode* op) final { + return PrintFunc(Text("fn "), GetRef(op)); + } - Doc VisitExpr_(const GlobalVarNode* op) final { - return Text('@' + op->name_hint); - } + Doc VisitExpr_(const GlobalVarNode* op) final { + return Text('@' + op->name_hint); + } - Doc VisitExpr_(const OpNode* op) final { - return Text(op->name); - } + Doc VisitExpr_(const OpNode* op) final { + return Text(op->name); + } - Doc VisitExpr_(const CallNode* op) final { - Doc doc = Nil(); - doc << Print(op->op); - std::vector args; - for (Expr arg : op->args) { - args.push_back(Print(arg)); - } - return doc << "(" << PrintVec(args) << PrintAttrs(op->attrs) << ")"; + Doc VisitExpr_(const CallNode* op) final { + Doc doc = Nil(); + doc << Print(op->op); + std::vector args; + for (Expr arg : op->args) { + args.push_back(Print(arg)); } + return doc << "(" << PrintVec(args) << PrintAttrs(op->attrs) << ")"; + } - //------------------------------------ - // Overload of Type printing functions - //------------------------------------ - Doc PrintType(const Type& type, bool meta) { - auto it = memo_type_.find(type); - if (it != memo_type_.end()) return it->second; - Doc printed_type; - if (meta) { - printed_type = meta_.GetMetaNode(GetRef(type.get())); - } else { - printed_type = VisitType(type); - } - memo_type_[type] = printed_type; - return printed_type; - } + //------------------------------------ + // Overload of Type printing functions + //------------------------------------ + Doc PrintType(const Type& type, bool meta) { + auto it = memo_type_.find(type); + if (it != memo_type_.end()) return it->second; + Doc printed_type; + if (meta) { + printed_type = meta_.GetMetaNode(GetRef(type.get())); + } else { + printed_type = VisitType(type); + } + memo_type_[type] = printed_type; + return printed_type; + } - Doc VisitTypeDefault_(const Node* node) final { // NOLINT(*) - // by default always print as meta data - return Print(GetRef(node), true, true); - } + Doc VisitTypeDefault_(const Node* node) final { // NOLINT(*) + // by default always print as meta data + return Print(GetRef(node), true, true); + } - Doc VisitType_(const TensorTypeNode* node) final { // NOLINT(*) - // scalar type - if (node->shape.size() == 0) { - return PrintDType(node->dtype); - } - Doc doc = Nil(); - doc << "Tensor[("; - std::vector shapes; - for (NodeRef shape : node->shape) { - shapes.push_back(PrintAttr(shape)); - } - doc << PrintVec(shapes); - // conform to python tuple format (1,) - if (node->shape.size() == 1) { - doc << ","; - } - return doc << "), " << PrintDType(node->dtype) << "]"; + Doc VisitType_(const TensorTypeNode* node) final { // NOLINT(*) + // scalar type + if (node->shape.size() == 0) { + return PrintDType(node->dtype); } + Doc doc = Nil(); + doc << "Tensor[("; + std::vector shapes; + for (NodeRef shape : node->shape) { + shapes.push_back(PrintAttr(shape)); + } + doc << PrintVec(shapes); + // conform to python tuple format (1,) + if (node->shape.size() == 1) { + doc << ","; + } + return doc << "), " << PrintDType(node->dtype) << "]"; + } - Doc VisitType_(const TupleTypeNode* node) final { - std::vector fields; - for (Type field : node->fields) { - fields.push_back(Print(field)); - } - Doc doc = Nil(); - doc << "(" << PrintVec(fields); - // conform to python tuple format (1,) - if (node->fields.size() == 1) { - doc << ","; - } - return doc << ")"; + Doc VisitType_(const TupleTypeNode* node) final { + std::vector fields; + for (Type field : node->fields) { + fields.push_back(Print(field)); + } + Doc doc = Nil(); + doc << "(" << PrintVec(fields); + // conform to python tuple format (1,) + if (node->fields.size() == 1) { + doc << ","; } + return doc << ")"; + } - Doc VisitType_(const FuncTypeNode* node) final { - Doc doc = Nil(); - std::vector arg_types; - for (Type arg_type : node->arg_types) { - arg_types.push_back(Print(arg_type)); - } - return doc << "fn (" << PrintVec(arg_types) << ") -> " << Print(node->ret_type); + Doc VisitType_(const FuncTypeNode* node) final { + Doc doc = Nil(); + std::vector arg_types; + for (Type arg_type : node->arg_types) { + arg_types.push_back(Print(arg_type)); } + return doc << "fn (" << PrintVec(arg_types) << ") -> " << Print(node->ret_type); + } - //------------------------------------ - // Overload of Attr printing functions - //------------------------------------ - - Doc PrintAttr(const NodeRef& value, bool meta = false) { // NOLINT(*) - if (value.defined()) { - Doc printed_attr; - if (meta) { - printed_attr = meta_.GetMetaNode(value); - } else { - printed_attr = VisitAttr(value); - } - return printed_attr; + //------------------------------------ + // Overload of Attr printing functions + //------------------------------------ + + Doc PrintAttr(const NodeRef& value, bool meta = false) { // NOLINT(*) + if (value.defined()) { + Doc printed_attr; + if (meta) { + printed_attr = meta_.GetMetaNode(value); } else { - return Text("None"); + printed_attr = VisitAttr(value); } + return printed_attr; + } else { + return Text("None"); } + } - Doc VisitAttrDefault_(const Node* op) final { // NOLINT(*) - return PrintAttr(GetRef(op), true); - } + Doc VisitAttrDefault_(const Node* op) final { // NOLINT(*) + return PrintAttr(GetRef(op), true); + } - Doc VisitAttr_(const ArrayNode* op) final { // NOLINT(*) - Doc doc = Nil(); - doc << "["; - std::vector arr_vals; - for (NodePtr val : op->data) { - arr_vals.push_back(PrintAttr(NodeRef(val))); - } - doc << PrintVec(arr_vals); - doc << "]"; - return doc; - } + Doc VisitAttr_(const ArrayNode* op) final { // NOLINT(*) + Doc doc = Nil(); + doc << "["; + std::vector arr_vals; + for (NodePtr val : op->data) { + arr_vals.push_back(PrintAttr(NodeRef(val))); + } + doc << PrintVec(arr_vals); + doc << "]"; + return doc; + } - Doc VisitAttr_(const ir::IntImm* op) final { // NOLINT(*) - return PrintConstScalar(op->type, &(op->value)); - } + Doc VisitAttr_(const ir::IntImm* op) final { // NOLINT(*) + return PrintConstScalar(op->type, &(op->value)); + } - Doc VisitAttr_(const ir::UIntImm* op) final { // NOLINT(*) - return PrintConstScalar(op->type, &(op->value)); - } + Doc VisitAttr_(const ir::UIntImm* op) final { // NOLINT(*) + return PrintConstScalar(op->type, &(op->value)); + } - Doc VisitAttr_(const ir::FloatImm* op) final { // NOLINT(*) - return PrintConstScalar(op->type, &(op->value)); - } + Doc VisitAttr_(const ir::FloatImm* op) final { // NOLINT(*) + return PrintConstScalar(op->type, &(op->value)); + } - Doc VisitAttr_(const ir::StringImm* op) final { // NOLINT(*) - return PrintString(op->value); - } + Doc VisitAttr_(const ir::StringImm* op) final { // NOLINT(*) + return PrintString(op->value); + } - private: - /*! \brief Whether to use GNF. */ - bool GNF_; - /*! \brief Whether to print meta data. */ - bool show_meta_data_; - /*! \brief additional comment function */ - runtime::TypedPackedFunc annotate_; - /*! \brief Stack of docs to implement scoped GNFing. */ - std::vector doc_stack_{}; - /*! \brief Map from Expr to Doc */ - std::unordered_map memo_; - /*! \brief Map from Type to Doc */ - std::unordered_map memo_type_; - /*! \brief name allocation map */ - std::unordered_map name_alloc_map_; - /*! \brief meta data context */ - TextMetaDataContext meta_; - /*! \brief counter of temporary variable */ - size_t temp_var_counter_{0}; - class AttrPrinter; - friend class AttrPrinter; + private: + /*! \brief Whether to use GNF. */ + bool GNF_; + /*! \brief Whether to print meta data. */ + bool show_meta_data_; + /*! \brief additional comment function */ + runtime::TypedPackedFunc annotate_; + /*! \brief Stack of docs to implement scoped GNFing. */ + std::vector doc_stack_{}; + /*! \brief Map from Expr to Doc */ + std::unordered_map memo_; + /*! \brief Map from Type to Doc */ + std::unordered_map memo_type_; + /*! \brief name allocation map */ + std::unordered_map name_alloc_map_; + /*! \brief meta data context */ + TextMetaDataContext meta_; + /*! \brief counter of temporary variable */ + size_t temp_var_counter_{0}; + class AttrPrinter; + friend class AttrPrinter; }; /*! @@ -552,7 +559,7 @@ class PrettyPrinter : */ class PrettyPrinter::AttrPrinter : public AttrVisitor { public: - AttrPrinter(Doc& doc_, PrettyPrinter* parent_) : doc_(doc_), parent_(parent_) {} + AttrPrinter(Doc& doc, PrettyPrinter* parent) : doc_(doc), parent_(parent) {} template Doc PrintKV(const char* key, const T& value) { @@ -591,13 +598,13 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor { LOG(FATAL) << "do not allow NDarray as argument"; } - private: - Doc& doc_; - PrettyPrinter* parent_; + private: + Doc& doc_; + PrettyPrinter* parent_; }; Doc PrettyPrinter::PrintAttrs(const Attrs& attrs) { // NOLINT(*) - // TODO: fallback meta? + // TODO(jmp): fallback meta? if (!attrs.defined()) return Nil(); Doc doc = Nil(); AttrPrinter printer(doc, this); @@ -605,14 +612,20 @@ Doc PrettyPrinter::PrintAttrs(const Attrs& attrs) { // NOLINT(*) return doc; } -std::string RelayPrint(const NodeRef& node, bool show_meta_data, runtime::TypedPackedFunc annotate, bool gnf) { +std::string RelayPrint(const NodeRef& node, + bool show_meta_data, + runtime::TypedPackedFunc annotate, + bool gnf) { Doc doc = Nil(); doc << "v0.0.1" << "\n" << PrettyPrinter(gnf, show_meta_data, annotate).PrintFinal(node) << "\n"; return Layout(doc); } TVM_REGISTER_API("relay._expr.RelayPrint") -.set_body_typed, bool)>(RelayPrint); +.set_body_typed, + bool)>(RelayPrint); -} // relay -} // tvm +} // namespace relay +} // namespace tvm From 589f778756d6b5c92830f7962febdd4b6da75d22 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Wed, 6 Mar 2019 14:07:11 -0800 Subject: [PATCH 45/73] more linting --- src/relay/ir/doc.cc | 5 +++-- src/relay/ir/doc.h | 8 +++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc index f72a0e95b193..5b4d27022f4f 100644 --- a/src/relay/ir/doc.cc +++ b/src/relay/ir/doc.cc @@ -4,6 +4,7 @@ * \brief Doc ADT used for pretty printing. * Based on Section 1 of https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf. */ +#include #include "doc.h" namespace tvm { @@ -134,5 +135,5 @@ Doc PrintString(const std::string& value) { // NOLINT(*) return doc << "\"" << value << "\""; } -} // relay -} // tvm +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/doc.h b/src/relay/ir/doc.h index 9140459a31c8..9f58fa1098a1 100644 --- a/src/relay/ir/doc.h +++ b/src/relay/ir/doc.h @@ -4,10 +4,12 @@ * \brief Doc ADT used for pretty printing. * Based on Section 1 of https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf. */ -#ifndef TVM_RELAY_DOC_H_ -#define TVM_RELAY_DOC_H_ +#ifndef TVM_RELAY_IR_DOC_H_ +#define TVM_RELAY_IR_DOC_H_ +#include #include +#include namespace tvm { namespace relay { @@ -96,4 +98,4 @@ Doc PrintConstScalar(DataType dtype, const T* data) { } // namespace relay } // namespace tvm -#endif // TVM_RELAY_DOC_H_ +#endif // TVM_RELAY_IR_DOC_H_ From f27c130196714a9c2c602495ab3b0bd581b272db Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Wed, 6 Mar 2019 14:13:03 -0800 Subject: [PATCH 46/73] fix typo --- src/relay/ir/doc.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/ir/doc.h b/src/relay/ir/doc.h index 9f58fa1098a1..e5ea46723d06 100644 --- a/src/relay/ir/doc.h +++ b/src/relay/ir/doc.h @@ -71,7 +71,7 @@ Doc Indent(int indent, const Doc& doc); // convert doc to a string std::string Layout(const Doc& doc); // render vectors of docs with a separator. e.g. [1, 2, 3], f -> 1f2f3 -Doc PrintVec(const std::vector& arr, const Doc& sep = Text(", ")); +Doc PrintVec(const std::vector& vec, const Doc& sep = Text(", ")); // Print constant bool value. Doc PrintBool(bool value); Doc PrintDType(DataType dtype); From 3586edf5ef38d014a06fe9afbb3824ebc843b4a5 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Wed, 6 Mar 2019 14:14:05 -0800 Subject: [PATCH 47/73] lint --- src/relay/ir/doc.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/ir/doc.h b/src/relay/ir/doc.h index e5ea46723d06..68a50e031977 100644 --- a/src/relay/ir/doc.h +++ b/src/relay/ir/doc.h @@ -7,8 +7,8 @@ #ifndef TVM_RELAY_IR_DOC_H_ #define TVM_RELAY_IR_DOC_H_ -#include #include +#include #include namespace tvm { From 84b9c13e1ebde302d6444025d2437f7c2445b064 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Wed, 6 Mar 2019 17:00:39 -0800 Subject: [PATCH 48/73] add default case --- src/relay/ir/pretty_printer.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index bd8332a1a538..9283f1b2d0e8 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -184,7 +184,10 @@ class PrettyPrinter : return PrintType(Downcast(node), meta); } else if (node.as_derived()) { return PrintMod(Downcast(node)); - } else { assert(false); } + } else { + Doc doc = Nil(); + return doc << node; + } } Doc TempVar(int n) { From e6e2d949abccde16bc7c462873fe26bd4e7f614b Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Wed, 6 Mar 2019 17:19:06 -0800 Subject: [PATCH 49/73] remove trailing whitespace --- src/relay/ir/pretty_printer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 9283f1b2d0e8..9998e449bc3d 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -184,7 +184,7 @@ class PrettyPrinter : return PrintType(Downcast(node), meta); } else if (node.as_derived()) { return PrintMod(Downcast(node)); - } else { + } else { Doc doc = Nil(); return doc << node; } From 7fa7cff5239d841c17eba5809acec5a39a35ea4c Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Wed, 6 Mar 2019 17:59:29 -0800 Subject: [PATCH 50/73] trigger ci --- include/tvm/relay/expr.h | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index ad6c328b9d91..4d1bae3a42c1 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -552,11 +552,10 @@ inline const TTypeNode* ExprNode::type_as() const { * \param gnf Whether to print in GNF. * \return The text representation. */ -std::string RelayPrint( - const NodeRef& node, - bool show_meta_data = true, - runtime::TypedPackedFunc annotate = nullptr, - bool gnf = true); +std::string RelayPrint(const NodeRef& node, + bool show_meta_data = true, + runtime::TypedPackedFunc annotate = nullptr, + bool gnf = true); } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_H_ From cbd7c89c0056e9c2de02e6f7d8ab72bc8f6d76f6 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Wed, 6 Mar 2019 20:06:12 -0800 Subject: [PATCH 51/73] fix some tests, add meta fallback, add hypothesis to python packages --- .../install/ubuntu_install_python_package.sh | 4 ++-- src/relay/ir/pretty_printer.cc | 21 ++++++++++++------- tests/python/relay/test_op_level1.py | 4 ++-- tests/python/relay/test_op_level4.py | 2 +- tests/python/relay/test_type_infer.py | 14 ++++++------- 5 files changed, 25 insertions(+), 20 deletions(-) diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh index 200fe6e47781..7530a6ba5ee2 100755 --- a/docker/install/ubuntu_install_python_package.sh +++ b/docker/install/ubuntu_install_python_package.sh @@ -5,5 +5,5 @@ set -u set -o pipefail # install libraries for python package on ubuntu -pip2 install nose pylint==1.9.4 six numpy nose-timer cython decorator scipy tornado typing antlr4-python2-runtime attrs -pip3 install nose pylint==1.9.4 six numpy nose-timer cython decorator scipy tornado typed_ast pytest mypy orderedset antlr4-python3-runtime attrs +pip2 install nose pylint==1.9.4 six numpy nose-timer cython decorator scipy tornado typing antlr4-python2-runtime attrs hypothesis +pip3 install nose pylint==1.9.4 six numpy nose-timer cython decorator scipy tornado typed_ast pytest mypy orderedset antlr4-python3-runtime attrs hypothesis diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 9998e449bc3d..b3232bc122ba 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -174,7 +174,7 @@ class PrettyPrinter : return doc; } - Doc PrintAttrs(const Attrs& attrs); + Doc PrintAttrs(const Attrs& attrs, const Expr& op); // note: gnf flag is only one level deep Doc Print(const NodeRef& node, bool gnf = true, bool meta = false) { @@ -375,7 +375,7 @@ class PrettyPrinter : for (Var param : fn->params) { params.push_back(AllocVar(param)); } - doc << PrintVec(params) << PrintAttrs(fn->attrs); + doc << PrintVec(params) << PrintAttrs(fn->attrs, fn); doc << ") "; if (fn->ret_type.defined()) { doc << "-> " << Print(fn->ret_type) << " "; @@ -418,7 +418,7 @@ class PrettyPrinter : for (Expr arg : op->args) { args.push_back(Print(arg)); } - return doc << "(" << PrintVec(args) << PrintAttrs(op->attrs) << ")"; + return doc << "(" << PrintVec(args) << PrintAttrs(op->attrs, GetRef(op)) << ")"; } //------------------------------------ @@ -606,13 +606,18 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor { PrettyPrinter* parent_; }; -Doc PrettyPrinter::PrintAttrs(const Attrs& attrs) { // NOLINT(*) - // TODO(jmp): fallback meta? +Doc PrettyPrinter::PrintAttrs(const Attrs& attrs, const Expr& op) { // NOLINT(*) if (!attrs.defined()) return Nil(); Doc doc = Nil(); - AttrPrinter printer(doc, this); - const_cast(attrs.operator->())->VisitNonDefaultAttrs(&printer); - return doc; + const auto* op_node = op.as(); + if (op_node && (attrs->type_index() != op_node->attrs_type_index)) { + // fallback + return doc << ", " << meta_.GetMetaNode(attrs); + } else { + AttrPrinter printer(doc, this); + const_cast(attrs.operator->())->VisitNonDefaultAttrs(&printer); + return doc; + } } std::string RelayPrint(const NodeRef& node, diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index b954e42bf1ab..ad981e6b904b 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -22,7 +22,7 @@ def check_single_op(opfunc, ref): x = relay.var("x", tp) y = opfunc(x) # test printer - assert ("%0 = {}(%x)".format(y.op.name)) in y.astext() + assert ("{}(%x)".format(y.op.name)) in y.astext() # test type inference assert relay.ir_pass.infer_type(y).checked_type == tp @@ -62,7 +62,7 @@ def check_binary_op(opfunc, ref): y = relay.var("y", t2) z = opfunc(x, y) # test printer - assert ("%0 = {}(%x, %y)".format(z.op.name)) in z.astext() + assert ("{}(%x, %y)".format(z.op.name)) in z.astext() assert relay.ir_pass.infer_type(z).checked_type == t1 if ref is not None: diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index ae7fe320940a..c876309b7383 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -13,7 +13,7 @@ def check_binary_op(opfunc, ref): y = relay.var("y", t2) z = opfunc(x, y) # test printer - assert ("%0 = {}(%x, %y)".format(z.op.name)) in z.astext() + assert ("{}(%x, %y)".format(z.op.name)) in z.astext() assert relay.ir_pass.infer_type(z).checked_type == t1 if ref is not None: diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 8c8e7dfd1fcc..d7705c683c5d 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -28,7 +28,7 @@ def initialize_box_adt(mod): def test_monomorphic_let(): - "Program: let x = 1; return x" + "Program: let x = 1; x" sb = relay.ScopeBuilder() x = sb.let('x', relay.const(1.0, "float64")) sb.ret(x) @@ -48,7 +48,7 @@ def test_add_broadcast_op(): """ Program: fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] { - return x + y; + x + y } """ x = relay.var('x', shape=(10, 4)) @@ -67,7 +67,7 @@ def test_dual_op(): fn (x : Tensor[f32, (10, 10)]) { let t1 = log(x); let t2 = add(t1, x); - return t1; + t1 } """ tp = relay.TensorType((10, 10), "float32") @@ -84,7 +84,7 @@ def test_dual_op(): def test_decl(): """Program: def f(x : Tensor[(10, 10), f32]) { - return log(x); + log(x) } """ tp = relay.TensorType((10, 10)) @@ -99,9 +99,9 @@ def test_recursion(): Program: def f(n: i32, data: f32) -> f32 { if (n == 0) { - return data; + data } else { - return f(n - 1, log(data)); + f(n - 1, log(data)) } } """ @@ -118,7 +118,7 @@ def f(n: i32, data: f32) -> f32 { sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data))) mod = relay.Module() mod[f] = relay.Function([n, data], sb.get()) - assert "%3 = @f(%1, %2)" in mod.astext() + assert "@f(%1, %2)" in mod.astext() assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32) From da6642f766f046306593b17e9cdd8ef18541870b Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Thu, 7 Mar 2019 13:53:20 -0800 Subject: [PATCH 52/73] references --- src/relay/ir/pretty_printer.cc | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index b3232bc122ba..1922f88d543e 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -421,6 +421,21 @@ class PrettyPrinter : return doc << "(" << PrintVec(args) << PrintAttrs(op->attrs, GetRef(op)) << ")"; } + Doc VisitExpr_(const RefCreateNode* op) final { + Doc doc = Nil(); + return doc << "ref(" << Print(op->value) << ")"; + } + + Doc VisitExpr_(const RefReadNode* op) final { + Doc doc = Nil(); + return doc << Print(op->ref) << "^"; + } + + Doc VisitExpr_(const RefWriteNode* op) final { + Doc doc = Nil(); + return doc << "(" << Print(op->ref) << " := " << Print(op->value) << ")"; + } + //------------------------------------ // Overload of Type printing functions //------------------------------------ @@ -484,6 +499,11 @@ class PrettyPrinter : return doc << "fn (" << PrintVec(arg_types) << ") -> " << Print(node->ret_type); } + Doc VisitType_(const RefTypeNode* node) final { + Doc doc = Nil(); + return doc << "ref(" << Print(node->value) << ")"; + } + //------------------------------------ // Overload of Attr printing functions //------------------------------------ From 89f58f6eef12af190554e7818939841b011f4838 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Thu, 7 Mar 2019 15:10:52 -0800 Subject: [PATCH 53/73] pattern matching --- src/relay/ir/pretty_printer.cc | 38 ++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 1922f88d543e..deb59108c30f 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -5,6 +5,8 @@ * Supports ANF, GNF, and metadata. */ #include +#include +#include #include "doc.h" #include "type_functor.h" #include "../../lang/attr_functor.h" @@ -110,6 +112,7 @@ class TextMetaDataContext { class PrettyPrinter : public ExprFunctor, + public PatternFunctor, public TypeFunctor, public AttrFunctor { public: @@ -137,6 +140,7 @@ class PrettyPrinter : } // indent a new body + // TODO(jmp): indent should be an instance variable of the printer Doc PrintBody(const NodeRef& node, int indent = 2) { Doc doc = Nil(); Doc body = Nil(); @@ -436,6 +440,40 @@ class PrettyPrinter : return doc << "(" << Print(op->ref) << " := " << Print(op->value) << ")"; } + Doc VisitExpr_(const MatchNode* op) final { + // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs. + Doc doc = Nil(); + Doc body = Nil(); + doc << "match " << Print(op->data) << " "; + doc << "{"; + std::vector clauses; + for (const auto& clause : op->clauses) { + Doc clause_doc = Nil(); + clauses.push_back(clause_doc << Print(clause->lhs, false) << " -> " << Print(clause->rhs, false)); + } + doc << Indent(2, body << "\n" << PrintVec(clauses, Line())) << "\n"; + doc << "}"; + return doc; + } + + Doc VisitPattern_(const PatternConstructorNode* p) final { + Doc doc = Nil(); + doc << p->constructor->name_hint << "("; + std::vector pats; + for (const auto& pat : p->patterns) { + pats.push_back(Print(pat)); + } + return doc << PrintVec(pats) << ")"; + } + + Doc VisitPattern_(const PatternVarNode* pv) final { + return AllocVar(pv->var); + } + + Doc VisitExpr_(const ConstructorNode* n) final { + return Text(n->name_hint); + } + //------------------------------------ // Overload of Type printing functions //------------------------------------ From f3a5728ba9735793d6f4fa348a978ae4f137b463 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Thu, 7 Mar 2019 15:18:44 -0800 Subject: [PATCH 54/73] fix attr printing --- src/relay/ir/pretty_printer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index deb59108c30f..92354c3134ea 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -673,7 +673,7 @@ Doc PrettyPrinter::PrintAttrs(const Attrs& attrs, const Expr& op) { // NOLINT(* return doc << ", " << meta_.GetMetaNode(attrs); } else { AttrPrinter printer(doc, this); - const_cast(attrs.operator->())->VisitNonDefaultAttrs(&printer); + const_cast(attrs.operator->())->VisitAttrs(&printer); return doc; } } From d090704d734e5ec91a01659964210cb997b2d5ba Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Thu, 7 Mar 2019 15:22:39 -0800 Subject: [PATCH 55/73] linting --- src/relay/ir/pretty_printer.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 92354c3134ea..6275a831a203 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -449,7 +449,8 @@ class PrettyPrinter : std::vector clauses; for (const auto& clause : op->clauses) { Doc clause_doc = Nil(); - clauses.push_back(clause_doc << Print(clause->lhs, false) << " -> " << Print(clause->rhs, false)); + clauses.push_back(clause_doc << Print(clause->lhs, false) << " -> " + << Print(clause->rhs, false)); } doc << Indent(2, body << "\n" << PrintVec(clauses, Line())) << "\n"; doc << "}"; From 9d8a167497f7aff5e99f381aa966fd9b94647e0b Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Fri, 8 Mar 2019 19:09:55 -0800 Subject: [PATCH 56/73] default attr flag, new tests, fix some bugs --- include/tvm/relay/expr.h | 3 +- python/tvm/relay/base.py | 4 +- src/relay/ir/pretty_printer.cc | 41 ++++-- .../python/relay/test_ir_parser_roundtrip.py | 117 +++++++++++------- 4 files changed, 103 insertions(+), 62 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 4d1bae3a42c1..65713483138b 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -555,7 +555,8 @@ inline const TTypeNode* ExprNode::type_as() const { std::string RelayPrint(const NodeRef& node, bool show_meta_data = true, runtime::TypedPackedFunc annotate = nullptr, - bool gnf = true); + bool gnf = true, + bool visit_default = false); } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_H_ diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index bde2be95ae8f..3e37f96d251b 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -38,7 +38,7 @@ def register_relay_attr_node(type_key=None): class RelayNode(NodeBase): """Base class of all Relay nodes.""" - def astext(self, show_meta_data=True, annotate=None, gnf=True): + def astext(self, show_meta_data=True, annotate=None, gnf=True, visit_default=False): """Get the text format of the expression. Parameters @@ -65,7 +65,7 @@ def astext(self, show_meta_data=True, annotate=None, gnf=True): text : str The text format of the expression. """ - return _expr.RelayPrint(self, show_meta_data, annotate, gnf) + return _expr.RelayPrint(self, show_meta_data, annotate, gnf, visit_default) def set_span(self, span): _base.set_span(self, span) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 6275a831a203..6dae88a720a4 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -118,9 +118,12 @@ class PrettyPrinter : public: explicit PrettyPrinter(bool GNF, bool show_meta_data, - runtime::TypedPackedFunc annotate) : GNF_(GNF), - show_meta_data_(show_meta_data), - annotate_(annotate) {} + runtime::TypedPackedFunc annotate, + bool visit_default) : + GNF_(GNF), + show_meta_data_(show_meta_data), + annotate_(annotate), + visit_default_(visit_default) {} /*! * \brief Print additional info about expr in comment. @@ -208,7 +211,8 @@ class PrettyPrinter : * \param prefix The prefix of the name * \return The returned name. */ - Doc GetUniqueName(std::string prefix) { + Doc GetUniqueName(const std::string& prefix) { + std::string unique_prefix = prefix; auto it = name_alloc_map_.find(prefix); if (it != name_alloc_map_.end()) { while (true) { @@ -216,13 +220,13 @@ class PrettyPrinter : os << prefix << (++it->second); std::string name = os.str(); if (name_alloc_map_.count(name) == 0) { - prefix = name; + unique_prefix = name; break; } } } - name_alloc_map_[prefix] = 0; - return Text(prefix); + name_alloc_map_[unique_prefix] = 0; + return Text(unique_prefix); } /*! @@ -277,7 +281,7 @@ class PrettyPrinter : !expr.as() && !expr.as()) { Doc temp_var = AllocTemp(); memo_[expr] = temp_var; - doc_stack_.back() << temp_var << " = " << printed_expr; + doc_stack_.back() << temp_var << " = " << printed_expr << ";"; if (expr.as()) { doc_stack_.back() << PrintOptionalInfo(expr); } @@ -334,7 +338,12 @@ class PrettyPrinter : fields.push_back(Print(field)); } Doc doc = Nil(); - return doc << "(" << PrintVec(fields) << ")"; + doc << "(" << PrintVec(fields); + // conform to python tuple format (1,) + if (op->fields.size() == 1) { + doc << ","; + } + return doc << ")"; } Doc VisitExpr_(const TupleGetItemNode* op) final { @@ -600,6 +609,8 @@ class PrettyPrinter : bool show_meta_data_; /*! \brief additional comment function */ runtime::TypedPackedFunc annotate_; + /*! \brief Whether to visit default attributes. */ + bool visit_default_; /*! \brief Stack of docs to implement scoped GNFing. */ std::vector doc_stack_{}; /*! \brief Map from Expr to Doc */ @@ -674,7 +685,11 @@ Doc PrettyPrinter::PrintAttrs(const Attrs& attrs, const Expr& op) { // NOLINT(* return doc << ", " << meta_.GetMetaNode(attrs); } else { AttrPrinter printer(doc, this); - const_cast(attrs.operator->())->VisitAttrs(&printer); + if (visit_default_) { + const_cast(attrs.operator->())->VisitAttrs(&printer); + } else { + const_cast(attrs.operator->())->VisitNonDefaultAttrs(&printer); + } return doc; } } @@ -682,9 +697,10 @@ Doc PrettyPrinter::PrintAttrs(const Attrs& attrs, const Expr& op) { // NOLINT(* std::string RelayPrint(const NodeRef& node, bool show_meta_data, runtime::TypedPackedFunc annotate, - bool gnf) { + bool gnf, + bool visit_default) { Doc doc = Nil(); - doc << "v0.0.1" << "\n" << PrettyPrinter(gnf, show_meta_data, annotate).PrintFinal(node) << "\n"; + doc << "v0.0.1" << "\n" << PrettyPrinter(gnf, show_meta_data, annotate, visit_default).PrintFinal(node); return Layout(doc); } @@ -692,6 +708,7 @@ TVM_REGISTER_API("relay._expr.RelayPrint") .set_body_typed, + bool, bool)>(RelayPrint); } // namespace relay diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py index f53a6ba0cb4d..fa953ada8642 100644 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ b/tests/python/relay/test_ir_parser_roundtrip.py @@ -3,8 +3,9 @@ from tvm.relay.ir_pass import alpha_equal import numpy as np -from hypothesis import given, reject, settings -from hypothesis.strategies import text, lists, integers, composite, recursive, deferred +# TODO(@jmp): Re-enable later when hypothesis is added as a dependency. +# from hypothesis import given, reject, settings +# from hypothesis.strategies import text, lists, integers, composite, recursive, deferred def gnf_print(expr): return expr.astext(gnf=True) @@ -12,54 +13,82 @@ def gnf_print(expr): def anf_print(expr): return expr.astext(gnf=False) -exprs = deferred(lambda: constants() - # | projections(exprs) - | tuples(exprs)) - -@composite -def constants(draw): - # python_tensor = draw(recursive(integers(), lists)) - # python_tensor = draw(lists(integers(min_value=-1000, max_value=1000))) - python_tensor = draw(integers(min_value=-1000, max_value=1000)) - # TODO: generate higher dimensional and 0D tensors. must be box shaped - return relay.Constant(tvm.nd.array(np.array(python_tensor).astype("int32"))) - -@composite -def tuples(draw, field_type): - return relay.Tuple(draw(lists(field_type, max_size=5))) - -@composite -def projections(draw, field_type): - return relay.TupleGetItem(draw(field_type), draw(integers(min_value=-1000, max_value=1000))) - -# TODO: figure out a way to not have to derandomize all the time -@settings(deadline=500, derandomize=True) -@given(exprs) -def test_roundtrip_pp(e): - alpha_equal(relay.fromtext(anf_print(e)), e) - -def test_gnf(): - assert gnf_print(relay.const(1)) == "v0.0.1\n1\n" - assert gnf_print(relay.Tuple([relay.const(1), relay.const(1)])) == "v0.0.1\n(1, 1)\n" +# TODO(@jmp): Re-enable later when hypothesis is added as a dependency. +# exprs = deferred(lambda: constants() +# # | projections(exprs) +# | tuples(exprs)) + +# @composite +# def constants(draw): +# # python_tensor = draw(recursive(integers(), lists)) +# # python_tensor = draw(lists(integers(min_value=-1000, max_value=1000))) +# python_tensor = draw(integers(min_value=-1000, max_value=1000)) +# # TODO: generate higher dimensional and 0D tensors. must be box shaped +# return relay.Constant(tvm.nd.array(np.array(python_tensor).astype("int32"))) + +# @composite +# def tuples(draw, field_type): +# return relay.Tuple(draw(lists(field_type, max_size=5))) + +# @composite +# def projections(draw, field_type): +# return relay.TupleGetItem(draw(field_type), draw(integers(min_value=-1000, max_value=1000))) + +# # TODO(@jmp): figure out a way to not have to derandomize all the time +# @settings(deadline=500, derandomize=True) +# @given(exprs) +# def test_roundtrip_pp(e): +# alpha_equal(relay.fromtext(anf_print(e)), e) + +def print_parse(e, gnf = True): + return alpha_equal(relay.fromtext(e.astext(gnf=gnf)), e) + +def parse_print(s): + s = "v0.0.1\n"+s + return relay.fromtext(s).astext() == s + +def roundtrip(e, s, gnf = True): + return print_parse(e, gnf) and parse_print(s) + +def test_gnf_simple(): + assert roundtrip(relay.const(1), "1") + +def test_tuple(): + assert roundtrip(relay.Tuple([]), "()") + assert roundtrip(relay.Tuple([relay.const(1)]), "(1,)") + assert roundtrip(relay.Tuple([relay.const(1), relay.const(1)]), "(1, 1)") one = relay.const(1) - assert gnf_print(relay.Tuple([one, one])) == "v0.0.1\n(1, 1)\n" - - # assert gnf_print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0))) == "v0.0.1\n%0 = True\nif (%0) {\n %1 = 1\n %2 = (%1, %1)\n %2.0\n} else {\n %1 = 1\n %2 = 1\n %3 = (%1, %1, %2)\n %3.0\n}\n" + assert print_parse(relay.Tuple([one, one])) + tup = relay.Tuple([relay.const(1), relay.const(1)]) + assert roundtrip(relay.Tuple([tup, tup]), "%0 = (1, 1);\n(%0, %0)") + assert print_parse(relay.Tuple([tup, tup]), gnf=False) + assert relay.Tuple([tup, tup]).astext(gnf=False) == "v0.0.1\n((1, 1), (1, 1))" def test_tensor_type(): - assert gnf_print(relay.TensorType([5, 5])) == "v0.0.1\nTensor[(5, 5), float32]\n" + assert relay.TensorType([5, 5]).astext() == "v0.0.1\nTensor[(5, 5), float32]" def test_tuple_type(): - assert gnf_print(relay.TupleType([])) == "v0.0.1\n()\n" - assert gnf_print(relay.TupleType([relay.scalar_type("int32")])) == "v0.0.1\n(int32,)\n" - assert gnf_print(relay.TupleType([relay.scalar_type("int32"),relay.scalar_type("int32")])) == "v0.0.1\n(int32, int32)\n" + assert relay.TupleType([]).astext() == "v0.0.1\n()" + assert relay.TupleType([relay.scalar_type("int32")]).astext() == "v0.0.1\n(int32,)" + assert relay.TupleType([relay.scalar_type("int32"),relay.scalar_type("int32")]).astext() == "v0.0.1\n(int32, int32)" def test_func_type(): - assert gnf_print(relay.FuncType([relay.scalar_type("int32"), relay.scalar_type("int32")], relay.scalar_type("int32"))) == "v0.0.1\nfn (int32, int32) -> int32\n" + assert relay.FuncType([relay.scalar_type("int32"), relay.scalar_type("int32")], relay.scalar_type("int32")).astext() == "v0.0.1\nfn (int32, int32) -> int32" + +def test_let(): + x = relay.var("x") + y = relay.var("y") + assert roundtrip(relay.Let(x, relay.const(1), relay.const(5)), "let %x = 1;\n5") + assert roundtrip(relay.Let(x, relay.const(1), x), "let %x = 1;\n%x") + assert roundtrip(relay.Let(x, relay.Tuple([relay.const(1), relay.const(1)]), x), "let %x = (1, 1);\n%x") + assert roundtrip(relay.Let(x, relay.Let(y, relay.const(2), y), x), "let %x = {\n let %y = 2;\n %y\n};\n%x") + +def test_func(): + x = relay.var("x") + assert roundtrip(relay.Function([x], x), "fn (%x) {\n %x\n}") + assert roundtrip(relay.Function([x], relay.Tuple([x, x])), "fn (%x) {\n (%x, %x)\n}") if __name__ == "__main__": - # for _ in range(10): - # print(anf_print(exprs.example())) one = relay.const(1) tup = relay.Tuple([relay.const(1), relay.const(1)]) print(gnf_print(relay.TupleGetItem(relay.Tuple([one, one]), 0))) @@ -68,9 +97,6 @@ def test_func_type(): print(gnf_print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0)))) print(anf_print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0)))) SEMVER = "v0.0.1" - print(gnf_print(relay.fromtext(SEMVER+"let %x = 1; 5"))) - print(relay.fromtext(SEMVER+"let %x = 1; %x").astext()) - print(relay.fromtext(SEMVER+"let %x = (1, 1); %x").astext()) print(relay.TupleGetItem(relay.Tuple([one, one]), 0).astext()) print(relay.fromtext(SEMVER+"let %x = 1; let %x = 2; %x").astext()) print(relay.Let(relay.var("x"), relay.Tuple([tup, tup]), relay.const(5)).astext()) @@ -78,8 +104,6 @@ def test_func_type(): print(anf_print(relay.Let(relay.var("x"), relay.Tuple([tup, tup]), relay.const(5)))) print(anf_print(relay.fromtext(SEMVER+"let %x = 1; let %x = 2; %x"))) print(gnf_print(relay.fromtext(SEMVER+"let %x = 1; let %x = 2; 3"))) - print(anf_print(relay.fromtext(SEMVER+"fn(%x) { %x }"))) - print(gnf_print(relay.fromtext(SEMVER+"fn(%x) { %x }"))) print(gnf_print(relay.fromtext(SEMVER+"fn(%x) { (%x, %x) }"))) print(gnf_print(relay.If(one, relay.TupleGetItem(relay.Tuple([one, one]), 0), one))) print(relay.If(relay.const(True), tup, tup).astext()) @@ -116,7 +140,6 @@ def test_func_type(): print(env.astext()) print(gnf_print(env)) print(anf_print(env)) - print(gnf_print(relay.fromtext(SEMVER+"let %x = { let %y = 2; %y }; %x"))) print(gnf_print(relay.fromtext(SEMVER+"let %x = { let %y = 2; ((%y + %y, %y * %y), 1) }; %x"))) print(anf_print(relay.fromtext(SEMVER+"let %x = { let %y = 2; ((%y + %y, %y * %y), 1) }; %x"))) print(relay.const([1,2,3]).astext()) From 1a26c84cf661b7b8b83e00a28b09d126fb95597d Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Fri, 8 Mar 2019 19:11:31 -0800 Subject: [PATCH 57/73] revert hypothesis addition --- docker/install/ubuntu_install_python_package.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh index 7530a6ba5ee2..200fe6e47781 100755 --- a/docker/install/ubuntu_install_python_package.sh +++ b/docker/install/ubuntu_install_python_package.sh @@ -5,5 +5,5 @@ set -u set -o pipefail # install libraries for python package on ubuntu -pip2 install nose pylint==1.9.4 six numpy nose-timer cython decorator scipy tornado typing antlr4-python2-runtime attrs hypothesis -pip3 install nose pylint==1.9.4 six numpy nose-timer cython decorator scipy tornado typed_ast pytest mypy orderedset antlr4-python3-runtime attrs hypothesis +pip2 install nose pylint==1.9.4 six numpy nose-timer cython decorator scipy tornado typing antlr4-python2-runtime attrs +pip3 install nose pylint==1.9.4 six numpy nose-timer cython decorator scipy tornado typed_ast pytest mypy orderedset antlr4-python3-runtime attrs From a587a36bd3081fe741e6ecf00ae288ac29a49162 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Fri, 8 Mar 2019 19:13:49 -0800 Subject: [PATCH 58/73] linting --- src/relay/ir/pretty_printer.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 6dae88a720a4..1c8fc6e6c9a0 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -700,7 +700,8 @@ std::string RelayPrint(const NodeRef& node, bool gnf, bool visit_default) { Doc doc = Nil(); - doc << "v0.0.1" << "\n" << PrettyPrinter(gnf, show_meta_data, annotate, visit_default).PrintFinal(node); + doc << "v0.0.1" << "\n" + << PrettyPrinter(gnf, show_meta_data, annotate, visit_default).PrintFinal(node); return Layout(doc); } From fb1345ff7f1ca422d10bebc216b4e2ecf8afb6b5 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Tue, 12 Mar 2019 11:57:51 -0700 Subject: [PATCH 59/73] revert overeager printing optimizations. untrack roundtrip tests --- src/relay/ir/pretty_printer.cc | 36 ++--- .../python/relay/test_ir_parser_roundtrip.py | 147 ------------------ tests/python/relay/test_ir_text_printer.py | 4 +- tests/python/relay/test_op_level1.py | 4 +- tests/python/relay/test_op_level4.py | 2 +- tests/python/relay/test_type_infer.py | 2 +- 6 files changed, 20 insertions(+), 175 deletions(-) delete mode 100644 tests/python/relay/test_ir_parser_roundtrip.py diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 1c8fc6e6c9a0..241d20ea2f57 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -159,7 +159,7 @@ class PrettyPrinter : // print in a new scope doc_stack_.push_back(Nil()); // must print first so doc_stack_.back() reference doesn't become stale - Doc doc = Print(node, false); + Doc doc = Print(node); doc = doc_stack_.back() << doc; doc_stack_.pop_back(); return doc; @@ -184,9 +184,9 @@ class PrettyPrinter : Doc PrintAttrs(const Attrs& attrs, const Expr& op); // note: gnf flag is only one level deep - Doc Print(const NodeRef& node, bool gnf = true, bool meta = false) { + Doc Print(const NodeRef& node, bool meta = false) { if (node.as_derived()) { - return PrintExpr(Downcast(node), gnf, meta); + return PrintExpr(Downcast(node), meta); } else if (node.as_derived()) { return PrintType(Downcast(node), meta); } else if (node.as_derived()) { @@ -255,7 +255,7 @@ class PrettyPrinter : //------------------------------------ // Overload of Expr printing functions //------------------------------------ - Doc PrintExpr(const Expr& expr, bool gnf, bool meta) { + Doc PrintExpr(const Expr& expr, bool meta) { // Exploit memoization to print GNF. // The first time we visit an expression, we need to allocate a temp var // for it. Every subsequent time we can just use its assigned variable. @@ -265,7 +265,7 @@ class PrettyPrinter : Doc printed_expr; if (meta) { printed_expr = meta_.GetMetaNode(GetRef(expr.get())); - } else if (GNF_ && gnf && expr.as()) { + } else if (GNF_ && expr.as()) { // wrap GNFed let in brackets printed_expr = Nil(); Doc body = Nil(); @@ -276,12 +276,12 @@ class PrettyPrinter : printed_expr = VisitExpr(expr); } // we choose to inline some nodes - if (GNF_ && gnf && + if (GNF_ && !expr.as() && !expr.as() && !expr.as() && !expr.as()) { Doc temp_var = AllocTemp(); memo_[expr] = temp_var; - doc_stack_.back() << temp_var << " = " << printed_expr << ";"; + doc_stack_.back() << temp_var << " = " << printed_expr; if (expr.as()) { doc_stack_.back() << PrintOptionalInfo(expr); } @@ -328,8 +328,8 @@ class PrettyPrinter : } // default fall-back, record it as meta node. Doc doc = Nil(); - return doc << Print(GetRef(op), true, true) - << PrintOptionalInfo(GetRef(op)); + return doc << Print(GetRef(op), true) + << PrintOptionalInfo(GetRef(op)); } Doc VisitExpr_(const TupleNode* op) final { @@ -362,16 +362,8 @@ class PrettyPrinter : Doc VisitExpr_(const LetNode* op) final { Doc doc = Nil(); - doc << "let " << AllocVar(op->var) << " = "; - if (op->value.as()) { - doc << PrintBody(op->value); - } else { - // we use ANF mode for the first level of the value position so the - // final expression isn't hoisted or added to the doc stream - doc << Print(op->value, false); - } - doc << ";" << "\n"; - // we use a nested scope here so GNF hoisting doesn't escape too far + doc << "let " << AllocVar(op->var) << " = " << PrintBody(op->value) << ";" << "\n"; + // we use a scope here so GNF hoisting doesn't escape too far // and so consecutive lets don't get hoisted doc << PrintScope(op->body); return doc; @@ -458,8 +450,8 @@ class PrettyPrinter : std::vector clauses; for (const auto& clause : op->clauses) { Doc clause_doc = Nil(); - clauses.push_back(clause_doc << Print(clause->lhs, false) << " -> " - << Print(clause->rhs, false)); + clauses.push_back(clause_doc << Print(clause->lhs) << " -> " + << Print(clause->rhs)); } doc << Indent(2, body << "\n" << PrintVec(clauses, Line())) << "\n"; doc << "}"; @@ -502,7 +494,7 @@ class PrettyPrinter : Doc VisitTypeDefault_(const Node* node) final { // NOLINT(*) // by default always print as meta data - return Print(GetRef(node), true, true); + return Print(GetRef(node), true); } Doc VisitType_(const TensorTypeNode* node) final { // NOLINT(*) diff --git a/tests/python/relay/test_ir_parser_roundtrip.py b/tests/python/relay/test_ir_parser_roundtrip.py deleted file mode 100644 index fa953ada8642..000000000000 --- a/tests/python/relay/test_ir_parser_roundtrip.py +++ /dev/null @@ -1,147 +0,0 @@ -import tvm -from tvm import relay -from tvm.relay.ir_pass import alpha_equal -import numpy as np - -# TODO(@jmp): Re-enable later when hypothesis is added as a dependency. -# from hypothesis import given, reject, settings -# from hypothesis.strategies import text, lists, integers, composite, recursive, deferred - -def gnf_print(expr): - return expr.astext(gnf=True) - -def anf_print(expr): - return expr.astext(gnf=False) - -# TODO(@jmp): Re-enable later when hypothesis is added as a dependency. -# exprs = deferred(lambda: constants() -# # | projections(exprs) -# | tuples(exprs)) - -# @composite -# def constants(draw): -# # python_tensor = draw(recursive(integers(), lists)) -# # python_tensor = draw(lists(integers(min_value=-1000, max_value=1000))) -# python_tensor = draw(integers(min_value=-1000, max_value=1000)) -# # TODO: generate higher dimensional and 0D tensors. must be box shaped -# return relay.Constant(tvm.nd.array(np.array(python_tensor).astype("int32"))) - -# @composite -# def tuples(draw, field_type): -# return relay.Tuple(draw(lists(field_type, max_size=5))) - -# @composite -# def projections(draw, field_type): -# return relay.TupleGetItem(draw(field_type), draw(integers(min_value=-1000, max_value=1000))) - -# # TODO(@jmp): figure out a way to not have to derandomize all the time -# @settings(deadline=500, derandomize=True) -# @given(exprs) -# def test_roundtrip_pp(e): -# alpha_equal(relay.fromtext(anf_print(e)), e) - -def print_parse(e, gnf = True): - return alpha_equal(relay.fromtext(e.astext(gnf=gnf)), e) - -def parse_print(s): - s = "v0.0.1\n"+s - return relay.fromtext(s).astext() == s - -def roundtrip(e, s, gnf = True): - return print_parse(e, gnf) and parse_print(s) - -def test_gnf_simple(): - assert roundtrip(relay.const(1), "1") - -def test_tuple(): - assert roundtrip(relay.Tuple([]), "()") - assert roundtrip(relay.Tuple([relay.const(1)]), "(1,)") - assert roundtrip(relay.Tuple([relay.const(1), relay.const(1)]), "(1, 1)") - one = relay.const(1) - assert print_parse(relay.Tuple([one, one])) - tup = relay.Tuple([relay.const(1), relay.const(1)]) - assert roundtrip(relay.Tuple([tup, tup]), "%0 = (1, 1);\n(%0, %0)") - assert print_parse(relay.Tuple([tup, tup]), gnf=False) - assert relay.Tuple([tup, tup]).astext(gnf=False) == "v0.0.1\n((1, 1), (1, 1))" - -def test_tensor_type(): - assert relay.TensorType([5, 5]).astext() == "v0.0.1\nTensor[(5, 5), float32]" - -def test_tuple_type(): - assert relay.TupleType([]).astext() == "v0.0.1\n()" - assert relay.TupleType([relay.scalar_type("int32")]).astext() == "v0.0.1\n(int32,)" - assert relay.TupleType([relay.scalar_type("int32"),relay.scalar_type("int32")]).astext() == "v0.0.1\n(int32, int32)" - -def test_func_type(): - assert relay.FuncType([relay.scalar_type("int32"), relay.scalar_type("int32")], relay.scalar_type("int32")).astext() == "v0.0.1\nfn (int32, int32) -> int32" - -def test_let(): - x = relay.var("x") - y = relay.var("y") - assert roundtrip(relay.Let(x, relay.const(1), relay.const(5)), "let %x = 1;\n5") - assert roundtrip(relay.Let(x, relay.const(1), x), "let %x = 1;\n%x") - assert roundtrip(relay.Let(x, relay.Tuple([relay.const(1), relay.const(1)]), x), "let %x = (1, 1);\n%x") - assert roundtrip(relay.Let(x, relay.Let(y, relay.const(2), y), x), "let %x = {\n let %y = 2;\n %y\n};\n%x") - -def test_func(): - x = relay.var("x") - assert roundtrip(relay.Function([x], x), "fn (%x) {\n %x\n}") - assert roundtrip(relay.Function([x], relay.Tuple([x, x])), "fn (%x) {\n (%x, %x)\n}") - -if __name__ == "__main__": - one = relay.const(1) - tup = relay.Tuple([relay.const(1), relay.const(1)]) - print(gnf_print(relay.TupleGetItem(relay.Tuple([one, one]), 0))) - print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0)).astext()) - print(gnf_print(relay.If(relay.const(True), relay.const(1), relay.const(1)))) - print(gnf_print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0)))) - print(anf_print(relay.If(relay.const(True), relay.TupleGetItem(relay.Tuple([one, one]), 0), relay.TupleGetItem(relay.Tuple([one, one, relay.const(1)]), 0)))) - SEMVER = "v0.0.1" - print(relay.TupleGetItem(relay.Tuple([one, one]), 0).astext()) - print(relay.fromtext(SEMVER+"let %x = 1; let %x = 2; %x").astext()) - print(relay.Let(relay.var("x"), relay.Tuple([tup, tup]), relay.const(5)).astext()) - print(gnf_print(relay.Let(relay.var("x"), relay.Tuple([tup, tup]), relay.const(5)))) - print(anf_print(relay.Let(relay.var("x"), relay.Tuple([tup, tup]), relay.const(5)))) - print(anf_print(relay.fromtext(SEMVER+"let %x = 1; let %x = 2; %x"))) - print(gnf_print(relay.fromtext(SEMVER+"let %x = 1; let %x = 2; 3"))) - print(gnf_print(relay.fromtext(SEMVER+"fn(%x) { (%x, %x) }"))) - print(gnf_print(relay.If(one, relay.TupleGetItem(relay.Tuple([one, one]), 0), one))) - print(relay.If(relay.const(True), tup, tup).astext()) - print(gnf_print(relay.If(relay.GlobalVar("foo"), relay.TupleGetItem(relay.Tuple([one, one]), 0), one))) - print(anf_print(relay.fromtext(SEMVER+"(fn(%x, %y) { %x })(1, 2)"))) - print(gnf_print(relay.fromtext(SEMVER+"(fn(%x, %y) { %x })(1, 2)"))) - print(relay.fromtext(SEMVER+"(fn(%x, %y) { %x })(1, 2)").astext()) - print(relay.fromtext(SEMVER+"fn(%x, %y) { %x + %y }").astext()) - print(anf_print(relay.fromtext(SEMVER+"fn(%x, %y) { %x + %y }"))) - print(gnf_print(relay.fromtext(SEMVER+"fn(%x, %y) { %x + %y }"))) - print(relay.Call(relay.fromtext(SEMVER+"fn(%x) { %x }"), [relay.const(1)], attrs=tvm.make.node("DictAttrs", n="foo")).astext()) - # print(anf_print(relay.Call(relay.fromtext(SEMVER+"fn(%x) { %x }"), [relay.const(1)], attrs=tvm.make.node("DictAttrs", n="foo")))) - # print(relay.fromtext(SEMVER+"add(n=5)").astext()) - # print(anf_print(relay.fromtext(SEMVER+"fn (n=5) { () }"))) - x = relay.var("x", shape=(3, 2)) - y = relay.var("y") - one = relay.const(10e10, dtype="float32") - z = relay.add(x, one) - z = relay.add(z, z) - f = relay.Function([x, y], z) - print(z.astext()) - print(f.astext()) - print(gnf_print(z)) - print(gnf_print(f)) - print(anf_print(z)) - print(anf_print(f)) - x = relay.var("x", "float32") - y = relay.var("y", "float32") - z = relay.add(x, y) - z = relay.add(z, z) - f = relay.Function([x, y], z) - env = relay.Module() - env["myf"] = f - print(env.astext()) - print(gnf_print(env)) - print(anf_print(env)) - print(gnf_print(relay.fromtext(SEMVER+"let %x = { let %y = 2; ((%y + %y, %y * %y), 1) }; %x"))) - print(anf_print(relay.fromtext(SEMVER+"let %x = { let %y = 2; ((%y + %y, %y * %y), 1) }; %x"))) - print(relay.const([1,2,3]).astext()) - print(gnf_print(relay.const([1,2,3]))) - print(anf_print(relay.const([1,2,3]))) diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index a32b6c2b608a..21bd85a3eb37 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -33,8 +33,8 @@ def test_env(): text = env.astext() assert "def @myf" in text assert "def @myf" in str(env) - assert "add(%0, %0) # ty=float32" in text - assert "add(%0, %0) # ty=float32" in str(env) + assert "%1 = add(%0, %0) # ty=float32" in text + assert "%1 = add(%0, %0) # ty=float32" in str(env) show(env.astext(annotate=lambda x: str(x.checked_type.dtype))) show(text) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index ad981e6b904b..b954e42bf1ab 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -22,7 +22,7 @@ def check_single_op(opfunc, ref): x = relay.var("x", tp) y = opfunc(x) # test printer - assert ("{}(%x)".format(y.op.name)) in y.astext() + assert ("%0 = {}(%x)".format(y.op.name)) in y.astext() # test type inference assert relay.ir_pass.infer_type(y).checked_type == tp @@ -62,7 +62,7 @@ def check_binary_op(opfunc, ref): y = relay.var("y", t2) z = opfunc(x, y) # test printer - assert ("{}(%x, %y)".format(z.op.name)) in z.astext() + assert ("%0 = {}(%x, %y)".format(z.op.name)) in z.astext() assert relay.ir_pass.infer_type(z).checked_type == t1 if ref is not None: diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index c876309b7383..ae7fe320940a 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -13,7 +13,7 @@ def check_binary_op(opfunc, ref): y = relay.var("y", t2) z = opfunc(x, y) # test printer - assert ("{}(%x, %y)".format(z.op.name)) in z.astext() + assert ("%0 = {}(%x, %y)".format(z.op.name)) in z.astext() assert relay.ir_pass.infer_type(z).checked_type == t1 if ref is not None: diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index d7705c683c5d..8fb83ece0ebd 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -118,7 +118,7 @@ def f(n: i32, data: f32) -> f32 { sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data))) mod = relay.Module() mod[f] = relay.Function([n, data], sb.get()) - assert "@f(%1, %2)" in mod.astext() + assert "%3 = @f(%1, %2)" in mod.astext() assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32) From 3a5d1de1df00d1d5744871491c4b7b0fb252473d Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Tue, 12 Mar 2019 12:34:00 -0700 Subject: [PATCH 60/73] separate interfaces for interchange and debug --- include/tvm/relay/expr.h | 4 +-- python/tvm/relay/base.py | 5 +--- python/tvm/relay/ir_pass.py | 31 +++++++++++++++++++++++ src/relay/ir/pretty_printer.cc | 46 ++++++++++++++++++++-------------- 4 files changed, 60 insertions(+), 26 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 65713483138b..41fd0b41c9a7 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -554,9 +554,7 @@ inline const TTypeNode* ExprNode::type_as() const { */ std::string RelayPrint(const NodeRef& node, bool show_meta_data = true, - runtime::TypedPackedFunc annotate = nullptr, - bool gnf = true, - bool visit_default = false); + runtime::TypedPackedFunc annotate = nullptr); } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_H_ diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index 3e37f96d251b..7c2b8e24eb12 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -38,7 +38,7 @@ def register_relay_attr_node(type_key=None): class RelayNode(NodeBase): """Base class of all Relay nodes.""" - def astext(self, show_meta_data=True, annotate=None, gnf=True, visit_default=False): + def astext(self, show_meta_data=True, annotate=None): """Get the text format of the expression. Parameters @@ -51,9 +51,6 @@ def astext(self, show_meta_data=True, annotate=None, gnf=True, visit_default=Fal Optional annotate function to provide additional information in the comment block. - gnf : bool - Whether to print in GNF. - Note ---- The meta data section is necessary to fully parse the text format. diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 12b6ec8ca8e2..c74b3720e559 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -905,3 +905,34 @@ def eliminate_common_subexpr(expr, fskip=None): The output expression. """ return _ir_pass.eliminate_common_subexpr(expr, fskip) + +def pass_debug_print(ast, show_meta_data=True, annotate=None, gnf=True): + """ + THIS SHOULD BE USED ONLY FOR DEBUGGING, NOT AS AN INTERCHANGE FORMAT! + USE `.astext()` INSTEAD! + + A version of the pretty printer intended for debugging passes. Contains + advanced printing options. + + Parameters + ---------- + ast : Union[relay.Expr, relay.Module, relay.Type] + The relay fragment to be turned into text. + + show_meta_data : bool + Whether to include meta data section in the text + if there is meta data. + + annotate: Optional[relay.Expr->str] + Optional annotate function to provide additional + information in the comment block. + + gnf : bool + Whether to print in GNF. If it is disabled, pointers are left implicit. + + Returns + ------- + text : str + A text representation of `ast`. + """ + return _ir_pass.pass_debug_print(ast, show_meta_data, annotate, gnf) \ No newline at end of file diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 241d20ea2f57..b7fae7a5507e 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -118,12 +118,10 @@ class PrettyPrinter : public: explicit PrettyPrinter(bool GNF, bool show_meta_data, - runtime::TypedPackedFunc annotate, - bool visit_default) : + runtime::TypedPackedFunc annotate) : GNF_(GNF), show_meta_data_(show_meta_data), - annotate_(annotate), - visit_default_(visit_default) {} + annotate_(annotate) {} /*! * \brief Print additional info about expr in comment. @@ -601,8 +599,6 @@ class PrettyPrinter : bool show_meta_data_; /*! \brief additional comment function */ runtime::TypedPackedFunc annotate_; - /*! \brief Whether to visit default attributes. */ - bool visit_default_; /*! \brief Stack of docs to implement scoped GNFing. */ std::vector doc_stack_{}; /*! \brief Map from Expr to Doc */ @@ -677,32 +673,44 @@ Doc PrettyPrinter::PrintAttrs(const Attrs& attrs, const Expr& op) { // NOLINT(* return doc << ", " << meta_.GetMetaNode(attrs); } else { AttrPrinter printer(doc, this); - if (visit_default_) { - const_cast(attrs.operator->())->VisitAttrs(&printer); - } else { - const_cast(attrs.operator->())->VisitNonDefaultAttrs(&printer); - } + const_cast(attrs.operator->())->VisitNonDefaultAttrs(&printer); return doc; } } -std::string RelayPrint(const NodeRef& node, - bool show_meta_data, - runtime::TypedPackedFunc annotate, - bool gnf, - bool visit_default) { +std::string PrettyPrint_(const NodeRef& node, + bool show_meta_data, + runtime::TypedPackedFunc annotate, + bool gnf) { Doc doc = Nil(); doc << "v0.0.1" << "\n" - << PrettyPrinter(gnf, show_meta_data, annotate, visit_default).PrintFinal(node); + << PrettyPrinter(gnf, show_meta_data, annotate).PrintFinal(node); return Layout(doc); } +std::string RelayPrint(const NodeRef& node, + bool show_meta_data, + runtime::TypedPackedFunc annotate) { + return PrettyPrint_(node, show_meta_data, annotate, false); +} + +std::string PassDebugPrint(const NodeRef& node, + bool show_meta_data, + runtime::TypedPackedFunc annotate, + bool gnf) { + return PrettyPrint_(node, show_meta_data, annotate, gnf); +} + TVM_REGISTER_API("relay._expr.RelayPrint") .set_body_typed, + runtime::TypedPackedFunc)>(RelayPrint); + +TVM_REGISTER_API("relay._ir_pass.pass_debug_print") +.set_body_typed(RelayPrint); + runtime::TypedPackedFunc, + bool)>(PassDebugPrint); } // namespace relay } // namespace tvm From 580781599245d05e2924f97c0f557e2e9b6a58a0 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Tue, 12 Mar 2019 12:39:30 -0700 Subject: [PATCH 61/73] bug fix --- python/tvm/relay/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index 7c2b8e24eb12..548c0e35a342 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -62,7 +62,7 @@ def astext(self, show_meta_data=True, annotate=None): text : str The text format of the expression. """ - return _expr.RelayPrint(self, show_meta_data, annotate, gnf, visit_default) + return _expr.RelayPrint(self, show_meta_data, annotate) def set_span(self, span): _base.set_span(self, span) From e7201faeb6ba8e88de9c5d34b40890e357e2bc87 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Tue, 12 Mar 2019 14:10:56 -0700 Subject: [PATCH 62/73] remove hacks --- src/relay/ir/pretty_printer.cc | 16 +++++++++------- tests/python/relay/test_ir_text_printer.py | 2 +- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index b7fae7a5507e..5480e2677496 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -250,6 +250,11 @@ class PrettyPrinter : return val; } + inline bool IsAtomicExpr(const Expr& expr) { + return expr.as() || expr.as() || + expr.as() || expr.as(); + } + //------------------------------------ // Overload of Expr printing functions //------------------------------------ @@ -273,10 +278,8 @@ class PrettyPrinter : } else { printed_expr = VisitExpr(expr); } - // we choose to inline some nodes - if (GNF_ && - !expr.as() && !expr.as() && - !expr.as() && !expr.as()) { + // we choose to inline atomic exprs + if (GNF_ && !IsAtomicExpr(expr)) { Doc temp_var = AllocTemp(); memo_[expr] = temp_var; doc_stack_.back() << temp_var << " = " << printed_expr; @@ -360,9 +363,8 @@ class PrettyPrinter : Doc VisitExpr_(const LetNode* op) final { Doc doc = Nil(); - doc << "let " << AllocVar(op->var) << " = " << PrintBody(op->value) << ";" << "\n"; + doc << "let " << AllocVar(op->var) << " = " << Print(op->value) << "\n"; // we use a scope here so GNF hoisting doesn't escape too far - // and so consecutive lets don't get hoisted doc << PrintScope(op->body); return doc; } @@ -691,7 +693,7 @@ std::string PrettyPrint_(const NodeRef& node, std::string RelayPrint(const NodeRef& node, bool show_meta_data, runtime::TypedPackedFunc annotate) { - return PrettyPrint_(node, show_meta_data, annotate, false); + return PrettyPrint_(node, show_meta_data, annotate, true); } std::string PassDebugPrint(const NodeRef& node, diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 21bd85a3eb37..f252194f8c0c 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -95,7 +95,7 @@ def test_let_if_scope(): f = relay.Function([x, y, cond], result) text = f.astext() - assert text.count("{") == 4 + assert text.count("{") == 6 assert "%cond: bool" in text show(f.astext()) From 3184de0f93f3cf925c6a9659cf2c27db951f5817 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Tue, 12 Mar 2019 17:25:13 -0700 Subject: [PATCH 63/73] fix atr bug --- src/relay/ir/pretty_printer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 5480e2677496..9e1862cc76eb 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -423,7 +423,7 @@ class PrettyPrinter : for (Expr arg : op->args) { args.push_back(Print(arg)); } - return doc << "(" << PrintVec(args) << PrintAttrs(op->attrs, GetRef(op)) << ")"; + return doc << "(" << PrintVec(args) << PrintAttrs(op->attrs, op->op) << ")"; } Doc VisitExpr_(const RefCreateNode* op) final { From f5081bdf0b4cc0a68ad375d5f7f94b83c7e9b948 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Tue, 12 Mar 2019 17:37:06 -0700 Subject: [PATCH 64/73] lint --- src/relay/ir/doc.cc | 1 + src/relay/ir/doc.h | 1 + 2 files changed, 2 insertions(+) diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc index 5b4d27022f4f..83229d7187fc 100644 --- a/src/relay/ir/doc.cc +++ b/src/relay/ir/doc.cc @@ -4,6 +4,7 @@ * \brief Doc ADT used for pretty printing. * Based on Section 1 of https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf. */ +#include #include #include "doc.h" diff --git a/src/relay/ir/doc.h b/src/relay/ir/doc.h index 68a50e031977..20c11b35d862 100644 --- a/src/relay/ir/doc.h +++ b/src/relay/ir/doc.h @@ -8,6 +8,7 @@ #define TVM_RELAY_IR_DOC_H_ #include +#include #include #include From c8ab556651c30fc8606c68c3778ec6f17112ba9b Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Tue, 12 Mar 2019 17:44:01 -0700 Subject: [PATCH 65/73] lint --- python/tvm/relay/ir_pass.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index c74b3720e559..f3f8fea97412 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -906,6 +906,7 @@ def eliminate_common_subexpr(expr, fskip=None): """ return _ir_pass.eliminate_common_subexpr(expr, fskip) + def pass_debug_print(ast, show_meta_data=True, annotate=None, gnf=True): """ THIS SHOULD BE USED ONLY FOR DEBUGGING, NOT AS AN INTERCHANGE FORMAT! @@ -935,4 +936,4 @@ def pass_debug_print(ast, show_meta_data=True, annotate=None, gnf=True): text : str A text representation of `ast`. """ - return _ir_pass.pass_debug_print(ast, show_meta_data, annotate, gnf) \ No newline at end of file + return _ir_pass.pass_debug_print(ast, show_meta_data, annotate, gnf) From da68638fa1e1ce1cb51304be99ccf3e31bd68378 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Wed, 13 Mar 2019 19:37:14 -0700 Subject: [PATCH 66/73] improve doc stream interface --- src/relay/ir/doc.cc | 93 ++++++++-------------------------- src/relay/ir/doc.h | 86 ++++++++++++++++--------------- src/relay/ir/pretty_printer.cc | 75 ++++++++++++++------------- 3 files changed, 104 insertions(+), 150 deletions(-) diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc index 83229d7187fc..da948898687f 100644 --- a/src/relay/ir/doc.cc +++ b/src/relay/ir/doc.cc @@ -13,104 +13,53 @@ namespace relay { // DSL function implementations -// empty doc/nil constructor -Doc Nil() { - return std::make_shared(); -} - // text constructor -Doc Text(const std::string& str, const Doc& doc) { - return std::make_shared(str, doc); -} - -// lift string to text -Doc Text(const std::string& str) { - return Text(str, Nil()); +DocAtom Text(const std::string& str) { + return std::make_shared(str); } // line constructor -Doc Line(int indent, const Doc& doc) { - return std::make_shared(indent, doc); -} - -// new line -Doc Line() { - return Line(0, Nil()); -} - -// concat two docs -Doc Concat(const Doc& left, const Doc& right) { - if (auto nil = std::dynamic_pointer_cast(left)) { - // throw away nil - return right; - } else if (auto text = std::dynamic_pointer_cast(left)) { - // push right into text continuation - return Text(text->str, Concat(text->doc, right)); - } else if (auto line = std::dynamic_pointer_cast(left)) { - // push right into line continuation - return Line(line->indent, Concat(line->doc, right)); - } else {assert(false);} -} - -// sugar for Concat -Doc operator+(const Doc& left, const Doc& right) { - return Concat(left, right); +DocAtom Line(int indent) { + return std::make_shared(indent); } // sugar for Concat with result stored in left -Doc& operator<<(Doc& left, const Doc& right) { - left = left + right; - return left; +Doc& Doc::operator<<(const Doc& right) { + this->stream_.insert(this->stream_.end(), right.stream_.begin(), right.stream_.end()); + return *this; } // like above, but automatically lifts string to a doc -Doc& operator<<(Doc& left, const std::string& right) { +Doc& Doc::operator<<(const std::string& right) { if (right == "\n") { - return left << Line(); + return *this << Line(); } else { - return left << Text(right); + return *this << Text(right); } } // indent a doc Doc Indent(int indent, const Doc& doc) { - if (auto nil = std::dynamic_pointer_cast(doc)) { - // absorb indent - return nil; - } else if (auto text = std::dynamic_pointer_cast(doc)) { - // push indent through - return Text(text->str, Indent(indent, text->doc)); - } else if (auto line = std::dynamic_pointer_cast(doc)) { - // add indent to line and continue - return Line(indent + line->indent, Indent(indent, line->doc)); - } else {assert(false);} -} - -// convert a doc to a string -std::string Layout(const Doc& doc) { - if (auto nil = std::dynamic_pointer_cast(doc)) { - return ""; - } else if (auto text = std::dynamic_pointer_cast(doc)) { - // add text and continue - return text->str + Layout(text->doc); - } else if (auto line = std::dynamic_pointer_cast(doc)) { - // add a newline and indents, then continue - return "\n" + std::string(line->indent, ' ') + Layout(line->doc); - } else {assert(false);} + Doc ret; + for (auto atom : doc.stream_) { + if (auto text = std::dynamic_pointer_cast(atom)) { + ret << atom; + } else if (auto line = std::dynamic_pointer_cast(atom)) { + ret << Line(indent + line->indent); + } else {assert(false);} + } + return ret; } // render vectors of docs with a separator. e.g. [1, 2, 3], f -> 1f2f3 Doc PrintVec(const std::vector& vec, const Doc& sep) { Doc seq; - if (vec.size() == 0) { - seq = Nil(); - } else { + if (vec.size() != 0) { seq = vec[0]; for (size_t i = 1; i < vec.size(); i++) { seq << sep << vec[i]; } } - return seq; } @@ -132,7 +81,7 @@ Doc PrintDType(DataType dtype) { Doc PrintString(const std::string& value) { // NOLINT(*) // TODO(M.K.): add escape. - Doc doc = Nil(); + Doc doc; return doc << "\"" << value << "\""; } diff --git a/src/relay/ir/doc.h b/src/relay/ir/doc.h index 20c11b35d862..a0bb94cc2145 100644 --- a/src/relay/ir/doc.h +++ b/src/relay/ir/doc.h @@ -16,61 +16,67 @@ namespace tvm { namespace relay { // ADT -struct DocNode { - virtual ~DocNode() = default; +struct DocAtomNode { + virtual ~DocAtomNode() = default; }; -using Doc = std::shared_ptr; +using DocAtom = std::shared_ptr; -struct NilNode : DocNode { }; - -struct TextNode : DocNode { +struct TextNode : DocAtomNode { std::string str; - Doc doc; - TextNode(const std::string& str, const Doc& doc) : str(str), doc(doc) {} + TextNode(const std::string& str) : str(str) {} }; -struct LineNode : DocNode { +struct LineNode : DocAtomNode { int indent; - Doc doc; - LineNode(int indent, const Doc& doc) : indent(indent), doc(doc) {} + LineNode(int indent) : indent(indent) {} }; -// text constructor -Doc Text(const std::string& str, const Doc& doc); +// Doc is a stream-like interface +struct Doc { + public: + Doc() {} + Doc(const DocAtom& atom) : stream_({atom}) {} -// line constructor -Doc Line(int indent, const Doc& doc); + // Append right to left. + Doc& operator<<(const Doc& right); + // like above, but automatically lifts string to a doc atom + Doc& operator<<(const std::string& right); + // like above, but converts right to a string first + template + Doc& operator<<(const T& right) { + std::ostringstream os; + os << right; + return *this << os.str(); + } + + // indent a doc stream + friend Doc Indent(int indent, const Doc& doc); + + std::string str() { + std::ostringstream os; + for (auto atom : stream_) { + if (auto text = std::dynamic_pointer_cast(atom)) { + os << text->str; + } else if (auto line = std::dynamic_pointer_cast(atom)) { + os << "\n" << std::string(line->indent, ' '); + } else {assert(false);} + } + return os.str(); + } + private: + std::vector stream_; +}; // DSL functions -// empty doc/nil constructor -Doc Nil(); -// lift string to text -Doc Text(const std::string& str); -// new line -Doc Line(); -// concat two docs -Doc Concat(const Doc& left, const Doc& right); -// sugar for Concat -Doc operator+(const Doc& left, const Doc& right); -// sugar for Concat with result stored in left -Doc& operator<<(Doc& left, const Doc& right); -// like above, but automatically lifts string to a doc -Doc& operator<<(Doc& left, const std::string& right); -// like above, but converts right to a string -template -Doc& operator<<(Doc& left, const T& right) { - std::ostringstream os; - os << right; - return left << os.str(); -} -// indent a doc -Doc Indent(int indent, const Doc& doc); -// convert doc to a string -std::string Layout(const Doc& doc); +// text constructor +DocAtom Text(const std::string& str); +// line constructor +DocAtom Line(int indent = 0); + // render vectors of docs with a separator. e.g. [1, 2, 3], f -> 1f2f3 Doc PrintVec(const std::vector& vec, const Doc& sep = Text(", ")); // Print constant bool value. diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 9e1862cc76eb..c084dc0fd1c6 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -83,7 +83,7 @@ class TextMetaDataContext { meta_data_[node->type_key()]; int64_t index = static_cast(mvector.size()); mvector.push_back(node); - Doc doc = Nil(); + Doc doc; doc << "meta[" << node->type_key() << "][" << index << "]"; meta_repr_[node] = doc; return meta_repr_[node]; @@ -128,23 +128,23 @@ class PrettyPrinter : * \param expr The expression. */ Doc PrintOptionalInfo(const Expr& expr) { - Doc doc = Nil(); + Doc doc; // additional information in comment. if (annotate_ != nullptr) { - return doc << " # " << annotate_(expr); + return doc << " // " << annotate_(expr); } else if (expr->checked_type_.defined()) { - doc << " # ty="; + doc << " // ty="; return doc << Print(expr->checked_type()); } else { - return Nil(); + return doc; } } // indent a new body // TODO(jmp): indent should be an instance variable of the printer Doc PrintBody(const NodeRef& node, int indent = 2) { - Doc doc = Nil(); - Doc body = Nil(); + Doc doc; + Doc body; doc << "{"; doc << Indent(indent, body << "\n" << PrintScope(node)) << "\n"; doc << "}"; @@ -155,7 +155,7 @@ class PrettyPrinter : // numbers to be reused and prevents hoisted vars from escaping too far Doc PrintScope(const NodeRef& node) { // print in a new scope - doc_stack_.push_back(Nil()); + doc_stack_.push_back({}); // must print first so doc_stack_.back() reference doesn't become stale Doc doc = Print(node); doc = doc_stack_.back() << doc; @@ -164,7 +164,7 @@ class PrettyPrinter : } Doc PrintFinal(const NodeRef& node) { - Doc doc = Nil(); + Doc doc; doc << PrintScope(node); if (!meta_.empty()) { if (show_meta_data_) { @@ -190,13 +190,13 @@ class PrettyPrinter : } else if (node.as_derived()) { return PrintMod(Downcast(node)); } else { - Doc doc = Nil(); + Doc doc; return doc << node; } } Doc TempVar(int n) { - Doc doc = Nil(); + Doc doc; return doc << "%" << n; } @@ -270,8 +270,7 @@ class PrettyPrinter : printed_expr = meta_.GetMetaNode(GetRef(expr.get())); } else if (GNF_ && expr.as()) { // wrap GNFed let in brackets - printed_expr = Nil(); - Doc body = Nil(); + Doc body; printed_expr << "{"; printed_expr << Indent(2, body << "\n" << VisitExpr(expr)) << "\n"; printed_expr << "}"; @@ -328,7 +327,7 @@ class PrettyPrinter : } } // default fall-back, record it as meta node. - Doc doc = Nil(); + Doc doc; return doc << Print(GetRef(op), true) << PrintOptionalInfo(GetRef(op)); } @@ -338,7 +337,7 @@ class PrettyPrinter : for (Expr field : op->fields) { fields.push_back(Print(field)); } - Doc doc = Nil(); + Doc doc; doc << "(" << PrintVec(fields); // conform to python tuple format (1,) if (op->fields.size() == 1) { @@ -348,12 +347,12 @@ class PrettyPrinter : } Doc VisitExpr_(const TupleGetItemNode* op) final { - Doc doc = Nil(); + Doc doc; return doc << Print(op->tuple) << "." << op->index; } Doc VisitExpr_(const IfNode* op) final { - Doc doc = Nil(); + Doc doc; doc << "if (" << Print(op->cond) << ") "; doc << PrintBody(op->true_branch); doc << " else "; @@ -362,7 +361,7 @@ class PrettyPrinter : } Doc VisitExpr_(const LetNode* op) final { - Doc doc = Nil(); + Doc doc; doc << "let " << AllocVar(op->var) << " = " << Print(op->value) << "\n"; // we use a scope here so GNF hoisting doesn't escape too far doc << PrintScope(op->body); @@ -374,7 +373,7 @@ class PrettyPrinter : // Possibly through meta data CHECK_EQ(fn->type_params.size(), 0U) << "generic fn not yet supported"; - Doc doc = Nil(); + Doc doc; doc << prefix << "("; std::vector params; for (Var param : fn->params) { @@ -390,7 +389,7 @@ class PrettyPrinter : } Doc PrintMod(const Module& mod) { - Doc doc = Nil(); + Doc doc; int counter = 0; for (const auto& kv : mod->functions) { std::ostringstream os; @@ -417,7 +416,7 @@ class PrettyPrinter : } Doc VisitExpr_(const CallNode* op) final { - Doc doc = Nil(); + Doc doc; doc << Print(op->op); std::vector args; for (Expr arg : op->args) { @@ -427,29 +426,29 @@ class PrettyPrinter : } Doc VisitExpr_(const RefCreateNode* op) final { - Doc doc = Nil(); + Doc doc; return doc << "ref(" << Print(op->value) << ")"; } Doc VisitExpr_(const RefReadNode* op) final { - Doc doc = Nil(); + Doc doc; return doc << Print(op->ref) << "^"; } Doc VisitExpr_(const RefWriteNode* op) final { - Doc doc = Nil(); + Doc doc; return doc << "(" << Print(op->ref) << " := " << Print(op->value) << ")"; } Doc VisitExpr_(const MatchNode* op) final { // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs. - Doc doc = Nil(); - Doc body = Nil(); + Doc doc; + Doc body; doc << "match " << Print(op->data) << " "; doc << "{"; std::vector clauses; for (const auto& clause : op->clauses) { - Doc clause_doc = Nil(); + Doc clause_doc; clauses.push_back(clause_doc << Print(clause->lhs) << " -> " << Print(clause->rhs)); } @@ -459,7 +458,7 @@ class PrettyPrinter : } Doc VisitPattern_(const PatternConstructorNode* p) final { - Doc doc = Nil(); + Doc doc; doc << p->constructor->name_hint << "("; std::vector pats; for (const auto& pat : p->patterns) { @@ -502,7 +501,7 @@ class PrettyPrinter : if (node->shape.size() == 0) { return PrintDType(node->dtype); } - Doc doc = Nil(); + Doc doc; doc << "Tensor[("; std::vector shapes; for (NodeRef shape : node->shape) { @@ -521,7 +520,7 @@ class PrettyPrinter : for (Type field : node->fields) { fields.push_back(Print(field)); } - Doc doc = Nil(); + Doc doc; doc << "(" << PrintVec(fields); // conform to python tuple format (1,) if (node->fields.size() == 1) { @@ -531,7 +530,7 @@ class PrettyPrinter : } Doc VisitType_(const FuncTypeNode* node) final { - Doc doc = Nil(); + Doc doc; std::vector arg_types; for (Type arg_type : node->arg_types) { arg_types.push_back(Print(arg_type)); @@ -540,7 +539,7 @@ class PrettyPrinter : } Doc VisitType_(const RefTypeNode* node) final { - Doc doc = Nil(); + Doc doc; return doc << "ref(" << Print(node->value) << ")"; } @@ -567,7 +566,7 @@ class PrettyPrinter : } Doc VisitAttr_(const ArrayNode* op) final { // NOLINT(*) - Doc doc = Nil(); + Doc doc; doc << "["; std::vector arr_vals; for (NodePtr val : op->data) { @@ -626,7 +625,7 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor { template Doc PrintKV(const char* key, const T& value) { - Doc doc = Nil(); + Doc doc; return doc << ", " << key << "=" << value; } @@ -667,8 +666,8 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor { }; Doc PrettyPrinter::PrintAttrs(const Attrs& attrs, const Expr& op) { // NOLINT(*) - if (!attrs.defined()) return Nil(); - Doc doc = Nil(); + Doc doc; + if (!attrs.defined()) return doc; const auto* op_node = op.as(); if (op_node && (attrs->type_index() != op_node->attrs_type_index)) { // fallback @@ -684,10 +683,10 @@ std::string PrettyPrint_(const NodeRef& node, bool show_meta_data, runtime::TypedPackedFunc annotate, bool gnf) { - Doc doc = Nil(); + Doc doc; doc << "v0.0.1" << "\n" << PrettyPrinter(gnf, show_meta_data, annotate).PrintFinal(node); - return Layout(doc); + return doc.str(); } std::string RelayPrint(const NodeRef& node, From c7c8a6764842873c967e2211cbad896e5ecd3398 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Wed, 13 Mar 2019 19:59:23 -0700 Subject: [PATCH 67/73] further simplify interface. remove NOLINTs --- src/relay/ir/doc.cc | 22 ++++++++++++--------- src/relay/ir/doc.h | 14 +++++++------ src/relay/ir/pretty_printer.cc | 36 +++++++++++++++++----------------- 3 files changed, 39 insertions(+), 33 deletions(-) diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc index da948898687f..01a0b9b8a6d7 100644 --- a/src/relay/ir/doc.cc +++ b/src/relay/ir/doc.cc @@ -11,6 +11,14 @@ namespace tvm { namespace relay { +Doc::Doc(const std::string& str) { + if (str == "\n") { + this->stream_ = {Line()}; + } else { + this->stream_ = {Text(str)}; + } +} + // DSL function implementations // text constructor @@ -31,11 +39,7 @@ Doc& Doc::operator<<(const Doc& right) { // like above, but automatically lifts string to a doc Doc& Doc::operator<<(const std::string& right) { - if (right == "\n") { - return *this << Line(); - } else { - return *this << Text(right); - } + return *this << Doc(right); } // indent a doc @@ -69,17 +73,17 @@ Doc PrintVec(const std::vector& vec, const Doc& sep) { */ Doc PrintBool(bool value) { if (value) { - return Text("True"); + return Doc("True"); } else { - return Text("False"); + return Doc("False"); } } Doc PrintDType(DataType dtype) { - return Text(runtime::TVMType2String(Type2TVMType(dtype))); + return Doc(Text(runtime::TVMType2String(Type2TVMType(dtype)))); } -Doc PrintString(const std::string& value) { // NOLINT(*) +Doc PrintString(const std::string& value) { // TODO(M.K.): add escape. Doc doc; return doc << "\"" << value << "\""; diff --git a/src/relay/ir/doc.h b/src/relay/ir/doc.h index a0bb94cc2145..aa153512a59b 100644 --- a/src/relay/ir/doc.h +++ b/src/relay/ir/doc.h @@ -25,20 +25,21 @@ using DocAtom = std::shared_ptr; struct TextNode : DocAtomNode { std::string str; - TextNode(const std::string& str) : str(str) {} + explicit TextNode(const std::string& str) : str(str) {} }; struct LineNode : DocAtomNode { int indent; - LineNode(int indent) : indent(indent) {} + explicit LineNode(int indent) : indent(indent) {} }; // Doc is a stream-like interface -struct Doc { +class Doc { public: Doc() {} - Doc(const DocAtom& atom) : stream_({atom}) {} + explicit Doc(const DocAtom& atom) : stream_({atom}) {} + explicit Doc(const std::string& str); // Append right to left. Doc& operator<<(const Doc& right); @@ -66,6 +67,7 @@ struct Doc { } return os.str(); } + private: std::vector stream_; }; @@ -78,7 +80,7 @@ DocAtom Text(const std::string& str); DocAtom Line(int indent = 0); // render vectors of docs with a separator. e.g. [1, 2, 3], f -> 1f2f3 -Doc PrintVec(const std::vector& vec, const Doc& sep = Text(", ")); +Doc PrintVec(const std::vector& vec, const Doc& sep = Doc(", ")); // Print constant bool value. Doc PrintBool(bool value); Doc PrintDType(DataType dtype); @@ -100,7 +102,7 @@ Doc PrintConstScalar(DataType dtype, const T* data) { } else { os << dtype << "(" << data[0] << ")"; } - return Text(os.str()); + return Doc(os.str()); } } // namespace relay diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index c084dc0fd1c6..37b6e840b964 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -224,7 +224,7 @@ class PrettyPrinter : } } name_alloc_map_[unique_prefix] = 0; - return Text(unique_prefix); + return Doc(unique_prefix); } /*! @@ -397,22 +397,22 @@ class PrettyPrinter : doc << "\n"; } os << "def @" << kv.first->name_hint; - doc << PrintFunc(Text(os.str()), kv.second); + doc << PrintFunc(Doc(os.str()), kv.second); doc << "\n"; } return doc; } Doc VisitExpr_(const FunctionNode* op) final { - return PrintFunc(Text("fn "), GetRef(op)); + return PrintFunc(Doc("fn "), GetRef(op)); } Doc VisitExpr_(const GlobalVarNode* op) final { - return Text('@' + op->name_hint); + return Doc('@' + op->name_hint); } Doc VisitExpr_(const OpNode* op) final { - return Text(op->name); + return Doc(op->name); } Doc VisitExpr_(const CallNode* op) final { @@ -452,7 +452,7 @@ class PrettyPrinter : clauses.push_back(clause_doc << Print(clause->lhs) << " -> " << Print(clause->rhs)); } - doc << Indent(2, body << "\n" << PrintVec(clauses, Line())) << "\n"; + doc << Indent(2, body << "\n" << PrintVec(clauses, Doc("\n"))) << "\n"; doc << "}"; return doc; } @@ -472,7 +472,7 @@ class PrettyPrinter : } Doc VisitExpr_(const ConstructorNode* n) final { - return Text(n->name_hint); + return Doc(n->name_hint); } //------------------------------------ @@ -491,12 +491,12 @@ class PrettyPrinter : return printed_type; } - Doc VisitTypeDefault_(const Node* node) final { // NOLINT(*) + Doc VisitTypeDefault_(const Node* node) final { // by default always print as meta data return Print(GetRef(node), true); } - Doc VisitType_(const TensorTypeNode* node) final { // NOLINT(*) + Doc VisitType_(const TensorTypeNode* node) final { // scalar type if (node->shape.size() == 0) { return PrintDType(node->dtype); @@ -547,7 +547,7 @@ class PrettyPrinter : // Overload of Attr printing functions //------------------------------------ - Doc PrintAttr(const NodeRef& value, bool meta = false) { // NOLINT(*) + Doc PrintAttr(const NodeRef& value, bool meta = false) { if (value.defined()) { Doc printed_attr; if (meta) { @@ -557,15 +557,15 @@ class PrettyPrinter : } return printed_attr; } else { - return Text("None"); + return Doc("None"); } } - Doc VisitAttrDefault_(const Node* op) final { // NOLINT(*) + Doc VisitAttrDefault_(const Node* op) final { return PrintAttr(GetRef(op), true); } - Doc VisitAttr_(const ArrayNode* op) final { // NOLINT(*) + Doc VisitAttr_(const ArrayNode* op) final { Doc doc; doc << "["; std::vector arr_vals; @@ -577,19 +577,19 @@ class PrettyPrinter : return doc; } - Doc VisitAttr_(const ir::IntImm* op) final { // NOLINT(*) + Doc VisitAttr_(const ir::IntImm* op) final { return PrintConstScalar(op->type, &(op->value)); } - Doc VisitAttr_(const ir::UIntImm* op) final { // NOLINT(*) + Doc VisitAttr_(const ir::UIntImm* op) final { return PrintConstScalar(op->type, &(op->value)); } - Doc VisitAttr_(const ir::FloatImm* op) final { // NOLINT(*) + Doc VisitAttr_(const ir::FloatImm* op) final { return PrintConstScalar(op->type, &(op->value)); } - Doc VisitAttr_(const ir::StringImm* op) final { // NOLINT(*) + Doc VisitAttr_(const ir::StringImm* op) final { return PrintString(op->value); } @@ -665,7 +665,7 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor { PrettyPrinter* parent_; }; -Doc PrettyPrinter::PrintAttrs(const Attrs& attrs, const Expr& op) { // NOLINT(*) +Doc PrettyPrinter::PrintAttrs(const Attrs& attrs, const Expr& op) { Doc doc; if (!attrs.defined()) return doc; const auto* op_node = op.as(); From 81ea258efd24bb88b9abc872c37248ed707948d7 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Wed, 13 Mar 2019 20:54:10 -0700 Subject: [PATCH 68/73] fix bugs and docs --- src/relay/ir/doc.cc | 34 +++++++++------------- src/relay/ir/doc.h | 21 +++++++------ src/relay/ir/pretty_printer.cc | 4 +-- tests/python/relay/test_ir_text_printer.py | 4 +-- 4 files changed, 27 insertions(+), 36 deletions(-) diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc index 01a0b9b8a6d7..0c3f77bc42b2 100644 --- a/src/relay/ir/doc.cc +++ b/src/relay/ir/doc.cc @@ -11,6 +11,16 @@ namespace tvm { namespace relay { +// Text constructor +DocAtom Text(const std::string& str) { + return std::make_shared(str); +} + +// Line constructor +DocAtom Line(int indent = 0) { + return std::make_shared(indent); +} + Doc::Doc(const std::string& str) { if (str == "\n") { this->stream_ = {Line()}; @@ -21,41 +31,27 @@ Doc::Doc(const std::string& str) { // DSL function implementations -// text constructor -DocAtom Text(const std::string& str) { - return std::make_shared(str); -} - -// line constructor -DocAtom Line(int indent) { - return std::make_shared(indent); -} - -// sugar for Concat with result stored in left Doc& Doc::operator<<(const Doc& right) { this->stream_.insert(this->stream_.end(), right.stream_.begin(), right.stream_.end()); return *this; } -// like above, but automatically lifts string to a doc Doc& Doc::operator<<(const std::string& right) { return *this << Doc(right); } -// indent a doc Doc Indent(int indent, const Doc& doc) { Doc ret; for (auto atom : doc.stream_) { if (auto text = std::dynamic_pointer_cast(atom)) { - ret << atom; + ret.stream_.push_back(text); } else if (auto line = std::dynamic_pointer_cast(atom)) { - ret << Line(indent + line->indent); + ret.stream_.push_back(Line(indent + line->indent)); } else {assert(false);} } return ret; } -// render vectors of docs with a separator. e.g. [1, 2, 3], f -> 1f2f3 Doc PrintVec(const std::vector& vec, const Doc& sep) { Doc seq; if (vec.size() != 0) { @@ -67,10 +63,6 @@ Doc PrintVec(const std::vector& vec, const Doc& sep) { return seq; } -/*! - * \brief Print constant bool value. - * \param value The value to be printed. - */ Doc PrintBool(bool value) { if (value) { return Doc("True"); @@ -80,7 +72,7 @@ Doc PrintBool(bool value) { } Doc PrintDType(DataType dtype) { - return Doc(Text(runtime::TVMType2String(Type2TVMType(dtype)))); + return Doc(runtime::TVMType2String(Type2TVMType(dtype))); } Doc PrintString(const std::string& value) { diff --git a/src/relay/ir/doc.h b/src/relay/ir/doc.h index aa153512a59b..04b49fe7cd85 100644 --- a/src/relay/ir/doc.h +++ b/src/relay/ir/doc.h @@ -2,7 +2,9 @@ * Copyright (c) 2019 by Contributors * \file tvm/relay/doc.h * \brief Doc ADT used for pretty printing. - * Based on Section 1 of https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf. + * Based on Section 1 of + * https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf, but with + * a vector instead of an implicitly linked list. */ #ifndef TVM_RELAY_IR_DOC_H_ #define TVM_RELAY_IR_DOC_H_ @@ -15,7 +17,7 @@ namespace tvm { namespace relay { -// ADT +// Doc Atom ADT struct DocAtomNode { virtual ~DocAtomNode() = default; }; @@ -38,12 +40,11 @@ struct LineNode : DocAtomNode { class Doc { public: Doc() {} - explicit Doc(const DocAtom& atom) : stream_({atom}) {} explicit Doc(const std::string& str); - // Append right to left. + // Append right to this. Doc& operator<<(const Doc& right); - // like above, but automatically lifts string to a doc atom + // like above, but automatically lifts string to a Doc Doc& operator<<(const std::string& right); // like above, but converts right to a string first template @@ -56,6 +57,7 @@ class Doc { // indent a doc stream friend Doc Indent(int indent, const Doc& doc); + // Wadler's `layout` std::string str() { std::ostringstream os; for (auto atom : stream_) { @@ -74,16 +76,13 @@ class Doc { // DSL functions -// text constructor -DocAtom Text(const std::string& str); -// line constructor -DocAtom Line(int indent = 0); - // render vectors of docs with a separator. e.g. [1, 2, 3], f -> 1f2f3 Doc PrintVec(const std::vector& vec, const Doc& sep = Doc(", ")); -// Print constant bool value. +// Print a constant bool value. Doc PrintBool(bool value); +// Print a data type. Doc PrintDType(DataType dtype); +// Print a string. Doc PrintString(const std::string& value); /*! * \brief special method to print out const scalar diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 37b6e840b964..f89c83756a05 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -155,7 +155,7 @@ class PrettyPrinter : // numbers to be reused and prevents hoisted vars from escaping too far Doc PrintScope(const NodeRef& node) { // print in a new scope - doc_stack_.push_back({}); + doc_stack_.push_back(Doc()); // must print first so doc_stack_.back() reference doesn't become stale Doc doc = Print(node); doc = doc_stack_.back() << doc; @@ -241,7 +241,7 @@ class PrettyPrinter : Doc val = GetUniqueName("%" + name); // still print if ir is malformed, but show the error. if (memo_.count(var)) { - val << Text("-malformed-ir"); + val << "-malformed-ir"; } memo_[var] = val; if (var->type_annotation.defined()) { diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index f252194f8c0c..626436d9573f 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -33,8 +33,8 @@ def test_env(): text = env.astext() assert "def @myf" in text assert "def @myf" in str(env) - assert "%1 = add(%0, %0) # ty=float32" in text - assert "%1 = add(%0, %0) # ty=float32" in str(env) + assert "%1 = add(%0, %0) // ty=float32" in text + assert "%1 = add(%0, %0) // ty=float32" in str(env) show(env.astext(annotate=lambda x: str(x.checked_type.dtype))) show(text) From fd43b918413fabb42f7df1adc29000c53eae1588 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Wed, 13 Mar 2019 21:13:37 -0700 Subject: [PATCH 69/73] move str to doc.cc --- src/relay/ir/doc.cc | 12 ++++++++++++ src/relay/ir/doc.h | 12 +----------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc index 0c3f77bc42b2..27b17acf4dca 100644 --- a/src/relay/ir/doc.cc +++ b/src/relay/ir/doc.cc @@ -52,6 +52,18 @@ Doc Indent(int indent, const Doc& doc) { return ret; } +std::string Doc::str() { + std::ostringstream os; + for (auto atom : this->stream_) { + if (auto text = std::dynamic_pointer_cast(atom)) { + os << text->str; + } else if (auto line = std::dynamic_pointer_cast(atom)) { + os << "\n" << std::string(line->indent, ' '); + } else {assert(false);} + } + return os.str(); +} + Doc PrintVec(const std::vector& vec, const Doc& sep) { Doc seq; if (vec.size() != 0) { diff --git a/src/relay/ir/doc.h b/src/relay/ir/doc.h index 04b49fe7cd85..db94e654d4a5 100644 --- a/src/relay/ir/doc.h +++ b/src/relay/ir/doc.h @@ -58,17 +58,7 @@ class Doc { friend Doc Indent(int indent, const Doc& doc); // Wadler's `layout` - std::string str() { - std::ostringstream os; - for (auto atom : stream_) { - if (auto text = std::dynamic_pointer_cast(atom)) { - os << text->str; - } else if (auto line = std::dynamic_pointer_cast(atom)) { - os << "\n" << std::string(line->indent, ' '); - } else {assert(false);} - } - return os.str(); - } + std::string str(); private: std::vector stream_; From 2eb7f6e5d1267e77de77213abc6f8cade3a05d05 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Thu, 14 Mar 2019 15:49:23 -0700 Subject: [PATCH 70/73] remove stale documentation --- include/tvm/relay/expr.h | 1 - 1 file changed, 1 deletion(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 41fd0b41c9a7..4513022687f8 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -549,7 +549,6 @@ inline const TTypeNode* ExprNode::type_as() const { * \param show_meta_data Whether to print meta data section. * \param annotate An optional callback function for attaching * additional comment block to an expr. - * \param gnf Whether to print in GNF. * \return The text representation. */ std::string RelayPrint(const NodeRef& node, From c9e7d8829a9f4928388a8ade4f125cb0f7e96a31 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Sat, 16 Mar 2019 17:11:20 -0700 Subject: [PATCH 71/73] Update src/relay/ir/pretty_printer.cc Co-Authored-By: joshpoll --- src/relay/ir/pretty_printer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index f89c83756a05..61dc7c4b7a28 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -90,7 +90,7 @@ class TextMetaDataContext { } /*! * \brief Get the metadata section in json format. - * \return the meta datastring. + * \return the meta data string. */ std::string GetMetaSection() const { if (meta_data_.size() == 0) return std::string(); From a528fe28840d08ba0b0591489beae0761a3e8a12 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Sun, 17 Mar 2019 21:02:20 -0700 Subject: [PATCH 72/73] address feedback --- src/relay/ir/doc.cc | 1 + src/relay/ir/pretty_printer.cc | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc index 27b17acf4dca..ef98bd8ed1ed 100644 --- a/src/relay/ir/doc.cc +++ b/src/relay/ir/doc.cc @@ -32,6 +32,7 @@ Doc::Doc(const std::string& str) { // DSL function implementations Doc& Doc::operator<<(const Doc& right) { + assert(this != &right); this->stream_.insert(this->stream_.end(), right.stream_.begin(), right.stream_.end()); return *this; } diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index f89c83756a05..0717b69ca234 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -181,7 +181,6 @@ class PrettyPrinter : Doc PrintAttrs(const Attrs& attrs, const Expr& op); - // note: gnf flag is only one level deep Doc Print(const NodeRef& node, bool meta = false) { if (node.as_derived()) { return PrintExpr(Downcast(node), meta); From 7f9921dbe7981f921ea64b66aa7c58c2e4729ad0 Mon Sep 17 00:00:00 2001 From: Josh Pollock Date: Tue, 19 Mar 2019 19:39:27 -0700 Subject: [PATCH 73/73] minor comment changes and bump ci --- src/relay/ir/doc.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/relay/ir/doc.h b/src/relay/ir/doc.h index db94e654d4a5..b9a82555c479 100644 --- a/src/relay/ir/doc.h +++ b/src/relay/ir/doc.h @@ -44,9 +44,9 @@ class Doc { // Append right to this. Doc& operator<<(const Doc& right); - // like above, but automatically lifts string to a Doc + // Like above, but automatically lifts string to a Doc. Doc& operator<<(const std::string& right); - // like above, but converts right to a string first + // Like above, but converts right to a string first. template Doc& operator<<(const T& right) { std::ostringstream os; @@ -54,7 +54,7 @@ class Doc { return *this << os.str(); } - // indent a doc stream + // Indent a doc stream. friend Doc Indent(int indent, const Doc& doc); // Wadler's `layout` @@ -66,7 +66,7 @@ class Doc { // DSL functions -// render vectors of docs with a separator. e.g. [1, 2, 3], f -> 1f2f3 +// Render vectors of docs with a separator. e.g. PrintVec([1, 2, 3], f) -> 1f2f3 Doc PrintVec(const std::vector& vec, const Doc& sep = Doc(", ")); // Print a constant bool value. Doc PrintBool(bool value);