diff --git a/.gitignore b/.gitignore
index e40752f4dea0..6f5cbd0733da 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,7 +2,8 @@
 /vllm/_version.py
 
 # vllm-flash-attn built from source
-vllm/vllm_flash_attn/
+vllm/vllm_flash_attn/*
+!vllm/vllm_flash_attn/fa_utils.py
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 4cb0b916739a..27bd292b51f2 100755
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -22,12 +22,13 @@
     compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
     get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
     is_all_encoder_attn_metadata_set, is_block_tables_empty)
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                   flash_attn_with_kvcache)
+from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
+                                           get_flash_attn_version)
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@@ -632,10 +633,13 @@ def __init__(
         self.kv_cache_dtype = kv_cache_dtype
         self.vllm_flash_attn_version = get_flash_attn_version(
             requires_alibi=self.alibi_slopes is not None)
-        if (is_quantized_kv_cache(self.kv_cache_dtype)
-                and self.vllm_flash_attn_version != 3):
+        if is_quantized_kv_cache(self.kv_cache_dtype) and (
+                not self.kv_cache_dtype.startswith("fp8")
+                or not flash_attn_supports_fp8()):
             raise NotImplementedError(
-                "Only FlashAttention3 supports FP8 KV cache")
+                f"FlashAttention does not support {self.kv_cache_dtype} "
+                "kv-cache on this device "
+                f"(FA supports fp8 = {flash_attn_supports_fp8()}).")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
@@ -704,6 +708,10 @@ def forward(
         logits_soft_cap: Optional[float] = self.logits_soft_cap
         fp8_attention = kv_cache_dtype.startswith("fp8")
 
+        if fp8_attention and not flash_attn_supports_fp8():
+            raise NotImplementedError(
+                "FlashAttention does not support FP8 kv-cache on this device.")
+
         if kv_cache.numel() > 0:
             key_cache = kv_cache[0]
             value_cache = kv_cache[1]
diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py
index 258090d3e80e..1b1ab314c01e 100644
--- a/vllm/attention/backends/mla/common.py
+++ b/vllm/attention/backends/mla/common.py
@@ -205,7 +205,6 @@
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
-from vllm.fa_utils import get_flash_attn_version
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
@@ -214,6 +213,7 @@
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
+from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
diff --git a/vllm/config.py b/vllm/config.py
index 1f7147f7cfd4..3f1307f9c6e0 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1150,10 +1150,6 @@ def _verify_cache_dtype(self) -> None:
         if self.cache_dtype == "auto":
             pass
         elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
-            if envs.VLLM_USE_V1:
-                raise NotImplementedError(
-                    "V1 does not yet support fp8 KV cache. "
-                    "Set VLLM_USE_V1=0 to enable fp8 kv cache.")
             logger.info(
                 "Using fp8 data type to store kv cache. It reduces the GPU "
                 "memory footprint and boosts the performance. "
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index edfa748b82d7..29f61189187f 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1500,9 +1500,20 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
 
         # No Fp8 KV cache so far.
         if self.kv_cache_dtype != "auto":
-            _raise_or_fallback(feature_name="--kv-cache-dtype",
-                               recommend_to_remove=False)
-            return False
+            fp8_attention = self.kv_cache_dtype.startswith("fp8")
+            will_use_fa = (
+                current_platform.is_cuda()
+                and not envs.is_set("VLLM_ATTENTION_BACKEND")
+            ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
+            supported = False
+            if fp8_attention and will_use_fa:
+                from vllm.vllm_flash_attn.fa_utils import (
+                    flash_attn_supports_fp8)
+                supported = flash_attn_supports_fp8()
+            if not supported:
+                _raise_or_fallback(feature_name="--kv-cache-dtype",
+                                   recommend_to_remove=False)
+                return False
 
         # No Prompt Adapter so far.
         if self.enable_prompt_adapter:
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 38d8fffd63c0..bb77318092fc 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -14,7 +14,6 @@
 # import custom ops, trigger op registration
 import vllm._C  # noqa
 import vllm.envs as envs
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.utils import import_pynvml
 
@@ -258,7 +257,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             try:
                 import vllm.vllm_flash_attn  # noqa: F401
                 from vllm.attention.backends.flash_attn import (  # noqa: F401
-                    FlashAttentionBackend)
+                    FlashAttentionBackend, flash_attn_supports_fp8)
 
                 supported_sizes = \
                     FlashAttentionBackend.get_supported_head_sizes()
@@ -269,10 +268,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                     target_backend = _Backend.XFORMERS
                 fp8_kv_cache = (kv_cache_dtype is not None
                                 and kv_cache_dtype.startswith("fp8"))
-                if (fp8_kv_cache and get_flash_attn_version() != 3):
+                if (fp8_kv_cache and not flash_attn_supports_fp8()):
                     logger.info(
-                        "Cannot use FlashAttention-2 backend for FP8 KV cache."
-                    )
+                        "Cannot use FlashAttention backend for FP8 KV cache.")
                     logger.warning(
                         "Please use FlashInfer backend with FP8 KV Cache for "
                         "better performance by setting environment variable "
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 27b3aabbc350..92e4ffd0371a 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -11,10 +11,11 @@
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
+from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
+                                           get_flash_attn_version)
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -182,9 +183,6 @@ def __init__(
         else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "FlashAttention V1 with FP8 KV cache not yet supported")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
@@ -206,6 +204,10 @@
                                       "are not implemented for "
                                       "FlashAttentionImpl")
         self.vllm_flash_attn_version = get_flash_attn_version()
+        if is_quantized_kv_cache(self.kv_cache_dtype) \
+            and not flash_attn_supports_fp8():
+            raise NotImplementedError(
+                "FlashAttention does not support fp8 kv-cache on this device.")
 
     def forward(
         self,
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 31244443108b..1437db7e9d48 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -196,7 +196,6 @@
                                          AttentionMetadata,
                                          MLAAttentionImpl)
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
@@ -204,6 +203,7 @@
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
+from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
diff --git a/vllm/fa_utils.py b/vllm/vllm_flash_attn/fa_utils.py
similarity index 90%
rename from vllm/fa_utils.py
rename to vllm/vllm_flash_attn/fa_utils.py
index 417653490158..ca88549f3f72 100644
--- a/vllm/fa_utils.py
+++ b/vllm/vllm_flash_attn/fa_utils.py
@@ -46,3 +46,9 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
         return fa_version
     except (ImportError, AssertionError):
         return None
+
+
+def flash_attn_supports_fp8() -> bool:
+    from vllm.platforms import current_platform
+    return get_flash_attn_version() == 3 and \
+        current_platform.get_device_capability().major == 9
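
Taken together, the patch routes every FP8 KV-cache check through the relocated helper `vllm/vllm_flash_attn/fa_utils.py`, whose new `flash_attn_supports_fp8()` returns true only for FlashAttention 3 on compute capability 9.x (Hopper) devices. Below is a minimal usage sketch, not part of the patch: it assumes vLLM with this change applied, and `kv_cache_dtype` is an illustrative stand-in for the value normally passed via `--kv-cache-dtype`.

```python
# Sketch only: mirrors the gating this diff adds in the attention backends.
from vllm.vllm_flash_attn.fa_utils import flash_attn_supports_fp8

kv_cache_dtype = "fp8_e4m3"  # illustrative; "auto" would skip the check

if kv_cache_dtype.startswith("fp8") and not flash_attn_supports_fp8():
    # flash_attn_supports_fp8() is true only for FA3 on SM90 (Hopper) GPUs.
    raise NotImplementedError(
        f"FlashAttention does not support {kv_cache_dtype} kv-cache "
        "on this device.")
```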