From 72ccd9aca017e51f0d76669fa79c0b8ecf53f5c1 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sat, 31 Aug 2019 19:44:15 -0700 Subject: [PATCH 01/14] Module refactor --- include/tvm/relay/type.h | 6 ++++++ src/relay/ir/expr_functor.cc | 1 - src/relay/ir/module.cc | 17 +++++++++++++-- src/relay/pass/type_infer.cc | 39 ++++++++++++----------------------- src/relay/pass/type_solver.cc | 13 +++++++++--- src/relay/pass/type_solver.h | 4 +++- 6 files changed, 47 insertions(+), 33 deletions(-) diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index d509fde2a875..16e36785c533 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -410,6 +410,12 @@ class TypeReporterNode : public Node { */ TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0; + /*! + * \brief Retrieve the current global module. + * \return The global module. + */ + TVM_DLL virtual Module GetModule() = 0; + // solver is not serializable. void VisitAttrs(tvm::AttrVisitor* v) final {} diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index da9f7b8d19b9..6a2db6b46d64 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -444,7 +444,6 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { } } - TVM_REGISTER_API("relay._expr.Bind") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef input = args[0]; diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index dbaea7f02fc7..19ce51bd0981 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -33,11 +33,22 @@ namespace relay { using tvm::IRPrinter; using namespace runtime; +void ModuleNode::RegisterBuiltins() { + std::cout << "IN BUILTINS" << std::endl; + // Add storage type. + auto storage = GlobalTypeVarNode::make("Storage", Kind::kAdtHandle); + auto type_data = TypeDataNode::make(storage, {}, {}); + this->AddDef(storage, type_data); +} + Module ModuleNode::make(tvm::Map global_funcs, tvm::Map global_type_defs) { auto n = make_node(); n->functions = std::move(global_funcs); n->type_definitions = std::move(global_type_defs); + n->global_type_var_map_ = {}; + n->global_var_map_ = {}; + n->constructor_tag_map_ = {}; for (const auto& kv : n->functions) { // set global var map @@ -85,7 +96,9 @@ void ModuleNode::AddUnchecked(const GlobalVar& var, } GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const { + CHECK(global_type_var_map_.defined()); auto it = global_type_var_map_.find(name); + std::cout << "pass here"; CHECK(it != global_type_var_map_.end()) << "Cannot find global type var " << name << " in the Module"; return (*it).second; @@ -160,8 +173,8 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) { this->type_definitions.Set(var, type); // set global type var map - CHECK(!global_type_var_map_.count(var->var->name_hint)) - << "Duplicate global type definition name " << var->var->name_hint; + // CHECK(!global_type_var_map_.count(var->var->name_hint)) + // << "Duplicate global type definition name " << var->var->name_hint; global_type_var_map_.Set(var->var->name_hint, var); RegisterConstructors(var, type); diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index f7de2a927c66..c61d41d4717d 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -108,7 +108,8 @@ class TypeInferencer : private ExprFunctor, explicit TypeInferencer(Module mod, GlobalVar current_func) : mod_(mod), current_func_(current_func), - err_reporter(), solver_(current_func, &this->err_reporter) { + err_reporter(), solver_(current_func, mod, &this->err_reporter) { + CHECK(mod.defined()) << "internal error: Module must be set in the type inferencer"; } // inference the type of expr. @@ -801,36 +802,22 @@ void EnsureCheckedType(const Expr& e) { AllCheckTypePopulated().VisitExpr(e); } -Expr InferType(const Expr& expr, const Module& mod_ref) { - if (!mod_ref.defined()) { - Module mod = ModuleNode::FromExpr(expr); - // NB(@jroesch): By adding the expression to the module we will - // type check it anyway; afterwards we can just recover type - // from the type-checked function to avoid doing unnecessary work. - - Function func = mod->Lookup("main"); - - // FromExpr wraps a naked expression as a function, we will unbox - // it here. - if (expr.as()) { - return std::move(func); - } else { - return func->body; - } - } else { - auto e = TypeInferencer(mod_ref, mod_ref->GetGlobalVar("main")).Infer(expr); - CHECK(WellFormed(e)); - auto free_tvars = FreeTypeVars(e, mod_ref); - CHECK(free_tvars.size() == 0) - << "Found unbound type variables in " << e << ": " << free_tvars; - EnsureCheckedType(e); - return e; - } +Expr InferType(const Expr& expr, const Module& mod) { + auto main = mod->GetGlobalVar("main"); + auto inferencer = TypeInferencer(mod, main); + auto e = inferencer.Infer(expr); + CHECK(WellFormed(e)); + auto free_tvars = FreeTypeVars(e, mod); + CHECK(free_tvars.size() == 0) + << "Found unbound type variables in " << e << ": " << free_tvars; + EnsureCheckedType(e); + return e; } Function InferType(const Function& func, const Module& mod, const GlobalVar& var) { + CHECK(mod.defined()) << "internal error: module must be set for type inference"; Function func_copy = Function(make_node(*func.operator->())); func_copy->checked_type_ = func_copy->func_type_annotation(); mod->AddUnchecked(var, func_copy); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 38870762d840..dd38f97555ec 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -61,6 +61,10 @@ class TypeSolver::Reporter : public TypeReporterNode { location = ref; } + TVM_DLL Module GetModule() final { + return this->solver_->module_; + } + private: /*! \brief The location to report unification errors at. */ mutable NodeRef location; @@ -512,10 +516,13 @@ class TypeSolver::Merger : public TypeFunctor { }; // constructor -TypeSolver::TypeSolver(const GlobalVar ¤t_func, ErrorReporter* err_reporter) +TypeSolver::TypeSolver(const GlobalVar& current_func, const Module& module, ErrorReporter* err_reporter) : reporter_(make_node(this)), current_func(current_func), - err_reporter_(err_reporter) { + err_reporter_(err_reporter), + module_(module) { + CHECK(module_.defined()) << + "internal error: module must be defined"; } // destructor @@ -639,7 +646,7 @@ TVM_REGISTER_API("relay._analysis._test_type_solver") using runtime::PackedFunc; using runtime::TypedPackedFunc; ErrorReporter *err_reporter = new ErrorReporter(); - auto solver = std::make_shared(GlobalVarNode::make("test"), err_reporter); + auto solver = std::make_shared(GlobalVarNode::make("test"), Module(), err_reporter); auto mod = [solver, err_reporter](std::string name) -> PackedFunc { if (name == "Solve") { diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index 28579633c1c6..4a6d2cfa7756 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -63,7 +63,7 @@ using common::LinkedList; */ class TypeSolver { public: - TypeSolver(const GlobalVar& current_func, ErrorReporter* err_reporter); + TypeSolver(const GlobalVar& current_func, const Module& _mod, ErrorReporter* err_reporter); ~TypeSolver(); /*! * \brief Add a type constraint to the solver. @@ -179,6 +179,8 @@ class TypeSolver { GlobalVar current_func; /*! \brief Error reporting. */ ErrorReporter* err_reporter_; + /*! \brief The module. */ + Module module_; /*! * \brief GetTypeNode that is corresponds to t. From b9a5f4cc981f2956db4ba1fd270371297f256911 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 5 Sep 2019 16:47:56 -0700 Subject: [PATCH 02/14] Add load module --- include/tvm/relay/module.h | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 3496c8815467..9516ffed434b 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -185,6 +185,18 @@ class ModuleNode : public RelayNode { */ TVM_DLL void Update(const Module& other); + /*! + * \brief Import Relay code from the file at path. + * \param path The path of the Relay code to import. + */ + TVM_DLL void Import(const std::string& path); + + /*! + * \brief Import Relay code from the file at path, relative to the standard library. + * \param path The path of the Relay code to import. + */ + TVM_DLL void ImportStd(const std::string& path); + /*! \brief Construct a module from a standalone expression. * * Allows one to optionally pass a global function map and @@ -222,6 +234,11 @@ class ModuleNode : public RelayNode { * for convenient access */ std::unordered_map constructor_tag_map_; + + /*! \brief The files previously imported, required to ensure + importing is idempotent for each module. + */ + std::unordered_set import_set_; }; struct Module : public NodeRef { From 175921903b3f7fd34f781f04321650a93805097e Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 5 Sep 2019 18:16:27 -0700 Subject: [PATCH 03/14] Add support for idempotent import --- include/tvm/relay/expr.h | 1 + include/tvm/relay/module.h | 1 + python/tvm/relay/__init__.py | 7 +++- python/tvm/relay/{ => std}/prelude.rly | 0 src/relay/ir/module.cc | 57 +++++++++++++++++++++----- 5 files changed, 55 insertions(+), 11 deletions(-) rename python/tvm/relay/{ => std}/prelude.rly (100%) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index c5cd6bb9e4ab..b1b8d6a7154e 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -575,6 +575,7 @@ std::string PrettyPrint(const NodeRef& node); std::string AsText(const NodeRef& node, bool show_meta_data = true, runtime::TypedPackedFunc annotate = nullptr); + } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_H_ diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 9516ffed434b..135c30337f49 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -252,6 +252,7 @@ struct Module : public NodeRef { using ContainerType = ModuleNode; }; +Module FromText(std::string source, const std::string& source_name); } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 092cd01d1d4a..73ed593f4031 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -17,6 +17,7 @@ # pylint: disable=wildcard-import, redefined-builtin, invalid-name """The Relay IR namespace containing the IR definition and compiler.""" from __future__ import absolute_import +import os from sys import setrecursionlimit from ..api import register_func from . import base @@ -50,7 +51,7 @@ from . import annotation from . import vision from . import contrib -from . import image +from . import imagers from . import frontend from . import backend from . import quantize @@ -63,6 +64,10 @@ # Required to traverse large programs setrecursionlimit(10000) +@register_func("tvm.relay.std_path") +def _std_path(): + return os.path.dirname(os.path.abspath(__file__)) + # Span Span = base.Span diff --git a/python/tvm/relay/prelude.rly b/python/tvm/relay/std/prelude.rly similarity index 100% rename from python/tvm/relay/prelude.rly rename to python/tvm/relay/std/prelude.rly diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 19ce51bd0981..936dadd4bd5e 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -26,6 +26,7 @@ #include #include #include +#include namespace tvm { namespace relay { @@ -33,14 +34,6 @@ namespace relay { using tvm::IRPrinter; using namespace runtime; -void ModuleNode::RegisterBuiltins() { - std::cout << "IN BUILTINS" << std::endl; - // Add storage type. - auto storage = GlobalTypeVarNode::make("Storage", Kind::kAdtHandle); - auto type_data = TypeDataNode::make(storage, {}, {}); - this->AddDef(storage, type_data); -} - Module ModuleNode::make(tvm::Map global_funcs, tvm::Map global_type_defs) { auto n = make_node(); @@ -173,8 +166,9 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) { this->type_definitions.Set(var, type); // set global type var map - // CHECK(!global_type_var_map_.count(var->var->name_hint)) - // << "Duplicate global type definition name " << var->var->name_hint; + CHECK(!global_type_var_map_.count(var->var->name_hint)) + << "Duplicate global type definition name " << var->var->name_hint; + global_type_var_map_.Set(var->var->name_hint, var); RegisterConstructors(var, type); @@ -254,6 +248,39 @@ Module ModuleNode::FromExpr( return mod; } +void ModuleNode::Import(const std::string& path) { + if (this->import_set_.count(path) == 0) { + this->import_set_.insert(path); + std::fstream src_file(path, std::fstream::in); + std::string file_contents { + std::istreambuf_iterator(src_file), + std::istreambuf_iterator() }; + auto mod_to_import = FromText(file_contents, path); + + for (auto func : mod_to_import->functions) { + this->Add(func.first, func.second, false); + } + + for (auto type : mod_to_import->type_definitions) { + this->AddDef(type.first, type.second); + } + } +} + +void ModuleNode::ImportStd(const std::string& path) { + auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path"); + CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; + std::string std_path = (*f)(); + return this->Import(std_path + path); +} + +Module FromText(std::string source, const std::string& source_name) { + auto* f = tvm::runtime::Registry::Get("relay.fromtext"); + CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; + Module mod = (*f)(source, source_name); + return mod; +} + TVM_REGISTER_NODE_TYPE(ModuleNode); TVM_REGISTER_API("relay._make.Module") @@ -333,6 +360,16 @@ TVM_REGISTER_API("relay._module.Module_Update") mod->Update(from); }); +TVM_REGISTER_API("relay._module.Module_Import") +.set_body_typed([](Module mod, std::string path) { + mod->Import(path); +}); + +TVM_REGISTER_API("relay._module.Module_ImportStd") +.set_body_typed([](Module mod, std::string path) { + mod->ImportStd(path); +});; + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch( [](const ModuleNode *node, tvm::IRPrinter *p) { From 67bde249348b7c7392f019749a629d95cbae583f Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 5 Sep 2019 18:20:18 -0700 Subject: [PATCH 04/14] Tweak load paths --- python/tvm/relay/__init__.py | 7 +++++-- python/tvm/relay/prelude.py | 5 +++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 73ed593f4031..bd245dea472c 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -51,7 +51,7 @@ from . import annotation from . import vision from . import contrib -from . import imagers +from . import image from . import frontend from . import backend from . import quantize @@ -64,9 +64,12 @@ # Required to traverse large programs setrecursionlimit(10000) +__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std") + @register_func("tvm.relay.std_path") def _std_path(): - return os.path.dirname(os.path.abspath(__file__)) + global __STD_PATH__ + return __STD_PATH__ # Span Span = base.Span diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index f9a7d3dcaf37..f2cf40174645 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -23,8 +23,9 @@ from .adt import Constructor, TypeData, Clause, Match from .adt import PatternConstructor, PatternVar, PatternWildcard, PatternTuple from .parser import fromtext -__PRELUDE_PATH__ = os.path.dirname(os.path.realpath(__file__)) + from .module import Module +from . import __STD_PATH__ class Prelude: """Contains standard definitions.""" @@ -479,7 +480,7 @@ def load_prelude(self): Parses the portions of the Prelude written in Relay's text format and adds them to the module. """ - prelude_file = os.path.join(__PRELUDE_PATH__, "prelude.rly") + prelude_file = os.path.join(__STD_PATH__, "prelude.rly") with open(prelude_file) as prelude: prelude = fromtext(prelude.read()) self.mod.update(prelude) From 4bb1959d04b6021ce08cc057e33fd278465a587e Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 5 Sep 2019 18:22:41 -0700 Subject: [PATCH 05/14] Move path around --- python/tvm/relay/__init__.py | 7 ------- python/tvm/relay/module.py | 9 +++++++++ python/tvm/relay/prelude.py | 3 +-- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index bd245dea472c..ceb98c4d251e 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -64,13 +64,6 @@ # Required to traverse large programs setrecursionlimit(10000) -__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std") - -@register_func("tvm.relay.std_path") -def _std_path(): - global __STD_PATH__ - return __STD_PATH__ - # Span Span = base.Span diff --git a/python/tvm/relay/module.py b/python/tvm/relay/module.py index e0511a257e6d..266387714a16 100644 --- a/python/tvm/relay/module.py +++ b/python/tvm/relay/module.py @@ -17,11 +17,20 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import """A global module storing everything needed to interpret or compile a Relay program.""" from .base import register_relay_node, RelayNode +from .. import register_func from .._ffi import base as _base from . import _make from . import _module from . import expr as _expr from . import ty as _ty +import os + +__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std") + +@register_func("tvm.relay.std_path") +def _std_path(): + global __STD_PATH__ + return __STD_PATH__ @register_relay_node class Module(RelayNode): diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index f2cf40174645..8a21bf9d178e 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -24,8 +24,7 @@ from .adt import PatternConstructor, PatternVar, PatternWildcard, PatternTuple from .parser import fromtext -from .module import Module -from . import __STD_PATH__ +from .module import Module, __STD_PATH__ class Prelude: """Contains standard definitions.""" From 0f7880b261174515cc3c36ca945998a7b1d08f3c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 5 Sep 2019 18:27:49 -0700 Subject: [PATCH 06/14] Expose C++ import functions in Python --- python/tvm/relay/module.py | 6 ++++++ python/tvm/relay/prelude.py | 10 ++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/module.py b/python/tvm/relay/module.py index 266387714a16..f3553c23ae49 100644 --- a/python/tvm/relay/module.py +++ b/python/tvm/relay/module.py @@ -211,3 +211,9 @@ def from_expr(expr, functions=None, type_defs=None): funcs = functions if functions is not None else {} defs = type_defs if type_defs is not None else {} return _module.Module_FromExpr(expr, funcs, defs) + + def _import(self, file_to_import): + return _module.Module_Import(self, file_to_import) + + def import_std(self, file_to_import): + return _module.Module_ImportStd(self, file_to_import) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 8a21bf9d178e..bb44ce76d9cb 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -479,12 +479,10 @@ def load_prelude(self): Parses the portions of the Prelude written in Relay's text format and adds them to the module. """ - prelude_file = os.path.join(__STD_PATH__, "prelude.rly") - with open(prelude_file) as prelude: - prelude = fromtext(prelude.read()) - self.mod.update(prelude) - self.id = self.mod.get_global_var("id") - self.compose = self.mod.get_global_var("compose") + # TODO(@jroesch): we should remove this helper when we port over prelude + self.mod.import_std("prelude.rly") + self.id = self.mod.get_global_var("id") + self.compose = self.mod.get_global_var("compose") def __init__(self, mod=None): From a61d209b699f28305b0cc3d940225f1f6f008870 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 5 Sep 2019 18:40:13 -0700 Subject: [PATCH 07/14] Fix import --- src/relay/ir/module.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 936dadd4bd5e..45acbd811640 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -249,6 +249,7 @@ Module ModuleNode::FromExpr( } void ModuleNode::Import(const std::string& path) { + LOG(INFO) << "Importing: " << path; if (this->import_set_.count(path) == 0) { this->import_set_.insert(path); std::fstream src_file(path, std::fstream::in); @@ -271,7 +272,7 @@ void ModuleNode::ImportStd(const std::string& path) { auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path"); CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; std::string std_path = (*f)(); - return this->Import(std_path + path); + return this->Import(std_path + "/" + path); } Module FromText(std::string source, const std::string& source_name) { From 95d1a0fc83e81a2723fc173a761c921057c597e7 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 5 Sep 2019 19:04:00 -0700 Subject: [PATCH 08/14] Add doc string --- include/tvm/relay/module.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 135c30337f49..3a5fefe13e10 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -252,6 +252,11 @@ struct Module : public NodeRef { using ContainerType = ModuleNode; }; +/*! \brief Parse Relay source into a module. + * \param source A string of Relay source code. + * \param source_name The name of the source file. + * \return A Relay module. + */ Module FromText(std::string source, const std::string& source_name); } // namespace relay From 2cf670252bf429b65f3e2d6f86a9f2d55e5fa569 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 10 Sep 2019 15:28:50 -0500 Subject: [PATCH 09/14] Fix --- include/tvm/relay/module.h | 9 +++++++-- python/tvm/relay/module.py | 4 ++-- python/tvm/relay/prelude.py | 2 +- src/relay/ir/module.cc | 9 ++++----- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 3a5fefe13e10..7d3323ce99cd 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -188,6 +188,11 @@ class ModuleNode : public RelayNode { /*! * \brief Import Relay code from the file at path. * \param path The path of the Relay code to import. + * + * \note The path resolution behavior is standard, + * if abosolute will be the absolute file, if + * relative it will be resovled against the current + * working directory. */ TVM_DLL void Import(const std::string& path); @@ -195,7 +200,7 @@ class ModuleNode : public RelayNode { * \brief Import Relay code from the file at path, relative to the standard library. * \param path The path of the Relay code to import. */ - TVM_DLL void ImportStd(const std::string& path); + TVM_DLL void ImportFromStd(const std::string& path); /*! \brief Construct a module from a standalone expression. * @@ -257,7 +262,7 @@ struct Module : public NodeRef { * \param source_name The name of the source file. * \return A Relay module. */ -Module FromText(std::string source, const std::string& source_name); +Module FromText(const std::string& source, const std::string& source_name); } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/module.py b/python/tvm/relay/module.py index f3553c23ae49..14dec664c965 100644 --- a/python/tvm/relay/module.py +++ b/python/tvm/relay/module.py @@ -215,5 +215,5 @@ def from_expr(expr, functions=None, type_defs=None): def _import(self, file_to_import): return _module.Module_Import(self, file_to_import) - def import_std(self, file_to_import): - return _module.Module_ImportStd(self, file_to_import) + def import_from_std(self, file_to_import): + return _module.Module_ImportFromStd(self, file_to_import) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index bb44ce76d9cb..c8935e9282b7 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -480,7 +480,7 @@ def load_prelude(self): them to the module. """ # TODO(@jroesch): we should remove this helper when we port over prelude - self.mod.import_std("prelude.rly") + self.mod.import_from_std("prelude.rly") self.id = self.mod.get_global_var("id") self.compose = self.mod.get_global_var("compose") diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 45acbd811640..e842326b03ac 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -91,7 +91,6 @@ void ModuleNode::AddUnchecked(const GlobalVar& var, GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const { CHECK(global_type_var_map_.defined()); auto it = global_type_var_map_.find(name); - std::cout << "pass here"; CHECK(it != global_type_var_map_.end()) << "Cannot find global type var " << name << " in the Module"; return (*it).second; @@ -268,14 +267,14 @@ void ModuleNode::Import(const std::string& path) { } } -void ModuleNode::ImportStd(const std::string& path) { +void ModuleNode::ImportFromStd(const std::string& path) { auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path"); CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; std::string std_path = (*f)(); return this->Import(std_path + "/" + path); } -Module FromText(std::string source, const std::string& source_name) { +Module FromText(const std::string& source, const std::string& source_name) { auto* f = tvm::runtime::Registry::Get("relay.fromtext"); CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; Module mod = (*f)(source, source_name); @@ -366,9 +365,9 @@ TVM_REGISTER_API("relay._module.Module_Import") mod->Import(path); }); -TVM_REGISTER_API("relay._module.Module_ImportStd") +TVM_REGISTER_API("relay._module.Module_ImportFromStd") .set_body_typed([](Module mod, std::string path) { - mod->ImportStd(path); + mod->ImportFromStd(path); });; TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) From 8351bf608d44a09e8ece3aa1fbfdff423dc60946 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 10 Sep 2019 15:35:16 -0500 Subject: [PATCH 10/14] Fix lint --- include/tvm/relay/module.h | 1 + src/relay/ir/module.cc | 3 ++- src/relay/pass/type_solver.cc | 14 +++++++------- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 7d3323ce99cd..ee9b4873d28a 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -33,6 +33,7 @@ #include #include #include +#include namespace tvm { namespace relay { diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index e842326b03ac..2601f355d03e 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -27,6 +27,7 @@ #include #include #include +#include namespace tvm { namespace relay { @@ -166,7 +167,7 @@ void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) { this->type_definitions.Set(var, type); // set global type var map CHECK(!global_type_var_map_.count(var->var->name_hint)) - << "Duplicate global type definition name " << var->var->name_hint; + << "Duplicate global type definition name " << var->var->name_hint; global_type_var_map_.Set(var->var->name_hint, var); RegisterConstructors(var, type); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index dd38f97555ec..004aa5f19dff 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -516,13 +516,13 @@ class TypeSolver::Merger : public TypeFunctor { }; // constructor -TypeSolver::TypeSolver(const GlobalVar& current_func, const Module& module, ErrorReporter* err_reporter) - : reporter_(make_node(this)), - current_func(current_func), - err_reporter_(err_reporter), - module_(module) { - CHECK(module_.defined()) << - "internal error: module must be defined"; +TypeSolver::TypeSolver(const GlobalVar& current_func, const Module& module, + ErrorReporter* err_reporter) + : reporter_(make_node(this)), + current_func(current_func), + err_reporter_(err_reporter), + module_(module) { + CHECK(module_.defined()) << "internal error: module must be defined"; } // destructor From 3a01c43d9322b41aba7e33f9aec29f58e48c4030 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 10 Sep 2019 15:57:45 -0500 Subject: [PATCH 11/14] Fix lint --- python/tvm/relay/module.py | 2 +- python/tvm/relay/prelude.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/module.py b/python/tvm/relay/module.py index 14dec664c965..57980dd09cf2 100644 --- a/python/tvm/relay/module.py +++ b/python/tvm/relay/module.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import """A global module storing everything needed to interpret or compile a Relay program.""" +import os from .base import register_relay_node, RelayNode from .. import register_func from .._ffi import base as _base @@ -23,7 +24,6 @@ from . import _module from . import expr as _expr from . import ty as _ty -import os __STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std") diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index c8935e9282b7..d05b669ee7f1 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -16,15 +16,12 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """A prelude containing useful global functions and ADT definitions.""" -import os from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const from .op.tensor import add, subtract, equal from .adt import Constructor, TypeData, Clause, Match from .adt import PatternConstructor, PatternVar, PatternWildcard, PatternTuple -from .parser import fromtext - -from .module import Module, __STD_PATH__ +from .module import Module class Prelude: """Contains standard definitions.""" From 65030b74475c1bc24eb678e4ddb7bbce7da758c7 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 11 Sep 2019 13:56:10 -0500 Subject: [PATCH 12/14] Fix test failure --- src/relay/pass/type_solver.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 004aa5f19dff..82844901913e 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -646,7 +646,8 @@ TVM_REGISTER_API("relay._analysis._test_type_solver") using runtime::PackedFunc; using runtime::TypedPackedFunc; ErrorReporter *err_reporter = new ErrorReporter(); - auto solver = std::make_shared(GlobalVarNode::make("test"), Module(), err_reporter); + auto module = ModuleNode::make({}, {}); + auto solver = std::make_shared(GlobalVarNode::make("test"), module, err_reporter); auto mod = [solver, err_reporter](std::string name) -> PackedFunc { if (name == "Solve") { From f008410e7207a84fd9559562c516275982563024 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 11 Sep 2019 14:11:49 -0500 Subject: [PATCH 13/14] Add type solver --- src/relay/pass/type_solver.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 82844901913e..5f38e0915567 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -647,18 +647,20 @@ TVM_REGISTER_API("relay._analysis._test_type_solver") using runtime::TypedPackedFunc; ErrorReporter *err_reporter = new ErrorReporter(); auto module = ModuleNode::make({}, {}); - auto solver = std::make_shared(GlobalVarNode::make("test"), module, err_reporter); + auto dummy_fn_name = GlobalVarNode::make("test"); + module->Add(dummy_fn_name, FunctionNode::make({}, TupleNode::make({}), Type(), {}, {})); + auto solver = std::make_shared(dummy_fn_name, module, err_reporter); - auto mod = [solver, err_reporter](std::string name) -> PackedFunc { + auto mod = [module, solver, err_reporter](std::string name) -> PackedFunc { if (name == "Solve") { return TypedPackedFunc([solver]() { return solver->Solve(); }); } else if (name == "Unify") { - return TypedPackedFunc([solver, err_reporter](Type lhs, Type rhs) { + return TypedPackedFunc([module, solver, err_reporter](Type lhs, Type rhs) { auto res = solver->Unify(lhs, rhs, lhs); if (err_reporter->AnyErrors()) { - err_reporter->RenderErrors(ModuleNode::make({}, {}), true); + err_reporter->RenderErrors(module, true); } return res; }); From a5233176a0edee966d02a39f2747db8e4a9e4cb1 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 11 Sep 2019 15:05:09 -0500 Subject: [PATCH 14/14] Fix lint --- src/relay/pass/type_solver.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 5f38e0915567..87aeae64745e 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -657,7 +657,8 @@ TVM_REGISTER_API("relay._analysis._test_type_solver") return solver->Solve(); }); } else if (name == "Unify") { - return TypedPackedFunc([module, solver, err_reporter](Type lhs, Type rhs) { + return TypedPackedFunc( + [module, solver, err_reporter](Type lhs, Type rhs) { auto res = solver->Unify(lhs, rhs, lhs); if (err_reporter->AnyErrors()) { err_reporter->RenderErrors(module, true);