diff --git a/tests/conftest.py b/tests/conftest.py
index f50e611a471b..feb52e26300a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1027,13 +1027,13 @@ def classify(self, prompts: list[str]) -> list[list[float]]:
         req_outputs = self.model.classify(prompts)
         return [req_output.outputs.probs for req_output in req_outputs]
 
-    def encode(self,
-               prompts: list[str],
-               images: Optional[PromptImageInput] = None,
-               videos: Optional[PromptVideoInput] = None,
-               audios: Optional[PromptAudioInput] = None,
-               *args,
-               **kwargs) -> list[list[float]]:
+    def embed(self,
+              prompts: list[str],
+              images: Optional[PromptImageInput] = None,
+              videos: Optional[PromptVideoInput] = None,
+              audios: Optional[PromptAudioInput] = None,
+              *args,
+              **kwargs) -> list[list[float]]:
         inputs = self.get_inputs(prompts,
                                  images=images,
                                  videos=videos,
@@ -1042,6 +1042,10 @@ def encode(self,
         req_outputs = self.model.embed(inputs, *args, **kwargs)
         return [req_output.outputs.embedding for req_output in req_outputs]
 
+    def encode(self, prompts: list[str]) -> list[list[float]]:
+        req_outputs = self.model.encode(prompts)
+        return [req_output.outputs.data for req_output in req_outputs]
+
     def score(
         self,
         text_1: Union[str, list[str]],
diff --git a/tests/model_executor/test_model_load_with_params.py b/tests/model_executor/test_model_load_with_params.py
index 94a14bd24bcb..4bdb651e5170 100644
--- a/tests/model_executor/test_model_load_with_params.py
+++ b/tests/model_executor/test_model_load_with_params.py
@@ -29,8 +29,8 @@ def test_model_loading_with_params(vllm_runner):
                      revision=REVISION,
                      dtype="float16",
                      max_model_len=MAX_MODEL_LEN) as vllm_model:
-        output = vllm_model.encode("Write a short story about a robot that"
-                                   " dreams for the first time.\n")
+        output = vllm_model.embed("Write a short story about a robot that"
+                                  " dreams for the first time.\n")
 
         model_config = vllm_model.model.llm_engine.model_config
         model_tokenizer = vllm_model.model.llm_engine.tokenizer
@@ -67,8 +67,8 @@ def test_roberta_model_loading_with_params(vllm_runner):
                      revision=REVISION_ROBERTA,
                      dtype="float16",
                      max_model_len=MAX_MODEL_LEN) as vllm_model:
-        output = vllm_model.encode("Write a short story about a robot that"
-                                   " dreams for the first time.\n")
+        output = vllm_model.embed("Write a short story about a robot that"
+                                  " dreams for the first time.\n")
 
         model_config = vllm_model.model.llm_engine.model_config
         model_tokenizer = vllm_model.model.llm_engine.tokenizer
@@ -105,8 +105,8 @@ def test_facebook_roberta_model_loading_with_params(vllm_runner):
     with vllm_runner(model_name=model_name,
                      dtype="float16",
                      max_model_len=MAX_MODEL_LEN) as vllm_model:
-        output = vllm_model.encode("Write a short story about a robot that"
-                                   " dreams for the first time.\n")
+        output = vllm_model.embed("Write a short story about a robot that"
+                                  " dreams for the first time.\n")
 
         model_tokenizer = vllm_model.model.llm_engine.tokenizer
         assert model_tokenizer.tokenizer_id == model_name
diff --git a/tests/models/language/pooling/embed_utils.py b/tests/models/language/pooling/embed_utils.py
index dabd7bee7f39..a663679a9c7c 100644
--- a/tests/models/language/pooling/embed_utils.py
+++ b/tests/models/language/pooling/embed_utils.py
@@ -55,7 +55,7 @@ def correctness_test_embed_models(hf_runner,
             task="embed",
             max_model_len=None,
             **vllm_extra_kwargs) as vllm_model:
-        vllm_outputs = vllm_model.encode(example_prompts)
+        vllm_outputs = vllm_model.embed(example_prompts)
 
     with hf_runner(
             model_info.name,
diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py
index 5ef9f768c574..b8b17524cf07 100644
--- a/tests/models/language/pooling/test_embedding.py
+++ b/tests/models/language/pooling/test_embedding.py
@@ -89,7 +89,7 @@ def test_models(
             task="embed",
             max_model_len=512,
             **vllm_extra_kwargs) as vllm_model:
-        vllm_outputs = vllm_model.encode(example_prompts)
+        vllm_outputs = vllm_model.embed(example_prompts)
 
     check_embeddings_close(
         embeddings_0_lst=hf_outputs,
diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py
index 0c44683e7486..0bc189d82b8a 100644
--- a/tests/models/language/pooling/test_jina.py
+++ b/tests/models/language/pooling/test_jina.py
@@ -98,11 +98,11 @@ def test_matryoshka(
 
         if dimensions not in matryoshka_dimensions:
             with pytest.raises(ValueError):
-                vllm_model.encode(
+                vllm_model.embed(
                     example_prompts,
                     pooling_params=PoolingParams(dimensions=dimensions))
         else:
-            vllm_outputs = vllm_model.encode(
+            vllm_outputs = vllm_model.embed(
                 example_prompts,
                 pooling_params=PoolingParams(dimensions=dimensions))
 
diff --git a/tests/models/language/pooling/test_reward.py b/tests/models/language/pooling/test_reward.py
new file mode 100644
index 000000000000..085cdca9f1f3
--- /dev/null
+++ b/tests/models/language/pooling/test_reward.py
@@ -0,0 +1,104 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+import torch.nn.functional as F
+from transformers import AutoModel
+
+from vllm.platforms import current_platform
+
+from ....conftest import HfRunner
+
+
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
+@pytest.fixture
+def math_step_prompts():
+    # ruff: noqa: E501
+    data = {
+        "system":
+        "Please reason step by step, and put your final answer within \\boxed{}. ",
+        "query":
+        "Sue lives in a fun neighborhood. One weekend, the neighbors decided to play a prank on Sue. On Friday morning, the neighbors placed 18 pink plastic flamingos out on Sue's front yard. On Saturday morning, the neighbors took back one third of the flamingos, painted them white, and put these newly painted white flamingos back out on Sue's front yard. Then, on Sunday morning, they added another 18 pink plastic flamingos to the collection. At noon on Sunday, how many more pink plastic flamingos were out than white plastic flamingos?",
+        "response": [
+            "To find out how many more pink plastic flamingos were out than white plastic flamingos at noon on Sunday, we can break down the problem into steps. First, on Friday, the neighbors start with 18 pink plastic flamingos.",
+            "On Saturday, they take back one third of the flamingos. Since there were 18 flamingos, (1/3 \\times 18 = 6) flamingos are taken back. So, they have (18 - 6 = 12) flamingos left in their possession. Then, they paint these 6 flamingos white and put them back out on Sue's front yard. Now, Sue has the original 12 pink flamingos plus the 6 new white ones. Thus, by the end of Saturday, Sue has (12 + 6 = 18) pink flamingos and 6 white flamingos.",
+            "On Sunday, the neighbors add another 18 pink plastic flamingos to Sue's front yard. By the end of Sunday morning, Sue has (18 + 18 = 36) pink flamingos and still 6 white flamingos.",
+            "To find the difference, subtract the number of white flamingos from the number of pink flamingos: (36 - 6 = 30). Therefore, at noon on Sunday, there were 30 more pink plastic flamingos out than white plastic flamingos. The answer is (\\boxed{30}).",
+        ],
+    }
+    answer = "<extra_0>".join(data['response']) + "<extra_0>"
+    prompt = f"system\n{data['system']}\nuser\n{data['query']}\nassistant\n{answer}<|endoftext|>"
+    return [prompt]
+
+
+def step_reward_patch_hf_model(hf_model: HfRunner):
+
+    # Patch the hf_runner to use the step reward function
+    def make_step_rewards(logits: torch.Tensor,
+                          token_masks: torch.Tensor) -> list[list[float]]:
+        probabilities = F.softmax(logits, dim=-1)
+        probabilities = probabilities * token_masks.unsqueeze(-1)
+
+        all_scores_res: list[list[float]] = []
+        for i in range(probabilities.size(0)):
+            sample = probabilities[i]  # seq_len, num_labels
+            positive_probs = sample[sample != 0].view(-1, 2)
+            non_zero_elements_list = positive_probs.cpu().tolist()
+            all_scores_res.append(non_zero_elements_list)
+        return all_scores_res
+
+    def reward(prompts: list[str]) -> list[list[float]]:
+        input_ids = hf_model.tokenizer(prompts, return_tensors="pt").input_ids
+        input_ids = hf_model.wrap_device(input_ids)
+        outputs = hf_model.model(input_ids=input_ids)
+
+        step_sep_id = hf_model.tokenizer.encode("<extra_0>")[0]
+        token_masks = (input_ids == step_sep_id)
+        return make_step_rewards(outputs[0], token_masks)
+
+    hf_model.reward = reward  # type: ignore[attr-defined]
+
+    return hf_model
+
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        pytest.param("Qwen/Qwen2.5-Math-PRM-7B",
+                     marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
+    ],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_prm_models(
+    hf_runner,
+    vllm_runner,
+    math_step_prompts,
+    model: str,
+    dtype: str,
+    monkeypatch,
+) -> None:
+    if current_platform.is_rocm():
+        # ROCm Triton FA does not currently support sliding window attention
+        # switch to use ROCm CK FA backend
+        monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
+
+    with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.encode(math_step_prompts)
+
+    with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model:
+        hf_model = step_reward_patch_hf_model(hf_model)
+        hf_outputs = hf_model.reward(math_step_prompts)
+
+    # check logits difference
+    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+        hf_output = torch.tensor(hf_output)
+        vllm_output = torch.tensor(vllm_output)
+
+        assert torch.allclose(hf_output, vllm_output, 1e-2)
diff --git a/tests/models/multimodal/pooling/test_dse_qwen2_vl.py b/tests/models/multimodal/pooling/test_dse_qwen2_vl.py
index 3734d87b7962..f889eea5e839 100644
--- a/tests/models/multimodal/pooling/test_dse_qwen2_vl.py
+++ b/tests/models/multimodal/pooling/test_dse_qwen2_vl.py
@@ -98,7 +98,7 @@ def _run_test(
                      max_model_len=8192) as vllm_model:
         tokenizer = vllm_model.model.get_tokenizer()
         texts = [
-            # this is necessary because vllm_model.encode will not apply any
+            # this is necessary because vllm_model.embed will not apply any
             # templating to the prompt, and therefore lacks an image_pad
             # token unless one is inserted beforehand (the (28,28) image
             # above is converted to an image pad token by the chat template).
@@ -109,7 +109,7 @@ def _run_test(
             # vllm will replace the pad token with the actual image,
             # which may be a placeholder image, later.
         ]
-        vllm_outputs = vllm_model.encode(texts, images=input_images)
+        vllm_outputs = vllm_model.embed(texts, images=input_images)
 
     hf_outputs = []
     with hf_runner(model,
diff --git a/tests/models/multimodal/pooling/test_llava_next.py b/tests/models/multimodal/pooling/test_llava_next.py
index b6d90d2b0abe..4a8f5cafbe48 100644
--- a/tests/models/multimodal/pooling/test_llava_next.py
+++ b/tests/models/multimodal/pooling/test_llava_next.py
@@ -68,7 +68,7 @@ def _run_test(
                      dtype=dtype,
                      max_model_len=4096,
                      enforce_eager=True) as vllm_model:
-        vllm_outputs = vllm_model.encode(input_texts, images=input_images)
+        vllm_outputs = vllm_model.embed(input_texts, images=input_images)
 
     with hf_runner(model, dtype=dtype,
                    auto_cls=AutoModelForImageTextToText) as hf_model:
diff --git a/tests/models/multimodal/pooling/test_phi3v.py b/tests/models/multimodal/pooling/test_phi3v.py
index b42ac6fb21ed..9a4b6d3ff8a8 100644
--- a/tests/models/multimodal/pooling/test_phi3v.py
+++ b/tests/models/multimodal/pooling/test_phi3v.py
@@ -46,7 +46,7 @@ def _run_test(
     # will hurt multiprocessing backend with fork method (the default method).
     with vllm_runner(model, task="embed", dtype=dtype,
                      enforce_eager=True) as vllm_model:
-        vllm_outputs = vllm_model.encode(input_texts, images=input_images)
+        vllm_outputs = vllm_model.embed(input_texts, images=input_images)
 
     # use eager mode for hf runner, since phi3_v didn't work with flash_attn
     hf_model_kwargs = {"_attn_implementation": "eager"}
diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py
index 8e39ed2fff87..363daa6d27ef 100644
--- a/tests/quantization/test_bitsandbytes.py
+++ b/tests/quantization/test_bitsandbytes.py
@@ -161,7 +161,7 @@ def test_4bit_bnb_embedding_model(
                      dtype=dtype,
                      gpu_memory_utilization=0.5,
                      quantization="bitsandbytes") as vllm_model:
-        vllm_outputs = vllm_model.encode(example_prompts)
+        vllm_outputs = vllm_model.embed(example_prompts)
     check_embeddings_close(
         embeddings_0_lst=hf_outputs,
         embeddings_1_lst=vllm_outputs,
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index eb2148d76452..8a33cd6be405 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -239,25 +239,24 @@ def extract_states(
         prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
         prompt_token_ids = self.get_prompt_token_ids(pooling_metadata)
 
-        pooled_data: list[torch.Tensor] = []
-
+        pooled_data_lst = list[torch.Tensor]()
         if isinstance(hidden_states, list):
             for req_state, prompt_len in zip(hidden_states, prompt_lens):
                 assert prompt_len == req_state.shape[0], \
-                    "partial prefill not supported with mean pooling"
-            pooled_data = hidden_states
+                    "partial prefill not supported with step pooling"
+            pooled_data_lst = hidden_states
         else:
             offset = 0
             for prompt_len in prompt_lens:
                 pooled_data_i = hidden_states[offset:offset + prompt_len]
                 offset += prompt_len
-                pooled_data.append(pooled_data_i)
+                pooled_data_lst.append(pooled_data_i)
 
-        pooled_data = []
+        pooled_data = list[torch.Tensor]()
         returned_token_ids = self.returned_token_ids
         step_tag_id = self.step_tag_id
 
-        for data, token_id in zip(pooled_data, prompt_token_ids):
+        for data, token_id in zip(pooled_data_lst, prompt_token_ids):
             if returned_token_ids is not None and len(returned_token_ids) > 0:
                 data = data[:, returned_token_ids]
 
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index 0e7e4e73eca9..f759f8f1f273 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -489,6 +489,12 @@ def supports_cross_encoding(
     return is_pooling_model(model) and _supports_cross_encoding(model)
 
 
+def has_step_pooler(model: Union[type[object], object]) -> bool:
+    """Check if the model uses step pooler."""
+    return is_pooling_model(model) and any(
+        type(module).__name__ == "StepPool" for module in model.modules())
+
+
 class SupportsQuant:
     """The interface required for all models that support quantization."""
 
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 3a2c9ef7dfac..ca2bfe831746 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -59,14 +59,15 @@ def get_token_id(self, idx: int) -> int:
 class InputBatch:
 
     def __init__(
-        self,
-        max_num_reqs: int,
-        max_model_len: int,
-        max_num_batched_tokens: int,
-        device: torch.device,
-        pin_memory: bool,
-        vocab_size: int,
-        block_sizes: list[int],  # The block_size of each kv cache group
+            self,
+            max_num_reqs: int,
+            max_model_len: int,
+            max_num_batched_tokens: int,
+            device: torch.device,
+            pin_memory: bool,
+            vocab_size: int,
+            block_sizes: list[int],  # The block_size of each kv cache group
+            logits_processing_needs_token_ids: bool = False,
     ):
         self.max_num_reqs = max_num_reqs
         self.max_model_len = max_model_len
@@ -74,6 +75,8 @@ def __init__(
         self.device = device
         self.pin_memory = pin_memory
         self.vocab_size = vocab_size
+        self.logits_processing_needs_token_ids = (
+            logits_processing_needs_token_ids)
 
         self._req_ids: list[Optional[str]] = []
         self.req_id_to_index: dict[str, int] = {}
@@ -579,9 +582,14 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
             copy_slice(self.repetition_penalties_cpu_tensor,
                        self.repetition_penalties, num_reqs)
 
-            # The prompt tokens are used only for applying penalties during
-            # the sampling process. Hence copy these tensors only when
-            # there are requests which need penalties to be applied.
+        needs_prompt_token_ids = (not self.no_penalties or
+                                  (self.num_reqs > 0
+                                   and self.logits_processing_needs_token_ids))
+        if needs_prompt_token_ids:
+            # The prompt tokens are used only for applying penalties or
+            # step pooling during the sampling/pooling process.
+            # Hence copy these tensors only when there are requests which
+            # need penalties/step_pooler to be applied.
             prompt_token_ids = self._make_prompt_token_ids_tensor()
         else:
             prompt_token_ids = None
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 330366006118..520d8fb186f4 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -33,6 +33,7 @@
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
+from vllm.model_executor.models.interfaces import has_step_pooler
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
@@ -1708,6 +1709,8 @@ def load_model(self) -> None:
             )
             model_loader.load_weights(self.model,
                                       model_config=self.model_config)
+        if has_step_pooler(self.model):
+            self.input_batch.logits_processing_needs_token_ids = True
         if self.lora_config:
             self.model = self.load_lora_model(self.model,
                                               self.model_config,
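
Usage sketch (not part of the patch): the split between the two runner helpers above is the core of this change. `VllmRunner.embed()` wraps `LLM.embed()` and returns one embedding vector per prompt, while the new `VllmRunner.encode()` wraps `LLM.encode()` and returns the raw pooled data, which is what the step-pooling reward test consumes. The test below only illustrates that split under assumptions: the `vllm_runner` and `example_prompts` fixtures are the ones used throughout this diff, and the embedding model name is an assumed example, not taken from the patch.

```python
# Illustrative sketch only -- not part of the diff above. Assumes the
# vllm_runner and example_prompts fixtures from tests/conftest.py and hardware
# able to load the listed models.
import pytest


@pytest.mark.parametrize("dtype", ["half"])
def test_runner_helper_split(vllm_runner, example_prompts, dtype):
    # Embedding models: embed() -> LLM.embed() -> outputs.embedding
    # ("BAAI/bge-base-en-v1.5" is an assumed example model).
    with vllm_runner("BAAI/bge-base-en-v1.5", task="embed",
                     dtype=dtype) as vllm_model:
        embeddings = vllm_model.embed(example_prompts)
        assert len(embeddings) == len(example_prompts)

    # Step-pooling reward models: encode() -> LLM.encode() -> outputs.data,
    # i.e. per-step pooled scores rather than a single embedding vector.
    with vllm_runner("Qwen/Qwen2.5-Math-PRM-7B", max_model_len=1024,
                     dtype=dtype) as vllm_model:
        step_scores = vllm_model.encode(example_prompts)
        assert len(step_scores) == len(example_prompts)
```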