
Commit 3f942ac

Fix latency benchmark script (#118)
1 parent 19d2899 commit 3f942ac

2 files changed: +43 −31 lines

benchmark/benchmark_latency.py

Lines changed: 33 additions & 29 deletions

@@ -1,71 +1,75 @@
 import argparse
 import time
-from typing import List

-from tqdm import tqdm
 import numpy as np
 import torch
+from tqdm import tqdm

-from cacheflow.core.server import (
-    add_server_arguments, process_server_arguments,
-    init_local_server_and_frontend_with_arguments)
-from cacheflow.sampling_params import SamplingParams
+from cacheflow import LLM, SamplingParams


 def main(args: argparse.Namespace):
-    server, frontend = init_local_server_and_frontend_with_arguments(args)
+    print(args)
+
+    # Process all the requests in a single batch if possible.
+    # NOTE(woosuk): If the request cannot be processed in a single batch,
+    # the server will automatically process the request in multiple batches.
+    llm = LLM(
+        model=args.model,
+        tensor_parallel_size=args.tensor_parallel_size,
+        max_num_seqs=args.batch_size,
+        max_num_batched_tokens=args.batch_size * args.input_len,
+    )

     sampling_params = SamplingParams(
         n=args.n,
         temperature=0.0 if args.use_beam_search else 1.0,
         top_p=1.0,
         use_beam_search=args.use_beam_search,
-        stop_token_ids=set(),
+        ignore_eos=True,
         max_tokens=args.output_len,
     )
     print(sampling_params)
-    input_token_ids = [0] * args.input_len
+    dummy_prompts = [""] * args.batch_size
+    dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size

-    def profile_step(profile=False):
+    def run_to_completion(profile: bool = False):
         if profile:
             torch.cuda.cudart().cudaProfilerStart()
-        for _ in range(args.batch_size):
-            dummy_prompt = ""
-            frontend._add_query(dummy_prompt, input_token_ids, sampling_params)
-        server.add_sequence_groups(frontend.get_inputs())
         start_time = time.time()
-        while True:
-            server.step()
-            if not server.has_unfinished_requests():
-                break
+
+        llm.generate(dummy_prompts, sampling_params, dummy_prompt_token_ids,
+                     use_tqdm=False)
+
         end_time = time.time()
         latency = end_time - start_time
         if profile:
             torch.cuda.cudart().cudaProfilerStop()
         return latency

-    print("Warm up step")
-    profile_step()
+    print("Warming up...")
+    run_to_completion(profile=False)

     # Benchmark.
     latencies = []
-    for _ in tqdm(range(3), desc="Profile step"):
-        latencies.append(profile_step())
+    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
+        latencies.append(run_to_completion(profile=False))
     print(f'Avg latency: {np.mean(latencies)} seconds')


 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
-        description='Benchmark the latency of decoding a single sentence.')
-    parser = add_server_arguments(parser)
+        description='Benchmark the latency of processing a single batch of '
+                    'requests till completion.')
+    parser.add_argument('--model', type=str, default='facebook/opt-125m')
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)
     parser.add_argument('--batch-size', type=int, default=8)
-    parser.add_argument('--n', type=int, default=1)
+    parser.add_argument('--n', type=int, default=1,
+                        help='Number of generated sequences per prompt.')
     parser.add_argument('--use-beam-search', action='store_true')
+    parser.add_argument('--num-iters', type=int, default=3,
+                        help='Number of iterations to run.')
     args = parser.parse_args()
-    args = process_server_arguments(args)
-    args.max_num_batched_tokens = max(
-        args.max_num_batched_tokens, args.batch_size * args.input_len)
-    print(args)
     main(args)
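For quick reading, the new measurement loop condenses to the sketch below. The LLM and SamplingParams APIs, the ignore_eos and use_tqdm parameters, and the dummy-prompt trick all come from the diff above; the concrete sizes simply mirror the script's defaults and are illustrative, not prescriptive.

import time

import numpy as np
from cacheflow import LLM, SamplingParams

# Sizes mirror the script's defaults (illustrative only).
batch_size, input_len, output_len = 8, 32, 128

# Size the server so the whole batch can be processed in a single step.
llm = LLM(model="facebook/opt-125m",
          max_num_seqs=batch_size,
          max_num_batched_tokens=batch_size * input_len)
sampling_params = SamplingParams(temperature=1.0, top_p=1.0,
                                 ignore_eos=True, max_tokens=output_len)

# Dummy prompts and token ids keep tokenization out of the measured path.
dummy_prompts = [""] * batch_size
dummy_prompt_token_ids = [[0] * input_len] * batch_size

def run_to_completion() -> float:
    start = time.time()
    llm.generate(dummy_prompts, sampling_params, dummy_prompt_token_ids,
                 use_tqdm=False)
    return time.time() - start

run_to_completion()  # warm-up iteration, excluded from the average
latencies = [run_to_completion() for _ in range(3)]
print(f"Avg latency: {np.mean(latencies)} seconds")

Running `python benchmark/benchmark_latency.py` with the defaults shown in the argument parser above follows this same path.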

cacheflow/entrypoints/llm.py

Lines changed: 10 additions & 2 deletions

@@ -35,18 +35,26 @@ def generate(
         self,
         prompts: List[str],
         sampling_params: Optional[SamplingParams] = None,
+        prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
     ) -> List[RequestOutput]:
         if sampling_params is None:
+            # Use default sampling params.
             sampling_params = SamplingParams()
         # Initialize tqdm.
         if use_tqdm:
             pbar = tqdm(total=len(prompts), desc="Processed prompts")

         # Add requests to the server.
-        for prompt in prompts:
+        for i in range(len(prompts)):
+            prompt = prompts[i]
+            if prompt_token_ids is None:
+                token_ids = None
+            else:
+                token_ids = prompt_token_ids[i]
             request_id = str(next(self.request_counter))
-            self.llm_server.add_request(request_id, prompt, sampling_params)
+            self.llm_server.add_request(request_id, prompt, sampling_params,
+                                        token_ids)

         # Run the server.
         outputs: List[RequestOutput] = []
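A minimal usage sketch for the extended signature, assuming (as the benchmark script in this commit does) that pre-tokenized inputs can be passed alongside placeholder prompt strings, and assuming the remaining LLM constructor arguments fall back to defaults; the model name and sizes here are purely illustrative.

from cacheflow import LLM, SamplingParams

# Assumes the other constructor arguments have usable defaults.
llm = LLM(model="facebook/opt-125m")
sampling_params = SamplingParams(max_tokens=16, ignore_eos=True)

# One token-id list per prompt, in the same order as `prompts`.
prompts = [""] * 4
prompt_token_ids = [[0] * 32 for _ in range(4)]

outputs = llm.generate(prompts, sampling_params,
                       prompt_token_ids=prompt_token_ids,
                       use_tqdm=False)

Each token-id list is forwarded to add_request together with its prompt string, which is how the benchmark above feeds fixed-length dummy inputs, presumably bypassing tokenization of the empty prompt strings.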
