60 changes: 54 additions & 6 deletions vllm/v1/worker/gpu_model_runner.py
@@ -2211,6 +2211,20 @@ def _bookkeeping_sync(
if i not in invalid_req_indices_set
}

# Collect updates in the for loop and apply a batch update at the end
# to vectorize updates to tensors and numpy arrays.
start_indices = self.input_batch.num_tokens_no_spec.tolist()
# Indices and values to update for num_tokens and num_tokens_no_spec
num_tokens_indices: list[int] = []
num_tokens_values: list[int] = []
# Flatten indices and values to update for token_ids (2D numpy array)
token_ids_cpu_flatten_indices: list[int] = []
token_ids_cpu_values: list[int] = []
token_ids_cpu_column_cnt = self.input_batch.token_ids_cpu.shape[1]
# Flatten is_token_ids indices for update
is_token_flatten_indices: list[int] = []
is_token_column_cnt = self.input_batch.is_token_ids.shape[1]

# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
@@ -2226,23 +2240,57 @@ def _bookkeeping_sync(
if not sampled_ids:
continue

start_idx = self.input_batch.num_tokens_no_spec[req_idx]
start_idx = start_indices[req_idx]
end_idx = start_idx + len(sampled_ids)
assert end_idx <= self.max_model_len, (
"Sampled token IDs exceed the max model length. "
f"Total number of tokens: {end_idx} > max_model_len: "
f"{self.max_model_len}")

self.input_batch.token_ids_cpu[req_idx,
start_idx:end_idx] = sampled_ids
self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx
# Collect flattened indices and updated values for a bulk token_ids_cpu
# update, which is equivalent to
# - self.input_batch.token_ids_cpu[req_idx,
# start_idx:end_idx] = sampled_ids
base_idx = req_idx * token_ids_cpu_column_cnt
token_ids_cpu_flatten_indices.extend(
base_idx + idx for idx in range(start_idx, end_idx))
token_ids_cpu_values.extend(sampled_ids)
# Collect flattened indices for a bulk is_token_ids update, which
# is equivalent to
# - self.input_batch.is_token_ids[req_idx,
# start_idx:end_idx] = True
base_idx = req_idx * is_token_column_cnt
is_token_flatten_indices.extend(
base_idx + idx for idx in range(start_idx, end_idx))
# Collect updates to num_tokens and num_tokens_no_spec,
# which is equivalent to
# - self.input_batch.num_tokens_no_spec[req_idx] = end_idx
# - self.input_batch.num_tokens[req_idx] = end_idx
num_tokens_indices.append(req_idx)
num_tokens_values.append(end_idx)

req_id = req_ids[req_idx]
req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids)

# Apply all tensor / numpy array updates in batch
if num_tokens_indices:
# Batch update num_tokens arrays
self.input_batch.num_tokens[num_tokens_indices] = num_tokens_values
self.input_batch.num_tokens_no_spec[
num_tokens_indices] = num_tokens_values
if token_ids_cpu_flatten_indices:
token_ids_cpu_view = self.input_batch.token_ids_cpu.ravel()
# Ensure ravel returned a view of the original numpy array
assert token_ids_cpu_view.base is self.input_batch.token_ids_cpu
token_ids_cpu_view[
token_ids_cpu_flatten_indices] = token_ids_cpu_values
if is_token_flatten_indices:
is_token_ids_view = self.input_batch.is_token_ids.ravel()
# Ensure ravel returned a view of the original tensor
assert is_token_ids_view._base is self.input_batch.is_token_ids
is_token_ids_view[is_token_flatten_indices] = True

return (
num_nans_in_logits,
logprobs_lists,
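For reference, a minimal standalone sketch (not part of the PR) of the flattened-index bulk update the new comments describe. The shapes, request indices, and sampled IDs below are made up for illustration; the same pattern applies to the torch-backed is_token_ids buffer, where Tensor.ravel() on a contiguous tensor likewise returns a view.

import numpy as np

# Hypothetical shapes standing in for input_batch.token_ids_cpu.
num_reqs, max_model_len = 4, 8
token_ids_cpu = np.zeros((num_reqs, max_model_len), dtype=np.int64)

# Example per-request updates collected during the loop:
# req_idx -> (start_idx, sampled_ids)
updates = {0: (2, [101, 102]), 3: (5, [7, 8, 9])}

flat_indices: list[int] = []
flat_values: list[int] = []
for req_idx, (start_idx, sampled_ids) in updates.items():
    base_idx = req_idx * max_model_len
    flat_indices.extend(
        base_idx + i for i in range(start_idx, start_idx + len(sampled_ids)))
    flat_values.extend(sampled_ids)

# One vectorized write through a flat view. ravel() on a C-contiguous
# array returns a view, so the assertion guards against a silent copy
# that would leave token_ids_cpu unmodified.
flat_view = token_ids_cpu.ravel()
assert flat_view.base is token_ids_cpu
flat_view[flat_indices] = flat_values

# Equivalent per-row form that the batched path replaces:
#   token_ids_cpu[req_idx, start_idx:start_idx + len(sampled_ids)] = sampled_ids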