Skip to content

Commit 317f062

Browse files
committed
[Disco] Implement num_workers property for disco.Session
Prior to this commit, while the `num_workers` argument was provided to the `disco.Session` object, it could not be determined from an existing `disco.Session` object. As a result, functions that interacted with a multi-GPU setup frequently required separate `num_workers` and `disco_session` argument, which could erroneously be out-of-sync (e.g. passing the incorrect `num_workers`, or omitting the `disco_session` argument when `num_workers>1`). To remove this class of errors, this commit adds a `disco.Session.num_workers` property. The separate `num_workers` argument is no longer necessary, as it can be determined from the `disco.Session` instance.
1 parent 0dfc5f9 commit 317f062

File tree

6 files changed

+20
-0
lines changed

6 files changed

+20
-0
lines changed

include/tvm/runtime/disco/session.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,8 @@ class SessionObj : public Object {
197197
* The thirtd element is the function to be called.
198198
*/
199199
TVM_DLL virtual DRef CallWithPacked(const TVMArgs& args) = 0;
200+
/*! \brief Get the number of workers in the session. */
201+
TVM_DLL virtual int64_t GetNumWorkers() = 0;
200202
/*! \brief Get a global functions on workers. */
201203
TVM_DLL virtual DRef GetGlobalFunc(const std::string& name) = 0;
202204
/*!

python/tvm/runtime/disco/session.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ def shutdown(self):
146146
"""Shut down the Disco session"""
147147
_ffi_api.SessionShutdown(self) # type: ignore # pylint: disable=no-member
148148

149+
@property
150+
def num_workers(self) -> int:
151+
"""Return the number of workers in the session"""
152+
return _ffi_api.SessionGetNumWorkers(self) # type: ignore # pylint: disable=no-member
153+
149154
def get_global_func(self, name: str) -> DRef:
150155
"""Get a global function on workers.
151156

src/runtime/disco/process_session.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ class ProcessSessionObj final : public BcastSessionObj {
153153

154154
~ProcessSessionObj() { Kill(); }
155155

156+
int64_t GetNumWorkers() { return workers_.size() + 1; }
157+
156158
TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) {
157159
if (worker_id == 0) {
158160
this->SyncWorker(worker_id);

src/runtime/disco/session.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ TVM_REGISTER_GLOBAL("runtime.disco.DRefDebugGetFromRemote")
3737
.set_body_method<DRef>(&DRefObj::DebugGetFromRemote);
3838
TVM_REGISTER_GLOBAL("runtime.disco.DRefDebugCopyFrom")
3939
.set_body_method<DRef>(&DRefObj::DebugCopyFrom);
40+
TVM_REGISTER_GLOBAL("runtime.disco.SessionGetNumWorkers")
41+
.set_body_method<Session>(&SessionObj::GetNumWorkers);
4042
TVM_REGISTER_GLOBAL("runtime.disco.SessionGetGlobalFunc")
4143
.set_body_method<Session>(&SessionObj::GetGlobalFunc);
4244
TVM_REGISTER_GLOBAL("runtime.disco.SessionCopyFromWorker0")

src/runtime/disco/threaded_session.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ class ThreadedSessionObj final : public BcastSessionObj {
154154
workers_.clear();
155155
}
156156

157+
int64_t GetNumWorkers() { return workers_.size(); }
158+
157159
TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) {
158160
this->SyncWorker(worker_id);
159161
return this->workers_.at(worker_id).worker->register_file.at(reg_id);

tests/python/disco/test_session.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,13 @@ def transpose_2(
220220
np.testing.assert_equal(z_nd, x_np)
221221

222222

223+
@pytest.mark.parametrize("session_kind", _all_session_kinds)
224+
@pytest.mark.parametrize("num_workers", [1, 2, 4])
225+
def test_num_workers(session_kind, num_workers):
226+
sess = session_kind(num_workers=num_workers)
227+
assert sess.num_workers == num_workers
228+
229+
223230
if __name__ == "__main__":
224231
test_int(di.ProcessSession)
225232
test_float(di.ProcessSession)

0 commit comments

Comments
 (0)