Skip to content

Commit 91b9d8b

Browse files
jroeschwweic
authored andcommitted
[Relay][Module] Refactor the way we interface between different modules of Relay. (apache#3906)
* Module refactor * Add load module * Add support for idempotent import * Tweak load paths * Move path around * Expose C++ import functions in Python * Fix import * Add doc string * Fix * Fix lint * Fix lint * Fix test failure * Add type solver * Fix lint
1 parent d91e865 commit 91b9d8b

File tree

12 files changed

+142
-45
lines changed

12 files changed

+142
-45
lines changed

include/tvm/relay/expr.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,7 @@ std::string PrettyPrint(const NodeRef& node);
575575
std::string AsText(const NodeRef& node,
576576
bool show_meta_data = true,
577577
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
578+
578579
} // namespace relay
579580
} // namespace tvm
580581
#endif // TVM_RELAY_EXPR_H_

include/tvm/relay/module.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include <string>
3434
#include <vector>
3535
#include <unordered_map>
36+
#include <unordered_set>
3637

3738
namespace tvm {
3839
namespace relay {
@@ -185,6 +186,23 @@ class ModuleNode : public RelayNode {
185186
*/
186187
TVM_DLL void Update(const Module& other);
187188

189+
/*!
190+
* \brief Import Relay code from the file at path.
191+
* \param path The path of the Relay code to import.
192+
*
193+
* \note The path resolution behavior is standard,
194+
* if abosolute will be the absolute file, if
195+
* relative it will be resovled against the current
196+
* working directory.
197+
*/
198+
TVM_DLL void Import(const std::string& path);
199+
200+
/*!
201+
* \brief Import Relay code from the file at path, relative to the standard library.
202+
* \param path The path of the Relay code to import.
203+
*/
204+
TVM_DLL void ImportFromStd(const std::string& path);
205+
188206
/*! \brief Construct a module from a standalone expression.
189207
*
190208
* Allows one to optionally pass a global function map and
@@ -222,6 +240,11 @@ class ModuleNode : public RelayNode {
222240
* for convenient access
223241
*/
224242
std::unordered_map<int32_t, Constructor> constructor_tag_map_;
243+
244+
/*! \brief The files previously imported, required to ensure
245+
importing is idempotent for each module.
246+
*/
247+
std::unordered_set<std::string> import_set_;
225248
};
226249

227250
struct Module : public NodeRef {
@@ -235,6 +258,12 @@ struct Module : public NodeRef {
235258
using ContainerType = ModuleNode;
236259
};
237260

261+
/*! \brief Parse Relay source into a module.
262+
* \param source A string of Relay source code.
263+
* \param source_name The name of the source file.
264+
* \return A Relay module.
265+
*/
266+
Module FromText(const std::string& source, const std::string& source_name);
238267

239268
} // namespace relay
240269
} // namespace tvm

include/tvm/relay/type.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,12 @@ class TypeReporterNode : public Node {
410410
*/
411411
TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0;
412412

413+
/*!
414+
* \brief Retrieve the current global module.
415+
* \return The global module.
416+
*/
417+
TVM_DLL virtual Module GetModule() = 0;
418+
413419
// solver is not serializable.
414420
void VisitAttrs(tvm::AttrVisitor* v) final {}
415421

python/tvm/relay/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
1818
"""The Relay IR namespace containing the IR definition and compiler."""
1919
from __future__ import absolute_import
20+
import os
2021
from sys import setrecursionlimit
2122
from ..api import register_func
2223
from . import base

python/tvm/relay/module.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,22 @@
1616
# under the License.
1717
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
1818
"""A global module storing everything needed to interpret or compile a Relay program."""
19+
import os
1920
from .base import register_relay_node, RelayNode
21+
from .. import register_func
2022
from .._ffi import base as _base
2123
from . import _make
2224
from . import _module
2325
from . import expr as _expr
2426
from . import ty as _ty
2527

28+
__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std")
29+
30+
@register_func("tvm.relay.std_path")
31+
def _std_path():
32+
global __STD_PATH__
33+
return __STD_PATH__
34+
2635
@register_relay_node
2736
class Module(RelayNode):
2837
"""The global Relay module containing collection of functions.
@@ -202,3 +211,9 @@ def from_expr(expr, functions=None, type_defs=None):
202211
funcs = functions if functions is not None else {}
203212
defs = type_defs if type_defs is not None else {}
204213
return _module.Module_FromExpr(expr, funcs, defs)
214+
215+
def _import(self, file_to_import):
216+
return _module.Module_Import(self, file_to_import)
217+
218+
def import_from_std(self, file_to_import):
219+
return _module.Module_ImportFromStd(self, file_to_import)

python/tvm/relay/prelude.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,11 @@
1616
# under the License.
1717
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
1818
"""A prelude containing useful global functions and ADT definitions."""
19-
import os
2019
from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type
2120
from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const
2221
from .op.tensor import add, subtract, equal
2322
from .adt import Constructor, TypeData, Clause, Match
2423
from .adt import PatternConstructor, PatternVar, PatternWildcard, PatternTuple
25-
from .parser import fromtext
26-
__PRELUDE_PATH__ = os.path.dirname(os.path.realpath(__file__))
2724
from .module import Module
2825

2926
class Prelude:
@@ -479,12 +476,10 @@ def load_prelude(self):
479476
Parses the portions of the Prelude written in Relay's text format and adds
480477
them to the module.
481478
"""
482-
prelude_file = os.path.join(__PRELUDE_PATH__, "prelude.rly")
483-
with open(prelude_file) as prelude:
484-
prelude = fromtext(prelude.read())
485-
self.mod.update(prelude)
486-
self.id = self.mod.get_global_var("id")
487-
self.compose = self.mod.get_global_var("compose")
479+
# TODO(@jroesch): we should remove this helper when we port over prelude
480+
self.mod.import_from_std("prelude.rly")
481+
self.id = self.mod.get_global_var("id")
482+
self.compose = self.mod.get_global_var("compose")
488483

489484

490485
def __init__(self, mod=None):
File renamed without changes.

src/relay/ir/expr_functor.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,6 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
444444
}
445445
}
446446

