From 7a870d9849210cee8687af2301bcfd3f3a0b919d Mon Sep 17 00:00:00 2001
From: courage17340
Date: Thu, 6 Mar 2025 03:53:22 +0000
Subject: [PATCH] [Bugfix][Core] fix abort_seq_group and memory leak when n>1

Signed-off-by: courage17340
---
 vllm/core/scheduler.py    | 33 ++++++++++++++++++++++++---------
 vllm/engine/llm_engine.py |  8 +++++++-
 2 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 3cdad496e843..e93143c83d9f 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -16,8 +16,9 @@
 from vllm.lora.request import LoRARequest
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
-                           SequenceGroupMetadata, SequenceGroupMetadataDelta,
-                           SequenceStage, SequenceStatus)
+                           SequenceGroupBase, SequenceGroupMetadata,
+                           SequenceGroupMetadataDelta, SequenceStage,
+                           SequenceStatus)
 from vllm.utils import Device, PyObjectCache
 
 logger = init_logger(__name__)
@@ -561,7 +562,11 @@ def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
         # Only for testing purposes.
         self.swapped.append(seq_group)
 
-    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
+    def abort_seq_group(
+        self,
+        request_id: Union[str, Iterable[str]],
+        seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None,
+    ) -> None:
         """Aborts a sequence group with the given ID.
 
         Check if the sequence group with the given ID
@@ -573,21 +578,29 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
 
         Args:
             request_id: The ID(s) of the sequence group to abort.
+            seq_id_to_seq_group: helper for groups with n>1
         """
         if isinstance(request_id, str):
             request_id = (request_id, )
         request_ids = set(request_id)
+        seq_id_to_seq_group = seq_id_to_seq_group or {}
         for state_queue in [self.waiting, self.running, self.swapped]:
             aborted_groups: List[SequenceGroup] = []
             for seq_group in state_queue:
-                if not request_ids:
-                    # Using 'break' here may add two extra iterations,
-                    # but is acceptable to reduce complexity.
-                    break
-                if seq_group.request_id in request_ids:
+                # When n>1, seq_group.request_id looks like
+                # foo_parallel_sample_0, while request_ids is just foo, and we
+                # should resolve it as real_request_id to match.
+                if seq_group.request_id in seq_id_to_seq_group:
+                    real_request_id = seq_id_to_seq_group[
+                        seq_group.request_id].group_id
+                else:
+                    real_request_id = seq_group.request_id
+                if real_request_id in request_ids:
                     # Appending aborted group into pending list.
                     aborted_groups.append(seq_group)
-                    request_ids.remove(seq_group.request_id)
+                    # We can't remove real_request_id in request_ids here,
+                    # because there may be other seq groups sharing the same
+                    # real_request_id
             for aborted_group in aborted_groups:
                 # Remove the sequence group from the state queue.
                 state_queue.remove(aborted_group)
@@ -598,6 +611,8 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
                     continue
                 seq.status = SequenceStatus.FINISHED_ABORTED
                 self.free_seq(seq)
+            if aborted_group.request_id in seq_id_to_seq_group:
+                del seq_id_to_seq_group[aborted_group.request_id]
 
             self._free_seq_group_cross_attn_blocks(aborted_group)
 
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index f055438d1feb..783275ab41d2 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -887,7 +887,8 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
         >>> engine.abort_request(request_id)
         """
         for scheduler in self.scheduler:
-            scheduler.abort_seq_group(request_id)
+            scheduler.abort_seq_group(
+                request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)
 
     def get_model_config(self) -> ModelConfig:
         """Gets the model configuration."""
@@ -1354,6 +1355,11 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
 
             finished_requests_ids = self.scheduler[
                 virtual_engine].get_and_reset_finished_requests_ids()
+            # When n>1, elements in self.seq_id_to_seq_group should be deleted
+            # here, otherwise memory leaks.
+            for finished_request_id in finished_requests_ids:
+                if finished_request_id in self.seq_id_to_seq_group:
+                    del self.seq_id_to_seq_group[finished_request_id]
 
             # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
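
Below is a minimal sketch of how the fixed path is exercised, using the
synchronous LLMEngine API; the model name and request id are illustrative
placeholders, not part of the patch:

    from vllm import EngineArgs, LLMEngine, SamplingParams

    engine = LLMEngine.from_engine_args(
        EngineArgs(model="facebook/opt-125m"))

    # With n>1, the engine registers child sequence groups whose request
    # ids look like "req-0_parallel_sample_0", while the caller only ever
    # sees the parent id "req-0".
    engine.add_request("req-0", "Hello, my name is",
                       SamplingParams(n=2, max_tokens=16))
    engine.step()

    # Previously, abort_seq_group compared "req-0" against the child ids
    # and matched nothing, so the groups kept running; with this patch the
    # scheduler resolves each child id back to its group_id before matching.
    engine.abort_request("req-0")

    # Likewise, entries in engine.seq_id_to_seq_group for finished or
    # aborted n>1 requests are now deleted instead of accumulating for the
    # lifetime of the engine.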