|
| 1 | +from typing import Optional |
| 2 | +import argparse |
| 3 | +import random |
| 4 | +import time |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +import torch |
| 8 | + |
| 9 | +try: |
| 10 | + from flash_attn import flash_attn_func, flash_attn_with_kvcache |
| 11 | +except ImportError: |
| 12 | + flash_attn_func, flash_attn_with_kvcache = None, None |
| 13 | + |
| 14 | +from xformers import ops as xops |
| 15 | +from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask |
| 16 | + |
| 17 | +from vllm._C import cache_ops |
| 18 | +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random |
| 19 | + |
| 20 | +NUM_BLOCKS = 1024 |
| 21 | + |
| 22 | + |
| 23 | +@torch.inference_mode() |
| 24 | +def main( |
| 25 | + version: str, |
| 26 | + num_seqs: int, |
| 27 | + context_len: int, |
| 28 | + num_query_heads: int, |
| 29 | + num_kv_heads: int, |
| 30 | + head_size: int, |
| 31 | + use_alibi: bool, |
| 32 | + block_size: int, |
| 33 | + dtype: torch.dtype, |
| 34 | + seed: int, |
| 35 | + do_profile: bool, |
| 36 | + device: str = "cuda", |
| 37 | + kv_cache_dtype: Optional[str] = None, |
| 38 | +) -> None: |
| 39 | + random.seed(seed) |
| 40 | + torch.random.manual_seed(seed) |
| 41 | + if torch.cuda.is_available(): |
| 42 | + torch.cuda.manual_seed(seed) |
| 43 | + |
| 44 | + use_flash_attn = version in ["flash-attn", "flash-attn-kvcache"] |
| 45 | + if use_flash_attn: |
| 46 | + if dtype not in [torch.half, torch.bfloat16 |
| 47 | + ] or kv_cache_dtype != "auto": |
| 48 | + raise ValueError( |
| 49 | + "skip: flash-attn requires dtype and kv_cache_dtype to be half or bfloat16" |
| 50 | + ) |
| 51 | + |
| 52 | + context_lens = [context_len for _ in range(num_seqs)] |
| 53 | + max_context_len = max(context_lens) |
| 54 | + context_lens_tensor = torch.tensor(context_lens, |
| 55 | + dtype=torch.int, |
| 56 | + device=device) |
| 57 | + zero_context_lens_tensor = torch.zeros_like(context_lens_tensor) |
| 58 | + |
| 59 | + scale = float(1.0 / (head_size**0.5)) |
| 60 | + qkv = torch.empty(num_seqs, |
| 61 | + max_context_len, |
| 62 | + num_query_heads + 2 * num_kv_heads, |
| 63 | + head_size, |
| 64 | + dtype=dtype, |
| 65 | + device=device) |
| 66 | + qkv.uniform_(-scale, scale) |
| 67 | + query, key, value = qkv.split( |
| 68 | + [num_query_heads, num_kv_heads, num_kv_heads], dim=2) |
| 69 | + |
| 70 | + assert num_query_heads % num_kv_heads == 0 |
| 71 | + num_queries_per_kv = num_query_heads // num_kv_heads |
| 72 | + |
| 73 | + alibi_slopes = None |
| 74 | + if use_alibi: |
| 75 | + alibi_slopes = torch.randn(num_query_heads, |
| 76 | + dtype=torch.float, |
| 77 | + device=device) |
| 78 | + |
| 79 | + # Create the block tables. |
| 80 | + if use_flash_attn: |
| 81 | + block_size = ((block_size + 256 - 1) // 256) * 256 |
| 82 | + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size |
| 83 | + block_tables, slot_mapping = [], [] |
| 84 | + for seq_idx in range(num_seqs): |
| 85 | + block_table = [ |
| 86 | + random.randint(0, NUM_BLOCKS - 1) |
| 87 | + for _ in range(max_num_blocks_per_seq) |
| 88 | + ] |
| 89 | + block_tables.append(block_table) |
| 90 | + slot_mapping.append([]) |
| 91 | + for i in range(context_lens[seq_idx]): |
| 92 | + block_number = block_table[i // block_size] |
| 93 | + block_offset = i % block_size |
| 94 | + slot = block_number * block_size + block_offset |
| 95 | + slot_mapping[-1].append(slot) |
| 96 | + for _ in range(max_context_len - context_lens[seq_idx]): |
| 97 | + slot_mapping[-1].append(-1) |
| 98 | + block_tables = torch.tensor(block_tables, dtype=torch.int, device=device) |
| 99 | + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device) |
| 100 | + |
| 101 | + # Create the KV cache. |
| 102 | + key_caches, value_caches = create_kv_caches_with_random( |
| 103 | + NUM_BLOCKS, |
| 104 | + block_size, |
| 105 | + 1, |
| 106 | + num_kv_heads, |
| 107 | + head_size, |
| 108 | + kv_cache_dtype, |
| 109 | + dtype, |
| 110 | + device=device, |
| 111 | + use_flash_attn=use_flash_attn) |
| 112 | + key_cache, value_cache = key_caches[0], value_caches[0] |
| 113 | + |
| 114 | + if version == "xformers": |
| 115 | + attn_bias = BlockDiagonalCausalMask.from_seqlens(context_lens) |
| 116 | + if num_queries_per_kv > 1: |
| 117 | + # Handle MQA and GQA |
| 118 | + key_repeated = torch.repeat_interleave(key, |
| 119 | + num_queries_per_kv, |
| 120 | + dim=2) |
| 121 | + value_repeated = torch.repeat_interleave(value, |
| 122 | + num_queries_per_kv, |
| 123 | + dim=2) |
| 124 | + else: |
| 125 | + key_repeated = key |
| 126 | + value_repeated = value |
| 127 | + |
| 128 | + def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: |
| 129 | + torch.cuda.synchronize() |
| 130 | + if profile: |
| 131 | + torch.cuda.cudart().cudaProfilerStart() |
| 132 | + start_time = time.perf_counter() |
| 133 | + |
| 134 | + for _ in range(num_iters): |
| 135 | + if version == "xformers": |
| 136 | + cache_ops.reshape_and_cache( |
| 137 | + key.reshape(-1, *key.shape[2:]), |
| 138 | + value.reshape(-1, *key.shape[2:]), |
| 139 | + key_cache, |
| 140 | + value_cache, |
| 141 | + slot_mapping.flatten(), |
| 142 | + kv_cache_dtype, |
| 143 | + ) |
| 144 | + output = xops.memory_efficient_attention_forward( |
| 145 | + query.reshape(1, -1, *query.shape[2:]), |
| 146 | + key_repeated.reshape(1, -1, *key_repeated.shape[2:]), |
| 147 | + value_repeated.reshape(1, -1, *value_repeated.shape[2:]), |
| 148 | + attn_bias=attn_bias, |
| 149 | + p=0.0, |
| 150 | + scale=scale, |
| 151 | + ) |
| 152 | + output = output.reshape(query.shape) |
| 153 | + elif version == "flash-attn": |
| 154 | + flat_slot_mapping = slot_mapping.flatten() |
| 155 | + slot_block_index = flat_slot_mapping // block_size |
| 156 | + slot_block_offset = flat_slot_mapping % block_size |
| 157 | + key_cache[slot_block_index, |
| 158 | + slot_block_offset, :, :] = key.reshape( |
| 159 | + -1, *key.shape[2:]) |
| 160 | + value_cache[slot_block_index, |
| 161 | + slot_block_offset, :, :] = value.reshape( |
| 162 | + -1, *key.shape[2:]) |
| 163 | + output = flash_attn_func( |
| 164 | + q=query, |
| 165 | + k=key, |
| 166 | + v=value, |
| 167 | + softmax_scale=scale, |
| 168 | + causal=True, |
| 169 | + alibi_slopes=alibi_slopes, |
| 170 | + ) |
| 171 | + elif version == "flash-attn-kvcache": |
| 172 | + output = flash_attn_with_kvcache( |
| 173 | + q=query, |
| 174 | + k_cache=key_cache, |
| 175 | + v_cache=value_cache, |
| 176 | + k=key, |
| 177 | + v=value, |
| 178 | + cache_seqlens=zero_context_lens_tensor, |
| 179 | + block_table=block_tables, |
| 180 | + softmax_scale=scale, |
| 181 | + causal=True, |
| 182 | + alibi_slopes=alibi_slopes, |
| 183 | + ) |
| 184 | + else: |
| 185 | + raise ValueError(f"Invalid version: {version}") |
| 186 | + torch.cuda.synchronize() |
| 187 | + |
| 188 | + end_time = time.perf_counter() |
| 189 | + if profile: |
| 190 | + torch.cuda.cudart().cudaProfilerStart() |
| 191 | + return (end_time - start_time) / num_iters |
| 192 | + |
| 193 | + # Warmup. |
| 194 | + print("Warming up...") |
| 195 | + run_benchmark = run_cuda_benchmark |
| 196 | + run_benchmark(num_iters=3, profile=False) |
| 197 | + |
| 198 | + # Benchmark. |
| 199 | + if do_profile: |
| 200 | + latency = run_benchmark(num_iters=1, profile=True) |
| 201 | + else: |
| 202 | + latency = run_benchmark(num_iters=100, profile=False) |
| 203 | + print( |
| 204 | + f"Version: {version}, Context Length: {context_len}, Batch size: {num_seqs}, Kernel running time: {latency * 1000000:.3f} us" |
| 205 | + ) |
| 206 | + |
| 207 | + |
| 208 | +if __name__ == '__main__': |
| 209 | + parser = argparse.ArgumentParser( |
| 210 | + description="Benchmark the paged attention kernel.") |
| 211 | + parser.add_argument( |
| 212 | + "--version", |
| 213 | + type=str, |
| 214 | + choices=["xformers", "flash-attn", "flash-attn-kvcache"], |
| 215 | + default="xformers") |
| 216 | + parser.add_argument("--batch-size", type=int, default=8) |
| 217 | + parser.add_argument("--context-len", type=int, default=4096) |
| 218 | + parser.add_argument("--num-query-heads", type=int, default=64) |
| 219 | + parser.add_argument("--num-kv-heads", type=int, default=8) |
| 220 | + parser.add_argument("--head-size", |
| 221 | + type=int, |
| 222 | + choices=[64, 80, 96, 112, 128, 256], |
| 223 | + default=128) |
| 224 | + parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) |
| 225 | + parser.add_argument("--use-alibi", action="store_true") |
| 226 | + parser.add_argument("--dtype", |
| 227 | + type=str, |
| 228 | + choices=["half", "bfloat16", "float"], |
| 229 | + default="half") |
| 230 | + parser.add_argument("--seed", type=int, default=0) |
| 231 | + parser.add_argument("--profile", action="store_true") |
| 232 | + parser.add_argument( |
| 233 | + "--kv-cache-dtype", |
| 234 | + type=str, |
| 235 | + choices=["auto", "fp8_e5m2"], |
| 236 | + default="auto", |
| 237 | + help= |
| 238 | + 'Data type for kv cache storage. If "auto", will use model data type.') |
| 239 | + parser.add_argument("--device", type=str, choices=["cuda"], default="cuda") |
| 240 | + args = parser.parse_args() |
| 241 | + print(args) |
| 242 | + |
| 243 | + if args.num_query_heads % args.num_kv_heads != 0: |
| 244 | + raise ValueError("num_query_heads must be divisible by num_kv_heads") |
| 245 | + main( |
| 246 | + version=args.version, |
| 247 | + num_seqs=args.batch_size, |
| 248 | + context_len=args.context_len, |
| 249 | + num_query_heads=args.num_query_heads, |
| 250 | + num_kv_heads=args.num_kv_heads, |
| 251 | + head_size=args.head_size, |
| 252 | + block_size=args.block_size, |
| 253 | + use_alibi=args.use_alibi, |
| 254 | + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], |
| 255 | + seed=args.seed, |
| 256 | + do_profile=args.profile, |
| 257 | + kv_cache_dtype=args.kv_cache_dtype, |
| 258 | + ) |
0 commit comments