diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py
index 3e810e525e1c..956d91c6daf7 100644
--- a/tests/v1/sample/test_rejection_sampler.py
+++ b/tests/v1/sample/test_rejection_sampler.py
@@ -43,6 +43,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
         output_token_ids=[],
         min_tokens={},
         logit_bias=[None] * batch_size,
+        allowed_token_ids_mask=None,
     )
diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py
index 3f6301c54267..34fba5a9f6d7 100644
--- a/tests/v1/sample/test_sampler.py
+++ b/tests/v1/sample/test_sampler.py
@@ -57,6 +57,26 @@ def _create_logit_bias(
     return res
 
 
+def _create_allowed_token_ids(
+    batch_size: int,
+    vocab_size: int,
+    num_allowed_token_ids: int,
+    device: torch.device,
+) -> Optional[torch.Tensor]:
+    mask: Optional[torch.Tensor] = None
+    for i in range(batch_size):
+        if i % 2 == 1:
+            continue
+        if mask is None:
+            mask = torch.zeros((batch_size, vocab_size),
+                               dtype=torch.bool,
+                               device=device)
+        start = min(i, vocab_size - 1)
+        end = min(i + num_allowed_token_ids, vocab_size - 1)
+        mask[i, start:end] = True
+    return mask
+
+
 def _create_default_sampling_metadata(
     num_output_tokens: int,
     batch_size: int,
@@ -92,6 +112,7 @@ def _create_default_sampling_metadata(
         no_penalties=True,
         min_tokens={},
         logit_bias=[None] * batch_size,
+        allowed_token_ids_mask=None,
     )
     return fake_sampling_metadata
 
@@ -253,7 +274,10 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
     sampling_metadata.frequency_penalties = _create_penalty_tensor(
         batch_size, frequency_penalty, torch.device(device))
     output_token_ids, sorted_token_ids_in_output = \
-        _create_weighted_output_token_list(batch_size, VOCAB_SIZE)
+        _create_weighted_output_token_list(
+            batch_size,
+            VOCAB_SIZE,
+        )
     sampling_metadata.output_token_ids = output_token_ids
     sampling_metadata.no_penalties = False
     sampler = Sampler()
@@ -262,8 +286,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
     for batch_idx in range(batch_size):
         non_penalized_token_id = logits[batch_idx].argmax().item()
         penalized_token_id = logits[batch_idx].argmin().item()
-        distinct_sorted_token_ids_in_output = \
-            sorted_token_ids_in_output[batch_idx]
+        distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[
+            batch_idx]
         most_frequent_token_id = distinct_sorted_token_ids_in_output[
             len(distinct_sorted_token_ids_in_output) - 1]
         if frequency_penalty > 0:
@@ -272,8 +296,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
             # non-penalized token ID is not present in the output, while the
             # most penalized token is the one that occurs most frequently in
             # the output.
-            assert non_penalized_token_id \
-                not in distinct_sorted_token_ids_in_output
+            assert (non_penalized_token_id
+                    not in distinct_sorted_token_ids_in_output)
             assert penalized_token_id == most_frequent_token_id
         elif frequency_penalty < 0:
             # If `frequency_penalty` is set to < 0, it indicates
@@ -282,8 +306,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
             # in the output, while the penalized token ID is one that has not
             # yet appeared.
             assert non_penalized_token_id == most_frequent_token_id
-            assert penalized_token_id \
-                not in distinct_sorted_token_ids_in_output
+            assert penalized_token_id not in distinct_sorted_token_ids_in_output
 
 
 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -318,18 +341,18 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
             # If `repetition_penalty` > 1.0, verify that the non-penalized
             # token ID has not been seen before, while the penalized token ID
             # exists either in the prompt or the output.
-            assert (non_penalized_token_id not in prompt_tokens and \
-                    non_penalized_token_id not in output_tokens)
-            assert (penalized_token_id in prompt_tokens or \
-                    penalized_token_id in output_tokens)
+            assert (non_penalized_token_id not in prompt_tokens
+                    and non_penalized_token_id not in output_tokens)
+            assert (penalized_token_id in prompt_tokens
+                    or penalized_token_id in output_tokens)
         elif repetition_penalty < 1.0:
             # If `repetition_penalty` < 1.0, verify that the penalized
             # token ID has not been seen before, while the non-penalized
             # token ID exists either in the prompt or the output.
-            assert (penalized_token_id not in prompt_tokens and \
-                    penalized_token_id not in output_tokens)
-            assert (non_penalized_token_id in prompt_tokens or \
-                    non_penalized_token_id in output_tokens)
+            assert (penalized_token_id not in prompt_tokens
+                    and penalized_token_id not in output_tokens)
+            assert (non_penalized_token_id in prompt_tokens
+                    or non_penalized_token_id in output_tokens)
 
 
 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -404,3 +427,44 @@ def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
                                                              1e-2)
         else:
             assert logits_for_req[token_id] == pytest.approx(1e-2)
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
+def test_sampler_allowed_token_ids(device: str, batch_size: int,
+                                   num_allowed_token_ids: int):
+    """
+    Test to verify that the allowed_token_ids mask is applied to the logits:
+    positions selected by the mask are filled with -inf, while all other
+    positions are left untouched.
+    """
+    torch.set_default_device(device)
+    # Create fake logits where each token is assigned the same
+    # logit value.
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+    mask = _create_allowed_token_ids(
+        batch_size=batch_size,
+        vocab_size=VOCAB_SIZE,
+        num_allowed_token_ids=num_allowed_token_ids,
+        device=device,
+    )
+    sampling_metadata.allowed_token_ids_mask = mask
+    sampler = Sampler()
+    logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
+    logits = logits.cpu()
+    for batch_idx in range(batch_size):
+        logits_for_req = logits[batch_idx]
+        if batch_idx % 2 == 1:
+            assert torch.all(logits_for_req != -float("inf"))
+            continue
+        for token_id in range(VOCAB_SIZE):
+            start = min(batch_idx, VOCAB_SIZE - 1)
+            end = min(batch_idx + num_allowed_token_ids, VOCAB_SIZE - 1)
+            if start <= token_id < end:
+                assert logits_for_req[token_id] == -float(
+                    "inf"), f"{batch_idx}, {token_id}"
+            else:
+                assert logits_for_req[token_id] != -float("inf")
diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py
index cb3b3d21fbb3..0aee266264ac 100644
--- a/tests/v1/worker/test_gpu_input_batch.py
+++ b/tests/v1/worker/test_gpu_input_batch.py
@@ -66,6 +66,10 @@ def _construct_expected_sampling_metadata(
     temperature = [0.0 for _ in range(num_reqs)]
     min_tokens = {}
     logit_bias = [None] * num_reqs
+    allowed_token_ids_mask = torch.zeros(num_reqs,
+                                         VOCAB_SIZE,
+                                         dtype=torch.bool,
+                                         device=device)
     for req in reqs:
         if req.req_id not in req_ids_retained:
             continue
@@ -86,6 +90,10 @@ def _construct_expected_sampling_metadata(
             req.sampling_params.min_tokens,
             req.sampling_params.all_stop_token_ids)
         logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
+        if req.sampling_params.allowed_token_ids:
+            allowed_token_ids_mask[index_in_input_batch][
+                req.sampling_params.allowed_token_ids] = True
+
     return SamplingMetadata(
         temperature=torch.tensor(temperature,
                                  dtype=torch.float,
                                  device=device),
@@ -121,6 +129,7 @@ def _construct_expected_sampling_metadata(
                       and all(x == 0 for x in frequency_penalties)
                       and all(x == 1 for x in repetition_penalties)),
         logit_bias=logit_bias,
+        allowed_token_ids_mask=allowed_token_ids_mask,
     )
 
 
@@ -242,3 +251,7 @@ def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
     assert expected_sampling_metadata.no_penalties == \
         sampling_metadata.no_penalties
     assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
+    if sampling_metadata.allowed_token_ids_mask is not None:
+        assert torch.allclose(
+            expected_sampling_metadata.allowed_token_ids_mask,
+            sampling_metadata.allowed_token_ids_mask)
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index b7eee5a39972..2547cebaede7 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -83,6 +83,19 @@ def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
 
+    def _validate_allowed_token_ids(
+        self,
+        params: Union[SamplingParams, PoolingParams],
+    ) -> None:
+        if not isinstance(params, SamplingParams):
+            return
+        if params.allowed_token_ids is None:
+            return
+        if not all(0 <= tid < self.model_config.vocab_size
+                   for tid in params.allowed_token_ids):
+            raise ValueError(
+                "allowed_token_ids contains out-of-vocab token id")
+
     def process_inputs(
         self,
         request_id: str,
@@ -100,6 +113,7 @@ def process_inputs(
 
         self._validate_logprobs(params)
         self._validate_lora(lora_request)
+        self._validate_allowed_token_ids(params)
         if arrival_time is None:
             arrival_time = time.time()
diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py
index 6d82d3a79c8e..9f7770bbd078 100644
--- a/vllm/v1/sample/metadata.py
+++ b/vllm/v1/sample/metadata.py
@@ -37,3 +37,7 @@ class SamplingMetadata:
     min_tokens: Dict[int, Tuple[int, Set[int]]]
 
     logit_bias: List[Optional[Dict[int, float]]]
+
+    # `allowed_token_ids_mask` is a 2D bool tensor of shape (num requests,
+    # vocab size); True entries are filled with -inf by the sampler.
+    allowed_token_ids_mask: Optional[torch.Tensor]
diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index ff978b3b6c41..47ec26d42024 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -47,6 +47,8 @@ def forward(
         # Use float32 for the logits.
         logits = logits.to(torch.float32)
+        # Apply allowed token ids.
+        logits = self.apply_allowed_token_ids(logits, sampling_metadata)
         # Apply logits bias.
         logits = self.apply_logits_bias(logits, sampling_metadata)
         # Apply penalties (e.g., min_tokens, freq_penalties).
@@ -184,11 +186,13 @@ def apply_penalties(
         if not sampling_metadata.no_penalties:
             assert sampling_metadata.prompt_token_ids is not None
             logits = apply_all_penalties(
-                logits, sampling_metadata.prompt_token_ids,
+                logits,
+                sampling_metadata.prompt_token_ids,
                 sampling_metadata.presence_penalties,
                 sampling_metadata.frequency_penalties,
                 sampling_metadata.repetition_penalties,
-                sampling_metadata.output_token_ids)
+                sampling_metadata.output_token_ids,
+            )
         return logits
 
     def apply_min_p(
@@ -226,3 +230,13 @@ def apply_logits_bias(
                 for token_id, bias in logit_bias.items():
                     logits[i, token_id] += bias
         return logits
+
+    def apply_allowed_token_ids(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
+        if sampling_metadata.allowed_token_ids_mask is not None:
+            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
+                                float("-inf"))
+        return logits
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index bd1c369acb30..d9fc53490c07 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -143,7 +143,7 @@ def __init__(
             device="cpu",
             pin_memory=pin_memory)
         self.frequency_penalties_cpu = \
-                self.frequency_penalties_cpu_tensor.numpy()
+            self.frequency_penalties_cpu_tensor.numpy()
         self.frequency_penalties_reqs: Set[str] = set()
 
         # Presence penalty related data structures
@@ -168,7 +168,7 @@ def __init__(
             device="cpu",
             pin_memory=pin_memory)
         self.repetition_penalties_cpu = \
-                self.repetition_penalties_cpu_tensor.numpy()
+            self.repetition_penalties_cpu_tensor.numpy()
         self.repetition_penalties_reqs: Set[str] = set()
 
         # req_index -> (min_tokens, stop_token_ids)
@@ -192,6 +192,9 @@ def __init__(
         self.logit_bias: List[Optional[Dict[int,
                                             float]]] = [None] * max_num_reqs
 
+        self.has_allowed_token_ids: Set[str] = set()
+        self.allowed_token_ids_mask: Optional[torch.Tensor] = None
+        self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
 
         self.req_output_token_ids: List[Optional[List[int]]] = []
 
@@ -287,6 +290,22 @@ def add_request(
         if sampling_params.logit_bias is not None:
             self.logit_bias[req_index] = sampling_params.logit_bias
 
+        if sampling_params.allowed_token_ids:
+            self.has_allowed_token_ids.add(req_id)
+            if self.allowed_token_ids_mask_cpu_tensor is None:
+                # Lazy allocation for this tensor, which can be large.
+                self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
+                                                          self.vocab_size,
+                                                          dtype=torch.bool,
+                                                          device=self.device)
+                self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
+                    self.max_num_reqs,
+                    self.vocab_size,
+                    dtype=torch.bool,
+                    device="cpu")
+            self.allowed_token_ids_mask_cpu_tensor[req_index][
+                sampling_params.allowed_token_ids] = True
+
         # Add request lora ID
         if request.lora_request:
             lora_id = request.lora_request.lora_int_id
@@ -332,6 +351,9 @@ def remove_request(self, req_id: str) -> Optional[int]:
 
         self.request_lora_mapping[req_index] = 0
         self.logit_bias[req_index] = None
+        self.has_allowed_token_ids.discard(req_id)
+        if self.allowed_token_ids_mask_cpu_tensor is not None:
+            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
         return req_index
 
     def condense(self, empty_req_indices: List[int]) -> None:
@@ -400,6 +422,11 @@ def condense(self, empty_req_indices: List[int]) -> None:
 
             self.logit_bias[empty_index] = self.logit_bias[last_req_index]
 
+            if self.allowed_token_ids_mask_cpu_tensor is not None:
+                self.allowed_token_ids_mask_cpu_tensor[
+                    empty_index] = self.allowed_token_ids_mask_cpu_tensor[
+                        last_req_index]
+
             # Decrement last_req_index since it is now empty.
             last_req_index -= 1
 
@@ -442,6 +469,13 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
         else:
             prompt_token_ids = None
 
+        allowed_token_ids_mask: Optional[torch.Tensor] = None
+        if not self.no_allowed_token_ids:
+            assert self.allowed_token_ids_mask is not None
+            copy_slice(self.allowed_token_ids_mask_cpu_tensor,
+                       self.allowed_token_ids_mask, num_reqs)
+            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
+
         return SamplingMetadata(
             temperature=temperature,
             all_greedy=self.all_greedy,
@@ -460,6 +494,7 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
             min_tokens=self.min_tokens,
             no_penalties=self.no_penalties,
             logit_bias=self.logit_bias[:num_reqs],
+            allowed_token_ids_mask=allowed_token_ids_mask,
         )
 
     def get_sampling_metadata(
@@ -550,3 +585,7 @@ def max_num_logprobs(self) -> Optional[int]:
     @property
     def no_prompt_logprob(self) -> bool:
         return not self.num_prompt_logprobs
+
+    @property
+    def no_allowed_token_ids(self) -> bool:
+        return len(self.has_allowed_token_ids) == 0
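
The core mechanism the sampler change adds is a per-request boolean mask over the vocabulary that is applied to the logits with `masked_fill_`, so every `True` position becomes `-inf` and can never be sampled. Below is a minimal standalone sketch of that mechanic outside vLLM; the helper names `build_token_id_mask` and `apply_mask` are illustrative only and not part of the codebase.

```python
# Illustrative sketch only; mirrors the masked_fill_ pattern used by
# Sampler.apply_allowed_token_ids in the diff above.
from typing import List, Optional

import torch


def build_token_id_mask(token_ids_per_req: List[Optional[List[int]]],
                        vocab_size: int) -> Optional[torch.Tensor]:
    # Returns a bool mask of shape (num_reqs, vocab_size). True positions are
    # later filled with -inf; requests with no restriction keep an all-False
    # row, and the mask is None when no request needs it.
    if all(ids is None for ids in token_ids_per_req):
        return None
    mask = torch.zeros(len(token_ids_per_req), vocab_size, dtype=torch.bool)
    for req_idx, token_ids in enumerate(token_ids_per_req):
        if token_ids:
            mask[req_idx, token_ids] = True
    return mask


def apply_mask(logits: torch.Tensor,
               mask: Optional[torch.Tensor]) -> torch.Tensor:
    # Same in-place pattern as the sampler: masked positions become -inf.
    if mask is not None:
        logits.masked_fill_(mask, float("-inf"))
    return logits


if __name__ == "__main__":
    logits = torch.zeros(2, 8)
    mask = build_token_id_mask([[1, 3], None], vocab_size=8)
    out = apply_mask(logits, mask)
    assert torch.isinf(out[0, 1]) and torch.isinf(out[0, 3])
    assert not torch.isinf(out[1]).any()
```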
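On the persistent-batch side, `InputBatch` keeps the mask as a lazily allocated `(max_num_reqs, vocab_size)` tensor, writes one row when a request with `allowed_token_ids` is added, clears that row when the request is removed, and exposes only the first `num_reqs` rows when building sampling metadata. A rough sketch of that bookkeeping pattern follows; `MaskBook` is a made-up stand-in for the real `InputBatch` and omits the separate GPU copy and `copy_slice` step.

```python
# Rough sketch of the lazy-mask bookkeeping pattern; MaskBook is not vLLM API.
from typing import List, Optional, Set

import torch


class MaskBook:

    def __init__(self, max_num_reqs: int, vocab_size: int):
        self.max_num_reqs = max_num_reqs
        self.vocab_size = vocab_size
        self.has_restrictions: Set[str] = set()
        # Allocated lazily because (max_num_reqs, vocab_size) can be large.
        self.mask_cpu: Optional[torch.Tensor] = None

    def add_request(self, req_id: str, req_index: int,
                    token_ids: Optional[List[int]]) -> None:
        if not token_ids:
            return
        self.has_restrictions.add(req_id)
        if self.mask_cpu is None:
            self.mask_cpu = torch.zeros(self.max_num_reqs,
                                        self.vocab_size,
                                        dtype=torch.bool)
        self.mask_cpu[req_index][token_ids] = True

    def remove_request(self, req_id: str, req_index: int) -> None:
        self.has_restrictions.discard(req_id)
        if self.mask_cpu is not None:
            self.mask_cpu[req_index].fill_(False)

    def make_mask(self, num_reqs: int) -> Optional[torch.Tensor]:
        # Only the rows for currently scheduled requests are exposed.
        if not self.has_restrictions or self.mask_cpu is None:
            return None
        return self.mask_cpu[:num_reqs]


if __name__ == "__main__":
    book = MaskBook(max_num_reqs=4, vocab_size=16)
    book.add_request("req-0", 0, [2, 5])
    assert book.make_mask(num_reqs=1).sum() == 2
    book.remove_request("req-0", 0)
    assert book.make_mask(num_reqs=1) is None
```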
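From the user's perspective, the feature is driven by `SamplingParams.allowed_token_ids`, which `_validate_allowed_token_ids` in the processor checks against the model's vocabulary size before the request is admitted. A hedged usage sketch, assuming a local vLLM install; the model name and token IDs below are placeholders, not values from this change.

```python
# Usage sketch only; model and token IDs are placeholders. Out-of-vocab IDs
# are rejected by the processor with a ValueError.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # example model, swap in your own
params = SamplingParams(max_tokens=8, allowed_token_ids=[11, 22, 33])
outputs = llm.generate(["Hello"], params)
print(outputs[0].outputs[0].text)
```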