Merged
29 commits
0c872b3
Make encoder-decoder inputs a composed structure
DarkLight1337 Oct 23, 2024
9287a1b
Merge branch 'main' into refactor-preprocessing
DarkLight1337 Oct 23, 2024
fa5ad17
Rename
DarkLight1337 Oct 23, 2024
44fd058
Fix type error
DarkLight1337 Oct 23, 2024
b73a345
Fix test
DarkLight1337 Oct 23, 2024
fa968b5
Fix llama-3.2
DarkLight1337 Oct 23, 2024
906ee1e
Remove force_bos
DarkLight1337 Oct 24, 2024
005ad95
Add example output
DarkLight1337 Oct 24, 2024
a5f0c16
format
DarkLight1337 Oct 24, 2024
6ab44e4
Fix
DarkLight1337 Oct 24, 2024
21be11f
Merge branch 'main' into refactor-preprocessing
DarkLight1337 Oct 29, 2024
1f927d2
Merge branch 'main' into refactor-preprocessing
DarkLight1337 Oct 31, 2024
760db05
Fix merge
DarkLight1337 Oct 31, 2024
acb8e6f
Update mllama processing
DarkLight1337 Oct 31, 2024
3bed519
Fix line
DarkLight1337 Oct 31, 2024
ea861e0
format
DarkLight1337 Oct 31, 2024
f654421
Avoid repeated lookups
DarkLight1337 Oct 31, 2024
594794e
Remove unused import
DarkLight1337 Oct 31, 2024
08ea824
Fix mypy
DarkLight1337 Oct 31, 2024
b622f41
Merge branch 'main' into refactor-preprocessing
DarkLight1337 Oct 31, 2024
800960d
Merge branch 'main' into refactor-preprocessing
DarkLight1337 Oct 31, 2024
283bc2c
Fix merge
DarkLight1337 Oct 31, 2024
e8169ea
Merge branch 'main' into refactor-preprocessing
DarkLight1337 Nov 2, 2024
61bf1d1
Merge branch 'main' into refactor-preprocessing
DarkLight1337 Nov 3, 2024
b45cdc9
Fix missing import
DarkLight1337 Nov 3, 2024
4d33b1e
Improve error message
DarkLight1337 Nov 3, 2024
0a549e5
Add missing export
DarkLight1337 Nov 3, 2024
f741a75
Improve error message.
DarkLight1337 Nov 3, 2024
cd231fa
Format
DarkLight1337 Nov 3, 2024
57 changes: 26 additions & 31 deletions tests/core/utils.py
@@ -4,6 +4,7 @@
from typing import Tuple

from vllm import SamplingParams
from vllm.inputs import EncoderDecoderInputs, token_inputs
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob, Sequence, SequenceGroup

@@ -27,10 +28,7 @@ def create_dummy_prompt(
prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id),
inputs={
"prompt": prompt_str,
"prompt_token_ids": prompt_tokens,
},
inputs=token_inputs(prompt_tokens, prompt=prompt_str),
block_size=block_size)
seq_group = SequenceGroup(request_id=request_id,
seqs=[prompt],
@@ -63,23 +61,21 @@ def create_dummy_prompt_encoder_decoder(
encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])

inputs = {
"prompt": decoder_prompt_str,
"prompt_token_ids": decoder_prompt_tokens,
"encoder_prompt": encoder_prompt_str,
"encoder_prompt_token_ids": encoder_prompt_tokens,
"multi_modal_data": None,
inputs: EncoderDecoderInputs = {
"decoder": token_inputs(decoder_prompt_tokens,
prompt=decoder_prompt_str),
"encoder": token_inputs(encoder_prompt_tokens,
prompt=encoder_prompt_str),
}

decoder_prompt = Sequence(int(request_id),
inputs=inputs,
block_size=block_size,
from_decoder_prompt=True)
inputs=inputs["decoder"],
block_size=block_size)

encoder_prompt = Sequence(int(request_id),
inputs=inputs,
block_size=block_size,
from_decoder_prompt=False)
inputs=inputs["encoder"],
block_size=block_size)

seq_group = SequenceGroup(request_id=request_id,
seqs=[decoder_prompt],
sampling_params=SamplingParams(best_of=best_of),
@@ -108,7 +104,7 @@ def create_seq_group(
for seq_id_offset, output_len in enumerate(seq_output_lens):
seq = Sequence(
seq_id=seq_id_start + seq_id_offset,
inputs={"prompt_token_ids": prompt_token_ids},
inputs=token_inputs(prompt_token_ids),
block_size=16,
)

@@ -143,21 +139,19 @@ def create_seq_group_encoder_decoder(

prompt_token_ids = [0] * seq_prompt_len

inputs = {
"prompt": "",
"prompt_token_ids": prompt_token_ids,
"encoder_prompt": "",
"encoder_prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
inputs: EncoderDecoderInputs = {
"decoder": token_inputs(prompt_token_ids),
"encoder": token_inputs(prompt_token_ids),
}

seqs = []
for seq_id_offset, output_len in enumerate(seq_output_lens):
# Construct decoder input sequences
seq = Sequence(seq_id=seq_id_start + seq_id_offset,
inputs=inputs,
block_size=16,
from_decoder_prompt=True)
seq = Sequence(
seq_id=seq_id_start + seq_id_offset,
inputs=inputs["decoder"],
block_size=16,
)

for i in range(output_len):
seq.append_token_id(
@@ -167,10 +161,11 @@
seqs.append(seq)

# Encoder input sequence
encoder_seq = Sequence(seq_id=seq_id_start + len(seq_output_lens),
inputs=inputs,
block_size=16,
from_decoder_prompt=False)
encoder_seq = Sequence(
seq_id=seq_id_start + len(seq_output_lens),
inputs=inputs["encoder"],
block_size=16,
)

return SequenceGroup(request_id=request_id,
seqs=seqs,
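For reference, the pattern these test helpers now follow can be sketched as below. The values are illustrative only: token_inputs() replaces the hand-written inputs dict, encoder-decoder inputs are composed under explicit "decoder"/"encoder" keys, and each Sequence receives only its own side, so the old from_decoder_prompt flag is no longer needed.

from vllm.inputs import EncoderDecoderInputs, token_inputs
from vllm.sequence import Sequence

prompt_tokens = [0, 1, 2, 3]

# Decoder-only: token_inputs() builds the TokenInputs dict directly.
seq = Sequence(seq_id=0,
               inputs=token_inputs(prompt_tokens, prompt="0 1 2 3"),
               block_size=16)

# Encoder-decoder: both sides live in one composed structure, and each
# Sequence is constructed from its own entry.
enc_dec: EncoderDecoderInputs = {
    "decoder": token_inputs(prompt_tokens),
    "encoder": token_inputs(list(reversed(prompt_tokens))),
}
decoder_seq = Sequence(seq_id=0, inputs=enc_dec["decoder"], block_size=16)
encoder_seq = Sequence(seq_id=1, inputs=enc_dec["encoder"], block_size=16)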
3 changes: 2 additions & 1 deletion tests/engine/output_processor/test_stop_checker.py
@@ -4,6 +4,7 @@
from transformers import PreTrainedTokenizer

from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.inputs import token_inputs
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob, Sequence, SequenceStatus

@@ -15,7 +16,7 @@ def sequence_with_eos(text: str, eos_token: str,
"""
seq = Sequence(
seq_id=0,
inputs={"prompt_token_ids": []},
inputs=token_inputs([]),
block_size=16,
eos_token_id=eos_token_id,
)
7 changes: 3 additions & 4 deletions tests/test_cache_block_hashing.py
@@ -6,6 +6,7 @@

import pytest

from vllm.inputs import token_inputs
from vllm.lora.request import LoRARequest
from vllm.sequence import Sequence
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@@ -70,10 +71,8 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
hashes[-1].append([])
prompt_token_ids = tokenizer.encode(prompt)
seq = Sequence(seq_id,
inputs={
"prompt": prompt,
"prompt_token_ids": prompt_token_ids,
},
inputs=token_inputs(prompt_token_ids,
prompt=prompt),
block_size=block_size,
eos_token_id=tokenizer.tokenizer.eos_token_id,
lora_request=lora_request)
6 changes: 2 additions & 4 deletions tests/tokenization/test_detokenize.py
@@ -3,6 +3,7 @@
import pytest
from transformers import AutoTokenizer

from vllm.inputs import token_inputs
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
from vllm.transformers_utils.detokenizer import (Detokenizer,
detokenize_incrementally)
@@ -169,10 +170,7 @@ def create_sequence(prompt_token_ids=None):
prompt_token_ids = prompt_token_ids or [1]
return Sequence(
seq_id=0,
inputs={
"prompt": "<s>",
"prompt_token_ids": prompt_token_ids,
},
inputs=token_inputs(prompt_token_ids, prompt="<s>"),
block_size=16,
)

51 changes: 24 additions & 27 deletions vllm/engine/llm_engine.py
@@ -10,7 +10,7 @@
from typing import Set, Type, Union, cast, overload

import torch
from typing_extensions import TypeIs, TypeVar
from typing_extensions import TypeVar

import vllm.envs as envs
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
@@ -29,9 +29,9 @@
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
EncoderDecoderInputs, InputRegistry, PromptType,
TokensPrompt)
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType)
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors
@@ -638,7 +638,7 @@ def _verify_args(self) -> None:
def _add_processed_request(
self,
request_id: str,
processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
processed_inputs: ProcessorInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
@@ -669,18 +669,19 @@ def _add_processed_request(
seq_id = next(self.seq_counter)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
if is_encoder_decoder_inputs(processed_inputs):
decoder_inputs = processed_inputs["decoder"]
encoder_inputs = processed_inputs["encoder"]
else:
decoder_inputs = processed_inputs
encoder_inputs = None

seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
lora_request, prompt_adapter_request)

encoder_seq = None
if 'encoder_prompt_token_ids' in processed_inputs:
encoder_seq = Sequence(seq_id,
processed_inputs,
block_size,
eos_token_id,
lora_request,
prompt_adapter_request,
from_decoder_prompt=False)
encoder_seq = (None if encoder_inputs is None else Sequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
prompt_adapter_request))

# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
@@ -874,7 +875,7 @@ def _validate_token_prompt(self, prompt: PromptType,
# This needs to happen before multimodal input pre-processing, which
# may add dummy <image> tokens that aren't part of the tokenizer's
# vocabulary.
if self._is_token_prompt(prompt):
if is_token_prompt(prompt):
prompt_ids = prompt["prompt_token_ids"]
if len(prompt_ids) == 0:
# Empty prompt check is handled later
@@ -884,10 +885,6 @@
raise ValueError(
"Token id {} is out of vocabulary".format(max_input_id))

@staticmethod
def _is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
return isinstance(prompt, dict) and "prompt_token_ids" in prompt

def _create_sequence_group_with_sampling(
self,
request_id: str,
@@ -1974,17 +1971,17 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None:
def is_encoder_decoder_model(self):
return self.input_preprocessor.is_encoder_decoder_model()

def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
EncoderDecoderInputs],
def _validate_model_inputs(self, inputs: ProcessorInputs,
lora_request: Optional[LoRARequest]):
if self.model_config.is_multimodal_model:
if is_encoder_decoder_inputs(inputs):
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
prompt_ids = inputs.get("prompt_token_ids")
elif self.is_encoder_decoder_model():
prompt_ids = inputs.get("encoder_prompt_token_ids")
prompt_inputs = inputs["decoder" if self.model_config.
is_multimodal_model else "encoder"]
else:
prompt_ids = inputs.get("prompt_token_ids")
prompt_inputs = inputs

prompt_ids = prompt_inputs.get("prompt_token_ids")

if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty")
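The engine-side handling above can be condensed into the following sketch. split_inputs() is a hypothetical helper written here only for illustration; in the PR the branch lives inline in _add_processed_request(), and _validate_model_inputs() applies the same idea when picking which side's prompt_token_ids to validate.

from vllm.inputs import ProcessorInputs
from vllm.inputs.parse import is_encoder_decoder_inputs


def split_inputs(processed: ProcessorInputs):
    """Return (decoder_inputs, encoder_inputs) for processed request inputs."""
    if is_encoder_decoder_inputs(processed):
        # Composed structure: each side is an explicit "decoder"/"encoder" entry.
        return processed["decoder"], processed["encoder"]
    # Decoder-only models: the inputs are used as-is and there is no encoder side.
    return processed, None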
23 changes: 16 additions & 7 deletions vllm/engine/protocol.py
@@ -1,11 +1,12 @@
import asyncio
from abc import ABC, abstractmethod
from typing import AsyncGenerator, List, Mapping, Optional, Union
from typing import AsyncGenerator, List, Mapping, Optional

from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -60,7 +61,7 @@ def generate(

async def beam_search(
self,
prompt: Union[PromptType, List[int]],
prompt: PromptType,
model_config: ModelConfig,
request_id: str,
params: BeamSearchParams,
@@ -76,11 +77,19 @@ async def beam_search(
tokenizer = await self.get_tokenizer()
input_preprocessor = InputPreprocessor(model_config, tokenizer)

(prompt_text, prompt_token_ids, multi_modal_data,
mm_processor_kwargs) = input_preprocessor._extract_prompt_components(
prompt,
request_id=request_id,
)
if is_explicit_encoder_decoder_prompt(prompt):
raise NotImplementedError
else:
processed_inputs = input_preprocessor._prompt_to_llm_inputs(
prompt,
request_id=request_id,
)

prompt_token_ids = processed_inputs["prompt_token_ids"]
prompt_text = processed_inputs.get("prompt")
multi_modal_data = processed_inputs.get("multi_modal_data")
mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")

tokenized_length = len(prompt_token_ids)

sort_beams_key = create_sort_beams_key_function(
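Note the signature change above: beam_search() now takes a PromptType rather than Union[PromptType, List[int]], and explicit encoder-decoder prompts raise NotImplementedError. A caller that previously passed a bare token list would now pass a TokensPrompt; the engine client, model_config, request_id, and params below are assumed for illustration.

from vllm.inputs.data import TokensPrompt

prompt: TokensPrompt = {"prompt_token_ids": [1, 2, 3]}
# results = await engine_client.beam_search(prompt, model_config, request_id, params)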
11 changes: 6 additions & 5 deletions vllm/inputs/__init__.py
@@ -1,8 +1,8 @@
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
ExplicitEncoderDecoderPrompt, PromptType, SingletonInputs,
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
token_inputs, zip_enc_dec_prompts)
ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
TokensPrompt, build_explicit_enc_dec_prompt,
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
from .registry import DummyData, InputContext, InputRegistry

INPUT_REGISTRY = InputRegistry()
@@ -22,9 +22,10 @@
"ExplicitEncoderDecoderPrompt",
"TokenInputs",
"token_inputs",
"SingletonInputs",
"DecoderOnlyInputs",
"EncoderDecoderInputs",
"ProcessorInputs",
"SingletonInputs",
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
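The export list above now includes ProcessorInputs alongside the existing input types, so everything stays importable from the package root. A quick sanity check of the re-exported helpers, assuming a build that includes this PR:

from vllm.inputs import TokenInputs, token_inputs

inputs: TokenInputs = token_inputs([1, 2, 3], prompt="hi")
assert inputs["prompt_token_ids"] == [1, 2, 3]
assert inputs.get("prompt") == "hi"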