diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 6a82127acdf7..971fe411695c 100755
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -14,19 +14,16 @@
                                               AttentionMetadataBuilder,
                                               AttentionType)
 from vllm.attention.backends.utils import (
-    PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
-    compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
-    get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
-    is_all_encoder_attn_metadata_set, is_block_tables_empty)
-from vllm.envs import VLLM_FLASH_ATTN_VERSION
+    PAD_SLOT_ID, VLLM_FLASH_ATTN_VERSION, CommonAttentionState,
+    compute_slot_mapping, compute_slot_mapping_start_idx,
+    get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
+    is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set,
+    is_block_tables_empty)
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
-from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
-from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
-                                  flash_attn_varlen_func,
-                                  flash_attn_with_kvcache,
-                                  is_fa_version_supported)
+from vllm.vllm_flash_attn import (flash_attn_varlen_func,
+                                  flash_attn_with_kvcache)
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@@ -644,25 +641,6 @@ def __init__(
                 f"Supported head sizes are: {support_head_sizes}.")
         self.attn_type = attn_type
 
-        # if hopper default to FA3, otherwise stick to FA2 for now
-        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
-        #  use FA3 as default for both
-        if current_platform.get_device_capability()[0] >= 9:
-            self.fa_version = 3 if is_fa_version_supported(3) else 2
-        else:
-            self.fa_version = 2
-
-        if VLLM_FLASH_ATTN_VERSION is not None:
-            assert VLLM_FLASH_ATTN_VERSION in [2, 3]
-            self.fa_version = VLLM_FLASH_ATTN_VERSION
-
-        if not is_fa_version_supported(self.fa_version):
-            logger.error("Cannot use FA version %d is not supported due to %s",
-                         self.fa_version,
-                         fa_version_unsupported_reason(self.fa_version))
-
-        assert is_fa_version_supported(self.fa_version)
-
     def forward(
         self,
         layer: AttentionLayer,
@@ -781,7 +759,7 @@ def forward(
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
                     out=prefill_output,
-                    fa_version=self.fa_version,
+                    fa_version=VLLM_FLASH_ATTN_VERSION,
                 )
             else:
                 # prefix-enabled attention
@@ -804,7 +782,7 @@ def forward(
                     block_table=prefill_meta.block_tables,
                     softcap=logits_soft_cap,
                     out=prefill_output,
-                    fa_version=self.fa_version,
+                    fa_version=VLLM_FLASH_ATTN_VERSION,
                 )
 
         if decode_meta := attn_metadata.decode_metadata:
@@ -833,7 +811,7 @@ def forward(
                     softcap=logits_soft_cap,
                     block_table=decode_meta.block_tables,
                     out=decode_output,
-                    fa_version=self.fa_version,
+                    fa_version=VLLM_FLASH_ATTN_VERSION,
                 )
             else:
                 # Use flash_attn_with_kvcache for normal decoding.
@@ -854,7 +832,7 @@ def forward(
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
                     out=decode_output.unsqueeze(1),
-                    fa_version=self.fa_version,
+                    fa_version=VLLM_FLASH_ATTN_VERSION,
                 )
 
         return output
diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py
index cd8c08e5ab47..e1285d1fad3c 100644
--- a/vllm/attention/backends/mla/utils.py
+++ b/vllm/attention/backends/mla/utils.py
@@ -12,6 +12,7 @@
 from vllm.attention.backends.abstract import (AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl, T)
+from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
 from vllm.distributed import (get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -533,6 +534,7 @@ def _forward_prefill_flash(
             max_seqlen_k=max_prefill_seq_len,
             softmax_scale=self.scale,
             causal=True,
+            fa_version=VLLM_FLASH_ATTN_VERSION,
         )
         attn_output = attn_output\
             .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py
index ad53e4e70b0f..3c5028a66d58 100644
--- a/vllm/attention/backends/utils.py
+++ b/vllm/attention/backends/utils.py
@@ -8,12 +8,17 @@
 import numpy as np
 import torch
 
+from vllm import envs
 from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
                             AttentionState)
 from vllm.attention.backends.abstract import AttentionType
+from vllm.logger import logging
 from vllm.multimodal import MultiModalPlaceholderMap
+from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 
+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from vllm.worker.model_runner_base import ModelRunnerBase
 
@@ -580,3 +585,32 @@ def get_num_prefill_decode_query_kv_tokens(
     return (num_prefill_query_tokens,
             num_prefill_kv_tokens,
             num_decode_query_tokens)
+
+
+try:
+    from vllm.vllm_flash_attn.flash_attn_interface import (
+        fa_version_unsupported_reason, is_fa_version_supported)
+
+    def flash_attn_version():
+        # if hopper default to FA3, otherwise stick to FA2 for now
+        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
+        #  use FA3 as default for both
+        if current_platform.get_device_capability()[0] >= 9:
+            fa_version = 3 if is_fa_version_supported(3) else 2
+        else:
+            fa_version = 2
+
+        if envs.VLLM_FLASH_ATTN_VERSION is not None:
+            assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
+            fa_version = envs.VLLM_FLASH_ATTN_VERSION
+
+        if not is_fa_version_supported(fa_version):
+            logger.error("Cannot use FA version %d is not supported due to %s",
+                         fa_version, fa_version_unsupported_reason(fa_version))
+
+        assert is_fa_version_supported(fa_version)
+        return fa_version
+
+    VLLM_FLASH_ATTN_VERSION = flash_attn_version()
+except ImportError:
+    VLLM_FLASH_ATTN_VERSION = None
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 837d7faf4370..204afc9f4025 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -10,13 +10,10 @@
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
-from vllm.envs import VLLM_FLASH_ATTN_VERSION
+from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 from vllm.utils import cdiv
-from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
-                                  flash_attn_varlen_func,
-                                  is_fa_version_supported)
+from vllm.vllm_flash_attn import flash_attn_varlen_func
 
 logger = init_logger(__name__)
 
@@ -136,25 +133,6 @@ def __init__(
                                       "are not implemented for "
                                       "FlashAttentionImpl")
 
-        # if hopper default to FA3, otherwise stick to FA2 for now
-        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
-        #  use FA3 as default for both
-        if current_platform.get_device_capability()[0] >= 9:
-            self.fa_version = 3 if is_fa_version_supported(3) else 2
-        else:
-            self.fa_version = 2
-
-        if VLLM_FLASH_ATTN_VERSION is not None:
-            assert VLLM_FLASH_ATTN_VERSION in [2, 3]
-            self.fa_version = VLLM_FLASH_ATTN_VERSION
-
-        if not is_fa_version_supported(self.fa_version):
-            logger.error("Cannot use FA version %d is not supported due to %s",
-                         self.fa_version,
-                         fa_version_unsupported_reason(self.fa_version))
-
-        assert is_fa_version_supported(self.fa_version)
-
     def forward(
         self,
         layer: torch.nn.Module,
@@ -227,7 +205,7 @@ def forward(
                 window_size=self.sliding_window,
                 block_table=attn_metadata.block_table,
                 softcap=self.logits_soft_cap,
-                fa_version=self.fa_version,
+                fa_version=VLLM_FLASH_ATTN_VERSION,
            )
            return output
 
@@ -249,7 +227,7 @@ def forward(
             logits_soft_cap=self.logits_soft_cap,
             block_table=attn_metadata.block_table,
             common_prefix_len=attn_metadata.common_prefix_len,
-            fa_version=self.fa_version,
+            fa_version=VLLM_FLASH_ATTN_VERSION,
         )
         return output
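The utils.py hunk above centralizes the FA2/FA3 choice that each backend previously recomputed in __init__: default by device capability, allow an environment override, then assert the result is supported. The snippet below is a minimal standalone sketch of that selection pattern, not the vLLM API; pick_fa_version, device_major, and is_supported are hypothetical stand-ins used only for illustration.

import os
from typing import Callable, Optional


def pick_fa_version(device_major: int,
                    is_supported: Callable[[int], bool],
                    env_override: Optional[int] = None) -> int:
    # Hopper (SM 9.x) and newer default to FA3 when the installed build
    # supports it; older GPUs stay on FA2.
    version = 3 if device_major >= 9 and is_supported(3) else 2

    # An explicit override (VLLM_FLASH_ATTN_VERSION in vLLM) wins, but it
    # must still be a version the installed build supports.
    if env_override is not None:
        assert env_override in (2, 3)
        version = env_override

    assert is_supported(version), f"FA{version} is not supported by this build"
    return version


if __name__ == "__main__":
    override = os.environ.get("VLLM_FLASH_ATTN_VERSION")
    print(pick_fa_version(device_major=9,
                          is_supported=lambda v: v in (2, 3),
                          env_override=int(override) if override else None))

Computing this once at import time (as the try/except block in utils.py does) lets every call site pass a single module-level constant instead of carrying a per-instance fa_version attribute.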