@@ -506,3 +506,231 @@ def restore_torch_device_after_vllm_init():
     current_device = torch.cuda.current_device()
     if origin_device != current_device:
         torch.cuda.set_device(origin_device)
+
+
+def patch_vllm_memory_leak():
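+    """Work around a memory leak in vLLM 0.7.3 when sampling with n > 1.
+
+    Monkey-patches Scheduler.abort_seq_group, LLMEngine.abort_request and
+    LLMEngine.step so that entries in seq_id_to_seq_group are removed once
+    their sequence groups finish or are aborted. The original methods are
+    kept as `_old_*` attributes; other vLLM versions are left untouched.
+    """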
+    import vllm
+    if version.parse(vllm.__version__) != version.parse('0.7.3'):
+        return
+
+    def patch_vllm_abort_seq_group():
+        from vllm.core.scheduler import Scheduler
+        from typing import Dict, Iterable, List, Optional, Union
+        from vllm.sequence import SequenceGroupBase, SequenceGroup, SequenceStatus
+
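+        # Replacement abort_seq_group that additionally accepts the engine's
+        # seq_id_to_seq_group mapping, so child groups created for n>1
+        # sampling (request_id like 'foo_parallel_sample_0') can be mapped
+        # back to the parent request id, aborted and dropped from the mapping.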
+        def new_abort_seq_group(
+            self,
+            request_id: Union[str, Iterable[str]],
+            seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None,
+        ) -> None:
+            if isinstance(request_id, str):
+                request_id = (request_id, )
+            request_ids = set(request_id)
+            seq_id_to_seq_group = seq_id_to_seq_group or {}
+            for state_queue in [self.waiting, self.running, self.swapped]:
+                aborted_groups: List[SequenceGroup] = []
+                for seq_group in state_queue:
+                    # When n>1, seq_group.request_id looks like
+                    # foo_parallel_sample_0 while request_ids contains just
+                    # foo, so resolve it to the real request id before matching.
+                    if seq_group.request_id in seq_id_to_seq_group:
+                        real_request_id = seq_id_to_seq_group[seq_group.request_id].group_id
+                    else:
+                        real_request_id = seq_group.request_id
+                    if real_request_id in request_ids:
+                        # Append the aborted group to the pending list.
+                        aborted_groups.append(seq_group)
+                        # Don't remove real_request_id from request_ids here:
+                        # other seq groups may share the same real_request_id.
+                for aborted_group in aborted_groups:
+                    # Remove the sequence group from the state queue.
+                    state_queue.remove(aborted_group)
+                    # Remove the aborted request from the Mamba cache.
+                    self._finished_requests_ids.append(aborted_group.request_id)
+                    for seq in aborted_group.get_seqs():
+                        if seq.is_finished():
+                            continue
+                        seq.status = SequenceStatus.FINISHED_ABORTED
+                        self.free_seq(seq)
+                    if aborted_group.request_id in seq_id_to_seq_group:
+                        del seq_id_to_seq_group[aborted_group.request_id]
+
+                    self._free_seq_group_cross_attn_blocks(aborted_group)
+
+        origin_method = Scheduler.abort_seq_group
+        Scheduler._old_abort_seq_group = origin_method
+        Scheduler.abort_seq_group = new_abort_seq_group
+
+    def patch_vllm_engine():
+        from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
+        from vllm.outputs import PoolingRequestOutput, RequestOutput
+        from vllm.sequence import ExecuteModelRequest
+
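+        # abort_request forwards the engine's seq_id_to_seq_group mapping so
+        # the patched Scheduler.abort_seq_group above can resolve the parent
+        # request id of each child group.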
+        def new_abort_request(self, request_id) -> None:
+            for scheduler in self.scheduler:
+                scheduler.abort_seq_group(request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)
+
+        origin_method = LLMEngine.abort_request
+        LLMEngine._old_abort_request = origin_method
+        LLMEngine.abort_request = new_abort_request
+
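+        # new_step mirrors the stock LLMEngine.step of vLLM 0.7.3; the key
+        # difference is that finished request ids are also removed from
+        # self.seq_id_to_seq_group right after scheduling (see below), which
+        # is where the n>1 leak originated.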
+        def new_step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
+            if self.parallel_config.pipeline_parallel_size > 1:
+                raise NotImplementedError('Pipeline parallelism is only supported through AsyncLLMEngine '
+                                          'as performance will be severely degraded otherwise.')
+
+            # For llm_engine, there is no pipeline parallel support, so the engine
+            # used is always 0.
+            virtual_engine = 0
+
+            # These are cached outputs from previous iterations. None if on first
+            # iteration
+            cached_outputs = self.cached_scheduler_outputs[virtual_engine]
+            seq_group_metadata_list = cached_outputs.seq_group_metadata_list
+            scheduler_outputs = cached_outputs.scheduler_outputs
+            allow_async_output_proc = cached_outputs.allow_async_output_proc
+
+            ctx = self.scheduler_contexts[virtual_engine]
+
+            # Clear outputs for each new scheduler iteration
+            ctx.request_outputs.clear()
+
+            # Skip the scheduler if there are any remaining steps in the seq groups.
+            # This ensures that the scheduler is only called again when the current
+            # batch has completed.
+            # The scheduler is also skipped if a single request caused the last
+            # engine step to fail, and the previous schedule needs to be rerun.
+            if not self._has_remaining_steps(seq_group_metadata_list):
+                # Schedule iteration
+                (seq_group_metadata_list, scheduler_outputs,
+                 allow_async_output_proc) = self.scheduler[virtual_engine].schedule()
+
+                ctx.seq_group_metadata_list = seq_group_metadata_list
+                ctx.scheduler_outputs = scheduler_outputs
+
+                finished_requests_ids = self.scheduler[virtual_engine].get_and_reset_finished_requests_ids()
+                # When n>1, the finished entries must also be deleted from
+                # self.seq_id_to_seq_group here, otherwise they leak memory.
+                for finished_request_id in finished_requests_ids:
+                    if finished_request_id in self.seq_id_to_seq_group:
+                        del self.seq_id_to_seq_group[finished_request_id]
+
+                # Maybe switch from async mode to sync mode
+                if not allow_async_output_proc and len(ctx.output_queue) > 0:
+                    self._process_model_outputs(ctx=ctx)
+
+                if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0):
+                    # cache the scheduler outputs for the next iteration if we have
+                    # lookahead slots
+                    self._cache_scheduler_outputs_for_multi_step(virtual_engine, seq_group_metadata_list,
+                                                                 scheduler_outputs, allow_async_output_proc)
+            else:
+                finished_requests_ids = list()
+
+            assert seq_group_metadata_list is not None
+            assert scheduler_outputs is not None
+
+            if not scheduler_outputs.is_empty():
+
+                # Check if we have a cached last_output from the previous iteration.
+                # For supporting PP this is probably the best way to pass the
+                # sampled_token_ids, as a separate broadcast over all the PP stages
+                # will cause one virtual engine's microbatch to block the pipeline.
+                last_sampled_token_ids = \
+                    self._get_last_sampled_token_ids(virtual_engine)
+
+                execute_model_req = ExecuteModelRequest(
+                    seq_group_metadata_list=seq_group_metadata_list,
+                    blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
+                    blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
+                    blocks_to_copy=scheduler_outputs.blocks_to_copy,
+                    num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
+                    running_queue_size=scheduler_outputs.running_queue_size,
+                    finished_requests_ids=finished_requests_ids,
+                    # We use ExecuteModelRequest to pass the last sampled_token_ids
+                    # to each of the non-last PP stages for in-place prepare_input.
+                    last_sampled_token_ids=last_sampled_token_ids)
+
+                if allow_async_output_proc:
+                    execute_model_req.async_callback = self.async_callbacks[virtual_engine]
+
+                outputs = self.model_executor.execute_model(execute_model_req=execute_model_req)
+
+                # We need to do this here so that last step's sampled_token_ids can
+                # be passed to the next iteration for PP.
+                if self.scheduler_config.is_multi_step:
+                    self._update_cached_scheduler_output(virtual_engine, outputs)
+            else:
+                # Nothing scheduled => If there is pending async postprocessor,
+                # then finish it here.
+                if len(ctx.output_queue) > 0:
+                    self._process_model_outputs(ctx=ctx)
+                # No outputs in this case
+                outputs = []
+
+            # Finish the current step for all the sequence groups.
+            if self.scheduler_config.is_multi_step:
+                for seq_group in seq_group_metadata_list:
+                    seq_group.finish_step()
+
+            if not self._has_remaining_steps(seq_group_metadata_list):
+                # clear the cache if we have finished all the steps.
+                if self.scheduler_config.is_multi_step:
+                    self.cached_scheduler_outputs[0] = SchedulerOutputState()
+
+                # is_first_step_output is True only when the num_steps of all
+                # the sequences are 1. When the num_steps > 1,
+                # multi_step_model_runner does the first-step output append.
+                is_first_step_output: bool = False if not seq_group_metadata_list \
+                    else seq_group_metadata_list[0].state.num_steps == 1
+
+                # Add results to the output_queue
+                ctx.append_output(
+                    outputs=outputs,
+                    seq_group_metadata_list=seq_group_metadata_list,
+                    scheduler_outputs=scheduler_outputs,
+                    is_async=allow_async_output_proc,
+                    is_last_step=True,
+                    is_first_step_output=is_first_step_output)
+
+                if outputs and allow_async_output_proc:
+                    assert len(outputs) == 1, ('Async postprocessor expects only a single output set')
+
+                    self._advance_to_next_step(outputs[0], seq_group_metadata_list,
+                                               scheduler_outputs.scheduled_seq_groups)
+
+                # Check if need to run the usual non-async path
+                if not allow_async_output_proc:
+                    self._process_model_outputs(ctx=ctx)
+
+                    # Log stats.
+                    self.do_log_stats(scheduler_outputs, outputs)
+
+                    # Tracing
+                    self.do_tracing(scheduler_outputs)
+            else:
+                # Multi-step case
+                return ctx.request_outputs
+
+            if not self.has_unfinished_requests():
+                # Drain async postprocessor (if exists)
+                if len(ctx.output_queue) > 0:
+                    self._process_model_outputs(ctx=ctx)
+                assert len(ctx.output_queue) == 0
+
+                # Stop the execute model loop in parallel workers until there are
+                # more requests to process. This avoids waiting indefinitely in
+                # torch.distributed ops which may otherwise timeout, and unblocks
+                # the RPC thread in the workers so that they can process any other
+                # queued control plane messages, such as add/remove lora adapters.
+                self.model_executor.stop_remote_worker_execution_loop()
+
+            return ctx.request_outputs
+
+        origin_method = LLMEngine.step
+        LLMEngine._old_step = origin_method
+        LLMEngine.step = new_step
+
+    patch_vllm_abort_seq_group()
+    patch_vllm_engine()
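
A minimal usage sketch, assuming patch_vllm_memory_leak is imported from the module this hunk modifies (the file path is not shown here); the model name and sampling parameters are illustrative:

    from vllm import LLM, SamplingParams

    patch_vllm_memory_leak()  # no-op on any vLLM version other than 0.7.3
    llm = LLM(model='Qwen/Qwen2.5-0.5B-Instruct')
    outputs = llm.generate(['Hello, my name is'], SamplingParams(n=4, max_tokens=16))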