29 changes: 29 additions & 0 deletions tests/v1/sample/test_sampler.py
@@ -412,3 +412,32 @@ def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
1e-2)
else:
assert logits_for_req[token_id] == pytest.approx(1e-2)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bias_value", [-0.1, 1.2])
def test_sampler_allowed_token_ids(device: str, batch_size: int,
bias_value: float):
"""
Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing
output.
"""
torch.set_default_device(device)
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    allowed_token_ids = {0, 1, 2}
sampling_metadata.allowed_token_ids = list(allowed_token_ids)
# https://github.com/vllm-project/vllm/blob/38094584566b89210a6f72a408eba1fae43c3d81/tests/entrypoints/openai/test_completion.py#L620
sampler = Sampler()
logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
logits_for_req = logits[batch_idx]
for token_id in range(VOCAB_SIZE):
if token_id not in allowed_token_ids:
assert logits_for_req[token_id] == float("-inf")
else:
assert logits_for_req[token_id] > float("-inf")
1 change: 1 addition & 0 deletions vllm/v1/sample/metadata.py
@@ -38,3 +38,4 @@ class SamplingMetadata:
stop_token_ids: List[Set[int]]

logit_bias: List[Optional[Dict[int, float]]]
allowed_token_ids: Optional[List[int]] = None
22 changes: 22 additions & 0 deletions vllm/v1/sample/sampler.py
@@ -49,6 +49,8 @@ def forward(
logits = logits.to(torch.float32)
# Apply logits bias.
logits = self.apply_logits_bias(logits, sampling_metadata)
# Apply allowed token ids.
logits = self.apply_allowed_token_ids(logits, sampling_metadata)
# Apply penalties (e.g., min_tokens, freq_penalties).
logits = self.apply_penalties(logits, sampling_metadata)
# Sample the next token.
@@ -227,3 +229,23 @@ def apply_logits_bias(
for token_id, bias in logit_bias.items():
logits[i, token_id] += bias
return logits

def apply_allowed_token_ids(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
if not sampling_metadata.allowed_token_ids:
return logits
vocab_size = logits.size(dim=1)
        if not all(0 <= tid < vocab_size
                   for tid in sampling_metadata.allowed_token_ids):
            raise ValueError("allowed_token_ids contains an "
                             "out-of-vocab token id")
        # Mask out every token id that is not in the allowed set so the
        # disallowed logits can never be sampled.
        allowed_ids = list(sampling_metadata.allowed_token_ids)
        mask = torch.ones(vocab_size,
                          dtype=torch.bool,
                          device=logits.device)
        mask[allowed_ids] = False
        logits.masked_fill_(mask, float("-inf"))
return logits
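
For context, below is a minimal standalone sketch of the same boolean-mask technique that apply_allowed_token_ids uses, written outside of vLLM's Sampler. The helper name mask_disallowed_tokens and the tensor shapes are made up for illustration; only torch is assumed.

import torch
from typing import List


def mask_disallowed_tokens(logits: torch.Tensor,
                           allowed_token_ids: List[int]) -> torch.Tensor:
    """Return a copy of logits with every non-allowed token id set to -inf."""
    vocab_size = logits.size(dim=1)
    # Mask is True for token ids that must be suppressed.
    mask = torch.ones(vocab_size, dtype=torch.bool, device=logits.device)
    mask[allowed_token_ids] = False
    return logits.masked_fill(mask, float("-inf"))


# Example: with a vocabulary of 8, only token ids {0, 1, 2} stay sampleable.
logits = torch.randn(2, 8)
masked = mask_disallowed_tokens(logits, [0, 1, 2])
assert torch.isinf(masked[:, 3:]).all()
assert torch.isfinite(masked[:, :3]).all()

The sketch uses the out-of-place masked_fill to keep the input untouched, whereas the diff above mutates the logits in place with masked_fill_, which is fine there because the sampler owns the tensor at that point.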