csrc/xpu/attention_xpu.cpp: 185 changes (125 additions, 60 deletions)
@@ -866,7 +866,7 @@ void context_attention_kernel_v1(
queue.submit(cgf);
}

-template <typename T, int GS, int HD>
+template <typename T, int BLOCK_SIZE, int HD>
void context_attention_kernel_v2(
void* query, void* key, void* value, const void* block_tables,
const float scale, const void* query_start_loc, const void* seq_lens,
@@ -884,7 +884,6 @@ void context_attention_kernel_v2(
const int num_queries_per_kv, const int max_input_length,
const int batch_size, const int num_heads, const int num_tokens,
const int max_context_len, const int max_q_len) {
-constexpr int BLOCK_SIZE = 8;
constexpr int NUM_THREADS = 128;
// Each warp handles one context block; therefore, the thread group size is
// set accordingly.
@@ -908,15 +907,15 @@

constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_context_len =
-DIVIDE_ROUND_UP(max_context_len + 1 + max_q_len, BLOCK_SIZE) * BLOCK_SIZE;
+DIVIDE_ROUND_UP(max_context_len + max_q_len, BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_context_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * HD * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len.
// Keep that in sync with the logic here!
int shared_mem_size = std::max(logits_size, outputs_size);
// WARN: we have changed this...
-sycl::range<3> grid(batch_size, num_heads, max_q_len);
+sycl::range<3> grid(batch_size, max_q_len, num_heads);
// One work-group that is executing on the device
sycl::range<3> block(1, 1, NUM_THREADS);
sycl::queue& queue = vllm::xpu::vllmGetQueue();
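
Side note (not part of the diff): the hunk above drops the "+ 1" term, so the logits buffer is now sized for context length plus query length, rounded up to the next multiple of the block size. A minimal standalone sketch of that arithmetic, assuming DIVIDE_ROUND_UP is the usual ceiling-division macro:

// Sketch only, not part of the diff. DIVIDE_ROUND_UP is assumed to be the
// conventional ceiling division, (a + b - 1) / b.
#include <cstdio>

constexpr int divide_round_up(int a, int b) { return (a + b - 1) / b; }

int main() {
  constexpr int BLOCK_SIZE = 16;  // now a template parameter (8/16/32/64)
  int max_context_len = 100;
  int max_q_len = 7;
  // Old sizing: ceil((100 + 1 + 7) / 16) * 16 = 112
  // New sizing: ceil((100 + 7) / 16) * 16     = 112
  int padded = divide_round_up(max_context_len + max_q_len, BLOCK_SIZE) * BLOCK_SIZE;
  std::printf("padded_max_context_len = %d\n", padded);
  return 0;
}
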
@@ -933,7 +932,7 @@
sycl::nd_range<3>(grid * block, block),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
const int bsz_idx = item_ct1.get_group(0);
-const int seq_idx = item_ct1.get_group(2);
+const int seq_idx = item_ct1.get_group(1);
constexpr bool USE_PARTITIONING = false;
int context_len = context_lens_ptr[bsz_idx] + seq_idx;
const int seq_len = seq_lens_ptr[bsz_idx];
@@ -977,8 +976,8 @@
const int thread_idx = item_ct1.get_local_id(2);
const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE;
-const int head_idx = item_ct1.get_group(1);
-const int num_heads = item_ct1.get_group_range(1);
+const int head_idx = item_ct1.get_group(2);
+const int num_heads = item_ct1.get_group_range(2);
const int kv_head_idx = head_idx / num_queries_per_kv;
// TODO: consider alibi_slope later
constexpr int NUM_ELEMS_PER_THREAD = HD / THREAD_GROUP_SIZE;
@@ -2535,7 +2534,7 @@ torch::Tensor context_attention_forward_v2(
// value: [num_blocks, num_kv_heads, head_size, block_dim]
int block_size = value.size(3);
// Currently, only block_size 16 is supported...
-assert(block_size == 16);
+// assert(block_size == 16);
int x = key.size(4);
int block_table_stride_bsz = block_tables.stride(0);
int block_table_stride_seq = block_tables.stride(1);
@@ -2551,60 +2550,126 @@
int v_cache_stride_block = value.stride(3);
switch(head_dim) {
case 128:
-vllm::context_attention_kernel_v2<sycl::half, 32, 128>(
-query.data_ptr(), key.data_ptr(), value.data_ptr(),
-block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(),
-seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x,
-output.data_ptr(), block_table_stride_bsz, block_table_stride_seq,
-query_stride_token, query_stride_head, query_stride_dim,
-k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim,
-k_cache_stride_block, k_cache_stride_x, v_cache_stride_token,
-v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block,
-output.stride(0), output.stride(1), num_queries_per_kv,
-max_input_length, batch_size, num_heads, query.size(0),
-max_context_length, max_q_length);
+switch(block_size) {
+case 8:
+vllm::context_attention_kernel_v2<sycl::half, 8, 128>(
+query.data_ptr(), key.data_ptr(), value.data_ptr(),
+block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(),
+seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x,
+output.data_ptr(), block_table_stride_bsz, block_table_stride_seq,
+query_stride_token, query_stride_head, query_stride_dim,
+k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim,
+k_cache_stride_block, k_cache_stride_x, v_cache_stride_token,
+v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block,
+output.stride(0), output.stride(1), num_queries_per_kv,
+max_input_length, batch_size, num_heads, query.size(0),
+max_context_length, max_q_length);
+break;
+case 16:
+vllm::context_attention_kernel_v2<sycl::half, 16, 128>(
+query.data_ptr(), key.data_ptr(), value.data_ptr(),
+block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(),
+seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x,
+output.data_ptr(), block_table_stride_bsz, block_table_stride_seq,
+query_stride_token, query_stride_head, query_stride_dim,
+k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim,
+k_cache_stride_block, k_cache_stride_x, v_cache_stride_token,
+v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block,
+output.stride(0), output.stride(1), num_queries_per_kv,
+max_input_length, batch_size, num_heads, query.size(0),
+max_context_length, max_q_length);
+break;
+case 32:
+vllm::context_attention_kernel_v2<sycl::half, 32, 128>(
+query.data_ptr(), key.data_ptr(), value.data_ptr(),
+block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(),
+seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x,
+output.data_ptr(), block_table_stride_bsz, block_table_stride_seq,
+query_stride_token, query_stride_head, query_stride_dim,
+k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim,
+k_cache_stride_block, k_cache_stride_x, v_cache_stride_token,
+v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block,
+output.stride(0), output.stride(1), num_queries_per_kv,
+max_input_length, batch_size, num_heads, query.size(0),
+max_context_length, max_q_length);
+break;
+case 64:
+vllm::context_attention_kernel_v2<sycl::half, 64, 128>(
+query.data_ptr(), key.data_ptr(), value.data_ptr(),
+block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(),
+seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x,
+output.data_ptr(), block_table_stride_bsz, block_table_stride_seq,
+query_stride_token, query_stride_head, query_stride_dim,
+k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim,
+k_cache_stride_block, k_cache_stride_x, v_cache_stride_token,
+v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block,
+output.stride(0), output.stride(1), num_queries_per_kv,
+max_input_length, batch_size, num_heads, query.size(0),
+max_context_length, max_q_length);
+break;
+default: throw std::runtime_error("unsupported block_size");
+}
break;
case 64:
-vllm::context_attention_kernel_v2<sycl::half, 32, 64>(
-query.data_ptr(), key.data_ptr(), value.data_ptr(),
-block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(),
-seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x,
-output.data_ptr(), block_table_stride_bsz, block_table_stride_seq,
-query_stride_token, query_stride_head, query_stride_dim,
-k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim,
-k_cache_stride_block, k_cache_stride_x, v_cache_stride_token,
-v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block,
-output.stride(0), output.stride(1), num_queries_per_kv,
-max_input_length, batch_size, num_heads, query.size(0),
-max_context_length, max_q_length);
-break;
-case 80:
-vllm::context_attention_kernel_v2<sycl::half, 32, 80>(
-query.data_ptr(), key.data_ptr(), value.data_ptr(),
-block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(),
-seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x,
-output.data_ptr(), block_table_stride_bsz, block_table_stride_seq,
-query_stride_token, query_stride_head, query_stride_dim,
-k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim,
-k_cache_stride_block, k_cache_stride_x, v_cache_stride_token,
-v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block,
-output.stride(0), output.stride(1), num_queries_per_kv,
-max_input_length, batch_size, num_heads, query.size(0),
-max_context_length, max_q_length);
-break;
-case 96:
-vllm::context_attention_kernel_v2<sycl::half, 32, 96>(
-query.data_ptr(), key.data_ptr(), value.data_ptr(),
-block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(),
-seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x,
-output.data_ptr(), block_table_stride_bsz, block_table_stride_seq,
-query_stride_token, query_stride_head, query_stride_dim,
-k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim,
-k_cache_stride_block, k_cache_stride_x, v_cache_stride_token,
-v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block,
-output.stride(0), output.stride(1), num_queries_per_kv,
-max_input_length, batch_size, num_heads, query.size(0),
-max_context_length, max_q_length);
+switch(block_size) {
+case 8:
+vllm::context_attention_kernel_v2<sycl::half, 8, 64>(
+query.data_ptr(), key.data_ptr(), value.data_ptr(),
+block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(),
+seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x,
+output.data_ptr(), block_table_stride_bsz, block_table_stride_seq,
+query_stride_token, query_stride_head, query_stride_dim,
+k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim,
+k_cache_stride_block, k_cache_stride_x, v_cache_stride_token,
+v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block,
+output.stride(0), output.stride(1), num_queries_per_kv,
+max_input_length, batch_size, num_heads, query.size(0),
+max_context_length, max_q_length);
+break;
+case 16:
+vllm::context_attention_kernel_v2<sycl::half, 16, 64>(
+query.data_ptr(), key.data_ptr(), value.data_ptr(),
+block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(),
+seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x,
+output.data_ptr(), block_table_stride_bsz, block_table_stride_seq,
+query_stride_token, query_stride_head, query_stride_dim,
+k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim,
+k_cache_stride_block, k_cache_stride_x, v_cache_stride_token,
+v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block,
+output.stride(0), output.stride(1), num_queries_per_kv,
+max_input_length, batch_size, num_heads, query.size(0),
+max_context_length, max_q_length);
+break;
+case 32:
+vllm::context_attention_kernel_v2<sycl::half, 32, 64>(
+query.data_ptr(), key.data_ptr(), value.data_ptr(),
+block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(),
+seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x,
+output.data_ptr(), block_table_stride_bsz, block_table_stride_seq,
+query_stride_token, query_stride_head, query_stride_dim,
+k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim,
+k_cache_stride_block, k_cache_stride_x, v_cache_stride_token,
+v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block,
+output.stride(0), output.stride(1), num_queries_per_kv,
+max_input_length, batch_size, num_heads, query.size(0),
+max_context_length, max_q_length);
+break;
+case 64:
+vllm::context_attention_kernel_v2<sycl::half, 64, 64>(
+query.data_ptr(), key.data_ptr(), value.data_ptr(),
+block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(),
+seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x,
+output.data_ptr(), block_table_stride_bsz, block_table_stride_seq,
+query_stride_token, query_stride_head, query_stride_dim,
+k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim,
+k_cache_stride_block, k_cache_stride_x, v_cache_stride_token,
+v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block,
+output.stride(0), output.stride(1), num_queries_per_kv,
+max_input_length, batch_size, num_heads, query.size(0),
+max_context_length, max_q_length);
+break;
+default: throw std::runtime_error("unsupported block_size");
+}
break;
default: throw std::runtime_error("unsupported head_dim");
}
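
Side note (not part of this PR; helper names are hypothetical): the call sites above differ only in the two template parameters, so the runtime-to-compile-time dispatch could be factored into a small helper that forwards the long argument list exactly once. A minimal sketch:

// Sketch only, not part of this PR. Maps the runtime (head_dim, block_size)
// pair onto compile-time template parameters via a generic callback.
#include <stdexcept>
#include <type_traits>

template <int HD, typename Fn>
void dispatch_block_size(int block_size, Fn&& fn) {
  switch (block_size) {
    case 8:  fn(std::integral_constant<int, 8>{},  std::integral_constant<int, HD>{}); break;
    case 16: fn(std::integral_constant<int, 16>{}, std::integral_constant<int, HD>{}); break;
    case 32: fn(std::integral_constant<int, 32>{}, std::integral_constant<int, HD>{}); break;
    case 64: fn(std::integral_constant<int, 64>{}, std::integral_constant<int, HD>{}); break;
    default: throw std::runtime_error("unsupported block_size");
  }
}

template <typename Fn>
void dispatch_context_attention(int head_dim, int block_size, Fn&& fn) {
  switch (head_dim) {
    case 64:  dispatch_block_size<64>(block_size, fn);  break;
    case 128: dispatch_block_size<128>(block_size, fn); break;
    default: throw std::runtime_error("unsupported head_dim");
  }
}

// Usage sketch: the lambda receives the two compile-time constants and passes
// the long argument list once.
//
//   dispatch_context_attention(head_dim, block_size, [&](auto BS, auto HD) {
//     vllm::context_attention_kernel_v2<sycl::half, BS.value, HD.value>(
//         query.data_ptr(), key.data_ptr(), value.data_ptr(), /* ... */);
//   });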