Skip to content

Commit 7c35267

Browse files
authored
[Fix] add TVM_DLL to disco functions (#16258)
1 parent 09acbc8 commit 7c35267

File tree

6 files changed

+17
-16
lines changed

6 files changed

+17
-16
lines changed

include/tvm/runtime/disco/builtin.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,14 @@ void AllGather(NDArray send, NDArray recv);
8989
* \param send The buffer to be broadcasted
9090
* \param recv The buffer receives the broadcasted array
9191
*/
92-
void BroadcastFromWorker0(NDArray send, NDArray recv);
92+
TVM_DLL void BroadcastFromWorker0(NDArray send, NDArray recv);
9393
/*!
9494
* \brief Perform a scatter operation from worker-0, chunking the given buffer into equal parts.
9595
* \param send For worker-0, it must be provided, and otherwise, the buffer must be None.
9696
* The buffer will be divided into equal parts and sent to each worker accordingly.
9797
* \param recv The receiving buffer, which must not be None.
9898
*/
99-
void ScatterFromWorker0(Optional<NDArray> send, NDArray recv);
99+
TVM_DLL void ScatterFromWorker0(Optional<NDArray> send, NDArray recv);
100100
/*!
101101
* \brief Perform a gather operation to worker-0.
102102
* \param send The sending buffer, which must not be None.

include/tvm/runtime/disco/disco_worker.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class DiscoWorker {
6060
/*! \brief Main loop of the worker */
6161
void MainLoop();
6262
/*! \brief Get the worker instance on the current thread */
63-
static DiscoWorker* ThreadLocal();
63+
TVM_DLL static DiscoWorker* ThreadLocal();
6464
/*! \brief Set the specific register to a specific value */
6565
void SetRegister(int reg_id, TVMArgValue value);
6666

include/tvm/runtime/relax_vm/ndarray_cache_support.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ struct NDArrayCacheMetadata {
6363
};
6464

6565
/*! \brief Load a FileRecord into memory */
66-
Array<NDArray> Load(Device device, //
67-
const std::string& path_prefix, //
68-
std::string* raw_data_buffer, //
69-
Optional<NDArray>* staging_buffer = nullptr) const;
66+
TVM_DLL Array<NDArray> Load(Device device, //
67+
const std::string& path_prefix, //
68+
std::string* raw_data_buffer, //
69+
Optional<NDArray>* staging_buffer = nullptr) const;
7070

7171
/*! \brief Relative path to the bin file */
7272
std::string data_path;
@@ -83,7 +83,7 @@ struct NDArrayCacheMetadata {
8383
std::string path;
8484

8585
/*! \brief Load the metadata from a specific directory */
86-
static NDArrayCacheMetadata Load(const std::string& path);
86+
TVM_DLL static NDArrayCacheMetadata Load(const std::string& path);
8787
/*! \brief Load the metadata from a given JSON string */
8888
static NDArrayCacheMetadata LoadFromStr(const std::string& json_str, const std::string& path);
8989
};

src/runtime/disco/builtin.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,11 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
8585

8686
void AllGather(NDArray send, NDArray recv) { GetCCLFunc("allgather")(send, recv); }
8787

88-
void BroadcastFromWorker0(NDArray send, NDArray recv) {
88+
TVM_DLL void BroadcastFromWorker0(NDArray send, NDArray recv) {
8989
GetCCLFunc("broadcast_from_worker0")(send, recv);
9090
}
9191

92-
void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
92+
TVM_DLL void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
9393
GetCCLFunc("scatter_from_worker0")(send, recv);
9494
}
9595

src/runtime/disco/disco_worker.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ struct ThreadLocalDiscoWorker {
3737
}
3838
};
3939

40-
DiscoWorker* DiscoWorker::ThreadLocal() {
40+
TVM_DLL DiscoWorker* DiscoWorker::ThreadLocal() {
4141
DiscoWorker* ret = ThreadLocalDiscoWorker::Get()->worker;
4242
CHECK(ret) << "ValueError: The current thread is not a DiscoWorker thread";
4343
return ret;

src/runtime/relax_vm/ndarray_cache_support.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ NDArrayCacheMetadata NDArrayCacheMetadata::LoadFromStr(const std::string& json_s
123123
return result;
124124
}
125125

126-
NDArrayCacheMetadata NDArrayCacheMetadata::Load(const std::string& path) {
126+
TVM_DLL NDArrayCacheMetadata NDArrayCacheMetadata::Load(const std::string& path) {
127127
picojson::value json_info;
128128
{
129129
std::string json_str;
@@ -183,10 +183,11 @@ NDArray NDArrayCacheMetadata::FileRecord::ParamRecord::Load(
183183
return arr;
184184
}
185185

186-
Array<NDArray> NDArrayCacheMetadata::FileRecord::Load(Device device,
187-
const std::string& path_prefix, //
188-
std::string* raw_data_buffer, //
189-
Optional<NDArray>* staging_buffer) const {
186+
TVM_DLL Array<NDArray> NDArrayCacheMetadata::FileRecord::Load(
187+
Device device,
188+
const std::string& path_prefix, //
189+
std::string* raw_data_buffer, //
190+
Optional<NDArray>* staging_buffer) const {
190191
LoadBinaryFromFile(path_prefix + "/" + this->data_path, raw_data_buffer);
191192
CHECK_EQ(this->format, "raw-shard") << "ValueError: Only `raw-shard` format is supported";
192193
CHECK_EQ(this->nbytes, raw_data_buffer->length())

0 commit comments

Comments
 (0)