Merged
Changes from all commits
62 commits
dfb3a2b
Add FP8 K/V Scale and dtype conversion for prefix/prefill
jon-chuang Aug 6, 2024
93bb6e0
remove outdated
jon-chuang Aug 6, 2024
f758ed7
undo incorrect
jon-chuang Aug 6, 2024
c15c038
Empty commit
jon-chuang Aug 6, 2024
cc1be05
Empty commit
jon-chuang Aug 6, 2024
47bf473
fix dtype conversion only when mixed-precision dot not supported
jon-chuang Aug 6, 2024
058df26
improve
jon-chuang Aug 6, 2024
ff5cf50
do torch tensor type conversion for triton consumption
jon-chuang Aug 6, 2024
ea13b24
minor
jon-chuang Aug 6, 2024
5c2a54f
add comments
jon-chuang Aug 6, 2024
2cb96c6
add missing args
jon-chuang Aug 6, 2024
44237b9
Empty commit
jon-chuang Aug 6, 2024
f1a855d
use correct types
jon-chuang Aug 6, 2024
c58450b
fmt
jon-chuang Aug 6, 2024
d9742e1
fix
jon-chuang Aug 7, 2024
87d5588
Merge branch 'main' of https://github.com/vllm-project/vllm into jon-…
jon-chuang Aug 7, 2024
a824564
triton doesn't like ternary if else
jon-chuang Aug 7, 2024
8d32e38
try
jon-chuang Aug 7, 2024
59572c9
convert dtype
jon-chuang Aug 7, 2024
39a1ffa
debug
jon-chuang Aug 7, 2024
452e74f
debug
jon-chuang Aug 7, 2024
105f3d7
improve tests
jon-chuang Aug 7, 2024
48c6f02
skip correctly
jon-chuang Aug 7, 2024
83a0206
skip hf if is fp8 model
jon-chuang Aug 7, 2024
a38da14
moar
jon-chuang Aug 7, 2024
02e0190
done
jon-chuang Aug 7, 2024
fcb4a82
use logprobs for test
jon-chuang Aug 7, 2024
351dc55
Num
jon-chuang Aug 7, 2024
7ddc42b
format
jon-chuang Aug 7, 2024
04870e5
improve tests
jon-chuang Aug 7, 2024
40b2c6c
Merge branch 'main' of https://github.com/vllm-project/vllm into jon-…
jon-chuang Aug 7, 2024
28557a2
use cache dtype
jon-chuang Aug 8, 2024
b88157f
minor
jon-chuang Aug 8, 2024
59d194d
use less strict for fp8 kv cache
jon-chuang Aug 8, 2024
16e144a
minor fix
jon-chuang Aug 8, 2024
047098f
slightly stricter
jon-chuang Aug 8, 2024
584579a
slightly less strict
jon-chuang Aug 8, 2024
371e56b
use 1e-3 safe
jon-chuang Aug 8, 2024
b2833e5
Simplify tests/kernels/test_prefix_prefill.py
jon-chuang Aug 9, 2024
fd18506
try
jon-chuang Aug 9, 2024
e9c5736
Merge branch 'jon-chuang/fix-fp8-triton-kernel' of https://github.com…
jon-chuang Aug 9, 2024
4874589
commit
jon-chuang Aug 9, 2024
5a30c42
Merge branch 'main' of https://github.com/vllm-project/vllm into jon-…
jon-chuang Aug 9, 2024
f80fc34
skip
jon-chuang Aug 9, 2024
ea45417
fix
jon-chuang Aug 9, 2024
f188624
improve
jon-chuang Aug 9, 2024
e3b7881
improve
jon-chuang Aug 9, 2024
de6c8ca
fmt
jon-chuang Aug 9, 2024
ac39c71
minor
jon-chuang Aug 12, 2024
dcda807
Merge branch 'main' of https://github.com/vllm-project/vllm into jon-…
jon-chuang Aug 12, 2024
cb77f1d
improve
jon-chuang Aug 12, 2024
e2998a9
minor
jon-chuang Aug 12, 2024
c09aaf7
split on plus
jon-chuang Aug 12, 2024
eb92719
limit len
jon-chuang Aug 12, 2024
8f106c3
try
jon-chuang Aug 12, 2024
df80fe5
improve
jon-chuang Aug 12, 2024
462a73c
fix
jon-chuang Aug 12, 2024
ab8a3c3
try again
jon-chuang Aug 12, 2024
35373aa
try again
jon-chuang Aug 12, 2024
9ff0075
use
jon-chuang Aug 12, 2024
ca133a4
apply code review
jon-chuang Aug 12, 2024
39e54cb
parse
jon-chuang Aug 12, 2024
2 changes: 0 additions & 2 deletions docs/source/quantization/fp8_e4m3_kvcache.rst
@@ -45,5 +45,3 @@ Here is an example of how to enable this feature:
# output w/ scaling factors: England, the United Kingdom, and one of the world's leading financial,
# output w/o scaling factors: England, located in the southeastern part of the country. It is known

Note, current prefix caching doesn't work with FP8 KV cache enabled, forward_prefix kernel should handle different KV and cache type.

2 changes: 0 additions & 2 deletions docs/source/quantization/fp8_e5m2_kvcache.rst
@@ -32,5 +32,3 @@ Here is an example of how to enable this feature:
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


Note, current prefix caching doesn't work with FP8 KV cache enabled, forward_prefix kernel should handle different KV and cache type.
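
The same note is removed from both FP8 KV cache pages: with this change, the forward_prefix path handles a KV cache dtype that differs from the model dtype, so prefix caching / chunked prefill can now be combined with an FP8 KV cache. A minimal usage sketch follows; the model name, flag values, and scale-file path are illustrative assumptions, not taken from this PR.

from vllm import LLM, SamplingParams

# Illustrative sketch only: with this PR an FP8 KV cache can be used together
# with prefix caching / chunked prefill.
llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    kv_cache_dtype="fp8_e4m3",
    # Optional per-tensor K/V scaling factors exported offline (path is illustrative).
    quantization_param_path="./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json",
    enable_prefix_caching=True,
)
params = SamplingParams(temperature=0.0, max_tokens=32)
out = llm.generate(["London is the capital of"], params)
print(out[0].outputs[0].text)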

104 changes: 96 additions & 8 deletions tests/basic_correctness/test_chunked_prefill.py
@@ -6,14 +6,27 @@

Run `pytest tests/models/test_chunked_prefill.py`.
"""

import pytest

from ..models.utils import check_outputs_equal
from ..models.utils import check_logprobs_close, check_outputs_equal

MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
]
E5M2_KV_MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-chat-hf",
]
E4M3_KV_MODELS = [
"meta-llama/Llama-2-7b-chat-hf", "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
"nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
]
KV_CACHE_QUANTIZATION_PATHS = {
"meta-llama/Llama-2-7b-chat-hf":
"./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json"
}


@pytest.mark.parametrize("model", MODELS)
@@ -35,12 +48,12 @@ def test_models(
enforce_eager: bool,
tensor_parallel_size: int,
) -> None:
max_num_seqs = min(chunked_prefill_token_size, 256)
enable_chunked_prefill = False
max_num_batched_tokens = None
if chunked_prefill_token_size != -1:
enable_chunked_prefill = True
max_num_batched_tokens = chunked_prefill_token_size
"""
Check that greedy decoding with chunked prefill in the vLLM runner
exactly matches the HuggingFace model's outputs.
"""
max_num_seqs = chunked_prefill_token_size
max_num_batched_tokens = chunked_prefill_token_size

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
@@ -49,7 +62,7 @@
model,
dtype=dtype,
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill,
enable_chunked_prefill=True,
tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
@@ -62,3 +75,78 @@
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("kv_cache_dtype,model",
[("fp8_e5m2", m)
for m in E5M2_KV_MODELS] + [("fp8_e4m3", m)
for m in E4M3_KV_MODELS])
# Due to low-precision numerical divergence, we only check the logprobs of 4 generated tokens
@pytest.mark.parametrize("max_tokens", [4])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
@pytest.mark.parametrize("enforce_eager", [False, True])
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset the distributed env properly. Use a value > 1 only when testing locally.
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_models_with_fp8_kv_cache(
vllm_runner,
example_prompts,
kv_cache_dtype: str,
model: str,
max_tokens: int,
chunked_prefill_token_size: int,
enforce_eager: bool,
tensor_parallel_size: int,
) -> None:
"""
Check that the logprobs of the chunked-prefill and
non-chunked-prefill versions of the vLLM model runner match.

This test is used when there is a discrepancy in kernels
/ numerics (e.g. when using lower-precision types like FP8).
"""
NUM_LOG_PROBS = 8

if model == "facebook/opt-125m":
pytest.skip(
"#7378: CUDA illegal memory access (undiagnosed) facebook/opt-125m"
)

max_num_seqs = chunked_prefill_token_size
max_num_batched_tokens = chunked_prefill_token_size

extra_kwargs = {}
if model in KV_CACHE_QUANTIZATION_PATHS:
extra_kwargs["quantization_param_path"] = KV_CACHE_QUANTIZATION_PATHS[
model]

with vllm_runner(
model,
tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
kv_cache_dtype=kv_cache_dtype,
**extra_kwargs,
) as vllm_model:
no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS)

with vllm_runner(
model,
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=True,
tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
kv_cache_dtype=kv_cache_dtype,
**extra_kwargs,
) as vllm_model:
chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS)

check_logprobs_close(
outputs_0_lst=no_chunked_prefill_outputs,
outputs_1_lst=chunked_prefill_outputs,
name_0="no_chunked_prefill",
name_1="chunked_prefill",
)
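
The comparison above does not require identical generated text; it only requires that wherever the two runs pick different tokens, each token still appears in the other run's top-k logprobs. A rough sketch of that check follows — an approximation of tests/models/utils.check_logprobs_close, not a copy of it.

from typing import Dict, List, Sequence, Tuple

# Assumed output shape: (token_ids, generated_text, per-position {token_id: logprob}).
Output = Tuple[List[int], str, List[Dict[int, float]]]

def logprobs_close(outputs_0: Sequence[Output], outputs_1: Sequence[Output]) -> None:
    for (ids_0, _, lps_0), (ids_1, _, lps_1) in zip(outputs_0, outputs_1):
        for pos, (tok_0, tok_1) in enumerate(zip(ids_0, ids_1)):
            if tok_0 == tok_1:
                continue  # same greedy choice; nothing to check at this position
            # Different choices are tolerated only if each token is still a
            # top-k candidate of the other run.
            assert tok_0 in lps_1[pos] and tok_1 in lps_0[pos], (
                f"position {pos}: runs diverged beyond top-k logprobs")
            break  # later positions are conditioned on different prefixes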
33 changes: 26 additions & 7 deletions tests/kernels/test_prefix_prefill.py
@@ -9,6 +9,7 @@

from vllm.attention.backends.xformers import _make_alibi_bias
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64]
@@ -18,12 +19,14 @@
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@torch.inference_mode()
@@ -33,6 +36,7 @@ def test_contexted_kv_attention(
head_size: int,
sliding_window: int,
dtype: torch.dtype,
kv_cache_dtype: str,
device: str,
) -> None:
random.seed(0)
@@ -67,16 +71,20 @@
kv.uniform_(-1e-3, 1e-3)
key, value = kv.unbind(dim=1)

if kv_cache_dtype == "auto":
cache_dtype = dtype
else:
cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
k_cache = torch.zeros(cache_size,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
dtype=cache_dtype)
v_cache = torch.zeros(cache_size,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
dtype=cache_dtype)
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.long)
@@ -132,6 +140,7 @@ def test_contexted_kv_attention(
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
@@ -146,6 +155,7 @@ def test_contexted_kv_attention(
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
@@ -208,20 +218,23 @@ def test_contexted_kv_attention(
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
output_ref = output_ref.reshape(output.shape)
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_contexted_kv_attention_alibi(
num_heads: int,
num_queries_per_kv: int,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str,
device: str,
) -> None:
random.seed(0)
@@ -282,17 +295,20 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
kv.uniform_(-1e-3, 1e-3)
key, value = kv.unbind(dim=1)

if kv_cache_dtype == "auto":
cache_dtype = dtype
else:
cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
k_cache = torch.zeros(cache_size,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
dtype=cache_dtype)
v_cache = torch.zeros(cache_size,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
dtype=cache_dtype)
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.long)
@@ -348,6 +364,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
@@ -362,6 +379,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
@@ -447,4 +465,5 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
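
The looser atol for FP8 caches reflects quantization error: storing K/V in 8 bits and reading them back loses far more precision than float16 rounding, so the tight 1e-6 tolerance used for the "auto" dtype cannot hold. A standalone sketch of that error for inputs in the same 1e-3 range as this test — an illustration assuming PyTorch's float8_e4m3fn dtype, not part of the PR:

import torch

# Values in the same range the test uses for K/V (uniform in [-1e-3, 1e-3]).
x = torch.empty(1024).uniform_(-1e-3, 1e-3).to(torch.float16)
# Round-trip through an FP8 E4M3 "cache" and back to the compute dtype.
roundtrip = x.to(torch.float8_e4m3fn).to(torch.float16)
err = (x - roundtrip).abs().max().item()
# The error is close to the 1e-3 scale of the inputs themselves, which is why
# the FP8 comparisons above use atol=1e-3 instead of 1e-6.
print(f"max fp8_e4m3 round-trip error: {err:.2e}")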
3 changes: 3 additions & 0 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -459,6 +459,7 @@ def forward(
query,
key,
value,
self.kv_cache_dtype,
key_cache,
value_cache,
prefill_meta.block_tables,
@@ -468,6 +469,8 @@
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
k_scale,
v_scale,
)

if decode_meta := attn_metadata.decode_metadata:
3 changes: 3 additions & 0 deletions vllm/attention/backends/xformers.py
@@ -604,6 +604,7 @@ def forward(
query,
key,
value,
self.kv_cache_dtype,
key_cache,
value_cache,
prefill_meta.block_tables,
@@ -613,6 +614,8 @@
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window,
k_scale,
v_scale,
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
1 change: 1 addition & 0 deletions vllm/attention/ops/ipex_attn.py
@@ -90,6 +90,7 @@ def forward_prefix(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache_dtype: str,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
6 changes: 6 additions & 0 deletions vllm/attention/ops/paged_attn.py
@@ -194,6 +194,7 @@ def forward_prefix(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache_dtype: str,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
@@ -203,13 +204,16 @@
max_query_len: int,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int],
k_scale: float,
v_scale: float,
) -> torch.Tensor:
output = torch.empty_like(query)
context_attention_fwd(
query,
key,
value,
output,
kv_cache_dtype,
key_cache,
value_cache,
block_tables,
@@ -218,6 +222,8 @@
seq_lens_tensor,
context_lens,
max_query_len,
k_scale,
v_scale,
alibi_slopes,
sliding_window,
)
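
For orientation, the new k_scale / v_scale arguments follow the usual FP8 KV-cache convention in vLLM: the cache stores value / scale, so reading an FP8 block back means casting to the compute dtype and multiplying by the scale. A reference sketch of that dequantization — an illustration of the convention, not the Triton kernel's actual code:

import torch

def dequantize_cache_block(cache_block: torch.Tensor, scale: float,
                           kv_cache_dtype: str,
                           compute_dtype: torch.dtype) -> torch.Tensor:
    # Reference sketch only. For FP8 caches the stored values were divided by
    # the scale when written, so recover them by casting back to the compute
    # dtype and multiplying by k_scale / v_scale. For "auto" the cache already
    # matches the compute dtype and the scale is effectively 1.0.
    if kv_cache_dtype.startswith("fp8"):
        return cache_block.to(compute_dtype) * scale
    return cache_block.to(compute_dtype)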