Skip to content

Commit 6b7d542

Browse files
committed
Support standardize runtime module
1 parent 9384353 commit 6b7d542

File tree

6 files changed

+419
-53
lines changed

6 files changed

+419
-53
lines changed

python/tvm/_ffi/function.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ def __init__(self, handle):
8282
def __del__(self):
8383
check_call(_LIB.TVMModFree(self.handle))
8484

85+
def __hash__(self):
86+
return ctypes.cast(self.handle, ctypes.c_void_p).value
87+
8588
@property
8689
def entry_func(self):
8790
"""Get the entry function

python/tvm/module.py

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -118,31 +118,37 @@ def export_library(self,
118118
self.save(file_name)
119119
return
120120

121-
if not (self.type_key == "llvm" or self.type_key == "c"):
122-
raise ValueError("Module[%s]: Only llvm and c support export shared" % self.type_key)
121+
modules = self._collect_dso_modules()
123122
temp = _util.tempdir()
124-
if fcompile is not None and hasattr(fcompile, "object_format"):
125-
object_format = fcompile.object_format
126-
else:
127-
if self.type_key == "llvm":
128-
object_format = "o"
123+
files = []
124+
is_system_lib = False
125+
updated_kwargs = False
126+
for module in modules:
127+
if fcompile is not None and hasattr(fcompile, "object_format"):
128+
object_format = fcompile.object_format
129129
else:
130-
assert self.type_key == "c"
131-
object_format = "cc"
132-
path_obj = temp.relpath("lib." + object_format)
133-
self.save(path_obj)
134-
files = [path_obj]
135-
is_system_lib = self.type_key == "llvm" and self.get_function("__tvm_is_system_module")()
136-
has_imported_c_file = False
130+
if module.type_key == "llvm":
131+
object_format = "o"
132+
else:
133+
assert module.type_key == "c"
134+
object_format = "cc"
135+
path_obj = temp.relpath("lib" + str(hash(module)) + "." + object_format)
136+
module.save(path_obj)
137+
files.append(path_obj)
138+
is_system_lib = (module.type_key == "llvm" and
139+
module.get_function("__tvm_is_system_module")())
140+
if module.type_key == "c":
141+
if updated_kwargs:
142+
continue
143+
options = []
144+
if "options" in kwargs:
145+
opts = kwargs["options"]
146+
options = opts if isinstance(opts, (list, tuple)) else [opts]
147+
opts = options + ["-I" + path for path in find_include_path()]
148+
kwargs.update({'options': opts})
149+
updated_kwargs = True
150+
137151
if self.imported_modules:
138-
for i, m in enumerate(self.imported_modules):
139-
if m.type_key == "c":
140-
has_imported_c_file = True
141-
c_file_name = "tmp_" + str(i) + ".cc"
142-
path_cc = temp.relpath(c_file_name)
143-
with open(path_cc, "w") as f:
144-
f.write(m.get_source())
145-
files.append(path_cc)
146152
path_cc = temp.relpath("devc.cc")
147153
with open(path_cc, "w") as f:
148154
f.write(_PackImportsToC(self, is_system_lib))
@@ -152,13 +158,7 @@ def export_library(self,
152158
fcompile = _tar.tar
153159
else:
154160
fcompile = _cc.create_shared
155-
if self.type_key == "c" or has_imported_c_file:
156-
options = []
157-
if "options" in kwargs:
158-
opts = kwargs["options"]
159-
options = opts if isinstance(opts, (list, tuple)) else [opts]
160-
opts = options + ["-I" + path for path in find_include_path()]
161-
kwargs.update({'options': opts})
161+
162162
fcompile(file_name, files, **kwargs)
163163

164164
def time_evaluator(self, func_name, ctx, number=10, repeat=1, min_repeat_ms=0):
@@ -219,6 +219,25 @@ def evaluator(*args):
219219
except NameError:
220220
raise NameError("time_evaluate is only supported when RPC is enabled")
221221

222+
def _collect_dso_modules(self):
223+
"""Helper function to collect dso modules, then return it."""
224+
visited, stack, dso_modules = set(), [], []
225+
# append root module
226+
visited.add(self)
227+
stack.append(self)
228+
while stack:
229+
module = stack.pop()
230+
if module._dso_exportable():
231+
dso_modules.append(module)
232+
for m in module.imported_modules:
233+
if m not in visited:
234+
visited.add(m)
235+
stack.append(m)
236+
return dso_modules
237+
238+
def _dso_exportable(self):
239+
return self.type_key == "llvm" or self.type_key == "c"
240+
222241

223242
def system_lib():
224243
"""Get system-wide library module singleton.

src/codegen/codegen.cc

Lines changed: 105 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@
2828
#include <tvm/build_module.h>
2929
#include <dmlc/memory_io.h>
3030
#include <sstream>
31-
#include <iostream>
31+
#include <vector>
32+
#include <cstdint>
33+
#include <unordered_set>
34+
#include <cstring>
3235

