Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions tpu_commons/runner/tpu_torchax_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,21 +447,24 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
req_ids_to_add.append(req_id)

# Update the states of the running/resumed requests.
for req_data in scheduler_output.scheduled_cached_reqs:
req_id = req_data.req_id
req_data = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i]

# Update the cached states.
req_state.num_computed_tokens = req_data.num_computed_tokens
if not req_data.resumed_from_preemption:
req_state.num_computed_tokens = num_computed_tokens
if not resumed_from_preemption:
# Append the new blocks to the existing block IDs.
for block_ids, new_block_ids in zip(req_state.block_ids,
req_data.new_block_ids):
block_ids.extend(new_block_ids)
for block_ids, new_ids in zip(req_state.block_ids,
new_block_ids):
block_ids.extend(new_ids)
else:
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
req_state.block_ids = req_data.new_block_ids
req_state.block_ids = new_block_ids

req_index = self.input_batch.req_id_to_index.get(req_id)
if req_index is None:
Expand All @@ -473,9 +476,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:

# Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = (
req_data.num_computed_tokens)
self.input_batch.block_table.append_row(req_data.new_block_ids,
req_index)
num_computed_tokens)
self.input_batch.block_table.append_row(new_block_ids, req_index)

# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
Expand Down
Loading