Merged
1 change: 1 addition & 0 deletions tests/v1/sample/test_rejection_sampler.py
@@ -43,6 +43,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
output_token_ids=[],
min_tokens={},
logit_bias=[None] * batch_size,
allowed_token_ids_mask=None,
)


94 changes: 79 additions & 15 deletions tests/v1/sample/test_sampler.py
@@ -57,6 +57,26 @@ def _create_logit_bias(
return res


def _create_allowed_token_ids(
batch_size: int,
vocab_size: int,
num_allowed_token_ids: int,
device: torch.device,
) -> Optional[torch.Tensor]:
mask: Optional[torch.Tensor] = None
for i in range(batch_size):
if i % 2 == 1:
continue
if mask is None:
mask = torch.zeros((batch_size, vocab_size),
dtype=torch.bool,
device=device)
start = min(i, vocab_size - 1)
end = min(i + num_allowed_token_ids, vocab_size - 1)
mask[i, start:end] = True
return mask


def _create_default_sampling_metadata(
num_output_tokens: int,
batch_size: int,
@@ -92,6 +112,7 @@ def _create_default_sampling_metadata(
no_penalties=True,
min_tokens={},
logit_bias=[None] * batch_size,
allowed_token_ids_mask=None,
)
return fake_sampling_metadata

@@ -253,7 +274,10 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
sampling_metadata.frequency_penalties = _create_penalty_tensor(
batch_size, frequency_penalty, torch.device(device))
output_token_ids, sorted_token_ids_in_output = \
_create_weighted_output_token_list(batch_size, VOCAB_SIZE)
_create_weighted_output_token_list(
batch_size,
VOCAB_SIZE,
)
sampling_metadata.output_token_ids = output_token_ids
sampling_metadata.no_penalties = False
sampler = Sampler()
@@ -262,8 +286,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
for batch_idx in range(batch_size):
non_penalized_token_id = logits[batch_idx].argmax().item()
penalized_token_id = logits[batch_idx].argmin().item()
distinct_sorted_token_ids_in_output = \
sorted_token_ids_in_output[batch_idx]
distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[
batch_idx]
most_frequent_token_id = distinct_sorted_token_ids_in_output[
len(distinct_sorted_token_ids_in_output) - 1]
if frequency_penalty > 0:
@@ -272,8 +296,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
# non-penalized token ID is not present in the output, while the
# most penalized token is the one that occurs most frequently in
# the output.
assert non_penalized_token_id \
not in distinct_sorted_token_ids_in_output
assert (non_penalized_token_id
not in distinct_sorted_token_ids_in_output)
assert penalized_token_id == most_frequent_token_id
elif frequency_penalty < 0:
# If `frequency_penalty` is set to < 0, it indicates
@@ -282,8 +306,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
# in the output, while the penalized token ID is one that has not
# yet appeared.
assert non_penalized_token_id == most_frequent_token_id
assert penalized_token_id \
not in distinct_sorted_token_ids_in_output
assert penalized_token_id not in distinct_sorted_token_ids_in_output


@pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -318,18 +341,18 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
# If `repetition_penalty` > 1.0, verify that the non-penalized
# token ID has not been seen before, while the penalized token ID
# exists either in the prompt or the output.
assert (non_penalized_token_id not in prompt_tokens and \
non_penalized_token_id not in output_tokens)
assert (penalized_token_id in prompt_tokens or \
penalized_token_id in output_tokens)
assert (non_penalized_token_id not in prompt_tokens
and non_penalized_token_id not in output_tokens)
assert (penalized_token_id in prompt_tokens
or penalized_token_id in output_tokens)
elif repetition_penalty < 1.0:
# If `repetition_penalty` < 1.0, verify that the penalized
# token ID has not been seen before, while the non-penalized
# token ID exists either in the prompt or the output.
assert (penalized_token_id not in prompt_tokens and \
penalized_token_id not in output_tokens)
assert (non_penalized_token_id in prompt_tokens or \
non_penalized_token_id in output_tokens)
assert (penalized_token_id not in prompt_tokens
and penalized_token_id not in output_tokens)
assert (non_penalized_token_id in prompt_tokens
or non_penalized_token_id in output_tokens)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -404,3 +427,44 @@ def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
1e-2)
else:
assert logits_for_req[token_id] == pytest.approx(1e-2)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
def test_sampler_allowed_token_ids(device: str, batch_size: int,
                                   num_allowed_token_ids: int):
    """
    Test to verify that when an allowed-token-ids mask is provided in the
    sampling metadata, the masked logit positions are set to -inf, while
    requests without a mask row are left untouched.
    """

Member: Thank you for adding this test!!

torch.set_default_device(device)
# Create fake logits where each token is assigned the same
# logit value.
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
mask = _create_allowed_token_ids(
batch_size=batch_size,
vocab_size=VOCAB_SIZE,
num_allowed_token_ids=num_allowed_token_ids,
device=device,
)
sampling_metadata.allowed_token_ids_mask = mask
sampler = Sampler()
logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
logits_for_req = logits[batch_idx]
if batch_idx % 2 == 1:
assert torch.all(logits_for_req != -float("inf"))
continue
for token_id in range(VOCAB_SIZE):
start = min(batch_idx, VOCAB_SIZE - 1)
end = min(batch_idx + num_allowed_token_ids, VOCAB_SIZE - 1)
if token_id >= start and token_id < end:
assert logits_for_req[token_id] == -float(
"inf"), f"{batch_idx}, {token_id}"
else:
assert logits_for_req[token_id] != -float("inf")
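
Editorial aside (not part of the diff): a minimal standalone sketch, with hypothetical sizes, of the mask that `_create_allowed_token_ids` above produces. Odd batch rows stay all-False, and even row i marks `num_allowed_token_ids` positions starting at column i, clamped to the vocab bound as in the helper.

import torch

# Hypothetical example: batch_size=4, vocab_size=8, num_allowed_token_ids=2,
# mirroring the helper above.
mask = torch.zeros((4, 8), dtype=torch.bool)
for i in range(0, 4, 2):
    start = min(i, 8 - 1)
    end = min(i + 2, 8 - 1)
    mask[i, start:end] = True
print(mask.int())
# tensor([[1, 1, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 1, 1, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int32)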
13 changes: 13 additions & 0 deletions tests/v1/worker/test_gpu_input_batch.py
@@ -66,6 +66,10 @@ def _construct_expected_sampling_metadata(
temperature = [0.0 for _ in range(num_reqs)]
min_tokens = {}
logit_bias = [None] * num_reqs
allowed_token_ids_mask = torch.zeros(num_reqs,
VOCAB_SIZE,
dtype=torch.bool,
device=device)
for req in reqs:
if req.req_id not in req_ids_retained:
continue
@@ -86,6 +90,10 @@ def _construct_expected_sampling_metadata(
req.sampling_params.min_tokens,
req.sampling_params.all_stop_token_ids)
logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
if req.sampling_params.allowed_token_ids:
allowed_token_ids_mask[index_in_input_batch][
req.sampling_params.allowed_token_ids] = True

return SamplingMetadata(
temperature=torch.tensor(temperature, dtype=torch.float,
device=device),
@@ -121,6 +129,7 @@ def _construct_expected_sampling_metadata(
and all(x == 0 for x in frequency_penalties)
and all(x == 1 for x in repetition_penalties)),
logit_bias=logit_bias,
allowed_token_ids_mask=allowed_token_ids_mask,
)


@@ -242,3 +251,7 @@ def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
assert expected_sampling_metadata.no_penalties == \
sampling_metadata.no_penalties
assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
if sampling_metadata.allowed_token_ids_mask:
assert torch.allclose(
expected_sampling_metadata.allowed_token_ids_mask,
sampling_metadata.allowed_token_ids_mask)
14 changes: 14 additions & 0 deletions vllm/v1/engine/processor.py
@@ -83,6 +83,19 @@ def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")

def _validate_allowed_token_ids(
self,
params: Union[SamplingParams, PoolingParams],
) -> None:
if not isinstance(params, SamplingParams):
return
if params.allowed_token_ids is None:
return
if not all(0 <= tid < self.model_config.vocab_size
for tid in params.allowed_token_ids):
raise ValueError(
"allowed_token_ids contains out-of-vocab token id")

def process_inputs(
self,
request_id: str,
@@ -100,6 +113,7 @@ def process_inputs(

self._validate_logprobs(params)
self._validate_lora(lora_request)
self._validate_allowed_token_ids(params)

if arrival_time is None:
arrival_time = time.time()
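
To illustrate the validation added in `_validate_allowed_token_ids` above, here is a self-contained sketch (not the vLLM code itself; the vocab size is hypothetical) of the check and the error it raises:

from typing import List, Optional

def validate_allowed_token_ids(allowed_token_ids: Optional[List[int]],
                               vocab_size: int) -> None:
    # Same check as the method above, rewritten as a free function for illustration.
    if allowed_token_ids is None:
        return
    if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
        raise ValueError("allowed_token_ids contains out-of-vocab token id")

validate_allowed_token_ids([5, 17, 31_999], vocab_size=32_000)  # accepted
try:
    validate_allowed_token_ids([5, 40_000], vocab_size=32_000)  # 40_000 >= vocab_size
except ValueError as err:
    print(err)  # allowed_token_ids contains out-of-vocab token id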
4 changes: 4 additions & 0 deletions vllm/v1/sample/metadata.py
@@ -37,3 +37,7 @@ class SamplingMetadata:
min_tokens: Dict[int, Tuple[int, Set[int]]]

logit_bias: List[Optional[Dict[int, float]]]

# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
# vocab size).
allowed_token_ids_mask: Optional[torch.Tensor]
18 changes: 16 additions & 2 deletions vllm/v1/sample/sampler.py
@@ -47,6 +47,8 @@ def forward(

# Use float32 for the logits.
logits = logits.to(torch.float32)
# Apply allowed token ids.
logits = self.apply_allowed_token_ids(logits, sampling_metadata)
# Apply logits bias.
logits = self.apply_logits_bias(logits, sampling_metadata)
# Apply penalties (e.g., min_tokens, freq_penalties).
@@ -184,11 +186,13 @@ def apply_penalties(
if not sampling_metadata.no_penalties:
assert sampling_metadata.prompt_token_ids is not None
logits = apply_all_penalties(
logits, sampling_metadata.prompt_token_ids,
logits,
sampling_metadata.prompt_token_ids,
sampling_metadata.presence_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.repetition_penalties,
sampling_metadata.output_token_ids)
sampling_metadata.output_token_ids,
)
Member: Same general comment: it's best not to include unrelated formatting changes. It looks like there are a couple above and some in the other file.

Collaborator (Author): Yeah, it's done by my VSCode, sorry. I will be more cautious next time.

Collaborator (Author): Okay, confirmed: it's actually yapf in the pre-commit hooks that forced this change.

return logits

def apply_min_p(
@@ -226,3 +230,13 @@ def apply_logits_bias(
for token_id, bias in logit_bias.items():
logits[i, token_id] += bias
return logits

def apply_allowed_token_ids(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
if sampling_metadata.allowed_token_ids_mask is not None:
logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
float("-inf"))
return logits
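
A minimal standalone sketch (toy tensors, not vLLM objects) of the masking step that `apply_allowed_token_ids` above performs: positions marked True in the mask are filled with -inf, and rows without any True entries pass through unchanged.

import torch

# Toy logits for 2 requests over a vocab of 6; only request 0 carries a mask row.
logits = torch.zeros(2, 6)
mask = torch.zeros(2, 6, dtype=torch.bool)
mask[0, 2:4] = True  # positions that will be forced to -inf for request 0

# The same in-place fill used by apply_allowed_token_ids above.
logits.masked_fill_(mask, float("-inf"))
print(logits[0])  # tensor([0., 0., -inf, -inf, 0., 0.])
print(logits[1])  # tensor([0., 0., 0., 0., 0., 0.])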
43 changes: 41 additions & 2 deletions vllm/v1/worker/gpu_input_batch.py
@@ -143,7 +143,7 @@ def __init__(
device="cpu",
pin_memory=pin_memory)
self.frequency_penalties_cpu = \
self.frequency_penalties_cpu_tensor.numpy()
self.frequency_penalties_cpu_tensor.numpy()
self.frequency_penalties_reqs: Set[str] = set()

# Presence penalty related data structures
@@ -168,7 +168,7 @@ def __init__(
device="cpu",
pin_memory=pin_memory)
self.repetition_penalties_cpu = \
self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: Set[str] = set()

# req_index -> (min_tokens, stop_token_ids)
@@ -192,6 +192,9 @@ def __init__(

self.logit_bias: List[Optional[Dict[int,
float]]] = [None] * max_num_reqs
self.has_allowed_token_ids: Set[str] = set()
self.allowed_token_ids_mask: Optional[torch.Tensor] = None
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None

self.req_output_token_ids: List[Optional[List[int]]] = []

@@ -287,6 +290,22 @@ def add_request(
if sampling_params.logit_bias is not None:
self.logit_bias[req_index] = sampling_params.logit_bias

if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
if self.allowed_token_ids_mask_cpu_tensor is None:
# Lazy allocation for this tensor, which can be large.
self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device=self.device)
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device="cpu")
self.allowed_token_ids_mask_cpu_tensor[req_index][
sampling_params.allowed_token_ids] = True

# Add request lora ID
if request.lora_request:
lora_id = request.lora_request.lora_int_id
@@ -332,6 +351,9 @@ def remove_request(self, req_id: str) -> Optional[int]:
self.request_lora_mapping[req_index] = 0

self.logit_bias[req_index] = None
self.has_allowed_token_ids.discard(req_id)
if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
return req_index

def condense(self, empty_req_indices: List[int]) -> None:
@@ -400,6 +422,11 @@ def condense(self, empty_req_indices: List[int]) -> None:

self.logit_bias[empty_index] = self.logit_bias[last_req_index]

if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[
empty_index] = self.allowed_token_ids_mask_cpu_tensor[
last_req_index]

# Decrement last_req_index since it is now empty.
last_req_index -= 1

@@ -442,6 +469,13 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
else:
prompt_token_ids = None

allowed_token_ids_mask: Optional[torch.Tensor] = None
if not self.no_allowed_token_ids:
assert self.allowed_token_ids_mask is not None
copy_slice(self.allowed_token_ids_mask_cpu_tensor,
self.allowed_token_ids_mask, num_reqs)
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]

return SamplingMetadata(
temperature=temperature,
all_greedy=self.all_greedy,
@@ -460,6 +494,7 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
min_tokens=self.min_tokens,
no_penalties=self.no_penalties,
logit_bias=self.logit_bias[:num_reqs],
allowed_token_ids_mask=allowed_token_ids_mask,
)

def get_sampling_metadata(
@@ -550,3 +585,7 @@ def max_num_logprobs(self) -> Optional[int]:
@property
def no_prompt_logprob(self) -> bool:
return not self.num_prompt_logprobs

@property
def no_allowed_token_ids(self) -> bool:
return len(self.has_allowed_token_ids) == 0
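
Finally, a rough standalone sketch of the bookkeeping this file adds (a hypothetical TinyBatch class, not the vLLM InputBatch): the mask tensor is allocated lazily on the first request that uses allowed_token_ids, each request sets its own row, and the per-step metadata receives a [:num_reqs] slice.

from typing import List, Optional

import torch


class TinyBatch:
    """Hypothetical, simplified stand-in for the InputBatch bookkeeping above."""

    def __init__(self, max_num_reqs: int, vocab_size: int):
        self.max_num_reqs = max_num_reqs
        self.vocab_size = vocab_size
        # Lazily allocated, as in add_request above, since it can be large.
        self.allowed_token_ids_mask_cpu: Optional[torch.Tensor] = None

    def add_request(self, req_index: int,
                    allowed_token_ids: Optional[List[int]]) -> None:
        if not allowed_token_ids:
            return
        if self.allowed_token_ids_mask_cpu is None:
            self.allowed_token_ids_mask_cpu = torch.zeros(
                self.max_num_reqs, self.vocab_size, dtype=torch.bool)
        self.allowed_token_ids_mask_cpu[req_index][allowed_token_ids] = True

    def remove_request(self, req_index: int) -> None:
        if self.allowed_token_ids_mask_cpu is not None:
            self.allowed_token_ids_mask_cpu[req_index].fill_(False)

    def mask_for_step(self, num_reqs: int) -> Optional[torch.Tensor]:
        if self.allowed_token_ids_mask_cpu is None:
            return None
        return self.allowed_token_ids_mask_cpu[:num_reqs]


batch = TinyBatch(max_num_reqs=4, vocab_size=8)
batch.add_request(0, allowed_token_ids=[1, 3])
batch.add_request(1, allowed_token_ids=None)
print(batch.mask_for_step(num_reqs=2))  # row 0 has True at columns 1 and 3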