Skip to content

Commit d4a5175

Browse files
t-vimasahi
authored andcommitted
ROCm: Add SaveToFile and LoadFile (#3665)
...and add rocm module_save to the tests.
1 parent 0cecd03 commit d4a5175

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

src/runtime/rocm/rocm_module.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,16 @@ class ROCMModuleNode : public runtime::ModuleNode {
7171
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
7272

7373

74+
void SaveToFile(const std::string& file_name,
75+
const std::string& format) final {
76+
std::string fmt = GetFileFormat(file_name, format);
77+
std::string meta_file = GetMetaFilePath(file_name);
78+
// note: llvm and asm formats are not laodable, so we don't save them
79+
CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_;
80+
SaveMetaDataToFile(meta_file, fmap_);
81+
SaveBinaryToFile(file_name, data_);
82+
}
83+
7484
void SaveToBinary(dmlc::Stream* stream) final {
7585
stream->Write(fmt_);
7686
stream->Write(fmap_);
@@ -230,6 +240,17 @@ Module ROCMModuleCreate(
230240
return Module(n);
231241
}
232242

243+
Module ROCMModuleLoadFile(const std::string& file_name,
244+
const std::string& format) {
245+
std::string data;
246+
std::unordered_map<std::string, FunctionInfo> fmap;
247+
std::string fmt = GetFileFormat(file_name, format);
248+
std::string meta_file = GetMetaFilePath(file_name);
249+
LoadBinaryFromFile(file_name, &data);
250+
LoadMetaDataFromFile(meta_file, &fmap);
251+
return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string());
252+
}
253+
233254
Module ROCMModuleLoadBinary(void* strm) {
234255
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
235256
std::string data;
@@ -248,5 +269,12 @@ TVM_REGISTER_GLOBAL("module.loadbinary_hsaco")
248269

249270
TVM_REGISTER_GLOBAL("module.loadbinary_hip")
250271
.set_body_typed(ROCMModuleLoadBinary);
272+
273+
274+
TVM_REGISTER_GLOBAL("module.loadfile_hsaco")
275+
.set_body_typed(ROCMModuleLoadFile);
276+
277+
TVM_REGISTER_GLOBAL("module.loadfile_hip")
278+
.set_body_typed(ROCMModuleLoadFile);
251279
} // namespace runtime
252280
} // namespace tvm

tests/python/unittest/test_codegen_device.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,12 @@ def check_module_save(device, host="stackvm"):
7676
return
7777
if not tvm.module.enabled(host):
7878
return
79-
fmt = "ptx" if device == "cuda" else device
79+
if device == "cuda":
80+
fmt = "ptx"
81+
elif device == "rocm":
82+
fmt = "hsaco"
83+
else:
84+
fmt = device
8085
mhost = tvm.codegen.build_module(fsplits[0], host)
8186
mdev = tvm.codegen.build_module(fsplits[1:], device)
8287
temp = util.tempdir()
@@ -99,8 +104,9 @@ def check_module_save(device, host="stackvm"):
99104
check_module_save("cuda", host="stackvm")
100105
check_target("nvptx", host="llvm")
101106
check_target("vulkan", host="llvm")
102-
check_target("rocm", host="llvm")
103107
check_module_save("vulkan", host="stackvm")
108+
check_target("rocm", host="llvm")
109+
check_module_save("rocm", host="llvm")
104110

105111

106112
if __name__ == "__main__":

0 commit comments

Comments
 (0)