|
@@ -32,7 +32,7 @@
     get_scheduler_metadata,
     reshape_and_cache_flash,
 )
-from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
 from vllm.config.cache import CacheDType
 from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
@@ -56,11 +56,18 @@
 class FlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    # NOTE(tdoublep): while in principle, FA supports
-    # MultipleOf(16), these are the block sizes that do not
-    # suffer from the NaN propagation problem described here:
-    # https://github.com/Dao-AILab/flash-attention/issues/1974
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
+
+    @staticmethod
+    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
+        vllm_config = get_current_vllm_config()
+        model_config = vllm_config.model_config
+        if model_config and model_config.is_hybrid:
+            # NOTE(tdoublep): while in principle, FA supports
+            # MultipleOf(16), these are the block sizes that do not
+            # suffer from the NaN propagation problem described here:
+            # https://github.com/Dao-AILab/flash-attention/issues/1974
+            return [16, 32, 64]
+        return [MultipleOf(16)]
 
     @staticmethod
     def get_name() -> str:
|
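For context, a self-contained sketch of how a caller might map a requested block size onto whatever get_supported_kernel_block_size() returns. The MultipleOf stand-in and the resolve_kernel_block_size helper below are illustrative assumptions, not vLLM APIs.

from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    # Illustrative stand-in for vLLM's MultipleOf constraint (assumed shape).
    base: int


def resolve_kernel_block_size(
    requested: int, supported: list[int | MultipleOf]
) -> int:
    # Return the requested size if the backend allows it; otherwise fall back
    # to the smallest explicitly listed size.
    for entry in supported:
        if isinstance(entry, MultipleOf) and requested % entry.base == 0:
            return requested
        if entry == requested:
            return requested
    fixed = sorted(e for e in supported if isinstance(e, int))
    if fixed:
        return fixed[0]
    raise ValueError(f"no supported kernel block size for {requested}")


# Hybrid models get the restricted list [16, 32, 64] (NaN-propagation workaround);
# non-hybrid models accept any multiple of 16.
assert resolve_kernel_block_size(128, [16, 32, 64]) == 16
assert resolve_kernel_block_size(128, [MultipleOf(16)]) == 128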