Skip to content

Commit 40b6c14

Browse files
authored
[Disco] Add NVSHMEM support (#17317)
This PR adds the supports of NVSHMEM.
1 parent 98de9ba commit 40b6c14

File tree

6 files changed

+261
-0
lines changed

6 files changed

+261
-0
lines changed

CMakeLists.txt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ include(cmake/utils/FindLLVM.cmake)
1313
include(cmake/utils/FindROCM.cmake)
1414
include(cmake/utils/FindRCCL.cmake)
1515
include(cmake/utils/FindEthosN.cmake)
16+
include(cmake/utils/FindNVSHMEM.cmake)
1617

1718
if(EXISTS ${CMAKE_BINARY_DIR}/config.cmake)
1819
include(${CMAKE_BINARY_DIR}/config.cmake)
@@ -133,6 +134,7 @@ tvm_option(USE_UMA "Build with UMA support" OFF)
133134
tvm_option(USE_VERILATOR "Build with Verilator support" OFF)
134135
tvm_option(USE_MSC "Enable Multi-System Compiler" OFF)
135136
tvm_option(USE_MRVL "Build with MRVL TVM support" OFF)
137+
tvm_option(USE_NVSHMEM "Build with NVSHMEM support" OFF)
136138

137139
# include directories
138140
include_directories(${CMAKE_INCLUDE_PATH})
@@ -472,6 +474,16 @@ if(USE_CUDA AND USE_NCCL)
472474
list(APPEND RUNTIME_SRCS ${RUNTIME_NCCL_SRC})
473475
endif()
474476

477+
if (USE_CUDA AND USE_NVSHMEM)
478+
message(STATUS "Build with NVSHMEM...")
479+
find_nvshmem(${USE_NVSHMEM})
480+
if (NOT NVSHMEM_FOUND)
481+
message(FATAL_ERROR "Cannot find NVSHMEM, USE_NVSHMEM=" ${USE_NVSHMEM})
482+
endif()
483+
tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc)
484+
list(APPEND RUNTIME_SRCS ${RUNTIME_NVSHMEM_SRCS})
485+
endif()
486+
475487
if(USE_ROCM AND USE_RCCL)
476488
message(STATUS "Build with RCCL...")
477489
find_rccl(${USE_RCCL})
@@ -957,6 +969,17 @@ if(USE_CUDA AND USE_NCCL)
957969
target_link_libraries(tvm_runtime PRIVATE nccl ${LIBRT})
958970
endif()
959971

