
Commit 5cbeb04

Introduce flash-attn (>= 2.5.0) and add test cases.
Signed-off-by: Tao He <[email protected]>
1 parent: 93348d9

File tree

14 files changed (+612, -95 lines)


benchmarks/benchmark_latency.py

Lines changed: 5 additions & 0 deletions
@@ -27,6 +27,7 @@ def main(args: argparse.Namespace):
         kv_cache_dtype=args.kv_cache_dtype,
         device=args.device,
         ray_workers_use_nsight=args.ray_workers_use_nsight,
+        use_flash_attn=args.use_flash_attn,
     )
 
     sampling_params = SamplingParams(
@@ -151,5 +152,9 @@ def run_to_completion(profile_dir: Optional[str] = None):
         action='store_true',
         help="If specified, use nsight to profile ray workers",
     )
+    parser.add_argument(
+        "--use-flash-attn",
+        action="store_true",
+        help="Use flash attention (requires flash-attn >= 2.5.0).")
     args = parser.parse_args()
     main(args)
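
The new option is a plain store_true flag, so existing invocations of the benchmark keep their previous behavior. The snippet below is only an illustration of that default, not part of the diff:

import argparse

# Illustrative only: the flag defaults to False, so runs without
# --use-flash-attn behave exactly as before this commit.
parser = argparse.ArgumentParser()
parser.add_argument("--use-flash-attn",
                    action="store_true",
                    help="Use flash attention (requires flash-attn >= 2.5.0).")
print(parser.parse_args([]).use_flash_attn)                    # False
print(parser.parse_args(["--use-flash-attn"]).use_flash_attn)  # True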

benchmarks/benchmark_throughput.py

Lines changed: 9 additions & 2 deletions
@@ -75,6 +75,7 @@ def run_vllm(
     device: str,
     enable_prefix_caching: bool,
     gpu_memory_utilization: float = 0.9,
+    use_flash_attn: Optional[bool] = False,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(model=model,
@@ -89,7 +90,8 @@ def run_vllm(
               enforce_eager=enforce_eager,
               kv_cache_dtype=kv_cache_dtype,
               device=device,
-              enable_prefix_caching=enable_prefix_caching)
+              enable_prefix_caching=enable_prefix_caching,
+              use_flash_attn=use_flash_attn)
 
     # Add the requests to the engine.
     for prompt, _, output_len in requests:
@@ -213,7 +215,8 @@ def main(args: argparse.Namespace):
             args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
             args.trust_remote_code, args.dtype, args.max_model_len,
             args.enforce_eager, args.kv_cache_dtype, args.device,
-            args.enable_prefix_caching, args.gpu_memory_utilization)
+            args.enable_prefix_caching, args.gpu_memory_utilization,
+            args.use_flash_attn)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -314,6 +317,10 @@ def main(args: argparse.Namespace):
         "--enable-prefix-caching",
         action='store_true',
         help="enable automatic prefix caching for vLLM backend.")
+    parser.add_argument(
+        "--use-flash-attn",
+        action="store_true",
+        help="Use flash attention (requires flash-attn >= 2.5.0).")
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
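
For reference, the flag reaches the engine through the LLM constructor that both benchmarks build, so a plain Python caller can enable it the same way. A minimal sketch, assuming flash-attn >= 2.5.0 is installed on this branch and a half/bfloat16-capable GPU; the model name and sampling settings are placeholders, not part of this commit:

from vllm import LLM, SamplingParams

# Placeholder model; any vLLM-supported model in half/bfloat16 should work
# once flash-attn >= 2.5.0 is installed.
llm = LLM(model="facebook/opt-125m",
          dtype="half",               # flash-attn requires half or bfloat16
          use_flash_attn=True)        # new keyword added by this commit
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=32))
print(outputs[0].outputs[0].text)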
Lines changed: 258 additions & 0 deletions
@@ -0,0 +1,258 @@
from typing import Optional
import argparse
import random
import time

import numpy as np
import torch

try:
    from flash_attn import flash_attn_func, flash_attn_with_kvcache
except ImportError:
    flash_attn_func, flash_attn_with_kvcache = None, None

from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

from vllm._C import cache_ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random

NUM_BLOCKS = 1024


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    context_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
    device: str = "cuda",
    kv_cache_dtype: Optional[str] = None,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    use_flash_attn = version in ["flash-attn", "flash-attn-kvcache"]
    if use_flash_attn:
        if dtype not in [torch.half, torch.bfloat16
                         ] or kv_cache_dtype != "auto":
            raise ValueError(
                "skip: flash-attn requires dtype to be half or bfloat16 "
                'and kv_cache_dtype to be "auto"')

    context_lens = [context_len for _ in range(num_seqs)]
    max_context_len = max(context_lens)
    context_lens_tensor = torch.tensor(context_lens,
                                       dtype=torch.int,
                                       device=device)
    zero_context_lens_tensor = torch.zeros_like(context_lens_tensor)

    scale = float(1.0 / (head_size**0.5))
    qkv = torch.empty(num_seqs,
                      max_context_len,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype,
                      device=device)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=2)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads

    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device=device)

    # Create the block tables.
    if use_flash_attn:
        # flash-attn's paged KV cache expects the block size to be a
        # multiple of 256, so round it up.
        block_size = ((block_size + 256 - 1) // 256) * 256
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables, slot_mapping = [], []
    for seq_idx in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
        slot_mapping.append([])
        for i in range(context_lens[seq_idx]):
            block_number = block_table[i // block_size]
            block_offset = i % block_size
            slot = block_number * block_size + block_offset
            slot_mapping[-1].append(slot)
        for _ in range(max_context_len - context_lens[seq_idx]):
            slot_mapping[-1].append(-1)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device)

    # Create the KV cache.
    key_caches, value_caches = create_kv_caches_with_random(
        NUM_BLOCKS,
        block_size,
        1,
        num_kv_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
        use_flash_attn=use_flash_attn)
    key_cache, value_cache = key_caches[0], value_caches[0]

    if version == "xformers":
        attn_bias = BlockDiagonalCausalMask.from_seqlens(context_lens)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            key_repeated = torch.repeat_interleave(key,
                                                   num_queries_per_kv,
                                                   dim=2)
            value_repeated = torch.repeat_interleave(value,
                                                     num_queries_per_kv,
                                                     dim=2)
        else:
            key_repeated = key
            value_repeated = value

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            if version == "xformers":
                cache_ops.reshape_and_cache(
                    key.reshape(-1, *key.shape[2:]),
                    value.reshape(-1, *value.shape[2:]),
                    key_cache,
                    value_cache,
                    slot_mapping.flatten(),
                    kv_cache_dtype,
                )
                output = xops.memory_efficient_attention_forward(
                    query.reshape(1, -1, *query.shape[2:]),
                    key_repeated.reshape(1, -1, *key_repeated.shape[2:]),
                    value_repeated.reshape(1, -1, *value_repeated.shape[2:]),
                    attn_bias=attn_bias,
                    p=0.0,
                    scale=scale,
                )
                output = output.reshape(query.shape)
            elif version == "flash-attn":
                # Scatter the new keys/values into the paged cache, then run
                # the variable-length attention kernel directly on q/k/v.
                flat_slot_mapping = slot_mapping.flatten()
                slot_block_index = flat_slot_mapping // block_size
                slot_block_offset = flat_slot_mapping % block_size
                key_cache[slot_block_index,
                          slot_block_offset, :, :] = key.reshape(
                              -1, *key.shape[2:])
                value_cache[slot_block_index,
                            slot_block_offset, :, :] = value.reshape(
                                -1, *value.shape[2:])
                output = flash_attn_func(
                    q=query,
                    k=key,
                    v=value,
                    softmax_scale=scale,
                    causal=True,
                    alibi_slopes=alibi_slopes,
                )
            elif version == "flash-attn-kvcache":
                output = flash_attn_with_kvcache(
                    q=query,
                    k_cache=key_cache,
                    v_cache=value_cache,
                    k=key,
                    v=value,
                    cache_seqlens=zero_context_lens_tensor,
                    block_table=block_tables,
                    softmax_scale=scale,
                    causal=True,
                    alibi_slopes=alibi_slopes,
                )
            else:
                raise ValueError(f"Invalid version: {version}")
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=100, profile=False)
    print(
        f"Version: {version}, Context Length: {context_len}, "
        f"Batch size: {num_seqs}, "
        f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Benchmark the paged attention kernel.")
    parser.add_argument(
        "--version",
        type=str,
        choices=["xformers", "flash-attn", "flash-attn-kvcache"],
        default="xformers")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--context-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        default=128)
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8_e5m2"],
        default="auto",
        help=
        'Data type for kv cache storage. If "auto", will use model data type.')
    parser.add_argument("--device", type=str, choices=["cuda"], default="cuda")
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    main(
        version=args.version,
        num_seqs=args.batch_size,
        context_len=args.context_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        kv_cache_dtype=args.kv_cache_dtype,
    )
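
The flash-attn branches of the benchmark fill the paged KV cache by turning each token's flat slot index back into a (block, offset) pair. The short sketch below walks through that arithmetic on the CPU with made-up numbers; it is an illustration of the mapping, not part of the commit:

import torch

# One sequence whose tokens live in three hypothetical physical blocks.
block_size = 256            # flash-attn paths round the block size up to 256
block_table = [3, 7, 1]     # logical block i -> physical block block_table[i]
context_len = 520

# Flat slot index for every token, built the same way as slot_mapping above.
slots = torch.tensor([
    block_table[i // block_size] * block_size + i % block_size
    for i in range(context_len)
])

# Recover (block, offset) from the flat slots, as the "flash-attn" branch does
# before scattering keys/values into key_cache / value_cache.
block_index = slots // block_size
block_offset = slots % block_size
assert block_index[0] == 3 and block_index[256] == 7 and block_index[512] == 1
assert block_offset[511] == 255 and block_offset[512] == 0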