3336
namespace tvm {
3437
namespace codegen {
@@ -58,20 +61,111 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
5861
return m;
5962
}
6063

64+
/*! \brief Helper class to serialize module */
65+
class ModuleSerializer {
66+
public:
67+
explicit ModuleSerializer(runtime::Module mod) : mod_(mod) {
68+
Init();
69+
}
70+
71+
void SerializeModule(dmlc::Stream* stream) {
72+
// Only have one DSO module and it is in the root, then
73+
// we will not produce import_tree_.
74+
bool has_import_tree = true;
75+
if (DSOExportable(mod_.operator->()) && mod_->imports().empty()) {
76+
has_import_tree = false;
77+
}
78+
uint64_t sz = 0;
79+
if (has_import_tree) {
80+
// we will append one key for _import_tree
81+
// The layout is the same as before: binary_size, key, logic, key, logic...
82+
sz = mod_vec_.size() + 1;
83+
} else {
84+
// Keep the old behaviour
85+
sz = mod_->imports().size();
86+
}
87+
stream->Write(sz);
88+
89+
for (auto m : mod_vec_) {
90+
std::string mod_type_key = m->type_key();
91+
if (!DSOExportable(m)) {
92+
stream->Write(mod_type_key);
93+
m->SaveToBinary(stream);
94+
} else if (has_import_tree) {
95+
mod_type_key = "_lib";
96+
stream->Write(mod_type_key);
97+
}
98+
}
99+
100+
// Write _import_tree key if we have
101+
if (has_import_tree) {
102+
std::string import_key = "_import_tree";
103+
stream->Write(import_key);
104+
stream->Write(import_tree_row_ptr_);
105+
stream->Write(import_tree_child_indices_);
106+
}
107+
}
108+
109+
private:
110+
void Init() {
111+
CreateModuleIndex();
112+
CreateImportTree();
113+
}
114+
115+
// invariance: root module is always at location 0.
116+
// The module order is collected via DFS
117+
void CreateModuleIndex() {
118+
std::unordered_set<const runtime::ModuleNode*> visited {mod_.operator->()};
119+
std::vector<runtime::ModuleNode*> stack {mod_.operator->()};
120+
uint64_t module_index = 0;
121+
122+
while (!stack.empty()) {
123+
runtime::ModuleNode* n = stack.back();
124+
stack.pop_back();
125+
mod2index_[n] = module_index++;
126+
mod_vec_.emplace_back(n);
127+
for (runtime::Module m : n->imports()) {
128+
runtime::ModuleNode* next = m.operator->();
129+
if (visited.count(next) == 0) {
130+
visited.insert(next);
131+
stack.push_back(next);
132+
}
133+
}
134+
}
135+
}
136+
137+
void CreateImportTree() {
138+
for (auto m : mod_vec_) {
139+
for (runtime::Module im : m->imports()) {
140+
uint64_t mod_index = mod2index_[im.operator->()];
141+
import_tree_child_indices_.push_back(mod_index);
142+
}
143+
import_tree_row_ptr_.push_back(import_tree_child_indices_.size());
144+
}
145+
}
146+
147+
bool DSOExportable(const runtime::ModuleNode* mod) {
148+
return !std::strcmp(mod->type_key(), "llvm") ||
149+
!std::strcmp(mod->type_key(), "c");
150+
}
151+
152+
runtime::Module mod_;
153+
// construct module to index
154+
std::unordered_map<runtime::ModuleNode*, size_t> mod2index_;
155+
// index -> module
156+
std::vector<runtime::ModuleNode*> mod_vec_;
157+
std::vector<uint64_t> import_tree_row_ptr_ {0};
158+
std::vector<uint64_t> import_tree_child_indices_;
159+
};
160+
61161
std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
62162
std::string bin;
63163
dmlc::MemoryStringStream ms(&bin);
64164
dmlc::Stream* stream = &ms;
65-
uint64_t sz = static_cast<uint64_t>(mod->imports().size());
66-
stream->Write(sz);
67-
for (runtime::Module im : mod->imports()) {
68-
CHECK_EQ(im->imports().size(), 0U)
69-
<< "Only support simply one-level hierarchy";
70-
std::string tkey = im->type_key();
71-
stream->Write(tkey);
72-
if (tkey == "c") continue;
73-
im->SaveToBinary(stream);
74-
}
165+
166+
ModuleSerializer module_serializer(mod);
167+
module_serializer.SerializeModule(stream);
168+
75169
// translate to C program
76170
std::ostringstream os;
77171
os << "#ifdef _WIN32\n"

src/runtime/library_module.cc

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <tvm/runtime/registry.h>
2929
#include <string>
3030
#include <vector>
31+
#include <cstdint>
3132
#include "library_module.h"
3233

3334
namespace tvm {
@@ -108,9 +109,11 @@ void InitContextFunctions(std::function<void*(const char*)> fgetsymbol) {
108109
/*!
109110
* \brief Load and append module blob to module list
110111
* \param mblob The module blob.
111-
* \param module_list The module list to append to
112+
* \param lib The library.
113+
*
114+
* \return Root Module.
112115
*/
113-
void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
116+
runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib) {
114117
#ifndef _LIBCPP_SGX_CONFIG
115118
CHECK(mblob != nullptr);
116119
uint64_t nbytes = 0;
@@ -123,20 +126,56 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
123126
dmlc::Stream* stream = &fs;
124127
uint64_t size;
125128
CHECK(stream->Read(&size));
129+
std::vector<Module> modules;
130+
std::vector<uint64_t> import_tree_row_ptr;
131+
std::vector<uint64_t> import_tree_child_indices;
126132
for (uint64_t i = 0; i < size; ++i) {
127133
std::string tkey;
128134
CHECK(stream->Read(&tkey));
129-
if (tkey == "c") continue;
130-
std::string fkey = "module.loadbinary_" + tkey;
131-
const PackedFunc* f = Registry::Get(fkey);
132-
CHECK(f != nullptr)
135+
// Currently, _lib is for DSOModule, but we
136+
// don't have loadbinary function for it currently
137+
if (tkey == "_lib") {
138+
auto dso_module = Module(make_object<LibraryModuleNode>(lib));
139+
modules.emplace_back(dso_module);
140+
} else if (tkey == "_import_tree") {
141+
CHECK(stream->Read(&import_tree_row_ptr));
142+
CHECK(stream->Read(&import_tree_child_indices));
143+
} else {
144+
std::string fkey = "module.loadbinary_" + tkey;
145+
const PackedFunc* f = Registry::Get(fkey);
146+
CHECK(f != nullptr)
133147
<< "Loader of " << tkey << "("
134148
<< fkey << ") is not presented.";
135-
Module m = (*f)(static_cast<void*>(stream));
136-
mlist->push_back(m);
149+
Module m = (*f)(static_cast<void*>(stream));
150+
modules.emplace_back(m);
151+
}
137152
}
153+
// if we are using old dll, we don't have import tree
154+
// so that we can't reconstruct module relationship using import tree
155+
if (import_tree_row_ptr.empty()) {
156+
auto n = make_object<LibraryModuleNode>(lib);
157+
auto module_import_addr = ModuleInternal::GetImportsAddr(n.operator->());
158+
for (const auto& m : modules) {
159+
module_import_addr->emplace_back(m);
160+
}
161+
return Module(n);
162+
} else {
163+
for (size_t i = 0; i < modules.size(); ++i) {
164+
for (size_t j = import_tree_row_ptr[i]; j < import_tree_row_ptr[i + 1]; ++j) {
165+
auto module_import_addr = ModuleInternal::GetImportsAddr(modules[i].operator->());
166+
auto child_index = import_tree_child_indices[j];
167+
CHECK(child_index < modules.size());
168+
module_import_addr->emplace_back(modules[child_index]);
169+
}
170+
}
171+
}
172+
CHECK(!modules.empty());
173+
// invariance: root module is always at location 0.
174+
// The module order is collected via DFS
175+
return modules[0];
138176
#else
139177
LOG(FATAL) << "SGX does not support ImportModuleBlob";
178+
return Module();
140179
#endif
141180
}
142181

@@ -149,17 +188,20 @@ Module CreateModuleFromLibrary(ObjectPtr<Library> lib) {
149188
const char* dev_mblob =
150189
reinterpret_cast<const char*>(
151190
lib->GetSymbol(runtime::symbol::tvm_dev_mblob));
191+
Module root_mod;
152192
if (dev_mblob != nullptr) {
153-
ImportModuleBlob(
154-
dev_mblob, ModuleInternal::GetImportsAddr(n.operator->()));
193+
root_mod = ProcessModuleBlob(dev_mblob, lib);
194+
} else {
195+
// Only have one single DSO Module
196+
root_mod = Module(n);
155197
}
156198

157-
Module root_mod = Module(n);
158-
// allow lookup of symbol from root(so all symbols are visible).
199+
// allow lookup of symbol from root (so all symbols are visible).
159200
if (auto *ctx_addr =
160201
reinterpret_cast<void**>(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) {
161202
*ctx_addr = root_mod.operator->();
162203
}
204+
163205
return root_mod;
164206
}
165207
} // namespace runtime

src/runtime/module.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
115115
if (it != import_cache_.end()) return it->second.get();
116116
PackedFunc pf;
117117
for (Module& m : this->imports_) {
118-
pf = m.GetFunction(name, false);
118+
pf = m.GetFunction(name, true);
119119
if (pf != nullptr) break;
120120
}
121121
if (pf == nullptr) {

0 commit comments

Comments
 (0)