diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 04d7cdc3d885..815ff06a94e3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -253,6 +253,7 @@ steps: - pytest -v -s v1/engine - pytest -v -s v1/entrypoints - pytest -v -s v1/sample + - pytest -v -s v1/logits_processors - pytest -v -s v1/worker - pytest -v -s v1/structured_output - pytest -v -s v1/spec_decode diff --git a/examples/offline_inference/logits_processor.py b/examples/offline_inference/logits_processor.py new file mode 100644 index 000000000000..7ef20efa7d28 --- /dev/null +++ b/examples/offline_inference/logits_processor.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""This example demonstrates instantiating vLLM with a custom logits processor +class object. + +For a basic example of implementing a custom logits processor, see +the `DummyLogitsProcessor` implementation in `vllm/test_utils.py`. + +For testing purposes, a dummy logits processor is employed which, if +`target_token` is passed as a keyword argument to `SamplingParams.extra_args`, +will mask out all tokens except `target_token`. + +A batch is constructed with `temperature=0.0` and 50% of requests specifying +`target_token`, and for these requests - and *only* these requests - we +expect the `target_token` to be decoded in each step, yielding an output +similar to that shown below: + +Generated Outputs: +------------------------------------------------------------ +Prompt: 'Hello, my name is' +Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '" +------------------------------------------------------------ +Prompt: 'The president of the United States is' +Output: " not a racist. He is a racist.\nHe's a racist because he" +------------------------------------------------------------ +Prompt: 'The capital of France is' +Output: ' also also also also also also also also also also also also also + also also also' +------------------------------------------------------------ +Prompt: 'The future of AI is' +Output: ' in the hands of the people.\n\nThe future of AI is in the' +------------------------------------------------------------ +""" + +from typing import Optional + +import torch + +from vllm import LLM, SamplingParams +from vllm.config import VllmConfig +from vllm.v1.sample.logits_processor import ( + BatchUpdate, + LogitsProcessor, + MoveDirectionality, +) + + +# Hypothetical custom logits processor +class DummyLogitsProcessor(LogitsProcessor): + """Fake logit processor to support unit testing and examples""" + + def __init__( + self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool + ): + self.req_info: dict[int, SamplingParams] = {} + + def is_argmax_invariant(self) -> bool: + """Never impacts greedy sampling""" + return False + + def update_state(self, batch_update: Optional[BatchUpdate]): + if not batch_update: + return + + # Process added requests. + for index, params, _, _ in batch_update.added: + assert params is not None + if params.extra_args and ( + target_token := params.extra_args.get("target_token") + ): + self.req_info[index] = target_token + + if self.req_info: + # Process removed requests. + for index in batch_update.removed: + self.req_info.pop(index, None) + + # Process moved requests, unidirectional move (a->b) and swap + # (a<->b) + for adx, bdx, direct in batch_update.moved: + a_val = self.req_info.pop(adx, None) + b_val = self.req_info.pop(bdx, None) + if a_val is not None: + self.req_info[bdx] = a_val + if direct == MoveDirectionality.SWAP and b_val is not None: + self.req_info[adx] = b_val + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if not self.req_info: + return logits + + # Save target values before modification + rows_list = list(self.req_info.keys()) + cols = torch.tensor( + [self.req_info[i] for i in rows_list], + dtype=torch.long, + device=logits.device, + ) + rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device) + values_to_keep = logits[rows, cols].clone() + + # Mask all but target tokens + logits[rows] = float("-inf") + logits[rows, cols] = values_to_keep + + return logits + + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a mixture of requests which do and don't utilize the dummy logitproc +sampling_params_list = [ + SamplingParams(temperature=0.0, extra_args={"target_token": 128}), + SamplingParams(temperature=0.0), + SamplingParams(temperature=0.0, extra_args={"target_token": 67}), + SamplingParams(temperature=0.0), +] + + +def main(): + # Create an LLM. + llm = LLM( + model="facebook/opt-125m", + logits_processors=[DummyLogitsProcessor], + ) + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params_list) + # Print the outputs. + print("\nGenerated Outputs:\n" + "-" * 60) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Output: {generated_text!r}") + print("-" * 60) + + +if __name__ == "__main__": + main() diff --git a/tests/utils.py b/tests/utils.py index 18fcde949160..e98707fb4447 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,6 +13,7 @@ import time import warnings from contextlib import contextmanager, suppress +from multiprocessing import Process from pathlib import Path from typing import Any, Callable, Literal, Optional, Union @@ -76,6 +77,23 @@ def _nvml(): class RemoteOpenAIServer: DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key + def _start_server(self, model: str, vllm_serve_args: list[str], + env_dict: Optional[dict[str, str]]) -> None: + """Subclasses override this method to customize server process launch + """ + env = os.environ.copy() + # the current process might initialize cuda, + # to be safe, we should use spawn method + env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' + if env_dict is not None: + env.update(env_dict) + self.proc: subprocess.Popen = subprocess.Popen( + ["vllm", "serve", model, *vllm_serve_args], + env=env, + stdout=sys.stdout, + stderr=sys.stderr, + ) + def __init__(self, model: str, vllm_serve_args: list[str], @@ -128,18 +146,7 @@ def __init__(self, model_loader = get_model_loader(load_config) model_loader.download_model(model_config) - env = os.environ.copy() - # the current process might initialize cuda, - # to be safe, we should use spawn method - env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' - if env_dict is not None: - env.update(env_dict) - self.proc = subprocess.Popen( - ["vllm", "serve", model, *vllm_serve_args], - env=env, - stdout=sys.stdout, - stderr=sys.stderr, - ) + self._start_server(model, vllm_serve_args, env_dict) max_wait_seconds = max_wait_seconds or 240 self._wait_for_server(url=self.url_for("health"), timeout=max_wait_seconds) @@ -155,6 +162,10 @@ def __exit__(self, exc_type, exc_value, traceback): # force kill if needed self.proc.kill() + def _poll(self) -> Optional[int]: + """Subclasses override this method to customize process polling""" + return self.proc.poll() + def _wait_for_server(self, *, url: str, timeout: float): # run health check start = time.time() @@ -169,7 +180,7 @@ def _wait_for_server(self, *, url: str, timeout: float): # which means the server is not ready yet. # the stack trace is not useful, so we suppress it # by using `raise from None`. - result = self.proc.poll() + result = self._poll() if result is not None and result != 0: raise RuntimeError("Server exited unexpectedly.") from None @@ -205,6 +216,48 @@ def get_async_client(self, **kwargs): **kwargs) +class RemoteOpenAIServerCustom(RemoteOpenAIServer): + """Launch test server with custom child process""" + + def _start_server(self, model: str, vllm_serve_args: list[str], + env_dict: Optional[dict[str, str]]) -> None: + self.proc: Process = Process( + target=self.child_process_fxn, + args=(env_dict, model, + vllm_serve_args)) # type: ignore[assignment] + self.proc.start() + + def __init__(self, + model: str, + vllm_serve_args: list[str], + child_process_fxn: Callable[ + [Optional[dict[str, str]], str, list[str]], None], + *, + env_dict: Optional[dict[str, str]] = None, + seed: Optional[int] = 0, + auto_port: bool = True, + max_wait_seconds: Optional[float] = None) -> None: + """Store custom child process function then invoke superclass + constructor which will indirectly launch it.""" + self.child_process_fxn = child_process_fxn + super().__init__(model=model, + vllm_serve_args=vllm_serve_args, + env_dict=env_dict, + seed=seed, + auto_port=auto_port, + max_wait_seconds=max_wait_seconds) + + def _poll(self) -> Optional[int]: + return self.proc.exitcode + + def __exit__(self, exc_type, exc_value, traceback): + self.proc.terminate() + self.proc.join(8) + if self.proc.is_alive(): + # force kill if needed + self.proc.kill() + + def _test_completion( client: openai.OpenAI, model: str, diff --git a/tests/v1/logits_processors/__init__.py b/tests/v1/logits_processors/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/v1/sample/test_logits_processors.py b/tests/v1/logits_processors/test_correctness.py similarity index 97% rename from tests/v1/sample/test_logits_processors.py rename to tests/v1/logits_processors/test_correctness.py index 84ee3b0392b4..43caef79b02f 100644 --- a/tests/v1/sample/test_logits_processors.py +++ b/tests/v1/logits_processors/test_correctness.py @@ -9,11 +9,13 @@ import pytest import torch +from tests.utils import create_new_process_for_each_test from tests.v1.sample.utils import (LogitsprocsTestFakes, create_fake_logits, create_penalty_tensor, create_prompt_tokens_tensor, fake_apply_logitsprocs, fake_update_logitsprocs_state) +from vllm.config import VllmConfig from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available @@ -24,7 +26,7 @@ MinPLogitsProcessor, MinTokensLogitsProcessor, MoveDirectionality, - init_builtin_logitsprocs) + build_logitsprocs) # yapf: enable from vllm.v1.sample.metadata import SamplingMetadata @@ -53,6 +55,7 @@ class LogitsProcsRequestParams: workload_index: int logitproc_type: LogitprocType # Logitproc enabled, specified by str id out_tokens: list[int] # Output tokens required for min tokens test + prompt_tokens: list[int] # Dummy prompt tokens placeholder params: SamplingParams # Settings customized for logitproc def __init__(self, workload_index: int, logitproc_type: LogitprocType): @@ -63,6 +66,7 @@ def __init__(self, workload_index: int, logitproc_type: LogitprocType): # don't matter *for these tests* so use 0 as a dummy value self.out_tokens = ([0] * (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2))) + self.prompt_tokens = [] self.params = _sampling_params_from_logitproc(logitproc_type) def __str__(self): @@ -88,11 +92,12 @@ def _generate_fake_sampling_metadata( vocab_size, size=np.random.randint( 1, MAX_NUM_PROMPT_TOKENS)).tolist()) - logitsprocs = init_builtin_logitsprocs( - pin_memory_available=PIN_MEMORY_AVAILABLE, - max_num_reqs=MAX_NUM_REQS + 1, - device=device) - + logitsprocs = build_logitsprocs( + vllm_config=VllmConfig(), + device=device, + is_pin_memory=PIN_MEMORY_AVAILABLE, + is_pooling_model=False, + ) fake_sampling_metadata = SamplingMetadata( temperature=torch.full((batch_size, ), 0.0), all_greedy=True, @@ -462,7 +467,8 @@ def _generate_fake_step_update( # Replace as many removed requests as possible with added requests add_remove_idx = batch_update_builder.pop_removed() batch_update_builder.added.append( - (add_remove_idx, add_req_params.params, add_req_params.out_tokens)) + (add_remove_idx, add_req_params.params, + add_req_params.prompt_tokens, add_req_params.out_tokens)) persistent_batch[add_remove_idx] = add_req_params # Append remaining added requests to end of batch @@ -470,7 +476,8 @@ def _generate_fake_step_update( num_step_add_replace):(wdx + num_step_add)] batch_update_builder.added.extend([ - (adx + batch_size, add_req_params.params, add_req_params.out_tokens) + (adx + batch_size, add_req_params.params, add_req_params.prompt_tokens, + add_req_params.out_tokens) for adx, add_req_params in enumerate(add_reqs_append) ]) persistent_batch.extend(add_reqs_append) @@ -561,6 +568,7 @@ def _assert_valid( step_idx=step_idx) +@create_new_process_for_each_test() @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC]) @pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases()) diff --git a/tests/v1/logits_processors/test_custom_offline.py b/tests/v1/logits_processors/test_custom_offline.py new file mode 100644 index 000000000000..a7fde1990f7e --- /dev/null +++ b/tests/v1/logits_processors/test_custom_offline.py @@ -0,0 +1,237 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random +import sys +from typing import Union + +import pytest + +from tests.utils import create_new_process_for_each_test +# yapf: disable +from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG, + DUMMY_LOGITPROC_FQCN, + DUMMY_LOGITPROC_MODULE, + MAX_TOKENS, MODEL_NAME, + POOLING_MODEL_NAME, TEMP_GREEDY, + CustomLogitprocSource, + DummyLogitsProcessor, + dummy_module) +from tests.v1.logits_processors.utils import entry_points as fake_entry_points +from tests.v1.logits_processors.utils import prompts +# yapf: enable +from vllm import LLM, SamplingParams +from vllm.v1.sample.logits_processor import (STR_POOLING_REJECTS_LOGITSPROCS, + LogitsProcessor) + +# Create a mixture of requests which do and don't utilize the dummy logitproc +sampling_params_list = [ + SamplingParams(temperature=TEMP_GREEDY, + max_tokens=MAX_TOKENS, + extra_args={DUMMY_LOGITPROC_ARG: 128}), + SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS), + SamplingParams(temperature=TEMP_GREEDY, + max_tokens=MAX_TOKENS, + extra_args={DUMMY_LOGITPROC_ARG: 67}), + SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS), +] + + +def _run_test(kwargs: dict, logitproc_loaded: bool) -> None: + """Compare `LLM` instance initialized with specified `kwargs` against + reference `LLM` instance. + + Two scenarios: + 1. Server has loaded dummy logitproc; test that requests which specify + dummy logitproc arg value behave as if logitproc is operating (output + token value should repeat), while requests that don't specify dummy + logitproc arg value should match reference `LLM` output. + 2. Server has *not* loaded dummy logitproc; test that all requests + behave as if logitproc is *not* operating (output matches reference + `LLM` output.) + + Args: + kwargs: `LLM` constructor kwargs + logitproc_loaded: server has loaded dummy logitproc if True + """ + + # Create a vLLM instance and load custom logitproc + llm_logitproc = LLM( + model=MODEL_NAME, + gpu_memory_utilization=0.1, + **kwargs, + ) + + # Create a reference vLLM instance without custom logitproc + llm_ref = LLM(model=MODEL_NAME, gpu_memory_utilization=0.1) + + # Run inference with logitproc loaded + outputs_logitproc = llm_logitproc.generate(prompts, sampling_params_list) + + # Reference run + outputs_ref = llm_ref.generate(prompts, sampling_params_list) + + # Validate outputs + for bdx, (out_lp, out_ref, params) in enumerate( + zip(outputs_logitproc, outputs_ref, sampling_params_list)): + lp_toks = out_lp.outputs[0].token_ids + if logitproc_loaded and params.extra_args: + # This request exercises custom logitproc; validate that logitproc + # forces `target_token` to be decoded in each step + target_token = params.extra_args[DUMMY_LOGITPROC_ARG] + if not all(x == target_token for x in lp_toks): + raise AssertionError( + f"Request {bdx} generated {lp_toks}, shoud all be " + f"{target_token}") + else: + # This request does not exercise custom logitproc (or custom + # logitproc is not enabled on this server); validate against + # reference result + ref_toks = out_ref.outputs[0].token_ids + if lp_toks != ref_toks: + raise AssertionError( + f"Request {bdx} generated {lp_toks}, should match " + f"{ref_toks}") + + +@create_new_process_for_each_test() +@pytest.mark.parametrize("logitproc_source", list(CustomLogitprocSource)) +def test_custom_logitsprocs(monkeypatch, + logitproc_source: CustomLogitprocSource): + """Test offline Python interface for passing custom logitsprocs + + Construct an `LLM` instance which loads a custom logitproc that has a + well-defined behavior (mask out all tokens except one `target_token`) + + Construct a reference `LLM` instance with no custom logitproc + + Pass in a batch of requests, 50% of which pass a `target_token` value + in through `SamplingParams.extra_args`, 50% of which do not. + + Validate that + * Requests which do not activate the custom logitproc, yield the same + results for both `LLM` instances + * Requests which activate the custom logitproc, only output `target_token` + + Test four scenarios, corresponding to `logitproc_source` value + * No logitsprocs loaded - test that generated tokens match reference `LLM` + instance output + * Logitproc passed in via {entrypoint, class object, fully-qualified class + name (FQCN)} - test that dummy logitproc is utilized correctly when + provided via any of these three possible sources + + Args: + monkeypatch: for setting env vars + logitproc_source: what source (entrypoint, fully-qualified class name + (FQCN), class object, or None) the user pulls the + logitproc from + """ + + # Test that logitproc info is passed to workers + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1") + random.seed(40) + + # Choose LLM args based on logitproc source + if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_NONE: + # Scenario: the server does not load any custom logitproc + # Every other scenario is a different way of loading a custom logitproc + _run_test({}, logitproc_loaded=False) + return + + if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT: + # Scenario: vLLM loads a logitproc from a preconfigured entrypoint + # To that end, mock a dummy logitproc entrypoint + import importlib.metadata + importlib.metadata.entry_points = fake_entry_points # type: ignore + + # fork is required for workers to see entrypoint patch + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork") + _run_test({}, logitproc_loaded=True) + return + + kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {} + if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN: + # Scenario: load logitproc based on fully-qualified class name (FQCN) + # Inject dummy module which defines logitproc + sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module + kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN] + elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS: + # Scenario: load logitproc from provided class object + kwargs["logits_processors"] = [DummyLogitsProcessor] + + _run_test(kwargs, logitproc_loaded=True) + + +@create_new_process_for_each_test() +@pytest.mark.parametrize("logitproc_source", [ + CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT, + CustomLogitprocSource.LOGITPROC_SOURCE_FQCN, + CustomLogitprocSource.LOGITPROC_SOURCE_CLASS, +]) +def test_pooling_rejects_custom_logitsprocs( + monkeypatch, logitproc_source: CustomLogitprocSource): + """Validate that vLLM engine initialization properly rejects custom + logitsprocs when the model is a pooling model. + + Use `LLM` entrypoint. We expect `LLM` initialization to fail before the + logitproc is actually loaded. + + Scenario 1: + * Mock a logitproc entrypoint + * Validate that `LLM` does not load the logitproc + + Scenario 2: + * Pass custom logitproc to `LLM` constructor + * Scenario 2a: via FQCN + * Scenario 2b: via class object + * Validate that initialization fails with appropriate exception + + Args: + monkeypatch: used to set environment variables + logitproc_source: what source (entrypoint, fully-qualified class name + (FQCN), or class object) the user pulls the + logitproc from + """ + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + random.seed(40) + + if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT: + # Scenario: vLLM loads a pooling model and ignores a logitproc that is + # available at a preconfigured entrypoint + + # Patch in dummy logitproc entrypoint + import importlib.metadata + importlib.metadata.entry_points = fake_entry_points # type: ignore + + # fork is required for entrypoint patch to be visible to workers, + # although they should ignore the entrypoint patch anyway + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork") + + llm = LLM( + runner="pooling", + model=POOLING_MODEL_NAME, + gpu_memory_utilization=0.1, + ) + # Require that no logitsprocs have been loaded + assert sum([ + 1 for _ in llm.llm_engine.model_executor.driver_worker.worker. + model_runner.input_batch.logitsprocs.all + ]) == 0 + return + + kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {} + if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN: + # Scenario: load logitproc based on fully-qualified class name (FQCN) + kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN] + elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS: + # Scenario: load logitproc from provided class object + kwargs["logits_processors"] = [DummyLogitsProcessor] + + with pytest.raises(ValueError, match=STR_POOLING_REJECTS_LOGITSPROCS): + # Require that loading a pooling model alongside the logitproc raises + # the appropriate exception. + LLM( + runner="pooling", + model=POOLING_MODEL_NAME, + gpu_memory_utilization=0.1, + **kwargs, + ) diff --git a/tests/v1/logits_processors/test_custom_online.py b/tests/v1/logits_processors/test_custom_online.py new file mode 100644 index 000000000000..a01a479e5b24 --- /dev/null +++ b/tests/v1/logits_processors/test_custom_online.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import random +import sys +from typing import Any, Optional + +import openai +import pytest +import pytest_asyncio + +from tests.utils import (RemoteOpenAIServerCustom, + create_new_process_for_each_test) +# yapf: disable +from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG, + DUMMY_LOGITPROC_FQCN, + DUMMY_LOGITPROC_MODULE, + MAX_TOKENS, MODEL_NAME, + TEMP_GREEDY, dummy_module) +from tests.v1.logits_processors.utils import entry_points as fake_entry_points +from tests.v1.logits_processors.utils import prompts + +# yapf: enable + + +def _server_with_logitproc_entrypoint( + env_dict: Optional[dict[str, str]], + model: str, + vllm_serve_args: list[str], +) -> None: + """Start vLLM server, inject dummy logitproc entrypoint""" + + # Patch `entry_points` to inject logitproc entrypoint + import importlib.metadata + importlib.metadata.entry_points = fake_entry_points # type: ignore + from vllm.entrypoints.cli import main + + # fork is required for workers to see entrypoint patch + os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "fork" + if env_dict is not None: + os.environ.update(env_dict) + + # Emulate `vllm serve ` + sys.argv = ["vllm", "serve", model] + vllm_serve_args + main.main() + + +def _server_with_logitproc_module( + env_dict: Optional[dict[str, str]], + model: str, + vllm_serve_args: list[str], +) -> None: + """Start vLLM server, inject module with dummy logitproc""" + + # Patch `modules` to inject dummy logitproc module + from vllm.entrypoints.cli import main + sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module + + # fork is required for workers to see entrypoint patch + os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "fork" + if env_dict is not None: + os.environ.update(env_dict) + + # Emulate `vllm serve ` + sys.argv = ["vllm", "serve", model] + vllm_serve_args + main.main() + + +@pytest.fixture(scope="module") +def default_server_args(): + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + ] + + +@pytest.fixture(scope="function", + params=[[], ["--logits-processors", DUMMY_LOGITPROC_FQCN]]) +def server(default_server_args, request, monkeypatch): + """Consider two server configurations: + (1) --logits-processors cli arg specifies dummy logits processor via fully- + qualified class name (FQCN); patch in a dummy logits processor module + (2) No --logits-processors cli arg; patch in a dummy logits processor + entrypoint + """ + + # Test that logitproc info is passed to workers + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1") + + if request.param: + # Launch server, append FQCN argument, inject dummy logitproc module + args = default_server_args + request.param + _server_fxn = _server_with_logitproc_module + else: + # Launch server, inject dummy logitproc entrypoint + args = default_server_args + _server_fxn = _server_with_logitproc_entrypoint + + with RemoteOpenAIServerCustom(MODEL_NAME, args, + _server_fxn) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +# General request argument values for these tests +api_keyword_args = { + # Greedy sampling ensures that requests which receive the `target_token` + # arg will decode it in every step + "temperature": TEMP_GREEDY, + # Since EOS will never be decoded (unless `target_token` is EOS) + "max_tokens": MAX_TOKENS, + # Return decoded token logprobs (as a way of getting token id) + "logprobs": 0, +} + + +@create_new_process_for_each_test() +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str): + """Test custom logitsprocs when starting OpenAI server from CLI + + Launch vLLM OpenAI-compatible server, configured to load a custom logitproc + that has a well-defined behavior (mask out all tokens except one + `target_token`). + + Pass in requests, 50% of which pass a `target_token` value + in through `extra_body["vllm_xargs"]`, 50% of which do not. + + Validate that requests which activate the custom logitproc, repeat the same + token + """ + + use_dummy_logitproc = True + for prompt in prompts: + # Build request arguments + request_keyword_args: dict[str, Any] = { + **api_keyword_args, + } + if use_dummy_logitproc: + # 50% of requests pass target_token custom arg + target_token = random.choice([128, 67]) + # For requests which activate the dummy logitproc, choose one of + # two `target_token` values which are known not to be EOS tokens + request_keyword_args["extra_body"] = { + "vllm_xargs": { + DUMMY_LOGITPROC_ARG: target_token + } + } + batch = await client.completions.create( + model=model_name, + prompt=prompt, + **request_keyword_args, + ) + + if use_dummy_logitproc: + # Only for requests which activate dummy logitproc - validate that + # output token is repeated + choices: openai.types.CompletionChoice = batch.choices + toks = choices[0].logprobs.tokens + if not all([x == toks[0] for x in toks]): + raise AssertionError( + f"Generated {toks} should all be {toks[0]}") + + # Alternate whether to activate dummy logitproc for each request + use_dummy_logitproc = not use_dummy_logitproc diff --git a/tests/v1/logits_processors/utils.py b/tests/v1/logits_processors/utils.py new file mode 100644 index 000000000000..c0bfc1a18fec --- /dev/null +++ b/tests/v1/logits_processors/utils.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import types +from enum import Enum, auto +from typing import Optional + +import torch + +from vllm.config import VllmConfig +from vllm.sampling_params import SamplingParams +from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, BatchUpdate, + LogitsProcessor, + MoveDirectionality) + +MODEL_NAME = "facebook/opt-125m" +POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5" +DUMMY_LOGITPROC_ARG = "target_token" +TEMP_GREEDY = 0.0 +MAX_TOKENS = 20 +DUMMY_LOGITPROC_ENTRYPOINT = "dummy_logitproc" +DUMMY_LOGITPROC_MODULE = "DummyModule" +DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor" + + +class CustomLogitprocSource(Enum): + """How to source a logitproc for testing purposes""" + LOGITPROC_SOURCE_NONE = auto() # No custom logitproc + LOGITPROC_SOURCE_ENTRYPOINT = auto() # Via entrypoint + LOGITPROC_SOURCE_FQCN = auto() # Via fully-qualified class name (FQCN) + LOGITPROC_SOURCE_CLASS = auto() # Via provided class object + + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + + +class DummyLogitsProcessor(LogitsProcessor): + """Fake logit processor to support unit testing and examples""" + + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool): + self.req_info: dict[int, SamplingParams] = {} + + def is_argmax_invariant(self) -> bool: + """Never impacts greedy sampling""" + return False + + def update_state(self, batch_update: Optional[BatchUpdate]): + if not batch_update: + return + + # Process added requests. + for index, params, _, _ in batch_update.added: + assert params is not None + if params.extra_args and (target_token := + params.extra_args.get("target_token")): + self.req_info[index] = target_token + + if self.req_info: + # Process removed requests. + for index in batch_update.removed: + self.req_info.pop(index, None) + + # Process moved requests, unidirectional move (a->b) and swap + # (a<->b) + for adx, bdx, direct in batch_update.moved: + a_val = self.req_info.pop(adx, None) + b_val = self.req_info.pop(bdx, None) + if a_val is not None: + self.req_info[bdx] = a_val + if direct == MoveDirectionality.SWAP and b_val is not None: + self.req_info[adx] = b_val + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if not self.req_info: + return logits + + # Save target values before modification + rows_list = list(self.req_info.keys()) + cols = torch.tensor([self.req_info[i] for i in rows_list], + dtype=torch.long, + device=logits.device) + rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device) + values_to_keep = logits[rows, cols].clone() + + # Mask all but target tokens + logits[rows] = float('-inf') + logits[rows, cols] = values_to_keep + + return logits + + +"""Dummy module with dummy logitproc class""" +dummy_module = types.ModuleType(DUMMY_LOGITPROC_MODULE) +dummy_module.DummyLogitsProcessor = DummyLogitsProcessor # type: ignore + + +class EntryPoint: + """Dummy entrypoint class for logitsprocs testing""" + + def __init__(self): + self.name = DUMMY_LOGITPROC_ENTRYPOINT + self.value = DUMMY_LOGITPROC_FQCN + + def load(self): + return DummyLogitsProcessor + + +class EntryPoints(list): + """Dummy EntryPoints class for logitsprocs testing""" + + def __init__(self, group: str): + # Emulate list-like functionality + eps = [EntryPoint()] if group == LOGITSPROCS_GROUP else [] + super().__init__(eps) + # Extra attributes + self.names = [ep.name for ep in eps] + + +"""Fake version of importlib.metadata.entry_points""" +entry_points = lambda group: EntryPoints(group) diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 3a4d48afc9d7..4e912f98f376 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from vllm.platforms import current_platform -from vllm.v1.sample.logits_processor import LogitsProcessorManager +from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID, RejectionSampler) @@ -69,7 +69,7 @@ def create_sampling_metadata( output_token_ids=[], allowed_token_ids_mask=None, bad_words_token_ids={}, - logitsprocs=LogitsProcessorManager(), + logitsprocs=LogitsProcessors(), ) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index 31c6c881d7b8..53215f88bb27 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -9,7 +9,7 @@ from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available, make_tensor_with_pad -from vllm.v1.sample.logits_processor import LogitsProcessorManager +from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler @@ -173,7 +173,7 @@ def _create_default_sampling_metadata( no_penalties=True, allowed_token_ids_mask=None, bad_words_token_ids={}, - logitsprocs=LogitsProcessorManager(), + logitsprocs=LogitsProcessors(), ) return fake_sampling_metadata diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 74ab19a3ce32..d7b4746562be 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -13,7 +13,7 @@ from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.v1.pool.metadata import PoolingMetadata -from vllm.v1.sample.logits_processor import LogitsProcessorManager +from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -169,7 +169,7 @@ def _construct_expected_sampling_metadata( and all(x == 1 for x in repetition_penalties)), allowed_token_ids_mask=allowed_token_ids_mask, bad_words_token_ids=bad_words_token_ids, - logitsprocs=LogitsProcessorManager(), + logitsprocs=LogitsProcessors(), ) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 280ae60c91ff..3145ecb96a9c 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -62,6 +62,7 @@ QuantizationConfig) from vllm.model_executor.model_loader import LoadFormats from vllm.model_executor.model_loader.tensorizer import TensorizerConfig + from vllm.v1.sample.logits_processor import LogitsProcessor HfOverrides = Union[dict, Callable[[type], type]] else: @@ -72,6 +73,7 @@ BaseModelLoader = Any LoadFormats = Any TensorizerConfig = Any + LogitsProcessor = Any HfOverrides = Union[dict[str, Any], Callable[[type], type]] me_quant = LazyLoader("model_executor", globals(), @@ -465,6 +467,9 @@ class ModelConfig: - "transformers" will use the Transformers model implementation.""" override_attention_dtype: Optional[str] = None """Override dtype for attention""" + logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None + """One or more logits processors' fully-qualified class names or class + definitions""" def compute_hash(self) -> str: """ diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f8af6d36e0c0..0ae63e7726ef 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -43,6 +43,7 @@ from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser, GiB_bytes, get_ip, is_in_ray_actor) +from vllm.v1.sample.logits_processor import LogitsProcessor # yapf: enable @@ -435,6 +436,10 @@ class EngineArgs: enable_multimodal_encoder_data_parallel: bool = \ ParallelConfig.enable_multimodal_encoder_data_parallel + logits_processors: Optional[list[Union[ + str, type[LogitsProcessor]]]] = ModelConfig.logits_processors + """Custom logitproc types""" + async_scheduling: bool = SchedulerConfig.async_scheduling # DEPRECATED enable_prompt_adapter: bool = False @@ -549,6 +554,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **model_kwargs["model_impl"]) model_group.add_argument("--override-attention-dtype", **model_kwargs["override_attention_dtype"]) + model_group.add_argument("--logits-processors", + **model_kwargs["logits_processors"]) # Model loading arguments load_kwargs = get_kwargs(LoadConfig) @@ -940,6 +947,7 @@ def create_model_config(self) -> ModelConfig: enable_sleep_mode=self.enable_sleep_mode, model_impl=self.model_impl, override_attention_dtype=self.override_attention_dtype, + logits_processors=self.logits_processors, ) def validate_tensorizer_args(self): diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 915f14a29b90..b002f234c043 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -55,6 +55,7 @@ get_cached_tokenizer) from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of +from vllm.v1.sample.logits_processor import LogitsProcessor if TYPE_CHECKING: from vllm.v1.metrics.reader import Metric @@ -198,6 +199,8 @@ def __init__( override_pooler_config: Optional[PoolerConfig] = None, compilation_config: Optional[Union[int, dict[str, Any], CompilationConfig]] = None, + logits_processors: Optional[list[Union[str, + type[LogitsProcessor]]]] = None, **kwargs, ) -> None: """LLM constructor.""" @@ -272,6 +275,7 @@ def __init__( mm_processor_kwargs=mm_processor_kwargs, override_pooler_config=override_pooler_config, compilation_config=compilation_config_instance, + logits_processors=logits_processors, **kwargs, ) diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index a1f8ad164762..a82e17995fe2 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -2553,7 +2553,7 @@ def direct_register_custom_op( def resolve_obj_by_qualname(qualname: str) -> Any: """ - Resolve an object by its fully qualified name. + Resolve an object by its fully-qualified class name. """ module_name, obj_name = qualname.rsplit(".", 1) module = importlib.import_module(module_name) diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py new file mode 100644 index 000000000000..822026916295 --- /dev/null +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib +import itertools +from collections.abc import Sequence +from typing import TYPE_CHECKING, Optional, Union + +import torch + +from vllm.logger import init_logger +from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor, + MinPLogitsProcessor, + MinTokensLogitsProcessor) +from vllm.v1.sample.logits_processor.interface import (BatchUpdate, + LogitsProcessor, + MoveDirectionality) +from vllm.v1.sample.logits_processor.state import (BatchUpdateBuilder, + LogitsProcessors) + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + +# Error message when the user tries to initialize vLLM with a pooling model +# and custom logitsproces +STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom" + " logits processors.") + +LOGITSPROCS_GROUP = 'vllm.logits_processors' + +BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [ + MinTokensLogitsProcessor, + LogitBiasLogitsProcessor, + MinPLogitsProcessor, +] + + +def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]: + """Load all installed logit processor plugins""" + + import sys + + if sys.version_info < (3, 10): + from importlib_metadata import entry_points + else: + from importlib.metadata import entry_points + + installed_logitsprocs_plugins = entry_points(group=LOGITSPROCS_GROUP) + if len(installed_logitsprocs_plugins) == 0: + logger.debug("No logitsprocs plugins installed (group %s).", + LOGITSPROCS_GROUP) + return [] + + # Load logitsprocs plugins + logger.debug("Loading installed logitsprocs plugins (group %s):", + LOGITSPROCS_GROUP) + classes: list[type[LogitsProcessor]] = [] + for entrypoint in installed_logitsprocs_plugins: + try: + logger.debug("- Loading logitproc plugin entrypoint=%s target=%s", + entrypoint.name, entrypoint.value) + classes.append(entrypoint.load()) + except Exception as e: + raise RuntimeError( + f"Failed to load LogitsProcessor plugin {entrypoint}") from e + return classes + + +def _load_logitsprocs_by_fqcns( + logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]] +) -> list[type[LogitsProcessor]]: + """Load logit processor types, identifying them by fully-qualified class + names (FQCNs). + + Effectively, a mixed list of logitproc types and FQCN strings is converted + into a list of entirely logitproc types, by loading from the FQCNs. + + FQCN syntax is : i.e. x.y.z:CustomLogitProc + + Already-loaded logitproc types must be subclasses of LogitsProcessor + + Args: + logits_processors: Potentially mixed list of logitsprocs types and FQCN + strings for logitproc types + + Returns: + List of logitproc types + + """ + if not logits_processors: + return [] + + logger.debug( + "%s additional custom logits processors specified, checking whether " + "they need to be loaded.", len(logits_processors)) + + classes: list[type[LogitsProcessor]] = [] + for ldx, logitproc in enumerate(logits_processors): + if isinstance(logitproc, type): + logger.debug(" - Already-loaded logit processor: %s", + logitproc.__name__) + if not issubclass(logitproc, LogitsProcessor): + raise ValueError( + f"{logitproc.__name__} is not a subclass of LogitsProcessor" + ) + classes.append(logitproc) + continue + + logger.debug("- Loading logits processor %s", logitproc) + module_path, qualname = logitproc.split(":") + + try: + # Load module + module = importlib.import_module(module_path) + except Exception as e: + raise RuntimeError( + f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}" + ) from e + + # Walk down dotted name to get logitproc class + obj = module + for attr in qualname.split("."): + obj = getattr(obj, attr) + if not isinstance(obj, type): + raise ValueError("Loaded logit processor must be a type.") + if not issubclass(obj, LogitsProcessor): + raise ValueError( + f"{obj.__name__} must be a subclass of LogitsProcessor") + classes.append(obj) + + return classes + + +def _load_custom_logitsprocs( + logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]], +) -> list[type[LogitsProcessor]]: + """Load all custom logits processors. + + * First load all installed logitproc plugins + * Second load custom logitsprocs pass by the user at initialization time + + Args: + logits_processors: potentially mixed list of logitproc types and + logitproc type fully-qualified names (FQCNs) + which need to be loaded + + Returns: + A list of all loaded logitproc types + """ + from vllm.platforms import current_platform + if current_platform.is_tpu(): + # No logitsprocs specified by caller + # TODO(andy) - vLLM V1 on TPU does not support custom logitsprocs + return [] + + return (_load_logitsprocs_plugins() + + _load_logitsprocs_by_fqcns(logits_processors)) + + +def build_logitsprocs( + vllm_config: "VllmConfig", + device: torch.device, + is_pin_memory: bool, + is_pooling_model: bool, + custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (), +) -> LogitsProcessors: + if is_pooling_model: + if custom_logitsprocs: + raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS) + logger.debug("Skipping logits processor loading because pooling models" + " do not support logits processors.") + return LogitsProcessors() + custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs) + return LogitsProcessors( + ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain( + BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes)) + + +__all__ = [ + "LogitsProcessor", "LogitBiasLogitsProcessor", "MinPLogitsProcessor", + "MinTokensLogitsProcessor", "BatchUpdate", "BatchUpdateBuilder", + "MoveDirectionality", "LogitsProcessors", "build_logitsprocs", + "STR_POOLING_REJECTS_LOGITSPROCS", "LOGITSPROCS_GROUP" +] diff --git a/vllm/v1/sample/logits_processor.py b/vllm/v1/sample/logits_processor/builtin.py similarity index 54% rename from vllm/v1/sample/logits_processor.py rename to vllm/v1/sample/logits_processor/builtin.py index 3a06e71057cd..24387ab79390 100644 --- a/vllm/v1/sample/logits_processor.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -1,241 +1,32 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import dataclasses -from abc import ABC, abstractmethod -from collections.abc import Iterator, Sequence -from dataclasses import dataclass, field -from enum import Enum -from itertools import chain -from typing import Optional, Union +from collections.abc import Sequence +from typing import TYPE_CHECKING, Optional import torch -from torch._prims_common import DeviceLikeType - -from vllm import PoolingParams, SamplingParams -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -class MoveDirectionality(Enum): - # One-way i1->i2 req move within batch - UNIDIRECTIONAL = 0 - # Two-way i1<->i2 req swap within batch - SWAP = 1 - - -# (index, params, output_tok_ids) tuples for new -# requests added to the batch. -AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]] -# (index 1, index 2, directionality) tuples representing -# one-way moves or two-way swaps of requests in batch -MovedRequest = tuple[int, int, MoveDirectionality] -# Batch indices of any removed requests. -RemovedRequest = int - - -@dataclasses.dataclass(frozen=True) -class BatchUpdate: - """Persistent batch state change info for logitsprocs""" - batch_size: int # Current num reqs in batch - - # Metadata for requests added to, removed from, and moved - # within the persistent batch. - # - # Note: each added request is represented as - # (index, params, output_tok_ids) - # Key assumption: output_tok_ids is a reference to the - # request's running output tokens list; in this way - # the logits processors always see the latest list of - # generated tokens - removed: Sequence[RemovedRequest] - moved: Sequence[MovedRequest] - added: Sequence[AddedRequest] - - -class BatchUpdateBuilder: - """Helps track persistent batch state changes and build - a batch update data structure for logitsprocs - - Assumptions: - * All information about requests removed from persistent batch - during a step is aggregated in self._removed through calls to - self.removed_append() at the beginning of a step. This must happen - before the first time that self.removed, self.pop_removed() - or self.peek_removed() are invoked in a given step - * After the first time that self.removed, self.pop_removed() - or self.peek_removed() are read in a step, no new removals - are registered using self.removed_append() - * Elements of self._removed are never directly modified, added or - removed (i.e. modification is only via self.removed_append() and - self.pop_removed()) - - Guarantees under above assumptions: - * self.removed is always sorted in descending order - * self.pop_removed() and self.peek_removed() both return - the lowest removed request index in the current step - """ - - _removed: list[RemovedRequest] - _is_removed_sorted: bool - moved: list[MovedRequest] - added: list[AddedRequest] - - def __init__( - self, - removed: Optional[list[RemovedRequest]] = None, - moved: Optional[list[MovedRequest]] = None, - added: Optional[list[AddedRequest]] = None, - ) -> None: - self._removed = removed or [] - self.moved = moved or [] - self.added = added or [] - self._is_removed_sorted = False - - def _ensure_removed_sorted(self) -> None: - """Sort removed request indices in - descending order. - - Idempotent after first call in a - given step, until reset. - """ - if not self._is_removed_sorted: - self._removed.sort(reverse=True) - self._is_removed_sorted = True - - @property - def removed(self) -> list[RemovedRequest]: - """Removed request indices sorted in - descending order""" - self._ensure_removed_sorted() - return self._removed - - def removed_append(self, index: int) -> None: - """Register the removal of a request from - the persistent batch. - - Must not be called after the first time - self.removed, self.pop_removed() or - self.peek_removed() are invoked. - - Args: - index: request index - """ - if self._is_removed_sorted: - raise RuntimeError("Cannot register new removed request after" - " self.removed has been read.") - self._removed.append(index) - - def has_removed(self) -> bool: - return bool(self._removed) - - def peek_removed(self) -> Optional[int]: - """Return lowest removed request index""" - if self.has_removed(): - self._ensure_removed_sorted() - return self._removed[-1] - return None - - def pop_removed(self) -> Optional[int]: - """Pop lowest removed request index""" - if self.has_removed(): - self._ensure_removed_sorted() - return self._removed.pop() - return None - - def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]: - """Generate a logitsprocs batch update data structure - and reset internal batch update builder state. - - Args: - batch_size: current persistent batch size - - Returns: - Frozen logitsprocs batch update instance; `None` if no updates - """ - # Reset removal-sorting logic - self._is_removed_sorted = False - if not any((self._removed, self.moved, self.added)): - # No update; short-circuit - return None - # Build batch state update - batch_update = BatchUpdate( - batch_size=batch_size, - removed=self._removed, - moved=self.moved, - added=self.added, - ) - # Reset removed/moved/added update lists - self._removed = [] - self.moved = [] - self.added = [] - return batch_update - - -class LogitsProcessor(ABC): - - @abstractmethod - def apply(self, logits: torch.Tensor) -> torch.Tensor: - raise NotImplementedError - @abstractmethod - def is_argmax_invariant(self) -> bool: - """True if logits processor has no impact on the - argmax computation in greedy sampling. - NOTE: may or may not have the same value for all - instances of a given LogitsProcessor subclass, - depending on subclass implementation. - TODO(andy): won't be utilized until logits - processors are user-extensible - """ - raise NotImplementedError - - @abstractmethod - def update_state( - self, - batch_update: Optional[BatchUpdate], - ) -> None: - """Called when there are new output tokens, prior - to each forward pass. - - Args: - batch_update is non-None iff there have been - changes to the batch makeup. - """ - raise NotImplementedError - - -@dataclass -class LogitsProcessorManager: - """Encapsulates initialized logitsproc objects.""" - argmax_invariant: list[LogitsProcessor] = field( - default_factory=list) # argmax-invariant logitsprocs - non_argmax_invariant: list[LogitsProcessor] = field( - default_factory=list) # non-argmax-invariant logitsprocs - - @property - def all(self) -> Iterator[LogitsProcessor]: - """Iterator over all logits processors.""" - return chain(self.argmax_invariant, self.non_argmax_invariant) - - -###### ----- Built-in LogitsProcessor impls below here +from vllm.v1.sample.logits_processor.interface import (BatchUpdate, + LogitsProcessor, + MoveDirectionality) + +if TYPE_CHECKING: + from vllm.config import VllmConfig class MinPLogitsProcessor(LogitsProcessor): - def __init__(self, max_num_reqs: int, pin_memory: bool, - device: DeviceLikeType): - super().__init__() + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool): + max_num_reqs = vllm_config.scheduler_config.max_num_seqs self.min_p_count: int = 0 self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ), dtype=torch.float32, device="cpu", - pin_memory=pin_memory) + pin_memory=is_pin_memory) self.min_p_cpu = self.min_p_cpu_tensor.numpy() - self.use_double_tensor = torch.device("cpu") != torch.device(device) + self.use_double_tensor = torch.device(device).type != "cpu" if self.use_double_tensor: # Pre-allocated device tensor @@ -260,8 +51,8 @@ def update_state(self, batch_update: Optional[BatchUpdate]): needs_update = False # Process added requests. - for index, params, _ in batch_update.added: - min_p = params.min_p if isinstance(params, SamplingParams) else 0.0 + for index, params, _, _ in batch_update.added: + min_p = params.min_p if self.min_p_cpu[index] != min_p: needs_update = True self.min_p_cpu[index] = min_p @@ -316,11 +107,10 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: class LogitBiasLogitsProcessor(LogitsProcessor): - def __init__(self, pin_memory: bool, device: torch.device): - super().__init__() - self.biases: dict[int, dict[int, float]] = {} + def __init__(self, _, device: torch.device, is_pin_memory: bool): self.device = device - self.pin_memory = pin_memory + self.pin_memory = is_pin_memory + self.biases: dict[int, dict[int, float]] = {} self.bias_tensor: torch.Tensor = torch.tensor(()) self.logits_slice = (self._device_tensor([], torch.int32), @@ -337,9 +127,8 @@ def update_state(self, batch_update: Optional[BatchUpdate]): needs_update: bool = False # Process added requests. - for index, params, _ in batch_update.added: - if isinstance(params, SamplingParams) and (lb := - params.logit_bias): + for index, params, _, _ in batch_update.added: + if lb := params.logit_bias: self.biases[index] = lb needs_update = True else: @@ -400,12 +189,12 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: class MinTokensLogitsProcessor(LogitsProcessor): - def __init__(self, pin_memory: bool, device: torch.device): + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool): # index -> (min_toks, output_token_ids, stop_token_ids) - super().__init__() - self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {} self.device = device - self.pin_memory = pin_memory + self.pin_memory = is_pin_memory + self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {} # (req_idx_tensor,eos_tok_id_tensor) self.logits_slice: tuple[torch.Tensor, @@ -424,9 +213,8 @@ def update_state(self, batch_update: Optional[BatchUpdate]): if batch_update: # Process added requests. - for index, params, output_tok_ids in batch_update.added: - if (isinstance(params, SamplingParams) - and (min_tokens := params.min_tokens) + for index, params, _, output_tok_ids in batch_update.added: + if ((min_tokens := params.min_tokens) and len(output_tok_ids) < min_tokens): # Replace request metadata at batch index self.min_toks[index] = (min_tokens, output_tok_ids, @@ -499,35 +287,3 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: # Inhibit EOS token for requests which have not reached min length logits[self.logits_slice] = -float("inf") return logits - - -def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int, - device: torch.device) -> LogitsProcessorManager: - """Construct 'builtin' vLLM logitsprocs which the engine - loads by default. - - Args: - pin_memory_available: pinned memory is available for use - for use by logitsproc - max_num_reqs: ceiling on request count in persistent batch - device: inference device - - Returns: - Data structure encapsulating loaded logitsprocs - """ - min_tokens_logitproc = MinTokensLogitsProcessor( - pin_memory=pin_memory_available, device=device) - logit_bias_logitproc = LogitBiasLogitsProcessor( - pin_memory=pin_memory_available, device=device) - min_p_logitproc = MinPLogitsProcessor( - pin_memory=pin_memory_available, - device=device, - # +1 for temporary swap space - max_num_reqs=max_num_reqs + 1) - return LogitsProcessorManager( - non_argmax_invariant=[ - min_tokens_logitproc, - logit_bias_logitproc, - ], - argmax_invariant=[min_p_logitproc], - ) diff --git a/vllm/v1/sample/logits_processor/interface.py b/vllm/v1/sample/logits_processor/interface.py new file mode 100644 index 000000000000..12b4db24bff8 --- /dev/null +++ b/vllm/v1/sample/logits_processor/interface.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from collections.abc import Sequence +from dataclasses import dataclass +from enum import Enum, auto +from typing import TYPE_CHECKING, Optional + +import torch + +from vllm import SamplingParams + +if TYPE_CHECKING: + from vllm.config import VllmConfig + + +class MoveDirectionality(Enum): + # One-way i1->i2 req move within batch + UNIDIRECTIONAL = auto() + # Two-way i1<->i2 req swap within batch + SWAP = auto() + + +# (index, params, prompt_tok_ids, output_tok_ids) tuples for new +# requests added to the batch. +AddedRequest = tuple[int, SamplingParams, list[int], list[int]] + +# (index 1, index 2, directionality) tuples representing +# one-way moves or two-way swaps of requests in batch +MovedRequest = tuple[int, int, MoveDirectionality] + +# Batch indices of any removed requests. +RemovedRequest = int + + +@dataclass(frozen=True) +class BatchUpdate: + """Persistent batch state change info for logitsprocs""" + batch_size: int # Current num reqs in batch + + # Metadata for requests added to, removed from, and moved + # within the persistent batch. + # + # Key assumption: the `output_tok_ids` list (which is an element of each + # tuple in `added`) is a reference to the request's running output tokens + # list; via this reference, the logits processors always see the latest + # list of generated output tokens + removed: Sequence[RemovedRequest] + moved: Sequence[MovedRequest] + added: Sequence[AddedRequest] + + +class LogitsProcessor(ABC): + + @abstractmethod + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool) -> None: + raise NotImplementedError + + @abstractmethod + def apply(self, logits: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def is_argmax_invariant(self) -> bool: + """True if logits processor has no impact on the + argmax computation in greedy sampling. + NOTE: may or may not have the same value for all + instances of a given LogitsProcessor subclass, + depending on subclass implementation. + """ + raise NotImplementedError + + @abstractmethod + def update_state( + self, + batch_update: Optional["BatchUpdate"], + ) -> None: + """Called when there are new output tokens, prior + to each forward pass. + + Args: + batch_update is non-None iff there have been + changes to the batch makeup. + """ + raise NotImplementedError diff --git a/vllm/v1/sample/logits_processor/state.py b/vllm/v1/sample/logits_processor/state.py new file mode 100644 index 000000000000..0f58b5249695 --- /dev/null +++ b/vllm/v1/sample/logits_processor/state.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterator +from itertools import chain +from typing import TYPE_CHECKING, Optional + +from vllm.v1.sample.logits_processor.interface import (AddedRequest, + BatchUpdate, + MovedRequest, + RemovedRequest) + +if TYPE_CHECKING: + from vllm.v1.sample.logits_processor.interface import LogitsProcessor + + +class BatchUpdateBuilder: + """Helps track persistent batch state changes and build + a batch update data structure for logitsprocs + Assumptions: + * All information about requests removed from persistent batch + during a step is aggregated in self._removed through calls to + self.removed_append() at the beginning of a step. This must happen + before the first time that self.removed, self.pop_removed() + or self.peek_removed() are invoked in a given step + * After the first time that self.removed, self.pop_removed() + or self.peek_removed() are read in a step, no new removals + are registered using self.removed_append() + * Elements of self._removed are never directly modified, added or + removed (i.e. modification is only via self.removed_append() and + self.pop_removed()) + Guarantees under above assumptions: + * self.removed is always sorted in descending order + * self.pop_removed() and self.peek_removed() both return + the lowest removed request index in the current step + """ + + _removed: list[RemovedRequest] + _is_removed_sorted: bool + moved: list[MovedRequest] + added: list[AddedRequest] + + def __init__( + self, + removed: Optional[list[RemovedRequest]] = None, + moved: Optional[list[MovedRequest]] = None, + added: Optional[list[AddedRequest]] = None, + ) -> None: + self._removed = removed or [] + self.moved = moved or [] + self.added = added or [] + self._is_removed_sorted = False + + def _ensure_removed_sorted(self) -> None: + """Sort removed request indices in + descending order. + Idempotent after first call in a + given step, until reset. + """ + if not self._is_removed_sorted: + self._removed.sort(reverse=True) + self._is_removed_sorted = True + + @property + def removed(self) -> list[RemovedRequest]: + """Removed request indices sorted in + descending order""" + self._ensure_removed_sorted() + return self._removed + + def removed_append(self, index: int) -> None: + """Register the removal of a request from the persistent batch. + + Must not be called after the first time self.removed, + self.pop_removed() or self.peek_removed() are invoked. + + Args: + index: request index + """ + if self._is_removed_sorted: + raise RuntimeError("Cannot register new removed request after" + " self.removed has been read.") + self._removed.append(index) + + def has_removed(self) -> bool: + return bool(self._removed) + + def peek_removed(self) -> Optional[int]: + """Return lowest removed request index""" + if self.has_removed(): + self._ensure_removed_sorted() + return self._removed[-1] + return None + + def pop_removed(self) -> Optional[int]: + """Pop lowest removed request index""" + if self.has_removed(): + self._ensure_removed_sorted() + return self._removed.pop() + return None + + def _is_update(self) -> bool: + """True if there is a batch state change""" + return any((self._removed, self.moved, self.added)) + + def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]: + """Generate a logitsprocs batch update data structure and reset + internal batch update builder state. + + Args: + batch_size: current persistent batch size + + Returns: + Frozen logitsprocs batch update instance; `None` if no updates + """ + # Reset removal-sorting logic + self._is_removed_sorted = False + if not self._is_update(): + # No update; short-circuit + return None + # Build batch state update + batch_update = BatchUpdate( + batch_size=batch_size, + removed=self._removed, + moved=self.moved, + added=self.added, + ) + self._removed = [] + self.moved = [] + self.added = [] + return batch_update + + +class LogitsProcessors: + """Encapsulates initialized logitsproc objects.""" + + def __init__( + self, + logitsprocs: Optional[Iterator["LogitsProcessor"]] = None) -> None: + self.argmax_invariant: list[LogitsProcessor] = [] + self.non_argmax_invariant: list[LogitsProcessor] = [] + if logitsprocs: + for logitproc in logitsprocs: + (self.argmax_invariant if logitproc.is_argmax_invariant() else + self.non_argmax_invariant).append(logitproc) + + @property + def all(self) -> Iterator["LogitsProcessor"]: + """Iterator over all logits processors.""" + return chain(self.argmax_invariant, self.non_argmax_invariant) diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 1189b12f3077..9d6a87cea3d0 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -6,7 +6,7 @@ import torch -from vllm.v1.sample.logits_processor import LogitsProcessorManager +from vllm.v1.sample.logits_processor import LogitsProcessors @dataclass @@ -40,4 +40,4 @@ class SamplingMetadata: bad_words_token_ids: dict[int, list[list[int]]] # Loaded logits processors - logitsprocs: LogitsProcessorManager + logitsprocs: LogitsProcessors diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 2469e09f8249..e718d9d5e0fb 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -18,8 +18,8 @@ from vllm.v1.outputs import LogprobsTensors from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, - MoveDirectionality, - init_builtin_logitsprocs) + LogitsProcessors, + MoveDirectionality) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.utils import is_spec_decode_unsupported from vllm.v1.utils import copy_slice @@ -78,8 +78,11 @@ def __init__( pin_memory: bool, vocab_size: int, block_sizes: list[int], # The block_size of each kv cache group + logitsprocs: Optional[LogitsProcessors] = None, is_spec_decode: bool = False, + is_pooling_model: bool = False, ): + self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len @@ -221,14 +224,6 @@ def __init__( # updates. Should reset each step. self.batch_update_builder = BatchUpdateBuilder() - # Define logits processors. - # TODO(andy): logits processor list should be extensible via engine - # constructor argument; for now the list is fixed. - self.logitsprocs = init_builtin_logitsprocs( - pin_memory_available=pin_memory, - max_num_reqs=max_num_reqs + 1, - device=device) - # TODO convert this to LogitsProcessor self.has_allowed_token_ids: set[str] = set() # NOTE(lufang): In the mask tensor, if the corresponding token allowed, @@ -244,6 +239,10 @@ def __init__( self.req_output_token_ids: list[Optional[list[int]]] = [] + # Store provided logitsprocs. If none are provided, initialize empty + # data structure + self.logitsprocs = logitsprocs or LogitsProcessors() + # This is updated each time the batch constituents change. self.sampling_metadata = self._make_sampling_metadata() @@ -255,28 +254,35 @@ def req_ids(self) -> list[str]: # while performing state updates to the batch. return cast(list[str], self._req_ids) - def _get_next_add_index(self) -> int: - if (req_index := self.batch_update_builder.pop_removed()) is not None: - # Fill the empty index. - return req_index - # Append to end - return self.num_reqs - def _register_add_request(self, request: "CachedRequestState") -> int: - """Track add-request operations""" - req_index = self._get_next_add_index() - assert req_index < self.max_num_reqs - params = (request.sampling_params - if request.sampling_params else request.pooling_params) + """Track add-request operations for logits processors. + Not applicable to pooling models. + """ + + # Detailed added request metadata is only required for non-pooling + # models, to support logitsprocs + assert request.sampling_params + + # Fill the next empty index if there is one. + if (new_req_index := self.batch_update_builder.pop_removed()) is None: + # Append to end otherwise. + new_req_index = self.num_reqs + + assert new_req_index < self.max_num_reqs self.batch_update_builder.added.append( - (req_index, params, request.output_token_ids)) - return req_index + (new_req_index, request.sampling_params, request.prompt_token_ids, + request.output_token_ids)) + return new_req_index def add_request( self, request: "CachedRequestState", ) -> int: - req_index = self._register_add_request(request) + if not self.is_pooling_model: + # New request index bookkeeping for autoregressive models. + req_index = self._register_add_request(request) + else: + req_index = self.num_reqs req_id = request.req_id if req_index == len(self._req_ids): @@ -411,7 +417,10 @@ def remove_request(self, req_id: str) -> Optional[int]: req_index = self.req_id_to_index.pop(req_id, None) if req_index is None: return None - self.batch_update_builder.removed_append(req_index) + if not self.is_pooling_model: + # Autoregressive models require bookkeeping of removed requests to + # support logitsprocs. + self.batch_update_builder.removed_append(req_index) self._req_ids[req_index] = None self.req_output_token_ids[req_index] = None @@ -446,6 +455,8 @@ def remove_request(self, req_id: str) -> Optional[int]: return req_index def swap_states(self, i1: int, i2: int) -> None: + # For autoregressive models, track detailed request reordering info + # to support logitsprocs self.batch_update_builder.moved.append( (i1, i2, MoveDirectionality.SWAP)) old_id_i1 = self._req_ids[i1] @@ -513,11 +524,18 @@ def condense(self) -> None: swaps: list of (from,to) swap tuples for moved requests empty_req_indices: indices not filled by condensation """ + num_reqs = self.num_reqs + + if self.is_pooling_model: + # Will be contiguous in pooling case, just trim the lists. + del self._req_ids[num_reqs:] + del self.req_output_token_ids[num_reqs:] + return + if not (empty_req_indices := self.batch_update_builder.removed): # All removed requests were replaced by added requests, or else no # requests were removed at all. No condense() needed return - num_reqs = self.num_reqs if num_reqs == 0: # The batched states are empty. self._req_ids.clear() @@ -541,6 +559,8 @@ def condense(self) -> None: # Move active request down into empty request # index. self.batch_update_builder.pop_removed() + # Autoregressive models require detailed tracking of condense + # operations to support logitsprocs self.batch_update_builder.moved.append( (last_req_index, empty_index, MoveDirectionality.UNIDIRECTIONAL)) @@ -596,15 +616,20 @@ def condense(self) -> None: last_req_index -= 1 # Trim lists to the batch size. - del self._req_ids[self.num_reqs:] - del self.req_output_token_ids[self.num_reqs:] + del self._req_ids[num_reqs:] + del self.req_output_token_ids[num_reqs:] def refresh_metadata(self): - """Apply batch updates, reset input batch at end of step + """Apply any batch updates to sampling metadata.""" - * Apply batch add/remove/permute to logits procs' states - * If batch state is modified, update sampling metadata - """ + if self.is_pooling_model: + # Batch changes every step for pooling models. + self.sampling_metadata = self._make_sampling_metadata() + return + + # For non-pooling models - generate and apply logitsprocs update; + # reset batch update tracking. + # Update sampling metadata if batch state is changed. batch_update = self.batch_update_builder.get_and_reset(self.num_reqs) for logit_proc in self.logitsprocs.all: logit_proc.update_state(batch_update) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9460d91c5832..19d13dfef6d5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -68,6 +68,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) from vllm.v1.pool.metadata import PoolingMetadata +from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler @@ -80,7 +81,6 @@ KVConnectorModelRunnerMixin, KVConnectorOutput) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from ..sample.logits_processor import LogitsProcessorManager from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache, gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) @@ -221,6 +221,11 @@ def __init__( vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.cache_config.block_size], is_spec_decode=bool(self.vllm_config.speculative_config), + logitsprocs=build_logitsprocs( + self.vllm_config, self.device, self.pin_memory, + self.is_pooling_model, + self.vllm_config.model_config.logits_processors), + is_pooling_model=self.is_pooling_model, ) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. @@ -2446,7 +2451,7 @@ def _dummy_sampler_run( output_token_ids=[[] for _ in range(num_reqs)], allowed_token_ids_mask=None, bad_words_token_ids={}, - logitsprocs=LogitsProcessorManager(), + logitsprocs=LogitsProcessors(), ) try: sampler_output = self.sampler(logits=logits, @@ -2967,6 +2972,8 @@ def may_reinitialize_input_batch(self, vocab_size=self.model_config.get_vocab_size(), block_sizes=block_sizes, is_spec_decode=bool(self.vllm_config.speculative_config), + logitsprocs=self.input_batch.logitsprocs, + is_pooling_model=self.is_pooling_model, ) def _allocate_kv_cache_tensors(