447-
448447
TVM_REGISTER_API("relay._expr.Bind")
449448
.set_body([](TVMArgs args, TVMRetValue* ret) {
450449
NodeRef input = args[0];

src/relay/ir/module.cc

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
#include <tvm/relay/analysis.h>
2727
#include <tvm/relay/transform.h>
2828
#include <sstream>
29+
#include <fstream>
30+
#include <unordered_set>
2931

3032
namespace tvm {
3133
namespace relay {
@@ -38,6 +40,9 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
3840
auto n = make_node<ModuleNode>();
3941
n->functions = std::move(global_funcs);
4042
n->type_definitions = std::move(global_type_defs);
43+
n->global_type_var_map_ = {};
44+
n->global_var_map_ = {};
45+
n->constructor_tag_map_ = {};
4146

4247
for (const auto& kv : n->functions) {
4348
// set global var map
@@ -85,6 +90,7 @@ void ModuleNode::AddUnchecked(const GlobalVar& var,
8590
}
8691

8792
GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
93+
CHECK(global_type_var_map_.defined());
8894
auto it = global_type_var_map_.find(name);
8995
CHECK(it != global_type_var_map_.end())
9096
<< "Cannot find global type var " << name << " in the Module";
@@ -162,6 +168,7 @@ void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) {
162168
// set global type var map
163169
CHECK(!global_type_var_map_.count(var->var->name_hint))
164170
<< "Duplicate global type definition name " << var->var->name_hint;
171+
165172
global_type_var_map_.Set(var->var->name_hint, var);
166173
RegisterConstructors(var, type);
167174

@@ -241,6 +248,40 @@ Module ModuleNode::FromExpr(
241248
return mod;
242249
}
243250

251+
void ModuleNode::Import(const std::string& path) {
252+
LOG(INFO) << "Importing: " << path;
253+
if (this->import_set_.count(path) == 0) {
254+
this->import_set_.insert(path);
255+
std::fstream src_file(path, std::fstream::in);
256+
std::string file_contents {
257+
std::istreambuf_iterator<char>(src_file),
258+
std::istreambuf_iterator<char>() };
259+
auto mod_to_import = FromText(file_contents, path);
260+
261+
for (auto func : mod_to_import->functions) {
262+
this->Add(func.first, func.second, false);
263+
}
264+
265+
for (auto type : mod_to_import->type_definitions) {
266+
this->AddDef(type.first, type.second);
267+
}
268+
}
269+
}
270+
271+
void ModuleNode::ImportFromStd(const std::string& path) {
272+
auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path");
273+
CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
274+
std::string std_path = (*f)();
275+
return this->Import(std_path + "/" + path);
276+
}
277+
278+
Module FromText(const std::string& source, const std::string& source_name) {
279+
auto* f = tvm::runtime::Registry::Get("relay.fromtext");
280+
CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
281+
Module mod = (*f)(source, source_name);
282+
return mod;
283+
}
284+
244285
TVM_REGISTER_NODE_TYPE(ModuleNode);
245286

246287
TVM_REGISTER_API("relay._make.Module")
@@ -320,6 +361,16 @@ TVM_REGISTER_API("relay._module.Module_Update")
320361
mod->Update(from);
321362
});
322363

364+
TVM_REGISTER_API("relay._module.Module_Import")
365+
.set_body_typed<void(Module, std::string)>([](Module mod, std::string path) {
366+
mod->Import(path);
367+
});
368+
369+
TVM_REGISTER_API("relay._module.Module_ImportFromStd")
370+
.set_body_typed<void(Module, std::string)>([](Module mod, std::string path) {
371+
mod->ImportFromStd(path);
372+
});;
373+
323374
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
324375
.set_dispatch<ModuleNode>(
325376
[](const ModuleNode *node, tvm::IRPrinter *p) {

src/relay/pass/type_infer.cc

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
108108

109109
explicit TypeInferencer(Module mod, GlobalVar current_func)
110110
: mod_(mod), current_func_(current_func),
111-
err_reporter(), solver_(current_func, &this->err_reporter) {
111+
err_reporter(), solver_(current_func, mod, &this->err_reporter) {
112+
CHECK(mod.defined()) << "internal error: Module must be set in the type inferencer";
112113
}
113114

114115
// inference the type of expr.
@@ -790,36 +791,22 @@ void EnsureCheckedType(const Expr& e) {
790791
AllCheckTypePopulated().VisitExpr(e);
791792
}
792793

793-
Expr InferType(const Expr& expr, const Module& mod_ref) {
794-
if (!mod_ref.defined()) {
795-
Module mod = ModuleNode::FromExpr(expr);
796-
// NB(@jroesch): By adding the expression to the module we will
797-
// type check it anyway; afterwards we can just recover type
798-
// from the type-checked function to avoid doing unnecessary work.
799-
800-
Function func = mod->Lookup("main");
801-
802-
// FromExpr wraps a naked expression as a function, we will unbox
803-
// it here.
804-
if (expr.as<FunctionNode>()) {
805-
return std::move(func);
806-
} else {
807-
return func->body;
808-
}
809-
} else {
810-
auto e = TypeInferencer(mod_ref, mod_ref->GetGlobalVar("main")).Infer(expr);
811-
CHECK(WellFormed(e));
812-
auto free_tvars = FreeTypeVars(e, mod_ref);
813-
CHECK(free_tvars.size() == 0)
814-
<< "Found unbound type variables in " << e << ": " << free_tvars;
815-
EnsureCheckedType(e);
816-
return e;
817-
}
794+
Expr InferType(const Expr& expr, const Module& mod) {
795+
auto main = mod->GetGlobalVar("main");
796+
auto inferencer = TypeInferencer(mod, main);
797+
auto e = inferencer.Infer(expr);
798+
CHECK(WellFormed(e));
799+
auto free_tvars = FreeTypeVars(e, mod);
800+
CHECK(free_tvars.size() == 0)
801+
<< "Found unbound type variables in " << e << ": " << free_tvars;
802+
EnsureCheckedType(e);
803+
return e;
818804
}
819805

820806
Function InferType(const Function& func,
821807
const Module& mod,
822808
const GlobalVar& var) {
809+
CHECK(mod.defined()) << "internal error: module must be set for type inference";
823810
Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
824811
func_copy->checked_type_ = func_copy->func_type_annotation();
825812
mod->AddUnchecked(var, func_copy);

0 commit comments

Comments
 (0)