972+
973+
if (USE_CUDA AND USE_NVSHMEM)
974+
include_directories(SYSTEM ${USE_NVSHMEM}/include)
975+
find_library(NVSHMEM_HOST nvshmem_host ${NVSHMEM_LIB_DIR})
976+
find_library(NVSHMEM_DEVICE nvshmem_device ${NVSHMEM_LIB_DIR})
977+
target_link_libraries(tvm PRIVATE ${NVSHMEM_HOST} ${NVSHMEM_DEVICE})
978+
target_link_libraries(tvm_runtime PRIVATE ${NVSHMEM_HOST} ${NVSHMEM_DEVICE})
979+
set_target_properties(tvm PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
980+
set_target_properties(tvm_runtime PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
981+
endif()
982+
960983
if(USE_ROCM AND USE_RCCL)
961984
target_link_libraries(tvm PRIVATE rccl)
962985
target_link_libraries(tvm_runtime PRIVATE rccl)

cmake/modules/LibInfo.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ function(add_lib_info src_file)
143143
TVM_INFO_USE_VERILATOR="${USE_VERILATOR}"
144144
TVM_INFO_USE_MSC="${USE_MSC}"
145145
TVM_INFO_USE_CCACHE="${USE_CCACHE}"
146+
TVM_INFO_USE_NVSHMEM="${USE_NVSHMEM}"
146147
TVM_INFO_BACKTRACE_ON_SEGFAULT="${BACKTRACE_ON_SEGFAULT}"
147148
)
148149

cmake/utils/FindNVSHMEM.cmake

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
#######################################################
19+
# Enhanced version of find NVSHMEM.
20+
#
21+
# Usage:
22+
# find_nvshmem(${USE_NVSHMEM})
23+
#
24+
# - When USE_NVSHMEM=ON, use auto search
25+
# - When USE_NVSHMEM=/path/to/installed/nvshmem, use the installed nvshmem path.
26+
# Can be useful when nvshmem is installed at specified location.
27+
#
28+
# Provide variables:
29+
#
30+
# - NVSHMEM_FOUND
31+
# - NVSHMEM_INCLUDE_DIR
32+
# - NVSHMEM_LIB_DIR
33+
#
34+
35+
macro(find_nvshmem use_nvshmem)
36+
set(__use_nvshmem ${use_nvshmem})
37+
if(IS_DIRECTORY ${__use_nvshmem})
38+
set(__nvshmem_path ${__use_nvshmem})
39+
message(STATUS "Custom NVSHMEM PATH=" ${__use_nvshmem})
40+
elseif(IS_DIRECTORY $ENV{NVSHMEM_HOME})
41+
set(__nvshmem_path $ENV{NVSHMEM_HOME})
42+
else()
43+
set(__nvshmem_path "")
44+
endif()
45+
46+
find_package(NVSHMEM HINTS ${__nvshmem_path}/lib/cmake/nvshmem/)
47+
48+
if(NVSHMEM_FOUND)
49+
message(STATUS "NVSHMEM_INCLUDE_DIR=" ${NVSHMEM_INCLUDE_DIR})
50+
message(STATUS "NVSHMEM_LIB_DIR=" ${NVSHMEM_LIB_DIR})
51+
endif(NVSHMEM_FOUND)
52+
endmacro(find_nvshmem)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#include <nvshmem.h>
20+
#include <nvshmemx.h>
21+
#include <tvm/runtime/disco/disco_worker.h>
22+
#include <tvm/runtime/packed_func.h>
23+
#include <tvm/runtime/registry.h>
24+
25+
#include "../../cuda/cuda_common.h"
26+
27+
namespace tvm {
28+
namespace runtime {
29+
30+
ShapeTuple InitNVSHMEMUID() {
31+
nvshmemx_uniqueid_t uid;
32+
nvshmemx_get_uniqueid(&uid);
33+
std::vector<int64_t> uid_64;
34+
uid_64.push_back(static_cast<int64_t>(uid.version));
35+
for (int i = 0; i < UNIQUEID_PADDING; ++i) {
36+
uid_64.push_back(static_cast<int64_t>(uid.internal[i]));
37+
}
38+
return ShapeTuple(uid_64);
39+
}
40+
41+
void InitNVSHMEM(ShapeTuple uid_64, int num_workers) {
42+
DiscoWorker* worker = DiscoWorker::ThreadLocal();
43+
ICHECK(worker != nullptr);
44+
CHECK_EQ(uid_64.size(), UNIQUEID_PADDING + 1)
45+
<< "ValueError: The length of unique_id must be " << UNIQUEID_PADDING << ", but got "
46+
<< uid_64.size() << ".";
47+
48+
nvshmemx_init_attr_t attr = NVSHMEMX_INIT_ATTR_INITIALIZER;
49+
50+
nvshmemx_uniqueid_t uid;
51+
uid.version = static_cast<int>(uid_64[0]);
52+
for (int i = 0; i < UNIQUEID_PADDING; ++i) {
53+
uid.internal[i] = static_cast<char>(uid_64[i + 1]);
54+
}
55+
nvshmemx_set_attr_uniqueid_args(worker->worker_id, num_workers, &uid, &attr);
56+
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
57+
LOG_INFO << "NVSHMEM init finished: mype=" << nvshmem_my_pe() << " "
58+
<< ", npes=" << nvshmem_n_pes();
59+
}
60+
61+
TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_uid").set_body_typed(InitNVSHMEMUID);
62+
63+
TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM);
64+
65+
} // namespace runtime
66+
} // namespace tvm

