
[Performance]: vllm 0.5.4 with enable_chunked_prefill=True, throughput is slightly lower than 0.5.3~0.5.0. #7592


Description

@noooop

Your current environment


Hardware & Nvidia driver & OS

GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4090

Nvidia driver version: 550.107.02

OS: Ubuntu 22.04.4 LTS (x86_64)

conda env create -f environment_linux.yml

name: vllm_xxx
channels:
  - conda-forge
  - pytorch
  - nvidia
  - defaults
dependencies:
  - python=3.11
  - anaconda
  - pip
  - pip:
      - easydict
  
Each tested version was installed with pip install vllm==<version> (0.5.4 / 0.5.3.post1 / 0.5.2 / 0.5.1).
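As a quick sanity check of what each environment actually resolved to, this minimal sketch prints the versions in play (output obviously depends on the local install):

import torch
import vllm

# Confirm which vLLM / PyTorch / CUDA builds the conda env actually resolved to,
# so each result row can be matched to the right release.
print("vllm :", vllm.__version__)
print("torch:", torch.__version__, "cuda", torch.version.cuda)
print("gpu  :", torch.cuda.get_device_name(0))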

🐛 Describe the bug

Throughput and latency vary with max_num_seqs.

With enable_chunked_prefill=True, throughput on vLLM 0.5.4 is slightly lower than on 0.5.3 / 0.5.2 / 0.5.1, which is very strange.

Benchmark results. The first number in each pair is requests/s, the second is the Delay in ms printed by the script below.

enable_chunked_prefill = False

max_num_seqs   0.5.4 req/s  delay ms   0.5.3 req/s  delay ms   0.5.2 req/s  delay ms   0.5.1 req/s  delay ms
1024           2.5115       36.10      2.4747       36.69      2.4878       36.82      2.4972       36.33
768            3.2838       42.60      3.2242       43.48      3.2608       42.93      3.2648       42.88
512            4.1063       51.93      4.0102       53.22      4.0966       52.07      4.0998       51.97
384            4.1705       54.83      4.0749       56.35      4.1538       55.07      4.1587       55.00
256            4.3613       59.37      4.2659       60.94      4.3586       59.44      4.3632       59.36
128            4.7441       42.75      4.6511       43.76      4.7564       42.62      4.7583       42.58
64             3.7161       29.56      3.6446       30.20      3.6972       29.71      3.7044       29.66
32             2.6923       20.76      2.6465       21.16      2.6702       20.94      2.6735       20.92

enable_chunked_prefill = True

max_num_seqs   0.5.4 req/s  delay ms   0.5.3 req/s  delay ms   0.5.2 req/s  delay ms   0.5.1 req/s  delay ms
1024           3.1797       39.72      3.1341       40.15      3.1696       39.67      3.1745       39.61
768            4.1047       43.40      3.9708       44.42      4.0413       43.60      4.0439       43.57
512            4.6486       46.79      4.6377       46.25      4.7419       45.15      4.7470       45.13
384            4.7980       49.66      4.8605       45.93      4.9834       44.72      5.0019       44.61
256            4.9096       48.27      5.4424       42.23      5.5876       41.10      5.5949       41.04
128            4.2047       32.77      4.0718       29.70      4.1528       29.11      4.1544       29.10
64             2.4350       26.88      2.3879       25.74      2.4141       25.45      2.4175       25.41
32             1.6103       19.83      1.5846       19.55      1.5930       19.44      1.5942       19.43
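To make the regression easier to see, here is a small sketch that recomputes the relative throughput change of 0.5.4 versus 0.5.3 from the enable_chunked_prefill=True rows above (the dictionaries just copy the requests/s values from the table):

# requests/s with enable_chunked_prefill=True, copied from the table above
v0_5_4 = {1024: 3.1797, 768: 4.1047, 512: 4.6486, 384: 4.7980,
          256: 4.9096, 128: 4.2047, 64: 2.4350, 32: 1.6103}
v0_5_3 = {1024: 3.1341, 768: 3.9708, 512: 4.6377, 384: 4.8605,
          256: 5.4424, 128: 4.0718, 64: 2.3879, 32: 1.5846}

for n in sorted(v0_5_4, reverse=True):
    change = (v0_5_4[n] - v0_5_3[n]) / v0_5_3[n] * 100
    print(f"max_num_seqs={n:4d}: {change:+.1f}% vs 0.5.3")

The gap is concentrated around max_num_seqs=256 (roughly -10%) and, to a lesser extent, 384; at the other settings 0.5.4 is on par or slightly faster.

The full benchmark script used to produce the table: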
import os
import random
import numpy as np
import time


def benchmark(args):
    random.seed(args.seed)

    os.environ["VLLM_LOGGING_LEVEL"] = "ERROR"
    os.environ["VLLM_NO_USAGE_STATS"] = "True"

    import vllm
    from vllm import LLMEngine, EngineArgs, SamplingParams, TextPrompt

    print(vllm.__version__)

    engine_args = EngineArgs(
        model=args.model,
        tokenizer=args.tokenizer,
        quantization=args.quantization,
        tensor_parallel_size=args.tensor_parallel_size,
        seed=args.seed,
        trust_remote_code=args.trust_remote_code,
        dtype=args.dtype,
        max_model_len=args.max_model_len,
        gpu_memory_utilization=args.gpu_memory_utilization,
        enforce_eager=args.enforce_eager,
        kv_cache_dtype=args.kv_cache_dtype,
        quantization_param_path=args.quantization_param_path,
        device=args.device,
        enable_prefix_caching=args.enable_prefix_caching,
        download_dir=args.download_dir,
        enable_chunked_prefill=args.enable_chunked_prefill,
        max_num_batched_tokens=args.max_num_batched_tokens,
        max_num_seqs=args.max_num_seqs,
        distributed_executor_backend=args.distributed_executor_backend,
        disable_log_stats=True
    )
    engine = LLMEngine.from_engine_args(engine_args)

    prompt = "hi" * (args.input_len - 1)
    requests = [(prompt, args.input_len, args.output_len)
                for _ in range(args.num_prompts)]

    start = time.perf_counter()
    for request_id, (prompt, _, output_len) in enumerate(requests):
        inputs = TextPrompt(prompt=prompt)
        sampling_params = SamplingParams(
            n=args.n,
            temperature=0.0 if args.use_beam_search else 1.0,
            top_p=1.0,
            use_beam_search=args.use_beam_search,
            ignore_eos=True,
            max_tokens=output_len,
        )
        engine.add_request(str(request_id), inputs, sampling_params)

    # Step the engine manually until all requests finish, recording a timestamp
    # per step so per-request inter-step delays can be computed afterwards.
    out = []
    while engine.has_unfinished_requests():
        request_outputs = engine.step()
        out.append((time.perf_counter(), request_outputs))
    end = time.perf_counter()

    # Group step timestamps by request id; the mean gap between consecutive
    # steps that produced output for a request is reported below as Delay (tpot).
    timestamp = {}
    for t, rs in out:
        for r in rs:
            request_id = r.request_id
            if request_id not in timestamp:
                timestamp[request_id] = []
            timestamp[request_id].append(t)

    tpot = []
    for v in timestamp.values():
        dd = [v[i]-v[i-1] for i in range(1, len(v))]
        tpot.extend(dd)

    tpot = np.mean(tpot)
    elapsed_time = end - start

    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len in requests)

    print(f"Throughput: {len(requests) / elapsed_time:.4f} requests/s, "
          f"{total_num_tokens / elapsed_time:.4f} tokens/s, "
          f"Delay {tpot*1000:0.2f} ms")


if __name__ == '__main__':
    from easydict import EasyDict as edict
    args = edict()

    args.dataset = None
    args.input_len = 512
    args.output_len = 512

    args.model = "Qwen/Qwen2-7B-Instruct"
    args.trust_remote_code = False
    args.tokenizer = args.model
    args.quantization = None
    args.quantization_param_path = None
    args.tensor_parallel_size = 1
    args.seed = 0
    args.n = 1
    args.use_beam_search = False
    args.num_prompts = 1000
    args.dtype = 'auto'
    args.max_model_len = 10000
    args.enforce_eager = True
    args.kv_cache_dtype = "auto"
    args.device = "cuda"
    args.enable_prefix_caching = False
    args.gpu_memory_utilization = 0.9
    args.output_json = None
    args.distributed_executor_backend = None
    args.download_dir = None

    import sys
    from concurrent.futures import ProcessPoolExecutor

    def run(args):
        # Run each configuration in a fresh subprocess so GPU memory and
        # engine state are fully released between benchmark runs.
        with ProcessPoolExecutor(1) as executor:
            f = executor.submit(benchmark, args)
            f.result()

    max_num_seqs_list = [1024, 768, 512, 384, 256, 128, 64, 32]

    print()
    print("enable_chunked_prefill = False")
    for max_num_seqs in max_num_seqs_list:
        print("max_num_seqs", max_num_seqs)
        args.enable_chunked_prefill = False
        args.max_num_batched_tokens = None
        args.max_num_seqs = max_num_seqs
        run(args)

    print()
    print("enable_chunked_prefill = True")
    for max_num_seqs in max_num_seqs_list:
        print("max_num_seqs", max_num_seqs)
        args.enable_chunked_prefill = True
        args.max_num_seqs = max_num_seqs
        # With chunked prefill enabled, the per-step token budget is capped at the
        # same value as max_num_seqs, so small settings also shrink the prefill chunks.
        args.max_num_batched_tokens = args.max_num_seqs
        run(args)
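For anyone who wants a smaller entry point than the sweep above, here is a minimal standalone sketch of the worst-affected configuration (max_num_seqs=256 with chunked prefill enabled), using the offline LLM API of vLLM 0.5.x. It mirrors the settings of the script; it is not a separate measurement:

from vllm import LLM, SamplingParams

# Same settings as the max_num_seqs=256 / enable_chunked_prefill=True row;
# extra keyword arguments are forwarded to EngineArgs by the LLM constructor.
llm = LLM(
    model="Qwen/Qwen2-7B-Instruct",
    max_model_len=10000,
    enforce_eager=True,
    gpu_memory_utilization=0.9,
    enable_chunked_prefill=True,
    max_num_seqs=256,
    max_num_batched_tokens=256,
)

prompt = "hi" * 511  # ~512-token synthetic prompt, as in the benchmark
params = SamplingParams(temperature=1.0, top_p=1.0, ignore_eos=True, max_tokens=512)
outputs = llm.generate([prompt] * 1000, params)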
