
Commit 0f9a6e3

[Bugfix][Kernel] allow non-power-of-2 for prefix prefill with alibi (#4573)
1 parent f6a5930 commit 0f9a6e3
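In short: `context_attention_fwd` used to assert `Lk == Lk_padded` on the ALiBi path, so prefix prefill with ALiBi only worked when the head size was already a power of two. This commit teaches `_fwd_kernel_alibi` to work on a padded head dimension (`BLOCK_DMODEL_PADDED`, presumably the next power of two above the real head size, computed elsewhere in `context_attention_fwd`) and to mask the padded lanes on every load and on the final store; the tests gain `head_size = 24` and a dedicated ALiBi test. The sketch below is mine, not part of the commit — a minimal plain-PyTorch illustration of why zero-padding the head dimension cannot change the attention result: padded lanes contribute 0 to every dot product and are dropped before the output is used.

# Minimal sketch (illustrative, not from the commit): zero-padding the head
# dimension to a power of two leaves the attention output unchanged, because
# the padded lanes contribute 0 to every dot product.
import math
import torch

def pad_to_pow2(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1]
    d_pad = 1 << math.ceil(math.log2(d))        # e.g. 24 -> 32
    return torch.nn.functional.pad(x, (0, d_pad - d))

torch.manual_seed(0)
q = torch.randn(16, 24)   # 16 query tokens, head size 24 (non-power-of-2)
k = torch.randn(32, 24)   # 32 key tokens
v = torch.randn(32, 24)

scale = 1.0 / math.sqrt(q.shape[-1])            # scale uses the true head size

# reference: attention at the original head size
ref = torch.softmax(q @ k.T * scale, dim=-1) @ v

# padded: same computation on power-of-two tiles, then drop the padded lanes
qp, kp, vp = pad_to_pow2(q), pad_to_pow2(k), pad_to_pow2(v)
padded = torch.softmax(qp @ kp.T * scale, dim=-1) @ vp
padded = padded[:, :q.shape[-1]]                # crop/mask the padded columns

assert torch.allclose(ref, padded, atol=1e-5)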

File tree: 2 files changed, +267 and -17 lines

tests/kernels/test_prefix_prefill.py

Lines changed: 242 additions & 1 deletion
@@ -1,3 +1,4 @@
+import math
 import random
 import time
 
@@ -6,11 +7,12 @@
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
 
+from vllm.attention.backends.xformers import _make_alibi_bias
 from vllm.attention.ops.prefix_prefill import context_attention_fwd
 
 NUM_HEADS = [64]
 NUM_QUERIES_PER_KV = [1, 8, 64]
-HEAD_SIZES = [128, 96]
+HEAD_SIZES = [128, 96, 24]
 DTYPES = [torch.float16]
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
@@ -207,3 +209,242 @@ def test_contexted_kv_attention(
     print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
     output_ref = output_ref.reshape(output.shape)
     assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
+
+
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_contexted_kv_attention_alibi(
+    num_heads: int,
+    num_queries_per_kv: int,
+    head_size: int,
+    dtype: torch.dtype,
+    device: str,
+) -> None:
+    random.seed(0)
+    torch.manual_seed(0)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(0)
+    torch.set_default_device(device)
+
+    # Need this, otherwise when we capture the graph the process
+    # for GPU 1 would run on both GPU0 and GPU1 and things would hang
+    #
+    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
+    torch.cuda.set_device(device)
+
+    def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
+        # Fork from: vllm/vllm/model_executor/models/bloom.py#L44
+        closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
+        base = torch.tensor(
+            2**(-(2**-(math.log2(closest_power_of_2) - 3))),
+            dtype=torch.float32,
+        )
+        powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
+        slopes = torch.pow(base, powers)
+
+        if closest_power_of_2 != total_num_heads:
+            extra_base = torch.tensor(
+                2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
+                dtype=torch.float32,
+            )
+            num_remaining_heads = min(closest_power_of_2,
+                                      total_num_heads - closest_power_of_2)
+            extra_powers = torch.arange(start=1,
+                                        end=1 + 2 * num_remaining_heads,
+                                        step=2,
+                                        dtype=torch.int32)
+            slopes = torch.cat(
+                [slopes, torch.pow(extra_base, extra_powers)], dim=0)
+        return slopes
+
+    alibi_slopes = _get_alibi_slopes(num_heads).to(device)
+
+    MAX_SEQ_LEN = 1024
+    MAX_CTX_LEN = 1024
+    BS = 10
+    cache_size = 640
+    block_size = 32
+    max_block_per_request = 64
+    query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
+    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
+    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
+    num_kv_heads = num_heads // num_queries_per_kv
+
+    num_tokens = sum(query_lens)
+    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
+    query.uniform_(-1e-3, 1e-3)
+    output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
+
+    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
+    kv.uniform_(-1e-3, 1e-3)
+    key, value = kv.unbind(dim=1)
+
+    k_cache = torch.zeros(cache_size,
+                          block_size,
+                          num_kv_heads,
+                          head_size,
+                          dtype=dtype)
+    v_cache = torch.zeros(cache_size,
+                          block_size,
+                          num_kv_heads,
+                          head_size,
+                          dtype=dtype)
+    k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
+    v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
+    values = torch.arange(0, cache_size, dtype=torch.long)
+    values = values[torch.randperm(cache_size)]
+    block_table = values[:BS * max_block_per_request].view(
+        BS, max_block_per_request)
+    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
+    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
+                                            dtype=torch.long),
+                               dim=0)
+    max_input_len = MAX_SEQ_LEN
+    # copy kv to cache
+    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
+                                                dtype=torch.long),
+                                   dim=0)
+    for i in range(BS):
+        for j in range(query_lens[i]):
+            k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
+                                            j])
+            v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
+                                              b_ctx_len[i] + j])
+        cur_ctx = 0
+        block_id = 0
+        while cur_ctx < b_ctx_len[i]:
+            start_loc = b_seq_start_loc[i] + cur_ctx
+            if cur_ctx + block_size > b_ctx_len[i]:
+                end_loc = b_seq_start_loc[i] + b_ctx_len[i]
+            else:
+                end_loc = start_loc + block_size
+            start_slot = block_table[i, block_id] * block_size
+            end_slot = start_slot + end_loc - start_loc
+            k_cache.view(-1, num_kv_heads,
+                         head_size)[start_slot:end_slot].copy_(
+                             key[start_loc:end_loc])
+            v_cache.view(-1, num_kv_heads,
+                         head_size)[start_slot:end_slot].copy_(
+                             value[start_loc:end_loc])
+            cur_ctx += block_size
+            block_id += 1
+    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
+    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
+    k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
+                           8).permute(0, 2, 3, 1, 4).contiguous()
+    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
+    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
+    v_cache = v_cache.view(-1, block_size, num_kv_heads,
+                           head_size).permute(0, 2, 3, 1).contiguous()
+
+    # Warm up the Triton kernel by calling it once before actually measuring
+    # generation time
+    context_attention_fwd(query,
+                          k,
+                          v,
+                          output,
+                          k_cache,
+                          v_cache,
+                          block_table,
+                          b_start_loc,
+                          b_seq_len,
+                          b_ctx_len,
+                          max_input_len,
+                          alibi_slopes=alibi_slopes)
+    torch.cuda.synchronize()
+    start_time = time.time()
+    context_attention_fwd(query,
+                          k,
+                          v,
+                          output,
+                          k_cache,
+                          v_cache,
+                          block_table,
+                          b_start_loc,
+                          b_seq_len,
+                          b_ctx_len,
+                          max_input_len,
+                          alibi_slopes=alibi_slopes)
+    torch.cuda.synchronize()
+    end_time = time.time()
+    print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
+    scale = float(1.0 / (head_size**0.5))
+
+    # NOTE(DefTruth): In order to reuse _make_alibi_bias function,
+    # we have to pad query tensor before MQA/GQA expanding.
+    if query.shape[0] != key.shape[0]:
+        query_pad = torch.empty(sum(seq_lens),
+                                num_heads,
+                                head_size,
+                                dtype=dtype)
+        query_pad.uniform_(-1e-3, 1e-3)
+        seq_start = 0
+        query_start = 0
+        for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
+            seq_end = seq_start + seq_len
+            query_end = query_start + query_len
+            query_pad[seq_start:seq_end, ...] = torch.cat([
+                torch.zeros(
+                    seq_len - query_len, num_heads, head_size, dtype=dtype),
+                query[query_start:query_end, ...]
+            ],
+                                                          dim=0)
+            seq_start += seq_len
+            query_start += query_len
+        query = query_pad
+
+    if num_kv_heads != num_heads:
+        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
+        # project the key and value tensors to the desired number of
+        # heads.
+        #
+        # see also: vllm/model_executor/layers/attention.py
+        query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
+                           query.shape[-1])
+        key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
+                                        num_queries_per_kv, key.shape[-1])
+        value = value[:, :,
+                      None, :].expand(value.shape[0], num_kv_heads,
+                                      num_queries_per_kv, value.shape[-1])
+
+    query = query.unsqueeze(0)
+    key = key.unsqueeze(0)
+    value = value.unsqueeze(0)
+
+    attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
+    output_ref = torch.empty_like(output)
+    seq_start = 0
+    query_start = 0
+    start_time = time.time()
+    # Attention with alibi slopes.
+    # FIXME(DefTruth): Because xformers does not support dynamic sequence
+    # lengths with custom attention bias, we process each prompt one by
+    # one. This is inefficient, especially when we have many short prompts.
+    # modified from: vllm/attention/backends/xformers.py#L343
+    for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
+        seq_end = seq_start + seq_len
+        query_end = query_start + query_len
+        out = xops.memory_efficient_attention_forward(query[:,
+                                                            seq_start:seq_end],
+                                                      key[:,
+                                                          seq_start:seq_end],
+                                                      value[:,
+                                                            seq_start:seq_end],
+                                                      attn_bias=attn_bias[i],
+                                                      p=0.0,
+                                                      scale=scale)
+        out = out.view_as(query[:, seq_start:seq_end]).view(
+            seq_len, num_heads, head_size)
+        output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
+                                                         ...])
+        seq_start += seq_len
+        query_start += query_len
+    torch.cuda.synchronize()
+    end_time = time.time()
+    print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
+    assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
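A note on the `_get_alibi_slopes` helper above (forked from vLLM's bloom.py): for a power-of-two head count n it produces the standard ALiBi schedule 2^(-8*i/n) for i = 1..n (for n = 64, the first slope is 2^(-1/8) ≈ 0.917 and the last is 2^(-8)); the extra-slope branch only fires when the head count itself is not a power of two, which never happens here since NUM_HEADS = [64]. The non-power-of-2 coverage this commit adds comes from the new head_size = 24. A small, illustrative sanity check of that closed form (mine, not part of the commit):

# Quick check (illustrative, not from the commit): for a power-of-two head
# count n, the helper's slope formula reduces to 2**(-8*i/n).
import math
import torch

def alibi_slopes_closed_form(n: int) -> torch.Tensor:
    # valid only when n is a power of two
    return torch.tensor([2**(-8.0 * i / n) for i in range(1, n + 1)])

n = 64  # NUM_HEADS in the test above
base = torch.tensor(2**(-(2**-(math.log2(n) - 3))), dtype=torch.float32)
slopes = torch.pow(base, torch.arange(1, 1 + n, dtype=torch.int32))
assert torch.allclose(slopes, alibi_slopes_closed_form(n).float(), rtol=1e-5)

The new test itself can be selected with pytest's -k filter (e.g. -k test_contexted_kv_attention_alibi) on a CUDA machine with vLLM and xformers installed.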

vllm/attention/ops/prefix_prefill.py

Lines changed: 25 additions & 16 deletions
@@ -472,7 +472,8 @@ def _fwd_kernel_alibi(
         stride_v_cache_bl,
         num_queries_per_kv: int,
         BLOCK_M: tl.constexpr,
-        BLOCK_DMODEL: tl.constexpr,
+        BLOCK_DMODEL: tl.constexpr,  # head size
+        BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
         BLOCK_N: tl.constexpr,
     ):
         # attn_bias[]
@@ -493,21 +494,24 @@ def _fwd_kernel_alibi(
 
         # initialize offsets
         offs_n = tl.arange(0, BLOCK_N)
-        offs_d = tl.arange(0, BLOCK_DMODEL)
+        offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
         offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
         off_q = (
             (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)
 
-        q = tl.load(
-            Q + off_q,
-            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
-            other=0.0)
+        dim_mask = tl.where(
+            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
+
+        q = tl.load(Q + off_q,
+                    mask=dim_mask[None, :] &
+                    (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
+                    other=0.0)
 
         # # initialize pointer to m and l
         m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
         l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
 
         alibi_slope = tl.load(Alibi_slopes + cur_head)
         alibi_start_q = tl.arange(
@@ -532,8 +536,9 @@ def _fwd_kernel_alibi(
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
             k = tl.load(K_cache + off_k,
-                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
-                        other=0.0)
+                        mask=dim_mask[:, None] &
+                        ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
+                        other=0.0)  # [D,N]
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
             qk += tl.dot(q, k)
@@ -567,7 +572,8 @@ def _fwd_kernel_alibi(
             acc = acc * acc_scale[:, None]
             # update acc
             v = tl.load(V_cache + off_v,
-                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
+                        mask=dim_mask[None, :] &
+                        ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                         other=0.0)
 
             p = p.to(v.dtype)
@@ -600,8 +606,9 @@ def _fwd_kernel_alibi(
             # -- compute qk ----
             k = tl.load(k_ptrs +
                         (cur_batch_in_all_start_index + start_n) * stride_kbs,
-                        mask=(start_n + offs_n[None, :]) <
-                        cur_batch_seq_len - cur_batch_ctx_len,
+                        mask=dim_mask[:, None] &
+                        ((start_n + offs_n[None, :]) <
+                         cur_batch_seq_len - cur_batch_ctx_len),
                         other=0.0)
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -637,8 +644,9 @@ def _fwd_kernel_alibi(
             # update acc
             v = tl.load(v_ptrs +
                         (cur_batch_in_all_start_index + start_n) * stride_vbs,
-                        mask=(start_n + offs_n[:, None]) <
-                        cur_batch_seq_len - cur_batch_ctx_len,
+                        mask=dim_mask[None, :] &
+                        ((start_n + offs_n[:, None]) <
+                         cur_batch_seq_len - cur_batch_ctx_len),
                         other=0.0)
 
             p = p.to(v.dtype)
@@ -656,7 +664,8 @@ def _fwd_kernel_alibi(
         out_ptrs = Out + off_o
         tl.store(out_ptrs,
                  acc,
-                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
+                 mask=dim_mask[None, :] &
+                 (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
         return
 
     @torch.inference_mode()
@@ -690,7 +699,6 @@ def context_attention_fwd(q,
 
         num_warps = 8 if Lk <= 64 else 8
         if alibi_slopes is not None:
-            assert Lk == Lk_padded
            _fwd_kernel_alibi[grid](
                 q,
                 k,
@@ -735,6 +743,7 @@ def context_attention_fwd(q,
                 num_queries_per_kv=num_queries_per_kv,
                 BLOCK_M=BLOCK,
                 BLOCK_DMODEL=Lk,
+                BLOCK_DMODEL_PADDED=Lk_padded,
                 BLOCK_N=BLOCK,
                 num_warps=num_warps,
                 num_stages=1,