Merged
44 changes: 11 additions & 33 deletions vllm/attention/backends/flash_attn.py
@@ -14,19 +14,16 @@
AttentionMetadataBuilder,
AttentionType)
from vllm.attention.backends.utils import (
PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty)
from vllm.envs import VLLM_FLASH_ATTN_VERSION
PAD_SLOT_ID, VLLM_FLASH_ATTN_VERSION, CommonAttentionState,
compute_slot_mapping, compute_slot_mapping_start_idx,
get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set,
is_block_tables_empty)
from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_varlen_func,
flash_attn_with_kvcache,
is_fa_version_supported)
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)

if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@@ -644,25 +641,6 @@ def __init__(
f"Supported head sizes are: {support_head_sizes}.")
self.attn_type = attn_type

# if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both
if current_platform.get_device_capability()[0] >= 9:
self.fa_version = 3 if is_fa_version_supported(3) else 2
else:
self.fa_version = 2

if VLLM_FLASH_ATTN_VERSION is not None:
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
self.fa_version = VLLM_FLASH_ATTN_VERSION

if not is_fa_version_supported(self.fa_version):
logger.error("Cannot use FA version %d is not supported due to %s",
self.fa_version,
fa_version_unsupported_reason(self.fa_version))

assert is_fa_version_supported(self.fa_version)

def forward(
self,
layer: AttentionLayer,
Expand Down Expand Up @@ -781,7 +759,7 @@ def forward(
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
out=prefill_output,
fa_version=self.fa_version,
fa_version=VLLM_FLASH_ATTN_VERSION,
)
else:
# prefix-enabled attention
@@ -804,7 +782,7 @@
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
out=prefill_output,
fa_version=self.fa_version,
fa_version=VLLM_FLASH_ATTN_VERSION,
)

if decode_meta := attn_metadata.decode_metadata:
@@ -833,7 +811,7 @@
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
out=decode_output,
fa_version=self.fa_version,
fa_version=VLLM_FLASH_ATTN_VERSION,
)
else:
# Use flash_attn_with_kvcache for normal decoding.
@@ -854,7 +832,7 @@
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
out=decode_output.unsqueeze(1),
fa_version=self.fa_version,
fa_version=VLLM_FLASH_ATTN_VERSION,
)
return output

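With this change the per-instance version selection in __init__ is gone: the call sites in this file now pass a module-level VLLM_FLASH_ATTN_VERSION constant that is resolved once when vllm.attention.backends.utils is imported (see the new hunk in that file further down). A minimal usage sketch of forcing FA2 through the existing VLLM_FLASH_ATTN_VERSION environment variable, on the assumption that it must be set before vLLM is imported since the value is now read only once; the model name is purely illustrative:

import os

# Force FA2 even on Hopper. Assumption: this has to happen before importing
# vLLM, because the version is now resolved once at module import time.
os.environ["VLLM_FLASH_ATTN_VERSION"] = "2"

from vllm import LLM  # imported after setting the override

llm = LLM(model="facebook/opt-125m")  # illustrative model
print(llm.generate("Hello")[0].outputs[0].text)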
2 changes: 2 additions & 0 deletions vllm/attention/backends/mla/utils.py
@@ -12,6 +12,7 @@
from vllm.attention.backends.abstract import (AttentionLayer,
AttentionMetadata,
MLAAttentionImpl, T)
from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -533,6 +534,7 @@ def _forward_prefill_flash(
max_seqlen_k=max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
fa_version=VLLM_FLASH_ATTN_VERSION,
[Inline review thread on the fa_version argument; a defensive call sketch follows after this file's diff]
Collaborator: This breaks when vllm.vllm_flash_attn is not found and "from flash_attn import flash_attn_varlen_func" is used instead. That flash_attn_varlen_func doesn't have fa_version.
Contributor: Do you know if we fixed it? I hit this problem on the latest commit.
Contributor: @gshtras @gaocegege how did you guys fix it?
Contributor: I used the latest commit, and it worked.
Contributor: Thanks :) Will try again.
)
attn_output = attn_output\
.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
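The thread above flags the weak spot of this hunk: when vllm.vllm_flash_attn cannot be imported, VLLM_FLASH_ATTN_VERSION ends up as None (see the try/except added to vllm/attention/backends/utils.py below), while the upstream flash_attn fallback's flash_attn_varlen_func has no fa_version keyword at all. A hedged sketch of a defensive call pattern that only forwards the keyword when a version was actually resolved; the wrapper name and the trimmed argument list are illustrative, not the code this PR ships:

from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION

try:
    from vllm.vllm_flash_attn import flash_attn_varlen_func
except ImportError:
    # Upstream fallback; its signature has no fa_version parameter.
    from flash_attn import flash_attn_varlen_func


def varlen_attention(q, k, v, cu_seqlens_q, cu_seqlens_k,
                     max_seqlen_q, max_seqlen_k, scale):
    # Illustrative wrapper: forward fa_version only when the vLLM fork
    # resolved one, so the upstream kernel never sees an unknown kwarg.
    kwargs = dict(q=q, k=k, v=v,
                  cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
                  max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
                  softmax_scale=scale, causal=True)
    if VLLM_FLASH_ATTN_VERSION is not None:
        kwargs["fa_version"] = VLLM_FLASH_ATTN_VERSION
    return flash_attn_varlen_func(**kwargs)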
34 changes: 34 additions & 0 deletions vllm/attention/backends/utils.py
@@ -8,12 +8,17 @@
import numpy as np
import torch

from vllm import envs
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState)
from vllm.attention.backends.abstract import AttentionType
from vllm.logger import logging
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.utils import async_tensor_h2d, make_tensor_with_pad

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from vllm.worker.model_runner_base import ModelRunnerBase

@@ -580,3 +585,32 @@ def get_num_prefill_decode_query_kv_tokens(

return (num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens)


try:
from vllm.vllm_flash_attn.flash_attn_interface import (
fa_version_unsupported_reason, is_fa_version_supported)

def flash_attn_version():
# if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both
if current_platform.get_device_capability()[0] >= 9:
fa_version = 3 if is_fa_version_supported(3) else 2
else:
fa_version = 2

if envs.VLLM_FLASH_ATTN_VERSION is not None:
assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
fa_version = envs.VLLM_FLASH_ATTN_VERSION

if not is_fa_version_supported(fa_version):
logger.error("Cannot use FA version %d is not supported due to %s",
fa_version, fa_version_unsupported_reason(fa_version))

assert is_fa_version_supported(fa_version)
return fa_version

VLLM_FLASH_ATTN_VERSION = flash_attn_version()
except ImportError:
VLLM_FLASH_ATTN_VERSION = None
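The new helper boils the policy down to three rules: default to FA3 on Hopper (compute capability 9.x) when the bundled wheel supports it, stay on FA2 everywhere else, and let an explicit VLLM_FLASH_ATTN_VERSION override win as long as it names a supported version. A small self-contained sketch of that decision with the platform and support checks stubbed out as plain arguments; the function and argument names are illustrative, not vLLM APIs:

from typing import Optional, Set


def pick_fa_version(major_cc: int, supported: Set[int],
                    override: Optional[int] = None) -> int:
    # Hopper (SM 9.x) defaults to FA3 if the build supports it; older GPUs
    # such as Ampere stay on FA2 for now.
    version = 3 if (major_cc >= 9 and 3 in supported) else 2
    # An explicit override (e.g. the VLLM_FLASH_ATTN_VERSION env var) takes
    # precedence, but must still be one of the known versions.
    if override is not None:
        assert override in (2, 3)
        version = override
    assert version in supported, f"FA{version} is not supported"
    return version


print(pick_fa_version(9, {2, 3}))              # -> 3 on Hopper (H100)
print(pick_fa_version(8, {2, 3}))              # -> 2 on Ampere (A100)
print(pick_fa_version(9, {2, 3}, override=2))  # -> 2, forced explicitly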
30 changes: 4 additions & 26 deletions vllm/v1/attention/backends/flash_attn.py
@@ -10,13 +10,10 @@

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.envs import VLLM_FLASH_ATTN_VERSION
from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_varlen_func,
is_fa_version_supported)
from vllm.vllm_flash_attn import flash_attn_varlen_func

logger = init_logger(__name__)

@@ -136,25 +133,6 @@ def __init__(
"are not implemented for "
"FlashAttentionImpl")

# if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both
if current_platform.get_device_capability()[0] >= 9:
self.fa_version = 3 if is_fa_version_supported(3) else 2
else:
self.fa_version = 2

if VLLM_FLASH_ATTN_VERSION is not None:
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
self.fa_version = VLLM_FLASH_ATTN_VERSION

if not is_fa_version_supported(self.fa_version):
logger.error("Cannot use FA version %d is not supported due to %s",
self.fa_version,
fa_version_unsupported_reason(self.fa_version))

assert is_fa_version_supported(self.fa_version)

def forward(
self,
layer: torch.nn.Module,
@@ -227,7 +205,7 @@ def forward(
window_size=self.sliding_window,
block_table=attn_metadata.block_table,
softcap=self.logits_soft_cap,
fa_version=self.fa_version,
fa_version=VLLM_FLASH_ATTN_VERSION,
)
return output

@@ -249,7 +227,7 @@ def forward(
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
fa_version=self.fa_version,
fa_version=VLLM_FLASH_ATTN_VERSION,
)
return output

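With the v0 and v1 backends both reading the same module-level constant, a quick way to check which version a given installation resolved is to import it directly; None here means the bundled vllm_flash_attn wheel could not be imported (a debugging sketch, not part of this PR):

from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION

print("Resolved FlashAttention version:", VLLM_FLASH_ATTN_VERSION)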