From 5cd3412e1578de4e3e1eb24db08596ab0056e047 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 25 May 2024 05:45:28 +0000 Subject: [PATCH 01/41] Use TORCH_LIBRARY instead of PYBIND11_MODULE --- csrc/attention/attention_kernels.cu | 22 ++-- csrc/cpu/attention.cpp | 24 ++-- csrc/cpu/layernorm.cpp | 4 +- csrc/cpu/pos_encoding.cpp | 2 +- csrc/cpu/pybind.cpp | 47 ++++--- csrc/layernorm_kernels.cu | 4 +- csrc/moe/moe_ops.cpp | 8 +- csrc/moe_align_block_size_kernels.cu | 4 +- csrc/ops.h | 54 ++++---- csrc/pos_encoding_kernels.cu | 10 +- csrc/punica/punica_ops.cu | 4 +- csrc/punica/punica_ops.h | 4 +- csrc/punica/punica_pybind.cpp | 11 +- csrc/pybind.cpp | 174 ++++++++++++++++++-------- csrc/quantization/awq/gemm_kernels.cu | 6 +- csrc/quantization/gptq/q_gemm.cu | 4 +- tests/kernels/test_int8_quant.py | 2 +- vllm/_custom_ops.py | 57 +++++---- 18 files changed, 260 insertions(+), 181 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 8f89f89786c3..3c8907dc2680 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -809,15 +809,15 @@ void paged_attention_v1( key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, // [num_heads] - float scale, + int64_t num_kv_heads, // [num_heads] + double scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] - int block_size, int max_seq_len, + int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, + const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, @@ -973,15 +973,15 @@ void paged_attention_v2( key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, // [num_heads] - float scale, + int64_t num_kv_heads, // [num_heads] + double scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] - int block_size, int max_seq_len, + int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, + const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE) @@ -990,4 +990,4 @@ void paged_attention_v2( #undef WARP_SIZE #undef MAX #undef MIN -#undef DIVIDE_ROUND_UP \ No newline at end of file +#undef DIVIDE_ROUND_UP diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index ed8cfbd421f0..47b674cfaa7b 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -420,12 +420,12 @@ void 
paged_attention_v1_impl_launcher( void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, - int max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, + const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); @@ -738,12 +738,12 @@ void paged_attention_v2_impl_launcher( void paged_attention_v2( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, - int max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, + const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); diff --git a/csrc/cpu/layernorm.cpp b/csrc/cpu/layernorm.cpp index 65d3ddcec570..a76ad08928a2 100644 --- a/csrc/cpu/layernorm.cpp +++ b/csrc/cpu/layernorm.cpp @@ -88,7 +88,7 @@ void fused_add_rms_norm_impl(scalar_t* __restrict__ input, } // namespace void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, - float epsilon) { + double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; @@ -102,7 +102,7 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, } void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, - torch::Tensor& weight, float epsilon) { + torch::Tensor& weight, double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp index e8aead17ae5a..96bce7dda013 100644 --- a/csrc/cpu/pos_encoding.cpp +++ b/csrc/cpu/pos_encoding.cpp @@ -168,7 +168,7 @@ void rotary_embedding_gptj_impl( }; // namespace void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int head_size, + torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox) { 
int num_tokens = query.numel() / query.size(-1); int rot_dim = cos_sin_cache.size(1); diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp index e5b2ce4f3011..1c864d2922b1 100644 --- a/csrc/cpu/pybind.cpp +++ b/csrc/cpu/pybind.cpp @@ -2,38 +2,45 @@ #include "ops.h" #include -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops - pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); // Attention ops - ops.def("paged_attention_v1", &paged_attention_v1, - "Compute the attention between an input query and the cached " - "keys/values using PagedAttention."); - ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2."); + // Compute the attention between an input query and the cached keys/values + //using PagedAttention. + ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1); + // PagedAttention V2. + ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2); // Activation ops - ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); - ops.def("gelu_and_mul", &gelu_and_mul, - "Activation function used in GeGLU with `none` approximation."); - ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, - "Activation function used in GeGLU with `tanh` approximation."); - ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2."); - ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation."); + // Activation function used in SwiGLU. + ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul); + // Activation function used in GeGLU with `none` approximation. + ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul); + // Activation function used in GeGLU with `tanh` approximation. + ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul); + // GELU implementation used in GPT-2. + ops.impl("gelu_new", torch::kCPU, &gelu_new); + // Approximate GELU implementation. + ops.impl("gelu_fast", torch::kCPU, &gelu_fast); // Layernorm - ops.def("rms_norm", &rms_norm, - "Apply Root Mean Square (RMS) Normalization to the input tensor."); + // Apply Root Mean Square (RMS) Normalization to the input tensor. + ops.impl("rms_norm", torch::kCPU, &rms_norm); - ops.def("fused_add_rms_norm", &fused_add_rms_norm, - "In-place fused Add and RMS Normalization"); + // In-place fused Add and RMS Normalization. + ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm); // Rotary embedding - ops.def("rotary_embedding", &rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. 
+ ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ // Cache ops - pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); + pybind11::module cache_ops = m.impl_submodule("cache_ops", "vLLM cache ops"); cache_ops.def("swap_blocks", &swap_blocks, "Swap in (out) the cache blocks from src to dst"); cache_ops.def("copy_blocks", ©_blocks, diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 70a2b3b0a07b..73d3dfa9e81a 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -291,7 +291,7 @@ fused_add_rms_norm_kernel( void rms_norm(torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] - float epsilon) { + double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; @@ -319,7 +319,7 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] - float epsilon) { + double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; diff --git a/csrc/moe/moe_ops.cpp b/csrc/moe/moe_ops.cpp index 4122f7630d7c..3b24c2ed4864 100644 --- a/csrc/moe/moe_ops.cpp +++ b/csrc/moe/moe_ops.cpp @@ -2,7 +2,9 @@ #include -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("topk_softmax", &topk_softmax, - "Apply topk softmax to the gating outputs."); +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { + // Apply topk softmax to the gating outputs. + m.impl("topk_softmax", torch::kCUDA, &topk_softmax); } diff --git a/csrc/moe_align_block_size_kernels.cu b/csrc/moe_align_block_size_kernels.cu index edc441d12102..63d7c73b1631 100644 --- a/csrc/moe_align_block_size_kernels.cu +++ b/csrc/moe_align_block_size_kernels.cu @@ -108,8 +108,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, } } // namespace vllm -void moe_align_block_size(torch::Tensor topk_ids, int num_experts, - int block_size, torch::Tensor sorted_token_ids, +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); diff --git a/csrc/ops.h b/csrc/ops.h index 06b60e748886..1bd480c1466a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -4,37 +4,37 @@ void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, - int max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step); + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, + const int64_t blocksparse_block_size, const int64_t 
blocksparse_head_sliding_step); void paged_attention_v2( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, - int max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step); + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, + const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, - float epsilon); + double epsilon); void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, - torch::Tensor& weight, float epsilon); + torch::Tensor& weight, double epsilon); void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int head_size, + torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int head_size, + torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox, - int rot_dim, + int64_t rot_dim, torch::Tensor& cos_sin_cache_offsets); void silu_and_mul(torch::Tensor& out, torch::Tensor& input); @@ -60,12 +60,12 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes, torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, - int split_k_iters); + int64_t split_k_iters); torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _scaling_factors, - torch::Tensor _zeros, int split_k_iters, int thx, - int thy); + torch::Tensor _zeros, int64_t split_k_iters, int64_t thx, + int64_t thy); torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& workspace, @@ -88,9 +88,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits); -int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, torch::Tensor const& a_scales, - torch::Tensor const& b_scales); +void cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales); #endif @@ -106,9 +106,9 @@ void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama, int bit); + bool use_exllama, int64_t bit); -void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit); +void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); @@ -116,8 +116,8 @@ void static_scaled_fp8_quant(torch::Tensor& out, 
torch::Tensor& input, void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); -void moe_align_block_size(torch::Tensor topk_ids, int num_experts, - int block_size, torch::Tensor sorted_token_ids, +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 69d6dae1c26b..caca03284735 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -127,7 +127,7 @@ void rotary_embedding( // [num_tokens, num_heads * head_size] torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or // [num_tokens, num_kv_heads * head_size] - int head_size, + int64_t head_size, torch::Tensor& cos_sin_cache, // [max_position, rot_dim] bool is_neox) { int64_t num_tokens = query.numel() / query.size(-1); @@ -138,7 +138,7 @@ void rotary_embedding( int64_t key_stride = key.stride(-2); dim3 grid(num_tokens); - dim3 block(std::min(num_heads * rot_dim / 2, 512)); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { @@ -168,9 +168,9 @@ void batched_rotary_embedding( // [num_tokens, num_heads * head_size] torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or // [num_tokens, num_kv_heads * head_size] - int head_size, + int64_t head_size, torch::Tensor& cos_sin_cache, // [max_position, rot_dim] - bool is_neox, int rot_dim, + bool is_neox, int64_t rot_dim, torch::Tensor& cos_sin_cache_offsets // [num_tokens] ) { int64_t num_tokens = cos_sin_cache_offsets.size(0); @@ -180,7 +180,7 @@ void batched_rotary_embedding( int64_t key_stride = key.stride(-2); dim3 grid(num_tokens); - dim3 block(std::min(num_heads * rot_dim / 2, 512)); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { diff --git a/csrc/punica/punica_ops.cu b/csrc/punica/punica_ops.cu index 61de3b37937c..e345d8a24d45 100644 --- a/csrc/punica/punica_ops.cu +++ b/csrc/punica/punica_ops.cu @@ -88,7 +88,7 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, } void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, float scale) { + torch::Tensor indicies, int64_t layer_idx, double scale) { CHECK_INPUT(y); CHECK_INPUT(x); CHECK_INPUT(w); @@ -320,7 +320,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, torch::Tensor indicies, int64_t layer_idx, - float scale, int64_t h_in, int64_t h_out, + double scale, int64_t h_in, int64_t h_out, int64_t y_offset) { CHECK_INPUT(y); CHECK_INPUT(x); diff --git a/csrc/punica/punica_ops.h b/csrc/punica/punica_ops.h index 937e2d1d25d4..e94d26f9701c 100644 --- a/csrc/punica/punica_ops.h +++ b/csrc/punica/punica_ops.h @@ -3,9 +3,9 @@ #include void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, float scale); + torch::Tensor indicies, int64_t layer_idx, double scale); void 
dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, torch::Tensor indicies, int64_t layer_idx, - float scale, int64_t h_in, int64_t h_out, + double scale, int64_t h_in, int64_t h_out, int64_t y_offset); diff --git a/csrc/punica/punica_pybind.cpp b/csrc/punica/punica_pybind.cpp index 9490ad59cdd5..ea054b475148 100644 --- a/csrc/punica/punica_pybind.cpp +++ b/csrc/punica/punica_pybind.cpp @@ -2,12 +2,9 @@ #include "punica_ops.h" -//====== pybind ====== +#define TORCH_LIBRARY_EXPAND(NAME, M) TORCH_LIBRARY(NAME, M) -#define DEFINE_pybind(name) m.def(#name, &name, #name); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); - m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, - "dispatch_bgmv_low_level"); +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { + m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv); + m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level); } diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 547823aa1b04..59e24b467d04 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -3,76 +3,150 @@ #include "ops.h" #include -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops - pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); // Attention ops - ops.def("paged_attention_v1", &paged_attention_v1, - "Compute the attention between an input query and the cached " - "keys/values using PagedAttention."); - ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2."); + // Compute the attention between an input query and the cached + // keys/values using PagedAttention. + ops.def("paged_attention_v1(Tensor out, Tensor query, Tensor key_cache, " + "Tensor value_cache, int num_kv_heads, float scale, Tensor " + "block_tables, Tensor seq_lens, int block_size, int max_seq_len, " + "Tensor? alibi_slopes, str kv_cache_dtype, float kv_scale) -> ()"); + ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1); + + // PagedAttention V2. + ops.def("paged_attention_v2(Tensor out, Tensor exp_sums, Tensor max_logits," + "Tensor tmp_out, Tensor query, Tensor key_cache, Tensor value_cache," + "int num_kv_heads, float scale, Tensor block_tables, Tensor seq_lens," + "int block_size, int max_seq_len, Tensor? alibi_slopes, " + "str kv_cache_dtype, float kv_scale) -> ()"); + ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); // Activation ops - ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); - ops.def("gelu_and_mul", &gelu_and_mul, - "Activation function used in GeGLU with `none` approximation."); - ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, - "Activation function used in GeGLU with `tanh` approximation."); - ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2."); - ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation."); + // Activation function used in SwiGLU. + ops.def("silu_and_mul(Tensor out, Tensor input) -> ()"); + ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); + + // Activation function used in GeGLU with `none` approximation. + ops.def("gelu_and_mul(Tensor out, Tensor input) -> ()"); + ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); + + // Activation function used in GeGLU with `tanh` approximation. + ops.def("gelu_tanh_and_mul(Tensor out, Tensor input) -> ()"); + ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); + + // GELU implementation used in GPT-2. 
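+ // (i.e. the tanh-based "new GELU" approximation.)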
+ ops.def("gelu_new(Tensor out, Tensor input) -> ()"); + ops.impl("gelu_new", torch::kCUDA, &gelu_new); + + // Approximate GELU implementation. + ops.def("gelu_fast(Tensor out, Tensor input) -> ()"); + ops.impl("gelu_fast", torch::kCUDA, &gelu_fast); // Layernorm - ops.def("rms_norm", &rms_norm, - "Apply Root Mean Square (RMS) Normalization to the input tensor."); + // Apply Root Mean Square (RMS) Normalization to the input tensor. + ops.def("rms_norm(Tensor out, Tensor input, Tensor weight, float epsilon) -> ()"); + //ops.def(torch::schema("rms_norm(Tensor out, Tensor input, Tensor weight, float epsilon) -> ()"), c10::AliasAnalysisKind::CONSERVATIVE); + ops.impl("rms_norm", torch::kCUDA, &rms_norm); - ops.def("fused_add_rms_norm", &fused_add_rms_norm, - "In-place fused Add and RMS Normalization"); + // In-place fused Add and RMS Normalization. + ops.def("fused_add_rms_norm(Tensor input, Tensor residual, Tensor weight, " + "float epsilon) -> ()"); + ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); // Rotary embedding - ops.def("rotary_embedding", &rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. + ops.def("rotary_embedding(Tensor positions, Tensor query, Tensor key, int " + "head_size, Tensor cos_sin_cache, bool is_neox) -> ()"); + ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); - ops.def("batched_rotary_embedding", &batched_rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key " - "(supports multiple loras)"); + // Apply GPT-NeoX or GPT-J style rotary embedding to query and key + // (supports multiple loras). + ops.def("batched_rotary_embedding(Tensor positions, Tensor query, Tensor " + "key, int head_size, Tensor cos_sin_cache, bool is_neox, int " + "rot_dim, Tensor cos_sin_cache_offsets) -> ()"); + ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding); -// Quantization ops + // Quantization ops #ifndef USE_ROCM - ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM"); - ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM"); - ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); - ops.def("marlin_gemm", &marlin_gemm, - "Marlin (Dense) Optimized Quantized GEMM for GPTQ"); - ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, - "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ"); - ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, - "gptq_marlin Optimized Quantized GEMM for GPTQ"); - ops.def("gptq_marlin_repack", &gptq_marlin_repack, - "gptq_marlin repack from GPTQ"); - ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); - ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, - "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or " - "per-row/column quantization."); + // Quantized GEMM for AQLM. + ops.def("aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, Tensor scales, Tensor codebook_partition_sizes, Tensor? bias) -> Tensor"); + ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm); + + // Decompression method for AQLM. + ops.def("aqlm_dequant(Tensor codes, Tensor codebooks, Tensor codebook_partition_sizes) -> Tensor"); + ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant); + + // Quantized GEMM for AWQ. + ops.def("awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, Tensor _zeros, int split_k_iters) -> Tensor"); + ops.impl("awq_gemm", torch::kCUDA, &awq_gemm); + + // Marlin (Dense) Optimized Quantized GEMM for GPTQ. 
+ ops.def("marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, Tensor workspace, int size_m, int size_n, int size_k) -> Tensor"); + ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm); + + // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ. + ops.def("gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, Tensor b_scales, Tensor workspace, int num_bits, int size_m, int size_n, int size_k) -> Tensor"); + ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm); + + // gptq_marlin Optimized Quantized GEMM for GPTQ. + ops.def("gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, Tensor g_idx, Tensor perm, Tensor workspace, int num_bits, int size_m, int size_n, int size_k, bool is_k_full) -> Tensor"); + ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm); + + // gptq_marlin repack from GPTQ. + ops.def("gptq_marlin_repack(Tensor b_q_weight, Tensor perm, int size_k, int size_n, int num_bits) -> Tensor"); + ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack); + + // Dequantization for AWQ. + ops.def("awq_dequantize(Tensor _kernel, Tensor _scaling_factors, Tensor _zeros, int split_k_iters, int thx, int thy) -> Tensor"); + ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize); + + // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column + // quantization. + ops.def("cutlass_scaled_mm_dq(Tensor out, Tensor a, Tensor b, Tensor a_scales, Tensor b_scales) -> ()"); + ops.impl("cutlass_scaled_mm_dq", torch::kCUDA, &cutlass_scaled_mm_dq); #endif - ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); - ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); - ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); - ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, - "Compute FP8 quantized tensor for given scaling factor"); - ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, - "Compute FP8 quantized tensor and scaling factor"); - ops.def("moe_align_block_size", &moe_align_block_size, - "Aligning the number of tokens to be processed by each expert such " - "that it is divisible by the block size."); + // Quantized GEMM for GPTQ. + ops.def("gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) -> Tensor"); + ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm); + + // Post processing for GPTQ. + ops.def("gptq_shuffle(Tensor q_weight, Tensor q_perm, int bit) -> ()"); + ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); + + // Quantized GEMM for SqueezeLLM. + ops.def("squeezellm_gemm(Tensor vec, Tensor mat, Tensor mul, Tensor lookup_table) -> ()"); + ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm); - ops.def("static_scaled_int8_quant", &static_scaled_int8_quant, - "Compute int8 quantized tensor for given scaling factor"); + // Compute FP8 quantized tensor for given scaling factor. + ops.def("static_scaled_fp8_quant(Tensor out, Tensor input, Tensor scale) -> ()"); + ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); + // Compute FP8 quantized tensor and scaling factor. + ops.def("dynamic_scaled_fp8_quant(Tensor out, Tensor input, Tensor scale) -> ()"); + ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); + + // Aligning the number of tokens to be processed by each expert such + // that it is divisible by the block size. 
+ ops.def("moe_align_block_size(Tensor topk_ids, int num_experts, int block_size," + "Tensor sorted_token_ids, Tensor experts_ids, Tensor num_tokens_post_pad) -> ()"); + ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); + + // Compute int8 quantized tensor for given scaling factor. + ops.def("static_scaled_int8_quant(Tensor out, Tensor input, float scale) -> ()"); + ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant); + + // Compute int8 quantized tensor and scaling factor ops.def("dynamic_scaled_int8_quant", &dynamic_scaled_int8_quant, "Compute int8 quantized tensor and scaling factor"); + ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); cache_ops.def("swap_blocks", &swap_blocks, diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index bb8e5bbb23d7..4ca69e956969 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -435,8 +435,8 @@ __global__ void __launch_bounds__(64) torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _scaling_factors, - torch::Tensor _zeros, int split_k_iters, int thx, - int thy) { + torch::Tensor _zeros, int64_t split_k_iters, int64_t thx, + int64_t thy) { int in_c = _kernel.size(0); int qout_c = _kernel.size(1); int out_c = qout_c * 8; @@ -491,7 +491,7 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, - int split_k_iters) { + int64_t split_k_iters) { int num_in_feats = _in_feats.size(0); int num_in_channels = _in_feats.size(1); const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 480c4986c382..91813f306713 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -1823,7 +1823,7 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height, torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama, int bit) { + bool use_exllama, int64_t bit) { const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); @@ -1845,7 +1845,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, return c; } -void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit) { +void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) { const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); vllm::gptq::shuffle_exllama_weight( (uint32_t*)q_weight.data_ptr(), diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index aab7af9d2cbf..d37f7d2e6ef4 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -1,7 +1,7 @@ import pytest import torch -from vllm._C import ops +from vllm import _custom_ops as ops DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 7e12f1ba14cd..2c4f71e14e1d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py 
@@ -1,9 +1,8 @@ -from typing import Optional, Tuple, Type +from typing import Optional, Tuple, Type, List import torch try: - from vllm._C import cache_ops as vllm_cache_ops from vllm._C import ops as vllm_ops except ImportError as e: from vllm.logger import init_logger @@ -13,23 +12,23 @@ # activation ops def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - vllm_ops.silu_and_mul(out, x) + torch.ops._C.silu_and_mul(out, x) def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - vllm_ops.gelu_and_mul(out, x) + torch.ops._C.gelu_and_mul(out, x) def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - vllm_ops.gelu_tanh_and_mul(out, x) + torch.ops._C.gelu_tanh_and_mul(out, x) def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: - vllm_ops.gelu_fast(out, x) + torch.ops._C.gelu_fast(out, x) def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: - vllm_ops.gelu_new(out, x) + torch.ops._C.gelu_new(out, x) # page attention ops @@ -100,8 +99,8 @@ def rotary_embedding( cos_sin_cache: torch.Tensor, is_neox: bool, ) -> None: - vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache, - is_neox) + torch.ops._C.rotary_embedding(positions, query, key, head_size, cos_sin_cache, + is_neox) def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, @@ -109,7 +108,7 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, cos_sin_cache: torch.Tensor, is_neox: bool, rot_dim: int, cos_sin_cache_offsets: torch.Tensor) -> None: - vllm_ops.batched_rotary_embedding(positions, query, key, head_size, + torch.ops._C.batched_rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox, rot_dim, cos_sin_cache_offsets) @@ -117,12 +116,12 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, # layer norm ops def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float) -> None: - vllm_ops.rms_norm(out, input, weight, epsilon) + torch.ops._C.rms_norm(out, input, weight, epsilon) def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, epsilon: float) -> None: - vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon) + torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) # quantization ops @@ -130,13 +129,13 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, split_k_iters: int, thx: int, thy: int) -> torch.Tensor: - return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, + return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy) def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: - return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters) + return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters) # gptq @@ -144,26 +143,26 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, use_exllama: bool, bit: int) -> torch.Tensor: - return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, + return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_exllama, bit) def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None: - vllm_ops.gptq_shuffle(q_weight, q_perm, bit) + torch.ops._C.gptq_shuffle(q_weight, q_perm, 
bit) # squeezellm def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor, lookup_table: torch.Tensor) -> None: - vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table) + torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table) # marlin def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, size_n: int, size_k: int) -> torch.Tensor: - return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, + return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, size_n, size_k) @@ -172,7 +171,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, num_bits: int, size_m: int, size_n: int, size_k: int) -> torch.Tensor: - return vllm_ops.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, + return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, workspace, num_bits, size_m, size_n, size_k) @@ -188,7 +187,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor, n = b.shape[1] out = torch.empty((m, n), dtype=out_dtype, device=a.device) - vllm_ops.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b) + torch.ops._C.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b) return out @@ -198,20 +197,20 @@ def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor, codebook_partition_sizes: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: - return vllm_ops.aqlm_gemm(input, codes, codebooks, scales, + return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales, codebook_partition_sizes, bias) def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, codebook_partition_sizes: torch.Tensor) -> torch.Tensor: - return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes) + return torch.ops._C.aqlm_dequant(codes, codebooks, codebook_partition_sizes) # gptq_marlin def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: - return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, + return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) @@ -220,7 +219,7 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, perm: torch.Tensor, workspace: torch.Tensor, num_bits: int, size_m: int, size_n: int, size_k: int, is_k_full: bool) -> torch.Tensor: - return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm, + return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm, workspace, num_bits, size_m, size_n, size_k, is_k_full) @@ -259,9 +258,9 @@ def scaled_fp8_quant( output = torch.empty_like(input, dtype=torch.float8_e4m3fn) if scale is None: scale = torch.zeros(1, device=input.device, dtype=torch.float32) - vllm_ops.dynamic_scaled_fp8_quant(output, input, scale) + torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: - vllm_ops.static_scaled_fp8_quant(output, input, scale) + torch.ops._C.static_scaled_fp8_quant(output, input, scale) return output, scale @@ -284,14 +283,14 @@ def scaled_int8_quant( output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. - vllm_ops.static_scaled_int8_quant(output, input, scale) + torch.ops._C.static_scaled_int8_quant(output, input, scale) return output, scale # dynamic-per-token quantization. 
input_scales = torch.empty((input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32) - vllm_ops.dynamic_scaled_int8_quant(output, input, input_scales) + torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales) return output, input_scales @@ -300,7 +299,7 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, experts_ids: torch.Tensor, num_tokens_post_pad: torch.Tensor) -> None: - vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size, + torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size, sorted_token_ids, experts_ids, num_tokens_post_pad) From 77c3e93e988f01c2dd9cb36fad885bc6abb42823 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 25 May 2024 05:47:54 +0000 Subject: [PATCH 02/41] fix cpu defs --- csrc/cpu/pybind.cpp | 25 ++++++++++++++++++++++++- csrc/pybind.cpp | 6 ++---- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp index 1c864d2922b1..9fbb8e764289 100644 --- a/csrc/cpu/pybind.cpp +++ b/csrc/cpu/pybind.cpp @@ -7,32 +7,55 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Attention ops // Compute the attention between an input query and the cached keys/values - //using PagedAttention. + // using PagedAttention. + ops.def("paged_attention_v1(Tensor out, Tensor query, Tensor key_cache, " + "Tensor value_cache, int num_kv_heads, float scale, Tensor " + "block_tables, Tensor seq_lens, int block_size, int max_seq_len, " + "Tensor? alibi_slopes, str kv_cache_dtype, float kv_scale) -> ()"); ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1); + // PagedAttention V2. + ops.def("paged_attention_v2(Tensor out, Tensor exp_sums, Tensor max_logits," + "Tensor tmp_out, Tensor query, Tensor key_cache, Tensor value_cache," + "int num_kv_heads, float scale, Tensor block_tables, Tensor seq_lens," + "int block_size, int max_seq_len, Tensor? alibi_slopes, " + "str kv_cache_dtype, float kv_scale) -> ()"); ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2); // Activation ops + // Activation function used in SwiGLU. + ops.def("silu_and_mul(Tensor out, Tensor input) -> ()"); ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul); + // Activation function used in GeGLU with `none` approximation. + ops.def("gelu_and_mul(Tensor out, Tensor input) -> ()"); ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul); + // Activation function used in GeGLU with `tanh` approximation. + ops.def("gelu_tanh_and_mul(Tensor out, Tensor input) -> ()"); ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul); + // GELU implementation used in GPT-2. + ops.def("gelu_new(Tensor out, Tensor input) -> ()"); ops.impl("gelu_new", torch::kCPU, &gelu_new); + // Approximate GELU implementation. + ops.def("gelu_fast(Tensor out, Tensor input) -> ()"); ops.impl("gelu_fast", torch::kCPU, &gelu_fast); // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. + ops.def("rms_norm(Tensor out, Tensor input, Tensor weight, float epsilon) -> ()"); ops.impl("rms_norm", torch::kCPU, &rms_norm); // In-place fused Add and RMS Normalization. + ops.def("fused_add_rms_norm(Tensor input, Tensor residual, Tensor weight, float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm); // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. 
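+ // These schemas mirror the CUDA registrations in csrc/pybind.cpp, so the
+ // Python wrappers in vllm/_custom_ops.py resolve to the same torch.ops._C
+ // entry points on both backends.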
+ ops.def("rotary_embedding(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, bool is_neox) -> ()"); ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); } diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 59e24b467d04..a2f3f7366069 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -51,14 +51,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("rms_norm", torch::kCUDA, &rms_norm); // In-place fused Add and RMS Normalization. - ops.def("fused_add_rms_norm(Tensor input, Tensor residual, Tensor weight, " - "float epsilon) -> ()"); + ops.def("fused_add_rms_norm(Tensor input, Tensor residual, Tensor weight, float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. - ops.def("rotary_embedding(Tensor positions, Tensor query, Tensor key, int " - "head_size, Tensor cos_sin_cache, bool is_neox) -> ()"); + ops.def("rotary_embedding(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, bool is_neox) -> ()"); ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); // Apply GPT-NeoX or GPT-J style rotary embedding to query and key From d3172994fbd186ddc1eb4238621a8d03c7046daf Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 25 May 2024 06:41:49 +0000 Subject: [PATCH 03/41] fix typo in cpu pybind --- csrc/cpu/pybind.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp index 9fbb8e764289..4024c29d2dbf 100644 --- a/csrc/cpu/pybind.cpp +++ b/csrc/cpu/pybind.cpp @@ -63,7 +63,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Cache ops - pybind11::module cache_ops = m.impl_submodule("cache_ops", "vLLM cache ops"); + pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); cache_ops.def("swap_blocks", &swap_blocks, "Swap in (out) the cache blocks from src to dst"); cache_ops.def("copy_blocks", ©_blocks, From 730ad1556c12e2c00dbd8c74ee8f9025cc799d8d Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 25 May 2024 06:55:10 +0000 Subject: [PATCH 04/41] fix moe/punica --- vllm/lora/punica.py | 20 +++++----- .../layers/fused_moe/fused_moe.py | 40 +++++++++++-------- 2 files changed, 33 insertions(+), 27 deletions(-) diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index c87bed54726f..98379b93b194 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -42,11 +42,11 @@ def bgmv( scale: Scaling factor. """ try: - import vllm._punica_C as punica_kernels + torch.ops._punica_C.dispatch_bgmv except ImportError as e: _raise_import_error(e) - punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) + torch.ops._punica_C.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, @@ -76,10 +76,10 @@ def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, y_slice_size: Size of the y column slice. """ try: - import vllm._punica_C as punica_kernels + torch.ops._punica_C.dispatch_bgmv except ImportError as e: _raise_import_error(e) - punica_kernels.dispatch_bgmv_low_level( + torch.ops._punica_C.dispatch_bgmv_low_level( y, x, w_t_all, @@ -123,7 +123,7 @@ def add_lora(y: torch.Tensor, buffer: Optional. Shape: `[B, R]`. Temporary buffer. 
""" try: - import vllm._punica_C as punica_kernels + torch.ops._punica_C.dispatch_bgmv except ImportError as e: _raise_import_error(e) @@ -135,8 +135,8 @@ def add_lora(y: torch.Tensor, buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0) - punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, + torch.ops._punica_C.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0) + torch.ops._punica_C.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, scale) @@ -177,7 +177,7 @@ def add_lora_slice(y: torch.Tensor, y_slice_size: Size of the y column slice. """ try: - import vllm._punica_C as punica_kernels + torch.ops._punica_C.dispatch_bgmv except ImportError as e: _raise_import_error(e) @@ -189,7 +189,7 @@ def add_lora_slice(y: torch.Tensor, buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - punica_kernels.dispatch_bgmv_low_level( + torch.ops._punica_C.dispatch_bgmv_low_level( buffer, x, wa_t_all, @@ -200,7 +200,7 @@ def add_lora_slice(y: torch.Tensor, buffer.size(1), 0, ) - punica_kernels.dispatch_bgmv_low_level( + torch.ops._punica_C.dispatch_bgmv_low_level( y, buffer, wb_t_all, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1c6947137a1c..757b873dc270 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -343,26 +343,32 @@ def fused_topk( M, _ = hidden_states.shape - topk_weights = torch.empty(M, + if is_hip(): + # The MoE kernels are not yet supported on ROCm. + routing_weights = torch.softmax(gating_output, + dim=-1, + dtype=torch.float32) + topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1) + else: + topk_weights = torch.empty(M, + topk, + dtype=torch.float32, + device=hidden_states.device) + topk_ids = torch.empty(M, topk, dtype=torch.float32, device=hidden_states.device) - topk_ids = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) - token_expert_indicies = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) - moe_kernels.topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output.float(), # TODO(woosuk): Optimize this. - ) - del token_expert_indicies # Not used. Will be used in the future. - + token_expert_indicies = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + torch.ops._moe_C.topk_softmax( + topk_weights, + topk_ids, + token_expert_indicies, + gating_output.float(), # TODO(woosuk): Optimize this. + ) + del token_expert_indicies # Not used. Will be used in the future. if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids From d4357329583efb20c0aca42f77a649c1f5de3000 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 25 May 2024 20:31:43 +0000 Subject: [PATCH 05/41] fixes --- csrc/moe/moe_ops.cpp | 6 ++++++ csrc/punica/punica_pybind.cpp | 8 ++++++++ 2 files changed, 14 insertions(+) diff --git a/csrc/moe/moe_ops.cpp b/csrc/moe/moe_ops.cpp index 3b24c2ed4864..ad298a5078e9 100644 --- a/csrc/moe/moe_ops.cpp +++ b/csrc/moe/moe_ops.cpp @@ -6,5 +6,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Apply topk softmax to the gating outputs. 
+ m.def("topk_softmax(Tensor topk_weights, Tensor topk_indices, Tensor token_expert_indices, Tensor gating_output) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); } + +// TODO: get rid of this +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ +} diff --git a/csrc/punica/punica_pybind.cpp b/csrc/punica/punica_pybind.cpp index ea054b475148..13a4811ec231 100644 --- a/csrc/punica/punica_pybind.cpp +++ b/csrc/punica/punica_pybind.cpp @@ -5,6 +5,14 @@ #define TORCH_LIBRARY_EXPAND(NAME, M) TORCH_LIBRARY(NAME, M) TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { + m.def("dispatch_bgmv(Tensor y, Tensor x, Tensor w, Tensor indicies, int layer_idx, float scale) -> ()"); m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv); + + m.def("dispatch_bgmv_low_level(Tensor y, Tensor x, Tensor w, Tensor indicies, int64_t layer_idx, float scale, int h_in, int h_out, int y_offset) -> ()"); m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level); } + +// TODO: get rid of this +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ +} From f474a785421d6a1f2d6292e8b07aaf3d1bc24f63 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 26 May 2024 02:58:25 +0000 Subject: [PATCH 06/41] fixes --- vllm/lora/punica.py | 8 ++++---- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 ++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index 98379b93b194..4ad37261828b 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -42,7 +42,7 @@ def bgmv( scale: Scaling factor. """ try: - torch.ops._punica_C.dispatch_bgmv + import vllm._punica_C as punica_kernels except ImportError as e: _raise_import_error(e) @@ -76,7 +76,7 @@ def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, y_slice_size: Size of the y column slice. """ try: - torch.ops._punica_C.dispatch_bgmv + import vllm._punica_C as punica_kernels except ImportError as e: _raise_import_error(e) torch.ops._punica_C.dispatch_bgmv_low_level( @@ -123,7 +123,7 @@ def add_lora(y: torch.Tensor, buffer: Optional. Shape: `[B, R]`. Temporary buffer. """ try: - torch.ops._punica_C.dispatch_bgmv + import vllm._punica_C as punica_kernels except ImportError as e: _raise_import_error(e) @@ -177,7 +177,7 @@ def add_lora_slice(y: torch.Tensor, y_slice_size: Size of the y column slice. 
""" try: - torch.ops._punica_C.dispatch_bgmv + import vllm._punica_C as punica_kernels except ImportError as e: _raise_import_error(e) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 757b873dc270..60d922aa6cc6 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -350,6 +350,8 @@ def fused_topk( dtype=torch.float32) topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1) else: + import vllm._moe_C as moe_kernels + topk_weights = torch.empty(M, topk, dtype=torch.float32, From 5ddf7c061695e9d0d930b0a08ba441f6835af074 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 26 May 2024 03:33:15 +0000 Subject: [PATCH 07/41] fix punica_pybind signature --- csrc/punica/punica_pybind.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/punica/punica_pybind.cpp b/csrc/punica/punica_pybind.cpp index 13a4811ec231..7027416d7ee8 100644 --- a/csrc/punica/punica_pybind.cpp +++ b/csrc/punica/punica_pybind.cpp @@ -8,7 +8,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def("dispatch_bgmv(Tensor y, Tensor x, Tensor w, Tensor indicies, int layer_idx, float scale) -> ()"); m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv); - m.def("dispatch_bgmv_low_level(Tensor y, Tensor x, Tensor w, Tensor indicies, int64_t layer_idx, float scale, int h_in, int h_out, int y_offset) -> ()"); + m.def("dispatch_bgmv_low_level(Tensor y, Tensor x, Tensor w, Tensor indicies, int layer_idx, float scale, int h_in, int h_out, int y_offset) -> ()"); m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level); } From 699a373faf1cdaefd878c967a0fb148c775441d3 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 26 May 2024 06:02:06 +0000 Subject: [PATCH 08/41] rebase --- csrc/attention/attention_kernels.cu | 4 ++-- csrc/ops.h | 2 ++ csrc/pybind.cpp | 8 ++++++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 3c8907dc2680..a985d1d19bdb 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -815,7 +815,7 @@ void paged_attention_v1( torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); @@ -979,7 +979,7 @@ void paged_attention_v2( torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); diff --git a/csrc/ops.h b/csrc/ops.h index 1bd480c1466a..77c95a10f5ad 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -2,6 +2,8 @@ #include +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) + void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, 
torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index a2f3f7366069..cd7b627288f7 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -12,7 +12,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("paged_attention_v1(Tensor out, Tensor query, Tensor key_cache, " "Tensor value_cache, int num_kv_heads, float scale, Tensor " "block_tables, Tensor seq_lens, int block_size, int max_seq_len, " - "Tensor? alibi_slopes, str kv_cache_dtype, float kv_scale) -> ()"); + "Tensor? alibi_slopes, str kv_cache_dtype, float kv_scale, int tp_rank," + "int blocksparse_local_blocks, int blocksparse_vert_stride, " + "int blocksparse_block_size, int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1); // PagedAttention V2. @@ -20,7 +22,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor tmp_out, Tensor query, Tensor key_cache, Tensor value_cache," "int num_kv_heads, float scale, Tensor block_tables, Tensor seq_lens," "int block_size, int max_seq_len, Tensor? alibi_slopes, " - "str kv_cache_dtype, float kv_scale) -> ()"); + "str kv_cache_dtype, float kv_scale, int tp_rank, " + "int blocksparse_local_blocks, int blocksparse_vert_stride," + "int blocksparse_block_size, int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); // Activation ops From f2646f4d6ad5334d4d9fb65ba1885ed913612520 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 26 May 2024 06:11:55 +0000 Subject: [PATCH 09/41] fix format --- csrc/cpu/pybind.cpp | 35 +++++++------ csrc/moe/moe_ops.cpp | 8 +-- csrc/pybind.cpp | 114 ++++++++++++++++++++++++++++--------------- vllm/_custom_ops.py | 39 +++++++-------- vllm/lora/punica.py | 8 +-- 5 files changed, 124 insertions(+), 80 deletions(-) diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp index 4024c29d2dbf..91ffc59d336b 100644 --- a/csrc/cpu/pybind.cpp +++ b/csrc/cpu/pybind.cpp @@ -8,18 +8,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Attention ops // Compute the attention between an input query and the cached keys/values // using PagedAttention. - ops.def("paged_attention_v1(Tensor out, Tensor query, Tensor key_cache, " - "Tensor value_cache, int num_kv_heads, float scale, Tensor " - "block_tables, Tensor seq_lens, int block_size, int max_seq_len, " - "Tensor? alibi_slopes, str kv_cache_dtype, float kv_scale) -> ()"); + ops.def( + "paged_attention_v1(Tensor out, Tensor query, Tensor key_cache, " + "Tensor value_cache, int num_kv_heads, float scale, Tensor " + "block_tables, Tensor seq_lens, int block_size, int max_seq_len, " + "Tensor? alibi_slopes, str kv_cache_dtype, float kv_scale) -> ()"); ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1); // PagedAttention V2. - ops.def("paged_attention_v2(Tensor out, Tensor exp_sums, Tensor max_logits," - "Tensor tmp_out, Tensor query, Tensor key_cache, Tensor value_cache," - "int num_kv_heads, float scale, Tensor block_tables, Tensor seq_lens," - "int block_size, int max_seq_len, Tensor? alibi_slopes, " - "str kv_cache_dtype, float kv_scale) -> ()"); + ops.def( + "paged_attention_v2(Tensor out, Tensor exp_sums, Tensor max_logits," + "Tensor tmp_out, Tensor query, Tensor key_cache, Tensor value_cache," + "int num_kv_heads, float scale, Tensor block_tables, Tensor seq_lens," + "int block_size, int max_seq_len, Tensor? 
alibi_slopes, " + "str kv_cache_dtype, float kv_scale) -> ()"); ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2); // Activation ops @@ -46,22 +48,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. - ops.def("rms_norm(Tensor out, Tensor input, Tensor weight, float epsilon) -> ()"); + ops.def( + "rms_norm(Tensor out, Tensor input, Tensor weight, float epsilon) -> ()"); ops.impl("rms_norm", torch::kCPU, &rms_norm); // In-place fused Add and RMS Normalization. - ops.def("fused_add_rms_norm(Tensor input, Tensor residual, Tensor weight, float epsilon) -> ()"); + ops.def( + "fused_add_rms_norm(Tensor input, Tensor residual, Tensor weight, float " + "epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm); // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. - ops.def("rotary_embedding(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, bool is_neox) -> ()"); + ops.def( + "rotary_embedding(Tensor positions, Tensor query, Tensor key, int " + "head_size, Tensor cos_sin_cache, bool is_neox) -> ()"); ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); } - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); cache_ops.def("swap_blocks", &swap_blocks, diff --git a/csrc/moe/moe_ops.cpp b/csrc/moe/moe_ops.cpp index ad298a5078e9..b3402269b377 100644 --- a/csrc/moe/moe_ops.cpp +++ b/csrc/moe/moe_ops.cpp @@ -6,11 +6,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Apply topk softmax to the gating outputs. - m.def("topk_softmax(Tensor topk_weights, Tensor topk_indices, Tensor token_expert_indices, Tensor gating_output) -> ()"); + m.def( + "topk_softmax(Tensor topk_weights, Tensor topk_indices, Tensor " + "token_expert_indices, Tensor gating_output) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); } // TODO: get rid of this -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ -} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index cd7b627288f7..636ee4b39ffe 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -9,22 +9,24 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Attention ops // Compute the attention between an input query and the cached // keys/values using PagedAttention. - ops.def("paged_attention_v1(Tensor out, Tensor query, Tensor key_cache, " - "Tensor value_cache, int num_kv_heads, float scale, Tensor " - "block_tables, Tensor seq_lens, int block_size, int max_seq_len, " - "Tensor? alibi_slopes, str kv_cache_dtype, float kv_scale, int tp_rank," - "int blocksparse_local_blocks, int blocksparse_vert_stride, " - "int blocksparse_block_size, int blocksparse_head_sliding_step) -> ()"); + ops.def( + "paged_attention_v1(Tensor out, Tensor query, Tensor key_cache, " + "Tensor value_cache, int num_kv_heads, float scale, Tensor " + "block_tables, Tensor seq_lens, int block_size, int max_seq_len, " + "Tensor? alibi_slopes, str kv_cache_dtype, float kv_scale, int tp_rank," + "int blocksparse_local_blocks, int blocksparse_vert_stride, " + "int blocksparse_block_size, int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1); // PagedAttention V2. 
- ops.def("paged_attention_v2(Tensor out, Tensor exp_sums, Tensor max_logits," - "Tensor tmp_out, Tensor query, Tensor key_cache, Tensor value_cache," - "int num_kv_heads, float scale, Tensor block_tables, Tensor seq_lens," - "int block_size, int max_seq_len, Tensor? alibi_slopes, " - "str kv_cache_dtype, float kv_scale, int tp_rank, " - "int blocksparse_local_blocks, int blocksparse_vert_stride," - "int blocksparse_block_size, int blocksparse_head_sliding_step) -> ()"); + ops.def( + "paged_attention_v2(Tensor out, Tensor exp_sums, Tensor max_logits," + "Tensor tmp_out, Tensor query, Tensor key_cache, Tensor value_cache," + "int num_kv_heads, float scale, Tensor block_tables, Tensor seq_lens," + "int block_size, int max_seq_len, Tensor? alibi_slopes, " + "str kv_cache_dtype, float kv_scale, int tp_rank, " + "int blocksparse_local_blocks, int blocksparse_vert_stride," + "int blocksparse_block_size, int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); // Activation ops @@ -50,68 +52,97 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. - ops.def("rms_norm(Tensor out, Tensor input, Tensor weight, float epsilon) -> ()"); - //ops.def(torch::schema("rms_norm(Tensor out, Tensor input, Tensor weight, float epsilon) -> ()"), c10::AliasAnalysisKind::CONSERVATIVE); + ops.def( + "rms_norm(Tensor out, Tensor input, Tensor weight, float epsilon) -> ()"); + // ops.def(torch::schema("rms_norm(Tensor out, Tensor input, Tensor weight, + // float epsilon) -> ()"), c10::AliasAnalysisKind::CONSERVATIVE); ops.impl("rms_norm", torch::kCUDA, &rms_norm); // In-place fused Add and RMS Normalization. - ops.def("fused_add_rms_norm(Tensor input, Tensor residual, Tensor weight, float epsilon) -> ()"); + ops.def( + "fused_add_rms_norm(Tensor input, Tensor residual, Tensor weight, float " + "epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. - ops.def("rotary_embedding(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, bool is_neox) -> ()"); + ops.def( + "rotary_embedding(Tensor positions, Tensor query, Tensor key, int " + "head_size, Tensor cos_sin_cache, bool is_neox) -> ()"); ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); // Apply GPT-NeoX or GPT-J style rotary embedding to query and key // (supports multiple loras). - ops.def("batched_rotary_embedding(Tensor positions, Tensor query, Tensor " - "key, int head_size, Tensor cos_sin_cache, bool is_neox, int " - "rot_dim, Tensor cos_sin_cache_offsets) -> ()"); + ops.def( + "batched_rotary_embedding(Tensor positions, Tensor query, Tensor " + "key, int head_size, Tensor cos_sin_cache, bool is_neox, int " + "rot_dim, Tensor cos_sin_cache_offsets) -> ()"); ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding); // Quantization ops #ifndef USE_ROCM // Quantized GEMM for AQLM. - ops.def("aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, Tensor scales, Tensor codebook_partition_sizes, Tensor? bias) -> Tensor"); + ops.def( + "aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, Tensor scales, " + "Tensor codebook_partition_sizes, Tensor? bias) -> Tensor"); ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm); // Decompression method for AQLM. 
- ops.def("aqlm_dequant(Tensor codes, Tensor codebooks, Tensor codebook_partition_sizes) -> Tensor"); + ops.def( + "aqlm_dequant(Tensor codes, Tensor codebooks, Tensor " + "codebook_partition_sizes) -> Tensor"); ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant); // Quantized GEMM for AWQ. - ops.def("awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, Tensor _zeros, int split_k_iters) -> Tensor"); + ops.def( + "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, " + "Tensor _zeros, int split_k_iters) -> Tensor"); ops.impl("awq_gemm", torch::kCUDA, &awq_gemm); // Marlin (Dense) Optimized Quantized GEMM for GPTQ. - ops.def("marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, Tensor workspace, int size_m, int size_n, int size_k) -> Tensor"); + ops.def( + "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, Tensor " + "workspace, int size_m, int size_n, int size_k) -> Tensor"); ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm); // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ. - ops.def("gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, Tensor b_scales, Tensor workspace, int num_bits, int size_m, int size_n, int size_k) -> Tensor"); + ops.def( + "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, Tensor " + "b_scales, Tensor workspace, int num_bits, int size_m, int size_n, int " + "size_k) -> Tensor"); ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm); // gptq_marlin Optimized Quantized GEMM for GPTQ. - ops.def("gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, Tensor g_idx, Tensor perm, Tensor workspace, int num_bits, int size_m, int size_n, int size_k, bool is_k_full) -> Tensor"); + ops.def( + "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, Tensor " + "g_idx, Tensor perm, Tensor workspace, int num_bits, int size_m, int " + "size_n, int size_k, bool is_k_full) -> Tensor"); ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm); // gptq_marlin repack from GPTQ. - ops.def("gptq_marlin_repack(Tensor b_q_weight, Tensor perm, int size_k, int size_n, int num_bits) -> Tensor"); + ops.def( + "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, int size_k, int " + "size_n, int num_bits) -> Tensor"); ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack); // Dequantization for AWQ. - ops.def("awq_dequantize(Tensor _kernel, Tensor _scaling_factors, Tensor _zeros, int split_k_iters, int thx, int thy) -> Tensor"); + ops.def( + "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, Tensor _zeros, " + "int split_k_iters, int thx, int thy) -> Tensor"); ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize); // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column // quantization. - ops.def("cutlass_scaled_mm_dq(Tensor out, Tensor a, Tensor b, Tensor a_scales, Tensor b_scales) -> ()"); + ops.def( + "cutlass_scaled_mm_dq(Tensor out, Tensor a, Tensor b, Tensor a_scales, " + "Tensor b_scales) -> ()"); ops.impl("cutlass_scaled_mm_dq", torch::kCUDA, &cutlass_scaled_mm_dq); #endif // Quantized GEMM for GPTQ. - ops.def("gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) -> Tensor"); + ops.def( + "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, Tensor " + "b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) -> Tensor"); ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm); // Post processing for GPTQ. 
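// Each registration in this file pairs a schema string (ops.def) with a
// backend kernel (ops.impl). In these schemas "int" binds to int64_t and
// "float" binds to double on the C++ side, which is why kv_scale, tp_rank
// and the blocksparse parameters are widened to double/int64_t in the kernel
// signatures touched elsewhere in this series. A minimal self-contained
// sketch of the pattern, using hypothetical names rather than one of vLLM's
// ops:
#include <torch/extension.h>

void scale_inplace(torch::Tensor& out, torch::Tensor& input, double scale) {
  out.copy_(input * scale);  // toy ATen-based kernel body
}

TORCH_LIBRARY(_sketch_ext, m) {
  // Schema "float scale" arrives as double in the C++ signature above.
  m.def("scale_inplace(Tensor out, Tensor input, float scale) -> ()");
  m.impl("scale_inplace", torch::kCPU, &scale_inplace);
  m.impl("scale_inplace", torch::kCUDA, &scale_inplace);
}
// Called from Python as torch.ops._sketch_ext.scale_inplace(out, x, 2.0).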
@@ -119,25 +150,32 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); // Quantized GEMM for SqueezeLLM. - ops.def("squeezellm_gemm(Tensor vec, Tensor mat, Tensor mul, Tensor lookup_table) -> ()"); + ops.def( + "squeezellm_gemm(Tensor vec, Tensor mat, Tensor mul, Tensor " + "lookup_table) -> ()"); ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm); // Compute FP8 quantized tensor for given scaling factor. - ops.def("static_scaled_fp8_quant(Tensor out, Tensor input, Tensor scale) -> ()"); + ops.def( + "static_scaled_fp8_quant(Tensor out, Tensor input, Tensor scale) -> ()"); ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); // Compute FP8 quantized tensor and scaling factor. - ops.def("dynamic_scaled_fp8_quant(Tensor out, Tensor input, Tensor scale) -> ()"); + ops.def( + "dynamic_scaled_fp8_quant(Tensor out, Tensor input, Tensor scale) -> ()"); ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); // Aligning the number of tokens to be processed by each expert such // that it is divisible by the block size. - ops.def("moe_align_block_size(Tensor topk_ids, int num_experts, int block_size," - "Tensor sorted_token_ids, Tensor experts_ids, Tensor num_tokens_post_pad) -> ()"); + ops.def( + "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size," + "Tensor sorted_token_ids, Tensor experts_ids, Tensor " + "num_tokens_post_pad) -> ()"); ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); // Compute int8 quantized tensor for given scaling factor. - ops.def("static_scaled_int8_quant(Tensor out, Tensor input, float scale) -> ()"); + ops.def( + "static_scaled_int8_quant(Tensor out, Tensor input, float scale) -> ()"); ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant); // Compute int8 quantized tensor and scaling factor @@ -146,9 +184,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant); } - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); cache_ops.def("swap_blocks", &swap_blocks, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 2c4f71e14e1d..8adfb631febc 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -52,7 +52,7 @@ def paged_attention_v1( blocksparse_block_size: int = 64, blocksparse_head_sliding_step: int = 0, ) -> None: - vllm_ops.paged_attention_v1( + torch.ops._C.paged_attention_v1( out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, @@ -82,7 +82,7 @@ def paged_attention_v2( blocksparse_block_size: int = 64, blocksparse_head_sliding_step: int = 0, ) -> None: - vllm_ops.paged_attention_v2( + torch.ops._C.paged_attention_v2( out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, tp_rank, @@ -99,8 +99,8 @@ def rotary_embedding( cos_sin_cache: torch.Tensor, is_neox: bool, ) -> None: - torch.ops._C.rotary_embedding(positions, query, key, head_size, cos_sin_cache, - is_neox) + torch.ops._C.rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox) def batched_rotary_embedding(positions: torch.Tensor, query: 
torch.Tensor, @@ -109,8 +109,8 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, rot_dim: int, cos_sin_cache_offsets: torch.Tensor) -> None: torch.ops._C.batched_rotary_embedding(positions, query, key, head_size, - cos_sin_cache, is_neox, rot_dim, - cos_sin_cache_offsets) + cos_sin_cache, is_neox, rot_dim, + cos_sin_cache_offsets) # layer norm ops @@ -129,8 +129,8 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, split_k_iters: int, thx: int, thy: int) -> torch.Tensor: - return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, - thy) + return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, + thx, thy) def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, @@ -144,7 +144,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_g_idx: torch.Tensor, use_exllama: bool, bit: int) -> torch.Tensor: return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, - b_g_idx, use_exllama, bit) + b_g_idx, use_exllama, bit) def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, @@ -163,7 +163,7 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, size_n: int, size_k: int) -> torch.Tensor: return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, - size_n, size_k) + size_n, size_k) # marlin_24 @@ -172,8 +172,8 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, workspace: torch.Tensor, num_bits: int, size_m: int, size_n: int, size_k: int) -> torch.Tensor: return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, - workspace, num_bits, size_m, size_n, - size_k) + workspace, num_bits, size_m, + size_n, size_k) # cutlass @@ -198,12 +198,13 @@ def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, codebook_partition_sizes: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales, - codebook_partition_sizes, bias) + codebook_partition_sizes, bias) def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, codebook_partition_sizes: torch.Tensor) -> torch.Tensor: - return torch.ops._C.aqlm_dequant(codes, codebooks, codebook_partition_sizes) + return torch.ops._C.aqlm_dequant(codes, codebooks, + codebook_partition_sizes) # gptq_marlin @@ -211,7 +212,7 @@ def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, - num_bits) + num_bits) def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, @@ -220,8 +221,8 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, num_bits: int, size_m: int, size_n: int, size_k: int, is_k_full: bool) -> torch.Tensor: return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm, - workspace, num_bits, size_m, size_n, - size_k, is_k_full) + workspace, num_bits, size_m, size_n, + size_k, is_k_full) # fp8 @@ -300,8 +301,8 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, experts_ids: torch.Tensor, num_tokens_post_pad: torch.Tensor) -> None: torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size, - sorted_token_ids, experts_ids, - num_tokens_post_pad) + sorted_token_ids, experts_ids, + num_tokens_post_pad) def reshape_and_cache( diff --git a/vllm/lora/punica.py 
b/vllm/lora/punica.py index 4ad37261828b..c006789a95e3 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -46,7 +46,8 @@ def bgmv( except ImportError as e: _raise_import_error(e) - torch.ops._punica_C.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) + torch.ops._punica_C.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, + scale) def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, @@ -135,9 +136,10 @@ def add_lora(y: torch.Tensor, buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - torch.ops._punica_C.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0) + torch.ops._punica_C.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, + 1.0) torch.ops._punica_C.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, - scale) + scale) def add_lora_slice(y: torch.Tensor, From 497c3f4a103d5ed7df4be59ac3cdd478722a6ae0 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 26 May 2024 06:15:18 +0000 Subject: [PATCH 10/41] more clang-format --- csrc/attention/attention_kernels.cu | 14 ++++++++------ csrc/cpu/attention.cpp | 10 ++++++---- csrc/ops.h | 14 ++++++++------ csrc/punica/punica_pybind.cpp | 12 +++++++----- csrc/quantization/awq/gemm_kernels.cu | 4 ++-- 5 files changed, 31 insertions(+), 23 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index a985d1d19bdb..fc2a53432abf 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -808,7 +808,7 @@ void paged_attention_v1( torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& - value_cache, // [num_blocks, num_heads, head_size, block_size] + value_cache, // [num_blocks, num_heads, head_size, block_size] int64_t num_kv_heads, // [num_heads] double scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] @@ -816,8 +816,9 @@ void paged_attention_v1( int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, - const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, @@ -972,7 +973,7 @@ void paged_attention_v2( torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& - value_cache, // [num_blocks, num_heads, head_size, block_size] + value_cache, // [num_blocks, num_heads, head_size, block_size] int64_t num_kv_heads, // [num_heads] double scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] @@ -980,8 +981,9 @@ void paged_attention_v2( int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, - const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), 
kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE) diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index 47b674cfaa7b..836709332531 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -424,8 +424,9 @@ void paged_attention_v1( torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, - const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); @@ -742,8 +743,9 @@ void paged_attention_v2( torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, - const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); diff --git a/csrc/ops.h b/csrc/ops.h index 77c95a10f5ad..0c5b233eca1e 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -10,8 +10,9 @@ void paged_attention_v1( torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, - const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step); void paged_attention_v2( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, @@ -20,8 +21,9 @@ void paged_attention_v2( torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, - const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step); void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double epsilon); @@ -66,8 +68,8 @@ torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _scaling_factors, - torch::Tensor _zeros, int64_t split_k_iters, int64_t thx, - int64_t thy); + torch::Tensor _zeros, int64_t split_k_iters, + int64_t thx, int64_t thy); torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& 
workspace, diff --git a/csrc/punica/punica_pybind.cpp b/csrc/punica/punica_pybind.cpp index 7027416d7ee8..b18a019b6057 100644 --- a/csrc/punica/punica_pybind.cpp +++ b/csrc/punica/punica_pybind.cpp @@ -5,14 +5,16 @@ #define TORCH_LIBRARY_EXPAND(NAME, M) TORCH_LIBRARY(NAME, M) TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { - m.def("dispatch_bgmv(Tensor y, Tensor x, Tensor w, Tensor indicies, int layer_idx, float scale) -> ()"); + m.def( + "dispatch_bgmv(Tensor y, Tensor x, Tensor w, Tensor indicies, int " + "layer_idx, float scale) -> ()"); m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv); - m.def("dispatch_bgmv_low_level(Tensor y, Tensor x, Tensor w, Tensor indicies, int layer_idx, float scale, int h_in, int h_out, int y_offset) -> ()"); + m.def( + "dispatch_bgmv_low_level(Tensor y, Tensor x, Tensor w, Tensor indicies, " + "int layer_idx, float scale, int h_in, int h_out, int y_offset) -> ()"); m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level); } // TODO: get rid of this -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ -} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 4ca69e956969..079694dcd5be 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -435,8 +435,8 @@ __global__ void __launch_bounds__(64) torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _scaling_factors, - torch::Tensor _zeros, int64_t split_k_iters, int64_t thx, - int64_t thy) { + torch::Tensor _zeros, int64_t split_k_iters, + int64_t thx, int64_t thy) { int in_c = _kernel.size(0); int qout_c = _kernel.size(1); int out_c = qout_c * 8; From ed5cb0b6c3da63430f1526ff6122699481e59441 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 26 May 2024 18:39:27 +0000 Subject: [PATCH 11/41] cpu fixes --- csrc/cpu/pybind.cpp | 8 +++-- vllm/_custom_ops.py | 2 +- vllm/lora/punica.py | 29 +++++++++---------- .../layers/fused_moe/fused_moe.py | 1 + 4 files changed, 21 insertions(+), 19 deletions(-) diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp index 91ffc59d336b..7b77983522b2 100644 --- a/csrc/cpu/pybind.cpp +++ b/csrc/cpu/pybind.cpp @@ -12,7 +12,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "paged_attention_v1(Tensor out, Tensor query, Tensor key_cache, " "Tensor value_cache, int num_kv_heads, float scale, Tensor " "block_tables, Tensor seq_lens, int block_size, int max_seq_len, " - "Tensor? alibi_slopes, str kv_cache_dtype, float kv_scale) -> ()"); + "Tensor? alibi_slopes, str kv_cache_dtype, float kv_scale, int tp_rank," + "int blocksparse_local_blocks, int blocksparse_vert_stride, " + "int blocksparse_block_size, int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1); // PagedAttention V2. @@ -21,7 +23,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor tmp_out, Tensor query, Tensor key_cache, Tensor value_cache," "int num_kv_heads, float scale, Tensor block_tables, Tensor seq_lens," "int block_size, int max_seq_len, Tensor? 
alibi_slopes, " - "str kv_cache_dtype, float kv_scale) -> ()"); + "str kv_cache_dtype, float kv_scale, int tp_rank, " + "int blocksparse_local_blocks, int blocksparse_vert_stride," + "int blocksparse_block_size, int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2); // Activation ops diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 8adfb631febc..4990ee28ad4e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Type, List +from typing import Optional, Tuple, Type import torch diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index c006789a95e3..b0483bc97b97 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -16,6 +16,14 @@ def _raise_import_error(e): "was set.") from e +def _check_punica_support(): + try: + # ruff: noqa: F401 + import vllm._punica_C as punica_kernels + except ImportError as e: + _raise_import_error(e) + + def bgmv( y: torch.Tensor, x: torch.Tensor, @@ -41,10 +49,7 @@ def bgmv( layer_idx: Layer index of the weight matrices. scale: Scaling factor. """ - try: - import vllm._punica_C as punica_kernels - except ImportError as e: - _raise_import_error(e) + _check_punica_support() torch.ops._punica_C.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) @@ -76,10 +81,8 @@ def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, y_offset: Offset to apply to the starting column of y. y_slice_size: Size of the y column slice. """ - try: - import vllm._punica_C as punica_kernels - except ImportError as e: - _raise_import_error(e) + _check_punica_support() + torch.ops._punica_C.dispatch_bgmv_low_level( y, x, @@ -123,10 +126,7 @@ def add_lora(y: torch.Tensor, scale: Scaling factor. buffer: Optional. Shape: `[B, R]`. Temporary buffer. """ - try: - import vllm._punica_C as punica_kernels - except ImportError as e: - _raise_import_error(e) + _check_punica_support() r = wb_t_all.size(-1) if buffer is None: @@ -178,10 +178,7 @@ def add_lora_slice(y: torch.Tensor, y_offset: Offset to apply to the starting column of y. y_slice_size: Size of the y column slice. 
""" - try: - import vllm._punica_C as punica_kernels - except ImportError as e: - _raise_import_error(e) + _check_punica_support() r = wb_t_all.size(-1) if buffer is None: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 60d922aa6cc6..d1c9049ce5b0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -350,6 +350,7 @@ def fused_topk( dtype=torch.float32) topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1) else: + # ruff: noqa: F401 import vllm._moe_C as moe_kernels topk_weights = torch.empty(M, From 81c2783a02ec0ab96cc78cbe582a5992d87c668a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 27 May 2024 03:33:31 +0000 Subject: [PATCH 12/41] convert cache_ops + cuda_utils --- csrc/cache.h | 8 ++--- csrc/cache_kernels.cu | 8 ++--- csrc/cuda_utils.h | 4 +-- csrc/cuda_utils_kernels.cu | 6 ++-- csrc/ops.h | 2 ++ csrc/pybind.cpp | 51 +++++++++++++++++---------- vllm/_custom_ops.py | 27 +++++++++----- vllm/attention/backends/flash_attn.py | 10 +++--- vllm/utils.py | 7 ++-- 9 files changed, 72 insertions(+), 51 deletions(-) diff --git a/csrc/cache.h b/csrc/cache.h index 435ae3e57f55..fbf6064f0791 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -8,14 +8,14 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, const torch::Tensor& block_mapping); -void copy_blocks(std::vector& key_caches, - std::vector& value_caches, +void copy_blocks(std::vector const& key_caches, + std::vector const& value_caches, const torch::Tensor& block_mapping); void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, const float kv_scale); + const std::string& kv_cache_dtype, const double kv_scale); void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, @@ -25,4 +25,4 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, - const float scale, const std::string& kv_cache_dtype); + const double scale, const std::string& kv_cache_dtype); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index d924ac39b89c..80538ac18b98 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -95,8 +95,8 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, } // namespace vllm -void copy_blocks(std::vector& key_caches, - std::vector& value_caches, +void copy_blocks(std::vector const& key_caches, + std::vector const& value_caches, const torch::Tensor& block_mapping) { int num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); @@ -255,7 +255,7 @@ void reshape_and_cache( torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype, const float kv_scale) { + const std::string& kv_cache_dtype, const double kv_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); @@ -334,7 +334,7 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, // Only for testing. 
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, - const float kv_scale, const std::string& kv_cache_dtype) { + const double kv_scale, const std::string& kv_cache_dtype) { torch::Device src_device = src_cache.device(); torch::Device dst_device = dst_cache.device(); TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index 2ba49b339e14..09f8a2f12dd0 100644 --- a/csrc/cuda_utils.h +++ b/csrc/cuda_utils.h @@ -2,6 +2,6 @@ #include -int get_device_attribute(int attribute, int device_id); +int64_t get_device_attribute(int64_t attribute, int64_t device_id); -int get_max_shared_memory_per_block_device_attribute(int device_id); +int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu index 7d8e2e19720f..d6f9eb646fad 100644 --- a/csrc/cuda_utils_kernels.cu +++ b/csrc/cuda_utils_kernels.cu @@ -2,7 +2,7 @@ #include #include #endif -int get_device_attribute(int attribute, int device_id) { +int64_t get_device_attribute(int64_t attribute, int64_t device_id) { int device, value; if (device_id < 0) { cudaGetDevice(&device); @@ -14,8 +14,8 @@ int get_device_attribute(int attribute, int device_id) { return value; } -int get_max_shared_memory_per_block_device_attribute(int device_id) { - int attribute; +int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id) { + int64_t attribute; // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 diff --git a/csrc/ops.h b/csrc/ops.h index 0c5b233eca1e..d968ff2e1e7b 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -3,6 +3,8 @@ #include #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 636ee4b39ffe..e88559279be3 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -184,30 +184,43 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { // Cache ops - pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); - cache_ops.def("swap_blocks", &swap_blocks, - "Swap in (out) the cache blocks from src to dst"); - cache_ops.def("copy_blocks", ©_blocks, - "Copy the cache blocks from src to dst"); - cache_ops.def("reshape_and_cache", &reshape_and_cache, - "Reshape the key and value tensors and cache them"); - cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash, - "Reshape the key and value tensors and cache them"); - cache_ops.def("convert_fp8", &convert_fp8, - "Convert the key and value cache to fp8 data type"); + // Swap in (out) the cache blocks from src to dst. + cache_ops.def("swap_blocks(Tensor src, Tensor dst, Tensor block_mapping) -> ()"); + cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks); + + // Copy the cache blocks from src to dst. + cache_ops.def("copy_blocks", ©_blocks); + cache_ops.impl("copy_blocks", torch::kCUDA, ©_blocks); + + // Reshape the key and value tensors and cache them. 
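// CONCAT(TORCH_EXTENSION_NAME, _cache_ops) is pasted together before
// TORCH_LIBRARY expands (that is what the _CONCAT/TORCH_LIBRARY_EXPAND
// indirection added to ops.h is for), so with TORCH_EXTENSION_NAME set to
// _C this block registers a separate library named _C_cache_ops. The ops
// defined here therefore surface in Python as torch.ops._C_cache_ops.*,
// which is exactly what the updated wrappers in vllm/_custom_ops.py call.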
+ cache_ops.def("reshape_and_cache", &reshape_and_cache); + cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache); + // Reshape the key and value tensors and cache them. + cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash); + cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash); + + // Convert the key and value cache to fp8 data type. + cache_ops.def("convert_fp8", &convert_fp8); + cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); +} + +TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { // Cuda utils - pybind11::module cuda_utils = - m.def_submodule("cuda_utils", "vLLM cuda utils"); - cuda_utils.def("get_device_attribute", &get_device_attribute, - "Gets the specified device attribute."); - cuda_utils.def("get_max_shared_memory_per_block_device_attribute", - &get_max_shared_memory_per_block_device_attribute, - "Gets the maximum shared memory per block device attribute."); + // Gets the specified device attribute. + cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int"); + cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute); + + // Gets the maximum shared memory per block device attribute. + cuda_utils.def("get_max_shared_memory_per_block_device_attribute(int device_id) -> int"); + cuda_utils.impl("get_max_shared_memory_per_block_device_attribute", torch::kCUDA, + &get_max_shared_memory_per_block_device_attribute); +} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { #ifndef USE_ROCM // Custom all-reduce kernels pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce"); diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 4990ee28ad4e..de240b264116 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -3,7 +3,8 @@ import torch try: - from vllm._C import ops as vllm_ops + # ruff: noqa: SIM105 + import vllm._C except ImportError as e: from vllm.logger import init_logger logger = init_logger(__name__) @@ -314,8 +315,8 @@ def reshape_and_cache( kv_cache_dtype: str, kv_scale: float, ) -> None: - vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, kv_scale) + torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype, kv_scale) def reshape_and_cache_flash( @@ -326,25 +327,33 @@ def reshape_and_cache_flash( slot_mapping: torch.Tensor, kv_cache_dtype: str, ) -> None: - vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype) + torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype) def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, block_mapping: torch.Tensor) -> None: - vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) + torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) def swap_blocks(src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor) -> None: - vllm_cache_ops.swap_blocks(src, dst, block_mapping) + torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping) def convert_fp8(output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8") -> None: - vllm_cache_ops.convert_fp8(output, input, scale, kv_dtype) + torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) -#TODO: cuda_utils, custom_ar +def get_device_attribute(attribute: int, device: int) -> int: + return 
torch.ops._C_cuda_utils.get_device_attribute(attribute, device) + + +def get_max_shared_memory_per_block_device_attribute(device: int) -> int: + return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute(device) + + +#TODO: custom_ar diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 070c074e511b..8c64c2bfdeb8 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -5,7 +5,7 @@ import torch from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache -from vllm._C import cache_ops +from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata) @@ -47,11 +47,11 @@ def swap_blocks( ) -> None: src_key_cache = src_kv_cache[0] dst_key_cache = dst_kv_cache[0] - cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) src_value_cache = src_kv_cache[1] dst_value_cache = dst_kv_cache[1] - cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) @staticmethod def copy_blocks( @@ -60,7 +60,7 @@ def copy_blocks( ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches] - cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) + ops.copy_blocks(key_caches, value_caches, src_to_dists) @dataclass @@ -285,7 +285,7 @@ def forward( # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - cache_ops.reshape_and_cache_flash( + ops.reshape_and_cache_flash( key, value, key_cache, diff --git a/vllm/utils.py b/vllm/utils.py index 2bd24d086f69..54d446b23350 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -22,6 +22,7 @@ import torch import vllm.envs as envs +from vllm import _custom_ops as ops from vllm.logger import enable_trace_function_call, init_logger T = TypeVar("T") @@ -148,12 +149,8 @@ def is_neuron() -> bool: @lru_cache(maxsize=None) def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" - # NOTE: This import statement should be executed lazily since - # the Neuron-X backend does not have the `cuda_utils` module. 
- from vllm._C import cuda_utils - max_shared_mem = ( - cuda_utils.get_max_shared_memory_per_block_device_attribute(gpu)) + ops.get_max_shared_memory_per_block_device_attribute(gpu)) # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py # will fail assert max_shared_mem > 0, "max_shared_mem can not be zero" From f0c5e87da940c28c3b996e5cc7bcecc324557f5e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 27 May 2024 05:27:04 +0000 Subject: [PATCH 13/41] add mutable indices to schema registration --- csrc/cpu/pybind.cpp | 59 +++++++--------- csrc/moe/moe_ops.cpp | 10 +-- csrc/moe/moe_ops.h | 2 +- csrc/ops.h | 6 +- csrc/punica/punica_ops.h | 2 +- csrc/punica/punica_pybind.cpp | 12 +--- csrc/pybind.cpp | 124 +++++++++++----------------------- csrc/register.h | 51 ++++++++++++++ vllm/_custom_ops.py | 18 +++-- 9 files changed, 132 insertions(+), 152 deletions(-) create mode 100644 csrc/register.h diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp index 7b77983522b2..a11f30c0d1cb 100644 --- a/csrc/cpu/pybind.cpp +++ b/csrc/cpu/pybind.cpp @@ -8,75 +8,62 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Attention ops // Compute the attention between an input query and the cached keys/values // using PagedAttention. - ops.def( - "paged_attention_v1(Tensor out, Tensor query, Tensor key_cache, " - "Tensor value_cache, int num_kv_heads, float scale, Tensor " - "block_tables, Tensor seq_lens, int block_size, int max_seq_len, " - "Tensor? alibi_slopes, str kv_cache_dtype, float kv_scale, int tp_rank," - "int blocksparse_local_blocks, int blocksparse_vert_stride, " - "int blocksparse_block_size, int blocksparse_head_sliding_step) -> ()"); + ops.def("paged_attention_v1", &paged_attention_v1); ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1); // PagedAttention V2. - ops.def( - "paged_attention_v2(Tensor out, Tensor exp_sums, Tensor max_logits," - "Tensor tmp_out, Tensor query, Tensor key_cache, Tensor value_cache," - "int num_kv_heads, float scale, Tensor block_tables, Tensor seq_lens," - "int block_size, int max_seq_len, Tensor? alibi_slopes, " - "str kv_cache_dtype, float kv_scale, int tp_rank, " - "int blocksparse_local_blocks, int blocksparse_vert_stride," - "int blocksparse_block_size, int blocksparse_head_sliding_step) -> ()"); + ops.def("paged_attention_v2", &paged_attention_v2); ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2); // Activation ops // Activation function used in SwiGLU. - ops.def("silu_and_mul(Tensor out, Tensor input) -> ()"); + ops.def("silu_and_mul", &silu_and_mul); ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul); // Activation function used in GeGLU with `none` approximation. - ops.def("gelu_and_mul(Tensor out, Tensor input) -> ()"); + ops.def("gelu_and_mul", &gelu_and_mul); ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul); // Activation function used in GeGLU with `tanh` approximation. - ops.def("gelu_tanh_and_mul(Tensor out, Tensor input) -> ()"); + ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul); ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul); // GELU implementation used in GPT-2. - ops.def("gelu_new(Tensor out, Tensor input) -> ()"); + ops.def("gelu_new", &gelu_new); ops.impl("gelu_new", torch::kCPU, &gelu_new); // Approximate GELU implementation. - ops.def("gelu_fast(Tensor out, Tensor input) -> ()"); + ops.def("gelu_fast", &gelu_fast); ops.impl("gelu_fast", torch::kCPU, &gelu_fast); // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. 
- ops.def( - "rms_norm(Tensor out, Tensor input, Tensor weight, float epsilon) -> ()"); + ops.def("rms_norm", &rms_norm); ops.impl("rms_norm", torch::kCPU, &rms_norm); // In-place fused Add and RMS Normalization. - ops.def( - "fused_add_rms_norm(Tensor input, Tensor residual, Tensor weight, float " - "epsilon) -> ()"); + ops.def("fused_add_rms_norm", &fused_add_rms_norm); ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm); // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. - ops.def( - "rotary_embedding(Tensor positions, Tensor query, Tensor key, int " - "head_size, Tensor cos_sin_cache, bool is_neox) -> ()"); + ops.def("rotary_embedding", &rotary_embedding); ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + +TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { // Cache ops - pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); - cache_ops.def("swap_blocks", &swap_blocks, - "Swap in (out) the cache blocks from src to dst"); - cache_ops.def("copy_blocks", ©_blocks, - "Copy the cache blocks from src to dst"); - cache_ops.def("reshape_and_cache", &reshape_and_cache, - "Reshape the key and value tensors and cache them"); + // Swap in (out) the cache blocks from src to dst. + cache_ops.def("swap_blocks", &swap_blocks); + cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks); + + // Copy the cache blocks from src to dst. + cache_ops.def("copy_blocks", ©_blocks); + cache_ops.impl("copy_blocks", torch::kCPU, ©_blocks); + + // Reshape the key and value tensors and cache them. + cache_ops.def("reshape_and_cache", &reshape_and_cache); + cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); } diff --git a/csrc/moe/moe_ops.cpp b/csrc/moe/moe_ops.cpp index b3402269b377..c7e11268ef97 100644 --- a/csrc/moe/moe_ops.cpp +++ b/csrc/moe/moe_ops.cpp @@ -1,16 +1,10 @@ #include "moe_ops.h" -#include - -#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) - TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Apply topk softmax to the gating outputs. - m.def( - "topk_softmax(Tensor topk_weights, Tensor topk_indices, Tensor " - "token_expert_indices, Tensor gating_output) -> ()"); + vllm::def(m, "topk_softmax", &topk_softmax, {0, 1}); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); } -// TODO: get rid of this +// TODO: get rid of this? 
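// Keeping the empty pybind11 module is deliberate for now: `import
// vllm._moe_C` only succeeds if the shared object exports a PyInit_* entry
// point, and PYBIND11_MODULE is what generates it. Importing the module
// loads the library, which in turn runs the TORCH_LIBRARY registration
// above; the kernel itself is then reached through torch.ops rather than as
// a module attribute (see the `import vllm._moe_C` added to fused_moe.py).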
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 93e7844ac199..66b6e44a3b7c 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -1,6 +1,6 @@ #pragma once -#include +#include "register.h" void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& token_expert_indices, diff --git a/csrc/ops.h b/csrc/ops.h index d968ff2e1e7b..156865ae3094 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -1,10 +1,6 @@ #pragma once -#include - -#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) -#define _CONCAT(A, B) A##B -#define CONCAT(A, B) _CONCAT(A, B) +#include "register.h" void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, diff --git a/csrc/punica/punica_ops.h b/csrc/punica/punica_ops.h index e94d26f9701c..d2038a828e9d 100644 --- a/csrc/punica/punica_ops.h +++ b/csrc/punica/punica_ops.h @@ -1,6 +1,6 @@ #pragma once -#include +#include "register.h" void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, torch::Tensor indicies, int64_t layer_idx, double scale); diff --git a/csrc/punica/punica_pybind.cpp b/csrc/punica/punica_pybind.cpp index b18a019b6057..cefe449dd7d2 100644 --- a/csrc/punica/punica_pybind.cpp +++ b/csrc/punica/punica_pybind.cpp @@ -1,18 +1,10 @@ -#include - #include "punica_ops.h" -#define TORCH_LIBRARY_EXPAND(NAME, M) TORCH_LIBRARY(NAME, M) - TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { - m.def( - "dispatch_bgmv(Tensor y, Tensor x, Tensor w, Tensor indicies, int " - "layer_idx, float scale) -> ()"); + vllm::def(m, "dispatch_bgmv", &dispatch_bgmv, {0}); m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv); - m.def( - "dispatch_bgmv_low_level(Tensor y, Tensor x, Tensor w, Tensor indicies, " - "int layer_idx, float scale, int h_in, int h_out, int y_offset) -> ()"); + vllm::def(m, "dispatch_bgmv_low_level", &dispatch_bgmv_low_level, {0}); m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level); } diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index e88559279be3..1047401416f8 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -3,179 +3,132 @@ #include "ops.h" #include +using vllm::def; + TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops // Attention ops // Compute the attention between an input query and the cached // keys/values using PagedAttention. - ops.def( - "paged_attention_v1(Tensor out, Tensor query, Tensor key_cache, " - "Tensor value_cache, int num_kv_heads, float scale, Tensor " - "block_tables, Tensor seq_lens, int block_size, int max_seq_len, " - "Tensor? alibi_slopes, str kv_cache_dtype, float kv_scale, int tp_rank," - "int blocksparse_local_blocks, int blocksparse_vert_stride, " - "int blocksparse_block_size, int blocksparse_head_sliding_step) -> ()"); + //ops.def("paged_attention_v1", &paged_attention_v1); + def(ops, "paged_attention_v1", &paged_attention_v1, {0}); ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1); // PagedAttention V2. - ops.def( - "paged_attention_v2(Tensor out, Tensor exp_sums, Tensor max_logits," - "Tensor tmp_out, Tensor query, Tensor key_cache, Tensor value_cache," - "int num_kv_heads, float scale, Tensor block_tables, Tensor seq_lens," - "int block_size, int max_seq_len, Tensor? 
alibi_slopes, " - "str kv_cache_dtype, float kv_scale, int tp_rank, " - "int blocksparse_local_blocks, int blocksparse_vert_stride," - "int blocksparse_block_size, int blocksparse_head_sliding_step) -> ()"); + def(ops, "paged_attention_v2", &paged_attention_v2, {0}); ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); // Activation ops // Activation function used in SwiGLU. - ops.def("silu_and_mul(Tensor out, Tensor input) -> ()"); + def(ops, "silu_and_mul", &silu_and_mul, {0}); ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); // Activation function used in GeGLU with `none` approximation. - ops.def("gelu_and_mul(Tensor out, Tensor input) -> ()"); + def(ops, "gelu_and_mul", &gelu_and_mul, {0}); ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); // Activation function used in GeGLU with `tanh` approximation. - ops.def("gelu_tanh_and_mul(Tensor out, Tensor input) -> ()"); + def(ops, "gelu_tanh_and_mul", &gelu_tanh_and_mul, {0}); ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); // GELU implementation used in GPT-2. - ops.def("gelu_new(Tensor out, Tensor input) -> ()"); + def(ops, "gelu_new", &gelu_new, {0}); ops.impl("gelu_new", torch::kCUDA, &gelu_new); // Approximate GELU implementation. - ops.def("gelu_fast(Tensor out, Tensor input) -> ()"); + def(ops, "gelu_fast", &gelu_fast, {0}); ops.impl("gelu_fast", torch::kCUDA, &gelu_fast); // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. - ops.def( - "rms_norm(Tensor out, Tensor input, Tensor weight, float epsilon) -> ()"); + def(ops, "rms_norm", &rms_norm, {0}); + //ops.def("rms_norm", &rms_norm); // ops.def(torch::schema("rms_norm(Tensor out, Tensor input, Tensor weight, // float epsilon) -> ()"), c10::AliasAnalysisKind::CONSERVATIVE); ops.impl("rms_norm", torch::kCUDA, &rms_norm); // In-place fused Add and RMS Normalization. - ops.def( - "fused_add_rms_norm(Tensor input, Tensor residual, Tensor weight, float " - "epsilon) -> ()"); + def(ops, "fused_add_rms_norm", &fused_add_rms_norm, {0, 1}); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. - ops.def( - "rotary_embedding(Tensor positions, Tensor query, Tensor key, int " - "head_size, Tensor cos_sin_cache, bool is_neox) -> ()"); + def(ops, "rotary_embedding", &rotary_embedding, {1, 2}); ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); // Apply GPT-NeoX or GPT-J style rotary embedding to query and key // (supports multiple loras). - ops.def( - "batched_rotary_embedding(Tensor positions, Tensor query, Tensor " - "key, int head_size, Tensor cos_sin_cache, bool is_neox, int " - "rot_dim, Tensor cos_sin_cache_offsets) -> ()"); + def(ops, "batched_rotary_embedding", &batched_rotary_embedding, {1, 2}); // ? ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding); // Quantization ops #ifndef USE_ROCM // Quantized GEMM for AQLM. - ops.def( - "aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, Tensor scales, " - "Tensor codebook_partition_sizes, Tensor? bias) -> Tensor"); + ops.def("aqlm_gemm", &aqlm_gemm); ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm); // Decompression method for AQLM. - ops.def( - "aqlm_dequant(Tensor codes, Tensor codebooks, Tensor " - "codebook_partition_sizes) -> Tensor"); + ops.def("aqlm_dequant", &aqlm_dequant); ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant); // Quantized GEMM for AWQ. 
- ops.def( - "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, " - "Tensor _zeros, int split_k_iters) -> Tensor"); + ops.def("awq_gemm", &awq_gemm); ops.impl("awq_gemm", torch::kCUDA, &awq_gemm); // Marlin (Dense) Optimized Quantized GEMM for GPTQ. - ops.def( - "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, Tensor " - "workspace, int size_m, int size_n, int size_k) -> Tensor"); + ops.def("marlin_gemm", &marlin_gemm); ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm); // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ. - ops.def( - "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, Tensor " - "b_scales, Tensor workspace, int num_bits, int size_m, int size_n, int " - "size_k) -> Tensor"); + ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm); ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm); // gptq_marlin Optimized Quantized GEMM for GPTQ. - ops.def( - "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, Tensor " - "g_idx, Tensor perm, Tensor workspace, int num_bits, int size_m, int " - "size_n, int size_k, bool is_k_full) -> Tensor"); + ops.def("gptq_marlin_gemm", &gptq_marlin_gemm); ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm); // gptq_marlin repack from GPTQ. - ops.def( - "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, int size_k, int " - "size_n, int num_bits) -> Tensor"); + ops.def("gptq_marlin_repack", &gptq_marlin_repack); ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack); // Dequantization for AWQ. - ops.def( - "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, Tensor _zeros, " - "int split_k_iters, int thx, int thy) -> Tensor"); + ops.def("awq_dequantize", &awq_dequantize); ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize); // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column // quantization. - ops.def( - "cutlass_scaled_mm_dq(Tensor out, Tensor a, Tensor b, Tensor a_scales, " - "Tensor b_scales) -> ()"); + def(ops, "cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, {0}); ops.impl("cutlass_scaled_mm_dq", torch::kCUDA, &cutlass_scaled_mm_dq); #endif // Quantized GEMM for GPTQ. - ops.def( - "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, Tensor " - "b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) -> Tensor"); + ops.def("gptq_gemm", &gptq_gemm); ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm); // Post processing for GPTQ. - ops.def("gptq_shuffle(Tensor q_weight, Tensor q_perm, int bit) -> ()"); + def(ops, "gptq_shuffle", &gptq_shuffle, {0}); ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); // Quantized GEMM for SqueezeLLM. - ops.def( - "squeezellm_gemm(Tensor vec, Tensor mat, Tensor mul, Tensor " - "lookup_table) -> ()"); + def(ops, "squeezellm_gemm", &squeezellm_gemm, {2}); ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm); // Compute FP8 quantized tensor for given scaling factor. - ops.def( - "static_scaled_fp8_quant(Tensor out, Tensor input, Tensor scale) -> ()"); + def(ops, "static_scaled_fp8_quant", &static_scaled_fp8_quant, {0}); ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); // Compute FP8 quantized tensor and scaling factor. 
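For ops that write into a preallocated output, the vllm::def helper added in csrc/register.h takes a list of mutating argument indices: it infers the schema from the C++ signature and then sets the write bit on the alias info of the listed arguments, which is what a hand-written "Tensor!" would express. A rough usage sketch, assuming register.h is on the include path (the op fill_twos is illustrative, not from this patch):

    #include <torch/library.h>

    #include "register.h"  // vllm::def, TORCH_LIBRARY_EXPAND

    // An out-parameter op: its result is written into `out`.
    void fill_twos(torch::Tensor out) { out.fill_(2); }

    TORCH_LIBRARY_EXPAND(sketch_helper, m) {
      // Roughly equivalent to m.def("fill_twos(Tensor! _0) -> ()").
      vllm::def(m, "fill_twos", &fill_twos, /*mutating_arg_indices=*/{0});
      m.impl("fill_twos", torch::kCPU, &fill_twos);
    }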
- ops.def( - "dynamic_scaled_fp8_quant(Tensor out, Tensor input, Tensor scale) -> ()"); + def(ops, "dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, {0}); ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); // Aligning the number of tokens to be processed by each expert such // that it is divisible by the block size. - ops.def( - "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size," - "Tensor sorted_token_ids, Tensor experts_ids, Tensor " - "num_tokens_post_pad) -> ()"); + def(ops, "moe_align_block_size", &moe_align_block_size, {3, 4, 5}); ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); // Compute int8 quantized tensor for given scaling factor. - ops.def( - "static_scaled_int8_quant(Tensor out, Tensor input, float scale) -> ()"); + def(ops, "static_scaled_int8_quant", &static_scaled_int8_quant, {0}); ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant); // Compute int8 quantized tensor and scaling factor @@ -187,23 +140,23 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { // Cache ops // Swap in (out) the cache blocks from src to dst. - cache_ops.def("swap_blocks(Tensor src, Tensor dst, Tensor block_mapping) -> ()"); + def(cache_ops, "swap_blocks", &swap_blocks, {0,1}); cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks); // Copy the cache blocks from src to dst. - cache_ops.def("copy_blocks", ©_blocks); + def(cache_ops, "copy_blocks", ©_blocks, {0, 1}); cache_ops.impl("copy_blocks", torch::kCUDA, ©_blocks); // Reshape the key and value tensors and cache them. - cache_ops.def("reshape_and_cache", &reshape_and_cache); + def(cache_ops, "reshape_and_cache", &reshape_and_cache, {2, 3}); // 4? cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache); // Reshape the key and value tensors and cache them. - cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash); + def(cache_ops, "reshape_and_cache_flash", &reshape_and_cache_flash, {2, 3}); // 4? cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash); // Convert the key and value cache to fp8 data type. - cache_ops.def("convert_fp8", &convert_fp8); + def(cache_ops, "convert_fp8", &convert_fp8, {0}); cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); } @@ -211,11 +164,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { // Cuda utils // Gets the specified device attribute. - cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int"); + cuda_utils.def("get_device_attribute", &get_device_attribute); cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute); // Gets the maximum shared memory per block device attribute. 
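The cache and CUDA-utility ops above live in their own libraries: TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), ...) expands to TORCH_LIBRARY(_C_cache_ops, ...) when the extension is built as _C, and each library then surfaces as its own namespace under torch.ops (torch.ops._C_cache_ops, torch.ops._C_cuda_utils), matching the call sites in vllm/_custom_ops.py. The extra _EXPAND level forces the name macro to expand before TORCH_LIBRARY token-pastes it. A reduced sketch, with _C_demo_utils and answer() as illustrative names:

    #include <torch/library.h>

    #define _CONCAT(A, B) A##B
    #define CONCAT(A, B) _CONCAT(A, B)
    #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

    int64_t answer() { return 42; }

    // CONCAT(_C, _demo_utils) becomes _C_demo_utils before TORCH_LIBRARY
    // pastes it, so the op is reachable as torch.ops._C_demo_utils.answer().
    TORCH_LIBRARY_EXPAND(CONCAT(_C, _demo_utils), m) {
      m.def("answer", &answer);
    }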
- cuda_utils.def("get_max_shared_memory_per_block_device_attribute(int device_id) -> int"); + cuda_utils.def("get_max_shared_memory_per_block_device_attribute", + &get_max_shared_memory_per_block_device_attribute); cuda_utils.impl("get_max_shared_memory_per_block_device_attribute", torch::kCUDA, &get_max_shared_memory_per_block_device_attribute); } diff --git a/csrc/register.h b/csrc/register.h new file mode 100644 index 000000000000..35f61d6b0c4a --- /dev/null +++ b/csrc/register.h @@ -0,0 +1,51 @@ +#pragma once + +#include + +#include + +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) + +namespace vllm { + +template +void def(torch::Library& lib, std::string const& name, FnType* fn, + std::initializer_list mutating_arg_indices = {}) { +#if 1 + auto raw_schema = + c10::detail::inferFunctionSchemaFromFunctor>(); + auto named_schema = raw_schema->cloneWithName(name, ""); + + if (mutating_arg_indices.size() != 0) { + std::vector const& args = named_schema.arguments(); + std::vector new_args; + for (size_t i = 0; i < args.size(); ++i) { + auto const& arg = args[i]; + if (std::find(mutating_arg_indices.begin(), mutating_arg_indices.end(), + i) == mutating_arg_indices.end()) { + new_args.push_back(arg); + } else { + c10::AliasInfo new_alias_info; + if (arg.alias_info()) { + new_alias_info = *arg.alias_info(); + } + new_alias_info.setIsWrite(true); + + new_args.emplace_back( + arg.name(), arg.type(), arg.real_type(), arg.N(), + arg.default_value(), arg.kwarg_only(), new_alias_info); + } + } + + named_schema = named_schema.cloneWithArguments(std::move(new_args)); + } + + lib.def(std::move(named_schema)); +#else + lib.def(name.c_str(), fn); +#endif +} + +} // namespace vllm diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index de240b264116..973c33659781 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -3,7 +3,7 @@ import torch try: - # ruff: noqa: SIM105 + # ruff: noqa: F401 SIM105 import vllm._C except ImportError as e: from vllm.logger import init_logger @@ -315,8 +315,10 @@ def reshape_and_cache( kv_cache_dtype: str, kv_scale: float, ) -> None: - torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, kv_scale) + torch.ops._C_cache_ops.reshape_and_cache( + key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype, kv_scale) def reshape_and_cache_flash( @@ -327,8 +329,10 @@ def reshape_and_cache_flash( slot_mapping: torch.Tensor, kv_cache_dtype: str, ) -> None: - torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype) + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype) def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, @@ -353,7 +357,9 @@ def get_device_attribute(attribute: int, device: int) -> int: def get_max_shared_memory_per_block_device_attribute(device: int) -> int: - return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute(device) + # ruff: noqa: E501 + return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute( + device) #TODO: custom_ar From 70bf9a9a9032ff5c2803bcb03d8dfd1955d8acb4 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 27 May 2024 05:29:06 +0000 Subject: [PATCH 14/41] clang format --- csrc/cpu/pybind.cpp | 1 - csrc/pybind.cpp | 23 +++++++++++++---------- csrc/register.h | 6 +++--- 3 files changed, 16 insertions(+), 
14 deletions(-) diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp index a11f30c0d1cb..7e5fe9d4bea5 100644 --- a/csrc/cpu/pybind.cpp +++ b/csrc/cpu/pybind.cpp @@ -52,7 +52,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); } - TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { // Cache ops // Swap in (out) the cache blocks from src to dst. diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 1047401416f8..707724d9fd7f 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -11,7 +11,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Attention ops // Compute the attention between an input query and the cached // keys/values using PagedAttention. - //ops.def("paged_attention_v1", &paged_attention_v1); + // ops.def("paged_attention_v1", &paged_attention_v1); def(ops, "paged_attention_v1", &paged_attention_v1, {0}); ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1); @@ -42,10 +42,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. - def(ops, "rms_norm", &rms_norm, {0}); - //ops.def("rms_norm", &rms_norm); - // ops.def(torch::schema("rms_norm(Tensor out, Tensor input, Tensor weight, - // float epsilon) -> ()"), c10::AliasAnalysisKind::CONSERVATIVE); + def(ops, "rms_norm", &rms_norm, {0}); + // ops.def("rms_norm", &rms_norm); + // ops.def(torch::schema("rms_norm(Tensor out, Tensor input, Tensor weight, + // float epsilon) -> ()"), c10::AliasAnalysisKind::CONSERVATIVE); ops.impl("rms_norm", torch::kCUDA, &rms_norm); // In-place fused Add and RMS Normalization. @@ -140,7 +140,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { // Cache ops // Swap in (out) the cache blocks from src to dst. - def(cache_ops, "swap_blocks", &swap_blocks, {0,1}); + def(cache_ops, "swap_blocks", &swap_blocks, {0, 1}); cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks); // Copy the cache blocks from src to dst. @@ -148,12 +148,14 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { cache_ops.impl("copy_blocks", torch::kCUDA, ©_blocks); // Reshape the key and value tensors and cache them. - def(cache_ops, "reshape_and_cache", &reshape_and_cache, {2, 3}); // 4? + def(cache_ops, "reshape_and_cache", &reshape_and_cache, {2, 3}); // 4? cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache); // Reshape the key and value tensors and cache them. - def(cache_ops, "reshape_and_cache_flash", &reshape_and_cache_flash, {2, 3}); // 4? - cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash); + def(cache_ops, "reshape_and_cache_flash", &reshape_and_cache_flash, + {2, 3}); // 4? + cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, + &reshape_and_cache_flash); // Convert the key and value cache to fp8 data type. def(cache_ops, "convert_fp8", &convert_fp8, {0}); @@ -170,7 +172,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { // Gets the maximum shared memory per block device attribute. 
cuda_utils.def("get_max_shared_memory_per_block_device_attribute", &get_max_shared_memory_per_block_device_attribute); - cuda_utils.impl("get_max_shared_memory_per_block_device_attribute", torch::kCUDA, + cuda_utils.impl("get_max_shared_memory_per_block_device_attribute", + torch::kCUDA, &get_max_shared_memory_per_block_device_attribute); } diff --git a/csrc/register.h b/csrc/register.h index 35f61d6b0c4a..adcb7e8178a6 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -33,9 +33,9 @@ void def(torch::Library& lib, std::string const& name, FnType* fn, } new_alias_info.setIsWrite(true); - new_args.emplace_back( - arg.name(), arg.type(), arg.real_type(), arg.N(), - arg.default_value(), arg.kwarg_only(), new_alias_info); + new_args.emplace_back(arg.name(), arg.type(), arg.real_type(), arg.N(), + arg.default_value(), arg.kwarg_only(), + new_alias_info); } } From 216d3d1b288eb2a7b3c591dd2dad4a6ba7182eac Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 27 May 2024 05:30:23 +0000 Subject: [PATCH 15/41] fix format --- csrc/cache.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/cache.h b/csrc/cache.h index fbf6064f0791..bf8887a74602 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -15,7 +15,8 @@ void copy_blocks(std::vector const& key_caches, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, const double kv_scale); + const std::string& kv_cache_dtype, + const double kv_scale); void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, From 4bc89c1901a23570f0b45011af5e155e8884ad03 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 27 May 2024 05:36:22 +0000 Subject: [PATCH 16/41] fix cpu binding --- csrc/cpu/pybind.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp index 7e5fe9d4bea5..e1a690cc862c 100644 --- a/csrc/cpu/pybind.cpp +++ b/csrc/cpu/pybind.cpp @@ -66,3 +66,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { cache_ops.def("reshape_and_cache", &reshape_and_cache); cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); } + +// TODO: get rid of this? 
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} From fb2a1950d6901484133391f62fd6cb9f3bf36282 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 27 May 2024 05:53:21 +0000 Subject: [PATCH 17/41] fix intel --- csrc/cpu/cache.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 2890ba6e2bb3..36e0523662b1 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -5,8 +5,8 @@ namespace { template -void copy_blocks_cpu_impl(std::vector& key_caches, - std::vector& value_caches, +void copy_blocks_cpu_impl(std::vector const& key_caches, + std::vector const& value_caches, const torch::Tensor& mapping_pairs, const int element_num_per_block, const int layer_num) { @@ -82,8 +82,8 @@ void reshape_and_cache_cpu_impl( } }; // namespace -void copy_blocks(std::vector& key_caches, - std::vector& value_caches, +void copy_blocks(std::vector const& key_caches, + std::vector const& value_caches, const torch::Tensor& block_mapping) { unsigned num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); @@ -104,7 +104,7 @@ void copy_blocks(std::vector& key_caches, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, float kv_scale) { + const std::string& kv_cache_dtype, double kv_scale) { TORCH_CHECK(kv_scale == 1.0f); int num_tokens = key.size(0); From 40521982bf9a8c53483d5b55d9df8ec0a1fecec9 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 27 May 2024 14:09:32 +0000 Subject: [PATCH 18/41] format --- vllm/_custom_ops.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 973c33659781..7881234a0f83 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -315,10 +315,9 @@ def reshape_and_cache( kv_cache_dtype: str, kv_scale: float, ) -> None: - torch.ops._C_cache_ops.reshape_and_cache( - key, value, key_cache, - value_cache, slot_mapping, - kv_cache_dtype, kv_scale) + torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype, kv_scale) def reshape_and_cache_flash( @@ -329,10 +328,9 @@ def reshape_and_cache_flash( slot_mapping: torch.Tensor, kv_cache_dtype: str, ) -> None: - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, value, key_cache, - value_cache, slot_mapping, - kv_cache_dtype) + torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype) def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, From e8e9af2825cb6d0bdb695668dce94ce9224566f0 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 31 May 2024 02:39:09 +0000 Subject: [PATCH 19/41] add meta functions and signatures --- csrc/custom_all_reduce.cu | 2 +- csrc/moe/moe_ops.cpp | 4 +- csrc/moe/moe_ops.h | 6 +- csrc/ops.h | 6 +- csrc/punica/punica_ops.h | 6 +- csrc/punica/punica_pybind.cpp | 10 +- csrc/pybind.cpp | 237 +++++++++++++++--- .../gptq_marlin/gptq_marlin_repack.cu | 12 + csrc/register.h | 51 ---- vllm/model_executor/models/llama.py | 2 + 10 files changed, 246 insertions(+), 90 deletions(-) delete mode 100644 csrc/register.h diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 0b1d95848525..13b8d49f33ce 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -125,7 +125,7 @@ void dispose(fptr_t _fa) { delete fa; } -int meta_size() { return sizeof(vllm::Signal); } +int64_t 
meta_size() { return sizeof(vllm::Signal); } void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector& handles, diff --git a/csrc/moe/moe_ops.cpp b/csrc/moe/moe_ops.cpp index c7e11268ef97..cc1ee9e8ebb2 100644 --- a/csrc/moe/moe_ops.cpp +++ b/csrc/moe/moe_ops.cpp @@ -2,7 +2,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Apply topk softmax to the gating outputs. - vllm::def(m, "topk_softmax", &topk_softmax, {0, 1}); + m.def( + "topk_softmax(Tensor! topk_weights, Tensor! topk_indices,Tensor " + "token_expert_indices,Tensor gating_output) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); } diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 66b6e44a3b7c..733b8560658c 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -1,6 +1,10 @@ #pragma once -#include "register.h" +#include + +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& token_expert_indices, diff --git a/csrc/ops.h b/csrc/ops.h index 156865ae3094..24f31928c9fd 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -1,6 +1,8 @@ #pragma once -#include "register.h" +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, @@ -135,7 +137,7 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out); void dispose(fptr_t _fa); -int meta_size(); +int64_t meta_size(); void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector& handles, const std::vector& offsets); diff --git a/csrc/punica/punica_ops.h b/csrc/punica/punica_ops.h index d2038a828e9d..110c94d402df 100644 --- a/csrc/punica/punica_ops.h +++ b/csrc/punica/punica_ops.h @@ -1,6 +1,10 @@ #pragma once -#include "register.h" +#include + +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, torch::Tensor indicies, int64_t layer_idx, double scale); diff --git a/csrc/punica/punica_pybind.cpp b/csrc/punica/punica_pybind.cpp index cefe449dd7d2..cd464c5f663c 100644 --- a/csrc/punica/punica_pybind.cpp +++ b/csrc/punica/punica_pybind.cpp @@ -1,10 +1,16 @@ #include "punica_ops.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { - vllm::def(m, "dispatch_bgmv", &dispatch_bgmv, {0}); + m.def( + "dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int " + "layer_idx, float scale) -> ()"); m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv); - vllm::def(m, "dispatch_bgmv_low_level", &dispatch_bgmv_low_level, {0}); + m.def( + "dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w," + "Tensor indicies, int layer_idx," + "float scale, int h_in, int h_out," + "int y_offset) -> ()"); m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level); } diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 707724d9fd7f..a7731094e721 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -1,9 +1,116 @@ #include "cache.h" #include "cuda_utils.h" #include "ops.h" +// #include "quantization/gptq_marlin/gptq_marlin.cuh" //?? 
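The declarations below add shape-only "meta" implementations for the quantized GEMMs: each torch::kMeta overload allocates an output with the right shape, dtype and device without launching any kernel, so FakeTensor and torch.compile tracing can propagate shapes through the custom ops. A minimal sketch of the same pattern with an illustrative op (row_sums is not part of this series):

    #include <torch/library.h>

    // Real kernel: does the arithmetic.
    torch::Tensor row_sums(torch::Tensor const& a) { return a.sum(1); }

    // Meta kernel: only describes the output, runs no math.
    torch::Tensor row_sums_meta(torch::Tensor const& a) {
      auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
      return torch::empty({a.size(0)}, options);
    }

    TORCH_LIBRARY(sketch_meta, m) {
      m.def("row_sums(Tensor a) -> Tensor");
      m.impl("row_sums", torch::kCPU, &row_sums);
      m.impl("row_sums", torch::kMeta, &row_sums_meta);
    }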
#include -using vllm::def; +torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight, + torch::Tensor& perm, int64_t size_k, + int64_t size_n, int64_t num_bits); + +// See +// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9 +// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations + +// where should these live? near the implementations of the kernels? +namespace vllm::meta { + +torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const torch::Tensor& codebook_partition_sizes, + const std::optional& bias) { + auto input_sizes = input.sizes(); + + auto out_features = codes.size(0) * codebooks.size(2); + auto flat_input = input.reshape({-1, input.size(-1)}); + auto flat_output = torch::empty( + {flat_input.size(0), out_features}, + torch::TensorOptions().dtype(input.dtype()).device(input.device())); + + auto output_sizes = input_sizes.vec(); + output_sizes.pop_back(); + output_sizes.push_back(-1); + return flat_output.reshape(output_sizes); +} + +torch::Tensor aqlm_dequant(const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& codebook_partition_sizes) { + auto in_features = codes.size(1) * 8; + auto out_features = codes.size(0); + return torch::empty({out_features, in_features}, + torch::TensorOptions() + .dtype(codebooks.dtype()) + .device(codebooks.device())); +} + +torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, torch::Tensor _zeros, + int64_t split_k_iters) { + int num_in_feats = _in_feats.size(0); + auto options = torch::TensorOptions() + .dtype(_in_feats.dtype()) + .device(_in_feats.device()); +#if 0 + at::Tensor _out_feats = + torch::empty({num_in_feats, _kernel.size(1) * 8}, options); + return _out_feats.sum(0); +#else + return torch::empty({_kernel.size(1) * 8}, options); +#endif +} + +torch::Tensor awq_dequantize(torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, int64_t split_k_iters, + int64_t thx, int64_t thy) { + int in_c = _kernel.size(0); + int qout_c = _kernel.size(1); + int out_c = qout_c * 8; + + auto options = torch::TensorOptions() + .dtype(_scaling_factors.dtype()) + .device(_scaling_factors.device()); + + return torch::empty({in_c, out_c}, options); +} + +torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t size_m, int64_t size_n, int64_t size_k) { + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + return torch::empty({size_m, size_n}, options); +} + +torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_meta, + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, + int64_t size_k) { + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + return torch::empty({size_m, size_n}, options); +} + +torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& g_idx, + torch::Tensor& perm, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full) { + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + return torch::empty({size_m, size_n}, options); +} + +torch::Tensor gptq_gemm(torch::Tensor a, 
torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, + bool use_exllama, int64_t bit) { + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + return torch::empty({a.size(0), b_q_weight.size(1)}, options); +} + +} // namespace vllm::meta TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -11,55 +118,86 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Attention ops // Compute the attention between an input query and the cached // keys/values using PagedAttention. - // ops.def("paged_attention_v1", &paged_attention_v1); - def(ops, "paged_attention_v1", &paged_attention_v1, {0}); + ops.def( + "paged_attention_v1(" + " Tensor! out, Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads, float scale," + " Tensor block_tables, Tensor seq_lens, int block_size," + " int max_seq_len, Tensor? alibi_slopes," + " str kv_cache_dtype, float kv_scale, int tp_rank," + " int blocksparse_local_blocks," + " int blocksparse_vert_stride, int blocksparse_block_size," + " int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1); // PagedAttention V2. - def(ops, "paged_attention_v2", &paged_attention_v2, {0}); + // def(ops, "paged_attention_v2", &paged_attention_v2, {0}); + ops.def( + "paged_attention_v2(" + " Tensor! out, Tensor exp_sums, Tensor max_logits," + " Tensor tmp_out, Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads, float scale," + " Tensor block_tables, Tensor seq_lens, int block_size," + " int max_seq_len, Tensor? alibi_slopes," + " str kv_cache_dtype, float kv_scale, int tp_rank," + " int blocksparse_local_blocks," + " int blocksparse_vert_stride, int blocksparse_block_size," + " int blocksparse_head_sliding_step) -> ()"); + ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); // Activation ops // Activation function used in SwiGLU. - def(ops, "silu_and_mul", &silu_and_mul, {0}); + ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); // Activation function used in GeGLU with `none` approximation. - def(ops, "gelu_and_mul", &gelu_and_mul, {0}); + ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); // Activation function used in GeGLU with `tanh` approximation. - def(ops, "gelu_tanh_and_mul", &gelu_tanh_and_mul, {0}); + ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); // GELU implementation used in GPT-2. - def(ops, "gelu_new", &gelu_new, {0}); + ops.def("gelu_new(Tensor! out, Tensor input) -> ()"); ops.impl("gelu_new", torch::kCUDA, &gelu_new); // Approximate GELU implementation. - def(ops, "gelu_fast", &gelu_fast, {0}); + ops.def("gelu_fast(Tensor! out, Tensor input) -> ()"); ops.impl("gelu_fast", torch::kCUDA, &gelu_fast); // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. - def(ops, "rms_norm", &rms_norm, {0}); - // ops.def("rms_norm", &rms_norm); + ops.def( + "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> " + "()"); // ops.def(torch::schema("rms_norm(Tensor out, Tensor input, Tensor weight, // float epsilon) -> ()"), c10::AliasAnalysisKind::CONSERVATIVE); ops.impl("rms_norm", torch::kCUDA, &rms_norm); // In-place fused Add and RMS Normalization. 
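The schema strings below use the alias annotation "Tensor!" for every argument a kernel writes (both input and residual for the fused add/RMS-norm that follows), which tells the dispatcher and functionalization passes that the op mutates those tensors in place. A tiny sketch with an illustrative op (add_one is not from this series):

    #include <torch/library.h>

    // An in-place op: mutates `x`, returns nothing.
    void add_one(torch::Tensor x) { x.add_(1); }

    TORCH_LIBRARY(sketch_inplace, m) {
      // "Tensor!" marks `x` as written in place; "-> ()" declares no result.
      m.def("add_one(Tensor! x) -> ()");
      m.impl("add_one", torch::kCPU, &add_one);
    }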
- def(ops, "fused_add_rms_norm", &fused_add_rms_norm, {0, 1}); + ops.def( + "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, " + "float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. - def(ops, "rotary_embedding", &rotary_embedding, {1, 2}); + ops.def( + "rotary_embedding(Tensor positions, Tensor! query," + " Tensor! key, int head_size," + " Tensor cos_sin_cache, bool is_neox) -> ()"); ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); // Apply GPT-NeoX or GPT-J style rotary embedding to query and key // (supports multiple loras). - def(ops, "batched_rotary_embedding", &batched_rotary_embedding, {1, 2}); // ? + ops.def( + "batched_rotary_embedding(Tensor positions, Tensor! query," + " Tensor! key, int head_size," + " Tensor cos_sin_cache, bool is_neox," + " int rot_dim," + " Tensor cos_sin_cache_offsets) -> ()"); ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding); // Quantization ops @@ -67,68 +205,91 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Quantized GEMM for AQLM. ops.def("aqlm_gemm", &aqlm_gemm); ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm); + ops.impl("aqlm_gemm", torch::kMeta, &vllm::meta::aqlm_gemm); // Decompression method for AQLM. ops.def("aqlm_dequant", &aqlm_dequant); ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant); + ops.impl("aqlm_dequant", torch::kMeta, &vllm::meta::aqlm_dequant); // Quantized GEMM for AWQ. ops.def("awq_gemm", &awq_gemm); ops.impl("awq_gemm", torch::kCUDA, &awq_gemm); + ops.impl("awq_gemm", torch::kMeta, &vllm::meta::awq_gemm); + + // Dequantization for AWQ. + ops.def("awq_dequantize", &awq_dequantize); + ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize); + ops.impl("awq_dequantize", torch::kMeta, &vllm::meta::awq_dequantize); // Marlin (Dense) Optimized Quantized GEMM for GPTQ. ops.def("marlin_gemm", &marlin_gemm); ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm); + ops.impl("marlin_gemm", torch::kMeta, &vllm::meta::marlin_gemm); // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ. ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm); ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm); + ops.impl("gptq_marlin_24_gemm", torch::kMeta, + &vllm::meta::gptq_marlin_24_gemm); // gptq_marlin Optimized Quantized GEMM for GPTQ. ops.def("gptq_marlin_gemm", &gptq_marlin_gemm); ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm); + ops.impl("gptq_marlin_gemm", torch::kMeta, &vllm::meta::gptq_marlin_gemm); // gptq_marlin repack from GPTQ. ops.def("gptq_marlin_repack", &gptq_marlin_repack); ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack); - - // Dequantization for AWQ. - ops.def("awq_dequantize", &awq_dequantize); - ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize); + ops.impl("gptq_marlin_repack", torch::kMeta, &gptq_marlin_repack_meta); // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column // quantization. - def(ops, "cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, {0}); + ops.def( + "cutlass_scaled_mm_dq(Tensor! out, Tensor a," + " Tensor b, Tensor a_scales," + " Tensor b_scales) -> ()"); ops.impl("cutlass_scaled_mm_dq", torch::kCUDA, &cutlass_scaled_mm_dq); #endif // Quantized GEMM for GPTQ. ops.def("gptq_gemm", &gptq_gemm); ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm); + ops.impl("gptq_gemm", torch::kMeta, &vllm::meta::gptq_gemm); // Post processing for GPTQ. 
- def(ops, "gptq_shuffle", &gptq_shuffle, {0}); + ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"); ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); // Quantized GEMM for SqueezeLLM. - def(ops, "squeezellm_gemm", &squeezellm_gemm, {2}); + ops.def( + "squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor " + "lookup_table) -> ()"); ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm); // Compute FP8 quantized tensor for given scaling factor. - def(ops, "static_scaled_fp8_quant", &static_scaled_fp8_quant, {0}); + ops.def( + "static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()"); ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); // Compute FP8 quantized tensor and scaling factor. - def(ops, "dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, {0}); + ops.def( + "dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> " + "()"); ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); // Aligning the number of tokens to be processed by each expert such // that it is divisible by the block size. - def(ops, "moe_align_block_size", &moe_align_block_size, {3, 4, 5}); + ops.def( + "moe_align_block_size(Tensor topk_ids, int num_experts," + " int block_size, Tensor! sorted_token_ids," + " Tensor! experts_ids," + " Tensor! num_tokens_post_pad) -> ()"); ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); // Compute int8 quantized tensor for given scaling factor. - def(ops, "static_scaled_int8_quant", &static_scaled_int8_quant, {0}); + ops.def( + "static_scaled_int8_quant(Tensor! out, Tensor input, float scale) -> ()"); ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant); // Compute int8 quantized tensor and scaling factor @@ -140,25 +301,39 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { // Cache ops // Swap in (out) the cache blocks from src to dst. - def(cache_ops, "swap_blocks", &swap_blocks, {0, 1}); + cache_ops.def( + "swap_blocks(Tensor! src, Tensor! dst, Tensor block_mapping) -> ()"); cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks); // Copy the cache blocks from src to dst. - def(cache_ops, "copy_blocks", ©_blocks, {0, 1}); + cache_ops.def( + "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor " + "block_mapping) -> ()"); cache_ops.impl("copy_blocks", torch::kCUDA, ©_blocks); // Reshape the key and value tensors and cache them. - def(cache_ops, "reshape_and_cache", &reshape_and_cache, {2, 3}); // 4? + cache_ops.def( + "reshape_and_cache(Tensor key, Tensor value," + " Tensor! key_cache, Tensor! value_cache," + " Tensor slot_mapping," + " str kv_cache_dtype," + " float kv_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache); // Reshape the key and value tensors and cache them. - def(cache_ops, "reshape_and_cache_flash", &reshape_and_cache_flash, - {2, 3}); // 4? + cache_ops.def( + "reshape_and_cache_flash(Tensor key, Tensor value," + " Tensor! key_cache," + " Tensor! value_cache," + " Tensor slot_mapping," + " str kv_cache_dtype) -> ()"); cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash); // Convert the key and value cache to fp8 data type. - def(cache_ops, "convert_fp8", &convert_fp8, {0}); + cache_ops.def( + "convert_fp8(Tensor! 
dst_cache, Tensor src_cache, float scale, str " + "kv_cache_dtype) -> ()"); cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); } diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu index 4adc158eb14e..6e0ad9cf3a61 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu @@ -348,3 +348,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, } #endif + +torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight, + torch::Tensor& perm, int64_t size_k, + int64_t size_n, int64_t num_bits) { + int const pack_factor = 32 / num_bits; + auto options = torch::TensorOptions() + .dtype(b_q_weight.dtype()) + .device(b_q_weight.device()); + return torch::empty({size_k / gptq_marlin::tile_size, + size_n * gptq_marlin::tile_size / pack_factor}, + options); +} diff --git a/csrc/register.h b/csrc/register.h deleted file mode 100644 index adcb7e8178a6..000000000000 --- a/csrc/register.h +++ /dev/null @@ -1,51 +0,0 @@ -#pragma once - -#include - -#include - -#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) -#define _CONCAT(A, B) A##B -#define CONCAT(A, B) _CONCAT(A, B) - -namespace vllm { - -template -void def(torch::Library& lib, std::string const& name, FnType* fn, - std::initializer_list mutating_arg_indices = {}) { -#if 1 - auto raw_schema = - c10::detail::inferFunctionSchemaFromFunctor>(); - auto named_schema = raw_schema->cloneWithName(name, ""); - - if (mutating_arg_indices.size() != 0) { - std::vector const& args = named_schema.arguments(); - std::vector new_args; - for (size_t i = 0; i < args.size(); ++i) { - auto const& arg = args[i]; - if (std::find(mutating_arg_indices.begin(), mutating_arg_indices.end(), - i) == mutating_arg_indices.end()) { - new_args.push_back(arg); - } else { - c10::AliasInfo new_alias_info; - if (arg.alias_info()) { - new_alias_info = *arg.alias_info(); - } - new_alias_info.setIsWrite(true); - - new_args.emplace_back(arg.name(), arg.type(), arg.real_type(), arg.N(), - arg.default_value(), arg.kwarg_only(), - new_alias_info); - } - } - - named_schema = named_schema.cloneWithArguments(std::move(new_args)); - } - - lib.def(std::move(named_schema)); -#else - lib.def(name.c_str(), fn); -#endif -} - -} // namespace vllm diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index d83ee9a201c0..4fd3a9bbd0be 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -270,6 +270,7 @@ def __init__( def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) + #@torch.compile def forward( self, input_ids: Optional[torch.Tensor], @@ -361,6 +362,7 @@ def __init__( config.vocab_size, logit_scale) self.sampler = Sampler() + #@torch.compile def forward( self, input_ids: torch.Tensor, From c72bf4c518a43ddb95f95a7d38f599ef0ec4b044 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 31 May 2024 02:53:31 +0000 Subject: [PATCH 20/41] update cpu bindings --- csrc/cpu/pybind.cpp | 60 +++++++++++++++++++++++++++++++++++---------- csrc/pybind.cpp | 1 - 2 files changed, 47 insertions(+), 14 deletions(-) diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp index e1a690cc862c..6fe5df36c0db 100644 --- a/csrc/cpu/pybind.cpp +++ b/csrc/cpu/pybind.cpp @@ -8,62 +8,96 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Attention ops // Compute the attention between an input query and the cached 
keys/values // using PagedAttention. - ops.def("paged_attention_v1", &paged_attention_v1); + ops.def( + "paged_attention_v1(" + " Tensor! out, Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads, float scale," + " Tensor block_tables, Tensor seq_lens, int block_size," + " int max_seq_len, Tensor? alibi_slopes," + " str kv_cache_dtype, float kv_scale, int tp_rank," + " int blocksparse_local_blocks," + " int blocksparse_vert_stride, int blocksparse_block_size," + " int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1); // PagedAttention V2. - ops.def("paged_attention_v2", &paged_attention_v2); + ops.def( + "paged_attention_v2(" + " Tensor! out, Tensor exp_sums, Tensor max_logits," + " Tensor tmp_out, Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads, float scale," + " Tensor block_tables, Tensor seq_lens, int block_size," + " int max_seq_len, Tensor? alibi_slopes," + " str kv_cache_dtype, float kv_scale, int tp_rank," + " int blocksparse_local_blocks," + " int blocksparse_vert_stride, int blocksparse_block_size," + " int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2); // Activation ops // Activation function used in SwiGLU. - ops.def("silu_and_mul", &silu_and_mul); + ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul); // Activation function used in GeGLU with `none` approximation. - ops.def("gelu_and_mul", &gelu_and_mul); + ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul); // Activation function used in GeGLU with `tanh` approximation. - ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul); + ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul); // GELU implementation used in GPT-2. - ops.def("gelu_new", &gelu_new); + ops.def("gelu_new(Tensor! out, Tensor input) -> ()"); ops.impl("gelu_new", torch::kCPU, &gelu_new); // Approximate GELU implementation. - ops.def("gelu_fast", &gelu_fast); + ops.def("gelu_fast(Tensor! out, Tensor input) -> ()"); ops.impl("gelu_fast", torch::kCPU, &gelu_fast); // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. - ops.def("rms_norm", &rms_norm); + ops.def( + "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> " + "()"); ops.impl("rms_norm", torch::kCPU, &rms_norm); // In-place fused Add and RMS Normalization. - ops.def("fused_add_rms_norm", &fused_add_rms_norm); + ops.def( + "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, " + "float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm); // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. - ops.def("rotary_embedding", &rotary_embedding); + ops.def( + "rotary_embedding(Tensor positions, Tensor! query," + " Tensor! key, int head_size," + " Tensor cos_sin_cache, bool is_neox) -> ()"); ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { // Cache ops // Swap in (out) the cache blocks from src to dst. - cache_ops.def("swap_blocks", &swap_blocks); + cache_ops.def( + "swap_blocks(Tensor! src, Tensor! dst, Tensor block_mapping) -> ()"); cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks); // Copy the cache blocks from src to dst. 
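These CPU bindings re-declare the same schemas and register torch::kCPU kernels against them, so the torch.ops call sites in vllm/_custom_ops.py stay the same regardless of which backend the extension was built for; dispatch then follows the device of the tensors passed in. A compact sketch of one schema with per-backend kernels, using an illustrative op (scaled_add) and a single build that registers both backends:

    #include <torch/library.h>

    // The same C++ body happens to be device-generic here.
    void scaled_add(torch::Tensor out, torch::Tensor a, double s) {
      out.copy_(a * s);
    }

    TORCH_LIBRARY(sketch_dispatch, m) {
      m.def("scaled_add(Tensor! out, Tensor a, float s) -> ()");
      m.impl("scaled_add", torch::kCPU, &scaled_add);
      m.impl("scaled_add", torch::kCUDA, &scaled_add);
    }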
- cache_ops.def("copy_blocks", ©_blocks); + cache_ops.def( + "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor " + "block_mapping) -> ()"); cache_ops.impl("copy_blocks", torch::kCPU, ©_blocks); // Reshape the key and value tensors and cache them. - cache_ops.def("reshape_and_cache", &reshape_and_cache); + cache_ops.def( + "reshape_and_cache(Tensor key, Tensor value," + " Tensor! key_cache, Tensor! value_cache," + " Tensor slot_mapping," + " str kv_cache_dtype," + " float kv_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); } diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index a7731094e721..44c0aa2a5c25 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -143,7 +143,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); - ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); // Activation ops From c5562c8dbbf2363b38dcb6f078fb379176961c38 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 31 May 2024 18:08:40 +0000 Subject: [PATCH 21/41] convert custom_ar ops --- csrc/custom_all_reduce.cu | 10 ++-- csrc/custom_all_reduce.cuh | 9 ++-- csrc/ops.h | 8 +-- csrc/pybind.cpp | 39 +++++++++------ vllm/_custom_ops.py | 50 ++++++++++++++++++- .../device_communicators/custom_all_reduce.py | 36 +++++++------ 6 files changed, 105 insertions(+), 47 deletions(-) diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 13b8d49f33ce..2420bd745127 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -5,13 +5,13 @@ #include "custom_all_reduce.cuh" -// fake pointer type -using fptr_t = uint64_t; +// fake pointer type, must match fptr_t type in ops.h +using fptr_t = int64_t; static_assert(sizeof(void*) == sizeof(fptr_t)); fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector& handles, - const std::vector& offsets, int rank, + const std::vector& offsets, int64_t rank, bool full_nvlink) { int world_size = offsets.size(); if (world_size > 8) @@ -55,7 +55,7 @@ bool _is_weak_contiguous(torch::Tensor& t) { t.numel() * t.element_size()); } -bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size, +bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, bool full_nvlink) { auto inp_size = inp.numel() * inp.element_size(); // custom allreduce requires input byte size to be multiples of 16 @@ -134,7 +134,7 @@ void register_buffer(fptr_t _fa, torch::Tensor& t, fa->register_buffer(handles, offsets, t.data_ptr()); } -std::pair, std::vector> get_graph_buffer_ipc_meta( +std::tuple, std::vector> get_graph_buffer_ipc_meta( fptr_t _fa) { auto fa = reinterpret_cast(_fa); return fa->get_graph_buffer_ipc_meta(); diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 1ed49b8aa9ca..8bc89834bed4 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -312,11 +312,12 @@ class CustomAllreduce { return it->second; } - std::pair, std::vector> + std::pair, std::vector> get_graph_buffer_ipc_meta() { auto num_buffers = graph_unreg_buffers_.size(); auto handle_sz = sizeof(cudaIpcMemHandle_t); - std::vector handles(handle_sz * num_buffers, 0); + std::string empty_handle_str(handle_sz, 0); + std::vector handles(num_buffers, empty_handle_str); std::vector offsets(num_buffers); for (int i = 0; i < num_buffers; i++) { auto ptr = graph_unreg_buffers_[i]; @@ -328,10 +329,10 @@ class CustomAllreduce { 
(CUdeviceptr)ptr) != CUDA_SUCCESS) throw std::runtime_error("failed to get pointer attr"); CUDACHECK(cudaIpcGetMemHandle( - (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); + (cudaIpcMemHandle_t*)handles[i].data(), base_ptr)); offsets[i] = ((char*)ptr) - ((char*)base_ptr); } - return std::make_pair(handles, offsets); + return {handles, offsets}; } void check_rank_data_capacity(size_t num = 1) { diff --git a/csrc/ops.h b/csrc/ops.h index 24f31928c9fd..a12bda8562bf 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -126,12 +126,12 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, torch::Tensor num_tokens_post_pad); #ifndef USE_ROCM -using fptr_t = uint64_t; +using fptr_t = int64_t; fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector& handles, - const std::vector& offsets, int rank, + const std::vector& offsets, int64_t rank, bool full_nvlink); -bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size, +bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, bool full_nvlink); void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, @@ -141,7 +141,7 @@ int64_t meta_size(); void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector& handles, const std::vector& offsets); -std::pair, std::vector> get_graph_buffer_ipc_meta( +std::tuple, std::vector> get_graph_buffer_ipc_meta( fptr_t _fa); void register_graph_buffers(fptr_t _fa, const std::vector& handles, const std::vector>& offsets); diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 44c0aa2a5c25..4865dd5fc658 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -131,7 +131,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1); // PagedAttention V2. - // def(ops, "paged_attention_v2", &paged_attention_v2, {0}); ops.def( "paged_attention_v2(" " Tensor! out, Tensor exp_sums, Tensor max_logits," @@ -171,8 +170,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> " "()"); - // ops.def(torch::schema("rms_norm(Tensor out, Tensor input, Tensor weight, - // float epsilon) -> ()"), c10::AliasAnalysisKind::CONSERVATIVE); ops.impl("rms_norm", torch::kCUDA, &rms_norm); // In-place fused Add and RMS Normalization. 
@@ -351,20 +348,30 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { &get_max_shared_memory_per_block_device_attribute); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { #ifndef USE_ROCM +TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { // Custom all-reduce kernels - pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce"); - custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar"); - custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar"); - custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg"); - custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg"); - custom_ar.def("dispose", &dispose, "dispose"); - custom_ar.def("meta_size", &meta_size, "meta_size"); - custom_ar.def("register_buffer", ®ister_buffer, "register_buffer"); - custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, - "get_graph_buffer_ipc_meta"); - custom_ar.def("register_graph_buffers", ®ister_graph_buffers, - "register_graph_buffers"); + custom_ar.def("init_custom_ar", &init_custom_ar); // modify inputs? + custom_ar.def("should_custom_ar", &should_custom_ar); + custom_ar.def("all_reduce_reg", &all_reduce_reg); // has out + custom_ar.def("all_reduce_unreg", &all_reduce_unreg); // has out + custom_ar.def("dispose", &dispose); + custom_ar.def("meta_size", &meta_size); + custom_ar.def("register_buffer", ®ister_buffer); + custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta); + custom_ar.def("register_graph_buffers", ®ister_graph_buffers); + + custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); + custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar); + custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg); + custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg); + custom_ar.impl("dispose", torch::kCPU, &dispose); + custom_ar.impl("meta_size", torch::kCPU, &meta_size); + custom_ar.impl("register_buffer", torch::kCUDA, ®ister_buffer); + custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU, &get_graph_buffer_ipc_meta); + custom_ar.impl("register_graph_buffers", torch::kCPU, ®ister_graph_buffers); +} #endif + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { } diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 7881234a0f83..e8073096e0b5 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Type +from typing import Optional, Tuple, Type, List import torch @@ -360,4 +360,50 @@ def get_max_shared_memory_per_block_device_attribute(device: int) -> int: device) -#TODO: custom_ar +def init_custom_ar(meta: torch.Tensor, + rank_data: torch.Tensor, + handles: List[str], + offsets: List[int], + rank: int, + full_nvlink: bool) -> int: + return torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles, offsets, rank, full_nvlink) + + +def should_custom_ar(inp: torch.Tensor, + max_size: int, + world_size: int, + full_nvlink: bool) -> bool: + return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size, full_nvlink) + + +def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: + torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out) + + +def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor) -> None: + torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out) + + +def dispose(fa: int) -> None: + torch.ops._C_custom_ar.dispose(fa) + + +def meta_size() -> int: + return 
torch.ops._C_custom_ar.meta_size() + + +def register_buffer(fa: int, + t: torch.Tensor, + handles: List[str], + offsets: List[int]) -> None: + return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets) + + +def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]: + return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa) + + +def register_graph_buffers(fa: int, + handles: List[str], + offsets: List[List[int]]) -> None: + torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index a3902aecb379..63de489cb682 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -15,7 +15,11 @@ try: import pynvml - from vllm._C import custom_ar + from vllm import _custom_ops as ops + + # this should fail if not installed. TODO: replace with some kind of ops query + ops.meta_size() + custom_ar = True @contextmanager def _nvml(): @@ -25,9 +29,9 @@ def _nvml(): finally: pynvml.nvmlShutdown() -except ImportError: +except (ImportError, AttributeError): # For AMD GPUs - custom_ar = None + custom_ar = False pynvml = None @contextmanager @@ -97,7 +101,7 @@ def __init__(self, self._IS_CAPTURING = False self.disabled = True - if custom_ar is None: + if not custom_ar: # disable because of missing custom allreduce library # e.g. in a non-cuda environment return @@ -175,7 +179,7 @@ def __init__(self, # meta data composes of two parts: meta data for synchronization # (256 bytes) and a temporary buffer for storing intermediate # allreduce results. - self.meta = torch.zeros(custom_ar.meta_size() + max_size, + self.meta = torch.zeros(ops.meta_size() + max_size, dtype=torch.uint8, device=self.device) # This is a pre-registered IPC buffer. 
In eager mode, input tensors @@ -196,9 +200,9 @@ def __init__(self, self.world_size = world_size handles, offsets = self._get_ipc_meta(self.meta) self.full_nvlink = full_nvlink - self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data, - handles, offsets, rank, - self.full_nvlink) + self._ptr = ops.init_custom_ar(self.meta, self.rank_data, + handles, offsets, rank, + self.full_nvlink) self.register_buffer(self.buffer) @contextmanager @@ -252,31 +256,31 @@ def _gather_ipc_meta(self, shard_data): def register_buffer(self, inp: torch.Tensor): handles, offsets = self._get_ipc_meta(inp) - custom_ar.register_buffer(self._ptr, inp, handles, offsets) + ops.register_buffer(self._ptr, inp, handles, offsets) def register_graph_buffers(self): - handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr) + handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) logger.info("Registering %d cuda graph addresses", len(offset)) - custom_ar.register_graph_buffers(self._ptr, handles, offsets) + ops.register_graph_buffers(self._ptr, handles, offsets) def should_custom_ar(self, inp: torch.Tensor): - return custom_ar.should_custom_ar(inp, self.max_size, self.world_size, - self.full_nvlink) + return ops.should_custom_ar(inp, self.max_size, self.world_size, + self.full_nvlink) # all reduce, assuming inp tensor is IPC registered with register_buffer, # or, in the context of cuda graphs, register_graph_buffers def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None): if out is None: out = torch.empty_like(inp) - custom_ar.all_reduce_reg(self._ptr, inp, out) + ops.all_reduce_reg(self._ptr, inp, out) return out # all reduce, assuming inp tensor is NOT IPC registered def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): if out is None: out = torch.empty_like(inp) - custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out) + ops.all_reduce_unreg(self._ptr, inp, self.buffer, out) return out def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: @@ -304,7 +308,7 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: def close(self): if not self.disabled and self._ptr: - custom_ar.dispose(self._ptr) + ops.dispose(self._ptr) self._ptr = 0 def __del__(self): From a0988ac55b7eb4d4e6811dae60bba53b45e9e25d Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 31 May 2024 18:17:50 +0000 Subject: [PATCH 22/41] fix formatting --- csrc/custom_all_reduce.cu | 4 ++-- csrc/custom_all_reduce.cuh | 4 ++-- csrc/ops.h | 4 ++-- csrc/pybind.cpp | 13 +++++++------ .../device_communicators/custom_all_reduce.py | 3 ++- 5 files changed, 15 insertions(+), 13 deletions(-) diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 2420bd745127..cfec3b572c0d 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -134,8 +134,8 @@ void register_buffer(fptr_t _fa, torch::Tensor& t, fa->register_buffer(handles, offsets, t.data_ptr()); } -std::tuple, std::vector> get_graph_buffer_ipc_meta( - fptr_t _fa) { +std::tuple, std::vector> +get_graph_buffer_ipc_meta(fptr_t _fa) { auto fa = reinterpret_cast(_fa); return fa->get_graph_buffer_ipc_meta(); } diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 8bc89834bed4..1f4f0ff95514 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -328,8 +328,8 @@ class CustomAllreduce { CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr) != CUDA_SUCCESS) throw std::runtime_error("failed 
to get pointer attr"); - CUDACHECK(cudaIpcGetMemHandle( - (cudaIpcMemHandle_t*)handles[i].data(), base_ptr)); + CUDACHECK(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handles[i].data(), + base_ptr)); offsets[i] = ((char*)ptr) - ((char*)base_ptr); } return {handles, offsets}; diff --git a/csrc/ops.h b/csrc/ops.h index a12bda8562bf..d53f1d9a61d3 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -141,8 +141,8 @@ int64_t meta_size(); void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector& handles, const std::vector& offsets); -std::tuple, std::vector> get_graph_buffer_ipc_meta( - fptr_t _fa); +std::tuple, std::vector> +get_graph_buffer_ipc_meta(fptr_t _fa); void register_graph_buffers(fptr_t _fa, const std::vector& handles, const std::vector>& offsets); #endif diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 4865dd5fc658..03aa78862dec 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -353,8 +353,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { // Custom all-reduce kernels custom_ar.def("init_custom_ar", &init_custom_ar); // modify inputs? custom_ar.def("should_custom_ar", &should_custom_ar); - custom_ar.def("all_reduce_reg", &all_reduce_reg); // has out - custom_ar.def("all_reduce_unreg", &all_reduce_unreg); // has out + custom_ar.def("all_reduce_reg", &all_reduce_reg); // has out + custom_ar.def("all_reduce_unreg", &all_reduce_unreg); // has out custom_ar.def("dispose", &dispose); custom_ar.def("meta_size", &meta_size); custom_ar.def("register_buffer", ®ister_buffer); @@ -368,10 +368,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { custom_ar.impl("dispose", torch::kCPU, &dispose); custom_ar.impl("meta_size", torch::kCPU, &meta_size); custom_ar.impl("register_buffer", torch::kCUDA, ®ister_buffer); - custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU, &get_graph_buffer_ipc_meta); - custom_ar.impl("register_graph_buffers", torch::kCPU, ®ister_graph_buffers); + custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU, + &get_graph_buffer_ipc_meta); + custom_ar.impl("register_graph_buffers", torch::kCPU, + ®ister_graph_buffers); } #endif -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { -} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 63de489cb682..32a6bb509b8d 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -17,7 +17,8 @@ from vllm import _custom_ops as ops - # this should fail if not installed. TODO: replace with some kind of ops query + # this should fail if not installed. 
+ # TODO: replace with some kind of ops query ops.meta_size() custom_ar = True From a0a2a00e160b516f830c5d9d4d8371e121fb50a8 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 31 May 2024 18:56:30 +0000 Subject: [PATCH 23/41] move punica and moe ops into _custom_ops --- vllm/_custom_ops.py | 70 ++++++++++++++++++- .../device_communicators/custom_all_reduce.py | 10 +-- vllm/lora/punica.py | 30 ++++---- .../layers/fused_moe/fused_moe.py | 5 +- 4 files changed, 88 insertions(+), 27 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e8073096e0b5..a21161c0ff96 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -3,13 +3,30 @@ import torch try: - # ruff: noqa: F401 SIM105 + # ruff: noqa: SIM105 import vllm._C except ImportError as e: from vllm.logger import init_logger logger = init_logger(__name__) logger.warning("Failed to import from vllm._C with %r", e) +try: + # ruff: noqa: SIM105 + import vllm._C_moe +except ImportError: + pass + +try: + # ruff: noqa: SIM105, F401 + import vllm._C_punica +except ImportError: + pass + + +def is_custom_op_supported(op_name: str) -> bool: + op, overloads = torch._C._jit_get_operation(f'_C_{op_name}') + return op is not None + # activation ops def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: @@ -306,6 +323,20 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, num_tokens_post_pad) +def topk_softmax( + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + token_expert_indicies: torch.Tensor, + gating_output: float +) -> None: + torch.ops._moe_C.topk_softmax( + topk_weights, + topk_ids, + token_expert_indicies, + gating_output + ) + + def reshape_and_cache( key: torch.Tensor, value: torch.Tensor, @@ -360,6 +391,7 @@ def get_max_shared_memory_per_block_device_attribute(device: int) -> int: device) +# custom ar def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor, handles: List[str], @@ -407,3 +439,39 @@ def register_graph_buffers(fa: int, handles: List[str], offsets: List[List[int]]) -> None: torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) + + +# punica +def dispatch_bgmv( + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + indicies: torch.Tensor, + layer_idx: int, + scale: float, +) -> None: + torch.ops._punica_C.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) + + +def dispatch_bgmv_low_level( + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + indicies: torch.Tensor, + layer_idx: int, + scale: float, + h_in: int, + h_out: int, + y_offset: int, +) -> None: + torch.ops._punica_C.dispatch_bgmv_low_level( + y, + x, + w_t_all, + indicies, + layer_idx, + scale, + h_in, + h_out, + y_offset, + ) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 32a6bb509b8d..91aaf66c7f25 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -11,15 +11,15 @@ from vllm.distributed.parallel_state import ( get_local_rank, get_tensor_model_parallel_cpu_group) from vllm.logger import init_logger +from vllm import _custom_ops as ops try: import pynvml - from vllm import _custom_ops as ops + # Simulate ImportError if custom_ar ops are not supported. + if not ops.is_custom_op_supported("custom_ar::meta_size"): + raise ImportError("custom_ar", __file__) - # this should fail if not installed. 
- # TODO: replace with some kind of ops query - ops.meta_size() custom_ar = True @contextmanager @@ -30,7 +30,7 @@ def _nvml(): finally: pynvml.nvmlShutdown() -except (ImportError, AttributeError): +except ImportError: # For AMD GPUs custom_ar = False pynvml = None diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index b0483bc97b97..465dac87e0a9 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -3,25 +3,22 @@ from typing import Optional import torch +from vllm import _custom_ops as ops - -def _raise_import_error(e): +def _raise_import_error(): if torch.cuda.get_device_capability() < (8, 0): raise ImportError( - "punica LoRA kernels require compute capability >= 8.0") from e + "punica LoRA kernels require compute capability >= 8.0") else: raise ImportError( "punica LoRA kernels could not be imported. If you built vLLM " "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var " - "was set.") from e + "was set.") def _check_punica_support(): - try: - # ruff: noqa: F401 - import vllm._punica_C as punica_kernels - except ImportError as e: - _raise_import_error(e) + if not ops.is_custom_op_supported("punica::dispatch_bgmv"): + _raise_import_error() def bgmv( @@ -51,8 +48,7 @@ def bgmv( """ _check_punica_support() - torch.ops._punica_C.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, - scale) + ops.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, @@ -83,7 +79,7 @@ def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, """ _check_punica_support() - torch.ops._punica_C.dispatch_bgmv_low_level( + ops.dispatch_bgmv_low_level( y, x, w_t_all, @@ -136,10 +132,8 @@ def add_lora(y: torch.Tensor, buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - torch.ops._punica_C.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, - 1.0) - torch.ops._punica_C.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, - scale) + ops.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0) + ops.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, scale) def add_lora_slice(y: torch.Tensor, @@ -188,7 +182,7 @@ def add_lora_slice(y: torch.Tensor, buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - torch.ops._punica_C.dispatch_bgmv_low_level( + ops.dispatch_bgmv_low_level( buffer, x, wa_t_all, @@ -199,7 +193,7 @@ def add_lora_slice(y: torch.Tensor, buffer.size(1), 0, ) - torch.ops._punica_C.dispatch_bgmv_low_level( + ops.dispatch_bgmv_low_level( y, buffer, wb_t_all, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d1c9049ce5b0..41faa44226e7 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -350,8 +350,7 @@ def fused_topk( dtype=torch.float32) topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1) else: - # ruff: noqa: F401 - import vllm._moe_C as moe_kernels + assert ops.is_custom_op_supported("moe::topk_softmax") topk_weights = torch.empty(M, topk, @@ -365,7 +364,7 @@ def fused_topk( topk, dtype=torch.int32, device=hidden_states.device) - torch.ops._moe_C.topk_softmax( + ops.topk_softmax( topk_weights, topk_ids, token_expert_indicies, From d9567247a545f5aeea999c3968d7557b3bb3d9b4 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 31 May 2024 20:28:38 +0000 Subject: [PATCH 24/41] maybe fix all_reduce --- vllm/distributed/device_communicators/custom_all_reduce.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 91aaf66c7f25..3ea83d2df821 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -261,7 +261,7 @@ def register_buffer(self, inp: torch.Tensor): def register_graph_buffers(self): handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) - handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) + handles, offsets = self._gather_ipc_meta(handle, offset) logger.info("Registering %d cuda graph addresses", len(offset)) ops.register_graph_buffers(self._ptr, handles, offsets) From a24f0231e62d85d9dc637b5291e1a45adeb65de8 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 31 May 2024 20:52:32 +0000 Subject: [PATCH 25/41] fix more formatting --- vllm/_custom_ops.py | 2 +- vllm/lora/punica.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a21161c0ff96..7880e057b62d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Type, List +from typing import List, Optional, Tuple, Type import torch diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index 465dac87e0a9..436f5184c8e7 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -3,8 +3,10 @@ from typing import Optional import torch + from vllm import _custom_ops as ops + def _raise_import_error(): if torch.cuda.get_device_capability() < (8, 0): raise ImportError( From b1f61f40af380284aa1e883d5e208ae2618dff3c Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 1 Jun 2024 01:53:25 +0000 Subject: [PATCH 26/41] fix some stuff --- vllm/_custom_ops.py | 83 ++++++++----------- .../device_communicators/custom_all_reduce.py | 8 +- vllm/lora/punica.py | 2 +- .../layers/fused_moe/fused_moe.py | 2 +- 4 files changed, 42 insertions(+), 53 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 7880e057b62d..ce1f2e9af2b0 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -12,19 +12,19 @@ try: # ruff: noqa: SIM105 - import vllm._C_moe + import vllm._moe_C except ImportError: pass try: # ruff: noqa: SIM105, F401 - import vllm._C_punica + import vllm._punica_C except ImportError: pass def is_custom_op_supported(op_name: str) -> bool: - op, overloads = torch._C._jit_get_operation(f'_C_{op_name}') + op, overloads = torch._C._jit_get_operation(op_name) return op is not None @@ -323,18 +323,11 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, num_tokens_post_pad) -def topk_softmax( - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - token_expert_indicies: torch.Tensor, - gating_output: float -) -> None: - torch.ops._moe_C.topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output - ) +def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, + token_expert_indicies: torch.Tensor, + gating_output: float) -> None: + torch.ops._moe_C.topk_softmax(topk_weights, topk_ids, + token_expert_indicies, gating_output) def reshape_and_cache( @@ -392,27 +385,25 @@ def get_max_shared_memory_per_block_device_attribute(device: int) -> int: # custom ar -def init_custom_ar(meta: torch.Tensor, - rank_data: torch.Tensor, - handles: List[str], - offsets: List[int], - rank: int, +def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor, + handles: List[str], offsets: List[int], rank: int, full_nvlink: bool) -> int: - return 
torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles, offsets, rank, full_nvlink) + return torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles, + offsets, rank, full_nvlink) -def should_custom_ar(inp: torch.Tensor, - max_size: int, - world_size: int, +def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int, full_nvlink: bool) -> bool: - return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size, full_nvlink) + return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size, + full_nvlink) def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out) -def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor) -> None: +def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, + out: torch.Tensor) -> None: torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out) @@ -424,9 +415,7 @@ def meta_size() -> int: return torch.ops._C_custom_ar.meta_size() -def register_buffer(fa: int, - t: torch.Tensor, - handles: List[str], +def register_buffer(fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]) -> None: return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets) @@ -435,34 +424,34 @@ def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]: return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa) -def register_graph_buffers(fa: int, - handles: List[str], +def register_graph_buffers(fa: int, handles: List[str], offsets: List[List[int]]) -> None: torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) # punica def dispatch_bgmv( - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - indicies: torch.Tensor, - layer_idx: int, - scale: float, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + indicies: torch.Tensor, + layer_idx: int, + scale: float, ) -> None: - torch.ops._punica_C.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) + torch.ops._punica_C.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, + scale) def dispatch_bgmv_low_level( - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - indicies: torch.Tensor, - layer_idx: int, - scale: float, - h_in: int, - h_out: int, - y_offset: int, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + indicies: torch.Tensor, + layer_idx: int, + scale: float, + h_in: int, + h_out: int, + y_offset: int, ) -> None: torch.ops._punica_C.dispatch_bgmv_low_level( y, diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 3ea83d2df821..961a9011bd48 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -201,9 +201,8 @@ def __init__(self, self.world_size = world_size handles, offsets = self._get_ipc_meta(self.meta) self.full_nvlink = full_nvlink - self._ptr = ops.init_custom_ar(self.meta, self.rank_data, - handles, offsets, rank, - self.full_nvlink) + self._ptr = ops.init_custom_ar(self.meta, self.rank_data, handles, + offsets, rank, self.full_nvlink) self.register_buffer(self.buffer) @contextmanager @@ -261,7 +260,8 @@ def register_buffer(self, inp: torch.Tensor): def register_graph_buffers(self): handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) - handles, offsets = self._gather_ipc_meta(handle, offset) + #handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) + handles, offsets = 
self._gather_ipc_meta((handle, offset)) logger.info("Registering %d cuda graph addresses", len(offset)) ops.register_graph_buffers(self._ptr, handles, offsets) diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index 436f5184c8e7..c9d08f6346a3 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -19,7 +19,7 @@ def _raise_import_error(): def _check_punica_support(): - if not ops.is_custom_op_supported("punica::dispatch_bgmv"): + if not ops.is_custom_op_supported("_punica_C::dispatch_bgmv"): _raise_import_error() diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 41faa44226e7..6bed659defc5 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -350,7 +350,7 @@ def fused_topk( dtype=torch.float32) topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1) else: - assert ops.is_custom_op_supported("moe::topk_softmax") + assert ops.is_custom_op_supported("_moe_C::topk_softmax") topk_weights = torch.empty(M, topk, From 6d35d5c248a3b79475b2d4e2298513c1a31581b7 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 1 Jun 2024 20:03:31 +0000 Subject: [PATCH 27/41] use python stable api --- CMakeLists.txt | 3 + cmake/cpu_extension.cmake | 10 +- cmake/utils.cmake | 11 +- csrc/activation_kernels.cu | 2 +- csrc/attention/attention_kernels.cu | 2 +- csrc/cache.h | 2 +- csrc/cache_kernels.cu | 2 +- csrc/cpu/cpu_types.hpp | 2 +- csrc/cpu/pybind.cpp | 7 +- csrc/cuda_utils.h | 2 - csrc/custom_all_reduce.cu | 2 +- csrc/dispatch_utils.h | 2 +- csrc/layernorm_kernels.cu | 2 +- csrc/moe/moe_ops.cpp | 4 +- csrc/moe/moe_ops.h | 6 +- csrc/moe/topk_softmax_kernels.cu | 2 +- csrc/moe_align_block_size_kernels.cu | 2 +- csrc/ops.h | 58 +++++++- csrc/pos_encoding_kernels.cu | 2 +- csrc/punica/punica_ops.cu | 2 +- csrc/punica/punica_ops.h | 6 +- csrc/punica/punica_pybind.cpp | 4 +- csrc/pybind.cpp | 130 ++---------------- csrc/quantization/aqlm/gemm_kernels.cu | 33 ++++- csrc/quantization/awq/gemm_kernels.cu | 27 +++- .../compressed_tensors/int8_quant_kernels.cu | 2 +- .../cutlass_w8a8/scaled_mm_dq_c2x.cu | 2 +- .../cutlass_w8a8/scaled_mm_dq_c3x.cu | 2 +- .../cutlass_w8a8/scaled_mm_dq_entry.cu | 2 +- csrc/quantization/fp8/common.cu | 2 +- csrc/quantization/gptq/q_gemm.cu | 10 +- csrc/quantization/gptq_marlin/gptq_marlin.cu | 12 +- csrc/quantization/gptq_marlin/gptq_marlin.cuh | 2 +- .../marlin/dense/marlin_cuda_kernel.cu | 10 +- .../marlin/sparse/marlin_24_cuda_kernel.cu | 10 +- .../squeezellm/quant_cuda_kernel.cu | 1 - csrc/registration.h | 26 ++++ setup.py | 2 +- 38 files changed, 233 insertions(+), 175 deletions(-) create mode 100644 csrc/registration.h diff --git a/CMakeLists.txt b/CMakeLists.txt index a197063f3360..d49353ff2279 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -218,6 +218,7 @@ define_gpu_extension_target( COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + USE_SABI 3 WITH_SOABI) # @@ -235,6 +236,7 @@ define_gpu_extension_target( SOURCES ${VLLM_MOE_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 WITH_SOABI) # @@ -286,6 +288,7 @@ if (VLLM_PUNICA_GPU_ARCHES) SOURCES ${VLLM_PUNICA_EXT_SRC} COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS} ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES} + USE_SABI 3 WITH_SOABI) else() message(WARNING "Unable to create _punica_C target because none of the " diff --git a/cmake/cpu_extension.cmake 
b/cmake/cpu_extension.cmake index 0cf37769a696..4c6c6beb91b8 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -12,7 +12,7 @@ include_directories("${CMAKE_SOURCE_DIR}/csrc") # # Check the compile flags # -list(APPEND CXX_COMPILE_FLAGS +list(APPEND CXX_COMPILE_FLAGS "-fopenmp" "-DVLLM_CPU_EXTENSION") @@ -44,8 +44,8 @@ if (AVX512_FOUND) find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND) if (AVX512BF16_FOUND OR ENABLE_AVX512BF16) - if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND - CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) + if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND + CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") else() message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3") @@ -81,10 +81,10 @@ define_gpu_extension_target( LANGUAGE CXX SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${CXX_COMPILE_FLAGS} - WITH_SOABI + USE_SABI 3 + WITH_SOABI ) add_custom_target(default) message(STATUS "Enabling C extension.") add_dependencies(default _C) - diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 00c81e4d00ad..f3c1286dd849 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -5,7 +5,7 @@ macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS) file(REAL_PATH ${EXECUTABLE} EXECUTABLE) set(Python_EXECUTABLE ${EXECUTABLE}) - find_package(Python COMPONENTS Interpreter Development.Module) + find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule) if (NOT Python_FOUND) message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.") endif() @@ -294,6 +294,7 @@ endmacro() # INCLUDE_DIRECTORIES - Extra include directories. # LIBRARIES - Extra link libraries. # WITH_SOABI - Generate library with python SOABI suffix name. +# USE_SABI - Use python stable api # # Note: optimization level/debug info is set via cmake build type. # @@ -301,7 +302,7 @@ function (define_gpu_extension_target GPU_MOD_NAME) cmake_parse_arguments(PARSE_ARGV 1 GPU "WITH_SOABI" - "DESTINATION;LANGUAGE" + "DESTINATION;LANGUAGE;USE_SABI" "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES") # Add hipify preprocessing step when building with HIP/ROCm. @@ -315,7 +316,11 @@ function (define_gpu_extension_target GPU_MOD_NAME) set(GPU_WITH_SOABI) endif() - Python_add_library(${GPU_MOD_NAME} MODULE "${GPU_SOURCES}" ${GPU_WITH_SOABI}) + if (GPU_USE_SABI) + Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}") + else() + Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}") + endif() if (GPU_LANGUAGE STREQUAL "HIP") # Make this target dependent on the hipify preprocessor step. diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 867f63f12de4..86ac2e75e78e 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -1,5 +1,5 @@ #include -#include +#include #include #include diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index fc2a53432abf..91083481705c 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -17,7 +17,7 @@ * limitations under the License. 
*/ -#include +#include #include #include #include diff --git a/csrc/cache.h b/csrc/cache.h index bf8887a74602..ba2cbdceaaf8 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 80538ac18b98..e26bcb5d5cf7 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,4 +1,4 @@ -#include +#include #include #include diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp index c1d3ec058b99..034c406a532d 100644 --- a/csrc/cpu/cpu_types.hpp +++ b/csrc/cpu/cpu_types.hpp @@ -3,7 +3,7 @@ #define CPU_TYPES_HPP #include -#include +#include namespace vec_op { diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp index 6fe5df36c0db..f3b1936e73d6 100644 --- a/csrc/cpu/pybind.cpp +++ b/csrc/cpu/pybind.cpp @@ -1,6 +1,8 @@ #include "cache.h" #include "ops.h" -#include + +#include +#include TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -101,5 +103,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); } -// TODO: get rid of this? -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index 09f8a2f12dd0..73944f4c1489 100644 --- a/csrc/cuda_utils.h +++ b/csrc/cuda_utils.h @@ -1,7 +1,5 @@ #pragma once -#include - int64_t get_device_attribute(int64_t attribute, int64_t device_id); int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index cfec3b572c0d..fc76002cd4df 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include "custom_all_reduce.cuh" diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 3ecea03242f0..a634e1c3d488 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -4,7 +4,7 @@ */ #pragma once -#include +#include #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 73d3dfa9e81a..ca1c04bd880d 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,4 +1,4 @@ -#include +#include #include #include diff --git a/csrc/moe/moe_ops.cpp b/csrc/moe/moe_ops.cpp index cc1ee9e8ebb2..ed65ec63f913 100644 --- a/csrc/moe/moe_ops.cpp +++ b/csrc/moe/moe_ops.cpp @@ -1,3 +1,4 @@ +#include "registration.h" #include "moe_ops.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { @@ -8,5 +9,4 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.impl("topk_softmax", torch::kCUDA, &topk_softmax); } -// TODO: get rid of this? 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 733b8560658c..a251730aa765 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -1,10 +1,6 @@ #pragma once -#include - -#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) -#define _CONCAT(A, B) A##B -#define CONCAT(A, B) _CONCAT(A, B) +#include void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& token_expert_indices, diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index 6ba4fcdb3a3f..de9747b60252 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -16,7 +16,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include +#include #include #include #include "../cuda_compat.h" diff --git a/csrc/moe_align_block_size_kernels.cu b/csrc/moe_align_block_size_kernels.cu index 63d7c73b1631..1f8d75da83bb 100644 --- a/csrc/moe_align_block_size_kernels.cu +++ b/csrc/moe_align_block_size_kernels.cu @@ -1,4 +1,4 @@ -#include +#include #include #include diff --git a/csrc/ops.h b/csrc/ops.h index d53f1d9a61d3..d6b43a6b1a38 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -1,8 +1,15 @@ #pragma once -#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) -#define _CONCAT(A, B) A##B -#define CONCAT(A, B) _CONCAT(A, B) +#include + +// Note on op signatures (TODO) +// The X_meta signatures are for the meta functions corresponding to op X. +// They must be kept in sync with the signature for X. Generally, only +// functions that return Tensors require a meta function. +// +// See the following links for detailed docs on op registration and function schemas. 
+// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9 +// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, @@ -58,23 +65,46 @@ torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, const torch::Tensor& codebook_partition_sizes, const std::optional& bias); +torch::Tensor aqlm_gemm_meta(const torch::Tensor& input, const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const torch::Tensor& codebook_partition_sizes, + const std::optional& bias); + torch::Tensor aqlm_dequant(const torch::Tensor& codes, const torch::Tensor& codebooks, const torch::Tensor& codebook_partition_sizes); +torch::Tensor aqlm_dequant_meta(const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& codebook_partition_sizes); + torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, int64_t split_k_iters); +torch::Tensor awq_gemm_meta(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, torch::Tensor _zeros, + int64_t split_k_iters); + torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, int64_t split_k_iters, int64_t thx, int64_t thy); +torch::Tensor awq_dequantize_meta(torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, int64_t split_k_iters, + int64_t thx, int64_t thy); + torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k); +torch::Tensor marlin_gemm_meta(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t size_m, int64_t size_n, int64_t size_k); + torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, torch::Tensor& b_scales, @@ -82,16 +112,33 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, int64_t size_m, int64_t size_n, int64_t size_k); +torch::Tensor gptq_marlin_24_gemm_meta(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_meta, + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, + int64_t size_k); + torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& g_idx, torch::Tensor& perm, torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full); +torch::Tensor gptq_marlin_gemm_meta(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& g_idx, + torch::Tensor& perm, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full); + torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits); +torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight, + torch::Tensor& perm, int64_t size_k, + int64_t size_n, int64_t num_bits); + void cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales); @@ -112,6 +159,11 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, 
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, bool use_exllama, int64_t bit); +torch::Tensor gptq_gemm_meta(torch::Tensor a, torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, + bool use_exllama, int64_t bit); + void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index caca03284735..97184a873559 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -1,4 +1,4 @@ -#include +#include #include #include diff --git a/csrc/punica/punica_ops.cu b/csrc/punica/punica_ops.cu index e345d8a24d45..dd29820144b3 100644 --- a/csrc/punica/punica_ops.cu +++ b/csrc/punica/punica_ops.cu @@ -1,4 +1,4 @@ -#include +#include #include #include diff --git a/csrc/punica/punica_ops.h b/csrc/punica/punica_ops.h index 110c94d402df..1242f6228633 100644 --- a/csrc/punica/punica_ops.h +++ b/csrc/punica/punica_ops.h @@ -1,10 +1,6 @@ #pragma once -#include - -#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) -#define _CONCAT(A, B) A##B -#define CONCAT(A, B) _CONCAT(A, B) +#include void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, torch::Tensor indicies, int64_t layer_idx, double scale); diff --git a/csrc/punica/punica_pybind.cpp b/csrc/punica/punica_pybind.cpp index cd464c5f663c..894e229b6d9d 100644 --- a/csrc/punica/punica_pybind.cpp +++ b/csrc/punica/punica_pybind.cpp @@ -1,3 +1,4 @@ +#include "registration.h" #include "punica_ops.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { @@ -14,5 +15,4 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level); } -// TODO: get rid of this -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 03aa78862dec..c8bae203635c 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -1,116 +1,9 @@ #include "cache.h" #include "cuda_utils.h" #include "ops.h" -// #include "quantization/gptq_marlin/gptq_marlin.cuh" //?? -#include - -torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight, - torch::Tensor& perm, int64_t size_k, - int64_t size_n, int64_t num_bits); - -// See -// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9 -// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations - -// where should these live? near the implementations of the kernels? 
-namespace vllm::meta { - -torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const torch::Tensor& codebook_partition_sizes, - const std::optional& bias) { - auto input_sizes = input.sizes(); - - auto out_features = codes.size(0) * codebooks.size(2); - auto flat_input = input.reshape({-1, input.size(-1)}); - auto flat_output = torch::empty( - {flat_input.size(0), out_features}, - torch::TensorOptions().dtype(input.dtype()).device(input.device())); - - auto output_sizes = input_sizes.vec(); - output_sizes.pop_back(); - output_sizes.push_back(-1); - return flat_output.reshape(output_sizes); -} - -torch::Tensor aqlm_dequant(const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& codebook_partition_sizes) { - auto in_features = codes.size(1) * 8; - auto out_features = codes.size(0); - return torch::empty({out_features, in_features}, - torch::TensorOptions() - .dtype(codebooks.dtype()) - .device(codebooks.device())); -} - -torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, - torch::Tensor _scaling_factors, torch::Tensor _zeros, - int64_t split_k_iters) { - int num_in_feats = _in_feats.size(0); - auto options = torch::TensorOptions() - .dtype(_in_feats.dtype()) - .device(_in_feats.device()); -#if 0 - at::Tensor _out_feats = - torch::empty({num_in_feats, _kernel.size(1) * 8}, options); - return _out_feats.sum(0); -#else - return torch::empty({_kernel.size(1) * 8}, options); -#endif -} - -torch::Tensor awq_dequantize(torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, int64_t split_k_iters, - int64_t thx, int64_t thy) { - int in_c = _kernel.size(0); - int qout_c = _kernel.size(1); - int out_c = qout_c * 8; - - auto options = torch::TensorOptions() - .dtype(_scaling_factors.dtype()) - .device(_scaling_factors.device()); - - return torch::empty({in_c, out_c}, options); -} - -torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t size_m, int64_t size_n, int64_t size_k) { - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - return torch::empty({size_m, size_n}, options); -} - -torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_meta, - torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, - int64_t size_k) { - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - return torch::empty({size_m, size_n}, options); -} - -torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& g_idx, - torch::Tensor& perm, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full) { - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - return torch::empty({size_m, size_n}, options); -} - -torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama, int64_t bit) { - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - return torch::empty({a.size(0), b_q_weight.size(1)}, options); -} +#include "registration.h" -} // namespace vllm::meta +#include TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -201,38 +94,37 @@ 
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Quantized GEMM for AQLM. ops.def("aqlm_gemm", &aqlm_gemm); ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm); - ops.impl("aqlm_gemm", torch::kMeta, &vllm::meta::aqlm_gemm); + ops.impl("aqlm_gemm", torch::kMeta, &aqlm_gemm_meta); // Decompression method for AQLM. ops.def("aqlm_dequant", &aqlm_dequant); ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant); - ops.impl("aqlm_dequant", torch::kMeta, &vllm::meta::aqlm_dequant); + ops.impl("aqlm_dequant", torch::kMeta, &aqlm_dequant_meta); // Quantized GEMM for AWQ. ops.def("awq_gemm", &awq_gemm); ops.impl("awq_gemm", torch::kCUDA, &awq_gemm); - ops.impl("awq_gemm", torch::kMeta, &vllm::meta::awq_gemm); + ops.impl("awq_gemm", torch::kMeta, &awq_gemm_meta); // Dequantization for AWQ. ops.def("awq_dequantize", &awq_dequantize); ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize); - ops.impl("awq_dequantize", torch::kMeta, &vllm::meta::awq_dequantize); + ops.impl("awq_dequantize", torch::kMeta, &awq_dequantize_meta); // Marlin (Dense) Optimized Quantized GEMM for GPTQ. ops.def("marlin_gemm", &marlin_gemm); ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm); - ops.impl("marlin_gemm", torch::kMeta, &vllm::meta::marlin_gemm); + ops.impl("marlin_gemm", torch::kMeta, &marlin_gemm_meta); // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ. ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm); ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm); - ops.impl("gptq_marlin_24_gemm", torch::kMeta, - &vllm::meta::gptq_marlin_24_gemm); + ops.impl("gptq_marlin_24_gemm", torch::kMeta, &gptq_marlin_24_gemm_meta); // gptq_marlin Optimized Quantized GEMM for GPTQ. ops.def("gptq_marlin_gemm", &gptq_marlin_gemm); ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm); - ops.impl("gptq_marlin_gemm", torch::kMeta, &vllm::meta::gptq_marlin_gemm); + ops.impl("gptq_marlin_gemm", torch::kMeta, &gptq_marlin_gemm_meta); // gptq_marlin repack from GPTQ. ops.def("gptq_marlin_repack", &gptq_marlin_repack); @@ -251,7 +143,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Quantized GEMM for GPTQ. ops.def("gptq_gemm", &gptq_gemm); ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm); - ops.impl("gptq_gemm", torch::kMeta, &vllm::meta::gptq_gemm); + ops.impl("gptq_gemm", torch::kMeta, &gptq_gemm_meta); // Post processing for GPTQ. ops.def("gptq_shuffle(Tensor! 
q_weight, Tensor q_perm, int bit) -> ()"); @@ -375,4 +267,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { } #endif -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu index 255844eec56d..2ed6325c56ce 100644 --- a/csrc/quantization/aqlm/gemm_kernels.cu +++ b/csrc/quantization/aqlm/gemm_kernels.cu @@ -18,7 +18,7 @@ #include #include #include -#include +#include #include #include @@ -543,6 +543,26 @@ torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, return {}; } +torch::Tensor aqlm_gemm_meta(const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const torch::Tensor& codebook_partition_sizes, + const std::optional& bias) { + auto input_sizes = input.sizes(); + + auto out_features = codes.size(0) * codebooks.size(2); + auto flat_input = input.reshape({-1, input.size(-1)}); + auto flat_output = torch::empty( + {flat_input.size(0), out_features}, + torch::TensorOptions().dtype(input.dtype()).device(input.device())); + + auto output_sizes = input_sizes.vec(); + output_sizes.pop_back(); + output_sizes.push_back(-1); + return flat_output.reshape(output_sizes); +} + torch::Tensor aqlm_dequant(const torch::Tensor& codes, const torch::Tensor& codebooks, const torch::Tensor& codebook_partition_sizes) { @@ -596,3 +616,14 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes, " entries is not currently supported.") return {}; } + +torch::Tensor aqlm_dequant_meta(const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& codebook_partition_sizes) { + auto in_features = codes.size(1) * 8; + auto out_features = codes.size(0); + return torch::empty({out_features, in_features}, + torch::TensorOptions() + .dtype(codebooks.dtype()) + .device(codebooks.device())); +} diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 079694dcd5be..dd71d286c740 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -7,7 +7,7 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} } */ -#include +#include #include #include "dequantize.cuh" @@ -483,6 +483,21 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, return _de_kernel; } +torch::Tensor awq_dequantize_meta(torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, int64_t split_k_iters, + int64_t thx, int64_t thy) { + int in_c = _kernel.size(0); + int qout_c = _kernel.size(1); + int out_c = qout_c * 8; + + auto options = torch::TensorOptions() + .dtype(_scaling_factors.dtype()) + .device(_scaling_factors.device()); + + return torch::empty({in_c, out_c}, options); +} + // in_feats: M, IC [float16] // kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b] // scaling_factors: IC // G, OC [float16] @@ -547,3 +562,13 @@ torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, } return _out_feats.sum(0); } + +torch::Tensor awq_gemm_meta(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, int64_t split_k_iters) { + int num_in_feats = _in_feats.size(0); + auto options = torch::TensorOptions() + .dtype(_in_feats.dtype()) + .device(_in_feats.device()); + return torch::empty({_kernel.size(1) * 8}, options); +} diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu 
b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 280b0327111d..aa9511daa277 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -1,5 +1,5 @@ #include -#include +#include #include #include "../../dispatch_utils.h" diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu index 088fee4783fa..23a8b4070b70 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu @@ -1,5 +1,5 @@ #include -#include +#include #include diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu index 8fc4ba662ecd..a99802153643 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu @@ -4,7 +4,7 @@ #if defined CUDA_VERSION && CUDA_VERSION >= 12000 -#include +#include #include diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu index eb532f2ac7a9..423e64a4932e 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu @@ -1,7 +1,7 @@ #include #include -#include +#include void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 55be3305a9b8..8c5b693bf6ed 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -1,5 +1,5 @@ #include -#include +#include #include #include diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 91813f306713..52dc4ab0837b 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -6,7 +6,7 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa #include #include -#include +#include #include #include #include @@ -1854,3 +1854,11 @@ void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) { : (int*)q_perm.data_ptr(), q_weight.size(0) * 32 / bit, q_weight.size(1), bit); } + +torch::Tensor gptq_gemm_meta(torch::Tensor a, torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, + bool use_exllama, int64_t bit) { + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + return torch::empty({a.size(0), b_q_weight.size(1)}, options); +} diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index c573b9041065..12192fcaef34 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -1867,4 +1867,14 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, return c; } -#endif \ No newline at end of file +#endif + +torch::Tensor gptq_marlin_gemm_meta(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, + torch::Tensor& g_idx, torch::Tensor& perm, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full) { + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + return torch::empty({size_m, size_n}, options); +} diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cuh b/csrc/quantization/gptq_marlin/gptq_marlin.cuh index ba5368ea8835..42af44951efd 100644 --- 
a/csrc/quantization/gptq_marlin/gptq_marlin.cuh +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cuh @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu index 03d66cecedf1..7a45d1327a20 100644 --- a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu @@ -15,7 +15,7 @@ * limitations under the License. */ -#include +#include #include #include @@ -1134,3 +1134,11 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, return c; } + +torch::Tensor marlin_gemm_meta(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t size_m, + int64_t size_n, int64_t size_k) { + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + return torch::empty({size_m, size_n}, options); +} diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu index 686dd7851e6a..b32c54bd5986 100644 --- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -16,7 +16,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include +#include #include #include @@ -1123,3 +1123,11 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, return c; } + +torch::Tensor gptq_marlin_24_gemm_meta( + torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, + torch::Tensor& b_scales, torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, int64_t size_k) { + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + return torch::empty({size_m, size_n}, options); +} diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu index 1b339fa4b392..40baac610869 100644 --- a/csrc/quantization/squeezellm/quant_cuda_kernel.cu +++ b/csrc/quantization/squeezellm/quant_cuda_kernel.cu @@ -1,5 +1,4 @@ #include -#include #include #include #include diff --git a/csrc/registration.h b/csrc/registration.h new file mode 100644 index 000000000000..be93cef37e8b --- /dev/null +++ b/csrc/registration.h @@ -0,0 +1,26 @@ +#pragma once + +#include + +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) + +#define _STRINGIFY(A) #A +#define STRINGIFY(A) _STRINGIFY(A) + +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) + +// REGISTER_EXTENSION allows the shared library to be loaded and initialized +// via python's import statement. 
+#define REGISTER_EXTENSION(NAME) \ +PyMODINIT_FUNC CONCAT(PyInit_,NAME)() \ +{ \ + static struct PyModuleDef module = { \ + PyModuleDef_HEAD_INIT, \ + STRINGIFY(NAME), \ + nullptr, \ + 0, \ + nullptr \ + }; \ + return PyModule_Create(&module); \ +} diff --git a/setup.py b/setup.py index f7d465b60c15..339b0ad6de2d 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,7 @@ def remove_prefix(text, prefix): class CMakeExtension(Extension): def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None: - super().__init__(name, sources=[], **kwa) + super().__init__(name, sources=[], py_limited_api=True, **kwa) self.cmake_lists_dir = os.path.abspath(cmake_lists_dir) From 1fcde34cbc03d7e0f1bc897f2053a7d44e5096d1 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 1 Jun 2024 20:07:30 +0000 Subject: [PATCH 28/41] fix cpu --- csrc/cpu/pybind.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp index f3b1936e73d6..3d63a96b1405 100644 --- a/csrc/cpu/pybind.cpp +++ b/csrc/cpu/pybind.cpp @@ -1,8 +1,8 @@ #include "cache.h" #include "ops.h" +#include "registration.h" #include -#include TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops From 88dc7244480051e7e1962a62de1d82f18aef14f4 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 1 Jun 2024 20:18:05 +0000 Subject: [PATCH 29/41] add comment --- vllm/_custom_ops.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ce1f2e9af2b0..59193316b4d6 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2,6 +2,8 @@ import torch +# TODO: try torch.ops.load_library here + try: # ruff: noqa: SIM105 import vllm._C From 9d42c29fe6ebe516120127aed8229d9f921d216f Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 1 Jun 2024 20:37:07 +0000 Subject: [PATCH 30/41] fix punica --- csrc/punica/punica_ops.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/punica/punica_ops.h b/csrc/punica/punica_ops.h index 1242f6228633..5d625d0564f7 100644 --- a/csrc/punica/punica_ops.h +++ b/csrc/punica/punica_ops.h @@ -1,6 +1,6 @@ #pragma once -#include +#include void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, torch::Tensor indicies, int64_t layer_idx, double scale); From f178aed7b29b8ef88361265db06af7a538a15691 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 1 Jun 2024 20:43:01 +0000 Subject: [PATCH 31/41] try to fix ROCM dockerfile --- Dockerfile.rocm | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index e30a2aaf3020..954958df88fc 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -106,9 +106,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \ pip install -U -r requirements-rocm.txt \ && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \ && python3 setup.py install \ - && cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \ - && cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.cpython-39-x86_64-linux-gnu.so vllm/ \ - && cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.cpython-39-x86_64-linux-gnu.so vllm/ \ + && cp build/lib.linux-x86_64-cpython-39/vllm/_C.abi3.so vllm/ \ + && cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.abi3.so vllm/ \ + && cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.abi3.so vllm/ \ && cd .. 
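Note on the registration and stable-ABI changes above: after this point in the series the extension modules (_C, _moe_C, _punica_C, plus the _C_custom_ar/_C_cache_ops/_C_cuda_utils sub-libraries) are built with USE_SABI 3 / py_limited_api and export no pybind11 bindings; importing a module only triggers its TORCH_LIBRARY registration, and every kernel is then reached through torch.ops.<namespace>.<op>. The sketch below illustrates that calling pattern. It is illustrative only and not part of any patch: it assumes a CUDA build with _moe_C present, the availability probe mirrors is_custom_op_supported() from vllm/_custom_ops.py in this series (with an extra RuntimeError guard), and the tensor shapes/dtypes follow the topk_softmax call in fused_moe.py.

import contextlib

import torch

# Importing the extension only runs its TORCH_LIBRARY registration as a
# side effect; the abi3 module itself exposes no pybind11 functions.
with contextlib.suppress(ImportError):
    import vllm._moe_C  # noqa: F401


def op_is_registered(qualified_name: str) -> bool:
    # Same idea as is_custom_op_supported() in this series; the RuntimeError
    # guard covers PyTorch builds that raise instead of returning None for
    # unknown operator names.
    try:
        op, _ = torch._C._jit_get_operation(qualified_name)
    except RuntimeError:
        return False
    return op is not None


if op_is_registered("_moe_C::topk_softmax") and torch.cuda.is_available():
    n_token, n_expert, topk = 4, 8, 2
    gating = torch.randn(n_token, n_expert, dtype=torch.float32, device="cuda")
    weights = torch.empty(n_token, topk, dtype=torch.float32, device="cuda")
    ids = torch.empty(n_token, topk, dtype=torch.int32, device="cuda")
    expert_idx = torch.empty(n_token, topk, dtype=torch.int32, device="cuda")
    # Registered ops are reached through torch.ops.<namespace>.<name>.
    torch.ops._moe_C.topk_softmax(weights, ids, expert_idx, gating)

Compared with the old pattern of catching ImportError around "from vllm._C import custom_ar", probing the dispatcher keeps the check meaningful even when the shared library imports fine but a particular op group (for example the custom all-reduce ops, which are excluded under USE_ROCM) was not compiled in.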
From 13e36f0e77bf0cd9aac42b61000f835a795a78d4 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 3 Jun 2024 18:04:10 +0000 Subject: [PATCH 32/41] cleanups --- csrc/ops.h | 9 ----- csrc/pybind.cpp | 39 ++++++++++++++----- csrc/registration.h | 2 + vllm/_custom_ops.py | 15 ++----- .../device_communicators/custom_all_reduce.py | 3 +- vllm/lora/punica.py | 10 ++--- vllm/model_executor/models/llama.py | 2 - 7 files changed, 40 insertions(+), 40 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index d6b43a6b1a38..f28ae5be17b5 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -2,15 +2,6 @@ #include -// Note on op signatures (TODO) -// The X_meta signatures are for the meta functions corresponding to op X. -// They must be kept in sync with the signature for X. Generally, only -// functions that return Tensors require a meta function. -// -// See the following links for detailed docs on op registration and function schemas. -// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9 -// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations - void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index c8bae203635c..8955a49a97bb 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -5,6 +5,16 @@ #include +// Note on op signatures: +// The X_meta signatures are for the meta functions corresponding to op X. +// They must be kept in sync with the signature for X. Generally, only +// functions that return Tensors require a meta function. +// +// See the following links for detailed docs on op registration and function +// schemas. +// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9 +// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations + TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -243,25 +253,34 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { #ifndef USE_ROCM TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { // Custom all-reduce kernels - custom_ar.def("init_custom_ar", &init_custom_ar); // modify inputs? - custom_ar.def("should_custom_ar", &should_custom_ar); - custom_ar.def("all_reduce_reg", &all_reduce_reg); // has out - custom_ar.def("all_reduce_unreg", &all_reduce_unreg); // has out - custom_ar.def("dispose", &dispose); - custom_ar.def("meta_size", &meta_size); - custom_ar.def("register_buffer", ®ister_buffer); - custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta); - custom_ar.def("register_graph_buffers", ®ister_graph_buffers); - + custom_ar.def("init_custom_ar", &init_custom_ar); custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); + + custom_ar.def("should_custom_ar", &should_custom_ar); custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar); + + custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()"); custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg); + + custom_ar.def( + "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! 
out) -> " + "()"); custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg); + + custom_ar.def("dispose", &dispose); custom_ar.impl("dispose", torch::kCPU, &dispose); + + custom_ar.def("meta_size", &meta_size); custom_ar.impl("meta_size", torch::kCPU, &meta_size); + + custom_ar.def("register_buffer", ®ister_buffer); custom_ar.impl("register_buffer", torch::kCUDA, ®ister_buffer); + + custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta); custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU, &get_graph_buffer_ipc_meta); + + custom_ar.def("register_graph_buffers", ®ister_graph_buffers); custom_ar.impl("register_graph_buffers", torch::kCPU, ®ister_graph_buffers); } diff --git a/csrc/registration.h b/csrc/registration.h index be93cef37e8b..7002e24d74c3 100644 --- a/csrc/registration.h +++ b/csrc/registration.h @@ -8,6 +8,8 @@ #define _STRINGIFY(A) #A #define STRINGIFY(A) _STRINGIFY(A) +// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) // REGISTER_EXTENSION allows the shared library to be loaded and initialized diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 59193316b4d6..e8f387f9ba6c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,28 +1,21 @@ from typing import List, Optional, Tuple, Type +import contextlib import torch -# TODO: try torch.ops.load_library here - try: - # ruff: noqa: SIM105 import vllm._C except ImportError as e: from vllm.logger import init_logger logger = init_logger(__name__) logger.warning("Failed to import from vllm._C with %r", e) -try: - # ruff: noqa: SIM105 +with contextlib.suppress(ImportError): import vllm._moe_C -except ImportError: - pass -try: - # ruff: noqa: SIM105, F401 +with contextlib.suppress(ImportError): + # ruff: noqa: F401 import vllm._punica_C -except ImportError: - pass def is_custom_op_supported(op_name: str) -> bool: diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 961a9011bd48..5d0b1db42838 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -17,7 +17,7 @@ import pynvml # Simulate ImportError if custom_ar ops are not supported. 
- if not ops.is_custom_op_supported("custom_ar::meta_size"): + if not ops.is_custom_op_supported("_C_custom_ar::meta_size"): raise ImportError("custom_ar", __file__) custom_ar = True @@ -260,7 +260,6 @@ def register_buffer(self, inp: torch.Tensor): def register_graph_buffers(self): handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) - #handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) handles, offsets = self._gather_ipc_meta((handle, offset)) logger.info("Registering %d cuda graph addresses", len(offset)) ops.register_graph_buffers(self._ptr, handles, offsets) diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index c9d08f6346a3..7ecaa450f175 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -7,7 +7,10 @@ from vllm import _custom_ops as ops -def _raise_import_error(): +def _check_punica_support(): + if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"): + return + if torch.cuda.get_device_capability() < (8, 0): raise ImportError( "punica LoRA kernels require compute capability >= 8.0") @@ -18,11 +21,6 @@ def _raise_import_error(): "was set.") -def _check_punica_support(): - if not ops.is_custom_op_supported("_punica_C::dispatch_bgmv"): - _raise_import_error() - - def bgmv( y: torch.Tensor, x: torch.Tensor, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 4fd3a9bbd0be..d83ee9a201c0 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -270,7 +270,6 @@ def __init__( def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) - #@torch.compile def forward( self, input_ids: Optional[torch.Tensor], @@ -362,7 +361,6 @@ def __init__( config.vocab_size, logit_scale) self.sampler = Sampler() - #@torch.compile def forward( self, input_ids: torch.Tensor, From 415dc4275c007a16d416cfbae28638b51feea447 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 3 Jun 2024 20:42:57 +0000 Subject: [PATCH 33/41] rebase + use Tensor instead of std::vector in custom ar api --- csrc/custom_all_reduce.cu | 10 +++-- csrc/custom_all_reduce.cuh | 11 +++-- csrc/ops.h | 2 +- csrc/pybind.cpp | 2 +- .../device_communicators/custom_all_reduce.py | 2 +- .../layers/fused_moe/fused_moe.py | 43 ++++++++----------- 6 files changed, 32 insertions(+), 38 deletions(-) diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index fc76002cd4df..9f063a71ebb1 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -134,10 +134,14 @@ void register_buffer(fptr_t _fa, torch::Tensor& t, fa->register_buffer(handles, offsets, t.data_ptr()); } -std::tuple, std::vector> -get_graph_buffer_ipc_meta(fptr_t _fa) { +std::tuple> get_graph_buffer_ipc_meta( + fptr_t _fa) { auto fa = reinterpret_cast(_fa); - return fa->get_graph_buffer_ipc_meta(); + auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta(); + auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto handles = torch::empty({static_cast(handle_bytes.size())}, options); + std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size()); + return {handles, std::move(offsets)}; } void register_graph_buffers(fptr_t _fa, const std::vector& handles, diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 1f4f0ff95514..1ed49b8aa9ca 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -312,12 +312,11 @@ class CustomAllreduce { return it->second; } - std::pair, std::vector> + std::pair, std::vector> 
get_graph_buffer_ipc_meta() { auto num_buffers = graph_unreg_buffers_.size(); auto handle_sz = sizeof(cudaIpcMemHandle_t); - std::string empty_handle_str(handle_sz, 0); - std::vector handles(num_buffers, empty_handle_str); + std::vector handles(handle_sz * num_buffers, 0); std::vector offsets(num_buffers); for (int i = 0; i < num_buffers; i++) { auto ptr = graph_unreg_buffers_[i]; @@ -328,11 +327,11 @@ class CustomAllreduce { CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr) != CUDA_SUCCESS) throw std::runtime_error("failed to get pointer attr"); - CUDACHECK(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handles[i].data(), - base_ptr)); + CUDACHECK(cudaIpcGetMemHandle( + (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); offsets[i] = ((char*)ptr) - ((char*)base_ptr); } - return {handles, offsets}; + return std::make_pair(handles, offsets); } void check_rank_data_capacity(size_t num = 1) { diff --git a/csrc/ops.h b/csrc/ops.h index f28ae5be17b5..61efab180c67 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -184,7 +184,7 @@ int64_t meta_size(); void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector& handles, const std::vector& offsets); -std::tuple, std::vector> +std::tuple> get_graph_buffer_ipc_meta(fptr_t _fa); void register_graph_buffers(fptr_t _fa, const std::vector& handles, const std::vector>& offsets); diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 8955a49a97bb..52a99d0675c7 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -187,7 +187,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute int8 quantized tensor for given scaling factor. ops.def( - "static_scaled_int8_quant(Tensor! out, Tensor input, float scale) -> ()"); + "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> ()"); ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant); // Compute int8 quantized tensor and scaling factor diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 5d0b1db42838..5a6496373b78 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -260,7 +260,7 @@ def register_buffer(self, inp: torch.Tensor): def register_graph_buffers(self): handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) - handles, offsets = self._gather_ipc_meta((handle, offset)) + handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) logger.info("Registering %d cuda graph addresses", len(offset)) ops.register_graph_buffers(self._ptr, handles, offsets) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 6bed659defc5..4d0160ff296a 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -8,7 +8,6 @@ import triton import triton.language as tl -import vllm._moe_C as moe_kernels from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -343,34 +342,26 @@ def fused_topk( M, _ = hidden_states.shape - if is_hip(): - # The MoE kernels are not yet supported on ROCm. 
- routing_weights = torch.softmax(gating_output, - dim=-1, - dtype=torch.float32) - topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1) - else: - assert ops.is_custom_op_supported("_moe_C::topk_softmax") - - topk_weights = torch.empty(M, - topk, - dtype=torch.float32, - device=hidden_states.device) - topk_ids = torch.empty(M, + topk_weights = torch.empty(M, topk, dtype=torch.float32, device=hidden_states.device) - token_expert_indicies = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) - ops.topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output.float(), # TODO(woosuk): Optimize this. - ) - del token_expert_indicies # Not used. Will be used in the future. + topk_ids = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + token_expert_indicies = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + ops.topk_softmax( + topk_weights, + topk_ids, + token_expert_indicies, + gating_output.float(), # TODO(woosuk): Optimize this. + ) + del token_expert_indicies # Not used. Will be used in the future. + if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids From 0f35b0815eb5be2e94816da5180ef25e45a04339 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 3 Jun 2024 20:45:12 +0000 Subject: [PATCH 34/41] fix formatting --- csrc/custom_all_reduce.cu | 6 ++-- csrc/ops.h | 33 ++++++++++--------- csrc/pybind.cpp | 3 +- csrc/registration.h | 18 ++++------ vllm/_custom_ops.py | 2 +- .../device_communicators/custom_all_reduce.py | 2 +- 6 files changed, 31 insertions(+), 33 deletions(-) diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 9f063a71ebb1..82a3563979f1 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -138,8 +138,10 @@ std::tuple> get_graph_buffer_ipc_meta( fptr_t _fa) { auto fa = reinterpret_cast(_fa); auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta(); - auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); - auto handles = torch::empty({static_cast(handle_bytes.size())}, options); + auto options = + torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto handles = + torch::empty({static_cast(handle_bytes.size())}, options); std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size()); return {handles, std::move(offsets)}; } diff --git a/csrc/ops.h b/csrc/ops.h index 61efab180c67..b71df3a9b6aa 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -56,7 +56,8 @@ torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, const torch::Tensor& codebook_partition_sizes, const std::optional& bias); -torch::Tensor aqlm_gemm_meta(const torch::Tensor& input, const torch::Tensor& codes, +torch::Tensor aqlm_gemm_meta(const torch::Tensor& input, + const torch::Tensor& codes, const torch::Tensor& codebooks, const torch::Tensor& scales, const torch::Tensor& codebook_partition_sizes, @@ -75,8 +76,8 @@ torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, int64_t split_k_iters); torch::Tensor awq_gemm_meta(torch::Tensor _in_feats, torch::Tensor _kernel, - torch::Tensor _scaling_factors, torch::Tensor _zeros, - int64_t split_k_iters); + torch::Tensor _scaling_factors, + torch::Tensor _zeros, int64_t split_k_iters); torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _scaling_factors, @@ -93,8 +94,9 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, int64_t size_m, 
int64_t size_n, int64_t size_k); torch::Tensor marlin_gemm_meta(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t size_m, int64_t size_n, int64_t size_k); + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t size_m, + int64_t size_n, int64_t size_k); torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, @@ -103,12 +105,10 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, int64_t size_m, int64_t size_n, int64_t size_k); -torch::Tensor gptq_marlin_24_gemm_meta(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_meta, - torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, - int64_t size_k); +torch::Tensor gptq_marlin_24_gemm_meta( + torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, + torch::Tensor& b_scales, torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, int64_t size_k); torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& g_idx, @@ -117,9 +117,10 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, int64_t size_k, bool is_k_full); torch::Tensor gptq_marlin_gemm_meta(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& g_idx, - torch::Tensor& perm, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, + torch::Tensor& b_scales, + torch::Tensor& g_idx, torch::Tensor& perm, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full); torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, @@ -184,8 +185,8 @@ int64_t meta_size(); void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector& handles, const std::vector& offsets); -std::tuple> -get_graph_buffer_ipc_meta(fptr_t _fa); +std::tuple> get_graph_buffer_ipc_meta( + fptr_t _fa); void register_graph_buffers(fptr_t _fa, const std::vector& handles, const std::vector>& offsets); #endif diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 52a99d0675c7..62dabf8e542d 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -187,7 +187,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute int8 quantized tensor for given scaling factor. ops.def( - "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> ()"); + "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> " + "()"); ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant); // Compute int8 quantized tensor and scaling factor diff --git a/csrc/registration.h b/csrc/registration.h index 7002e24d74c3..e5396e9a8b13 100644 --- a/csrc/registration.h +++ b/csrc/registration.h @@ -14,15 +14,9 @@ // REGISTER_EXTENSION allows the shared library to be loaded and initialized // via python's import statement. 
-#define REGISTER_EXTENSION(NAME) \ -PyMODINIT_FUNC CONCAT(PyInit_,NAME)() \ -{ \ - static struct PyModuleDef module = { \ - PyModuleDef_HEAD_INIT, \ - STRINGIFY(NAME), \ - nullptr, \ - 0, \ - nullptr \ - }; \ - return PyModule_Create(&module); \ -} +#define REGISTER_EXTENSION(NAME) \ + PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ + static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \ + STRINGIFY(NAME), nullptr, 0, nullptr}; \ + return PyModule_Create(&module); \ + } diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e8f387f9ba6c..440b0e8afa99 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,6 +1,6 @@ +import contextlib from typing import List, Optional, Tuple, Type -import contextlib import torch try: diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 5a6496373b78..4a0e19bc0c15 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -6,12 +6,12 @@ from torch.distributed import ProcessGroup import vllm.envs as envs +from vllm import _custom_ops as ops from vllm.distributed.device_communicators.custom_all_reduce_utils import ( gpu_p2p_access_check) from vllm.distributed.parallel_state import ( get_local_rank, get_tensor_model_parallel_cpu_group) from vllm.logger import init_logger -from vllm import _custom_ops as ops try: import pynvml From 3b722e3a84f1e68c8697b86bf08528fe7c690c6e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 3 Jun 2024 23:18:53 +0000 Subject: [PATCH 35/41] fix test_int8_quant.py test --- tests/kernels/test_int8_quant.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index d37f7d2e6ef4..b8a5a8100312 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -55,11 +55,11 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - out1 = (x / scale).round().clamp(int8_traits.min, - int8_traits.max).to(torch.int8) - out2 = torch.empty_like(x, dtype=torch.int8) + out1 = (x / scale).round().clamp( + torch.iinfo(torch.int8).min, + torch.iinfo(torch.int8).max).to(torch.int8) scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda") - ops.static_scaled_int8_quant(out2, x, scale_argument) + out2 = ops.static_scaled_int8_quant(x, scale_argument) assert torch.allclose(out1, out2, atol=1) # big atol to account for rounding errors From 3e5cec2220a4c18499379210d089c3dbcd08eb79 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 4 Jun 2024 19:30:06 +0000 Subject: [PATCH 36/41] remove meta ops for now --- csrc/ops.h | 46 ------------------- csrc/pybind.cpp | 9 ---- csrc/quantization/aqlm/gemm_kernels.cu | 31 ------------- csrc/quantization/awq/gemm_kernels.cu | 25 ---------- csrc/quantization/gptq/q_gemm.cu | 8 ---- csrc/quantization/gptq_marlin/gptq_marlin.cu | 10 ---- .../gptq_marlin/gptq_marlin_repack.cu | 12 ----- .../marlin/dense/marlin_cuda_kernel.cu | 8 ---- .../marlin/sparse/marlin_24_cuda_kernel.cu | 8 ---- 9 files changed, 157 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index b71df3a9b6aa..0c270a78c331 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -56,48 +56,23 @@ torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, const torch::Tensor& codebook_partition_sizes, const std::optional& bias); -torch::Tensor 
aqlm_gemm_meta(const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const torch::Tensor& codebook_partition_sizes, - const std::optional& bias); - torch::Tensor aqlm_dequant(const torch::Tensor& codes, const torch::Tensor& codebooks, const torch::Tensor& codebook_partition_sizes); -torch::Tensor aqlm_dequant_meta(const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& codebook_partition_sizes); - torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, int64_t split_k_iters); -torch::Tensor awq_gemm_meta(torch::Tensor _in_feats, torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, int64_t split_k_iters); - torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, int64_t split_k_iters, int64_t thx, int64_t thy); -torch::Tensor awq_dequantize_meta(torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, int64_t split_k_iters, - int64_t thx, int64_t thy); - torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k); -torch::Tensor marlin_gemm_meta(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t size_m, - int64_t size_n, int64_t size_k); - torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, torch::Tensor& b_scales, @@ -105,32 +80,16 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, int64_t size_m, int64_t size_n, int64_t size_k); -torch::Tensor gptq_marlin_24_gemm_meta( - torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, - torch::Tensor& b_scales, torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, int64_t size_k); - torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& g_idx, torch::Tensor& perm, torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full); -torch::Tensor gptq_marlin_gemm_meta(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, - torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full); - torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits); -torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight, - torch::Tensor& perm, int64_t size_k, - int64_t size_n, int64_t num_bits); - void cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales); @@ -151,11 +110,6 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, bool use_exllama, int64_t bit); -torch::Tensor gptq_gemm_meta(torch::Tensor a, torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama, int64_t bit); - void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 
62dabf8e542d..2d27ecfa591c 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -104,42 +104,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Quantized GEMM for AQLM. ops.def("aqlm_gemm", &aqlm_gemm); ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm); - ops.impl("aqlm_gemm", torch::kMeta, &aqlm_gemm_meta); // Decompression method for AQLM. ops.def("aqlm_dequant", &aqlm_dequant); ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant); - ops.impl("aqlm_dequant", torch::kMeta, &aqlm_dequant_meta); // Quantized GEMM for AWQ. ops.def("awq_gemm", &awq_gemm); ops.impl("awq_gemm", torch::kCUDA, &awq_gemm); - ops.impl("awq_gemm", torch::kMeta, &awq_gemm_meta); // Dequantization for AWQ. ops.def("awq_dequantize", &awq_dequantize); ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize); - ops.impl("awq_dequantize", torch::kMeta, &awq_dequantize_meta); // Marlin (Dense) Optimized Quantized GEMM for GPTQ. ops.def("marlin_gemm", &marlin_gemm); ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm); - ops.impl("marlin_gemm", torch::kMeta, &marlin_gemm_meta); // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ. ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm); ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm); - ops.impl("gptq_marlin_24_gemm", torch::kMeta, &gptq_marlin_24_gemm_meta); // gptq_marlin Optimized Quantized GEMM for GPTQ. ops.def("gptq_marlin_gemm", &gptq_marlin_gemm); ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm); - ops.impl("gptq_marlin_gemm", torch::kMeta, &gptq_marlin_gemm_meta); // gptq_marlin repack from GPTQ. ops.def("gptq_marlin_repack", &gptq_marlin_repack); ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack); - ops.impl("gptq_marlin_repack", torch::kMeta, &gptq_marlin_repack_meta); // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column // quantization. @@ -153,7 +145,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Quantized GEMM for GPTQ. ops.def("gptq_gemm", &gptq_gemm); ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm); - ops.impl("gptq_gemm", torch::kMeta, &gptq_gemm_meta); // Post processing for GPTQ. ops.def("gptq_shuffle(Tensor! 
q_weight, Tensor q_perm, int bit) -> ()"); diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu index 2ed6325c56ce..8fb985680086 100644 --- a/csrc/quantization/aqlm/gemm_kernels.cu +++ b/csrc/quantization/aqlm/gemm_kernels.cu @@ -543,26 +543,6 @@ torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, return {}; } -torch::Tensor aqlm_gemm_meta(const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const torch::Tensor& codebook_partition_sizes, - const std::optional& bias) { - auto input_sizes = input.sizes(); - - auto out_features = codes.size(0) * codebooks.size(2); - auto flat_input = input.reshape({-1, input.size(-1)}); - auto flat_output = torch::empty( - {flat_input.size(0), out_features}, - torch::TensorOptions().dtype(input.dtype()).device(input.device())); - - auto output_sizes = input_sizes.vec(); - output_sizes.pop_back(); - output_sizes.push_back(-1); - return flat_output.reshape(output_sizes); -} - torch::Tensor aqlm_dequant(const torch::Tensor& codes, const torch::Tensor& codebooks, const torch::Tensor& codebook_partition_sizes) { @@ -616,14 +596,3 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes, " entries is not currently supported.") return {}; } - -torch::Tensor aqlm_dequant_meta(const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& codebook_partition_sizes) { - auto in_features = codes.size(1) * 8; - auto out_features = codes.size(0); - return torch::empty({out_features, in_features}, - torch::TensorOptions() - .dtype(codebooks.dtype()) - .device(codebooks.device())); -} diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index dd71d286c740..6d6da5f3d874 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -483,21 +483,6 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, return _de_kernel; } -torch::Tensor awq_dequantize_meta(torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, int64_t split_k_iters, - int64_t thx, int64_t thy) { - int in_c = _kernel.size(0); - int qout_c = _kernel.size(1); - int out_c = qout_c * 8; - - auto options = torch::TensorOptions() - .dtype(_scaling_factors.dtype()) - .device(_scaling_factors.device()); - - return torch::empty({in_c, out_c}, options); -} - // in_feats: M, IC [float16] // kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b] // scaling_factors: IC // G, OC [float16] @@ -562,13 +547,3 @@ torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, } return _out_feats.sum(0); } - -torch::Tensor awq_gemm_meta(torch::Tensor _in_feats, torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, int64_t split_k_iters) { - int num_in_feats = _in_feats.size(0); - auto options = torch::TensorOptions() - .dtype(_in_feats.dtype()) - .device(_in_feats.device()); - return torch::empty({_kernel.size(1) * 8}, options); -} diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 52dc4ab0837b..785f1a09c190 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -1854,11 +1854,3 @@ void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) { : (int*)q_perm.data_ptr(), q_weight.size(0) * 32 / bit, q_weight.size(1), bit); } - -torch::Tensor gptq_gemm_meta(torch::Tensor a, torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor 
b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama, int64_t bit) { - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - return torch::empty({a.size(0), b_q_weight.size(1)}, options); -} diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 12192fcaef34..0beb9de14c68 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -1868,13 +1868,3 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, } #endif - -torch::Tensor gptq_marlin_gemm_meta(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, - torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full) { - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - return torch::empty({size_m, size_n}, options); -} diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu index 6e0ad9cf3a61..4adc158eb14e 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu @@ -348,15 +348,3 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, } #endif - -torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight, - torch::Tensor& perm, int64_t size_k, - int64_t size_n, int64_t num_bits) { - int const pack_factor = 32 / num_bits; - auto options = torch::TensorOptions() - .dtype(b_q_weight.dtype()) - .device(b_q_weight.device()); - return torch::empty({size_k / gptq_marlin::tile_size, - size_n * gptq_marlin::tile_size / pack_factor}, - options); -} diff --git a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu index 7a45d1327a20..d124c0149912 100644 --- a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu @@ -1134,11 +1134,3 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, return c; } - -torch::Tensor marlin_gemm_meta(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t size_m, - int64_t size_n, int64_t size_k) { - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - return torch::empty({size_m, size_n}, options); -} diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu index b32c54bd5986..b5effc305544 100644 --- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -1123,11 +1123,3 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, return c; } - -torch::Tensor gptq_marlin_24_gemm_meta( - torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, - torch::Tensor& b_scales, torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, int64_t size_k) { - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - return torch::empty({size_m, size_n}, options); -} From dc8732098ced42d95962d8b1de9010925d96459a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 7 Jun 2024 20:44:39 +0000 Subject: [PATCH 37/41] rebase + some review comments --- csrc/cpu/pybind.cpp | 2 +- csrc/moe/moe_ops.cpp | 4 ++-- csrc/pybind.cpp | 12 +++++++----- tests/kernels/test_int8_quant.py | 12 
++++++------ 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp index 3d63a96b1405..a2bf0d49adba 100644 --- a/csrc/cpu/pybind.cpp +++ b/csrc/cpu/pybind.cpp @@ -84,7 +84,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { // Cache ops // Swap in (out) the cache blocks from src to dst. cache_ops.def( - "swap_blocks(Tensor! src, Tensor! dst, Tensor block_mapping) -> ()"); + "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"); cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks); // Copy the cache blocks from src to dst. diff --git a/csrc/moe/moe_ops.cpp b/csrc/moe/moe_ops.cpp index ed65ec63f913..243752b9a9e8 100644 --- a/csrc/moe/moe_ops.cpp +++ b/csrc/moe/moe_ops.cpp @@ -4,8 +4,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Apply topk softmax to the gating outputs. m.def( - "topk_softmax(Tensor! topk_weights, Tensor! topk_indices,Tensor " - "token_expert_indices,Tensor gating_output) -> ()"); + "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " + "token_expert_indices, Tensor gating_output) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); } diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 2d27ecfa591c..df2603544c85 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -163,7 +163,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute FP8 quantized tensor and scaling factor. ops.def( - "dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> " + "dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> " "()"); ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); @@ -183,16 +183,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant); // Compute int8 quantized tensor and scaling factor - ops.def("dynamic_scaled_int8_quant", &dynamic_scaled_int8_quant, - "Compute int8 quantized tensor and scaling factor"); - ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant); + ops.def( + "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> " + "()"); + ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, + &dynamic_scaled_int8_quant); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { // Cache ops // Swap in (out) the cache blocks from src to dst. cache_ops.def( - "swap_blocks(Tensor! src, Tensor! dst, Tensor block_mapping) -> ()"); + "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"); cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks); // Copy the cache blocks from src to dst. 
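Note on the static_scaled_int8_quant schema shown above: with the op declared as "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> ()", callers now pass the scale as a 1-element float32 tensor rather than a Python float, and the kernel writes its result into out in place (the "Tensor!" annotation marks the mutated argument). A minimal sketch of the new calling convention, assuming a CUDA build of vllm._C; the shapes and the scale value are illustrative only. The test diff that follows exercises this same path.

import torch

import vllm._C  # noqa: F401  (runs the TORCH_LIBRARY registration for torch.ops._C)

x = torch.rand(32, 768, dtype=torch.half, device="cuda")
out = torch.empty_like(x, dtype=torch.int8)
# The scale is a 1-element float32 tensor on the same device, not a Python float.
scale = torch.tensor([2.0], dtype=torch.float32, device="cuda")
# The op mutates `out` and returns nothing.
torch.ops._C.static_scaled_int8_quant(out, x, scale)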
diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index b8a5a8100312..3fe98df6698d 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -1,7 +1,7 @@ import pytest import torch -from vllm import _custom_ops as ops +import vllm._C DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, @@ -33,7 +33,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda") scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda") - ops.dynamic_scaled_int8_quant(ops_out, x, scales_out) + torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out) assert torch.allclose(scales_out, scales) assert torch.allclose(torch_out, ops_out, @@ -55,11 +55,11 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - out1 = (x / scale).round().clamp( - torch.iinfo(torch.int8).min, - torch.iinfo(torch.int8).max).to(torch.int8) + out1 = (x / scale).round().clamp(int8_traits.min, + int8_traits.max).to(torch.int8) + out2 = torch.empty_like(x, dtype=torch.int8) scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda") - out2 = ops.static_scaled_int8_quant(x, scale_argument) + torch.ops._C.static_scaled_int8_quant(out2, x, scale_argument) assert torch.allclose(out1, out2, atol=1) # big atol to account for rounding errors From bb2446ed4cd914ea268003064df62ee830a9d599 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 7 Jun 2024 20:48:14 +0000 Subject: [PATCH 38/41] libtorch_python.so no longer needed? --- CMakeLists.txt | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d49353ff2279..a0d665a47144 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -66,19 +66,6 @@ endif() # find_package(Torch REQUIRED) -# -# Normally `torch.utils.cpp_extension.CUDAExtension` would add -# `libtorch_python.so` for linking against an extension. Torch's cmake -# configuration does not include this library (presumably since the cmake -# config is used for standalone C++ binaries that link against torch). -# The `libtorch_python.so` library defines some of the glue code between -# torch/python via pybind and is required by VLLM extensions for this -# reason. So, add it by manually with `find_library` using torch's -# installed library path. -# -find_library(torch_python_LIBRARY torch_python PATHS - "${TORCH_INSTALL_PREFIX}/lib") - # # Forward the non-CUDA device extensions to external CMake scripts. 
# From 48868e12b4fe57e96dc983097365deec4868acda Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 7 Jun 2024 21:03:41 +0000 Subject: [PATCH 39/41] rename pybind files to torch_bindings.cpp --- CMakeLists.txt | 6 +++--- cmake/cpu_extension.cmake | 2 +- csrc/cpu/{pybind.cpp => torch_bindings.cpp} | 0 csrc/moe/{moe_ops.cpp => torch_bindings.cpp} | 0 csrc/punica/{punica_pybind.cpp => torch_bindings.cpp} | 0 csrc/{pybind.cpp => torch_bindings.cpp} | 0 6 files changed, 4 insertions(+), 4 deletions(-) rename csrc/cpu/{pybind.cpp => torch_bindings.cpp} (100%) rename csrc/moe/{moe_ops.cpp => torch_bindings.cpp} (100%) rename csrc/punica/{punica_pybind.cpp => torch_bindings.cpp} (100%) rename csrc/{pybind.cpp => torch_bindings.cpp} (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index a0d665a47144..ad6736c47f45 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -158,7 +158,7 @@ set(VLLM_EXT_SRC "csrc/quantization/fp8/common.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" - "csrc/pybind.cpp") + "csrc/torch_bindings.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") include(FetchContent) @@ -213,7 +213,7 @@ define_gpu_extension_target( # set(VLLM_MOE_EXT_SRC - "csrc/moe/moe_ops.cpp" + "csrc/moe/torch_bindings.cpp" "csrc/moe/topk_softmax_kernels.cu") define_gpu_extension_target( @@ -238,7 +238,7 @@ set(VLLM_PUNICA_EXT_SRC "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu" "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu" "csrc/punica/punica_ops.cu" - "csrc/punica/punica_pybind.cpp") + "csrc/punica/torch_bindings.cpp") # # Copy GPU compilation flags+update for punica diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 4c6c6beb91b8..61d4843838ba 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -73,7 +73,7 @@ set(VLLM_EXT_SRC "csrc/cpu/cache.cpp" "csrc/cpu/layernorm.cpp" "csrc/cpu/pos_encoding.cpp" - "csrc/cpu/pybind.cpp") + "csrc/cpu/torch_bindings.cpp") define_gpu_extension_target( _C diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/torch_bindings.cpp similarity index 100% rename from csrc/cpu/pybind.cpp rename to csrc/cpu/torch_bindings.cpp diff --git a/csrc/moe/moe_ops.cpp b/csrc/moe/torch_bindings.cpp similarity index 100% rename from csrc/moe/moe_ops.cpp rename to csrc/moe/torch_bindings.cpp diff --git a/csrc/punica/punica_pybind.cpp b/csrc/punica/torch_bindings.cpp similarity index 100% rename from csrc/punica/punica_pybind.cpp rename to csrc/punica/torch_bindings.cpp diff --git a/csrc/pybind.cpp b/csrc/torch_bindings.cpp similarity index 100% rename from csrc/pybind.cpp rename to csrc/torch_bindings.cpp From 4ed8bf246c7ada6eb78aedd6c438b41d889dff2e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 7 Jun 2024 21:42:25 +0000 Subject: [PATCH 40/41] add comments about const vectors --- csrc/cache.h | 3 +++ csrc/cache_kernels.cu | 3 +++ csrc/cpu/cache.cpp | 3 +++ 3 files changed, 9 insertions(+) diff --git a/csrc/cache.h b/csrc/cache.h index ba2cbdceaaf8..86caa9345361 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -8,6 +8,9 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, const torch::Tensor& block_mapping); +// Note: the key_caches and value_caches vectors are constant but +// not the Tensors they contain. The vectors need to be const refs +// in order to satisfy pytorch's C++ operator registration code. 
void copy_blocks(std::vector const& key_caches, std::vector const& value_caches, const torch::Tensor& block_mapping); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index e26bcb5d5cf7..72041076ae00 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -95,6 +95,9 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, } // namespace vllm +// Note: the key_caches and value_caches vectors are constant but +// not the Tensors they contain. The vectors need to be const refs +// in order to satisfy pytorch's C++ operator registration code. void copy_blocks(std::vector const& key_caches, std::vector const& value_caches, const torch::Tensor& block_mapping) { diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 36e0523662b1..2b5c3bd6ee70 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -82,6 +82,9 @@ void reshape_and_cache_cpu_impl( } }; // namespace +// Note: the key_caches and value_caches vectors are constant but +// not the Tensors they contain. The vectors need to be const refs +// in order to satisfy pytorch's C++ operator registration code. void copy_blocks(std::vector const& key_caches, std::vector const& value_caches, const torch::Tensor& block_mapping) { From 57088e422ce3139a8fa10a4c96def83eab1f06c5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 9 Jun 2024 00:20:33 +0000 Subject: [PATCH 41/41] rebase + run format.sh --- tests/kernels/test_int8_quant.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 3fe98df6698d..0daf7439468a 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -1,6 +1,7 @@ import pytest import torch +# ruff: noqa: F401 import vllm._C DTYPES = [torch.half, torch.bfloat16, torch.float]
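With the full series applied, every custom kernel is reachable through the torch.ops dispatcher, and the optional extension modules are imported only for their registration side effects. Below is a minimal usage sketch; the is_custom_op_supported helper is an illustrative assumption of how such a probe can be written against torch.ops, not necessarily the exact implementation in vllm/_custom_ops.py, and the example assumes a CUDA build of the extensions.

import contextlib

import torch

# Importing an extension module runs its TORCH_LIBRARY blocks, which is what
# populates torch.ops namespaces such as _C, _C_cache_ops and _C_custom_ar.
with contextlib.suppress(ImportError):
    import vllm._C  # noqa: F401
with contextlib.suppress(ImportError):
    import vllm._punica_C  # noqa: F401


def is_custom_op_supported(op_name: str) -> bool:
    # Illustrative: "_C_custom_ar::meta_size" -> namespace "_C_custom_ar",
    # op "meta_size". A missing op surfaces as an AttributeError on the
    # torch.ops namespace, so hasattr doubles as an availability check.
    namespace, name = op_name.split("::")
    return hasattr(getattr(torch.ops, namespace), name)


if is_custom_op_supported("_C_custom_ar::meta_size"):
    # meta_size() is registered with a CPU impl and takes no arguments.
    print("custom all-reduce metadata size:", torch.ops._C_custom_ar.meta_size())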