|
| 1 | +"""This file is a pure Python wrapper for the cudart library. |
| 2 | +It avoids the need to compile a separate shared library, and is |
| 3 | +convenient for use when we just need to call a few functions. |
| 4 | +""" |
| 5 | + |
| 6 | +import ctypes |
| 7 | +from dataclasses import dataclass |
| 8 | +from typing import Any, Dict, List, Optional |
| 9 | + |
| 10 | +# this line makes it possible to directly load `libcudart.so` using `ctypes` |
| 11 | +import torch # noqa |
| 12 | + |
| 13 | +from vllm.logger import init_logger |
| 14 | + |
| 15 | +logger = init_logger(__name__) |
| 16 | + |
| 17 | +# === export types and functions from cudart to Python === |
| 18 | +# for the original cudart definition, please check |
| 19 | +# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html |
| 20 | + |
| 21 | +cudaError_t = ctypes.c_int |
| 22 | +cudaMemcpyKind = ctypes.c_int |
| 23 | + |
| 24 | + |
| 25 | +class cudaIpcMemHandle_t(ctypes.Structure): |
| 26 | + _fields_ = [("internal", ctypes.c_byte * 128)] |
| 27 | + |
| 28 | + |
| 29 | +@dataclass |
| 30 | +class Function: |
| 31 | + name: str |
| 32 | + restype: Any |
| 33 | + argtypes: List[Any] |
| 34 | + |
| 35 | + |
| 36 | +class CudaRTLibrary: |
| 37 | + exported_functions = [ |
| 38 | + # cudaError_t cudaSetDevice ( int device ) |
| 39 | + Function("cudaSetDevice", cudaError_t, [ctypes.c_int]), |
| 40 | + # cudaError_t cudaDeviceSynchronize ( void ) |
| 41 | + Function("cudaDeviceSynchronize", cudaError_t, []), |
| 42 | + # cudaError_t cudaDeviceReset ( void ) |
| 43 | + Function("cudaDeviceReset", cudaError_t, []), |
| 44 | + |
| 45 | + # const char* cudaGetErrorString ( cudaError_t error ) |
| 46 | + Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]), |
| 47 | + |
| 48 | + # cudaError_t cudaMalloc ( void** devPtr, size_t size ) |
| 49 | + Function("cudaMalloc", cudaError_t, |
| 50 | + [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]), |
| 51 | + # cudaError_t cudaFree ( void* devPtr ) |
| 52 | + Function("cudaFree", cudaError_t, [ctypes.c_void_p]), |
| 53 | + # cudaError_t cudaMemset ( void* devPtr, int value, size_t count ) |
| 54 | + Function("cudaMemset", cudaError_t, |
| 55 | + [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]), |
| 56 | + # cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa |
| 57 | + Function("cudaMemcpy", cudaError_t, [ |
| 58 | + ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind |
| 59 | + ]), |
| 60 | + |
| 61 | + # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa |
| 62 | + Function("cudaIpcGetMemHandle", cudaError_t, |
| 63 | + [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]), |
| 64 | + # cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa |
| 65 | + Function("cudaIpcOpenMemHandle", cudaError_t, [ |
| 66 | + ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint |
| 67 | + ]), |
| 68 | + ] |
| 69 | + |
| 70 | + # class attribute to store the mapping from the path to the library |
| 71 | + # to avoid loading the same library multiple times |
| 72 | + path_to_library_cache: Dict[str, Any] = {} |
| 73 | + |
| 74 | + # class attribute to store the mapping from library path |
| 75 | + # to the corresponding dictionary |
| 76 | + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} |
| 77 | + |
| 78 | + def __init__(self, so_file: Optional[str] = None): |
| 79 | + if so_file is None: |
| 80 | + assert torch.version.cuda is not None |
| 81 | + major_version = torch.version.cuda.split(".")[0] |
| 82 | + so_file = f"libcudart.so.{major_version}" |
| 83 | + if so_file not in CudaRTLibrary.path_to_library_cache: |
| 84 | + lib = ctypes.CDLL(so_file) |
| 85 | + CudaRTLibrary.path_to_library_cache[so_file] = lib |
| 86 | + self.lib = CudaRTLibrary.path_to_library_cache[so_file] |
| 87 | + |
| 88 | + if so_file not in CudaRTLibrary.path_to_dict_mapping: |
| 89 | + _funcs = {} |
| 90 | + for func in CudaRTLibrary.exported_functions: |
| 91 | + f = getattr(self.lib, func.name) |
| 92 | + f.restype = func.restype |
| 93 | + f.argtypes = func.argtypes |
| 94 | + _funcs[func.name] = f |
| 95 | + CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs |
| 96 | + self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file] |
| 97 | + |
| 98 | + def CUDART_CHECK(self, result: cudaError_t) -> None: |
| 99 | + if result != 0: |
| 100 | + error_str = self.cudaGetErrorString(result) |
| 101 | + raise RuntimeError(f"CUDART error: {error_str}") |
| 102 | + |
| 103 | + def cudaGetErrorString(self, error: cudaError_t) -> str: |
| 104 | + return self.funcs["cudaGetErrorString"](error).decode("utf-8") |
| 105 | + |
| 106 | + def cudaSetDevice(self, device: int) -> None: |
| 107 | + self.CUDART_CHECK(self.funcs["cudaSetDevice"](device)) |
| 108 | + |
| 109 | + def cudaDeviceSynchronize(self) -> None: |
| 110 | + self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]()) |
| 111 | + |
| 112 | + def cudaDeviceReset(self) -> None: |
| 113 | + self.CUDART_CHECK(self.funcs["cudaDeviceReset"]()) |
| 114 | + |
| 115 | + def cudaMalloc(self, size: int) -> ctypes.c_void_p: |
| 116 | + devPtr = ctypes.c_void_p() |
| 117 | + self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size)) |
| 118 | + return devPtr |
| 119 | + |
| 120 | + def cudaFree(self, devPtr: ctypes.c_void_p) -> None: |
| 121 | + self.CUDART_CHECK(self.funcs["cudaFree"](devPtr)) |
| 122 | + |
| 123 | + def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, |
| 124 | + count: int) -> None: |
| 125 | + self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count)) |
| 126 | + |
| 127 | + def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p, |
| 128 | + count: int) -> None: |
| 129 | + cudaMemcpyDefault = 4 |
| 130 | + kind = cudaMemcpyDefault |
| 131 | + self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind)) |
| 132 | + |
| 133 | + def cudaIpcGetMemHandle(self, |
| 134 | + devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t: |
| 135 | + handle = cudaIpcMemHandle_t() |
| 136 | + self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"]( |
| 137 | + ctypes.byref(handle), devPtr)) |
| 138 | + return handle |
| 139 | + |
| 140 | + def cudaIpcOpenMemHandle(self, |
| 141 | + handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: |
| 142 | + cudaIpcMemLazyEnablePeerAccess = 1 |
| 143 | + devPtr = ctypes.c_void_p() |
| 144 | + self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"]( |
| 145 | + ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess)) |
| 146 | + return devPtr |
0 commit comments