Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions vllm/config/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@
structured outputs, speculative decoding, and pipeline parallelism.
"""

split_prefill_from_chunk: bool = False
"""Whether to split the prefill request into pure prefill and chunked prefill in a single

Check failure on line 163 in vllm/config/scheduler.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/config/scheduler.py:163:81: E501 Line too long (93 > 80)
batch."""

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
Expand Down
19 changes: 9 additions & 10 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,16 +239,11 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
if selected_backend is None or selected_backend == _Backend.FLASH_ATTN:
selected_backend = _Backend.ROCM_FLASH

if envs.VLLM_USE_V1:
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \
and on_gfx9():
logger.info("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"rocm_aiter_fa.AiterFlashAttentionBackend")
else:
logger.info("Using Triton Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
if envs.VLLM_USE_V1:
from vllm.v1.attention.backends.rocm_mha_backend_helper import get_rocm_mha_backend_selection
backend_class_path, _ = get_rocm_mha_backend_selection()
if backend_class_path:
return backend_class_path
if selected_backend == _Backend.ROCM_FLASH:
if not cls.has_device_capability(90):
# not Instinct series GPUs.
Expand Down Expand Up @@ -346,6 +341,10 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"

if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA:
# enable the request reorder if we are using AITER MHA for calculation
vllm_config.scheduler_config.split_prefill_from_chunk = True

@classmethod
def verify_model_arch(cls, model_arch: str) -> None:
if model_arch in _ROCM_UNSUPPORTED_MODELS:
Expand Down
Loading
Loading