diff --git a/docs/source/quantization/fp8_e4m3_kvcache.rst b/docs/source/quantization/fp8_e4m3_kvcache.rst index fd71c00b7bf8..cc52d8f40af8 100644 --- a/docs/source/quantization/fp8_e4m3_kvcache.rst +++ b/docs/source/quantization/fp8_e4m3_kvcache.rst @@ -45,5 +45,3 @@ Here is an example of how to enable this feature: # output w/ scaling factors: England, the United Kingdom, and one of the world's leading financial, # output w/o scaling factors: England, located in the southeastern part of the country. It is known -Note, current prefix caching doesn't work with FP8 KV cache enabled, forward_prefix kernel should handle different KV and cache type. - diff --git a/docs/source/quantization/fp8_e5m2_kvcache.rst b/docs/source/quantization/fp8_e5m2_kvcache.rst index 337252a00aef..9ae07bcd3b99 100644 --- a/docs/source/quantization/fp8_e5m2_kvcache.rst +++ b/docs/source/quantization/fp8_e5m2_kvcache.rst @@ -32,5 +32,3 @@ Here is an example of how to enable this feature: print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") -Note, current prefix caching doesn't work with FP8 KV cache enabled, forward_prefix kernel should handle different KV and cache type. - diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 767e0628765b..fcc444842213 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -6,14 +6,27 @@ Run `pytest tests/models/test_chunked_prefill.py`. """ + import pytest -from ..models.utils import check_outputs_equal +from ..models.utils import check_logprobs_close, check_outputs_equal MODELS = [ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", ] +E5M2_KV_MODELS = [ + "facebook/opt-125m", + "meta-llama/Llama-2-7b-chat-hf", +] +E4M3_KV_MODELS = [ + "meta-llama/Llama-2-7b-chat-hf", "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V", + "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme" +] +KV_CACHE_QUANTIZATION_PATHS = { + "meta-llama/Llama-2-7b-chat-hf": + "./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json" +} @pytest.mark.parametrize("model", MODELS) @@ -35,12 +48,12 @@ def test_models( enforce_eager: bool, tensor_parallel_size: int, ) -> None: - max_num_seqs = min(chunked_prefill_token_size, 256) - enable_chunked_prefill = False - max_num_batched_tokens = None - if chunked_prefill_token_size != -1: - enable_chunked_prefill = True - max_num_batched_tokens = chunked_prefill_token_size + """ + Checks exact match decode between huggingface model and vllm runner with + chunked prefill. 
+ """ + max_num_seqs = chunked_prefill_token_size + max_num_batched_tokens = chunked_prefill_token_size with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) @@ -49,7 +62,7 @@ def test_models( model, dtype=dtype, max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=enable_chunked_prefill, + enable_chunked_prefill=True, tensor_parallel_size=tensor_parallel_size, enforce_eager=enforce_eager, max_num_seqs=max_num_seqs, @@ -62,3 +75,78 @@ def test_models( name_0="hf", name_1="vllm", ) + + +@pytest.mark.parametrize("kv_cache_dtype,model", + [("fp8_e5m2", m) + for m in E5M2_KV_MODELS] + [("fp8_e4m3", m) + for m in E4M3_KV_MODELS]) +# Due to low-precision numerical divergence, we only test logprob of 4 tokens +@pytest.mark.parametrize("max_tokens", [4]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +@pytest.mark.parametrize("enforce_eager", [False, True]) +# NOTE: Increasing this in this suite will fail CI because we currently cannot +# reset distributed env properly. Use a value > 1 just when you test. +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_models_with_fp8_kv_cache( + vllm_runner, + example_prompts, + kv_cache_dtype: str, + model: str, + max_tokens: int, + chunked_prefill_token_size: int, + enforce_eager: bool, + tensor_parallel_size: int, +) -> None: + """ + Only checks log probs match between chunked-prefill and + non-chunked-prefill version of vLLM model runner. + + This test is used when there is discrepancy in kernels + / numerics (e.g. when using lower-precision types like FP8). + """ + NUM_LOG_PROBS = 8 + + if model == "facebook/opt-125m": + pytest.skip( + "#7378: CUDA illegal memory access (undiagnosed) facebook/opt-125m" + ) + + max_num_seqs = chunked_prefill_token_size + max_num_batched_tokens = chunked_prefill_token_size + + extra_kwargs = {} + if model in KV_CACHE_QUANTIZATION_PATHS: + extra_kwargs["quantization_param_path"] = KV_CACHE_QUANTIZATION_PATHS[ + model] + + with vllm_runner( + model, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + max_num_seqs=max_num_seqs, + kv_cache_dtype=kv_cache_dtype, + **extra_kwargs, + ) as vllm_model: + no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, NUM_LOG_PROBS) + + with vllm_runner( + model, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=True, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + max_num_seqs=max_num_seqs, + kv_cache_dtype=kv_cache_dtype, + **extra_kwargs, + ) as vllm_model: + chunked_prefill_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, NUM_LOG_PROBS) + + check_logprobs_close( + outputs_0_lst=no_chunked_prefill_outputs, + outputs_1_lst=chunked_prefill_outputs, + name_0="no_chunked_prefill", + name_1="chunked_prefill", + ) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 99fda8364dc0..60f9a4dc9f90 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -9,6 +9,7 @@ from vllm.attention.backends.xformers import _make_alibi_bias from vllm.attention.ops.prefix_prefill import context_attention_fwd +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE NUM_HEADS = [64] NUM_QUERIES_PER_KV = [1, 8, 64] @@ -18,12 +19,14 @@ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048] +KV_CACHE_DTYPES = ["auto", 
"fp8", "fp8_e5m2"] @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("sliding_window", SLIDING_WINDOW) @torch.inference_mode() @@ -33,6 +36,7 @@ def test_contexted_kv_attention( head_size: int, sliding_window: int, dtype: torch.dtype, + kv_cache_dtype: str, device: str, ) -> None: random.seed(0) @@ -67,16 +71,20 @@ def test_contexted_kv_attention( kv.uniform_(-1e-3, 1e-3) key, value = kv.unbind(dim=1) + if kv_cache_dtype == "auto": + cache_dtype = dtype + else: + cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] k_cache = torch.zeros(cache_size, block_size, num_kv_heads, head_size, - dtype=dtype) + dtype=cache_dtype) v_cache = torch.zeros(cache_size, block_size, num_kv_heads, head_size, - dtype=dtype) + dtype=cache_dtype) k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) values = torch.arange(0, cache_size, dtype=torch.long) @@ -132,6 +140,7 @@ def test_contexted_kv_attention( k, v, output, + kv_cache_dtype, k_cache, v_cache, block_table, @@ -146,6 +155,7 @@ def test_contexted_kv_attention( k, v, output, + kv_cache_dtype, k_cache, v_cache, block_table, @@ -208,13 +218,15 @@ def test_contexted_kv_attention( end_time = time.time() print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") output_ref = output_ref.reshape(output.shape) - assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) + atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6 + torch.testing.assert_close(output, output_ref, atol=atol, rtol=0) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_contexted_kv_attention_alibi( @@ -222,6 +234,7 @@ def test_contexted_kv_attention_alibi( num_queries_per_kv: int, head_size: int, dtype: torch.dtype, + kv_cache_dtype: str, device: str, ) -> None: random.seed(0) @@ -282,17 +295,20 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) kv.uniform_(-1e-3, 1e-3) key, value = kv.unbind(dim=1) - + if kv_cache_dtype == "auto": + cache_dtype = dtype + else: + cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] k_cache = torch.zeros(cache_size, block_size, num_kv_heads, head_size, - dtype=dtype) + dtype=cache_dtype) v_cache = torch.zeros(cache_size, block_size, num_kv_heads, head_size, - dtype=dtype) + dtype=cache_dtype) k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) values = torch.arange(0, cache_size, dtype=torch.long) @@ -348,6 +364,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: k, v, output, + kv_cache_dtype, k_cache, v_cache, block_table, @@ -362,6 +379,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: k, v, output, + kv_cache_dtype, k_cache, v_cache, block_table, @@ -447,4 +465,5 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: torch.cuda.synchronize() end_time = time.time() 
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") - assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) + atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6 + torch.testing.assert_close(output, output_ref, atol=atol, rtol=0) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 26e9b8a93fb9..e305679231d0 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -459,6 +459,7 @@ def forward( query, key, value, + self.kv_cache_dtype, key_cache, value_cache, prefill_meta.block_tables, @@ -468,6 +469,8 @@ def forward( prefill_meta.max_query_len, self.alibi_slopes, self.sliding_window[0], + k_scale, + v_scale, ) if decode_meta := attn_metadata.decode_metadata: diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 24ba5fc72540..7e36509bff86 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -604,6 +604,7 @@ def forward( query, key, value, + self.kv_cache_dtype, key_cache, value_cache, prefill_meta.block_tables, @@ -613,6 +614,8 @@ def forward( prefill_meta.max_query_len, self.alibi_slopes, self.sliding_window, + k_scale, + v_scale, ) assert output[:num_prefill_tokens].shape == out.shape output[:num_prefill_tokens] = out diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index 81d308c4d4e2..6b270ffd5bc0 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -90,6 +90,7 @@ def forward_prefix( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + kv_cache_dtype: str, key_cache: torch.Tensor, value_cache: torch.Tensor, block_tables: torch.Tensor, diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index e88963ade16c..92023d5b75f5 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -194,6 +194,7 @@ def forward_prefix( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + kv_cache_dtype: str, key_cache: torch.Tensor, value_cache: torch.Tensor, block_tables: torch.Tensor, @@ -203,6 +204,8 @@ def forward_prefix( max_query_len: int, alibi_slopes: Optional[torch.Tensor], sliding_window: Optional[int], + k_scale: float, + v_scale: float, ) -> torch.Tensor: output = torch.empty_like(query) context_attention_fwd( @@ -210,6 +213,7 @@ def forward_prefix( key, value, output, + kv_cache_dtype, key_cache, value_cache, block_tables, @@ -218,6 +222,8 @@ def forward_prefix( seq_lens_tensor, context_lens, max_query_len, + k_scale, + v_scale, alibi_slopes, sliding_window, ) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 4577d84db18a..558b2f3eeac7 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -18,6 +18,8 @@ def _fwd_kernel( V_cache, B_Loc, sm_scale, + k_scale, + v_scale, B_Start_Loc, B_Seqlen, B_Ctxlen, @@ -117,10 +119,15 @@ def _fwd_kernel( cur_kv_head * stride_v_cache_h + offs_d[None, :] * stride_v_cache_d + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k = tl.load(K_cache + off_k, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] + k_load = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * k_scale).to(q.dtype) + else: + k = k_load qk = tl.zeros([BLOCK_M, BLOCK_N], 
dtype=tl.float32) # [M,N] qk += tl.dot(q, k) @@ -161,12 +168,16 @@ def _fwd_kernel( acc_scale = l_i / l_i_new * alpha acc = acc * acc_scale[:, None] # update acc - v = tl.load(V_cache + off_v, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_ctx_len), - other=0.0) # [N,D] - + v_load = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) # [N,D] + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * v_scale).to(q.dtype) + else: + v = v_load p = p.to(v.dtype) + acc += tl.dot(p, v) # # update m_i and l_i l_i = l_i_new @@ -225,8 +236,8 @@ def _fwd_kernel( mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) < cur_batch_query_len), other=0.0) - p = p.to(v.dtype) + acc += tl.dot(p, v) # update m_i and l_i l_i = l_i_new @@ -336,7 +347,6 @@ def _fwd_kernel_flash_attn_v2( k = tl.load(K_cache + off_k, mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, other=0.0) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, @@ -442,6 +452,8 @@ def _fwd_kernel_alibi( V_cache, B_Loc, sm_scale, + k_scale, + v_scale, B_Start_Loc, B_Seqlen, B_Ctxlen, @@ -537,10 +549,15 @@ def _fwd_kernel_alibi( cur_kv_head * stride_v_cache_h + offs_d[None, :] * stride_v_cache_d + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k = tl.load(K_cache + off_k, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] + k_load = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * k_scale).to(q.dtype) + else: + k = k_load qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -573,12 +590,16 @@ def _fwd_kernel_alibi( # acc_scale = l_i / l_i_new * alpha acc = acc * acc_scale[:, None] # update acc - v = tl.load(V_cache + off_v, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_ctx_len), - other=0.0) - + v_load = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * v_scale).to(q.dtype) + else: + v = v_load p = p.to(v.dtype) + acc += tl.dot(p, v, allow_tf32=False) # update m_i and l_i l_i = l_i_new @@ -650,8 +671,8 @@ def _fwd_kernel_alibi( ((start_n + offs_n[:, None]) < cur_batch_seq_len - cur_batch_ctx_len), other=0.0) - p = p.to(v.dtype) + acc += tl.dot(p, v, allow_tf32=False) # update m_i and l_i l_i = l_i_new @@ -675,6 +696,7 @@ def context_attention_fwd(q, k, v, o, + kv_cache_dtype: str, k_cache, v_cache, b_loc, @@ -682,17 +704,41 @@ def context_attention_fwd(q, b_seq_len, b_ctx_len, max_input_len, + k_scale: float = 1.0, + v_scale: float = 1.0, alibi_slopes=None, sliding_window=None): cap = current_platform.get_device_capability() BLOCK = 128 if cap[0] >= 8 else 64 + NUM_WARPS = 8 # need to reduce num. 
blocks when using fp32 # due to increased use of GPU shared memory if q.dtype is torch.float32: BLOCK = BLOCK // 2 + # Conversion of FP8 Tensor from uint8 storage to + # appropriate torch.dtype for interpretation by Triton + if "fp8" in kv_cache_dtype: + assert (k_cache.dtype == torch.uint8) + assert (v_cache.dtype == torch.uint8) + + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + target_dtype = torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + target_dtype = torch.float8_e5m2 + else: + raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) + + k_cache = k_cache.view(target_dtype) + v_cache = v_cache.view(target_dtype) + + if (k_cache.dtype == torch.uint8 + or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"): + raise ValueError("kv_cache_dtype='auto' unsupported for\ + FP8 KV Cache prefill kernel") + # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv @@ -709,7 +755,6 @@ def context_attention_fwd(q, if sliding_window is None or sliding_window <= 0: sliding_window = 0 - num_warps = 8 if Lk <= 64 else 8 if alibi_slopes is not None: _fwd_kernel_alibi[grid]( q, @@ -719,6 +764,8 @@ def context_attention_fwd(q, v_cache, b_loc, sm_scale, + k_scale, + v_scale, b_start_loc, b_seq_len, b_ctx_len, @@ -757,7 +804,7 @@ def context_attention_fwd(q, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, - num_warps=num_warps, + num_warps=NUM_WARPS, num_stages=1, ) return @@ -770,6 +817,8 @@ def context_attention_fwd(q, v_cache, b_loc, sm_scale, + k_scale, + v_scale, b_start_loc, b_seq_len, b_ctx_len, @@ -807,7 +856,7 @@ def context_attention_fwd(q, BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, SLIDING_WINDOW=sliding_window, - num_warps=num_warps, + num_warps=NUM_WARPS, num_stages=1, ) return diff --git a/vllm/config.py b/vllm/config.py index 4207466cfc5c..0ebe8110be55 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -546,10 +546,6 @@ def _verify_prefix_caching(self) -> None: raise NotImplementedError( "Prefix caching is not supported with sliding window. " "Run with --disable-sliding-window to use prefix caching.") - if self.cache_dtype == "fp8": - raise NotImplementedError( - "Prefix caching is not supported for fp8 cache_dtype. " - "Run with --kv-cache-dtype auto to use prefix caching.") def verify_with_parallel_config( self,
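
For reference, the FP8 path added to `_fwd_kernel` and `_fwd_kernel_alibi` dequantizes cached K/V blocks in-kernel by casting to float32 and multiplying by a per-tensor scale before the dot products. The sketch below reproduces that round trip in plain PyTorch; it is not the Triton kernel itself, and the scale derivation (`amax / fp8_max`) is an assumption about how per-tensor KV scales are typically produced, not something taken from this diff.

```python
import torch

def quantize_per_tensor_fp8(x: torch.Tensor,
                            fp8_dtype: torch.dtype = torch.float8_e4m3fn):
    """Quantize to FP8 with a single per-tensor scale (assumed: amax / fp8_max)."""
    fp8_max = torch.finfo(fp8_dtype).max
    scale = x.abs().amax().float() / fp8_max
    x_fp8 = (x.float() / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return x_fp8, scale

def dequantize_per_tensor_fp8(x_fp8: torch.Tensor, scale: torch.Tensor,
                              out_dtype: torch.dtype = torch.float16):
    # Mirrors the kernel's `(k_load.to(tl.float32) * k_scale).to(q.dtype)`.
    return (x_fp8.float() * scale).to(out_dtype)

# Same value range the kernel tests use for K/V (uniform in [-1e-3, 1e-3]).
k = torch.empty(16, 8, 128, dtype=torch.float16).uniform_(-1e-3, 1e-3)
k_fp8, k_scale = quantize_per_tensor_fp8(k)
k_deq = dequantize_per_tensor_fp8(k_fp8, k_scale)
print((k - k_deq).abs().max())  # small but nonzero, hence the looser 1e-3 atol
```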
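
`context_attention_fwd` reinterprets the uint8-backed KV cache as a true float8 dtype before launching the kernel, so that `k_load.dtype.is_fp8()` is true inside Triton. A minimal standalone illustration of that `Tensor.view(dtype)` reinterpretation (no vLLM required):

```python
import torch

# A cache allocated with uint8 storage, as vLLM does for fp8 KV caches.
cache_uint8 = torch.zeros(4, 16, 8, 128, dtype=torch.uint8)

# "fp8" / "fp8_e4m3" -> torch.float8_e4m3fn, "fp8_e5m2" -> torch.float8_e5m2
target_dtype = torch.float8_e4m3fn
cache_fp8 = cache_uint8.view(target_dtype)

assert cache_fp8.shape == cache_uint8.shape             # both 1 byte/element
assert cache_fp8.data_ptr() == cache_uint8.data_ptr()   # no copy, same storage
```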
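
For intuition about why `test_models_with_fp8_kv_cache` compares logprobs rather than exact decode output: with low-precision numerics, a chunked-prefill run and a non-chunked run can legitimately sample different tokens, as long as each choice was still a plausible candidate under the other run. The following is a simplified sketch of that acceptance criterion; it is not the actual `check_logprobs_close` helper from `tests/models/utils.py`.

```python
from typing import Dict, List, Tuple

# One sequence: (generated token ids, per-step top-k logprobs keyed by token id)
SequenceOutput = Tuple[List[int], List[Dict[int, float]]]

def logprobs_close(out_a: SequenceOutput, out_b: SequenceOutput) -> bool:
    tokens_a, logprobs_a = out_a
    tokens_b, logprobs_b = out_b
    for step, (tok_a, tok_b) in enumerate(zip(tokens_a, tokens_b)):
        if tok_a == tok_b:
            continue
        # A divergent step is acceptable only if each run's token was still a
        # top-k candidate under the other run.
        if tok_a not in logprobs_b[step] or tok_b not in logprobs_a[step]:
            return False
    return True
```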
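
With the `_verify_prefix_caching` check removed in `vllm/config.py` and the corresponding notes dropped from the FP8 KV cache docs, an FP8 KV cache and automatic prefix caching can be requested together. A hedged usage sketch, mirroring the docs examples edited above; the model name is just one of the models exercised in the new tests, and calibration scales via `quantization_param_path` are optional and omitted here:

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    kv_cache_dtype="fp8",
    enable_prefix_caching=True,  # previously rejected for fp8 cache dtypes
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=16)
outputs = llm.generate(["London is the capital of"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```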