20 changes: 16 additions & 4 deletions vllm/v1/sample/logits_processor/state.py
@@ -50,6 +50,10 @@ def __init__(
self.added = added or []
self._is_removed_sorted = False

# Used to track changes in the pooling case
# where we don't populate the added list.
self.batch_changed = False

def _ensure_removed_sorted(self) -> None:
"""Sort removed request indices in
descending order.
@@ -80,6 +84,7 @@ def removed_append(self, index: int) -> None:
raise RuntimeError("Cannot register new removed request after"
" self.removed has been read.")
self._removed.append(index)
self.batch_changed = True

def has_removed(self) -> bool:
return bool(self._removed)
@@ -98,9 +103,15 @@ def pop_removed(self) -> Optional[int]:
return self._removed.pop()
return None

def _is_update(self) -> bool:
"""True if there is a batch state change"""
return any((self._removed, self.moved, self.added))
def reset(self) -> bool:
"""Returns True if there were any changes to the batch."""
self._is_removed_sorted = False
self._removed.clear()
self.moved.clear()
self.added.clear()
batch_changed = self.batch_changed
self.batch_changed = False
return batch_changed

def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
"""Generate a logitsprocs batch update data structure and reset
@@ -114,7 +125,8 @@ def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
"""
# Reset removal-sorting logic
self._is_removed_sorted = False
if not self._is_update():
self.batch_changed = False
if not any((self._removed, self.moved, self.added)):
# No update; short-circuit
return None
# Build batch state update
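A minimal usage sketch (not part of the diff) of how the two reset paths above are meant to be consumed; `builder` is a BatchUpdateBuilder, and the helper functions are hypothetical placeholders:

# Sketch only: pooling models just need a boolean "did the batch change" signal,
# while autoregressive models need the full BatchUpdate for logits processors.
def refresh(builder, is_pooling_model: bool, num_reqs: int) -> None:
    if is_pooling_model:
        if builder.reset():
            rebuild_sampling_metadata()          # hypothetical helper
    else:
        batch_update = builder.get_and_reset(num_reqs)
        apply_logitsprocs_update(batch_update)   # hypothetical helper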
146 changes: 76 additions & 70 deletions vllm/v1/worker/gpu_input_batch.py
@@ -65,8 +65,7 @@ def mm_inputs(self) -> list[MultiModalKwargsItems]:
def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
return self.prompt_token_ids[idx]
else:
return self.output_token_ids[idx - self.num_prompt_tokens]
return self.output_token_ids[idx - self.num_prompt_tokens]


class InputBatch:
@@ -261,30 +260,27 @@ def _register_add_request(self, request: "CachedRequestState") -> int:
Not applicable to pooling models.
"""

# Detailed added request metadata is only required for non-pooling
# models, to support logitsprocs
assert request.sampling_params

# Fill the next empty index if there is one.
if (new_req_index := self.batch_update_builder.pop_removed()) is None:
# Append to end otherwise.
new_req_index = self.num_reqs

assert new_req_index < self.max_num_reqs
self.batch_update_builder.added.append(
(new_req_index, request.sampling_params, request.prompt_token_ids,
request.output_token_ids))
self.batch_update_builder.batch_changed = True
if request.sampling_params:
# Detailed added request metadata is only required for non-pooling
# models, to support logitsprocs.
self.batch_update_builder.added.append(
(new_req_index, request.sampling_params,
request.prompt_token_ids, request.output_token_ids))

return new_req_index

def add_request(
self,
request: "CachedRequestState",
) -> int:
if not self.is_pooling_model:
# New request index bookkeeping for autoregressive models.
req_index = self._register_add_request(request)
else:
req_index = self.num_reqs
req_index = self._register_add_request(request)

req_id = request.req_id
if req_index == len(self._req_ids):
@@ -389,7 +385,7 @@ def add_request(
self.logits_processing_needs_token_ids[req_index] = (
pooling_params.requires_token_ids)
else:
raise NotImplementedError(request)
raise NotImplementedError("Unrecognized request type")

# Add request lora ID
if request.lora_request:
@@ -419,13 +415,25 @@ def remove_request(self, req_id: str) -> Optional[int]:
req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
return None
if not self.is_pooling_model:
# Autoregressive models require bookkeeping of removed requests to
# support logitsprocs.
self.batch_update_builder.removed_append(req_index)

self.batch_update_builder.removed_append(req_index)
self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None

# LoRA
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
lora_req_ids = self.lora_id_to_request_ids[lora_id]
lora_req_ids.discard(req_id)
if not lora_req_ids:
del self.lora_id_to_request_ids[lora_id]
del self.lora_id_to_lora_request[lora_id]
self.request_lora_mapping[req_index] = 0

if self.is_pooling_model:
self.pooling_params.pop(req_id, None)
return req_index

self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
@@ -439,29 +447,14 @@ def remove_request(self, req_id: str) -> Optional[int]:
self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)

# LoRA
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
lora_req_ids = self.lora_id_to_request_ids[lora_id]
lora_req_ids.discard(req_id)
if not lora_req_ids:
del self.lora_id_to_request_ids[lora_id]
del self.lora_id_to_lora_request[lora_id]
self.request_lora_mapping[req_index] = 0

self.has_allowed_token_ids.discard(req_id)
if self.allowed_token_ids_mask_cpu_tensor is not None:
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
self.bad_words_token_ids.pop(req_index, None)
self.pooling_params.pop(req_id, None)
return req_index

def swap_states(self, i1: int, i2: int) -> None:
# For autoregressive models, track detailed request reordering info
# to support logitsprocs
self.batch_update_builder.moved.append(
(i1, i2, MoveDirectionality.SWAP))
old_id_i1 = self._req_ids[i1]
old_id_i2 = self._req_ids[i2]
self._req_ids[i1], self._req_ids[i2] =\
@@ -479,18 +472,6 @@ def swap_states(self, i1: int, i2: int) -> None:
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
self.temperature_cpu[i1], self.temperature_cpu[i2] =\
self.temperature_cpu[i2], self.temperature_cpu[i1]
self.top_p_cpu[i1], self.top_p_cpu[i2] =\
self.top_p_cpu[i2], self.top_p_cpu[i1]
self.top_k_cpu[i1], self.top_k_cpu[i2] =\
self.top_k_cpu[i2], self.top_k_cpu[i1]
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]

# NOTE: the following is unsafe
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
@@ -501,18 +482,41 @@ def swap_states(self, i1: int, i2: int) -> None:
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
self.token_ids_cpu[i2, ...] = tmp

swap_dict_values(self.generators, i1, i2)
swap_dict_values(self.bad_words_token_ids, i1, i2)
self.block_table.swap_row(i1, i2)

self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
self.request_lora_mapping[i2], self.request_lora_mapping[i1]

if self.is_pooling_model:
# Sampling and logits parameters don't apply to pooling models.
return

# For autoregressive models, track detailed request reordering info
# to support logitsprocs.
self.batch_update_builder.moved.append(
(i1, i2, MoveDirectionality.SWAP))

self.temperature_cpu[i1], self.temperature_cpu[i2] = \
self.temperature_cpu[i2], self.temperature_cpu[i1]
self.top_p_cpu[i1], self.top_p_cpu[i2] = \
self.top_p_cpu[i2], self.top_p_cpu[i1]
self.top_k_cpu[i1], self.top_k_cpu[i2] = \
self.top_k_cpu[i2], self.top_k_cpu[i1]
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]

swap_dict_values(self.generators, i1, i2)
swap_dict_values(self.bad_words_token_ids, i1, i2)

if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[i1], \
self.allowed_token_ids_mask_cpu_tensor[i2] =\
self.allowed_token_ids_mask_cpu_tensor[i2], \
self.allowed_token_ids_mask_cpu_tensor[i1]
self.block_table.swap_row(i1, i2)

def condense(self) -> None:
"""Slide non-empty requests down into lower, empty indices.
@@ -529,12 +533,6 @@ def condense(self) -> None:
"""
num_reqs = self.num_reqs

if self.is_pooling_model:
# Will be contiguous in pooling case, just trim the lists.
del self._req_ids[num_reqs:]
del self.req_output_token_ids[num_reqs:]
return

if not (empty_req_indices := self.batch_update_builder.removed):
# All removed requests were replaced by added requests, or else no
# requests were removed at all. No condense() needed
@@ -562,11 +560,6 @@ def condense(self) -> None:
# Move active request down into empty request
# index.
self.batch_update_builder.pop_removed()
# Autoregressive models require detailed tracking of condense
# operations to support logitsprocs
self.batch_update_builder.moved.append(
(last_req_index, empty_index,
MoveDirectionality.UNIDIRECTIONAL))
req_id = self._req_ids[last_req_index]
output_token_ids = self.req_output_token_ids[last_req_index]
assert req_id is not None
@@ -587,6 +580,21 @@ def condense(self) -> None:
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.block_table.move_row(last_req_index, empty_index)

self.request_lora_mapping[empty_index] = self.request_lora_mapping[
last_req_index]

if self.is_pooling_model:
last_req_index -= 1
# Sampling state not used by pooling models.
continue

# Autoregressive models require detailed tracking of condense
# operations to support logitsprocs
self.batch_update_builder.moved.append(
(last_req_index, empty_index,
MoveDirectionality.UNIDIRECTIONAL))

self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
@@ -601,9 +609,6 @@ def condense(self) -> None:
if generator is not None:
self.generators[empty_index] = generator

self.request_lora_mapping[empty_index] = self.request_lora_mapping[
last_req_index]

# TODO convert these to LogitsProcessors
if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[
@@ -626,8 +631,9 @@ def refresh_metadata(self):
"""Apply any batch updates to sampling metadata."""

if self.is_pooling_model:
# Batch changes every step for pooling models.
self.sampling_metadata = self._make_sampling_metadata()
batch_changed = self.batch_update_builder.reset()
if batch_changed:
self.sampling_metadata = self._make_sampling_metadata()
return

# For non-pooling models - generate and apply logitsprocs update;
@@ -720,19 +726,19 @@ def pooling_metadata(self) -> PoolingMetadata:
)

def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
num_reqs = self.num_reqs
max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(
(self.num_reqs, max_prompt_len),
device="cpu",
dtype=torch.int64,
pin_memory=self.pin_memory,
)
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
prompt_token_ids[:] = self.token_ids_cpu[:self.
num_reqs, :max_prompt_len]
prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for i in range(self.num_reqs):
for i in range(num_reqs):
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
return prompt_token_ids_cpu_tensor.to(device=self.device,
non_blocking=True)
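For orientation, a rough per-step sketch (assumed names, not from the PR) of the InputBatch flow these changes touch, now shared by pooling and autoregressive models:

# Sketch only: `batch` is an InputBatch; `finished_id` and `new_req` are assumed.
def step_bookkeeping(batch, finished_id, new_req) -> None:
    batch.remove_request(finished_id)   # records the freed slot via removed_append()
    batch.add_request(new_req)          # reuses a freed slot via pop_removed()
    batch.condense()                    # slides trailing requests into empty slots
    batch.refresh_metadata()            # pooling: rebuilds sampling metadata only
                                        # when reset() reports the batch changed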
8 changes: 3 additions & 5 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1482,10 +1482,8 @@ def _pool(
for raw_output, seq_len, prompt_len in zip(
raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):

if seq_len == prompt_len:
pooler_output.append(raw_output.data)
else:
pooler_output.append(None)
output = raw_output.data if seq_len == prompt_len else None
pooler_output.append(output)

return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
@@ -1515,7 +1513,7 @@ def execute_model(
# Prepare the decoder inputs.
(attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
max_query_len) = (self._prepare_inputs(scheduler_output))
max_query_len) = self._prepare_inputs(scheduler_output)

num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
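A small illustration (assumed values, not from the PR) of the `_pool` behavior above: only requests whose scheduled sequence length has reached the full prompt produce a pooled output this step; chunked prefills still in progress get None.

# Sketch only: stand-ins for the pooler's raw outputs; in the runner these
# would be tensor-backed output items rather than strings.
from types import SimpleNamespace

raw_outputs = [SimpleNamespace(data="embedding-0"), SimpleNamespace(data="embedding-1")]
seq_lens_cpu = [8, 4]    # tokens scheduled so far for each request
prompt_lens = [8, 10]    # full prompt length for each request

pooler_output = [
    out.data if seq_len == prompt_len else None
    for out, seq_len, prompt_len in zip(raw_outputs, seq_lens_cpu, prompt_lens)
]
# -> ["embedding-0", None]: request 1 is mid chunked-prefill, so no output yet.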