30 changes: 22 additions & 8 deletions vllm/model_executor/layers/utils.py
@@ -37,6 +37,8 @@ def get_token_bin_counts_and_mask(
vocab_size: int,
num_seqs: int,
) -> tuple[torch.Tensor, torch.Tensor]:
if tokens is None:
return None, None
# Compute the bin counts for the tokens.
# vocab_size + 1 for padding.
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
@@ -73,15 +75,27 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
output_bin_counts, output_mask = get_token_bin_counts_and_mask(
output_tokens_tensor, vocab_size, num_seqs)

# Apply repetition penalties as a custom op
from vllm._custom_ops import apply_repetition_penalties
apply_repetition_penalties(logits, prompt_mask, output_mask,
repetition_penalties)
if prompt_mask is not None or output_mask is not None:
from vllm._custom_ops import apply_repetition_penalties

if prompt_mask is None:
prompt_mask = torch.zeros((num_seqs, vocab_size),
dtype=torch.bool,
device=logits.device)
if output_mask is None:
output_mask = torch.zeros((num_seqs, vocab_size),
dtype=torch.bool,
device=logits.device)

apply_repetition_penalties(logits, prompt_mask, output_mask,
repetition_penalties)

if output_bin_counts is not None:
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts

if output_mask is not None:
logits -= presence_penalties.unsqueeze(dim=1) * output_mask

# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
logits -= presence_penalties.unsqueeze(dim=1) * output_mask
return logits
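
Taken together, the two hunks above let get_token_bin_counts_and_mask return (None, None) when it is given no tokens, and apply_penalties now zero-fills whichever mask is missing before invoking the custom op, then applies the frequency and presence penalties only when their inputs exist. Below is a runnable, pure-PyTorch sketch of that flow; apply_repetition_penalties_ref is a plain-PyTorch stand-in for the fused custom op, and all names, shapes, and values are illustrative rather than part of the vLLM API.

# Illustrative sketch only; apply_repetition_penalties_ref approximates the
# fused custom op in plain PyTorch, it is not the vLLM kernel.
import torch


def apply_repetition_penalties_ref(logits, prompt_mask, output_mask, penalties):
    seen = prompt_mask | output_mask                        # tokens already used
    scale = torch.where(seen, penalties.unsqueeze(1), torch.ones_like(logits))
    # Divide positive logits, multiply negative ones (standard repetition penalty).
    logits[:] = torch.where(logits > 0, logits / scale, logits * scale)


num_seqs, vocab_size = 2, 8
logits = torch.randn(num_seqs, vocab_size)
repetition_penalties = torch.tensor([1.2, 1.0])
frequency_penalties = torch.tensor([0.5, 0.0])
presence_penalties = torch.tensor([0.1, 0.0])

# Example: no prompt tokens at all, so the prompt-side mask comes back as None.
prompt_mask = None
output_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)
output_mask[0, 3] = True                                    # seq 0 emitted token 3
output_bin_counts = output_mask.long()                      # emitted once

if prompt_mask is not None or output_mask is not None:
    if prompt_mask is None:
        prompt_mask = torch.zeros_like(output_mask)         # zero-fill missing mask
    if output_mask is None:
        output_mask = torch.zeros_like(prompt_mask)
    apply_repetition_penalties_ref(logits, prompt_mask, output_mask,
                                   repetition_penalties)

# Frequency/presence penalties follow the OpenAI parameter definitions and are
# skipped entirely when there is no output history.
if output_bin_counts is not None:
    logits -= frequency_penalties.unsqueeze(1) * output_bin_counts
if output_mask is not None:
    logits -= presence_penalties.unsqueeze(1) * output_mask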


25 changes: 17 additions & 8 deletions vllm/utils/__init__.py
@@ -1175,7 +1175,7 @@ def make_ndarray_with_pad(
"""
if max_len is None:
# Unlike for most functions, map is faster than a genexpr over `len`
max_len = max(map(len, x), default=0)
max_len = max(map(len, x)) if x else 0

padded_x = np.full((len(x), max_len), pad, dtype=dtype)
for ind, blocktb in enumerate(x):
@@ -1196,18 +1196,27 @@ def make_tensor_with_pad(
) -> torch.Tensor:
"""
Make a padded tensor from 2D inputs.

The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)
if max_len is None:
max_len = max(len(row) for row in x) if x else 0

padded_tensor = torch.full((len(x), max_len),
fill_value=pad,
dtype=dtype,
device=device)

for i, row in enumerate(x):
row_len = len(row)
if row_len > 0:
row_tensor = torch.as_tensor(row, dtype=dtype, device=device)
padded_tensor[i, :row_len] = row_tensor

tensor = torch.from_numpy(padded_x).to(device)
if pin_memory:
tensor = tensor.pin_memory()
if pin_memory and padded_tensor.device.type == 'cpu':
return padded_tensor.pin_memory()

return tensor
return padded_tensor
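
Two notes on the vllm/utils/__init__.py hunks: max(map(len, x)) if x else 0 is equivalent to the previous max(map(len, x), default=0) for list inputs, and make_tensor_with_pad now builds the padded tensor directly with torch.full plus row-wise copies instead of round-tripping through NumPy, pinning memory only when the result lives on the CPU. A hedged usage sketch follows; the ragged input is made up, and the keyword-style call is an assumption since the full signature is elided above.

# Illustrative call only; the exact signature beyond what the hunk shows is an
# assumption (keyword arguments are used defensively).
import torch
from vllm.utils import make_tensor_with_pad

x = [[1, 2, 3], [4], []]          # ragged input, including an empty row
padded = make_tensor_with_pad(x, pad=0, dtype=torch.long, device="cpu")
assert padded.tolist() == [[1, 2, 3], [4, 0, 0], [0, 0, 0]]

# max_len can still be forced, e.g. to match a fixed block size; rows shorter
# than max_len are padded at the end, and empty rows stay all-pad.
padded4 = make_tensor_with_pad(x, pad=0, dtype=torch.long, device="cpu",
                               max_len=4)
assert padded4.shape == (3, 4)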


def async_tensor_h2d(
1 change: 0 additions & 1 deletion vllm/v1/sample/sampler.py
@@ -250,7 +250,6 @@ def apply_penalties(
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
if not sampling_metadata.no_penalties:
assert sampling_metadata.prompt_token_ids is not None
logits = apply_all_penalties(
logits,
sampling_metadata.prompt_token_ids,
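
The removed assertion is what previously forced prompt_token_ids to be materialized before penalties ran; with the (None, None) early return added in get_token_bin_counts_and_mask, a missing prompt tensor now simply skips the prompt-side mask. A minimal sketch of that interaction follows; the dataclass and helper are stand-ins for illustration, not vLLM's SamplingMetadata or its sampler.

# Stand-in sketch, not vLLM's SamplingMetadata: a None prompt_token_ids now
# flows through instead of tripping an assertion.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FakeSamplingMetadata:
    no_penalties: bool
    prompt_token_ids: Optional[torch.Tensor]


def bin_counts_and_mask(tokens, vocab_size, num_seqs):
    if tokens is None:                      # mirrors the early return added above
        return None, None
    bin_counts = torch.zeros(num_seqs, vocab_size + 1, dtype=torch.long)
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    bin_counts = bin_counts[:, :vocab_size]
    return bin_counts, bin_counts > 0


meta = FakeSamplingMetadata(no_penalties=False, prompt_token_ids=None)
counts, mask = bin_counts_and_mask(meta.prompt_token_ids, vocab_size=8, num_seqs=2)
assert counts is None and mask is None      # penalties simply skip the prompt side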