@@ -71,6 +71,16 @@ class ROCMModuleNode : public runtime::ModuleNode {
7171 const std::shared_ptr<ModuleNode>& sptr_to_self) final ;
7272
7373
74+ void SaveToFile (const std::string& file_name,
75+ const std::string& format) final {
76+ std::string fmt = GetFileFormat (file_name, format);
77+ std::string meta_file = GetMetaFilePath (file_name);
78+ // note: llvm and asm formats are not laodable, so we don't save them
79+ CHECK_EQ (fmt, fmt_) << " Can only save to format=" << fmt_;
80+ SaveMetaDataToFile (meta_file, fmap_);
81+ SaveBinaryToFile (file_name, data_);
82+ }
83+
7484 void SaveToBinary (dmlc::Stream* stream) final {
7585 stream->Write (fmt_);
7686 stream->Write (fmap_);
@@ -230,6 +240,17 @@ Module ROCMModuleCreate(
230240 return Module (n);
231241}
232242
243+ Module ROCMModuleLoadFile (const std::string& file_name,
244+ const std::string& format) {
245+ std::string data;
246+ std::unordered_map<std::string, FunctionInfo> fmap;
247+ std::string fmt = GetFileFormat (file_name, format);
248+ std::string meta_file = GetMetaFilePath (file_name);
249+ LoadBinaryFromFile (file_name, &data);
250+ LoadMetaDataFromFile (meta_file, &fmap);
251+ return ROCMModuleCreate (data, fmt, fmap, std::string (), std::string ());
252+ }
253+
233254Module ROCMModuleLoadBinary (void * strm) {
234255 dmlc::Stream* stream = static_cast <dmlc::Stream*>(strm);
235256 std::string data;
@@ -248,5 +269,12 @@ TVM_REGISTER_GLOBAL("module.loadbinary_hsaco")
248269
249270TVM_REGISTER_GLOBAL (" module.loadbinary_hip" )
250271.set_body_typed(ROCMModuleLoadBinary);
272+
273+
274+ TVM_REGISTER_GLOBAL (" module.loadfile_hsaco" )
275+ .set_body_typed(ROCMModuleLoadFile);
276+
277+ TVM_REGISTER_GLOBAL (" module.loadfile_hip" )
278+ .set_body_typed(ROCMModuleLoadFile);
251279} // namespace runtime
252280} // namespace tvm
0 commit comments