2626#include < tvm/relay/analysis.h>
2727#include < tvm/relay/transform.h>
2828#include < sstream>
29+ #include < fstream>
30+ #include < unordered_set>
2931
3032namespace tvm {
3133namespace relay {
@@ -38,6 +40,9 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
3840 auto n = make_node<ModuleNode>();
3941 n->functions = std::move (global_funcs);
4042 n->type_definitions = std::move (global_type_defs);
43+ n->global_type_var_map_ = {};
44+ n->global_var_map_ = {};
45+ n->constructor_tag_map_ = {};
4146
4247 for (const auto & kv : n->functions ) {
4348 // set global var map
@@ -85,6 +90,7 @@ void ModuleNode::AddUnchecked(const GlobalVar& var,
8590}
8691
8792GlobalTypeVar ModuleNode::GetGlobalTypeVar (const std::string& name) const {
93+ CHECK (global_type_var_map_.defined ());
8894 auto it = global_type_var_map_.find (name);
8995 CHECK (it != global_type_var_map_.end ())
9096 << " Cannot find global type var " << name << " in the Module" ;
@@ -162,6 +168,7 @@ void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) {
162168 // set global type var map
163169 CHECK (!global_type_var_map_.count (var->var ->name_hint ))
164170 << " Duplicate global type definition name " << var->var ->name_hint ;
171+
165172 global_type_var_map_.Set (var->var ->name_hint , var);
166173 RegisterConstructors (var, type);
167174
@@ -241,6 +248,40 @@ Module ModuleNode::FromExpr(
241248 return mod;
242249}
243250
251+ void ModuleNode::Import (const std::string& path) {
252+ LOG (INFO) << " Importing: " << path;
253+ if (this ->import_set_ .count (path) == 0 ) {
254+ this ->import_set_ .insert (path);
255+ std::fstream src_file (path, std::fstream::in);
256+ std::string file_contents {
257+ std::istreambuf_iterator<char >(src_file),
258+ std::istreambuf_iterator<char >() };
259+ auto mod_to_import = FromText (file_contents, path);
260+
261+ for (auto func : mod_to_import->functions ) {
262+ this ->Add (func.first , func.second , false );
263+ }
264+
265+ for (auto type : mod_to_import->type_definitions ) {
266+ this ->AddDef (type.first , type.second );
267+ }
268+ }
269+ }
270+
271+ void ModuleNode::ImportFromStd (const std::string& path) {
272+ auto * f = tvm::runtime::Registry::Get (" tvm.relay.std_path" );
273+ CHECK (f != nullptr ) << " The Relay std_path is not set, please register tvm.relay.std_path." ;
274+ std::string std_path = (*f)();
275+ return this ->Import (std_path + " /" + path);
276+ }
277+
278+ Module FromText (const std::string& source, const std::string& source_name) {
279+ auto * f = tvm::runtime::Registry::Get (" relay.fromtext" );
280+ CHECK (f != nullptr ) << " The Relay std_path is not set, please register tvm.relay.std_path." ;
281+ Module mod = (*f)(source, source_name);
282+ return mod;
283+ }
284+
244285TVM_REGISTER_NODE_TYPE (ModuleNode);
245286
246287TVM_REGISTER_API (" relay._make.Module" )
@@ -320,6 +361,16 @@ TVM_REGISTER_API("relay._module.Module_Update")
320361 mod->Update (from);
321362});
322363
364+ TVM_REGISTER_API (" relay._module.Module_Import" )
365+ .set_body_typed<void (Module, std::string)>([](Module mod, std::string path) {
366+ mod->Import (path);
367+ });
368+
369+ TVM_REGISTER_API (" relay._module.Module_ImportFromStd" )
370+ .set_body_typed<void (Module, std::string)>([](Module mod, std::string path) {
371+ mod->ImportFromStd (path);
372+ });;
373+
323374TVM_STATIC_IR_FUNCTOR_REGISTER (IRPrinter, vtable)
324375.set_dispatch<ModuleNode>(
325376 [](const ModuleNode *node, tvm::IRPrinter *p) {
0 commit comments