From 42352ba3e64cc85f1f014f0da63455b19b9b7546 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 7 Jul 2025 11:09:27 +0800 Subject: [PATCH 01/30] +test Signed-off-by: wang.yuqi --- .../test_classification_pooler_config.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/models/language/pooling/test_classification_pooler_config.py diff --git a/tests/models/language/pooling/test_classification_pooler_config.py b/tests/models/language/pooling/test_classification_pooler_config.py new file mode 100644 index 000000000000..b591a82f54a9 --- /dev/null +++ b/tests/models/language/pooling/test_classification_pooler_config.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +import torch.nn.functional as F + +from vllm.config import PoolerConfig + + +@pytest.mark.parametrize( + "model", + [ + "jason9693/Qwen2.5-1.5B-apeach", + "papluca/xlm-roberta-base-language-detection" + ], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + + with vllm_runner( + model, + max_model_len=512, + dtype=dtype, + override_pooler_config=PoolerConfig(softmax=False)) as vllm_model: + wo_softmax_out = vllm_model.classify(example_prompts) + + with vllm_runner( + model, + max_model_len=512, + dtype=dtype, + override_pooler_config=PoolerConfig(softmax=True)) as vllm_model: + w_softmax_out = vllm_model.classify(example_prompts) + + for wo_softmax, w_softmax in zip(wo_softmax_out, w_softmax_out): + wo_softmax = torch.tensor(wo_softmax) + w_softmax = torch.tensor(w_softmax) + + assert not torch.allclose( + wo_softmax, w_softmax, + atol=1e-2), "override_pooler_config is not working" + assert torch.allclose(F.softmax(wo_softmax, dim=-1), w_softmax, + 1e-3 if dtype == "float" else 1e-2) From f27031c091b73b022b3c6d9f8eb55609501e298f Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Wed, 23 Jul 2025 15:10:29 +0800 Subject: [PATCH 02/30] + using_normalize Signed-off-by: wang.yuqi --- tests/entrypoints/llm/test_embedding.py | 62 +++++++++++++++++ .../entrypoints/openai/test_classification.py | 30 +++++++++ tests/entrypoints/openai/test_embedding.py | 33 ++++++++++ ...nfig.py => test_override_pooler_config.py} | 39 ++++++++++- vllm/entrypoints/openai/protocol.py | 17 +++-- vllm/model_executor/layers/pooler.py | 66 ++++++++++++++----- vllm/pooling_params.py | 16 +++++ 7 files changed, 239 insertions(+), 24 deletions(-) create mode 100644 tests/entrypoints/llm/test_embedding.py rename tests/models/language/pooling/{test_classification_pooler_config.py => test_override_pooler_config.py} (55%) diff --git a/tests/entrypoints/llm/test_embedding.py b/tests/entrypoints/llm/test_embedding.py new file mode 100644 index 000000000000..f5f34b383152 --- /dev/null +++ b/tests/entrypoints/llm/test_embedding.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import weakref + +import pytest +import torch +import torch.nn.functional as F + +from vllm import LLM, PoolingParams +from vllm.distributed import cleanup_dist_env_and_memory + +MODEL_NAME = "intfloat/multilingual-e5-small" + +prompts = ["The chef prepared a delicious meal."] + + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a 
package + pass + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + seed=0) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + +@pytest.mark.skip_global_cleanup +def test_normalize(llm: LLM): + + def get_outputs(normalize): + outputs = llm.embed(prompts, + pooling_params=PoolingParams(normalize=normalize)) + return torch.tensor([x.outputs.embedding for x in outputs]) + + default = get_outputs(normalize=None) + w_normal = get_outputs(normalize=True) + wo_normal = get_outputs(normalize=False) + + assert torch.allclose(default, w_normal), "Default should use normal." + assert not torch.allclose(w_normal, + wo_normal), "wo_normal should not use normal." + assert torch.allclose(w_normal, F.normalize( + wo_normal, p=2, + dim=-1)), "w_normal should be close to normal(wo_normal)." diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/openai/test_classification.py index b2472658ca81..8c4b937f4848 100644 --- a/tests/entrypoints/openai/test_classification.py +++ b/tests/entrypoints/openai/test_classification.py @@ -3,6 +3,7 @@ import pytest import requests +import torch from vllm.entrypoints.openai.protocol import ClassificationResponse @@ -181,3 +182,32 @@ async def test_invocations(server: RemoteOpenAIServer): assert classification_data.keys() == invocation_data.keys() assert classification_data["probs"] == pytest.approx( invocation_data["probs"], rel=0.01) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_softmax(server: RemoteOpenAIServer, model_name: str): + input_text = ["This product was excellent and exceeded my expectations"] + + async def get_outputs(softmax): + + response = requests.post(server.url_for("classify"), + json={ + "model": model_name, + "input": input_text, + "softmax": softmax + }) + + outputs = response.json() + return torch.tensor([x['probs'] for x in outputs["data"]]) + + default = await get_outputs(softmax=None) + w_softmax = await get_outputs(softmax=True) + wo_softmax = await get_outputs(softmax=False) + + assert torch.allclose(default, w_softmax), "Default should use softmax." + assert not torch.allclose(w_softmax, + wo_softmax), "wo_softmax should not use softmax." + assert torch.allclose( + w_softmax, + wo_softmax), "w_softmax should be close to softmax(wo_softmax)." 
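The invariant the classification test above relies on is that a `softmax=False` response carries raw logits which, after applying softmax once, must reproduce the `softmax=True` response. A minimal self-contained sketch of that relationship (plain PyTorch; the tensor values are made up for illustration):

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[2.0, -1.0]])  # hypothetical raw classifier output
    probs = F.softmax(logits, dim=-1)     # what softmax=True should return

    # softmax is not idempotent, so the two responses can only be compared
    # by applying softmax to the raw logits first
    assert torch.allclose(F.softmax(logits, dim=-1), probs)
    assert not torch.allclose(F.softmax(probs, dim=-1), probs)
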
diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index f03c96b12179..e39c97398676 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -8,6 +8,8 @@ import pytest import pytest_asyncio import requests +import torch +import torch.nn.functional as F from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.transformers_utils.tokenizer import get_tokenizer @@ -369,3 +371,34 @@ async def test_invocations_conversation(server: RemoteOpenAIServer): embeddings_1_lst=[invocation_data["embedding"]], name_0="chat", name_1="invocation") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_normalize(server: RemoteOpenAIServer, model_name: str): + input_text = ["The chef prepared a delicious meal."] + + async def get_outputs(normalize): + request_args = { + "model": MODEL_NAME, + "input": input_text, + "encoding_format": "float", + "normalize": normalize + } + + response = requests.post(server.url_for("v1/embeddings"), + json=request_args) + outputs = response.json() + + return torch.tensor([x['embedding'] for x in outputs["data"]]) + + default = await get_outputs(normalize=None) + w_normal = await get_outputs(normalize=True) + wo_normal = await get_outputs(normalize=False) + + assert torch.allclose(default, w_normal), "Default should use normal." + assert not torch.allclose(w_normal, + wo_normal), "wo_normal should not use normal." + assert torch.allclose(w_normal, F.normalize( + wo_normal, p=2, + dim=-1)), "w_normal should be close to normal(wo_normal)." diff --git a/tests/models/language/pooling/test_classification_pooler_config.py b/tests/models/language/pooling/test_override_pooler_config.py similarity index 55% rename from tests/models/language/pooling/test_classification_pooler_config.py rename to tests/models/language/pooling/test_override_pooler_config.py index b591a82f54a9..925b383aed5d 100644 --- a/tests/models/language/pooling/test_classification_pooler_config.py +++ b/tests/models/language/pooling/test_override_pooler_config.py @@ -15,7 +15,7 @@ ], ) @pytest.mark.parametrize("dtype", ["half"]) -def test_models( +def test_classify_models_using_softmax( hf_runner, vllm_runner, example_prompts, @@ -46,3 +46,40 @@ def test_models( atol=1e-2), "override_pooler_config is not working" assert torch.allclose(F.softmax(wo_softmax, dim=-1), w_softmax, 1e-3 if dtype == "float" else 1e-2) + + +@pytest.mark.parametrize( + "model", + [ + "intfloat/multilingual-e5-small", + ], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_embed_models_using_normalize( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + + with vllm_runner(model, + max_model_len=512, + dtype=dtype, + override_pooler_config=PoolerConfig( + normalize=False)) as vllm_model: + wo_normalize = torch.tensor(vllm_model.embed(example_prompts)) + + with vllm_runner( + model, + max_model_len=512, + dtype=dtype, + override_pooler_config=PoolerConfig(normalize=True)) as vllm_model: + w_normalize = torch.tensor(vllm_model.embed(example_prompts)) + + assert not torch.allclose( + wo_normalize, + w_normalize), "override_pooler_config normalize is not working" + assert torch.allclose( + F.normalize(wo_normalize, p=2, dim=-1), + w_normalize), "w_normal should be close to normal(wo_normal)." 
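The embedding tests check the analogous round trip for `normalize`: the `normalize=False` output is the raw pooled vector, and L2-normalizing it should match the `normalize=True` output. A self-contained sketch with a made-up vector:

    import torch
    import torch.nn.functional as F

    raw = torch.tensor([[3.0, 4.0]])      # hypothetical unnormalized embedding
    unit = F.normalize(raw, p=2, dim=-1)  # what normalize=True should return

    # L2 normalization divides by the Euclidean norm, yielding a unit
    # vector ([0.6, 0.8] here)
    assert torch.allclose(unit, raw / raw.norm(p=2, dim=-1, keepdim=True))
    assert torch.allclose(unit.norm(p=2, dim=-1), torch.tensor([1.0]))
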
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 95e5bcd3bae1..ffb8e992afd8 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1251,11 +1251,13 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): "default: 0). Any priority other than 0 will raise an error " "if the served model does not use priority scheduling."), ) + normalize: Optional[bool] = None # --8<-- [end:embedding-extra-params] def to_pooling_params(self): - return PoolingParams(dimensions=self.dimensions) + return PoolingParams(dimensions=self.dimensions, + normalize=self.normalize) class EmbeddingChatRequest(OpenAIBaseModel): @@ -1302,6 +1304,7 @@ class EmbeddingChatRequest(OpenAIBaseModel): "default: 0). Any priority other than 0 will raise an error " "if the served model does not use priority scheduling."), ) + normalize: Optional[bool] = None # --8<-- [end:chat-embedding-extra-params] @model_validator(mode="before") @@ -1314,7 +1317,8 @@ def check_generation_prompt(cls, data): return data def to_pooling_params(self): - return PoolingParams(dimensions=self.dimensions) + return PoolingParams(dimensions=self.dimensions, + normalize=self.normalize) EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] @@ -1329,6 +1333,7 @@ class ScoreRequest(OpenAIBaseModel): text_1: Union[list[str], str, ScoreMultiModalParam] text_2: Union[list[str], str, ScoreMultiModalParam] truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None + softmax: Optional[bool] = None # --8<-- [start:score-extra-params] @@ -1348,7 +1353,7 @@ class ScoreRequest(OpenAIBaseModel): # --8<-- [end:score-extra-params] def to_pooling_params(self): - return PoolingParams() + return PoolingParams(softmax=self.softmax) class RerankRequest(OpenAIBaseModel): @@ -1357,6 +1362,7 @@ class RerankRequest(OpenAIBaseModel): documents: Union[list[str], ScoreMultiModalParam] top_n: int = Field(default_factory=lambda: 0) truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None + softmax: Optional[bool] = None # --8<-- [start:rerank-extra-params] @@ -1376,7 +1382,7 @@ class RerankRequest(OpenAIBaseModel): # --8<-- [end:rerank-extra-params] def to_pooling_params(self): - return PoolingParams() + return PoolingParams(softmax=self.softmax) class RerankDocument(BaseModel): @@ -1513,6 +1519,7 @@ class ClassificationRequest(OpenAIBaseModel): input: Union[list[str], str] truncate_prompt_tokens: Optional[int] = None user: Optional[str] = None + softmax: Optional[bool] = None # --8<-- [start:classification-extra-params] priority: int = Field( @@ -1526,7 +1533,7 @@ class ClassificationRequest(OpenAIBaseModel): # --8<-- [end:classification-extra-params] def to_pooling_params(self): - return PoolingParams() + return PoolingParams(softmax=self.softmax) class ClassificationData(OpenAIBaseModel): diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index c06cca080227..498de0197371 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -482,6 +482,8 @@ def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: class PoolerHead(nn.Module): + # embed use this class + # Classify & Score seems not to use this class @classmethod def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "PoolerHead": @@ -489,24 +491,17 @@ def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "PoolerHead": raise ValueError("`normalize=True` and `softmax=True` should not " "be set together") - 
activation: PoolerActivation - if pooler_config.normalize: - activation = PoolerNormalize() - elif pooler_config.softmax: - activation = PoolerClassify() - else: - activation = PoolerIdentity() - - return cls(activation) + return cls(pooler_config) - def __init__(self, activation: PoolerActivation) -> None: + def __init__(self, pooler_config: ResolvedPoolingConfig) -> None: super().__init__() - self.activation = activation + self.pooler_config = pooler_config + self.normalize = PoolerNormalize() + self.softmax = PoolerClassify() def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): - # Using float32 in PoolerHead if isinstance(pooled_data, list): for i in range(len(pooled_data)): @@ -514,18 +509,22 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], else: pooled_data = pooled_data.to(torch.float32) - # for matryoshka representation if isinstance(pooling_metadata, V0PoolingMetadata): - dimensions_list = [ - pooling_param.dimensions + pooling_params = [ + pooling_param for _, pooling_param in pooling_metadata.seq_groups ] else: assert isinstance(pooled_data, list) - dimensions_list = [ - pooling_param.dimensions + pooling_params = [ + pooling_param for pooling_param in pooling_metadata.pooling_params ] + + # for matryoshka representation + dimensions_list = [ + pooling_param.dimensions for pooling_param in pooling_params + ] if any(d is not None for d in dimensions_list): # change the output dimension assert len(pooled_data) == len(dimensions_list) @@ -540,7 +539,38 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], for vecs, d in zip(pooled_data, dimensions_list) ] - return self.activation(pooled_data) + # for normalize + normalize_list = [ + pooling_param.normalize or + (pooling_param.normalize is None and self.pooler_config.normalize) + for pooling_param in pooling_params + ] + + if len(set(normalize_list)) == 1: + if normalize_list[0]: + pooled_data = self.normalize(pooled_data) + else: + pooled_data = [ + self.normalize(vecs) if f else vecs + for vecs, f in zip(pooled_data, normalize_list) + ] + + # for softmax + softmax_list = [ + pooling_param.softmax + or (pooling_param.softmax is None and self.pooler_config.softmax) + for pooling_param in pooling_params + ] + + if len(set(softmax_list)) == 1: + if softmax_list[0]: + pooled_data = self.softmax(pooled_data) + else: + pooled_data = [ + self.softmax(vecs) if f else vecs + for vecs, f in zip(pooled_data, softmax_list) + ] + return pooled_data class SimplePooler(Pooler): diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 868facbe2557..fd723a5fa517 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -22,10 +22,18 @@ class PoolingParams( Attributes: dimensions: Reduce the dimensions of embeddings if model support matryoshka representation. 
+ softmax: Whether to using softmax, + None means using the model's default + normalize: normalize: Whether to using softmax, + None means using the model's default """ dimensions: Optional[int] = None + softmax: Optional[bool] = None + + normalize: Optional[bool] = None + output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY task: Optional[PoolingTask] = None @@ -38,6 +46,8 @@ def clone(self) -> "PoolingParams": """Returns a deep copy of the PoolingParams instance.""" return PoolingParams( dimensions=self.dimensions, + softmax=self.softmax, + normalize=self.normalize, task=self.task, requires_token_ids=self.requires_token_ids, ) @@ -71,10 +81,16 @@ def verify(self, task: PoolingTask, model_config: "ModelConfig") -> None: elif self.dimensions < 1: raise ValueError("Dimensions must be greater than 0") + if self.normalize and self.softmax: + raise ValueError("`normalize=True` and `softmax=True` should not " + "be set together") + def __repr__(self) -> str: return (f"PoolingParams(" f"dimensions={self.dimensions}, " f"task={self.task}, " + f"softmax={self.softmax}, " + f"normalize={self.normalize}, " f"requires_token_ids={self.requires_token_ids})") def __post_init__(self) -> None: From 593df3e8c3adecb3a79dbbb6b50ffd6588b8fba2 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Thu, 24 Jul 2025 16:54:28 +0800 Subject: [PATCH 03/30] + using_activation Signed-off-by: wang.yuqi --- tests/entrypoints/llm/test_classification.py | 64 ++++++++++++++++++ tests/entrypoints/llm/test_embedding.py | 13 ++-- tests/entrypoints/llm/test_score.py | 67 +++++++++++++++++++ .../entrypoints/openai/test_classification.py | 27 ++++---- tests/entrypoints/openai/test_embedding.py | 13 ++-- tests/entrypoints/openai/test_rerank.py | 38 +++++++++++ tests/entrypoints/openai/test_score.py | 36 ++++++++++ .../pooling/test_override_pooler_config.py | 45 +++++++------ vllm/config.py | 5 ++ vllm/entrypoints/llm.py | 14 +++- vllm/entrypoints/openai/protocol.py | 12 ++-- vllm/model_executor/layers/pooler.py | 48 +++++++++---- vllm/model_executor/models/adapters.py | 2 + vllm/model_executor/models/bert.py | 2 + vllm/model_executor/models/modernbert.py | 2 + vllm/model_executor/models/roberta.py | 2 + vllm/pooling_params.py | 7 +- 17 files changed, 328 insertions(+), 69 deletions(-) create mode 100644 tests/entrypoints/llm/test_classification.py create mode 100644 tests/entrypoints/llm/test_score.py diff --git a/tests/entrypoints/llm/test_classification.py b/tests/entrypoints/llm/test_classification.py new file mode 100644 index 000000000000..03ac583cfc4a --- /dev/null +++ b/tests/entrypoints/llm/test_classification.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import weakref + +import pytest +import torch +import torch.nn.functional as F + +from vllm import LLM, PoolingParams +from vllm.distributed import cleanup_dist_env_and_memory + +MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" + +prompts = ["The chef prepared a delicious meal."] + + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + seed=0) + + with 
llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + +@pytest.mark.skip_global_cleanup +def test_activation(llm: LLM): + + def get_outputs(activation): + outputs = llm.classify(prompts, + pooling_params=PoolingParams(activation=activation)) + return torch.tensor([x.outputs.probs for x in outputs]) + + default = get_outputs(activation=None) + w_activation = get_outputs(activation=True) + wo_activation = get_outputs(activation=False) + + assert torch.allclose(default, w_activation, + atol=1e-2), "Default should use activation." + assert not torch.allclose( + w_activation, wo_activation, + atol=1e-2), "wo_activation should not use activation." + assert torch.allclose( + F.softmax(wo_activation, dim=-1), w_activation, atol=1e-2 + ), "w_activation should be close to activation(wo_activation)." diff --git a/tests/entrypoints/llm/test_embedding.py b/tests/entrypoints/llm/test_embedding.py index f5f34b383152..2a372c2f93d6 100644 --- a/tests/entrypoints/llm/test_embedding.py +++ b/tests/entrypoints/llm/test_embedding.py @@ -54,9 +54,10 @@ def get_outputs(normalize): w_normal = get_outputs(normalize=True) wo_normal = get_outputs(normalize=False) - assert torch.allclose(default, w_normal), "Default should use normal." - assert not torch.allclose(w_normal, - wo_normal), "wo_normal should not use normal." - assert torch.allclose(w_normal, F.normalize( - wo_normal, p=2, - dim=-1)), "w_normal should be close to normal(wo_normal)." + assert torch.allclose(default, w_normal, + atol=1e-2), "Default should use normal." + assert not torch.allclose(w_normal, wo_normal, + atol=1e-2), "wo_normal should not use normal." + assert torch.allclose( + w_normal, F.normalize(wo_normal, p=2, dim=-1), + atol=1e-2), "w_normal should be close to normal(wo_normal)." diff --git a/tests/entrypoints/llm/test_score.py b/tests/entrypoints/llm/test_score.py new file mode 100644 index 000000000000..435a68b16717 --- /dev/null +++ b/tests/entrypoints/llm/test_score.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import weakref + +import pytest +import torch +import torch.nn.functional as F + +from vllm import LLM, PoolingParams +from vllm.distributed import cleanup_dist_env_and_memory + +MODEL_NAME = "BAAI/bge-reranker-v2-m3" + + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + seed=0) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + +@pytest.mark.skip_global_cleanup +def test_activation(llm: LLM): + + def get_outputs(activation): + text_1 = "What is the capital of France?" + text_2 = "The capital of France is Paris." + + outputs = llm.score( + text_1, + text_2, + pooling_params=PoolingParams(activation=activation)) + return torch.tensor([x.outputs.score for x in outputs]) + + default = get_outputs(activation=None) + w_activation = get_outputs(activation=True) + wo_activation = get_outputs(activation=False) + + assert torch.allclose(default, w_activation, + atol=1e-2), "Default should use activation." 
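+
+    # BAAI/bge-reranker-v2-m3 is a cross-encoder producing a single
+    # relevance logit per pair, so its default activation is sigmoid
+    # rather than softmax (checked against F.sigmoid below).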
+ assert not torch.allclose( + w_activation, wo_activation, + atol=1e-2), "wo_activation should not use activation." + assert torch.allclose( + F.sigmoid(wo_activation), w_activation, atol=1e-2 + ), "w_activation should be close to activation(wo_activation)." diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/openai/test_classification.py index 8c4b937f4848..bcf127307f73 100644 --- a/tests/entrypoints/openai/test_classification.py +++ b/tests/entrypoints/openai/test_classification.py @@ -4,6 +4,7 @@ import pytest import requests import torch +import torch.nn.functional as F from vllm.entrypoints.openai.protocol import ClassificationResponse @@ -186,28 +187,28 @@ async def test_invocations(server: RemoteOpenAIServer): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_softmax(server: RemoteOpenAIServer, model_name: str): +async def test_activation(server: RemoteOpenAIServer, model_name: str): input_text = ["This product was excellent and exceeded my expectations"] - async def get_outputs(softmax): - + async def get_outputs(activation): response = requests.post(server.url_for("classify"), json={ "model": model_name, "input": input_text, - "softmax": softmax + "activation": activation }) - outputs = response.json() return torch.tensor([x['probs'] for x in outputs["data"]]) - default = await get_outputs(softmax=None) - w_softmax = await get_outputs(softmax=True) - wo_softmax = await get_outputs(softmax=False) + default = await get_outputs(activation=None) + w_activation = await get_outputs(activation=True) + wo_activation = await get_outputs(activation=False) - assert torch.allclose(default, w_softmax), "Default should use softmax." - assert not torch.allclose(w_softmax, - wo_softmax), "wo_softmax should not use softmax." + assert torch.allclose(default, w_activation, + atol=1e-2), "Default should use activation." + assert not torch.allclose( + w_activation, wo_activation, + atol=1e-2), "wo_activation should not use activation." assert torch.allclose( - w_softmax, - wo_softmax), "w_softmax should be close to softmax(wo_softmax)." + F.softmax(wo_activation, dim=-1), w_activation, atol=1e-2 + ), "w_activation should be close to activation(wo_activation)." diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index e39c97398676..29e2c52a64dd 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -396,9 +396,10 @@ async def get_outputs(normalize): w_normal = await get_outputs(normalize=True) wo_normal = await get_outputs(normalize=False) - assert torch.allclose(default, w_normal), "Default should use normal." - assert not torch.allclose(w_normal, - wo_normal), "wo_normal should not use normal." - assert torch.allclose(w_normal, F.normalize( - wo_normal, p=2, - dim=-1)), "w_normal should be close to normal(wo_normal)." + assert torch.allclose(default, w_normal, + atol=1e-2), "Default should use normal." + assert not torch.allclose(w_normal, wo_normal, + atol=1e-2), "wo_normal should not use normal." + assert torch.allclose( + w_normal, F.normalize(wo_normal, p=2, dim=-1), + atol=1e-2), "w_normal should be close to normal(wo_normal)." 
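For score and rerank models the default activation is sigmoid rather than softmax, since a cross-encoder emits one relevance logit per query-document pair. A sketch of how a client could opt out through the new extra parameter, assuming a vLLM server is already serving a reranker locally (the base URL is a placeholder):

    import requests
    import torch

    response = requests.post(
        "http://localhost:8000/rerank",  # placeholder base URL
        json={
            "model": "BAAI/bge-reranker-v2-m3",
            "query": "What is the capital of France?",
            "documents": ["The capital of France is Paris."],
            "activation": False,  # return the raw logit
        },
    )
    logit = response.json()["results"][0]["relevance_score"]
    print(torch.sigmoid(torch.tensor(logit)))  # recovers the default score
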
diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py index 4da97fe13691..f121693e329f 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/openai/test_rerank.py @@ -3,6 +3,8 @@ import pytest import requests +import torch +import torch.nn.functional as F from vllm.entrypoints.openai.protocol import RerankResponse @@ -125,3 +127,39 @@ def test_invocations(server: RemoteOpenAIServer): assert rerank_result.keys() == invocations_result.keys() assert rerank_result["relevance_score"] == pytest.approx( invocations_result["relevance_score"], rel=0.01) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_activation(server: RemoteOpenAIServer, model_name: str): + + async def get_outputs(activation): + query = "What is the capital of France?" + documents = [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris." + ] + + response = requests.post(server.url_for("rerank"), + json={ + "model": model_name, + "query": query, + "documents": documents, + "activation": activation + }) + outputs = response.json() + + return torch.tensor([x['relevance_score'] for x in outputs["results"]]) + + default = await get_outputs(activation=None) + w_activation = await get_outputs(activation=True) + wo_activation = await get_outputs(activation=False) + + assert torch.allclose(default, w_activation, + atol=1e-2), "Default should use activation." + assert not torch.allclose( + w_activation, wo_activation, + atol=1e-2), "wo_activation should not use activation." + assert torch.allclose( + F.sigmoid(wo_activation), w_activation, atol=1e-2 + ), "w_activation should be close to activation(wo_activation)." diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py index 187542b7bafc..dfe010e5845c 100644 --- a/tests/entrypoints/openai/test_score.py +++ b/tests/entrypoints/openai/test_score.py @@ -4,6 +4,7 @@ import pytest import requests +import torch import torch.nn.functional as F from torch import tensor @@ -220,3 +221,38 @@ def test_invocations(self, server: RemoteOpenAIServer, model: dict[str, assert score_data.keys() == invocation_data.keys() assert score_data["score"] == pytest.approx( invocation_data["score"], rel=0.01) + + def test_activation(self, server: RemoteOpenAIServer, model: dict[str, + Any]): + + def get_outputs(activation): + text_1 = "What is the capital of France?" + text_2 = "The capital of France is Paris." + response = requests.post(server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + "activation": activation + }) + outputs = response.json() + + return torch.tensor([x['score'] for x in outputs["data"]]) + + default = get_outputs(activation=None) + w_activation = get_outputs(activation=True) + wo_activation = get_outputs(activation=False) + + if model["is_cross_encoder"]: + assert torch.allclose(default, w_activation, + atol=1e-2), "Default should use activation." + assert not torch.allclose( + w_activation, wo_activation, + atol=1e-2), "wo_activation should not use activation." + assert torch.allclose( + F.sigmoid(wo_activation), w_activation, atol=1e-2 + ), "w_activation should be close to activation(wo_activation)." 
+ else: + # The activation parameter only works for the is_cross_encoder model + assert torch.allclose(default, w_activation, atol=1e-2) + assert torch.allclose(wo_activation, w_activation, atol=1e-2) diff --git a/tests/models/language/pooling/test_override_pooler_config.py b/tests/models/language/pooling/test_override_pooler_config.py index 925b383aed5d..63da5aa39469 100644 --- a/tests/models/language/pooling/test_override_pooler_config.py +++ b/tests/models/language/pooling/test_override_pooler_config.py @@ -15,7 +15,7 @@ ], ) @pytest.mark.parametrize("dtype", ["half"]) -def test_classify_models_using_softmax( +def test_classify_models_using_activation( hf_runner, vllm_runner, example_prompts, @@ -23,28 +23,29 @@ def test_classify_models_using_softmax( dtype: str, ) -> None: - with vllm_runner( - model, - max_model_len=512, - dtype=dtype, - override_pooler_config=PoolerConfig(softmax=False)) as vllm_model: - wo_softmax_out = vllm_model.classify(example_prompts) + with vllm_runner(model, + max_model_len=512, + dtype=dtype, + override_pooler_config=PoolerConfig( + activation=False)) as vllm_model: + wo_activation_out = vllm_model.classify(example_prompts) - with vllm_runner( - model, - max_model_len=512, - dtype=dtype, - override_pooler_config=PoolerConfig(softmax=True)) as vllm_model: - w_softmax_out = vllm_model.classify(example_prompts) + with vllm_runner(model, + max_model_len=512, + dtype=dtype, + override_pooler_config=PoolerConfig( + activation=True)) as vllm_model: + w_activation_out = vllm_model.classify(example_prompts) - for wo_softmax, w_softmax in zip(wo_softmax_out, w_softmax_out): - wo_softmax = torch.tensor(wo_softmax) - w_softmax = torch.tensor(w_softmax) + for wo_activation, w_activation in zip(wo_activation_out, + w_activation_out): + wo_activation = torch.tensor(wo_activation) + w_activation = torch.tensor(w_activation) assert not torch.allclose( - wo_softmax, w_softmax, + wo_activation, w_activation, atol=1e-2), "override_pooler_config is not working" - assert torch.allclose(F.softmax(wo_softmax, dim=-1), w_softmax, + assert torch.allclose(F.softmax(wo_activation, dim=-1), w_activation, 1e-3 if dtype == "float" else 1e-2) @@ -78,8 +79,8 @@ def test_embed_models_using_normalize( w_normalize = torch.tensor(vllm_model.embed(example_prompts)) assert not torch.allclose( - wo_normalize, - w_normalize), "override_pooler_config normalize is not working" + wo_normalize, w_normalize, + atol=1e-2), "override_pooler_config normalize is not working" assert torch.allclose( - F.normalize(wo_normalize, p=2, dim=-1), - w_normalize), "w_normal should be close to normal(wo_normal)." + F.normalize(wo_normalize, p=2, dim=-1), w_normalize, + atol=1e-2), "w_normal should be close to normal(wo_normal)." diff --git a/vllm/config.py b/vllm/config.py index 6623a48f839a..565571a4b1a5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3281,6 +3281,11 @@ class PoolerConfig: to ``True`` for classification outputs. """ + activation: Optional[bool] = True + """ + Whether to apply activation function to the pooled outputs. 
+ """ + step_tag_id: Optional[int] = None """ If set, only the score corresponding to the ``step_tag_id`` in the diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index c4f1b3b86619..880417d4660c 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1217,6 +1217,8 @@ def classify( /, *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> list[ClassificationRequestOutput]: @@ -1252,6 +1254,7 @@ def classify( items = self.encode( prompts, use_tqdm=use_tqdm, + pooling_params=pooling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, pooling_task="classify", @@ -1266,6 +1269,7 @@ def _embedding_score( text_2: list[Union[str, TextPrompt, TokensPrompt]], truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, + pooling_params: Optional[PoolingParams] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> list[ScoringRequestOutput]: @@ -1302,6 +1306,7 @@ def _cross_encoding_score( data_2: Union[list[str], list[ScoreContentPartParam]], truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, + pooling_params: Optional[PoolingParams] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> list[ScoringRequestOutput]: @@ -1313,7 +1318,11 @@ def _cross_encoding_score( if len(data_1) == 1: data_1 = data_1 * len(data_2) - pooling_params = PoolingParams(task="score") + if pooling_params is None: + pooling_params = PoolingParams(task="score") + else: + pooling_params.task = "score" + tokenization_kwargs: dict[str, Any] = {} _validate_truncation_size(self.llm_engine.model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs) @@ -1380,6 +1389,7 @@ def score( *, truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, + pooling_params: Optional[PoolingParams] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> list[ScoringRequestOutput]: @@ -1504,6 +1514,7 @@ def ensure_str(prompt: SingletonPrompt): data_2, # type: ignore[arg-type] truncate_prompt_tokens, use_tqdm, + pooling_params, lora_request, prompt_adapter_request) else: @@ -1513,6 +1524,7 @@ def ensure_str(prompt: SingletonPrompt): data_2, # type: ignore[arg-type] truncate_prompt_tokens, use_tqdm, + pooling_params, lora_request, prompt_adapter_request) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index ffb8e992afd8..501fb8e9e441 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1333,7 +1333,7 @@ class ScoreRequest(OpenAIBaseModel): text_1: Union[list[str], str, ScoreMultiModalParam] text_2: Union[list[str], str, ScoreMultiModalParam] truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None - softmax: Optional[bool] = None + activation: Optional[bool] = None # --8<-- [start:score-extra-params] @@ -1353,7 +1353,7 @@ class ScoreRequest(OpenAIBaseModel): # --8<-- [end:score-extra-params] def to_pooling_params(self): - return PoolingParams(softmax=self.softmax) + return 
PoolingParams(activation=self.activation) class RerankRequest(OpenAIBaseModel): @@ -1362,7 +1362,7 @@ class RerankRequest(OpenAIBaseModel): documents: Union[list[str], ScoreMultiModalParam] top_n: int = Field(default_factory=lambda: 0) truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None - softmax: Optional[bool] = None + activation: Optional[bool] = None # --8<-- [start:rerank-extra-params] @@ -1382,7 +1382,7 @@ class RerankRequest(OpenAIBaseModel): # --8<-- [end:rerank-extra-params] def to_pooling_params(self): - return PoolingParams(softmax=self.softmax) + return PoolingParams(activation=self.activation) class RerankDocument(BaseModel): @@ -1519,7 +1519,7 @@ class ClassificationRequest(OpenAIBaseModel): input: Union[list[str], str] truncate_prompt_tokens: Optional[int] = None user: Optional[str] = None - softmax: Optional[bool] = None + activation: Optional[bool] = None # --8<-- [start:classification-extra-params] priority: int = Field( @@ -1533,7 +1533,7 @@ class ClassificationRequest(OpenAIBaseModel): # --8<-- [end:classification-extra-params] def to_pooling_params(self): - return PoolingParams(softmax=self.softmax) + return PoolingParams(activation=self.activation) class ClassificationData(OpenAIBaseModel): diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 498de0197371..f3a188188c05 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -147,6 +147,7 @@ def for_classify( pooling=base_pooler.pooling, classifier=classifier, act_fn=base_pooler.head.activation, + activation=pooler_config.activation, ) @abstractmethod @@ -482,8 +483,6 @@ def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: class PoolerHead(nn.Module): - # embed use this class - # Classify & Score seems not to use this class @classmethod def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "PoolerHead": @@ -509,17 +508,7 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], else: pooled_data = pooled_data.to(torch.float32) - if isinstance(pooling_metadata, V0PoolingMetadata): - pooling_params = [ - pooling_param - for _, pooling_param in pooling_metadata.seq_groups - ] - else: - assert isinstance(pooled_data, list) - pooling_params = [ - pooling_param - for pooling_param in pooling_metadata.pooling_params - ] + pooling_params = _get_pooling_params(pooling_metadata) # for matryoshka representation dimensions_list = [ @@ -700,12 +689,14 @@ def __init__( pooling: PoolingFn, classifier: ClassifierFn, act_fn: PoolerActivation, + activation: bool = True, ) -> None: super().__init__() self.pooling = pooling self.classifier = classifier self.act_fn = act_fn + self.activation = activation def get_supported_tasks(self) -> Set[PoolingTask]: return {"classify", "score"} @@ -725,7 +716,24 @@ def forward( else: pooled_output = [self.classifier(data) for data in pooled_data] - scores = self.act_fn(pooled_output) + pooling_params = _get_pooling_params(pooling_metadata) + + activation_list = [ + pooling_param.activation + or (pooling_param.activation is None and self.activation) + for pooling_param in pooling_params + ] + + if len(set(activation_list)) == 1: + if activation_list[0]: + scores = self.act_fn(pooled_output) + else: + scores = pooled_output + else: + scores = [ + self.act_fn(vecs) if f else vecs + for vecs, f in zip(pooled_output, activation_list) + ] return build_output(scores) @@ -781,3 +789,15 @@ def forward( offset += num_items return PoolerOutput(outputs) + + +def 
_get_pooling_params(pooling_metadata: PoolingMetadata): + if isinstance(pooling_metadata, V0PoolingMetadata): + pooling_params = [ + pooling_param for _, pooling_param in pooling_metadata.seq_groups + ] + else: + pooling_params = [ + pooling_param for pooling_param in pooling_metadata.pooling_params + ] + return pooling_params diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 867de2c68b4c..8bdfc03c6631 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -194,6 +194,7 @@ def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): classifier=self._classifier, act_fn=ClassifierPooler.act_fn_for_seq_cls( vllm_config.model_config), + activation=pooler_config.activation, ), "score": ClassifierPooler( @@ -201,6 +202,7 @@ def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): classifier=self._classifier, act_fn=ClassifierPooler.act_fn_for_cross_encoder( vllm_config.model_config), + activation=pooler_config.activation, ), }) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 9dc6115f850e..760979983b04 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -548,6 +548,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): classifier=self.classifier, act_fn=ClassifierPooler.act_fn_for_seq_cls( vllm_config.model_config), + activation=pooler_config.activation, ), "score": ClassifierPooler( @@ -555,6 +556,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): classifier=self.classifier, act_fn=ClassifierPooler.act_fn_for_cross_encoder( vllm_config.model_config), + activation=pooler_config.activation, ), }) diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index be1c3438d9db..feaccb2a45d6 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -321,6 +321,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): classifier=self.classifier, act_fn=ClassifierPooler.act_fn_for_seq_cls( vllm_config.model_config), + activation=pooler_config.activation, ), "score": ClassifierPooler( @@ -328,6 +329,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): classifier=self.classifier, act_fn=ClassifierPooler.act_fn_for_cross_encoder( vllm_config.model_config), + activation=pooler_config.activation, ), }) diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index c6b411644034..e139f4e7a95b 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -194,6 +194,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): classifier=self.classifier, act_fn=ClassifierPooler.act_fn_for_seq_cls( vllm_config.model_config), + activation=pooler_config.activation, ), "score": ClassifierPooler( @@ -201,6 +202,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): classifier=self.classifier, act_fn=ClassifierPooler.act_fn_for_cross_encoder( vllm_config.model_config), + activation=pooler_config.activation, ), }) diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index fd723a5fa517..00857a82ffc1 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -24,8 +24,9 @@ class PoolingParams( if model support matryoshka representation. 
softmax: Whether to using softmax, None means using the model's default - normalize: normalize: Whether to using softmax, + normalize: Whether to using normalize, None means using the model's default + activation: Whether to using activation function """ dimensions: Optional[int] = None @@ -34,6 +35,8 @@ class PoolingParams( normalize: Optional[bool] = None + activation: Optional[bool] = None + output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY task: Optional[PoolingTask] = None @@ -48,6 +51,7 @@ def clone(self) -> "PoolingParams": dimensions=self.dimensions, softmax=self.softmax, normalize=self.normalize, + activation=self.activation, task=self.task, requires_token_ids=self.requires_token_ids, ) @@ -91,6 +95,7 @@ def __repr__(self) -> str: f"task={self.task}, " f"softmax={self.softmax}, " f"normalize={self.normalize}, " + f"activation={self.activation}, " f"requires_token_ids={self.requires_token_ids})") def __post_init__(self) -> None: From 3473b059acf82a8f377a57858457ec4b1cf7cad3 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Thu, 24 Jul 2025 17:04:23 +0800 Subject: [PATCH 04/30] conflicts Signed-off-by: wang.yuqi --- vllm/entrypoints/llm.py | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 880417d4660c..fbebbcee633c 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1269,7 +1269,6 @@ def _embedding_score( text_2: list[Union[str, TextPrompt, TokensPrompt]], truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, - pooling_params: Optional[PoolingParams] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> list[ScoringRequestOutput]: @@ -1306,7 +1305,6 @@ def _cross_encoding_score( data_2: Union[list[str], list[ScoreContentPartParam]], truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, - pooling_params: Optional[PoolingParams] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> list[ScoringRequestOutput]: @@ -1318,11 +1316,7 @@ def _cross_encoding_score( if len(data_1) == 1: data_1 = data_1 * len(data_2) - if pooling_params is None: - pooling_params = PoolingParams(task="score") - else: - pooling_params.task = "score" - + pooling_params = PoolingParams(task="score") tokenization_kwargs: dict[str, Any] = {} _validate_truncation_size(self.llm_engine.model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs) @@ -1389,7 +1383,6 @@ def score( *, truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, - pooling_params: Optional[PoolingParams] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> list[ScoringRequestOutput]: @@ -1405,17 +1398,17 @@ def score( of your inputs into a single list and pass it to this method. Supports both text and multi-modal data (images, etc.) when used with - appropriate multi-modal models. For multi-modal inputs, ensure the + appropriate multi-modal models. For multi-modal inputs, ensure the prompt structure matches the model's expected input format. Args: - data_1: Can be a single prompt, a list of prompts or - `ScoreMultiModalParam`, which can contain either text or - multi-modal data. 
When a list, it must have the same length as + data_1: Can be a single prompt, a list of prompts or + `ScoreMultiModalParam`, which can contain either text or + multi-modal data. When a list, it must have the same length as the `data_2` list. - data_2: The data to pair with the query to form the input to + data_2: The data to pair with the query to form the input to the LLM. Can be text or multi-modal data. See [PromptType] - [vllm.inputs.PromptType] for more details about the format of + [vllm.inputs.PromptType] for more details about the format of each prompt. use_tqdm: If `True`, shows a tqdm progress bar. If a callable (e.g., `functools.partial(tqdm, leave=False)`), @@ -1514,7 +1507,6 @@ def ensure_str(prompt: SingletonPrompt): data_2, # type: ignore[arg-type] truncate_prompt_tokens, use_tqdm, - pooling_params, lora_request, prompt_adapter_request) else: @@ -1524,7 +1516,6 @@ def ensure_str(prompt: SingletonPrompt): data_2, # type: ignore[arg-type] truncate_prompt_tokens, use_tqdm, - pooling_params, lora_request, prompt_adapter_request) From 8c5f74471e7f1182a07110a7be606d7657cba3c8 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Thu, 24 Jul 2025 17:18:05 +0800 Subject: [PATCH 05/30] + pooling_params Signed-off-by: wang.yuqi --- vllm/entrypoints/llm.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 699b35109fe0..5a25f2a7f63d 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1193,7 +1193,7 @@ def classify( *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + Sequence[PoolingParams]]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[ClassificationRequestOutput]: """ @@ -1212,7 +1212,8 @@ def classify( it is used to create the progress bar. If `False`, no progress bar is created. lora_request: LoRA request to use for generation, if any. - + pooling_params: The pooling parameters for pooling. If None, we + use the default pooling parameters. Returns: A list of `ClassificationRequestOutput` objects containing the embedding vectors in the same order as the input prompts. 
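With `pooling_params` accepted by `LLM.classify`, the per-request override can also be used offline. A minimal sketch mirroring the tests added earlier in this series (model name taken from those tests):

    from vllm import LLM, PoolingParams

    llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", enforce_eager=True)

    # activation=False returns raw logits; activation=None falls back to
    # the model's default (softmax for this classifier)
    outputs = llm.classify(["An example prompt."],
                           pooling_params=PoolingParams(activation=False))
    print(outputs[0].outputs.probs)
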
@@ -1240,6 +1241,7 @@ def _embedding_score( text_2: list[Union[str, TextPrompt, TokensPrompt]], truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, + pooling_params: Optional[PoolingParams] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[ScoringRequestOutput]: @@ -1248,6 +1250,7 @@ def _embedding_score( truncate_prompt_tokens=truncate_prompt_tokens, use_tqdm=use_tqdm, lora_request=lora_request, + pooling_params=pooling_params, pooling_task="embed", ) @@ -1274,6 +1277,7 @@ def _cross_encoding_score( data_2: Union[list[str], list[ScoreContentPartParam]], truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, + pooling_params: Optional[PoolingParams] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[ScoringRequestOutput]: @@ -1284,7 +1288,11 @@ def _cross_encoding_score( if len(data_1) == 1: data_1 = data_1 * len(data_2) - pooling_params = PoolingParams(task="score") + if pooling_params is None: + pooling_params = PoolingParams(task="score") + else: + pooling_params.task = "score" + tokenization_kwargs: dict[str, Any] = {} _validate_truncation_size(self.llm_engine.model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs) @@ -1350,6 +1358,7 @@ def score( *, truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, + pooling_params: Optional[PoolingParams] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[ScoringRequestOutput]: """Generate similarity scores for all pairs `` or @@ -1381,7 +1390,8 @@ def score( it is used to create the progress bar. If `False`, no progress bar is created. lora_request: LoRA request to use for generation, if any. - + pooling_params: The pooling parameters for pooling. If None, we + use the default pooling parameters. Returns: A list of `ScoringRequestOutput` objects containing the generated scores in the same order as the input prompts. 
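The same override reaches the cross-encoder path through `LLM.score`. A sketch based on the scoring tests in this series:

    from vllm import LLM, PoolingParams

    llm = LLM(model="BAAI/bge-reranker-v2-m3", enforce_eager=True)

    out = llm.score("What is the capital of France?",
                    "The capital of France is Paris.",
                    pooling_params=PoolingParams(activation=False))
    # raw logit; torch.sigmoid(logit) would give the default score
    print(out[0].outputs.score)
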
@@ -1471,6 +1481,7 @@ def ensure_str(prompt: SingletonPrompt): data_2, # type: ignore[arg-type] truncate_prompt_tokens, use_tqdm, + pooling_params, lora_request) else: return self._embedding_score( @@ -1479,6 +1490,7 @@ def ensure_str(prompt: SingletonPrompt): data_2, # type: ignore[arg-type] truncate_prompt_tokens, use_tqdm, + pooling_params, lora_request) def start_profile(self) -> None: From e3bc35ab678ec2733b6b12cba73d50fee5824394 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Thu, 24 Jul 2025 17:28:47 +0800 Subject: [PATCH 06/30] fix Signed-off-by: wang.yuqi --- tests/entrypoints/llm/test_classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/entrypoints/llm/test_classification.py b/tests/entrypoints/llm/test_classification.py index 03ac583cfc4a..7429b582fe9e 100644 --- a/tests/entrypoints/llm/test_classification.py +++ b/tests/entrypoints/llm/test_classification.py @@ -46,8 +46,8 @@ def llm(): def test_activation(llm: LLM): def get_outputs(activation): - outputs = llm.classify(prompts, - pooling_params=PoolingParams(activation=activation)) + outputs = llm.classify( + prompts, pooling_params=PoolingParams(activation=activation)) return torch.tensor([x.outputs.probs for x in outputs]) default = get_outputs(activation=None) From 1a370b55a3253561d5d9bfac533cf724c677e58b Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 25 Jul 2025 11:48:40 +0800 Subject: [PATCH 07/30] fix Signed-off-by: wang.yuqi --- tests/entrypoints/llm/test_embedding.py | 8 -------- tests/entrypoints/llm/test_encode.py | 8 -------- tests/entrypoints/llm/test_score.py | 8 -------- 3 files changed, 24 deletions(-) diff --git a/tests/entrypoints/llm/test_embedding.py b/tests/entrypoints/llm/test_embedding.py index 2a372c2f93d6..db43398b4685 100644 --- a/tests/entrypoints/llm/test_embedding.py +++ b/tests/entrypoints/llm/test_embedding.py @@ -15,14 +15,6 @@ prompts = ["The chef prepared a delicious meal."] -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index b930f05bebd0..c0adee8e5f04 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -29,14 +29,6 @@ ] -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to diff --git a/tests/entrypoints/llm/test_score.py b/tests/entrypoints/llm/test_score.py index 435a68b16717..f06edf3b2adc 100644 --- a/tests/entrypoints/llm/test_score.py +++ b/tests/entrypoints/llm/test_score.py @@ -13,14 +13,6 @@ MODEL_NAME = "BAAI/bge-reranker-v2-m3" -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to From 9d756288fad7b0dfaaa4d0e406afd3dceb7c5327 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 25 Jul 2025 18:37:36 +0800 Subject: [PATCH 08/30] + 
test_pooling_params.py Signed-off-by: wang.yuqi --- tests/test_pooling_params.py | 105 +++++++++++++++++++++++++++++++++++ vllm/config.py | 20 ++++--- vllm/pooling_params.py | 84 +++++++++++++++++----------- 3 files changed, 170 insertions(+), 39 deletions(-) create mode 100644 tests/test_pooling_params.py diff --git a/tests/test_pooling_params.py b/tests/test_pooling_params.py new file mode 100644 index 000000000000..eb25f191ccbe --- /dev/null +++ b/tests/test_pooling_params.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from tests.models.utils import EmbedModelInfo +from vllm import PoolingParams +from vllm.config import ModelConfig + +EMBEDDING_MODELS = [ + EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False), + EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", + is_matryoshka=True, + matryoshka_dimensions=[256]), +] + + +def test_task(): + pooling_params = PoolingParams() + pooling_params.verify(task="score") + + pooling_params = PoolingParams(task="score") + pooling_params.verify(task="score") + + with pytest.raises(ValueError): + pooling_params.verify(task="encode") + + +def test_embed(): + task = "embed" + pooling_params = PoolingParams(normalize=None) + pooling_params.verify(task=task) + + pooling_params = PoolingParams(normalize=True) + pooling_params.verify(task=task) + + pooling_params = PoolingParams(normalize=False) + pooling_params.verify(task=task) + + invalid_parameters = ["activation", "softmax"] + for p in invalid_parameters: + with pytest.raises(ValueError): + pooling_params = PoolingParams(**{p: True}) + pooling_params.verify(task=task) + + +@pytest.mark.parametrize("model_info", EMBEDDING_MODELS) +def test_embed_dimensions(model_info: EmbedModelInfo): + task = "embed" + model_config = ModelConfig( + model_info.name, + task="auto", + tokenizer=model_info.name, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + ) + + pooling_params = PoolingParams(dimensions=None) + pooling_params.verify(task=task, model_config=model_config) + + with pytest.raises(ValueError): + pooling_params = PoolingParams(dimensions=1) + pooling_params.verify(task=task, model_config=model_config) + + if model_info.is_matryoshka: + pooling_params = PoolingParams( + dimensions=model_info.matryoshka_dimensions[0]) + pooling_params.verify(task=task, model_config=model_config) + + +@pytest.mark.parametrize("task", ["score", "classify"]) +def test_classify(task): + pooling_params = PoolingParams(activation=None) + pooling_params.verify(task=task) + + pooling_params = PoolingParams(activation=True) + pooling_params.verify(task=task) + + pooling_params = PoolingParams(activation=False) + pooling_params.verify(task=task) + + invalid_parameters = ["dimensions", "normalize", "softmax"] + for p in invalid_parameters: + with pytest.raises(ValueError): + pooling_params = PoolingParams(**{p: True}) + pooling_params.verify(task=task) + + +def test_encode(): + task = "encode" + pooling_params = PoolingParams(softmax=None) + pooling_params.verify(task=task) + + pooling_params = PoolingParams(softmax=True) + pooling_params.verify(task=task) + + pooling_params = PoolingParams(softmax=False) + pooling_params.verify(task=task) + + invalid_parameters = ["dimensions", "normalize", "activation"] + for p in invalid_parameters: + with pytest.raises(ValueError): + pooling_params = PoolingParams(**{p: True}) + pooling_params.verify(task=task) diff --git a/vllm/config.py 
b/vllm/config.py index a72362ff2b76..4d686a78201b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3217,30 +3217,34 @@ class PoolerConfig: [`vllm.model_executor.layers.pooler.PoolingType`][]. """ + ## for embeddings models normalize: Optional[bool] = None """ - Whether to normalize the pooled outputs. Usually, this should be set to - ``True`` for embedding outputs. + Whether to normalize the embeddings outputs. """ - - softmax: Optional[bool] = None + dimensions: Optional[int] = None """ - Whether to apply softmax to the pooled outputs. Usually, this should be set - to ``True`` for classification outputs. + Reduce the dimensions of embeddings if model + support matryoshka representation. """ + ## for classification models activation: Optional[bool] = True """ - Whether to apply activation function to the pooled outputs. + Whether to apply activation function to the classification outputs. """ + ## for reward models + softmax: Optional[bool] = None + """ + Whether to apply softmax to the reward outputs. + """ step_tag_id: Optional[int] = None """ If set, only the score corresponding to the ``step_tag_id`` in the generated sentence should be returned. Otherwise, the scores for all tokens are returned. """ - returned_token_ids: Optional[list[int]] = None """ A list of indices for the vocabulary dimensions to be extracted, diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 00857a82ffc1..462109787508 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Literal, Optional +from typing import TYPE_CHECKING, Literal, Optional, assert_never import msgspec @@ -20,22 +20,21 @@ class PoolingParams( """API parameters for pooling models. Attributes: + normalize: Whether to normalize the embeddings outputs. dimensions: Reduce the dimensions of embeddings if model support matryoshka representation. - softmax: Whether to using softmax, - None means using the model's default - normalize: Whether to using normalize, - None means using the model's default - activation: Whether to using activation function + activation: Whether to apply activation function to + the classification outputs. + softmax: Whether to apply softmax to the reward outputs. """ + ## for embeddings models dimensions: Optional[int] = None - - softmax: Optional[bool] = None - normalize: Optional[bool] = None - + ## for classification models activation: Optional[bool] = None + ## for reward models + softmax: Optional[bool] = None output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY @@ -49,14 +48,16 @@ def clone(self) -> "PoolingParams": """Returns a deep copy of the PoolingParams instance.""" return PoolingParams( dimensions=self.dimensions, - softmax=self.softmax, normalize=self.normalize, activation=self.activation, + softmax=self.softmax, task=self.task, requires_token_ids=self.requires_token_ids, ) - def verify(self, task: PoolingTask, model_config: "ModelConfig") -> None: + def verify(self, + task: PoolingTask, + model_config: Optional["ModelConfig"] = None) -> None: if self.task is None: self.task = task elif self.task != task: @@ -67,27 +68,48 @@ def verify(self, task: PoolingTask, model_config: "ModelConfig") -> None: # which is not available in model config. 
So, it's not included # in this method - if self.dimensions is not None: - if not model_config.is_matryoshka: - raise ValueError( - f'Model "{model_config.served_model_name}" does not ' - f'support matryoshka representation, ' - f'changing output dimensions will lead to poor results.') + all_parameters = ["dimensions", "normalize", "activation", "softmax"] + + if self.task == "embed": + legal_parameters = ["dimensions", "normalize"] - mds = model_config.matryoshka_dimensions - if mds is not None: - if self.dimensions not in mds: + if self.dimensions is not None and model_config is not None: + if not model_config.is_matryoshka: raise ValueError( - f'Model "{model_config.served_model_name}" ' - f'only supports {str(mds)} matryoshka dimensions, ' - f'use other output dimensions will ' - f'lead to poor results.') - elif self.dimensions < 1: - raise ValueError("Dimensions must be greater than 0") - - if self.normalize and self.softmax: - raise ValueError("`normalize=True` and `softmax=True` should not " - "be set together") + f'Model "{model_config.served_model_name}" does not ' + f'support matryoshka representation, ' + f'changing output dimensions will lead to poor results.' + ) + + mds = model_config.matryoshka_dimensions + if mds is not None: + if self.dimensions not in mds: + raise ValueError( + f'Model "{model_config.served_model_name}" ' + f'only supports {str(mds)} matryoshka dimensions, ' + f'use other output dimensions will ' + f'lead to poor results.') + elif self.dimensions < 1: + raise ValueError("Dimensions must be greater than 0") + elif self.task in ["classify", "score"]: + legal_parameters = ["activation"] + elif self.task == "encode": + legal_parameters = ["softmax"] + else: + assert_never(self.task) + + invalid_parameters = [] + for k in all_parameters: + if k in legal_parameters: + continue + + if getattr(self, k, None) is not None: + invalid_parameters.append(k) + + if invalid_parameters: + raise ValueError( + f"{self.task} only supports {legal_parameters} parameters, " + f"does not support {invalid_parameters} parameters") def __repr__(self) -> str: return (f"PoolingParams(" From 5a131d1243d0bf86c62f41d08c46cb8365e630b2 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 25 Jul 2025 19:47:46 +0800 Subject: [PATCH 09/30] + merge_default_parameters Signed-off-by: wang.yuqi --- vllm/config.py | 2 +- vllm/entrypoints/llm.py | 11 +- .../openai/serving_classification.py | 1 + vllm/entrypoints/openai/serving_embedding.py | 1 + vllm/entrypoints/openai/serving_pooling.py | 1 + vllm/entrypoints/openai/serving_score.py | 2 + vllm/model_executor/layers/pooler.py | 209 ++++++++---------- vllm/model_executor/models/adapters.py | 6 +- vllm/model_executor/models/bert.py | 6 +- vllm/model_executor/models/modernbert.py | 6 +- vllm/model_executor/models/roberta.py | 6 +- vllm/pooling_params.py | 24 +- 12 files changed, 133 insertions(+), 142 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 4d686a78201b..36c42ce6e038 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3229,7 +3229,7 @@ class PoolerConfig: """ ## for classification models - activation: Optional[bool] = True + activation: Optional[bool] = None """ Whether to apply activation function to the classification outputs. 
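+
+    Example (illustrative): passing
+    ``override_pooler_config=PoolerConfig(activation=False)`` to ``LLM``
+    makes classification outputs default to raw logits instead of
+    probabilities.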
""" diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 5a25f2a7f63d..a215039c50ea 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1113,9 +1113,12 @@ def encode( pooling_params = PoolingParams() if isinstance(pooling_params, PoolingParams): + pooling_params.merge_default_parameters(model_config.pooler_config) pooling_params.verify(pooling_task, model_config) else: for pooling_param in pooling_params: + pooling_params.merge_default_parameters( + model_config.pooler_config) pooling_param.verify(pooling_task, model_config) if tokenization_kwargs is None: @@ -1290,8 +1293,12 @@ def _cross_encoding_score( if pooling_params is None: pooling_params = PoolingParams(task="score") - else: - pooling_params.task = "score" + + model_config = self.llm_engine.model_config + pooling_task = "score" + + pooling_params.merge_default_parameters(model_config.pooler_config) + pooling_params.verify(pooling_task, model_config) tokenization_kwargs: dict[str, Any] = {} _validate_truncation_size(self.llm_engine.model_config.max_model_len, diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py index 377f7f684717..e10ba3c616f6 100644 --- a/vllm/entrypoints/openai/serving_classification.py +++ b/vllm/entrypoints/openai/serving_classification.py @@ -177,6 +177,7 @@ def _create_pooling_params( return pooling_params try: + pooling_params.merge_default_parameters(self.model_config.pooler_config) pooling_params.verify("classify", self.model_config) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 697f43c018b2..7e9c1c408cc8 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -199,6 +199,7 @@ def _create_pooling_params( return pooling_params try: + pooling_params.merge_default_parameters(self.model_config.pooler_config) pooling_params.verify("embed", self.model_config) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index 12334cdac365..b17207042050 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -136,6 +136,7 @@ async def create_pooling( pooling_params = request.to_pooling_params() try: + pooling_params.merge_default_parameters(self.model_config.pooler_config) pooling_params.verify("encode", self.model_config) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 4da2094147ce..4a5a9476572f 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -86,6 +86,7 @@ async def _embedding_score( pooling_params = request.to_pooling_params() try: + pooling_params.merge_default_parameters(self.model_config.pooler_config) pooling_params.verify("embed", self.model_config) except ValueError as e: return self.create_error_response(str(e)) @@ -245,6 +246,7 @@ async def _cross_encoding_score( pooling_params = request.to_pooling_params() try: + pooling_params.merge_default_parameters(self.model_config.pooler_config) pooling_params.verify("score", self.model_config) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 
f3a188188c05..eb27b31d542d 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -143,12 +143,9 @@ def for_classify( if classifier is None: return base_pooler - return ClassifierPooler( - pooling=base_pooler.pooling, - classifier=classifier, - act_fn=base_pooler.head.activation, - activation=pooler_config.activation, - ) + return ClassifierPooler(pooling=base_pooler.pooling, + classifier=classifier, + act_fn=base_pooler.head.activation) @abstractmethod def get_supported_tasks(self) -> Set[PoolingTask]: @@ -170,78 +167,6 @@ def forward( raise NotImplementedError -def get_prompt_lens( - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, -) -> torch.Tensor: - if isinstance(pooling_metadata, V1PoolingMetadata): - return pooling_metadata.prompt_lens - - return PoolingTensors.from_pooling_metadata( - pooling_metadata, hidden_states[0].device).prompt_lens - - -def get_prompt_token_ids( - pooling_metadata: PoolingMetadata) -> list[torch.Tensor]: - if isinstance(pooling_metadata, V1PoolingMetadata): - assert pooling_metadata.prompt_token_ids is not None, ( - "Please set `requires_token_ids=True` in `get_pooling_updates`") - - return [ - pooling_metadata.prompt_token_ids[i, :num] - for i, num in enumerate(pooling_metadata.prompt_lens) - ] - - return [ - torch.tensor(seq_data_i.prompt_token_ids) - for seq_data_i in pooling_metadata.seq_data.values() - ] - - -def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]: - if isinstance(pooling_metadata, V0PoolingMetadata): - pooling_params = [p for _, p in pooling_metadata.seq_groups] - else: - pooling_params = pooling_metadata.pooling_params - - tasks: list[PoolingTask] = [ - task for pooling_param in pooling_params - if (task := pooling_param.task) is not None - ] - assert len(pooling_params) == len(tasks) - - return tasks - - -def get_classification_activation_function(config: PretrainedConfig): - return PoolerClassify() - - -def get_cross_encoder_activation_function(config: PretrainedConfig): - function_name: Optional[str] = None - if (hasattr(config, "sentence_transformers") - and "activation_fn" in config.sentence_transformers): - function_name = config.sentence_transformers["activation_fn"] - elif (hasattr(config, "sbert_ce_default_activation_function") - and config.sbert_ce_default_activation_function is not None): - function_name = config.sbert_ce_default_activation_function - - if function_name is not None: - assert function_name.startswith("torch.nn.modules."), ( - "Loading of activation functions is restricted to " - "torch.nn.modules for security reasons") - fn = resolve_obj_by_qualname(function_name)() - return PoolerActivation.wraps(fn) - - return PoolerScore() - - -def build_output( - all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput: - all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data] - return PoolerOutput(outputs=all_outputs) - - class PoolingMethod(nn.Module, ABC): @staticmethod @@ -508,7 +433,7 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], else: pooled_data = pooled_data.to(torch.float32) - pooling_params = _get_pooling_params(pooling_metadata) + pooling_params = get_pooling_params(pooling_metadata) # for matryoshka representation dimensions_list = [ @@ -529,35 +454,25 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], ] # for normalize - normalize_list = [ - pooling_param.normalize or - (pooling_param.normalize is None and 
self.pooler_config.normalize) - for pooling_param in pooling_params - ] - - if len(set(normalize_list)) == 1: - if normalize_list[0]: + flags = [p.normalize for p in pooling_params] + if len(set(flags)) == 1: + if flags[0]: pooled_data = self.normalize(pooled_data) else: pooled_data = [ self.normalize(vecs) if f else vecs - for vecs, f in zip(pooled_data, normalize_list) + for vecs, f in zip(pooled_data, flags) ] # for softmax - softmax_list = [ - pooling_param.softmax - or (pooling_param.softmax is None and self.pooler_config.softmax) - for pooling_param in pooling_params - ] - - if len(set(softmax_list)) == 1: - if softmax_list[0]: + flags = [p.softmax for p in pooling_params] + if len(set(flags)) == 1: + if flags[0]: pooled_data = self.softmax(pooled_data) else: pooled_data = [ self.softmax(vecs) if f else vecs - for vecs, f in zip(pooled_data, softmax_list) + for vecs, f in zip(pooled_data, flags) ] return pooled_data @@ -689,14 +604,12 @@ def __init__( pooling: PoolingFn, classifier: ClassifierFn, act_fn: PoolerActivation, - activation: bool = True, ) -> None: super().__init__() self.pooling = pooling self.classifier = classifier self.act_fn = act_fn - self.activation = activation def get_supported_tasks(self) -> Set[PoolingTask]: return {"classify", "score"} @@ -716,23 +629,15 @@ def forward( else: pooled_output = [self.classifier(data) for data in pooled_data] - pooling_params = _get_pooling_params(pooling_metadata) - - activation_list = [ - pooling_param.activation - or (pooling_param.activation is None and self.activation) - for pooling_param in pooling_params - ] + pooling_params = get_pooling_params(pooling_metadata) + flags = [p.activation for p in pooling_params] - if len(set(activation_list)) == 1: - if activation_list[0]: - scores = self.act_fn(pooled_output) - else: - scores = pooled_output + if len(set(flags)) == 1: + scores = self.act_fn(pooled_output) if flags[0] else pooled_output else: scores = [ self.act_fn(vecs) if f else vecs - for vecs, f in zip(pooled_output, activation_list) + for vecs, f in zip(pooled_output, flags) ] return build_output(scores) @@ -791,13 +696,79 @@ def forward( return PoolerOutput(outputs) -def _get_pooling_params(pooling_metadata: PoolingMetadata): - if isinstance(pooling_metadata, V0PoolingMetadata): - pooling_params = [ - pooling_param for _, pooling_param in pooling_metadata.seq_groups +def get_prompt_lens( + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, +) -> torch.Tensor: + if isinstance(pooling_metadata, V1PoolingMetadata): + return pooling_metadata.prompt_lens + + return PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states[0].device).prompt_lens + + +def get_prompt_token_ids( + pooling_metadata: PoolingMetadata) -> list[torch.Tensor]: + if isinstance(pooling_metadata, V1PoolingMetadata): + assert pooling_metadata.prompt_token_ids is not None, ( + "Please set `requires_token_ids=True` in `get_pooling_updates`") + + return [ + pooling_metadata.prompt_token_ids[i, :num] + for i, num in enumerate(pooling_metadata.prompt_lens) ] + + return [ + torch.tensor(seq_data_i.prompt_token_ids) + for seq_data_i in pooling_metadata.seq_data.values() + ] + + +def get_pooling_params( + pooling_metadata: PoolingMetadata) -> list[PoolingParams]: + if isinstance(pooling_metadata, V0PoolingMetadata): + pooling_params = [p for _, p in pooling_metadata.seq_groups] else: - pooling_params = [ - pooling_param for pooling_param in pooling_metadata.pooling_params - ] + pooling_params = 
pooling_metadata.pooling_params return pooling_params + + +def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]: + pooling_params = get_pooling_params(pooling_metadata) + + tasks: list[PoolingTask] = [ + task for pooling_param in pooling_params + if (task := pooling_param.task) is not None + ] + assert len(pooling_params) == len(tasks) + + return tasks + + +def get_classification_activation_function(config: PretrainedConfig): + return PoolerClassify() + + +def get_cross_encoder_activation_function(config: PretrainedConfig): + function_name: Optional[str] = None + if (hasattr(config, "sentence_transformers") + and "activation_fn" in config.sentence_transformers): + function_name = config.sentence_transformers["activation_fn"] + elif (hasattr(config, "sbert_ce_default_activation_function") + and config.sbert_ce_default_activation_function is not None): + function_name = config.sbert_ce_default_activation_function + + if function_name is not None: + assert function_name.startswith("torch.nn.modules."), ( + "Loading of activation functions is restricted to " + "torch.nn.modules for security reasons") + fn = resolve_obj_by_qualname(function_name)() + return PoolerActivation.wraps(fn) + + return PoolerScore() + + +def build_output( + all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput: + all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data] + return PoolerOutput(outputs=all_outputs) diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 8bdfc03c6631..aa608a836297 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -193,16 +193,14 @@ def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): pooling=PoolingMethod.from_pooling_type(pooling_type), classifier=self._classifier, act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config), - activation=pooler_config.activation, + vllm_config.model_config) ), "score": ClassifierPooler( pooling=PoolingMethod.from_pooling_type(pooling_type), classifier=self._classifier, act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config), - activation=pooler_config.activation, + vllm_config.model_config) ), }) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 760979983b04..6477bfc3e855 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -547,16 +547,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): pooling=self.bert.pooler, classifier=self.classifier, act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config), - activation=pooler_config.activation, + vllm_config.model_config) ), "score": ClassifierPooler( pooling=self.bert.pooler, classifier=self.classifier, act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config), - activation=pooler_config.activation, + vllm_config.model_config) ), }) diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index feaccb2a45d6..87cfe17bff75 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -320,16 +320,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): pooling=ModernBertPooler(config), classifier=self.classifier, act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config), - activation=pooler_config.activation, + vllm_config.model_config) ), "score": ClassifierPooler( 
pooling=ModernBertPooler(config), classifier=self.classifier, act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config), - activation=pooler_config.activation, + vllm_config.model_config) ), }) diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index e139f4e7a95b..bf42540e4a4c 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -193,16 +193,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): pooling=CLSPool(), classifier=self.classifier, act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config), - activation=pooler_config.activation, + vllm_config.model_config) ), "score": ClassifierPooler( pooling=CLSPool(), classifier=self.classifier, act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config), - activation=pooler_config.activation, + vllm_config.model_config) ), }) diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 462109787508..0275aa5f89eb 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -8,7 +8,7 @@ from vllm.sampling_params import RequestOutputKind if TYPE_CHECKING: - from vllm.config import ModelConfig + from vllm.config import ModelConfig, PoolerConfig PoolingTask = Literal["encode", "embed", "classify", "score"] @@ -44,6 +44,10 @@ class PoolingParams( requires_token_ids: bool = False """Internal use only.""" + @property + def all_parameters(self) -> list[str]: + return ["dimensions", "normalize", "activation", "softmax"] + def clone(self) -> "PoolingParams": """Returns a deep copy of the PoolingParams instance.""" return PoolingParams( @@ -68,8 +72,6 @@ def verify(self, # which is not available in model config. So, it's not included # in this method - all_parameters = ["dimensions", "normalize", "activation", "softmax"] - if self.task == "embed": legal_parameters = ["dimensions", "normalize"] @@ -91,15 +93,24 @@ def verify(self, f'lead to poor results.') elif self.dimensions < 1: raise ValueError("Dimensions must be greater than 0") + + if self.normalize is None: + self.normalize = True + elif self.task in ["classify", "score"]: legal_parameters = ["activation"] + if self.activation is None: + self.activation = True + elif self.task == "encode": legal_parameters = ["softmax"] + if self.softmax is None: + self.softmax = True else: assert_never(self.task) invalid_parameters = [] - for k in all_parameters: + for k in self.all_parameters: if k in legal_parameters: continue @@ -111,6 +122,11 @@ def verify(self, f"{self.task} only supports {legal_parameters} parameters, " f"does not support {invalid_parameters} parameters") + def merge_default_parameters(self, pooler_config: "PoolerConfig") -> None: + for k in self.all_parameters: + if getattr(self, k, None) is None: + setattr(self, k, getattr(pooler_config, k)) + def __repr__(self) -> str: return (f"PoolingParams(" f"dimensions={self.dimensions}, " From 5385b764eb65652ebe2a02da34f670d61c909055 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 25 Jul 2025 19:57:57 +0800 Subject: [PATCH 10/30] Remove unnecessary changes Signed-off-by: wang.yuqi --- vllm/model_executor/layers/pooler.py | 156 +++++++++++------------ vllm/model_executor/models/adapters.py | 4 +- vllm/model_executor/models/bert.py | 4 +- vllm/model_executor/models/modernbert.py | 4 +- vllm/model_executor/models/roberta.py | 4 +- 5 files changed, 86 insertions(+), 86 deletions(-) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 
eb27b31d542d..36822e50459e 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -167,6 +167,84 @@ def forward( raise NotImplementedError +def get_prompt_lens( + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, +) -> torch.Tensor: + if isinstance(pooling_metadata, V1PoolingMetadata): + return pooling_metadata.prompt_lens + + return PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states[0].device).prompt_lens + + +def get_prompt_token_ids( + pooling_metadata: PoolingMetadata) -> list[torch.Tensor]: + if isinstance(pooling_metadata, V1PoolingMetadata): + assert pooling_metadata.prompt_token_ids is not None, ( + "Please set `requires_token_ids=True` in `get_pooling_updates`") + + return [ + pooling_metadata.prompt_token_ids[i, :num] + for i, num in enumerate(pooling_metadata.prompt_lens) + ] + + return [ + torch.tensor(seq_data_i.prompt_token_ids) + for seq_data_i in pooling_metadata.seq_data.values() + ] + + +def get_pooling_params( + pooling_metadata: PoolingMetadata) -> list[PoolingParams]: + if isinstance(pooling_metadata, V0PoolingMetadata): + pooling_params = [p for _, p in pooling_metadata.seq_groups] + else: + pooling_params = pooling_metadata.pooling_params + return pooling_params + + +def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]: + pooling_params = get_pooling_params(pooling_metadata) + + tasks: list[PoolingTask] = [ + task for pooling_param in pooling_params + if (task := pooling_param.task) is not None + ] + assert len(pooling_params) == len(tasks) + + return tasks + + +def get_classification_activation_function(config: PretrainedConfig): + return PoolerClassify() + + +def get_cross_encoder_activation_function(config: PretrainedConfig): + function_name: Optional[str] = None + if (hasattr(config, "sentence_transformers") + and "activation_fn" in config.sentence_transformers): + function_name = config.sentence_transformers["activation_fn"] + elif (hasattr(config, "sbert_ce_default_activation_function") + and config.sbert_ce_default_activation_function is not None): + function_name = config.sbert_ce_default_activation_function + + if function_name is not None: + assert function_name.startswith("torch.nn.modules."), ( + "Loading of activation functions is restricted to " + "torch.nn.modules for security reasons") + fn = resolve_obj_by_qualname(function_name)() + return PoolerActivation.wraps(fn) + + return PoolerScore() + + +def build_output( + all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput: + all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data] + return PoolerOutput(outputs=all_outputs) + + class PoolingMethod(nn.Module, ABC): @staticmethod @@ -694,81 +772,3 @@ def forward( offset += num_items return PoolerOutput(outputs) - - -def get_prompt_lens( - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, -) -> torch.Tensor: - if isinstance(pooling_metadata, V1PoolingMetadata): - return pooling_metadata.prompt_lens - - return PoolingTensors.from_pooling_metadata( - pooling_metadata, hidden_states[0].device).prompt_lens - - -def get_prompt_token_ids( - pooling_metadata: PoolingMetadata) -> list[torch.Tensor]: - if isinstance(pooling_metadata, V1PoolingMetadata): - assert pooling_metadata.prompt_token_ids is not None, ( - "Please set `requires_token_ids=True` in `get_pooling_updates`") - - return [ - pooling_metadata.prompt_token_ids[i, :num] - for i, num in 
enumerate(pooling_metadata.prompt_lens) - ] - - return [ - torch.tensor(seq_data_i.prompt_token_ids) - for seq_data_i in pooling_metadata.seq_data.values() - ] - - -def get_pooling_params( - pooling_metadata: PoolingMetadata) -> list[PoolingParams]: - if isinstance(pooling_metadata, V0PoolingMetadata): - pooling_params = [p for _, p in pooling_metadata.seq_groups] - else: - pooling_params = pooling_metadata.pooling_params - return pooling_params - - -def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]: - pooling_params = get_pooling_params(pooling_metadata) - - tasks: list[PoolingTask] = [ - task for pooling_param in pooling_params - if (task := pooling_param.task) is not None - ] - assert len(pooling_params) == len(tasks) - - return tasks - - -def get_classification_activation_function(config: PretrainedConfig): - return PoolerClassify() - - -def get_cross_encoder_activation_function(config: PretrainedConfig): - function_name: Optional[str] = None - if (hasattr(config, "sentence_transformers") - and "activation_fn" in config.sentence_transformers): - function_name = config.sentence_transformers["activation_fn"] - elif (hasattr(config, "sbert_ce_default_activation_function") - and config.sbert_ce_default_activation_function is not None): - function_name = config.sbert_ce_default_activation_function - - if function_name is not None: - assert function_name.startswith("torch.nn.modules."), ( - "Loading of activation functions is restricted to " - "torch.nn.modules for security reasons") - fn = resolve_obj_by_qualname(function_name)() - return PoolerActivation.wraps(fn) - - return PoolerScore() - - -def build_output( - all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput: - all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data] - return PoolerOutput(outputs=all_outputs) diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index aa608a836297..867de2c68b4c 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -193,14 +193,14 @@ def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): pooling=PoolingMethod.from_pooling_type(pooling_type), classifier=self._classifier, act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config) + vllm_config.model_config), ), "score": ClassifierPooler( pooling=PoolingMethod.from_pooling_type(pooling_type), classifier=self._classifier, act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config) + vllm_config.model_config), ), }) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 6477bfc3e855..9dc6115f850e 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -547,14 +547,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): pooling=self.bert.pooler, classifier=self.classifier, act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config) + vllm_config.model_config), ), "score": ClassifierPooler( pooling=self.bert.pooler, classifier=self.classifier, act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config) + vllm_config.model_config), ), }) diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 87cfe17bff75..be1c3438d9db 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -320,14 +320,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): 
pooling=ModernBertPooler(config), classifier=self.classifier, act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config) + vllm_config.model_config), ), "score": ClassifierPooler( pooling=ModernBertPooler(config), classifier=self.classifier, act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config) + vllm_config.model_config), ), }) diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index bf42540e4a4c..c6b411644034 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -193,14 +193,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): pooling=CLSPool(), classifier=self.classifier, act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config) + vllm_config.model_config), ), "score": ClassifierPooler( pooling=CLSPool(), classifier=self.classifier, act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config) + vllm_config.model_config), ), }) From 194659658502c60e080b1b4c7f523ac5f0ad553b Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 25 Jul 2025 20:02:46 +0800 Subject: [PATCH 11/30] Remove unnecessary changes Signed-off-by: wang.yuqi --- tests/entrypoints/llm/test_encode.py | 8 ++++++++ vllm/model_executor/layers/pooler.py | 8 +++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index c0adee8e5f04..b930f05bebd0 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -29,6 +29,14 @@ ] +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 36822e50459e..260acd1d83f9 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -143,9 +143,11 @@ def for_classify( if classifier is None: return base_pooler - return ClassifierPooler(pooling=base_pooler.pooling, - classifier=classifier, - act_fn=base_pooler.head.activation) + return ClassifierPooler( + pooling=base_pooler.pooling, + classifier=classifier, + act_fn=base_pooler.head.activation, + ) @abstractmethod def get_supported_tasks(self) -> Set[PoolingTask]: From beead6b05aa59e499d61f4600cb529bd14928dee Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Sat, 26 Jul 2025 14:16:01 +0800 Subject: [PATCH 12/30] fix Signed-off-by: wang.yuqi --- tests/entrypoints/openai/test_score.py | 19 ++-- .../pooling/test_truncation_control.py | 6 +- vllm/entrypoints/openai/protocol.py | 17 +--- vllm/pooling_params.py | 95 ++++--------------- 4 files changed, 38 insertions(+), 99 deletions(-) diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py index dfe010e5845c..1a5df1d2dbd2 100644 --- a/tests/entrypoints/openai/test_score.py +++ b/tests/entrypoints/openai/test_score.py @@ -235,15 +235,18 @@ def get_outputs(activation): "text_2": text_2, "activation": activation }) - outputs = response.json() + if response.status_code != 200: + return response + outputs = response.json() return torch.tensor([x['score'] for x in outputs["data"]]) - default = get_outputs(activation=None) - w_activation = get_outputs(activation=True) - wo_activation = 
get_outputs(activation=False) - if model["is_cross_encoder"]: + + default = get_outputs(activation=None) + w_activation = get_outputs(activation=True) + wo_activation = get_outputs(activation=False) + assert torch.allclose(default, w_activation, atol=1e-2), "Default should use activation." assert not torch.allclose( @@ -253,6 +256,8 @@ def get_outputs(activation): F.sigmoid(wo_activation), w_activation, atol=1e-2 ), "w_activation should be close to activation(wo_activation)." else: + get_outputs(activation=None) + # The activation parameter only works for the is_cross_encoder model - assert torch.allclose(default, w_activation, atol=1e-2) - assert torch.allclose(wo_activation, w_activation, atol=1e-2) + response = get_outputs(activation=True) + assert response.status_code == 400 diff --git a/tests/models/language/pooling/test_truncation_control.py b/tests/models/language/pooling/test_truncation_control.py index c7399e01c735..c68ff078044a 100644 --- a/tests/models/language/pooling/test_truncation_control.py +++ b/tests/models/language/pooling/test_truncation_control.py @@ -28,7 +28,7 @@ def test_smaller_truncation_size(vllm_runner, with vllm_runner(model_name, task="embed", max_model_len=max_model_len) as vllm_model: - vllm_output = vllm_model.llm.encode( + vllm_output = vllm_model.llm.embed( input_str, truncate_prompt_tokens=truncate_prompt_tokens) prompt_tokens = vllm_output[0].prompt_token_ids @@ -43,7 +43,7 @@ def test_max_truncation_size(vllm_runner, with vllm_runner(model_name, task="embed", max_model_len=max_model_len) as vllm_model: - vllm_output = vllm_model.llm.encode( + vllm_output = vllm_model.llm.embed( input_str, truncate_prompt_tokens=truncate_prompt_tokens) prompt_tokens = vllm_output[0].prompt_token_ids @@ -61,7 +61,7 @@ def test_bigger_truncation_size(vllm_runner, model_name, task="embed", max_model_len=max_model_len) as vllm_model: - llm_output = vllm_model.llm.encode( + llm_output = vllm_model.llm.embed( input_str, truncate_prompt_tokens=truncate_prompt_tokens) assert llm_output == f"""truncate_prompt_tokens value diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 56e7472915c3..6c6ec207a3ca 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1251,13 +1251,11 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): "default: 0). Any priority other than 0 will raise an error " "if the served model does not use priority scheduling."), ) - normalize: Optional[bool] = None # --8<-- [end:embedding-extra-params] def to_pooling_params(self): - return PoolingParams(dimensions=self.dimensions, - normalize=self.normalize) + return PoolingParams(dimensions=self.dimensions) class EmbeddingChatRequest(OpenAIBaseModel): @@ -1304,7 +1302,6 @@ class EmbeddingChatRequest(OpenAIBaseModel): "default: 0). 
Any priority other than 0 will raise an error " "if the served model does not use priority scheduling."), ) - normalize: Optional[bool] = None # --8<-- [end:chat-embedding-extra-params] @model_validator(mode="before") @@ -1317,8 +1314,7 @@ def check_generation_prompt(cls, data): return data def to_pooling_params(self): - return PoolingParams(dimensions=self.dimensions, - normalize=self.normalize) + return PoolingParams(dimensions=self.dimensions) EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] @@ -1333,7 +1329,6 @@ class ScoreRequest(OpenAIBaseModel): text_1: Union[list[str], str, ScoreMultiModalParam] text_2: Union[list[str], str, ScoreMultiModalParam] truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None - activation: Optional[bool] = None # --8<-- [start:score-extra-params] @@ -1353,7 +1348,7 @@ class ScoreRequest(OpenAIBaseModel): # --8<-- [end:score-extra-params] def to_pooling_params(self): - return PoolingParams(activation=self.activation) + return PoolingParams() class RerankRequest(OpenAIBaseModel): @@ -1362,7 +1357,6 @@ class RerankRequest(OpenAIBaseModel): documents: Union[list[str], ScoreMultiModalParam] top_n: int = Field(default_factory=lambda: 0) truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None - activation: Optional[bool] = None # --8<-- [start:rerank-extra-params] @@ -1382,7 +1376,7 @@ class RerankRequest(OpenAIBaseModel): # --8<-- [end:rerank-extra-params] def to_pooling_params(self): - return PoolingParams(activation=self.activation) + return PoolingParams() class RerankDocument(BaseModel): @@ -1519,7 +1513,6 @@ class ClassificationRequest(OpenAIBaseModel): input: Union[list[str], str] truncate_prompt_tokens: Optional[int] = None user: Optional[str] = None - activation: Optional[bool] = None # --8<-- [start:classification-extra-params] priority: int = Field( @@ -1533,7 +1526,7 @@ class ClassificationRequest(OpenAIBaseModel): # --8<-- [end:classification-extra-params] def to_pooling_params(self): - return PoolingParams(activation=self.activation) + return PoolingParams() class ClassificationData(OpenAIBaseModel): diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 0275aa5f89eb..868facbe2557 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -1,14 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Literal, Optional, assert_never +from typing import TYPE_CHECKING, Literal, Optional import msgspec from vllm.sampling_params import RequestOutputKind if TYPE_CHECKING: - from vllm.config import ModelConfig, PoolerConfig + from vllm.config import ModelConfig PoolingTask = Literal["encode", "embed", "classify", "score"] @@ -20,21 +20,11 @@ class PoolingParams( """API parameters for pooling models. Attributes: - normalize: Whether to normalize the embeddings outputs. dimensions: Reduce the dimensions of embeddings if model support matryoshka representation. - activation: Whether to apply activation function to - the classification outputs. - softmax: Whether to apply softmax to the reward outputs. 
""" - ## for embeddings models dimensions: Optional[int] = None - normalize: Optional[bool] = None - ## for classification models - activation: Optional[bool] = None - ## for reward models - softmax: Optional[bool] = None output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY @@ -44,24 +34,15 @@ class PoolingParams( requires_token_ids: bool = False """Internal use only.""" - @property - def all_parameters(self) -> list[str]: - return ["dimensions", "normalize", "activation", "softmax"] - def clone(self) -> "PoolingParams": """Returns a deep copy of the PoolingParams instance.""" return PoolingParams( dimensions=self.dimensions, - normalize=self.normalize, - activation=self.activation, - softmax=self.softmax, task=self.task, requires_token_ids=self.requires_token_ids, ) - def verify(self, - task: PoolingTask, - model_config: Optional["ModelConfig"] = None) -> None: + def verify(self, task: PoolingTask, model_config: "ModelConfig") -> None: if self.task is None: self.task = task elif self.task != task: @@ -72,68 +53,28 @@ def verify(self, # which is not available in model config. So, it's not included # in this method - if self.task == "embed": - legal_parameters = ["dimensions", "normalize"] + if self.dimensions is not None: + if not model_config.is_matryoshka: + raise ValueError( + f'Model "{model_config.served_model_name}" does not ' + f'support matryoshka representation, ' + f'changing output dimensions will lead to poor results.') - if self.dimensions is not None and model_config is not None: - if not model_config.is_matryoshka: + mds = model_config.matryoshka_dimensions + if mds is not None: + if self.dimensions not in mds: raise ValueError( - f'Model "{model_config.served_model_name}" does not ' - f'support matryoshka representation, ' - f'changing output dimensions will lead to poor results.' 
- ) - - mds = model_config.matryoshka_dimensions - if mds is not None: - if self.dimensions not in mds: - raise ValueError( - f'Model "{model_config.served_model_name}" ' - f'only supports {str(mds)} matryoshka dimensions, ' - f'use other output dimensions will ' - f'lead to poor results.') - elif self.dimensions < 1: - raise ValueError("Dimensions must be greater than 0") - - if self.normalize is None: - self.normalize = True - - elif self.task in ["classify", "score"]: - legal_parameters = ["activation"] - if self.activation is None: - self.activation = True - - elif self.task == "encode": - legal_parameters = ["softmax"] - if self.softmax is None: - self.softmax = True - else: - assert_never(self.task) - - invalid_parameters = [] - for k in self.all_parameters: - if k in legal_parameters: - continue - - if getattr(self, k, None) is not None: - invalid_parameters.append(k) - - if invalid_parameters: - raise ValueError( - f"{self.task} only supports {legal_parameters} parameters, " - f"does not support {invalid_parameters} parameters") - - def merge_default_parameters(self, pooler_config: "PoolerConfig") -> None: - for k in self.all_parameters: - if getattr(self, k, None) is None: - setattr(self, k, getattr(pooler_config, k)) + f'Model "{model_config.served_model_name}" ' + f'only supports {str(mds)} matryoshka dimensions, ' + f'use other output dimensions will ' + f'lead to poor results.') + elif self.dimensions < 1: + raise ValueError("Dimensions must be greater than 0") def __repr__(self) -> str: return (f"PoolingParams(" f"dimensions={self.dimensions}, " f"task={self.task}, " - f"softmax={self.softmax}, " - f"normalize={self.normalize}, " - f"activation={self.activation}, " f"requires_token_ids={self.requires_token_ids})") def __post_init__(self) -> None: From 889570ad9f75fb5bebac925693c402d9394f27d3 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Sat, 26 Jul 2025 14:22:21 +0800 Subject: [PATCH 13/30] fix Signed-off-by: wang.yuqi --- vllm/entrypoints/openai/protocol.py | 20 ++++-- vllm/pooling_params.py | 95 +++++++++++++++++++++++------ 2 files changed, 92 insertions(+), 23 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index b6b3bf3f530e..f884e9766fef 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1265,11 +1265,13 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): "not set it, a random_uuid will be generated. This id is used " "through out the inference process and return in response."), ) + normalize: Optional[bool] = None # --8<-- [end:embedding-extra-params] def to_pooling_params(self): - return PoolingParams(dimensions=self.dimensions) + return PoolingParams(dimensions=self.dimensions, + normalize=self.normalize) class EmbeddingChatRequest(OpenAIBaseModel): @@ -1323,6 +1325,7 @@ class EmbeddingChatRequest(OpenAIBaseModel): "not set it, a random_uuid will be generated. 
This id is used " "through out the inference process and return in response."), ) + normalize: Optional[bool] = None # --8<-- [end:chat-embedding-extra-params] @model_validator(mode="before") @@ -1335,7 +1338,8 @@ def check_generation_prompt(cls, data): return data def to_pooling_params(self): - return PoolingParams(dimensions=self.dimensions) + return PoolingParams(dimensions=self.dimensions, + normalize=self.normalize) EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] @@ -1366,10 +1370,12 @@ class ScoreRequest(OpenAIBaseModel): "if the served model does not use priority scheduling."), ) + activation: Optional[bool] = None + # --8<-- [end:score-extra-params] def to_pooling_params(self): - return PoolingParams() + return PoolingParams(activation=self.activation) class RerankRequest(OpenAIBaseModel): @@ -1394,10 +1400,12 @@ class RerankRequest(OpenAIBaseModel): "if the served model does not use priority scheduling."), ) + activation: Optional[bool] = None + # --8<-- [end:rerank-extra-params] def to_pooling_params(self): - return PoolingParams() + return PoolingParams(activation=self.activation) class RerankDocument(BaseModel): @@ -1544,10 +1552,12 @@ class ClassificationRequest(OpenAIBaseModel): "if the served model does not use priority scheduling."), ) + activation: Optional[bool] = None + # --8<-- [end:classification-extra-params] def to_pooling_params(self): - return PoolingParams() + return PoolingParams(activation=self.activation) class ClassificationData(OpenAIBaseModel): diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 23eb775f2dc6..27e0bd7bec0d 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, assert_never import msgspec @@ -9,7 +9,7 @@ from vllm.tasks import PoolingTask if TYPE_CHECKING: - from vllm.config import ModelConfig + from vllm.config import ModelConfig, PoolerConfig class PoolingParams( @@ -19,11 +19,21 @@ class PoolingParams( """API parameters for pooling models. Attributes: + normalize: Whether to normalize the embeddings outputs. dimensions: Reduce the dimensions of embeddings if model support matryoshka representation. + activation: Whether to apply activation function to + the classification outputs. + softmax: Whether to apply softmax to the reward outputs. 
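+
+        Example (illustrative; the task/parameter pairings mirror the
+        verify() rules below):
+
+            PoolingParams(normalize=False)   # embed: skip L2 normalization
+            PoolingParams(activation=False)  # classify/score: raw logits
+            PoolingParams(softmax=False)     # encode: raw reward scores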
""" + ## for embeddings models dimensions: Optional[int] = None + normalize: Optional[bool] = None + ## for classification models + activation: Optional[bool] = None + ## for reward models + softmax: Optional[bool] = None output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY @@ -33,15 +43,24 @@ class PoolingParams( requires_token_ids: bool = False """Internal use only.""" + @property + def all_parameters(self) -> list[str]: + return ["dimensions", "normalize", "activation", "softmax"] + def clone(self) -> "PoolingParams": """Returns a deep copy of the PoolingParams instance.""" return PoolingParams( dimensions=self.dimensions, + normalize=self.normalize, + activation=self.activation, + softmax=self.softmax, task=self.task, requires_token_ids=self.requires_token_ids, ) - def verify(self, task: PoolingTask, model_config: "ModelConfig") -> None: + def verify(self, + task: PoolingTask, + model_config: Optional["ModelConfig"] = None) -> None: if self.task is None: self.task = task elif self.task != task: @@ -52,28 +71,68 @@ def verify(self, task: PoolingTask, model_config: "ModelConfig") -> None: # which is not available in model config. So, it's not included # in this method - if self.dimensions is not None: - if not model_config.is_matryoshka: - raise ValueError( - f'Model "{model_config.served_model_name}" does not ' - f'support matryoshka representation, ' - f'changing output dimensions will lead to poor results.') + if self.task == "embed": + legal_parameters = ["dimensions", "normalize"] - mds = model_config.matryoshka_dimensions - if mds is not None: - if self.dimensions not in mds: + if self.dimensions is not None and model_config is not None: + if not model_config.is_matryoshka: raise ValueError( - f'Model "{model_config.served_model_name}" ' - f'only supports {str(mds)} matryoshka dimensions, ' - f'use other output dimensions will ' - f'lead to poor results.') - elif self.dimensions < 1: - raise ValueError("Dimensions must be greater than 0") + f'Model "{model_config.served_model_name}" does not ' + f'support matryoshka representation, ' + f'changing output dimensions will lead to poor results.' 
+ ) + + mds = model_config.matryoshka_dimensions + if mds is not None: + if self.dimensions not in mds: + raise ValueError( + f'Model "{model_config.served_model_name}" ' + f'only supports {str(mds)} matryoshka dimensions, ' + f'use other output dimensions will ' + f'lead to poor results.') + elif self.dimensions < 1: + raise ValueError("Dimensions must be greater than 0") + + if self.normalize is None: + self.normalize = True + + elif self.task in ["classify", "score"]: + legal_parameters = ["activation"] + if self.activation is None: + self.activation = True + + elif self.task == "encode": + legal_parameters = ["softmax"] + if self.softmax is None: + self.softmax = True + else: + assert_never(self.task) + + invalid_parameters = [] + for k in self.all_parameters: + if k in legal_parameters: + continue + + if getattr(self, k, None) is not None: + invalid_parameters.append(k) + + if invalid_parameters: + raise ValueError( + f"{self.task} only supports {legal_parameters} parameters, " + f"does not support {invalid_parameters} parameters") + + def merge_default_parameters(self, pooler_config: "PoolerConfig") -> None: + for k in self.all_parameters: + if getattr(self, k, None) is None: + setattr(self, k, getattr(pooler_config, k)) def __repr__(self) -> str: return (f"PoolingParams(" f"dimensions={self.dimensions}, " f"task={self.task}, " + f"softmax={self.softmax}, " + f"normalize={self.normalize}, " + f"activation={self.activation}, " f"requires_token_ids={self.requires_token_ids})") def __post_init__(self) -> None: From fa1367ee2cf9aa4cc942d78abfde289b2d2a8b2f Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Sat, 26 Jul 2025 15:07:11 +0800 Subject: [PATCH 14/30] mypy Signed-off-by: wang.yuqi --- vllm/entrypoints/openai/serving_classification.py | 3 ++- vllm/entrypoints/openai/serving_embedding.py | 3 ++- vllm/entrypoints/openai/serving_pooling.py | 3 ++- vllm/entrypoints/openai/serving_score.py | 6 ++++-- vllm/pooling_params.py | 15 ++++++++++----- 5 files changed, 20 insertions(+), 10 deletions(-) diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py index e10ba3c616f6..db6309fc0872 100644 --- a/vllm/entrypoints/openai/serving_classification.py +++ b/vllm/entrypoints/openai/serving_classification.py @@ -177,7 +177,8 @@ def _create_pooling_params( return pooling_params try: - pooling_params.merge_default_parameters(self.model_config.pooler_config) + pooling_params.merge_default_parameters( + self.model_config.pooler_config) pooling_params.verify("classify", self.model_config) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 9ad99556bccf..0cb1401fca78 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -200,7 +200,8 @@ def _create_pooling_params( return pooling_params try: - pooling_params.merge_default_parameters(self.model_config.pooler_config) + pooling_params.merge_default_parameters( + self.model_config.pooler_config) pooling_params.verify("embed", self.model_config) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index 4c8a6fe01d46..c58428ee132a 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -140,7 +140,8 @@ async def create_pooling( pooling_params = 
request.to_pooling_params() try: - pooling_params.merge_default_parameters(self.model_config.pooler_config) + pooling_params.merge_default_parameters( + self.model_config.pooler_config) pooling_params.verify("encode", self.model_config) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 4a5a9476572f..eb3e313cc3f4 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -86,7 +86,8 @@ async def _embedding_score( pooling_params = request.to_pooling_params() try: - pooling_params.merge_default_parameters(self.model_config.pooler_config) + pooling_params.merge_default_parameters( + self.model_config.pooler_config) pooling_params.verify("embed", self.model_config) except ValueError as e: return self.create_error_response(str(e)) @@ -246,7 +247,8 @@ async def _cross_encoding_score( pooling_params = request.to_pooling_params() try: - pooling_params.merge_default_parameters(self.model_config.pooler_config) + pooling_params.merge_default_parameters( + self.model_config.pooler_config) pooling_params.verify("score", self.model_config) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 27e0bd7bec0d..5c9946f13144 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Optional, assert_never +from typing import TYPE_CHECKING, Optional import msgspec @@ -106,7 +106,7 @@ def verify(self, if self.softmax is None: self.softmax = True else: - assert_never(self.task) + raise ValueError(f"Unknown pooling task: {self.task}") invalid_parameters = [] for k in self.all_parameters: @@ -118,10 +118,15 @@ def verify(self, if invalid_parameters: raise ValueError( - f"{self.task} only supports {legal_parameters} parameters, " - f"does not support {invalid_parameters} parameters") + f"Task {self.task} only supports {legal_parameters} " + f"parameters, does not support " + f"{invalid_parameters} parameters") + + def merge_default_parameters( + self, pooler_config: Optional["PoolerConfig"]) -> None: + if pooler_config is None: + return - def merge_default_parameters(self, pooler_config: "PoolerConfig") -> None: for k in self.all_parameters: if getattr(self, k, None) is None: setattr(self, k, getattr(pooler_config, k)) From efcf72e1a872928b244e43c0d2c536edeea8230d Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Sat, 26 Jul 2025 15:16:31 +0800 Subject: [PATCH 15/30] fix Signed-off-by: wang.yuqi --- vllm/entrypoints/llm.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index cc4c0b565b4c..e4b0517442e0 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1129,7 +1129,7 @@ def encode( pooling_params.verify(pooling_task, model_config) else: for pooling_param in pooling_params: - pooling_params.merge_default_parameters( + pooling_param.merge_default_parameters( model_config.pooler_config) pooling_param.verify(pooling_task, model_config) @@ -1305,10 +1305,8 @@ def _cross_encoding_score( pooling_params = PoolingParams(task="score") model_config = self.llm_engine.model_config - pooling_task = "score" - pooling_params.merge_default_parameters(model_config.pooler_config) - pooling_params.verify(pooling_task, model_config) + 
pooling_params.verify("score", model_config) tokenization_kwargs: dict[str, Any] = {} _validate_truncation_size(self.llm_engine.model_config.max_model_len, From 8526e2a1fe5e5d1b79b41b6ae3850e0770f0b40a Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Sat, 26 Jul 2025 15:43:53 +0800 Subject: [PATCH 16/30] + test_reward_models_using_softmax Signed-off-by: wang.yuqi --- .../pooling/test_override_pooler_config.py | 40 +++++++++++++++++++ tests/test_pooling_params.py | 1 + 2 files changed, 41 insertions(+) diff --git a/tests/models/language/pooling/test_override_pooler_config.py b/tests/models/language/pooling/test_override_pooler_config.py index 63da5aa39469..f7af56e2cac3 100644 --- a/tests/models/language/pooling/test_override_pooler_config.py +++ b/tests/models/language/pooling/test_override_pooler_config.py @@ -84,3 +84,43 @@ def test_embed_models_using_normalize( assert torch.allclose( F.normalize(wo_normalize, p=2, dim=-1), w_normalize, atol=1e-2), "w_normal should be close to normal(wo_normal)." + + +@pytest.mark.parametrize( + "model", + [ + "Qwen/Qwen2.5-Math-PRM-7B", + ], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_reward_models_using_softmax( + hf_runner, + vllm_runner, + math_step_prompts, + model: str, + dtype: str, +) -> None: + + with vllm_runner( + model, + max_model_len=1024, + dtype=dtype, + override_pooler_config=PoolerConfig(softmax=False)) as vllm_model: + wo_softmax = vllm_model.encode(math_step_prompts) + + with vllm_runner( + model, + max_model_len=1024, + dtype=dtype, + override_pooler_config=PoolerConfig(softmax=True)) as vllm_model: + w_softmax = vllm_model.encode(math_step_prompts) + + for wo, w in zip(wo_softmax, w_softmax): + wo = torch.tensor(wo) + w = torch.tensor(w) + + assert not torch.allclose( + wo, w, atol=1e-2), "override_pooler_config softmax is not working" + assert torch.allclose( + F.softmax(wo, dim=-1), w, + atol=1e-2), "w_softmax should be close to softmax(wo_softmax)." 
diff --git a/tests/test_pooling_params.py b/tests/test_pooling_params.py index eb25f191ccbe..52c03015483c 100644 --- a/tests/test_pooling_params.py +++ b/tests/test_pooling_params.py @@ -64,6 +64,7 @@ def test_embed_dimensions(model_info: EmbedModelInfo): pooling_params.verify(task=task, model_config=model_config) if model_info.is_matryoshka: + assert model_info.matryoshka_dimensions is not None pooling_params = PoolingParams( dimensions=model_info.matryoshka_dimensions[0]) pooling_params.verify(task=task, model_config=model_config) From 684a2d98805d39a933e2e0d2488a8b632d01e115 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Wed, 30 Jul 2025 20:18:20 +0800 Subject: [PATCH 17/30] + test_reward Signed-off-by: wang.yuqi --- ...est_classification.py => test_classify.py} | 5 +- tests/entrypoints/llm/test_embedding.py | 5 +- tests/entrypoints/llm/test_reward.py | 66 +++++++++++ tests/entrypoints/llm/test_score.py | 5 +- vllm/model_executor/layers/pooler.py | 112 +++++++++--------- vllm/pooling_params.py | 39 ++++-- 6 files changed, 160 insertions(+), 72 deletions(-) rename tests/entrypoints/llm/{test_classification.py => test_classify.py} (96%) create mode 100644 tests/entrypoints/llm/test_reward.py diff --git a/tests/entrypoints/llm/test_classification.py b/tests/entrypoints/llm/test_classify.py similarity index 96% rename from tests/entrypoints/llm/test_classification.py rename to tests/entrypoints/llm/test_classify.py index 7429b582fe9e..2dd30f8d798c 100644 --- a/tests/entrypoints/llm/test_classification.py +++ b/tests/entrypoints/llm/test_classify.py @@ -43,11 +43,12 @@ def llm(): @pytest.mark.skip_global_cleanup -def test_activation(llm: LLM): +def test_pooling_params(llm: LLM): def get_outputs(activation): outputs = llm.classify( - prompts, pooling_params=PoolingParams(activation=activation)) + prompts, pooling_params=PoolingParams(activation=activation), + use_tqdm=False) return torch.tensor([x.outputs.probs for x in outputs]) default = get_outputs(activation=None) diff --git a/tests/entrypoints/llm/test_embedding.py b/tests/entrypoints/llm/test_embedding.py index db43398b4685..ba20d7b9548e 100644 --- a/tests/entrypoints/llm/test_embedding.py +++ b/tests/entrypoints/llm/test_embedding.py @@ -35,11 +35,12 @@ def llm(): @pytest.mark.skip_global_cleanup -def test_normalize(llm: LLM): +def test_pooling_params(llm: LLM): def get_outputs(normalize): outputs = llm.embed(prompts, - pooling_params=PoolingParams(normalize=normalize)) + pooling_params=PoolingParams(normalize=normalize), + use_tqdm=False) return torch.tensor([x.outputs.embedding for x in outputs]) default = get_outputs(normalize=None) diff --git a/tests/entrypoints/llm/test_reward.py b/tests/entrypoints/llm/test_reward.py new file mode 100644 index 000000000000..62097eb942ba --- /dev/null +++ b/tests/entrypoints/llm/test_reward.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import weakref + +import pytest +import torch +import torch.nn.functional as F + +from vllm import LLM, PoolingParams +from vllm.distributed import cleanup_dist_env_and_memory + +MODEL_NAME = "internlm/internlm2-1_8b-reward" + +prompts = ["The chef prepared a delicious meal."] + + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use 
weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + trust_remote_code=True, + seed=0) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + +@pytest.mark.skip_global_cleanup +def test_pooling_params(llm: LLM): + + def get_outputs(softmax): + outputs = llm.reward( + prompts, pooling_params=PoolingParams(softmax=softmax), + use_tqdm=False) + return torch.tensor([x.outputs.data for x in outputs]) + + default = get_outputs(softmax=None) + w_softmax = get_outputs(softmax=True) + wo_softmax = get_outputs(softmax=False) + + assert torch.allclose(default, w_softmax, + atol=1e-2), "Default should use softmax." + assert not torch.allclose( + w_softmax, wo_softmax, + atol=1e-2), "wo_softmax should not use softmax." + assert torch.allclose( + F.softmax(wo_softmax, dim=-1), w_softmax, atol=1e-2 + ), "w_softmax should be close to softmax(wo_softmax)." diff --git a/tests/entrypoints/llm/test_score.py b/tests/entrypoints/llm/test_score.py index f06edf3b2adc..315c41537bd3 100644 --- a/tests/entrypoints/llm/test_score.py +++ b/tests/entrypoints/llm/test_score.py @@ -33,7 +33,7 @@ def llm(): @pytest.mark.skip_global_cleanup -def test_activation(llm: LLM): +def test_pooling_params(llm: LLM): def get_outputs(activation): text_1 = "What is the capital of France?" @@ -42,7 +42,8 @@ def get_outputs(activation): outputs = llm.score( text_1, text_2, - pooling_params=PoolingParams(activation=activation)) + pooling_params=PoolingParams(activation=activation), + use_tqdm=False) return torch.tensor([x.outputs.score for x in outputs]) default = get_outputs(activation=None) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 21c4907c973a..f4f05ce46afd 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -41,35 +41,18 @@ class PoolingType(IntEnum): @dataclass(frozen=True) class ResolvedPoolingConfig: pooling_type: PoolingType - - normalize: bool - softmax: bool - step_tag_id: Optional[int] - returned_token_ids: Optional[list[int]] + task: PoolingTask @classmethod def from_config_with_defaults( cls, + task: PoolingTask, pooler_config: PoolerConfig, pooling_type: PoolingType, - normalize: bool, - softmax: bool, - step_tag_id: Optional[int] = None, - returned_token_ids: Optional[list[int]] = None, ) -> "ResolvedPoolingConfig": - return cls( - pooling_type=PoolingType[pooler_config.pooling_type] - if pooler_config.pooling_type is not None else pooling_type, - normalize=pooler_config.normalize - if pooler_config.normalize is not None else normalize, - softmax=pooler_config.softmax - if pooler_config.softmax is not None else softmax, - step_tag_id=pooler_config.step_tag_id - if pooler_config.step_tag_id is not None else step_tag_id, - returned_token_ids=pooler_config.returned_token_ids - if pooler_config.returned_token_ids is not None else - returned_token_ids, - ) + return cls(task=task, + pooling_type=PoolingType[pooler_config.pooling_type] + if pooler_config.pooling_type is not None else pooling_type) @dataclass(frozen=True) @@ -89,18 +72,11 @@ def for_encode( pooler_config: PoolerConfig, *, default_pooling_type: PoolingType = PoolingType.ALL, - default_normalize: bool = False, - default_softmax: bool = False, - default_step_tag_id: Optional[int] = None, - default_returned_token_ids: Optional[list[int]] = None, ): resolved_config = 
ResolvedPoolingConfig.from_config_with_defaults( + task="encode", pooler_config=pooler_config, pooling_type=default_pooling_type, - normalize=default_normalize, - softmax=default_softmax, - step_tag_id=default_step_tag_id, - returned_token_ids=default_returned_token_ids, ) if resolved_config.pooling_type == PoolingType.STEP: @@ -113,14 +89,11 @@ def for_embed( pooler_config: PoolerConfig, *, default_pooling_type: PoolingType = PoolingType.LAST, - default_normalize: bool = True, - default_softmax: bool = False, ): resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + task="embed", pooler_config=pooler_config, pooling_type=default_pooling_type, - normalize=default_normalize, - softmax=default_softmax, ) return SimplePooler.from_config(resolved_config) @@ -131,14 +104,11 @@ def for_classify( classifier: Optional[ClassifierFn], *, default_pooling_type: PoolingType = PoolingType.LAST, - default_normalize: bool = False, - default_softmax: bool = True, ): resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + task="classify", pooler_config=pooler_config, pooling_type=default_pooling_type, - normalize=default_normalize, - softmax=default_softmax, ) base_pooler = SimplePooler.from_config(resolved_config) if classifier is None: @@ -490,23 +460,31 @@ def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: class PoolerHead(nn.Module): - @classmethod - def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "PoolerHead": - if pooler_config.normalize and pooler_config.softmax: - raise ValueError("`normalize=True` and `softmax=True` should not " - "be set together") + def __init__(self, activation: PoolerActivation) -> None: + super().__init__() + self.activation = activation - return cls(pooler_config) + def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], + pooling_metadata: PoolingMetadata): + + # Using float32 in PoolerHead + if isinstance(pooled_data, list): + for i in range(len(pooled_data)): + pooled_data[i] = pooled_data[i].to(torch.float32) + else: + pooled_data = pooled_data.to(torch.float32) + + return self.activation(pooled_data) - def __init__(self, pooler_config: ResolvedPoolingConfig) -> None: - super().__init__() - self.pooler_config = pooler_config - self.normalize = PoolerNormalize() - self.softmax = PoolerClassify() +class EmbeddingPoolerHead(PoolerHead): + + def __init__(self) -> None: + super().__init__(activation=PoolerNormalize()) def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): + # Using float32 in PoolerHead if isinstance(pooled_data, list): for i in range(len(pooled_data)): @@ -538,23 +516,43 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], flags = [p.normalize for p in pooling_params] if len(set(flags)) == 1: if flags[0]: - pooled_data = self.normalize(pooled_data) + pooled_data = self.activation(pooled_data) else: pooled_data = [ - self.normalize(vecs) if f else vecs + self.activation(vecs) if f else vecs for vecs, f in zip(pooled_data, flags) ] + return pooled_data + + +class RewardPoolerHead(PoolerHead): + + def __init__(self) -> None: + super().__init__(activation=PoolerNormalize()) + + def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], + pooling_metadata: PoolingMetadata): + # Using float32 in PoolerHead + if isinstance(pooled_data, list): + for i in range(len(pooled_data)): + pooled_data[i] = pooled_data[i].to(torch.float32) + else: + pooled_data = pooled_data.to(torch.float32) + + pooling_params = 
get_pooling_params(pooling_metadata) + # for softmax flags = [p.softmax for p in pooling_params] if len(set(flags)) == 1: if flags[0]: - pooled_data = self.softmax(pooled_data) + pooled_data = self.activation(pooled_data) else: pooled_data = [ - self.softmax(vecs) if f else vecs + self.activation(vecs) if f else vecs for vecs, f in zip(pooled_data, flags) ] + return pooled_data @@ -573,8 +571,12 @@ def from_config( pooler_config: ResolvedPoolingConfig, ) -> "SimplePooler": pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type) - head = PoolerHead.from_config(pooler_config) - + if pooler_config.task == "embed": + head = EmbeddingPoolerHead() + elif pooler_config.task == "encode": + head = RewardPoolerHead() + else: + raise NotImplementedError(f"Unknown task: {pooler_config.task}") return cls(pooling, head) def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None: diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 5c9946f13144..ebc5beef1446 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -30,12 +30,14 @@ class PoolingParams( ## for embeddings models dimensions: Optional[int] = None normalize: Optional[bool] = None + ## for classification models activation: Optional[bool] = None + ## for reward models softmax: Optional[bool] = None - - output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY + step_tag_id: Optional[int] = None + returned_token_ids: Optional[list[int]] = None task: Optional[PoolingTask] = None """Internal use only.""" @@ -43,9 +45,23 @@ class PoolingParams( requires_token_ids: bool = False """Internal use only.""" + output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY + @property def all_parameters(self) -> list[str]: - return ["dimensions", "normalize", "activation", "softmax"] + return [ + "dimensions", "normalize", "activation", "softmax", "step_tag_id", + "returned_token_ids" + ] + + @property + def legal_parameters(self): + return { + "embed": ["dimensions", "normalize"], + "classify": ["activation"], + "score": ["activation"], + "encode": ["softmax", "step_tag_id", "returned_token_ids"], + } def clone(self) -> "PoolingParams": """Returns a deep copy of the PoolingParams instance.""" @@ -72,8 +88,6 @@ def verify(self, # in this method if self.task == "embed": - legal_parameters = ["dimensions", "normalize"] - if self.dimensions is not None and model_config is not None: if not model_config.is_matryoshka: raise ValueError( @@ -92,22 +106,20 @@ def verify(self, f'lead to poor results.') elif self.dimensions < 1: raise ValueError("Dimensions must be greater than 0") - if self.normalize is None: self.normalize = True - elif self.task in ["classify", "score"]: - legal_parameters = ["activation"] if self.activation is None: self.activation = True elif self.task == "encode": - legal_parameters = ["softmax"] if self.softmax is None: self.softmax = True else: raise ValueError(f"Unknown pooling task: {self.task}") + assert self.task is not None, "task must be set" + legal_parameters = self.legal_parameters[self.task] invalid_parameters = [] for k in self.all_parameters: if k in legal_parameters: @@ -128,16 +140,21 @@ def merge_default_parameters( return for k in self.all_parameters: + if getattr(pooler_config, k, None) is None: + continue + if getattr(self, k, None) is None: setattr(self, k, getattr(pooler_config, k)) def __repr__(self) -> str: return (f"PoolingParams(" - f"dimensions={self.dimensions}, " f"task={self.task}, " - f"softmax={self.softmax}, " f"normalize={self.normalize}, " + 
f"dimensions={self.dimensions}, " f"activation={self.activation}, " + f"softmax={self.softmax}, " + f"step_tag_id={self.step_tag_id}, " + f"returned_token_ids={self.returned_token_ids}, " f"requires_token_ids={self.requires_token_ids})") def __post_init__(self) -> None: From d5e30e60ab9cd048563d9694fe972a7d27a0f918 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 1 Aug 2025 17:19:21 +0800 Subject: [PATCH 18/30] - default_normalize & default_softmax Signed-off-by: wang.yuqi --- tests/entrypoints/llm/test_classify.py | 10 +++-- tests/entrypoints/llm/test_reward.py | 20 +++++----- tests/entrypoints/llm/test_score.py | 5 ++- tests/models/utils.py | 7 ++++ vllm/config.py | 9 ----- vllm/entrypoints/llm.py | 8 ++-- .../openai/serving_classification.py | 2 +- vllm/entrypoints/openai/serving_embedding.py | 2 +- vllm/entrypoints/openai/serving_pooling.py | 2 +- vllm/entrypoints/openai/serving_score.py | 4 +- vllm/model_executor/layers/pooler.py | 37 ++++++------------ vllm/model_executor/models/config.py | 18 +++++++++ vllm/model_executor/models/jamba.py | 2 - vllm/model_executor/models/qwen2_rm.py | 3 -- vllm/pooling_params.py | 38 +++++++++++-------- 15 files changed, 89 insertions(+), 78 deletions(-) diff --git a/tests/entrypoints/llm/test_classify.py b/tests/entrypoints/llm/test_classify.py index 2dd30f8d798c..abdce8935ea5 100644 --- a/tests/entrypoints/llm/test_classify.py +++ b/tests/entrypoints/llm/test_classify.py @@ -5,11 +5,12 @@ import pytest import torch -import torch.nn.functional as F from vllm import LLM, PoolingParams from vllm.distributed import cleanup_dist_env_and_memory +from ...models.utils import softmax + MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" prompts = ["The chef prepared a delicious meal."] @@ -47,8 +48,9 @@ def test_pooling_params(llm: LLM): def get_outputs(activation): outputs = llm.classify( - prompts, pooling_params=PoolingParams(activation=activation), - use_tqdm=False) + prompts, + pooling_params=PoolingParams(activation=activation), + use_tqdm=False) return torch.tensor([x.outputs.probs for x in outputs]) default = get_outputs(activation=None) @@ -61,5 +63,5 @@ def get_outputs(activation): w_activation, wo_activation, atol=1e-2), "wo_activation should not use activation." assert torch.allclose( - F.softmax(wo_activation, dim=-1), w_activation, atol=1e-2 + softmax(wo_activation), w_activation, atol=1e-2 ), "w_activation should be close to activation(wo_activation)." 
diff --git a/tests/entrypoints/llm/test_reward.py b/tests/entrypoints/llm/test_reward.py index 62097eb942ba..6fca2f64ffbc 100644 --- a/tests/entrypoints/llm/test_reward.py +++ b/tests/entrypoints/llm/test_reward.py @@ -5,11 +5,12 @@ import pytest import torch -import torch.nn.functional as F from vllm import LLM, PoolingParams from vllm.distributed import cleanup_dist_env_and_memory +from ...models.utils import softmax + MODEL_NAME = "internlm/internlm2-1_8b-reward" prompts = ["The chef prepared a delicious meal."] @@ -47,10 +48,10 @@ def llm(): def test_pooling_params(llm: LLM): def get_outputs(softmax): - outputs = llm.reward( - prompts, pooling_params=PoolingParams(softmax=softmax), - use_tqdm=False) - return torch.tensor([x.outputs.data for x in outputs]) + outputs = llm.reward(prompts, + pooling_params=PoolingParams(softmax=softmax), + use_tqdm=False) + return torch.cat([x.outputs.data for x in outputs]) default = get_outputs(softmax=None) w_softmax = get_outputs(softmax=True) @@ -58,9 +59,8 @@ def get_outputs(softmax): assert torch.allclose(default, w_softmax, atol=1e-2), "Default should use softmax." - assert not torch.allclose( - w_softmax, wo_softmax, - atol=1e-2), "wo_softmax should not use softmax." + assert not torch.allclose(w_softmax, wo_softmax, + atol=1e-2), "wo_softmax should not use softmax." assert torch.allclose( - F.softmax(wo_softmax, dim=-1), w_softmax, atol=1e-2 - ), "w_softmax should be close to softmax(wo_softmax)." + softmax(wo_softmax), w_softmax, + atol=1e-2), "w_softmax should be close to softmax(wo_softmax)." \ No newline at end of file diff --git a/tests/entrypoints/llm/test_score.py b/tests/entrypoints/llm/test_score.py index 315c41537bd3..100e4b58f615 100644 --- a/tests/entrypoints/llm/test_score.py +++ b/tests/entrypoints/llm/test_score.py @@ -5,11 +5,12 @@ import pytest import torch -import torch.nn.functional as F from vllm import LLM, PoolingParams from vllm.distributed import cleanup_dist_env_and_memory +from ...models.utils import softmax + MODEL_NAME = "BAAI/bge-reranker-v2-m3" @@ -56,5 +57,5 @@ def get_outputs(activation): w_activation, wo_activation, atol=1e-2), "wo_activation should not use activation." assert torch.allclose( - F.sigmoid(wo_activation), w_activation, atol=1e-2 + softmax(wo_activation), w_activation, atol=1e-2 ), "w_activation should be close to activation(wo_activation)." 
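The next diff adds the shared softmax test helper these assertions now rely on. Its contract, restated as a standalone sketch: a single-logit head (a reranker-style score) is squashed with a sigmoid, while anything wider gets a softmax over the last dimension:

import torch
import torch.nn.functional as F

def softmax(data: torch.Tensor) -> torch.Tensor:
    # One logit: treat it as a binary score (sigmoid).
    if data.shape[-1] == 1:
        return F.sigmoid(data)
    # Otherwise: a probability distribution over classes.
    return F.softmax(data, dim=-1)

print(softmax(torch.tensor([[0.3]])))        # sigmoid path, ~0.574
print(softmax(torch.tensor([[0.3, -1.2]])))  # softmax path, sums to 1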
diff --git a/tests/models/utils.py b/tests/models/utils.py index 3cd0721be1b6..0a1846637d4e 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -330,6 +330,13 @@ def matryoshka_fy(tensor: torch.Tensor, dimensions: int): return tensor +def softmax(data): + if data.shape[-1] ==1: + return F.sigmoid(data) + else: + return F.softmax(data, dim=-1) + + class EmbedModelInfo(NamedTuple): name: str is_matryoshka: bool = False diff --git a/vllm/config.py b/vllm/config.py index 9d1c685fc060..daa4b0f070e5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -899,15 +899,6 @@ def _init_pooler_config(self) -> Optional["PoolerConfig"]: if getattr(pooler_config, k) is None: setattr(pooler_config, k, v) - if self.is_matryoshka: - if pooler_config.normalize is None: - pooler_config.normalize = True - elif not pooler_config.normalize: - raise ValueError( - "`normalize` must be enabled (set to True) " - "for models that are compatible with " - "Matryoshka Representation.") - return pooler_config return None diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 1746d3e78de6..22720453d931 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1110,12 +1110,13 @@ def encode( pooling_params = PoolingParams() if isinstance(pooling_params, PoolingParams): - pooling_params.merge_default_parameters(model_config.pooler_config) + pooling_params.merge_default_parameters(pooling_task, + model_config.pooler_config) pooling_params.verify(pooling_task, model_config) else: for pooling_param in pooling_params: pooling_param.merge_default_parameters( - model_config.pooler_config) + pooling_task, model_config.pooler_config) pooling_param.verify(pooling_task, model_config) if tokenization_kwargs is None: @@ -1331,7 +1332,8 @@ def _cross_encoding_score( pooling_params = PoolingParams(task="score") model_config = self.llm_engine.model_config - pooling_params.merge_default_parameters(model_config.pooler_config) + pooling_params.merge_default_parameters("score", + model_config.pooler_config) pooling_params.verify("score", model_config) tokenization_kwargs: dict[str, Any] = {} diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py index db6309fc0872..a6e0744d92fe 100644 --- a/vllm/entrypoints/openai/serving_classification.py +++ b/vllm/entrypoints/openai/serving_classification.py @@ -178,7 +178,7 @@ def _create_pooling_params( try: pooling_params.merge_default_parameters( - self.model_config.pooler_config) + "classify", self.model_config.pooler_config) pooling_params.verify("classify", self.model_config) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 0cb1401fca78..f429f884b13c 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -201,7 +201,7 @@ def _create_pooling_params( try: pooling_params.merge_default_parameters( - self.model_config.pooler_config) + "embed", self.model_config.pooler_config) pooling_params.verify("embed", self.model_config) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index c58428ee132a..e9aa9373d52f 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -141,7 +141,7 @@ async def create_pooling( try: pooling_params.merge_default_parameters( - 
self.model_config.pooler_config) + "encode", self.model_config.pooler_config) pooling_params.verify("encode", self.model_config) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index eb3e313cc3f4..cfb49d8feffb 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -87,7 +87,7 @@ async def _embedding_score( try: pooling_params.merge_default_parameters( - self.model_config.pooler_config) + "embed", self.model_config.pooler_config) pooling_params.verify("embed", self.model_config) except ValueError as e: return self.create_error_response(str(e)) @@ -248,7 +248,7 @@ async def _cross_encoding_score( try: pooling_params.merge_default_parameters( - self.model_config.pooler_config) + "score", self.model_config.pooler_config) pooling_params.verify("score", self.model_config) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index f4f05ce46afd..3887cd79aa84 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -80,7 +80,7 @@ def for_encode( ) if resolved_config.pooling_type == PoolingType.STEP: - return StepPooler.from_config(resolved_config) + return StepPooler() return SimplePooler.from_config(resolved_config) @@ -529,7 +529,7 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], class RewardPoolerHead(PoolerHead): def __init__(self) -> None: - super().__init__(activation=PoolerNormalize()) + super().__init__(activation=PoolerClassify()) def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): @@ -603,29 +603,11 @@ def forward( class StepPooler(Pooler): - @classmethod - def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "StepPooler": - assert pooler_config.pooling_type == PoolingType.STEP - - return cls( - PoolerHead.from_config(pooler_config), - step_tag_id=pooler_config.step_tag_id, - returned_token_ids=pooler_config.returned_token_ids, - ) - - def __init__( - self, - head: PoolerHead, - *, - step_tag_id: Optional[int] = None, - returned_token_ids: Optional[list[int]] = None, - ) -> None: + def __init__(self, ) -> None: super().__init__() self.pooling = AllPool() - self.head = head - self.step_tag_id = step_tag_id - self.returned_token_ids = returned_token_ids + self.head = RewardPoolerHead() def extract_states( self, @@ -636,10 +618,15 @@ def extract_states( prompt_token_ids = get_prompt_token_ids(pooling_metadata) pooled_data = list[torch.Tensor]() - returned_token_ids = self.returned_token_ids - step_tag_id = self.step_tag_id - for data, token_id in zip(pooled_data_lst, prompt_token_ids): + pooling_params = get_pooling_params(pooling_metadata) + + for data, token_id, pooling_param in zip(pooled_data_lst, + prompt_token_ids, + pooling_params): + step_tag_id = pooling_param.step_tag_id + returned_token_ids = pooling_param.returned_token_ids + if returned_token_ids is not None and len(returned_token_ids) > 0: data = data[:, returned_token_ids] diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 6f50b1753098..52f2c45abe1e 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -44,6 +44,14 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: } +class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig): + 
@staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + pooler_config = vllm_config.model_config.pooler_config + if pooler_config.activation is None: + pooler_config.activation = False + + class JinaRobertaModelConfig(VerifyAndUpdateConfig): @staticmethod @@ -155,6 +163,15 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: vllm_config.recalculate_max_model_len(max_model_len) +class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig): + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + pooler_config = vllm_config.model_config.pooler_config + + if pooler_config.step_tag_id is None: + pooler_config.step_tag_id = 151651 + + class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig): @staticmethod @@ -309,6 +326,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: "GteModel": SnowflakeGteNewModelConfig, "GteNewModel": GteNewModelConfig, "NomicBertModel": NomicBertModelConfig, + "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig, "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig, "XLMRobertaModel": JinaRobertaModelConfig, "JinaVLForRanking": JinaVLForSequenceClassificationConfig, diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 263f4c8379cf..ab21b7ce2c5f 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -593,7 +593,5 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): pooler_config, classifier=self.score, default_pooling_type=PoolingType.LAST, - default_normalize=False, - default_softmax=False, ), }) diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index f12e9a041a94..9b6b70c75c34 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -117,8 +117,5 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): Pooler.for_encode( pooler_config, default_pooling_type=PoolingType.STEP, - default_normalize=False, - default_softmax=True, - default_step_tag_id=151651, ) }) diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index ebc5beef1446..e30108d8d60f 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from copy import deepcopy from typing import TYPE_CHECKING, Optional import msgspec @@ -65,18 +66,12 @@ def legal_parameters(self): def clone(self) -> "PoolingParams": """Returns a deep copy of the PoolingParams instance.""" - return PoolingParams( - dimensions=self.dimensions, - normalize=self.normalize, - activation=self.activation, - softmax=self.softmax, - task=self.task, - requires_token_ids=self.requires_token_ids, - ) + return deepcopy(self) - def verify(self, - task: PoolingTask, - model_config: Optional["ModelConfig"] = None) -> None: + def verify_task( + self, + task: PoolingTask, + ): if self.task is None: self.task = task elif self.task != task: @@ -87,7 +82,15 @@ def verify(self, # which is not available in model config. 
So, it's not included # in this method + def verify(self, + task: PoolingTask, + model_config: Optional["ModelConfig"] = None) -> None: + self.verify_task(task) + if self.task == "embed": + if self.normalize is None: + self.normalize = True + if self.dimensions is not None and model_config is not None: if not model_config.is_matryoshka: raise ValueError( @@ -106,8 +109,7 @@ def verify(self, f'lead to poor results.') elif self.dimensions < 1: raise ValueError("Dimensions must be greater than 0") - if self.normalize is None: - self.normalize = True + elif self.task in ["classify", "score"]: if self.activation is None: self.activation = True @@ -135,11 +137,17 @@ def verify(self, f"{invalid_parameters} parameters") def merge_default_parameters( - self, pooler_config: Optional["PoolerConfig"]) -> None: + self, task: PoolingTask, + pooler_config: Optional["PoolerConfig"]) -> None: if pooler_config is None: return - for k in self.all_parameters: + self.verify_task(task) + + assert self.task is not None, "task must be set" + legal_parameters = self.legal_parameters[self.task] + + for k in legal_parameters: if getattr(pooler_config, k, None) is None: continue From cfa1a3d33ce48ab06f99053e02b2e6017f2af33f Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 1 Aug 2025 17:26:19 +0800 Subject: [PATCH 19/30] + JambaForSequenceClassificationConfig Signed-off-by: wang.yuqi --- vllm/model_executor/models/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 52f2c45abe1e..22c9cbd0c2ce 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -330,5 +330,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig, "XLMRobertaModel": JinaRobertaModelConfig, "JinaVLForRanking": JinaVLForSequenceClassificationConfig, + "JambaForSequenceClassification": JambaForSequenceClassificationConfig, "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig, } From d0488e75a1bc3f8729154369fd6e6c7b3bc1d4e6 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 1 Aug 2025 17:52:12 +0800 Subject: [PATCH 20/30] fix Signed-off-by: wang.yuqi --- tests/models/utils.py | 2 +- vllm/model_executor/models/config.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/models/utils.py b/tests/models/utils.py index 0a1846637d4e..bda7ea3e3ad5 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -331,7 +331,7 @@ def matryoshka_fy(tensor: torch.Tensor, dimensions: int): def softmax(data): - if data.shape[-1] ==1: + if data.shape[-1] == 1: return F.sigmoid(data) else: return F.softmax(data, dim=-1) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 22c9cbd0c2ce..7950e8ab0fca 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -45,6 +45,7 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig): + @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: pooler_config = vllm_config.model_config.pooler_config @@ -164,6 +165,7 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig): + @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: pooler_config = vllm_config.model_config.pooler_config From 
5274e2f3900febcd73fc5797f67e6ba9cdeb9002 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 1 Aug 2025 19:11:03 +0800 Subject: [PATCH 21/30] fix Signed-off-by: wang.yuqi --- .../language/pooling/test_override_pooler_config.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/models/language/pooling/test_override_pooler_config.py b/tests/models/language/pooling/test_override_pooler_config.py index f7af56e2cac3..2b1c74652e76 100644 --- a/tests/models/language/pooling/test_override_pooler_config.py +++ b/tests/models/language/pooling/test_override_pooler_config.py @@ -4,6 +4,7 @@ import torch import torch.nn.functional as F +from tests.models.utils import softmax from vllm.config import PoolerConfig @@ -45,7 +46,7 @@ def test_classify_models_using_activation( assert not torch.allclose( wo_activation, w_activation, atol=1e-2), "override_pooler_config is not working" - assert torch.allclose(F.softmax(wo_activation, dim=-1), w_activation, + assert torch.allclose(softmax(wo_activation), w_activation, 1e-3 if dtype == "float" else 1e-2) @@ -89,14 +90,14 @@ def test_embed_models_using_normalize( @pytest.mark.parametrize( "model", [ - "Qwen/Qwen2.5-Math-PRM-7B", + "internlm/internlm2-1_8b-reward", ], ) @pytest.mark.parametrize("dtype", ["half"]) def test_reward_models_using_softmax( hf_runner, vllm_runner, - math_step_prompts, + example_prompts, model: str, dtype: str, ) -> None: @@ -106,14 +107,14 @@ def test_reward_models_using_softmax( max_model_len=1024, dtype=dtype, override_pooler_config=PoolerConfig(softmax=False)) as vllm_model: - wo_softmax = vllm_model.encode(math_step_prompts) + wo_softmax = vllm_model.encode(example_prompts) with vllm_runner( model, max_model_len=1024, dtype=dtype, override_pooler_config=PoolerConfig(softmax=True)) as vllm_model: - w_softmax = vllm_model.encode(math_step_prompts) + w_softmax = vllm_model.encode(example_prompts) for wo, w in zip(wo_softmax, w_softmax): wo = torch.tensor(wo) @@ -122,5 +123,5 @@ def test_reward_models_using_softmax( assert not torch.allclose( wo, w, atol=1e-2), "override_pooler_config softmax is not working" assert torch.allclose( - F.softmax(wo, dim=-1), w, + softmax(wo), w, atol=1e-2), "w_softmax should be close to softmax(wo_softmax)." 
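The next patch (22/30) removes the explicit merge_default_parameters calls from the entrypoints: defaults from the model's pooler config are folded into PoolingParams.verify() itself, so callers make a single call. A minimal sketch of the resulting pattern (assuming the refactored PoolingParams; with no model config passed, only the task defaults are applied):

from vllm import PoolingParams

params = PoolingParams()   # user sets no flags
params.verify("embed")     # resolves the task and fills task defaults

assert params.task == "embed"
assert params.normalize is True  # the embed default, filled in by verify()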
From 2ffa8342f6866cc1d422007eb19ebe834140d2b8 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 4 Aug 2025 09:46:38 +0800 Subject: [PATCH 22/30] - merge_default_parameters Signed-off-by: wang.yuqi --- tests/entrypoints/llm/test_score.py | 9 +++ vllm/entrypoints/llm.py | 12 +--- .../openai/serving_classification.py | 4 +- vllm/entrypoints/openai/serving_embedding.py | 2 - vllm/entrypoints/openai/serving_pooling.py | 2 - vllm/entrypoints/openai/serving_score.py | 4 -- vllm/model_executor/layers/pooler.py | 21 ------ vllm/model_executor/models/config.py | 9 +++ vllm/pooling_params.py | 66 ++++++++++--------- 9 files changed, 57 insertions(+), 72 deletions(-) diff --git a/tests/entrypoints/llm/test_score.py b/tests/entrypoints/llm/test_score.py index 100e4b58f615..ef4186c00d9f 100644 --- a/tests/entrypoints/llm/test_score.py +++ b/tests/entrypoints/llm/test_score.py @@ -14,6 +14,15 @@ MODEL_NAME = "BAAI/bge-reranker-v2-m3" +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +@pytest.mark.skip_v1 @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 22720453d931..8f504b4af03a 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1110,14 +1110,10 @@ def encode( pooling_params = PoolingParams() if isinstance(pooling_params, PoolingParams): - pooling_params.merge_default_parameters(pooling_task, - model_config.pooler_config) - pooling_params.verify(pooling_task, model_config) + pooling_params.verify(pooling_task, model_config=model_config) else: for pooling_param in pooling_params: - pooling_param.merge_default_parameters( - pooling_task, model_config.pooler_config) - pooling_param.verify(pooling_task, model_config) + pooling_param.verify(pooling_task, model_config=model_config) if tokenization_kwargs is None: tokenization_kwargs = dict[str, Any]() @@ -1332,9 +1328,7 @@ def _cross_encoding_score( pooling_params = PoolingParams(task="score") model_config = self.llm_engine.model_config - pooling_params.merge_default_parameters("score", - model_config.pooler_config) - pooling_params.verify("score", model_config) + pooling_params.verify("score", model_config=model_config) tokenization_kwargs: dict[str, Any] = {} diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py index a6e0744d92fe..84b5e29008e6 100644 --- a/vllm/entrypoints/openai/serving_classification.py +++ b/vllm/entrypoints/openai/serving_classification.py @@ -177,9 +177,7 @@ def _create_pooling_params( return pooling_params try: - pooling_params.merge_default_parameters( - "classify", self.model_config.pooler_config) - pooling_params.verify("classify", self.model_config) + pooling_params.verify("classify", model_config=self.model_config) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index f429f884b13c..84ba00873103 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -200,8 +200,6 @@ def _create_pooling_params( return pooling_params try: - pooling_params.merge_default_parameters( - "embed", self.model_config.pooler_config) pooling_params.verify("embed", self.model_config) except ValueError as e: return 
self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index e9aa9373d52f..38745d001ade 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -140,8 +140,6 @@ async def create_pooling( pooling_params = request.to_pooling_params() try: - pooling_params.merge_default_parameters( - "encode", self.model_config.pooler_config) pooling_params.verify("encode", self.model_config) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index cfb49d8feffb..4da2094147ce 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -86,8 +86,6 @@ async def _embedding_score( pooling_params = request.to_pooling_params() try: - pooling_params.merge_default_parameters( - "embed", self.model_config.pooler_config) pooling_params.verify("embed", self.model_config) except ValueError as e: return self.create_error_response(str(e)) @@ -247,8 +245,6 @@ async def _cross_encoding_score( pooling_params = request.to_pooling_params() try: - pooling_params.merge_default_parameters( - "score", self.model_config.pooler_config) pooling_params.verify("score", self.model_config) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 3887cd79aa84..d40dffe29539 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -467,13 +467,6 @@ def __init__(self, activation: PoolerActivation) -> None: def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): - # Using float32 in PoolerHead - if isinstance(pooled_data, list): - for i in range(len(pooled_data)): - pooled_data[i] = pooled_data[i].to(torch.float32) - else: - pooled_data = pooled_data.to(torch.float32) - return self.activation(pooled_data) @@ -485,13 +478,6 @@ def __init__(self) -> None: def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): - # Using float32 in PoolerHead - if isinstance(pooled_data, list): - for i in range(len(pooled_data)): - pooled_data[i] = pooled_data[i].to(torch.float32) - else: - pooled_data = pooled_data.to(torch.float32) - pooling_params = get_pooling_params(pooling_metadata) # for matryoshka representation @@ -533,13 +519,6 @@ def __init__(self) -> None: def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): - # Using float32 in PoolerHead - if isinstance(pooled_data, list): - for i in range(len(pooled_data)): - pooled_data[i] = pooled_data[i].to(torch.float32) - else: - pooled_data = pooled_data.to(torch.float32) - pooling_params = get_pooling_params(pooling_metadata) # for softmax diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 18eda22bcf12..477ee0fd4344 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -173,6 +173,15 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: if pooler_config.step_tag_id is None: pooler_config.step_tag_id = 151651 +class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + pooler_config = vllm_config.model_config.pooler_config + + if pooler_config.softmax 
is None: + pooler_config.softmax = False + class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig): diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index e30108d8d60f..7077f68353fc 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -10,7 +10,7 @@ from vllm.tasks import PoolingTask if TYPE_CHECKING: - from vllm.config import ModelConfig, PoolerConfig + from vllm.config import ModelConfig class PoolingParams( @@ -56,7 +56,7 @@ def all_parameters(self) -> list[str]: ] @property - def legal_parameters(self): + def valid_parameters(self): return { "embed": ["dimensions", "normalize"], "classify": ["activation"], @@ -68,10 +68,10 @@ def clone(self) -> "PoolingParams": """Returns a deep copy of the PoolingParams instance.""" return deepcopy(self) - def verify_task( - self, - task: PoolingTask, - ): + def verify(self, + task: PoolingTask, + model_config: Optional["ModelConfig"] = None) -> None: + if self.task is None: self.task = task elif self.task != task: @@ -82,11 +82,32 @@ def verify_task( # which is not available in model config. So, it's not included # in this method - def verify(self, - task: PoolingTask, - model_config: Optional["ModelConfig"] = None) -> None: - self.verify_task(task) + self._merge_default_parameters(model_config) + self._set_default_parameters(model_config) + self._verify_valid_parameters() + def _merge_default_parameters(self, + model_config: Optional["ModelConfig"] = None + ) -> None: + + if model_config is None: + return + + pooler_config = model_config.pooler_config + if pooler_config is None: + return + + assert self.task is not None, "task must be set" + valid_parameters = self.valid_parameters[self.task] + + for k in valid_parameters: + if getattr(pooler_config, k, None) is None: + continue + + if getattr(self, k, None) is None: + setattr(self, k, getattr(pooler_config, k)) + + def _set_default_parameters(self, model_config: Optional["ModelConfig"]): if self.task == "embed": if self.normalize is None: self.normalize = True @@ -120,11 +141,12 @@ def verify(self, else: raise ValueError(f"Unknown pooling task: {self.task}") + def _verify_valid_parameters(self): assert self.task is not None, "task must be set" - legal_parameters = self.legal_parameters[self.task] + valid_parameters = self.valid_parameters[self.task] invalid_parameters = [] for k in self.all_parameters: - if k in legal_parameters: + if k in valid_parameters: continue if getattr(self, k, None) is not None: @@ -132,28 +154,10 @@ def verify(self, if invalid_parameters: raise ValueError( - f"Task {self.task} only supports {legal_parameters} " + f"Task {self.task} only supports {valid_parameters} " f"parameters, does not support " f"{invalid_parameters} parameters") - def merge_default_parameters( - self, task: PoolingTask, - pooler_config: Optional["PoolerConfig"]) -> None: - if pooler_config is None: - return - - self.verify_task(task) - - assert self.task is not None, "task must be set" - legal_parameters = self.legal_parameters[self.task] - - for k in legal_parameters: - if getattr(pooler_config, k, None) is None: - continue - - if getattr(self, k, None) is None: - setattr(self, k, getattr(pooler_config, k)) - def __repr__(self) -> str: return (f"PoolingParams(" f"task={self.task}, " From 9e692228f6a8406f8c31b6e38ea4ddbe6e00d348 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 4 Aug 2025 09:48:53 +0800 Subject: [PATCH 23/30] fix Signed-off-by: wang.yuqi --- tests/entrypoints/llm/test_reward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/tests/entrypoints/llm/test_reward.py b/tests/entrypoints/llm/test_reward.py index 6fca2f64ffbc..361e2d0e1047 100644 --- a/tests/entrypoints/llm/test_reward.py +++ b/tests/entrypoints/llm/test_reward.py @@ -63,4 +63,4 @@ def get_outputs(softmax): atol=1e-2), "wo_softmax should not use softmax." assert torch.allclose( softmax(wo_softmax), w_softmax, - atol=1e-2), "w_softmax should be close to softmax(wo_softmax)." \ No newline at end of file + atol=1e-2), "w_softmax should be close to softmax(wo_softmax)." From bd83adab2bda573d9cfa4faf6d5e2e4c35f39cfb Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 4 Aug 2025 09:51:50 +0800 Subject: [PATCH 24/30] fix Signed-off-by: wang.yuqi --- vllm/entrypoints/llm.py | 6 +++--- vllm/entrypoints/openai/serving_classification.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 8f504b4af03a..ca24b0c32b73 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1110,10 +1110,10 @@ def encode( pooling_params = PoolingParams() if isinstance(pooling_params, PoolingParams): - pooling_params.verify(pooling_task, model_config=model_config) + pooling_params.verify(pooling_task, model_config) else: for pooling_param in pooling_params: - pooling_param.verify(pooling_task, model_config=model_config) + pooling_param.verify(pooling_task, model_config) if tokenization_kwargs is None: tokenization_kwargs = dict[str, Any]() @@ -1328,7 +1328,7 @@ def _cross_encoding_score( pooling_params = PoolingParams(task="score") model_config = self.llm_engine.model_config - pooling_params.verify("score", model_config=model_config) + pooling_params.verify("score", model_config) tokenization_kwargs: dict[str, Any] = {} diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py index 84b5e29008e6..377f7f684717 100644 --- a/vllm/entrypoints/openai/serving_classification.py +++ b/vllm/entrypoints/openai/serving_classification.py @@ -177,7 +177,7 @@ def _create_pooling_params( return pooling_params try: - pooling_params.verify("classify", model_config=self.model_config) + pooling_params.verify("classify", self.model_config) except ValueError as e: return self.create_error_response(str(e)) From dab5b55ff01a3bb483fe785e1a77961435bf2897 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 4 Aug 2025 09:53:11 +0800 Subject: [PATCH 25/30] fix Signed-off-by: wang.yuqi --- vllm/model_executor/models/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 477ee0fd4344..9ba9d4c1d4a0 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -338,6 +338,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: "GteNewModel": GteNewModelConfig, "NomicBertModel": NomicBertModelConfig, "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig, + "Qwen2ForRewardModel": Qwen2ForRewardModelConfig, "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig, "XLMRobertaModel": JinaRobertaModelConfig, "JinaVLForRanking": JinaVLForSequenceClassificationConfig, From 9129093d7f4a88c8c8141bfb717f906abcf7b2c5 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 4 Aug 2025 10:01:13 +0800 Subject: [PATCH 26/30] fix Signed-off-by: wang.yuqi --- vllm/model_executor/models/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 
9ba9d4c1d4a0..6f09be7a5941 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -173,6 +173,7 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: if pooler_config.step_tag_id is None: pooler_config.step_tag_id = 151651 + class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig): @staticmethod From 0973e6b3277b0a1562d5dff00d1b47023adef9a3 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 4 Aug 2025 10:20:53 +0800 Subject: [PATCH 27/30] fix Signed-off-by: wang.yuqi --- tests/models/language/pooling/test_reward.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/language/pooling/test_reward.py b/tests/models/language/pooling/test_reward.py index a5f7dca76d82..7add1d975c63 100644 --- a/tests/models/language/pooling/test_reward.py +++ b/tests/models/language/pooling/test_reward.py @@ -103,7 +103,7 @@ def test_prm_models( # check logits difference for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): - hf_output = torch.tensor(hf_output) - vllm_output = torch.tensor(vllm_output) + hf_output = torch.tensor(hf_output).float() + vllm_output = torch.tensor(vllm_output).float() assert torch.allclose(hf_output, vllm_output, 1.5e-2) From f0d6190fc6561d4cb502f9aea91fea5b11731d19 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 4 Aug 2025 13:16:58 +0800 Subject: [PATCH 28/30] using tomaarsen/Qwen3-Reranker-0.6B-seq-cls Signed-off-by: wang.yuqi --- tests/entrypoints/llm/test_score.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/entrypoints/llm/test_score.py b/tests/entrypoints/llm/test_score.py index ef4186c00d9f..dd4eae0ccc06 100644 --- a/tests/entrypoints/llm/test_score.py +++ b/tests/entrypoints/llm/test_score.py @@ -11,7 +11,7 @@ from ...models.utils import softmax -MODEL_NAME = "BAAI/bge-reranker-v2-m3" +MODEL_NAME = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls" @pytest.fixture(autouse=True) @@ -22,7 +22,6 @@ def v1(run_with_both_engines): pass -@pytest.mark.skip_v1 @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to From e55a3424219c9f6c1b025fb92d87a0202b06cee5 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 4 Aug 2025 15:21:18 +0800 Subject: [PATCH 29/30] fix Signed-off-by: wang.yuqi --- vllm/model_executor/layers/pooler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index d40dffe29539..edd04fc5ecc7 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -555,7 +555,7 @@ def from_config( elif pooler_config.task == "encode": head = RewardPoolerHead() else: - raise NotImplementedError(f"Unknown task: {pooler_config.task}") + head = PoolerHead(PoolerIdentity()) return cls(pooling, head) def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None: From b3624e11b047a9fa23268feb1e61fcca3c86f3b8 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Mon, 4 Aug 2025 22:06:27 +0800 Subject: [PATCH 30/30] ci bug ? 
Signed-off-by: wang.yuqi --- vllm/model_executor/layers/pooler.py | 35 +++++++++++++-------------- vllm/model_executor/models/jina_vl.py | 5 +--- 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index edd04fc5ecc7..0f2e58eb9b5d 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -110,14 +110,12 @@ def for_classify( pooler_config=pooler_config, pooling_type=default_pooling_type, ) - base_pooler = SimplePooler.from_config(resolved_config) - if classifier is None: - return base_pooler + + pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type) return ClassifierPooler( - pooling=base_pooler.pooling, + pooling=pooling, classifier=classifier, - act_fn=base_pooler.head.activation, ) @abstractmethod @@ -555,7 +553,7 @@ def from_config( elif pooler_config.task == "encode": head = RewardPoolerHead() else: - head = PoolerHead(PoolerIdentity()) + raise NotImplementedError(f"Unknown task: {pooler_config.task}") return cls(pooling, head) def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None: @@ -651,14 +649,14 @@ def act_fn_for_cross_encoder(config: ModelConfig): def __init__( self, pooling: PoolingFn, - classifier: ClassifierFn, - act_fn: PoolerActivation, + classifier: Optional[ClassifierFn], + act_fn: Optional[PoolerActivation] = None, ) -> None: super().__init__() self.pooling = pooling self.classifier = classifier - self.act_fn = act_fn + self.act_fn = act_fn or PoolerClassify() def get_supported_tasks(self) -> Set[PoolingTask]: return {"classify", "score"} @@ -670,23 +668,24 @@ def forward( ) -> PoolerOutput: pooled_data = self.pooling(hidden_states, pooling_metadata) - # apply classifier once on the full batch if possible - if isinstance(pooled_data, torch.Tensor): - pooled_output = self.classifier(pooled_data) - elif len({data.shape for data in pooled_data}) <= 1: - pooled_output = self.classifier(torch.stack(pooled_data)) - else: - pooled_output = [self.classifier(data) for data in pooled_data] + if self.classifier is not None: + # apply classifier once on the full batch if possible + if isinstance(pooled_data, torch.Tensor): + pooled_data = self.classifier(pooled_data) + elif len({data.shape for data in pooled_data}) <= 1: + pooled_data = self.classifier(torch.stack(pooled_data)) + else: + pooled_data = [self.classifier(data) for data in pooled_data] pooling_params = get_pooling_params(pooling_metadata) flags = [p.activation for p in pooling_params] if len(set(flags)) == 1: - scores = self.act_fn(pooled_output) if flags[0] else pooled_output + scores = self.act_fn(pooled_data) if flags[0] else pooled_data else: scores = [ self.act_fn(vecs) if f else vecs - for vecs, f in zip(pooled_output, flags) + for vecs, f in zip(pooled_data, flags) ] return build_output(scores) diff --git a/vllm/model_executor/models/jina_vl.py b/vllm/model_executor/models/jina_vl.py index 0c4284f7daaa..8c64f636c6a0 100644 --- a/vllm/model_executor/models/jina_vl.py +++ b/vllm/model_executor/models/jina_vl.py @@ -90,15 +90,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "qwen2_vl")) config = vllm_config.model_config.hf_config pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None # logit bias for sigmoid normalization self.LOGIT_BIAS = 2.65 self.score = JinaVLScorer(config) - - pooler_config = vllm_config.model_config.pooler_config - assert pooler_config is not None - 
self.pooler = DispatchPooler({ "encode": Pooler.for_encode(pooler_config),