1 change: 0 additions & 1 deletion vllm/attention/backends/flash_attn.py
@@ -732,7 +732,6 @@ def forward(
prefill_output = output[:num_prefill_query_tokens]
assert query.shape[0] == num_prefill_query_tokens
assert decode_query.shape[0] == num_decode_query_tokens

if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
7 changes: 7 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -486,6 +486,13 @@
if arrival_time is None:
arrival_time = time.time()

if isinstance(prompt, dict) and prompt.get("prompt_embeds",
None) is not None:
if not prompt.get("prompt_token_ids", None):

Check failure (GitHub Actions / pre-commit, Ruff SIM102): vllm/engine/async_llm_engine.py:489:9: Use a single `if` statement instead of nested `if` statements
prompt["prompt_token_ids"] = [
0
] * prompt["prompt_embeds"].shape[0]

if self.tokenizer is not None:
tokenizer = await self.get_tokenizer_async(lora_request)
self._validate_token_prompt(prompt, tokenizer=tokenizer)
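Note: the block above back-fills placeholder token IDs when a request arrives with embeddings but no token IDs. A minimal, self-contained sketch of that padding rule (plain Python plus PyTorch, not vLLM internals; the helper name is illustrative):

import torch

def ensure_placeholder_token_ids(prompt: dict) -> dict:
    # Mirror of the padding above: if a prompt dict carries prompt_embeds but
    # no prompt_token_ids, fill the IDs with zeros, one per embedding row.
    embeds = prompt.get("prompt_embeds")
    if embeds is not None and not prompt.get("prompt_token_ids"):
        prompt["prompt_token_ids"] = [0] * embeds.shape[0]
    return prompt

# Five embedding vectors of hidden size 4096 -> five placeholder token IDs.
request = {"prompt_embeds": torch.zeros(5, 4096)}
print(ensure_placeholder_token_ids(request)["prompt_token_ids"])  # [0, 0, 0, 0, 0]
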
4 changes: 4 additions & 0 deletions vllm/engine/llm_engine.py
@@ -753,6 +753,10 @@
if arrival_time is None:
arrival_time = time.time()

if isinstance(prompt, dict) and prompt.get("prompt_embeds", None) is not None:
if not prompt.get("prompt_token_ids", None):
prompt["prompt_token_ids"] = [0] * prompt["prompt_embeds"].shape[0]

Check failure (GitHub Actions / pre-commit, Ruff SIM102): vllm/engine/llm_engine.py:756:9: Use a single `if` statement instead of nested `if` statements

if self.tokenizer is not None:
self._validate_token_prompt(
prompt,
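Note: this is the same Ruff SIM102 complaint as in async_llm_engine.py above; the suggested fix is to merge the outer and inner checks into one condition. A sketch of the flattened form, behavior unchanged from the nested version in this diff:

import torch

def fill_placeholder_ids(prompt):
    # Single `if` per SIM102: the isinstance check, the prompt_embeds check,
    # and the missing-token-ids check combined into one condition.
    if (isinstance(prompt, dict)
            and prompt.get("prompt_embeds", None) is not None
            and not prompt.get("prompt_token_ids", None)):
        prompt["prompt_token_ids"] = [0] * prompt["prompt_embeds"].shape[0]
    return prompt

print(fill_placeholder_ids({"prompt_embeds": torch.ones(3, 8)})["prompt_token_ids"])  # [0, 0, 0]
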
147 changes: 14 additions & 133 deletions vllm/entrypoints/llm.py
@@ -10,7 +10,7 @@
import torch.nn as nn
from tqdm import tqdm
from typing_extensions import TypeVar, deprecated

import torch
from vllm import envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
@@ -368,7 +368,7 @@
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
...

Check failure (GitHub Actions / pre-commit): vllm/entrypoints/llm.py:371: Overloaded function implementation does not accept all possible arguments of signature 2 [misc]
Check failure (GitHub Actions / pre-commit): vllm/entrypoints/llm.py:371: Overloaded function implementation does not accept all possible arguments of signature 3 [misc]
Check failure (GitHub Actions / pre-commit): vllm/entrypoints/llm.py:371: Overloaded function implementation does not accept all possible arguments of signature 6 [misc]

@deprecate_kwargs(
"prompt_token_ids",
@@ -382,6 +382,7 @@
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
prompt_embeds: Optional[torch.Tensor] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -401,13 +402,18 @@
for more details about the format of each prompts.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.

Check failure (GitHub Actions / pre-commit, Ruff E501): vllm/entrypoints/llm.py:405:81: Line too long (81 > 80)
When it is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.

Check failure (GitHub Actions / pre-commit, Ruff E501): vllm/entrypoints/llm.py:407:81: Line too long (81 > 80)
prompt_token_ids: DEPRECATED. Token IDs for the prompts. If provided,
the `prompts` will be ignored.
prompt_embeds: Optional tensor of prompt embeddings to use instead of
text prompts.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
guided_options_request: Options for guided decoding, if any.
priority: The priority of the requests, if any.
Only applicable when priority scheduling policy is enabled.

@@ -444,8 +450,13 @@
)
else:
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)

Check failure (GitHub Actions / pre-commit, Ruff E501): vllm/entrypoints/llm.py:453:81: Line too long (99 > 80)

# Handle prompt_embeds separately
# This is a simplified approach - you may need to adjust based on how prompt_embeds is used
if prompt_embeds is not None:
parsed_prompts.prompt_embeds = prompt_embeds

if isinstance(guided_options_request, dict):
if len(guided_options_request) > 1:
raise ValueError(
@@ -456,7 +467,7 @@

if sampling_params is None:
# Use default sampling params.
sampling_params = self.get_default_sampling_params()

Check failure (GitHub Actions / pre-commit): vllm/entrypoints/llm.py:470: "LLM" has no attribute "_validate_and_add_requests" [attr-defined]

self._validate_and_add_requests(
prompts=parsed_prompts,
@@ -1227,6 +1238,7 @@
self,
prompts: Optional[Union[str, list[str]]],
prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
prompt_embeds: Optional[torch.Tensor] = None,
):
# skip_tokenizer_init is now checked in engine

@@ -1261,140 +1273,9 @@
item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
else:
raise AssertionError

Check failure (GitHub Actions / pre-commit, Ruff E501): vllm/entrypoints/llm.py:1276:81: Line too long (94 > 80)
parsed_prompts.append(item)

# We don't need to handle prompt_embeds here since it's handled in the generate method
return parsed_prompts

def _validate_and_add_requests(
self,
prompts: Union[PromptType, Sequence[PromptType]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest],
guided_options: Optional[GuidedDecodingRequest] = None,
priority: Optional[list[int]] = None,
) -> None:
if guided_options is not None:
warnings.warn(
"guided_options_request is deprecated, use "
"SamplingParams.guided_decoding instead",
DeprecationWarning,
stacklevel=2,
)

if isinstance(prompts, (str, dict)):
# Convert a single prompt to a list.
prompts = [prompts]

num_requests = len(prompts)
if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params "
"must be the same.")
if isinstance(lora_request,
list) and len(lora_request) != num_requests:
raise ValueError("The lengths of prompts and lora_request "
"must be the same.")

for sp in params if isinstance(params, list) else (params, ):
if isinstance(sp, SamplingParams):
self._add_guided_params(sp, guided_options)

# We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY

# Add requests to the engine.
for i, prompt in enumerate(prompts):
self._add_request(
prompt,
params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
prompt_adapter_request=prompt_adapter_request,
priority=priority[i] if priority else 0,
)

def _add_request(
self,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(
request_id,
prompt,
params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
)

def _add_guided_params(
self,
params: SamplingParams,
guided_options: Optional[GuidedDecodingRequest] = None):
if guided_options is None:
return params

if params.guided_decoding is not None:
raise ValueError("Cannot set both guided_options_request and "
"params.guided_decoding.")

params.guided_decoding = GuidedDecodingParams(
json=guided_options.guided_json,
regex=guided_options.guided_regex,
choice=guided_options.guided_choice,
grammar=guided_options.guided_grammar,
json_object=guided_options.guided_json_object,
backend=guided_options.guided_decoding_backend,
whitespace_pattern=guided_options.guided_whitespace_pattern)
return params

def _run_engine(
self, *, use_tqdm: bool
) -> list[Union[RequestOutput, PoolingRequestOutput]]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
pbar = tqdm(
total=num_requests,
desc="Processed prompts",
dynamic_ncols=True,
postfix=(f"est. speed input: {0:.2f} toks/s, "
f"output: {0:.2f} toks/s"),
)

# Run the engine.
outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
total_in_toks = 0
total_out_toks = 0
while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
for output in step_outputs:
if output.finished:
outputs.append(output)
if use_tqdm:
if isinstance(output, RequestOutput):
# Calculate tokens only for RequestOutput
assert output.prompt_token_ids is not None
total_in_toks += len(output.prompt_token_ids)
in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += sum(
len(stp.token_ids) for stp in output.outputs)
out_spd = (total_out_toks /
pbar.format_dict["elapsed"])
pbar.postfix = (
f"est. speed input: {in_spd:.2f} toks/s, "
f"output: {out_spd:.2f} toks/s")
pbar.update(1)

if use_tqdm:
pbar.close()
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
return sorted(outputs, key=lambda x: int(x.request_id))
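
Note: a sketch of the intended call pattern for embedding inputs once this branch passes CI (the pre-commit failures above show generate() is still being reworked). The model name, hidden size, and random tensor are illustrative placeholders; a real caller would usually take the embeddings from a model's input embedding layer. Embeddings travel inside the prompt dict, and the engine back-fills placeholder token IDs (see the llm_engine.py change above) when none are given:

import torch
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-0.5B-Instruct")  # illustrative model choice

# One embedding row per prompt position; 896 is Qwen2-0.5B's hidden size.
embeds = torch.randn(16, 896)

outputs = llm.generate(
    {"prompt_embeds": embeds},
    sampling_params=SamplingParams(max_tokens=32),
)
print(outputs[0].outputs[0].text)
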
14 changes: 13 additions & 1 deletion vllm/inputs/data.py
@@ -20,6 +20,9 @@ class TextPrompt(TypedDict):
prompt: str
"""The input text to be tokenized before passing to the model."""

prompt_embeds: NotRequired[torch.Tensor]
"""The embeddings of the prompt, if available."""

multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
@@ -41,6 +44,9 @@ class TokensPrompt(TypedDict):
prompt_token_ids: List[int]
"""A list of token IDs to pass to the model."""

prompt_embeds: NotRequired[torch.Tensor]
"""The embeddings of the prompt, if available."""

token_type_ids: NotRequired[List[int]]
"""A list of token type IDs to pass to the cross encoder model."""

@@ -147,6 +153,9 @@ class TokenInputs(TypedDict):
The original prompt text corresponding to the token IDs, if available.
"""

prompt_embeds: NotRequired[torch.Tensor]
"""The embeddings of the prompt, if available."""

multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
@@ -182,6 +191,7 @@ def token_inputs(
prompt_token_ids: List[int],
token_type_ids: Optional[List[int]] = None,
prompt: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
multi_modal_inputs: Optional["MultiModalKwargs"] = None,
multi_modal_hashes: Optional[List[str]] = None,
@@ -195,6 +205,8 @@
inputs["prompt"] = prompt
if token_type_ids is not None:
inputs["token_type_ids"] = token_type_ids
if prompt_embeds is not None:
inputs["prompt_embeds"] = prompt_embeds
if multi_modal_data is not None:
inputs["multi_modal_data"] = multi_modal_data
if multi_modal_inputs is not None:
@@ -277,7 +289,7 @@ def prompt_embeds(self) -> Optional[torch.Tensor]:
inputs = self.inputs

if inputs["type"] == "token" or inputs["type"] == "multimodal":
return None
return inputs.get("prompt_embeds")

assert_never(inputs) # type: ignore[arg-type]

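Note: a short sketch of how the new fields fit together, using the TokensPrompt TypedDict and the token_inputs helper extended above (the tensor shape is illustrative):

import torch
from vllm.inputs import TokensPrompt, token_inputs

embeds = torch.randn(4, 1024)  # illustrative [num_tokens, hidden_size] tensor

# TokensPrompt may now carry embeddings alongside (placeholder) token IDs.
prompt = TokensPrompt(prompt_token_ids=[0, 0, 0, 0], prompt_embeds=embeds)

# token_inputs() records prompt_embeds only when it is provided.
inputs = token_inputs(prompt_token_ids=prompt["prompt_token_ids"],
                      prompt_embeds=prompt.get("prompt_embeds"))
assert inputs["prompt_embeds"].shape == (4, 1024)
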
3 changes: 3 additions & 0 deletions vllm/inputs/parse.py
@@ -97,6 +97,9 @@ def parse_singleton_prompt(
elif "prompt" in prompt:
return ParsedTextPrompt(type="text", content=prompt)

elif "prompt_embeds" in prompt:
return ParsedTokensPrompt(type="tokens", content=prompt)

raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")


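Note: with the new branch, a dict that carries only prompt_embeds is classified as a tokens prompt instead of falling through to the TypeError; a quick sketch:

import torch
from vllm.inputs.parse import parse_singleton_prompt

parsed = parse_singleton_prompt({"prompt_embeds": torch.randn(3, 1024)})
print(parsed["type"])  # "tokens"
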
4 changes: 4 additions & 0 deletions vllm/inputs/preprocess.py
@@ -360,6 +360,7 @@ def _prompt_to_llm_inputs(

return token_inputs(
prompt_token_ids=prompt_token_ids,
prompt_embeds=tokens_content.get("prompt_embeds"),
token_type_ids=token_type_ids,
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
@@ -389,6 +390,7 @@ def _prompt_to_llm_inputs(
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
prompt_embeds=text_content.get("prompt_embeds"),
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)
@@ -434,6 +436,7 @@ async def _prompt_to_llm_inputs_async(

return token_inputs(
prompt_token_ids=prompt_token_ids,
prompt_embeds=tokens_content.get("prompt_embeds"),
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)
@@ -462,6 +465,7 @@ async def _prompt_to_llm_inputs_async(
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
prompt_embeds=text_content.get("prompt_embeds"),
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)
5 changes: 3 additions & 2 deletions vllm/model_executor/models/qwen2.py
@@ -437,6 +437,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
True,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
@@ -459,8 +460,8 @@ def forward(
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)

hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds, self.lm_head.bias)
return hidden_states

def compute_logits(
Expand Down
14 changes: 13 additions & 1 deletion vllm/sequence.py
@@ -158,6 +158,8 @@ class SequenceData(msgspec.Struct,
_output_token_ids: array = msgspec.field(
default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))

_prompt_embeds: Optional[torch.Tensor] = None

### The below fields should not be passed as an argument ###
_cumulative_logprob: float = 0.0
_prompt_token_ids_tuple: tuple[int,
@@ -254,7 +256,7 @@ def prompt_token_ids_array(self) -> array:
@property
def output_token_ids(self) -> tuple[int, ...]:
return tuple(self._output_token_ids)

@output_token_ids.setter
def output_token_ids(self,
new_output_token_ids: GenericSequence[int]) -> None:
@@ -271,6 +273,14 @@ def output_token_ids_array(self) -> array:
"""
assert isinstance(self._output_token_ids, array)
return self._output_token_ids

@property
def prompt_embeds(self) -> Optional[torch.Tensor]:
return self._prompt_embeds

@prompt_embeds.setter
def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None:
self._prompt_embeds = prompt_embeds

@property
def mrope_position_delta(self) -> Optional[int]:
@@ -379,6 +389,7 @@ def stage(self) -> SequenceStage:
def __repr__(self) -> str:
return (f"SequenceData("
f"prompt_token_ids={self._prompt_token_ids}, "
f"prompt_embeds={getattr(self._prompt_embeds, 'shape', None)}, "
f"output_token_ids={self.output_token_ids}, "
f"cumulative_logprob={self.cumulative_logprob}, "
f"get_num_computed_tokens={self.get_num_computed_tokens()})")
@@ -418,6 +429,7 @@ def __init__(
self.prompt_adapter_request = prompt_adapter_request

self.data = SequenceData.from_seqs(self.prompt_token_ids)
self.data.prompt_embeds = self.inputs.prompt_embeds
self.output_logprobs: SampleLogprobs = []
self.output_text = ""

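Note: a minimal sketch of the new SequenceData plumbing, relying on SequenceData.from_seqs as used in Sequence.__init__ above; the tensor shape is illustrative:

import torch
from vllm.sequence import SequenceData

data = SequenceData.from_seqs([0, 0, 0])   # three placeholder token IDs
data.prompt_embeds = torch.randn(3, 1024)  # setter added in this diff

print(data.prompt_embeds.shape)  # torch.Size([3, 1024])
print(data)  # repr now reports prompt_embeds=torch.Size([3, 1024])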