src/support/libinfo.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,10 @@
275275
#define TVM_INFO_USE_CCACHE "NOT-FOUND"
276276
#endif
277277

278+
#ifndef TVM_INFO_USE_NVSHMEM
279+
#define TVM_INFO_USE_NVSHMEM "NOT-FOUND"
280+
#endif
281+
278282
namespace tvm {
279283

280284
/*!
@@ -387,6 +391,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
387391
{"USE_VERILATOR", TVM_INFO_USE_VERILATOR},
388392
{"USE_MSC", TVM_INFO_USE_MSC},
389393
{"USE_CCACHE", TVM_INFO_USE_CCACHE},
394+
{"USE_NVSHMEM", TVM_INFO_USE_NVSHMEM},
390395
{"BACKTRACE_ON_SEGFAULT", TVM_INFO_BACKTRACE_ON_SEGFAULT},
391396
};
392397
return result;

tests/python/disco/test_nvshmem.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Basic tests for a Disco nvshmem support"""
18+
# pylint: disable=missing-docstring
19+
import tempfile
20+
21+
import numpy as np
22+
import pytest
23+
import subprocess
24+
import threading
25+
import sys
26+
27+
import tvm
28+
import tvm.testing
29+
from tvm.runtime import ShapeTuple
30+
from tvm.runtime import disco as di
31+
from tvm.exec import disco_worker as _ # pylint: disable=unused-import
32+
33+
_SOCKET_SESSION_TESTER = None
34+
35+
36+
def get_free_port():
37+
import socket
38+
39+
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
40+
s.bind(("", 0))
41+
port = s.getsockname()[1]
42+
s.close()
43+
return port
44+
45+
46+
class SocketSessionTester:
47+
def __init__(self, num_workers):
48+
num_nodes = 2
49+
num_groups = 1
50+
assert num_workers % num_nodes == 0
51+
num_workers_per_node = num_workers // num_nodes
52+
server_host = "localhost"
53+
server_port = get_free_port()
54+
self.sess = None
55+
56+
def start_server():
57+
self.sess = di.SocketSession(
58+
num_nodes, num_workers_per_node, num_groups, server_host, server_port
59+
)
60+
61+
thread = threading.Thread(target=start_server)
62+
thread.start()
63+
64+
cmd = "tvm.exec.disco_remote_socket_session"
65+
self.remote_nodes = []
66+
for _ in range(num_nodes - 1):
67+
self.remote_nodes.append(
68+
subprocess.Popen(
69+
[
70+
"python3",
71+
"-m",
72+
cmd,
73+
server_host,
74+
str(server_port),
75+
str(num_workers_per_node),
76+
],
77+
stdout=sys.stdout,
78+
stderr=sys.stderr,
79+
)
80+
)
81+
82+
thread.join()
83+
84+
def __del__(self):
85+
for node in self.remote_nodes:
86+
node.kill()
87+
if self.sess is not None:
88+
self.sess.shutdown()
89+
del self.sess
90+
91+
92+
def create_socket_session(num_workers):
93+
global _SOCKET_SESSION_TESTER
94+
if _SOCKET_SESSION_TESTER is not None:
95+
del _SOCKET_SESSION_TESTER
96+
_SOCKET_SESSION_TESTER = SocketSessionTester(num_workers)
97+
assert _SOCKET_SESSION_TESTER.sess is not None
98+
return _SOCKET_SESSION_TESTER.sess
99+
100+
101+
@pytest.mark.parametrize("num_workers", [2, 4])
102+
def test_nvshmem_init(num_workers):
103+
if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
104+
return
105+
sess = create_socket_session(num_workers=num_workers)
106+
f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
107+
uid = f_init_nvshmem_uid()
108+
init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
109+
init_dfunc(uid, num_workers)
110+
sess.sync_worker_0()
111+
112+
113+
if __name__ == "__main__":
114+
tvm.testing.main()

0 commit comments

Comments
 (0)