diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 0960fe3a25fb..3da136dd30ed 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2211,6 +2211,20 @@ def _bookkeeping_sync(
             if i not in invalid_req_indices_set
         }
 
+        # Collect updates in the for loop and apply a batch update at the end
+        # to vectorize updates to tensors and numpy arrays.
+        start_indices = self.input_batch.num_tokens_no_spec.tolist()
+        # Indices and values to update for num_tokens and num_tokens_no_spec
+        num_tokens_indices: list[int] = []
+        num_tokens_values: list[int] = []
+        # Flatten indices and values to update for token_ids (2D numpy array)
+        token_ids_cpu_flatten_indices: list[int] = []
+        token_ids_cpu_values: list[int] = []
+        token_ids_cpu_column_cnt = self.input_batch.token_ids_cpu.shape[1]
+        # Flatten is_token_ids indices for update
+        is_token_flatten_indices: list[int] = []
+        is_token_column_cnt = self.input_batch.is_token_ids.shape[1]
+
         # Cache the sampled tokens in the model runner, so that the scheduler
         # doesn't need to send them back.
         # NOTE(woosuk): As an exception, when using PP, the scheduler sends
@@ -2226,23 +2240,57 @@ def _bookkeeping_sync(
             if not sampled_ids:
                 continue
 
-            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
+            start_idx = start_indices[req_idx]
             end_idx = start_idx + len(sampled_ids)
             assert end_idx <= self.max_model_len, (
                 "Sampled token IDs exceed the max model length. "
                 f"Total number of tokens: {end_idx} > max_model_len: "
                 f"{self.max_model_len}")
 
-            self.input_batch.token_ids_cpu[req_idx,
-                                           start_idx:end_idx] = sampled_ids
-            self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
-            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
-            self.input_batch.num_tokens[req_idx] = end_idx
+            # Collect flattened indices and updated values for a bulk
+            # token_ids_cpu update, which is equivalent to
+            # - self.input_batch.token_ids_cpu[req_idx,
+            #                                  start_idx:end_idx] = sampled_ids
+            base_idx = req_idx * token_ids_cpu_column_cnt
+            token_ids_cpu_flatten_indices.extend(
+                base_idx + idx for idx in range(start_idx, end_idx))
+            token_ids_cpu_values.extend(sampled_ids)
+            # Collect flattened indices for a bulk is_token_ids update, which
+            # is equivalent to
+            # - self.input_batch.is_token_ids[req_idx,
+            #                                 start_idx:end_idx] = True
+            base_idx = req_idx * is_token_column_cnt
+            is_token_flatten_indices.extend(
+                base_idx + idx for idx in range(start_idx, end_idx))
+            # Collect updates to num_tokens and num_tokens_no_spec,
+            # which is equivalent to
+            # - self.input_batch.num_tokens_no_spec[req_idx] = end_idx
+            # - self.input_batch.num_tokens[req_idx] = end_idx
+            num_tokens_indices.append(req_idx)
+            num_tokens_values.append(end_idx)
 
             req_id = req_ids[req_idx]
             req_state = self.requests[req_id]
             req_state.output_token_ids.extend(sampled_ids)
 
+        # Apply all tensor / numpy array updates in batch
+        if num_tokens_indices:
+            # Batch update num_tokens arrays
+            self.input_batch.num_tokens[num_tokens_indices] = num_tokens_values
+            self.input_batch.num_tokens_no_spec[
+                num_tokens_indices] = num_tokens_values
+        if token_ids_cpu_flatten_indices:
+            token_ids_cpu_view = self.input_batch.token_ids_cpu.ravel()
+            # Ensure ravel returned a view of the original numpy array
+            assert token_ids_cpu_view.base is self.input_batch.token_ids_cpu
+            token_ids_cpu_view[
+                token_ids_cpu_flatten_indices] = token_ids_cpu_values
+        if is_token_flatten_indices:
+            is_token_ids_view = self.input_batch.is_token_ids.ravel()
+            # Ensure ravel returned a view of the original tensor
+            assert is_token_ids_view._base is self.input_batch.is_token_ids
+            is_token_ids_view[is_token_flatten_indices] = True
+
         return (
            num_nans_in_logits,
            logprobs_lists,
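To make the intent of the second hunk easier to follow, here is a minimal, self-contained sketch of the flattened-index bulk update it performs. A plain NumPy array stands in for `input_batch.token_ids_cpu`; the shapes, the `sampled` mapping, and the helper names (`flat_indices`, `flat_values`, `n_cols`) are illustrative assumptions, not part of the patch.

```python
import numpy as np

# Hypothetical stand-ins for input_batch.token_ids_cpu and the per-request
# sampled token ids; shapes and values are made up for illustration.
token_ids_cpu = np.zeros((4, 8), dtype=np.int64)   # (max_reqs, max_model_len)
num_tokens_no_spec = np.array([3, 1, 4, 2])        # tokens already present
sampled = {0: [11, 12], 2: [31], 3: [41, 42, 43]}  # req_idx -> sampled ids

# Naive per-request slice assignment (what the old code did).
expected = token_ids_cpu.copy()
for req_idx, ids in sampled.items():
    start = int(num_tokens_no_spec[req_idx])
    expected[req_idx, start:start + len(ids)] = ids

# Vectorized variant: collect flattened indices and values in the loop,
# then do a single fancy-index write through a ravel() view of the 2D array.
flat_indices: list[int] = []
flat_values: list[int] = []
n_cols = token_ids_cpu.shape[1]
for req_idx, ids in sampled.items():
    start = int(num_tokens_no_spec[req_idx])
    base = req_idx * n_cols
    flat_indices.extend(base + i for i in range(start, start + len(ids)))
    flat_values.extend(ids)

flat_view = token_ids_cpu.ravel()
# ravel() on a contiguous array returns a view, so the write below lands in
# the original storage rather than a temporary copy.
assert flat_view.base is token_ids_cpu
flat_view[flat_indices] = flat_values

assert np.array_equal(token_ids_cpu, expected)
```

The patch applies the same idea to `num_tokens` / `num_tokens_no_spec` (one fancy-index write per array) and to `is_token_ids`, where the collected positions are set to `True` through a `ravel()` view of the boolean tensor; the asserts on `.base` / `._base` guard against `ravel()` returning a copy, in which case the write would not reach the original buffer.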