10 changes: 5 additions & 5 deletions benchmarks/benchmark_latency.py
@@ -65,7 +65,9 @@ def run_to_completion(profile_dir: Optional[str] = None):
if args.profile:
profile_dir = args.profile_result_dir
if not profile_dir:
profile_dir = Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
profile_dir = Path(
"."
) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
print(f"Profiling (results will be saved to '{profile_dir}')...")
run_to_completion(profile_dir=args.profile_result_dir)
return
@@ -123,9 +125,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
'--profile-result-dir',
type=str,
default=None,
help=(
'path to save the pytorch profiler output. Can be visualized '
'with ui.perfetto.dev or Tensorboard.'
))
help=('path to save the pytorch profiler output. Can be visualized '
'with ui.perfetto.dev or Tensorboard.'))
args = parser.parse_args()
main(args)
14 changes: 14 additions & 0 deletions bin/build.sh
@@ -0,0 +1,14 @@
#!/bin/bash
function pkg_deploy() {
tar_src=`find dist -type f | grep "tar.gz" | grep -v cli`
pip install -q oss2 -i https://mirrors.aliyun.com/pypi/simple
python $ROOT/bin/osscli $tar_src alps/`basename $tar_src`
}
cd ..
rm -rf dist
python setup.py sdist

ROOT=$(pwd)
echo $ROOT

pkg_deploy
26 changes: 26 additions & 0 deletions bin/osscli
@@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
import oss2
import sys

def decode(key):
import base64
return base64.standard_b64decode(key).decode()

accessid=decode("TFRBSUN1S0hFSzFFNnNXNw==")
accesskey=decode("UFJ5YnNnVU1taHVKSUNXMUp3U0FKcDlwS1NjbXlR")
host="oss-cn-hangzhou-zmf.aliyuncs.com"
bucket="alps-common"
src=sys.argv[1]
target=sys.argv[2]
if target.endswith("/"):
import os
name = os.path.basename(src)
target = os.path.join(target, name)
oss_bucket = oss2.Bucket(oss2.Auth(accessid, accesskey), host, bucket, connect_timeout=300)
out = oss_bucket.put_object_from_file(target, src)
addr="http://%s.oss-cn-hangzhou-zmf.aliyuncs.com/%s" % (bucket, target)

if out.status == 200:
print("Upload to %s " % addr)
else:
raise ValueError("Upload to %s " % addr)
1 change: 0 additions & 1 deletion examples/openai_chatcompletion_client.py
@@ -32,6 +32,5 @@
model=model,
)


print("Chat completion results:")
print(chat_completion)
3 changes: 1 addition & 2 deletions examples/openai_completion_client.py
@@ -21,8 +21,7 @@
echo=False,
n=2,
stream=stream,
logprobs=3
)
logprobs=3)

print("Completion results:")
if stream:
57 changes: 55 additions & 2 deletions vllm/core/scheduler.py
@@ -4,7 +4,7 @@
from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union

from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.block_manager import AllocStatus, BlockSpaceManager
from vllm.core.block_manager import AllocStatus, BlockSpaceManager, BlockTable
from vllm.core.policy import PolicyFactory
from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
@@ -82,11 +82,31 @@ def __init__(
self.running: Deque[SequenceGroup] = deque()
# Sequence groups in the SWAPPED state.
self.swapped: Deque[SequenceGroup] = deque()
self.prefix_cache: Dict[str, SequenceGroup] = {}

def add_seq_group(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the waiting queue.
if seq_group.prefix_name in self.prefix_cache:
prefix_seq = self.prefix_cache[seq_group.prefix_name].get_seqs()[0]
if prefix_seq.is_finished():
seq_group.set_prefix_seq(prefix_seq)
else:
prefix = prefix_seq.prompt
prefix_token_ids = prefix_seq.data.prompt_token_ids
seqs = seq_group.get_seqs()
for seq in seqs:
seq_group.seqs_dict[seq.seq_id] = Sequence(
seq.seq_id, prefix + seq.prompt,
prefix_token_ids + seq.data.prompt_token_ids,
seq.block_size)
self.waiting.append(seq_group)

def add_prefix_seq_groups(self,
seq_group_map: Dict[str, SequenceGroup]) -> None:
self.prefix_cache.update(seq_group_map)
for seq_group in seq_group_map.values():
self.waiting.append(seq_group)

def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
if isinstance(request_id, str):
request_id = (request_id, )
@@ -294,6 +314,11 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq)
self._reuse_prefix_cache(
seq_group=seq_group,
block_table=block_tables[seq_id],
blocks_to_copy=scheduler_outputs.blocks_to_copy
if scheduler_outputs.prompt_run else None)

seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
@@ -309,7 +334,9 @@ def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
self.block_manager.fork(parent_seq, child_seq)

def free_seq(self, seq: Sequence) -> None:
self.block_manager.free(seq)
# Do not free prefix sequences that still hold a cached block table.
if not seq.block_table:
self.block_manager.free(seq)

def free_finished_seq_groups(self) -> None:
self.running = [
@@ -411,3 +438,29 @@ def _swap_out(
blocks_to_swap_out.update(mapping)
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq.status = SequenceStatus.SWAPPED

def _reuse_prefix_cache(self, seq_group: SequenceGroup,
block_table: BlockTable,
blocks_to_copy: Dict[int, List[int]]) -> None:
if seq_group.is_prefix:
assert len(seq_group.get_seqs()
) == 1, "a prefix seq_group must contain exactly one sequence."
seq = seq_group.get_seqs()[0]
seq.block_table = block_table # cache prefix_block_table in prefix_seq
elif seq_group.prefix_seq:
# Reuse the prefix block table from prefix_seq.
prefix_seq = seq_group.prefix_seq
need_copy_last_block = prefix_seq.get_prompt_len(
) % prefix_seq.block_size > 0
prefix_block_num = len(
prefix_seq.block_table) - need_copy_last_block
block_table[:
prefix_block_num] = prefix_seq.block_table[:
prefix_block_num]
if need_copy_last_block and blocks_to_copy is not None:
src_block, dst_block = prefix_seq.block_table[-1], block_table[
prefix_block_num]
if src_block in blocks_to_copy:
blocks_to_copy[src_block].append(dst_block)
else:
blocks_to_copy[src_block] = [dst_block]
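
The block-sharing arithmetic in _reuse_prefix_cache above can be illustrated in isolation. The sketch below is a standalone simplification, not vLLM's actual BlockTable: block tables are plain lists of block ids and reuse_prefix_blocks is a hypothetical helper. Full prefix blocks are shared by reference, while a partially filled last block is queued in blocks_to_copy for copy-on-write.

# Standalone illustration of the prefix block-reuse arithmetic above.
from typing import Dict, List

BLOCK_SIZE = 16

def reuse_prefix_blocks(prefix_block_table: List[int],
                        prefix_prompt_len: int,
                        seq_block_table: List[int],
                        blocks_to_copy: Dict[int, List[int]]) -> None:
    # A last block that is only partially filled cannot be shared
    # directly, because the new sequence keeps writing into it.
    need_copy_last_block = prefix_prompt_len % BLOCK_SIZE > 0
    prefix_block_num = len(prefix_block_table) - need_copy_last_block
    # Full prefix blocks are reused as-is.
    seq_block_table[:prefix_block_num] = prefix_block_table[:prefix_block_num]
    # The partial last block is scheduled for a copy (copy-on-write).
    if need_copy_last_block:
        src = prefix_block_table[-1]
        dst = seq_block_table[prefix_block_num]
        blocks_to_copy.setdefault(src, []).append(dst)

# Example: a 40-token prefix occupies 3 blocks (2 full + 1 partial).
blocks_to_copy: Dict[int, List[int]] = {}
seq_table = [10, 11, 12, 13]          # freshly allocated blocks
reuse_prefix_blocks([0, 1, 2], 40, seq_table, blocks_to_copy)
print(seq_table)        # [0, 1, 12, 13] -> first two blocks shared
print(blocks_to_copy)   # {2: [12]}     -> partial block copied
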
20 changes: 12 additions & 8 deletions vllm/engine/async_llm_engine.py
@@ -370,6 +370,7 @@ async def add_request(
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
prefix_name: Optional[str] = None,
) -> AsyncStream:
if self.log_requests:
shortened_prompt = prompt
@@ -400,17 +401,18 @@ async def add_request(
prompt=prompt,
sampling_params=sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
arrival_time=arrival_time,
prefix_name=prefix_name)

return stream

async def generate(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None
) -> AsyncIterator[RequestOutput]:
self,
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
prefix_name: Optional[str] = None) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.

Generate outputs for a request. This method is a coroutine. It adds the
@@ -424,6 +426,7 @@ async def generate(
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
prefix_name: The name of a prefix whose KV cache has already been built by the prompt cache.

Yields:
The output `RequestOutput` objects from the LLMEngine for the
@@ -438,7 +441,8 @@ async def generate(
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
arrival_time=arrival_time,
prefix_name=prefix_name)

async for request_output in stream:
yield request_output
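
A hedged usage sketch of the extended async API follows. Only the prefix_name keyword on AsyncLLMEngine.generate comes from this diff; the engine construction, the model name, and the assumption that a prefix named "system" was registered beforehand (for example through the add_prefix_template path in llm_engine.py) are illustrative, not part of the change.

# Sketch: passing prefix_name to AsyncLLMEngine.generate.
import asyncio
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams

async def main():
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))  # illustrative model
    params = SamplingParams(temperature=0.8, max_tokens=64)
    # prefix_name refers to a prefix whose KV cache was registered
    # beforehand (assumed here; the async wiring is not in this diff).
    async for output in engine.generate("What is vLLM?",
                                        params,
                                        request_id="req-0",
                                        prefix_name="system"):
        if output.finished:
            print(output.outputs[0].text)

asyncio.run(main())
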
51 changes: 50 additions & 1 deletion vllm/engine/llm_engine.py
@@ -318,6 +318,7 @@ def add_request(
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
prefix_name: Optional[str] = None,
) -> None:
"""Add a request to the engine's request pool.

@@ -334,12 +335,22 @@
use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
prefix_name: Key into the prefix cache. If set, the cached prefix
KV blocks are reused to speed up prompt computation.
"""
if arrival_time is None:
arrival_time = time.monotonic()
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = self.tokenizer.encode(prompt)
if prefix_name:
special_tokens = self.tokenizer.build_inputs_with_special_tokens(
[], [])
if special_tokens:
for idex, token_id in enumerate(prompt_token_ids):
if token_id not in special_tokens:
prompt_token_ids = prompt_token_ids[idex:]
break

# Create the sequences.
block_size = self.cache_config.block_size
@@ -348,11 +359,47 @@

# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
arrival_time)
arrival_time, prefix_name)

# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)

def add_prefix_template(self, prefix_map: Dict[str, str]) -> None:
"""
prefix_map = {prefix_name: prefix_content or (prefix_content, prefix_token_ids)}
Generate the kv cache for the corresponding prefix_content, stored as promptcache.
Ensure that when a normal request is added, the required prefix_cache has already been created.
The current strategy is to concatenate the corresponding prefix_context if the promptcache has not been generated.
"""
greedy_sampling_params = SamplingParams(temperature=0.0,
use_beam_search=False,
max_tokens=1)
seq_group_map = {}
for prefix_name, con in prefix_map.items():
if isinstance(con, tuple):
prefix_content, prefix_token_ids = con
else:
prefix_content = con
assert prefix_content is not None
prefix_token_ids = self.tokenizer.encode(prefix_content)
special_tokens = self.tokenizer.build_inputs_with_special_tokens(
[], [])
if special_tokens:
for idex in range(1, len(prefix_token_ids) + 1):
if prefix_token_ids[-idex] not in special_tokens:
prefix_token_ids = prefix_token_ids[:len(
prefix_token_ids) - idex + 1]
break
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prefix_content, prefix_token_ids,
block_size)
# Create the sequence group.
seq_group_map[prefix_name] = SequenceGroup(None, [seq],
greedy_sampling_params,
None)
self.scheduler.add_prefix_seq_groups(seq_group_map)

def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts a request(s) with the given ID.

@@ -602,6 +649,8 @@ def _process_model_outputs(
# Create the outputs.
request_outputs: List[RequestOutput] = []
for seq_group in scheduled_seq_groups:
if seq_group.is_prefix:
continue
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
for seq_group in scheduler_outputs.ignored_seq_groups:
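
For reference, a hedged sketch of how the new engine-level API might be driven end to end. The model name and the explicit step() loop are illustrative assumptions; add_prefix_template, add_request(..., prefix_name=...), and the skipping of prefix groups in _process_model_outputs come from the additions above.

# Sketch: precompute a prefix cache, then reuse it for a request.
from vllm import EngineArgs, LLMEngine, SamplingParams

engine = LLMEngine.from_engine_args(
    EngineArgs(model="facebook/opt-125m"))  # illustrative model

# 1. Precompute the KV cache for the named prefix; prefix groups are
#    skipped by _process_model_outputs, so these steps yield no output.
engine.add_prefix_template({"system": "You are a helpful assistant.\n"})
while engine.has_unfinished_requests():
    engine.step()

# 2. Requests that name the prefix reuse its cached blocks.
engine.add_request("req-0",
                   "What is the capital of France?",
                   SamplingParams(max_tokens=32),
                   prefix_name="system")
while engine.has_unfinished_requests():
    for out in engine.step():
        if out.finished:
            print(out.outputs[0].text)
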
6 changes: 5 additions & 1 deletion vllm/entrypoints/llm.py
@@ -1,4 +1,4 @@
from typing import List, Optional, Union
from typing import List, Optional, Union, Dict

from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -174,6 +174,10 @@ def _add_request(
self.llm_engine.add_request(request_id, prompt, sampling_params,
prompt_token_ids)

def add_prefix_template(self, prefix_map: Dict[str, str]) -> None:
self.llm_engine.add_prefix_template(prefix_map)
self._run_engine(True)

def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
# Initialize tqdm.
if use_tqdm:
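
A short, hedged sketch of the offline entrypoint; the model name and prefix strings are illustrative. Note that in this diff LLM.generate does not itself accept prefix_name, so prefixes registered here are consumed through the engine-level add_request shown in llm_engine.py.

# Sketch: registering prefix templates on the offline LLM class.
from vllm import LLM

llm = LLM(model="facebook/opt-125m")  # illustrative model
# Precompute and cache the KV blocks for each named prefix; this runs
# the engine once (add_prefix_template calls _run_engine internally).
llm.add_prefix_template({
    "system": "You are a helpful assistant.\n",
    "translator": "Translate the following text to French:\n",
})
# Later requests reference these prefixes by name via
# llm.llm_engine.add_request(..., prefix_name="system").
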
6 changes: 5 additions & 1 deletion vllm/model_executor/input_metadata.py
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, List

import torch

@@ -22,13 +22,17 @@ def __init__(
context_lens: Optional[torch.Tensor],
block_tables: Optional[torch.Tensor],
use_cuda_graph: bool,
prefix_len_list: Optional[List[int]] = None,
prefix_slot_mapping: Optional[torch.Tensor] = None,
) -> None:
self.is_prompt = is_prompt
self.max_context_len = max_context_len
self.slot_mapping = slot_mapping
self.context_lens = context_lens
self.block_tables = block_tables
self.use_cuda_graph = use_cuda_graph
self.prefix_len_list = prefix_len_list
self.prefix_slot_mapping = prefix_slot_mapping

# Set during the execution of the first attention op.
# FIXME(woosuk): This is a hack.