Skip to content

Commit f5bb85b

Browse files
authored
[Core][Distributed] improve p2p cache generation (#5528)
1 parent 28c145e commit f5bb85b

File tree

2 files changed

+265
-96
lines changed

2 files changed

+265
-96
lines changed
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""This file is a pure Python wrapper for the cudart library.
2+
It avoids the need to compile a separate shared library, and is
3+
convenient for use when we just need to call a few functions.
4+
"""
5+
6+
import ctypes
7+
from dataclasses import dataclass
8+
from typing import Any, Dict, List, Optional
9+
10+
# this line makes it possible to directly load `libcudart.so` using `ctypes`
11+
import torch # noqa
12+
13+
from vllm.logger import init_logger
14+
15+
logger = init_logger(__name__)
16+
17+
# === export types and functions from cudart to Python ===
18+
# for the original cudart definition, please check
19+
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
20+
21+
cudaError_t = ctypes.c_int
22+
cudaMemcpyKind = ctypes.c_int
23+
24+
25+
class cudaIpcMemHandle_t(ctypes.Structure):
26+
_fields_ = [("internal", ctypes.c_byte * 128)]
27+
28+
29+
@dataclass
30+
class Function:
31+
name: str
32+
restype: Any
33+
argtypes: List[Any]
34+
35+
36+
class CudaRTLibrary:
37+
exported_functions = [
38+
# ​cudaError_t cudaSetDevice ( int device )
39+
Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
40+
# cudaError_t cudaDeviceSynchronize ( void )
41+
Function("cudaDeviceSynchronize", cudaError_t, []),
42+
# ​cudaError_t cudaDeviceReset ( void )
43+
Function("cudaDeviceReset", cudaError_t, []),
44+
45+
# const char* cudaGetErrorString ( cudaError_t error )
46+
Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
47+
48+
# ​cudaError_t cudaMalloc ( void** devPtr, size_t size )
49+
Function("cudaMalloc", cudaError_t,
50+
[ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
51+
# ​cudaError_t cudaFree ( void* devPtr )
52+
Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
53+
# ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
54+
Function("cudaMemset", cudaError_t,
55+
[ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
56+
# ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
57+
Function("cudaMemcpy", cudaError_t, [
58+
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind
59+
]),
60+
61+
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
62+
Function("cudaIpcGetMemHandle", cudaError_t,
63+
[ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
64+
# ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
65+
Function("cudaIpcOpenMemHandle", cudaError_t, [
66+
ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint
67+
]),
68+
]
69+
70+
# class attribute to store the mapping from the path to the library
71+
# to avoid loading the same library multiple times
72+
path_to_library_cache: Dict[str, Any] = {}
73+
74+
# class attribute to store the mapping from library path
75+
# to the corresponding dictionary
76+
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
77+
78+
def __init__(self, so_file: Optional[str] = None):
79+
if so_file is None:
80+
assert torch.version.cuda is not None
81+
major_version = torch.version.cuda.split(".")[0]
82+
so_file = f"libcudart.so.{major_version}"
83+
if so_file not in CudaRTLibrary.path_to_library_cache:
84+
lib = ctypes.CDLL(so_file)
85+
CudaRTLibrary.path_to_library_cache[so_file] = lib
86+
self.lib = CudaRTLibrary.path_to_library_cache[so_file]
87+
88+
if so_file not in CudaRTLibrary.path_to_dict_mapping:
89+
_funcs = {}
90+
for func in CudaRTLibrary.exported_functions:
91+
f = getattr(self.lib, func.name)
92+
f.restype = func.restype
93+
f.argtypes = func.argtypes
94+
_funcs[func.name] = f
95+
CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
96+
self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]
97+
98+
def CUDART_CHECK(self, result: cudaError_t) -> None:
99+
if result != 0:
100+
error_str = self.cudaGetErrorString(result)
101+
raise RuntimeError(f"CUDART error: {error_str}")
102+
103+
def cudaGetErrorString(self, error: cudaError_t) -> str:
104+
return self.funcs["cudaGetErrorString"](error).decode("utf-8")
105+
106+
def cudaSetDevice(self, device: int) -> None:
107+
self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))
108+
109+
def cudaDeviceSynchronize(self) -> None:
110+
self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())
111+
112+
def cudaDeviceReset(self) -> None:
113+
self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())
114+
115+
def cudaMalloc(self, size: int) -> ctypes.c_void_p:
116+
devPtr = ctypes.c_void_p()
117+
self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
118+
return devPtr
119+
120+
def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
121+
self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))
122+
123+
def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,
124+
count: int) -> None:
125+
self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))
126+
127+
def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,
128+
count: int) -> None:
129+
cudaMemcpyDefault = 4
130+
kind = cudaMemcpyDefault
131+
self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))
132+
133+
def cudaIpcGetMemHandle(self,
134+
devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
135+
handle = cudaIpcMemHandle_t()
136+
self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"](
137+
ctypes.byref(handle), devPtr))
138+
return handle
139+
140+
def cudaIpcOpenMemHandle(self,
141+
handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
142+
cudaIpcMemLazyEnablePeerAccess = 1
143+
devPtr = ctypes.c_void_p()
144+
self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"](
145+
ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))
146+
return devPtr

0 commit comments

Comments
 (0)