Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions tests/samplers/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,37 @@ def test_sampler_mixed(seed: int):
continue
for nth_output in sequence_output.samples:
assert nth_output.output_token in expected_tokens


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_logits_processors(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, _, sampler, worker = _prepare_test(batch_size)

# This sample logits processor gives infinite score to the i-th token,
# where i is the length of the input sequence.
# We therefore expect the output token sequence to be [0, 1, 2, ...]
def pick_ith(token_ids, logits):
logits[len(token_ids)] = float("inf")
return logits

seq_group_metadata_list = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(temperature=0,
logits_processors=[pick_ith]),
block_tables={0: [1]},
))

_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
for i, sequence_output in enumerate(sampler_output):
for idx, nth_output in enumerate(sequence_output.samples):
assert nth_output.output_token == idx
24 changes: 24 additions & 0 deletions vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def forward(
logits = _get_logits(hidden_states, embedding, embedding_bias,
self.vocab_size)

# Apply logits processors (if any).
logits = _apply_logits_processors(logits, input_metadata)
# Apply presence and frequency penalties.
output_tokens = _get_output_tokens(input_metadata)
assert len(output_tokens) == logits.shape[0]
Expand Down Expand Up @@ -155,6 +157,28 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
return output_tokens


def _apply_logits_processors(logits: torch.Tensor,
input_metadata: InputMetadata) -> torch.Tensor:
logits_row_idx = 0
found_logits_processors = False
for seq_ids, sampling_params in input_metadata.seq_groups:
logits_processors = sampling_params.logits_processors
if logits_processors:
found_logits_processors = True
for seq_id in seq_ids:
logits_row = logits[logits_row_idx]
token_ids = input_metadata.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row)
logits[logits_row_idx] = logits_row
logits_row_idx += 1
else:
logits_row_idx += len(seq_ids)
if found_logits_processors:
assert logits_row_idx == logits.shape[0]
return logits


def _apply_penalties(
logits: torch.Tensor,
output_tokens: List[List[int]],
Expand Down
14 changes: 12 additions & 2 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Sampling parameters for text generation."""
from enum import IntEnum
from functools import cached_property
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union
import torch

_SAMPLING_EPS = 1e-5

Expand All @@ -12,6 +13,12 @@ class SamplingType(IntEnum):
BEAM = 2


LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
"""LogitsProcessor is a function that takes a list of previously generated
tokens and a tensor of the logits for the next token, and returns a modified
tensor of logits to sample from."""


class SamplingParams:
"""Sampling parameters for text generation.

Expand Down Expand Up @@ -73,6 +80,8 @@ class SamplingParams:
skip_special_tokens: Whether to skip special tokens in the output.
spaces_between_special_tokens: Whether to add spaces between special
tokens in the output. Defaults to True.
logits_processors: List of functions that modify logits based on
previously generated tokens.
"""

def __init__(
Expand All @@ -96,6 +105,7 @@ def __init__(
prompt_logprobs: Optional[int] = None,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None,
) -> None:
self.n = n
self.best_of = best_of if best_of is not None else n
Expand Down Expand Up @@ -124,7 +134,7 @@ def __init__(
self.prompt_logprobs = prompt_logprobs
self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens

self.logits_processors = logits_processors
self._verify_args()
if self.use_beam_search:
self._verify_beam_search()
Expand Down