Merged
76 commits
70cf633
feat: :sparkles: do not fall back to v0 engine when enabling prompt e…
qthequartermasterman Aug 7, 2025
3f4de3f
feat: :pipe: propagate prompt_embeds from the requests to the model e…
qthequartermasterman Aug 8, 2025
88329c5
refactor: :recycle: :truck: move length_from_prompt_tokens_or_embeds …
qthequartermasterman Aug 14, 2025
3df28f4
fix: keep track of prompt embeds within input batch
qthequartermasterman Aug 14, 2025
addcd8e
fix: use placeholder prompt_token_ids in RequestOutput when using pro…
qthequartermasterman Aug 14, 2025
7179da1
feat: pass prompt embeds from InputBatch to GPU model runner
qthequartermasterman Aug 14, 2025
61b962d
fix: correctly place the correct indices
qthequartermasterman Aug 14, 2025
d429d21
Remove accidental file
qthequartermasterman Aug 14, 2025
9c90c87
fix: do not copy token_ids->embeds into a temporary tensor when batched
qthequartermasterman Aug 28, 2025
ab3f02d
disable prefix caching in v1 when prompt embeds is enabled.
qthequartermasterman Aug 28, 2025
994b600
disable prefix caching in v0 when prompt embeds is enabled.
qthequartermasterman Aug 28, 2025
4f1480f
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Aug 28, 2025
d6a376b
use cpu/gpu buffer classes after merge
qthequartermasterman Aug 28, 2025
0ef06db
test: fix missing argument
qthequartermasterman Aug 28, 2025
7eeafac
test: add test cases for v1 engine + prompt embeds
qthequartermasterman Sep 3, 2025
b683d49
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Sep 3, 2025
0c0c53a
fix: cudagraph compilation when prompt embeds are enabled.
qthequartermasterman Sep 4, 2025
ce908c8
style: use old style union syntax
qthequartermasterman Sep 4, 2025
f499067
fix: always cast tensors to CPU before sending through msgpack
qthequartermasterman Sep 4, 2025
b3ae070
style: reorder prompt_embeds in OpenAI spec so that they appear in ex…
qthequartermasterman Sep 4, 2025
4fc6454
style: remove unnecessary TODO comment
qthequartermasterman Sep 4, 2025
6d94a14
style: remove unnecessary TODO comment
qthequartermasterman Sep 4, 2025
3d9d400
style: remove unnecessary TODO comment
qthequartermasterman Sep 4, 2025
a23bc88
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Sep 4, 2025
e236802
fix: avoid slow NCCL initialization by using unitialized CPU prompt e…
qthequartermasterman Sep 5, 2025
6574014
test: add prompt_embeds + tensor_parallel correctness tests
qthequartermasterman Sep 5, 2025
ce04ea0
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Sep 5, 2025
6891ca6
Update vllm/v1/core/sched/output.py
qthequartermasterman Sep 5, 2025
418cbcf
Update vllm/v1/core/sched/output.py
qthequartermasterman Sep 5, 2025
28b32cd
refactor: :rename length_from_prompt_token_ids_or_prompt_embeds to le…
qthequartermasterman Sep 5, 2025
020a152
style: undo accidental indentation change
qthequartermasterman Sep 5, 2025
edf5f40
refactor: use is_token_ids in input batch instead of is_prompt_embeds
qthequartermasterman Sep 5, 2025
425a21d
refactor: use is_token_ids in model_runner instead of is_prompt_embeds
qthequartermasterman Sep 5, 2025
75d9f89
refactor: do not cast token_indices to tensor multiple times)
qthequartermasterman Sep 5, 2025
9bfc0f8
refactor: rename CpuGpuBuffer args to size
qthequartermasterman Sep 5, 2025
ca4abfc
refactor: coalesce CpuGpuBuffer back into one class with a conditiona…
qthequartermasterman Sep 5, 2025
b870980
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Sep 8, 2025
a037619
refactor: refactor copying from cpu to gpu into _prepare_input_ids to…
qthequartermasterman Sep 8, 2025
fcec0c3
fix: remove unnecessary GPU copy
qthequartermasterman Sep 8, 2025
f055dda
refactor: avoid instantiating giant tensor on the CPU and OOMing
qthequartermasterman Sep 8, 2025
5559267
Update vllm/v1/worker/gpu_input_batch.py
qthequartermasterman Sep 9, 2025
17980f0
chore: forbid prompt logprobs and prompt embeds in a request
qthequartermasterman Sep 9, 2025
cd40fe2
feat: do not allow prompt logprobs when using prompt embeds
qthequartermasterman Sep 9, 2025
912963c
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Sep 9, 2025
67f3fee
style: remove unnecessary TODO
qthequartermasterman Sep 9, 2025
3f98ed6
Update vllm/v1/worker/gpu_model_runner.py
qthequartermasterman Sep 10, 2025
c7b2549
style: remove unnecessary comment
qthequartermasterman Sep 10, 2025
fcb0b64
test: add prompt_embeds to data structures that expect them
qthequartermasterman Sep 11, 2025
9e7b78c
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Sep 11, 2025
3844fc6
test: fix missing arguments in signature
qthequartermasterman Sep 11, 2025
77dd4cd
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Sep 11, 2025
46fea2c
test: add missing arguments
qthequartermasterman Sep 11, 2025
9dca231
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Sep 12, 2025
5e70006
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Sep 12, 2025
fe66418
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Sep 12, 2025
5a7ebb8
test: require new process for model tests because tensor parallel can…
qthequartermasterman Sep 12, 2025
609c256
refactor: remove unnecessary tensor parallel tests. #suppress-api-com…
qthequartermasterman Sep 12, 2025
680973d
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Sep 13, 2025
a98e105
test: use kwargs for clarity
qthequartermasterman Sep 13, 2025
2f7f5b2
fix: remove vestigial inputbatch attributes
qthequartermasterman Sep 15, 2025
20117fa
fix: remove vestigial inputbatch attributes
qthequartermasterman Sep 15, 2025
5f74fd3
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Sep 15, 2025
b50c414
test: remove now-unnecessary flag in prompt embeds tests
qthequartermasterman Sep 15, 2025
f94fc4a
refactor: make prompt_embeds have a default value of None to avoid a …
qthequartermasterman Sep 16, 2025
4733d9d
chore: remove now vestigial type narrowing assert
qthequartermasterman Sep 16, 2025
0e8be37
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Sep 16, 2025
8960f3a
refactor: make prompt_embeds have a default value of None to avoid a …
qthequartermasterman Sep 16, 2025
9fe8244
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Sep 16, 2025
403dad2
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Sep 16, 2025
56b136c
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Sep 16, 2025
78357ef
test: revert changes to match main and reduce diff
qthequartermasterman Sep 16, 2025
c8cc1ff
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Sep 17, 2025
aca498d
revert changes independently included in #25077
qthequartermasterman Sep 17, 2025
ecbb5b0
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Sep 17, 2025
6dad7a8
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Sep 18, 2025
1604a57
Merge branch 'main' into enable-prompt-embeds-in-v1
qthequartermasterman Sep 18, 2025
10 changes: 0 additions & 10 deletions tests/basic_correctness/test_basic_correctness.py
@@ -76,11 +76,6 @@ def test_models(
model_executor: str,
enable_prompt_embeds: bool,
) -> None:

if enable_prompt_embeds and envs.is_set(
"VLLM_USE_V1") and envs.VLLM_USE_V1:
pytest.skip("enable_prompt_embeds is not supported in v1.")

if not envs.VLLM_USE_V1:
if async_scheduling:
pytest.skip("async_scheduling only supported in v1.")
@@ -164,11 +159,6 @@ def test_models_distributed(
extra_env: dict[str, str],
enable_prompt_embeds: bool,
) -> None:

if enable_prompt_embeds and envs.is_set(
"VLLM_USE_V1") and envs.VLLM_USE_V1:
pytest.skip("enable_prompt_embeds is not supported in v1.")

if test_suite != TARGET_TEST_SUITE:
pytest.skip(f"Skip test for {test_suite}")

@@ -36,7 +36,6 @@ def default_server_args() -> list[str]:
"--enforce-eager",
# Prompt Embeds server args
"--enable-prompt-embeds",
"--no-enable-chunked-prefill",
]


6 changes: 0 additions & 6 deletions tests/models/language/generation/test_common.py
@@ -125,12 +125,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
# in parts of the operators
pytest.skip(f"Skipping '{model}' model test with AITER kernel.")

# Note: can be removed when
# https://github.com/vllm-project/vllm/pull/24278 finished
if current_platform.is_cpu() and use_prompt_embeds:
pytest.skip("Skipping use_prompt_embeds=True with "
"V1-only CPU backend.")

with hf_runner(model) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
24 changes: 18 additions & 6 deletions vllm/engine/arg_utils.py
@@ -1513,12 +1513,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
recommend_to_remove=False)
return False

# No text embedding inputs so far.
if self.enable_prompt_embeds:
_raise_or_fallback(feature_name="--enable-prompt-embeds",
recommend_to_remove=False)
return False

# No Mamba or Encoder-Decoder so far.
if not model_config.is_v1_compatible:
_raise_or_fallback(feature_name=model_config.architectures,
@@ -1651,6 +1645,13 @@ def _set_default_args_v0(self, model_config: ModelConfig) -> None:
"models in V0 and has been disabled.")
self.enable_prefix_caching = False

if self.enable_prompt_embeds:
logger.warning(
"--enable-prompt-embeds and --enable-prefix-caching "
"are not supported together in V0. Prefix caching has "
"been disabled.")
self.enable_prefix_caching = False

# Set max_num_seqs to 256 for VLLM_V0.
if self.max_num_seqs is None:
self.max_num_seqs = 256
@@ -1664,6 +1665,17 @@ def _set_default_args_v1(self, usage_context: UsageContext,
# For pooling tasks the default is False
if model_config.runner_type != "pooling":
self.enable_chunked_prefill = True

# TODO: When prefix caching supports prompt embeds inputs, this
# check can be removed.
if (self.enable_prompt_embeds
and self.enable_prefix_caching is not False):
logger.warning(
"--enable-prompt-embeds and --enable-prefix-caching "
"are not supported together in V1. Prefix caching has "
"been disabled.")
self.enable_prefix_caching = False

if self.enable_prefix_caching is None:
self.enable_prefix_caching = True
else:
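Taken together, the two guards above mean that prefix caching is disabled with a warning whenever prompt embeds are enabled, in both V0 and V1, while the V1 oracle no longer rejects --enable-prompt-embeds outright. A minimal sketch of the resulting behavior, assuming the usual EngineArgs entry point; the model name and flag combination are illustrative only:

from vllm.engine.arg_utils import EngineArgs

# Prompt embeds requested together with prefix caching: the request stays on
# the V1 engine, but prefix caching is switched off with a warning, since
# prefix-cache keys are derived from token ids that embeds requests lack.
args = EngineArgs(
    model="facebook/opt-125m",  # illustrative model only
    enable_prompt_embeds=True,
    enable_prefix_caching=True,
)
config = args.create_engine_config()
assert config.cache_config.enable_prefix_caching is False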
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/protocol.py
@@ -973,7 +973,6 @@ class CompletionRequest(OpenAIBaseModel):
# https://platform.openai.com/docs/api-reference/completions/create
model: Optional[str] = None
prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None
best_of: Optional[int] = None
echo: Optional[bool] = False
frequency_penalty: Optional[float] = 0.0
@@ -1009,6 +1008,7 @@ class CompletionRequest(OpenAIBaseModel):
# --8<-- [end:completion-sampling-params]

# --8<-- [start:completion-extra-params]
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None
add_special_tokens: bool = Field(
default=True,
description=(
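Relocating prompt_embeds into the extra-params block leaves the wire format unchanged: the field still carries base64-encoded bytes produced by torch.save. A minimal client sketch against a server started with --enable-prompt-embeds; the model name, tensor shape, and empty prompt field are assumptions modeled on the existing prompt-embeds example:

import base64
import io

import torch
from openai import OpenAI

# Serialize a [seq_len, hidden_size] embedding tensor the way the endpoint
# expects: torch.save into a buffer, then base64-encode for the JSON body.
prompt_embeds = torch.randn(16, 4096, dtype=torch.float16)
buffer = io.BytesIO()
torch.save(prompt_embeds, buffer)
encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="my-served-model",  # placeholder served model name
    prompt="",                # left empty; the embeds define the prompt
    max_tokens=16,
    extra_body={"prompt_embeds": encoded},
)
print(completion.choices[0].text)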
27 changes: 27 additions & 0 deletions vllm/utils/__init__.py
@@ -3443,3 +3443,30 @@ def decorate_logs(process_name: Optional[str] = None) -> None:
pid = os.getpid()
_add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid)


def length_from_prompt_token_ids_or_embeds(
prompt_token_ids: Optional[list[int]],
prompt_embeds: Optional[torch.Tensor],
) -> int:
"""Calculate the request length (in number of tokens) give either
prompt_token_ids or prompt_embeds.
"""
prompt_token_len = None if prompt_token_ids is None else len(
prompt_token_ids)
prompt_embeds_len = \
None if prompt_embeds is None else len(prompt_embeds)

if prompt_token_len is None:
if prompt_embeds_len is None:
raise ValueError(
"Neither prompt_token_ids nor prompt_embeds were defined.")
return prompt_embeds_len
else:
if (prompt_embeds_len is not None
and prompt_embeds_len != prompt_token_len):
raise ValueError(
"Prompt token ids and prompt embeds had different lengths"
f" prompt_token_ids={prompt_token_len}"
f" prompt_embeds={prompt_embeds_len}")
return prompt_token_len
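A quick illustration of the helper's contract (shapes and values chosen arbitrarily):

import torch

from vllm.utils import length_from_prompt_token_ids_or_embeds

length_from_prompt_token_ids_or_embeds([1, 2, 3], None)                  # -> 3
length_from_prompt_token_ids_or_embeds(None, torch.zeros(5, 4096))       # -> 5
length_from_prompt_token_ids_or_embeds([1, 2, 3], torch.zeros(3, 4096))  # -> 3 (lengths agree)
length_from_prompt_token_ids_or_embeds(None, None)                 # raises ValueError
length_from_prompt_token_ids_or_embeds([1, 2], torch.zeros(3, 4))  # raises ValueError (mismatch)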
24 changes: 18 additions & 6 deletions vllm/v1/core/sched/output.py
@@ -11,6 +11,7 @@
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
import torch

from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata)
@@ -26,13 +27,14 @@
class NewRequestData:

req_id: str
prompt_token_ids: list[int]
prompt_token_ids: Optional[list[int]]

Check warning (GitHub Actions / bc_lint) on line 30: Function NewRequestData: prompt_token_ids changed from list[int] to Optional[list[int]]
mm_features: list[MultiModalFeatureSpec]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
block_ids: tuple[list[int], ...]
num_computed_tokens: int
lora_request: Optional[LoRARequest]
prompt_embeds: Optional[torch.Tensor] = None

@classmethod
def from_request(
@@ -49,29 +51,39 @@
block_ids=block_ids,
num_computed_tokens=request.num_computed_tokens,
lora_request=request.lora_request,
prompt_embeds=request.prompt_embeds,
)

def __repr__(self):
def __repr__(self) -> str:
prompt_embeds_shape = (self.prompt_embeds.shape
if self.prompt_embeds else None)
return (f"NewRequestData("
f"req_id={self.req_id},"
f"prompt_token_ids={self.prompt_token_ids},"
f"mm_features={self.mm_features},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"lora_request={self.lora_request}"
f"lora_request={self.lora_request},"
f"prompt_embeds_shape={prompt_embeds_shape}"
")")

# Version of __repr__ with the prompt data obfuscated
def anon_repr(self):
def anon_repr(self) -> str:
prompt_token_ids_len = len(
self.prompt_token_ids
) if self.prompt_token_ids is not None else None
prompt_embeds_shape = (self.prompt_embeds.shape
if self.prompt_embeds else None)
return (f"NewRequestData("
f"req_id={self.req_id},"
f"prompt_token_ids_len={len(self.prompt_token_ids)},"
f"prompt_token_ids_len={prompt_token_ids_len},"
f"mm_features={self.mm_features},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"lora_request={self.lora_request}"
f"lora_request={self.lora_request},"
f"prompt_embeds_shape={prompt_embeds_shape}"
")")


3 changes: 2 additions & 1 deletion vllm/v1/engine/__init__.py
@@ -47,7 +47,7 @@ class EngineCoreRequest(
gc=False): # type: ignore[call-arg]

request_id: str
prompt_token_ids: list[int]
prompt_token_ids: Optional[list[int]]
mm_features: Optional[list[MultiModalFeatureSpec]]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
@@ -56,6 +56,7 @@ class EngineCoreRequest(
lora_request: Optional[LoRARequest]
cache_salt: Optional[str]
data_parallel_rank: Optional[int]
prompt_embeds: Optional[torch.Tensor] = None

# Index of the client, used to ensure outputs are sent back to the same
# client for this request when scaling out the front-end.
31 changes: 21 additions & 10 deletions vllm/v1/engine/detokenizer.py
@@ -13,6 +13,7 @@
from vllm.logger import init_logger
from vllm.transformers_utils.detokenizer_utils import (
AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest

logger = init_logger(__name__)
@@ -179,11 +180,12 @@ def __init__(self, tokenizer: PreTrainedTokenizerFast,
self.tokenizer: Tokenizer = tokenizer._tokenizer

# Find a safe place to start.
prompt_suffix = request.prompt_token_ids
prompt_token_ids = request.prompt_token_ids or []
prompt_suffix = prompt_token_ids
prompt_len = len(prompt_suffix)
if prompt_len > 4:
for i in range(4, min(prompt_len + 1, 24)):
suffix = request.prompt_token_ids[-i:]
suffix = prompt_token_ids[-i:]
if '�' not in self.tokenizer.decode(suffix):
prompt_suffix = suffix
break
@@ -260,16 +262,25 @@ def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
params = request.sampling_params
assert params is not None

self.prompt_len = length_from_prompt_token_ids_or_embeds(
request.prompt_token_ids, request.prompt_embeds)

# Metadata for incremental detokenization.
self.tokens, self.prefix_offset, self.read_offset = (
convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
prompt_ids=request.prompt_token_ids,
skip_special_tokens=params.skip_special_tokens,
))
if request.prompt_token_ids is not None:
self.tokens, self.prefix_offset, self.read_offset = (
convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
prompt_ids=request.prompt_token_ids,
skip_special_tokens=params.skip_special_tokens,
))
else:
# Prompt embedding requests cannot be detokenized, in general.
self.tokens = [""] * self.prompt_len
self.prefix_offset = 0
self.read_offset = 0

self.token_ids.extend(request.prompt_token_ids)
self.prompt_len = len(request.prompt_token_ids)
self.token_ids.extend(request.prompt_token_ids
or [0] * self.prompt_len)

self.skip_special_tokens = params.skip_special_tokens
self.spaces_between_special_tokens = (
25 changes: 20 additions & 5 deletions vllm/v1/engine/output_processor.py
@@ -14,6 +14,7 @@
from vllm.tracing import (SpanAttributes, SpanKind, Tracer,
extract_trace_context)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
@@ -86,7 +87,8 @@ def __init__(
lora_name: Optional[str],
output_kind: RequestOutputKind,
prompt: Optional[str],
prompt_token_ids: list[int],
prompt_token_ids: Optional[list[int]],
prompt_embeds: Optional[torch.Tensor],
logprobs_processor: Optional[LogprobsProcessor],
detokenizer: Optional[IncrementalDetokenizer],
max_tokens_param: Optional[int],
@@ -104,7 +106,9 @@
self.output_kind = output_kind
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.prompt_len = len(prompt_token_ids)
self.prompt_embeds = prompt_embeds
self.prompt_len = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds)
self.logprobs_processor = logprobs_processor
self.detokenizer = detokenizer
self.max_tokens_param = max_tokens_param
@@ -165,6 +169,7 @@ def from_new_request(
output_kind=output_kind,
prompt=prompt,
prompt_token_ids=request.prompt_token_ids,
prompt_embeds=request.prompt_embeds,
logprobs_processor=logprobs_processor,
detokenizer=detokenizer,
max_tokens_param=max_tokens_param,
@@ -223,6 +228,8 @@ def _new_request_output(
first_output = outputs[0]
if isinstance(first_output, PoolingOutput):
assert len(outputs) == 1
# Prompt embeddings are currently not supported by pooling requests.
assert self.prompt_token_ids is not None
return PoolingRequestOutput(
request_id=request_id,
outputs=first_output,
@@ -236,10 +243,15 @@
else:
prompt_logprobs = self.logprobs_processor.prompt_logprobs

# If prompt embeds were used, put placeholder prompt token ids
prompt_token_ids = self.prompt_token_ids
if prompt_token_ids is None and self.prompt_embeds is not None:
prompt_token_ids = [0] * len(self.prompt_embeds)

return RequestOutput(
request_id=request_id,
prompt=self.prompt,
prompt_token_ids=self.prompt_token_ids,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=prompt_logprobs,
outputs=cast(list[CompletionOutput], outputs),
finished=finished,
@@ -469,6 +481,8 @@ def do_tracing(self, engine_core_output: EngineCoreOutput,

arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
trace_context = extract_trace_context(engine_core_output.trace_headers)
prompt_length = length_from_prompt_token_ids_or_embeds(
req_state.prompt_token_ids, req_state.prompt_embeds)
with (self.tracer.start_as_current_span(
"llm_request",
kind=SpanKind.SERVER,
@@ -488,7 +502,7 @@
span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
queued_time)
span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
len(req_state.prompt_token_ids))
prompt_length)
span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
metrics.num_generation_tokens)
span.set_attribute(
@@ -544,7 +558,8 @@ def _update_stats_from_finished(self, req_state: RequestState,
assert req_state.stats is not None
iteration_stats.update_from_finished_request(
finish_reason=finish_reason,
num_prompt_tokens=len(req_state.prompt_token_ids),
num_prompt_tokens=length_from_prompt_token_ids_or_embeds(
req_state.prompt_token_ids, req_state.prompt_embeds),
max_tokens_param=req_state.max_tokens_param,
req_stats=req_state.stats)
self.lora_states.finish_request(req_state)
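With these pieces in place, a prompt-embeds request runs end to end on V1 and comes back with placeholder prompt token ids (one 0 per embedded position), since the request never had real token ids. A sketch of offline usage, assuming embeddings are taken from the Hugging Face model's own input embedding layer as in vLLM's existing prompt-embeds example; the model name and prompt are illustrative:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm import LLM, SamplingParams

model_name = "facebook/opt-125m"  # illustrative model only

# Build prompt embeds from the HF model's input embedding layer.
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)
token_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
with torch.no_grad():
    prompt_embeds = hf_model.get_input_embeddings()(token_ids).squeeze(0)

llm = LLM(model=model_name, enable_prompt_embeds=True)
outputs = llm.generate({"prompt_embeds": prompt_embeds},
                       SamplingParams(max_tokens=8))

# The returned prompt_token_ids are placeholders (all zeros) because the
# request was defined by embeddings rather than token ids.
print(outputs[0].prompt_token_ids)
print(outputs[0].outputs[0].text)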