14 changes: 13 additions & 1 deletion vllm/config.py
@@ -845,7 +845,8 @@ def __init__(self,
delay_factor: float = 0.0,
enable_chunked_prefill: bool = False,
embedding_mode: Optional[bool] = False,
preemption_mode: Optional[str] = None) -> None:
preemption_mode: Optional[str] = None,
num_scheduler_steps: int = 1) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
else:
@@ -874,6 +875,7 @@ def __init__(self,
self.chunked_prefill_enabled = enable_chunked_prefill
self.embedding_mode = embedding_mode
self.preemption_mode = preemption_mode
self.num_scheduler_steps = num_scheduler_steps
self._verify_args()

def _verify_args(self) -> None:
@@ -899,6 +901,16 @@ def _verify_args(self) -> None:
f"({self.num_lookahead_slots}) must be greater than or "
"equal to 0.")

if self.num_scheduler_steps < 1:
raise ValueError(
"num_scheduler_steps "
f"({self.num_scheduler_steps}) must be greater than or "
"equal to 1.")

@property
def is_multi_step(self) -> bool:
return self.num_scheduler_steps > 1


class DeviceConfig:
device: Optional[torch.device]
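Taken together, the config change adds a validated num_scheduler_steps knob plus an is_multi_step convenience property. The following standalone sketch mirrors only the new fields and checks; it is illustrative and does not import vllm.

    from dataclasses import dataclass

    @dataclass
    class SchedulerConfigSketch:
        # Mirrors only the field added in this diff; the real SchedulerConfig
        # has many more arguments.
        num_scheduler_steps: int = 1

        def __post_init__(self) -> None:
            if self.num_scheduler_steps < 1:
                raise ValueError(
                    f"num_scheduler_steps ({self.num_scheduler_steps}) must be "
                    "greater than or equal to 1.")

        @property
        def is_multi_step(self) -> bool:
            return self.num_scheduler_steps > 1

    assert not SchedulerConfigSketch().is_multi_step
    assert SchedulerConfigSketch(num_scheduler_steps=8).is_multi_step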
5 changes: 5 additions & 0 deletions vllm/core/scheduler.py
@@ -805,6 +805,9 @@ def _schedule_prefills(
curr_loras.add(lora_int_id)
waiting_queue.popleft()
self._allocate_and_set_running(seq_group)
seq_group.init_multi_step(
num_scheduler_steps=self._get_num_lookahead_slots(
is_prefill=True) + 1)
seq_groups.append(
ScheduledSequenceGroup(seq_group=seq_group,
token_chunk_size=num_new_tokens))
@@ -1108,6 +1111,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
computed_block_nums=common_computed_block_nums,
encoder_seq_data=encoder_seq_data,
cross_block_table=cross_block_table,
state=seq_group.state,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
@@ -1184,6 +1188,7 @@ def _append_slots(
slots.
"""
num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1)

for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
cows = self.block_manager.append_slots(seq, num_lookahead_slots)
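In both the prefill and decode paths above, the scheduler seeds each sequence group with num_scheduler_steps = num_lookahead_slots + 1: one step for the token being scheduled now plus one per lookahead slot. A tiny sketch of that relationship (the function name is invented for illustration; see the arg_utils.py change below for where the lookahead slot count comes from):

    def scheduler_steps_from_lookahead(num_lookahead_slots: int) -> int:
        # One step for the currently scheduled token plus one per lookahead slot.
        return num_lookahead_slots + 1

    # With --num-scheduler-steps 8 (and no speculative decoding) the engine
    # requests 7 lookahead slots, so each scheduled group runs 8 steps.
    assert scheduler_steps_from_lookahead(0) == 1
    assert scheduler_steps_from_lookahead(7) == 8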
28 changes: 25 additions & 3 deletions vllm/engine/arg_utils.py
@@ -93,6 +93,7 @@ class EngineArgs:
lora_dtype: str = 'auto'
max_cpu_loras: Optional[int] = None
device: str = 'auto'
num_scheduler_steps: int = 1
ray_workers_use_nsight: bool = False
num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0
@@ -506,6 +507,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"tpu", "xpu"
],
help='Device type for vLLM execution.')
parser.add_argument('--num-scheduler-steps',
type=int,
default=1,
help=('Maximum number of forward steps per '
'scheduler call.'))

parser.add_argument(
'--scheduler-delay-factor',
@@ -820,18 +826,34 @@ def create_engine_config(self, ) -> EngineConfig:
disable_logprobs=self.disable_logprobs_during_spec_decoding,
)

if self.num_scheduler_steps > 1:
raise NotImplementedError("Multi-step is not yet supported.")
if speculative_config is not None:
raise ValueError("Speculative decoding is not supported with "
"multi-step (--num-scheduler-steps > 1)")
if self.enable_chunked_prefill:
raise ValueError("Chunked prefill is not supported with "
"multi-step (--num-scheduler-steps > 1)")

# Make sure num_lookahead_slots is set to the higher value depending on
# whether we are using speculative decoding or multi-step.
num_lookahead_slots = max(self.num_lookahead_slots,
self.num_scheduler_steps - 1)
num_lookahead_slots = num_lookahead_slots \
if speculative_config is None \
else speculative_config.num_lookahead_slots

scheduler_config = SchedulerConfig(
max_num_batched_tokens=self.max_num_batched_tokens,
max_num_seqs=self.max_num_seqs,
max_model_len=model_config.max_model_len,
use_v2_block_manager=self.use_v2_block_manager,
num_lookahead_slots=(self.num_lookahead_slots
if speculative_config is None else
speculative_config.num_lookahead_slots),
num_lookahead_slots=num_lookahead_slots,
delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill,
embedding_mode=model_config.embedding_mode,
preemption_mode=self.preemption_mode,
num_scheduler_steps=self.num_scheduler_steps,
)
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
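The lookahead-slot resolution above boils down to: take the larger of the user-provided num_lookahead_slots and num_scheduler_steps - 1, unless speculative decoding is configured, in which case the speculative config dictates the value. A standalone sketch of that rule (function and parameter names are invented for illustration):

    from typing import Optional

    def resolve_num_lookahead_slots(num_lookahead_slots: int,
                                    num_scheduler_steps: int,
                                    spec_lookahead: Optional[int]) -> int:
        # Multi-step decoding needs num_scheduler_steps - 1 extra slots per sequence.
        resolved = max(num_lookahead_slots, num_scheduler_steps - 1)
        # Speculative decoding, when enabled, dictates the lookahead count instead.
        return resolved if spec_lookahead is None else spec_lookahead

    assert resolve_num_lookahead_slots(0, 8, None) == 7
    assert resolve_num_lookahead_slots(0, 1, None) == 0
    assert resolve_num_lookahead_slots(0, 1, 5) == 5

Note that multi-step execution itself is still gated by the NotImplementedError above, so the new flag is parsed and validated but cannot yet run end to end.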
58 changes: 57 additions & 1 deletion vllm/sequence.py
@@ -8,6 +8,7 @@
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
Union, cast)

import numpy
Review comment (Member): Nit: suggested change, use "import numpy as np" instead of "import numpy".
import torch

from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
@@ -489,6 +490,19 @@ def __repr__(self) -> str:
f"num_blocks={self.n_blocks}, ")


@dataclass
class SequenceGroupState:
"""Mutable state tied to a specific sequence group"""
Review comment (Collaborator): I wonder if this should be part of SequenceData? All the state is actually stored in SequenceData now.
Review comment (Collaborator): Also, IIUC num_steps is not state, and it may not belong here?
Reply (Contributor, author): Perhaps eventually we will support different num_steps within a batch, so we track this per SequenceGroup. There used to be a SequenceGroupState used for the seed generator, but it has since been removed (#6698).

# for multi-step decoding
num_steps: int = 1
current_step: int = 0

@property
def remaining_steps(self) -> int:
return self.num_steps - self.current_step


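The new SequenceGroupState is plain step bookkeeping: num_steps is set by the scheduler via init_multi_step, current_step advances as steps finish (see finish_step further down), and remaining_steps is the difference. A minimal standalone sketch of that lifecycle; the driving loop is hypothetical and only the field and property names match the diff.

    from dataclasses import dataclass

    @dataclass
    class SequenceGroupStateSketch:
        # Same fields as the SequenceGroupState added in this diff.
        num_steps: int = 1
        current_step: int = 0

        @property
        def remaining_steps(self) -> int:
            return self.num_steps - self.current_step

    state = SequenceGroupStateSketch()
    state.num_steps, state.current_step = 8, 0   # what init_multi_step(8) sets
    while state.remaining_steps > 0:
        # ... one forward/decode step for the whole group would run here ...
        state.current_step += 1                  # what finish_step() does
    assert state.remaining_steps == 0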
class SequenceGroup:
"""A group of sequences that are generated from the same prompt.

@@ -534,6 +548,7 @@ def __init__(
time_in_queue=None)
self.lora_request = lora_request
self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState()
self.embeddings = embeddings
self.pooling_params = pooling_params
self.prompt_adapter_request = prompt_adapter_request
@@ -588,6 +603,10 @@ def prompt_adapter_num_virtual_tokens(self) -> int:
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
if self.prompt_adapter_request else 0

def init_multi_step(self, num_scheduler_steps: int) -> None:
self.state.num_steps = num_scheduler_steps
self.state.current_step = 0

def get_last_latency(self, now: float) -> Optional[float]:
"""Sets the last token time for Request level timings."""
# If still in prefill phase, raise Error.
@@ -756,6 +775,7 @@ class SequenceGroupMetadata:
lora_request: LoRA request.
computed_block_nums: The block numbers that are already computed,
used in prefix caching.
state: Internal state tied to this sequence group.
multi_modal_data: Multi modal data.
encoder_seq_data: Optional sequence data for encoder prompt
(SequenceGroup.encoder_seq). Should be None
@@ -781,6 +801,7 @@ def __init__(
token_chunk_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
computed_block_nums: Optional[List[int]] = None,
state: Optional[SequenceGroupState] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
encoder_seq_data: Optional[SequenceData] = None,
cross_block_table: Optional[List[int]] = None,
@@ -796,6 +817,7 @@
self.prompt_adapter_request = prompt_adapter_request
self.computed_block_nums = computed_block_nums
self.multi_modal_data = multi_modal_data
self.state = SequenceGroupState() if state is None else state
self.encoder_seq_data = encoder_seq_data
self.cross_block_table = cross_block_table
self._token_chunk_size = token_chunk_size
@@ -834,6 +856,10 @@ def token_chunk_size(self) -> int:
assert self._token_chunk_size is not None
return self._token_chunk_size

def finish_step(self) -> None:
Review comment (Collaborator): Not used in this PR. Will it be used in the next PR with a scheduler?
Reply (Contributor, author): Yes, it will be used in AsyncLLMEngine.
assert self.state.current_step < self.state.num_steps
self.state.current_step += 1


class SequenceOutput:
"""The model output associated with a sequence.
@@ -971,6 +997,7 @@ class SamplerOutput:

# On-device tensor containing the sampled token ids.
sampled_token_ids: Optional[torch.Tensor] = None
sampled_token_ids_numpy: Optional[numpy.ndarray] = None

# Spec decode metrics populated by workers.
spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
@@ -1112,6 +1139,33 @@ class ExecuteModelRequest:
num_steps: int = 1
# Finished request ids since last step.
finished_requests_ids: List[str] = field(default_factory=list)
# The last sampled token ids for multi step decoding.
last_sampled_token_ids: Optional[torch.Tensor] = None

@property
def is_first_multi_step(self) -> bool:
# TODO(will) make this be able to handle batches with variable number of
# steps
assert len(self.seq_group_metadata_list) > 0
first_seq_group = self.seq_group_metadata_list[0]
return first_seq_group.state.current_step == 0
Review comment (Collaborator): So this means the multi-step state is decided by the state of the first sequence group in the batch? E.g., if the first seq group has only 1 step left, we only run 1 step although num_steps > 1?
Reply (Contributor, author, @SolitaryThinker, Aug 14, 2024): Yes. Eventually we may change this behavior and allow sequences in a batch to be on different steps concurrently.
Review comment (Member, @zhuohan123, Aug 15, 2024): Is this really a necessary feature? I personally think it should be OK to assume we run all the requests in the batch for the same number of steps.

@property
def is_last_step(self) -> bool:
# TODO(will) make this be able to handle batches with variable number of
# steps
assert len(self.seq_group_metadata_list) > 0
first_seq_group = self.seq_group_metadata_list[0]
num_steps = first_seq_group.state.num_steps
current_step = first_seq_group.state.current_step
return num_steps - current_step == 1

@property
def current_step(self) -> int:
# TODO(will) make this be able to handle batches with variable number of
# steps
assert len(self.seq_group_metadata_list) > 0
return self.seq_group_metadata_list[0].state.current_step

def clone(
self, seq_group_metadata_list: List[SequenceGroupMetadata]
@@ -1127,4 +1181,6 @@ def clone(
running_queue_size=self.running_queue_size,
previous_hidden_states=self.previous_hidden_states,
num_steps=self.num_steps,
finished_requests_ids=self.finished_requests_ids)
finished_requests_ids=self.finished_requests_ids,
last_sampled_token_ids=self.last_sampled_token_ids.clone()
if self.last_sampled_token_ids is not None else None)
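ExecuteModelRequest now exposes the batch's step position (is_first_multi_step, is_last_step, current_step, all read from the first sequence group) and carries last_sampled_token_ids so later steps can reuse the previous step's on-device sampling results. The sketch below is a hypothetical illustration of how a multi-step worker loop might consult these properties; the worker methods named here are invented and not part of this PR.

    def run_multi_step(worker, execute_model_req):
        # Hypothetical driver; only the ExecuteModelRequest properties and
        # last_sampled_token_ids come from this PR.
        if execute_model_req.is_first_multi_step:
            # First step of the window: full metadata is available, so the
            # worker can build and cache model inputs for the whole window.
            worker.prepare_cached_inputs(execute_model_req)
        worker.execute_one_step(execute_model_req)
        if execute_model_req.is_last_step:
            # Last step of the window: return all buffered sampler outputs.
            return worker.flush_outputs()
        # Intermediate step: nothing is returned yet; the caller advances the
        # per-group state (finish_step) and re-sends the request with
        # last_sampled_token_ids taken from this step's sampling.
        return []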