13 changes: 7 additions & 6 deletions vllm/attention/backends/flash_attn.py
@@ -14,8 +14,8 @@
                                               AttentionMetadataBuilder,
                                               AttentionType)
 from vllm.attention.backends.utils import (
-    PAD_SLOT_ID, VLLM_FLASH_ATTN_VERSION, CommonAttentionState,
-    compute_slot_mapping, compute_slot_mapping_start_idx,
+    PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
+    compute_slot_mapping_start_idx, get_flash_attn_version,
     get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
     is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set,
     is_block_tables_empty)
@@ -640,6 +640,7 @@ def __init__(
                 f"Head size {head_size} is not supported by FlashAttention. "
                 f"Supported head sizes are: {support_head_sizes}.")
         self.attn_type = attn_type
+        self.vllm_flash_attn_version = get_flash_attn_version()

     def forward(
         self,
@@ -759,7 +760,7 @@ def forward(
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
                     out=prefill_output,
-                    fa_version=VLLM_FLASH_ATTN_VERSION,
+                    fa_version=self.vllm_flash_attn_version,
                 )
             else:
                 # prefix-enabled attention
@@ -782,7 +783,7 @@ def forward(
                     block_table=prefill_meta.block_tables,
                     softcap=logits_soft_cap,
                     out=prefill_output,
-                    fa_version=VLLM_FLASH_ATTN_VERSION,
+                    fa_version=self.vllm_flash_attn_version,
                 )

         if decode_meta := attn_metadata.decode_metadata:
@@ -811,7 +812,7 @@ def forward(
                     softcap=logits_soft_cap,
                     block_table=decode_meta.block_tables,
                     out=decode_output,
-                    fa_version=VLLM_FLASH_ATTN_VERSION,
+                    fa_version=self.vllm_flash_attn_version,
                 )
             else:
                 # Use flash_attn_with_kvcache for normal decoding.
@@ -832,7 +833,7 @@ def forward(
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
                     out=decode_output.unsqueeze(1),
-                    fa_version=VLLM_FLASH_ATTN_VERSION,
+                    fa_version=self.vllm_flash_attn_version,
                 )
         return output
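For readers skimming the diff, here is a minimal standalone sketch of the pattern being applied (hypothetical names such as get_kernel_version, AttentionImplSketch, and some_optional_kernel_pkg; this is not vLLM code): a module-level constant that used to be computed at import time becomes a helper function, and each attention implementation resolves the value once in its own __init__, then passes it along on every kernel call.

import importlib
from typing import Optional


def get_kernel_version() -> Optional[int]:
    """Hypothetical stand-in for get_flash_attn_version()."""
    try:
        # Stand-in for the optional vllm_flash_attn import and version probe.
        importlib.import_module("some_optional_kernel_pkg")  # hypothetical package
        return 2
    except ImportError:
        return None


class AttentionImplSketch:
    def __init__(self) -> None:
        # Resolved once per constructed layer, not once per process at import time.
        self.kernel_version = get_kernel_version()

    def forward(self) -> str:
        return f"fa_version={self.kernel_version}"


print(AttentionImplSketch().forward())  # "fa_version=None" when the package is absent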
5 changes: 3 additions & 2 deletions vllm/attention/backends/mla/utils.py
@@ -12,7 +12,7 @@
 from vllm.attention.backends.abstract import (AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl, T)
-from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
+from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.distributed import (get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -181,6 +181,7 @@ def __init__(
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
+        self.vllm_flash_attn_version = get_flash_attn_version()

     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
@@ -515,7 +516,7 @@ def _forward_prefill_flash(
             max_seqlen_k=max_prefill_seq_len,
             softmax_scale=self.scale,
             causal=True,
-            fa_version=VLLM_FLASH_ATTN_VERSION,
+            fa_version=self.vllm_flash_attn_version,
         )
         attn_output = attn_output\
             .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
14 changes: 6 additions & 8 deletions vllm/attention/backends/utils.py
@@ -587,11 +587,11 @@ def get_num_prefill_decode_query_kv_tokens(
                             num_decode_query_tokens)


-try:
-    from vllm.vllm_flash_attn.flash_attn_interface import (
-        fa_version_unsupported_reason, is_fa_version_supported)
+def get_flash_attn_version():
+    try:
+        from vllm.vllm_flash_attn.flash_attn_interface import (
+            fa_version_unsupported_reason, is_fa_version_supported)

-    def flash_attn_version():
         # if hopper default to FA3, otherwise stick to FA2 for now
         # TODO(lucas): profile FA3 on ampere to see if it makes sense to
         # use FA3 as default for both
@@ -610,7 +610,5 @@ def flash_attn_version():

         assert is_fa_version_supported(fa_version)
         return fa_version
-
-    VLLM_FLASH_ATTN_VERSION = flash_attn_version()
-except (ImportError, AssertionError):
-    VLLM_FLASH_ATTN_VERSION = None
+    except (ImportError, AssertionError):
+        return None
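A short usage sketch of the new helper, assuming a vLLM install that ships the bundled vllm_flash_attn package; per the comments in the hunk above, the helper is expected to prefer FA3 on Hopper and FA2 otherwise, and to return None when the interface cannot be imported or the selected version is unsupported.

from vllm.attention.backends.utils import get_flash_attn_version

fa_version = get_flash_attn_version()
if fa_version is None:
    print("vllm_flash_attn unavailable or unsupported on this platform")
else:
    # The attention impls in this PR pass this value through as fa_version=...
    print(f"FlashAttention calls will use fa_version={fa_version}")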
7 changes: 4 additions & 3 deletions vllm/v1/attention/backends/flash_attn.py
@@ -10,7 +10,7 @@

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
-from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
+from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.utils import cdiv
 from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -132,6 +132,7 @@ def __init__(
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
                                       "FlashAttentionImpl")
+        self.vllm_flash_attn_version = get_flash_attn_version()

     def forward(
         self,
@@ -205,7 +206,7 @@ def forward(
                 window_size=self.sliding_window,
                 block_table=attn_metadata.block_table,
                 softcap=self.logits_soft_cap,
-                fa_version=VLLM_FLASH_ATTN_VERSION,
+                fa_version=self.vllm_flash_attn_version,
             )
             return output

@@ -227,7 +228,7 @@ def forward(
             logits_soft_cap=self.logits_soft_cap,
             block_table=attn_metadata.block_table,
             common_prefix_len=attn_metadata.common_prefix_len,
-            fa_version=VLLM_FLASH_ATTN_VERSION,
+            fa_version=self.vllm_flash_attn_version,
         )
         return output