@@ -65,7 +65,7 @@ def check_xformers_availability():
     return USE_XFORMERS_OPS
 
 
-def check_upstream_fa_availability(dtype: torch.dtype):
+def check_upstream_fa_availability(dtype: torch.dtype) -> bool:
     if (
         dtype in (torch.float16, torch.bfloat16)
         and current_platform.is_cuda()
@@ -80,26 +80,31 @@ def check_upstream_fa_availability(dtype: torch.dtype):
         return find_spec("flash_attn") is not None
     return False
 
+def is_fa_backend(backend: _Backend) -> bool:
+    return backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}
 
 def maybe_get_vit_flash_attn_backend(
-    attn_backend: _Backend, use_upstream_fa: bool
-) -> tuple[_Backend, Callable]:
-    if (
-        attn_backend != _Backend.FLASH_ATTN
-        and attn_backend != _Backend.ROCM_AITER_FA
-        and check_upstream_fa_availability(torch.get_default_dtype())
-    ):
-        attn_backend = _Backend.FLASH_ATTN
-        use_upstream_fa = True
+    attn_backend: _Backend,
+    try_switch_to_fa: bool = False,
+    try_use_upstream_fa: bool = False) -> tuple[_Backend, Callable]:
 
-    if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN:
-        use_upstream_fa = True
-
-    if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
+    if try_switch_to_fa and not is_fa_backend(attn_backend):
+        attn_backend = _Backend.FLASH_ATTN
+
+    if current_platform.is_rocm() and \
+            attn_backend == _Backend.FLASH_ATTN:
+        # Always try upstream on ROCM.
+        logger.info_once("maybe_get_vit_flash_attn_backend: forcing upstream FlashAttn on ROCM.")
+        try_use_upstream_fa = True
+
+    if is_fa_backend(attn_backend):
         if attn_backend == _Backend.ROCM_AITER_FA:
             from aiter import flash_attn_varlen_func
         else:
-            if use_upstream_fa:
+            if try_use_upstream_fa:
+                assert check_upstream_fa_availability( \
+                    torch.get_default_dtype()), \
+                    "Upstream FlashAttn is not available."
                 from flash_attn import flash_attn_varlen_func
             else:
                 from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -108,7 +113,6 @@ def maybe_get_vit_flash_attn_backend(
 
     return attn_backend, flash_attn_varlen_func
 
-
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.
 
@@ -428,11 +432,6 @@ def __init__(
         # Determine the attention backend
         backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
 
-        # Some auto-selected backends can be upgraded
-        # to upstream flash attention if available.
-        # If vllm native fa is selected, we use it directly.
-        use_upstream_fa = False
-
         if current_platform.is_xpu():
             # currently, only torch_sdpa is supported on xpu
             self.attn_backend = _Backend.TORCH_SDPA
@@ -450,30 +449,20 @@ def __init__(
             else _Backend.TORCH_SDPA
         )
 
-        self.attn_backend, self._flash_attn_varlen_func = (
-            maybe_get_vit_flash_attn_backend(
+        self.attn_backend, self._flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                use_upstream_fa,
+                try_switch_to_fa=False,
+                try_use_upstream_fa=False,
             )
-        )
 
         if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability():
             self.attn_backend = _Backend.TORCH_SDPA
 
-        self.is_flash_attn_backend = self.attn_backend in {
-            _Backend.FLASH_ATTN,
-            _Backend.ROCM_AITER_FA,
-        }
-
-        # this condition is just to make sure that the
-        # use_upstream_fa in the log is correct
-        if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
-            use_upstream_fa = True
+        self.is_flash_attn_backend = is_fa_backend(self.attn_backend)
 
         logger.info_once(
-            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
-            f"use_upstream_fa: {use_upstream_fa}"
-        )
+            f"MultiHeadAttention attn_backend: {self.attn_backend}")
 
     def forward(
         self,
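
For context, a minimal usage sketch of the refactored helper under its new signature. This is illustrative only and not part of the diff; it assumes `_Backend`, `is_fa_backend`, and `maybe_get_vit_flash_attn_backend` are in scope as defined in the module above.

# Hypothetical caller (not from this commit), showing the two new flags:
# try_switch_to_fa opts in to upgrading a non-FA backend to FLASH_ATTN;
# try_use_upstream_fa prefers upstream flash_attn over vllm.vllm_flash_attn
# and asserts check_upstream_fa_availability(...) before importing it.
backend, flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
    _Backend.TORCH_SDPA,
    try_switch_to_fa=True,
    try_use_upstream_fa=False,
)
if is_fa_backend(backend):
    # flash_attn_varlen_func comes from the selected provider
    # (aiter, flash_attn, or vllm.vllm_flash_attn).
    pass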