From 1d0b4db5e8048b44139af9ba06361404739028a0 Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Wed, 25 Oct 2023 12:09:16 +0300 Subject: [PATCH 1/5] Added logits processor API to sampling params --- tests/samplers/test_sampler.py | 34 +++++++++++++++++++++++++++ vllm/model_executor/layers/sampler.py | 22 +++++++++++++++++ vllm/sampling_params.py | 13 +++++++++- 3 files changed, 68 insertions(+), 1 deletion(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index c4d33711cc9a..eec0d9ff7972 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -183,3 +183,37 @@ def test_sampler_mixed(seed: int): continue for nth_output in sequence_output.samples: assert nth_output.output_token in expected_tokens + + +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +def test_sampler_logits_processors(seed: int): + set_random_seed(seed) + batch_size = random.randint(1, 256) + input_tensor, _, sampler, worker = _prepare_test(batch_size) + + # This sample logits processor gives infinite score to the i-th token, + # where i is the length of the input sequence. + # We therefore expect the output token sequence to be [0, 1, 2, ...] + def pick_ith(token_ids, logits): + logits[len(token_ids)] = float("inf") + return logits + + seq_group_metadata_list = [] + for i in range(batch_size): + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: SequenceData([1, 2, 3])}, + sampling_params=SamplingParams(temperature=0, + logits_processors=[pick_ith]), + block_tables={0: [1]}, + )) + + _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list) + sampler_output = sampler(embedding=None, + hidden_states=input_tensor, + input_metadata=input_metadata) + for i, sequence_output in enumerate(sampler_output): + for idx, nth_output in enumerate(sequence_output.samples): + assert nth_output.output_token == idx diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index a12c82a21f46..b44ebe78c594 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -47,6 +47,8 @@ def forward( logits = _get_logits(hidden_states, embedding, embedding_bias, self.vocab_size) + # Apply logits processors (if any). + logits = _apply_logits_processors(logits, input_metadata) # Apply presence and frequency penalties. output_tokens = _get_output_tokens(input_metadata) assert len(output_tokens) == logits.shape[0] @@ -170,6 +172,26 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]: return output_tokens +def _apply_logits_processors(logits: torch.Tensor, + input_metadata: InputMetadata) -> torch.Tensor: + logits_row_idx = 0 + found_logits_processors = False + for seq_ids, sampling_params in input_metadata.seq_groups: + logits_processors = sampling_params.logits_processors + for seq_id in seq_ids: + if logits_processors: + found_logits_processors = True + logits_row = logits[logits_row_idx] + token_ids = input_metadata.seq_data[seq_id].output_token_ids + for logits_processor in logits_processors: + logits_row = logits_processor(token_ids, logits_row) + logits[logits_row_idx] = logits_row + logits_row_idx += 1 + if found_logits_processors: + assert logits_row_idx == logits.shape[0] + return logits + + def _apply_penalties( logits: torch.Tensor, output_tokens: List[List[int]], diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 10e97d1fcb19..e507b9ac9c00 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -1,7 +1,8 @@ """Sampling parameters for text generation.""" from enum import IntEnum from functools import cached_property -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union +import torch _SAMPLING_EPS = 1e-5 @@ -12,6 +13,12 @@ class SamplingType(IntEnum): BEAM = 2 +LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor] +"""LogitsProcessor is a function that takes a list of previously generated +tokens and a tensor of the logits for the next token, and returns a modified +tensor of logits to sample from.""" + + class SamplingParams: """Sampling parameters for text generation. @@ -67,6 +74,8 @@ class SamplingParams: `logprobs+1` elements in the response. prompt_logprobs: Number of log probabilities to return per prompt token. skip_special_tokens: Whether to skip special tokens in the output. + logits_processors: List of functions that modify logits based on + previously generated tokens. """ def __init__( @@ -88,6 +97,7 @@ def __init__( logprobs: Optional[int] = None, prompt_logprobs: Optional[int] = None, skip_special_tokens: bool = True, + logits_processors: Optional[List[LogitsProcessor]] = None, ) -> None: self.n = n self.best_of = best_of if best_of is not None else n @@ -114,6 +124,7 @@ def __init__( self.logprobs = logprobs self.prompt_logprobs = prompt_logprobs self.skip_special_tokens = skip_special_tokens + self.logits_processors = logits_processors self._verify_args() if self.use_beam_search: From 58e528d51266504bc7d27c337d3e7e05895955c2 Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Tue, 31 Oct 2023 07:34:20 +0200 Subject: [PATCH 2/5] Lint fix --- vllm/sampling_params.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index d48284c2c1ea..f8ef9be7b6a6 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -135,7 +135,6 @@ def __init__( self.skip_special_tokens = skip_special_tokens self.spaces_between_special_tokens = spaces_between_special_tokens self.logits_processors = logits_processors - self._verify_args() if self.use_beam_search: self._verify_beam_search() From 57ba07aa0729934ff1a007b6f5963aa14a573f7a Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Thu, 2 Nov 2023 00:00:31 +0200 Subject: [PATCH 3/5] Reduced runtime footprint to zero if there are no logits processors --- vllm/model_executor/layers/sampler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 85d12ab40938..d24295ffe685 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -163,8 +163,8 @@ def _apply_logits_processors(logits: torch.Tensor, found_logits_processors = False for seq_ids, sampling_params in input_metadata.seq_groups: logits_processors = sampling_params.logits_processors - for seq_id in seq_ids: - if logits_processors: + if logits_processors: + for seq_id in seq_ids: found_logits_processors = True logits_row = logits[logits_row_idx] token_ids = input_metadata.seq_data[seq_id].output_token_ids @@ -172,6 +172,8 @@ def _apply_logits_processors(logits: torch.Tensor, logits_row = logits_processor(token_ids, logits_row) logits[logits_row_idx] = logits_row logits_row_idx += 1 + else: + logits_row_idx += len(seq_ids) if found_logits_processors: assert logits_row_idx == logits.shape[0] return logits From f78c78613c9eb33f93ca6ae94d51d61432a66cc2 Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Thu, 2 Nov 2023 00:02:41 +0200 Subject: [PATCH 4/5] Code cleanup --- vllm/model_executor/layers/sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index d24295ffe685..929896b38d4a 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -164,8 +164,8 @@ def _apply_logits_processors(logits: torch.Tensor, for seq_ids, sampling_params in input_metadata.seq_groups: logits_processors = sampling_params.logits_processors if logits_processors: + found_logits_processors = True for seq_id in seq_ids: - found_logits_processors = True logits_row = logits[logits_row_idx] token_ids = input_metadata.seq_data[seq_id].output_token_ids for logits_processor in logits_processors: From 417fc0033885602fcbbe19046b921585f90b33a6 Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Fri, 3 Nov 2023 23:04:27 +0200 Subject: [PATCH 5/5] Update vllm/model_executor/layers/sampler.py Co-authored-by: Simon Mo --- vllm/model_executor/layers/sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 929896b38d4a..e0ec42081179 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -171,7 +171,7 @@ def _apply_logits_processors(logits: torch.Tensor, for logits_processor in logits_processors: logits_row = logits_processor(token_ids, logits_row) logits[logits_row_idx] = logits_row - logits_row_idx += 1 + logits_row_idx += 1 else: logits_row_idx += len(seq_ids) if found_logits_processors: