diff --git a/.buildkite/nightly-benchmarks/nightly-annotation.md b/.buildkite/nightly-benchmarks/nightly-annotation.md index e43ea765f155..ef11c040057c 100644 --- a/.buildkite/nightly-benchmarks/nightly-annotation.md +++ b/.buildkite/nightly-benchmarks/nightly-annotation.md @@ -16,7 +16,7 @@ Please download the visualization scripts in the post - Download `nightly-benchmarks.zip`. - In the same folder, run the following code: - ```console + ```bash export HF_TOKEN= apt update apt install -y git diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 16b5ad0297fe..55678b8936e0 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -102,6 +102,7 @@ steps: commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ." + - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest" - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)" env: DOCKER_BUILDKIT: "1" @@ -117,6 +118,7 @@ steps: commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest --progress plain -f docker/Dockerfile.neuron ." 
+ - "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest" - "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version)" env: DOCKER_BUILDKIT: "1" diff --git a/.buildkite/scripts/tpu/config_v6e_1.env b/.buildkite/scripts/tpu/config_v6e_1.env index 441758647347..03ec116f698d 100644 --- a/.buildkite/scripts/tpu/config_v6e_1.env +++ b/.buildkite/scripts/tpu/config_v6e_1.env @@ -4,8 +4,8 @@ CONTAINER_NAME=vllm-tpu # vllm config MODEL=meta-llama/Llama-3.1-8B-Instruct -MAX_NUM_SEQS=512 -MAX_NUM_BATCHED_TOKENS=512 +MAX_NUM_SEQS=256 +MAX_NUM_BATCHED_TOKENS=1024 TENSOR_PARALLEL_SIZE=1 MAX_MODEL_LEN=2048 DOWNLOAD_DIR=/mnt/disks/persist diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index fe775bb370f2..d6c9ee680abf 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -615,13 +615,16 @@ steps: - vllm/executor/ - vllm/model_executor/models/ - tests/distributed/ + - tests/examples/offline_inference/data_parallel.py commands: - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up) - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' + - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' + - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code - label: Distributed Tests (2 GPUs) # 40min mirror_hardwares: [amdexperimental] diff --git a/benchmarks/README.md b/benchmarks/README.md index 6f9fbb91cbd9..2714b8b49821 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -269,6 +269,21 @@ python3 vllm/benchmarks/benchmark_serving.py \ --num-prompts 10 ``` +### Running With Ramp-Up Request Rate + +The benchmark tool also supports ramping up the request rate over the +duration of the benchmark run. This can be useful for stress testing the +server or finding the maximum throughput that it can handle, given some latency budget. + +Two ramp-up strategies are supported: +- `linear`: Increases the request rate linearly from a start value to an end value. +- `exponential`: Increases the request rate exponentially. + +The following arguments can be used to control the ramp-up: +- `--ramp-up-strategy`: The ramp-up strategy to use (`linear` or `exponential`). +- `--ramp-up-start-rps`: The request rate at the beginning of the benchmark. +- `--ramp-up-end-rps`: The request rate at the end of the benchmark. + --- ## Example - Offline Throughput Benchmark @@ -387,3 +402,178 @@ python3 vllm/benchmarks/benchmark_throughput.py \ --enable-lora \ --lora-path yard1/llama-2-7b-sql-lora-test ``` + +--- +## Example - Structured Output Benchmark + +Benchmark the performance of structured output generation (JSON, grammar, regex). 
+ +### Server Setup + +```bash +vllm serve NousResearch/Hermes-3-Llama-3.1-8B --disable-log-requests +``` + +### JSON Schema Benchmark + +```bash +python3 benchmarks/benchmark_serving_structured_output.py \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset json \ + --structured-output-ratio 1.0 \ + --request-rate 10 \ + --num-prompts 1000 +``` + +### Grammar-based Generation Benchmark + +```bash +python3 benchmarks/benchmark_serving_structured_output.py \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset grammar \ + --structure-type grammar \ + --request-rate 10 \ + --num-prompts 1000 +``` + +### Regex-based Generation Benchmark + +```bash +python3 benchmarks/benchmark_serving_structured_output.py \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset regex \ + --request-rate 10 \ + --num-prompts 1000 +``` + +### Choice-based Generation Benchmark + +```bash +python3 benchmarks/benchmark_serving_structured_output.py \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset choice \ + --request-rate 10 \ + --num-prompts 1000 +``` + +### XGrammar Benchmark Dataset + +```bash +python3 benchmarks/benchmark_serving_structured_output.py \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset xgrammar_bench \ + --request-rate 10 \ + --num-prompts 1000 +``` + +--- +## Example - Long Document QA Throughput Benchmark + +Benchmark the performance of long document question-answering with prefix caching. + +### Basic Long Document QA Test + +```bash +python3 benchmarks/benchmark_long_document_qa_throughput.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-documents 16 \ + --document-length 2000 \ + --output-len 50 \ + --repeat-count 5 +``` + +### Different Repeat Modes + +```bash +# Random mode (default) - shuffle prompts randomly +python3 benchmarks/benchmark_long_document_qa_throughput.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-documents 8 \ + --document-length 3000 \ + --repeat-count 3 \ + --repeat-mode random + +# Tile mode - repeat entire prompt list in sequence +python3 benchmarks/benchmark_long_document_qa_throughput.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-documents 8 \ + --document-length 3000 \ + --repeat-count 3 \ + --repeat-mode tile + +# Interleave mode - repeat each prompt consecutively +python3 benchmarks/benchmark_long_document_qa_throughput.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-documents 8 \ + --document-length 3000 \ + --repeat-count 3 \ + --repeat-mode interleave +``` + +--- +## Example - Prefix Caching Benchmark + +Benchmark the efficiency of automatic prefix caching. 
+ +### Fixed Prompt with Prefix Caching + +```bash +python3 benchmarks/benchmark_prefix_caching.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-prompts 1 \ + --repeat-count 100 \ + --input-length-range 128:256 +``` + +### ShareGPT Dataset with Prefix Caching + +```bash +# download dataset +# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + +python3 benchmarks/benchmark_prefix_caching.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --dataset-path /path/ShareGPT_V3_unfiltered_cleaned_split.json \ + --enable-prefix-caching \ + --num-prompts 20 \ + --repeat-count 5 \ + --input-length-range 128:256 +``` + +--- +## Example - Request Prioritization Benchmark + +Benchmark the performance of request prioritization in vLLM. + +### Basic Prioritization Test + +```bash +python3 benchmarks/benchmark_prioritization.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --input-len 128 \ + --output-len 64 \ + --num-prompts 100 \ + --scheduling-policy priority +``` + +### Multiple Sequences per Prompt + +```bash +python3 benchmarks/benchmark_prioritization.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --input-len 128 \ + --output-len 64 \ + --num-prompts 100 \ + --scheduling-policy priority \ + --n 2 +``` diff --git a/benchmarks/auto_tune.sh b/benchmarks/auto_tune.sh index 1b01bbd61b62..b257b57ce06f 100644 --- a/benchmarks/auto_tune.sh +++ b/benchmarks/auto_tune.sh @@ -10,6 +10,7 @@ # 3. Set variables (ALL REQUIRED) # BASE: your directory for vllm repo # MODEL: the model served by vllm +# SYSTEM: the hardware, choice TPU or GPU, for other systems, "get best profile" might not support. # TP: ways of tensor parallelism # DOWNLOAD_DIR: directory to download and load model weights. 
# INPUT_LEN: request input len @@ -34,6 +35,7 @@ TAG=$(date +"%Y_%m_%d_%H_%M") BASE="" MODEL="meta-llama/Llama-3.1-8B-Instruct" +SYSTEM="TPU" TP=1 DOWNLOAD_DIR="" INPUT_LEN=4000 @@ -45,12 +47,15 @@ NUM_BATCHED_TOKENS_LIST="512 1024 2048 4096" LOG_FOLDER="$BASE/auto-benchmark/$TAG" RESULT="$LOG_FOLDER/result.txt" +PROFILE_PATH="$LOG_FOLDER/profile" echo "result file: $RESULT" echo "model: $MODEL" rm -rf $LOG_FOLDER +rm -rf $PROFILE_PATH mkdir -p $LOG_FOLDER +mkdir -p $PROFILE_PATH cd "$BASE/vllm" @@ -70,10 +75,11 @@ start_server() { local max_num_seqs=$2 local max_num_batched_tokens=$3 local vllm_log=$4 + local profile_dir=$5 pkill -f vllm - VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 vllm serve $MODEL \ + VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir vllm serve $MODEL \ --disable-log-requests \ --port 8004 \ --gpu-memory-utilization $gpu_memory_utilization \ @@ -105,19 +111,37 @@ start_server() { fi } +update_best_profile() { + local profile_dir=$1 + local profile_index=$2 + sorted_paths=($(find "$profile_dir" -maxdepth 1 -not -path "$profile_dir" | sort)) + selected_profile_file= + if [[ "$SYSTEM" == "TPU" ]]; then + selected_profile_file="${sorted_paths[$profile_index]}/*.xplane.pb" + fi + if [[ "$SYSTEM" == "GPU" ]]; then + selected_profile_file="${sorted_paths[$profile_index]}" + fi + rm -f $PROFILE_PATH/* + cp $selected_profile_file $PROFILE_PATH +} + run_benchmark() { local max_num_seqs=$1 local max_num_batched_tokens=$2 local gpu_memory_utilization=$3 echo "max_num_seq: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens" local vllm_log="$LOG_FOLDER/vllm_log_${max_num_seqs}_${max_num_batched_tokens}.txt" + local profile_dir="$LOG_FOLDER/profile_${max_num_seqs}_${max_num_batched_tokens}" echo "vllm_log: $vllm_log" echo rm -f $vllm_log + mkdir -p $profile_dir pkill -f vllm + local profile_index=0 echo "starting server..." - start_server $gpu_memory_utilization $max_num_seqs $max_num_batched_tokens $vllm_log + start_server $gpu_memory_utilization $max_num_seqs $max_num_batched_tokens $vllm_log $profile_dir result=$? if [[ "$result" -eq 1 ]]; then echo "server failed to start. 
gpu_memory_utilization:$gpu_memory_utilization, max_num_seqs:$max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens" @@ -144,7 +168,8 @@ run_benchmark() { --goodput e2el:$MAX_LATENCY_ALLOWED_MS \ --num-prompts 1000 \ --random-prefix-len $prefix_len \ - --port 8004 &> "$bm_log" + --port 8004 \ + --profile &> "$bm_log" throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}') goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') @@ -158,6 +183,7 @@ run_benchmark() { # start from request-rate as int(throughput) + 1 request_rate=$((${throughput%.*} + 1)) while ((request_rate > 0)); do + profile_index=$((profile_index+1)) # clear prefix cache curl -X POST http://0.0.0.0:8004/reset_prefix_cache sleep 5 @@ -195,6 +221,12 @@ run_benchmark() { best_max_num_seqs=$max_num_seqs best_num_batched_tokens=$max_num_batched_tokens best_goodput=$goodput + if [[ "$SYSTEM" == "TPU" ]]; then + update_best_profile "$profile_dir/plugins/profile" $profile_index + fi + if [[ "$SYSTEM" == "GPU" ]]; then + update_best_profile "$profile_dir" $profile_index + fi fi else echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens does not meet latency requirement ${MAX_LATENCY_ALLOWED_MS}" @@ -239,6 +271,6 @@ for num_seqs in "${num_seqs_list[@]}"; do done done echo "finish permutations" -echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput" -echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput" >> "$RESULT" +echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" +echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" >> "$RESULT" diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index f38e45b26113..886a51e1cbd9 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -33,7 +33,7 @@ from collections.abc import AsyncGenerator, Iterable from dataclasses import dataclass from datetime import datetime -from typing import Any, Optional +from typing import Any, Literal, Optional import numpy as np from tqdm.asyncio import tqdm @@ -107,14 +107,42 @@ class BenchmarkMetrics: percentiles_e2el_ms: list[tuple[float, float]] +def _get_current_request_rate( + ramp_up_strategy: Optional[Literal["linear", "exponential"]], + ramp_up_start_rps: Optional[int], + ramp_up_end_rps: Optional[int], + request_index: int, + total_requests: int, + request_rate: float, +) -> float: + if ( + ramp_up_strategy + and ramp_up_start_rps is not None + and ramp_up_end_rps is not None + ): + progress = request_index / max(total_requests - 1, 1) + if ramp_up_strategy == "linear": + increase = (ramp_up_end_rps - ramp_up_start_rps) * progress + return ramp_up_start_rps + increase + elif ramp_up_strategy == "exponential": + ratio = ramp_up_end_rps / ramp_up_start_rps + return ramp_up_start_rps * (ratio**progress) + else: + raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}") + return request_rate + + async def get_request( input_requests: list[SampleRequest], request_rate: float, burstiness: float = 1.0, -) -> AsyncGenerator[SampleRequest, None]: + ramp_up_strategy: Optional[Literal["linear", 
"exponential"]] = None, + ramp_up_start_rps: Optional[int] = None, + ramp_up_end_rps: Optional[int] = None, +) -> AsyncGenerator[tuple[SampleRequest, float], None]: """ Asynchronously generates requests at a specified rate - with OPTIONAL burstiness. + with OPTIONAL burstiness and OPTIONAL ramp-up strategy. Args: input_requests: @@ -129,22 +157,44 @@ async def get_request( A lower burstiness value (0 < burstiness < 1) results in more bursty requests, while a higher burstiness value (burstiness > 1) results in a more uniform arrival of requests. + ramp_up_strategy (optional): + The ramp-up strategy. Can be "linear" or "exponential". + If None, uses constant request rate (specified by request_rate). + ramp_up_start_rps (optional): + The starting request rate for ramp-up. + ramp_up_end_rps (optional): + The ending request rate for ramp-up. """ - input_requests: Iterable[SampleRequest] = iter(input_requests) - - # Calculate scale parameter theta to maintain the desired request_rate. assert burstiness > 0, ( f"A positive burstiness factor is expected, but given {burstiness}." ) - theta = 1.0 / (request_rate * burstiness) + # Convert to list to get length for ramp-up calculations + if isinstance(input_requests, Iterable) and not isinstance(input_requests, list): + input_requests = list(input_requests) + + total_requests = len(input_requests) + request_index = 0 for request in input_requests: - yield request + current_request_rate = _get_current_request_rate( + ramp_up_strategy, + ramp_up_start_rps, + ramp_up_end_rps, + request_index, + total_requests, + request_rate, + ) + + yield request, current_request_rate - if request_rate == float("inf"): + request_index += 1 + + if current_request_rate == float("inf"): # If the request rate is infinity, then we don't need to wait. continue + theta = 1.0 / (current_request_rate * burstiness) + # Sample the request interval from the gamma distribution. # If burstiness is 1, it follows exponential distribution. interval = np.random.gamma(shape=burstiness, scale=theta) @@ -290,6 +340,9 @@ async def benchmark( max_concurrency: Optional[int], lora_modules: Optional[Iterable[str]], extra_body: Optional[dict], + ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, + ramp_up_start_rps: Optional[int] = None, + ramp_up_end_rps: Optional[int] = None, ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] @@ -353,7 +406,15 @@ async def benchmark( distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution" - print(f"Traffic request rate: {request_rate}") + if ramp_up_strategy is not None: + print( + f"Traffic ramp-up strategy: {ramp_up_strategy}. Will increase " + f"RPS from {ramp_up_start_rps} to {ramp_up_end_rps} RPS over " + "the duration of the benchmark." 
+ ) + else: + print(f"Traffic request rate: {request_rate} RPS.") + print(f"Burstiness factor: {burstiness} ({distribution})") print(f"Maximum request concurrency: {max_concurrency}") @@ -373,7 +434,34 @@ async def limited_request_func(request_func_input, pbar): benchmark_start_time = time.perf_counter() tasks: list[asyncio.Task] = [] - async for request in get_request(input_requests, request_rate, burstiness): + + rps_change_events = [] + last_int_rps = -1 + if ramp_up_strategy is not None and ramp_up_start_rps is not None: + last_int_rps = ramp_up_start_rps + rps_change_events.append( + { + "rps": last_int_rps, + "timestamp": datetime.now().isoformat(), + } + ) + + async for request, current_request_rate in get_request( + input_requests, + request_rate, + burstiness, + ramp_up_strategy, + ramp_up_start_rps, + ramp_up_end_rps, + ): + if ramp_up_strategy is not None: + current_int_rps = int(current_request_rate) + if current_int_rps > last_int_rps: + timestamp = datetime.now().isoformat() + for rps_val in range(last_int_rps + 1, current_int_rps + 1): + rps_change_events.append({"rps": rps_val, "timestamp": timestamp}) + last_int_rps = current_int_rps + prompt, prompt_len, output_len, mm_content = ( request.prompt, request.prompt_len, @@ -397,11 +485,8 @@ async def limited_request_func(request_func_input, pbar): ignore_eos=ignore_eos, extra_body=extra_body, ) - tasks.append( - asyncio.create_task( - limited_request_func(request_func_input=request_func_input, pbar=pbar) - ) - ) + task = limited_request_func(request_func_input=request_func_input, pbar=pbar) + tasks.append(asyncio.create_task(task)) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) if profile: @@ -477,6 +562,9 @@ async def limited_request_func(request_func_input, pbar): "errors": [output.error for output in outputs], } + if rps_change_events: + result["rps_change_events"] = rps_change_events + def process_one_metric( # E.g., "ttft" metric_attribute_name: str, @@ -610,6 +698,26 @@ def main(args: argparse.Namespace): tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model tokenizer_mode = args.tokenizer_mode + # Validate ramp-up arguments + if args.ramp_up_strategy is not None: + if args.request_rate != float("inf"): + raise ValueError( + "When using ramp-up, do not specify --request-rate. " + "The request rate will be controlled by ramp-up parameters. " + "Please remove the --request-rate argument." 
+ ) + if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None: + raise ValueError( + "When using --ramp-up-strategy, both --ramp-up-start-rps and " + "--ramp-up-end-rps must be specified" + ) + if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0: + raise ValueError("Ramp-up start and end RPS must be non-negative") + if args.ramp_up_start_rps > args.ramp_up_end_rps: + raise ValueError("Ramp-up start RPS must be less than end RPS") + if args.ramp_up_strategy == "exponential" and args.ramp_up_start_rps == 0: + raise ValueError("For exponential ramp-up, the start RPS cannot be 0.") + if args.base_url is not None: api_url = f"{args.base_url}{args.endpoint}" base_url = f"{args.base_url}" @@ -802,6 +910,9 @@ def main(args: argparse.Namespace): max_concurrency=args.max_concurrency, lora_modules=args.lora_modules, extra_body=sampling_params, + ramp_up_strategy=args.ramp_up_strategy, + ramp_up_start_rps=args.ramp_up_start_rps, + ramp_up_end_rps=args.ramp_up_end_rps, ) ) @@ -834,6 +945,11 @@ def main(args: argparse.Namespace): result_json["burstiness"] = args.burstiness result_json["max_concurrency"] = args.max_concurrency + if args.ramp_up_strategy is not None: + result_json["ramp_up_strategy"] = args.ramp_up_strategy + result_json["ramp_up_start_rps"] = args.ramp_up_start_rps + result_json["ramp_up_end_rps"] = args.ramp_up_end_rps + # Merge with benchmark result result_json = {**result_json, **benchmark_result} @@ -859,7 +975,10 @@ def main(args: argparse.Namespace): if args.max_concurrency is not None else "" ) - file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa + if args.ramp_up_strategy is not None: + file_name = f"{backend}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa + else: + file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa if args.result_filename: file_name = args.result_filename if args.result_dir: @@ -1225,6 +1344,31 @@ def create_argument_parser(): "script chooses a LoRA module at random.", ) + parser.add_argument( + "--ramp-up-strategy", + type=str, + default=None, + choices=["linear", "exponential"], + help="The ramp-up strategy. This would be used to " + "ramp up the request rate from initial RPS to final " + "RPS rate (specified by --ramp-up-start-rps and --ramp-up-end-rps). " + "over the duration of the benchmark.", + ) + parser.add_argument( + "--ramp-up-start-rps", + type=int, + default=None, + help="The starting request rate for ramp-up (RPS). " + "Needs to be specified when --ramp-up-strategy is used.", + ) + parser.add_argument( + "--ramp-up-end-rps", + type=int, + default=None, + help="The ending request rate for ramp-up (RPS). 
" + "Needs to be specified when --ramp-up-strategy is used.", + ) + return parser diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index cec422e8d597..a5a5b52f6039 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -19,7 +19,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( w8a8_block_fp8_matmul, ) -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, cdiv DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] @@ -117,14 +117,9 @@ def bench_fp8( scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - def ceil_div(x: int, y: int) -> int: - return (x + y - 1) // y - - block_scale_a = torch.rand( - (m, ceil_div(k, 128)), device="cuda", dtype=torch.float32 - ) + block_scale_a = torch.rand((m, cdiv(k, 128)), device="cuda", dtype=torch.float32) block_scale_b = torch.rand( - ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32 + cdiv(k, 128), cdiv(n, 128), device="cuda", dtype=torch.float32 ) block_scale_a_M_major = block_scale_a.t().contiguous().t() block_scale_b_K_major = block_scale_b.t().contiguous().t() diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 9ea1fddae2a3..34cc45e94d76 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -22,8 +22,16 @@ MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types, ) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + FP4_MARLIN_SUPPORTED_GROUP_SIZES, + rand_marlin_weight_fp4_like, +) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + marlin_quant_fp8_torch, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( MarlinWorkspace, + awq_marlin_quantize, marlin_quantize, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( @@ -35,7 +43,7 @@ quantize_weights, sort_weights, ) -from vllm.scalar_type import ScalarType +from vllm.scalar_type import ScalarType, scalar_types from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] @@ -57,80 +65,144 @@ def bench_run( size_n: int, ): label = "Quant Matmul" - sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format( model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n ) - print(f"Testing: {sub_label}") a = torch.randn(size_m, size_k).to(torch.half).cuda() b = torch.rand(size_k, size_n).to(torch.half).cuda() + has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] + if act_order and (group_size == -1 or group_size == size_k or has_zp): + return + if size_k % group_size != 0: + return - a_tmp = torch.zeros(size_m, size_k).to(torch.half).cuda() + marlin_24_supported = ( + quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES + and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES + ) + repack_supported = ( + quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES + and group_size in MARLIN_SUPPORTED_GROUP_SIZES + ) + allspark_supported = ( + quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES + and group_size == -1 + and not act_order + and is_k_full + ) + + def gen_marlin_params(): + # Marlin quant + marlin_g_idx = marlin_sort_indices = marlin_zp = marlin_s2 = None + if quant_type == 
scalar_types.float4_e2m1f: + if group_size != 16 or act_order: + return + marlin_w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like( + b.T, group_size + ) + elif quant_type == scalar_types.float8_e4m3fn: + if group_size not in [-1, 128] or act_order: + return + marlin_w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b.T, group_size) + elif group_size == 16: + return + elif has_zp: + marlin_w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( + b, quant_type, group_size + ) + else: + marlin_w_ref, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, _ = ( + marlin_quantize(b, quant_type, group_size, act_order) + ) + return ( + marlin_w_ref, + marlin_q_w, + marlin_s, + marlin_s2, + marlin_zp, + marlin_g_idx, + marlin_sort_indices, + ) + + def gen_marlin_24_params(): + marlin_24_w_ref = marlin_24_q_w_comp = marlin_24_meta = marlin_24_s = None + if marlin_24_supported: + (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = ( + marlin_24_quantize(b, quant_type, group_size) + ) + return (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) + + def gen_repack_params(): + q_w_gptq = None + repack_sort_indices = None + if repack_supported: + (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights( + b, quant_type, group_size, act_order + ) + q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) + + # For act_order, sort the "weights" and "g_idx" + # so that group ids are increasing + repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device) + if act_order: + (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) + return q_w_gptq, repack_sort_indices + + def gen_allspark_params(): + qw_reorder = s_reorder = zp_reorder = sm_count = sm_version = ( + CUBLAS_M_THRESHOLD + ) = None + nonlocal allspark_supported + if allspark_supported: + properties = torch.cuda.get_device_properties(b.device.index) + sm_count = properties.multi_processor_count + sm_version = properties.major * 10 + properties.minor + + supported_arch = sm_version >= 80 and sm_version < 90 + allspark_supported = allspark_supported and supported_arch + if supported_arch: + w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp) + qw = qw.to(torch.uint8) + + qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight( + qw, s, zp, has_zp + ) + CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD + return ( + qw_reorder, + s_reorder, + zp_reorder, + sm_count, + sm_version, + CUBLAS_M_THRESHOLD, + ) - # Marlin quant ( marlin_w_ref, marlin_q_w, marlin_s, + marlin_s2, + marlin_zp, marlin_g_idx, marlin_sort_indices, - marlin_rand_perm, - ) = marlin_quantize(b, quant_type, group_size, act_order) - - # Marlin_24 quant - (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = ( - marlin_24_quantize(b, quant_type, group_size) + ) = gen_marlin_params() + marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s = ( + gen_marlin_24_params() ) - - marlin_zp = torch.empty(0, dtype=torch.int, device=b.device) - - # GPTQ quant - (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights( - b, quant_type, group_size, act_order + q_w_gptq, repack_sort_indices = gen_repack_params() + qw_reorder, s_reorder, zp_reorder, sm_count, sm_version, CUBLAS_M_THRESHOLD = ( + gen_allspark_params() ) - q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) - - # For act_order, sort the "weights" and "g_idx" - # so that group ids are increasing - repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device) - if 
act_order: - (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) # Prepare marlin_workspace = MarlinWorkspace( size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL ) - marlin_24_workspace = MarlinWorkspace( size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL ) - marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int) - - # AllSpark W8A16 quant - as_supported_case = ( - quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES - and group_size == -1 - and not act_order - and is_k_full - ) - if as_supported_case: - properties = torch.cuda.get_device_properties(b.device.index) - sm_count = properties.multi_processor_count - sm_version = properties.major * 10 + properties.minor - - supported_arch = sm_version >= 80 and sm_version < 90 - as_supported_case = as_supported_case and supported_arch - if supported_arch: - has_zp = False - w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp) - qw = qw.to(torch.uint8) - - qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight( - qw, s, zp, has_zp - ) - CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD globals = { # Gen params @@ -140,15 +212,14 @@ def bench_run( "size_n": size_n, "size_k": size_k, "a": a, - "a_tmp": a_tmp, # Marlin params "marlin_w_ref": marlin_w_ref, "marlin_q_w": marlin_q_w, "marlin_s": marlin_s, + "marlin_s2": marlin_s2, "marlin_zp": marlin_zp, "marlin_g_idx": marlin_g_idx, "marlin_sort_indices": marlin_sort_indices, - "marlin_rand_perm": marlin_rand_perm, "marlin_workspace": marlin_workspace, "is_k_full": is_k_full, # Marlin_24 params @@ -161,12 +232,12 @@ def bench_run( "q_w_gptq": q_w_gptq, "repack_sort_indices": repack_sort_indices, # AllSpark W8A16 params - "qw_reorder": qw_reorder if as_supported_case else None, - "s_reorder": s_reorder if as_supported_case else None, - "zp_reorder": zp_reorder if as_supported_case else None, - "sm_count": sm_count if as_supported_case else None, - "sm_version": sm_version if as_supported_case else None, - "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD if as_supported_case else None, + "qw_reorder": qw_reorder, + "s_reorder": s_reorder, + "zp_reorder": zp_reorder, + "sm_count": sm_count, + "sm_version": sm_version, + "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD, # Kernels "gptq_marlin_gemm": ops.gptq_marlin_gemm, "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, @@ -177,7 +248,7 @@ def bench_run( min_run_time = 1 # Warmup pytorch - for i in range(5): + for _ in range(5): torch.matmul(a, marlin_w_ref) results.append( @@ -192,17 +263,17 @@ def bench_run( results.append( benchmark.Timer( - stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 + stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, - description="gptq_marlin_gemm_fp16", + description="gptq_marlin_gemm", ).blocked_autorange(min_run_time=min_run_time) ) results.append( benchmark.Timer( - stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 + stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, 
marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -210,10 +281,7 @@ def bench_run( ).blocked_autorange(min_run_time=min_run_time) ) - if ( - quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES - and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES - ): + if marlin_24_supported: results.append( benchmark.Timer( stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501 @@ -224,17 +292,18 @@ def bench_run( ).blocked_autorange(min_run_time=min_run_time) ) - results.append( - benchmark.Timer( - stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501 - globals=globals, - label=label, - sub_label=sub_label, - description="gptq_marlin_repack", - ).blocked_autorange(min_run_time=min_run_time) - ) + if repack_supported: + results.append( + benchmark.Timer( + stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_repack", + ).blocked_autorange(min_run_time=min_run_time) + ) - if as_supported_case: + if allspark_supported: results.append( benchmark.Timer( stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501 @@ -250,7 +319,6 @@ def main(args): print("Benchmarking models:") for i, model in enumerate(args.models): print(f"[{i}] {model}") - results: list[benchmark.Measurement] = [] for model in args.models: @@ -278,14 +346,17 @@ def main(args): ): continue - for quant_type in query_marlin_supported_quant_types(False): + for quant_type in query_marlin_supported_quant_types(): if ( len(args.limit_num_bits) > 0 and quant_type.size_bits not in args.limit_num_bits ): continue - for group_size in MARLIN_SUPPORTED_GROUP_SIZES: + for group_size in ( + MARLIN_SUPPORTED_GROUP_SIZES + + FP4_MARLIN_SUPPORTED_GROUP_SIZES + ): if ( len(args.limit_group_size) > 0 and group_size not in args.limit_group_size diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py index e67ce0545318..43c54d56ca8c 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -85,12 +85,6 @@ def benchmark_shape(m: int, # === DeepGEMM Implementation === def deepgemm_gemm(): - # A quantization is inside the loop as it depends on activations - # A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A) - # A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8( - # A, block_size[1]) - # A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm) - # C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm) @@ -98,8 +92,6 @@ def deepgemm_gemm(): # === vLLM Triton Implementation === def vllm_triton_gemm(): - # A quantization is inside the loop as it depends on activations - # A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) return w8a8_block_fp8_matmul(A_vllm, B_vllm, A_scale_vllm, @@ -109,9 +101,6 @@ def vllm_triton_gemm(): # === vLLM CUTLASS Implementation === def vllm_cutlass_gemm(): - # A quantization is 
inside the loop as it depends on activations - # A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( - # A, block_size[1], column_major_scales=True) return ops.cutlass_scaled_mm(A_vllm_cutlass, B_vllm.T, scale_a=A_scale_vllm_cutlass, diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index dba5baa362b8..7b17018f65ab 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 763ad155a1c826f71ff318f41edb1e4e5e376ddb + GIT_TAG 5f3644181c7a15345ce20bfc65af117d3601b524 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 39997030751b..3bddd12cad07 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -1598,7 +1598,6 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; const int lane2id = laneid % 2; - const int lane4id = laneid % 4; const int lane16id = laneid % 16; const int rowid = laneid / 16; @@ -1745,7 +1744,6 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride; const int klocal_token_idx = TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; - const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE; const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX; @@ -2368,7 +2366,6 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; const int lane2id = laneid % 2; - const int lane4id = laneid % 4; const int lane16id = laneid % 16; const int rowid = laneid / 16; @@ -2514,7 +2511,6 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride; const int klocal_token_idx = TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; - const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE; const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX; diff --git a/docs/contributing/README.md b/docs/contributing/README.md index e977ec3d2f71..c0c338b42695 100644 --- a/docs/contributing/README.md +++ b/docs/contributing/README.md @@ -29,6 +29,8 @@ See . Depending on the kind of development you'd like to do (e.g. Python, CUDA), you can choose to build vLLM with or without compilation. Check out the [building from source][build-from-source] documentation for details. +For an optimized workflow when iterating on C++/CUDA kernels, see the [Incremental Compilation Workflow](./incremental_build.md) for recommendations. + ### Building the docs with MkDocs #### Introduction to MkDocs @@ -188,6 +190,7 @@ The PR needs to meet the following code quality standards: ### Adding or Changing Kernels +When actively developing or modifying kernels, using the [Incremental Compilation Workflow](./incremental_build.md) is highly recommended for faster build times. 
Each custom kernel needs a schema and one or more implementations to be registered with PyTorch. - Make sure custom ops are registered following PyTorch guidelines: diff --git a/docs/contributing/incremental_build.md b/docs/contributing/incremental_build.md new file mode 100644 index 000000000000..8efa34825eca --- /dev/null +++ b/docs/contributing/incremental_build.md @@ -0,0 +1,138 @@ +# Incremental Compilation Workflow for vLLM Development + +When working on vLLM's C++/CUDA kernels located in the `csrc/` directory, recompiling the entire project with `uv pip install -e .` for every change can be time-consuming. An incremental compilation workflow using CMake allows for faster iteration by only recompiling the necessary components after an initial setup. This guide details how to set up and use such a workflow, which complements your editable Python installation. + +## Prerequisites + +Before setting up the incremental build: + +1. **vLLM Editable Install:** Ensure you have vLLM installed from source in an editable mode. Using pre-compiled wheels for the initial editable setup can be faster, as the CMake workflow will handle subsequent kernel recompilations. + + ```console + uv venv --python 3.12 --seed + source .venv/bin/activate + VLLM_USE_PRECOMPILED=1 uv pip install -U -e . --torch-backend=auto + ``` + +2. **CUDA Toolkit:** Verify that the NVIDIA CUDA Toolkit is correctly installed and `nvcc` is accessible in your `PATH`. CMake relies on `nvcc` to compile CUDA code. You can typically find `nvcc` in `$CUDA_HOME/bin/nvcc` or by running `which nvcc`. If you encounter issues, refer to the [official CUDA Toolkit installation guides](https://developer.nvidia.com/cuda-toolkit-archive) and vLLM's main [GPU installation documentation](../getting_started/installation/gpu/cuda.inc.md#troubleshooting) for troubleshooting. The `CMAKE_CUDA_COMPILER` variable in your `CMakeUserPresets.json` should also point to your `nvcc` binary. + +3. **Build Tools:** It is highly recommended to install `ccache` for fast rebuilds by caching compilation results (e.g., `sudo apt install ccache` or `conda install ccache`). Also, ensure the core build dependencies like `cmake` and `ninja` are installed. These are installable through `requirements/build.txt` or your system's package manager. + + ```console + uv pip install -r requirements/build.txt --torch-backend=auto + ``` + +## Setting up the CMake Build Environment + +The incremental build process is managed through CMake. You can configure your build settings using a `CMakeUserPresets.json` file at the root of the vLLM repository. + +### Generate `CMakeUserPresets.json` using the helper script + +To simplify the setup, vLLM provides a helper script that attempts to auto-detect your system's configuration (like CUDA path, Python environment, and CPU cores) and generates the `CMakeUserPresets.json` file for you. + +**Run the script:** + +Navigate to the root of your vLLM clone and execute the following command: + +```console +python tools/generate_cmake_presets.py +``` + +The script will prompt you if it cannot automatically determine certain paths (e.g., `nvcc` or a specific Python executable for your vLLM development environment). Follow the on-screen prompts. If an existing `CMakeUserPresets.json` is found, the script will ask for confirmation before overwriting it. + +After running the script, a `CMakeUserPresets.json` file will be created in the root of your vLLM repository. 
+ +### Example `CMakeUserPresets.json` + +Below is an example of what the generated `CMakeUserPresets.json` might look like. The script will tailor these values based on your system and any input you provide. + +```json +{ + "version": 6, + "cmakeMinimumRequired": { + "major": 3, + "minor": 26, + "patch": 1 + }, + "configurePresets": [ + { + "name": "release", + "generator": "Ninja", + "binaryDir": "${sourceDir}/cmake-build-release", + "cacheVariables": { + "CMAKE_CUDA_COMPILER": "/usr/local/cuda/bin/nvcc", + "CMAKE_C_COMPILER_LAUNCHER": "ccache", + "CMAKE_CXX_COMPILER_LAUNCHER": "ccache", + "CMAKE_CUDA_COMPILER_LAUNCHER": "ccache", + "CMAKE_BUILD_TYPE": "Release", + "VLLM_PYTHON_EXECUTABLE": "/home/user/venvs/vllm/bin/python", + "CMAKE_INSTALL_PREFIX": "${sourceDir}", + "CMAKE_CUDA_FLAGS": "", + "NVCC_THREADS": "4", + "CMAKE_JOB_POOLS": "compile=32" + } + } + ], + "buildPresets": [ + { + "name": "release", + "configurePreset": "release", + "jobs": 32 + } + ] +} +``` + +**What do the various configurations mean?** +- `CMAKE_CUDA_COMPILER`: Path to your `nvcc` binary. The script attempts to find this automatically. +- `CMAKE_C_COMPILER_LAUNCHER`, `CMAKE_CXX_COMPILER_LAUNCHER`, `CMAKE_CUDA_COMPILER_LAUNCHER`: Setting these to `ccache` (or `sccache`) significantly speeds up rebuilds by caching compilation results. Ensure `ccache` is installed (e.g., `sudo apt install ccache` or `conda install ccache`). The script sets these by default. +- `VLLM_PYTHON_EXECUTABLE`: Path to the Python executable in your vLLM development environment. The script will prompt for this, defaulting to the current Python environment if suitable. +- `CMAKE_INSTALL_PREFIX: "${sourceDir}"`: Specifies that the compiled components should be installed back into your vLLM source directory. This is crucial for the editable install, as it makes the newly built kernels immediately available to your Python environment. +- `CMAKE_JOB_POOLS` and `jobs` in build presets: Control the parallelism of the build. The script sets these based on the number of CPU cores detected on your system. +- `binaryDir`: Specifies where the build artifacts will be stored (e.g., `cmake-build-release`). + +## Building and Installing with CMake + +Once your `CMakeUserPresets.json` is configured: + +1. **Initialize the CMake build environment:** + This step configures the build system according to your chosen preset (e.g., `release`) and creates the build directory at `binaryDir` + + ```console + cmake --preset release + ``` + +2. **Build and install the vLLM components:** + This command compiles the code and installs the resulting binaries into your vLLM source directory, making them available to your editable Python installation. + + ```console + cmake --build --preset release --target install + ``` + +3. **Make changes and repeat!** + Now you start using your editable install of vLLM, testing and making changes as needed. If you need to build again to update based on changes, simply run the CMake command again to build only the affected files. + + ```console + cmake --build --preset release --target install + ``` + +## Verifying the Build + +After a successful build, you will find a populated build directory (e.g., `cmake-build-release/` if you used the `release` preset and the example configuration). 
+ +```console +> ls cmake-build-release/ +bin cmake_install.cmake _deps machete_generation.log +build.ninja CPackConfig.cmake detect_cuda_compute_capabilities.cu marlin_generation.log +_C.abi3.so CPackSourceConfig.cmake detect_cuda_version.cc _moe_C.abi3.so +CMakeCache.txt ctest _flashmla_C.abi3.so moe_marlin_generation.log +CMakeFiles cumem_allocator.abi3.so install_local_manifest.txt vllm-flash-attn +``` + +The `cmake --build ... --target install` command copies the compiled shared libraries (like `_C.abi3.so`, `_moe_C.abi3.so`, etc.) into the appropriate `vllm` package directory within your source tree. This updates your editable installation with the newly compiled kernels. + +## Additional Tips + +- **Adjust Parallelism:** Fine-tune the `CMAKE_JOB_POOLS` in `configurePresets` and `jobs` in `buildPresets` in your `CMakeUserPresets.json`. Too many jobs can overload systems with limited RAM or CPU cores, leading to slower builds or system instability. Too few won't fully utilize available resources. +- **Clean Builds When Necessary:** If you encounter persistent or strange build errors, especially after significant changes or switching branches, consider removing the CMake build directory (e.g., `rm -rf cmake-build-release`) and re-running the `cmake --preset` and `cmake --build` commands. +- **Specific Target Builds:** For even faster iterations when working on a specific module, you can sometimes build a specific target instead of the full `install` target, though `install` ensures all necessary components are updated in your Python environment. Refer to CMake documentation for more advanced target management. diff --git a/docs/deployment/docker.md b/docs/deployment/docker.md index eb84db7871e4..5f6a22c28c28 100644 --- a/docs/deployment/docker.md +++ b/docs/deployment/docker.md @@ -10,7 +10,7 @@ title: Using Docker vLLM offers an official Docker image for deployment. The image can be used to run OpenAI compatible server and is available on Docker Hub as [vllm/vllm-openai](https://hub.docker.com/r/vllm/vllm-openai/tags). -```console +```bash docker run --runtime nvidia --gpus all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HUGGING_FACE_HUB_TOKEN=" \ @@ -22,7 +22,7 @@ docker run --runtime nvidia --gpus all \ This image can also be used with other container engines such as [Podman](https://podman.io/). -```console +```bash podman run --gpus all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \ @@ -71,7 +71,7 @@ You can add any other [engine-args][engine-args] you need after the image tag (` You can build and run vLLM from source via the provided . To build vLLM: -```console +```bash # optionally specifies: --build-arg max_jobs=8 --build-arg nvcc_threads=2 DOCKER_BUILDKIT=1 docker build . \ --target vllm-openai \ @@ -99,7 +99,7 @@ of PyTorch Nightly and should be considered **experimental**. Using the flag `-- ??? Command - ```console + ```bash # Example of building on Nvidia GH200 server. (Memory usage: ~15GB, Build time: ~1475s / ~25 min, Image size: 6.93GB) python3 use_existing_torch.py DOCKER_BUILDKIT=1 docker build . \ @@ -118,7 +118,7 @@ of PyTorch Nightly and should be considered **experimental**. Using the flag `-- Run the following command on your host machine to register QEMU user static handlers: - ```console + ```bash docker run --rm --privileged multiarch/qemu-user-static --reset -p yes ``` @@ -128,7 +128,7 @@ of PyTorch Nightly and should be considered **experimental**. 
Using the flag `-- To run vLLM with the custom-built Docker image: -```console +```bash docker run --runtime nvidia --gpus all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ -p 8000:8000 \ diff --git a/docs/deployment/frameworks/anything-llm.md b/docs/deployment/frameworks/anything-llm.md index a89e633c086e..4633c2946cde 100644 --- a/docs/deployment/frameworks/anything-llm.md +++ b/docs/deployment/frameworks/anything-llm.md @@ -15,7 +15,7 @@ It allows you to deploy a large language model (LLM) server with vLLM as the bac - Start the vLLM server with the supported chat completion model, e.g. -```console +```bash vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096 ``` diff --git a/docs/deployment/frameworks/autogen.md b/docs/deployment/frameworks/autogen.md index 295664daeadb..13930e67ab2f 100644 --- a/docs/deployment/frameworks/autogen.md +++ b/docs/deployment/frameworks/autogen.md @@ -11,7 +11,7 @@ title: AutoGen - Setup [AutoGen](https://microsoft.github.io/autogen/0.2/docs/installation/) environment -```console +```bash pip install vllm # Install AgentChat and OpenAI client from Extensions @@ -23,7 +23,7 @@ pip install -U "autogen-agentchat" "autogen-ext[openai]" - Start the vLLM server with the supported chat completion model, e.g. -```console +```bash python -m vllm.entrypoints.openai.api_server \ --model mistralai/Mistral-7B-Instruct-v0.2 ``` diff --git a/docs/deployment/frameworks/cerebrium.md b/docs/deployment/frameworks/cerebrium.md index 8e096f26db71..5c5f2f48d50b 100644 --- a/docs/deployment/frameworks/cerebrium.md +++ b/docs/deployment/frameworks/cerebrium.md @@ -11,14 +11,14 @@ vLLM can be run on a cloud based GPU machine with [Cerebrium](https://www.cerebr To install the Cerebrium client, run: -```console +```bash pip install cerebrium cerebrium login ``` Next, create your Cerebrium project, run: -```console +```bash cerebrium init vllm-project ``` @@ -58,7 +58,7 @@ Next, let us add our code to handle inference for the LLM of your choice (`mistr Then, run the following code to deploy it to the cloud: -```console +```bash cerebrium deploy ``` diff --git a/docs/deployment/frameworks/chatbox.md b/docs/deployment/frameworks/chatbox.md index 10da2fc71002..b1b50b55146c 100644 --- a/docs/deployment/frameworks/chatbox.md +++ b/docs/deployment/frameworks/chatbox.md @@ -15,7 +15,7 @@ It allows you to deploy a large language model (LLM) server with vLLM as the bac - Start the vLLM server with the supported chat completion model, e.g. -```console +```bash vllm serve qwen/Qwen1.5-0.5B-Chat ``` diff --git a/docs/deployment/frameworks/dify.md b/docs/deployment/frameworks/dify.md index 886484b54347..a0e40784f0ea 100644 --- a/docs/deployment/frameworks/dify.md +++ b/docs/deployment/frameworks/dify.md @@ -18,13 +18,13 @@ This guide walks you through deploying Dify using a vLLM backend. - Start the vLLM server with the supported chat completion model, e.g. 
-```console +```bash vllm serve Qwen/Qwen1.5-7B-Chat ``` - Start the Dify server with docker compose ([details](https://github.com/langgenius/dify?tab=readme-ov-file#quick-start)): -```console +```bash git clone https://github.com/langgenius/dify.git cd dify cd docker diff --git a/docs/deployment/frameworks/dstack.md b/docs/deployment/frameworks/dstack.md index 0b91fc88ce3f..8b4bc459683b 100644 --- a/docs/deployment/frameworks/dstack.md +++ b/docs/deployment/frameworks/dstack.md @@ -11,14 +11,14 @@ vLLM can be run on a cloud based GPU machine with [dstack](https://dstack.ai/), To install dstack client, run: -```console +```bash pip install "dstack[all] dstack server ``` Next, to configure your dstack project, run: -```console +```bash mkdir -p vllm-dstack cd vllm-dstack dstack init diff --git a/docs/deployment/frameworks/haystack.md b/docs/deployment/frameworks/haystack.md index 04d9eba3065c..7a4cab4c2ee3 100644 --- a/docs/deployment/frameworks/haystack.md +++ b/docs/deployment/frameworks/haystack.md @@ -13,7 +13,7 @@ It allows you to deploy a large language model (LLM) server with vLLM as the bac - Setup vLLM and Haystack environment -```console +```bash pip install vllm haystack-ai ``` @@ -21,7 +21,7 @@ pip install vllm haystack-ai - Start the vLLM server with the supported chat completion model, e.g. -```console +```bash vllm serve mistralai/Mistral-7B-Instruct-v0.1 ``` diff --git a/docs/deployment/frameworks/helm.md b/docs/deployment/frameworks/helm.md index 192b90438acf..cff8af2c09d2 100644 --- a/docs/deployment/frameworks/helm.md +++ b/docs/deployment/frameworks/helm.md @@ -22,7 +22,7 @@ Before you begin, ensure that you have the following: To install the chart with the release name `test-vllm`: -```console +```bash helm upgrade --install --create-namespace --namespace=ns-vllm test-vllm . -f values.yaml --set secrets.s3endpoint=$ACCESS_POINT --set secrets.s3bucketname=$BUCKET --set secrets.s3accesskeyid=$ACCESS_KEY --set secrets.s3accesskey=$SECRET_KEY ``` @@ -30,7 +30,7 @@ helm upgrade --install --create-namespace --namespace=ns-vllm test-vllm . -f val To uninstall the `test-vllm` deployment: -```console +```bash helm uninstall test-vllm --namespace=ns-vllm ``` diff --git a/docs/deployment/frameworks/litellm.md b/docs/deployment/frameworks/litellm.md index 8498feaa2972..8279613b1a27 100644 --- a/docs/deployment/frameworks/litellm.md +++ b/docs/deployment/frameworks/litellm.md @@ -18,7 +18,7 @@ And LiteLLM supports all models on VLLM. - Setup vLLM and litellm environment -```console +```bash pip install vllm litellm ``` @@ -28,7 +28,7 @@ pip install vllm litellm - Start the vLLM server with the supported chat completion model, e.g. -```console +```bash vllm serve qwen/Qwen1.5-0.5B-Chat ``` @@ -56,7 +56,7 @@ vllm serve qwen/Qwen1.5-0.5B-Chat - Start the vLLM server with the supported embedding model, e.g. -```console +```bash vllm serve BAAI/bge-base-en-v1.5 ``` diff --git a/docs/deployment/frameworks/open-webui.md b/docs/deployment/frameworks/open-webui.md index 1ab1931068fa..676a0f58b54f 100644 --- a/docs/deployment/frameworks/open-webui.md +++ b/docs/deployment/frameworks/open-webui.md @@ -7,13 +7,13 @@ title: Open WebUI 2. Start the vLLM server with the supported chat completion model, e.g. -```console +```bash vllm serve qwen/Qwen1.5-0.5B-Chat ``` 1. 
Start the [Open WebUI](https://github.com/open-webui/open-webui) docker container (replace the vllm serve host and vllm serve port): -```console +```bash docker run -d -p 3000:8080 \ --name open-webui \ -v open-webui:/app/backend/data \ diff --git a/docs/deployment/frameworks/retrieval_augmented_generation.md b/docs/deployment/frameworks/retrieval_augmented_generation.md index cb26c8378dee..851c31db32f2 100644 --- a/docs/deployment/frameworks/retrieval_augmented_generation.md +++ b/docs/deployment/frameworks/retrieval_augmented_generation.md @@ -15,7 +15,7 @@ Here are the integrations: - Setup vLLM and langchain environment -```console +```bash pip install -U vllm \ langchain_milvus langchain_openai \ langchain_community beautifulsoup4 \ @@ -26,14 +26,14 @@ pip install -U vllm \ - Start the vLLM server with the supported embedding model, e.g. -```console +```bash # Start embedding service (port 8000) vllm serve ssmits/Qwen2-7B-Instruct-embed-base ``` - Start the vLLM server with the supported chat completion model, e.g. -```console +```bash # Start chat service (port 8001) vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 ``` @@ -52,7 +52,7 @@ python retrieval_augmented_generation_with_langchain.py - Setup vLLM and llamaindex environment -```console +```bash pip install vllm \ llama-index llama-index-readers-web \ llama-index-llms-openai-like \ @@ -64,14 +64,14 @@ pip install vllm \ - Start the vLLM server with the supported embedding model, e.g. -```console +```bash # Start embedding service (port 8000) vllm serve ssmits/Qwen2-7B-Instruct-embed-base ``` - Start the vLLM server with the supported chat completion model, e.g. -```console +```bash # Start chat service (port 8001) vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 ``` diff --git a/docs/deployment/frameworks/skypilot.md b/docs/deployment/frameworks/skypilot.md index b649312971b5..ecf987539ced 100644 --- a/docs/deployment/frameworks/skypilot.md +++ b/docs/deployment/frameworks/skypilot.md @@ -15,7 +15,7 @@ vLLM can be **run and scaled to multiple service replicas on clouds and Kubernet - Check that you have installed SkyPilot ([docs](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html)). - Check that `sky check` shows clouds or Kubernetes are enabled. -```console +```bash pip install skypilot-nightly sky check ``` @@ -71,7 +71,7 @@ See the vLLM SkyPilot YAML for serving, [serving.yaml](https://github.com/skypil Start the serving the Llama-3 8B model on any of the candidate GPUs listed (L4, A10g, ...): -```console +```bash HF_TOKEN="your-huggingface-token" sky launch serving.yaml --env HF_TOKEN ``` @@ -83,7 +83,7 @@ Check the output of the command. There will be a shareable gradio link (like the **Optional**: Serve the 70B model instead of the default 8B and use more GPU: -```console +```bash HF_TOKEN="your-huggingface-token" \ sky launch serving.yaml \ --gpus A100:8 \ @@ -159,7 +159,7 @@ SkyPilot can scale up the service to multiple service replicas with built-in aut Start the serving the Llama-3 8B model on multiple replicas: -```console +```bash HF_TOKEN="your-huggingface-token" \ sky serve up -n vllm serving.yaml \ --env HF_TOKEN @@ -167,7 +167,7 @@ HF_TOKEN="your-huggingface-token" \ Wait until the service is ready: -```console +```bash watch -n10 sky serve status vllm ``` @@ -271,13 +271,13 @@ This will scale the service up to when the QPS exceeds 2 for each replica. 
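Once the autoscaled `vllm` service is up, every replica exposes the same OpenAI-compatible API, so it can be queried with the official `openai` Python client as well as with `curl`. The snippet below is a minimal sketch only: the endpoint placeholder should be replaced with the value printed by `sky serve status --endpoint vllm`, and the model name is an assumed example (use whatever model your `serving.yaml` actually serves).

```python
# Minimal sketch: query the SkyPilot-served vLLM replicas with the OpenAI SDK.
# SERVICE_ENDPOINT and MODEL are placeholders -- substitute your own values.
from openai import OpenAI

SERVICE_ENDPOINT = "<ip>:<port>"  # from `sky serve status --endpoint vllm`
MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"  # assumed example model

client = OpenAI(base_url=f"http://{SERVICE_ENDPOINT}/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model=MODEL,
    messages=[{"role": "user", "content": "Who are you?"}],
)
print(response.choices[0].message.content)
```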
To update the service with the new config: -```console +```bash HF_TOKEN="your-huggingface-token" sky serve update vllm serving.yaml --env HF_TOKEN ``` To stop the service: -```console +```bash sky serve down vllm ``` @@ -317,7 +317,7 @@ It is also possible to access the Llama-3 service with a separate GUI frontend, 1. Start the chat web UI: - ```console + ```bash sky launch \ -c gui ./gui.yaml \ --env ENDPOINT=$(sky serve status --endpoint vllm) diff --git a/docs/deployment/frameworks/streamlit.md b/docs/deployment/frameworks/streamlit.md index 33ed8c5f5b54..5e998e3cca6e 100644 --- a/docs/deployment/frameworks/streamlit.md +++ b/docs/deployment/frameworks/streamlit.md @@ -15,13 +15,13 @@ It can be quickly integrated with vLLM as a backend API server, enabling powerfu - Start the vLLM server with the supported chat completion model, e.g. -```console +```bash vllm serve qwen/Qwen1.5-0.5B-Chat ``` - Install streamlit and openai: -```console +```bash pip install streamlit openai ``` @@ -29,7 +29,7 @@ pip install streamlit openai - Start the streamlit web UI and start to chat: -```console +```bash streamlit run streamlit_openai_chatbot_webserver.py # or specify the VLLM_API_BASE or VLLM_API_KEY diff --git a/docs/deployment/integrations/llamastack.md b/docs/deployment/integrations/llamastack.md index 2ae600a423ff..9bbc6b5b296c 100644 --- a/docs/deployment/integrations/llamastack.md +++ b/docs/deployment/integrations/llamastack.md @@ -7,7 +7,7 @@ vLLM is also available via [Llama Stack](https://github.com/meta-llama/llama-sta To install Llama Stack, run -```console +```bash pip install llama-stack -q ``` diff --git a/docs/deployment/k8s.md b/docs/deployment/k8s.md index 13225ba208fd..f01e3d2fae0e 100644 --- a/docs/deployment/k8s.md +++ b/docs/deployment/k8s.md @@ -115,7 +115,7 @@ Next, start the vLLM server as a Kubernetes Deployment and Service: We can verify that the vLLM server has started successfully via the logs (this might take a couple of minutes to download the model): -```console +```bash kubectl logs -l app.kubernetes.io/name=vllm ... INFO: Started server process [1] @@ -358,14 +358,14 @@ INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) Apply the deployment and service configurations using `kubectl apply -f `: - ```console + ```bash kubectl apply -f deployment.yaml kubectl apply -f service.yaml ``` To test the deployment, run the following `curl` command: - ```console + ```bash curl http://mistral-7b.default.svc.cluster.local/v1/completions \ -H "Content-Type: application/json" \ -d '{ diff --git a/docs/deployment/nginx.md b/docs/deployment/nginx.md index 752be76b3864..7f09453be0c4 100644 --- a/docs/deployment/nginx.md +++ b/docs/deployment/nginx.md @@ -11,13 +11,13 @@ This document shows how to launch multiple vLLM serving containers and use Nginx This guide assumes that you have just cloned the vLLM project and you're currently in the vllm root directory. -```console +```bash export vllm_root=`pwd` ``` Create a file named `Dockerfile.nginx`: -```console +```dockerfile FROM nginx:latest RUN rm /etc/nginx/conf.d/default.conf EXPOSE 80 @@ -26,7 +26,7 @@ CMD ["nginx", "-g", "daemon off;"] Build the container: -```console +```bash docker build . -f Dockerfile.nginx --tag nginx-lb ``` @@ -60,14 +60,14 @@ Create a file named `nginx_conf/nginx.conf`. Note that you can add as many serve ## Build vLLM Container -```console +```bash cd $vllm_root docker build -f docker/Dockerfile . 
--tag vllm ``` If you are behind proxy, you can pass the proxy settings to the docker build command as shown below: -```console +```bash cd $vllm_root docker build \ -f docker/Dockerfile . \ @@ -80,7 +80,7 @@ docker build \ ## Create Docker Network -```console +```bash docker network create vllm_nginx ``` @@ -129,7 +129,7 @@ Notes: ## Launch Nginx -```console +```bash docker run \ -itd \ -p 8000:80 \ @@ -142,7 +142,7 @@ docker run \ ## Verify That vLLM Servers Are Ready -```console +```bash docker logs vllm0 | grep Uvicorn docker logs vllm1 | grep Uvicorn ``` diff --git a/docs/features/multimodal_inputs.md b/docs/features/multimodal_inputs.md index d4465beb8593..e3a77afb02f1 100644 --- a/docs/features/multimodal_inputs.md +++ b/docs/features/multimodal_inputs.md @@ -307,7 +307,7 @@ Full example: ``` @@ -370,7 +370,7 @@ Full example: ``` @@ -476,7 +476,7 @@ Full example: ``` diff --git a/docs/features/quantization/auto_awq.md b/docs/features/quantization/auto_awq.md index 8362672f40b3..9f97ea406e25 100644 --- a/docs/features/quantization/auto_awq.md +++ b/docs/features/quantization/auto_awq.md @@ -9,7 +9,7 @@ The main benefits are lower latency and memory usage. You can quantize your own models by installing AutoAWQ or picking one of the [6500+ models on Huggingface](https://huggingface.co/models?search=awq). -```console +```bash pip install autoawq ``` @@ -43,7 +43,7 @@ After installing AutoAWQ, you are ready to quantize a model. Please refer to the To run an AWQ model with vLLM, you can use [TheBloke/Llama-2-7b-Chat-AWQ](https://huggingface.co/TheBloke/Llama-2-7b-Chat-AWQ) with the following command: -```console +```bash python examples/offline_inference/llm_engine_example.py \ --model TheBloke/Llama-2-7b-Chat-AWQ \ --quantization awq diff --git a/docs/features/quantization/bitblas.md b/docs/features/quantization/bitblas.md index 3f8ae7a959cd..c8f874ff8414 100644 --- a/docs/features/quantization/bitblas.md +++ b/docs/features/quantization/bitblas.md @@ -12,7 +12,7 @@ vLLM now supports [BitBLAS](https://github.com/microsoft/BitBLAS) for more effic Below are the steps to utilize BitBLAS with vLLM. -```console +```bash pip install bitblas>=0.1.0 ``` diff --git a/docs/features/quantization/bnb.md b/docs/features/quantization/bnb.md index a8dc2476f30a..5756fdb28837 100644 --- a/docs/features/quantization/bnb.md +++ b/docs/features/quantization/bnb.md @@ -9,7 +9,7 @@ Compared to other quantization methods, BitsAndBytes eliminates the need for cal Below are the steps to utilize BitsAndBytes with vLLM. 
-```console +```bash pip install bitsandbytes>=0.45.3 ``` @@ -54,6 +54,6 @@ llm = LLM( Append the following to your model arguments for 4bit inflight quantization: -```console +```bash --quantization bitsandbytes ``` diff --git a/docs/features/quantization/fp8.md b/docs/features/quantization/fp8.md index ec7639af805b..b9ed668b2ef3 100644 --- a/docs/features/quantization/fp8.md +++ b/docs/features/quantization/fp8.md @@ -23,7 +23,7 @@ The FP8 types typically supported in hardware have two distinct representations, To produce performant FP8 quantized models with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library: -```console +```bash pip install llmcompressor ``` @@ -81,7 +81,7 @@ Since simple RTN does not require data for weight quantization and the activatio Install `vllm` and `lm-evaluation-harness` for evaluation: -```console +```bash pip install vllm lm-eval==0.4.4 ``` @@ -99,9 +99,9 @@ Evaluate accuracy with `lm_eval` (for example on 250 samples of `gsm8k`): !!! note Quantized models can be sensitive to the presence of the `bos` token. `lm_eval` does not add a `bos` token by default, so make sure to include the `add_bos_token=True` argument when running your evaluations. -```console -$ MODEL=$PWD/Meta-Llama-3-8B-Instruct-FP8-Dynamic -$ lm_eval \ +```bash +MODEL=$PWD/Meta-Llama-3-8B-Instruct-FP8-Dynamic +lm_eval \ --model vllm \ --model_args pretrained=$MODEL,add_bos_token=True \ --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 250 diff --git a/docs/features/quantization/gguf.md b/docs/features/quantization/gguf.md index 014b513eeda7..102a3ee1cccc 100644 --- a/docs/features/quantization/gguf.md +++ b/docs/features/quantization/gguf.md @@ -11,7 +11,7 @@ title: GGUF To run a GGUF model with vLLM, you can download and use the local GGUF model from [TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF](https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF) with the following command: -```console +```bash wget https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf # We recommend using the tokenizer from base model to avoid long-time and buggy tokenizer conversion. vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf \ @@ -20,7 +20,7 @@ vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf \ You can also add `--tensor-parallel-size 2` to enable tensor parallelism inference with 2 GPUs: -```console +```bash # We recommend using the tokenizer from base model to avoid long-time and buggy tokenizer conversion. vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf \ --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ @@ -32,7 +32,7 @@ vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf \ GGUF assumes that huggingface can convert the metadata to a config file. In case huggingface doesn't support your model you can manually create a config and pass it as hf-config-path -```console +```bash # If you model is not supported by huggingface you can manually provide a huggingface compatible config path vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf \ --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ diff --git a/docs/features/quantization/gptqmodel.md b/docs/features/quantization/gptqmodel.md index 2f088f474f19..37bb02d4fb5b 100644 --- a/docs/features/quantization/gptqmodel.md +++ b/docs/features/quantization/gptqmodel.md @@ -21,7 +21,7 @@ for more details on this and other advanced features. 
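The GGUF checkpoint shown above can also be loaded through vLLM's offline `LLM` entrypoint instead of `vllm serve`. The following is a hedged sketch that reuses the same file path and base-model tokenizer as the serving example; adjust both to match what you actually downloaded.

```python
# Hedged sketch: offline inference with the GGUF file downloaded above.
# The path and tokenizer mirror the `vllm serve` example; adjust as needed.
from vllm import LLM, SamplingParams

llm = LLM(
    model="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
    # Use the base model's tokenizer to avoid the slow GGUF tokenizer conversion.
    tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
)

sampling = SamplingParams(temperature=0.8, max_tokens=64)
outputs = llm.generate(["The capital of France is"], sampling)
print(outputs[0].outputs[0].text)
```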
You can quantize your own models by installing [GPTQModel](https://github.com/ModelCloud/GPTQModel) or picking one of the [5000+ models on Huggingface](https://huggingface.co/models?search=gptq). -```console +```bash pip install -U gptqmodel --no-build-isolation -v ``` @@ -60,7 +60,7 @@ Here is an example of how to quantize `meta-llama/Llama-3.2-1B-Instruct`: To run an GPTQModel quantized model with vLLM, you can use [DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2](https://huggingface.co/ModelCloud/DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2) with the following command: -```console +```bash python examples/offline_inference/llm_engine_example.py \ --model ModelCloud/DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2 ``` diff --git a/docs/features/quantization/int4.md b/docs/features/quantization/int4.md index 185e13649f48..2008bef5c8a2 100644 --- a/docs/features/quantization/int4.md +++ b/docs/features/quantization/int4.md @@ -14,13 +14,13 @@ Please visit the HF collection of [quantized INT4 checkpoints of popular LLMs re To use INT4 quantization with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library: -```console +```bash pip install llmcompressor ``` Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: -```console +```bash pip install vllm lm-eval==0.4.4 ``` @@ -116,8 +116,8 @@ model = LLM("./Meta-Llama-3-8B-Instruct-W4A16-G128") To evaluate accuracy, you can use `lm_eval`: -```console -$ lm_eval --model vllm \ +```bash +lm_eval --model vllm \ --model_args pretrained="./Meta-Llama-3-8B-Instruct-W4A16-G128",add_bos_token=true \ --tasks gsm8k \ --num_fewshot 5 \ diff --git a/docs/features/quantization/int8.md b/docs/features/quantization/int8.md index de5ae5c04401..3a8f855aa057 100644 --- a/docs/features/quantization/int8.md +++ b/docs/features/quantization/int8.md @@ -15,13 +15,13 @@ Please visit the HF collection of [quantized INT8 checkpoints of popular LLMs re To use INT8 quantization with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library: -```console +```bash pip install llmcompressor ``` Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: -```console +```bash pip install vllm lm-eval==0.4.4 ``` @@ -122,8 +122,8 @@ model = LLM("./Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Per-Token") To evaluate accuracy, you can use `lm_eval`: -```console -$ lm_eval --model vllm \ +```bash +lm_eval --model vllm \ --model_args pretrained="./Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Per-Token",add_bos_token=true \ --tasks gsm8k \ --num_fewshot 5 \ diff --git a/docs/features/quantization/modelopt.md b/docs/features/quantization/modelopt.md index 0bb6003832ba..39f2a78e705f 100644 --- a/docs/features/quantization/modelopt.md +++ b/docs/features/quantization/modelopt.md @@ -4,7 +4,7 @@ The [NVIDIA TensorRT Model Optimizer](https://github.com/NVIDIA/TensorRT-Model-O We recommend installing the library with: -```console +```bash pip install nvidia-modelopt ``` diff --git a/docs/features/quantization/quantized_kvcache.md b/docs/features/quantization/quantized_kvcache.md index 52b8d38ace1d..323dcb7d052d 100644 --- a/docs/features/quantization/quantized_kvcache.md +++ b/docs/features/quantization/quantized_kvcache.md @@ -65,7 +65,7 @@ For optimal model quality when using FP8 KV Cache, we recommend using calibrated First, install the required dependencies: -```console +```bash pip install llmcompressor ``` diff --git 
a/docs/features/quantization/quark.md b/docs/features/quantization/quark.md index 6e77584da232..77e383495406 100644 --- a/docs/features/quantization/quark.md +++ b/docs/features/quantization/quark.md @@ -13,7 +13,7 @@ AWQ, GPTQ, Rotation and SmoothQuant. Before quantizing models, you need to install Quark. The latest release of Quark can be installed with pip: -```console +```bash pip install amd-quark ``` @@ -22,13 +22,13 @@ for more installation details. Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: -```console +```bash pip install vllm lm-eval==0.4.4 ``` ## Quantization Process -After installing Quark, we will use an example to illustrate how to use Quark. +After installing Quark, we will use an example to illustrate how to use Quark. The Quark quantization process can be listed for 5 steps as below: 1. Load the model @@ -209,8 +209,8 @@ Now, you can load and run the Quark quantized model directly through the LLM ent Or, you can use `lm_eval` to evaluate accuracy: -```console -$ lm_eval --model vllm \ +```bash +lm_eval --model vllm \ --model_args pretrained=Llama-2-70b-chat-hf-w-fp8-a-fp8-kvcache-fp8-pertensor-autosmoothquant,kv_cache_dtype='fp8',quantization='quark' \ --tasks gsm8k ``` @@ -222,7 +222,7 @@ to quantize large language models more conveniently. It supports quantizing mode of different quantization schemes and optimization algorithms. It can export the quantized model and run evaluation tasks on the fly. With the script, the example above can be: -```console +```bash python3 quantize_quark.py --model_dir meta-llama/Llama-2-70b-chat-hf \ --output_dir /path/to/output \ --quant_scheme w_fp8_a_fp8 \ diff --git a/docs/features/quantization/torchao.md b/docs/features/quantization/torchao.md index c45979a36117..f8df3c4b0809 100644 --- a/docs/features/quantization/torchao.md +++ b/docs/features/quantization/torchao.md @@ -4,7 +4,7 @@ TorchAO is an architecture optimization library for PyTorch, it provides high pe We recommend installing the latest torchao nightly with -```console +```bash # Install the latest TorchAO nightly build # Choose the CUDA version that matches your system (cu126, cu128, etc.) pip install \ diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md index 9fb878777a48..41a024ba632e 100644 --- a/docs/features/tool_calling.md +++ b/docs/features/tool_calling.md @@ -351,7 +351,7 @@ Here is a summary of a plugin file: Then you can use this plugin in the command line like this. -```console +```bash --enable-auto-tool-choice \ --tool-parser-plugin --tool-call-parser example \ diff --git a/docs/getting_started/installation/aws_neuron.md b/docs/getting_started/installation/aws_neuron.md index 6b2efd85f06b..b8bd76bd5bcb 100644 --- a/docs/getting_started/installation/aws_neuron.md +++ b/docs/getting_started/installation/aws_neuron.md @@ -26,7 +26,7 @@ The easiest way to launch a Trainium or Inferentia instance with pre-installed N - After launching the instance, follow the instructions in [Connect to your instance](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/AccessingInstancesLinux.html) to connect to the instance - Once inside your instance, activate the pre-installed virtual environment for inference by running -```console +```bash source /opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/bin/activate ``` @@ -47,7 +47,7 @@ Currently, there are no pre-built Neuron wheels. 
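A server launched with the tool-calling flags shown earlier (`--enable-auto-tool-choice` together with a tool parser) can then be exercised with any OpenAI-compatible client. The sketch below is illustrative only: the model name and the example `get_weather` function are assumptions, not part of vLLM itself.

```python
# Illustrative sketch: send a tool-enabled chat request to a vLLM server that
# was started with --enable-auto-tool-choice and a tool parser.
# MODEL and the get_weather tool definition are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",  # assumed example model
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
    tool_choice="auto",
)
print(response.choices[0].message.tool_calls)
```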
To build and install vLLM from source, run: -```console +```bash git clone https://github.com/vllm-project/vllm.git cd vllm pip install -U -r requirements/neuron.txt @@ -66,7 +66,7 @@ Refer to [vLLM User Guide for NxD Inference](https://awsdocs-neuron.readthedocs- To install the AWS Neuron fork, run the following: -```console +```bash git clone -b neuron-2.23-vllm-v0.7.2 https://github.com/aws-neuron/upstreaming-to-vllm.git cd upstreaming-to-vllm pip install -r requirements/neuron.txt @@ -100,7 +100,7 @@ to perform most of the heavy lifting which includes PyTorch model initialization To configure NxD Inference features through the vLLM entrypoint, use the `override_neuron_config` setting. Provide the configs you want to override as a dictionary (or JSON object when starting vLLM from the CLI). For example, to disable auto bucketing, include -```console +```python override_neuron_config={ "enable_bucketing":False, } @@ -108,7 +108,7 @@ override_neuron_config={ or when launching vLLM from the CLI, pass -```console +```bash --override-neuron-config "{\"enable_bucketing\":false}" ``` diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index 5d7019e5a867..370b854def0f 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -78,13 +78,13 @@ Currently, there are no pre-built CPU wheels. ??? Commands - ```console - $ docker build -f docker/Dockerfile.cpu \ + ```bash + docker build -f docker/Dockerfile.cpu \ --tag vllm-cpu-env \ --target vllm-openai . - # Launching OpenAI server - $ docker run --rm \ + # Launching OpenAI server + docker run --rm \ --privileged=true \ --shm-size=4g \ -p 8000:8000 \ @@ -123,7 +123,7 @@ vLLM CPU backend supports the following vLLM features: - We highly recommend to use TCMalloc for high performance memory allocation and better cache locality. For example, on Ubuntu 22.4, you can run: -```console +```bash sudo apt-get install libtcmalloc-minimal4 # install TCMalloc library find / -name *libtcmalloc* # find the dynamic link library path export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD # prepend the library to LD_PRELOAD @@ -132,7 +132,7 @@ python examples/offline_inference/basic/basic.py # run vLLM - When using the online serving, it is recommended to reserve 1-2 CPU cores for the serving framework to avoid CPU oversubscription. For example, on a platform with 32 physical CPU cores, reserving CPU 30 and 31 for the framework and using CPU 0-29 for OpenMP: -```console +```bash export VLLM_CPU_KVCACHE_SPACE=40 export VLLM_CPU_OMP_THREADS_BIND=0-29 vllm serve facebook/opt-125m @@ -140,7 +140,7 @@ vllm serve facebook/opt-125m or using default auto thread binding: -```console +```bash export VLLM_CPU_KVCACHE_SPACE=40 export VLLM_CPU_NUM_OF_RESERVED_CPU=2 vllm serve facebook/opt-125m @@ -189,7 +189,7 @@ vllm serve facebook/opt-125m - Tensor Parallel is supported for serving and offline inferencing. In general each NUMA node is treated as one GPU card. 
Below is the example script to enable Tensor Parallel = 2 for serving: - ```console + ```bash VLLM_CPU_KVCACHE_SPACE=40 VLLM_CPU_OMP_THREADS_BIND="0-31|32-63" \ vllm serve meta-llama/Llama-2-7b-chat-hf \ -tp=2 \ @@ -198,7 +198,7 @@ vllm serve facebook/opt-125m or using default auto thread binding: - ```console + ```bash VLLM_CPU_KVCACHE_SPACE=40 \ vllm serve meta-llama/Llama-2-7b-chat-hf \ -tp=2 \ diff --git a/docs/getting_started/installation/cpu/apple.inc.md b/docs/getting_started/installation/cpu/apple.inc.md index 7a91e3ce5e5b..1771213f5591 100644 --- a/docs/getting_started/installation/cpu/apple.inc.md +++ b/docs/getting_started/installation/cpu/apple.inc.md @@ -25,11 +25,11 @@ Currently the CPU implementation for macOS supports FP32 and FP16 datatypes. After installation of XCode and the Command Line Tools, which include Apple Clang, execute the following commands to build and install vLLM from the source. -```console +```bash git clone https://github.com/vllm-project/vllm.git cd vllm pip install -r requirements/cpu.txt -pip install -e . +pip install -e . ``` !!! note diff --git a/docs/getting_started/installation/cpu/arm.inc.md b/docs/getting_started/installation/cpu/arm.inc.md index 59b71dcaf911..6c05900cf45c 100644 --- a/docs/getting_started/installation/cpu/arm.inc.md +++ b/docs/getting_started/installation/cpu/arm.inc.md @@ -23,7 +23,7 @@ ARM CPU backend currently supports Float32, FP16 and BFloat16 datatypes. # --8<-- [end:pre-built-wheels] # --8<-- [start:build-wheel-from-source] ---8<-- "docs/getting_started/installation/cpu/cpu/build.inc.md" +--8<-- "docs/getting_started/installation/cpu/build.inc.md" Testing has been conducted on AWS Graviton3 instances for compatibility. diff --git a/docs/getting_started/installation/cpu/build.inc.md b/docs/getting_started/installation/cpu/build.inc.md index 7ddadccb1b4f..d9ca04edee02 100644 --- a/docs/getting_started/installation/cpu/build.inc.md +++ b/docs/getting_started/installation/cpu/build.inc.md @@ -1,6 +1,6 @@ First, install recommended compiler. We recommend to use `gcc/g++ >= 12.3.0` as the default compiler to avoid potential problems. For example, on Ubuntu 22.4, you can run: -```console +```bash sudo apt-get update -y sudo apt-get install -y gcc-12 g++-12 libnuma-dev python3-dev sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 @@ -8,14 +8,14 @@ sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave / Second, clone vLLM project: -```console +```bash git clone https://github.com/vllm-project/vllm.git vllm_source cd vllm_source ``` Third, install Python packages for vLLM CPU backend building: -```console +```bash pip install --upgrade pip pip install "cmake>=3.26.1" wheel packaging ninja "setuptools-scm>=8" numpy pip install -v -r requirements/cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu @@ -23,13 +23,13 @@ pip install -v -r requirements/cpu.txt --extra-index-url https://download.pytorc Finally, build and install vLLM CPU backend: -```console +```bash VLLM_TARGET_DEVICE=cpu python setup.py install ``` If you want to develop vllm, install it in editable mode instead. 
-```console +```bash VLLM_TARGET_DEVICE=cpu python setup.py develop ``` diff --git a/docs/getting_started/installation/cpu/s390x.inc.md b/docs/getting_started/installation/cpu/s390x.inc.md index 670485feefb6..6c6c40baecec 100644 --- a/docs/getting_started/installation/cpu/s390x.inc.md +++ b/docs/getting_started/installation/cpu/s390x.inc.md @@ -26,7 +26,7 @@ Currently the CPU implementation for s390x architecture supports FP32 datatype o Install the following packages from the package manager before building the vLLM. For example on RHEL 9.4: -```console +```bash dnf install -y \ which procps findutils tar vim git gcc g++ make patch make cython zlib-devel \ libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \ @@ -35,7 +35,7 @@ dnf install -y \ Install rust>=1.80 which is needed for `outlines-core` and `uvloop` python packages installation. -```console +```bash curl https://sh.rustup.rs -sSf | sh -s -- -y && \ . "$HOME/.cargo/env" ``` @@ -45,7 +45,7 @@ Execute the following commands to build and install vLLM from the source. !!! tip Please build the following dependencies, `torchvision`, `pyarrow` from the source before building vLLM. -```console +```bash sed -i '/^torch/d' requirements-build.txt # remove torch from requirements-build.txt since we use nightly builds pip install -v \ --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ diff --git a/docs/getting_started/installation/cpu/x86.inc.md b/docs/getting_started/installation/cpu/x86.inc.md index 9434eeea8b4a..0412d4ccef00 100644 --- a/docs/getting_started/installation/cpu/x86.inc.md +++ b/docs/getting_started/installation/cpu/x86.inc.md @@ -24,7 +24,7 @@ vLLM initially supports basic model inferencing and serving on x86 CPU platform, # --8<-- [end:pre-built-wheels] # --8<-- [start:build-wheel-from-source] ---8<-- "docs/getting_started/installation/cpu/cpu/build.inc.md" +--8<-- "docs/getting_started/installation/cpu/build.inc.md" !!! note - AVX512_BF16 is an extension ISA provides native BF16 data type conversion and vector product instructions, which brings some performance improvement compared with pure AVX512. The CPU backend build script will check the host CPU flags to determine whether to enable AVX512_BF16. diff --git a/docs/getting_started/installation/google_tpu.md b/docs/getting_started/installation/google_tpu.md index 0cb10b8de835..a81a19df38b0 100644 --- a/docs/getting_started/installation/google_tpu.md +++ b/docs/getting_started/installation/google_tpu.md @@ -68,7 +68,7 @@ For more information about using TPUs with GKE, see: Create a TPU v5e with 4 TPU chips: -```console +```bash gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --node-id TPU_NAME \ --project PROJECT_ID \ @@ -156,13 +156,13 @@ See [deployment-docker-pre-built-image][deployment-docker-pre-built-image] for i You can use to build a Docker image with TPU support. -```console +```bash docker build -f docker/Dockerfile.tpu -t vllm-tpu . ``` Run the Docker image with the following command: -```console +```bash # Make sure to add `--privileged --net host --shm-size=16G`. 
docker run --privileged --net host --shm-size=16G -it vllm-tpu ``` @@ -185,6 +185,6 @@ docker run --privileged --net host --shm-size=16G -it vllm-tpu Install OpenBLAS with the following command: - ```console + ```bash sudo apt-get install --no-install-recommends --yes libopenblas-base libopenmpi-dev libomp-dev ``` diff --git a/docs/getting_started/installation/gpu/cuda.inc.md b/docs/getting_started/installation/gpu/cuda.inc.md index 4503bb443188..0417a25f85ad 100644 --- a/docs/getting_started/installation/gpu/cuda.inc.md +++ b/docs/getting_started/installation/gpu/cuda.inc.md @@ -22,7 +22,7 @@ Therefore, it is recommended to install vLLM with a **fresh new** environment. I You can install vLLM using either `pip` or `uv pip`: -```console +```bash # Install vLLM with CUDA 12.8. # If you are using pip. pip install vllm --extra-index-url https://download.pytorch.org/whl/cu128 @@ -37,7 +37,7 @@ We recommend leveraging `uv` to [automatically select the appropriate PyTorch in As of now, vLLM's binaries are compiled with CUDA 12.8 and public PyTorch release versions by default. We also provide vLLM binaries compiled with CUDA 12.6, 11.8, and public PyTorch release versions: -```console +```bash # Install vLLM with CUDA 11.8. export VLLM_VERSION=0.6.1.post1 export PYTHON_VERSION=312 @@ -52,7 +52,7 @@ LLM inference is a fast-evolving field, and the latest code may contain bug fixe ##### Install the latest code using `pip` -```console +```bash pip install -U vllm \ --pre \ --extra-index-url https://wheels.vllm.ai/nightly @@ -62,7 +62,7 @@ pip install -U vllm \ Another way to install the latest code is to use `uv`: -```console +```bash uv pip install -U vllm \ --torch-backend=auto \ --extra-index-url https://wheels.vllm.ai/nightly @@ -72,7 +72,7 @@ uv pip install -U vllm \ If you want to access the wheels for previous commits (e.g. to bisect the behavior change, performance regression), due to the limitation of `pip`, you have to specify the full URL of the wheel file by embedding the commit hash in the URL: -```console +```bash export VLLM_COMMIT=33f460b17a54acb3b6cc0b03f4a17876cff5eafd # use full commit hash from the main branch pip install https://wheels.vllm.ai/${VLLM_COMMIT}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl ``` @@ -83,7 +83,7 @@ Note that the wheels are built with Python 3.8 ABI (see [PEP 425](https://peps.p If you want to access the wheels for previous commits (e.g. to bisect the behavior change, performance regression), you can specify the commit hash in the URL: -```console +```bash export VLLM_COMMIT=72d9c316d3f6ede485146fe5aabd4e61dbc59069 # use full commit hash from the main branch uv pip install vllm \ --torch-backend=auto \ @@ -99,7 +99,7 @@ The `uv` approach works for vLLM `v0.6.6` and later and offers an easy-to-rememb If you only need to change Python code, you can build and install vLLM without compilation. Using `pip`'s [`--editable` flag](https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs), changes you make to the code will be reflected when you run vLLM: -```console +```bash git clone https://github.com/vllm-project/vllm.git cd vllm VLLM_USE_PRECOMPILED=1 pip install --editable . @@ -118,7 +118,7 @@ This command will do the following: In case you see an error about wheel not found when running the above command, it might be because the commit you based on in the main branch was just merged and the wheel is being built. 
In this case, you can wait for around an hour to try again, or manually assign the previous commit in the installation using the `VLLM_PRECOMPILED_WHEEL_LOCATION` environment variable. -```console +```bash export VLLM_COMMIT=72d9c316d3f6ede485146fe5aabd4e61dbc59069 # use full commit hash from the main branch export VLLM_PRECOMPILED_WHEEL_LOCATION=https://wheels.vllm.ai/${VLLM_COMMIT}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl pip install --editable . @@ -134,7 +134,7 @@ You can find more information about vLLM's wheels in [install-the-latest-code][i If you want to modify C++ or CUDA code, you'll need to build vLLM from source. This can take several minutes: -```console +```bash git clone https://github.com/vllm-project/vllm.git cd vllm pip install -e . @@ -151,6 +151,9 @@ pip install -e . [sccache](https://github.com/mozilla/sccache) works similarly to `ccache`, but has the capability to utilize caching in remote storage environments. The following environment variables can be set to configure the vLLM `sccache` remote: `SCCACHE_BUCKET=vllm-build-sccache SCCACHE_REGION=us-west-2 SCCACHE_S3_NO_CREDENTIALS=1`. We also recommend setting `SCCACHE_IDLE_TIMEOUT=0`. +!!! note "Faster Kernel Development" + For frequent C++/CUDA kernel changes, after the initial `pip install -e .` setup, consider using the [Incremental Compilation Workflow](../../contributing/incremental_build.md) for significantly faster rebuilds of only the modified kernel code. + ##### Use an existing PyTorch installation There are scenarios where the PyTorch dependency cannot be easily installed via pip, e.g.: @@ -160,7 +163,7 @@ There are scenarios where the PyTorch dependency cannot be easily installed via To build vLLM using an existing PyTorch installation: -```console +```bash git clone https://github.com/vllm-project/vllm.git cd vllm python use_existing_torch.py @@ -173,7 +176,7 @@ pip install --no-build-isolation -e . Currently, before starting the build process, vLLM fetches cutlass code from GitHub. However, there may be scenarios where you want to use a local version of cutlass instead. To achieve this, you can set the environment variable VLLM_CUTLASS_SRC_DIR to point to your local cutlass directory. -```console +```bash git clone https://github.com/vllm-project/vllm.git cd vllm VLLM_CUTLASS_SRC_DIR=/path/to/cutlass pip install -e . @@ -184,7 +187,7 @@ VLLM_CUTLASS_SRC_DIR=/path/to/cutlass pip install -e . To avoid your system being overloaded, you can limit the number of compilation jobs to be run simultaneously, via the environment variable `MAX_JOBS`. For example: -```console +```bash export MAX_JOBS=6 pip install -e . ``` @@ -194,7 +197,7 @@ A side effect is a much slower build process. Additionally, if you have trouble building vLLM, we recommend using the NVIDIA PyTorch Docker image. -```console +```bash # Use `--ipc=host` to make sure the shared memory is large enough. docker run \ --gpus all \ @@ -205,14 +208,14 @@ docker run \ If you don't want to use docker, it is recommended to have a full installation of CUDA Toolkit. You can download and install it from [the official website](https://developer.nvidia.com/cuda-toolkit-archive). 
After installation, set the environment variable `CUDA_HOME` to the installation path of CUDA Toolkit, and make sure that the `nvcc` compiler is in your `PATH`, e.g.: -```console +```bash export CUDA_HOME=/usr/local/cuda export PATH="${CUDA_HOME}/bin:$PATH" ``` Here is a sanity check to verify that the CUDA Toolkit is correctly installed: -```console +```bash nvcc --version # verify that nvcc is in your PATH ${CUDA_HOME}/bin/nvcc --version # verify that nvcc is in your CUDA_HOME ``` @@ -223,7 +226,7 @@ vLLM can fully run only on Linux but for development purposes, you can still bui Simply disable the `VLLM_TARGET_DEVICE` environment variable before installing: -```console +```bash export VLLM_TARGET_DEVICE=empty pip install -e . ``` @@ -238,7 +241,7 @@ See [deployment-docker-pre-built-image][deployment-docker-pre-built-image] for i Another way to access the latest code is to use the docker images: -```console +```bash export VLLM_COMMIT=33f460b17a54acb3b6cc0b03f4a17876cff5eafd # use full commit hash from the main branch docker pull public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:${VLLM_COMMIT} ``` diff --git a/docs/getting_started/installation/gpu/rocm.inc.md b/docs/getting_started/installation/gpu/rocm.inc.md index 6bc714fe6e8b..aa4cacaf1aed 100644 --- a/docs/getting_started/installation/gpu/rocm.inc.md +++ b/docs/getting_started/installation/gpu/rocm.inc.md @@ -31,17 +31,17 @@ Currently, there are no pre-built ROCm wheels. Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch [Getting Started](https://pytorch.org/get-started/locally/). Example: - ```console + ```bash # Install PyTorch - $ pip uninstall torch -y - $ pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.3 + pip uninstall torch -y + pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.3 ``` 1. Install [Triton flash attention for ROCm](https://github.com/ROCm/triton) Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from [ROCm/triton](https://github.com/ROCm/triton/blob/triton-mlir/README.md) - ```console + ```bash python3 -m pip install ninja cmake wheel pybind11 pip uninstall -y triton git clone https://github.com/OpenAI/triton.git @@ -62,7 +62,7 @@ Currently, there are no pre-built ROCm wheels. For example, for ROCm 6.3, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`. - ```console + ```bash git clone https://github.com/ROCm/flash-attention.git cd flash-attention git checkout b7d29fb @@ -76,7 +76,7 @@ Currently, there are no pre-built ROCm wheels. 3. If you choose to build AITER yourself to use a certain branch or commit, you can build AITER using the following steps: - ```console + ```bash python3 -m pip uninstall -y aiter git clone --recursive https://github.com/ROCm/aiter.git cd aiter @@ -148,7 +148,7 @@ If you choose to build this rocm_base image yourself, the steps are as follows. It is important that the user kicks off the docker build using buildkit. Either the user put DOCKER_BUILDKIT=1 as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: -```console +```json { "features": { "buildkit": true @@ -158,7 +158,7 @@ It is important that the user kicks off the docker build using buildkit. 
Either To build vllm on ROCm 6.3 for MI200 and MI300 series, you can use the default: -```console +```bash DOCKER_BUILDKIT=1 docker build \ -f docker/Dockerfile.rocm_base \ -t rocm/vllm-dev:base . @@ -169,7 +169,7 @@ DOCKER_BUILDKIT=1 docker build \ First, build a docker image from and launch a docker container from the image. It is important that the user kicks off the docker build using buildkit. Either the user put `DOCKER_BUILDKIT=1` as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: -```console +```bash { "features": { "buildkit": true @@ -187,13 +187,13 @@ Their values can be passed in when running `docker build` with `--build-arg` opt To build vllm on ROCm 6.3 for MI200 and MI300 series, you can use the default: -```console +```bash DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile.rocm -t vllm-rocm . ``` To build vllm on ROCm 6.3 for Radeon RX7900 series (gfx1100), you should pick the alternative base image: -```console +```bash DOCKER_BUILDKIT=1 docker build \ --build-arg BASE_IMAGE="rocm/vllm-dev:navi_base" \ -f docker/Dockerfile.rocm \ @@ -205,7 +205,7 @@ To run the above docker image `vllm-rocm`, use the below command: ??? Command - ```console + ```bash docker run -it \ --network=host \ --group-add=video \ diff --git a/docs/getting_started/installation/gpu/xpu.inc.md b/docs/getting_started/installation/gpu/xpu.inc.md index 128fff164c3a..4469be36c007 100644 --- a/docs/getting_started/installation/gpu/xpu.inc.md +++ b/docs/getting_started/installation/gpu/xpu.inc.md @@ -22,10 +22,10 @@ Currently, there are no pre-built XPU wheels. # --8<-- [end:pre-built-wheels] # --8<-- [start:build-wheel-from-source] -- First, install required driver and Intel OneAPI 2025.0 or later. +- First, install required [driver](https://dgpu-docs.intel.com/driver/installation.html#installing-gpu-drivers) and [Intel OneAPI](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) 2025.0 or later. - Second, install Python packages for vLLM XPU backend building: -```console +```bash git clone https://github.com/vllm-project/vllm.git cd vllm pip install --upgrade pip @@ -34,7 +34,7 @@ pip install -v -r requirements/xpu.txt - Then, build and install vLLM XPU backend: -```console +```bash VLLM_TARGET_DEVICE=xpu python setup.py install ``` @@ -53,9 +53,9 @@ Currently, there are no pre-built XPU images. # --8<-- [end:pre-built-images] # --8<-- [start:build-image-from-source] -```console -$ docker build -f docker/Dockerfile.xpu -t vllm-xpu-env --shm-size=4g . -$ docker run -it \ +```bash +docker build -f docker/Dockerfile.xpu -t vllm-xpu-env --shm-size=4g . +docker run -it \ --rm \ --network=host \ --device /dev/dri \ @@ -68,7 +68,7 @@ $ docker run -it \ XPU platform supports **tensor parallel** inference/serving and also supports **pipeline parallel** as a beta feature for online serving. We require Ray as the distributed runtime backend. 
For example, a reference execution like following: -```console +```bash python -m vllm.entrypoints.openai.api_server \ --model=facebook/opt-13b \ --dtype=bfloat16 \ diff --git a/docs/getting_started/installation/intel_gaudi.md b/docs/getting_started/installation/intel_gaudi.md index 056caa708147..a4f13dca4bf4 100644 --- a/docs/getting_started/installation/intel_gaudi.md +++ b/docs/getting_started/installation/intel_gaudi.md @@ -24,7 +24,7 @@ please follow the methods outlined in the To verify that the Intel Gaudi software was correctly installed, run: -```console +```bash hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core, habanalabs-thunk and habanalabs-container-runtime are installed pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed @@ -42,7 +42,7 @@ for more details. Use the following commands to run a Docker image: -```console +```bash docker pull vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest docker run \ -it \ @@ -65,7 +65,7 @@ Currently, there are no pre-built Intel Gaudi wheels. To build and install vLLM from source, run: -```console +```bash git clone https://github.com/vllm-project/vllm.git cd vllm pip install -r requirements/hpu.txt @@ -74,7 +74,7 @@ python setup.py develop Currently, the latest features and performance optimizations are developed in Gaudi's [vLLM-fork](https://github.com/HabanaAI/vllm-fork) and we periodically upstream them to vLLM main repo. To install latest [HabanaAI/vLLM-fork](https://github.com/HabanaAI/vllm-fork), run the following: -```console +```bash git clone https://github.com/HabanaAI/vllm-fork.git cd vllm-fork git checkout habana_main @@ -90,7 +90,7 @@ Currently, there are no pre-built Intel Gaudi images. ### Build image from source -```console +```bash docker build -f docker/Dockerfile.hpu -t vllm-hpu-env . docker run \ -it \ diff --git a/docs/getting_started/installation/python_env_setup.inc.md b/docs/getting_started/installation/python_env_setup.inc.md index 911301d68335..423bf9b00d07 100644 --- a/docs/getting_started/installation/python_env_setup.inc.md +++ b/docs/getting_started/installation/python_env_setup.inc.md @@ -1,6 +1,6 @@ It's recommended to use [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, to create and manage Python environments. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`. After installing `uv`, you can create a new Python environment and install vLLM using the following commands: -```console +```bash uv venv --python 3.12 --seed source .venv/bin/activate ``` diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md index d02cb18bcb94..39100e4ca540 100644 --- a/docs/getting_started/quickstart.md +++ b/docs/getting_started/quickstart.md @@ -19,7 +19,7 @@ If you are using NVIDIA GPUs, you can install vLLM using [pip](https://pypi.org/ It's recommended to use [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, to create and manage Python environments. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`. 
After installing `uv`, you can create a new Python environment and install vLLM using the following commands: -```console +```bash uv venv --python 3.12 --seed source .venv/bin/activate uv pip install vllm --torch-backend=auto @@ -29,13 +29,13 @@ uv pip install vllm --torch-backend=auto Another delightful way is to use `uv run` with `--with [dependency]` option, which allows you to run commands such as `vllm serve` without creating any permanent environment: -```console +```bash uv run --with vllm vllm --help ``` You can also use [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/getting-started.html) to create and manage Python environments. You can install `uv` to the conda environment through `pip` if you want to manage it within the environment. -```console +```bash conda create -n myenv python=3.12 -y conda activate myenv pip install --upgrade uv @@ -110,7 +110,7 @@ By default, it starts the server at `http://localhost:8000`. You can specify the Run the following command to start the vLLM server with the [Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) model: -```console +```bash vllm serve Qwen/Qwen2.5-1.5B-Instruct ``` @@ -124,7 +124,7 @@ vllm serve Qwen/Qwen2.5-1.5B-Instruct This server can be queried in the same format as OpenAI API. For example, to list the models: -```console +```bash curl http://localhost:8000/v1/models ``` @@ -134,7 +134,7 @@ You can pass in the argument `--api-key` or environment variable `VLLM_API_KEY` Once your server is started, you can query the model with input prompts: -```console +```bash curl http://localhost:8000/v1/completions \ -H "Content-Type: application/json" \ -d '{ @@ -172,7 +172,7 @@ vLLM is designed to also support the OpenAI Chat Completions API. The chat inter You can use the [create chat completion](https://platform.openai.com/docs/api-reference/chat/completions/create) endpoint to interact with the model: -```console +```bash curl http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ diff --git a/docs/models/extensions/runai_model_streamer.md b/docs/models/extensions/runai_model_streamer.md index 6755b574ea67..60b43d21d9f6 100644 --- a/docs/models/extensions/runai_model_streamer.md +++ b/docs/models/extensions/runai_model_streamer.md @@ -9,27 +9,27 @@ Further reading can be found in [Run:ai Model Streamer Documentation](https://gi vLLM supports loading weights in Safetensors format using the Run:ai Model Streamer. You first need to install vLLM RunAI optional dependency: -```console +```bash pip3 install vllm[runai] ``` To run it as an OpenAI-compatible server, add the `--load-format runai_streamer` flag: -```console +```bash vllm serve /home/meta-llama/Llama-3.2-3B-Instruct \ --load-format runai_streamer ``` To run model from AWS S3 object store run: -```console +```bash vllm serve s3://core-llm/Llama-3-8b \ --load-format runai_streamer ``` To run model from a S3 compatible object store run: -```console +```bash RUNAI_STREAMER_S3_USE_VIRTUAL_ADDRESSING=0 \ AWS_EC2_METADATA_DISABLED=true \ AWS_ENDPOINT_URL=https://storage.googleapis.com \ @@ -44,7 +44,7 @@ You can tune parameters using `--model-loader-extra-config`: You can tune `concurrency` that controls the level of concurrency and number of OS threads reading tensors from the file to the CPU buffer. For reading from S3, it will be the number of client instances the host is opening to the S3 server. 
-```console +```bash vllm serve /home/meta-llama/Llama-3.2-3B-Instruct \ --load-format runai_streamer \ --model-loader-extra-config '{"concurrency":16}' @@ -53,7 +53,7 @@ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct \ You can control the size of the CPU Memory buffer to which tensors are read from the file, and limit this size. You can read further about CPU buffer memory limiting [here](https://github.com/run-ai/runai-model-streamer/blob/master/docs/src/env-vars.md#runai_streamer_memory_limit). -```console +```bash vllm serve /home/meta-llama/Llama-3.2-3B-Instruct \ --load-format runai_streamer \ --model-loader-extra-config '{"memory_limit":5368709120}' @@ -66,13 +66,13 @@ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct \ vLLM also supports loading sharded models using Run:ai Model Streamer. This is particularly useful for large models that are split across multiple files. To use this feature, use the `--load-format runai_streamer_sharded` flag: -```console +```bash vllm serve /path/to/sharded/model --load-format runai_streamer_sharded ``` The sharded loader expects model files to follow the same naming pattern as the regular sharded state loader: `model-rank-{rank}-part-{part}.safetensors`. You can customize this pattern using the `pattern` parameter in `--model-loader-extra-config`: -```console +```bash vllm serve /path/to/sharded/model \ --load-format runai_streamer_sharded \ --model-loader-extra-config '{"pattern":"custom-model-rank-{rank}-part-{part}.safetensors"}' @@ -82,7 +82,7 @@ To create sharded model files, you can use the script provided in + +#### Extra Parameters + +The following [sampling parameters][sampling-params] are supported. + +```python +--8<-- "vllm/entrypoints/openai/protocol.py:translation-sampling-params" +``` + +The following extra parameters are supported: + +```python +--8<-- "vllm/entrypoints/openai/protocol.py:translation-extra-params" +``` [](){ #tokenizer-api } diff --git a/docs/usage/metrics.md b/docs/usage/metrics.md index 988b9a551725..4350ab5025f5 100644 --- a/docs/usage/metrics.md +++ b/docs/usage/metrics.md @@ -6,7 +6,7 @@ OpenAI compatible API server. You can start the server using Python, or using [Docker][deployment-docker]: -```console +```bash vllm serve unsloth/Llama-3.2-1B-Instruct ``` diff --git a/docs/usage/troubleshooting.md b/docs/usage/troubleshooting.md index 631c8c40cfec..82957d33b19e 100644 --- a/docs/usage/troubleshooting.md +++ b/docs/usage/troubleshooting.md @@ -127,13 +127,13 @@ If GPU/CPU communication cannot be established, you can use the following Python If you are testing with a single node, adjust `--nproc-per-node` to the number of GPUs you want to use: -```console +```bash NCCL_DEBUG=TRACE torchrun --nproc-per-node= test.py ``` If you are testing with multi-nodes, adjust `--nproc-per-node` and `--nnodes` according to your setup and set `MASTER_ADDR` to the correct IP address of the master node, reachable from all nodes. 
Then, run: -```console +```bash NCCL_DEBUG=TRACE torchrun --nnodes 2 \ --nproc-per-node=2 \ --rdzv_backend=c10d \ diff --git a/examples/offline_inference/openai_batch/README.md b/examples/offline_inference/openai_batch/README.md index ce7529782122..631fde91fcd0 100644 --- a/examples/offline_inference/openai_batch/README.md +++ b/examples/offline_inference/openai_batch/README.md @@ -29,14 +29,14 @@ We currently support `/v1/chat/completions`, `/v1/embeddings`, and `/v1/score` e To follow along with this example, you can download the example batch, or create your own batch file in your working directory. -```console +```bash wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl ``` Once you've created your batch file it should look like this -```console -$ cat offline_inference/openai_batch/openai_example_batch.jsonl +```bash +cat offline_inference/openai_batch/openai_example_batch.jsonl {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} ``` @@ -47,7 +47,7 @@ The batch running tool is designed to be used from the command line. You can run the batch with the following command, which will write its results to a file called `results.jsonl` -```console +```bash python -m vllm.entrypoints.openai.run_batch \ -i offline_inference/openai_batch/openai_example_batch.jsonl \ -o results.jsonl \ @@ -56,7 +56,7 @@ python -m vllm.entrypoints.openai.run_batch \ or use command-line: -```console +```bash vllm run-batch \ -i offline_inference/openai_batch/openai_example_batch.jsonl \ -o results.jsonl \ @@ -67,8 +67,8 @@ vllm run-batch \ You should now have your results at `results.jsonl`. You can check your results by running `cat results.jsonl` -```console -$ cat results.jsonl +```bash +cat results.jsonl {"id":"vllm-383d1c59835645aeb2e07d004d62a826","custom_id":"request-1","response":{"id":"cmpl-61c020e54b964d5a98fa7527bfcdd378","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"Hello! It's great to meet you! I'm here to help with any questions or tasks you may have. 
What's on your mind today?"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":25,"total_tokens":56,"completion_tokens":31}},"error":null} {"id":"vllm-42e3d09b14b04568afa3f1797751a267","custom_id":"request-2","response":{"id":"cmpl-f44d049f6b3a42d4b2d7850bb1e31bcc","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"*silence*"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":27,"total_tokens":32,"completion_tokens":5}},"error":null} ``` @@ -79,7 +79,7 @@ The batch runner supports remote input and output urls that are accessible via h For example, to run against our example input file located at `https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl`, you can run -```console +```bash python -m vllm.entrypoints.openai.run_batch \ -i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl \ -o results.jsonl \ @@ -88,7 +88,7 @@ python -m vllm.entrypoints.openai.run_batch \ or use command-line: -```console +```bash vllm run-batch \ -i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl \ -o results.jsonl \ @@ -112,21 +112,21 @@ To integrate with cloud blob storage, we recommend using presigned urls. To follow along with this example, you can download the example batch, or create your own batch file in your working directory. -```console +```bash wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl ``` Once you've created your batch file it should look like this -```console -$ cat offline_inference/openai_batch/openai_example_batch.jsonl +```bash +cat offline_inference/openai_batch/openai_example_batch.jsonl {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} ``` Now upload your batch file to your S3 bucket. -```console +```bash aws s3 cp offline_inference/openai_batch/openai_example_batch.jsonl s3://MY_BUCKET/MY_INPUT_FILE.jsonl ``` @@ -181,7 +181,7 @@ output_url='https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AW You can now run the batch runner, using the urls generated in the previous section. 
-```console +```bash python -m vllm.entrypoints.openai.run_batch \ -i "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_INPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \ -o "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \ @@ -190,7 +190,7 @@ python -m vllm.entrypoints.openai.run_batch \ or use command-line: -```console +```bash vllm run-batch \ -i "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_INPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \ -o "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \ @@ -201,7 +201,7 @@ vllm run-batch \ Your results are now on S3. You can view them in your terminal by running -```console +```bash aws s3 cp s3://MY_BUCKET/MY_OUTPUT_FILE.jsonl - ``` @@ -230,8 +230,8 @@ You can run the batch using the same command as in earlier examples. You can check your results by running `cat results.jsonl` -```console -$ cat results.jsonl +```bash +cat results.jsonl {"id":"vllm-db0f71f7dec244e6bce530e0b4ef908b","custom_id":"request-1","response":{"status_code":200,"request_id":"vllm-batch-3580bf4d4ae54d52b67eee266a6eab20","body":{"id":"embd-33ac2efa7996430184461f2e38529746","object":"list","created":444647,"model":"intfloat/e5-mistral-7b-instruct","data":[{"index":0,"object":"embedding","embedding":[0.016204833984375,0.0092010498046875,0.0018358230590820312,-0.0028228759765625,0.001422882080078125,-0.0031147003173828125,...]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0}}},"error":null} ... ``` @@ -261,8 +261,8 @@ You can run the batch using the same command as in earlier examples. You can check your results by running `cat results.jsonl` -```console -$ cat results.jsonl +```bash +cat results.jsonl {"id":"vllm-f87c5c4539184f618e555744a2965987","custom_id":"request-1","response":{"status_code":200,"request_id":"vllm-batch-806ab64512e44071b37d3f7ccd291413","body":{"id":"score-4ee45236897b4d29907d49b01298cdb1","object":"list","created":1737847944,"model":"BAAI/bge-reranker-v2-m3","data":[{"index":0,"object":"score","score":0.0010900497436523438},{"index":1,"object":"score","score":1.0}],"usage":{"prompt_tokens":37,"total_tokens":37,"completion_tokens":0,"prompt_tokens_details":null}}},"error":null} {"id":"vllm-41990c51a26d4fac8419077f12871099","custom_id":"request-2","response":{"status_code":200,"request_id":"vllm-batch-73ce66379026482699f81974e14e1e99","body":{"id":"score-13f2ffe6ba40460fbf9f7f00ad667d75","object":"list","created":1737847944,"model":"BAAI/bge-reranker-v2-m3","data":[{"index":0,"object":"score","score":0.001094818115234375},{"index":1,"object":"score","score":1.0}],"usage":{"prompt_tokens":37,"total_tokens":37,"completion_tokens":0,"prompt_tokens_details":null}}},"error":null} ``` diff --git a/examples/offline_inference/qwen3_reranker.py b/examples/offline_inference/qwen3_reranker.py index 27c4071bf094..fe3cebc348f1 100644 --- a/examples/offline_inference/qwen3_reranker.py +++ b/examples/offline_inference/qwen3_reranker.py @@ -22,15 +22,19 @@ # If you want to load the official original version, the init parameters are # as follows. 
-model = LLM( - model=model_name, - task="score", - hf_overrides={ - "architectures": ["Qwen3ForSequenceClassification"], - "classifier_from_token": ["no", "yes"], - "is_original_qwen3_reranker": True, - }, -) + +def get_model() -> LLM: + """Initializes and returns the LLM model for Qwen3-Reranker.""" + return LLM( + model=model_name, + task="score", + hf_overrides={ + "architectures": ["Qwen3ForSequenceClassification"], + "classifier_from_token": ["no", "yes"], + "is_original_qwen3_reranker": True, + }, + ) + # Why do we need hf_overrides for the official original version: # vllm converts it to Qwen3ForSequenceClassification when loaded for @@ -51,7 +55,8 @@ query_template = "{prefix}: {instruction}\n: {query}\n" document_template = ": {doc}{suffix}" -if __name__ == "__main__": + +def main() -> None: instruction = ( "Given a web search query, retrieve relevant passages that answer the query" ) @@ -72,6 +77,13 @@ ] documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents] + model = get_model() outputs = model.score(queries, documents) + print("-" * 30) print([output.outputs.score for output in outputs]) + print("-" * 30) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_transcription_client.py b/examples/online_serving/openai_transcription_client.py index ae43cb5da790..755038a76139 100644 --- a/examples/online_serving/openai_transcription_client.py +++ b/examples/online_serving/openai_transcription_client.py @@ -26,23 +26,12 @@ from vllm.assets.audio import AudioAsset -mary_had_lamb = AudioAsset("mary_had_lamb").get_local_path() -winning_call = AudioAsset("winning_call").get_local_path() -# Modify OpenAI's API key and API base to use vLLM's API server. -openai_api_key = "EMPTY" -openai_api_base = "http://localhost:8000/v1" -client = OpenAI( - api_key=openai_api_key, - base_url=openai_api_base, -) - - -def sync_openai(): +def sync_openai(audio_path: str, client: OpenAI): """ Perform synchronous transcription using OpenAI-compatible API. """ - with open(str(mary_had_lamb), "rb") as f: + with open(audio_path, "rb") as f: transcription = client.audio.transcriptions.create( file=f, model="openai/whisper-large-v3", @@ -58,8 +47,7 @@ def sync_openai(): print("transcription result:", transcription.text) -# OpenAI Transcription API client does not support streaming. -async def stream_openai_response(): +async def stream_openai_response(audio_path: str, base_url: str, api_key: str): """ Perform streaming transcription using vLLM's raw HTTP streaming API. """ @@ -68,11 +56,12 @@ async def stream_openai_response(): "stream": True, "model": "openai/whisper-large-v3", } - url = openai_api_base + "/audio/transcriptions" - headers = {"Authorization": f"Bearer {openai_api_key}"} + url = base_url + "/audio/transcriptions" + headers = {"Authorization": f"Bearer {api_key}"} print("transcription result:", end=" ") + # OpenAI Transcription API client does not support streaming. async with httpx.AsyncClient() as client: - with open(str(winning_call), "rb") as f: + with open(audio_path, "rb") as f: async with client.stream( "POST", url, files={"file": f}, data=data, headers=headers ) as response: @@ -93,10 +82,20 @@ async def stream_openai_response(): def main(): - sync_openai() - + mary_had_lamb = str(AudioAsset("mary_had_lamb").get_local_path()) + winning_call = str(AudioAsset("winning_call").get_local_path()) + + # Modify OpenAI's API key and API base to use vLLM's API server. 
+ openai_api_key = "EMPTY" + openai_api_base = "http://localhost:8000/v1" + client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, + ) + + sync_openai(mary_had_lamb, client) # Run the asynchronous function - asyncio.run(stream_openai_response()) + asyncio.run(stream_openai_response(winning_call, openai_api_base, openai_api_key)) if __name__ == "__main__": diff --git a/examples/online_serving/openai_translation_client.py b/examples/online_serving/openai_translation_client.py new file mode 100644 index 000000000000..6f7253e2a789 --- /dev/null +++ b/examples/online_serving/openai_translation_client.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import json + +import httpx +from openai import OpenAI + +from vllm.assets.audio import AudioAsset + + +def sync_openai(audio_path: str, client: OpenAI): + with open(audio_path, "rb") as f: + translation = client.audio.translations.create( + file=f, + model="openai/whisper-large-v3", + response_format="json", + temperature=0.0, + # Additional params not provided by OpenAI API. + extra_body=dict( + language="it", + seed=4419, + repetition_penalty=1.3, + ), + ) + print("translation result:", translation.text) + + +async def stream_openai_response(audio_path: str, base_url: str, api_key: str): + data = { + "language": "it", + "stream": True, + "model": "openai/whisper-large-v3", + } + url = base_url + "/audio/translations" + headers = {"Authorization": f"Bearer {api_key}"} + print("translation result:", end=" ") + # OpenAI translation API client does not support streaming. + async with httpx.AsyncClient() as client: + with open(audio_path, "rb") as f: + async with client.stream( + "POST", url, files={"file": f}, data=data, headers=headers + ) as response: + async for line in response.aiter_lines(): + # Each line is a JSON object prefixed with 'data: ' + if line: + if line.startswith("data: "): + line = line[len("data: ") :] + # Last chunk, stream ends + if line.strip() == "[DONE]": + break + # Parse the JSON response + chunk = json.loads(line) + # Extract and print the content + content = chunk["choices"][0].get("delta", {}).get("content") + print(content, end="") + + +def main(): + foscolo = str(AudioAsset("azacinto_foscolo").get_local_path()) + + # Modify OpenAI's API key and API base to use vLLM's API server. + openai_api_key = "EMPTY" + openai_api_base = "http://localhost:8000/v1" + client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, + ) + sync_openai(foscolo, client) + # Run the asynchronous function + asyncio.run(stream_openai_response(foscolo, openai_api_base, openai_api_key)) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/opentelemetry/README.md b/examples/online_serving/opentelemetry/README.md index af0034007974..ae5d84d8ef19 100644 --- a/examples/online_serving/opentelemetry/README.md +++ b/examples/online_serving/opentelemetry/README.md @@ -2,7 +2,7 @@ 1. Install OpenTelemetry packages: - ```console + ```bash pip install \ 'opentelemetry-sdk>=1.26.0,<1.27.0' \ 'opentelemetry-api>=1.26.0,<1.27.0' \ @@ -12,7 +12,7 @@ 1. Start Jaeger in a docker container: - ```console + ```bash # From: https://www.jaegertracing.io/docs/1.57/getting-started/ docker run --rm --name jaeger \ -e COLLECTOR_ZIPKIN_HOST_PORT=:9411 \ @@ -31,14 +31,14 @@ 1. 
In a new shell, export Jaeger IP: - ```console + ```bash export JAEGER_IP=$(docker inspect --format '{{ .NetworkSettings.IPAddress }}' jaeger) export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=grpc://$JAEGER_IP:4317 ``` Then set vLLM's service name for OpenTelemetry, enable insecure connections to Jaeger and run vLLM: - ```console + ```bash export OTEL_SERVICE_NAME="vllm-server" export OTEL_EXPORTER_OTLP_TRACES_INSECURE=true vllm serve facebook/opt-125m --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT" @@ -46,7 +46,7 @@ 1. In a new shell, send requests with trace context from a dummy client - ```console + ```bash export JAEGER_IP=$(docker inspect --format '{{ .NetworkSettings.IPAddress }}' jaeger) export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=grpc://$JAEGER_IP:4317 export OTEL_EXPORTER_OTLP_TRACES_INSECURE=true @@ -67,7 +67,7 @@ OpenTelemetry supports either `grpc` or `http/protobuf` as the transport protocol for trace data in the exporter. By default, `grpc` is used. To set `http/protobuf` as the protocol, configure the `OTEL_EXPORTER_OTLP_TRACES_PROTOCOL` environment variable as follows: -```console +```bash export OTEL_EXPORTER_OTLP_TRACES_PROTOCOL=http/protobuf export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=http://$JAEGER_IP:4318/v1/traces vllm serve facebook/opt-125m --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT" @@ -79,13 +79,13 @@ OpenTelemetry allows automatic instrumentation of FastAPI. 1. Install the instrumentation library - ```console + ```bash pip install opentelemetry-instrumentation-fastapi ``` 1. Run vLLM with `opentelemetry-instrument` - ```console + ```bash opentelemetry-instrument vllm serve facebook/opt-125m ``` diff --git a/requirements/common.txt b/requirements/common.txt index 639abe511017..9a9ae1d93896 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -44,3 +44,4 @@ watchfiles # required for http server to monitor the updates of TLS files python-json-logger # Used by logging as per examples/others/logging_configuration.md scipy # Required for phi-4-multimodal-instruct ninja # Required for xgrammar, rocm, tpu, xpu +pybase64 # fast base64 implementation diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 37d8ae0c08bf..8679d5c3019b 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -5,6 +5,15 @@ import vllm from vllm.compilation.counter import compilation_counter from vllm.config import VllmConfig +from vllm.utils import _is_torch_equal_or_newer + + +def test_version(): + assert _is_torch_equal_or_newer('2.8.0.dev20250624+cu128', '2.8.0.dev') + assert _is_torch_equal_or_newer('2.8.0a0+gitc82a174', '2.8.0.dev') + assert _is_torch_equal_or_newer('2.8.0', '2.8.0.dev') + assert _is_torch_equal_or_newer('2.8.1', '2.8.0.dev') + assert not _is_torch_equal_or_newer('2.7.1', '2.8.0.dev') def test_use_cudagraphs_dynamic(monkeypatch): diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index c689befdf2da..b56edfc90612 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -6,7 +6,9 @@ import vllm.envs as envs from vllm.compilation.fix_functionalization import FixFunctionalizationPass +from vllm.compilation.fusion import FusionPass from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func +from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.sequence_parallelism import SequenceParallelismPass from vllm.config import (CompilationConfig, 
DeviceConfig, ModelConfig, PassConfig, VllmConfig) @@ -14,12 +16,15 @@ from vllm.distributed.parallel_state import (init_distributed_environment, initialize_model_parallel) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp) from vllm.platforms import current_platform from vllm.utils import update_environment_variables from ..utils import multi_gpu_test from .backend import TestBackend +FP8_DTYPE = current_platform.fp8_dtype() prompts = [ "Hello, my name is", "The president of the United States is", @@ -30,13 +35,16 @@ class TestModel(torch.nn.Module): - def __init__(self, hidden_size=16, intermediate_size=32): + def __init__(self, + hidden_size=16, + intermediate_size=32, + vllm_config: VllmConfig = None): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.gate_proj = torch.nn.Parameter( torch.empty((intermediate_size, hidden_size))) - self.norm = RMSNorm(hidden_size, 1e-05) + self.norm = RMSNorm(intermediate_size, 1e-05) # Initialize weights torch.nn.init.normal_(self.gate_proj, std=0.02) @@ -79,32 +87,138 @@ def ops_in_model(self): return [torch.ops._C.fused_add_rms_norm.default] +class TestQuantModel(torch.nn.Module): + + def __init__(self, + hidden_size=16, + intermediate_size=32, + vllm_config: VllmConfig = None): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.vllm_config = vllm_config + self.gate_proj = torch.nn.Parameter(torch.empty( + (intermediate_size, hidden_size)), + requires_grad=False) + self.norm = RMSNorm(intermediate_size, 1e-05) + # Initialize weights + torch.nn.init.normal_(self.gate_proj, std=0.02) + + self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=True, + use_per_token_if_dynamic=False) + + self.scale = torch.rand(1, dtype=torch.float32) + # Create a weight that is compatible with torch._scaled_mm, + # which expects a column-major layout. 
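+        # (A freshly allocated torch.rand(hidden_size, intermediate_size)
+        # tensor is row-major; .t() returns a transposed view whose strides
+        # are column-major, so the required layout is obtained without a copy.)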
+ self.w = torch.rand(hidden_size, + intermediate_size).to(dtype=FP8_DTYPE).t() + self.wscale = torch.rand(1, dtype=torch.float32) + + def forward(self, hidden_states, residual): + """ + Forward pass implementing the operations in the FX graph + + Args: + hidden_states: Input tensor + residual: Residual tensor from previous layer + + Returns: + Tuple containing the output tensor + """ + # Reshape input + view = hidden_states.reshape(-1, self.hidden_size) + + #matrix multiplication + permute = self.gate_proj.permute(1, 0) + mm = torch.mm(view, permute) + + # Tensor parallel all-reduce + all_reduce = tensor_model_parallel_all_reduce(mm) + + # layer normalization + norm_output, residual_output = self.norm(all_reduce, residual) + + # for static input quantization + # self.fp8_linear is initialized with use_per_token_if_dynamic=False + fp8_linear_result = self.fp8_linear.apply(norm_output, + self.w, + self.wscale, + input_scale=self.scale.to( + norm_output.device)) + + return fp8_linear_result, residual_output + + def ops_in_model_before(self): + ops_to_remove = [torch.ops.vllm.all_reduce.default + ] # Always removed by SP + # The following are only removed if fusion happens + if self.vllm_config and self.vllm_config.compilation_config \ + .pass_config.enable_fusion: + ops_to_remove.extend([ + torch.ops._C.fused_add_rms_norm.default, + torch.ops._C.static_scaled_fp8_quant.default, + ]) + return ops_to_remove + + def ops_in_model_after(self): + ops_to_add = [ + torch.ops.vllm.reduce_scatter.default, + torch.ops.vllm.all_gather.default + ] + # The following is only added if fusion happens + if self.vllm_config and self.vllm_config.compilation_config \ + .pass_config.enable_fusion: + ops_to_add.append( + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default) + return ops_to_add + + def ops_in_model(self): + if self.vllm_config and self.vllm_config.compilation_config \ + .pass_config.enable_fusion: + # If fusion happens, the fused op is the one + # we check for (de)functionalization + return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default + ] # noqa: E501 + else: + # If no fusion, the original ops are checked + return [ + torch.ops._C.fused_add_rms_norm.default, + # TODO functionalization pass does not handle this yet + # torch.ops._C.static_scaled_fp8_quant.default, + ] + + @multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("test_model_cls", [TestModel, TestQuantModel]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [16]) @pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("enable_fusion", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") -def test_sequence_parallelism_pass(batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype): +def test_sequence_parallelism_pass(test_model_cls: type[torch.nn.Module], + batch_size: int, seq_len: int, + hidden_size: int, dtype: torch.dtype, + enable_fusion: bool): num_processes = 2 def run_torch_spawn(fn, nprocs): # need to use torch.mp.spawn otherwise will have problems with # torch.distributed and cuda torch.multiprocessing.spawn(fn, - args=(num_processes, batch_size, seq_len, - hidden_size, dtype), + args=(num_processes, test_model_cls, + batch_size, seq_len, hidden_size, + dtype, enable_fusion), nprocs=nprocs) run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes) -def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, 
- batch_size: int, seq_len: int, - hidden_size: int, - dtype: torch.dtype): +def sequence_parallelism_pass_on_test_model( + local_rank: int, world_size: int, + test_model_cls: type[torch.nn.Module], batch_size: int, seq_len: int, + hidden_size: int, dtype: torch.dtype, enable_fusion: bool): current_platform.seed_everything(0) device = torch.device(f"cuda:{local_rank}") @@ -127,26 +241,39 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, # configure vllm config for SequenceParallelismPass vllm_config = VllmConfig() vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig( - enable_sequence_parallelism=True)) + enable_sequence_parallelism=True, + enable_fusion=enable_fusion, + enable_noop=True)) # NoOp needed for fusion vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to construct the model config # in the vllm_config, it's not really used. - model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" - vllm_config.model_config = ModelConfig(model=model, + model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" + vllm_config.model_config = ModelConfig(model=model_name, task="auto", - tokenizer=model, + tokenizer=model_name, tokenizer_mode="auto", trust_remote_code=True, dtype=dtype, seed=42) sequence_parallelism_pass = SequenceParallelismPass(vllm_config) - backend_no_func = TestBackend(sequence_parallelism_pass) + noop_pass = NoOpEliminationPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config) - backend_func = TestBackend(sequence_parallelism_pass, func_pass) - model = TestModel(hidden_size, hidden_size * 2) + passes_for_backend = [noop_pass, sequence_parallelism_pass] + + if enable_fusion: + fusion_pass = FusionPass.instance(vllm_config) + passes_for_backend.append(fusion_pass) + + backend_no_func = TestBackend(*passes_for_backend) + backend_func = TestBackend(*passes_for_backend, func_pass) + + model = test_model_cls(hidden_size, + hidden_size * 2, + vllm_config=vllm_config) + hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) diff --git a/tests/conftest.py b/tests/conftest.py index f50e611a471b..feb52e26300a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1027,13 +1027,13 @@ def classify(self, prompts: list[str]) -> list[list[float]]: req_outputs = self.model.classify(prompts) return [req_output.outputs.probs for req_output in req_outputs] - def encode(self, - prompts: list[str], - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, - *args, - **kwargs) -> list[list[float]]: + def embed(self, + prompts: list[str], + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + *args, + **kwargs) -> list[list[float]]: inputs = self.get_inputs(prompts, images=images, videos=videos, @@ -1042,6 +1042,10 @@ def encode(self, req_outputs = self.model.embed(inputs, *args, **kwargs) return [req_output.outputs.embedding for req_output in req_outputs] + def encode(self, prompts: list[str]) -> list[list[float]]: + req_outputs = self.model.encode(prompts) + return [req_output.outputs.data for req_output in req_outputs] + def score( self, text_1: Union[str, list[str]], diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index 91a594eac5c4..b2f6a8ab9dd3 100644 --- 
a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -28,7 +28,7 @@ class ParallelSetup(NamedTuple): tp_size: int pp_size: int - sp_enabled: bool + enable_fusion: bool eager_mode: bool chunked_prefill: bool @@ -67,49 +67,18 @@ def detailed( task: TaskOption = "auto", load_format: Optional[str] = None, ): + parallel_setups = [] + for eager_mode_val in [False, True]: + for pp_multiplier in [1, 2]: + for chunked_prefill_val in [False, True]: + parallel_setups.append( + ParallelSetup(tp_size=tp_base, + pp_size=pp_multiplier * pp_base, + enable_fusion=False, + eager_mode=eager_mode_val, + chunked_prefill=chunked_prefill_val)) return SPTestSettings( - parallel_setups=[ - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=True), - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - sp_enabled=True, - eager_mode=True, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - sp_enabled=True, - eager_mode=True, - chunked_prefill=True), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=True), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - sp_enabled=True, - eager_mode=True, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - sp_enabled=True, - eager_mode=True, - chunked_prefill=True) - ], + parallel_setups=parallel_setups, distributed_backends=["mp", "ray"], vllm_major_versions=["1", "1"], task=task, @@ -126,19 +95,44 @@ def fast( multi_node_only: bool = False, load_format: Optional[str] = None, ): + parallel_setups = [] + for eager_mode_val in [False, True]: + for pp_multiplier in [1, 2]: + for chunked_prefill_val in [False, True]: + parallel_setups.append( + ParallelSetup(tp_size=tp_base, + pp_size=pp_multiplier * pp_base, + enable_fusion=False, + eager_mode=eager_mode_val, + chunked_prefill=chunked_prefill_val)) return SPTestSettings( - parallel_setups=[ + parallel_setups=parallel_setups, + distributed_backends=["mp", "ray"], + vllm_major_versions=["1", "1"], + task=task, + test_options=SPTestOptions(multi_node_only=multi_node_only, + load_format=load_format), + ) + + @staticmethod + def fp8_quant( + *, + tp_base: int = 2, + pp_base: int = 1, + task: TaskOption = "auto", + multi_node_only: bool = False, + load_format: Optional[str] = None, + ): + parallel_setups = [] + for fusion_val in [False, True]: + parallel_setups.append( ParallelSetup(tp_size=tp_base, pp_size=pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=False), - ], + enable_fusion=fusion_val, + eager_mode=True, + chunked_prefill=False)) + return SPTestSettings( + parallel_setups=parallel_setups, distributed_backends=["mp", "ray"], vllm_major_versions=["1", "1"], task=task, @@ -171,7 +165,7 @@ def _compare_sp( ( tp_size, pp_size, - sp_enabled, + enable_fusion, eager_mode, chunked_prefill, ) = parallel_setup @@ -240,9 +234,9 @@ def _compare_sp( 'compile_sizes': [4, 8], 'splitting_ops': [], 'pass_config': { - 'enable_sequence_parallelism': sp_enabled, + 'enable_sequence_parallelism': True, + 'enable_fusion': enable_fusion, 'enable_noop': True, 
- 'enable_fusion': True, }, } @@ -291,12 +285,14 @@ def _compare_sp( SP_TEXT_GENERATION_MODELS = { # [Decoder-only] "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(), + "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8": SPTestSettings.fp8_quant(), } SP_TEST_MODELS = [ # TODO support other models # [LANGUAGE GENERATION] "meta-llama/Llama-3.2-1B-Instruct", + "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8" ] diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py index 8117e774951e..dab14f1d7d03 100644 --- a/tests/entrypoints/openai/test_transcription_validation.py +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -82,6 +82,8 @@ async def test_long_audio_request(mary_had_lamb): mary_had_lamb.seek(0) audio, sr = librosa.load(mary_had_lamb) + # Add small silence after each audio for repeatability in the split process + audio = np.pad(audio, (0, 1600)) repeated_audio = np.tile(audio, 10) # Repeated audio to buffer buffer = io.BytesIO() diff --git a/tests/entrypoints/openai/test_translation_validation.py b/tests/entrypoints/openai/test_translation_validation.py new file mode 100644 index 000000000000..0c2cb367f330 --- /dev/null +++ b/tests/entrypoints/openai/test_translation_validation.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import io +# imports for guided decoding tests +import json +from unittest.mock import patch + +import librosa +import numpy as np +import pytest +import soundfile as sf +from openai._base_client import AsyncAPIClient + +from vllm.assets.audio import AudioAsset + +from ...utils import RemoteOpenAIServer + + +@pytest.fixture +def foscolo(): + # Test translation it->en + path = AudioAsset('azacinto_foscolo').get_local_path() + with open(str(path), "rb") as f: + yield f + + +# NOTE: (NickLucche) the large-v3-turbo model was not trained on translation! 
+@pytest.mark.asyncio +async def test_basic_audio(foscolo): + model_name = "openai/whisper-small" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + translation = await client.audio.translations.create( + model=model_name, + file=foscolo, + response_format="text", + # TODO remove once language detection is implemented + extra_body=dict(language="it"), + temperature=0.0) + out = json.loads(translation)['text'].strip() + assert "Nor will I ever touch the sacred" in out + + +@pytest.mark.asyncio +async def test_audio_prompt(foscolo): + model_name = "openai/whisper-small" + server_args = ["--enforce-eager"] + # Condition whisper on starting text + prompt = "Nor have I ever" + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + transcription = await client.audio.translations.create( + model=model_name, + file=foscolo, + prompt=prompt, + extra_body=dict(language="it"), + response_format="text", + temperature=0.0) + out = json.loads(transcription)['text'] + assert "Nor will I ever touch the sacred" not in out + assert prompt not in out + + +@pytest.mark.asyncio +async def test_non_asr_model(foscolo): + # text to text model + model_name = "JackFram/llama-68m" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + res = await client.audio.translations.create(model=model_name, + file=foscolo, + temperature=0.0) + assert res.code == 400 and not res.text + assert res.message == "The model does not support Translations API" + + +@pytest.mark.asyncio +async def test_streaming_response(foscolo): + model_name = "openai/whisper-small" + server_args = ["--enforce-eager"] + translation = "" + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + res_no_stream = await client.audio.translations.create( + model=model_name, + file=foscolo, + response_format="json", + extra_body=dict(language="it"), + temperature=0.0) + # Unfortunately this only works when the openai client is patched + # to use streaming mode, not exposed in the translation api. 
+ original_post = AsyncAPIClient.post + + async def post_with_stream(*args, **kwargs): + kwargs['stream'] = True + return await original_post(*args, **kwargs) + + with patch.object(AsyncAPIClient, "post", new=post_with_stream): + client = remote_server.get_async_client() + res = await client.audio.translations.create(model=model_name, + file=foscolo, + temperature=0.0, + extra_body=dict( + stream=True, + language="it")) + # Reconstruct from chunks and validate + async for chunk in res: + # just a chunk + text = chunk.choices[0]['delta']['content'] + translation += text + + assert translation == res_no_stream.text + + +@pytest.mark.asyncio +async def test_stream_options(foscolo): + model_name = "openai/whisper-small" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + original_post = AsyncAPIClient.post + + async def post_with_stream(*args, **kwargs): + kwargs['stream'] = True + return await original_post(*args, **kwargs) + + with patch.object(AsyncAPIClient, "post", new=post_with_stream): + client = remote_server.get_async_client() + res = await client.audio.translations.create( + model=model_name, + file=foscolo, + temperature=0.0, + extra_body=dict(language="it", + stream=True, + stream_include_usage=True, + stream_continuous_usage_stats=True)) + final = False + continuous = True + async for chunk in res: + if not len(chunk.choices): + # final usage sent + final = True + else: + continuous = continuous and hasattr(chunk, 'usage') + assert final and continuous + + +@pytest.mark.asyncio +async def test_long_audio_request(foscolo): + model_name = "openai/whisper-small" + server_args = ["--enforce-eager"] + + foscolo.seek(0) + audio, sr = librosa.load(foscolo) + repeated_audio = np.tile(audio, 2) + # Repeated audio to buffer + buffer = io.BytesIO() + sf.write(buffer, repeated_audio, sr, format='WAV') + buffer.seek(0) + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + translation = await client.audio.translations.create( + model=model_name, + file=buffer, + extra_body=dict(language="it"), + response_format="text", + temperature=0.0) + out = json.loads(translation)['text'].strip().lower() + # TODO investigate higher model uncertainty in for longer translations. 
+ assert out.count("nor will i ever") == 2 diff --git a/tests/kernels/attention/test_mla_decode_cpu.py b/tests/kernels/attention/test_mla_decode_cpu.py index 5a7480a6beae..f8b307c595de 100644 --- a/tests/kernels/attention/test_mla_decode_cpu.py +++ b/tests/kernels/attention/test_mla_decode_cpu.py @@ -7,10 +7,7 @@ import vllm._custom_ops as ops from vllm.platforms import current_platform - - -def cdiv(a, b): - return (a + b - 1) // b +from vllm.utils import cdiv def ref_mla( diff --git a/tests/kernels/attention/test_triton_decode_attention.py b/tests/kernels/attention/test_triton_decode_attention.py index 358b374ea75b..2dca720fe330 100644 --- a/tests/kernels/attention/test_triton_decode_attention.py +++ b/tests/kernels/attention/test_triton_decode_attention.py @@ -5,10 +5,7 @@ import torch from vllm.attention.ops.triton_decode_attention import decode_attention_fwd - - -def cdiv(a, b): - return (a + b - 1) // b +from vllm.utils import cdiv @pytest.mark.parametrize("B", [3, 5]) diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index ce420901e317..158100a09879 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -29,7 +29,10 @@ (224, 1024, 1536), (224, 3072, 1024), (224, 3072, 1536), - (1024 * 128, 1024, 1024), + (32768, 1024, 1024), + # These sizes trigger wrong answers. + #(7232, 2048, 5120), + #(40000, 2048, 5120), ] vllm_config = VllmConfig(parallel_config=ParallelConfig( @@ -232,8 +235,10 @@ def test_cutlass_moe_8_bit_no_graph( topk: int, per_act_token: bool, per_out_ch: bool, + monkeypatch, ): current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch) @@ -274,8 +279,10 @@ def test_cutlass_moe_8_bit_cuda_graph( topk: int, per_act_token: bool, per_out_ch: bool, + monkeypatch, ): current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): dtype = torch.half @@ -329,8 +336,10 @@ def test_cutlass_moe_8_bit_EP( per_act_token: bool, per_out_channel: bool, ep_size: int, + monkeypatch, ): current_platform.seed_everything(7) + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_channel) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index bed374cf4d56..0c31168566e2 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -4,6 +4,9 @@ Run `pytest tests/kernels/test_moe.py`. 
""" +import functools +from typing import Callable, Optional, Union + import pytest import torch from torch.nn import Parameter @@ -14,6 +17,7 @@ import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.utils import opcheck, stack_and_dev, torch_moe from vllm.config import VllmConfig, set_current_vllm_config +from vllm.forward_context import set_forward_context from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, modular_triton_fused_moe) @@ -40,7 +44,76 @@ vllm_config.scheduler_config.max_model_len = 8192 -@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128]) +def run_moe_test( + baseline: Union[Callable, torch.Tensor], + moe_fn: Callable, + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + padding: bool = False, + use_compile: bool = False, + use_cudagraph: bool = False, + atol: float = 2e-2, + rtol: float = 0, +) -> torch.Tensor: + if isinstance(baseline, torch.Tensor): + baseline_output = baseline + else: + baseline_output = baseline(a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map) + + # Pad the weight if moe padding is enabled + if padding: + w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] + w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] + + if use_compile: + moe_fn = torch.compile(moe_fn, backend="inductor", fullgraph=True) + torch._dynamo.mark_dynamic(a, 0) + torch._dynamo.mark_dynamic(score, 0) + + test_output = moe_fn(a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map) + + if use_cudagraph: + test_output.fill_(0) + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + test_output = moe_fn(a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map) + torch.cuda.synchronize() + graph.replay() + torch.cuda.synchronize() + + torch.testing.assert_close(test_output, + baseline_output, + atol=atol, + rtol=rtol) + + return baseline_output + + +@pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @@ -48,6 +121,7 @@ @pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("padding", [True, False]) +@pytest.mark.parametrize("chunk_size", [8192]) def test_fused_moe( m: int, n: int, @@ -57,7 +131,17 @@ def test_fused_moe( ep_size: int, dtype: torch.dtype, padding: bool, + chunk_size: int, + monkeypatch, ): + current_platform.seed_everything(7) + + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size)) + + # + # Setup test data + # + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 @@ -77,58 +161,70 @@ def test_fused_moe( else: e_map = None - m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_channel_quant=False, - block_shape=None) - - with set_current_vllm_config(vllm_config): - torch_output = torch_moe(a, w1, w2, score, topk, e_map) - iterative_output = iterative_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - 
expert_map=e_map, - renormalize=False) - - # Pad the weight if moe padding is enabled - if padding: - w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() - w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() + # + # Setup test functions + # + + m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_channel_quant=False, + block_shape=None) + + def m_fused_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) + return m_fused_moe_fn(a, + w1, + w2, + topk_weights, + topk_ids, + global_num_experts=global_num_experts, + expert_map=expert_map) + + fused_moe_fn = functools.partial(fused_moe, renormalize=False) + + # + # Run tests + # + runner = functools.partial( + run_moe_test, + a=a, + w1=w1, + w2=w2, + score=score, + topk=topk, + global_num_experts=e, + expert_map=e_map, + padding=padding, + ) - triton_output = fused_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) + # Note: for now use_compile will error out if the problem size is + # large enough to trigger chunking. I'm leaving the flag and + # setup code in case we are able to revisit this later. + use_compile = False - topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) - m_triton_output = m_fused_moe(a, - w1, - w2, - topk_weights, - topk_ids, - global_num_experts=e, - expert_map=e_map) + use_cudagraph = (n >= 1024 and k >= 1024 + and current_platform.is_cuda_alike()) - torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) - torch.testing.assert_close(m_triton_output, - torch_output, - atol=2e-2, - rtol=0) - torch.testing.assert_close(iterative_output, - torch_output, - atol=2e-2, - rtol=0) + with set_current_vllm_config(vllm_config): + baseline_output = runner(torch_moe, iterative_moe) + runner(baseline_output, + fused_moe_fn, + use_compile=use_compile, + use_cudagraph=use_cudagraph) + runner(baseline_output, + m_fused_moe, + use_compile=use_compile, + use_cudagraph=use_cudagraph) @pytest.mark.parametrize("m", [1, 32, 222]) @@ -238,7 +334,12 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, w1_zp=w1_qzeros if has_zp else None, w2_zp=w2_qzeros if has_zp else None, block_shape=[0, group_size]) - torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map) + torch_output = torch_moe(a, + w1_ref, + w2_ref, + score, + topk, + expert_map=e_map) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) @@ -265,45 +366,51 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, pytest.skip("AITER ROCm test skip for float32") # Instantiate our and huggingface's MoE blocks - config = MixtralConfig() - hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda") - vllm_moe = MixtralMoE( - num_experts=config.num_local_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - params_dtype=dtype, - tp_size=1, - dp_size=1, - ).cuda() - - # Load the weights - vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data - for i in range(config.num_local_experts): - weights = (hf_moe.experts[i].w1.weight.data, - hf_moe.experts[i].w3.weight.data) - vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0) 
- vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data - - # Generate input batch of dimensions [batch_size, seq_len, hidden_dim] - hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda") - # vLLM uses 1D query [num_tokens, hidden_dim] - vllm_inputs = hf_inputs.flatten(0, 1) + vllm_config.compilation_config.static_forward_context = dict() + with (set_current_vllm_config(vllm_config), + set_forward_context(None, vllm_config)): + config = MixtralConfig() + hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda") + vllm_moe = MixtralMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + params_dtype=dtype, + tp_size=1, + dp_size=1, + ).cuda() + + # Load the weights + vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data + for i in range(config.num_local_experts): + weights = (hf_moe.experts[i].w1.weight.data, + hf_moe.experts[i].w3.weight.data) + vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0) + vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data + + # Generate input batch of dimensions [batch_size, seq_len, hidden_dim] + hf_inputs = torch.randn( + (1, 64, config.hidden_size)).to(dtype).to("cuda") + # vLLM uses 1D query [num_tokens, hidden_dim] + vllm_inputs = hf_inputs.flatten(0, 1) - # Pad the weight if moe padding is enabled - if padding: - vllm_moe.experts.w13_weight = Parameter(F.pad( - vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128], - requires_grad=False) - torch.cuda.empty_cache() - vllm_moe.experts.w2_weight = Parameter(F.pad( - vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128], - requires_grad=False) - torch.cuda.empty_cache() - - # Run forward passes for both MoE blocks - hf_states, _ = hf_moe.forward(hf_inputs) - vllm_states = vllm_moe.forward(vllm_inputs) + # Pad the weight if moe padding is enabled + if padding: + vllm_moe.experts.w13_weight = Parameter(F.pad( + vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., + 0:-128], + requires_grad=False) + torch.cuda.empty_cache() + vllm_moe.experts.w2_weight = Parameter(F.pad( + vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., + 0:-128], + requires_grad=False) + torch.cuda.empty_cache() + + # Run forward passes for both MoE blocks + hf_states, _ = hf_moe.forward(hf_inputs) + vllm_states = vllm_moe.forward(vllm_inputs) mixtral_moe_tol = { torch.float32: 1e-3, @@ -546,7 +653,12 @@ def test_fused_marlin_moe( topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) with set_current_vllm_config(vllm_config): - torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) + torch_output = torch_moe(a, + w_ref1, + w_ref2, + score, + topk, + expert_map=e_map) marlin_output = torch.ops.vllm.fused_marlin_moe( a, diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index 22482d9ca85a..76b560e1bb41 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -136,7 +136,7 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, device=w2.device, block_size=quant_blocksize) - torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None) + torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk) torch.testing.assert_close(torch_output, cutlass_output, diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index d90202dfcb3b..0caf14f040bb 100644 --- 
a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -6,9 +6,9 @@ import pytest import torch +from tests.kernels.utils import torch_experts from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.modular_kernel import ( @@ -164,22 +164,6 @@ def pplx_cutlass_moe( vllm_config.scheduler_config.max_model_len = 8192 -def torch_moe2(a, w1, w2, topk_weight, topk_ids): - M, K = a.shape - topk = topk_ids.shape[1] - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) - num_experts = w1.shape[0] - for i in range(num_experts): - mask = (topk_ids == i).view(-1) - if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) - - return (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - - def _pplx_moe( pgi: ProcessGroupInfo, dp_size: int, @@ -210,8 +194,8 @@ def _pplx_moe( group_name = cpu_group.group_name with set_current_vllm_config(vllm_config): - torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights, - topk_ids) + torch_output = torch_experts(a_full, w1_full, w2_full, topk_weights, + topk_ids) pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale, w2_scale, topk_weights, topk_ids, a1_scale, out_dtype, per_act_token, diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 2d6a8f39cec5..c4ad3af6802d 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -18,8 +18,8 @@ except ImportError: has_pplx = False +from tests.kernels.utils import torch_experts from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import override_config from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts) @@ -163,29 +163,6 @@ def batched_moe( return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts) -# Note: same as torch_moe but with fused_topk factored out. 
-def torch_moe2( - a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weight: torch.Tensor, - topk_ids: torch.Tensor, -) -> torch.Tensor: - M, K = a.shape - topk = topk_ids.shape[1] - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) - num_experts = w1.shape[0] - for i in range(num_experts): - mask = (topk_ids == i).view(-1) - if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) - - return (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - - @pytest.mark.parametrize("m", [1, 33, 64, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 512, 1024]) @@ -209,7 +186,7 @@ def test_fused_moe_batched_experts( with set_current_vllm_config(vllm_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids) torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids) @@ -409,7 +386,7 @@ def pplx_moe( w2: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor, - use_compile: bool = True, + use_compile: bool = False, use_cudagraphs: bool = True, ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( @@ -470,10 +447,16 @@ def pplx_moe( w1_chunk = chunk_by_rank(w1, rank, world_size).to(device) w2_chunk = chunk_by_rank(w2, rank, world_size).to(device) + # Note: for now use_compile will error out if the problem size is + # large enough to trigger chunking. I'm leaving the flag and + # setup code in case we are able to revisit this later. 
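+    # (If compilation is re-enabled, the mark_dynamic calls below keep the
+    # token dimension symbolic so the compiled graph is not recompiled for
+    # every distinct number of tokens handled by a rank.)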
if use_compile: _fused_experts = torch.compile(fused_experts, backend='inductor', fullgraph=True) + torch._dynamo.mark_dynamic(a_chunk, 0) + torch._dynamo.mark_dynamic(chunk_topk_weight, 0) + torch._dynamo.mark_dynamic(chunk_topk_ids, 0) else: _fused_experts = fused_experts @@ -576,7 +559,7 @@ def _pplx_moe( with set_current_vllm_config(vllm_config), override_config(moe_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + torch_output = torch_experts(a, w1, w2, topk_weight, topk_ids) pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size, a, w1, w2, topk_weight, topk_ids) # TODO (bnell): fix + re-enable diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py new file mode 100644 index 000000000000..673a0aa36794 --- /dev/null +++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + silu_mul_fp8_quant_deep_gemm) +from vllm.platforms import current_platform + +# (E, T, H, group_size, seed) +CASES = [ + (1, 1, 128, 64, 0), + (1, 4, 128, 128, 0), + (2, 4, 256, 128, 0), + (32, 64, 256, 128, 0), + (17, 31, 768, 128, 0), +] + + +@pytest.mark.parametrize("E,T,H,group_size,seed", CASES) +@torch.inference_mode() +def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): + current_platform.seed_everything(seed) + + # Input tensor of shape (E, T, 2*H) + y = torch.randn((E, T, 2 * H), dtype=torch.float32, device="cuda") + tokens_per_expert = torch.randint( + low=0, + high=T, + size=(E, ), + dtype=torch.int32, + device="cuda", + ) + + # Run the Triton kernel + y_q, y_s = silu_mul_fp8_quant_deep_gemm(y, + tokens_per_expert, + group_size=group_size, + eps=1e-10) + + # Reference implementation + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max = fp8_info.max + fp8_min = fp8_info.min + eps = 1e-10 + + # Compute silu activation and elementwise multiplication + y1 = y[..., :H] + y2 = y[..., H:] + silu_x = y1 * torch.sigmoid(y1) + merged = silu_x * y2 + + # Compute reference scales and quantized output, skipping padded tokens + for e in range(E): + nt = tokens_per_expert[e].item() + ref_s = torch.empty((T, H // group_size), + dtype=torch.float32, + device="cuda") + ref_q = torch.empty((T, H), dtype=torch.float8_e4m3fn, device="cuda") + for t in range(nt): + data = merged[e, t] + data_grp = data.view(H // group_size, group_size) + amax = data_grp.abs().amax(dim=1).clamp(min=eps) + scale = amax / fp8_max + + scaled = data / scale.repeat_interleave(group_size) + clamped = scaled.clamp(fp8_min, fp8_max) + q = clamped.to(torch.float8_e4m3fn) + + ref_s[t] = scale + ref_q[t] = q + + y_se = y_s[e] + y_qe = y_q[e] + + torch.testing.assert_close(y_se[:nt], ref_s[:nt]) + torch.testing.assert_close( + y_qe[:nt].to(torch.float32), + ref_q[:nt].to(torch.float32), + atol=2, + rtol=2e-1, + ) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index eec59573792d..1ca0a80ab9a9 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -403,19 +403,24 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS)) @pytest.mark.skipif(not 
dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() -def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): - - block_m = deep_gemm.get_m_alignment_for_contiguous_layout() - block_size = [block_m, block_m] - dtype = torch.bfloat16 - +def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, + monkeypatch): if topk > E: pytest.skip(f"Skipping test: topk={topk} > E={E}") if not _valid_deep_gemm_shape(M, N, K): pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}") + chunk_size = 1024 + torch.manual_seed(seed) + + monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size)) + + block_m = deep_gemm.get_m_alignment_for_contiguous_layout() + block_size = [block_m, block_m] + dtype = torch.bfloat16 + fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_max, fp8_min = fp8_info.max, fp8_info.min @@ -451,6 +456,14 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) + # Note: for now use_compile will error out if the problem size is + # large enough to trigger chunking. I'm leaving the flag and + # setup code in case we are able to revisit this later. + use_compile = False + + use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024 + and current_platform.is_cuda_alike()) + # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): if M >= 128: @@ -463,7 +476,29 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): topk_weights, topk_ids, token_expert_indices = fused_topk( a, score.float(), topk, False) - out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) + if use_compile: + deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8, + backend="inductor", + fullgraph=True) + torch._dynamo.mark_dynamic(a, 0) + torch._dynamo.mark_dynamic(topk_weights, 0) + torch._dynamo.mark_dynamic(topk_ids, 0) + else: + deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 + + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids) + + if use_cudagraph: + out.fill_(0) + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids) + torch.cuda.synchronize() + graph.replay() + torch.cuda.synchronize() #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index d1db6a8eb1ba..dcda8e479b29 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1054,12 +1054,21 @@ def compute_max_diff(output, output_ref): torch.abs(output_ref)) -def torch_moe(a, w1, w2, score, topk, expert_map): +def torch_experts(a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: + assert (global_num_experts == -1 + or (global_num_experts == w1.shape[0] and expert_map is None) + or (expert_map is not None + and global_num_experts == expert_map.shape[0])) + topk = topk_ids.shape[1] B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) topk_weight = topk_weight.view(-1) topk_ids = topk_ids.view(-1) if expert_map is not None: @@ -1073,6 +1082,19 @@ def 
torch_moe(a, w1, w2, score, topk, expert_map): topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) +def torch_moe(a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts, + expert_map) + + def torch_moe_single(a, w, score, topk): B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) diff --git a/tests/model_executor/test_model_load_with_params.py b/tests/model_executor/test_model_load_with_params.py index 94a14bd24bcb..4bdb651e5170 100644 --- a/tests/model_executor/test_model_load_with_params.py +++ b/tests/model_executor/test_model_load_with_params.py @@ -29,8 +29,8 @@ def test_model_loading_with_params(vllm_runner): revision=REVISION, dtype="float16", max_model_len=MAX_MODEL_LEN) as vllm_model: - output = vllm_model.encode("Write a short story about a robot that" - " dreams for the first time.\n") + output = vllm_model.embed("Write a short story about a robot that" + " dreams for the first time.\n") model_config = vllm_model.model.llm_engine.model_config model_tokenizer = vllm_model.model.llm_engine.tokenizer @@ -67,8 +67,8 @@ def test_roberta_model_loading_with_params(vllm_runner): revision=REVISION_ROBERTA, dtype="float16", max_model_len=MAX_MODEL_LEN) as vllm_model: - output = vllm_model.encode("Write a short story about a robot that" - " dreams for the first time.\n") + output = vllm_model.embed("Write a short story about a robot that" + " dreams for the first time.\n") model_config = vllm_model.model.llm_engine.model_config model_tokenizer = vllm_model.model.llm_engine.tokenizer @@ -105,8 +105,8 @@ def test_facebook_roberta_model_loading_with_params(vllm_runner): with vllm_runner(model_name=model_name, dtype="float16", max_model_len=MAX_MODEL_LEN) as vllm_model: - output = vllm_model.encode("Write a short story about a robot that" - " dreams for the first time.\n") + output = vllm_model.embed("Write a short story about a robot that" + " dreams for the first time.\n") model_tokenizer = vllm_model.model.llm_engine.tokenizer assert model_tokenizer.tokenizer_id == model_name diff --git a/tests/models/language/pooling/embed_utils.py b/tests/models/language/pooling/embed_utils.py index dabd7bee7f39..a663679a9c7c 100644 --- a/tests/models/language/pooling/embed_utils.py +++ b/tests/models/language/pooling/embed_utils.py @@ -55,7 +55,7 @@ def correctness_test_embed_models(hf_runner, task="embed", max_model_len=None, **vllm_extra_kwargs) as vllm_model: - vllm_outputs = vllm_model.encode(example_prompts) + vllm_outputs = vllm_model.embed(example_prompts) with hf_runner( model_info.name, diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index 5ef9f768c574..b8b17524cf07 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -89,7 +89,7 @@ def test_models( task="embed", max_model_len=512, **vllm_extra_kwargs) as vllm_model: - vllm_outputs = vllm_model.encode(example_prompts) + vllm_outputs = vllm_model.embed(example_prompts) check_embeddings_close( embeddings_0_lst=hf_outputs, diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py index 0c44683e7486..0bc189d82b8a 100644 --- 
a/tests/models/language/pooling/test_jina.py +++ b/tests/models/language/pooling/test_jina.py @@ -98,11 +98,11 @@ def test_matryoshka( if dimensions not in matryoshka_dimensions: with pytest.raises(ValueError): - vllm_model.encode( + vllm_model.embed( example_prompts, pooling_params=PoolingParams(dimensions=dimensions)) else: - vllm_outputs = vllm_model.encode( + vllm_outputs = vllm_model.embed( example_prompts, pooling_params=PoolingParams(dimensions=dimensions)) diff --git a/tests/models/language/pooling/test_reward.py b/tests/models/language/pooling/test_reward.py new file mode 100644 index 000000000000..ec3d25ee22a9 --- /dev/null +++ b/tests/models/language/pooling/test_reward.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +import torch.nn.functional as F +from transformers import AutoModel + +from vllm.platforms import current_platform + +from ....conftest import HfRunner + + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +@pytest.fixture +def math_step_prompts(): + # ruff: noqa: E501 + data = { + "system": + "Please reason step by step, and put your final answer within \\boxed{}. ", + "query": + "Sue lives in a fun neighborhood. One weekend, the neighbors decided to play a prank on Sue. On Friday morning, the neighbors placed 18 pink plastic flamingos out on Sue's front yard. On Saturday morning, the neighbors took back one third of the flamingos, painted them white, and put these newly painted white flamingos back out on Sue's front yard. Then, on Sunday morning, they added another 18 pink plastic flamingos to the collection. At noon on Sunday, how many more pink plastic flamingos were out than white plastic flamingos?", + "response": [ + "To find out how many more pink plastic flamingos were out than white plastic flamingos at noon on Sunday, we can break down the problem into steps. First, on Friday, the neighbors start with 18 pink plastic flamingos.", + "On Saturday, they take back one third of the flamingos. Since there were 18 flamingos, (1/3 \\times 18 = 6) flamingos are taken back. So, they have (18 - 6 = 12) flamingos left in their possession. Then, they paint these 6 flamingos white and put them back out on Sue's front yard. Now, Sue has the original 12 pink flamingos plus the 6 new white ones. Thus, by the end of Saturday, Sue has (12 + 6 = 18) pink flamingos and 6 white flamingos.", + "On Sunday, the neighbors add another 18 pink plastic flamingos to Sue's front yard. By the end of Sunday morning, Sue has (18 + 18 = 36) pink flamingos and still 6 white flamingos.", + "To find the difference, subtract the number of white flamingos from the number of pink flamingos: (36 - 6 = 30). Therefore, at noon on Sunday, there were 30 more pink plastic flamingos out than white plastic flamingos. 
The answer is (\\boxed{30}).", ], } + answer = "<extra_0>".join(data['response']) + "<extra_0>" + prompt = f"<|im_start|>system\n{data['system']}<|im_end|>\n<|im_start|>user\n{data['query']}<|im_end|>\n<|im_start|>assistant\n{answer}<|im_end|><|endoftext|>" + return [prompt] + + +def step_reward_patch_hf_model(hf_model: HfRunner): + + # Patch the hf_runner to use the step reward function + def make_step_rewards(logits: torch.Tensor, + token_masks: torch.Tensor) -> list[list[float]]: + probabilities = F.softmax(logits, dim=-1) + probabilities = probabilities * token_masks.unsqueeze(-1) + + all_scores_res: list[list[float]] = [] + for i in range(probabilities.size(0)): + sample = probabilities[i] # seq_len, num_labels + positive_probs = sample[sample != 0].view(-1, 2) + non_zero_elements_list = positive_probs.cpu().tolist() + all_scores_res.append(non_zero_elements_list) + return all_scores_res + + def reward(prompts: list[str]) -> list[list[float]]: + input_ids = hf_model.tokenizer(prompts, return_tensors="pt").input_ids + input_ids = hf_model.wrap_device(input_ids) + outputs = hf_model.model(input_ids=input_ids) + + step_sep_id = hf_model.tokenizer.encode("<extra_0>")[0] + token_masks = (input_ids == step_sep_id) + return make_step_rewards(outputs[0], token_masks) + + hf_model.reward = reward # type: ignore[attr-defined] + + return hf_model + + +@pytest.mark.parametrize( + "model", + [ + pytest.param("Qwen/Qwen2.5-Math-PRM-7B", + marks=[pytest.mark.core_model, pytest.mark.cpu_model]), + ], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_prm_models( + hf_runner, + vllm_runner, + math_step_prompts, + model: str, + dtype: str, + monkeypatch, +) -> None: + if current_platform.is_rocm(): + # ROCm Triton FA does not currently support sliding window attention + # switch to use ROCm CK FA backend + monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") + + with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.encode(math_step_prompts) + + with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model: + hf_model = step_reward_patch_hf_model(hf_model) + hf_outputs = hf_model.reward(math_step_prompts) + + # check logits difference + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + hf_output = torch.tensor(hf_output) + vllm_output = torch.tensor(vllm_output) + + assert torch.allclose(hf_output, vllm_output, 1.5e-2) diff --git a/tests/models/multimodal/pooling/test_dse_qwen2_vl.py b/tests/models/multimodal/pooling/test_dse_qwen2_vl.py index 3734d87b7962..f889eea5e839 100644 --- a/tests/models/multimodal/pooling/test_dse_qwen2_vl.py +++ b/tests/models/multimodal/pooling/test_dse_qwen2_vl.py @@ -98,7 +98,7 @@ def _run_test( max_model_len=8192) as vllm_model: tokenizer = vllm_model.model.get_tokenizer() texts = [ - # this is necessary because vllm_model.encode will not apply any + # this is necessary because vllm_model.embed will not apply any # templating to the prompt, and therefore lacks an image_pad # token unless one is inserted beforehand (the (28,28) image # above is converted to an image pad token by the chat template). @@ -109,7 +109,7 @@ def _run_test( # vllm will replace the pad token with the actual image, # which may be a placeholder image, later. 
] - vllm_outputs = vllm_model.encode(texts, images=input_images) + vllm_outputs = vllm_model.embed(texts, images=input_images) hf_outputs = [] with hf_runner(model, diff --git a/tests/models/multimodal/pooling/test_llava_next.py b/tests/models/multimodal/pooling/test_llava_next.py index b6d90d2b0abe..4a8f5cafbe48 100644 --- a/tests/models/multimodal/pooling/test_llava_next.py +++ b/tests/models/multimodal/pooling/test_llava_next.py @@ -68,7 +68,7 @@ def _run_test( dtype=dtype, max_model_len=4096, enforce_eager=True) as vllm_model: - vllm_outputs = vllm_model.encode(input_texts, images=input_images) + vllm_outputs = vllm_model.embed(input_texts, images=input_images) with hf_runner(model, dtype=dtype, auto_cls=AutoModelForImageTextToText) as hf_model: diff --git a/tests/models/multimodal/pooling/test_phi3v.py b/tests/models/multimodal/pooling/test_phi3v.py index b42ac6fb21ed..9a4b6d3ff8a8 100644 --- a/tests/models/multimodal/pooling/test_phi3v.py +++ b/tests/models/multimodal/pooling/test_phi3v.py @@ -46,7 +46,7 @@ def _run_test( # will hurt multiprocessing backend with fork method (the default method). with vllm_runner(model, task="embed", dtype=dtype, enforce_eager=True) as vllm_model: - vllm_outputs = vllm_model.encode(input_texts, images=input_images) + vllm_outputs = vllm_model.embed(input_texts, images=input_images) # use eager mode for hf runner, since phi3_v didn't work with flash_attn hf_model_kwargs = {"_attn_implementation": "eager"} diff --git a/tests/models/registry.py b/tests/models/registry.py index 49510af880cf..4a587e39ad4c 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -193,7 +193,8 @@ def check_available_online( extras={"tiny": "ai21labs/Jamba-tiny-dev"}), # noqa: E501 "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct", extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501 - "hermes": "NousResearch/Hermes-3-Llama-3.1-8B"}), # noqa: E501 + "hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501 + "fp8": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"}), # noqa: E501 "LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf", is_available_online=False), "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"), diff --git a/tests/neuron/1_core/test_prefix_prefill.py b/tests/neuron/1_core/test_prefix_prefill.py index 8b9a5f6e4a6a..abf7febc2955 100644 --- a/tests/neuron/1_core/test_prefix_prefill.py +++ b/tests/neuron/1_core/test_prefix_prefill.py @@ -7,6 +7,8 @@ import torch import torch.nn.functional as F +from vllm.utils import cdiv + class BlockDiagonalCausalFromBottomRightMask: @@ -398,11 +400,8 @@ def test_contexted_kv_attention( assert (large_tile_size >= B_P_SIZE ), f"Expect {large_tile_size=} to be larger than {B_P_SIZE=}" - def ceil_div(a, b): - return (a + b - 1) // b - def pad_to_multiple(a, b): - return ceil_div(a, b) * b + return cdiv(a, b) * b def pad_to_next_power_of_2(a): assert a > 0 @@ -411,7 +410,7 @@ def pad_to_next_power_of_2(a): # calculate input shapes max_num_queries = pad_to_next_power_of_2(sum(query_lens)) context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens) - num_active_blocks = ceil_div(context_lens, block_size).sum().item() + num_active_blocks = cdiv(context_lens, block_size).sum().item() num_active_blocks = pad_to_multiple(num_active_blocks, large_tile_size // block_size) context_kv_len = num_active_blocks * block_size diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index 8e39ed2fff87..363daa6d27ef 100644 --- 
a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -161,7 +161,7 @@ def test_4bit_bnb_embedding_model( dtype=dtype, gpu_memory_utilization=0.5, quantization="bitsandbytes") as vllm_model: - vllm_outputs = vllm_model.encode(example_prompts) + vllm_outputs = vllm_model.embed(example_prompts) check_embeddings_close( embeddings_0_lst=hf_outputs, embeddings_1_lst=vllm_outputs, diff --git a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py index 3d720fe0cafe..c58cb0286f13 100644 --- a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py +++ b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py @@ -196,8 +196,7 @@ async def stream_service_response(client_info: dict, endpoint: str, yield chunk -@app.post("/v1/completions") -async def handle_completions(request: Request): +async def _handle_completions(api: str, request: Request): try: req_data = await request.json() request_id = str(uuid.uuid4()) @@ -206,9 +205,8 @@ async def handle_completions(request: Request): prefill_client_info = get_next_client(request.app, 'prefill') # Send request to prefill service - response = await send_request_to_service(prefill_client_info, - "/completions", req_data, - request_id) + response = await send_request_to_service(prefill_client_info, api, + req_data, request_id) # Extract the needed fields response_json = response.json() @@ -224,7 +222,7 @@ async def handle_completions(request: Request): # Stream response from decode service async def generate_stream(): async for chunk in stream_service_response(decode_client_info, - "/completions", + api, req_data, request_id=request_id): yield chunk @@ -237,12 +235,22 @@ async def generate_stream(): import traceback exc_info = sys.exc_info() print("Error occurred in disagg prefill proxy server" - " - completions endpoint") + f" - {api} endpoint") print(e) print("".join(traceback.format_exception(*exc_info))) raise +@app.post("/v1/completions") +async def handle_completions(request: Request): + return await _handle_completions("/completions", request) + + +@app.post("/v1/chat/completions") +async def handle_chat_completions(request: Request): + return await _handle_completions("/chat/completions", request) + + @app.get("/healthcheck") async def healthcheck(): """Simple endpoint to check if the server is running.""" diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index b00be7b83e12..ab9729aae2e9 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -7,13 +7,6 @@ from typing import Optional from unittest.mock import patch -import pytest - -try: - from nixl._api import nixl_agent as NixlWrapper -except ImportError: - NixlWrapper = None - from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata, NixlConnectorWorker) @@ -92,7 +85,8 @@ def test_prompt_less_than_block_size(): class FakeNixlWrapper: """Mock implementation of NixlWrapper for testing. - We don't inherit from NixlWrapper because NixlWrapper could be None. + We don't inherit from nixl._api.nixl_agent because nixl may not be + installed. 
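+ It implements just enough of the nixl agent interface for NixlConnectorWorker to run in these unit tests.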
""" AGENT_METADATA = b"fake_agent_metadata" @@ -167,7 +161,7 @@ def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs): super().__init__(*args, **kwargs) self._hand_shake_latency = hand_shake_latency - def _nixl_handshake(self, host: str, port: int): + def _nixl_handshake(self, host: str, port: int) -> dict[int, str]: # Mimic slow _nixl_handshake, as well as bypass zmq communication. time.sleep(self._hand_shake_latency) # These should've been done in register_kv_caches(), called by @@ -177,7 +171,7 @@ def _nixl_handshake(self, host: str, port: int): self.num_blocks = 1 self.dst_num_blocks[self.engine_id] = self.num_blocks - self.add_remote_agent( + remote_agent_name = self.add_remote_agent( NixlAgentMetadata( engine_id=self.REMOTE_ENGINE_ID, agent_metadata=FakeNixlWrapper.AGENT_METADATA, @@ -187,40 +181,101 @@ def _nixl_handshake(self, host: str, port: int): block_len=self.block_len, attn_backend_name=self.backend_name, )) - - -@pytest.mark.skipif(NixlWrapper is None, reason="nixl not installed") -@patch( - "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", - FakeNixlWrapper) -def test_multi_xfer_one_engine( - # dist_init is a fixture that initializes the distributed environment. - dist_init): - """Test case where multiple xfers are initiated to the same engine. - - This test triggers the connector to load remote KV for the same - `request_id`. The transfer is not done immediately due to - `set_cycles_before_xfer_done`, so there is a state where there are multiple - transfer states for the same `request_id`, and `get_finished` should handle - it correctly (wait for all transfers to be done). - """ - vllm_config = create_vllm_config() - - request_id = "req_id" - - # Test worker role in decode server. - connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) - connector.connector_worker = FakeNixlConnectorWorker(vllm_config, - connector.engine_id, - hand_shake_latency=0) - assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper) - connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3) - for i in range(4): + return {0: remote_agent_name} + + +class TestNixlHandshake: + + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper) + def test_multi_xfer_one_engine( + self, + # dist_init is a fixture that initializes the distributed environment. + dist_init): + """Test case where multiple xfers are initiated to the same engine. + + This test triggers the connector to load remote KV for the same + `request_id`. The transfer is not done immediately due to + `set_cycles_before_xfer_done`, so there is a state where there are + multiple transfer states for the same `request_id`, and `get_finished` + should handle it correctly (wait for all transfers to be done). + """ + vllm_config = create_vllm_config() + + request_id = "req_id" + + # Test worker role in decode server. + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0) + assert isinstance(connector.connector_worker.nixl_wrapper, + FakeNixlWrapper) + connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3) + num_xfers = 4 + while True: + # For the same request_id, initiate multiple xfers across different + # round of `execute_model` calls. 
+ metadata = NixlConnectorMetadata() + if num_xfers > 0: + num_xfers -= 1 + metadata.add_new_req( + request_id=request_id, + local_block_ids=[ + num_xfers + 1, num_xfers + 2, num_xfers + 3 + ], + kv_transfer_params={ + "remote_block_ids": + [num_xfers + 4, num_xfers + 5, num_xfers + 6], + "remote_engine_id": + FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_host": + "localhost", + "remote_port": + 1234, + }) + connector.bind_connector_metadata(metadata) + + # Mimic maybe_setup_kv_connector in gpu_model_runner. + dummy_ctx = ForwardContext( + no_compile_layers={}, + attn_metadata={}, + virtual_engine=0, + ) + _before_load = time.perf_counter() + connector.start_load_kv(dummy_ctx) + _after_load = time.perf_counter() + assert _after_load - _before_load < 0.1, "start_load_kv took " \ + f"{_after_load - _before_load} seconds" + + # Mimic get_finished_kv_transfers in gpu_model_runner. + _, done_recving = connector.get_finished(finished_req_ids=set()) + if len(done_recving) > 0: + assert request_id in done_recving + break + + connector.clear_connector_metadata() + + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper) + def test_async_load_kv( + self, + # dist_init is a fixture that initializes the distributed environment. + dist_init): + """Test that NixlConnector's start_load_kv should be non-blocking.""" + + vllm_config = create_vllm_config() + + # Test worker role in decode server. + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id) metadata = NixlConnectorMetadata() - metadata.add_new_req(request_id=request_id, - local_block_ids=[i + 1, i + 2, i + 3], + metadata.add_new_req(request_id="id", + local_block_ids=[1, 2, 3], kv_transfer_params={ - "remote_block_ids": [i + 4, i + 5, i + 6], + "remote_block_ids": [4, 5, 6], "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, "remote_host": "localhost", @@ -228,19 +283,74 @@ def test_multi_xfer_one_engine( }) connector.bind_connector_metadata(metadata) - dummy_ctx = ForwardContext( - no_compile_layers={}, - attn_metadata={}, - virtual_engine=0, - ) - _before_load = time.perf_counter() - connector.start_load_kv(dummy_ctx) - _after_load = time.perf_counter() - assert _after_load - _before_load < 0.1, "start_load_kv took " \ - f"{_after_load - _before_load} seconds" - - while True: - _, done_recving = connector.get_finished(finished_req_ids=set()) - if len(done_recving) > 0: - assert request_id in done_recving - break + timeout = 2.5 + start = time.perf_counter() + while time.perf_counter() - start < timeout: + dummy_ctx = ForwardContext( + no_compile_layers={}, + attn_metadata={}, + virtual_engine=0, + ) + _before_load = time.perf_counter() + connector.start_load_kv(dummy_ctx) + _after_load = time.perf_counter() + assert _after_load - _before_load < 0.1, "start_load_kv took " \ + f"{_after_load - _before_load} seconds" + time.sleep(0.5) # backoff for the async handshake to complete. + connector.bind_connector_metadata(NixlConnectorMetadata()) + _, done_recving = connector.get_finished(finished_req_ids=set()) + if len(done_recving) > 0: + return + raise TimeoutError("Took too long to complete async handshake.") + + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper) + def test_concurrent_load_kv( + self, + # dist_init is a fixture that initializes the distributed environment. 
+ dist_init): + """Test that multiple start_load_kv calls should occur concurrently.""" + + vllm_config = create_vllm_config() + + # Test worker role in decode server. + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id) + metadata = NixlConnectorMetadata() + total_reqs = 5 + for i in range(total_reqs): + metadata.add_new_req(request_id=f"id_{i}", + local_block_ids=[1, 2, 3], + kv_transfer_params={ + "remote_block_ids": [4, 5, 6], + "remote_engine_id": + FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_host": "localhost", + "remote_port": 1234, + }) + connector.bind_connector_metadata(metadata) + + timeout = 2.5 * total_reqs + cnt_finished_reqs = 0 + start = time.perf_counter() + while time.perf_counter() - start < timeout: + dummy_ctx = ForwardContext( + no_compile_layers={}, + attn_metadata={}, + virtual_engine=0, + ) + _before_load = time.perf_counter() + connector.start_load_kv(dummy_ctx) + _after_load = time.perf_counter() + assert _after_load - _before_load < 0.1, "start_load_kv took " \ + f"{_after_load - _before_load} seconds" + time.sleep(0.5) # backoff for the async handshake to complete. + connector.bind_connector_metadata(NixlConnectorMetadata()) + _, done_recving = connector.get_finished(finished_req_ids=set()) + if len(done_recving) > 0: + cnt_finished_reqs += len(done_recving) + if cnt_finished_reqs == total_reqs: + return + raise TimeoutError("Took too long to complete async handshake.") diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index 1787b9a0b469..d640d7dc49d1 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -74,12 +74,6 @@ def test_unsupported_configs(monkeypatch): disable_async_output_proc=True, ).create_engine_config() - with pytest.raises(NotImplementedError): - AsyncEngineArgs( - model=MODEL, - scheduling_policy="priority", - ).create_engine_config() - with pytest.raises(NotImplementedError): AsyncEngineArgs( model=MODEL, diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index 0e7d305fef9e..d22ddf5c7e58 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -6,6 +6,7 @@ from vllm.attention.layer import Attention from vllm.config import (CacheConfig, ModelConfig, SchedulerConfig, VllmConfig, set_current_vllm_config) +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import GiB_bytes from vllm.v1.core.kv_cache_utils import (estimate_max_model_len, @@ -71,6 +72,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: mm_hashes=[], mm_positions=[], sampling_params=SamplingParams(), + pooling_params=PoolingParams(), block_ids=([0], ), # block_ids should be tuple[list[int]] num_computed_tokens=0, lora_request=None, diff --git a/tools/generate_cmake_presets.py b/tools/generate_cmake_presets.py new file mode 100644 index 000000000000..5f92f2f5848f --- /dev/null +++ b/tools/generate_cmake_presets.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +import multiprocessing +import os +import sys +from shutil import which + +try: + # Try to get CUDA_HOME from PyTorch installation, which is the + # most reliable source of truth for vLLM's build. + from torch.utils.cpp_extension import CUDA_HOME +except ImportError: + print("Warning: PyTorch not found. 
" + "Falling back to CUDA_HOME environment variable.") + CUDA_HOME = os.environ.get("CUDA_HOME") + + +def get_python_executable(): + """Get the current Python executable, which is used to run this script.""" + return sys.executable + + +def get_cpu_cores(): + """Get the number of CPU cores.""" + return multiprocessing.cpu_count() + + +def generate_presets(output_path="CMakeUserPresets.json"): + """Generates the CMakeUserPresets.json file.""" + + print("Attempting to detect your system configuration...") + + # Detect NVCC + nvcc_path = None + if CUDA_HOME: + prospective_path = os.path.join(CUDA_HOME, "bin", "nvcc") + if os.path.exists(prospective_path): + nvcc_path = prospective_path + print("Found nvcc via torch.utils.cpp_extension.CUDA_HOME: " + f"{nvcc_path}") + + if not nvcc_path: + nvcc_path = which("nvcc") + if nvcc_path: + print(f"Found nvcc in PATH: {nvcc_path}") + + if not nvcc_path: + nvcc_path_input = input( + "Could not automatically find 'nvcc'. Please provide the full " + "path to nvcc (e.g., /usr/local/cuda/bin/nvcc): ") + nvcc_path = nvcc_path_input.strip() + print(f"Using NVCC path: {nvcc_path}") + + # Detect Python executable + python_executable = get_python_executable() + if python_executable: + print(f"Found Python via sys.executable: {python_executable}") + else: + python_executable_prompt = ( + "Could not automatically find Python executable. Please provide " + "the full path to your Python executable for vLLM development " + "(typically from your virtual environment, e.g., " + "/home/user/venvs/vllm/bin/python): ") + python_executable = input(python_executable_prompt).strip() + if not python_executable: + raise ValueError( + "Could not determine Python executable. Please provide it " + "manually.") + + print(f"Using Python executable: {python_executable}") + + # Get CPU cores + cpu_cores = get_cpu_cores() + nvcc_threads = min(4, cpu_cores) + cmake_jobs = max(1, cpu_cores // nvcc_threads) + print(f"Detected {cpu_cores} CPU cores. " + f"Setting NVCC_THREADS={nvcc_threads} and CMake jobs={cmake_jobs}.") + + # Get vLLM project root (assuming this script is in vllm/tools/) + project_root = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..")) + print(f"VLLM project root detected as: {project_root}") + + # Ensure python_executable path is absolute or resolvable + if not os.path.isabs(python_executable) and which(python_executable): + python_executable = os.path.abspath(which(python_executable)) + elif not os.path.isabs(python_executable): + print(f"Warning: Python executable '{python_executable}' is not an " + "absolute path and not found in PATH. 
CMake might not find it.") + + cache_variables = { + "CMAKE_CUDA_COMPILER": nvcc_path, + "CMAKE_BUILD_TYPE": "Release", + "VLLM_PYTHON_EXECUTABLE": python_executable, + "CMAKE_INSTALL_PREFIX": "${sourceDir}", + "CMAKE_CUDA_FLAGS": "", + "NVCC_THREADS": str(nvcc_threads), + } + + # Detect compiler cache + if which("sccache"): + print("Using sccache for compiler caching.") + for launcher in ("C", "CXX", "CUDA", "HIP"): + cache_variables[f"CMAKE_{launcher}_COMPILER_LAUNCHER"] = "sccache" + elif which("ccache"): + print("Using ccache for compiler caching.") + for launcher in ("C", "CXX", "CUDA", "HIP"): + cache_variables[f"CMAKE_{launcher}_COMPILER_LAUNCHER"] = "ccache" + else: + print("No compiler cache ('ccache' or 'sccache') found.") + + configure_preset = { + "name": "release", + "binaryDir": "${sourceDir}/cmake-build-release", + "cacheVariables": cache_variables, + } + if which("ninja"): + print("Using Ninja generator.") + configure_preset["generator"] = "Ninja" + cache_variables["CMAKE_JOB_POOLS"] = f"compile={cmake_jobs}" + else: + print("Ninja not found, using default generator. " + "Build may be slower.") + + presets = { + "version": + 6, + # Keep in sync with CMakeLists.txt and requirements/build.txt + "cmakeMinimumRequired": { + "major": 3, + "minor": 26, + "patch": 1 + }, + "configurePresets": [configure_preset], + "buildPresets": [{ + "name": "release", + "configurePreset": "release", + "jobs": cmake_jobs, + }], + } + + output_file_path = os.path.join(project_root, output_path) + + if os.path.exists(output_file_path): + overwrite = input( + f"'{output_file_path}' already exists. Overwrite? (y/N): ").strip( + ).lower() + if overwrite != 'y': + print("Generation cancelled.") + return + + try: + with open(output_file_path, "w") as f: + json.dump(presets, f, indent=4) + print(f"Successfully generated '{output_file_path}'") + print("\nTo use this preset:") + print( + f"1. Ensure you are in the vLLM root directory: cd {project_root}") + print("2. Initialize CMake: cmake --preset release") + print("3. 
Build+install: cmake --build --preset release " + "--target install") + + except OSError as e: + print(f"Error writing file: {e}") + + +if __name__ == "__main__": + generate_presets() diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 247fddbfe261..761704463782 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1289,7 +1289,7 @@ def scaled_fp8_quant( torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case - assert (scale.numel() == 1 or num_token_padding is None) + assert (scale.numel() == 1 and num_token_padding is None) torch.ops._C.static_scaled_fp8_quant(output, input, scale) return output, scale diff --git a/vllm/attention/ops/nki_flash_attn.py b/vllm/attention/ops/nki_flash_attn.py index e28ff7e8b4ed..29fa43201761 100644 --- a/vllm/attention/ops/nki_flash_attn.py +++ b/vllm/attention/ops/nki_flash_attn.py @@ -8,9 +8,7 @@ from neuronxcc import nki from neuronxcc.nki.language import par_dim - -def ceil_div(a, b): - return (a + b - 1) // b +from vllm.utils import cdiv def is_power_of_2(x): @@ -35,11 +33,10 @@ def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile): (num_tiles, num_blocks_per_tile)) block_tables_sbuf = nl.zeros( - (ceil_div(num_tiles, - B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile), + (cdiv(num_tiles, B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile), dtype=nl.int32, ) - for i in nl.affine_range(ceil_div(num_tiles, B_P_SIZE)): + for i in nl.affine_range(cdiv(num_tiles, B_P_SIZE)): i_p = nl.arange(B_P_SIZE)[:, None] i_f = nl.arange(num_blocks_per_tile)[None, :] block_tables_sbuf[i, i_p, i_f] = nl.load( @@ -83,7 +80,7 @@ def transform_block_tables_for_indirect_load( assert is_power_of_2( num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2" - num_loads = ceil_div(num_blocks_per_tile, B_P_SIZE) + num_loads = cdiv(num_blocks_per_tile, B_P_SIZE) block_tables_transposed = nl.ndarray( ( num_loads, @@ -165,7 +162,7 @@ def load_kv_tile_from_cache( equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE) """ # load key cache - num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE) + num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE) for load_idx in nl.affine_range(num_loads): i_p = nl.arange(B_P_SIZE)[:, None] i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :] @@ -605,7 +602,7 @@ def flash_paged_attention( ) for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile): - num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE) + num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE) cur_k_tile = nl.ndarray( (par_dim(B_D_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype, diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 4487d2d6841a..302f655f424a 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -26,7 +26,7 @@ from collections.abc import AsyncGenerator, Iterable from dataclasses import dataclass from datetime import datetime -from typing import Any, Optional +from typing import Any, Literal, Optional import numpy as np from tqdm.asyncio import tqdm @@ -75,14 +75,39 @@ class BenchmarkMetrics: percentiles_e2el_ms: list[tuple[float, float]] +def _get_current_request_rate( + ramp_up_strategy: Optional[Literal["linear", "exponential"]], + ramp_up_start_rps: Optional[int], + ramp_up_end_rps: Optional[int], + request_index: int, + total_requests: int, + request_rate: float, +) -> float: + if (ramp_up_strategy and ramp_up_start_rps is not None + and ramp_up_end_rps is not None): + progress = request_index / 
max(total_requests - 1, 1) + if ramp_up_strategy == "linear": + increase = (ramp_up_end_rps - ramp_up_start_rps) * progress + return ramp_up_start_rps + increase + elif ramp_up_strategy == "exponential": + ratio = ramp_up_end_rps / ramp_up_start_rps + return ramp_up_start_rps * (ratio**progress) + else: + raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}") + return request_rate + + async def get_request( input_requests: list[SampleRequest], request_rate: float, burstiness: float = 1.0, -) -> AsyncGenerator[SampleRequest, None]: + ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, + ramp_up_start_rps: Optional[int] = None, + ramp_up_end_rps: Optional[int] = None, +) -> AsyncGenerator[tuple[SampleRequest, float], None]: """ Asynchronously generates requests at a specified rate - with OPTIONAL burstiness. + with OPTIONAL burstiness and OPTIONAL ramp-up strategy. Args: input_requests: @@ -97,21 +122,42 @@ async def get_request( A lower burstiness value (0 < burstiness < 1) results in more bursty requests, while a higher burstiness value (burstiness > 1) results in a more uniform arrival of requests. + ramp_up_strategy (optional): + The ramp-up strategy. Can be "linear" or "exponential". + If None, uses constant request rate (specified by request_rate). + ramp_up_start_rps (optional): + The starting request rate for ramp-up. + ramp_up_end_rps (optional): + The ending request rate for ramp-up. """ - input_requests: Iterable[SampleRequest] = iter(input_requests) - - # Calculate scale parameter theta to maintain the desired request_rate. assert burstiness > 0, ( f"A positive burstiness factor is expected, but given {burstiness}.") - theta = 1.0 / (request_rate * burstiness) + # Convert to list to get length for ramp-up calculations + if isinstance(input_requests, Iterable) and not isinstance( + input_requests, list): + input_requests = list(input_requests) - for request in input_requests: - yield request + total_requests = len(input_requests) + request_index = 0 - if request_rate == float("inf"): + for request in input_requests: + current_request_rate = _get_current_request_rate(ramp_up_strategy, + ramp_up_start_rps, + ramp_up_end_rps, + request_index, + total_requests, + request_rate) + + yield request, current_request_rate + + request_index += 1 + + if current_request_rate == float("inf"): # If the request rate is infinity, then we don't need to wait. continue + theta = 1.0 / (current_request_rate * burstiness) + # Sample the request interval from the gamma distribution. # If burstiness is 1, it follows exponential distribution. 
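+ # When a ramp-up strategy is active, theta is recomputed from the current request rate on every iteration, so the expected inter-arrival interval shrinks as the rate ramps up.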
interval = np.random.gamma(shape=burstiness, scale=theta) @@ -259,6 +305,9 @@ async def benchmark( max_concurrency: Optional[int], lora_modules: Optional[Iterable[str]], extra_body: Optional[dict], + ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, + ramp_up_start_rps: Optional[int] = None, + ramp_up_end_rps: Optional[int] = None, ): if endpoint_type in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[endpoint_type] @@ -316,12 +365,16 @@ async def benchmark( if profile_output.success: print("Profiler started") - if burstiness == 1.0: - distribution = "Poisson process" + distribution = ("Poisson process" if burstiness == 1.0 + else "Gamma distribution") + + if ramp_up_strategy is not None: + print(f"Traffic ramp-up strategy: {ramp_up_strategy}.") + print(f"Will increase RPS from {ramp_up_start_rps} to " + f"{ramp_up_end_rps} RPS over the duration of the benchmark.") else: - distribution = "Gamma distribution" + print(f"Traffic request rate: {request_rate}") - print(f"Traffic request rate: {request_rate}") print(f"Burstiness factor: {burstiness} ({distribution})") print(f"Maximum request concurrency: {max_concurrency}") @@ -344,7 +397,29 @@ async def limited_request_func(request_func_input, pbar): benchmark_start_time = time.perf_counter() tasks: list[asyncio.Task] = [] - async for request in get_request(input_requests, request_rate, burstiness): + + rps_change_events = [] + last_int_rps = -1 + if ramp_up_strategy is not None and ramp_up_start_rps is not None: + last_int_rps = ramp_up_start_rps + rps_change_events.append({ + "rps": last_int_rps, + "timestamp": datetime.now().isoformat(), + }) + + async for request, current_request_rate in get_request( + input_requests, request_rate, burstiness, ramp_up_strategy, + ramp_up_start_rps, ramp_up_end_rps): + if ramp_up_strategy is not None: + current_int_rps = int(current_request_rate) + if current_int_rps > last_int_rps: + timestamp = datetime.now().isoformat() + for rps_val in range(last_int_rps + 1, current_int_rps + 1): + rps_change_events.append({ + "rps": rps_val, + "timestamp": timestamp + }) + last_int_rps = current_int_rps prompt, prompt_len, output_len, mm_content = ( request.prompt, request.prompt_len, @@ -435,6 +510,9 @@ async def limited_request_func(request_func_input, pbar): "errors": [output.error for output in outputs], } + if rps_change_events: + result["rps_change_events"] = rps_change_events + def process_one_metric( # E.g., "ttft" metric_attribute_name: str, @@ -771,12 +849,60 @@ def add_cli_args(parser: argparse.ArgumentParser): "launching the server. For each request, the " "script chooses a LoRA module at random.") + parser.add_argument( + "--ramp-up-strategy", + type=str, + default=None, + choices=["linear", "exponential"], + help="The ramp-up strategy. This would be used to " + "ramp up the request rate from initial RPS to final " + "RPS rate (specified by --ramp-up-start-rps and " + "--ramp-up-end-rps.) over the duration of the benchmark." + ) + parser.add_argument( + "--ramp-up-start-rps", + type=int, + default=None, + help="The starting request rate for ramp-up (RPS). " + "Needs to be specified when --ramp-up-strategy is used.", + ) + parser.add_argument( + "--ramp-up-end-rps", + type=int, + default=None, + help="The ending request rate for ramp-up (RPS). 
" + "Needs to be specified when --ramp-up-strategy is used.", + ) + def main(args: argparse.Namespace): print(args) random.seed(args.seed) np.random.seed(args.seed) + # Validate ramp-up arguments + if args.ramp_up_strategy is not None: + if args.request_rate != float("inf"): + raise ValueError( + "When using ramp-up, do not specify --request-rate. " + "The request rate will be controlled by ramp-up parameters. " + "Please remove the --request-rate argument." + ) + if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None: + raise ValueError( + "When using --ramp-up-strategy, both --ramp-up-start-rps and " + "--ramp-up-end-rps must be specified" + ) + if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0: + raise ValueError("Ramp-up start and end RPS must be non-negative") + if args.ramp_up_start_rps > args.ramp_up_end_rps: + raise ValueError("Ramp-up start RPS must be less than end RPS") + if (args.ramp_up_strategy == "exponential" + and args.ramp_up_start_rps == 0): + raise ValueError( + "For exponential ramp-up, the start RPS cannot be 0.") + + endpoint_type = args.endpoint_type label = args.label model_id = args.model model_name = args.served_model_name @@ -849,6 +975,9 @@ def main(args: argparse.Namespace): max_concurrency=args.max_concurrency, lora_modules=args.lora_modules, extra_body=sampling_params, + ramp_up_strategy=args.ramp_up_strategy, + ramp_up_start_rps=args.ramp_up_start_rps, + ramp_up_end_rps=args.ramp_up_end_rps, )) # Save config and results to json @@ -881,6 +1010,11 @@ def main(args: argparse.Namespace): result_json["burstiness"] = args.burstiness result_json["max_concurrency"] = args.max_concurrency + if args.ramp_up_strategy is not None: + result_json["ramp_up_strategy"] = args.ramp_up_strategy + result_json["ramp_up_start_rps"] = args.ramp_up_start_rps + result_json["ramp_up_end_rps"] = args.ramp_up_end_rps + # Merge with benchmark result result_json = {**result_json, **benchmark_result} @@ -903,8 +1037,11 @@ def main(args: argparse.Namespace): base_model_id = model_id.split("/")[-1] max_concurrency_str = (f"-concurrency{args.max_concurrency}" if args.max_concurrency is not None else "") - label = label or args.endpoint_type - file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa + label = label or endpoint_type + if args.ramp_up_strategy is not None: + file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa + else: + file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa if args.result_filename: file_name = args.result_filename if args.result_dir: diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 8bb8c3a2a2e4..a2bb053cec4a 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -32,7 +32,7 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface: if compilation_config.use_inductor: if envs.VLLM_USE_STANDALONE_COMPILE and is_torch_equal_or_newer( - "2.8.0a"): + "2.8.0.dev"): logger.debug("Using InductorStandaloneAdaptor") return InductorStandaloneAdaptor() else: diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 9d908fcae3df..951a2861e3a4 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -345,8 +345,8 @@ def process(self): # 0 is always None fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)} 
self.insert_fused_node(fused_return_mapping, - epsilon=rms_node.kwargs["epsilon"], - **kwargs) + **kwargs, + epsilon=rms_node.kwargs["epsilon"]) class RMSNormDynamicQuantPattern(RMSNormQuantPattern): diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 28a59905ecf8..3ce00e3610c5 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -51,15 +51,15 @@ def configure(self, config: VllmConfig): if self.pass_config.enable_noop: self.passes += [NoOpEliminationPass(config)] - if self.pass_config.enable_fusion: - self.passes += [FusionPass.instance(config)] - self.passes += [ActivationQuantFusionPass(config)] - if self.pass_config.enable_sequence_parallelism: self.passes += [SequenceParallelismPass(config)] if self.pass_config.enable_async_tp: self.passes += [AsyncTPPass(config)] + if self.pass_config.enable_fusion: + self.passes += [FusionPass.instance(config)] + self.passes += [ActivationQuantFusionPass(config)] + if self.pass_config.enable_attn_fusion: self.passes += [AttnFusionPass(config)] diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index d41093903480..6107046e40dc 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -12,91 +12,142 @@ from vllm.distributed.parallel_state import ( get_tensor_model_parallel_world_size) from vllm.logger import init_logger +from vllm.platforms import current_platform from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) -class AllReduceRMSNormPattern: +class _RMSNormAndQuantOpHelper: + """Base helper for RMSNorm and RMSNorm + Quantization functionalization.""" - def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + def __init__(self, + epsilon: float, + dtype: torch.dtype, + device: str, + quant_op: Optional[torch._ops.OpOverload] = None, + **kwargs): self.epsilon = epsilon self.dtype = dtype self.device = device - - -class EmbeddingAllReduceRMSNormPattern(AllReduceRMSNormPattern): + self.quant_op = quant_op + + def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor): + return torch.ops.higher_order.auto_functionalized( + torch.ops._C.rms_norm.default, + result=result_buffer, + input=input_tensor, + weight=weight_tensor, + epsilon=self.epsilon) + + def _functional_fused_add_rmsnorm(self, input_tensor, residual_tensor, + weight_tensor): + return torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=input_tensor, + residual=residual_tensor, + weight=weight_tensor, + epsilon=self.epsilon) + + def _functional_rmsnorm_then_quant(self, rmsnorm_result_buffer, + quant_result_buffer, input_tensor, + weight_tensor, scale_tensor): + if self.quant_op is None: + raise RuntimeError( + "_RMSNormAndQuantOpHelper was not initialized with a quant_op." + ) + rmsnorm_out_tuple = self._functional_rmsnorm(rmsnorm_result_buffer, + input_tensor, + weight_tensor) + quant_out_tuple = torch.ops.higher_order.auto_functionalized( + self.quant_op, + result=quant_result_buffer, + input=rmsnorm_out_tuple[1], + scale=scale_tensor) + return quant_out_tuple + + def _functional_fused_add_rmsnorm_then_quant(self, quant_result_buffer, + input_tensor, residual_tensor, + weight_tensor, scale_tensor): + if self.quant_op is None: + raise RuntimeError( + "_RMSNormAndQuantOpHelper was not initialized with a quant_op." 
+ ) + fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm( + input_tensor, residual_tensor, weight_tensor) + quant_out_tuple = torch.ops.higher_order.auto_functionalized( + self.quant_op, + result=quant_result_buffer, + input=fused_add_rmsnorm_out_tuple[1], + scale=scale_tensor) + return quant_out_tuple, fused_add_rmsnorm_out_tuple[2] + + +class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper): + """Helper for sequence parallelism patterns.""" + + def __init__(self, + epsilon: float, + dtype: torch.dtype, + device: str, + quant_op: Optional[torch._ops.OpOverload] = None, + **kwargs): + super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs) + self.tp_group = get_tp_group() + self.tp_size = get_tensor_model_parallel_world_size() + + def _all_reduce(self, x: torch.Tensor) -> torch.Tensor: + return tensor_model_parallel_all_reduce(x) + + def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.vllm.reduce_scatter.default( + x, + dim=0, + world_size=self.tp_size, + group_name=self.tp_group.unique_name) + + def _all_gather(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.vllm.all_gather.default( + x, + dim=0, + world_size=self.tp_size, + group_name=self.tp_group.unique_name) + + +class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): def get_inputs(self): - arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype) - mul_6 = torch.tensor([[3, 7, 1, 4, 9, 2, 5, 0]], - device=self.device, - dtype=torch.long) - unsqueeze = torch.rand([1, 8, 1], device=self.device, \ - dtype=self.dtype) > 0.5 - full_default = torch.zeros([1, 8, 4], device=self.device, \ - dtype=self.dtype) + input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype) - return [arg2_1, mul_6, unsqueeze, full_default, permute, arg3_1] + return [input, permute, arg3_1] def register(self, pm_pass: PatternMatcherPass): def pattern( - arg2_1: torch.Tensor, - mul_6: torch.Tensor, - unsqueeze: torch.Tensor, - full_default: torch.Tensor, + input: torch.Tensor, permute: torch.Tensor, arg3_1: torch.Tensor, ): - embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) - where = torch.ops.aten.where.self(unsqueeze, full_default, - embedding) - all_reduce = tensor_model_parallel_all_reduce(where) - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.rms_norm.default, - result=permute, - input=all_reduce, - weight=arg3_1, - epsilon=self.epsilon, - ) + all_reduce = self._all_reduce(input) + rmsnorm = self._functional_rmsnorm(permute, all_reduce, arg3_1) return rmsnorm[1], all_reduce def replacement( - arg2_1: torch.Tensor, - mul_6: torch.Tensor, - unsqueeze: torch.Tensor, - full_default: torch.Tensor, + input: torch.Tensor, permute: torch.Tensor, arg3_1: torch.Tensor, ): - embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) - where = torch.ops.aten.where.self(unsqueeze, full_default, - embedding) - - tp = get_tp_group() - tp_size = get_tensor_model_parallel_world_size() - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - where, dim=0, world_size=tp_size, group_name=tp.unique_name) + reduce_scatter = self._reduce_scatter(input) rmsnorm_result = torch.empty_like(reduce_scatter) - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.rms_norm.default, - result=rmsnorm_result, - input=reduce_scatter, - weight=arg3_1, - epsilon=self.epsilon, - ) + rmsnorm = 
self._functional_rmsnorm(rmsnorm_result, reduce_scatter, + arg3_1) - all_gather = torch.ops.vllm.all_gather.default( - rmsnorm[1], - dim=0, - world_size=tp_size, - group_name=tp.unique_name) + all_gather = self._all_gather(rmsnorm[1]) return all_gather, reduce_scatter @@ -104,7 +155,7 @@ def replacement( pm.fwd_only, pm_pass) -class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern): +class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -127,16 +178,9 @@ def pattern( mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - all_reduce = tensor_model_parallel_all_reduce(mm_1) - - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=all_reduce, - residual=residual, - weight=rms_norm_weights, - epsilon=self.epsilon, - ) - + all_reduce = self._all_reduce(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm( + all_reduce, residual, rms_norm_weights) return rmsnorm[1], rmsnorm[2] def replacement( @@ -144,32 +188,17 @@ def replacement( mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - tp = get_tp_group() - tp_size = get_tensor_model_parallel_world_size() - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) - - # TODO is it possible to extract epsilon from somewhere - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=reduce_scatter, - residual=residual, - weight=rms_norm_weights, - epsilon=self.epsilon, - ) - - all_gather = torch.ops.vllm.all_gather.default( - rmsnorm[1], - dim=0, - world_size=tp_size, - group_name=tp.unique_name) + reduce_scatter = self._reduce_scatter(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm( + reduce_scatter, residual, rms_norm_weights) + all_gather = self._all_gather(rmsnorm[1]) return all_gather, rmsnorm[2] pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) -class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern): +class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper): def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -192,16 +221,9 @@ def pattern( mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - all_reduce = tensor_model_parallel_all_reduce(mm_1) - - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=all_reduce, - residual=residual, - weight=rms_norm_weights, - epsilon=self.epsilon, - ) - + all_reduce = self._all_reduce(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm( + all_reduce, residual, rms_norm_weights) return rmsnorm[1] def replacement( @@ -209,26 +231,185 @@ def replacement( mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - tp = get_tp_group() - tp_size = get_tensor_model_parallel_world_size() - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) - - # TODO is it possible to extract epsilon from somewhere - rmsnorm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=reduce_scatter, - residual=residual, - weight=rms_norm_weights, - epsilon=self.epsilon, - ) + reduce_scatter = self._reduce_scatter(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm( + 
reduce_scatter, residual, rms_norm_weights) + normalized = self._all_gather(rmsnorm[1]) + return normalized + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +FP8_DTYPE = current_platform.fp8_dtype() + + +class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + op: torch._ops.OpOverload): + super().__init__(epsilon, dtype, device, quant_op=op) + + def get_inputs(self): + input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) + rmsnorm_result = torch.empty([1, 8, 4], + device=self.device, + dtype=self.dtype) + quant_result = torch.empty([1, 8, 4], + device=self.device, + dtype=FP8_DTYPE) + weight = torch.empty([4], device=self.device, dtype=self.dtype) + scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) + return [input, rmsnorm_result, quant_result, weight, scale] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + input: torch.Tensor, + rmsnorm_result: torch.Tensor, + quant_result: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + all_reduce = self._all_reduce(input) + static_fp8 = self._functional_rmsnorm_then_quant( + rmsnorm_result, quant_result, all_reduce, weight, scale) + return static_fp8[1], all_reduce + + def replacement( + input: torch.Tensor, + rmsnorm_result: torch.Tensor, + quant_result: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + reduce_scatter = self._reduce_scatter(input) + + rmsnorm_result = torch.empty_like(reduce_scatter, + dtype=rmsnorm_result.dtype) + quant_result = torch.empty_like( + rmsnorm_result, # Output of RMSNorm + dtype=quant_result.dtype) + static_fp8 = self._functional_rmsnorm_then_quant( + rmsnorm_result, quant_result, reduce_scatter, weight, scale) + all_gather = self._all_gather(static_fp8[1]) + + return all_gather, reduce_scatter + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + op: torch._ops.OpOverload): + super().__init__(epsilon, dtype, device, quant_op=op) + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) + scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) + + return [ + result, + residual, + mm_1, + rms_norm_weights, + scale, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + all_reduce = self._all_reduce(mm_1) + static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 + result, all_reduce, residual, rms_norm_weights, scale) + return static_fp8[1], rmsnorm_residual_out + + def replacement( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + reduce_scatter = self._reduce_scatter(mm_1) + quant_result_buf = torch.empty_like(reduce_scatter, + dtype=result.dtype) + static_fp8, rmsnorm_residual_out = 
self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 + quant_result_buf, reduce_scatter, residual, rms_norm_weights, + scale) + all_gather = self._all_gather(static_fp8[1]) + return all_gather, rmsnorm_residual_out + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + op: torch._ops.OpOverload): + super().__init__(epsilon, dtype, device, quant_op=op) + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) + scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) + + return [ + result, + residual, + mm_1, + rms_norm_weights, + scale, + ] + + def register(self, pm_pass: PatternMatcherPass): - normalized = torch.ops.vllm.all_gather.default( - rmsnorm[1], - dim=0, - world_size=tp_size, - group_name=tp.unique_name) + def pattern( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + all_reduce = self._all_reduce(mm_1) + static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant( + result, all_reduce, residual, rms_norm_weights, scale) + return static_fp8[1] + def replacement( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + reduce_scatter = self._reduce_scatter(mm_1) + quant_result_buf = torch.empty_like(reduce_scatter, + dtype=result.dtype) + static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant( + quant_result_buf, reduce_scatter, residual, rms_norm_weights, + scale) + normalized = self._all_gather(static_fp8[1]) return normalized pm.register_replacement(pattern, replacement, self.get_inputs(), @@ -236,21 +417,54 @@ def replacement( class SequenceParallelismPass(VllmInductorPass): + """ + This pass enables sequence parallelism for models. + It identifies patterns where an AllReduce operation is followed by + an RMSNorm (or RMSNorm and then Quantization) operation. + These patterns are replaced with a ReduceScatter operation, followed by + a local RMSNorm/Quantization, and then an AllGather operation. + + The general transformation is: + Input -> AllReduce -> RMSNorm -> Output + becomes + Input -> ReduceScatter -> RMSNorm -> AllGather -> Output + + While this pass itself does not directly yield performance improvements, + it lays the groundwork for subsequent fusion passes, such as + GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can + significantly reduce communication overhead and improve overall model + performance. 
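The docstring above summarizes the rewrite these patterns perform. The following is a minimal, self-contained sketch (not the pass itself) of why the two forms are numerically equivalent when the norm acts token-wise: it assumes `torch.distributed` is already initialized, the token count divides the world size, and `rms_norm` stands in for the fused custom op used by the real patterns.

```python
import torch
import torch.distributed as dist


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Row-wise RMSNorm: each token is normalized independently, which is what
    # makes it legal to apply after a ReduceScatter over the token dimension.
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight


def allreduce_then_norm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    dist.all_reduce(x)            # every rank materializes the full activation
    return rms_norm(x, weight)    # every rank repeats the same normalization


def sequence_parallel_norm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    world = dist.get_world_size()
    shard = x.new_empty(x.shape[0] // world, x.shape[1])
    dist.reduce_scatter_tensor(shard, x)      # reduce only this rank's token slice
    shard = rms_norm(shard, weight)           # local work on 1/world of the tokens
    out = torch.empty_like(x)
    dist.all_gather_into_tensor(out, shard)   # reassemble the full sequence
    return out
```

The quantized variants follow the same shape: the static FP8 quant runs on the scattered shard, and only the quantized output is all-gathered.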
+ """ def __init__(self, config: VllmConfig): super().__init__(config) self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="sequence_parallelism_pass") + for epsilon in [1e-5, 1e-6]: - EmbeddingAllReduceRMSNormPattern( - epsilon, self.model_dtype, self.device).register(self.patterns) + # RMSNorm + Static FP8 quantization patterns + fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default + FirstAllReduceRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device, + fp8_quant_op).register(self.patterns) + MiddleAllReduceRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device, + fp8_quant_op).register(self.patterns) + LastAllReduceRMSNormStaticFP8Pattern( + epsilon, self.model_dtype, self.device, + fp8_quant_op).register(self.patterns) + + # Normal RMSNorm patterns + FirstAllReduceRMSNormPattern(epsilon, self.model_dtype, + self.device).register(self.patterns) MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype, self.device).register(self.patterns) LastAllReduceRMSNormPattern(epsilon, self.model_dtype, self.device).register(self.patterns) + # WARNING: This is a hack to clear the pattern matcher cache # and allow multiple values of epsilon. torch._inductor.pattern_matcher._seen_patterns.clear() diff --git a/vllm/config.py b/vllm/config.py index 684d81e0a267..6883ec29a184 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -27,19 +27,13 @@ from pydantic.dataclasses import dataclass from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from torch.distributed import ProcessGroup, ReduceOp -from transformers import PretrainedConfig from typing_extensions import Self, deprecated, runtime_checkable import vllm.envs as envs from vllm import version from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.logger import init_logger -from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, - QuantizationMethods, - get_quantization_config) -from vllm.model_executor.models import ModelRegistry from vllm.platforms import current_platform -from vllm.tracing import is_otel_available, otel_import_error_traceback from vllm.transformers_utils.config import ( ConfigFormat, get_config, get_hf_image_processor_config, get_hf_text_config, get_pooling_config, @@ -48,32 +42,49 @@ try_get_tokenizer_config, uses_mrope) from vllm.transformers_utils.s3_utils import S3Model from vllm.transformers_utils.utils import is_s3, maybe_model_redirect +# yapf conflicts with isort for this block +# yapf: disable from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes, - LayerBlockType, common_broadcastable_dtype, + LayerBlockType, LazyLoader, common_broadcastable_dtype, cuda_device_count_stateless, get_cpu_memory, get_open_port, is_torch_equal_or_newer, random_uuid, resolve_obj_by_qualname) +# yapf: enable + if TYPE_CHECKING: from _typeshed import DataclassInstance from ray.util.placement_group import PlacementGroup + from transformers.configuration_utils import PretrainedConfig + import vllm.model_executor.layers.quantization as me_quant + import vllm.model_executor.models as me_models from vllm.executor.executor_base import ExecutorBase + from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.model_loader import BaseModelLoader from vllm.model_executor.model_loader.tensorizer import TensorizerConfig ConfigType = 
type[DataclassInstance] + HfOverrides = Union[dict, Callable[[type], type]] else: PlacementGroup = Any + PretrainedConfig = Any ExecutorBase = Any QuantizationConfig = Any + QuantizationMethods = Any BaseModelLoader = Any TensorizerConfig = Any ConfigType = type + HfOverrides = Union[dict[str, Any], Callable[[type], type]] + + me_quant = LazyLoader("model_executor", globals(), + "vllm.model_executor.layers.quantization") + me_models = LazyLoader("model_executor", globals(), + "vllm.model_executor.models") logger = init_logger(__name__) @@ -100,9 +111,6 @@ for task in tasks } -HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig], - PretrainedConfig]] - @runtime_checkable class SupportsHash(Protocol): @@ -538,10 +546,10 @@ def __post_init__(self) -> None: self.code_revision, self.config_format) if hf_overrides_kw: - logger.info("Overriding HF config with %s", hf_overrides_kw) + logger.debug("Overriding HF config with %s", hf_overrides_kw) hf_config.update(hf_overrides_kw) if hf_overrides_fn: - logger.info("Overriding HF config with %s", hf_overrides_fn) + logger.debug("Overriding HF config with %s", hf_overrides_fn) hf_config = hf_overrides_fn(hf_config) self.hf_config = hf_config @@ -648,7 +656,7 @@ def validate_model_config_after(self: "ModelConfig") -> "ModelConfig": @property def registry(self): - return ModelRegistry + return me_models.ModelRegistry @property def architectures(self) -> list[str]: @@ -859,14 +867,15 @@ def _parse_quant_hf_config(self): return quant_cfg def _verify_quantization(self) -> None: - supported_quantization = QUANTIZATION_METHODS + supported_quantization = me_quant.QUANTIZATION_METHODS optimized_quantization_methods = [ "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8", "quark", "modelopt_fp4", "bitblas", "gptq_bitblas" ] if self.quantization is not None: - self.quantization = cast(QuantizationMethods, self.quantization) + self.quantization = cast(me_quant.QuantizationMethods, + self.quantization) # Parse quantization method from the HF model config, if available. quant_cfg = self._parse_quant_hf_config() @@ -900,14 +909,14 @@ def _verify_quantization(self) -> None: # Detect which checkpoint is it for name in quantization_methods: - method = get_quantization_config(name) + method = me_quant.get_quantization_config(name) quantization_override = method.override_quantization_method( quant_cfg, self.quantization) if quantization_override is not None: # Raise error if the override is not custom (custom would # be in QUANTIZATION_METHODS but not QuantizationMethods) # and hasn't been added to the overrides list. 
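The config changes above replace module-level imports of the quantization and model registries with `LazyLoader` placeholders so that importing `vllm.config` stays cheap and avoids import cycles. Below is a minimal sketch of how such a loader can work; this is not vLLM's exact implementation, and `numpy` is used only as a stand-in heavy module.

```python
import importlib
import types


class LazyLoader(types.ModuleType):
    """Defers importing a module until one of its attributes is first used."""

    def __init__(self, local_name: str, parent_globals: dict, module_name: str):
        super().__init__(module_name)
        self._local_name = local_name
        self._parent_globals = parent_globals
        self._module_name = module_name

    def _load(self) -> types.ModuleType:
        module = importlib.import_module(self._module_name)
        # Swap the placeholder out so later lookups hit the real module directly.
        self._parent_globals[self._local_name] = module
        self.__dict__.update(module.__dict__)
        return module

    def __getattr__(self, item: str):
        return getattr(self._load(), item)


# The heavy import only happens on first attribute access.
np = LazyLoader("np", globals(), "numpy")
print(np.arange(3).sum())  # numpy is imported here, not at definition time
```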
- if (name in get_args(QuantizationMethods) + if (name in get_args(me_quant.QuantizationMethods) and name not in overrides): raise ValueError( f"Quantization method {name} is an override but " @@ -1417,7 +1426,7 @@ def runner_type(self) -> RunnerType: @property def is_v1_compatible(self) -> bool: architectures = getattr(self.hf_config, "architectures", []) - return ModelRegistry.is_v1_compatible(architectures) + return me_models.ModelRegistry.is_v1_compatible(architectures) @property def is_matryoshka(self) -> bool: @@ -1938,8 +1947,8 @@ def __post_init__(self) -> None: if get_current_placement_group(): backend = "ray" self.distributed_executor_backend = backend - logger.info("Defaulting to use %s for distributed inference", - backend) + logger.debug("Defaulting to use %s for distributed inference", + backend) if self.distributed_executor_backend is None and self.world_size == 1: self.distributed_executor_backend = "uni" @@ -2376,7 +2385,7 @@ class SpeculativeConfig: according to the log probability settings in SamplingParams.""" # Draft model configuration - quantization: Optional[QuantizationMethods] = None + quantization: Optional[me_quant.QuantizationMethods] = None """Quantization method that was used to quantize the draft model weights. If `None`, we assume the model weights are not quantized. Note that it only takes effect when using the draft model-based speculative method.""" @@ -3624,6 +3633,7 @@ def __post_init__(self): and "," in self.collect_detailed_traces[0]): self._parse_collect_detailed_traces() + from vllm.tracing import is_otel_available, otel_import_error_traceback if not is_otel_available() and self.otlp_traces_endpoint is not None: raise ValueError( "OpenTelemetry is not available. Unable to configure " @@ -3802,11 +3812,11 @@ class PassConfig: its own stages (before, after, maybe in-between).""" dump_graph_dir: Path = Path(".") """Directory to dump the graphs.""" - enable_fusion: bool = True + enable_fusion: bool = field(default_factory=lambda: not envs.VLLM_USE_V1) """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass.""" enable_attn_fusion: bool = False """Whether to enable the custom attention+quant fusion pass.""" - enable_noop: bool = True + enable_noop: bool = field(default_factory=lambda: not envs.VLLM_USE_V1) """Whether to enable the custom no-op elimination pass.""" enable_sequence_parallelism: bool = False """Whether to enable sequence parallelism.""" diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 2d80cbf2b24f..a962a9241d73 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -2,11 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib import math +import queue import threading import time import uuid from collections import defaultdict from collections.abc import Iterator +from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional @@ -23,6 +25,7 @@ get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group) from vllm.distributed.utils import divide +from vllm.forward_context import ForwardContext from vllm.logger import init_logger from vllm.platforms import _Backend from vllm.utils import make_zmq_path, make_zmq_socket, round_down @@ -31,11 +34,12 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract 
import AttentionMetadata - from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request Transfer = tuple[int, float] # (xfer_handle, start_time) +EngineId = str +ReqId = str GET_META_MSG = b"get_meta_msg" logger = init_logger(__name__) @@ -69,17 +73,17 @@ class ReqMeta: remote_block_ids: list[int] remote_host: str remote_port: int - remote_engine_id: str + remote_engine_id: EngineId class NixlConnectorMetadata(KVConnectorMetadata): def __init__(self): - self.requests: dict[str, ReqMeta] = {} + self.requests: dict[ReqId, ReqMeta] = {} def add_new_req( self, - request_id: str, + request_id: ReqId, local_block_ids: list[int], kv_transfer_params: dict[str, Any], ): @@ -96,16 +100,17 @@ class NixlConnector(KVConnectorBase_V1): def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): assert vllm_config.kv_transfer_config is not None - self.engine_id = vllm_config.kv_transfer_config.engine_id + assert vllm_config.kv_transfer_config.engine_id is not None + self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id if role == KVConnectorRole.SCHEDULER: - self.connector_scheduler : Optional[NixlConnectorScheduler] = \ - NixlConnectorScheduler(vllm_config, str(self.engine_id)) + self.connector_scheduler: Optional[NixlConnectorScheduler] = \ + NixlConnectorScheduler(vllm_config, self.engine_id) self.connector_worker: Optional[NixlConnectorWorker] = None elif role == KVConnectorRole.WORKER: self.connector_scheduler = None self.connector_worker = NixlConnectorWorker( - vllm_config, str(self.engine_id)) + vllm_config, self.engine_id) ############################################################ # Scheduler Side Methods @@ -179,18 +184,18 @@ class NixlConnectorScheduler: def __init__(self, vllm_config: VllmConfig, engine_id: str): self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size - self.engine_id = engine_id + self.engine_id: EngineId = engine_id self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST self.side_channel_port = ( envs.VLLM_NIXL_SIDE_CHANNEL_PORT + - vllm_config.parallel_config.data_parallel_rank_local * + vllm_config.parallel_config.data_parallel_rank * vllm_config.parallel_config.tensor_parallel_size) logger.info("Initializing NIXL Scheduler %s", engine_id) # Requests that need to start recv. # New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. - self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} + self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} def get_num_new_matched_tokens( self, request: "Request", @@ -332,19 +337,19 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Agent. self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. - self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict) + self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) # NIXL handshake port. # NOTE(rob): Within a DP group, each DP rank gets its own # base port (which is sent in the KVTransferParams). # Each TP rank listens/queries on the base_port + tp_rank. - self.side_channel_port = ( + self.side_channel_port: int = ( envs.VLLM_NIXL_SIDE_CHANNEL_PORT + - vllm_config.parallel_config.data_parallel_rank_local * + vllm_config.parallel_config.data_parallel_rank * vllm_config.parallel_config.tensor_parallel_size) # Metadata. 
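The scheduler and worker above now derive the NIXL side-channel base port from the global data-parallel rank rather than the node-local one, and each TP rank listens or queries at its own offset within that block. A worked example of the resulting layout, using an assumed base port and sizes (the real base comes from `VLLM_NIXL_SIDE_CHANNEL_PORT`):

```python
# Assumed values for illustration only.
BASE_PORT = 5600
TP_SIZE = 4


def side_channel_port(dp_rank: int, tp_rank: int) -> int:
    # Each DP rank owns a contiguous block of TP_SIZE ports; each TP rank
    # listens/queries on its own offset inside that block.
    return BASE_PORT + dp_rank * TP_SIZE + tp_rank


for dp in range(2):
    print(f"DP rank {dp}:", [side_channel_port(dp, tp) for tp in range(TP_SIZE)])
# DP rank 0: [5600, 5601, 5602, 5603]
# DP rank 1: [5604, 5605, 5606, 5607]
```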
- self.engine_id = engine_id + self.engine_id: EngineId = engine_id self.tp_rank = get_tensor_model_parallel_rank() self.world_size = get_tensor_model_parallel_world_size() self.tp_group = get_tp_group() @@ -354,7 +359,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Map of engine_id -> kv_caches_base_addr. For TP case, each local # rank will still only pull from a single remote TP worker. - self.kv_caches_base_addr: dict[str, list[int]] = {} + self.kv_caches_base_addr: dict[EngineId, list[int]] = {} # Number of NIXL regions. Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) @@ -364,27 +369,36 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # nixl_prepped_dlist_handle. self.src_xfer_side_handle: int = 0 # Map of engine_id -> nixl_prepped_dlist_handle (int)]. - self.dst_xfer_side_handles: dict[str, int] = {} + self.dst_xfer_side_handles: dict[EngineId, int] = {} # Map of engine_id -> num_blocks. All ranks in the same deployment will # have the same number of blocks. - self.dst_num_blocks: dict[str, int] = {} + self.dst_num_blocks: dict[EngineId, int] = {} self._registered_descs: list[Any] = [] # In progress transfers. # [req_id -> list[handle]] - self._recving_transfers = defaultdict[str, list[Transfer]](list) + self._recving_transfers = defaultdict[ReqId, list[Transfer]](list) # Complete transfer tracker. Used by the rank 0 to track finished # transactions on ranks 1 to N-1. # [req_id -> count] - self._done_recving_count: defaultdict[str, + self._done_recving_count: defaultdict[ReqId, int] = defaultdict(lambda: 0) - self._done_sending_count: defaultdict[str, + self._done_sending_count: defaultdict[ReqId, int] = defaultdict(lambda: 0) - # Background thread for establishing new connections. + # Background thread for handling new handshake requests. self._nixl_handshake_listener_t: Optional[threading.Thread] = None + # Background thread for initializing new NIXL handshakes. + self._handshake_initiation_executor = ThreadPoolExecutor( + # NIXL is not guaranteed to be thread-safe, limit 1 worker. + max_workers=1, + thread_name_prefix="vllm-nixl-handshake-initiator") + self._ready_requests = queue.Queue[tuple[ReqId, ReqMeta]]() + self._handshake_futures: dict[EngineId, Future[dict[int, str]]] = {} + # Protects _handshake_futures and _remote_agents. + self._handshake_lock = threading.RLock() self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size @@ -408,10 +422,16 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1 logger.debug("Detected attention backend %s", self.backend_name) - self._tp_size: dict[str, int] = {self.engine_id: self.world_size} + self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} # With heterogeneous TP, P must wait for all assigned D TP workers to # finish reading before safely freeing the blocks. 
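The comment above notes that with heterogeneous tensor parallelism the prefill (P) side must hear back from every decode (D) rank assigned to it before freeing a request's blocks, which is what `consumer_notification_counts_by_req` tracks. A small worked example with assumed sizes, illustrative rather than the connector's exact bookkeeping:

```python
from collections import defaultdict

# Assumed sizes: prefill runs TP=2, decode runs TP=4, so each P rank serves
# tp_ratio = 4 // 2 = 2 decode ranks and must see that many notifications.
P_TP, D_TP = 2, 4
tp_ratio = D_TP // P_TP

notifications = defaultdict(int)  # req_id -> read-complete notifications seen


def on_read_done(req_id: str) -> bool:
    notifications[req_id] += 1
    return notifications[req_id] == tp_ratio  # safe to free the blocks now


print(on_read_done("req-0"))  # False: only 1 of 2 decode ranks has finished
print(on_read_done("req-0"))  # True:  both decode ranks finished reading
```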
- self.consumer_notification_counts_by_req = defaultdict[str, int](int) + self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) + + def __del__(self): + """Cleanup background threads on destruction.""" + self._handshake_initiation_executor.shutdown(wait=False) + if self._nixl_handshake_listener_t: + self._nixl_handshake_listener_t.join(timeout=0) @staticmethod def _nixl_handshake_listener(metadata: NixlAgentMetadata, @@ -440,7 +460,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata, "Connection listener got unexpected message %s", msg) sock.send_multipart((identity, b"", encoded_data)) - def _nixl_handshake(self, host: str, port: int): + def _nixl_handshake(self, host: str, port: int) -> dict[int, str]: """Do a NIXL handshake with a remote instance.""" start_time = time.perf_counter() @@ -449,7 +469,7 @@ def _nixl_handshake(self, host: str, port: int): # a hack to keep us moving. We will switch when moving to etcd # or where we have a single ZMQ socket in the scheduler. - def handshake(path: str, rank: int) -> NixlAgentMetadata: + def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]: # Send query for the request. with zmq_ctx(zmq.REQ, path) as sock: sock.send(GET_META_MSG) @@ -459,19 +479,20 @@ def handshake(path: str, rank: int) -> NixlAgentMetadata: got_metadata_time = time.perf_counter() # Register Remote agent. - self.add_remote_agent(metadata, rank) + remote_agent_name = self.add_remote_agent(metadata, rank) setup_agent_time = time.perf_counter() logger.debug("NIXL handshake: get metadata took: %s", got_metadata_time - start_time) logger.debug("NIXL handshake: add agent took: %s", setup_agent_time - got_metadata_time) - return metadata + return metadata, remote_agent_name # Handshake with remote agent-rank0 first to get the tp_size of remote path = make_zmq_path("tcp", host, port) logger.debug("Querying master rank metadata on path: %s", path) - metadata = handshake(path, 0) + rank_to_agent_name: dict[int, str] = {} + metadata, rank_to_agent_name[0] = handshake(path, 0) # Handshake only with the other TP remote the current local rank will # pull from. With homogeneous TP it happens to be the same rank_i. @@ -481,7 +502,10 @@ def handshake(path: str, rank: int) -> NixlAgentMetadata: path = make_zmq_path("tcp", host, port + p_remote_rank) logger.debug("Querying metadata on path: %s at remote rank %s", path, p_remote_rank) - _ = handshake(path, p_remote_rank) + _, rank_to_agent_name[p_remote_rank] = handshake( + path, p_remote_rank) + + return rank_to_agent_name def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" @@ -618,11 +642,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): daemon=True, name="nixl_handshake_listener") self._nixl_handshake_listener_t.start() - ready_event.wait() + ready_event.wait() # Wait for listener ZMQ socket to be ready. def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, - remote_tp_rank: int = 0): + remote_tp_rank: int = 0) -> str: """ Add the remote NIXL agent and prepare the descriptors for reading cache blocks from remote. 
@@ -663,8 +687,8 @@ def add_remote_agent(self, """ # noqa: E501 engine_id = nixl_agent_meta.engine_id # TODO re-evaluate refreshing for scaling/recovery - if remote_tp_rank in self._remote_agents.get(engine_id, ()): - return + if remote_tp_rank in self._remote_agents.get(engine_id, {}): + return self._remote_agents[engine_id][remote_tp_rank] if engine_id in self._tp_size: assert self._tp_size[engine_id] == nixl_agent_meta.tp_size @@ -674,9 +698,8 @@ def add_remote_agent(self, # layout and close outputs. assert nixl_agent_meta.attn_backend_name == self.backend_name - self._remote_agents[engine_id][ - remote_tp_rank] = self.nixl_wrapper.add_remote_agent( - nixl_agent_meta.agent_metadata) + remote_agent_name = self.nixl_wrapper.add_remote_agent( + nixl_agent_meta.agent_metadata) # Number of D TP workers reading from a single P TP worker. This is # 1 when P and D `--tensor-parallel-size` match. @@ -705,8 +728,9 @@ def add_remote_agent(self, "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." ) - assert self.block_size == remote_block_size, "Remote P worker with " \ - "different block size is not supported" + assert self.block_size == remote_block_size, ( + "Remote P worker with different block size is not supported " + f"{self.block_size=} {remote_block_size=}") # Create dst descs and xfer side handles. TP workers have same #blocks. if engine_id in self.dst_num_blocks: @@ -745,7 +769,9 @@ def add_remote_agent(self, descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") self.dst_xfer_side_handles[ engine_id] = self.nixl_wrapper.prep_xfer_dlist( - self._remote_agents[engine_id][remote_tp_rank], descs) + remote_agent_name, descs) + + return remote_agent_name def get_finished(self) -> tuple[set[str], set[str]]: """ @@ -863,33 +889,68 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): We check for these trnxs to complete in each step(). """ for req_id, meta in metadata.requests.items(): + remote_engine_id = meta.remote_engine_id logger.debug( "start_load_kv for request %s from remote engine %s. " "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, - meta.remote_engine_id, len(meta.local_block_ids), + remote_engine_id, len(meta.local_block_ids), len(meta.remote_block_ids)) - self._read_blocks( - request_id=req_id, - dst_engine_id=meta.remote_engine_id, - local_block_ids=meta.local_block_ids, - remote_block_ids=meta.remote_block_ids, - remote_host=meta.remote_host, - remote_port=meta.remote_port, - ) + if remote_engine_id not in self._remote_agents: + # Being optimistic to assume engine is usually ready, apply + # lock only when the optimistic check fails. + with self._handshake_lock: + if remote_engine_id not in self._remote_agents: + fut = self._handshake_futures.get(remote_engine_id) + if fut is None: + fut = self._handshake_initiation_executor.submit( + self._nixl_handshake, meta.remote_host, + meta.remote_port) + self._handshake_futures[remote_engine_id] = fut + + def done_callback(f: Future[dict[int, str]], + eid=remote_engine_id): + with self._handshake_lock: + del self._handshake_futures[eid] + try: + self._remote_agents[eid] = f.result() + except Exception: + logger.exception( + "Handshake with %s failed", eid) + + fut.add_done_callback(done_callback) + + # TODO: handle failure state of future in the + # callback, we want to fail the request in this case. 
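The hunk that follows moves the expensive NIXL handshake off the request hot path: an optimistic unlocked check, at most one in-flight handshake future per remote engine, and a queue of requests that only start their block reads once the handshake completes. A condensed sketch of the same pattern, with illustrative names rather than the connector's real API:

```python
import queue
import threading
from concurrent.futures import Future, ThreadPoolExecutor


class HandshakeManager:
    """Illustrative only: mirrors the lock/future/queue structure above."""

    def __init__(self):
        # NIXL is not guaranteed to be thread-safe, so a single worker is used.
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._lock = threading.RLock()
        self._agents: dict[str, object] = {}   # engine_id -> agent handle(s)
        self._futures: dict[str, Future] = {}  # engine_id -> pending handshake
        self.ready = queue.Queue()             # req_ids whose remote agent is ready

    def start_or_queue(self, engine_id: str, req_id: str, do_handshake):
        if engine_id in self._agents:          # fast path, no locking
            self.ready.put(req_id)
            return
        with self._lock:
            if engine_id in self._agents:      # raced with a finishing handshake
                self.ready.put(req_id)
                return
            fut = self._futures.get(engine_id)
            if fut is None:
                fut = self._executor.submit(do_handshake, engine_id)
                self._futures[engine_id] = fut

                def on_done(f: Future, eid=engine_id):
                    with self._lock:
                        del self._futures[eid]
                        try:
                            self._agents[eid] = f.result()
                        except Exception:
                            pass  # real code must fail the affected requests
                fut.add_done_callback(on_done)
        # Whether new or already in flight, queue the request once it completes.
        fut.add_done_callback(lambda _f, rid=req_id: self.ready.put(rid))
```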
+ def request_ready(_f: Future[Any], + entry=(req_id, meta)): + self._ready_requests.put(entry) + + fut.add_done_callback(request_ready) + continue + self._read_blocks_for_req(req_id, meta) + + # Start transfers for requests whose handshakes have now finished. + while not self._ready_requests.empty(): + self._read_blocks_for_req(*self._ready_requests.get_nowait()) + + def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): + logger.debug( + "Remote agent %s available, calling _read_blocks for req %s", + meta.remote_engine_id, req_id) + self._read_blocks( + request_id=req_id, + dst_engine_id=meta.remote_engine_id, + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + ) def _read_blocks( self, local_block_ids: list[int], remote_block_ids: list[int], - remote_host: str, - remote_port: int, dst_engine_id: str, request_id: str, ): - # NOTE(rob): this takes ~2s. We need to get this off the hotpath. - if dst_engine_id not in self._remote_agents: - self._nixl_handshake(remote_host, remote_port) - # NOTE(rob): having the staging blocks be on the READER side is # not going to work well (since we will have to call rearrange tensors). # after we detect the txn is complete (which means we cannot make the diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index dd09f514906d..9d1008b6b350 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1289,11 +1289,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: recommend_to_remove=True) return False - if self.scheduling_policy != SchedulerConfig.policy: - _raise_or_fallback(feature_name="--scheduling-policy", - recommend_to_remove=False) - return False - if self.num_scheduler_steps != SchedulerConfig.num_scheduler_steps: _raise_or_fallback(feature_name="--num-scheduler-steps", recommend_to_remove=True) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 05e0be61adad..63967e4d2d4b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1568,6 +1568,8 @@ def _run_engine( pbar.update(n) else: pbar.update(1) + if pbar.n == num_requests: + pbar.refresh() if use_tqdm: pbar.close() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 62f1c6a7c12b..681633a2aff7 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -73,6 +73,8 @@ TokenizeResponse, TranscriptionRequest, TranscriptionResponse, + TranslationRequest, + TranslationResponse, UnloadLoRAAdapterRequest) # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat @@ -88,7 +90,7 @@ from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) from vllm.entrypoints.openai.serving_transcription import ( - OpenAIServingTranscription) + OpenAIServingTranscription, OpenAIServingTranslation) from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.utils import (cli_env_setup, load_aware_call, with_cancellation) @@ -401,6 +403,10 @@ def transcription(request: Request) -> OpenAIServingTranscription: return request.app.state.openai_serving_transcription +def translation(request: Request) -> OpenAIServingTranslation: + return request.app.state.openai_serving_translation + + def engine_client(request: Request) -> EngineClient: return request.app.state.engine_client @@ -774,6 +780,47 @@ async def create_transcriptions(raw_request: Request, return StreamingResponse(content=generator, media_type="text/event-stream") 
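The `/v1/audio/translations` route registered just below mirrors the transcription endpoint. A hypothetical client call against it using the official `openai` client; the server address and model name are assumptions for illustration:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("sample_fr.wav", "rb") as f:
    result = client.audio.translations.create(
        model="openai/whisper-large-v3",  # assumed Whisper-style model
        file=f,
        response_format="json",
    )

print(result.text)  # English translation of the input audio
```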
+@router.post("/v1/audio/translations", + responses={ + HTTPStatus.OK.value: { + "content": { + "text/event-stream": {} + } + }, + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.UNPROCESSABLE_ENTITY.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) +@with_cancellation +@load_aware_call +async def create_translations(request: Annotated[TranslationRequest, + Form()], + raw_request: Request): + handler = translation(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Translations API") + + audio_data = await request.file.read() + generator = await handler.create_translation(audio_data, request, + raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + + elif isinstance(generator, TranslationResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + @router.post("/rerank", dependencies=[Depends(validate_json_request)], responses={ @@ -1190,6 +1237,7 @@ async def init_app_state( tool_parser=args.tool_call_parser, reasoning_parser=args.reasoning_parser, enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, ) if model_config.runner_type == "generate" else None state.openai_serving_completion = OpenAIServingCompletion( engine_client, @@ -1197,6 +1245,7 @@ async def init_app_state( state.openai_serving_models, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_force_include_usage=args.enable_force_include_usage, ) if model_config.runner_type == "generate" else None state.openai_serving_pooling = OpenAIServingPooling( engine_client, @@ -1246,6 +1295,12 @@ async def init_app_state( state.openai_serving_models, request_logger=request_logger, ) if model_config.runner_type == "transcription" else None + state.openai_serving_translation = OpenAIServingTranslation( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + ) if model_config.runner_type == "transcription" else None state.task = model_config.task state.enable_server_load_tracking = args.enable_server_load_tracking diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index ca70e78df326..dd4bd53046a3 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -272,6 +272,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action='store_true', default=False, help="If set to True, enable prompt_tokens_details in usage.") + parser.add_argument( + "--enable-force-include-usage", + action='store_true', + default=False, + help="If set to True, including usage on every request.") parser.add_argument( "--enable-server-load-tracking", action='store_true', diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index b278d0d00586..3b5281962b2d 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1947,3 +1947,190 @@ class TranscriptionResponseVerbose(OpenAIBaseModel): words: Optional[list[TranscriptionWord]] = None """Extracted words and their corresponding timestamps.""" + + +class TranslationResponseStreamChoice(OpenAIBaseModel): + delta: DeltaMessage + finish_reason: Optional[str] = 
None + stop_reason: Optional[Union[int, str]] = None + + +class TranslationStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}") + object: Literal["translation.chunk"] = "translation.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[TranslationResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class TranslationRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/audio/createTranslation + + file: UploadFile + """ + The audio file object (not file name) to translate, in one of these + formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. + """ + + model: Optional[str] = None + """ID of the model to use. + """ + + prompt: str = Field(default="") + """An optional text to guide the model's style or continue a previous audio + segment. + + The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting) + should match the audio language. + """ + + response_format: AudioResponseFormat = Field(default="json") + """ + The format of the output, in one of these options: `json`, `text`, `srt`, + `verbose_json`, or `vtt`. + """ + + # TODO support additional sampling parameters + # --8<-- [start:translation-sampling-params] + temperature: float = Field(default=0.0) + """The sampling temperature, between 0 and 1. + + Higher values like 0.8 will make the output more random, while lower values + like 0.2 will make it more focused / deterministic. If set to 0, the model + will use [log probability](https://en.wikipedia.org/wiki/Log_probability) + to automatically increase the temperature until certain thresholds are hit. + """ + # --8<-- [end:translation-sampling-params] + + # --8<-- [start:translation-extra-params] + language: Optional[str] = None + """The language of the input audio we translate from. + + Supplying the input language in + [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format + will improve accuracy. + """ + + stream: Optional[bool] = False + """Custom field not present in the original OpenAI definition. When set, + it will enable output to be streamed in a similar fashion as the Chat + Completion endpoint. + """ + # Flattened stream option to simplify form data. + stream_include_usage: Optional[bool] = False + stream_continuous_usage_stats: Optional[bool] = False + # --8<-- [end:translation-extra-params] + + # Default sampling parameters for translation requests. 
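The flattened `stream_include_usage` and `stream_continuous_usage_stats` form fields defined in this request model are vLLM extensions that simplify multipart form data; they are only honored when `stream` is set, as the validator below enforces. A hypothetical raw request exercising them (URL, file, and model name are placeholders):

```python
import requests

with open("sample_fr.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/v1/audio/translations",
        files={"file": f},
        data={
            "model": "openai/whisper-large-v3",
            "stream": "true",
            "stream_include_usage": "true",  # only valid together with stream
        },
        stream=True,
    )

for line in resp.iter_lines():
    if line:
        print(line.decode())  # SSE chunks: "data: {...}", then "data: [DONE]"
```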
+ _DEFAULT_SAMPLING_PARAMS: dict = { + "temperature": 0, + } + + def to_sampling_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None) -> SamplingParams: + # TODO(#9845): remove max_tokens when field is removed from OpenAI API + max_tokens = default_max_tokens + + if default_sampling_params is None: + default_sampling_params = {} + # Default parameters + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + + return SamplingParams.from_optional(temperature=temperature, + max_tokens=max_tokens, + output_kind=RequestOutputKind.DELTA + if self.stream \ + else RequestOutputKind.FINAL_ONLY) + + @model_validator(mode="before") + @classmethod + def validate_stream_options(cls, data): + stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"] + stream = data.get("stream", False) + if any(bool(data.get(so, False)) for so in stream_opts) and not stream: + raise ValueError( + "Stream options can only be defined when `stream=True`.") + + return data + + +# Translation response objects +class TranslationResponse(OpenAIBaseModel): + text: str + """The translated text.""" + + +class TranslationWord(OpenAIBaseModel): + end: float + """End time of the word in seconds.""" + + start: float + """Start time of the word in seconds.""" + + word: str + """The text content of the word.""" + + +class TranslationSegment(OpenAIBaseModel): + id: int + """Unique identifier of the segment.""" + + avg_logprob: float + """Average logprob of the segment. + + If the value is lower than -1, consider the logprobs failed. + """ + + compression_ratio: float + """Compression ratio of the segment. + + If the value is greater than 2.4, consider the compression failed. + """ + + end: float + """End time of the segment in seconds.""" + + no_speech_prob: float + """Probability of no speech in the segment. + + If the value is higher than 1.0 and the `avg_logprob` is below -1, consider + this segment silent. 
+ """ + + seek: int + """Seek offset of the segment.""" + + start: float + """Start time of the segment in seconds.""" + + temperature: float + """Temperature parameter used for generating the segment.""" + + text: str + """Text content of the segment.""" + + tokens: list[int] + """Array of token IDs for the text content.""" + + +class TranslationResponseVerbose(OpenAIBaseModel): + duration: str + """The duration of the input audio.""" + + language: str + """The language of the input audio.""" + + text: str + """The translated text.""" + + segments: Optional[list[TranslationSegment]] = None + """Segments of the translated text and their corresponding details.""" + + words: Optional[list[TranslationWord]] = None + """Extracted words and their corresponding timestamps.""" diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 2a0d4cd74a28..10aced83b60b 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -64,12 +64,14 @@ def __init__( enable_auto_tools: bool = False, tool_parser: Optional[str] = None, enable_prompt_tokens_details: bool = False, + enable_force_include_usage: bool = False, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, - return_tokens_as_token_ids=return_tokens_as_token_ids) + return_tokens_as_token_ids=return_tokens_as_token_ids, + enable_force_include_usage=enable_force_include_usage) self.response_role = response_role self.chat_template = chat_template @@ -110,6 +112,7 @@ def __init__( "been registered") from e self.enable_prompt_tokens_details = enable_prompt_tokens_details + self.enable_force_include_usage = enable_force_include_usage self.default_sampling_params = ( self.model_config.get_diff_sampling_param()) if self.default_sampling_params: @@ -261,8 +264,14 @@ async def create_chat_completion( # Streaming response if request.stream: return self.chat_completion_stream_generator( - request, result_generator, request_id, model_name, - conversation, tokenizer, request_metadata) + request, + result_generator, + request_id, + model_name, + conversation, + tokenizer, + request_metadata, + enable_force_include_usage=self.enable_force_include_usage) try: return await self.chat_completion_full_generator( @@ -405,6 +414,7 @@ async def chat_completion_stream_generator( conversation: list[ConversationMessage], tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, + enable_force_include_usage: bool, ) -> AsyncGenerator[str, None]: created_time = int(time.time()) chunk_object_type: Final = "chat.completion.chunk" @@ -471,7 +481,8 @@ async def chat_completion_stream_generator( stream_options = request.stream_options if stream_options: - include_usage = stream_options.include_usage + include_usage = stream_options.include_usage \ + or enable_force_include_usage include_continuous_usage = include_usage and \ stream_options.continuous_usage_stats else: diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index ce5eca855028..a19fde8d70a8 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -52,12 +52,14 @@ def __init__( *, request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, + enable_force_include_usage: bool = False, ): super().__init__(engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, - 
return_tokens_as_token_ids=return_tokens_as_token_ids) + return_tokens_as_token_ids=return_tokens_as_token_ids, + enable_force_include_usage=enable_force_include_usage) self.default_sampling_params = ( self.model_config.get_diff_sampling_param()) if self.default_sampling_params: @@ -227,7 +229,8 @@ async def create_completion( model_name, num_prompts=num_prompts, tokenizer=tokenizer, - request_metadata=request_metadata) + request_metadata=request_metadata, + enable_force_include_usage=self.enable_force_include_usage) # Non-streaming response final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts @@ -289,6 +292,7 @@ async def completion_stream_generator( num_prompts: int, tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, + enable_force_include_usage: bool, ) -> AsyncGenerator[str, None]: num_choices = 1 if request.n is None else request.n previous_text_lens = [0] * num_choices * num_prompts @@ -298,7 +302,8 @@ async def completion_stream_generator( stream_options = request.stream_options if stream_options: - include_usage = stream_options.include_usage + include_usage = stream_options.include_usage or \ + enable_force_include_usage include_continuous_usage = include_usage and \ stream_options.continuous_usage_stats else: diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index ac3883bdeb33..cf2b738ba55e 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -58,7 +58,8 @@ TokenizeCompletionRequest, TokenizeResponse, TranscriptionRequest, - TranscriptionResponse) + TranscriptionResponse, + TranslationRequest) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser # yapf: enable @@ -89,9 +90,8 @@ ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, TokenizeChatRequest] - -AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, - TranscriptionRequest] +SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest] +AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest] AnyResponse = Union[ CompletionResponse, @@ -132,7 +132,7 @@ def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]: class RequestProcessingMixin(BaseModel): """ - Mixin for request processing, + Mixin for request processing, handling prompt preparation and engine input. """ request_prompts: Optional[Sequence[RequestPrompt]] = [] @@ -144,7 +144,7 @@ class RequestProcessingMixin(BaseModel): class ResponseGenerationMixin(BaseModel): """ - Mixin for response generation, + Mixin for response generation, managing result generators and final batch results. 
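The `enable_force_include_usage` plumbing in the chat and completion serving classes above ORs the server-side flag into `include_usage`, so streamed responses report token usage even when the client never requests it. A hypothetical client-side check against a server started with `--enable-force-include-usage` (address and model name are assumptions):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,  # note: no stream_options={"include_usage": True}
)

last = None
for chunk in stream:
    last = chunk

# With the flag set, the final chunk still carries token accounting.
print(last.usage)
```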
""" result_generator: Optional[AsyncGenerator[tuple[int, Union[ @@ -208,6 +208,7 @@ def __init__( *, request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, + enable_force_include_usage: bool = False, ): super().__init__() @@ -219,6 +220,7 @@ def __init__( self.request_logger = request_logger self.return_tokens_as_token_ids = return_tokens_as_token_ids + self.enable_force_include_usage = enable_force_include_usage self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index 60d66434ea5a..0d6989fe91bf 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -1,155 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import asyncio -import io -import math -import time from collections.abc import AsyncGenerator -from math import ceil -from typing import Final, Optional, Union, cast +from typing import Optional, Union -import numpy as np from fastapi import Request from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( - DeltaMessage, ErrorResponse, RequestResponseMetadata, TranscriptionRequest, + ErrorResponse, RequestResponseMetadata, TranscriptionRequest, TranscriptionResponse, TranscriptionResponseStreamChoice, - TranscriptionStreamResponse, UsageInfo) -from vllm.entrypoints.openai.serving_engine import OpenAIServing + TranscriptionStreamResponse, TranslationRequest, TranslationResponse, + TranslationResponseStreamChoice, TranslationStreamResponse) from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.inputs.data import PromptType +from vllm.entrypoints.openai.speech_to_text import OpenAISpeechToText from vllm.logger import init_logger from vllm.outputs import RequestOutput -from vllm.transformers_utils.processor import cached_get_processor -from vllm.utils import PlaceholderModule - -try: - import librosa -except ImportError: - librosa = PlaceholderModule("librosa") # type: ignore[assignment] logger = init_logger(__name__) -# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages#supported-languages -# TODO these configs should live somewhere with the model so we can support -# additional ones - -ISO639_1_SUPPORTED_LANGS = { - "af": "Afrikaans", - "ar": "Arabic", - "hy": "Armenian", - "az": "Azerbaijani", - "be": "Belarusian", - "bs": "Bosnian", - "bg": "Bulgarian", - "ca": "Catalan", - "zh": "Chinese", - "hr": "Croatian", - "cs": "Czech", - "da": "Danish", - "nl": "Dutch", - "en": "English", - "et": "Estonian", - "fi": "Finnish", - "fr": "French", - "gl": "Galician", - "de": "German", - "el": "Greek", - "he": "Hebrew", - "hi": "Hindi", - "hu": "Hungarian", - "is": "Icelandic", - "id": "Indonesian", - "it": "Italian", - "ja": "Japanese", - "kn": "Kannada", - "kk": "Kazakh", - "ko": "Korean", - "lv": "Latvian", - "lt": "Lithuanian", - "mk": "Macedonian", - "ms": "Malay", - "mr": "Marathi", - "mi": "Maori", - "ne": "Nepali", - "no": "Norwegian", - "fa": "Persian", - "pl": "Polish", - "pt": "Portuguese", - "ro": "Romanian", - "ru": "Russian", - "sr": "Serbian", - "sk": "Slovak", - "sl": "Slovenian", - "es": "Spanish", - "sw": "Swahili", - "sv": "Swedish", - "tl": "Tagalog", - "ta": "Tamil", - "th": "Thai", - "tr": "Turkish", - "uk": "Ukrainian", - 
"ur": "Urdu", - "vi": "Vietnamese", - "cy": "Welsh" -} -ISO639_1_OTHER_LANGS = { - "lo": "Lao", - "jw": "Javanese", - "tk": "Turkmen", - "yi": "Yiddish", - "so": "Somali", - "bn": "Bengali", - "nn": "Norwegian Nynorsk", - "si": "Sinhala", - "yo": "Yoruba", - "sa": "Sanskrit", - "mi": "Māori", - "fo": "Faroese", # codespell:ignore - "mt": "Maltese", - "tg": "Tajik", - "mg": "Malagasy", - "haw": "Hawaiian", - "km": "Khmer", - "br": "Breton", - "ps": "Pashto", - "ln": "Lingala", - "la": "Latin", - "ml": "Malayalam", - "sq": "Albanian", - "su": "Sundanese", - "eu": "Basque", - "ka": "Georgian", - "uz": "Uzbek", - "sn": "Shona", - "ht": "Haitian", - "as": "Assamese", - "mn": "Mongolian", - "te": "Telugu", - "pa": "Panjabi", - "tt": "Tatar", - "gu": "Gujarati", - "oc": "Occitan", - "ha": "Hausa", - "ba": "Bashkir", - "my": "Burmese", - "sd": "Sindhi", - "am": "Amharic", - "lb": "Luxembourgish", - "bo": "Tibetan" -} - -# As per https://platform.openai.com/docs/guides/speech-to-text#overview. -# TODO configurable -MAX_AUDIO_CLIP_FILESIZE_MB = 25 -OVERLAP_CHUNK_SECOND = 1 -MIN_ENERGY_WINDOW_SIZE = 1600 # 1600 ~ 100ms for 16000 Hz audio - -class OpenAIServingTranscription(OpenAIServing): +class OpenAIServingTranscription(OpenAISpeechToText): + """Handles transcription requests.""" def __init__( self, @@ -164,70 +37,9 @@ def __init__( model_config=model_config, models=models, request_logger=request_logger, - return_tokens_as_token_ids=return_tokens_as_token_ids) - - self.default_sampling_params = ( - self.model_config.get_diff_sampling_param()) - processor = cached_get_processor(model_config.model) - self.max_audio_clip_s = processor.feature_extractor.chunk_length - self.model_sr = processor.feature_extractor.sampling_rate - self.hop_length = processor.feature_extractor.hop_length - - if self.default_sampling_params: - logger.info( - "Overwriting default completion sampling param with: %s", - self.default_sampling_params) - - async def _preprocess_transcription( - self, - request: TranscriptionRequest, - audio_data: bytes, - ) -> tuple[list[PromptType], float]: - # Validate request - # TODO language should be optional and can be guessed. - # For now we default to en. See - # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520 - lang_token = f"<|{request.language}|>" if request.language else "<|en|>" - if request.language: - if request.language in ISO639_1_SUPPORTED_LANGS: - pass - elif request.language in ISO639_1_OTHER_LANGS: - logger.warning( - "The selected language %s has limited accuracy with" - " reported WER>=0.5. Results may be less accurate " - "for this choice.", request.language) - else: - raise ValueError( - f"Unsupported language: {request.language}." 
- "Language should be one of:" + - f" {list(ISO639_1_SUPPORTED_LANGS.values())}" + - f"or {list(ISO639_1_OTHER_LANGS.values())}") - - if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB: - raise ValueError("Maximum file size exceeded.") - - with io.BytesIO(audio_data) as bytes_: - y, sr = librosa.load(bytes_) - - duration = librosa.get_duration(y=y, sr=sr) - chunks = [y] if duration < 30 else self._split_audio(y, sr) - prompts = [] - for i, chunk in enumerate(chunks): - prompt = { - "encoder_prompt": { - "prompt": "", - "multi_modal_data": { - "audio": (chunk, sr), - }, - }, - "decoder_prompt": - f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}" - if i == 0 else "" - } - prompts.append(cast(PromptType, prompt)) - return prompts, duration + return_tokens_as_token_ids=return_tokens_as_token_ids, + task_type="transcribe") - # TODO (varun) : Make verbose response work ! async def create_transcription( self, audio_data: bytes, request: TranscriptionRequest, raw_request: Request @@ -238,250 +50,83 @@ async def create_transcription( See https://platform.openai.com/docs/api-reference/audio/createTranscription for the API specification. This API mimics the OpenAI transcription API. """ - error_check_ret = await self._check_model(request) - if error_check_ret is not None: - return error_check_ret - - # If the engine is dead, raise the engine's DEAD_ERROR. - # This is required for the streaming case, where we return a - # success status before we actually start generating text :). - if self.engine_client.errored: - raise self.engine_client.dead_error - - if request.response_format not in ['text', 'json']: - return self.create_error_response( - "Currently only support response_format `text` or `json`") - - request_id = f"trsc-{self._base_request_id(raw_request)}" - - request_metadata = RequestResponseMetadata(request_id=request_id) - if raw_request: - raw_request.state.request_metadata = request_metadata - - try: - ( - lora_request, - prompt_adapter_request, - ) = self._maybe_get_adapters(request) - - if lora_request: - return self.create_error_response( - "Currently do not support LoRA for Transcription.") - if prompt_adapter_request: - return self.create_error_response( - "Currently do not support PromptAdapter for Transcription." - ) - - prompts, duration_s = await self._preprocess_transcription( - request=request, - audio_data=audio_data, - ) - - except ValueError as e: - logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(str(e)) - - list_result_generator: Optional[list[AsyncGenerator[RequestOutput, - None]]] = None - try: - # Unlike most decoder-only models, whisper generation length is not - # constrained by the size of the input audio, which is mapped to a - # fixed-size log-mel-spectogram. 
- default_max_tokens = self.model_config.max_model_len - sampling_params = request.to_sampling_params( - default_max_tokens, self.default_sampling_params) - - self._log_inputs( - request_id, - prompts[0]['decoder_prompt'], # type: ignore - params=sampling_params, - lora_request=None, - prompt_adapter_request=None) - - list_result_generator = [ - self.engine_client.generate( - prompt, - sampling_params, - request_id, - ) for prompt in prompts - ] - except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) - - if request.stream: - return self.transcription_stream_generator(request, - list_result_generator, - request_id, - request_metadata, - duration_s) - # Non-streaming response. - try: - assert list_result_generator is not None - text = "" - for result_generator in list_result_generator: - async for op in result_generator: - text += op.outputs[0].text - return TranscriptionResponse(text=text) - except asyncio.CancelledError: - return self.create_error_response("Client disconnected") - except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + return await self._create_speech_to_text( + audio_data=audio_data, + request=request, + raw_request=raw_request, + response_class=TranscriptionResponse, + stream_generator_method=self.transcription_stream_generator, + ) async def transcription_stream_generator( self, request: TranscriptionRequest, - list_result_generator: list[AsyncGenerator[RequestOutput, None]], + result_generator: list[AsyncGenerator[RequestOutput, None]], request_id: str, request_metadata: RequestResponseMetadata, audio_duration_s: float) -> AsyncGenerator[str, None]: - created_time = int(time.time()) - model_name = request.model - chunk_object_type: Final = "transcription.chunk" - - completion_tokens = 0 - num_prompt_tokens = 0 - - include_usage = request.stream_include_usage \ - if request.stream_include_usage else False - include_continuous_usage = request.stream_continuous_usage_stats\ - if include_usage and request.stream_continuous_usage_stats\ - else False - - try: - for result_generator in list_result_generator: - async for res in result_generator: - # On first result. - if res.prompt_token_ids is not None: - # Do not account the 4-tokens `<|startoftranscript|>..` - # Could be negative when language token - # is not specified. - num_prompt_tokens = max( - len(res.prompt_token_ids) - 4, 0) - # NOTE(NickLucche) user can't pass encoder - # prompts directly at least not to Whisper. - # One indicator of the encoder amount of processing - # is the log-mel spectogram length. - num_prompt_tokens += ceil( - audio_duration_s * self.model_sr / self.hop_length) - - # We need to do it here, because if there are exceptions in - # the result_generator, it needs to be sent as the FIRST - # response (by the try...catch). - - # Just one output (n=1) supported. 
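The usage accounting above pads the prompt-token count with the number of log-mel frames the encoder processes, since users cannot pass encoder prompts to Whisper directly. A worked example of that estimate with assumed Whisper-style feature-extractor values:

```python
from math import ceil

# Assumed values, matching typical Whisper feature extractors.
MODEL_SR = 16_000   # feature_extractor.sampling_rate
HOP_LENGTH = 160    # feature_extractor.hop_length


def encoder_frames(audio_duration_s: float) -> int:
    # One log-mel frame per hop; this approximates the encoder-side "tokens".
    return ceil(audio_duration_s * MODEL_SR / HOP_LENGTH)


print(encoder_frames(30.0))  # 3000 frames for a 30 s clip
print(encoder_frames(7.5))   # 750 frames for a 7.5 s clip
```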
- assert len(res.outputs) == 1 - output = res.outputs[0] + generator = self._speech_to_text_stream_generator( + request=request, + list_result_generator=result_generator, + request_id=request_id, + request_metadata=request_metadata, + audio_duration_s=audio_duration_s, + chunk_object_type="transcription.chunk", + response_stream_choice_class=TranscriptionResponseStreamChoice, + stream_response_class=TranscriptionStreamResponse, + ) + async for chunk in generator: + yield chunk + + +class OpenAIServingTranslation(OpenAISpeechToText): + """Handles translation requests.""" - delta_message = DeltaMessage(content=output.text) - completion_tokens += len(output.token_ids) - - if output.finish_reason is None: - # Still generating, send delta update. - choice_data = TranscriptionResponseStreamChoice( - delta=delta_message) - else: - # Model is finished generating. - choice_data = TranscriptionResponseStreamChoice( - delta=delta_message, - finish_reason=output.finish_reason, - stop_reason=output.stop_reason) - - chunk = TranscriptionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - model=model_name) - - # handle usage stats if requested & if continuous - if include_continuous_usage: - chunk.usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=num_prompt_tokens + completion_tokens, - ) - - data = chunk.model_dump_json(exclude_unset=True) - yield f"data: {data}\n\n" - - # Once the final token is handled, if stream_options.include_usage - # is sent, send the usage. - if include_usage: - final_usage = UsageInfo(prompt_tokens=num_prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=num_prompt_tokens + - completion_tokens) - - final_usage_chunk = TranscriptionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[], - model=model_name, - usage=final_usage) - final_usage_data = (final_usage_chunk.model_dump_json( - exclude_unset=True, exclude_none=True)) - yield f"data: {final_usage_data}\n\n" - - # report to FastAPI middleware aggregate usage across all choices - request_metadata.final_usage_info = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=num_prompt_tokens + completion_tokens) - - except Exception as e: - # TODO: Use a vllm-specific Validation Error - logger.exception("Error in chat completion stream generator.") - data = self.create_streaming_error_response(str(e)) - yield f"data: {data}\n\n" - # Send the final done message after all response.n are finished - yield "data: [DONE]\n\n" - - def _split_audio(self, audio_data: np.ndarray, - sample_rate: int) -> list[np.ndarray]: - chunk_size = sample_rate * self.max_audio_clip_s - overlap_size = sample_rate * OVERLAP_CHUNK_SECOND - chunks = [] - i = 0 - while i < audio_data.shape[-1]: - if i + chunk_size >= audio_data.shape[-1]: - # handle last chunk - chunks.append(audio_data[..., i:]) - break - - # Find the best split point in the overlap region - search_start = i + chunk_size - overlap_size - search_end = min(i + chunk_size, audio_data.shape[-1]) - split_point = self._find_split_point(audio_data, search_start, - search_end) + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + ): + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + 
request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + task_type="translate") - # Extract chunk up to the split point - chunks.append(audio_data[..., i:split_point]) - i = split_point - return chunks + async def create_translation( + self, audio_data: bytes, request: TranslationRequest, + raw_request: Request + ) -> Union[TranslationResponse, AsyncGenerator[str, None], ErrorResponse]: + """Translation API similar to OpenAI's API. - def _find_split_point(self, wav: np.ndarray, start_idx: int, - end_idx: int) -> int: - """Find the best point to split audio by - looking for silence or low amplitude. - Args: - wav: Audio tensor [1, T] - start_idx: Start index of search region - end_idx: End index of search region - Returns: - Index of best splitting point + See https://platform.openai.com/docs/api-reference/audio/createTranslation + for the API specification. This API mimics the OpenAI translation API. """ - segment = wav[start_idx:end_idx] - - # Calculate RMS energy in small windows - min_energy = math.inf - quietest_idx = 0 - for i in range(0, - len(segment) - MIN_ENERGY_WINDOW_SIZE, - MIN_ENERGY_WINDOW_SIZE): - window = segment[i:i + MIN_ENERGY_WINDOW_SIZE] - energy = (window**2).mean()**0.5 - if energy < min_energy: - quietest_idx = i + start_idx - min_energy = energy - return quietest_idx + return await self._create_speech_to_text( + audio_data=audio_data, + request=request, + raw_request=raw_request, + response_class=TranslationResponse, + stream_generator_method=self.translation_stream_generator, + ) + + async def translation_stream_generator( + self, request: TranslationRequest, + result_generator: list[AsyncGenerator[RequestOutput, None]], + request_id: str, request_metadata: RequestResponseMetadata, + audio_duration_s: float) -> AsyncGenerator[str, None]: + generator = self._speech_to_text_stream_generator( + request=request, + list_result_generator=result_generator, + request_id=request_id, + request_metadata=request_metadata, + audio_duration_s=audio_duration_s, + chunk_object_type="translation.chunk", + response_stream_choice_class=TranslationResponseStreamChoice, + stream_response_class=TranslationStreamResponse, + ) + async for chunk in generator: + yield chunk diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py new file mode 100644 index 000000000000..b23cf6cab097 --- /dev/null +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -0,0 +1,503 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import io +import math +import time +from collections.abc import AsyncGenerator +from math import ceil +from typing import Callable, Literal, Optional, TypeVar, Union, cast + +import numpy as np +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import ( + DeltaMessage, ErrorResponse, RequestResponseMetadata, + TranscriptionResponse, TranscriptionResponseStreamChoice, + TranscriptionStreamResponse, TranslationResponse, + TranslationResponseStreamChoice, TranslationStreamResponse, UsageInfo) +from vllm.entrypoints.openai.serving_engine import (OpenAIServing, + SpeechToTextRequest) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import PromptType +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from 
vllm.transformers_utils.processor import cached_get_processor +from vllm.utils import PlaceholderModule + +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") # type: ignore[assignment] + +SpeechToTextResponse = Union[TranscriptionResponse, TranslationResponse] +T = TypeVar("T", bound=SpeechToTextResponse) + +logger = init_logger(__name__) + +# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages +# TODO these configs should live somewhere with the model so we can support +# additional ones + +ISO639_1_SUPPORTED_LANGS = { + "af": "Afrikaans", + "ar": "Arabic", + "hy": "Armenian", + "az": "Azerbaijani", + "be": "Belarusian", + "bs": "Bosnian", + "bg": "Bulgarian", + "ca": "Catalan", + "zh": "Chinese", + "hr": "Croatian", + "cs": "Czech", + "da": "Danish", + "nl": "Dutch", + "en": "English", + "et": "Estonian", + "fi": "Finnish", + "fr": "French", + "gl": "Galician", + "de": "German", + "el": "Greek", + "he": "Hebrew", + "hi": "Hindi", + "hu": "Hungarian", + "is": "Icelandic", + "id": "Indonesian", + "it": "Italian", + "ja": "Japanese", + "kn": "Kannada", + "kk": "Kazakh", + "ko": "Korean", + "lv": "Latvian", + "lt": "Lithuanian", + "mk": "Macedonian", + "ms": "Malay", + "mr": "Marathi", + "mi": "Maori", + "ne": "Nepali", + "no": "Norwegian", + "fa": "Persian", + "pl": "Polish", + "pt": "Portuguese", + "ro": "Romanian", + "ru": "Russian", + "sr": "Serbian", + "sk": "Slovak", + "sl": "Slovenian", + "es": "Spanish", + "sw": "Swahili", + "sv": "Swedish", + "tl": "Tagalog", + "ta": "Tamil", + "th": "Thai", + "tr": "Turkish", + "uk": "Ukrainian", + "ur": "Urdu", + "vi": "Vietnamese", + "cy": "Welsh" +} +ISO639_1_OTHER_LANGS = { + "lo": "Lao", + "jw": "Javanese", + "tk": "Turkmen", + "yi": "Yiddish", + "so": "Somali", + "bn": "Bengali", + "nn": "Norwegian Nynorsk", + "si": "Sinhala", + "yo": "Yoruba", + "sa": "Sanskrit", + "mi": "Māori", + "fo": "Faroese", # codespell:ignore + "mt": "Maltese", + "tg": "Tajik", + "mg": "Malagasy", + "haw": "Hawaiian", + "km": "Khmer", + "br": "Breton", + "ps": "Pashto", + "ln": "Lingala", + "la": "Latin", + "ml": "Malayalam", + "sq": "Albanian", + "su": "Sundanese", + "eu": "Basque", + "ka": "Georgian", + "uz": "Uzbek", + "sn": "Shona", + "ht": "Haitian", + "as": "Assamese", + "mn": "Mongolian", + "te": "Telugu", + "pa": "Panjabi", + "tt": "Tatar", + "gu": "Gujarati", + "oc": "Occitan", + "ha": "Hausa", + "ba": "Bashkir", + "my": "Burmese", + "sd": "Sindhi", + "am": "Amharic", + "lb": "Luxembourgish", + "bo": "Tibetan" +} + +# As per https://platform.openai.com/docs/guides/speech-to-text#overview. 
+# TODO configurable +MAX_AUDIO_CLIP_FILESIZE_MB = 25 +OVERLAP_CHUNK_SECOND = 1 +MIN_ENERGY_WINDOW_SIZE = 1600 # 1600 ~ 100ms for 16000 Hz audio + + +class OpenAISpeechToText(OpenAIServing): + """Base class for speech-to-text operations like transcription and + translation.""" + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + task_type: Literal["transcribe", "translate"] = "transcribe", + ): + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids) + + self.default_sampling_params = ( + self.model_config.get_diff_sampling_param()) + processor = cached_get_processor(model_config.model) + self.max_audio_clip_s = processor.feature_extractor.chunk_length + self.model_sr = processor.feature_extractor.sampling_rate + self.hop_length = processor.feature_extractor.hop_length + self.task_type = task_type + + if self.default_sampling_params: + logger.info( + "Overwriting default completion sampling param with: %s", + self.default_sampling_params) + + async def _preprocess_speech_to_text( + self, + request: SpeechToTextRequest, + audio_data: bytes, + ) -> tuple[list[PromptType], float]: + # Validate request + # TODO language should be optional and can be guessed. + # For now we default to en. See + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520 + lang_token = f"<|{request.language}|>" if request.language else "<|en|>" + if request.language: + if request.language in ISO639_1_SUPPORTED_LANGS: + pass + elif request.language in ISO639_1_OTHER_LANGS: + logger.warning( + "The selected language %s has limited accuracy with" + " reported WER>=0.5. Results may be less accurate " + "for this choice.", request.language) + else: + raise ValueError( + f"Unsupported language: {request.language}." + "Language should be one of:" + + f" {list(ISO639_1_SUPPORTED_LANGS.values())}" + + f"or {list(ISO639_1_OTHER_LANGS.values())}") + + if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB: + raise ValueError("Maximum file size exceeded.") + + with io.BytesIO(audio_data) as bytes_: + # NOTE resample to model SR here for efficiency. This is also a + # pre-requisite for chunking, as it assumes Whisper SR. + y, sr = librosa.load(bytes_, sr=self.model_sr) + + duration = librosa.get_duration(y=y, sr=sr) + chunks = [y] if duration < 30 else self._split_audio(y, int(sr)) + prompts = [] + for chunk in chunks: + prompt = { + "encoder_prompt": { + "prompt": "", + "multi_modal_data": { + "audio": (chunk, sr), + }, + }, + "decoder_prompt": + (f"<|startoftranscript|>{lang_token}" + f"<|{self.task_type}|><|notimestamps|>{request.prompt}") + } + prompts.append(cast(PromptType, prompt)) + return prompts, duration + + async def _create_speech_to_text( + self, + audio_data: bytes, + request: SpeechToTextRequest, + raw_request: Request, + response_class: type[T], + stream_generator_method: Callable[..., AsyncGenerator[str, None]], + ) -> Union[T, AsyncGenerator[str, None], ErrorResponse]: + """Base method for speech-to-text operations like transcription and + translation.""" + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. 
+ # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + + if request.response_format not in ['text', 'json']: + return self.create_error_response( + "Currently only support response_format `text` or `json`") + + request_id = f"{self.task_type}-{self._base_request_id(raw_request)}" + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + if lora_request: + return self.create_error_response( + "Currently do not support LoRA for " + f"{self.task_type.title()}.") + if prompt_adapter_request: + return self.create_error_response( + f"Currently do not support PromptAdapter for " + f"{self.task_type.title()}.") + + prompts, duration_s = await self._preprocess_speech_to_text( + request=request, + audio_data=audio_data, + ) + + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + list_result_generator: Optional[list[AsyncGenerator[RequestOutput, + None]]] = None + try: + # Unlike most decoder-only models, whisper generation length is not + # constrained by the size of the input audio, which is mapped to a + # fixed-size log-mel-spectogram. + default_max_tokens = self.model_config.max_model_len + sampling_params = request.to_sampling_params( + default_max_tokens, self.default_sampling_params) + + self._log_inputs( + request_id, + prompts[0]['decoder_prompt'], # type: ignore + params=sampling_params, + lora_request=None, + prompt_adapter_request=None) + + list_result_generator = [ + self.engine_client.generate( + prompt, + sampling_params, + request_id, + ) for prompt in prompts + ] + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + if request.stream: + return stream_generator_method(request, list_result_generator, + request_id, request_metadata, + duration_s) + # Non-streaming response. 
+ try: + assert list_result_generator is not None + text = "" + for result_generator in list_result_generator: + async for op in result_generator: + text += op.outputs[0].text + return cast(T, response_class(text=text)) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + async def _speech_to_text_stream_generator( + self, + request: SpeechToTextRequest, + list_result_generator: list[AsyncGenerator[RequestOutput, None]], + request_id: str, + request_metadata: RequestResponseMetadata, + audio_duration_s: float, + chunk_object_type: Literal["translation.chunk", "transcription.chunk"], + response_stream_choice_class: Union[ + type[TranscriptionResponseStreamChoice], + type[TranslationResponseStreamChoice]], + stream_response_class: Union[type[TranscriptionStreamResponse], + type[TranslationStreamResponse]], + ) -> AsyncGenerator[str, None]: + created_time = int(time.time()) + model_name = request.model + + completion_tokens = 0 + num_prompt_tokens = 0 + + include_usage = request.stream_include_usage \ + if request.stream_include_usage else False + include_continuous_usage = request.stream_continuous_usage_stats\ + if include_usage and request.stream_continuous_usage_stats\ + else False + + try: + for result_generator in list_result_generator: + async for res in result_generator: + # On first result. + if res.prompt_token_ids is not None: + # Do not account the 4-tokens `<|startoftranscript|>..` + # Could be negative when language token + # is not specified. + num_prompt_tokens = max( + len(res.prompt_token_ids) - 4, 0) + # NOTE(NickLucche) user can't pass encoder + # prompts directly at least not to Whisper. + # One indicator of the encoder amount of processing + # is the log-mel spectogram length. + num_prompt_tokens += ceil( + audio_duration_s * self.model_sr / self.hop_length) + + # We need to do it here, because if there are exceptions in + # the result_generator, it needs to be sent as the FIRST + # response (by the try...catch). + + # Just one output (n=1) supported. + assert len(res.outputs) == 1 + output = res.outputs[0] + + delta_message = DeltaMessage(content=output.text) + completion_tokens += len(output.token_ids) + + if output.finish_reason is None: + # Still generating, send delta update. + choice_data = response_stream_choice_class( + delta=delta_message) + else: + # Model is finished generating. + choice_data = response_stream_choice_class( + delta=delta_message, + finish_reason=output.finish_reason, + stop_reason=output.stop_reason) + + chunk = stream_response_class(id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + + # handle usage stats if requested & if continuous + if include_continuous_usage: + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, + ) + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # Once the final token is handled, if stream_options.include_usage + # is sent, send the usage. 
+ if include_usage: + final_usage = UsageInfo(prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + + completion_tokens) + + final_usage_chunk = stream_response_class( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[], + model=model_name, + usage=final_usage) + final_usage_data = (final_usage_chunk.model_dump_json( + exclude_unset=True, exclude_none=True)) + yield f"data: {final_usage_data}\n\n" + + # report to FastAPI middleware aggregate usage across all choices + request_metadata.final_usage_info = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens) + + except Exception as e: + # TODO: Use a vllm-specific Validation Error + logger.exception("Error in %s stream generator.", self.task_type) + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + # Send the final done message after all response.n are finished + yield "data: [DONE]\n\n" + + def _split_audio(self, audio_data: np.ndarray, + sample_rate: int) -> list[np.ndarray]: + chunk_size = sample_rate * self.max_audio_clip_s + overlap_size = sample_rate * OVERLAP_CHUNK_SECOND + chunks = [] + i = 0 + while i < audio_data.shape[-1]: + if i + chunk_size >= audio_data.shape[-1]: + # handle last chunk + chunks.append(audio_data[..., i:]) + break + + # Find the best split point in the overlap region + search_start = i + chunk_size - overlap_size + search_end = min(i + chunk_size, audio_data.shape[-1]) + split_point = self._find_split_point(audio_data, search_start, + search_end) + + # Extract chunk up to the split point + chunks.append(audio_data[..., i:split_point]) + i = split_point + return chunks + + def _find_split_point(self, wav: np.ndarray, start_idx: int, + end_idx: int) -> int: + """Find the best point to split audio by + looking for silence or low amplitude. + Args: + wav: Audio tensor [1, T] + start_idx: Start index of search region + end_idx: End index of search region + Returns: + Index of best splitting point + """ + segment = wav[start_idx:end_idx] + + # Calculate RMS energy in small windows + min_energy = math.inf + quietest_idx = 0 + for i in range(0, + len(segment) - MIN_ENERGY_WINDOW_SIZE, + MIN_ENERGY_WINDOW_SIZE): + window = segment[i:i + MIN_ENERGY_WINDOW_SIZE] + energy = (window**2).mean()**0.5 + if energy < min_energy: + quietest_idx = i + start_idx + min_energy = energy + return quietest_idx diff --git a/vllm/envs.py b/vllm/envs.py index ee7a615bfba8..745ca626cda1 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -133,6 +133,7 @@ VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300 VLLM_KV_CACHE_LAYOUT: Optional[str] = None VLLM_COMPUTE_NANS_IN_LOGITS: bool = False + VLLM_USE_NVFP4_CT_EMULATIONS: bool = False def get_default_cache_root(): @@ -918,6 +919,12 @@ def get_vllm_port() -> Optional[int]: # or bad hardware but it may add compute overhead. 
"VLLM_COMPUTE_NANS_IN_LOGITS": lambda: bool(int(os.getenv("VLLM_COMPUTE_NANS_IN_LOGITS", "0"))), + + # Controls whether or not emulations are used for NVFP4 + # generations on machines < 100 for compressed-tensors + # models + "VLLM_USE_NVFP4_CT_EMULATIONS": + lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))) } # --8<-- [end:env-vars-definition] @@ -981,6 +988,7 @@ def factorize(name: str): "VLLM_DP_RANK", "VLLM_DP_SIZE", "VLLM_USE_STANDALONE_COMPILE", + "VLLM_FUSED_MOE_CHUNK_SIZE", ] for key in environment_variables_to_hash: if key in environment_variables: diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 5492399efdf8..70836879d17c 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -6,14 +6,179 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.utils import ( - _resize_cache, per_token_group_quant_fp8) +from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.triton_utils import tl, triton logger = init_logger(__name__) has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None +@triton.jit +def _silu_mul_fp8_quant_deep_gemm( + # Pointers ------------------------------------------------------------ + input_ptr, # 16-bit activations (E, T, 2*H) + y_q_ptr, # fp8 quantized activations (E, T, H) + y_s_ptr, # 16-bit scales (E, T, G) + counts_ptr, # int32 num tokens per expert (E) + + # Sizes --------------------------------------------------------------- + H: tl.constexpr, # hidden dimension (per output) + GROUP_SIZE: tl.constexpr, # elements per group (usually 128) + + # Strides for input (elements) --------------------------------------- + stride_i_e, + stride_i_t, + stride_i_h, + + # Strides for y_q (elements) ----------------------------------------- + stride_yq_e, + stride_yq_t, + stride_yq_h, + + # Strides for y_s (elements) ----------------------------------------- + stride_ys_e, + stride_ys_t, + stride_ys_g, + + # Stride for counts (elements) + stride_counts_e, + + # Numeric params ------------------------------------------------------ + eps: tl.constexpr, + fp8_min: tl.constexpr, + fp8_max: tl.constexpr, + + # Meta --------------------------------------------------------------- + BLOCK: tl.constexpr, +): + G = H // GROUP_SIZE + + # map program id -> (e, g) + pid = tl.program_id(0) + e = pid // G + g = pid % G + + e = e.to(tl.int64) + g = g.to(tl.int64) + + # number of valid tokens for this expert + n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64) + + cols = tl.arange(0, BLOCK) + cols = cols.to(tl.int64) + mask_h = cols < BLOCK + + t = tl.zeros([], tl.int64) + while t < n_tokens: + base_i_offset = (e * stride_i_e + t * stride_i_t + + g * GROUP_SIZE * stride_i_h) + base_yq_offset = (e * stride_yq_e + t * stride_yq_t + + g * GROUP_SIZE * stride_yq_h) + base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g + + mask = mask_h + x = tl.load(input_ptr + base_i_offset + cols * stride_i_h, + mask=mask, + other=0.0).to(tl.float32) + y2 = tl.load(input_ptr + base_i_offset + H * stride_i_h + + cols * stride_i_h, + mask=mask, + other=0.0).to(tl.float32) + + x = x * (1.0 / (1.0 + tl.exp(-x))) + y = x * y2 + + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, 
fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask) + tl.store(y_s_ptr + base_ys_offset, y_s) + + t += 1 + + +def silu_mul_fp8_quant_deep_gemm( + y: torch.Tensor, # (E, T, 2*H) float32 + tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert + group_size: int = 128, + eps: float = 1e-10, +): + """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales + + y has shape (E, T, 2*H). The first half of the last dimension is + silu-activated, multiplied by the second half, then quantized into FP8. + + Returns `(y_q, y_s)` where + * `y_q` is the FP8 tensor of shape `(E, T, H)`, same layout as `y[..., :H]`. + * `y_s` has shape `(E, T, H // group_size)` and strides `(T*G, 1, T)` + """ + assert y.ndim == 3, "y must be (E, T, 2*H)" + E, T, H2 = y.shape + assert H2 % 2 == 0, "last dim of y must be even (2*H)" + H = H2 // 2 + G = H // group_size + assert H % group_size == 0, "H must be divisible by group_size" + assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, \ + "tokens_per_expert must be shape (E,)" + tokens_per_expert = tokens_per_expert.to(device=y.device, + dtype=torch.int32) + + # allocate outputs + fp8_dtype = torch.float8_e4m3fn + y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) + + # strides (elements) + stride_i_e, stride_i_t, stride_i_h = y.stride() + stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() + + # desired scale strides (elements): (T*G, 1, T) + stride_ys_e = T * G + stride_ys_t = 1 + stride_ys_g = T + y_s = torch.empty_strided((E, T, G), + (stride_ys_e, stride_ys_t, stride_ys_g), + dtype=torch.float32, + device=y.device) + + stride_cnt_e = tokens_per_expert.stride()[0] + + # static grid over experts and H-groups. + # A loop inside the kernel handles the token dim + grid = (E * G, ) + + f_info = torch.finfo(fp8_dtype) + fp8_max = f_info.max + fp8_min = f_info.min + + _silu_mul_fp8_quant_deep_gemm[grid]( + y, + y_q, + y_s, + tokens_per_expert, + H, + group_size, + stride_i_e, + stride_i_t, + stride_i_h, + stride_yq_e, + stride_yq_t, + stride_yq_h, + stride_ys_e, + stride_ys_t, + stride_ys_g, + stride_cnt_e, + eps, + fp8_min, + fp8_max, + BLOCK=group_size, + num_warps=4, + ) + + return y_q, y_s + + class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): # The Deep Gemm kernels only support block size of 128 @@ -96,7 +261,6 @@ def apply( hidden_states, w1, w2, topk_ids) workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N)) - workspace2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2)) # (from deepgemm docs) : A value hint (which is a value on CPU) # for the M expectation of each batch, correctly setting this value @@ -109,19 +273,9 @@ def apply( masked_m=expert_num_tokens, expected_m=expected_m) - # TODO (varun) [Optimization]: Use a batched version of activation. - # Similarly for the quant below. 
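The TODO removed just below asked for a batched activation and quantization; the new `_silu_mul_fp8_quant_deep_gemm` Triton kernel is that fusion. A rough eager-mode sketch of what it computes — it ignores the per-expert `tokens_per_expert` masking and the strided scale layout, so it is for intuition only, not a drop-in replacement:

```python
import torch


def silu_mul_fp8_quant_reference(y: torch.Tensor,
                                 group_size: int = 128,
                                 eps: float = 1e-10):
    """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with one scale per
    group_size channels. Unlike the Triton kernel, this touches all T tokens
    and returns a contiguous (E, T, G) scale tensor."""
    E, T, H2 = y.shape
    H = H2 // 2
    x, gate = y[..., :H].float(), y[..., H:].float()
    act = x * torch.sigmoid(x) * gate  # SiLU(x) * gate
    groups = act.reshape(E, T, H // group_size, group_size)
    finfo = torch.finfo(torch.float8_e4m3fn)
    # One scale per (expert, token, group): max(|.|, eps) / fp8_max
    y_s = groups.abs().amax(dim=-1).clamp_min(eps) / finfo.max
    y_q = (groups / y_s.unsqueeze(-1)).clamp(finfo.min, finfo.max)
    return y_q.reshape(E, T, H).to(torch.float8_e4m3fn), y_s
```

The kernel additionally launches one program per (expert, group) and loops over the token dimension on-device, so the per-expert token counts never have to be read back to the host.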
- self.activation(activation, workspace2, workspace1.view(-1, N)) - - w2_hidden_size = workspace2.size(-1) - workspace2 = workspace2.view(-1, w2_hidden_size) - - a2q_scale: Optional[torch.Tensor] = None - a2q, a2q_scale = per_token_group_quant_fp8(workspace2, - self.block_shape[1], - column_major_scales=False) - a2q = a2q.view(E, max_num_tokens, -1) - a2q_scale = a2q_scale.view(E, max_num_tokens, -1) + assert expert_num_tokens is not None + a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1, + expert_num_tokens) dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale), (w2, w2_scale), diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 3f9ceac8b6e3..73d169a84808 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -41,24 +41,24 @@ def run_cutlass_moe_fp8( assert w1.dtype == torch.float8_e4m3fn assert w2.dtype == torch.float8_e4m3fn if expert_num_tokens is None: - assert a1q.shape[1] == w1.shape[2], "Hidden size mismatch w1" + assert a1q.size(1) == w1.size(2), "Hidden size mismatch w1" else: - assert a1q.shape[2] == w1.shape[2], "Hidden size mismatch w1" - assert w1.shape[1] == w2.shape[2] * 2, "Hidden size mismatch w2" - assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ - 1] == w1.shape[1], "W1 scale shape mismatch" - assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ - 1] == w2.shape[1], "W2 scale shape mismatch" - assert w1.shape[0] == w2.shape[0], "Expert number mismatch" - assert a1q_scale is None or a1q_scale.dim( - ) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[ - 0], "Input scale shape mismatch" - assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch" - assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch" - assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch" - assert a2_scale is None or a2_scale.dim( - ) == 0 or a2_scale.shape[0] == 1 or a2_scale.shape[0] == a1q.shape[ - 0], "Intermediate scale shape mismatch" + assert a1q.size(2) == w1.size(2), "Hidden size mismatch w1" + assert w1.size(1) == w2.size(2) * 2, "Hidden size mismatch w2" + assert w1_scale.dim() == 1 or w1_scale.size( + 1) == 1 or w1_scale.shape[1] == w1.size(1), "W1 scale shape mismatch" + assert w2_scale.dim() == 1 or w2_scale.size( + 1) == 1 or w2_scale.shape[1] == w2.size(1), "W2 scale shape mismatch" + assert w1.size(0) == w2.size(0), "Expert number mismatch" + assert a1q_scale is None or a1q_scale.dim() == 0 or a1q_scale.size( + 0) == 1 or a1q_scale.size( + 0) == a1q.shape[0], "Input scale shape mismatch" + assert w1.size(0) == w2.size(0), "Weights expert number mismatch" + assert w1.size(0) == w1_scale.size(0), "w1 scales expert number mismatch" + assert w1.size(0) == w2_scale.size(0), "w2 scales expert number mismatch" + assert a2_scale is None or a2_scale.dim() == 0 or a2_scale.size( + 0) == 1 or a2_scale.size( + 0) == a1q.shape[0], "Intermediate scale shape mismatch" assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" if expert_map is not None: assert expert_num_tokens is None @@ -75,12 +75,12 @@ def run_cutlass_moe_fp8( # their tokens are already contiguous for each expert as a result of # the dispatch function. 
- M = a1q.shape[0] # non batched expert M - padded_M = a1q.shape[1] # batched expert M + M = a1q.size(0) # non batched expert M + padded_M = a1q.size(1) # batched expert M _, K, N = w2.shape device = a1q.device - assert w1.shape[2] == K + assert w1.size(2) == K assert global_num_experts != -1 assert a1q_scale is not None @@ -91,8 +91,8 @@ def run_cutlass_moe_fp8( else: local_topk_ids = topk_ids - topk = local_topk_ids.shape[1] - local_E = w1.shape[0] + topk = local_topk_ids.size(1) + local_E = w1.size(0) if use_batched_format: assert expert_num_tokens is not None @@ -111,10 +111,10 @@ def run_cutlass_moe_fp8( problem_sizes2, expert_num_tokens, local_E, padded_M, N, K) - w1_scale = w1_scale.reshape(w1_scale.shape[0], -1) - w2_scale = w2_scale.reshape(w2_scale.shape[0], -1) - a1q = a1q.reshape(-1, a1q.shape[2]) - a1q_scale = a1q_scale.reshape(-1, a1q_scale.shape[2]).contiguous() + w1_scale = w1_scale.reshape(w1_scale.size(0), -1) + w2_scale = w2_scale.reshape(w2_scale.size(0), -1) + a1q = a1q.reshape(-1, a1q.size(2)) + a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous() else: expert_offsets = torch.empty((global_num_experts + 1), @@ -151,19 +151,19 @@ def run_cutlass_moe_fp8( a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale expert_offsets = expert_offsets[:-1] - ab_strides1 = torch.full((w1.shape[0], ), + ab_strides1 = torch.full((w1.size(0), ), K, device=device, dtype=torch.int64) - c_strides1 = torch.full((w1.shape[0], ), + c_strides1 = torch.full((w1.size(0), ), 2 * N, device=device, dtype=torch.int64) - ab_strides2 = torch.full((w1.shape[0], ), + ab_strides2 = torch.full((w1.size(0), ), N, device=device, dtype=torch.int64) - c_strides2 = torch.full((w1.shape[0], ), + c_strides2 = torch.full((w1.size(0), ), K, device=device, dtype=torch.int64) @@ -237,7 +237,7 @@ def workspace_shapes( workspace2: tuple[int, ...] = () output: tuple[int, ...] = () if self.use_batched_format: - padded_M = aq.shape[1] + padded_M = aq.size(1) workspace1 = (self.max_experts_per_worker, padded_M, max(N, K)) workspace2 = (self.max_experts_per_worker, padded_M, (N // 2)) output = (self.max_experts_per_worker, padded_M, K) @@ -332,7 +332,7 @@ def cutlass_moe_fp8( """ per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) - per_out_ch = w1_scale.numel() != w1_q.shape[0] + per_out_ch = w1_scale.numel() != w1_q.size(0) out_dtype = a.dtype @@ -425,11 +425,11 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, assert (m == m_a), "input shape mismatch" assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" - assert (topk_weights.shape[0] == m and topk_ids.shape[0] + assert (topk_weights.size(0) == m and topk_ids.size(0) == m), ("topk must be provided for each row of a") out_dtype = a.dtype - num_topk = topk_ids.shape[1] + num_topk = topk_ids.size(1) expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device) blockscale_offsets = torch.empty((e + 1), dtype=torch.int32, device=device) @@ -463,7 +463,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, out_dtype, device) del rep_a_fp4, rep_a_blockscale # hidden size dimension is split to one halfpytho sized tensor. 
- intermediate = torch.empty((m * num_topk, w1_fp4.shape[1] // 2), + intermediate = torch.empty((m * num_topk, w1_fp4.size(1) // 2), device=device, dtype=out_dtype) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index b4473b907381..050d9520ca01 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -48,7 +48,7 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, M = hidden_states.size(0) _, K, N = w2.size() if not _valid_deep_gemm_shape(M, N, K): - logger.debug("DeepGemm disabled: unalinged problem size.") + logger.debug("DeepGemm disabled: unaligned problem size.") return False if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn): @@ -143,6 +143,7 @@ def apply( quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2)) mm2_out = _resize_cache(workspace2, (M_sum, K)) + # import pdb; pdb.set_trace() dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 3484a7a8a496..5a8accd80463 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -25,7 +25,7 @@ def dequant_fp8(expert_x_fp8: torch.Tensor, expert_x_fp32 = expert_x_fp8.to(torch.float32).view( num_experts, -1, DEEPEP_QUANT_BLOCK_SIZE) expert_x_scales = expert_x_scales.view(num_experts, -1, 1) - return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.shape) + return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size()) class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 437e80696ac6..f22884b8a1a5 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -488,10 +488,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, if use_fp8_w8a8 or use_int8_w8a8: assert B_scale is not None - assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0]) - == B_scale.shape[-2]) - assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1]) - == B_scale.shape[-1]) + assert (block_shape is None + or triton.cdiv(B.size(-2), block_shape[0]) == B_scale.size(-2)) + assert (block_shape is None + or triton.cdiv(B.size(-1), block_shape[1]) == B_scale.size(-1)) elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None @@ -500,19 +500,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor, assert A_scale is None assert B_scale is None - M = A.shape[0] + M = A.size(0) num_tokens = M * top_k - EM = sorted_token_ids.shape[0] - if A.shape[0] < config["BLOCK_SIZE_M"]: + EM = sorted_token_ids.size(0) + if A.size(0) < config["BLOCK_SIZE_M"]: # optimize for small batch_size. # We assume that top_ids of each token is unique, so # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, # and we can skip some invalid blocks. 
- EM = min(sorted_token_ids.shape[0], - A.shape[0] * top_k * config['BLOCK_SIZE_M']) + EM = min(sorted_token_ids.size(0), + A.size(0) * top_k * config['BLOCK_SIZE_M']) grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( - B.shape[1], META['BLOCK_SIZE_N']), ) + B.size(1), META['BLOCK_SIZE_N']), ) if (use_int8_w8a16 or use_int4_w4a16) and \ block_shape is not None and block_shape[1] > 0: @@ -522,16 +522,16 @@ def invoke_fused_moe_kernel(A: torch.Tensor, use_moe_wna16_cuda = should_moe_wna16_use_cuda( num_valid_tokens=num_tokens, group_size=block_shape[1], - num_experts=B.shape[0], + num_experts=B.size(0), bit=4 if use_int4_w4a16 else 8) config = config.copy() config.update( get_moe_wna16_block_config(config=config, use_moe_wna16_cuda=use_moe_wna16_cuda, num_valid_tokens=num_tokens, - size_k=A.shape[1], - size_n=B.shape[1], - num_experts=B.shape[1], + size_k=A.size(1), + size_n=B.size(1), + num_experts=B.size(1), group_size=block_shape[1], real_top_k=top_k, block_size_m=config["BLOCK_SIZE_M"])) @@ -556,8 +556,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, sorted_token_ids, expert_ids, num_tokens_post_padded, - B.shape[1], - A.shape[1], + B.size(1), + A.size(1), EM, num_tokens, A.stride(0), @@ -573,7 +573,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B_zp.stride(0) if B_zp is not None else 0, B_zp.stride(2) if B_zp is not None else 0, B_zp.stride(1) if B_zp is not None else 0, - block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, + block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0, group_size=block_shape[1], MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k, @@ -599,8 +599,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, sorted_token_ids, expert_ids, num_tokens_post_padded, - B.shape[1], - B.shape[2], + B.size(1), + B.size(2), EM, num_tokens, A.stride(0), @@ -818,7 +818,7 @@ def try_get_optimal_moe_config( M: int, is_marlin: bool = False, block_shape: Optional[list[int]] = None, -): +) -> dict[str, int]: from vllm.model_executor.layers.fused_moe import get_config override_config = get_config() if override_config: @@ -873,10 +873,10 @@ def fused_topk( renormalize: bool, indices_type: Optional[torch.dtype] = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - assert hidden_states.shape[0] == gating_output.shape[0], ( + assert hidden_states.size(0) == gating_output.size(0), ( "Number of tokens mismatch") - M, _ = hidden_states.shape + M, _ = hidden_states.size() topk_weights = torch.empty(M, topk, @@ -915,7 +915,7 @@ def grouped_topk( e_score_correction_bias: Optional[torch.Tensor] = None ) -> tuple[torch.Tensor, torch.Tensor]: - assert hidden_states.shape[0] == gating_output.shape[0], ( + assert hidden_states.size(0) == gating_output.size(0), ( "Number of tokens mismatch") if scoring_func == "softmax": @@ -925,7 +925,7 @@ def grouped_topk( else: raise ValueError(f"Unsupported scoring function: {scoring_func}") - num_token = scores.shape[0] + num_token = scores.size(0) if e_score_correction_bias is not None: # Store original scores before applying correction bias. 
We use biased # scores for expert selection but original scores for routing weights @@ -942,7 +942,7 @@ def grouped_topk( group_mask.scatter_(1, group_idx, 1) # [n, n_group] score_mask = group_mask.unsqueeze(-1).expand( num_token, num_expert_group, - scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] + scores.size(-1) // num_expert_group).reshape(num_token, -1) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] @@ -1162,7 +1162,7 @@ def fused_experts(hidden_states: torch.Tensor, allow_deep_gemm: bool = False) -> torch.Tensor: # For now, disable DeepGemm for small N (<= 512) until better # permute/unpermute ops are available. - N = w1.shape[1] + N = w1.size(1) if (allow_deep_gemm and use_fp8_w8a8 and N > 512 and _valid_deep_gemm(hidden_states, w1, w2)): assert apply_router_weight_on_input is False @@ -1233,13 +1233,13 @@ def fused_experts_impl( ) -> torch.Tensor: # Check constraints. if use_int4_w4a16: - assert hidden_states.shape[1] // 2 == w1.shape[ - 2], "Hidden size mismatch" + assert hidden_states.size(1) // 2 == w1.size(2), ( + "Hidden size mismatch") else: - assert hidden_states.shape[1] == w1.shape[2], ( - f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}") + assert hidden_states.size(1) == w1.size(2), ( + f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}") - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert topk_weights.size() == topk_ids.size(), "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.stride(-1) == 1, "Stride of last dimension must be 1" assert w2.stride(-1) == 1, "Stride of last dimension must be 1" @@ -1247,12 +1247,12 @@ def fused_experts_impl( torch.float32, torch.float16, torch.bfloat16 ] - num_tokens = hidden_states.shape[0] - E, N, _ = w1.shape - K = w2.shape[1] + num_tokens = hidden_states.size(0) + E, N, _ = w1.size() + K = w2.size(1) if global_num_experts == -1: global_num_experts = E - top_k_num = topk_ids.shape[1] + top_k_num = topk_ids.size(1) # We execute the fused_moe kernel in chunks to circumvent this issue: # https://github.com/vllm-project/vllm/issues/5938 CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE @@ -1269,8 +1269,8 @@ def fused_experts_impl( get_config_func = functools.partial( try_get_optimal_moe_config, - w1.shape, - w2.shape, + w1.size(), + w2.size(), top_k_num, config_dtype, block_shape=block_shape, @@ -1310,7 +1310,7 @@ def fused_experts_impl( min((chunk + 1) * CHUNK_SIZE, num_tokens)) curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] - tokens_in_chunk, _ = curr_hidden_states.shape + tokens_in_chunk, _ = curr_hidden_states.size() if tokens_in_chunk == 0: break @@ -1322,7 +1322,7 @@ def fused_experts_impl( # do not need to be adjusted. 
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * - topk_ids.shape[1]] + topk_ids.size(1)] intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] config = get_config_func(tokens_in_chunk) @@ -1398,7 +1398,7 @@ def fused_experts_impl( per_channel_quant=per_channel_quant, block_shape=block_shape) - ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), + ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), out_hidden_states[begin_chunk_idx:end_chunk_idx]) return out_hidden_states @@ -1611,8 +1611,8 @@ def apply( dtype=hidden_states.dtype) config = try_get_optimal_moe_config( - w1.shape, - w2.shape, + w1.size(), + w2.size(), top_k_num, config_dtype, num_tokens, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 1fd8f2175886..c1bae033c2b4 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -45,7 +45,8 @@ from .pplx_prepare_finalize import PplxPrepareAndFinalize if has_deepep: from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize - from .deepep_ll_prepare_finalize import DeepEPLLPrepareAndFinalize + from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE, + DeepEPLLPrepareAndFinalize) else: fused_experts = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore @@ -377,6 +378,13 @@ def init_prepare_finalize(self, moe: MoEConfig, all2all_manager.world_size) handle = all2all_manager.get_handle(all_to_all_args) + # Note : We may want to use FP8 dispatch even otherwise just to + # reduce datamovement + assert act_quant_block_size is not None + use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype() + and act_quant_block_size[1] + == DEEPEP_QUANT_BLOCK_SIZE) + # Note (varun): Whether to use FP8 dispatch or not needs some # profiling. Turning it off for now. 
prepare_finalize = DeepEPLLPrepareAndFinalize( @@ -386,7 +394,7 @@ def init_prepare_finalize(self, moe: MoEConfig, max_tokens_per_rank=moe.max_num_tokens, quant_dtype=quant_dtype, block_shape=act_quant_block_size, - use_fp8_dispatch=False, + use_fp8_dispatch=use_fp8_dispatch, ) self.topk_indices_dtype = None @@ -853,13 +861,11 @@ def __init__( self.global_num_experts = num_experts # For smuggling this layer into the fused moe custom op - self.use_direct_call = self.dp_size == 1 - if not self.use_direct_call: - compilation_config = vllm_config.compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError("Duplicate layer name: {}".format(prefix)) - compilation_config.static_forward_context[prefix] = self - self.layer_name = prefix + compilation_config = vllm_config.compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError("Duplicate layer name: {}".format(prefix)) + compilation_config.static_forward_context[prefix] = self + self.layer_name = prefix # Determine expert maps if self.use_ep: @@ -1353,11 +1359,8 @@ def maybe_all_reduce_tensor_model_parallel( def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): - if self.use_direct_call: - return self.forward_impl(hidden_states, router_logits) - else: - return torch.ops.vllm.moe_forward(hidden_states, router_logits, - self.layer_name) + return torch.ops.vllm.moe_forward(hidden_states, router_logits, + self.layer_name) def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_router_logits: torch.Tensor): diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py index f9451ca2fde4..ceb96add0fde 100644 --- a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py +++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py @@ -6,11 +6,7 @@ from vllm import _custom_ops as ops from vllm.triton_utils import tl, triton -from vllm.utils import round_up - - -def ceil_div(a, b): - return (a + b - 1) // b +from vllm.utils import cdiv, round_up @triton.jit @@ -115,7 +111,7 @@ def moe_align_block_size_triton( cumsum = torch.zeros((num_experts + 1, ), dtype=torch.int32, device=topk_ids.device) - tokens_per_thread = ceil_div(numel, num_experts) + tokens_per_thread = cdiv(numel, num_experts) moe_align_block_size_stage1[grid]( topk_ids, diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 5bc01dbf2025..2ff8ef99b2ec 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -69,7 +69,7 @@ def prepare( a1 = a1 * rank_topk_weights.to(a1.dtype) repeat_cols = 4 - repeat_rows = 1 if self.per_act_token else a1.shape[0] + repeat_rows = 1 if self.per_act_token else a1.size(0) a1q, a1q_scale = moe_kernel_quantize_input( a1, (None if self.per_act_token else a1_scale), self.quant_dtype, self.per_act_token, self.block_shape) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index eb2148d76452..8a33cd6be405 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -239,25 +239,24 @@ def extract_states( prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) prompt_token_ids = self.get_prompt_token_ids(pooling_metadata) - pooled_data: list[torch.Tensor] = [] - + pooled_data_lst = list[torch.Tensor]() if 
isinstance(hidden_states, list): for req_state, prompt_len in zip(hidden_states, prompt_lens): assert prompt_len == req_state.shape[0], \ - "partial prefill not supported with mean pooling" - pooled_data = hidden_states + "partial prefill not supported with step pooling" + pooled_data_lst = hidden_states else: offset = 0 for prompt_len in prompt_lens: pooled_data_i = hidden_states[offset:offset + prompt_len] offset += prompt_len - pooled_data.append(pooled_data_i) + pooled_data_lst.append(pooled_data_i) - pooled_data = [] + pooled_data = list[torch.Tensor]() returned_token_ids = self.returned_token_ids step_tag_id = self.step_tag_id - for data, token_id in zip(pooled_data, prompt_token_ids): + for data, token_id in zip(pooled_data_lst, prompt_token_ids): if returned_token_ids is not None and len(returned_token_ids) > 0: data = data[:, returned_token_ids] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index e5702c871cc9..d21abb2741a2 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -13,6 +13,7 @@ QuantizationType) from pydantic import BaseModel +import vllm.envs as envs from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, @@ -374,7 +375,8 @@ def _get_scheme_from_parts( if is_activation_quantization_format(self.quant_format): if self._is_fp4a4_nvfp4(weight_quant, input_quant): - if CompressedTensorsW4A4Fp4.cutlass_fp4_supported(): + if CompressedTensorsW4A4Fp4.cutlass_fp4_supported( + ) or envs.VLLM_USE_NVFP4_CT_EMULATIONS: return CompressedTensorsW4A4Fp4() else: logger.warning_once( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 32718972a627..ec1d4a6c0efa 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -4,11 +4,14 @@ import torch from torch.nn.parameter import Parameter +import vllm.envs as envs from vllm._custom_ops import (cutlass_scaled_fp4_mm, cutlass_scaled_mm_supports_fp4, scaled_fp4_quant) from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 + run_nvfp4_emulations) from vllm.model_executor.parameter import (GroupQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) @@ -26,6 +29,8 @@ def __init__(self): @classmethod def get_min_capability(cls) -> int: + if envs.VLLM_USE_NVFP4_CT_EMULATIONS: + return 80 return 100 @classmethod @@ -129,6 +134,17 @@ def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + if envs.VLLM_USE_NVFP4_CT_EMULATIONS: + out = run_nvfp4_emulations( + x=x, + input_global_scale=layer.input_global_scale, + weight=layer.weight, + weight_scale_swizzled=layer.weight_scale_swizzled, + weight_global_scale=layer.weight_global_scale) + if bias is not None: + out = out + bias + return out + output_dtype = x.dtype 
output_shape = [x.shape[0], layer.weight.shape[0]] diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index f92ebdea986d..e9b8dc3266b4 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from copy import deepcopy from typing import Any, Callable, Optional, Union import torch @@ -9,7 +10,8 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, + UnquantizedFusedMoEMethod) from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -19,7 +21,7 @@ MPLinearLayerConfig, choose_mp_linear_kernel) from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.gptq_utils import ( - get_linear_quant_method) + get_dynamic_override, get_linear_quant_method, override_config) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_marlin_supported, check_moe_marlin_supports_layer, marlin_make_workspace_new, marlin_moe_permute_scales, @@ -35,6 +37,29 @@ logger = init_logger(__name__) +def get_moe_quant_method( + config: QuantizationConfig, + layer: torch.nn.Module, + prefix: str, + moe_method_cls: type, +): + cloned_config = deepcopy(config) + + if isinstance(layer, FusedMoE): + # False = skip module, None = no override, else = Positive match + if get_dynamic_override( # noqa: E712 + cloned_config, # noqa: E712 + layer_name=prefix) == False: # noqa: E712 + return UnquantizedFusedMoEMethod(layer.moe_config) + + if prefix: + # Dynamic per module/layer rules may override base config + override_config(cloned_config, prefix=prefix) + + return moe_method_cls(cloned_config) + return None + + class GPTQMarlinConfig(QuantizationConfig): """Config class for GPTQ Marlin""" @@ -163,7 +188,8 @@ def get_quant_method(self, layer: torch.nn.Module, "Falling back to Moe WNA16 kernels.") return MoeWNA16Config.from_config( self.full_config).get_quant_method(layer, prefix) - return GPTQMarlinMoEMethod(self) + return get_moe_quant_method(self, layer, prefix, + GPTQMarlinMoEMethod) return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod) diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py index 9c909a3a430c..a4e0356c0268 100644 --- a/vllm/model_executor/layers/quantization/torchao.py +++ b/vllm/model_executor/layers/quantization/torchao.py @@ -44,14 +44,14 @@ def __init__(self, """ # TorchAO quantization relies on tensor subclasses. 
In order, # to enable proper caching this needs standalone compile - if is_torch_equal_or_newer("2.8.0a"): + if is_torch_equal_or_newer("2.8.0.dev"): os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1" logger.info( "Using TorchAO: Setting VLLM_TEST_STANDALONE_COMPILE=1") # TODO: remove after the torch dependency is updated to 2.8 if is_torch_equal_or_newer( - "2.7.0") and not is_torch_equal_or_newer("2.8.0a"): + "2.7.0") and not is_torch_equal_or_newer("2.8.0.dev"): os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1" logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1") """ diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 754650ebeffb..3a0fb83d627a 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -19,7 +19,7 @@ CUTLASS_BLOCK_FP8_SUPPORTED) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import direct_register_custom_op +from vllm.utils import cdiv, direct_register_custom_op logger = init_logger(__name__) has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None @@ -158,12 +158,9 @@ def apply_w8a8_block_fp8_linear( if current_platform.is_cuda(): if current_platform.has_device_capability(100): - def ceil_div(x: int, y: int) -> int: - return (x + y - 1) // y - use_cutlass = cutlass_block_fp8_supported and ( - ceil_div(weight.shape[0], 128) == weight_scale.shape[0] - and ceil_div(weight.shape[1], 128) == weight_scale.shape[1]) + cdiv(weight.shape[0], 128) == weight_scale.shape[0] + and cdiv(weight.shape[1], 128) == weight_scale.shape[1]) else: # TODO: update this after switching to public sm90 block scale gemm # as it also supports weight.shape % 128 != 0 diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 9de2338968a1..b7bb2affc4fa 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -26,6 +26,7 @@ import math from typing import Any, Optional, Union +import numpy as np import torch import torch.nn as nn from transformers import PretrainedConfig @@ -1458,15 +1459,14 @@ def get_next_input_positions( ] @staticmethod - def get_next_input_positions_tensor( - mrope_position_delta: int, - context_len: int, - seq_len: int, - ) -> torch.Tensor: - return torch.arange( - mrope_position_delta + context_len, - mrope_position_delta + seq_len, - ).expand(3, -1) + def get_next_input_positions_tensor(out: np.ndarray, out_offset: int, + mrope_position_delta: int, + context_len: int, num_new_tokens: int): + + values = np.arange(mrope_position_delta + context_len, + mrope_position_delta + context_len + num_new_tokens, + dtype=out.dtype) + out[:, out_offset:out_offset + num_new_tokens] = values @classmethod def omni_get_updates_use_audio_in_video( diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 6e6e74b0d1d9..911f0036c2dd 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -52,11 +52,6 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.eh_proj = 
nn.Linear(config.hidden_size * 2,
@@ -74,8 +69,6 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_index: int = 0,
     ) -> torch.Tensor:
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
         assert inputs_embeds is not None
         # masking inputs at position 0, as not needed by MTP
         inputs_embeds[positions == 0] = 0
@@ -112,7 +105,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             for idx in range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers)
         })
-
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
         self.logits_processor = LogitsProcessor(config.vocab_size)

     def forward(
@@ -123,6 +119,8 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
         current_step_idx = (spec_step_idx % self.num_mtp_layers)
         return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
             input_ids,
@@ -242,6 +240,12 @@ def load_weights(self, weights: Iterable[tuple[str,
                 if name.endswith(".bias") and name not in params_dict:
                     continue

+                # According to the DeepSeek-V3 Technical Report, MTP modules
+                # share the embedding layer. We only load the first weights.
+                if (spec_layer != self.model.mtp_start_layer_idx
+                        and ".layers" not in name):
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -253,17 +257,25 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
         """
         Rewrite the weight name to match the format of the original model.
         Add .mtp_block for modules in transformer layer block for spec layer
+        and rename shared layer weights to be top level.
         """
         spec_layer_weight_names = [
             "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
         ]
+        shared_weight_names = ["embed_tokens"]
         spec_layer_weight = False
+        shared_weight = False
         for weight_name in spec_layer_weight_names:
             if weight_name in name:
                 spec_layer_weight = True
+                if weight_name in shared_weight_names:
+                    shared_weight = True
                 break
         if not spec_layer_weight:
             # treat rest weights as weights for transformer layer block
             name = name.replace(f"model.layers.{spec_layer}.",
                                 f"model.layers.{spec_layer}.mtp_block.")
+        elif shared_weight:
+            # treat shared weights as top level weights
+            name = name.replace(f"model.layers.{spec_layer}.", "model.")
         return name
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index 619d2aa67491..3a1c14978b45 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -479,7 +479,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
             "model.vision_tower.": "vision_tower.",
             "model.multi_modal_projector.": "multi_modal_projector.",
             "lm_head.": "language_model.lm_head.",
-            "vision_tower.vision_model.": "vision_model.",
         })

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index 0e7e4e73eca9..f759f8f1f273 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -489,6 +489,12 @@ def supports_cross_encoding(
     return is_pooling_model(model) and _supports_cross_encoding(model)


+def has_step_pooler(model: Union[type[object], object]) -> bool:
+    """Check if the model uses step pooler."""
+    return is_pooling_model(model) and any(
+        type(module).__name__ == "StepPool" for module in model.modules())
+
+
 class SupportsQuant:
     """The interface required for all models that support quantization."""

diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
index 9fb73261cd89..0c9baab1f2e4 100644
--- a/vllm/model_executor/models/llama4.py
+++ b/vllm/model_executor/models/llama4.py
@@ -148,9 +148,8 @@ def __init__(self,
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
-        # TODO: attn_temperature_tuning should be a bool in huggingface
         self.attn_temperature_tuning = self.nope and \
-            config.attn_temperature_tuning > 0
+            config.attn_temperature_tuning

         self.floor_scale = getattr(config, "floor_scale", 8192.0)
         self.attn_scale = getattr(config, "attn_scale", 0.1)
diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py
index 7c1f889e8f38..9d619b38d38d 100644
--- a/vllm/model_executor/models/modernbert.py
+++ b/vllm/model_executor/models/modernbert.py
@@ -258,6 +258,7 @@ def __init__(self, config: ModernBertConfig):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size,
                                config.classifier_bias)
+        self.pooling_type = config.classifier_pooling
         self.act = nn.GELU()
         self.norm = nn.LayerNorm(config.hidden_size,
                                  eps=config.norm_eps,
@@ -265,7 +266,13 @@ def __init__(self, config: ModernBertConfig):

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         pooled_output = hidden_states
-        pooled_output = pooled_output.mean(dim=0, keepdim=False)
+        if self.pooling_type == "mean":
+            pooled_output = pooled_output.mean(dim=0, keepdim=False)
+        elif self.pooling_type == "cls":
+            pooled_output = pooled_output[0, :]
+        else:
+            raise ValueError("Pooling type should be either `cls` or `mean`, "
+                             f"but got {self.pooling_type}")
         pooled_output = self.norm(self.act(self.dense(pooled_output)))
         return pooled_output

diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py
index e673632d4366..dce4c4c1cadb 100644
--- a/vllm/multimodal/image.py
+++ b/vllm/multimodal/image.py
@@ -1,10 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import base64
 from io import BytesIO
 from pathlib import Path

+import pybase64
 import torch
 from PIL import Image

@@ -55,7 +55,7 @@ def load_bytes(self, data: bytes) -> Image.Image:
         return convert_image_mode(image, self.image_mode)

     def load_base64(self, media_type: str, data: str) -> Image.Image:
-        return self.load_bytes(base64.b64decode(data))
+        return self.load_bytes(pybase64.b64decode(data, validate=True))

     def load_file(self, filepath: Path) -> Image.Image:
         image = Image.open(filepath)
@@ -75,7 +75,7 @@ def encode_base64(
             image.save(buffer, image_format)
             data = buffer.getvalue()

-        return base64.b64encode(data).decode('utf-8')
+        return pybase64.b64encode(data).decode('utf-8')


 class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
@@ -88,10 +88,10 @@ def load_bytes(self, data: bytes) -> torch.Tensor:
         return torch.load(buffer, weights_only=True)

     def load_base64(self, media_type: str, data: str) -> torch.Tensor:
-        return self.load_bytes(base64.b64decode(data))
+        return self.load_bytes(pybase64.b64decode(data, validate=True))

     def load_file(self, filepath: Path) -> torch.Tensor:
         return torch.load(filepath, weights_only=True)

     def encode_base64(self, media: torch.Tensor) -> str:
-        return base64.b64encode(media.numpy()).decode('utf-8')
+        return pybase64.b64encode(media.numpy()).decode('utf-8')
diff --git a/vllm/utils.py b/vllm/utils.py
index 34be4d52c483..fdefda901c4d 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -2919,8 +2919,13 @@ def is_torch_equal_or_newer(target: str) -> bool:
         Whether the condition meets.
     """
     try:
-        torch_version = version.parse(str(torch.__version__))
-        return torch_version >= version.parse(target)
+        return _is_torch_equal_or_newer(str(torch.__version__), target)
     except Exception:
         # Fallback to PKG-INFO to load the package info, needed by the doc gen.
         return Version(importlib.metadata.version('torch')) >= Version(target)
+
+
+# Helper function used in testing.
+def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool:
+    torch_version = version.parse(torch_version)
+    return torch_version >= version.parse(target)
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 4ad7178374b1..ef65d2ea36e4 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -158,12 +158,13 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
         self.aot_schedule = (get_flash_attn_version() == 3)
         self.use_full_cuda_graph = compilation_config.full_cuda_graph

-        if self.use_full_cuda_graph and not self.aot_schedule:
-            raise ValueError("Full CUDA graph mode requires AOT scheduling, "
-                             "which requires FlashAttention 3.")
-        self.scheduler_metadata = torch.zeros(self.runner.max_num_reqs + 1,
-                                              dtype=torch.int32,
-                                              device=self.runner.device)
+        if self.use_full_cuda_graph:
+            # NOTE(lucas): AOT scheduling not supported in full cuda graph mode
+            # yet. This is because the scheduler and kernel need to always use
+            # the same num_splits (which acts as an upper bound with the
+            # dynamic split scheduler) which is currently heuristically decided
+            # by the kernel launching code.
+            self.aot_schedule = False

         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
@@ -299,18 +300,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                                       max_seq_len=max_seq_len,
                                       causal=True)

-        if self.use_full_cuda_graph:
-            assert scheduler_metadata is not None
-            n = scheduler_metadata.shape[0]
-            self.scheduler_metadata[:n].copy_(scheduler_metadata,
-                                              non_blocking=True)
-            # NOTE(woosuk): We should zero out the rest of the scheduler
-            # metadata to guarantee the correctness. Otherwise, some thread
-            # blocks may use the invalid scheduler metadata and overwrite the
-            # output buffer.
-            self.scheduler_metadata[n:] = 0
-            scheduler_metadata = self.scheduler_metadata[:n]
-
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index da65550354d0..453ed364dc81 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -877,12 +877,16 @@ def run_busy_loop(self):
                 local_unfinished_reqs)

             if not self.engines_running:
-                if self.dp_rank == 0:
+                if self.dp_rank == 0 or not self.has_coordinator:
                     # Notify client that we are pausing the loop.
                     logger.debug("Wave %d finished, pausing engine loop.",
                                  self.current_wave)
+                    # In the coordinator case, dp rank 0 sends updates to the
+                    # coordinator. Otherwise (offline spmd case), each rank
+                    # sends the update to its colocated front-end process.
+                    client_index = -1 if self.has_coordinator else 0
                     self.output_queue.put_nowait(
-                        (-1,
+                        (client_index,
                          EngineCoreOutputs(wave_complete=self.current_wave)))
                 self.current_wave += 1
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index 8058cd3127df..856310df5888 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -155,6 +155,11 @@ def collective_rpc(self,
                        kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
         raise NotImplementedError

+    def dp_engines_running(self) -> bool:
+        """Returns True if data parallel engines are collectively in a
+        running state."""
+        raise NotImplementedError
+
     async def get_output_async(self) -> EngineCoreOutputs:
         raise NotImplementedError

@@ -282,6 +287,9 @@ def collective_rpc(self,
                        kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
         return self.engine_core.collective_rpc(method, timeout, args, kwargs)

+    def dp_engines_running(self) -> bool:
+        return False
+

 @dataclass
 class BackgroundResources:
@@ -384,6 +392,9 @@ def __init__(
         dp_size = parallel_config.data_parallel_size
         dp_rank = parallel_config.data_parallel_rank

+        # State used for data parallel.
+        self.engines_running = False
+
         # SPMD mode is where there is an LLM instance per DP rank and
         # one core engine per LLM, see
         # examples/offline_inference/data_parallel.py.
@@ -539,6 +550,9 @@ def free_pending_messages(self):
         while self.pending_messages and self.pending_messages[-1][0].done:
             self.pending_messages.pop()

+    def dp_engines_running(self) -> bool:
+        return self.engines_running
+

 def _process_utility_output(output: UtilityOutput,
                             utility_results: dict[int, AnyFuture]):
@@ -562,6 +576,7 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
             log_stats=log_stats,
         )

+        self.is_dp = self.vllm_config.parallel_config.data_parallel_size > 1
         self.outputs_queue = queue.Queue[Union[EngineCoreOutputs, Exception]]()

         # Ensure that the outputs socket processing thread does not have
@@ -623,6 +638,8 @@ def get_output(self) -> EngineCoreOutputs:
         outputs = self.outputs_queue.get()
         if isinstance(outputs, Exception):
             raise self._format_exception(outputs) from None
+        if outputs.wave_complete is not None:
+            self.engines_running = False
         return outputs

     def _send_input(self, request_type: EngineCoreRequestType, request: Any):
@@ -650,6 +667,8 @@ def call_utility(self, method: str, *args) -> Any:
         return future.result()

     def add_request(self, request: EngineCoreRequest) -> None:
+        if self.is_dp:
+            self.engines_running = True
         self._send_input(EngineCoreRequestType.ADD, request)

     def abort_requests(self, request_ids: list[str]) -> None:
@@ -911,7 +930,6 @@ def __init__(self,
                  client_addresses: Optional[dict[str, str]] = None,
                  client_index: int = 0):
         self.current_wave = 0
-        self.engines_running = False

         # To route aborts to the correct engine.
         self.reqs_in_flight: dict[str, CoreEngine] = {}
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 1932cd10bb1b..25fab2713114 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -160,7 +160,7 @@ def get_num_unfinished_requests(self) -> int:
     def has_unfinished_requests(self) -> bool:
         has_unfinished = self.output_processor.has_unfinished_requests()
         if self.dp_group is None:
-            return has_unfinished
+            return has_unfinished or self.engine_core.dp_engines_running()
         return self.has_unfinished_requests_dp(has_unfinished)

     def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 153b67fe5714..156f5764e8dc 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -148,7 +148,7 @@ def propose(
         assert self.runner is not None

         # FIXME: need to consider multiple kv_cache_groups
-        attn_metadata = self.runner.attn_metadata_builder.build(
+        attn_metadata = self.runner.attn_metadata_builders[0].build(
             common_prefix_len=0,
             common_attn_metadata=common_attn_metadata,
         )
diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py
index 6631c9636eac..370de9f11599 100644
--- a/vllm/v1/worker/cpu_model_runner.py
+++ b/vllm/v1/worker/cpu_model_runner.py
@@ -7,6 +7,7 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models.interfaces import has_step_pooler
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner

 logger = init_logger(__name__)
@@ -52,6 +53,9 @@ def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         self.model = get_model(vllm_config=self.vllm_config)

+        if has_step_pooler(self.model):
+            self.input_batch.logits_processing_needs_token_ids = True
+
         if self.lora_config:
             self.model = self.load_lora_model(self.model, self.model_config,
                                               self.scheduler_config,
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 3a2c9ef7dfac..ca2bfe831746 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -59,14 +59,15 @@ def get_token_id(self, idx: int) -> int:
 class InputBatch:

     def __init__(
-        self,
-        max_num_reqs: int,
-        max_model_len: int,
-        max_num_batched_tokens: int,
-        device: torch.device,
-        pin_memory: bool,
-        vocab_size: int,
-        block_sizes: list[int],  # The block_size of each kv cache group
+        self,
+        max_num_reqs: int,
+        max_model_len: int,
+        max_num_batched_tokens: int,
+        device: torch.device,
+        pin_memory: bool,
+        vocab_size: int,
+        block_sizes: list[int],  # The block_size of each kv cache group
+        logits_processing_needs_token_ids: bool = False,
     ):
         self.max_num_reqs = max_num_reqs
         self.max_model_len = max_model_len
@@ -74,6 +75,8 @@ def __init__(
         self.device = device
         self.pin_memory = pin_memory
         self.vocab_size = vocab_size
+        self.logits_processing_needs_token_ids = (
+            logits_processing_needs_token_ids)

         self._req_ids: list[Optional[str]] = []
         self.req_id_to_index: dict[str, int] = {}
@@ -579,9 +582,14 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
             copy_slice(self.repetition_penalties_cpu_tensor,
                        self.repetition_penalties, num_reqs)

-            # The prompt tokens are used only for applying penalties during
-            # the sampling process. Hence copy these tensors only when
-            # there are requests which need penalties to be applied.
+        needs_prompt_token_ids = (not self.no_penalties or
+                                  (self.num_reqs > 0
+                                   and self.logits_processing_needs_token_ids))
+        if needs_prompt_token_ids:
+            # The prompt tokens are used only for applying penalties or
+            # step pooling during the sampling/pooling process.
+            # Hence copy these tensors only when there are requests which
+            # need penalties/step_pooler to be applied.
             prompt_token_ids = self._make_prompt_token_ids_tensor()
         else:
             prompt_token_ids = None
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 330366006118..40639fdf2433 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -33,6 +33,7 @@
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
+from vllm.model_executor.models.interfaces import has_step_pooler
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
@@ -261,6 +262,7 @@ def __init__(
                                                    dtype=torch.int64,
                                                    device="cpu",
                                                    pin_memory=self.pin_memory)
+            self.mrope_positions_np = self.mrope_positions_cpu.numpy()

         # Only relevant for models using ALiBi (e.g, MPT)
         self.use_alibi = check_use_alibi(model_config)
@@ -888,15 +890,13 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
                 dst_start = mrope_pos_ptr
                 dst_end = mrope_pos_ptr + completion_part_len

-                self.mrope_positions_cpu[:, dst_start:dst_end] = \
-                    MRotaryEmbedding.get_next_input_positions_tensor(
-                        req.mrope_position_delta,
-                        context_len=num_computed_tokens +
-                        prompt_part_len,
-                        seq_len=num_computed_tokens +
-                        prompt_part_len +
-                        completion_part_len,
-                    )
+                MRotaryEmbedding.get_next_input_positions_tensor(
+                    out=self.mrope_positions_np,
+                    out_offset=dst_start,
+                    mrope_position_delta=req.mrope_position_delta,
+                    context_len=num_computed_tokens + prompt_part_len,
+                    num_new_tokens=completion_part_len,
+                )

                 mrope_pos_ptr += completion_part_len
@@ -1708,6 +1708,8 @@ def load_model(self) -> None:
             )
             model_loader.load_weights(self.model,
                                       model_config=self.model_config)
+        if has_step_pooler(self.model):
+            self.input_batch.logits_processing_needs_token_ids = True
         if self.lora_config:
             self.model = self.load_lora_model(self.model,
                                               self.model_config,