53 changes: 53 additions & 0 deletions tests/entrypoints/llm/test_chat.py
@@ -6,6 +6,7 @@

from vllm import LLM
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.sampling_params import SamplingParams

from ..openai.test_vision import TEST_IMAGE_ASSETS

@@ -23,6 +24,29 @@ def text_llm():
cleanup_dist_env_and_memory()


@pytest.fixture(scope="function")
def llm_for_failure_test():
"""
Fixture for testing issue #26081.
Uses a small max_model_len to easily trigger length errors.
"""
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
seed=0,
max_model_len=128,
disable_log_stats=True,
)

yield weakref.proxy(llm)

del llm

cleanup_dist_env_and_memory()


def test_chat(text_llm):
prompt1 = "Explain the concept of entropy."
messages = [
@@ -157,3 +181,32 @@ def test_chat_extra_kwargs(thinking_llm, enable_thinking):
else:
# The chat template includes dummy thinking process
assert think_id in prompt_token_ids


def test_chat_batch_failure_cleanup(llm_for_failure_test):
"""
Tests that if a batch call to llm.chat() fails mid-way
(e.g., due to one invalid prompt), the requests that
were already enqueued are properly aborted and do not
pollute the queue for subsequent calls.
(Fixes Issue #26081)
"""
llm = llm_for_failure_test
valid_msg = [{"role": "user", "content": "Hello"}]
long_text = "This is a very long text to test the error " * 50
invalid_msg = [{"role": "user", "content": long_text}]
batch_1 = [
valid_msg,
valid_msg,
invalid_msg,
]
batch_2 = [
valid_msg,
valid_msg,
]
sampling_params = SamplingParams(temperature=0, max_tokens=10)
with pytest.raises(ValueError, match="longer than the maximum model length"):
llm.chat(batch_1, sampling_params=sampling_params)
outputs_2 = llm.chat(batch_2, sampling_params=sampling_params)
assert len(outputs_2) == len(batch_2)
assert llm.llm_engine.get_num_unfinished_requests() == 0
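
A note on the weakref.proxy pattern used in llm_for_failure_test above: because pytest keeps a reference to whatever a fixture yields, yielding the LLM directly would keep it alive past the del llm, so the fixture yields a proxy and the real object stays collectable. A tiny standalone sketch of the mechanism, where Resource is a hypothetical stand-in for a heavyweight object such as an engine:

import weakref


class Resource:
    """Hypothetical stand-in for a heavyweight object (e.g. an LLM engine)."""

    def ping(self) -> str:
        return "alive"


r = Resource()
p = weakref.proxy(r)  # behaves like r, but does not keep it alive
print(p.ping())       # works while the real object exists

del r  # drop the only strong reference; CPython frees the object right away
try:
    p.ping()  # using the proxy after collection raises ReferenceError
except ReferenceError:
    print("underlying object has been garbage collected")
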
36 changes: 22 additions & 14 deletions vllm/entrypoints/llm.py
@@ -1560,20 +1560,27 @@ def _validate_and_add_requests(
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
it = tqdm_func(it, desc="Adding requests")

-        for i, prompt in enumerate(it):
-            if isinstance(prompt, dict):
-                self._validate_mm_data_and_uuids(
-                    prompt.get("multi_modal_data"), prompt.get("multi_modal_uuids")
-                )
+        added_request_ids: list[str] = []

-            self._add_request(
-                prompt,
-                params[i] if isinstance(params, Sequence) else params,
-                lora_request=lora_request[i]
-                if isinstance(lora_request, Sequence)
-                else lora_request,
-                priority=priority[i] if priority else 0,
-            )
+        try:
+            for i, prompt in enumerate(it):
+                if isinstance(prompt, dict):
+                    self._validate_mm_data_and_uuids(
+                        prompt.get("multi_modal_data"), prompt.get("multi_modal_uuids")
+                    )
+                request_id = self._add_request(
+                    prompt,
+                    params[i] if isinstance(params, Sequence) else params,
+                    lora_request=lora_request[i]
+                    if isinstance(lora_request, Sequence)
+                    else lora_request,
+                    priority=priority[i] if priority else 0,
+                )
+                added_request_ids.append(request_id)
+        except Exception as e:
+            if added_request_ids:
+                self.llm_engine.abort_request(added_request_ids)
+            raise e

def _validate_mm_data_and_uuids(
self,
@@ -1656,7 +1663,7 @@ def _add_request(
params: SamplingParams | PoolingParams,
lora_request: LoRARequest | None = None,
priority: int = 0,
-    ) -> None:
+    ) -> str:
prompt_text, _, _ = get_prompt_components(prompt)
request_id = str(next(self.request_counter))

@@ -1677,6 +1684,7 @@
priority=priority,
prompt_text=prompt_text,
)
+        return request_id

def _run_engine(
self, *, use_tqdm: bool | Callable[..., tqdm] = True
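
For readers skimming the diff above, the change to _validate_and_add_requests boils down to an enqueue-with-rollback pattern: record every request id that was added successfully, and if a later prompt in the batch fails, abort all of them before re-raising. A minimal self-contained sketch of that pattern follows; ToyEngine and add_batch are illustrative stand-ins, not the real vLLM engine or API.

# Illustrative sketch of the enqueue-with-rollback pattern used by the fix.
class ToyEngine:
    def __init__(self) -> None:
        self.queue: dict[str, str] = {}

    def add_request(self, request_id: str, prompt: str) -> None:
        if len(prompt) > 32:  # stand-in for the max-model-len check
            raise ValueError("prompt is longer than the maximum model length")
        self.queue[request_id] = prompt

    def abort_request(self, request_ids: list[str]) -> None:
        for rid in request_ids:
            self.queue.pop(rid, None)


def add_batch(engine: ToyEngine, prompts: list[str]) -> list[str]:
    """Add a whole batch; roll back already-added requests if any prompt fails."""
    added: list[str] = []
    try:
        for i, prompt in enumerate(prompts):
            engine.add_request(str(i), prompt)
            added.append(str(i))
    except Exception as e:
        if added:
            engine.abort_request(added)
        raise e
    return added


engine = ToyEngine()
try:
    add_batch(engine, ["hello", "hi", "x" * 100])  # third prompt is too long
except ValueError:
    pass
assert not engine.queue  # nothing left behind to pollute the next batch
assert add_batch(engine, ["hello", "hi"]) == ["0", "1"]
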