3 changes: 2 additions & 1 deletion .gitignore
@@ -2,7 +2,8 @@
 /vllm/_version.py
 
 # vllm-flash-attn built from source
-vllm/vllm_flash_attn/
+vllm/vllm_flash_attn/*
+!vllm/vllm_flash_attn/fa_utils.py
 
 # Byte-compiled / optimized / DLL files
 __pycache__/

16 changes: 12 additions & 4 deletions vllm/attention/backends/flash_attn.py
@@ -22,12 +22,13 @@
     compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
     get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
     is_all_encoder_attn_metadata_set, is_block_tables_empty)
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                   flash_attn_with_kvcache)
+from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
+                                           get_flash_attn_version)
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@@ -632,10 +633,13 @@ def __init__(
         self.kv_cache_dtype = kv_cache_dtype
         self.vllm_flash_attn_version = get_flash_attn_version(
             requires_alibi=self.alibi_slopes is not None)
-        if (is_quantized_kv_cache(self.kv_cache_dtype)
-                and self.vllm_flash_attn_version != 3):
+        if is_quantized_kv_cache(self.kv_cache_dtype) and (
+                not self.kv_cache_dtype.startswith("fp8")
+                or not flash_attn_supports_fp8()):
             raise NotImplementedError(
-                "Only FlashAttention3 supports FP8 KV cache")
+                f"FlashAttention does not support {self.kv_cache_dtype} "
+                "kv-cache on this device "
+                f"(FA supports fp8 = {flash_attn_supports_fp8()}).")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
@@ -704,6 +708,10 @@ def forward(
         logits_soft_cap: Optional[float] = self.logits_soft_cap
         fp8_attention = kv_cache_dtype.startswith("fp8")
 
+        if fp8_attention and not flash_attn_supports_fp8():
+            raise NotImplementedError(
+                "FlashAttention does not support FP8 kv-cache on this device.")
+
         if kv_cache.numel() > 0:
             key_cache = kv_cache[0]
             value_cache = kv_cache[1]

2 changes: 1 addition & 1 deletion vllm/attention/backends/mla/common.py
@@ -205,7 +205,6 @@
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
-from vllm.fa_utils import get_flash_attn_version
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
@@ -214,6 +213,7 @@
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
+from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func

4 changes: 0 additions & 4 deletions vllm/config.py
@@ -1150,10 +1150,6 @@ def _verify_cache_dtype(self) -> None:
         if self.cache_dtype == "auto":
             pass
         elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
-            if envs.VLLM_USE_V1:
-                raise NotImplementedError(
-                    "V1 does not yet support fp8 KV cache. "
-                    "Set VLLM_USE_V1=0 to enable fp8 kv cache.")
             logger.info(
                 "Using fp8 data type to store kv cache. It reduces the GPU "
                 "memory footprint and boosts the performance. "

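With the V1-only rejection above removed, an fp8 KV cache can be requested through the regular engine arguments. A minimal hedged sketch of offline usage (the model name is illustrative, and it assumes a device/FlashAttention build for which flash_attn_supports_fp8() returns True, see the helper at the end of this diff):

```python
# Minimal sketch: request an fp8 KV cache on the V1 engine.
# Assumes an FA3-capable (Hopper) GPU; the model name is only an example.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model
    kv_cache_dtype="fp8",  # same values as --kv-cache-dtype: fp8, fp8_e4m3, fp8_e5m2
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=8, temperature=0.0))
print(outputs[0].outputs[0].text)
```
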
17 changes: 14 additions & 3 deletions vllm/engine/arg_utils.py
@@ -1500,9 +1500,20 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
 
         # No Fp8 KV cache so far.
         if self.kv_cache_dtype != "auto":
-            _raise_or_fallback(feature_name="--kv-cache-dtype",
-                               recommend_to_remove=False)
-            return False
+            fp8_attention = self.kv_cache_dtype.startswith("fp8")
+            will_use_fa = (
+                current_platform.is_cuda()
+                and not envs.is_set("VLLM_ATTENTION_BACKEND")
+            ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
+            supported = False
+            if fp8_attention and will_use_fa:
+                from vllm.vllm_flash_attn.fa_utils import (
+                    flash_attn_supports_fp8)
+                supported = flash_attn_supports_fp8()
+            if not supported:
+                _raise_or_fallback(feature_name="--kv-cache-dtype",
+                                   recommend_to_remove=False)
+                return False
 
         # No Prompt Adapter so far.
         if self.enable_prompt_adapter:

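The oracle above keeps --kv-cache-dtype on V1 only when the FlashAttention backend will actually be selected and the device supports FP8. A hedged restatement of that decision as a standalone helper, shown only to make the combined condition easier to read (the function name is illustrative; the individual checks mirror the hunk above):

```python
# Illustrative helper mirroring the _is_v1_supported_oracle check above.
import vllm.envs as envs
from vllm.platforms import current_platform


def v1_allows_kv_cache_dtype(kv_cache_dtype: str) -> bool:
    if kv_cache_dtype == "auto":
        return True
    # FlashAttention is used on CUDA when no backend override is set,
    # or when FLASH_ATTN_VLLM_V1 is requested explicitly.
    will_use_fa = (current_platform.is_cuda()
                   and not envs.is_set("VLLM_ATTENTION_BACKEND")
                   ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
    if kv_cache_dtype.startswith("fp8") and will_use_fa:
        from vllm.vllm_flash_attn.fa_utils import flash_attn_supports_fp8
        return flash_attn_supports_fp8()
    return False  # everything else still falls back to V0 (or raises)
```
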
8 changes: 3 additions & 5 deletions vllm/platforms/cuda.py
@@ -14,7 +14,6 @@
 # import custom ops, trigger op registration
 import vllm._C  # noqa
 import vllm.envs as envs
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.utils import import_pynvml
 
@@ -258,7 +257,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             try:
                 import vllm.vllm_flash_attn  # noqa: F401
                 from vllm.attention.backends.flash_attn import (  # noqa: F401
-                    FlashAttentionBackend)
+                    FlashAttentionBackend, flash_attn_supports_fp8)
 
                 supported_sizes = \
                     FlashAttentionBackend.get_supported_head_sizes()
@@ -269,10 +268,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                     target_backend = _Backend.XFORMERS
                 fp8_kv_cache = (kv_cache_dtype is not None
                                 and kv_cache_dtype.startswith("fp8"))
-                if (fp8_kv_cache and get_flash_attn_version() != 3):
+                if (fp8_kv_cache and not flash_attn_supports_fp8()):
                     logger.info(
-                        "Cannot use FlashAttention-2 backend for FP8 KV cache."
-                    )
+                        "Cannot use FlashAttention backend for FP8 KV cache.")
                     logger.warning(
                         "Please use FlashInfer backend with FP8 KV Cache for "
                         "better performance by setting environment variable "

10 changes: 6 additions & 4 deletions vllm/v1/attention/backends/flash_attn.py
@@ -11,10 +11,11 @@
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
+from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
+                                           get_flash_attn_version)
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -182,9 +183,6 @@ def __init__(
         else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "FlashAttention V1 with FP8 KV cache not yet supported")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
@@ -206,6 +204,10 @@ def __init__(
                                       "are not implemented for "
                                       "FlashAttentionImpl")
         self.vllm_flash_attn_version = get_flash_attn_version()
+        if is_quantized_kv_cache(self.kv_cache_dtype) \
+            and not flash_attn_supports_fp8():
+            raise NotImplementedError(
+                "FlashAttention does not support fp8 kv-cache on this device.")
 
     def forward(
         self,

2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/mla/common.py
@@ -196,14 +196,14 @@
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
+from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func

6 changes: 6 additions & 0 deletions vllm/fa_utils.py → vllm/vllm_flash_attn/fa_utils.py
@@ -46,3 +46,9 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
         return fa_version
     except (ImportError, AssertionError):
         return None
+
+
+def flash_attn_supports_fp8() -> bool:
+    from vllm.platforms import current_platform
+    return get_flash_attn_version() == 3 and \
+        current_platform.get_device_capability().major == 9
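
The new helper keys FP8 support off a FlashAttention 3 build running on a compute-capability 9.x (Hopper) device. A small hedged sketch of consulting it before picking a KV-cache dtype; only the two imported functions come from this diff, the chooser itself is illustrative:

```python
# Hedged sketch: choose a KV-cache dtype based on the helper added above.
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
                                           get_flash_attn_version)


def choose_kv_cache_dtype(prefer_fp8: bool = True) -> str:
    fa_version = get_flash_attn_version()  # None if vllm-flash-attn is unusable
    if prefer_fp8 and fa_version == 3 and flash_attn_supports_fp8():
        return "fp8"  # FA3 on an SM90 (Hopper) GPU
    return "auto"     # otherwise keep the model's native KV-cache dtype


if __name__ == "__main__":
    print("kv_cache_dtype =", choose_kv_cache_dtype())
```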