Merged
Commits
80 commits
cef6894
(vllm) add input embedding
Jan 2, 2025
c51d8fb
improve embedding input
Bryce1010 Jan 6, 2025
9564b40
(vllm) fix import error
Bryce1010 Mar 6, 2025
c60298a
(vllm) fix pre commit error
Bryce1010 Mar 6, 2025
0c24a82
apply ruff and isort fixes
qthequartermasterman Mar 25, 2025
403a165
apply ruff and isort fixes
qthequartermasterman Mar 25, 2025
b1ac072
styling
qthequartermasterman Mar 25, 2025
0390c33
fix missing imports from rebase
qthequartermasterman Mar 25, 2025
0ca4dae
typing fixes
qthequartermasterman Mar 25, 2025
35320fe
type fix
qthequartermasterman Mar 25, 2025
0a77630
type fix
qthequartermasterman Mar 25, 2025
11b6c02
remove unnecessary changes
qthequartermasterman Mar 25, 2025
cb92a3d
remove unnecessary changes
qthequartermasterman Mar 25, 2025
375bd5b
re-add deleted whitespace
qthequartermasterman Mar 25, 2025
c9d8024
Include unit tests from #6869.
qthequartermasterman Mar 25, 2025
a64e627
remove unrelated qwen2 changes
qthequartermasterman Mar 26, 2025
6ab349e
guard clause around fully consumed prompt embeds to avoid returning e…
qthequartermasterman Mar 27, 2025
26c8784
use v0 for prompt embeds model runner tests
qthequartermasterman Mar 27, 2025
b71a13c
fix batching of input embeddings
qthequartermasterman Apr 2, 2025
4aa9ade
style formatting
qthequartermasterman Apr 2, 2025
e2c4c26
remove incorrect overload
qthequartermasterman Apr 3, 2025
26d108a
remove incorrect overload
qthequartermasterman Apr 3, 2025
af20435
Update representations
qthequartermasterman Apr 4, 2025
25aaf3f
remove unrelated changes to docs
qthequartermasterman Apr 4, 2025
bc05860
remove unrelated typing change
qthequartermasterman Apr 4, 2025
b55800d
fix missing syntax
qthequartermasterman Apr 4, 2025
be42a17
do not schedule prompt embeds and non-prompt embeds in the same batch
qthequartermasterman Apr 4, 2025
c8fcfe4
fix style linelength
qthequartermasterman Apr 4, 2025
b21688f
Merge branch 'main' into feature/vllm/add-input-embedding
qthequartermasterman Apr 7, 2025
1e359ae
propogate embeddings for sampled output tokens for decoding
qthequartermasterman Apr 11, 2025
59fbe70
fix type check
qthequartermasterman Apr 11, 2025
c152a3a
do not schedule decode sequence groups with batches containing both p…
qthequartermasterman Apr 11, 2025
42ad800
Merge branch 'main' into feature/vllm/add-input-embedding
qthequartermasterman Apr 11, 2025
e7ab2a2
fix type check
qthequartermasterman Apr 11, 2025
911adbe
add default value to optional parameter
qthequartermasterman Apr 11, 2025
82d923d
remove unused comments
qthequartermasterman Apr 14, 2025
c951479
properly pass in placeholder token ids when testing prompt embeds
qthequartermasterman Apr 15, 2025
01e1a6e
do not test mixed token_ids/prompt_embeds batches in the model_runner
qthequartermasterman Apr 15, 2025
193ad5c
refactor cuda_prepare_decode test
qthequartermasterman Apr 15, 2025
74bd9f4
use correct expected input embeds length for prepare_decode_cuda_grap…
qthequartermasterman Apr 15, 2025
d949f1b
add scheduler test to ensure prompt embeds and prompt tokens are not …
qthequartermasterman Apr 15, 2025
62bbc88
support inputs_embeds in compiled mode
qthequartermasterman Apr 16, 2025
1d1ae4b
fix typing in test
qthequartermasterman Apr 16, 2025
1914676
use corrector operator precedence for handling empty strings
qthequartermasterman Apr 16, 2025
70198f6
only test decoder models with input embeds in v0 backend
qthequartermasterman Apr 16, 2025
934ceae
Merge branch 'vllm-project:main' into feature/vllm/add-input-embedding
qthequartermasterman Apr 16, 2025
5595b45
adjust type hints for modelinputforgpubuilder.build
qthequartermasterman Apr 18, 2025
3343d3e
simplify conditional logic
qthequartermasterman Apr 18, 2025
5010ea0
simplify compilation conditional logic
qthequartermasterman Apr 18, 2025
2075e53
refactor decoder only language model tests to reduce number of times …
qthequartermasterman Apr 18, 2025
9a4fb3c
break up multiple assignments for readability
qthequartermasterman Apr 18, 2025
8ad4091
update type hints in scheduler
qthequartermasterman Apr 18, 2025
9055daf
clear existing lists instead of instantiating new ones
qthequartermasterman Apr 18, 2025
9a57aca
preprocess tensors to handle batched/misshaped prompt embeds to avoid…
qthequartermasterman Apr 18, 2025
bbfb0f0
use seperate Embedsprompt class for preprocessing inputs embeddings
qthequartermasterman Apr 18, 2025
933e567
fix typing
qthequartermasterman Apr 18, 2025
4e0d12f
fix type errors
qthequartermasterman Apr 19, 2025
164aeb5
Merge branch 'vllm-project:main' into feature/vllm/add-input-embedding
qthequartermasterman Apr 19, 2025
9e6909e
fix mistaken type change
qthequartermasterman Apr 19, 2025
90b950a
add missing type hint
qthequartermasterman Apr 19, 2025
01d83f4
add spaces for style
qthequartermasterman Apr 20, 2025
6985452
seperate EmbedsInputs from TokenInputs and embeds_inputs from token_i…
qthequartermasterman Apr 20, 2025
e916551
fix docstrings for EmbedsInputs
qthequartermasterman Apr 20, 2025
69f8725
fix typing for token_type_ids
qthequartermasterman Apr 20, 2025
9c2c89f
fix typing for embeds_tokens in InputRegistry and InputsAdapter
qthequartermasterman Apr 20, 2025
499dc6a
remove prompts and prompt_token_ids from EmbedsPrompts
qthequartermasterman Apr 21, 2025
20668ca
Merge branch 'main' into feature/vllm/add-input-embedding
qthequartermasterman Apr 28, 2025
6712ba6
fight mypy to get correct typing for not embeds prompts
qthequartermasterman Apr 28, 2025
740b290
remove incorrect call to embeds_inputs
qthequartermasterman Apr 28, 2025
8f9bd51
wrestle with mypy and typeddict type narrowing
qthequartermasterman Apr 29, 2025
b8d36c6
wrestle with mypy and typeddict type narrowing
qthequartermasterman Apr 29, 2025
b764c19
support indexing graph runners that with inputs_embeds
qthequartermasterman Apr 29, 2025
cb6ff22
Merge branch 'main' into feature/vllm/add-input-embedding
qthequartermasterman May 1, 2025
85642d0
support encoder decoder models with inputs_embeds
qthequartermasterman May 1, 2025
b226fd6
simplify redundant ternary statement
qthequartermasterman May 1, 2025
b738d3f
explicitly remove support for inputs embeds with speculative decoding…
qthequartermasterman May 1, 2025
2340119
fix occasional device mismatch errors when appending output tokens to…
qthequartermasterman May 1, 2025
92b3264
Merge branch 'main' into feature/vllm/add-input-embedding
qthequartermasterman May 2, 2025
b9271c1
Merge branch 'main' into feature/vllm/add-input-embedding
qthequartermasterman May 2, 2025
28b0983
Fix a typo
DarkLight1337 May 2, 2025
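
Illustrative sketch (not part of this diff): the commits above build toward passing precomputed prompt embeddings to vLLM instead of text. The snippet below is a minimal end-user sketch inferred from the test changes that follow; the prompt_embeds field on TextPrompt and its acceptance by LLM.generate are the features this PR introduces (V0 engine only, per the tests), and the model name is a placeholder.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.inputs import TextPrompt

model_name = "facebook/opt-125m"  # placeholder model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)

# Build embeddings with the HF input-embedding layer, mirroring
# tests/models/language/generation/test_common.py below.
token_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
prompt_embeds = hf_model.get_input_embeddings()(token_ids).squeeze(0)

llm = LLM(model=model_name)
outputs = llm.generate(
    TextPrompt(prompt_embeds=prompt_embeds),  # embeddings instead of a string
    SamplingParams(temperature=0.0, max_tokens=16),
)
print(outputs[0].outputs[0].text)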
18 changes: 10 additions & 8 deletions tests/conftest.py
@@ -779,7 +779,7 @@ def __init__(

def get_inputs(
self,
prompts: list[str],
prompts: Union[list[str], list[torch.Tensor]],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
@@ -801,16 +801,18 @@ def get_inputs(
if audios is not None and (audio := audios[i]) is not None:
multi_modal_data["audio"] = audio

inputs.append(
TextPrompt(prompt=prompt,
multi_modal_data=multi_modal_data
if multi_modal_data else None))
text_prompt_kwargs = {
("prompt" if isinstance(prompt, str) else "prompt_embeds"):
prompt,
"multi_modal_data": multi_modal_data or None
}
inputs.append(TextPrompt(**text_prompt_kwargs))

return inputs

def generate(
self,
prompts: list[str],
prompts: Union[list[str], list[torch.Tensor]],
sampling_params: SamplingParams,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
@@ -836,7 +838,7 @@ def generate(
output_str = sample.text
output_ids = list(sample.token_ids)
req_sample_output_ids.append(prompt_ids + output_ids)
req_sample_output_strs.append(prompt_str + output_str)
req_sample_output_strs.append((prompt_str or "") + output_str)
outputs.append((req_sample_output_ids, req_sample_output_strs))
return outputs

@@ -903,7 +905,7 @@ def generate_encoder_decoder_w_logprobs(

def generate_greedy(
self,
prompts: list[str],
prompts: Union[list[str], list[torch.Tensor]],
max_tokens: int,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
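
A standalone sketch of the dispatch pattern the get_inputs change above introduces: string prompts keep the "prompt" key while tensor prompts are routed to "prompt_embeds". The helper name below is illustrative, not part of the diff.

from typing import Any, Optional, Union

import torch


def build_text_prompt_kwargs(
        prompt: Union[str, torch.Tensor],
        multi_modal_data: Optional[dict[str, Any]] = None) -> dict[str, Any]:
    # Strings go to "prompt"; embedding tensors go to "prompt_embeds",
    # matching the kwargs built before TextPrompt(**text_prompt_kwargs).
    key = "prompt" if isinstance(prompt, str) else "prompt_embeds"
    return {key: prompt, "multi_modal_data": multi_modal_data or None}


build_text_prompt_kwargs("Hello")              # {"prompt": "Hello", ...}
build_text_prompt_kwargs(torch.zeros(4, 768))  # {"prompt_embeds": <tensor>, ...}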
74 changes: 73 additions & 1 deletion tests/core/test_scheduler.py
@@ -2,16 +2,18 @@

import time
from collections import deque
from typing import Optional
from unittest.mock import MagicMock

import pytest # noqa
import torch
from torch import Use # noqa

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest
from vllm.sequence import SequenceGroup
from vllm.sequence import SequenceGroup, SequenceStatus

from .utils import (append_new_token, append_new_token_seq,
append_new_token_seq_group, create_dummy_prompt,
@@ -968,3 +970,73 @@ def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
), "A partial prefix of C (4 tokens) should be prefilled, with the "
"remaining tokens fit into 3 token budget (4-1 from the seqA). It will "
"then be rounded down to 2 tokens on block size, thus 6 tokens in total."


def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds():
"""
Test that the scheduler does not schedule batches with prompt tokens and
prompt embeddings co-mingled.
"""
block_size = 2
max_seq_group = 3
scheduler = initialize_scheduler(
block_size=block_size,
num_cpu_blocks=16,
num_gpu_blocks=16,
max_num_seqs=max_seq_group,
max_model_len=100,
enable_prefix_caching=True,
)

# the even indexed inputs should be passed in via embeddings,
# odds via token_ids
seq_length = 7
embedding_size = 5
num_seqs = 11
seq_tokens: list[list[int]] = []
seq_embeds: list[Optional[torch.Tensor]] = []
for i in range(num_seqs):
if i % 2:
seq_tokens.append(list(range(seq_length)))
seq_embeds.append(None)
else:
seq_tokens.append([0] * seq_length)
seq_embeds.append(torch.rand(embedding_size))

seq_and_seq_groups = [
create_dummy_prompt(f"{i}",
prompt_tokens=seq_tokens[i],
prompt_embeds=seq_embeds[i],
block_size=block_size)
for i in range(len(seq_tokens))
]

for _, seq_group in seq_and_seq_groups:
scheduler.add_seq_group(seq_group)

while not all(seq.is_finished() for seq, _ in seq_and_seq_groups):
unfinished_seq_groups = [
seq_group for _, seq_group in seq_and_seq_groups
if not seq_group.is_finished()
]
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) > 0
batch_is_prompt_embeds = out.scheduled_seq_groups[
0].seq_group.uses_prompt_embeds()
expected_scheduled_seq_groups = [
seq_group for seq_group in unfinished_seq_groups
if seq_group.uses_prompt_embeds() == batch_is_prompt_embeds
]

# We should have as many scheduled groups as possible, without mixing
assert len(out.scheduled_seq_groups) == min(
max_seq_group, len(expected_scheduled_seq_groups))
assert all(scheduled_seq_group.seq_group.uses_prompt_embeds() ==
batch_is_prompt_embeds
for scheduled_seq_group in out.scheduled_seq_groups)

# Finish the scheduled groups
for scheduled_seq_group in out.scheduled_seq_groups:
for seq in scheduled_seq_group.seq_group.seqs:
seq.status = SequenceStatus.FINISHED_STOPPED
scheduler.free_finished_seq_groups()
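
The invariant this new test enforces can be summarized by a small helper (illustrative only, not part of the diff): every scheduled batch must be homogeneous with respect to uses_prompt_embeds().

def batch_is_homogeneous(scheduled_seq_groups) -> bool:
    """True if all groups in a scheduled batch agree on whether they use
    prompt embeddings, i.e. the property asserted in the test above."""
    if not scheduled_seq_groups:
        return True
    first = scheduled_seq_groups[0].seq_group.uses_prompt_embeds()
    return all(s.seq_group.uses_prompt_embeds() == first
               for s in scheduled_seq_groups)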
11 changes: 9 additions & 2 deletions tests/core/utils.py
@@ -5,9 +5,11 @@
from collections.abc import Sequence as GenericSequence
from typing import Any, Optional

import torch

from vllm import SamplingParams
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.inputs import EncoderDecoderInputs, token_inputs
from vllm.inputs import EncoderDecoderInputs, embeds_inputs, token_inputs
from vllm.lora.request import LoRARequest
from vllm.sequence import (Logprob, Sequence, SequenceGroup,
SequenceGroupMetadata)
@@ -19,6 +21,7 @@ def create_dummy_prompt(
block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
prompt_tokens: Optional[list[int]] = None,
prompt_embeds: Optional[torch.Tensor] = None,
min_tokens: int = 0,
max_tokens: int = 16,
) -> tuple[Sequence, SequenceGroup]:
@@ -31,9 +34,13 @@
prompt_tokens = list(range(prompt_length))

prompt_str = " ".join([str(t) for t in prompt_tokens])
inputs = token_inputs(
prompt_token_ids=prompt_tokens,
prompt=prompt_str) if prompt_embeds is None else embeds_inputs(
prompt_embeds=prompt_embeds)
prompt = Sequence(
int(request_id),
inputs=token_inputs(prompt_tokens, prompt=prompt_str),
inputs=inputs,
block_size=block_size,
)
seq_group = SequenceGroup(
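
Usage sketch for the extended helper (illustrative, not part of the diff): with prompt_embeds the dummy sequence is built from embeds_inputs, otherwise from token_inputs as before. The import path below is an assumption based on how the scheduler tests import these utilities.

import torch

from tests.core.utils import create_dummy_prompt  # assumed import path

# Token-id based dummy prompt (existing behaviour).
seq_tok, group_tok = create_dummy_prompt("0",
                                         prompt_tokens=list(range(8)),
                                         block_size=2)

# Embedding based dummy prompt: the token ids act as placeholders while the
# actual input is the prompt_embeds tensor. As in the scheduler test above,
# a small dummy tensor is enough here.
seq_emb, group_emb = create_dummy_prompt("1",
                                         prompt_tokens=[0] * 8,
                                         prompt_embeds=torch.rand(16),
                                         block_size=2)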
26 changes: 26 additions & 0 deletions tests/models/language/generation/test_common.py
@@ -1,4 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Optional

import pytest
import torch

@@ -110,6 +113,18 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)

prompt_embeds: Optional[list[torch.Tensor]] = [] if os.getenv(
"VLLM_USE_V1") == "0" else None
prompt_token_ids = []
for prompt in example_prompts:
token_ids = hf_model.tokenizer(prompt,
return_tensors="pt").input_ids.to(
hf_model.model.device)
prompt_token_ids.append(token_ids)
if prompt_embeds is not None:
prompt_embeds.append(hf_model.model.get_input_embeddings()(
token_ids).squeeze(0))

with vllm_runner(
model,
tokenizer_name=model_info.tokenizer or model,
@@ -119,13 +134,24 @@
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
if prompt_embeds is not None:
vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs(
prompt_embeds, max_tokens, num_logprobs)

check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
if prompt_embeds is not None:
check_logprobs_close(
outputs_0_lst=vllm_outputs,
outputs_1_lst=vllm_outputs_from_embeds,
name_0="vllm",
name_1="vllm_from_embeds",
)

if use_rocm_aiter:
# this is to ensure that vllm engine
# has deallocated the memory before running the next
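
The embedding-extraction loop in the test above can be read as a reusable helper (illustrative sketch; here hf_model and tokenizer are the raw transformers causal LM and its tokenizer, rather than the hf_runner wrapper used in the test).

import torch


def prompts_to_embeds(hf_model, tokenizer,
                      prompts: list[str]) -> list[torch.Tensor]:
    """Convert text prompts into per-prompt embedding tensors using the
    model's input-embedding layer, as test_models does above."""
    embeds = []
    for prompt in prompts:
        token_ids = tokenizer(prompt,
                              return_tensors="pt").input_ids.to(hf_model.device)
        # (1, seq_len, hidden) -> (seq_len, hidden) after squeezing the batch dim.
        embeds.append(hf_model.get_input_embeddings()(token_ids).squeeze(0))
    return embeds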