From b1176ff56ebc1305bd97324dab06028f19ee6fb7 Mon Sep 17 00:00:00 2001 From: xiangyuT Date: Tue, 18 Mar 2025 17:13:42 +0800 Subject: [PATCH 1/2] init --- csrc/xpu/attention_xpu.cpp | 185 +++++++++++++++++++++++++------------ 1 file changed, 125 insertions(+), 60 deletions(-) diff --git a/csrc/xpu/attention_xpu.cpp b/csrc/xpu/attention_xpu.cpp index 807fa368fefe..8f96171d3895 100644 --- a/csrc/xpu/attention_xpu.cpp +++ b/csrc/xpu/attention_xpu.cpp @@ -866,7 +866,7 @@ void context_attention_kernel_v1( queue.submit(cgf); } -template +template void context_attention_kernel_v2( void* query, void* key, void* value, const void* block_tables, const float scale, const void* query_start_loc, const void* seq_lens, @@ -884,7 +884,6 @@ void context_attention_kernel_v2( const int num_queries_per_kv, const int max_input_length, const int batch_size, const int num_heads, const int num_tokens, const int max_context_len, const int max_q_len) { - constexpr int BLOCK_SIZE = 8; constexpr int NUM_THREADS = 128; // Each wrap handles one context block, therefore, each thread_group_size is // this. @@ -908,7 +907,7 @@ void context_attention_kernel_v2( constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int padded_max_context_len = - DIVIDE_ROUND_UP(max_context_len + 1 + max_q_len, BLOCK_SIZE) * BLOCK_SIZE; + DIVIDE_ROUND_UP(max_context_len + max_q_len, BLOCK_SIZE) * BLOCK_SIZE; int logits_size = padded_max_context_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * HD * sizeof(float); // Python-side check in @@ -916,7 +915,7 @@ void context_attention_kernel_v2( // sync with the logic here! int shared_mem_size = std::max(logits_size, outputs_size); // WARN: we have changed this... - sycl::range<3> grid(batch_size, num_heads, max_q_len); + sycl::range<3> grid(batch_size, max_q_len, num_heads); // One work-group that is executing on the device sycl::range<3> block(1, 1, NUM_THREADS); sycl::queue& queue = vllm::xpu::vllmGetQueue(); @@ -933,7 +932,7 @@ void context_attention_kernel_v2( sycl::nd_range<3>(grid * block, block), [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { const int bsz_idx = item_ct1.get_group(0); - const int seq_idx = item_ct1.get_group(2); + const int seq_idx = item_ct1.get_group(1); constexpr bool USE_PARTITIONING = false; int context_len = context_lens_ptr[bsz_idx] + seq_idx; const int seq_len = seq_lens_ptr[bsz_idx]; @@ -977,8 +976,8 @@ void context_attention_kernel_v2( const int thread_idx = item_ct1.get_local_id(2); const int warp_idx = thread_idx / WARP_SIZE; const int lane = thread_idx % WARP_SIZE; - const int head_idx = item_ct1.get_group(1); - const int num_heads = item_ct1.get_group_range(1); + const int head_idx = item_ct1.get_group(2); + const int num_heads = item_ct1.get_group_range(2); const int kv_head_idx = head_idx / num_queries_per_kv; // TODO: consider alibi_slope later constexpr int NUM_ELEMS_PER_THREAD = HD / THREAD_GROUP_SIZE; @@ -2535,7 +2534,7 @@ torch::Tensor context_attention_forward_v2( // value: [num_blocks, num_kv_heads, head_size, block_dim] int block_size = value.size(3); // Currently, only block_size 16 is supported... - assert(block_size == 16); + // assert(block_size == 16); int x = key.size(4); int block_table_stride_bsz = block_tables.stride(0); int block_table_stride_seq = block_tables.stride(1); @@ -2551,60 +2550,126 @@ torch::Tensor context_attention_forward_v2( int v_cache_stride_block = value.stride(3); switch(head_dim) { case 128: - vllm::context_attention_kernel_v2( - query.data_ptr(), key.data_ptr(), value.data_ptr(), - block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), - seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, - output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, - query_stride_token, query_stride_head, query_stride_dim, - k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, - k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, - v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, - output.stride(0), output.stride(1), num_queries_per_kv, - max_input_length, batch_size, num_heads, query.size(0), - max_context_length, max_q_length); + switch(block_size) { + case 8: + vllm::context_attention_kernel_v2( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads, query.size(0), + max_context_length, max_q_length); + break; + case 16: + vllm::context_attention_kernel_v2( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads, query.size(0), + max_context_length, max_q_length); + break; + case 32: + vllm::context_attention_kernel_v2( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads, query.size(0), + max_context_length, max_q_length); + break; + case 64: + vllm::context_attention_kernel_v2( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads, query.size(0), + max_context_length, max_q_length); + break; + default: throw std::runtime_error("unsupported block_size"); + } break; case 64: - vllm::context_attention_kernel_v2( - query.data_ptr(), key.data_ptr(), value.data_ptr(), - block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), - seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, - output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, - query_stride_token, query_stride_head, query_stride_dim, - k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, - k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, - v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, - output.stride(0), output.stride(1), num_queries_per_kv, - max_input_length, batch_size, num_heads, query.size(0), - max_context_length, max_q_length); - break; - case 80: - vllm::context_attention_kernel_v2( - query.data_ptr(), key.data_ptr(), value.data_ptr(), - block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), - seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, - output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, - query_stride_token, query_stride_head, query_stride_dim, - k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, - k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, - v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, - output.stride(0), output.stride(1), num_queries_per_kv, - max_input_length, batch_size, num_heads, query.size(0), - max_context_length, max_q_length); - break; - case 96: - vllm::context_attention_kernel_v2( - query.data_ptr(), key.data_ptr(), value.data_ptr(), - block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), - seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, - output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, - query_stride_token, query_stride_head, query_stride_dim, - k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, - k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, - v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, - output.stride(0), output.stride(1), num_queries_per_kv, - max_input_length, batch_size, num_heads, query.size(0), - max_context_length, max_q_length); + switch(block_size) { + case 8: + vllm::context_attention_kernel_v2( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads, query.size(0), + max_context_length, max_q_length); + break; + case 16: + vllm::context_attention_kernel_v2( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads, query.size(0), + max_context_length, max_q_length); + break; + case 32: + vllm::context_attention_kernel_v2( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads, query.size(0), + max_context_length, max_q_length); + break; + case 64: + vllm::context_attention_kernel_v2( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads, query.size(0), + max_context_length, max_q_length); + break; + default: throw std::runtime_error("unsupported block_size"); + } break; default: throw std::runtime_error("unsupported head_dim"); } From 1b9f239a6444b2f66cde311760cede864998dbbb Mon Sep 17 00:00:00 2001 From: xiangyuT Date: Thu, 20 Mar 2025 10:19:34 +0800 Subject: [PATCH 2/2] refine --- csrc/xpu/attention_xpu.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/csrc/xpu/attention_xpu.cpp b/csrc/xpu/attention_xpu.cpp index 8f96171d3895..af8da0bf0ee2 100644 --- a/csrc/xpu/attention_xpu.cpp +++ b/csrc/xpu/attention_xpu.cpp @@ -866,7 +866,7 @@ void context_attention_kernel_v1( queue.submit(cgf); } -template +template void context_attention_kernel_v2( void* query, void* key, void* value, const void* block_tables, const float scale, const void* query_start_loc, const void* seq_lens, @@ -2552,7 +2552,7 @@ torch::Tensor context_attention_forward_v2( case 128: switch(block_size) { case 8: - vllm::context_attention_kernel_v2( + vllm::context_attention_kernel_v2( query.data_ptr(), key.data_ptr(), value.data_ptr(), block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, @@ -2566,7 +2566,7 @@ torch::Tensor context_attention_forward_v2( max_context_length, max_q_length); break; case 16: - vllm::context_attention_kernel_v2( + vllm::context_attention_kernel_v2( query.data_ptr(), key.data_ptr(), value.data_ptr(), block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, @@ -2580,7 +2580,7 @@ torch::Tensor context_attention_forward_v2( max_context_length, max_q_length); break; case 32: - vllm::context_attention_kernel_v2( + vllm::context_attention_kernel_v2( query.data_ptr(), key.data_ptr(), value.data_ptr(), block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, @@ -2594,7 +2594,7 @@ torch::Tensor context_attention_forward_v2( max_context_length, max_q_length); break; case 64: - vllm::context_attention_kernel_v2( + vllm::context_attention_kernel_v2( query.data_ptr(), key.data_ptr(), value.data_ptr(), block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, @@ -2613,7 +2613,7 @@ torch::Tensor context_attention_forward_v2( case 64: switch(block_size) { case 8: - vllm::context_attention_kernel_v2( + vllm::context_attention_kernel_v2( query.data_ptr(), key.data_ptr(), value.data_ptr(), block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, @@ -2627,7 +2627,7 @@ torch::Tensor context_attention_forward_v2( max_context_length, max_q_length); break; case 16: - vllm::context_attention_kernel_v2( + vllm::context_attention_kernel_v2( query.data_ptr(), key.data_ptr(), value.data_ptr(), block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, @@ -2641,7 +2641,7 @@ torch::Tensor context_attention_forward_v2( max_context_length, max_q_length); break; case 32: - vllm::context_attention_kernel_v2( + vllm::context_attention_kernel_v2( query.data_ptr(), key.data_ptr(), value.data_ptr(), block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, @@ -2655,7 +2655,7 @@ torch::Tensor context_attention_forward_v2( max_context_length, max_q_length); break; case 64: - vllm::context_attention_kernel_v2( + vllm::context_attention_kernel_v2( query.data_ptr(), key.data_ptr(), value.data_ptr(), block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x,