From 86c0f384df94624920c80d4cb79fb5c317afe443 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Sat, 18 Oct 2025 15:47:25 +0800 Subject: [PATCH 01/18] Support chunked prefill with ALL pooling Signed-off-by: wang.yuqi --- .../test_all_pooling_plus_chunked_prefill.py | 57 +++++++++++++++++ vllm/model_executor/layers/pooler.py | 61 ++++++++++++++++--- vllm/pooling_params.py | 3 + vllm/v1/outputs.py | 2 +- vllm/v1/pool/metadata.py | 16 +++-- vllm/v1/worker/gpu_model_runner.py | 47 ++++---------- 6 files changed, 138 insertions(+), 48 deletions(-) create mode 100644 tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py diff --git a/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py b/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py new file mode 100644 index 000000000000..9bc4912441a4 --- /dev/null +++ b/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +from transformers import AutoModel + +from tests.models.utils import check_embeddings_close +from vllm import TokensPrompt + + +@pytest.mark.parametrize( + "model", + ["Qwen/Qwen3-Embedding-0.6B"], +) +@pytest.mark.parametrize("dtype", ["half"]) +@torch.inference_mode +def test_embed_models(hf_runner, vllm_runner, model: str, dtype: str): + chunk_size = 10 + n_prompt_tokens = [55, 56, 57] + token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens] + + with vllm_runner( + model, + runner="pooling", + long_prefill_token_threshold=chunk_size, + max_model_len=128, + enforce_eager=True, + enable_chunked_prefill=True, + # If enable_prefix_caching is enabled, + # the output of all pooling will be less than n_prompt_tokens, + # we need a method to disable prefix_caching at the request level. 
+ enable_prefix_caching=False, + max_num_batched_tokens=chunk_size, + ) as vllm_model: + vllm_outputs = vllm_model.token_embed( + [TokensPrompt(prompt_token_ids=t) for t in token_prompts], + ) + + with hf_runner( + model, + auto_cls=AutoModel, + ) as hf_model: + hf_outputs = [] + for token_prompt in token_prompts: + inputs = hf_model.wrap_device({"input_ids": torch.tensor([token_prompt])}) + input_ids = inputs["input_ids"] + output = hf_model.model(input_ids) + hf_outputs.append(output.last_hidden_state.cpu().float()[0]) + + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + check_embeddings_close( + embeddings_0_lst=hf_output, + embeddings_1_lst=vllm_output, + name_0="hf", + name_1="vllm", + tol=1e-2, + ) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index a8c66315684e..647ac9460629 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -204,22 +204,57 @@ def forward_all( class AllPool(PoolingMethod): + def __init__(self): + super().__init__() + + vllm_config = get_current_vllm_config() + self.enable_chunked_prefill = ( + vllm_config.scheduler_config.enable_chunked_prefill + ) + def get_supported_tasks(self) -> Set[PoolingTask]: return {"token_embed", "token_classify"} def forward_all( + self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor + ) -> list[torch.Tensor] | torch.Tensor: + pass + + def forward( self, hidden_states: torch.Tensor, - pooling_cursor: PoolingCursor, + pooling_metadata: PoolingMetadata, ) -> list[torch.Tensor] | torch.Tensor: - assert not pooling_cursor.is_partial_prefill(), ( - "partial prefill not supported with ALL pooling" - ) + pooling_cursor = pooling_metadata.pooling_cursor + pooling_params = get_pooling_params(pooling_metadata) + is_finished = pooling_cursor.is_finished() hidden_states_lst = list( hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist()) ) - return [hidden_states_lst[i] for i in pooling_cursor.index] + hidden_states_lst = [hidden_states_lst[i] for i in pooling_cursor.index] + + if not self.enable_chunked_prefill: + return hidden_states_lst + + # If chunked_prefill is enabled + # 1. first store the chunked hidden_states in pooling_param.hidden_states_cache + for pooling_param, hidden_states in zip(pooling_params, hidden_states_lst): + pooling_param.hidden_states_cache.append(hidden_states) + + # 2. 
Once prefill is finished, send hidden_states_cache to PoolerHead + hidden_states = [] + for pooling_param, finished in zip(pooling_params, is_finished): + if finished: + hidden_states_cache = pooling_param.hidden_states_cache + if len(hidden_states_cache) == 1: + hidden_states.append(hidden_states_cache[0]) + else: + hidden_states.append(torch.concat(hidden_states_cache, dim=0)) + else: + hidden_states.append(None) + + return hidden_states class MeanPool(PoolingMethod): @@ -610,8 +645,12 @@ def forward( class TokenEmbeddingPoolerHead(EmbeddingPoolerHead): def forward( - self, pooled_data: torch.Tensor, pooling_param: PoolingParams - ) -> torch.Tensor: + self, pooled_data: torch.Tensor | None, pooling_param: PoolingParams + ) -> PoolerOutput: + # for unfinished chunked prefill + if pooled_data is None: + return None + pooled_data = pooled_data.to(self.head_dtype) # pooled_data shape: [n_tokens, hidden_dimension] @@ -654,9 +693,13 @@ def get_supported_tasks(self) -> Set[PoolingTask]: def forward( self, - hidden_states: torch.Tensor, + hidden_states: torch.Tensor | None, pooling_param: PoolingParams, - ) -> torch.Tensor: + ) -> PoolerOutput: + # for unfinished chunked prefill + if hidden_states is None: + return None + hidden_states = hidden_states.to(self.head_dtype) # hidden_states shape: [n_token, hidden_size] diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index c6dff6e01c1d..57fac7990592 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Annotated, Any, Optional import msgspec +import torch from vllm.sampling_params import RequestOutputKind from vllm.tasks import PoolingTask @@ -55,6 +56,8 @@ class PoolingParams( task: PoolingTask | None = None requires_token_ids: bool = False extra_kwargs: dict[str, Any] | None = None + # use in AllPool + hidden_states_cache: list[torch.Tensor] = [] output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY @property diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index c224555da6ca..bbcb5f204bf7 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -66,7 +66,7 @@ def empty_cpu( # [num_reqs, ] # The shape of each element depends on the pooler used -PoolerOutput = torch.Tensor | list[torch.Tensor] +PoolerOutput = torch.Tensor | list[torch.Tensor] | None @dataclass diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py index 2fb320dd2aaf..fe10ee1d8318 100644 --- a/vllm/v1/pool/metadata.py +++ b/vllm/v1/pool/metadata.py @@ -16,6 +16,7 @@ class PoolingCursor: first_token_indices_gpu: torch.Tensor last_token_indices_gpu: torch.Tensor prompt_lens_cpu: torch.Tensor + seq_lens_cpu: torch.Tensor num_scheduled_tokens_cpu: torch.Tensor def __getitem__(self, indices: slice): @@ -24,11 +25,15 @@ def __getitem__(self, indices: slice): first_token_indices_gpu=self.first_token_indices_gpu[indices], last_token_indices_gpu=self.last_token_indices_gpu[indices], prompt_lens_cpu=self.prompt_lens_cpu[indices], + seq_lens_cpu=self.seq_lens_cpu[indices], num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices], ) def is_partial_prefill(self): - return not torch.all(self.prompt_lens_cpu == self.num_scheduled_tokens_cpu) + return self.prompt_lens_cpu == self.num_scheduled_tokens_cpu + + def is_finished(self): + return self.prompt_lens_cpu == self.seq_lens_cpu @dataclass @@ -53,15 +58,17 @@ def __getitem__(self, indices: slice): ) def build_pooling_cursor( - self, num_scheduled_tokens: list[int], device: torch.device + self, num_scheduled_tokens: list[int], seq_lens_cpu: 
torch.Tensor, device: torch.device ): self.pooling_cursor = build_pooling_cursor( - num_scheduled_tokens, self.prompt_lens, device + num_scheduled_tokens, + seq_lens_cpu, + self.prompt_lens, device ) def build_pooling_cursor( - num_scheduled_tokens: list[int], prompt_lens: torch.Tensor, device: torch.device + num_scheduled_tokens: list[int], seq_lens_cpu: torch.Tensor, prompt_lens: torch.Tensor, device: torch.device ): assert len(prompt_lens) == len(num_scheduled_tokens) @@ -78,5 +85,6 @@ def build_pooling_cursor( first_token_indices_gpu=cumsum[:n_seq], last_token_indices_gpu=cumsum[1:] - 1, prompt_lens_cpu=prompt_lens, + seq_lens_cpu=seq_lens_cpu, num_scheduled_tokens_cpu=num_scheduled_tokens, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7e72ce937be4..16678274c4a4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1923,20 +1923,6 @@ def get_supported_pooling_tasks(self) -> list[PoolingTask]: supported_tasks = list(model.pooler.get_supported_tasks()) - if self.scheduler_config.chunked_prefill_enabled: - if "token_embed" in supported_tasks: - supported_tasks.remove("token_embed") - if "token_classify" in supported_tasks: - supported_tasks.remove("token_classify") - - logger.debug_once( - "Chunked prefill is not supported with " - "token_embed and token_classify tasks " - "which using ALL pooling. " - "Please turn off chunked prefill by " - "`--no-enable-chunked-prefill` before using it." - ) - if "score" in supported_tasks: num_labels = getattr(self.model_config.hf_config, "num_labels", 0) if num_labels != 1: @@ -2027,11 +2013,12 @@ def _pool( ) hidden_states = hidden_states[:num_scheduled_tokens] + seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs] + pooling_metadata = self.input_batch.get_pooling_metadata() pooling_metadata.build_pooling_cursor( - num_scheduled_tokens_np.tolist(), device=hidden_states.device + num_scheduled_tokens_np.tolist(), seq_lens_cpu, device=hidden_states.device ) - seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs] model = cast(VllmModelForPooling, self.model) raw_pooler_output: PoolerOutput = model.pooler( @@ -2039,7 +2026,7 @@ def _pool( pooling_metadata=pooling_metadata, ) raw_pooler_output = json_map_leaves( - lambda x: x.to("cpu", non_blocking=True), + lambda x: x.to("cpu", non_blocking=True) if x is not None else x, raw_pooler_output, ) self._sync_device() @@ -3592,7 +3579,9 @@ def _dummy_pooler_run_task( ) dummy_metadata.build_pooling_cursor( - num_scheduled_tokens_list, device=hidden_states.device + num_scheduled_tokens_list, + seq_lens_cpu=dummy_prompt_lens, + device=hidden_states.device, ) try: @@ -3619,22 +3608,12 @@ def _dummy_pooler_run( supported_pooling_tasks = self.get_supported_pooling_tasks() if not supported_pooling_tasks: - if self.scheduler_config.chunked_prefill_enabled: - raise RuntimeError( - f"Model {self.model_config.model} does not support " - "any pooling tasks with chunked prefill enabled. " - "Please add --no-enable-chunked-prefill to your " - "config or CLI args. See " - "https://docs.vllm.ai/en/latest/models/pooling_models.html " - "to learn more." - ) - else: - raise RuntimeError( - f"Model {self.model_config.model} does not support " - "any pooling tasks. See " - "https://docs.vllm.ai/en/latest/models/pooling_models.html " - "to learn more." - ) + raise RuntimeError( + f"Model {self.model_config.model} does not support " + "any pooling tasks. 
See " + "https://docs.vllm.ai/en/latest/models/pooling_models.html " + "to learn more." + ) output_size = dict[PoolingTask, float]() for task in supported_pooling_tasks: From 6bd49f271c02082d2e4c5b1030e5fecb45b0b553 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Sat, 18 Oct 2025 15:49:57 +0800 Subject: [PATCH 02/18] fix Signed-off-by: wang.yuqi --- vllm/v1/pool/metadata.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py index fe10ee1d8318..006b2b9219ae 100644 --- a/vllm/v1/pool/metadata.py +++ b/vllm/v1/pool/metadata.py @@ -58,17 +58,21 @@ def __getitem__(self, indices: slice): ) def build_pooling_cursor( - self, num_scheduled_tokens: list[int], seq_lens_cpu: torch.Tensor, device: torch.device + self, + num_scheduled_tokens: list[int], + seq_lens_cpu: torch.Tensor, + device: torch.device, ): self.pooling_cursor = build_pooling_cursor( - num_scheduled_tokens, - seq_lens_cpu, - self.prompt_lens, device + num_scheduled_tokens, seq_lens_cpu, self.prompt_lens, device ) def build_pooling_cursor( - num_scheduled_tokens: list[int], seq_lens_cpu: torch.Tensor, prompt_lens: torch.Tensor, device: torch.device + num_scheduled_tokens: list[int], + seq_lens_cpu: torch.Tensor, + prompt_lens: torch.Tensor, + device: torch.device, ): assert len(prompt_lens) == len(num_scheduled_tokens) From 44c6ee1379383fb76db4f3c51319952b1e7f2faf Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Sat, 18 Oct 2025 15:53:19 +0800 Subject: [PATCH 03/18] fix Signed-off-by: wang.yuqi --- vllm/v1/pool/metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py index 006b2b9219ae..6e7ba174cb1c 100644 --- a/vllm/v1/pool/metadata.py +++ b/vllm/v1/pool/metadata.py @@ -30,7 +30,7 @@ def __getitem__(self, indices: slice): ) def is_partial_prefill(self): - return self.prompt_lens_cpu == self.num_scheduled_tokens_cpu + return not torch.all(self.prompt_lens_cpu == self.num_scheduled_tokens_cpu) def is_finished(self): return self.prompt_lens_cpu == self.seq_lens_cpu From 86f0868e0df35b17450ec4b86982d8d65b8a5753 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Sat, 18 Oct 2025 15:55:46 +0800 Subject: [PATCH 04/18] fix Signed-off-by: wang.yuqi --- vllm/pooling_params.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 57fac7990592..e723d4b356cf 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -57,7 +57,7 @@ class PoolingParams( requires_token_ids: bool = False extra_kwargs: dict[str, Any] | None = None # use in AllPool - hidden_states_cache: list[torch.Tensor] = [] + hidden_states_cache: list[torch.Tensor] = msgspec.field(default_factory=list) output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY @property From 7c1d68d4f3a0f51efb7c0ff0d924ab503b4e72c7 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Sat, 18 Oct 2025 15:58:23 +0800 Subject: [PATCH 05/18] Update vllm/model_executor/layers/pooler.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: wang.yuqi --- vllm/model_executor/layers/pooler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 647ac9460629..28a126b95127 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -218,7 +218,8 @@ def get_supported_tasks(self) -> 
Set[PoolingTask]: def forward_all( self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor ) -> list[torch.Tensor] | torch.Tensor: - pass + raise NotImplementedError( + "forward_all is not implemented for AllPool. Use forward instead.") def forward( self, From f9034156ba8f177410e9e39b487afe502bb63187 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Sat, 18 Oct 2025 15:58:47 +0800 Subject: [PATCH 06/18] Update vllm/model_executor/layers/pooler.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: wang.yuqi --- vllm/model_executor/layers/pooler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 28a126b95127..a4febb346e96 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -252,6 +252,7 @@ def forward( hidden_states.append(hidden_states_cache[0]) else: hidden_states.append(torch.concat(hidden_states_cache, dim=0)) + pooling_param.hidden_states_cache.clear() else: hidden_states.append(None) From 72df85d80b956265bc73d84a1c6187c913a1d385 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Sat, 18 Oct 2025 16:02:05 +0800 Subject: [PATCH 07/18] Update vllm/model_executor/layers/pooler.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: wang.yuqi --- vllm/model_executor/layers/pooler.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index a4febb346e96..186eb87fa41d 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -240,23 +240,23 @@ def forward( # If chunked_prefill is enabled # 1. first store the chunked hidden_states in pooling_param.hidden_states_cache - for pooling_param, hidden_states in zip(pooling_params, hidden_states_lst): - pooling_param.hidden_states_cache.append(hidden_states) + for pooling_param, hs_chunk in zip(pooling_params, hidden_states_lst): + pooling_param.hidden_states_cache.append(hs_chunk) # 2. 
Once prefill is finished, send hidden_states_cache to PoolerHead - hidden_states = [] + output_list = [] for pooling_param, finished in zip(pooling_params, is_finished): if finished: hidden_states_cache = pooling_param.hidden_states_cache if len(hidden_states_cache) == 1: - hidden_states.append(hidden_states_cache[0]) + output_list.append(hidden_states_cache[0]) else: - hidden_states.append(torch.concat(hidden_states_cache, dim=0)) + output_list.append(torch.concat(hidden_states_cache, dim=0)) pooling_param.hidden_states_cache.clear() else: - hidden_states.append(None) + output_list.append(None) - return hidden_states + return output_list class MeanPool(PoolingMethod): From 6b6e7a8b349d6d8f3cd128f1bf713314b98ca506 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Sat, 18 Oct 2025 16:26:43 +0800 Subject: [PATCH 08/18] fix deep copy Signed-off-by: wang.yuqi --- vllm/v1/engine/core.py | 3 +++ vllm/v1/worker/gpu_model_runner.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 2773dc61ff3d..104cd20cace7 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -272,6 +272,9 @@ def add_request(self, request: Request, request_wave: int = 0): f"Supported tasks: {supported_pooling_tasks}" ) + # Ensure that no multiple requests share the same pooling_params + request.pooling_params = request.pooling_params.clone() + if request.kv_transfer_params is not None and ( not self.scheduler.get_kv_connector() ): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 16678274c4a4..2dde5db9d4bb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3575,7 +3575,7 @@ def _dummy_pooler_run_task( dummy_metadata = PoolingMetadata( prompt_lens=dummy_prompt_lens, prompt_token_ids=dummy_token_ids, - pooling_params=[dummy_pooling_params] * num_reqs, + pooling_params=[dummy_pooling_params.clone() for i in range(num_reqs)], ) dummy_metadata.build_pooling_cursor( From 9aef3544daa2795736f4b633895f35d79732be4c Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Sat, 18 Oct 2025 16:35:54 +0800 Subject: [PATCH 09/18] fix Signed-off-by: wang.yuqi --- vllm/model_executor/layers/pooler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 186eb87fa41d..29539228837a 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -219,7 +219,8 @@ def forward_all( self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor ) -> list[torch.Tensor] | torch.Tensor: raise NotImplementedError( - "forward_all is not implemented for AllPool. Use forward instead.") + "forward_all is not implemented for AllPool. Use forward instead." 
+ ) def forward( self, From d574b6c35ebfea3cbadddac4294d20b904df5517 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Sat, 18 Oct 2025 19:46:23 +0800 Subject: [PATCH 10/18] + tests Signed-off-by: wang.yuqi --- .../entrypoints/pooling/llm/test_classify.py | 7 +--- .../entrypoints/pooling/llm/test_embedding.py | 2 +- .../pooling/openai/test_classification.py | 13 +++++-- .../test_all_pooling_plus_chunked_prefill.py | 7 ++-- .../pooling/test_auto_prefix_cache_support.py | 28 ++++++++++++-- .../pooling/test_extract_hidden_states.py | 38 +++++++++++++++++++ vllm/outputs.py | 13 ++++++- vllm/v1/engine/output_processor.py | 1 + 8 files changed, 90 insertions(+), 19 deletions(-) create mode 100644 tests/models/language/pooling/test_extract_hidden_states.py diff --git a/tests/entrypoints/pooling/llm/test_classify.py b/tests/entrypoints/pooling/llm/test_classify.py index 96f634ee0a8c..206f4df55a09 100644 --- a/tests/entrypoints/pooling/llm/test_classify.py +++ b/tests/entrypoints/pooling/llm/test_classify.py @@ -59,11 +59,8 @@ def get_outputs(activation): @pytest.mark.skip_global_cleanup -def test_encode_api(llm: LLM): - # chunked prefill does not support all pooling - err_msg = "pooling_task must be one of.+" - with pytest.raises(ValueError, match=err_msg): - llm.encode(prompts, pooling_task="token_classify", use_tqdm=False) +def test_token_classify(llm: LLM): + llm.encode(prompts, pooling_task="token_classify", use_tqdm=False) def test_score_api(llm: LLM): diff --git a/tests/entrypoints/pooling/llm/test_embedding.py b/tests/entrypoints/pooling/llm/test_embedding.py index 5455b5f91fc0..e478ed3ac0b4 100644 --- a/tests/entrypoints/pooling/llm/test_embedding.py +++ b/tests/entrypoints/pooling/llm/test_embedding.py @@ -36,7 +36,7 @@ def llm(): @pytest.mark.skip_global_cleanup -def test_encode_api(llm: LLM): +def test_token_embed(llm: LLM): outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False) multi_vector = outputs[0].outputs.data assert multi_vector.shape == (11, 384) diff --git a/tests/entrypoints/pooling/openai/test_classification.py b/tests/entrypoints/pooling/openai/test_classification.py index 92d40efad21c..dc6f37cbae6d 100644 --- a/tests/entrypoints/pooling/openai/test_classification.py +++ b/tests/entrypoints/pooling/openai/test_classification.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from tests.utils import RemoteOpenAIServer -from vllm.entrypoints.openai.protocol import ClassificationResponse +from vllm.entrypoints.openai.protocol import ClassificationResponse, PoolingResponse MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" DTYPE = "float32" # Use float32 to avoid NaN issue @@ -192,12 +192,17 @@ async def get_outputs(activation): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) def test_pooling(server: RemoteOpenAIServer, model_name: str): - # pooling api uses ALL pooling, which does not support chunked prefill. 
+ input_text = ["This product was excellent and exceeded my expectations"] response = requests.post( server.url_for("pooling"), - json={"model": model_name, "input": "test", "encoding_format": "float"}, + json={"model": model_name, "input": input_text, "encoding_format": "float"}, ) - assert response.json()["error"]["type"] == "BadRequestError" + poolings = PoolingResponse.model_validate(response.json()) + + # token_classify + assert len(poolings.data) == 1 + assert len(poolings.data[0].data) == 8 + assert len(poolings.data[0].data[0]) == 2 @pytest.mark.asyncio diff --git a/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py b/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py index 9bc4912441a4..6634945f3b97 100644 --- a/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py +++ b/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py @@ -12,9 +12,8 @@ "model", ["Qwen/Qwen3-Embedding-0.6B"], ) -@pytest.mark.parametrize("dtype", ["half"]) @torch.inference_mode -def test_embed_models(hf_runner, vllm_runner, model: str, dtype: str): +def test_embed_models(hf_runner, vllm_runner, model: str): chunk_size = 10 n_prompt_tokens = [55, 56, 57] token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens] @@ -22,15 +21,15 @@ def test_embed_models(hf_runner, vllm_runner, model: str, dtype: str): with vllm_runner( model, runner="pooling", - long_prefill_token_threshold=chunk_size, max_model_len=128, + max_num_batched_tokens=chunk_size, enforce_eager=True, + # `enable_chunked_prefill`: Set to `False` instead of `None` in VllmRunner enable_chunked_prefill=True, # If enable_prefix_caching is enabled, # the output of all pooling will be less than n_prompt_tokens, # we need a method to disable prefix_caching at the request level. 
enable_prefix_caching=False, - max_num_batched_tokens=chunk_size, ) as vllm_model: vllm_outputs = vllm_model.token_embed( [TokensPrompt(prompt_token_ids=t) for t in token_prompts], diff --git a/tests/models/language/pooling/test_auto_prefix_cache_support.py b/tests/models/language/pooling/test_auto_prefix_cache_support.py index e95119df95c7..0904c7e877ef 100644 --- a/tests/models/language/pooling/test_auto_prefix_cache_support.py +++ b/tests/models/language/pooling/test_auto_prefix_cache_support.py @@ -19,14 +19,25 @@ def test_classify_models( model: str, dtype: str, ) -> None: - example_prompts = example_prompts * 2 + # example_prompts is too short for testing prefix_caching + example_prompts = [s * 10 for s in example_prompts] with vllm_runner( model, max_model_len=512, dtype=dtype, enable_prefix_caching=True ) as vllm_model: cache_config = vllm_model.llm.llm_engine.cache_config assert cache_config.enable_prefix_caching - vllm_outputs = vllm_model.classify(example_prompts) + + # First Run + vllm_model.classify(example_prompts) + + # assert prefix_caching works + pooling_outputs = vllm_model.llm.encode( + example_prompts, pooling_task="classify" + ) + for output in pooling_outputs: + assert output.num_cached_tokens > 0 + vllm_outputs = [req_output.outputs.data for req_output in pooling_outputs] with hf_runner( model, dtype=dtype, auto_cls=AutoModelForSequenceClassification @@ -54,7 +65,8 @@ def test_embed_models( model: str, dtype: str, ): - example_prompts = [str(s).strip() for s in example_prompts] * 2 + # example_prompts is too short for testing prefix_caching + example_prompts = [str(s).strip() * 10 for s in example_prompts] with vllm_runner( model, @@ -64,7 +76,15 @@ def test_embed_models( ) as vllm_model: cache_config = vllm_model.llm.llm_engine.cache_config assert cache_config.enable_prefix_caching - vllm_outputs = vllm_model.embed(example_prompts) + + # First Run + vllm_model.embed(example_prompts) + + # assert prefix_caching works + pooling_outputs = vllm_model.llm.encode(example_prompts, pooling_task="embed") + for output in pooling_outputs: + assert output.num_cached_tokens > 0 + vllm_outputs = [req_output.outputs.data for req_output in pooling_outputs] with hf_runner( model, diff --git a/tests/models/language/pooling/test_extract_hidden_states.py b/tests/models/language/pooling/test_extract_hidden_states.py new file mode 100644 index 000000000000..4964eedcac6e --- /dev/null +++ b/tests/models/language/pooling/test_extract_hidden_states.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from vllm import TokensPrompt + + +@pytest.mark.parametrize( + "model", + ["Qwen/Qwen3-0.6B"], +) +@torch.inference_mode +def test_embed_models(hf_runner, vllm_runner, model: str): + chunk_size = 10 + n_prompt_tokens = [55, 56, 57] + token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens] + + with vllm_runner( + model, + max_model_len=128, + max_num_batched_tokens=chunk_size, + enforce_eager=True, + runner="pooling", + # `enable_chunked_prefill`: Set to `False` instead of `None` in VllmRunner + enable_chunked_prefill=True, + enable_prefix_caching=True, + ) as vllm_model: + pooling_outputs = vllm_model.llm.encode( + [TokensPrompt(prompt_token_ids=t) for t in token_prompts], + pooling_task="token_embed", + ) + + for n, output in zip(n_prompt_tokens, pooling_outputs): + assert len(output.prompt_token_ids) == n + # We should ensure that all pooling task 
output.num_cached_tokens == 0 + # even if prefix caching is enabled + assert output.num_cached_tokens >= 0 diff --git a/vllm/outputs.py b/vllm/outputs.py index 114c1c5dc4b0..cdfe06f1c7fa 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -201,14 +201,21 @@ class PoolingRequestOutput(Generic[_O]): request_id (str): A unique identifier for the pooling request. outputs (PoolingOutput): The pooling results for the given input. prompt_token_ids (list[int]): A list of token IDs used in the prompt. + num_cached_tokens: The number of tokens with prefix cache hit. finished (bool): A flag indicating whether the pooling is completed. """ def __init__( - self, request_id: str, outputs: _O, prompt_token_ids: list[int], finished: bool + self, + request_id: str, + outputs: _O, + prompt_token_ids: list[int], + num_cached_tokens: int, + finished: bool, ): self.request_id = request_id self.prompt_token_ids = prompt_token_ids + self.num_cached_tokens = num_cached_tokens self.finished = finished self.outputs = outputs @@ -217,6 +224,7 @@ def __repr__(self): f"{type(self).__name__}(request_id={self.request_id!r}, " f"outputs={self.outputs!r}, " f"prompt_token_ids={self.prompt_token_ids}, " + f"num_cached_tokens={self.num_cached_tokens}, " f"finished={self.finished})" ) @@ -255,6 +263,7 @@ def from_base(request_output: PoolingRequestOutput): request_id=request_output.request_id, outputs=EmbeddingOutput.from_base(request_output.outputs), prompt_token_ids=request_output.prompt_token_ids, + num_cached_tokens=request_output.num_cached_tokens, finished=request_output.finished, ) @@ -294,6 +303,7 @@ def from_base(request_output: PoolingRequestOutput): request_id=request_output.request_id, outputs=ClassificationOutput.from_base(request_output.outputs), prompt_token_ids=request_output.prompt_token_ids, + num_cached_tokens=request_output.num_cached_tokens, finished=request_output.finished, ) @@ -330,5 +340,6 @@ def from_base(request_output: PoolingRequestOutput): request_id=request_output.request_id, outputs=ScoringOutput.from_base(request_output.outputs), prompt_token_ids=request_output.prompt_token_ids, + num_cached_tokens=request_output.num_cached_tokens, finished=request_output.finished, ) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 2bc1542187c9..e013023e26fa 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -231,6 +231,7 @@ def _new_request_output( request_id=request_id, outputs=first_output, prompt_token_ids=self.prompt_token_ids, + num_cached_tokens=self.num_cached_tokens, finished=finished, ) assert self.logprobs_processor is not None From 5c4b13c9fa3bf7fc28be68ced4555e6fb9955dc6 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Sat, 18 Oct 2025 21:35:14 +0800 Subject: [PATCH 11/18] fix Signed-off-by: wang.yuqi --- vllm/entrypoints/llm.py | 3 +++ vllm/entrypoints/openai/serving_embedding.py | 1 + vllm/entrypoints/score_utils.py | 1 + 3 files changed, 5 insertions(+) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 30bcb59437d9..89a84d32ce5f 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1068,6 +1068,9 @@ def encode( PoolingRequestOutput[Any]( request_id="", outputs=processed_outputs, + num_cached_tokens=getattr( + processed_outputs, "num_cached_tokens", 0 + ), prompt_token_ids=[], finished=True, ) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 2e3129cbeb8e..ceb9597a0758 100644 --- 
a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -548,6 +548,7 @@ async def _collect_batch( request_id=aggregator["request_id"], prompt_token_ids=original_token_ids, outputs=pooling_output_data, + num_cached_tokens=0, finished=True, ) diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py index cd62cfe5448c..309a4c996392 100644 --- a/vllm/entrypoints/score_utils.py +++ b/vllm/entrypoints/score_utils.py @@ -66,6 +66,7 @@ def _cosine_similarity( request_id=f"{emb_1.request_id}_{emb_2.request_id}", outputs=pair_score, prompt_token_ids=tokens, + num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens, finished=True, ) ) From 178ccd26e52044ea9d80f3d19772d4b421117954 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Sat, 18 Oct 2025 21:58:52 +0800 Subject: [PATCH 12/18] fix StepPooler Signed-off-by: wang.yuqi --- vllm/model_executor/layers/pooler.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 29539228837a..d5a39d58b4f6 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -759,14 +759,17 @@ def extract_states( ) -> torch.Tensor | list[torch.Tensor]: pooled_data_lst = self.pooling(hidden_states, pooling_metadata) prompt_token_ids = get_prompt_token_ids(pooling_metadata) - - pooled_data = list[torch.Tensor]() - pooling_params = get_pooling_params(pooling_metadata) + pooled_data = [] for data, token_id, pooling_param in zip( pooled_data_lst, prompt_token_ids, pooling_params ): + # for unfinished chunked prefill + if data is None: + pooled_data.append(data) + continue + step_tag_id = pooling_param.step_tag_id returned_token_ids = pooling_param.returned_token_ids From 43291dbd4ebfd5e7347cea2fa68534341029f26a Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Sat, 18 Oct 2025 22:06:56 +0800 Subject: [PATCH 13/18] fix StepPooler Signed-off-by: wang.yuqi --- vllm/model_executor/layers/pooler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index d5a39d58b4f6..3039bc302def 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -756,12 +756,12 @@ def extract_states( self, hidden_states: torch.Tensor | list[torch.Tensor], pooling_metadata: PoolingMetadata, - ) -> torch.Tensor | list[torch.Tensor]: + ) -> PoolerOutput: pooled_data_lst = self.pooling(hidden_states, pooling_metadata) prompt_token_ids = get_prompt_token_ids(pooling_metadata) pooling_params = get_pooling_params(pooling_metadata) - pooled_data = [] + pooled_data: list[torch.Tensor | None] = [] for data, token_id, pooling_param in zip( pooled_data_lst, prompt_token_ids, pooling_params ): From bb9a4ad7c7396bf40f76a614265e50c40b3fa4ef Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Tue, 21 Oct 2025 09:35:30 +0800 Subject: [PATCH 14/18] + preempted_req Signed-off-by: wang.yuqi --- vllm/v1/core/sched/scheduler.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 08368b7d99ef..abc10b4bf759 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -283,6 +283,14 @@ def schedule(self) -> SchedulerOutput: self.kv_cache_manager.free(preempted_req) self.encoder_cache_manager.free(preempted_req) + + # The hidden_states_cache is used in requests that + # use all pooling + 
chunked prefill. + # If the request is preempted, the hidden_states_cache + # needs to be cleared and recalculated. + if preempted_req.pooling_params is not None: + preempted_req.pooling_params.hidden_states_cache.clear() + preempted_req.status = RequestStatus.PREEMPTED preempted_req.num_computed_tokens = 0 preempted_req.num_preemptions += 1 From eea5f6c0ecb79cf5bcd0610fb7415e304fd52d8d Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Tue, 28 Oct 2025 13:49:04 +0800 Subject: [PATCH 15/18] update Signed-off-by: wang.yuqi --- tests/models/language/pooling/test_extract_hidden_states.py | 5 ----- vllm/pooling_params.py | 5 +++-- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/models/language/pooling/test_extract_hidden_states.py b/tests/models/language/pooling/test_extract_hidden_states.py index 0747ef6a237f..4964eedcac6e 100644 --- a/tests/models/language/pooling/test_extract_hidden_states.py +++ b/tests/models/language/pooling/test_extract_hidden_states.py @@ -25,10 +25,6 @@ def test_embed_models(hf_runner, vllm_runner, model: str): # `enable_chunked_prefill`: Set to `False` instead of `None` in VllmRunner enable_chunked_prefill=True, enable_prefix_caching=True, - enforce_eager=True, - runner="pooling", - enable_chunked_prefill=False, - enable_prefix_caching=False, ) as vllm_model: pooling_outputs = vllm_model.llm.encode( [TokensPrompt(prompt_token_ids=t) for t in token_prompts], @@ -40,4 +36,3 @@ def test_embed_models(hf_runner, vllm_runner, model: str): # We should ensure that all pooling task output.num_cached_tokens == 0 # even if prefix caching is enabled assert output.num_cached_tokens >= 0 - assert output.num_cached_tokens == 0 diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 22e14587cf6c..a8402c735ea8 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -56,10 +56,11 @@ class PoolingParams( task: PoolingTask | None = None requires_token_ids: bool = False extra_kwargs: dict[str, Any] | None = None - # use in AllPool - hidden_states_cache: list[torch.Tensor] = msgspec.field(default_factory=list) output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY + # for chunked prefill with ALL pooling + hidden_states_cache: list[torch.Tensor] = msgspec.field(default_factory=list) + @property def all_parameters(self) -> list[str]: return ["dimensions", "normalize", "activation"] From 41ff486d17650286b4fe2e0d372d149db95e32e4 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Tue, 28 Oct 2025 14:01:06 +0800 Subject: [PATCH 16/18] update Signed-off-by: wang.yuqi --- vllm/v1/engine/output_processor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 5c8cfcf333dd..44e4eadce42a 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -232,7 +232,6 @@ def _new_request_output( outputs=first_output, num_cached_tokens=self.num_cached_tokens, prompt_token_ids=self.prompt_token_ids, - num_cached_tokens=self.num_cached_tokens, finished=finished, ) assert self.logprobs_processor is not None From e8f222e3cd65b590a48fe93c8550028eae6cb5e1 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 17 Nov 2025 11:39:01 +0800 Subject: [PATCH 17/18] update Signed-off-by: wang.yuqi --- .../pooling/openai/test_classification.py | 3 +- .../test_all_pooling_plus_chunked_prefill.py | 5 +--- .../pooling/test_extract_hidden_states.py | 6 ---- vllm/v1/core/sched/scheduler.py | 8 +++++ vllm/v1/worker/gpu_model_runner.py | 30 ------------------- 5 files changed, 10 
insertions(+), 42 deletions(-) diff --git a/tests/entrypoints/pooling/openai/test_classification.py b/tests/entrypoints/pooling/openai/test_classification.py index 3d0c7da0ee43..f1d481e1bb9c 100644 --- a/tests/entrypoints/pooling/openai/test_classification.py +++ b/tests/entrypoints/pooling/openai/test_classification.py @@ -205,7 +205,7 @@ async def get_outputs(use_activation): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_pooling(server: RemoteOpenAIServer, model_name: str): +async def test_pooling(server: RemoteOpenAIServer, model_name: str): input_text = ["This product was excellent and exceeded my expectations"] response = requests.post( server.url_for("pooling"), @@ -221,7 +221,6 @@ def test_pooling(server: RemoteOpenAIServer, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_score(server: RemoteOpenAIServer, model_name: str): async def test_score(server: RemoteOpenAIServer, model_name: str): # score api is only enabled for num_labels == 1. response = requests.post( diff --git a/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py b/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py index 6634945f3b97..c259c532220b 100644 --- a/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py +++ b/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py @@ -26,10 +26,7 @@ def test_embed_models(hf_runner, vllm_runner, model: str): enforce_eager=True, # `enable_chunked_prefill`: Set to `False` instead of `None` in VllmRunner enable_chunked_prefill=True, - # If enable_prefix_caching is enabled, - # the output of all pooling will be less than n_prompt_tokens, - # we need a method to disable prefix_caching at the request level. - enable_prefix_caching=False, + enable_prefix_caching=True, ) as vllm_model: vllm_outputs = vllm_model.token_embed( [TokensPrompt(prompt_token_ids=t) for t in token_prompts], diff --git a/tests/models/language/pooling/test_extract_hidden_states.py b/tests/models/language/pooling/test_extract_hidden_states.py index d24f9b670788..2a02bf0aedef 100644 --- a/tests/models/language/pooling/test_extract_hidden_states.py +++ b/tests/models/language/pooling/test_extract_hidden_states.py @@ -11,8 +11,6 @@ ["Qwen/Qwen3-0.6B"], ) @torch.inference_mode -def test_embed_models(hf_runner, vllm_runner, model: str): - chunk_size = 10 def test_extract_hidden_states(hf_runner, vllm_runner, model: str): n_prompt_tokens = [55, 56, 57] token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens] @@ -20,12 +18,8 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str): with vllm_runner( model, max_model_len=128, - max_num_batched_tokens=chunk_size, enforce_eager=True, runner="pooling", - # `enable_chunked_prefill`: Set to `False` instead of `None` in VllmRunner - enable_chunked_prefill=True, - enable_chunked_prefill=False, enable_prefix_caching=True, ) as vllm_model: pooling_outputs = vllm_model.llm.encode( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 8e62542337a7..8bc9ee7e84f8 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -320,6 +320,14 @@ def schedule(self) -> SchedulerOutput: self.kv_cache_manager.free(preempted_req) self.encoder_cache_manager.free(preempted_req) + + # The hidden_states_cache is used in requests that + # use all pooling + chunked prefill. 
+ # If the request is preempted, the hidden_states_cache + # needs to be cleared and recalculated. + if preempted_req.pooling_params is not None: + preempted_req.pooling_params.hidden_states_cache.clear() + preempted_req.status = RequestStatus.PREEMPTED preempted_req.num_computed_tokens = 0 preempted_req.num_preemptions += 1 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a846abfe3472..c5307e3925d6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2023,20 +2023,6 @@ def get_supported_pooling_tasks(self) -> list[PoolingTask]: supported_tasks = list(model.pooler.get_supported_tasks()) - if self.scheduler_config.enable_chunked_prefill: - if "token_embed" in supported_tasks: - supported_tasks.remove("token_embed") - if "token_classify" in supported_tasks: - supported_tasks.remove("token_classify") - - logger.debug_once( - "Chunked prefill is not supported with " - "token_embed and token_classify tasks " - "which using ALL pooling. " - "Please turn off chunked prefill by " - "`--no-enable-chunked-prefill` before using it." - ) - if "score" in supported_tasks: num_labels = getattr(self.model_config.hf_config, "num_labels", 0) if num_labels != 1: @@ -3834,22 +3820,6 @@ def _dummy_pooler_run( "https://docs.vllm.ai/en/latest/models/pooling_models.html " "to learn more." ) - if self.scheduler_config.enable_chunked_prefill: - raise RuntimeError( - f"Model {self.model_config.model} does not support " - "any pooling tasks with chunked prefill enabled. " - "Please add --no-enable-chunked-prefill to your " - "config or CLI args. See " - "https://docs.vllm.ai/en/latest/models/pooling_models.html " - "to learn more." - ) - else: - raise RuntimeError( - f"Model {self.model_config.model} does not support " - "any pooling tasks. See " - "https://docs.vllm.ai/en/latest/models/pooling_models.html " - "to learn more." - ) output_size = dict[PoolingTask, float]() for task in supported_pooling_tasks: From fb8197be47f0887ccad41a6958903488232a9ee1 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 17 Nov 2025 14:01:36 +0800 Subject: [PATCH 18/18] update Signed-off-by: wang.yuqi --- tests/models/language/pooling/test_extract_hidden_states.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/models/language/pooling/test_extract_hidden_states.py b/tests/models/language/pooling/test_extract_hidden_states.py index 2a02bf0aedef..488b27e2da0f 100644 --- a/tests/models/language/pooling/test_extract_hidden_states.py +++ b/tests/models/language/pooling/test_extract_hidden_states.py @@ -29,9 +29,6 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str): for n, output in zip(n_prompt_tokens, pooling_outputs): assert len(output.prompt_token_ids) == n - # We should ensure that all pooling task output.num_cached_tokens == 0 - # even if prefix caching is enabled - assert output.num_cached_tokens >= 0 assert len(output.outputs.data) == n assert output.num_cached_tokens == 0
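
Note on the mechanism this series adds (an illustrative sketch, not part of any patch above): with chunked prefill enabled, AllPool.forward no longer sees a whole prompt in one scheduler step. Each step it splits the batch's hidden states per request, appends each request's chunk to PoolingParams.hidden_states_cache, and only when PoolingCursor.is_finished() reports prompt_lens == seq_lens does it concatenate the cached chunks and hand the full sequence to the pooler head; unfinished requests yield None, which the token_embed/token_classify heads and the step pooler are taught to pass through. The standalone Python sketch below mirrors that accumulate-and-flush flow under those assumptions; the function name and the explicit caches/finished arguments are hypothetical stand-ins for the vLLM internals, not real APIs.

import torch


def all_pool_with_chunked_prefill(
    hidden_states: torch.Tensor,
    num_scheduled_tokens: list[int],
    caches: list[list[torch.Tensor]],
    finished: list[bool],
) -> list[torch.Tensor | None]:
    # Split this step's hidden states into one chunk per request,
    # mirroring the split over num_scheduled_tokens_cpu in AllPool.forward.
    chunks = hidden_states.split(num_scheduled_tokens)

    outputs: list[torch.Tensor | None] = []
    for cache, chunk, done in zip(caches, chunks, finished):
        # 1. Always stash the chunk (stand-in for hidden_states_cache).
        cache.append(chunk)
        if done:
            # 2. Prefill finished: concatenate the cached chunks, emit the
            #    full [n_prompt_tokens, hidden_size] tensor, drop the cache.
            out = cache[0] if len(cache) == 1 else torch.cat(cache, dim=0)
            cache.clear()
            outputs.append(out)
        else:
            # Unfinished chunked prefill: emit None so downstream heads skip it.
            outputs.append(None)
    return outputs


if __name__ == "__main__":
    hidden_size = 4
    caches: list[list[torch.Tensor]] = [[], []]

    # Step 1: request 0 gets 3 of its 5 prompt tokens; request 1 finishes (2 of 2).
    out1 = all_pool_with_chunked_prefill(
        torch.randn(5, hidden_size), [3, 2], caches, [False, True]
    )
    assert out1[0] is None and out1[1].shape == (2, hidden_size)

    # Step 2: request 0 receives its last 2 tokens and is flushed as 5 x hidden_size.
    out2 = all_pool_with_chunked_prefill(
        torch.randn(2, hidden_size), [2], caches[:1], [True]
    )
    assert out2[0].shape == (5, hidden_size)

Clearing the per-request cache on flush mirrors what the scheduler change above does on preemption: a preempted request's hidden_states_cache is dropped so its prefill, and the cached chunks, are recomputed from scratch.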