Skip to content

Commit c44798d

Browse files
committed
[Disco] Implement SocketSession
Implements SocketSession that connects multiple local worker processes/threads over multiple distributed nodes via TCP socket.
1 parent 73078f1 commit c44798d

File tree

14 files changed

+645
-111
lines changed

14 files changed

+645
-111
lines changed

include/tvm/runtime/disco/disco_worker.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class DiscoWorker {
5252
DiscoChannel* channel)
5353
: worker_id(worker_id),
5454
num_workers(num_workers),
55+
local_worker_id(worker_id),
5556
default_device(Device{DLDeviceType::kDLCPU, 0}),
5657
worker_zero_data(worker_zero_data),
5758
channel(channel),
@@ -68,6 +69,10 @@ class DiscoWorker {
6869
int worker_id;
6970
/*! \brief Total number of workers */
7071
int num_workers;
72+
/*! \brief The local worker id. This can be different from `worker_id` if the session is
73+
* composed of multiple distributed sub-sessions.
74+
*/
75+
int local_worker_id;
7176
/*! \brief The default device to allocate data if not specified */
7277
Device default_device;
7378
/*! \brief The name of the underlying collective communication library. */

include/tvm/runtime/disco/session.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,17 @@ class Session : public ObjectRef {
279279
*/
280280
TVM_DLL static Session ProcessSession(int num_workers, String process_pool_creator,
281281
String entrypoint);
282+
283+
/*!
284+
* \brief Create a session backed by TCP sockets
285+
* \param num_nodes The number of nodes in the cluster connected by TCP sockets.
286+
* \param num_workers_per_node The number of workers on each node.
287+
* \param host The host name of the controller.
288+
* \param port The port number of the controller.
289+
*/
290+
TVM_DLL static Session SocketSession(int num_nodes, int num_workers_per_node, const String& host,
291+
int port);
292+
282293
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef, SessionObj);
283294
};
284295

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name
18+
import sys
19+
import tvm
20+
from . import disco_worker as _ # pylint: disable=unused-import
21+
22+
23+
if __name__ == "__main__":
24+
if len(sys.argv) != 3:
25+
print("Usage: <server_host> <server_port>")
26+
sys.exit(1)
27+
28+
server_host = sys.argv[1]
29+
server_port = int(sys.argv[2])
30+
func = tvm.get_global_func("runtime.disco.RemoteSocketSession")
31+
func(server_host, server_port)

python/tvm/runtime/disco/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@
2222
ProcessSession,
2323
Session,
2424
ThreadedSession,
25+
SocketSession,
2526
)

python/tvm/runtime/disco/session.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,26 @@ def _configure_structlog(self) -> None:
527527
func(config, os.getpid())
528528

529529

