Labels: bug (Something isn't working)
Your current environment
Hardware & Nvidia driver & OS
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4090
Nvidia driver version: 550.107.02
OS: Ubuntu 22.04.4 LTS (x86_64)
The environment was created with `conda env create -f environment_linux.yml`:

```yaml
name: vllm_xxx
channels:
  - conda-forge
  - pytorch
  - nvidia
  - defaults
dependencies:
  - python=3.11
  - anaconda
  - pip
  - pip:
      - easydict
```

vLLM was installed with `pip install vllm==0.5.4` (and 0.5.3.post1, 0.5.2, 0.5.1 for the comparison runs).
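For completeness, the version actually loaded can be checked with a one-liner (the benchmark script below prints the same value at startup):

```python
# Print the installed vLLM version; the repro script below does the same.
import vllm
print(vllm.__version__)
```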
🐛 Describe the bug
Throughput and latency vary with max_num_seqs.

With enable_chunked_prefill=True, throughput on vLLM 0.5.4 is slightly lower than on 0.5.1–0.5.3, which is very strange. The gap is clearest around max_num_seqs = 256–384 (e.g. 4.91 req/s on 0.5.4 vs. 5.59 req/s on 0.5.1 at max_num_seqs = 256).

In the table below, "disabled"/"enabled" refer to enable_chunked_prefill; "req/s" is the measured throughput and "delay (ms)" is the mean time between output tokens ("Delay" in the script output).

| max_num_seqs | 0.5.4 disabled: req/s | 0.5.4 disabled: delay (ms) | 0.5.3 disabled: req/s | 0.5.3 disabled: delay (ms) | 0.5.2 disabled: req/s | 0.5.2 disabled: delay (ms) | 0.5.1 disabled: req/s | 0.5.1 disabled: delay (ms) | 0.5.4 enabled: req/s | 0.5.4 enabled: delay (ms) | 0.5.3 enabled: req/s | 0.5.3 enabled: delay (ms) | 0.5.2 enabled: req/s | 0.5.2 enabled: delay (ms) | 0.5.1 enabled: req/s | 0.5.1 enabled: delay (ms) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1024 | 2.5115 | 36.1 | 2.4747 | 36.69 | 2.4878 | 36.82 | 2.4972 | 36.33 | 3.1797 | 39.72 | 3.1341 | 40.15 | 3.1696 | 39.67 | 3.1745 | 39.61 |
| 768 | 3.2838 | 42.6 | 3.2242 | 43.48 | 3.2608 | 42.93 | 3.2648 | 42.88 | 4.1047 | 43.4 | 3.9708 | 44.42 | 4.0413 | 43.6 | 4.0439 | 43.57 |
| 512 | 4.1063 | 51.93 | 4.0102 | 53.22 | 4.0966 | 52.07 | 4.0998 | 51.97 | 4.6486 | 46.79 | 4.6377 | 46.25 | 4.7419 | 45.15 | 4.747 | 45.13 |
| 384 | 4.1705 | 54.83 | 4.0749 | 56.35 | 4.1538 | 55.07 | 4.1587 | 55 | 4.798 | 49.66 | 4.8605 | 45.93 | 4.9834 | 44.72 | 5.0019 | 44.61 |
| 256 | 4.3613 | 59.37 | 4.2659 | 60.94 | 4.3586 | 59.44 | 4.3632 | 59.36 | 4.9096 | 48.27 | 5.4424 | 42.23 | 5.5876 | 41.1 | 5.5949 | 41.04 |
| 128 | 4.7441 | 42.75 | 4.6511 | 43.76 | 4.7564 | 42.62 | 4.7583 | 42.58 | 4.2047 | 32.77 | 4.0718 | 29.7 | 4.1528 | 29.11 | 4.1544 | 29.1 |
| 64 | 3.7161 | 29.56 | 3.6446 | 30.2 | 3.6972 | 29.71 | 3.7044 | 29.66 | 2.435 | 26.88 | 2.3879 | 25.74 | 2.4141 | 25.45 | 2.4175 | 25.41 |
| 32 | 2.6923 | 20.76 | 2.6465 | 21.16 | 2.6702 | 20.94 | 2.6735 | 20.92 | 1.6103 | 19.83 | 1.5846 | 19.55 | 1.593 | 19.44 | 1.5942 | 19.43 |
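The "delay (ms)" columns are the "Delay" value printed by the script: for each request, the wall-clock times at which engine.step() returned output for it are collected, consecutive timestamps are diffed, and the gaps are averaged over all requests (a mean time-per-output-token). A minimal sketch of that calculation, with made-up timestamps purely for illustration:

```python
import numpy as np

# Illustrative timestamps only: request_id -> times (s) at which engine.step()
# returned an output for that request.
timestamps = {"0": [0.000, 0.042, 0.081, 0.122], "1": [0.010, 0.052, 0.090]}

# Gaps between consecutive outputs of the same request, pooled over requests.
gaps = [t2 - t1 for ts in timestamps.values() for t1, t2 in zip(ts, ts[1:])]
print(f"Delay {np.mean(gaps) * 1000:.2f} ms")
```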
Repro script:

```python
import os
import random
import numpy as np
import time


def benchmark(args):
    random.seed(args.seed)
    os.environ["VLLM_LOGGING_LEVEL"] = "ERROR"
    os.environ["VLLM_NO_USAGE_STATS"] = "True"

    import vllm
    from vllm import LLMEngine, EngineArgs, SamplingParams, TextPrompt
    print(vllm.__version__)

    engine_args = EngineArgs(
        model=args.model,
        tokenizer=args.tokenizer,
        quantization=args.quantization,
        tensor_parallel_size=args.tensor_parallel_size,
        seed=args.seed,
        trust_remote_code=args.trust_remote_code,
        dtype=args.dtype,
        max_model_len=args.max_model_len,
        gpu_memory_utilization=args.gpu_memory_utilization,
        enforce_eager=args.enforce_eager,
        kv_cache_dtype=args.kv_cache_dtype,
        quantization_param_path=args.quantization_param_path,
        device=args.device,
        enable_prefix_caching=args.enable_prefix_caching,
        download_dir=args.download_dir,
        enable_chunked_prefill=args.enable_chunked_prefill,
        max_num_batched_tokens=args.max_num_batched_tokens,
        max_num_seqs=args.max_num_seqs,
        distributed_executor_backend=args.distributed_executor_backend,
        disable_log_stats=True,
    )
    engine = LLMEngine.from_engine_args(engine_args)

    # Synthetic workload: every request has the same input and output length.
    prompt = "hi" * (args.input_len - 1)
    requests = [(prompt, args.input_len, args.output_len)
                for _ in range(args.num_prompts)]

    start = time.perf_counter()
    for request_id, (prompt, _, output_len) in enumerate(requests):
        inputs = TextPrompt(prompt=prompt)
        sampling_params = SamplingParams(
            n=args.n,
            temperature=0.0 if args.use_beam_search else 1.0,
            top_p=1.0,
            use_beam_search=args.use_beam_search,
            ignore_eos=True,
            max_tokens=output_len,
        )
        engine.add_request(str(request_id), inputs, sampling_params)

    # Drive the engine manually and record a timestamp for every step's outputs.
    out = []
    while engine.has_unfinished_requests():
        request_outputs = engine.step()
        out.append((time.perf_counter(), request_outputs))
    end = time.perf_counter()

    # Collect, per request, the timestamps at which it produced output.
    timestamp = {}
    for t, rs in out:
        for r in rs:
            request_id = r.request_id
            if request_id not in timestamp:
                timestamp[request_id] = []
            timestamp[request_id].append(t)

    # Mean gap between consecutive outputs of the same request (reported as "Delay").
    tpot = []
    for v in timestamp.values():
        dd = [v[i] - v[i - 1] for i in range(1, len(v))]
        tpot.extend(dd)
    tpot = np.mean(tpot)

    elapsed_time = end - start
    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len in requests)
    print(f"Throughput: {len(requests) / elapsed_time:.4f} requests/s, "
          f"{total_num_tokens / elapsed_time:.4f} tokens/s, "
          f"Delay {tpot * 1000:0.2f} ms")


if __name__ == '__main__':
    from easydict import EasyDict as edict

    args = edict()
    args.dataset = None
    args.input_len = 512
    args.output_len = 512
    args.model = "Qwen/Qwen2-7B-Instruct"
    args.trust_remote_code = False
    args.tokenizer = args.model
    args.quantization = None
    args.quantization_param_path = None
    args.tensor_parallel_size = 1
    args.seed = 0
    args.n = 1
    args.use_beam_search = False
    args.num_prompts = 1000
    args.dtype = 'auto'
    args.max_model_len = 10000
    args.enforce_eager = True
    args.kv_cache_dtype = "auto"
    args.device = "cuda"
    args.enable_prefix_caching = False
    args.gpu_memory_utilization = 0.9
    args.output_json = None
    args.distributed_executor_backend = None
    args.download_dir = None

    import sys
    from concurrent.futures import ProcessPoolExecutor

    # Run each configuration in its own subprocess so the engine and CUDA
    # context are torn down between runs.
    def run(args):
        with ProcessPoolExecutor(1) as executor:
            f = executor.submit(benchmark, args)
            f.result()

    max_num_seqs_list = [1024, 768, 512, 384, 256, 128, 64, 32]

    print()
    print("enable_chunked_prefill = False")
    for max_num_seqs in max_num_seqs_list:
        print("max_num_seqs", max_num_seqs)
        args.enable_chunked_prefill = False
        args.max_num_batched_tokens = None
        args.max_num_seqs = max_num_seqs
        run(args)

    print()
    print("enable_chunked_prefill = True")
    for max_num_seqs in max_num_seqs_list:
        print("max_num_seqs", max_num_seqs)
        args.enable_chunked_prefill = True
        args.max_num_seqs = max_num_seqs
        args.max_num_batched_tokens = args.max_num_seqs
        run(args)
```
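If it helps with triage, a single point from the table can be rerun by replacing the two sweep loops at the bottom with the snippet below; it reuses the `args` and `run` defined above, and max_num_seqs = 256 is just the configuration where the 0.5.4 gap is largest. As in the sweep, the chunked-prefill runs tie max_num_batched_tokens to max_num_seqs.

```python
# Single configuration: chunked prefill enabled, max_num_seqs = 256.
# Reuses args/run from the script above; not a standalone script.
args.enable_chunked_prefill = True
args.max_num_seqs = 256
args.max_num_batched_tokens = args.max_num_seqs
run(args)
```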