Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions benchmarks/benchmark_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser


Expand Down Expand Up @@ -40,6 +41,20 @@ def main(args: argparse.Namespace):
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]

def llm_generate():
if not args.use_beam_search:
llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
else:
llm.beam_search(
dummy_prompts,
BeamSearchParams(
beam_width=args.n,
max_tokens=args.output_len,
ignore_eos=True,
))

def run_to_completion(profile_dir: Optional[str] = None):
if profile_dir:
with torch.profiler.profile(
Expand All @@ -49,15 +64,11 @@ def run_to_completion(profile_dir: Optional[str] = None):
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p:
llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
llm_generate()
print(p.key_averages().table(sort_by="self_cuda_time_total"))
else:
start_time = time.perf_counter()
llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
llm_generate()
end_time = time.perf_counter()
latency = end_time - start_time
return latency
Expand Down
10 changes: 6 additions & 4 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
parse_chat_messages,
resolve_chat_template_content_format)
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import (
Expand Down Expand Up @@ -457,7 +457,7 @@ def generate(

def beam_search(
self,
prompts: List[Union[str, List[int]]],
prompts: List[Union[TokensPrompt, TextPrompt]],
params: BeamSearchParams,
) -> List[BeamSearchOutput]:
"""
Expand Down Expand Up @@ -493,8 +493,10 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
instances: List[BeamSearchInstance] = []

for prompt in prompts:
prompt_tokens = prompt if isinstance(
prompt, list) else tokenizer.encode(prompt)
if is_token_prompt(prompt):
prompt_tokens = prompt["prompt_token_ids"]
else:
prompt_tokens = tokenizer.encode(prompt["prompt"])
instances.append(BeamSearchInstance(prompt_tokens))

for _ in range(max_tokens):
Expand Down
Loading