
Commit dbc803a

bnellnm authored and jimpang committed
[Kernel][Misc] Use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (vllm-project#5047)
1 parent 29c5f3a commit dbc803a


55 files changed with 833 additions and 451 deletions.
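At a glance, the commit replaces pybind11 module definitions (csrc/pybind.cpp and friends) with PyTorch dispatcher registrations (csrc/torch_bindings.cpp and friends). The sketch below shows the general before/after pattern only; the op, function, and library names are placeholders, not the literal contents of the new binding files.

```cpp
// Illustrative sketch of the binding style this commit moves to.
// The op and library names here are placeholders, not vLLM's real ones.
#include <torch/all.h>
#include <torch/library.h>

// Hypothetical custom op; the real kernels live under csrc/.
torch::Tensor scale_tensor(torch::Tensor const& input, double scale) {
  return input * scale;
}

// Old style (csrc/pybind.cpp): a pybind11 module, which links against
// libtorch_python and is tied to a specific CPython version.
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("scale_tensor", &scale_tensor, "Scale a tensor");
//   }

// New style (csrc/torch_bindings.cpp): declare a schema and register a
// device-specific implementation with the PyTorch dispatcher. No pybind11
// glue is needed, so the extension can target the Python stable ABI.
TORCH_LIBRARY(_example_C, ops) {
  ops.def("scale_tensor(Tensor input, float scale) -> Tensor");
}

TORCH_LIBRARY_IMPL(_example_C, CUDA, ops) {
  ops.impl("scale_tensor", &scale_tensor);
}
```

With dispatcher registration, Python reaches the op as torch.ops.&lt;namespace&gt;.&lt;op&gt; rather than as an attribute of a pybind11 module, and the extension no longer needs the pybind glue from libtorch_python, which is why the manual find_library(torch_python_LIBRARY ...) block can be dropped from CMakeLists.txt below.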

CMakeLists.txt

Lines changed: 6 additions & 16 deletions
@@ -66,19 +66,6 @@ endif()
 #
 find_package(Torch REQUIRED)
 
-#
-# Normally `torch.utils.cpp_extension.CUDAExtension` would add
-# `libtorch_python.so` for linking against an extension. Torch's cmake
-# configuration does not include this library (presumably since the cmake
-# config is used for standalone C++ binaries that link against torch).
-# The `libtorch_python.so` library defines some of the glue code between
-# torch/python via pybind and is required by VLLM extensions for this
-# reason. So, add it by manually with `find_library` using torch's
-# installed library path.
-#
-find_library(torch_python_LIBRARY torch_python PATHS
-  "${TORCH_INSTALL_PREFIX}/lib")
-
 #
 # Forward the non-CUDA device extensions to external CMake scripts.
 #
@@ -171,7 +158,7 @@ set(VLLM_EXT_SRC
   "csrc/quantization/fp8/common.cu"
   "csrc/cuda_utils_kernels.cu"
   "csrc/moe_align_block_size_kernels.cu"
-  "csrc/pybind.cpp")
+  "csrc/torch_bindings.cpp")
 
 if(VLLM_GPU_LANG STREQUAL "CUDA")
   include(FetchContent)
@@ -218,14 +205,15 @@ define_gpu_extension_target(
   COMPILE_FLAGS ${VLLM_GPU_FLAGS}
   ARCHITECTURES ${VLLM_GPU_ARCHES}
   INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
+  USE_SABI 3
   WITH_SOABI)
 
 #
 # _moe_C extension
 #
 
 set(VLLM_MOE_EXT_SRC
-  "csrc/moe/moe_ops.cpp"
+  "csrc/moe/torch_bindings.cpp"
   "csrc/moe/topk_softmax_kernels.cu")
 
 define_gpu_extension_target(
@@ -235,6 +223,7 @@ define_gpu_extension_target(
   SOURCES ${VLLM_MOE_EXT_SRC}
   COMPILE_FLAGS ${VLLM_GPU_FLAGS}
   ARCHITECTURES ${VLLM_GPU_ARCHES}
+  USE_SABI 3
   WITH_SOABI)
 
 #
@@ -249,7 +238,7 @@ set(VLLM_PUNICA_EXT_SRC
   "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
   "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
   "csrc/punica/punica_ops.cu"
-  "csrc/punica/punica_pybind.cpp")
+  "csrc/punica/torch_bindings.cpp")
 
 #
 # Copy GPU compilation flags+update for punica
@@ -286,6 +275,7 @@ if (VLLM_PUNICA_GPU_ARCHES)
   SOURCES ${VLLM_PUNICA_EXT_SRC}
   COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
   ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
+  USE_SABI 3
   WITH_SOABI)
 else()
   message(WARNING "Unable to create _punica_C target because none of the "

Dockerfile.rocm

Lines changed: 3 additions & 3 deletions
@@ -106,9 +106,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \
     pip install -U -r requirements-rocm.txt \
     && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \
     && python3 setup.py install \
-    && cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \
-    && cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.cpython-39-x86_64-linux-gnu.so vllm/ \
-    && cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.cpython-39-x86_64-linux-gnu.so vllm/ \
+    && cp build/lib.linux-x86_64-cpython-39/vllm/_C.abi3.so vllm/ \
+    && cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.abi3.so vllm/ \
+    && cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.abi3.so vllm/ \
     && cd ..

cmake/cpu_extension.cmake

Lines changed: 6 additions & 6 deletions
@@ -12,7 +12,7 @@ include_directories("${CMAKE_SOURCE_DIR}/csrc")
 #
 # Check the compile flags
 #
-list(APPEND CXX_COMPILE_FLAGS 
+list(APPEND CXX_COMPILE_FLAGS
   "-fopenmp"
   "-DVLLM_CPU_EXTENSION")
 
@@ -44,8 +44,8 @@ if (AVX512_FOUND)
 
   find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND)
   if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
-    if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
-      CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
+    if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
+        CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
       list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
     else()
       message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
@@ -73,18 +73,18 @@ set(VLLM_EXT_SRC
   "csrc/cpu/cache.cpp"
   "csrc/cpu/layernorm.cpp"
   "csrc/cpu/pos_encoding.cpp"
-  "csrc/cpu/pybind.cpp")
+  "csrc/cpu/torch_bindings.cpp")
 
 define_gpu_extension_target(
   _C
   DESTINATION vllm
   LANGUAGE CXX
   SOURCES ${VLLM_EXT_SRC}
   COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
-  WITH_SOABI
+  USE_SABI 3
+  WITH_SOABI
 )
 
 add_custom_target(default)
 message(STATUS "Enabling C extension.")
 add_dependencies(default _C)
-

cmake/utils.cmake

Lines changed: 8 additions & 3 deletions
@@ -5,7 +5,7 @@
 macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
   file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
   set(Python_EXECUTABLE ${EXECUTABLE})
-  find_package(Python COMPONENTS Interpreter Development.Module)
+  find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule)
   if (NOT Python_FOUND)
     message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
   endif()
@@ -294,14 +294,15 @@ endmacro()
 # INCLUDE_DIRECTORIES <dirs> - Extra include directories.
 # LIBRARIES <libraries> - Extra link libraries.
 # WITH_SOABI - Generate library with python SOABI suffix name.
+# USE_SABI <version> - Use python stable api <version>
 #
 # Note: optimization level/debug info is set via cmake build type.
 #
 function (define_gpu_extension_target GPU_MOD_NAME)
   cmake_parse_arguments(PARSE_ARGV 1
     GPU
     "WITH_SOABI"
-    "DESTINATION;LANGUAGE"
+    "DESTINATION;LANGUAGE;USE_SABI"
     "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
 
   # Add hipify preprocessing step when building with HIP/ROCm.
@@ -315,7 +316,11 @@ function (define_gpu_extension_target GPU_MOD_NAME)
     set(GPU_WITH_SOABI)
   endif()
 
-  Python_add_library(${GPU_MOD_NAME} MODULE "${GPU_SOURCES}" ${GPU_WITH_SOABI})
+  if (GPU_USE_SABI)
+    Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
+  else()
+    Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
+  endif()
 
   if (GPU_LANGUAGE STREQUAL "HIP")
     # Make this target dependent on the hipify preprocessor step.

csrc/activation_kernels.cu

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 #include <ATen/cuda/CUDAContext.h>
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <c10/cuda/CUDAGuard.h>
 
 #include <cmath>

csrc/attention/attention_kernels.cu

Lines changed: 18 additions & 16 deletions
@@ -17,7 +17,7 @@
  * limitations under the License.
  */
 
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <algorithm>
@@ -808,16 +808,17 @@ void paged_attention_v1(
     torch::Tensor&
         key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
     torch::Tensor&
-        value_cache,  // [num_blocks, num_heads, head_size, block_size]
-    int num_kv_heads,  // [num_heads]
-    float scale,
+        value_cache,  // [num_blocks, num_heads, head_size, block_size]
+    int64_t num_kv_heads,  // [num_heads]
+    double scale,
     torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
     torch::Tensor& seq_lens,  // [num_seqs]
-    int block_size, int max_seq_len,
+    int64_t block_size, int64_t max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
-    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
-    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
+    const int64_t blocksparse_local_blocks,
+    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
+    const int64_t blocksparse_head_sliding_step) {
   const bool is_block_sparse = (blocksparse_vert_stride > 1);
 
   DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
@@ -972,16 +973,17 @@ void paged_attention_v2(
     torch::Tensor&
         key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
     torch::Tensor&
-        value_cache,  // [num_blocks, num_heads, head_size, block_size]
-    int num_kv_heads,  // [num_heads]
-    float scale,
+        value_cache,  // [num_blocks, num_heads, head_size, block_size]
+    int64_t num_kv_heads,  // [num_heads]
+    double scale,
     torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
     torch::Tensor& seq_lens,  // [num_seqs]
-    int block_size, int max_seq_len,
+    int64_t block_size, int64_t max_seq_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
-    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
-    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
+    const int64_t blocksparse_local_blocks,
+    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
+    const int64_t blocksparse_head_sliding_step) {
   const bool is_block_sparse = (blocksparse_vert_stride > 1);
   DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                              CALL_V2_LAUNCHER_BLOCK_SIZE)
@@ -990,4 +992,4 @@ void paged_attention_v2(
 #undef WARP_SIZE
 #undef MAX
 #undef MIN
-#undef DIVIDE_ROUND_UP
+#undef DIVIDE_ROUND_UP
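The signature changes in this file (int to int64_t, float to double) reflect how the dispatcher types scalars: an `int` in an op schema is carried as C++ int64_t and a `float` as double, so functions registered through TORCH_LIBRARY cannot take plain int or float arguments. A minimal sketch under assumed names (not vLLM's real paged_attention schema):

```cpp
// Hypothetical op illustrating the scalar-type rule for registered ops.
#include <torch/all.h>
#include <torch/library.h>

void toy_attention(torch::Tensor& out,  // written in place
                   int64_t block_size,  // Python int arrives as int64_t
                   double scale) {      // Python float arrives as double
  (void)block_size;  // placeholder; a real kernel would use this
  out.mul_(scale);
}

TORCH_LIBRARY(_toy_ops, ops) {
  // In the schema string, "int" means int64_t and "float" means double;
  // "Tensor!" marks `out` as mutated in place.
  ops.def("toy_attention(Tensor! out, int block_size, float scale) -> ()");
}

TORCH_LIBRARY_IMPL(_toy_ops, CUDA, ops) {
  ops.impl("toy_attention", &toy_attention);
}
```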

csrc/cache.h

Lines changed: 9 additions & 5 deletions
@@ -1,21 +1,25 @@
 #pragma once
 
-#include <torch/extension.h>
+#include <torch/all.h>
 
 #include <map>
 #include <vector>
 
 void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                  const torch::Tensor& block_mapping);
 
-void copy_blocks(std::vector<torch::Tensor>& key_caches,
-                 std::vector<torch::Tensor>& value_caches,
+// Note: the key_caches and value_caches vectors are constant but
+// not the Tensors they contain. The vectors need to be const refs
+// in order to satisfy pytorch's C++ operator registration code.
+void copy_blocks(std::vector<torch::Tensor> const& key_caches,
+                 std::vector<torch::Tensor> const& value_caches,
                  const torch::Tensor& block_mapping);
 
 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        torch::Tensor& slot_mapping,
-                       const std::string& kv_cache_dtype, const float kv_scale);
+                       const std::string& kv_cache_dtype,
+                       const double kv_scale);
 
 void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                              torch::Tensor& key_cache,
@@ -25,4 +29,4 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
 
 // Just for unittest
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
-                 const float scale, const std::string& kv_cache_dtype);
+                 const double scale, const std::string& kv_cache_dtype);
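The new comment on copy_blocks is about operator registration: a schema argument of type Tensor[] is handed to the C++ kernel as a list the function may not rebind, so the vector parameters must be const references even though the Tensors inside can still be written to. A hypothetical registration illustrating the pairing (library, op name, and schema are made up, not the actual vLLM binding):

```cpp
// Hypothetical copy_blocks-style op; "Tensor[]" in the schema pairs with a
// std::vector<torch::Tensor> const& parameter on the C++ side.
#include <torch/all.h>
#include <torch/library.h>

#include <vector>

void copy_blocks_demo(std::vector<torch::Tensor> const& key_caches,
                      std::vector<torch::Tensor> const& value_caches,
                      const torch::Tensor& block_mapping) {
  // Placeholder body; the real kernels are in csrc/cache_kernels.cu and
  // csrc/cpu/cache.cpp.
  TORCH_CHECK(key_caches.size() == value_caches.size());
  (void)block_mapping;
}

TORCH_LIBRARY(_cache_demo, ops) {
  ops.def(
      "copy_blocks_demo(Tensor[] key_caches, Tensor[] value_caches, "
      "Tensor block_mapping) -> ()");
}

TORCH_LIBRARY_IMPL(_cache_demo, CUDA, ops) {
  ops.impl("copy_blocks_demo", &copy_blocks_demo);
}
```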

csrc/cache_kernels.cu

Lines changed: 8 additions & 5 deletions
@@ -1,4 +1,4 @@
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 
@@ -95,8 +95,11 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
 
 }  // namespace vllm
 
-void copy_blocks(std::vector<torch::Tensor>& key_caches,
-                 std::vector<torch::Tensor>& value_caches,
+// Note: the key_caches and value_caches vectors are constant but
+// not the Tensors they contain. The vectors need to be const refs
+// in order to satisfy pytorch's C++ operator registration code.
+void copy_blocks(std::vector<torch::Tensor> const& key_caches,
+                 std::vector<torch::Tensor> const& value_caches,
                  const torch::Tensor& block_mapping) {
   int num_layers = key_caches.size();
   TORCH_CHECK(num_layers == value_caches.size());
@@ -255,7 +258,7 @@ void reshape_and_cache(
     torch::Tensor&
         value_cache,  // [num_blocks, num_heads, head_size, block_size]
     torch::Tensor& slot_mapping,  // [num_tokens]
-    const std::string& kv_cache_dtype, const float kv_scale) {
+    const std::string& kv_cache_dtype, const double kv_scale) {
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
@@ -334,7 +337,7 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
 
 // Only for testing.
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
-                 const float kv_scale, const std::string& kv_cache_dtype) {
+                 const double kv_scale, const std::string& kv_cache_dtype) {
   torch::Device src_device = src_cache.device();
   torch::Device dst_device = dst_cache.device();
   TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")

csrc/cpu/attention.cpp

Lines changed: 14 additions & 12 deletions
@@ -420,12 +420,13 @@ void paged_attention_v1_impl_launcher(
 
 void paged_attention_v1(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
-    torch::Tensor& value_cache, int num_kv_heads, float scale,
-    torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
-    int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
-    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
-    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
+    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
+    const int64_t blocksparse_local_blocks,
+    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
+    const int64_t blocksparse_head_sliding_step) {
   TORCH_CHECK(kv_scale == 1.0f);
   TORCH_CHECK(blocksparse_vert_stride <= 1,
               "CPU backend does not support blocksparse attention yet.");
@@ -738,12 +739,13 @@ void paged_attention_v2_impl_launcher(
 void paged_attention_v2(
     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
-    torch::Tensor& value_cache, int num_kv_heads, float scale,
-    torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
-    int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
-    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
-    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
+    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
+    const int64_t blocksparse_local_blocks,
+    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
+    const int64_t blocksparse_head_sliding_step) {
   TORCH_CHECK(kv_scale == 1.0f);
   TORCH_CHECK(blocksparse_vert_stride <= 1,
               "CPU backend does not support blocksparse attention yet.");

csrc/cpu/cache.cpp

Lines changed: 8 additions & 5 deletions
@@ -5,8 +5,8 @@
 
 namespace {
 template <typename scalar_t>
-void copy_blocks_cpu_impl(std::vector<torch::Tensor>& key_caches,
-                          std::vector<torch::Tensor>& value_caches,
+void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
+                          std::vector<torch::Tensor> const& value_caches,
                           const torch::Tensor& mapping_pairs,
                           const int element_num_per_block,
                           const int layer_num) {
@@ -82,8 +82,11 @@ void reshape_and_cache_cpu_impl(
 }
 };  // namespace
 
-void copy_blocks(std::vector<torch::Tensor>& key_caches,
-                 std::vector<torch::Tensor>& value_caches,
+// Note: the key_caches and value_caches vectors are constant but
+// not the Tensors they contain. The vectors need to be const refs
+// in order to satisfy pytorch's C++ operator registration code.
+void copy_blocks(std::vector<torch::Tensor> const& key_caches,
+                 std::vector<torch::Tensor> const& value_caches,
                  const torch::Tensor& block_mapping) {
   unsigned num_layers = key_caches.size();
   TORCH_CHECK(num_layers == value_caches.size());
@@ -104,7 +107,7 @@ void copy_blocks(std::vector<torch::Tensor>& key_caches,
 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        torch::Tensor& slot_mapping,
-                       const std::string& kv_cache_dtype, float kv_scale) {
+                       const std::string& kv_cache_dtype, double kv_scale) {
   TORCH_CHECK(kv_scale == 1.0f);
 
   int num_tokens = key.size(0);
