Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions include/tvm/runtime/disco/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,30 @@ TVM_DLL void GatherToWorker0(NDArray send, bool in_group, Optional<NDArray> recv
* \param buffer The buffer to be received
*/
TVM_DLL void RecvFromWorker0(NDArray buffer);
/*!
 * \brief Send a buffer to the worker with the same in-group index in the next group.
 * An error is thrown if the worker is already in the last group.
 * \param buffer The sending buffer.
 */
TVM_DLL void SendToNextGroup(NDArray buffer);
/*!
 * \brief Receive a buffer from the worker with the same in-group index in the previous group.
 * An error is thrown if the worker is already in the first group.
 * \param buffer The receiving buffer.
 */
TVM_DLL void RecvFromPrevGroup(NDArray buffer);
/*!
 * \brief Send a buffer to the target receiver worker (globally across all groups).
 * An error is thrown if the receiver id is out of range or refers to the current worker.
 * \param buffer The sending buffer.
 * \param receiver_id The global receiver worker id.
 */
TVM_DLL void SendToWorker(NDArray buffer, int receiver_id);
/*!
 * \brief Receive a buffer from the target sender worker (globally across all groups).
 * An error is thrown if the sender id is out of range or refers to the current worker.
 * \param buffer The receiving buffer.
 * \param sender_id The global sender worker id.
 */
TVM_DLL void RecvFromWorker(NDArray buffer, int sender_id);
/*! \brief Get the local worker id */
TVM_DLL int WorkerId();
/*!
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/relax/frontend/nn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,16 +549,16 @@ def __init__(self, modules: List[Module]):
def __iter__(self):
return iter(self.modules)

def __getitem__(self, idx):
def __getitem__(self, idx: int) -> Module:
    """Return the module stored at position ``idx``."""
    return self.modules.__getitem__(idx)

def __setitem__(self, idx, module):
def __setitem__(self, idx: int, module: Module) -> None:
    """Replace the module stored at position ``idx`` with ``module``."""
    self.modules.__setitem__(idx, module)

def __len__(self):
return len(self.modules)

def append(self, module):
def append(self, module: Module):
    """Add a module to the end of the ModuleList"""
    self.modules.insert(len(self.modules), module)

Expand Down
16 changes: 16 additions & 0 deletions src/runtime/disco/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,18 @@ void GatherToWorker0(NDArray send, bool in_group, Optional<NDArray> recv) {

void RecvFromWorker0(NDArray buffer) { GetCCLFunc("recv_from_worker0")(buffer); }

void SendToNextGroup(NDArray buffer) {
  // Dispatch to the implementation registered by the active CCL backend.
  GetCCLFunc("send_to_next_group")(buffer);
}

void RecvFromPrevGroup(NDArray buffer) {
  // Dispatch to the implementation registered by the active CCL backend.
  GetCCLFunc("recv_from_prev_group")(buffer);
}

void SendToWorker(NDArray buffer, int receiver_id) {
  // Dispatch to the implementation registered by the active CCL backend.
  const auto& pf = GetCCLFunc("send_to_worker");
  pf(buffer, receiver_id);
}

void RecvFromWorker(NDArray buffer, int sender_id) {
  // Dispatch to the implementation registered by the active CCL backend.
  const auto& pf = GetCCLFunc("recv_from_worker");
  pf(buffer, sender_id);
}

int WorkerId() {
  // Id of the disco worker bound to the calling thread.
  return DiscoWorker::ThreadLocal()->worker_id;
}

void SyncWorker() {
Expand Down Expand Up @@ -136,6 +148,10 @@ TVM_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0").set_body_typed(Broad
TVM_REGISTER_GLOBAL("runtime.disco.scatter_from_worker0").set_body_typed(ScatterFromWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.gather_to_worker0").set_body_typed(GatherToWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker0").set_body_typed(RecvFromWorker0);
// CCL-agnostic entry points for cross-group and point-to-point send/recv;
// each forwards to the loaded CCL backend via GetCCLFunc.
TVM_REGISTER_GLOBAL("runtime.disco.send_to_next_group").set_body_typed(SendToNextGroup);
TVM_REGISTER_GLOBAL("runtime.disco.recv_from_prev_group").set_body_typed(RecvFromPrevGroup);
TVM_REGISTER_GLOBAL("runtime.disco.send_to_worker").set_body_typed(SendToWorker);
TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker").set_body_typed(RecvFromWorker);
TVM_REGISTER_GLOBAL("runtime.disco.worker_id").set_body_typed([]() -> ShapeTuple {
return ShapeTuple({WorkerId()});
});
Expand Down
86 changes: 86 additions & 0 deletions src/runtime/disco/nccl/nccl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,57 @@ void RecvFromWorker0(NDArray buffer) {
NCCL_CALL(ncclGroupEnd());
}

void SendToNextGroup(NDArray buffer) {
  // Send `buffer` to the worker that holds the same in-group position in the
  // next group, i.e. global id `worker_id + group_size`.
  CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
  deviceStream_t stream = ctx->GetDefaultStream();
  const int num_workers = ctx->worker->num_workers;
  const int workers_per_group = num_workers / ctx->worker->num_groups;
  const int peer = ctx->worker->worker_id + workers_per_group;
  CHECK_LT(peer, num_workers)
      << "The current group is already the last group and there is no such a next group.";
  NCCL_CALL(ncclGroupStart());
  NCCL_CALL(ncclSend(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()),
                     peer, ctx->global_comm, stream));
  NCCL_CALL(ncclGroupEnd());
}

void RecvFromPrevGroup(NDArray buffer) {
  // Receive `buffer` from the worker that holds the same in-group position in
  // the previous group, i.e. global id `worker_id - group_size`.
  CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
  deviceStream_t stream = ctx->GetDefaultStream();
  const int workers_per_group = ctx->worker->num_workers / ctx->worker->num_groups;
  const int peer = ctx->worker->worker_id - workers_per_group;
  CHECK_GE(peer, 0)
      << "The current group is already the first group and there is no such a previous group.";
  NCCL_CALL(ncclGroupStart());
  NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()),
                     peer, ctx->global_comm, stream));
  NCCL_CALL(ncclGroupEnd());
}

void SendToWorker(NDArray buffer, int receiver_id) {
  // Point-to-point send of `buffer` to the worker with global id `receiver_id`.
  CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
  deviceStream_t stream = ctx->GetDefaultStream();
  int worker_id = ctx->worker->worker_id;
  CHECK(receiver_id >= 0 && receiver_id < ctx->worker->num_workers)
      << "Invalid receiver id " << receiver_id << ". The world size is "
      << ctx->worker->num_workers;
  CHECK_NE(worker_id, receiver_id) << "Cannot send to worker itself.";
  // Wrap the p2p call in an NCCL group for consistency with the other
  // send/recv helpers in this file (RecvFromWorker0, SendToNextGroup, ...).
  NCCL_CALL(ncclGroupStart());
  NCCL_CALL(ncclSend(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()),
                     receiver_id, ctx->global_comm, stream));
  NCCL_CALL(ncclGroupEnd());
}

void RecvFromWorker(NDArray buffer, int sender_id) {
  // Point-to-point receive into `buffer` from the worker with global id `sender_id`.
  CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
  deviceStream_t stream = ctx->GetDefaultStream();
  int worker_id = ctx->worker->worker_id;
  CHECK(sender_id >= 0 && sender_id < ctx->worker->num_workers)
      << "Invalid sender id " << sender_id << ". The world size is " << ctx->worker->num_workers;
  CHECK_NE(worker_id, sender_id) << "Cannot receive from the worker itself.";
  // Wrap the p2p call in an NCCL group for consistency with the other
  // send/recv helpers in this file (RecvFromWorker0, RecvFromPrevGroup, ...).
  NCCL_CALL(ncclGroupStart());
  NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()),
                     sender_id, ctx->global_comm, stream));
  NCCL_CALL(ncclGroupEnd());
}

void SyncWorker() {
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
ICHECK(ctx->worker != nullptr);
Expand Down Expand Up @@ -284,8 +335,43 @@ TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".gather_to_worker0")
.set_body_typed(GatherToWorker0);
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker0")
.set_body_typed(RecvFromWorker0);
// Backend-specific implementations of the cross-group / point-to-point
// primitives, registered under the "runtime.disco.<ccl>." prefix so the
// CCL-agnostic wrappers in builtin.cc can dispatch to them.
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_next_group")
    .set_body_typed(SendToNextGroup);
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_prev_group")
    .set_body_typed(RecvFromPrevGroup);
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_worker")
    .set_body_typed(SendToWorker);
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker")
    .set_body_typed(RecvFromWorker);
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".sync_worker").set_body_typed(SyncWorker);

TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME
                    ".test_send_to_next_group_recv_from_prev_group")
    .set_body_typed([](NDArray buffer) {
      // Test helper: workers in group 0 send `buffer` to their counterparts in
      // group 1, which receive into `buffer`. Requires 4 workers in 2 groups.
      CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
      CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4.";
      // The check is on the number of groups; the previous message incorrectly
      // said "group size".
      CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the number of groups to be 2.";
      int group_size = ctx->worker->num_workers / ctx->worker->num_groups;
      int group_id = ctx->worker->worker_id / group_size;
      if (group_id == 0) {
        tvm::runtime::nccl::SendToNextGroup(buffer);
      } else {
        tvm::runtime::nccl::RecvFromPrevGroup(buffer);
      }
    });

TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".test_worker2_sends_to_worker0")
    .set_body_typed([](NDArray buffer) {
      // Test helper: worker 2 sends `buffer` to worker 0, which receives into
      // `buffer`; all other workers are no-ops. Requires 4 workers in 2 groups.
      CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
      CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4.";
      // The check is on the number of groups; the previous message incorrectly
      // said "group size".
      CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the number of groups to be 2.";
      if (ctx->worker->worker_id == 2) {
        tvm::runtime::nccl::SendToWorker(buffer, 0);
      } else if (ctx->worker->worker_id == 0) {
        tvm::runtime::nccl::RecvFromWorker(buffer, 2);
      }
    });

} // namespace nccl
} // namespace runtime
} // namespace tvm
40 changes: 39 additions & 1 deletion tests/python/disco/test_ccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
import tvm
import tvm.testing
from tvm import dlight as dl
from tvm import get_global_func
from tvm import relax as rx
from tvm.runtime import disco as di
from tvm.runtime.relax_vm import VirtualMachine
from tvm.script import relax as R
from tvm import get_global_func

_all_session_kinds = [di.ThreadedSession, di.ProcessSession]
_ccl = [get_global_func("runtime.disco.compiled_ccl")()]
Expand Down Expand Up @@ -391,6 +391,44 @@ def test_group_gather(session_kind, ccl, capfd):
), "No warning messages should be generated from disco.Session.gather_to_worker0"


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_send_to_next_group_receive_from_prev_group(session_kind, ccl):
    """Workers 0/1 (group 0) send to workers 2/3 (group 1); verify the payloads."""
    devices = [0, 1, 2, 3]
    sess = session_kind(num_workers=len(devices), num_groups=2)
    sess.init_ccl(ccl, *devices)

    src_worker0 = np.arange(12, dtype="float32").reshape(3, 4)
    src_worker1 = np.arange(1, -11, -1, dtype="float32").reshape(3, 4)
    remote_arr = sess.empty((3, 4), "float32")
    remote_arr.debug_copy_from(0, src_worker0)
    remote_arr.debug_copy_from(1, src_worker1)
    func = sess.get_global_func(
        "runtime.disco." + ccl + ".test_send_to_next_group_recv_from_prev_group"
    )
    func(remote_arr)

    # Each receiver in group 1 should now hold its sender's array.
    np.testing.assert_equal(remote_arr.debug_get_from_remote(2).numpy(), src_worker0)
    np.testing.assert_equal(remote_arr.debug_get_from_remote(3).numpy(), src_worker1)


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_worker2_send_to_worker0(session_kind, ccl):
    """Worker 2 sends an array to worker 0; verify worker 0 received it."""
    devices = [0, 1, 2, 3]
    sess = session_kind(num_workers=len(devices), num_groups=2)
    sess.init_ccl(ccl, *devices)

    payload = np.arange(1, -11, -1, dtype="float32").reshape(3, 4)
    remote_arr = sess.empty((3, 4), "float32")
    remote_arr.debug_copy_from(2, payload)
    func = sess.get_global_func("runtime.disco." + ccl + ".test_worker2_sends_to_worker0")
    func(remote_arr)

    received = remote_arr.debug_get_from_remote(0).numpy()
    np.testing.assert_equal(received, payload)


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_mlp(session_kind, ccl): # pylint: disable=too-many-locals
Expand Down