diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py
index b2a958a992a6..a9698632b82e 100644
--- a/tests/entrypoints/llm/test_chat.py
+++ b/tests/entrypoints/llm/test_chat.py
@@ -6,6 +6,7 @@
 
 from vllm import LLM
 from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.sampling_params import SamplingParams
 
 from ..openai.test_vision import TEST_IMAGE_ASSETS
 
@@ -23,6 +24,29 @@ def text_llm():
     cleanup_dist_env_and_memory()
 
 
+@pytest.fixture(scope="function")
+def llm_for_failure_test():
+    """
+    Fixture for testing issue #26081.
+    Uses a small max_model_len to easily trigger length errors.
+    """
+    # pytest caches the fixture so we use weakref.proxy to
+    # enable garbage collection
+    llm = LLM(
+        model="meta-llama/Llama-3.2-1B-Instruct",
+        enforce_eager=True,
+        seed=0,
+        max_model_len=128,
+        disable_log_stats=True,
+    )
+
+    yield weakref.proxy(llm)
+
+    del llm
+
+    cleanup_dist_env_and_memory()
+
+
 def test_chat(text_llm):
     prompt1 = "Explain the concept of entropy."
     messages = [
@@ -157,3 +181,32 @@ def test_chat_extra_kwargs(thinking_llm, enable_thinking):
     else:
         # The chat template includes dummy thinking process
         assert think_id in prompt_token_ids
+
+
+def test_chat_batch_failure_cleanup(llm_for_failure_test):
+    """
+    Tests that if a batch call to llm.chat() fails mid-way
+    (e.g., due to one invalid prompt), the requests that
+    were already enqueued are properly aborted and do not
+    pollute the queue for subsequent calls.
+    (Fixes Issue #26081)
+    """
+    llm = llm_for_failure_test
+    valid_msg = [{"role": "user", "content": "Hello"}]
+    long_text = "This is a very long text to test the error " * 50
+    invalid_msg = [{"role": "user", "content": long_text}]
+    batch_1 = [
+        valid_msg,
+        valid_msg,
+        invalid_msg,
+    ]
+    batch_2 = [
+        valid_msg,
+        valid_msg,
+    ]
+    sampling_params = SamplingParams(temperature=0, max_tokens=10)
+    with pytest.raises(ValueError, match="longer than the maximum model length"):
+        llm.chat(batch_1, sampling_params=sampling_params)
+    outputs_2 = llm.chat(batch_2, sampling_params=sampling_params)
+    assert len(outputs_2) == len(batch_2)
+    assert llm.llm_engine.get_num_unfinished_requests() == 0
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 869861afff03..d6a36373aba2 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1560,20 +1560,27 @@ def _validate_and_add_requests(
             tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
             it = tqdm_func(it, desc="Adding requests")
 
-        for i, prompt in enumerate(it):
-            if isinstance(prompt, dict):
-                self._validate_mm_data_and_uuids(
-                    prompt.get("multi_modal_data"), prompt.get("multi_modal_uuids")
-                )
+        added_request_ids: list[str] = []
 
-            self._add_request(
-                prompt,
-                params[i] if isinstance(params, Sequence) else params,
-                lora_request=lora_request[i]
-                if isinstance(lora_request, Sequence)
-                else lora_request,
-                priority=priority[i] if priority else 0,
-            )
+        try:
+            for i, prompt in enumerate(it):
+                if isinstance(prompt, dict):
+                    self._validate_mm_data_and_uuids(
+                        prompt.get("multi_modal_data"), prompt.get("multi_modal_uuids")
+                    )
+                request_id = self._add_request(
+                    prompt,
+                    params[i] if isinstance(params, Sequence) else params,
+                    lora_request=lora_request[i]
+                    if isinstance(lora_request, Sequence)
+                    else lora_request,
+                    priority=priority[i] if priority else 0,
+                )
+                added_request_ids.append(request_id)
+        except Exception as e:
+            if added_request_ids:
+                self.llm_engine.abort_request(added_request_ids)
+            raise e
 
     def _validate_mm_data_and_uuids(
         self,
@@ -1656,7 +1663,7 @@ def _add_request(
         params: SamplingParams | PoolingParams,
         lora_request: LoRARequest | None = None,
         priority: int = 0,
-    ) -> None:
+    ) -> str:
         prompt_text, _, _ = get_prompt_components(prompt)
         request_id = str(next(self.request_counter))
 
@@ -1677,6 +1684,7 @@ def _add_request(
             priority=priority,
             prompt_text=prompt_text,
         )
+        return request_id
 
     def _run_engine(
         self, *, use_tqdm: bool | Callable[..., tqdm] = True
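
For reference, a minimal caller-side sketch (not part of the patch) of the behavior the change guarantees: a batch that fails validation part-way through enqueueing no longer leaves orphaned requests in the engine queue. The model name, max_model_len, and prompt text mirror the test fixture above and are illustrative assumptions, not requirements.

from vllm import LLM, SamplingParams

# Small context window so an over-length prompt reliably raises (as in the test fixture).
llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    enforce_eager=True,
    max_model_len=128,
)
params = SamplingParams(temperature=0, max_tokens=10)

valid = [{"role": "user", "content": "Hello"}]
too_long = [{"role": "user", "content": "This is a very long text to test the error " * 50}]

try:
    # The over-length prompt makes the batch fail after some requests were
    # already added; with the fix, those requests are aborted.
    llm.chat([valid, valid, too_long], sampling_params=params)
except ValueError:
    pass

# The queue is clean, so a follow-up batch completes normally.
assert llm.llm_engine.get_num_unfinished_requests() == 0
outputs = llm.chat([valid, valid], sampling_params=params)
assert len(outputs) == 2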