[core] [3/N] multi-step args and sequence.py #7452
Changes from all commits
@@ -8,6 +8,7 @@
 from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
                     Union, cast)
 
+import numpy
 import torch
 
 from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs

Review comment (on the `import numpy` line): Nit, with a suggested change.
@@ -489,6 +490,19 @@ def __repr__(self) -> str:
                 f"num_blocks={self.n_blocks}, ")
 
 
+@dataclass
+class SequenceGroupState:
+    """Mutable state tied to a specific sequence group"""
+
+    # for multi-step decoding
+    num_steps: int = 1
+    current_step: int = 0
+
+    @property
+    def remaining_steps(self) -> int:
+        return self.num_steps - self.current_step
+
+
 class SequenceGroup:
     """A group of sequences that are generated from the same prompt.

Review comment (on `SequenceGroupState`): I wonder if this should be a part of SequenceData? All of the state is actually stored in SequenceData now.

Reply: also IIUC

Reply: Perhaps eventually we will support different num_steps in a batch, so we track this per SequenceGroup. There used to be a SequenceGroupState used for the seed generator, but it has since been removed. #6698
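For readers outside the diff, a minimal standalone sketch (the dataclass mirrors the one added above; the driver lines are hypothetical) of why the counter is tracked per sequence group: each group carries its own step state, which is what would eventually allow groups in a batch to sit on different steps, as discussed in the thread above.

```python
from dataclasses import dataclass


@dataclass
class SequenceGroupState:
    """Mutable state tied to a specific sequence group (mirrors the diff)."""

    # for multi-step decoding
    num_steps: int = 1
    current_step: int = 0

    @property
    def remaining_steps(self) -> int:
        return self.num_steps - self.current_step


# Hypothetical driver code: two groups scheduled for 8 steps each.
group_a = SequenceGroupState(num_steps=8)
group_b = SequenceGroupState(num_steps=8)

group_a.current_step = 3              # group A has run 3 decode steps
assert group_a.remaining_steps == 5
assert group_b.remaining_steps == 8   # group B's state is independent
```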
@@ -534,6 +548,7 @@ def __init__(
             time_in_queue=None)
         self.lora_request = lora_request
         self.prompt_logprobs: Optional[PromptLogprobs] = None
+        self.state = SequenceGroupState()
         self.embeddings = embeddings
         self.pooling_params = pooling_params
         self.prompt_adapter_request = prompt_adapter_request
@@ -588,6 +603,10 @@ def prompt_adapter_num_virtual_tokens(self) -> int:
         return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
             if self.prompt_adapter_request else 0
 
+    def init_multi_step(self, num_scheduler_steps: int) -> None:
+        self.state.num_steps = num_scheduler_steps
+        self.state.current_step = 0
+
     def get_last_latency(self, now: float) -> Optional[float]:
         """Sets the last token time for Request level timings."""
         # If still in prefill phase, raise Error.
@@ -756,6 +775,7 @@ class SequenceGroupMetadata:
         lora_request: LoRA request.
         computed_block_nums: The block numbers that are already computed,
             used in prefix caching.
+        state: Internal state tied to this sequence group.
         multi_modal_data: Multi modal data.
         encoder_seq_data: Optional sequence data for encoder prompt
             (SequenceGroup.encoder_seq). Should be None
@@ -781,6 +801,7 @@ def __init__(
         token_chunk_size: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
         computed_block_nums: Optional[List[int]] = None,
+        state: Optional[SequenceGroupState] = None,
         multi_modal_data: Optional["MultiModalDataDict"] = None,
         encoder_seq_data: Optional[SequenceData] = None,
         cross_block_table: Optional[List[int]] = None,
@@ -796,6 +817,7 @@ def __init__(
         self.prompt_adapter_request = prompt_adapter_request
         self.computed_block_nums = computed_block_nums
         self.multi_modal_data = multi_modal_data
+        self.state = SequenceGroupState() if state is None else state
         self.encoder_seq_data = encoder_seq_data
         self.cross_block_table = cross_block_table
         self._token_chunk_size = token_chunk_size
@@ -834,6 +856,10 @@ def token_chunk_size(self) -> int:
         assert self._token_chunk_size is not None
         return self._token_chunk_size
 
+    def finish_step(self) -> None:
+        assert self.state.current_step < self.state.num_steps
+        self.state.current_step += 1
+
 
 class SequenceOutput:
     """The model output associated with a sequence.

Review comment (on `finish_step`): Not used in this PR. Will it be used in the next PR with the scheduler?

Reply: Yes, it will be used in AsyncLLMEngine.
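A hedged sketch of the intended call pattern (an assumption based on the thread above: `finish_step()` is not exercised in this PR and would later be driven by AsyncLLMEngine). The stand-in class below merges `init_multi_step` (added to SequenceGroup) and `finish_step` (added to SequenceGroupMetadata) onto one object purely for brevity; it is not the real vLLM type.

```python
from dataclasses import dataclass


@dataclass
class _State:
    num_steps: int = 1
    current_step: int = 0


class _MultiStepStandIn:
    """Only the multi-step bits added in this diff, merged for illustration."""

    def __init__(self) -> None:
        self.state = _State()

    def init_multi_step(self, num_scheduler_steps: int) -> None:
        # Reset the per-group counter for a fresh multi-step window.
        self.state.num_steps = num_scheduler_steps
        self.state.current_step = 0

    def finish_step(self) -> None:
        # Overrunning the step budget trips the assertion, surfacing driver bugs.
        assert self.state.current_step < self.state.num_steps
        self.state.current_step += 1


meta = _MultiStepStandIn()
meta.init_multi_step(num_scheduler_steps=4)
for _ in range(meta.state.num_steps):
    # ... run one decode step for this group ...
    meta.finish_step()
assert meta.state.current_step == meta.state.num_steps  # budget exhausted
```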
@@ -971,6 +997,7 @@ class SamplerOutput:
 
     # On-device tensor containing the sampled token ids.
     sampled_token_ids: Optional[torch.Tensor] = None
+    sampled_token_ids_numpy: Optional[numpy.ndarray] = None
 
     # Spec decode metrics populated by workers.
     spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
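The reason for carrying a numpy copy alongside the on-device tensor is not spelled out in the diff; as a hedged illustration only, the snippet below shows the explicit device-to-host conversion that such a field would cache, using plain PyTorch.

```python
import torch

# Stand-in for sampled ids produced by the sampler (would live on the GPU
# in a real run; CPU here so the snippet runs anywhere).
sampled_token_ids = torch.tensor([[17], [42]])

# .cpu().numpy() forces a device-to-host copy; keeping the result around
# (as sampled_token_ids_numpy presumably does) means CPU-side consumers
# do not have to repeat that conversion.
sampled_token_ids_numpy = sampled_token_ids.cpu().numpy()
print(sampled_token_ids_numpy.shape)  # (2, 1)
```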
@@ -1112,6 +1139,33 @@ class ExecuteModelRequest:
     num_steps: int = 1
     # Finished request ids since last step.
     finished_requests_ids: List[str] = field(default_factory=list)
+    # The last sampled token ids for multi step decoding.
+    last_sampled_token_ids: Optional[torch.Tensor] = None
+
+    @property
+    def is_first_multi_step(self) -> bool:
+        # TODO(will) make this be able to handle batches with variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        first_seq_group = self.seq_group_metadata_list[0]
+        return first_seq_group.state.current_step == 0
+
+    @property
+    def is_last_step(self) -> bool:
+        # TODO(will) make this be able to handle batches with variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        first_seq_group = self.seq_group_metadata_list[0]
+        num_steps = first_seq_group.state.num_steps
+        current_step = first_seq_group.state.current_step
+        return num_steps - current_step == 1
+
+    @property
+    def current_step(self) -> int:
+        # TODO(will) make this be able to handle batches with variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        return self.seq_group_metadata_list[0].state.current_step
+
     def clone(
         self, seq_group_metadata_list: List[SequenceGroupMetadata]

Review comment (on `is_first_multi_step`): So this means the multi-step state is decided by the state of the first sequence group in the batch? E.g., if the first seq group has only 1 step left, do we only run 1 step even though num_steps > 1?

Reply: Yes. Eventually we may change this behavior and allow sequences in a batch to be on different steps concurrently.

Reply: Is this really a necessary feature? I personally think it should be OK to assume we run all the requests in the batch for the same number of steps.
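To make the behavior discussed in the thread concrete, here is a small self-contained sketch (hypothetical stand-in classes, not the real vLLM types) showing that the batch-level view is derived solely from the first sequence group's state:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class _State:
    num_steps: int = 1
    current_step: int = 0


@dataclass
class _SeqGroupMetadata:
    state: _State = field(default_factory=_State)


@dataclass
class _ExecuteModelRequestStandIn:
    seq_group_metadata_list: List[_SeqGroupMetadata] = field(default_factory=list)

    @property
    def is_first_multi_step(self) -> bool:
        return self.seq_group_metadata_list[0].state.current_step == 0

    @property
    def is_last_step(self) -> bool:
        state = self.seq_group_metadata_list[0].state
        return state.num_steps - state.current_step == 1


req = _ExecuteModelRequestStandIn([
    _SeqGroupMetadata(_State(num_steps=4, current_step=3)),
    _SeqGroupMetadata(_State(num_steps=4, current_step=0)),
])
# Only the first group is consulted: the batch reports "last step" even
# though the second group still has all 4 of its steps remaining.
assert req.is_last_step and not req.is_first_multi_step
```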
@@ -1127,4 +1181,6 @@ def clone(
             running_queue_size=self.running_queue_size,
             previous_hidden_states=self.previous_hidden_states,
             num_steps=self.num_steps,
-            finished_requests_ids=self.finished_requests_ids)
+            finished_requests_ids=self.finished_requests_ids,
+            last_sampled_token_ids=self.last_sampled_token_ids.clone()
+            if self.last_sampled_token_ids is not None else None)
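The `.clone()` here is standard PyTorch copy semantics; as a hedged aside (the rationale is my reading, not stated in the PR), cloning gives the new request its own buffer, so later in-place writes to one request's `last_sampled_token_ids` cannot leak into the other:

```python
import torch

last_sampled_token_ids = torch.tensor([[7], [9]])
cloned = last_sampled_token_ids.clone()   # independent copy, same values

last_sampled_token_ids[0, 0] = 123        # mutate the original in place
print(cloned[0, 0].item())                # still 7: the clone is unaffected
```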