diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py
index a4bd651f8224..d7cc2b8b7740 100644
--- a/tests/v1/sample/test_sampler.py
+++ b/tests/v1/sample/test_sampler.py
@@ -412,3 +412,32 @@ def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
                                                                  1e-2)
         else:
             assert logits_for_req[token_id] == pytest.approx(1e-2)
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("bias_value", [-0.1, 1.2])
+def test_sampler_allowed_token_ids(device: str, batch_size: int,
+                                   bias_value: float):
+    """
+    Test to verify that when allowed_token_ids are provided, the logits
+    of every token outside the allowed set are masked to -inf while the
+    allowed tokens keep finite values.
+    """
+    torch.set_default_device(device)
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+    allowed_token_ids = set([0, 1, 2])
+    sampling_metadata.allowed_token_ids = list(allowed_token_ids)
+    # https://github.com/vllm-project/vllm/blob/38094584566b89210a6f72a408eba1fae43c3d81/tests/entrypoints/openai/test_completion.py#L620
+    sampler = Sampler()
+    logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
+    logits = logits.cpu()
+    for batch_idx in range(batch_size):
+        logits_for_req = logits[batch_idx]
+        for token_id in range(VOCAB_SIZE):
+            if token_id not in allowed_token_ids:
+                assert logits_for_req[token_id] == float("-inf")
+            else:
+                assert logits_for_req[token_id] > float("-inf")
diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py
index ea64181c0aeb..1998bec7b47a 100644
--- a/vllm/v1/sample/metadata.py
+++ b/vllm/v1/sample/metadata.py
@@ -38,3 +38,4 @@ class SamplingMetadata:
     stop_token_ids: List[Set[int]]
 
     logit_bias: List[Optional[Dict[int, float]]]
+    allowed_token_ids: Optional[List[int]] = None
diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index ec6374d12b17..d69b391f47ee 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -49,6 +49,8 @@ def forward(
         logits = logits.to(torch.float32)
         # Apply logits bias.
         logits = self.apply_logits_bias(logits, sampling_metadata)
+        # Apply allowed token ids.
+        logits = self.apply_allowed_token_ids(logits, sampling_metadata)
         # Apply penalties (e.g., min_tokens, freq_penalties).
         logits = self.apply_penalties(logits, sampling_metadata)
         # Sample the next token.
@@ -227,3 +229,23 @@ def apply_logits_bias(
                 for token_id, bias in logit_bias.items():
                     logits[i, token_id] += bias
         return logits
+
+    def apply_allowed_token_ids(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
+        if not sampling_metadata.allowed_token_ids:
+            return logits
+        vocab_size = logits.size(dim=1)
+        if not all(0 <= tid < vocab_size
+                   for tid in sampling_metadata.allowed_token_ids):
+            raise ValueError("allowed_token_ids contains "
+                             "out-of-vocab token id")
+        allowed_ids = list(sampling_metadata.allowed_token_ids)
+        mask = torch.ones((logits.shape[-1], ),
+                          dtype=torch.bool,
+                          device=logits.device)
+        mask[allowed_ids] = False
+        logits.masked_fill_(mask, float("-inf"))
+        return logits
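
For context, below is a minimal standalone sketch of the masking technique used by the new Sampler.apply_allowed_token_ids method: build a boolean mask over the vocabulary, clear it for the allowed ids, and fill every other position with -inf. The helper name mask_disallowed_tokens and the toy tensor shapes are illustrative only and are not part of this patch.

import torch


def mask_disallowed_tokens(logits: torch.Tensor,
                           allowed_token_ids: list) -> torch.Tensor:
    # Boolean mask over the vocab: True = disallowed, False = allowed.
    vocab_size = logits.shape[-1]
    if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
        raise ValueError("allowed_token_ids contains out-of-vocab token id")
    mask = torch.ones(vocab_size, dtype=torch.bool, device=logits.device)
    mask[allowed_token_ids] = False
    # The (vocab_size,) mask broadcasts over the batch dimension;
    # disallowed logits become -inf in place.
    return logits.masked_fill_(mask, float("-inf"))


# Toy usage: batch of 2 requests, vocab of 5, only tokens {0, 1, 2} allowed.
logits = torch.randn(2, 5)
masked = mask_disallowed_tokens(logits, [0, 1, 2])
assert torch.isinf(masked[:, 3:]).all()
assert torch.isfinite(masked[:, :3]).all()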