Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions include/tvm/runtime/disco/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,28 +62,28 @@ inline std::string ReduceKind2String(ReduceKind kind) {
* \param device The default device used to initialize the RelaxVM
* \return The RelaxVM as a runtime Module
*/
Module LoadVMModule(std::string path, Device device);
TVM_DLL Module LoadVMModule(std::string path, Device device);
/*!
* \brief Create an uninitialized empty NDArray
* \param shape The shape of the NDArray
* \param dtype The dtype of the NDArray
* \param device The device the NDArray is created on. If None, use the thread local default device
* \return The NDArray created
*/
NDArray DiscoEmptyNDArray(ShapeTuple shape, DataType dtype, Device device);
TVM_DLL NDArray DiscoEmptyNDArray(ShapeTuple shape, DataType dtype, Device device);
/*!
* \brief Perform an allreduce operation using the underlying communication library
* \param send The array to be sent for the allreduce
* \param reduce_kind The kind of reduction operation (e.g. sum, avg, min, max)
* \param recv The array that receives the outcome of the allreduce
*/
void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv);
TVM_DLL void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv);
/*!
* \brief Perform an allgather operation using the underlying communication library
* \param send The array to be sent for the allgather
* \param recv The array that receives the outcome of the allgather
*/
void AllGather(NDArray send, NDArray recv);
TVM_DLL void AllGather(NDArray send, NDArray recv);
/*!
* \brief Perform a broadcast operation from worker-0
* \param send The buffer to be broadcasted
Expand All @@ -103,20 +103,20 @@ TVM_DLL void ScatterFromWorker0(Optional<NDArray> send, NDArray recv);
* \param recv For worker-0, it must be provided, and otherwise, the buffer must be None. The
* receiving buffer will be divided into equal parts and receive from each worker accordingly.
*/
void GatherToWorker0(NDArray send, Optional<NDArray> recv);
TVM_DLL void GatherToWorker0(NDArray send, Optional<NDArray> recv);
/*!
* \brief Receive a buffer from worker-0. No-op if the current worker is worker-0.
* \param buffer The buffer to be received
*/
void RecvFromWorker0(NDArray buffer);
TVM_DLL void RecvFromWorker0(NDArray buffer);
/*! \brief Get the local worker id */
int WorkerId();
TVM_DLL int WorkerId();
/*!
* \brief Called by the worker thread. Waits until the worker completes all its tasks.
* As a specific example, on a CUDA worker, it blocks until all kernels are launched and
* cudaStreamSynchronize is complete.
*/
void SyncWorker();
TVM_DLL void SyncWorker();

} // namespace runtime
} // namespace tvm
Expand Down
18 changes: 9 additions & 9 deletions include/tvm/runtime/disco/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,51 +196,51 @@ class SessionObj : public Object {
* The second element must be 0, which will later be updated by the session to return reg_id
* The third element is the function to be called.
*/
virtual DRef CallWithPacked(const TVMArgs& args) = 0;
TVM_DLL virtual DRef CallWithPacked(const TVMArgs& args) = 0;
/*! \brief Get a global function on workers. */
virtual DRef GetGlobalFunc(const std::string& name) = 0;
TVM_DLL virtual DRef GetGlobalFunc(const std::string& name) = 0;
/*!
* \brief Copy an NDArray from worker-0 to the controller-side NDArray
* \param host_array The controller-side array that receives the data copied from worker-0
* \param remote_array The NDArray on worker-0
*/
virtual void CopyFromWorker0(const NDArray& host_array, const DRef& remote_array) = 0;
TVM_DLL virtual void CopyFromWorker0(const NDArray& host_array, const DRef& remote_array) = 0;
/*!
* \brief Copy the controller-side NDArray to worker-0
* \param host_array The array to be copied to worker-0
* \param remote_array The NDArray on worker-0
*/
virtual void CopyToWorker0(const NDArray& host_array, const DRef& remote_array) = 0;
TVM_DLL virtual void CopyToWorker0(const NDArray& host_array, const DRef& remote_array) = 0;
/*!
* \brief Synchronize the controller with a worker, waiting until the worker finishes
* executing this instruction.
* \param worker_id The id of the worker to be synced with.
* \note This function is usually used for worker-0, because it is the only worker that is
* assumed to be collocated with the controller. Syncing with other workers may not be supported.
*/
virtual void SyncWorker(int worker_id) = 0;
TVM_DLL virtual void SyncWorker(int worker_id) = 0;
/*! \brief Signal all the workers to shutdown */
virtual void Shutdown() = 0;
TVM_DLL virtual void Shutdown() = 0;
/*!
* \brief Initialize the data plane between workers.
* \param ccl The name of the communication backend, e.g., nccl, rccl, mpi.
* \param device_ids The device ids of the workers.
*/
virtual void InitCCL(String ccl, IntTuple device_ids) = 0;
TVM_DLL virtual void InitCCL(String ccl, IntTuple device_ids) = 0;
/*!
* \brief Get the value of a register from a remote worker.
* \param reg_id The id of the register to be fetched.
* \param worker_id The id of the worker to be fetched from.
* \return The value of the register.
*/
virtual TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) = 0;
TVM_DLL virtual TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) = 0;
/*!
* \brief Set the value of a register on a remote worker.
* \param reg_id The id of the register to be set.
* \param value The value to be set.
* \param worker_id The id of the worker to be set.
*/
virtual void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id) = 0;
TVM_DLL virtual void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id) = 0;

struct FFI;
friend struct SessionObj::FFI;
Expand Down