530+
@register_func("runtime.disco.create_socket_session_local_workers")
531+
def _create_socket_session_local_workers(num_workers) -> Session:
532+
"""Create the local session for each distributed node over socket session."""
533+
return ProcessSession(num_workers)
534+
535+
536+
@register_object("runtime.disco.SocketSession")
537+
class SocketSession(Session):
538+
"""A Disco session backed by socket-based multi-node communication."""
539+
540+
def __init__(self, num_nodes: int, num_workers_per_node: int, host: str, port: int) -> None:
541+
self.__init_handle_by_constructor__(
542+
_ffi_api.SocketSession, # type: ignore # pylint: disable=no-member
543+
num_nodes,
544+
num_workers_per_node,
545+
host,
546+
port,
547+
)
548+
549+
530550
@register_func("runtime.disco._configure_structlog")
531551
def _configure_structlog(pickled_config: bytes, parent_pid: int) -> None:
532552
"""Configure structlog for all disco workers

src/runtime/disco/bcast_session.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,16 @@ class BcastSessionObj : public SessionObj {
6565
* \param TVMArgs The input arguments in TVM's PackedFunc calling convention
6666
*/
6767
virtual void BroadcastPacked(const TVMArgs& args) = 0;
68+
69+
/*!
70+
* \brief Send a packed sequence to a worker. This function is usually called by the controller to
71+
* communicate with worker-0, because the worker-0 is assumed to be always collocated with the
72+
* controller. Sending to other workers may not be supported.
73+
* \param worker_id The worker id to send the packed sequence to.
74+
* \param args The packed sequence to send.
75+
*/
76+
virtual void SendPacked(int worker_id, const TVMArgs& args) = 0;
77+
6878
/*!
6979
* \brief Receive a packed sequence from a worker. This function is usually called by the
7080
* controller to communicate with worker-0, because the worker-0 is assumed to be always
@@ -83,6 +93,16 @@ class BcastSessionObj : public SessionObj {
8393

8494
struct Internal;
8595
friend struct Internal;
96+
friend class SocketSessionObj;
97+
friend class RemoteSocketSession;
98+
};
99+
100+
/*!
101+
* \brief Managed reference to BcastSessionObj.
102+
*/
103+
class BcastSession : public Session {
104+
public:
105+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BcastSession, Session, BcastSessionObj);
86106
};
87107

88108
} // namespace runtime

src/runtime/disco/disco_worker.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,15 +129,15 @@ struct DiscoWorker::Impl {
129129
}
130130

131131
static void CopyFromWorker0(DiscoWorker* self, int reg_id) {
132-
if (self->worker_zero_data != nullptr) {
132+
if (self->worker_id == 0) {
133133
NDArray tgt = GetNDArrayFromHost(self);
134134
NDArray src = GetReg(self, reg_id);
135135
tgt.CopyFrom(src);
136136
}
137137
}
138138

139139
static void CopyToWorker0(DiscoWorker* self, int reg_id) {
140-
if (self->worker_zero_data != nullptr) {
140+
if (self->worker_id == 0) {
141141
NDArray src = GetNDArrayFromHost(self);
142142
NDArray tgt = GetReg(self, reg_id);
143143
tgt.CopyFrom(src);
@@ -206,5 +206,11 @@ struct DiscoWorker::Impl {
206206

207207
void DiscoWorker::MainLoop() { DiscoWorker::Impl::MainLoop(this); }
208208

209+
TVM_REGISTER_GLOBAL("runtime.disco.set_worker_id").set_body_typed([](IntTuple worker_ids) {
210+
DiscoWorker* worker = DiscoWorker::ThreadLocal();
211+
ICHECK_EQ(worker->num_workers, worker_ids.size());
212+
worker->worker_id = worker_ids[worker->local_worker_id];
213+
});
214+
209215
} // namespace runtime
210216
} // namespace tvm

src/runtime/disco/message_queue.h

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#ifndef TVM_RUNTIME_DISCO_MESSAGE_QUEUE_H_
20+
#define TVM_RUNTIME_DISCO_MESSAGE_QUEUE_H_
21+
22+
#include <dmlc/io.h>
23+
24+
#include "./protocol.h"
25+
26+
namespace tvm {
27+
namespace runtime {
28+
29+
class DiscoStreamMessageQueue : private dmlc::Stream,
30+
private DiscoProtocol<DiscoStreamMessageQueue> {
31+
public:
32+
explicit DiscoStreamMessageQueue(Stream* stream) : stream_(stream) {}
33+
34+
~DiscoStreamMessageQueue() = default;
35+
36+
void Send(const TVMArgs& args) {
37+
RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.num_args, this);
38+
CommitSendAndNotifyEnqueue();
39+
}
40+
41+
TVMArgs Recv() {
42+
bool is_implicit_shutdown = DequeueNextPacket();
43+
TVMValue* values = nullptr;
44+
int* type_codes = nullptr;
45+
int num_args = 0;
46+
47+
if (is_implicit_shutdown) {
48+
num_args = 2;
49+
values = ArenaAlloc<TVMValue>(num_args);
50+
type_codes = ArenaAlloc<int>(num_args);
51+
TVMArgsSetter setter(values, type_codes);
52+
setter(0, static_cast<int>(DiscoAction::kShutDown));
53+
setter(1, 0);
54+
} else {
55+
RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this);
56+
}
57+
return TVMArgs(values, type_codes, num_args);
58+
}
59+
60+
protected:
61+
void CommitSendAndNotifyEnqueue() {
62+
stream_->Write(write_buffer_.data(), write_buffer_.size());
63+
write_buffer_.clear();
64+
}
65+
66+
/*! \brief Read next packet and reset unpacker
67+
*
68+
* Read the next packet into `read_buffer_`, releasing all arena
69+
* allocations performed by the unpacker and resetting the unpacker
70+
* to its initial state.
71+
*
72+
* \return A boolean value. If true, this packet should be treated
73+
* equivalently to a `DiscoAction::kShutdown` event. If false,
74+
* this packet should be unpacked.
75+
*/
76+
bool DequeueNextPacket() {
77+
uint64_t packet_nbytes = 0;
78+
int read_size = stream_->Read(&packet_nbytes, sizeof(packet_nbytes));
79+
if (read_size == 0) {
80+
// Special case, connection dropped between packets. Treat as a
81+
// request to shutdown.
82+
return true;
83+
}
84+
85+
ICHECK_EQ(read_size, sizeof(packet_nbytes))
86+
<< "Stream closed without proper shutdown. Please make sure to explicitly call "
87+
"`Session::Shutdown`";
88+
read_buffer_.resize(packet_nbytes);
89+
read_size = stream_->Read(read_buffer_.data(), packet_nbytes);
90+
ICHECK_EQ(read_size, packet_nbytes)
91+
<< "Stream closed without proper shutdown. Please make sure to explicitly call "
92+
"`Session::Shutdown`";
93+
read_offset_ = 0;
94+
this->RecycleAll();
95+
RPCCode code = RPCCode::kReturn;
96+
this->Read(&code);
97+
return false;
98+
}
99+
100+
size_t Read(void* data, size_t size) final {
101+
std::memcpy(data, read_buffer_.data() + read_offset_, size);
102+
read_offset_ += size;
103+
ICHECK_LE(read_offset_, read_buffer_.size());
104+
return size;
105+
}
106+
107+
size_t Write(const void* data, size_t size) final {
108+
size_t cur_size = write_buffer_.size();
109+
write_buffer_.resize(cur_size + size);
110+
std::memcpy(write_buffer_.data() + cur_size, data, size);
111+
return size;
112+
}
113+
114+
using dmlc::Stream::Read;
115+
using dmlc::Stream::ReadArray;
116+
using dmlc::Stream::Write;
117+
using dmlc::Stream::WriteArray;
118+
friend struct RPCReference;
119+
friend struct DiscoProtocol<DiscoStreamMessageQueue>;
120+
121+
// The read/write buffer will only be accessed by the producer thread.
122+
std::string write_buffer_;
123+
std::string read_buffer_;
124+
size_t read_offset_ = 0;
125+
dmlc::Stream* stream_;
126+
};
127+
128+
} // namespace runtime
129+
} // namespace tvm
130+
131+
#endif // TVM_RUNTIME_DISCO_MESSAGE_QUEUE_H_

src/runtime/disco/nccl/nccl.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ void InitCCLPerWorker(IntTuple device_ids, std::string unique_id_bytes) {
8383
<< "and has not been destructed";
8484

8585
// Set up local context of NCCL
86-
int device_id = device_ids[worker->worker_id];
86+
int device_id = device_ids[worker->local_worker_id];
8787
SetDevice(device_id);
8888
#if TVM_NCCL_RCCL_SWITCH == 0
8989
StreamCreate(&ctx->default_stream);

0 commit comments

Comments
 (0)