diff --git a/tests/entrypoints/llm/test_classify.py b/tests/entrypoints/llm/test_classify.py new file mode 100644 index 000000000000..abdce8935ea5 --- /dev/null +++ b/tests/entrypoints/llm/test_classify.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import weakref + +import pytest +import torch + +from vllm import LLM, PoolingParams +from vllm.distributed import cleanup_dist_env_and_memory + +from ...models.utils import softmax + +MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" + +prompts = ["The chef prepared a delicious meal."] + + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + seed=0) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + +@pytest.mark.skip_global_cleanup +def test_pooling_params(llm: LLM): + + def get_outputs(activation): + outputs = llm.classify( + prompts, + pooling_params=PoolingParams(activation=activation), + use_tqdm=False) + return torch.tensor([x.outputs.probs for x in outputs]) + + default = get_outputs(activation=None) + w_activation = get_outputs(activation=True) + wo_activation = get_outputs(activation=False) + + assert torch.allclose(default, w_activation, + atol=1e-2), "Default should use activation." + assert not torch.allclose( + w_activation, wo_activation, + atol=1e-2), "wo_activation should not use activation." + assert torch.allclose( + softmax(wo_activation), w_activation, atol=1e-2 + ), "w_activation should be close to activation(wo_activation)." diff --git a/tests/entrypoints/llm/test_embedding.py b/tests/entrypoints/llm/test_embedding.py new file mode 100644 index 000000000000..ba20d7b9548e --- /dev/null +++ b/tests/entrypoints/llm/test_embedding.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import weakref + +import pytest +import torch +import torch.nn.functional as F + +from vllm import LLM, PoolingParams +from vllm.distributed import cleanup_dist_env_and_memory + +MODEL_NAME = "intfloat/multilingual-e5-small" + +prompts = ["The chef prepared a delicious meal."] + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + seed=0) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + +@pytest.mark.skip_global_cleanup +def test_pooling_params(llm: LLM): + + def get_outputs(normalize): + outputs = llm.embed(prompts, + pooling_params=PoolingParams(normalize=normalize), + use_tqdm=False) + return torch.tensor([x.outputs.embedding for x in outputs]) + + default = get_outputs(normalize=None) + w_normal = get_outputs(normalize=True) + wo_normal = get_outputs(normalize=False) + + assert torch.allclose(default, w_normal, + atol=1e-2), "Default should use normal." 
+ assert not torch.allclose(w_normal, wo_normal, + atol=1e-2), "wo_normal should not use normal." + assert torch.allclose( + w_normal, F.normalize(wo_normal, p=2, dim=-1), + atol=1e-2), "w_normal should be close to normal(wo_normal)." diff --git a/tests/entrypoints/llm/test_reward.py b/tests/entrypoints/llm/test_reward.py new file mode 100644 index 000000000000..361e2d0e1047 --- /dev/null +++ b/tests/entrypoints/llm/test_reward.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import weakref + +import pytest +import torch + +from vllm import LLM, PoolingParams +from vllm.distributed import cleanup_dist_env_and_memory + +from ...models.utils import softmax + +MODEL_NAME = "internlm/internlm2-1_8b-reward" + +prompts = ["The chef prepared a delicious meal."] + + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + trust_remote_code=True, + seed=0) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + +@pytest.mark.skip_global_cleanup +def test_pooling_params(llm: LLM): + + def get_outputs(softmax): + outputs = llm.reward(prompts, + pooling_params=PoolingParams(softmax=softmax), + use_tqdm=False) + return torch.cat([x.outputs.data for x in outputs]) + + default = get_outputs(softmax=None) + w_softmax = get_outputs(softmax=True) + wo_softmax = get_outputs(softmax=False) + + assert torch.allclose(default, w_softmax, + atol=1e-2), "Default should use softmax." + assert not torch.allclose(w_softmax, wo_softmax, + atol=1e-2), "wo_softmax should not use softmax." + assert torch.allclose( + softmax(wo_softmax), w_softmax, + atol=1e-2), "w_softmax should be close to softmax(wo_softmax)." diff --git a/tests/entrypoints/llm/test_score.py b/tests/entrypoints/llm/test_score.py new file mode 100644 index 000000000000..dd4eae0ccc06 --- /dev/null +++ b/tests/entrypoints/llm/test_score.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import weakref + +import pytest +import torch + +from vllm import LLM, PoolingParams +from vllm.distributed import cleanup_dist_env_and_memory + +from ...models.utils import softmax + +MODEL_NAME = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls" + + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + seed=0) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + +@pytest.mark.skip_global_cleanup +def test_pooling_params(llm: LLM): + + def get_outputs(activation): + text_1 = "What is the capital of France?" 
+ text_2 = "The capital of France is Paris." + + outputs = llm.score( + text_1, + text_2, + pooling_params=PoolingParams(activation=activation), + use_tqdm=False) + return torch.tensor([x.outputs.score for x in outputs]) + + default = get_outputs(activation=None) + w_activation = get_outputs(activation=True) + wo_activation = get_outputs(activation=False) + + assert torch.allclose(default, w_activation, + atol=1e-2), "Default should use activation." + assert not torch.allclose( + w_activation, wo_activation, + atol=1e-2), "wo_activation should not use activation." + assert torch.allclose( + softmax(wo_activation), w_activation, atol=1e-2 + ), "w_activation should be close to activation(wo_activation)." diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/openai/test_classification.py index b2472658ca81..bcf127307f73 100644 --- a/tests/entrypoints/openai/test_classification.py +++ b/tests/entrypoints/openai/test_classification.py @@ -3,6 +3,8 @@ import pytest import requests +import torch +import torch.nn.functional as F from vllm.entrypoints.openai.protocol import ClassificationResponse @@ -181,3 +183,32 @@ async def test_invocations(server: RemoteOpenAIServer): assert classification_data.keys() == invocation_data.keys() assert classification_data["probs"] == pytest.approx( invocation_data["probs"], rel=0.01) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_activation(server: RemoteOpenAIServer, model_name: str): + input_text = ["This product was excellent and exceeded my expectations"] + + async def get_outputs(activation): + response = requests.post(server.url_for("classify"), + json={ + "model": model_name, + "input": input_text, + "activation": activation + }) + outputs = response.json() + return torch.tensor([x['probs'] for x in outputs["data"]]) + + default = await get_outputs(activation=None) + w_activation = await get_outputs(activation=True) + wo_activation = await get_outputs(activation=False) + + assert torch.allclose(default, w_activation, + atol=1e-2), "Default should use activation." + assert not torch.allclose( + w_activation, wo_activation, + atol=1e-2), "wo_activation should not use activation." + assert torch.allclose( + F.softmax(wo_activation, dim=-1), w_activation, atol=1e-2 + ), "w_activation should be close to activation(wo_activation)." 
diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index a7203befcc40..cf2442a56938 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -8,6 +8,8 @@ import pytest import pytest_asyncio import requests +import torch +import torch.nn.functional as F from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.transformers_utils.tokenizer import get_tokenizer @@ -369,3 +371,35 @@ async def test_invocations_conversation(server: RemoteOpenAIServer): embeddings_1_lst=[invocation_data["embedding"]], name_0="chat", name_1="invocation") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_normalize(server: RemoteOpenAIServer, model_name: str): + input_text = ["The chef prepared a delicious meal."] + + async def get_outputs(normalize): + request_args = { + "model": MODEL_NAME, + "input": input_text, + "encoding_format": "float", + "normalize": normalize + } + + response = requests.post(server.url_for("v1/embeddings"), + json=request_args) + outputs = response.json() + + return torch.tensor([x['embedding'] for x in outputs["data"]]) + + default = await get_outputs(normalize=None) + w_normal = await get_outputs(normalize=True) + wo_normal = await get_outputs(normalize=False) + + assert torch.allclose(default, w_normal, + atol=1e-2), "Default should use normal." + assert not torch.allclose(w_normal, wo_normal, + atol=1e-2), "wo_normal should not use normal." + assert torch.allclose( + w_normal, F.normalize(wo_normal, p=2, dim=-1), + atol=1e-2), "w_normal should be close to normal(wo_normal)." diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py index 4da97fe13691..f121693e329f 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/openai/test_rerank.py @@ -3,6 +3,8 @@ import pytest import requests +import torch +import torch.nn.functional as F from vllm.entrypoints.openai.protocol import RerankResponse @@ -125,3 +127,39 @@ def test_invocations(server: RemoteOpenAIServer): assert rerank_result.keys() == invocations_result.keys() assert rerank_result["relevance_score"] == pytest.approx( invocations_result["relevance_score"], rel=0.01) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_activation(server: RemoteOpenAIServer, model_name: str): + + async def get_outputs(activation): + query = "What is the capital of France?" + documents = [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris." + ] + + response = requests.post(server.url_for("rerank"), + json={ + "model": model_name, + "query": query, + "documents": documents, + "activation": activation + }) + outputs = response.json() + + return torch.tensor([x['relevance_score'] for x in outputs["results"]]) + + default = await get_outputs(activation=None) + w_activation = await get_outputs(activation=True) + wo_activation = await get_outputs(activation=False) + + assert torch.allclose(default, w_activation, + atol=1e-2), "Default should use activation." + assert not torch.allclose( + w_activation, wo_activation, + atol=1e-2), "wo_activation should not use activation." + assert torch.allclose( + F.sigmoid(wo_activation), w_activation, atol=1e-2 + ), "w_activation should be close to activation(wo_activation)." 
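Similarly, a minimal sketch of the new `normalize` extra parameter on the embeddings route, matching the payload used in `test_normalize` above (base URL and model name are placeholders, not part of this change):

```python
# Sketch only: request raw (unnormalized) embeddings from a running server;
# base URL and model name are placeholders.
import requests

response = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "intfloat/multilingual-e5-small",
        "input": ["The chef prepared a delicious meal."],
        "encoding_format": "float",
        # None/True -> L2-normalized vectors (the default); False -> raw pooled output.
        "normalize": False,
    },
)
print(response.json()["data"][0]["embedding"][:4])
```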
diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py index 187542b7bafc..1a5df1d2dbd2 100644 --- a/tests/entrypoints/openai/test_score.py +++ b/tests/entrypoints/openai/test_score.py @@ -4,6 +4,7 @@ import pytest import requests +import torch import torch.nn.functional as F from torch import tensor @@ -220,3 +221,43 @@ def test_invocations(self, server: RemoteOpenAIServer, model: dict[str, assert score_data.keys() == invocation_data.keys() assert score_data["score"] == pytest.approx( invocation_data["score"], rel=0.01) + + def test_activation(self, server: RemoteOpenAIServer, model: dict[str, + Any]): + + def get_outputs(activation): + text_1 = "What is the capital of France?" + text_2 = "The capital of France is Paris." + response = requests.post(server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + "activation": activation + }) + if response.status_code != 200: + return response + + outputs = response.json() + return torch.tensor([x['score'] for x in outputs["data"]]) + + if model["is_cross_encoder"]: + + default = get_outputs(activation=None) + w_activation = get_outputs(activation=True) + wo_activation = get_outputs(activation=False) + + assert torch.allclose(default, w_activation, + atol=1e-2), "Default should use activation." + assert not torch.allclose( + w_activation, wo_activation, + atol=1e-2), "wo_activation should not use activation." + assert torch.allclose( + F.sigmoid(wo_activation), w_activation, atol=1e-2 + ), "w_activation should be close to activation(wo_activation)." + else: + get_outputs(activation=None) + + # The activation parameter only works for the is_cross_encoder model + response = get_outputs(activation=True) + assert response.status_code == 400 diff --git a/tests/models/language/pooling/test_override_pooler_config.py b/tests/models/language/pooling/test_override_pooler_config.py new file mode 100644 index 000000000000..2b1c74652e76 --- /dev/null +++ b/tests/models/language/pooling/test_override_pooler_config.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +import torch.nn.functional as F + +from tests.models.utils import softmax +from vllm.config import PoolerConfig + + +@pytest.mark.parametrize( + "model", + [ + "jason9693/Qwen2.5-1.5B-apeach", + "papluca/xlm-roberta-base-language-detection" + ], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_classify_models_using_activation( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + + with vllm_runner(model, + max_model_len=512, + dtype=dtype, + override_pooler_config=PoolerConfig( + activation=False)) as vllm_model: + wo_activation_out = vllm_model.classify(example_prompts) + + with vllm_runner(model, + max_model_len=512, + dtype=dtype, + override_pooler_config=PoolerConfig( + activation=True)) as vllm_model: + w_activation_out = vllm_model.classify(example_prompts) + + for wo_activation, w_activation in zip(wo_activation_out, + w_activation_out): + wo_activation = torch.tensor(wo_activation) + w_activation = torch.tensor(w_activation) + + assert not torch.allclose( + wo_activation, w_activation, + atol=1e-2), "override_pooler_config is not working" + assert torch.allclose(softmax(wo_activation), w_activation, + 1e-3 if dtype == "float" else 1e-2) + + +@pytest.mark.parametrize( + "model", + [ + "intfloat/multilingual-e5-small", + ], +) +@pytest.mark.parametrize("dtype", 
["half"]) +def test_embed_models_using_normalize( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + + with vllm_runner(model, + max_model_len=512, + dtype=dtype, + override_pooler_config=PoolerConfig( + normalize=False)) as vllm_model: + wo_normalize = torch.tensor(vllm_model.embed(example_prompts)) + + with vllm_runner( + model, + max_model_len=512, + dtype=dtype, + override_pooler_config=PoolerConfig(normalize=True)) as vllm_model: + w_normalize = torch.tensor(vllm_model.embed(example_prompts)) + + assert not torch.allclose( + wo_normalize, w_normalize, + atol=1e-2), "override_pooler_config normalize is not working" + assert torch.allclose( + F.normalize(wo_normalize, p=2, dim=-1), w_normalize, + atol=1e-2), "w_normal should be close to normal(wo_normal)." + + +@pytest.mark.parametrize( + "model", + [ + "internlm/internlm2-1_8b-reward", + ], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_reward_models_using_softmax( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + + with vllm_runner( + model, + max_model_len=1024, + dtype=dtype, + override_pooler_config=PoolerConfig(softmax=False)) as vllm_model: + wo_softmax = vllm_model.encode(example_prompts) + + with vllm_runner( + model, + max_model_len=1024, + dtype=dtype, + override_pooler_config=PoolerConfig(softmax=True)) as vllm_model: + w_softmax = vllm_model.encode(example_prompts) + + for wo, w in zip(wo_softmax, w_softmax): + wo = torch.tensor(wo) + w = torch.tensor(w) + + assert not torch.allclose( + wo, w, atol=1e-2), "override_pooler_config softmax is not working" + assert torch.allclose( + softmax(wo), w, + atol=1e-2), "w_softmax should be close to softmax(wo_softmax)." diff --git a/tests/models/language/pooling/test_reward.py b/tests/models/language/pooling/test_reward.py index a5f7dca76d82..7add1d975c63 100644 --- a/tests/models/language/pooling/test_reward.py +++ b/tests/models/language/pooling/test_reward.py @@ -103,7 +103,7 @@ def test_prm_models( # check logits difference for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): - hf_output = torch.tensor(hf_output) - vllm_output = torch.tensor(vllm_output) + hf_output = torch.tensor(hf_output).float() + vllm_output = torch.tensor(vllm_output).float() assert torch.allclose(hf_output, vllm_output, 1.5e-2) diff --git a/tests/models/utils.py b/tests/models/utils.py index 3cd0721be1b6..bda7ea3e3ad5 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -330,6 +330,13 @@ def matryoshka_fy(tensor: torch.Tensor, dimensions: int): return tensor +def softmax(data): + if data.shape[-1] == 1: + return F.sigmoid(data) + else: + return F.softmax(data, dim=-1) + + class EmbedModelInfo(NamedTuple): name: str is_matryoshka: bool = False diff --git a/tests/test_pooling_params.py b/tests/test_pooling_params.py new file mode 100644 index 000000000000..52c03015483c --- /dev/null +++ b/tests/test_pooling_params.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from tests.models.utils import EmbedModelInfo +from vllm import PoolingParams +from vllm.config import ModelConfig + +EMBEDDING_MODELS = [ + EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False), + EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", + is_matryoshka=True, + matryoshka_dimensions=[256]), +] + + +def test_task(): + pooling_params = PoolingParams() + pooling_params.verify(task="score") + + pooling_params = 
PoolingParams(task="score") + pooling_params.verify(task="score") + + with pytest.raises(ValueError): + pooling_params.verify(task="encode") + + +def test_embed(): + task = "embed" + pooling_params = PoolingParams(normalize=None) + pooling_params.verify(task=task) + + pooling_params = PoolingParams(normalize=True) + pooling_params.verify(task=task) + + pooling_params = PoolingParams(normalize=False) + pooling_params.verify(task=task) + + invalid_parameters = ["activation", "softmax"] + for p in invalid_parameters: + with pytest.raises(ValueError): + pooling_params = PoolingParams(**{p: True}) + pooling_params.verify(task=task) + + +@pytest.mark.parametrize("model_info", EMBEDDING_MODELS) +def test_embed_dimensions(model_info: EmbedModelInfo): + task = "embed" + model_config = ModelConfig( + model_info.name, + task="auto", + tokenizer=model_info.name, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + ) + + pooling_params = PoolingParams(dimensions=None) + pooling_params.verify(task=task, model_config=model_config) + + with pytest.raises(ValueError): + pooling_params = PoolingParams(dimensions=1) + pooling_params.verify(task=task, model_config=model_config) + + if model_info.is_matryoshka: + assert model_info.matryoshka_dimensions is not None + pooling_params = PoolingParams( + dimensions=model_info.matryoshka_dimensions[0]) + pooling_params.verify(task=task, model_config=model_config) + + +@pytest.mark.parametrize("task", ["score", "classify"]) +def test_classify(task): + pooling_params = PoolingParams(activation=None) + pooling_params.verify(task=task) + + pooling_params = PoolingParams(activation=True) + pooling_params.verify(task=task) + + pooling_params = PoolingParams(activation=False) + pooling_params.verify(task=task) + + invalid_parameters = ["dimensions", "normalize", "softmax"] + for p in invalid_parameters: + with pytest.raises(ValueError): + pooling_params = PoolingParams(**{p: True}) + pooling_params.verify(task=task) + + +def test_encode(): + task = "encode" + pooling_params = PoolingParams(softmax=None) + pooling_params.verify(task=task) + + pooling_params = PoolingParams(softmax=True) + pooling_params.verify(task=task) + + pooling_params = PoolingParams(softmax=False) + pooling_params.verify(task=task) + + invalid_parameters = ["dimensions", "normalize", "activation"] + for p in invalid_parameters: + with pytest.raises(ValueError): + pooling_params = PoolingParams(**{p: True}) + pooling_params.verify(task=task) diff --git a/vllm/config.py b/vllm/config.py index 1100e1077401..6c1e85bb15c8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -913,15 +913,6 @@ def _init_pooler_config(self) -> Optional["PoolerConfig"]: if getattr(pooler_config, k) is None: setattr(pooler_config, k, v) - if self.is_matryoshka: - if pooler_config.normalize is None: - pooler_config.normalize = True - elif not pooler_config.normalize: - raise ValueError( - "`normalize` must be enabled (set to True) " - "for models that are compatible with " - "Matryoshka Representation.") - return pooler_config return None @@ -3438,25 +3429,34 @@ class PoolerConfig: [`vllm.model_executor.layers.pooler.PoolingType`][]. """ + ## for embeddings models normalize: Optional[bool] = None """ - Whether to normalize the pooled outputs. Usually, this should be set to - ``True`` for embedding outputs. + Whether to normalize the embeddings outputs. + """ + dimensions: Optional[int] = None + """ + Reduce the dimensions of embeddings if model + support matryoshka representation. 
""" - softmax: Optional[bool] = None + ## for classification models + activation: Optional[bool] = None """ - Whether to apply softmax to the pooled outputs. Usually, this should be set - to ``True`` for classification outputs. + Whether to apply activation function to the classification outputs. """ + ## for reward models + softmax: Optional[bool] = None + """ + Whether to apply softmax to the reward outputs. + """ step_tag_id: Optional[int] = None """ If set, only the score corresponding to the ``step_tag_id`` in the generated sentence should be returned. Otherwise, the scores for all tokens are returned. """ - returned_token_ids: Optional[list[int]] = None """ A list of indices for the vocabulary dimensions to be extracted, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 842a22ccebaa..ca24b0c32b73 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1189,6 +1189,8 @@ def classify( /, *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[ClassificationRequestOutput]: """ @@ -1207,7 +1209,8 @@ def classify( it is used to create the progress bar. If `False`, no progress bar is created. lora_request: LoRA request to use for generation, if any. - + pooling_params: The pooling parameters for pooling. If None, we + use the default pooling parameters. Returns: A list of `ClassificationRequestOutput` objects containing the embedding vectors in the same order as the input prompts. @@ -1220,6 +1223,7 @@ def classify( items = self.encode( prompts, use_tqdm=use_tqdm, + pooling_params=pooling_params, lora_request=lora_request, pooling_task="classify", ) @@ -1272,6 +1276,7 @@ def _embedding_score( text_2: list[Union[str, TextPrompt, TokensPrompt]], truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, + pooling_params: Optional[PoolingParams] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[ScoringRequestOutput]: @@ -1280,6 +1285,7 @@ def _embedding_score( truncate_prompt_tokens=truncate_prompt_tokens, use_tqdm=use_tqdm, lora_request=lora_request, + pooling_params=pooling_params, pooling_task="embed", ) @@ -1306,6 +1312,7 @@ def _cross_encoding_score( data_2: Union[list[str], list[ScoreContentPartParam]], truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, + pooling_params: Optional[PoolingParams] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[ScoringRequestOutput]: model_config = self.llm_engine.model_config @@ -1317,7 +1324,12 @@ def _cross_encoding_score( if len(data_1) == 1: data_1 = data_1 * len(data_2) - pooling_params = PoolingParams(task="score") + if pooling_params is None: + pooling_params = PoolingParams(task="score") + + model_config = self.llm_engine.model_config + pooling_params.verify("score", model_config) + tokenization_kwargs: dict[str, Any] = {} _validate_truncation_size(model_config.max_model_len, @@ -1379,6 +1391,7 @@ def score( *, truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, + pooling_params: Optional[PoolingParams] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[ScoringRequestOutput]: """Generate similarity scores for all pairs `` or @@ -1410,7 +1423,8 @@ def score( it is used to create the progress bar. 
If `False`, no progress bar is created. lora_request: LoRA request to use for generation, if any. - + pooling_params: The pooling parameters for pooling. If None, we + use the default pooling parameters. Returns: A list of `ScoringRequestOutput` objects containing the generated scores in the same order as the input prompts. @@ -1494,6 +1508,7 @@ def ensure_str(prompt: SingletonPrompt): data_2, # type: ignore[arg-type] truncate_prompt_tokens, use_tqdm, + pooling_params, lora_request) else: return self._embedding_score( @@ -1502,6 +1517,7 @@ def ensure_str(prompt: SingletonPrompt): data_2, # type: ignore[arg-type] truncate_prompt_tokens, use_tqdm, + pooling_params, lora_request) def start_profile(self) -> None: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index d77aee345843..64f2beb14021 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1274,11 +1274,13 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): "not set it, a random_uuid will be generated. This id is used " "through out the inference process and return in response."), ) + normalize: Optional[bool] = None # --8<-- [end:embedding-extra-params] def to_pooling_params(self): - return PoolingParams(dimensions=self.dimensions) + return PoolingParams(dimensions=self.dimensions, + normalize=self.normalize) class EmbeddingChatRequest(OpenAIBaseModel): @@ -1332,6 +1334,7 @@ class EmbeddingChatRequest(OpenAIBaseModel): "not set it, a random_uuid will be generated. This id is used " "through out the inference process and return in response."), ) + normalize: Optional[bool] = None # --8<-- [end:chat-embedding-extra-params] @model_validator(mode="before") @@ -1344,7 +1347,8 @@ def check_generation_prompt(cls, data): return data def to_pooling_params(self): - return PoolingParams(dimensions=self.dimensions) + return PoolingParams(dimensions=self.dimensions, + normalize=self.normalize) EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] @@ -1375,10 +1379,12 @@ class ScoreRequest(OpenAIBaseModel): "if the served model does not use priority scheduling."), ) + activation: Optional[bool] = None + # --8<-- [end:score-extra-params] def to_pooling_params(self): - return PoolingParams() + return PoolingParams(activation=self.activation) class RerankRequest(OpenAIBaseModel): @@ -1403,10 +1409,12 @@ class RerankRequest(OpenAIBaseModel): "if the served model does not use priority scheduling."), ) + activation: Optional[bool] = None + # --8<-- [end:rerank-extra-params] def to_pooling_params(self): - return PoolingParams() + return PoolingParams(activation=self.activation) class RerankDocument(BaseModel): @@ -1553,10 +1561,12 @@ class ClassificationRequest(OpenAIBaseModel): "if the served model does not use priority scheduling."), ) + activation: Optional[bool] = None + # --8<-- [end:classification-extra-params] def to_pooling_params(self): - return PoolingParams() + return PoolingParams(activation=self.activation) class ClassificationData(OpenAIBaseModel): diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 5bfd4aaccc17..0f2e58eb9b5d 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -41,35 +41,18 @@ class PoolingType(IntEnum): @dataclass(frozen=True) class ResolvedPoolingConfig: pooling_type: PoolingType - - normalize: bool - softmax: bool - step_tag_id: Optional[int] - returned_token_ids: Optional[list[int]] + task: PoolingTask @classmethod def 
from_config_with_defaults( cls, + task: PoolingTask, pooler_config: PoolerConfig, pooling_type: PoolingType, - normalize: bool, - softmax: bool, - step_tag_id: Optional[int] = None, - returned_token_ids: Optional[list[int]] = None, ) -> "ResolvedPoolingConfig": - return cls( - pooling_type=PoolingType[pooler_config.pooling_type] - if pooler_config.pooling_type is not None else pooling_type, - normalize=pooler_config.normalize - if pooler_config.normalize is not None else normalize, - softmax=pooler_config.softmax - if pooler_config.softmax is not None else softmax, - step_tag_id=pooler_config.step_tag_id - if pooler_config.step_tag_id is not None else step_tag_id, - returned_token_ids=pooler_config.returned_token_ids - if pooler_config.returned_token_ids is not None else - returned_token_ids, - ) + return cls(task=task, + pooling_type=PoolingType[pooler_config.pooling_type] + if pooler_config.pooling_type is not None else pooling_type) @dataclass(frozen=True) @@ -89,22 +72,15 @@ def for_encode( pooler_config: PoolerConfig, *, default_pooling_type: PoolingType = PoolingType.ALL, - default_normalize: bool = False, - default_softmax: bool = False, - default_step_tag_id: Optional[int] = None, - default_returned_token_ids: Optional[list[int]] = None, ): resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + task="encode", pooler_config=pooler_config, pooling_type=default_pooling_type, - normalize=default_normalize, - softmax=default_softmax, - step_tag_id=default_step_tag_id, - returned_token_ids=default_returned_token_ids, ) if resolved_config.pooling_type == PoolingType.STEP: - return StepPooler.from_config(resolved_config) + return StepPooler() return SimplePooler.from_config(resolved_config) @@ -113,14 +89,11 @@ def for_embed( pooler_config: PoolerConfig, *, default_pooling_type: PoolingType = PoolingType.LAST, - default_normalize: bool = True, - default_softmax: bool = False, ): resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + task="embed", pooler_config=pooler_config, pooling_type=default_pooling_type, - normalize=default_normalize, - softmax=default_softmax, ) return SimplePooler.from_config(resolved_config) @@ -131,23 +104,18 @@ def for_classify( classifier: Optional[ClassifierFn], *, default_pooling_type: PoolingType = PoolingType.LAST, - default_normalize: bool = False, - default_softmax: bool = True, ): resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + task="classify", pooler_config=pooler_config, pooling_type=default_pooling_type, - normalize=default_normalize, - softmax=default_softmax, ) - base_pooler = SimplePooler.from_config(resolved_config) - if classifier is None: - return base_pooler + + pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type) return ClassifierPooler( - pooling=base_pooler.pooling, + pooling=pooling, classifier=classifier, - act_fn=base_pooler.head.activation, ) @abstractmethod @@ -198,11 +166,17 @@ def get_prompt_token_ids( ] -def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]: +def get_pooling_params( + pooling_metadata: PoolingMetadata) -> list[PoolingParams]: if isinstance(pooling_metadata, V0PoolingMetadata): pooling_params = [p for _, p in pooling_metadata.seq_groups] else: pooling_params = pooling_metadata.pooling_params + return pooling_params + + +def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]: + pooling_params = get_pooling_params(pooling_metadata) tasks: list[PoolingTask] = [ task for pooling_param in pooling_params @@ -484,49 +458,30 @@ 
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: class PoolerHead(nn.Module): - @classmethod - def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "PoolerHead": - if pooler_config.normalize and pooler_config.softmax: - raise ValueError("`normalize=True` and `softmax=True` should not " - "be set together") - - activation: PoolerActivation - if pooler_config.normalize: - activation = PoolerNormalize() - elif pooler_config.softmax: - activation = PoolerClassify() - else: - activation = PoolerIdentity() - - return cls(activation) - def __init__(self, activation: PoolerActivation) -> None: super().__init__() - self.activation = activation def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): - # Using float32 in PoolerHead - if isinstance(pooled_data, list): - for i in range(len(pooled_data)): - pooled_data[i] = pooled_data[i].to(torch.float32) - else: - pooled_data = pooled_data.to(torch.float32) + return self.activation(pooled_data) + + +class EmbeddingPoolerHead(PoolerHead): + + def __init__(self) -> None: + super().__init__(activation=PoolerNormalize()) + + def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], + pooling_metadata: PoolingMetadata): + + pooling_params = get_pooling_params(pooling_metadata) # for matryoshka representation - if isinstance(pooling_metadata, V0PoolingMetadata): - dimensions_list = [ - pooling_param.dimensions - for _, pooling_param in pooling_metadata.seq_groups - ] - else: - assert isinstance(pooled_data, list) - dimensions_list = [ - pooling_param.dimensions - for pooling_param in pooling_metadata.pooling_params - ] + dimensions_list = [ + pooling_param.dimensions for pooling_param in pooling_params + ] if any(d is not None for d in dimensions_list): # change the output dimension assert len(pooled_data) == len(dimensions_list) @@ -541,7 +496,41 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], for vecs, d in zip(pooled_data, dimensions_list) ] - return self.activation(pooled_data) + # for normalize + flags = [p.normalize for p in pooling_params] + if len(set(flags)) == 1: + if flags[0]: + pooled_data = self.activation(pooled_data) + else: + pooled_data = [ + self.activation(vecs) if f else vecs + for vecs, f in zip(pooled_data, flags) + ] + + return pooled_data + + +class RewardPoolerHead(PoolerHead): + + def __init__(self) -> None: + super().__init__(activation=PoolerClassify()) + + def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], + pooling_metadata: PoolingMetadata): + pooling_params = get_pooling_params(pooling_metadata) + + # for softmax + flags = [p.softmax for p in pooling_params] + if len(set(flags)) == 1: + if flags[0]: + pooled_data = self.activation(pooled_data) + else: + pooled_data = [ + self.activation(vecs) if f else vecs + for vecs, f in zip(pooled_data, flags) + ] + + return pooled_data class SimplePooler(Pooler): @@ -559,8 +548,12 @@ def from_config( pooler_config: ResolvedPoolingConfig, ) -> "SimplePooler": pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type) - head = PoolerHead.from_config(pooler_config) - + if pooler_config.task == "embed": + head = EmbeddingPoolerHead() + elif pooler_config.task == "encode": + head = RewardPoolerHead() + else: + raise NotImplementedError(f"Unknown task: {pooler_config.task}") return cls(pooling, head) def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None: @@ -587,29 +580,11 @@ def forward( class StepPooler(Pooler): - @classmethod 
- def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "StepPooler": - assert pooler_config.pooling_type == PoolingType.STEP - - return cls( - PoolerHead.from_config(pooler_config), - step_tag_id=pooler_config.step_tag_id, - returned_token_ids=pooler_config.returned_token_ids, - ) - - def __init__( - self, - head: PoolerHead, - *, - step_tag_id: Optional[int] = None, - returned_token_ids: Optional[list[int]] = None, - ) -> None: + def __init__(self, ) -> None: super().__init__() self.pooling = AllPool() - self.head = head - self.step_tag_id = step_tag_id - self.returned_token_ids = returned_token_ids + self.head = RewardPoolerHead() def extract_states( self, @@ -620,10 +595,15 @@ def extract_states( prompt_token_ids = get_prompt_token_ids(pooling_metadata) pooled_data = list[torch.Tensor]() - returned_token_ids = self.returned_token_ids - step_tag_id = self.step_tag_id - for data, token_id in zip(pooled_data_lst, prompt_token_ids): + pooling_params = get_pooling_params(pooling_metadata) + + for data, token_id, pooling_param in zip(pooled_data_lst, + prompt_token_ids, + pooling_params): + step_tag_id = pooling_param.step_tag_id + returned_token_ids = pooling_param.returned_token_ids + if returned_token_ids is not None and len(returned_token_ids) > 0: data = data[:, returned_token_ids] @@ -669,14 +649,14 @@ def act_fn_for_cross_encoder(config: ModelConfig): def __init__( self, pooling: PoolingFn, - classifier: ClassifierFn, - act_fn: PoolerActivation, + classifier: Optional[ClassifierFn], + act_fn: Optional[PoolerActivation] = None, ) -> None: super().__init__() self.pooling = pooling self.classifier = classifier - self.act_fn = act_fn + self.act_fn = act_fn or PoolerClassify() def get_supported_tasks(self) -> Set[PoolingTask]: return {"classify", "score"} @@ -688,15 +668,25 @@ def forward( ) -> PoolerOutput: pooled_data = self.pooling(hidden_states, pooling_metadata) - # apply classifier once on the full batch if possible - if isinstance(pooled_data, torch.Tensor): - pooled_output = self.classifier(pooled_data) - elif len({data.shape for data in pooled_data}) <= 1: - pooled_output = self.classifier(torch.stack(pooled_data)) - else: - pooled_output = [self.classifier(data) for data in pooled_data] + if self.classifier is not None: + # apply classifier once on the full batch if possible + if isinstance(pooled_data, torch.Tensor): + pooled_data = self.classifier(pooled_data) + elif len({data.shape for data in pooled_data}) <= 1: + pooled_data = self.classifier(torch.stack(pooled_data)) + else: + pooled_data = [self.classifier(data) for data in pooled_data] - scores = self.act_fn(pooled_output) + pooling_params = get_pooling_params(pooling_metadata) + flags = [p.activation for p in pooling_params] + + if len(set(flags)) == 1: + scores = self.act_fn(pooled_data) if flags[0] else pooled_data + else: + scores = [ + self.act_fn(vecs) if f else vecs + for vecs, f in zip(pooled_data, flags) + ] return build_output(scores) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 9030ff307bee..6f09be7a5941 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -44,6 +44,15 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: } +class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + pooler_config = vllm_config.model_config.pooler_config + if pooler_config.activation is None: + pooler_config.activation = 
False + + class JinaRobertaModelConfig(VerifyAndUpdateConfig): @staticmethod @@ -155,6 +164,26 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: vllm_config.recalculate_max_model_len(max_model_len) +class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + pooler_config = vllm_config.model_config.pooler_config + + if pooler_config.step_tag_id is None: + pooler_config.step_tag_id = 151651 + + +class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + pooler_config = vllm_config.model_config.pooler_config + + if pooler_config.softmax is None: + pooler_config.softmax = False + + class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig): @staticmethod @@ -309,8 +338,11 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: "GteModel": SnowflakeGteNewModelConfig, "GteNewModel": GteNewModelConfig, "NomicBertModel": NomicBertModelConfig, + "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig, + "Qwen2ForRewardModel": Qwen2ForRewardModelConfig, "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig, "XLMRobertaModel": JinaRobertaModelConfig, "JinaVLForRanking": JinaVLForSequenceClassificationConfig, + "JambaForSequenceClassification": JambaForSequenceClassificationConfig, "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig, } diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 263f4c8379cf..ab21b7ce2c5f 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -593,7 +593,5 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): pooler_config, classifier=self.score, default_pooling_type=PoolingType.LAST, - default_normalize=False, - default_softmax=False, ), }) diff --git a/vllm/model_executor/models/jina_vl.py b/vllm/model_executor/models/jina_vl.py index 0c4284f7daaa..8c64f636c6a0 100644 --- a/vllm/model_executor/models/jina_vl.py +++ b/vllm/model_executor/models/jina_vl.py @@ -90,15 +90,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "qwen2_vl")) config = vllm_config.model_config.hf_config pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None # logit bias for sigmoid normalization self.LOGIT_BIAS = 2.65 self.score = JinaVLScorer(config) - - pooler_config = vllm_config.model_config.pooler_config - assert pooler_config is not None - self.pooler = DispatchPooler({ "encode": Pooler.for_encode(pooler_config), diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index f12e9a041a94..9b6b70c75c34 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -117,8 +117,5 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): Pooler.for_encode( pooler_config, default_pooling_type=PoolingType.STEP, - default_normalize=False, - default_softmax=True, - default_step_tag_id=151651, ) }) diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 23eb775f2dc6..7077f68353fc 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from copy import deepcopy from typing import TYPE_CHECKING, Optional import msgspec @@ -19,13 +20,25 @@ class PoolingParams( """API 
parameters for pooling models. Attributes: + normalize: Whether to normalize the embeddings outputs. dimensions: Reduce the dimensions of embeddings if model support matryoshka representation. + activation: Whether to apply activation function to + the classification outputs. + softmax: Whether to apply softmax to the reward outputs. """ + ## for embeddings models dimensions: Optional[int] = None + normalize: Optional[bool] = None - output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY + ## for classification models + activation: Optional[bool] = None + + ## for reward models + softmax: Optional[bool] = None + step_tag_id: Optional[int] = None + returned_token_ids: Optional[list[int]] = None task: Optional[PoolingTask] = None """Internal use only.""" @@ -33,15 +46,32 @@ class PoolingParams( requires_token_ids: bool = False """Internal use only.""" + output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY + + @property + def all_parameters(self) -> list[str]: + return [ + "dimensions", "normalize", "activation", "softmax", "step_tag_id", + "returned_token_ids" + ] + + @property + def valid_parameters(self): + return { + "embed": ["dimensions", "normalize"], + "classify": ["activation"], + "score": ["activation"], + "encode": ["softmax", "step_tag_id", "returned_token_ids"], + } + def clone(self) -> "PoolingParams": """Returns a deep copy of the PoolingParams instance.""" - return PoolingParams( - dimensions=self.dimensions, - task=self.task, - requires_token_ids=self.requires_token_ids, - ) + return deepcopy(self) + + def verify(self, + task: PoolingTask, + model_config: Optional["ModelConfig"] = None) -> None: - def verify(self, task: PoolingTask, model_config: "ModelConfig") -> None: if self.task is None: self.task = task elif self.task != task: @@ -52,28 +82,91 @@ def verify(self, task: PoolingTask, model_config: "ModelConfig") -> None: # which is not available in model config. 
So, it's not included # in this method - if self.dimensions is not None: - if not model_config.is_matryoshka: - raise ValueError( - f'Model "{model_config.served_model_name}" does not ' - f'support matryoshka representation, ' - f'changing output dimensions will lead to poor results.') + self._merge_default_parameters(model_config) + self._set_default_parameters(model_config) + self._verify_valid_parameters() + + def _merge_default_parameters(self, + model_config: Optional["ModelConfig"] = None + ) -> None: + + if model_config is None: + return - mds = model_config.matryoshka_dimensions - if mds is not None: - if self.dimensions not in mds: + pooler_config = model_config.pooler_config + if pooler_config is None: + return + + assert self.task is not None, "task must be set" + valid_parameters = self.valid_parameters[self.task] + + for k in valid_parameters: + if getattr(pooler_config, k, None) is None: + continue + + if getattr(self, k, None) is None: + setattr(self, k, getattr(pooler_config, k)) + + def _set_default_parameters(self, model_config: Optional["ModelConfig"]): + if self.task == "embed": + if self.normalize is None: + self.normalize = True + + if self.dimensions is not None and model_config is not None: + if not model_config.is_matryoshka: raise ValueError( - f'Model "{model_config.served_model_name}" ' - f'only supports {str(mds)} matryoshka dimensions, ' - f'use other output dimensions will ' - f'lead to poor results.') - elif self.dimensions < 1: - raise ValueError("Dimensions must be greater than 0") + f'Model "{model_config.served_model_name}" does not ' + f'support matryoshka representation, ' + f'changing output dimensions will lead to poor results.' + ) + + mds = model_config.matryoshka_dimensions + if mds is not None: + if self.dimensions not in mds: + raise ValueError( + f'Model "{model_config.served_model_name}" ' + f'only supports {str(mds)} matryoshka dimensions, ' + f'use other output dimensions will ' + f'lead to poor results.') + elif self.dimensions < 1: + raise ValueError("Dimensions must be greater than 0") + + elif self.task in ["classify", "score"]: + if self.activation is None: + self.activation = True + + elif self.task == "encode": + if self.softmax is None: + self.softmax = True + else: + raise ValueError(f"Unknown pooling task: {self.task}") + + def _verify_valid_parameters(self): + assert self.task is not None, "task must be set" + valid_parameters = self.valid_parameters[self.task] + invalid_parameters = [] + for k in self.all_parameters: + if k in valid_parameters: + continue + + if getattr(self, k, None) is not None: + invalid_parameters.append(k) + + if invalid_parameters: + raise ValueError( + f"Task {self.task} only supports {valid_parameters} " + f"parameters, does not support " + f"{invalid_parameters} parameters") def __repr__(self) -> str: return (f"PoolingParams(" - f"dimensions={self.dimensions}, " f"task={self.task}, " + f"normalize={self.normalize}, " + f"dimensions={self.dimensions}, " + f"activation={self.activation}, " + f"softmax={self.softmax}, " + f"step_tag_id={self.step_tag_id}, " + f"returned_token_ids={self.returned_token_ids}, " f"requires_token_ids={self.requires_token_ids})") def __post_init__(self) -> None:
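Taken together, the offline `LLM` entrypoints now accept per-request `PoolingParams`, and `PoolingParams.verify` rejects flags that do not belong to the requested task. A minimal sketch of the resulting usage, mirroring the new tests (the model name is the one used in the tests and is illustrative only):

```python
# Sketch only: per-request pooling controls, mirroring the new tests.
from vllm import LLM, PoolingParams

llm = LLM(model="intfloat/multilingual-e5-small", enforce_eager=True)

# Skip L2 normalization for this request only, without overriding the
# model-wide pooler config.
outputs = llm.embed(
    ["The chef prepared a delicious meal."],
    pooling_params=PoolingParams(normalize=False),
)
print(outputs[0].outputs.embedding[:4])

# Task/parameter mismatches are caught by verify(): `activation` is only
# valid for classify/score, not for embed.
try:
    PoolingParams(activation=True).verify(task="embed")
except ValueError as err:
    print(err)
```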