@@ -1,16 +1,18 @@
+import math
 import random
 import time
 
 import pytest
 import torch
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
 
+from vllm.attention.backends.xformers import _make_alibi_bias
 from vllm.attention.ops.prefix_prefill import context_attention_fwd
 
 NUM_HEADS = [64]
 NUM_QUERIES_PER_KV = [1, 8, 64]
-HEAD_SIZES = [128, 96]
+HEAD_SIZES = [128, 96, 24]
 DTYPES = [torch.float16]
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
@@ -207,3 +209,242 @@ def test_contexted_kv_attention( |
     print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
     output_ref = output_ref.reshape(output.shape)
     assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
+
+
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_contexted_kv_attention_alibi(
+    num_heads: int,
+    num_queries_per_kv: int,
+    head_size: int,
+    dtype: torch.dtype,
+    device: str,
+) -> None:
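+    # Checks the Triton context_attention_fwd kernel with ALiBi slopes
+    # against an xformers reference built from _make_alibi_bias.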
+    random.seed(0)
+    torch.manual_seed(0)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(0)
+    torch.set_default_device(device)
+
+    # Need this, otherwise when we capture the graph the process
+    # for GPU 1 would run on both GPU0 and GPU1 and things would hang
+    #
+    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
+    torch.cuda.set_device(device)
+
+    def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
+        # Fork from: vllm/vllm/model_executor/models/bloom.py#L44
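+        # The slopes form a geometric sequence base**i = 2**(-8 * i / n) for
+        # i = 1..n, where n is the closest power of two <= total_num_heads.
+        # Heads beyond that power of two use base 2**(-4 / n) raised to odd
+        # powers (1, 3, 5, ...), following the BLOOM/ALiBi recipe.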
+        closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
+        base = torch.tensor(
+            2**(-(2**-(math.log2(closest_power_of_2) - 3))),
+            dtype=torch.float32,
+        )
+        powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
+        slopes = torch.pow(base, powers)
+
+        if closest_power_of_2 != total_num_heads:
+            extra_base = torch.tensor(
+                2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
+                dtype=torch.float32,
+            )
+            num_remaining_heads = min(closest_power_of_2,
+                                      total_num_heads - closest_power_of_2)
+            extra_powers = torch.arange(start=1,
+                                        end=1 + 2 * num_remaining_heads,
+                                        step=2,
+                                        dtype=torch.int32)
+            slopes = torch.cat(
+                [slopes, torch.pow(extra_base, extra_powers)], dim=0)
+        return slopes
+
+    alibi_slopes = _get_alibi_slopes(num_heads).to(device)
+
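+    # Cache sizing: 640 blocks of 32 slots gives 20480 cache slots in total;
+    # each of the 10 requests owns up to 64 blocks (2048 tokens), enough for
+    # MAX_SEQ_LEN + MAX_CTX_LEN.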
+    MAX_SEQ_LEN = 1024
+    MAX_CTX_LEN = 1024
+    BS = 10
+    cache_size = 640
+    block_size = 32
+    max_block_per_request = 64
+    query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
+    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
+    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
+    num_kv_heads = num_heads // num_queries_per_kv
+
+    num_tokens = sum(query_lens)
+    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
+    query.uniform_(-1e-3, 1e-3)
+    output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
+
+    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
+    kv.uniform_(-1e-3, 1e-3)
+    key, value = kv.unbind(dim=1)
+
+    k_cache = torch.zeros(cache_size,
+                          block_size,
+                          num_kv_heads,
+                          head_size,
+                          dtype=dtype)
+    v_cache = torch.zeros(cache_size,
+                          block_size,
+                          num_kv_heads,
+                          head_size,
+                          dtype=dtype)
+    k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
+    v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
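+    # The block table maps each request's logical blocks to a random
+    # permutation of physical blocks in the cache.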
+    values = torch.arange(0, cache_size, dtype=torch.long)
+    values = values[torch.randperm(cache_size)]
+    block_table = values[:BS * max_block_per_request].view(
+        BS, max_block_per_request)
+    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
+    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
+                                            dtype=torch.long),
+                               dim=0)
+    max_input_len = MAX_SEQ_LEN
+    # copy kv to cache
+    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
+                                                dtype=torch.long),
+                                   dim=0)
+    for i in range(BS):
+        for j in range(query_lens[i]):
+            k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
+                                            j])
+            v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
+                                              b_ctx_len[i] + j])
+        cur_ctx = 0
+        block_id = 0
+        while cur_ctx < b_ctx_len[i]:
+            start_loc = b_seq_start_loc[i] + cur_ctx
+            if cur_ctx + block_size > b_ctx_len[i]:
+                end_loc = b_seq_start_loc[i] + b_ctx_len[i]
+            else:
+                end_loc = start_loc + block_size
+            start_slot = block_table[i, block_id] * block_size
+            end_slot = start_slot + end_loc - start_loc
+            k_cache.view(-1, num_kv_heads,
+                         head_size)[start_slot:end_slot].copy_(
+                             key[start_loc:end_loc])
+            v_cache.view(-1, num_kv_heads,
+                         head_size)[start_slot:end_slot].copy_(
+                             value[start_loc:end_loc])
+            cur_ctx += block_size
+            block_id += 1
+    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
+    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
+    k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
+                           8).permute(0, 2, 3, 1, 4).contiguous()
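+    # (The trailing group of 8 packs 16 bytes of fp16 values, matching the
+    # paged KV-cache layout the Triton kernel reads.)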
+    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
+    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
+    v_cache = v_cache.view(-1, block_size, num_kv_heads,
+                           head_size).permute(0, 2, 3, 1).contiguous()
+
+    # Warm up the Triton kernel by calling it once before actually measuring
+    # generation time
+    context_attention_fwd(query,
+                          k,
+                          v,
+                          output,
+                          k_cache,
+                          v_cache,
+                          block_table,
+                          b_start_loc,
+                          b_seq_len,
+                          b_ctx_len,
+                          max_input_len,
+                          alibi_slopes=alibi_slopes)
+    torch.cuda.synchronize()
+    start_time = time.time()
+    context_attention_fwd(query,
+                          k,
+                          v,
+                          output,
+                          k_cache,
+                          v_cache,
+                          block_table,
+                          b_start_loc,
+                          b_seq_len,
+                          b_ctx_len,
+                          max_input_len,
+                          alibi_slopes=alibi_slopes)
+    torch.cuda.synchronize()
+    end_time = time.time()
+    print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
+    scale = float(1.0 / (head_size**0.5))
+
+    # NOTE(DefTruth): In order to reuse _make_alibi_bias function,
+    # we have to pad query tensor before MQA/GQA expanding.
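+    # _make_alibi_bias builds the bias over the full seq_len, so each query
+    # is left-padded with zero rows up to seq_len here; the padded rows are
+    # discarded again after the reference attention below.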
+    if query.shape[0] != key.shape[0]:
+        query_pad = torch.empty(sum(seq_lens),
+                                num_heads,
+                                head_size,
+                                dtype=dtype)
+        query_pad.uniform_(-1e-3, 1e-3)
+        seq_start = 0
+        query_start = 0
+        for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
+            seq_end = seq_start + seq_len
+            query_end = query_start + query_len
+            query_pad[seq_start:seq_end, ...] = torch.cat([
+                torch.zeros(
+                    seq_len - query_len, num_heads, head_size, dtype=dtype),
+                query[query_start:query_end, ...]
+            ],
+                                                          dim=0)
+            seq_start += seq_len
+            query_start += query_len
+        query = query_pad
+
+    if num_kv_heads != num_heads:
+        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
+        # project the key and value tensors to the desired number of
+        # heads.
+        #
+        # see also: vllm/model_executor/layers/attention.py
+        query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
+                           query.shape[-1])
+        key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
+                                        num_queries_per_kv, key.shape[-1])
+        value = value[:, :,
+                      None, :].expand(value.shape[0], num_kv_heads,
+                                      num_queries_per_kv, value.shape[-1])
+
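+    # xops.memory_efficient_attention_forward expects a leading batch
+    # dimension, so a batch of size 1 is added before slicing per sequence.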
+    query = query.unsqueeze(0)
+    key = key.unsqueeze(0)
+    value = value.unsqueeze(0)
+
+    attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
+    output_ref = torch.empty_like(output)
+    seq_start = 0
+    query_start = 0
+    start_time = time.time()
+    # Attention with alibi slopes.
+    # FIXME(DefTruth): Because xformers does not support dynamic sequence
+    # lengths with custom attention bias, we process each prompt one by
+    # one. This is inefficient, especially when we have many short prompts.
+    # modified from: vllm/attention/backends/xformers.py#L343
+    for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
+        seq_end = seq_start + seq_len
+        query_end = query_start + query_len
+        out = xops.memory_efficient_attention_forward(query[:,
+                                                            seq_start:seq_end],
+                                                      key[:,
+                                                          seq_start:seq_end],
+                                                      value[:,
+                                                            seq_start:seq_end],
+                                                      attn_bias=attn_bias[i],
+                                                      p=0.0,
+                                                      scale=scale)
+        out = out.view_as(query[:, seq_start:seq_end]).view(
+            seq_len, num_heads, head_size)
+        output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
+                                                         ...])
+        seq_start += seq_len
+        query_start += query_len
+    torch.cuda.synchronize()
+    end_time = time.time()
+    print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
+    assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)