diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
index a2a5c2a02cbb..90cad506ab1e 100755
--- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
@@ -159,6 +159,8 @@ run_and_track_test 14 "test_tpu_qkv_linear.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py"
run_and_track_test 15 "test_spmd_model_weight_loading.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
+run_and_track_test 16 "test_kv_cache_update_kernel.py" \
+ "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
# After all tests have been attempted, exit with the overall status.
if [ "$overall_script_exit_code" -ne 0 ]; then
diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh
index f54010c4231f..827649bfcf54 100644
--- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh
@@ -28,4 +28,5 @@ docker run \
sh -c '
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
+ VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
'
diff --git a/.buildkite/scripts/tpu/docker_run_bm.sh b/.buildkite/scripts/tpu/docker_run_bm.sh
index 6705da03e3d7..715afce5f71a 100755
--- a/.buildkite/scripts/tpu/docker_run_bm.sh
+++ b/.buildkite/scripts/tpu/docker_run_bm.sh
@@ -68,7 +68,7 @@ docker run \
echo "run script..."
echo
-docker exec "$CONTAINER_NAME" /bin/bash -c ".buildkite/scripts/hardware_ci/run_bm.sh"
+docker exec "$CONTAINER_NAME" /bin/bash -c ".buildkite/scripts/tpu/run_bm.sh"
echo "copy result back..."
VLLM_LOG="$LOG_ROOT/$TEST_NAME"_vllm_log.txt
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index d6c9ee680abf..a13e2cb78218 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -41,6 +41,16 @@ steps:
# TODO: add `--strict` once warnings in docstrings are fixed
- mkdocs build
+- label: Pytorch Nightly Dependency Override Check # 2min
+ # if this test fails, it means the nightly torch version is not compatible with some
+ # of the dependencies. Please check the error message and add the package to the whitelist
+ # in /vllm/tools/generate_nightly_torch_test.py
+ soft_fail: true
+ source_file_dependencies:
+ - requirements/nightly_torch_test.txt
+ commands:
+ - bash standalone_tests/pytorch_nightly_dependency.sh
+
- label: Async Engine, Inputs, Utils, Worker Test # 24min
mirror_hardwares: [amdexperimental]
source_file_dependencies:
@@ -168,6 +178,23 @@ steps:
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
- popd
+- label: EPLB Algorithm Test
+ working_dir: "/vllm-workspace/tests"
+ source_file_dependencies:
+ - vllm/distributed/eplb
+ - tests/distributed/test_eplb_algo.py
+ commands:
+ - pytest -v -s distributed/test_eplb_algo.py
+
+- label: EPLB Execution Test # 5min
+ working_dir: "/vllm-workspace/tests"
+ num_gpus: 4
+ source_file_dependencies:
+ - vllm/distributed/eplb
+ - tests/distributed/test_eplb_execute.py
+ commands:
+ - pytest -v -s distributed/test_eplb_execute.py
+
- label: Metrics, Tracing Test # 10min
mirror_hardwares: [amdexperimental, amdproduction]
num_gpus: 2
@@ -509,6 +536,17 @@ steps:
- pip freeze | grep -E 'torch'
- pytest -v -s models/language -m core_model
+- label: Language Models Test (Hybrid) # 35 min
+ mirror_hardwares: [amdexperimental]
+ torch_nightly: true
+ source_file_dependencies:
+ - vllm/
+ - tests/models/language/generation
+ commands:
+ # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
+ - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
+ - pytest -v -s models/language/generation -m hybrid_model
+
- label: Language Models Test (Extended Generation) # 1hr20min
mirror_hardwares: [amdexperimental]
optional: true
@@ -518,7 +556,7 @@ steps:
commands:
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
- pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
- - pytest -v -s models/language/generation -m 'not core_model'
+ - pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)'
- label: Language Models Test (Extended Pooling) # 36min
mirror_hardwares: [amdexperimental]
@@ -619,11 +657,13 @@ steps:
commands:
- # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
+ - NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
- python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
- # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
+ - NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
- python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
- label: Distributed Tests (2 GPUs) # 40min
@@ -748,7 +788,7 @@ steps:
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt
- label: Weight Loading Multiple GPU Test - Large Models # optional
- mirror_hardwares: [amdexperimental]
+ mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests"
num_gpus: 2
gpu: a100
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e62b623b4e11..15ef5defff69 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -53,6 +53,11 @@ repos:
files: ^requirements/test\.(in|txt)$
- repo: local
hooks:
+ - id: format-torch-nightly-test
+ name: reformat nightly_torch_test.txt to be in sync with test.in
+ language: python
+ entry: python tools/generate_nightly_torch_test.py
+ files: ^requirements/test\.(in|txt)$
- id: mypy-local
name: Run mypy for local Python installation
entry: tools/mypy.sh 0 "local"
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 19f3f7542a7e..927500fe5365 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -513,6 +513,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4=1")
+ list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
@@ -547,8 +548,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# if it's possible to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
- set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
- "csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
+ set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
@@ -562,7 +562,27 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"if you intend on running FP8 quantized MoE models on Hopper.")
else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
- "in CUDA target architectures")
+ "in CUDA target architectures.")
+ endif()
+ endif()
+
+ # moe_data.cu is used by all CUTLASS MoE kernels.
+ cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
+ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
+ set(SRCS "csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
+ set_gencode_flags_for_srcs(
+ SRCS "${SRCS}"
+ CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
+ list(APPEND VLLM_EXT_SRC "${SRCS}")
+ message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
+ else()
+ if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
+ message(STATUS "Not building moe_data as CUDA Compiler version is "
+ "not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
+ "if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
+ else()
+ message(STATUS "Not building moe_data as no compatible archs found "
+ "in CUDA target architectures.")
endif()
endif()
@@ -638,6 +658,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# if CUDA endif
endif()
+if (VLLM_GPU_LANG STREQUAL "HIP")
+ # Add QuickReduce kernels
+ list(APPEND VLLM_EXT_SRC
+ "csrc/custom_quickreduce.cu"
+ )
+# if ROCM endif
+endif()
+
message(STATUS "Enabling C extension.")
define_gpu_extension_target(
_C
diff --git a/benchmarks/README.md b/benchmarks/README.md
index 2714b8b49821..fb8690d42db9 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -4,7 +4,7 @@ This README guides you through running benchmark tests with the extensive
datasets supported on vLLM. It's a living document, updated as new features and datasets
become available.
-## Dataset Overview
+**Dataset Overview**
@@ -82,7 +82,10 @@ become available.
**Note**: HuggingFace dataset's `dataset-name` should be set to `hf`
---
-## Example - Online Benchmark
+<details>
+<summary>Example - Online Benchmark</summary>
+
+
First start serving your model
@@ -130,7 +133,8 @@ P99 ITL (ms): 8.39
==================================================
```
-### Custom Dataset
+**Custom Dataset**
+
If the dataset you want to benchmark is not yet supported in vLLM, you can still benchmark it using `CustomDataset`. Your data needs to be in `.jsonl` format and needs to have a "prompt" field per entry, e.g., data.jsonl
```
@@ -162,7 +166,7 @@ python3 benchmarks/benchmark_serving.py --port 9001 --save-result --save-detaile
You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`.
-### VisionArena Benchmark for Vision Language Models
+**VisionArena Benchmark for Vision Language Models**
```bash
# need a model with vision capability here
@@ -180,7 +184,7 @@ python3 vllm/benchmarks/benchmark_serving.py \
--num-prompts 1000
```
-### InstructCoder Benchmark with Speculative Decoding
+**InstructCoder Benchmark with Speculative Decoding**
``` bash
VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
@@ -197,7 +201,7 @@ python3 benchmarks/benchmark_serving.py \
--num-prompts 2048
```
-### Other HuggingFaceDataset Examples
+**Other HuggingFaceDataset Examples**
```bash
vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests
@@ -251,7 +255,7 @@ python3 vllm/benchmarks/benchmark_serving.py \
--num-prompts 80
```
-### Running With Sampling Parameters
+**Running With Sampling Parameters**
When using OpenAI-compatible backends such as `vllm`, optional sampling
parameters can be specified. Example client command:
@@ -269,7 +273,7 @@ python3 vllm/benchmarks/benchmark_serving.py \
--num-prompts 10
```
-### Running With Ramp-Up Request Rate
+**Running With Ramp-Up Request Rate**
The benchmark tool also supports ramping up the request rate over the
duration of the benchmark run. This can be useful for stress testing the
@@ -284,8 +288,12 @@ The following arguments can be used to control the ramp-up:
- `--ramp-up-start-rps`: The request rate at the beginning of the benchmark.
- `--ramp-up-end-rps`: The request rate at the end of the benchmark.
----
-## Example - Offline Throughput Benchmark
+</details>
+
+<details>
+<summary>Example - Offline Throughput Benchmark</summary>
+
+
```bash
python3 vllm/benchmarks/benchmark_throughput.py \
@@ -303,7 +311,7 @@ Total num prompt tokens: 5014
Total num output tokens: 1500
```
-### VisionArena Benchmark for Vision Language Models
+**VisionArena Benchmark for Vision Language Models**
``` bash
python3 vllm/benchmarks/benchmark_throughput.py \
@@ -323,7 +331,7 @@ Total num prompt tokens: 14527
Total num output tokens: 1280
```
-### InstructCoder Benchmark with Speculative Decoding
+**InstructCoder Benchmark with Speculative Decoding**
``` bash
VLLM_WORKER_MULTIPROC_METHOD=spawn \
@@ -347,7 +355,7 @@ Total num prompt tokens: 261136
Total num output tokens: 204800
```
-### Other HuggingFaceDataset Examples
+**Other HuggingFaceDataset Examples**
**`lmms-lab/LLaVA-OneVision-Data`**
@@ -386,7 +394,7 @@ python3 benchmarks/benchmark_throughput.py \
--num-prompts 10
```
-### Benchmark with LoRA Adapters
+**Benchmark with LoRA Adapters**
``` bash
# download dataset
@@ -403,18 +411,22 @@ python3 vllm/benchmarks/benchmark_throughput.py \
--lora-path yard1/llama-2-7b-sql-lora-test
```
----
-## Example - Structured Output Benchmark
+</details>
+
+<details>
+<summary>Example - Structured Output Benchmark</summary>
+
+
Benchmark the performance of structured output generation (JSON, grammar, regex).
-### Server Setup
+**Server Setup**
```bash
vllm serve NousResearch/Hermes-3-Llama-3.1-8B --disable-log-requests
```
-### JSON Schema Benchmark
+**JSON Schema Benchmark**
```bash
python3 benchmarks/benchmark_serving_structured_output.py \
@@ -426,7 +438,7 @@ python3 benchmarks/benchmark_serving_structured_output.py \
--num-prompts 1000
```
-### Grammar-based Generation Benchmark
+**Grammar-based Generation Benchmark**
```bash
python3 benchmarks/benchmark_serving_structured_output.py \
@@ -438,7 +450,7 @@ python3 benchmarks/benchmark_serving_structured_output.py \
--num-prompts 1000
```
-### Regex-based Generation Benchmark
+**Regex-based Generation Benchmark**
```bash
python3 benchmarks/benchmark_serving_structured_output.py \
@@ -449,7 +461,7 @@ python3 benchmarks/benchmark_serving_structured_output.py \
--num-prompts 1000
```
-### Choice-based Generation Benchmark
+**Choice-based Generation Benchmark**
```bash
python3 benchmarks/benchmark_serving_structured_output.py \
@@ -460,7 +472,7 @@ python3 benchmarks/benchmark_serving_structured_output.py \
--num-prompts 1000
```
-### XGrammar Benchmark Dataset
+**XGrammar Benchmark Dataset**
```bash
python3 benchmarks/benchmark_serving_structured_output.py \
@@ -471,12 +483,16 @@ python3 benchmarks/benchmark_serving_structured_output.py \
--num-prompts 1000
```
----
-## Example - Long Document QA Throughput Benchmark
+</details>
+
+<details>
+<summary>Example - Long Document QA Benchmark</summary>
+
+
Benchmark the performance of long document question-answering with prefix caching.
-### Basic Long Document QA Test
+**Basic Long Document QA Test**
```bash
python3 benchmarks/benchmark_long_document_qa_throughput.py \
@@ -488,7 +504,7 @@ python3 benchmarks/benchmark_long_document_qa_throughput.py \
--repeat-count 5
```
-### Different Repeat Modes
+**Different Repeat Modes**
```bash
# Random mode (default) - shuffle prompts randomly
@@ -519,12 +535,16 @@ python3 benchmarks/benchmark_long_document_qa_throughput.py \
--repeat-mode interleave
```
----
-## Example - Prefix Caching Benchmark
+</details>
+
+<details>
+<summary>Example - Prefix Caching Benchmark</summary>
+
+
Benchmark the efficiency of automatic prefix caching.
-### Fixed Prompt with Prefix Caching
+**Fixed Prompt with Prefix Caching**
```bash
python3 benchmarks/benchmark_prefix_caching.py \
@@ -535,7 +555,7 @@ python3 benchmarks/benchmark_prefix_caching.py \
--input-length-range 128:256
```
-### ShareGPT Dataset with Prefix Caching
+**ShareGPT Dataset with Prefix Caching**
```bash
# download dataset
@@ -550,12 +570,16 @@ python3 benchmarks/benchmark_prefix_caching.py \
--input-length-range 128:256
```
----
-## Example - Request Prioritization Benchmark
+</details>
+
+<details>
+<summary>Example - Request Prioritization Benchmark</summary>
+
+
Benchmark the performance of request prioritization in vLLM.
-### Basic Prioritization Test
+**Basic Prioritization Test**
```bash
python3 benchmarks/benchmark_prioritization.py \
@@ -566,7 +590,7 @@ python3 benchmarks/benchmark_prioritization.py \
--scheduling-policy priority
```
-### Multiple Sequences per Prompt
+**Multiple Sequences per Prompt**
```bash
python3 benchmarks/benchmark_prioritization.py \
@@ -577,3 +601,5 @@ python3 benchmarks/benchmark_prioritization.py \
--scheduling-policy priority \
--n 2
```
+
+</details>
diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py
index 8671719bce72..55c0cf851264 100644
--- a/benchmarks/benchmark_dataset.py
+++ b/benchmarks/benchmark_dataset.py
@@ -349,8 +349,9 @@ def sample(
# [1650, 939, 486] -> ['ฤ call', 'sh', 'ere']
# To avoid uncontrolled change of the prompt length,
# the encoded sequence is truncated before being decode again.
+ total_input_len = prefix_len + int(input_lens[i])
re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[
- : input_lens[i]
+ :total_input_len
]
prompt = tokenizer.decode(re_encoded_sequence)
total_input_len = len(re_encoded_sequence)
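The fix above truncates the re-encoded prompt to `prefix_len + input_lens[i]` rather than `input_lens[i]` alone, so the prefix tokens are no longer dropped from the length budget. A minimal sketch of the intended behavior, using a generic Hugging Face tokenizer ("gpt2" here is only an illustrative assumption, not what the benchmark uses):

```python
# Sketch of the corrected truncation (illustration only, not part of the patch).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
prefix_len, input_len = 4, 8
prompt = tokenizer.decode(list(range(50, 50 + prefix_len + input_len)))

re_encoded = tokenizer.encode(prompt, add_special_tokens=False)
total_input_len = prefix_len + input_len   # new: prefix counts toward the budget
re_encoded = re_encoded[:total_input_len]  # old code truncated to input_len only
prompt = tokenizer.decode(re_encoded)
print(len(re_encoded), "tokens kept")
```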
diff --git a/cmake/utils.cmake b/cmake/utils.cmake
index 59c78950a109..621179a70169 100644
--- a/cmake/utils.cmake
+++ b/cmake/utils.cmake
@@ -265,8 +265,8 @@ macro(set_gencode_flags_for_srcs)
endmacro()
#
-# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
-# `.[letter]` compute the "loose intersection" with the
+# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
+# `.[letter]` compute the "loose intersection" with the
# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
@@ -278,7 +278,7 @@ endmacro()
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
# We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
# in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add
-# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
+# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
# The result is stored in `OUT_CUDA_ARCHS`.
#
# Example:
@@ -313,21 +313,16 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
set(_CUDA_ARCHS)
- if ("9.0a" IN_LIST _SRC_CUDA_ARCHS)
- list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a")
- if ("9.0" IN_LIST TGT_CUDA_ARCHS)
- list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0")
- set(_CUDA_ARCHS "9.0a")
- endif()
- endif()
-
- if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
- list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a")
- if ("10.0" IN_LIST TGT_CUDA_ARCHS)
- list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0")
- set(_CUDA_ARCHS "10.0a")
+ foreach(_arch ${_SRC_CUDA_ARCHS})
+ if(_arch MATCHES "\\a$")
+ list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
+ string(REPLACE "a" "" _base "${_arch}")
+ if ("${_base}" IN_LIST TGT_CUDA_ARCHS)
+ list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}")
+ list(APPEND _CUDA_ARCHS "${_arch}")
+ endif()
endif()
- endif()
+ endforeach()
list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
@@ -359,7 +354,7 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
endforeach()
list(REMOVE_DUPLICATES _CUDA_ARCHS)
-
+
# reapply +PTX suffix to architectures that requested PTX
set(_FINAL_ARCHS)
foreach(_arch ${_CUDA_ARCHS})
@@ -370,7 +365,7 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
endif()
endforeach()
set(_CUDA_ARCHS ${_FINAL_ARCHS})
-
+
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
endfunction()
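The rewritten block generalizes the old hard-coded `9.0a`/`10.0a` handling: any source arch with an `a` suffix is moved into the result whenever its base version appears in the target list. A rough Python model of just that step (an illustration of the rule, not the CMake macro itself):

```python
# Rough model of the generalized "a"-suffix step in cuda_archs_loose_intersection.
def a_suffix_pass(src_archs, tgt_archs):
    src, tgt, result = list(src_archs), list(tgt_archs), []
    for arch in list(src):
        if arch.endswith("a"):
            src.remove(arch)
            base = arch[:-1]
            if base in tgt:          # e.g. "9.0a" matches a targeted "9.0"
                tgt.remove(base)
                result.append(arch)
    return result, src, tgt

print(a_suffix_pass(["9.0a", "10.0a"], ["9.0", "10.0"]))
# (['9.0a', '10.0a'], [], [])   -> both family-specific archs are kept
print(a_suffix_pass(["9.0a", "10.0a"], ["8.0", "9.0"]))
# (['9.0a'], [], ['8.0'])       -> 10.0a is dropped since 10.0 is not targeted
```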
diff --git a/csrc/attention/mla/cutlass_mla_kernels.cu b/csrc/attention/mla/cutlass_mla_kernels.cu
index f4b6b19f4b23..9d05d910dd81 100644
--- a/csrc/attention/mla/cutlass_mla_kernels.cu
+++ b/csrc/attention/mla/cutlass_mla_kernels.cu
@@ -207,7 +207,7 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out,
"page_table must be a 32-bit integer tensor");
auto in_dtype = q_nope.dtype();
- at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(q_nope));
const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(q_nope.get_device());
if (in_dtype == at::ScalarType::Half) {
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index 447e826bc1c0..60304d229a8f 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -131,16 +131,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Quantization
#ifdef __AVX512F__
+ at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
- "Tensor? azp) -> ()");
+ "Tensor? azp) -> ()",
+ {stride_tag});
ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
// Compute int8 quantized tensor and scaling factor
ops.def(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
- "Tensor!? azp) -> ()");
+ "Tensor!? azp) -> ()",
+ {stride_tag});
ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
&dynamic_scaled_int8_quant);
// W8A8 GEMM, supporting symmetric per-tensor or per-row/column
@@ -148,7 +151,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
- " Tensor b_scales, Tensor? bias) -> ()");
+ " Tensor b_scales, Tensor? bias) -> ()",
+ {stride_tag});
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);
// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
@@ -156,7 +160,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
- " Tensor? azp, Tensor? bias) -> ()");
+ " Tensor? azp, Tensor? bias) -> ()",
+ {stride_tag});
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#elif defined(__powerpc64__)
// Compute int8 quantized tensor for given scaling factor.
diff --git a/csrc/custom_quickreduce.cu b/csrc/custom_quickreduce.cu
new file mode 100644
index 000000000000..33d0d4a7226e
--- /dev/null
+++ b/csrc/custom_quickreduce.cu
@@ -0,0 +1,114 @@
+#include <ATen/cuda/Exceptions.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <c10/cuda/CUDAStream.h>
+#include <torch/all.h>
+
+#ifdef USE_ROCM
+
+ #include "quickreduce/quick_reduce.h"
+
+quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size,
+ std::optional<int64_t> qr_max_size) {
+ if (world_size > 8)
+ throw std::invalid_argument("world size > 8 is not supported");
+ if (world_size == 6)
+ throw std::invalid_argument("world size == 6 is not supported");
+ if (world_size % 2 != 0)
+ throw std::invalid_argument("Odd num gpus is not supported for now");
+ if (rank < 0 || rank >= world_size)
+ throw std::invalid_argument("invalid rank passed in");
+ quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms();
+ fptr->init(world_size, rank, qr_max_size);
+ return (quickreduce::fptr_t)fptr;
+}
+
+void qr_destroy(quickreduce::fptr_t _fa) {
+ if (_fa) {
+ auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
+ fa->destroy();
+ delete fa;
+ }
+}
+
+torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) {
+ auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
+ hipIpcMemHandle_t handle = fa->get_handle();
+ auto options =
+ torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
+ auto data_handle =
+ torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options);
+ std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t));
+ return data_handle;
+}
+
+void qr_open_handles(quickreduce::fptr_t _fa,
+ const std::vector<torch::Tensor>& handles) {
+ auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
+ std::vector<hipIpcMemHandle_t> ipc_handles;
+ ipc_handles.reserve(handles.size());
+ for (auto& handle : handles) {
+ // Ensure the tensor is on the same device as the current device.
+ hipIpcMemHandle_t ipc_handle;
+ std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t));
+ ipc_handles.push_back(ipc_handle);
+ }
+ fa->open_ipc_handles(ipc_handles);
+}
+
+void qr_all_reduce(quickreduce::fptr_t _fa, torch::Tensor& inp,
+ torch::Tensor& out, int64_t quant_level, bool cast_bf2half) {
+ auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
+ auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA();
+
+ TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
+ TORCH_CHECK_EQ(inp.numel(), out.numel());
+ TORCH_CHECK_LE(out.numel(), fa->kMaxProblemSize);
+ if (out.scalar_type() == at::ScalarType::Half) {
+ fa->allreduce<half, false>(reinterpret_cast<half*>(inp.data_ptr()),
+ reinterpret_cast<half*>(out.data_ptr()),
+ out.numel(), quant_level, stream);
+ } else if (out.scalar_type() == at::ScalarType::BFloat16) {
+ if (cast_bf2half) {
+ fa->allreduce<half, true>(reinterpret_cast<half*>(inp.data_ptr()),
+ reinterpret_cast<half*>(out.data_ptr()),
+ out.numel(), quant_level, stream);
+ } else {
+ fa->allreduce<quickreduce::nv_bfloat16, false>(
+ reinterpret_cast<quickreduce::nv_bfloat16*>(inp.data_ptr()),
+ reinterpret_cast<quickreduce::nv_bfloat16*>(out.data_ptr()),
+ out.numel(), quant_level, stream);
+ }
+ } else {
+ throw std::runtime_error(
+ "quick allreduce only supports float16 and bfloat16");
+ }
+}
+
+int64_t qr_max_size() {
+ // The default is 2GB (2,147,483,648 bytes)
+ return static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
+}
+
+ #define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half) \
+ template struct quickreduce::AllReduceTwoshot<T, Codec<T, 2>, \
+ cast_bf2half>; \
+ template struct quickreduce::AllReduceTwoshot<T, Codec<T, 4>, \
+ cast_bf2half>; \
+ template struct quickreduce::AllReduceTwoshot<T, Codec<T, 8>, cast_bf2half>;
+
+INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false)
+INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false)
+INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false)
+INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false)
+INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true)
+INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true)
+INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true)
+INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true)
+
+INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false)
+INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false)
+INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false)
+INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false)
+
+#endif // USE_ROCM
\ No newline at end of file
diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu
index f62d08c17c6d..c83d72751a55 100644
--- a/csrc/mamba/causal_conv1d/causal_conv1d.cu
+++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu
@@ -185,9 +185,7 @@ void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
params.conv_states_ptr = nullptr;
}
- // Otherwise the kernel will be launched from cuda:0 device
- // Cast to char to avoid compiler warning about narrowing
- at::cuda::CUDAGuard device_guard{(char)x.get_device()};
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
auto stream = at::cuda::getCurrentCUDAStream().stream();
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
causal_conv1d_fwd_cuda(params, stream);
@@ -278,9 +276,7 @@ void causal_conv1d_update(const at::Tensor &x,
params.conv_state_indices_ptr = nullptr;
}
- // Otherwise the kernel will be launched from cuda:0 device
- // Cast to char to avoid compiler warning about narrowing
- at::cuda::CUDAGuard device_guard{(char)x.get_device()};
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
auto stream = at::cuda::getCurrentCUDAStream().stream();
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
causal_conv1d_update_cuda(params, stream);
diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
index 0c9df925bdbf..785d316025ec 100644
--- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
@@ -647,9 +647,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
);
- // Otherwise the kernel will be launched from cuda:0 device
- // Cast to char to avoid compiler warning about narrowing
- at::cuda::CUDAGuard device_guard{(char)u.get_device()};
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(u));
auto stream = at::cuda::getCurrentCUDAStream().stream();
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
selective_scan_fwd_cuda(params, stream);
diff --git a/csrc/ops.h b/csrc/ops.h
index b123da92f512..c54fa1574820 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -363,3 +363,14 @@ std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
int64_t size);
int64_t open_mem_handle(torch::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);
+
+#ifdef USE_ROCM
+fptr_t init_custom_qr(int64_t rank, int64_t world_size,
+ std::optional<int64_t> qr_max_size = std::nullopt);
+void qr_destroy(fptr_t _fa);
+torch::Tensor qr_get_handle(fptr_t _fa);
+void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
+void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
+ int64_t quant_level, bool cast_bf2half = false);
+int64_t qr_max_size();
+#endif
\ No newline at end of file
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
index 1549ed96aa2b..24564efbd21b 100644
--- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
@@ -29,12 +29,12 @@ struct sm100_fp8_config_default {
template <typename InType, typename OutType, template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M256 {
- // M in (128, 256]
+ // M in (64, 256]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_128, _128, _128>;
- using ClusterShape = Shape<_2, _2, _1>;
+ using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>;
@@ -42,13 +42,13 @@ struct sm100_fp8_config_M256 {
template <typename InType, typename OutType, template <typename, typename, typename> typename Epilogue>
-struct sm100_fp8_config_M128 {
- // M in (64, 128]
+struct sm100_fp8_config_M64 {
+ // M in (16, 64]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
- using TileShape = Shape<_128, _128, _256>;
- using ClusterShape = Shape<_2, _4, _1>;
+ using TileShape = Shape<_64, _64, _128>;
+ using ClusterShape = Shape<_1, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>;
@@ -56,13 +56,13 @@ struct sm100_fp8_config_M128 {
template <typename InType, typename OutType, template <typename, typename, typename> typename Epilogue>
-struct sm100_fp8_config_M64 {
- // M in [1, 64]
+struct sm100_fp8_config_M16 {
+ // M in [1, 16]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
- using TileShape = Shape<_64, _64, _256>;
- using ClusterShape = Shape<_1, _8, _1>;
+ using TileShape = Shape<_64, _64, _128>;
+ using ClusterShape = Shape<_1, _4, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>;
@@ -82,27 +82,27 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
using Cutlass3xGemmDefault =
typename sm100_fp8_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
+ using Cutlass3xGemmM16 =
+ typename sm100_fp8_config_M16<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
- using Cutlass3xGemmM128 =
- typename sm100_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM256 =
typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
uint32_t const m = a.size(0);
uint32_t const mp2 =
- std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
+ std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
- if (mp2 <= 64) {
- // m in [1, 64]
- return cutlass_gemm_caller<Cutlass3xGemmM64>(
+ if (mp2 <= 16) {
+ // m in [1, 16]
+ return cutlass_gemm_caller<Cutlass3xGemmM16>(
out, a, b, std::forward<EpilogueArgs>(args)...);
- } else if (mp2 <= 128) {
- // m in (64, 128]
- return cutlass_gemm_caller<Cutlass3xGemmM128>(
+ } else if (mp2 <= 64) {
+ // m in (16, 64]
+ return cutlass_gemm_caller<Cutlass3xGemmM64>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
- // m in (128, 256]
+ // m in (64, 256]
return cutlass_gemm_caller<Cutlass3xGemmM256>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
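The retuned dispatch rounds M up to the next power of two with a floor of 16 and then picks the M16, M64, M256, or default configuration. A quick illustration of which bucket a few batch sizes land in (a sketch of the selection logic only, with an assumed `next_pow_2`):

```python
# Illustration of the new sm100 FP8 dispatch buckets (not the CUDA code itself).
def next_pow_2(x: int) -> int:
    return 1 if x <= 1 else 1 << (x - 1).bit_length()

def chosen_config(m: int) -> str:
    mp2 = max(16, next_pow_2(m))
    if mp2 <= 16:
        return "M16  (tile 64x64x128,   cluster 1x4x1)"
    if mp2 <= 64:
        return "M64  (tile 64x64x128,   cluster 1x1x1)"
    if mp2 <= 256:
        return "M256 (tile 128x128x128, cluster 2x1x1)"
    return "default"

for m in (1, 16, 17, 48, 200, 1024):
    print(m, "->", chosen_config(m))
# 1 and 16 -> M16, 17 and 48 -> M64, 200 -> M256, 1024 -> default
```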
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
index 348525810810..a2080c300119 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
@@ -241,7 +241,7 @@ void get_cutlass_moe_mm_data(
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
- (defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
+ (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation,
output_permutation, num_experts, n, k,
@@ -252,7 +252,7 @@ void get_cutlass_moe_mm_data(
false,
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
"CUDA device capability: ",
- version_num, ". Required capability: 90");
+ version_num, ". Required capability: 90 or 100");
}
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
@@ -265,7 +265,8 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
-#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
+#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
+ (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k);
@@ -275,7 +276,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
false,
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
"for CUDA device capability: ",
- version_num, ". Required capability: 90");
+ version_num, ". Required capability: 90 or 100");
}
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu
index b51033c9b72c..190d66f318a8 100644
--- a/csrc/quantization/fp4/nvfp4_experts_quant.cu
+++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu
@@ -561,7 +561,7 @@ void scaled_fp4_experts_quant_sm100a(
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
auto in_dtype = input.dtype();
- at::cuda::CUDAGuard device_guard{(char)input.get_device()};
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(input.get_device());
if (in_dtype == at::ScalarType::Half) {
@@ -579,4 +579,4 @@ void scaled_fp4_experts_quant_sm100a(
} else {
TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
}
-}
\ No newline at end of file
+}
diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu
index fef74111624f..d32911357a95 100644
--- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu
+++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu
@@ -347,7 +347,7 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output,
auto input_sf_ptr = static_cast(input_sf.data_ptr());
auto sf_out = static_cast(output_sf.data_ptr());
auto output_ptr = static_cast(output.data_ptr());
- at::cuda::CUDAGuard device_guard{(char)input.get_device()};
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
// We don't support e8m0 scales at this moment.
diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
index 97c0e0da7b1f..7572a7eb3122 100644
--- a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
+++ b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
@@ -267,7 +267,7 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
B_sf.sizes()[1], ")");
auto out_dtype = D.dtype();
- at::cuda::CUDAGuard device_guard{(char)A.get_device()};
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
if (out_dtype == at::ScalarType::Half) {
diff --git a/csrc/quickreduce/base.h b/csrc/quickreduce/base.h
new file mode 100644
index 000000000000..a2170e483207
--- /dev/null
+++ b/csrc/quickreduce/base.h
@@ -0,0 +1,338 @@
+#pragma once
+
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#include <hip/hip_bf16.h>
+#include <cstdint>
+
+#define __quickreduce_device_inline__ __device__ __forceinline__
+#define __quickreduce_launch_bounds_two_shot__ __launch_bounds__(256, 4)
+#define __quickreduce_launch_bounds_one_shot__ __launch_bounds__(512, 4)
+
+namespace quickreduce {
+
+typedef __hip_bfloat16 nv_bfloat16;
+typedef __hip_bfloat162 nv_bfloat162;
+
+using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
+using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
+
+// Setup acquire-release semantics for vector memory reads (mubuf instruction)
+// as per architecture.
+#if defined(__gfx942__)
+// CDNA3: Scope bits sc0, sc1
+ #define MUBUF_ACQUIRE 16
+ #define MUBUF_RELEASE 16
+#elif (defined(__gfx908__) || defined(__gfx90a__))
+// CDNA1 and CDNA2 - glc bit
+ #define MUBUF_ACQUIRE 1
+ #define MUBUF_RELEASE 0
+#endif
+
+static constexpr int kNegOne = 0xBC00BC00; // {-1, -1}, fp16x2_t
+
+// Number of atoms (4xf16x2_t) processed by a single thread
+static constexpr int kAtoms = 8;
+
+// We use a workgroup of 256 threads
+static constexpr int kBlockSize = 256;
+static constexpr int kAtomStride = kBlockSize;
+
+// Size and atom stride of source/destination data that the block will
+// process.
+// Workgroup scope = Tile = (256 threads x 8 atoms x 16B)
+static constexpr int kTileSize = kBlockSize * kAtoms * sizeof(int32x4_t);
+
+// Max number of blocks. 304 CUs on MI300
+static constexpr int kMaxNumBlocks = 304 * 4;
+
+// Standard CDNA wavefront size.
+static constexpr int kWavefront = 64;
+
+// 256 thread, 4 wavefronts.
+static dim3 constexpr kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1};
+
+// Number of threads in a group for quantization
+// It corresponds to 32 F16 elements in quantization block
+static constexpr int kThreadGroupSize = 8;
+
+// Methods
+__quickreduce_device_inline__ __host__ unsigned long divceil(unsigned long x,
+ unsigned long y) {
+ return ((x + y - 1) / y);
+}
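With a 256-thread workgroup and 8 atoms of 16 B per thread, one block consumes a 32 KiB tile per pass, and the grid is capped at `kMaxNumBlocks = 304 * 4`. A back-of-envelope check of those constants (illustration only):

```python
# Sanity check of the quickreduce tile/grid constants defined above.
kBlockSize, kAtoms, atom_bytes = 256, 8, 16
kTileSize = kBlockSize * kAtoms * atom_bytes      # 32768 bytes per block pass
kMaxNumBlocks = 304 * 4                           # 1216

def divceil(x, y):
    return (x + y - 1) // y

msg_bytes = 64 * 1024 * 1024                      # e.g. 32M fp16 elements
num_blocks = divceil(msg_bytes, kTileSize)        # 2048 tiles to cover
grid = min(kMaxNumBlocks, num_blocks)             # 1216 blocks actually launched
print(kTileSize, num_blocks, grid)                # 32768 2048 1216
```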
+
+union BufferResource {
+ __quickreduce_device_inline__ constexpr BufferResource()
+ : config(0x00020000U) {}
+
+ __quickreduce_device_inline__ constexpr BufferResource(void* buffer_address,
+ uint32_t buffer_size)
+ : address(buffer_address), range(buffer_size), config(0x00020000U) {}
+
+ int32x4_t descriptor;
+ struct {
+ void* address; // 8B, out of which first 48b is address, and 16b is stride
+ // (unused)
+ uint32_t range; // Byte range for the buffer resource
+ uint32_t config; // Constant, DFMT=32b
+ };
+};
+
+__quickreduce_device_inline__ static int32x4_t buffer_load_dwordx4(
+ int32x4_t srsrc, int32_t voffset, int32_t soffset,
+ int32_t aux) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
+
+__quickreduce_device_inline__ static void buffer_store_dwordx4(
+ int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset,
+ int32_t aux) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
+
+__quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) {
+#if defined(__gfx942__)
+ if (value) {
+ asm volatile("s_setreg_imm32_b32 0xdc1, 1;" ::);
+ } else {
+ asm volatile("s_setreg_imm32_b32 0xdc1, 0;" ::);
+ }
+#endif
+}
+union bf162_int_union {
+ int i;
+ nv_bfloat162 bf2;
+};
+
+template <typename T>
+__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A,
+ int32x4_t* B);
+
+template <>
+__quickreduce_device_inline__ void packed_assign_add<half>(int32x4_t* A,
+ int32x4_t* B) {
+ int32x4_t& tR_fragment = A[0];
+ int32x4_t& tA_fragment = B[0];
+
+ asm volatile("v_pk_add_f16 %0, %1, %2"
+ : "=v"(tR_fragment[0])
+ : "v"(tR_fragment[0]), "v"(tA_fragment[0]));
+ asm volatile("v_pk_add_f16 %0, %1, %2"
+ : "=v"(tR_fragment[1])
+ : "v"(tR_fragment[1]), "v"(tA_fragment[1]));
+ asm volatile("v_pk_add_f16 %0, %1, %2"
+ : "=v"(tR_fragment[2])
+ : "v"(tR_fragment[2]), "v"(tA_fragment[2]));
+ asm volatile("v_pk_add_f16 %0, %1, %2"
+ : "=v"(tR_fragment[3])
+ : "v"(tR_fragment[3]), "v"(tA_fragment[3]));
+}
+
+template <>
+__quickreduce_device_inline__ void packed_assign_add<nv_bfloat16>(
+ int32x4_t* A, int32x4_t* B) {
+ nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(A);
+ nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(B);
+#pragma unroll
+ for (int i = 0; i < 4; i++) {
+ tA[i] = __hadd2(tA[i], tB[i]);
+ }
+}
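The `packed_*` helpers below all treat a 32-bit register as two 16-bit float lanes, which is the layout the `v_pk_*` instructions operate on. A plain numpy model of that f16x2-in-int32 layout and a lane-wise add (illustration only, not the device code):

```python
import numpy as np

def pack(lo: float, hi: float) -> int:
    h = np.array([lo, hi], dtype=np.float16).view(np.uint16)
    return int(h[0]) | (int(h[1]) << 16)

def unpack(word: int):
    h = np.array([word & 0xFFFF, (word >> 16) & 0xFFFF], dtype=np.uint16)
    return h.view(np.float16).tolist()

a, b = pack(1.5, -2.0), pack(0.25, 4.0)
s = pack(*(x + y for x, y in zip(unpack(a), unpack(b))))  # lane-wise, like v_pk_add_f16
print(unpack(s))  # [1.75, 2.0]
```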
+
+template <typename T>
+__quickreduce_device_inline__ int packed_max(int a, int b);
+
+template <>
+__quickreduce_device_inline__ int packed_max<half>(int a, int b) {
+ int result;
+ asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
+ return result;
+}
+
+template <>
+__quickreduce_device_inline__ int packed_max<nv_bfloat16>(int a, int b) {
+ bf162_int_union A, B, R;
+ A.i = a;
+ B.i = b;
+ R.bf2 = __hmax2(A.bf2, B.bf2);
+ return R.i;
+}
+
+template <typename T>
+__quickreduce_device_inline__ int packed_min(int a, int b);
+
+template <>
+__quickreduce_device_inline__ int packed_min<half>(int a, int b) {
+ int result;
+ asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
+ return result;
+}
+
+template <>
+__quickreduce_device_inline__ int packed_min<nv_bfloat16>(int a, int b) {
+ bf162_int_union A, B, R;
+ A.i = a;
+ B.i = b;
+ R.bf2 = __hmin2(A.bf2, B.bf2);
+ return R.i;
+}
+
+template <typename T>
+__quickreduce_device_inline__ int packed_abs_max(int a, int b);
+
+template <>
+__quickreduce_device_inline__ int packed_abs_max<half>(int a, int b) {
+ half2 wmaxh2 = __builtin_bit_cast(half2, a);
+ half2 wminh2 = __builtin_bit_cast(half2, b);
+ half2 wblockmaxh2;
+
+ wblockmaxh2.x =
+ __hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x;
+ wblockmaxh2.y =
+ __hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y;
+ return __builtin_bit_cast(int, wblockmaxh2);
+}
+
+template <>
+__quickreduce_device_inline__ int packed_abs_max<nv_bfloat16>(int a, int b) {
+ bf162_int_union A, B, R;
+ A.i = a;
+ B.i = b;
+ R.bf2.x = __hgt(__habs(A.bf2.x), __habs(B.bf2.x)) ? A.bf2.x : B.bf2.x;
+ R.bf2.y = __hgt(__habs(A.bf2.y), __habs(B.bf2.y)) ? A.bf2.y : B.bf2.y;
+ return R.i;
+}
+
+template <typename T>
+__quickreduce_device_inline__ int packed_add(int a, int b);
+
+template <>
+__quickreduce_device_inline__ int packed_add<half>(int a, int b) {
+ int result;
+ asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
+ return result;
+}
+
+template <>
+__quickreduce_device_inline__ int packed_add<nv_bfloat16>(int a, int b) {
+ bf162_int_union A, B, R;
+ A.i = a;
+ B.i = b;
+ R.bf2 = __hadd2(A.bf2, B.bf2);
+ return R.i;
+}
+
+template <>
+__quickreduce_device_inline__ int packed_add<int16_t>(int a, int b) {
+ int result;
+ asm volatile("v_pk_add_i16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
+ return result;
+}
+
+template <typename T>
+__quickreduce_device_inline__ int packed_sub(int a, int b);
+
+template <>
+__quickreduce_device_inline__ int packed_sub<half>(int a, int b) {
+ int result;
+
+ // MI300 lacks packed fp16 sub instruction. So we do -1 * min + max
+ asm volatile("v_pk_fma_f16 %0, %1, %2 %3"
+ : "=v"(result)
+ : "v"(kNegOne), "v"(b), "v"(a));
+ return result;
+}
+
+template <>
+__quickreduce_device_inline__ int packed_sub<nv_bfloat16>(int a, int b) {
+ bf162_int_union A, B, R;
+ A.i = a;
+ B.i = b;
+ R.bf2 = __hsub2(A.bf2, B.bf2);
+ return R.i;
+}
+
+template <typename T>
+__quickreduce_device_inline__ int packed_mul(int a, int b);
+
+template <>
+__quickreduce_device_inline__ int packed_mul<half>(int a, int b) {
+ int result;
+ asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
+ return result;
+}
+
+template <>
+__quickreduce_device_inline__ int packed_mul<nv_bfloat16>(int a, int b) {
+ nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a);
+ nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(&b);
+ nv_bfloat162 tR = __hmul2(*tA, *tB);
+ return *(reinterpret_cast<int*>(&tR));
+}
+
+template <typename T>
+__quickreduce_device_inline__ int packed_rcp(int a);
+
+template <>
+__quickreduce_device_inline__ int packed_rcp<half>(int a) {
+ return __builtin_bit_cast(int, h2rcp(__builtin_bit_cast(half2, a)));
+}
+
+template <>
+__quickreduce_device_inline__ int packed_rcp<nv_bfloat16>(int a) {
+ bf162_int_union A, R;
+ A.i = a;
+ R.bf2 = h2rcp(A.bf2);
+ return R.i;
+}
+
+// changes dtype
+__quickreduce_device_inline__ float T2float_cast(half a) {
+ return __half2float(a);
+}
+
+__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) {
+ return __bfloat162float(a);
+}
+
+template <typename T>
+__quickreduce_device_inline__ int group_abs_max(int32x4_t atom) {
+ const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize;
+
+ int wmax, wmin, wblockmax;
+ int a, b;
+ a = packed_max<T>(atom[0], atom[1]);
+ b = packed_max<T>(atom[2], atom[3]);
+
+ wmax = packed_max<T>(a, b);
+
+ a = packed_min<T>(atom[0], atom[1]);
+ b = packed_min<T>(atom[2], atom[3]);
+
+ wmin = packed_min<T>(a, b);
+
+ // Reduce the max among a group of threads
+ // Note: This is basically 2 blocks of values setup as the
+ // upper/lower halves of the f16x2_t
+ for (int i = 1; i < kThreadGroupSize; i <<= 1) {
+ int x = __shfl_down(wmax, i);
+ wmax = packed_max(wmax, x);
+
+ int y = __shfl_down(wmin, i);
+ wmin = packed_min(wmin, y);
+ }
+ wblockmax = packed_abs_max<T>(wmax, wmin);
+ // Share with the cohort
+ wblockmax = __shfl(wblockmax, group_leader);
+ return wblockmax;
+}
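`group_abs_max` reduces the packed max and min across the 8 lanes of a quantization group with shuffle-down steps, keeps whichever has the larger magnitude, and broadcasts the leader's result. A scalar Python emulation of that reduction, with one value per lane instead of an f16x2 pair (illustration only):

```python
kThreadGroupSize = 8

def group_abs_max(lane_vals):
    wmax, wmin = list(lane_vals), list(lane_vals)
    step = 1
    while step < kThreadGroupSize:
        # lane i folds in lane i + step (kept as-is past the group end)
        wmax = [max(wmax[i], wmax[i + step]) if i + step < kThreadGroupSize
                else wmax[i] for i in range(kThreadGroupSize)]
        wmin = [min(wmin[i], wmin[i + step]) if i + step < kThreadGroupSize
                else wmin[i] for i in range(kThreadGroupSize)]
        step <<= 1
    # lane 0 (the group leader) now holds the group-wide max and min
    return wmax[0] if abs(wmax[0]) > abs(wmin[0]) else wmin[0]

print(group_abs_max([0.5, -3.0, 1.25, 2.0, -0.75, 0.1, -1.5, 2.5]))  # -3.0
```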
+
+__quickreduce_device_inline__ void set_sync_flag(uint32_t* flag_ptr,
+ uint32_t flag) {
+ __atomic_store_n(flag_ptr, flag, __ATOMIC_RELEASE);
+}
+
+__quickreduce_device_inline__ void wait_sync_flag(uint32_t* flag_ptr,
+ uint32_t flag) {
+ while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag) {
+ }
+}
+
+} // namespace quickreduce
\ No newline at end of file
diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h
new file mode 100644
index 000000000000..4fe4c44be7eb
--- /dev/null
+++ b/csrc/quickreduce/quick_reduce.h
@@ -0,0 +1,196 @@
+#pragma once
+
+#include <hip/hip_runtime.h>
+#include <vector>
+#include "quick_reduce_impl.cuh"
+
+#define HIP_CHECK(err) \
+ do { \
+ hipError_t err_ = (err); \
+ if (err_ != hipSuccess) { \
+ std::printf("HIP error %d at %s:%d. %s\n", err_, __FILE__, __LINE__, \
+ hipGetErrorString(err_)); \
+ throw std::runtime_error("HIP error"); \
+ } \
+ } while (0)
+
+namespace quickreduce {
+using fptr_t = int64_t;
+static_assert(sizeof(void*) == sizeof(fptr_t));
+
+template <typename T, typename AllReduceKernel>
+__global__ __quickreduce_launch_bounds_two_shot__ static void
+allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
+ int rank, uint8_t** dbuffer_list,
+ uint32_t data_offset, uint32_t flag_color) {
+ int block = blockIdx.x;
+ int grid = gridDim.x;
+
+ while (block < num_blocks) {
+ AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset,
+ flag_color);
+ block += grid;
+ flag_color++;
+ }
+}
+
+#define TWOSHOT_DISPATCH(__codec) \
+ if (world_size == 2) { \
+ using LineCodec = __codec<T, 2>; \
+ using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
+ hipLaunchKernelGGL((allreduce_prototype_twoshot<T, AllReduceKernel>), \
+ dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
+ num_blocks, rank, dbuffer_list, data_offset, \
+ flag_color); \
+ } else if (world_size == 4) { \
+ using LineCodec = __codec<T, 4>; \
+ using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
+ hipLaunchKernelGGL((allreduce_prototype_twoshot<T, AllReduceKernel>), \
+ dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
+ num_blocks, rank, dbuffer_list, data_offset, \
+ flag_color); \
+ } else if (world_size == 8) { \
+ using LineCodec = __codec<T, 8>; \
+ using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
+ hipLaunchKernelGGL((allreduce_prototype_twoshot<T, AllReduceKernel>), \
+ dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
+ num_blocks, rank, dbuffer_list, data_offset, \
+ flag_color); \
+ }
+
+enum QuickReduceQuantLevel {
+ F16 = 0,
+ INT8 = 1,
+ INT6 = 2,
+ INT4 = 3,
+};
+
+struct DeviceComms {
+ // Max problem size is 2GB (in bytes) or half of uint32_t max value.
+ int64_t kMaxProblemSize =
+ static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
+
+ // Max TP-8
+ static int constexpr kMaxWorldSize = 8;
+
+ bool initialized = false;
+ uint32_t flag_color = 1;
+ int world_size;
+ int rank;
+
+ uint8_t* dbuffer;
+ uint8_t** dbuffer_list;
+ hipIpcMemHandle_t buffer_ipc_handle;
+ std::vector<hipIpcMemHandle_t> all_buffer_ipc_handles;
+ std::vector<uint8_t*> buffer_list;
+ uint32_t data_offset;
+
+ DeviceComms() : initialized(false), world_size(1), rank(0) {}
+ ~DeviceComms() { destroy(); }
+
+ void init(int world_size, int rank,
+ std::optional<int64_t> max_problem_size = std::nullopt) {
+ destroy();
+ this->world_size = world_size;
+ this->rank = rank;
+ if (max_problem_size.has_value() && max_problem_size.value() > 0) {
+ this->kMaxProblemSize = max_problem_size.value();
+ }
+ // Allocate buffer size for worst case: F16 2-stage buffer.
+ uint32_t flags_buffer_size =
+ 2 * world_size * kMaxNumBlocks * sizeof(uint32_t);
+ static int64_t data_buffer_size = 2 * this->kMaxProblemSize;
+ int64_t total_buffer_size = flags_buffer_size + data_buffer_size;
+ data_offset = flags_buffer_size;
+ HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size,
+ hipDeviceMallocUncached));
+
+ // Clear the flags buffer.
+ HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size));
+
+ // Device-side list of IPC buffers.
+ buffer_list.resize(world_size);
+ HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*)));
+
+ // Create IPC handles for rank's communication buffer.
+ all_buffer_ipc_handles.resize(world_size);
+ HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer));
+
+ initialized = true;
+ }
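`init` carves a single uncached allocation into a flags region (2 x world_size x kMaxNumBlocks uint32 flags) followed by a data region of twice `kMaxProblemSize`, with `data_offset` pointing just past the flags. A quick sizing check for the defaults at world_size = 8 (illustration only):

```python
world_size, kMaxNumBlocks = 8, 304 * 4
kMaxProblemSize = 2**31                            # 2 GiB default

flags_bytes = 2 * world_size * kMaxNumBlocks * 4   # 77824 bytes of sync flags
data_bytes = 2 * kMaxProblemSize                   # worst-case two-stage F16 buffer
data_offset = flags_bytes                          # data starts right after the flags
print(flags_bytes, data_bytes, flags_bytes + data_bytes)
# 77824 4294967296 4295045120  (~4 GiB per rank)
```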
+ int get_world_size() { return world_size; }
+ int get_rank() { return rank; }
+ bool status() { return initialized; }
+ hipIpcMemHandle_t const get_handle() { return buffer_ipc_handle; }
+
+ void destroy() {
+ if (initialized) {
+ for (int i = 0; i < world_size; i++) {
+ if (i != rank) {
+ HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i]));
+ }
+ }
+
+ HIP_CHECK(hipFree(dbuffer));
+ HIP_CHECK(hipFree(dbuffer_list));
+
+ initialized = false;
+ }
+ }
+
+ void open_ipc_handles(std::vector<hipIpcMemHandle_t> const& ipc_handles) {
+ assert(ipc_handles.size() == all_buffer_ipc_handles.size());
+ for (int i = 0; i < world_size; i++) {
+ all_buffer_ipc_handles[i] = ipc_handles[i];
+ }
+
+ // Open device memory access to the IPC communication buffers.
+ // Note: For our own rank, we do not need to open a handle.
+ for (int i = 0; i < world_size; i++) {
+ if (i != rank) {
+ HIP_CHECK(hipIpcOpenMemHandle((void**)&buffer_list[i],
+ all_buffer_ipc_handles[i],
+ hipIpcMemLazyEnablePeerAccess));
+ } else {
+ buffer_list[i] = dbuffer;
+ }
+ }
+
+ HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(),
+ world_size * sizeof(uint8_t*), hipMemcpyHostToDevice));
+ }
+
+ template <typename T, bool cast_bf2half>
+ void allreduce(T const* A, T* B, uint32_t N, int quant_level,
+ hipStream_t stream) {
+ if (world_size != 2 && world_size != 4 && world_size != 8) {
+ throw std::runtime_error("All Reduce not supported for world_size = " +
+ std::to_string(world_size));
+ }
+
+ // Configuration.
+ uint32_t msg_size = N * sizeof(T);
+ uint32_t num_blocks = divceil(msg_size, kTileSize);
+ uint32_t grid = min(kMaxNumBlocks, num_blocks);
+ auto quant_level_ = static_cast<QuickReduceQuantLevel>(quant_level);
+ switch (quant_level_) {
+ case QuickReduceQuantLevel::INT8:
+ TWOSHOT_DISPATCH(CodecQ8)
+ break;
+ case QuickReduceQuantLevel::INT6:
+ TWOSHOT_DISPATCH(CodecQ6)
+ break;
+ case QuickReduceQuantLevel::INT4:
+ TWOSHOT_DISPATCH(CodecQ4)
+ break;
+ default:
+ TWOSHOT_DISPATCH(CodecFP)
+ break;
+ }
+ HIP_CHECK(cudaGetLastError());
+ // Rotate the flag color.
+ flag_color += divceil(N, grid);
+ }
+};
+
+} // namespace quickreduce
\ No newline at end of file
diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh
new file mode 100644
index 000000000000..17816c552d25
--- /dev/null
+++ b/csrc/quickreduce/quick_reduce_impl.cuh
@@ -0,0 +1,698 @@
+#pragma once
+
+#include <hip/hip_runtime.h>
+#include "base.h"
+
+namespace quickreduce {
+
+struct CodecBase {
+ const int thread;
+ const int rank;
+ const int group_leader;
+ __quickreduce_device_inline__ CodecBase(int thread, int rank)
+ : thread(thread),
+ rank(rank),
+ group_leader((threadIdx.x / kThreadGroupSize) * kThreadGroupSize) {
+ set_fp16_ovfl(true);
+ }
+};
+
+// Default full precision codec.
+template <typename T, int world_size>
+struct CodecFP : public CodecBase {
+ static constexpr int kWorldSize = world_size;
+ static constexpr int kRankAtoms = kAtoms / kWorldSize;
+
+ // Codec tile size process by this workgroup.
+ // Each thread processes atoms of f16x8_t (16B).
+ static constexpr int kRankTransmittedTileSize =
+ kBlockSize * kRankAtoms * sizeof(int32x4_t);
+ static_assert(kRankTransmittedTileSize % 16 == 0,
+ "kRankTransmittedTileSize must be 16B aligned.");
+
+ // Total tile size for the collective communication.
+ static constexpr int kTransmittedTileSize =
+ kRankTransmittedTileSize * kWorldSize;
+
+ __quickreduce_device_inline__ CodecFP(int thread, int rank)
+ : CodecBase(thread, rank) {}
+
+ __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
+ const int32x4_t* __restrict__ data) {
+ for (int i = 0; i < kRankAtoms; i++) {
+ __builtin_nontemporal_store(data[i], send_buffer + thread);
+ send_buffer += kAtomStride;
+ }
+ }
+
+ __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer,
+ int32x4_t* __restrict__ data) {
+ for (int i = 0; i < kRankAtoms; i++) {
+ data[i] = __builtin_nontemporal_load(*recv_buffer + thread);
+ *recv_buffer += kAtomStride;
+ }
+ }
+};
+
+// Int4 symmetric quantization codec.
+// We quantize the FP16 data to block-scaled Int4 in blocks of 4 *
+// kThreadGroupSize.
+template <typename T, int world_size>
+struct CodecQ4 : public CodecBase {
+ static constexpr int kWorldSize = world_size;
+
+ // Codec tile size process by this workgroup.
+ // Each threads processes a fragment of fp16x8_t (16B),
+ // into a int4x8_t (4B) and a fp16 scale shared among 32 values.
+ static constexpr int kRankAtoms = kAtoms / kWorldSize;
+ static constexpr int kRankTileStride = 1152;
+ static constexpr int kRankTileScaleOffset = 1024;
+ static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
+ static_assert(kRankTransmittedTileSize % 16 == 0,
+ "kRankTransmittedTileSize must be 16B aligned.");
+
+ static constexpr int kRankBufferTileStride =
+ kRankTileStride / sizeof(int32x4_t);
+
+ // Total tile size for the collective communication.
+ static constexpr int kTransmittedTileSize =
+ kRankTransmittedTileSize * kWorldSize;
+
+ // Constants configuration
+
+ // {-1/8.0h, -1/8.0h}, f16x2_t
+ static constexpr int kScaleFactor =
+ std::is_same<T, half>::value ? 0xB000B000 : 0xBE00BE00;
+
+ // {1e-7, 1e-7}, f16x2_t
+ static constexpr int kScaleEpsilon =
+ std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
+
+ // {-8, -8}, f16x2_t
+ static constexpr int kRangeMin =
+ std::is_same<T, half>::value ? 0xC800C800 : 0xC100C100;
+
+ // {+7, +7}, f16x2_t
+ static constexpr int kRangeMax =
+ std::is_same<T, half>::value ? 0x47004700 : 0x40E040E0;
+
+ // {+8, +8}, int16x2_t
+ static constexpr int kRangeBias = 0x00080008;
+
+ __quickreduce_device_inline__ CodecQ4(int thread, int rank)
+ : CodecBase(thread, rank) {}
+
+ __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
+ const int32x4_t* __restrict__ data) {
+ for (int k = 0; k < kRankAtoms; k++) {
+ int32x4_t const atom = data[k];
+
+ // Compute the absolute maximum of the atom in the thread group
+ // In 2 blocks of values, upper/lower halves of the f16x2_t
+ int wblockmax = group_abs_max<T>(atom);
+
+ // Derive scales
+ int decoding_scale;
+ int encoding_scale;
+ decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
+ encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
+ encoding_scale = packed_rcp<T>(encoding_scale);
+
+ // Apply scales to get quantized values
+ int32x4_t w;
+ for (int i = 0; i < 4; i++) {
+ w[i] = packed_mul<T>(atom[i], encoding_scale);
+ w[i] = packed_max<T>(w[i], kRangeMin);
+ w[i] = packed_min<T>(w[i], kRangeMax);
+ }
+
+ // Convert from f16x2_t to uint16x2_t
+ int32x4_t q;
+ {
+ int16_t* qi = reinterpret_cast<int16_t*>(&q);
+ T* wh = reinterpret_cast<T*>(&w);
+ for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
+
+ for (int i = 0; i < 4; i++) {
+ q[i] = packed_add<int16_t>(q[i], kRangeBias);
+ }
+ }
+
+ // Pack 8 x q4 into int32_t
+ int qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12);
+
+ // Write quantized atom to send_buffer
+ // note: only the group leader stores the scale
+ uint8_t* atom_ptr =
+ reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
+ int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
+ int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
+ (thread / 8);
+
+ __builtin_nontemporal_store(qw, qw_ptr);
+ if (threadIdx.x == group_leader) {
+ __builtin_nontemporal_store(decoding_scale, qs_ptr);
+ }
+ }
+ }
+
+ __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer,
+ int32x4_t* __restrict__ data) {
+ for (int k = 0; k < kRankAtoms; k++) {
+ // Directly read quantized atom from recv_buffer
+ uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
+ int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
+ int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
+ (thread / 8);
+
+ int32_t qw = __builtin_nontemporal_load(qw_ptr);
+ int qs = __builtin_nontemporal_load(qs_ptr);
+
+ *recv_buffer += kRankBufferTileStride;
+
+ // Unpack q4 into f16x8_t
+ int32x4_t w;
+ {
+ static constexpr uint kMask000F = 0x000F000F;
+ static constexpr uint kHalf2_1024 =
+ 0x64006400; // {1024.0, 1024.0}, fp16x2_t
+ static uint constexpr kHalf2_1032 =
+ 0xE408E408; // {-1032.0, -1032.0}, fp16x2_t
+
+ for (int i = 0; i < 4; i++) {
+ if constexpr (std::is_same<T, half>::value) {
+ int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024;
+ w[i] = packed_add(q4, kHalf2_1032);
+ } else {
+ int32_t int16_2 = (qw >> (i * 4)) & kMask000F;
+ int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
+ int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
+ nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
+ nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
+ nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
+ int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
+ w[i] = packed_add(packed_bf16, kRangeMin);
+ }
+ }
+ }
+
+ // Apply decoding scales
+ for (int i = 0; i < 4; i++) {
+ w[i] = packed_mul(w[i], qs);
+ }
+
+ data[k] = w;
+ }
+ }
+};
+
+// Int6 symmetric quantization codec.
+// We quantize the FP16 data to block-scaled Int6 in blocks of 4 *
+// kThreadGroupSize.
+template <typename T, int world_size>
+struct CodecQ6 : public CodecBase {
+ static constexpr int kWorldSize = world_size;
+
+ // Codec tile size processed by this workgroup.
+ // Each thread processes a fragment of fp16x8_t (16B),
+ // into an int6x8_t (4B + 2B) and an fp16 scale shared among 32 values.
+ static constexpr int kRankAtoms = kAtoms / kWorldSize;
+ static constexpr int kRankTileStride = 1664;
+ static constexpr int kRankTileQ2Offset = 1024;
+ static constexpr int kRankTileScaleOffset = 1536;
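+ // (Layout of one transmitted atom, assuming a 256-thread workgroup:
+ // 1024B of low-nibble q4 data + 512B of 2-bit remainders + 128B of
+ // scales = 1664B = kRankTileStride.)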
+ static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
+ static_assert(kRankTransmittedTileSize % 16 == 0,
+ "kRankTransmittedTileSize must be 16B aligned.");
+
+ static constexpr int kRankBufferTileStride =
+ kRankTileStride / sizeof(int32x4_t);
+
+ // Total tile size for the collective communication.
+ static constexpr int kTransmittedTileSize =
+ kRankTransmittedTileSize * kWorldSize;
+
+ // Constants configuration
+
+ // {-1/32.0h, -1/32.0h}, fp16x2_t
+ static constexpr int kScaleFactor =
+ std::is_same<T, half>::value ? 0xA800A800 : 0xBD00BD00;
+
+ // {1e-7, 1e-7}, fp16x2_t
+ static constexpr int kScaleEpsilon =
+ std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
+
+ // {-32, -32}, fp16x2_t
+ static constexpr int kRangeMin =
+ std::is_same<T, half>::value ? 0xD000D000 : 0xC200C200;
+
+ // {+31, +31}, fp16x2_t
+ static constexpr int kRangeMax =
+ std::is_same<T, half>::value ? 0x4FC04FC0 : 0x41F841F8;
+
+ // {+32, +32}, int16x2_t
+ static constexpr int kRangeBias = 0x00200020;
+
+ __quickreduce_device_inline__ CodecQ6(int thread, int rank)
+ : CodecBase(thread, rank) {}
+
+ __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
+ const int32x4_t* __restrict__ data) {
+ for (int k = 0; k < kRankAtoms; k++) {
+ int32x4_t const atom = data[k];
+
+ // Compute the absolute maximum of the atom in the thread group
+ // In 2 blocks of values, upper/lower halves of the f16x2_t
+ int wblockmax = group_abs_max(atom);
+
+ // Derive scales
+ int decoding_scale;
+ int encoding_scale;
+ decoding_scale = packed_mul(wblockmax, kScaleFactor);
+ encoding_scale = packed_add(decoding_scale, kScaleEpsilon);
+ encoding_scale = packed_rcp(encoding_scale);
+
+ // Apply scales to get quantized values
+ int32x4_t w;
+ for (int i = 0; i < 4; i++) {
+ w[i] = packed_mul(atom[i], encoding_scale);
+ w[i] = packed_max(w[i], kRangeMin);
+ w[i] = packed_min(w[i], kRangeMax);
+ }
+
+ // Convert from f16x2_t to uint16x2_t
+ int32x4_t q;
+ {
+ int16_t* qi = reinterpret_cast<int16_t*>(&q);
+ T* wh = reinterpret_cast<T*>(&w);
+ for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
+
+ for (int i = 0; i < 4; i++) {
+ q[i] = packed_add(q[i], kRangeBias);
+ }
+ }
+
+ // Pack 8 x q6 into int32_t + int16_t
+ uint32_t q4w;
+ uint16_t q2w = 0;
+ q4w = (q[0] & 0x000F000F) | ((q[1] & 0x000F000F) << 4) |
+ ((q[2] & 0x000F000F) << 8) | ((q[3] & 0x000F000F) << 12);
+ {
+ int16_t* tw = reinterpret_cast<int16_t*>(&q);
+#pragma unroll
+ for (int i = 0; i < 8; i++) {
+ q2w |= (tw[i] >> 4) << (i * 2);
+ }
+ }
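+ // (q4w now holds the low nibbles of all 8 codes, interleaved as in the
+ // Q4 codec; q2w collects the top 2 bits of each 6-bit code.)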
+ // Write quantized atom to send_buffer
+ // note: only the group leader stores the scale
+ uint8_t* atom_ptr =
+ reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
+ uint32_t* q4w_ptr = reinterpret_cast<uint32_t*>(atom_ptr) + thread;
+ uint16_t* q2w_ptr =
+ reinterpret_cast<uint16_t*>(atom_ptr + kRankTileQ2Offset) + thread;
+ int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
+ (thread / 8);
+
+ __builtin_nontemporal_store(q4w, q4w_ptr);
+ __builtin_nontemporal_store(q2w, q2w_ptr);
+ if (threadIdx.x == group_leader) {
+ __builtin_nontemporal_store(decoding_scale, qs_ptr);
+ }
+ }
+ }
+
+ __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer,
+ int32x4_t* __restrict__ data) {
+ for (int k = 0; k < kRankAtoms; k++) {
+ // Directly read quantized atom from recv_buffer
+ uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
+ uint32_t* q4w_ptr = reinterpret_cast<uint32_t*>(atom_ptr) + thread;
+ uint16_t* q2w_ptr =
+ reinterpret_cast<uint16_t*>(atom_ptr + kRankTileQ2Offset) + thread;
+ int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
+ (thread / 8);
+
+ uint32_t q4w = __builtin_nontemporal_load(q4w_ptr);
+ uint16_t q2w = __builtin_nontemporal_load(q2w_ptr);
+ int qs = __builtin_nontemporal_load(qs_ptr);
+
+ *recv_buffer += kRankBufferTileStride;
+
+ // Unpack q6 into fp16x8_t
+ int32x4_t w;
+ {
+ static uint constexpr kMask000F = 0x000F000F;
+ static uint constexpr kHalf2_1024 =
+ 0x64006400; // {1024.0, 1024.0}, fp16x2_t
+ static uint constexpr kHalf2_1056 =
+ 0xE420E420; // {-1056.0, -1056.0}, fp16x2_t
+
+#pragma unroll
+ for (int i = 0; i < 4; i++) {
+ int32_t q4 = q4w & kMask000F;
+ int32_t q2 = (q2w & 0x3) | ((q2w & 0xC) << 14);
+ q4w >>= 4;
+ q2w >>= 4;
+ if constexpr (std::is_same<T, half>::value) {
+ int32_t q6 = q4 | (q2 << 4) | kHalf2_1024;
+ asm volatile("v_pk_add_f16 %0, %1, %2"
+ : "=v"(w[i])
+ : "v"(q6), "v"(kHalf2_1056));
+ } else {
+ int32_t int16_2 = q4 | (q2 << 4);
+ int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
+ int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
+
+ nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
+ nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
+ nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
+ int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
+ w[i] = packed_add(packed_bf16, kRangeMin);
+ }
+ }
+ }
+
+ // Apply decoding scales
+ for (int i = 0; i < 4; i++) {
+ w[i] = packed_mul(w[i], qs);
+ }
+
+ // That's pretty much it...
+ data[k] = w;
+ }
+ }
+};
+
+// Int8 symmetric quantization codec.
+// We quantize the FP16 data to block-scaled Int8 in blocks of 4 *
+// kThreadGroupSize.
+template <typename T, int world_size>
+struct CodecQ8 : public CodecBase {
+ static constexpr int kWorldSize = world_size;
+
+ // Codec tile size processed by this workgroup.
+ // Each thread processes a fragment of f16x8_t (16B),
+ // into an int8x8_t (8B) and an f16 scale shared among 32 values.
+ static constexpr int kRankAtoms = kAtoms / kWorldSize;
+ static constexpr int kRankTileStride = 2176;
+ static constexpr int kRankTileScaleOffset = 2048;
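+ // (Layout of one transmitted atom, assuming a 256-thread workgroup:
+ // 2048B of packed q8 data (8B per thread) + 128B of scales = 2176B =
+ // kRankTileStride.)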
+ static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
+ static_assert(kRankTransmittedTileSize % 16 == 0,
+ "kRankTileSize must be 16B aligned.");
+
+ static constexpr int kRankBufferTileStride =
+ kRankTileStride / sizeof(int32x4_t);
+
+ // Total tile size for the collective communication.
+ static constexpr int kTransmittedTileSize =
+ kRankTransmittedTileSize * kWorldSize;
+
+ // Constants configuration
+
+ // {-1/128.0h, -1/128.0h}, f16x2_t
+ static constexpr int kScaleFactor =
+ std::is_same<T, half>::value ? 0xA000A000 : 0xBC00BC00;
+
+ // {1e-7, 1e-7}, f16x2_t
+ static constexpr int kScaleEpsilon =
+ std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
+
+ // {-128, -128}, f16x2_t
+ static constexpr int kRangeMin =
+ std::is_same<T, half>::value ? 0xD800D800 : 0xC300C300;
+ // {+127, +127}, f16x2_t
+ static constexpr int kRangeMax =
+ std::is_same<T, half>::value ? 0x57F057F0 : 0x42FE42FE;
+
+ // {+128, +128}, int16x2_t
+ static constexpr int kRangeBias = 0x00800080;
+
+ __quickreduce_device_inline__ CodecQ8(int thread, int rank)
+ : CodecBase(thread, rank) {}
+
+ __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
+ int32x4_t const* __restrict__ data) {
+ for (int k = 0; k < kRankAtoms; k++) {
+ int32x4_t const atom = data[k];
+ // Compute the absolute maximum of the atom in the thread group
+ // In 2 blocks of values, upper/lower halves of the f16x2_t
+ int wblockmax = group_abs_max(atom);
+
+ // Derive scales
+ int decoding_scale;
+ int encoding_scale;
+ decoding_scale = packed_mul(wblockmax, kScaleFactor);
+ encoding_scale = packed_add(decoding_scale, kScaleEpsilon);
+ encoding_scale = packed_rcp(encoding_scale);
+
+ // Apply scales to get quantized values
+ int32x4_t w;
+ for (int i = 0; i < 4; i++) {
+ w[i] = packed_mul(atom[i], encoding_scale);
+ w[i] = packed_max(w[i], kRangeMin);
+ w[i] = packed_min(w[i], kRangeMax);
+ }
+
+ // Convert from f16x2_t to uint16x2_t
+ int32x4_t q;
+ {
+ int16_t* qi = reinterpret_cast<int16_t*>(&q);
+ T* wh = reinterpret_cast<T*>(&w);
+ for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
+
+ for (int i = 0; i < 4; i++) {
+ q[i] = packed_add(q[i], kRangeBias);
+ }
+ }
+
+ // Pack 8 x q8 into int32x2_t
+ int32x2_t qw;
+ qw[0] = q[0] | (q[1] << 8);
+ qw[1] = q[2] | (q[3] << 8);
+
+ // Write quantized atom to send_buffer
+ // note: only the group leader stores the scale
+ uint8_t* atom_ptr =
+ reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
+ int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
+ int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
+ (thread / 8);
+
+ __builtin_nontemporal_store(qw, qw_ptr);
+ if (threadIdx.x == group_leader) {
+ __builtin_nontemporal_store(decoding_scale, qs_ptr);
+ }
+ }
+ }
+
+ __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer,
+ int32x4_t* __restrict__ data) {
+ for (int k = 0; k < kRankAtoms; k++) {
+ // Directly read quantized atom from recv_buffer
+ uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
+ int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
+ int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
+ (thread / 8);
+
+ int32x2_t qw = __builtin_nontemporal_load(qw_ptr);
+ int qs = __builtin_nontemporal_load(qs_ptr);
+
+ *recv_buffer += kRankBufferTileStride;
+
+ // Unpack q8 into fp16x8_t
+ int32x4_t w;
+ {
+ static uint constexpr kMask00FF = 0x00FF00FF;
+
+ // {1024.0, 1024.0}, fp16x2_t
+ static uint constexpr kHalf2_1024 = 0x64006400;
+
+ // {-1152.0, -1152.0}, fp16x2_t
+ static uint constexpr kHalf2_1152 = 0xE480E480;
+
+#pragma unroll
+ for (int i = 0; i < 4; i++) {
+ if constexpr (std::is_same<T, half>::value) {
+ int32_t q8 =
+ ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024;
+ w[i] = packed_add(q8, kHalf2_1152);
+ } else {
+ int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF;
+ int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
+ int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
+ nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
+ nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
+ nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
+ int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
+ w[i] = packed_add(packed_bf16, kRangeMin);
+ }
+ }
+ }
+
+ // Apply decoding scales
+ for (int i = 0; i < 4; i++) {
+ w[i] = packed_mul(w[i], qs);
+ }
+
+ data[k] = w;
+ }
+ }
+};
+
+// Twoshot All Reduce
+template <typename T, typename Codec, bool cast_bf2half>
+struct AllReduceTwoshot {
+ static_assert(sizeof(T) == 2);
+
+ static constexpr int kWorldSize = Codec::kWorldSize;
+
+ __device__ static void run(
+ T const* __restrict__ input, T* __restrict__ output,
+ uint32_t const N, // number of elements
+ int const block, // block index
+ int const rank, // rank index
+ uint8_t** __restrict__ buffer_list, // communication buffers
+ uint32_t const data_offset, // offset to start of the data buffer
+ uint32_t flag_color) {
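+ // Two-shot schedule, as implemented below: Phase-1 is a reduce-scatter
+ // (each rank receives and reduces its own 1/kWorldSize segment), and
+ // Phase-2 is an all-gather of the reduced segments. flag_color tags this
+ // invocation's sync flags so ranks do not consume flags left over from a
+ // previous call.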
+ // Topology
+ int thread = threadIdx.x + threadIdx.y * kWavefront;
+ uint8_t* rank_buffer = buffer_list[rank];
+ Codec codec(thread, rank);
+ int block_id = blockIdx.x;
+ int grid_size = gridDim.x;
+ // --------------------------------------------------------
+ // Read input into registers
+ int32x4_t tA[kAtoms];
+
+ BufferResource src_buffer(const_cast<T*>(input), N * sizeof(T));
+ uint32_t src_offset = block * kTileSize + thread * sizeof(int32x4_t);
+
+ for (int i = 0; i < kAtoms; i++) {
+ tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0);
+ src_offset += kAtomStride * sizeof(int32x4_t);
+ if constexpr (cast_bf2half) {
+ const nv_bfloat162* bf_buf =
+ reinterpret_cast<const nv_bfloat162*>(&tA[i]);
+ half2 half_buf[4];
+#pragma unroll
+ for (int j = 0; j < 4; ++j) {
+ float2 f = __bfloat1622float2(bf_buf[j]);
+ half_buf[j] = __float22half2_rn(f);
+ }
+ tA[i] = *reinterpret_cast<int32x4_t*>(half_buf);
+ }
+ }
+
+ // --------------------------------------------------------
+ // Phase-1A: Write segment data into the communication buffer of the target
+ // rank responsible for this segment.
+ uint32_t comm_data0_offset =
+ data_offset + block_id * Codec::kTransmittedTileSize;
+ uint32_t comm_data1_offset =
+ grid_size * Codec::kTransmittedTileSize + comm_data0_offset;
+
+ uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t));
+ uint32_t comm_flags1_offset =
+ grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset;
+
+ for (int r = 0; r < kWorldSize; r++) {
+ int32x4_t* send_buffer =
+ reinterpret_cast<int32x4_t*>(buffer_list[r] + comm_data0_offset +
+ rank * Codec::kRankTransmittedTileSize);
+ codec.send(send_buffer, &tA[r * Codec::kRankAtoms]);
+ }
+
+ __syncthreads();
+ if (thread < kWorldSize) {
+ int r = thread;
+ uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(
+ buffer_list[r] + comm_flags0_offset + rank * sizeof(uint32_t));
+ set_sync_flag(flag_ptr, flag_color);
+ }
+ // --------------------------------------------------------
+ // Phase-1B: Reduce the segment data from the communication buffers.
+ int32x4_t tR[Codec::kRankAtoms] = {};
+ {
+ // Read the data from the communication buffer.
+ int32x4_t* recv_buffer =
+ reinterpret_cast<int32x4_t*>(rank_buffer + comm_data0_offset);
+ uint32_t* flag_ptr =
+ reinterpret_cast<uint32_t*>(rank_buffer + comm_flags0_offset);
+
+ for (int r = 0; r < kWorldSize; r++) {
+ // Wait for the flags to be set.
+ if (thread == 0) {
+ wait_sync_flag(&flag_ptr[r], flag_color);
+ }
+ __syncthreads();
+
+ // note: we reuse tA as temp buffer here
+ codec.recv(&recv_buffer, tA);
+
+ for (int i = 0; i < Codec::kRankAtoms; i++) {
+ packed_assign_add(&tR[i], &tA[i]);
+ }
+ }
+ }
+
+ // Phase-2: Write the reduced segment to every other rank
+ for (int r = 0; r < kWorldSize; r++) {
+ int32x4_t* send_buffer =
+ reinterpret_cast<int32x4_t*>(buffer_list[r] + comm_data1_offset +
+ rank * Codec::kRankTransmittedTileSize);
+ codec.send(send_buffer, tR);
+ }
+
+ __syncthreads();
+ if (thread < kWorldSize) {
+ int r = thread;
+ uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(
+ buffer_list[r] + comm_flags1_offset + rank * sizeof(uint32_t));
+ set_sync_flag(flag_ptr, flag_color);
+ }
+
+ // Phase-2: Read the gather segments from the rank's communication buffer.
+ {
+ // Read the data from the communication buffer.
+ int32x4_t* recv_buffer =
+ reinterpret_cast<int32x4_t*>(rank_buffer + comm_data1_offset);
+ uint32_t* flag_ptr =
+ reinterpret_cast<uint32_t*>(rank_buffer + comm_flags1_offset);
+
+ for (int r = 0; r < kWorldSize; r++) {
+ // Wait for the flags to be set.
+ if (thread == 0) {
+ wait_sync_flag(&flag_ptr[r], flag_color);
+ }
+ __syncthreads();
+
+ // Gather all reduced and final rank segments into tA.
+ codec.recv(&recv_buffer, &tA[r * Codec::kRankAtoms]);
+ }
+ }
+
+ // --------------------------------------------------------
+ // Write the result to output.
+ BufferResource dst_buffer(output, N * sizeof(T));
+ uint32_t dst_offset = block * kTileSize + thread * sizeof(int32x4_t);
+
+ for (int i = 0; i < kAtoms; i++) {
+ if constexpr (cast_bf2half) {
+ const half2* half_buf = reinterpret_cast<const half2*>(&tA[i]);
+ nv_bfloat162 bf16_buf[4];
+#pragma unroll
+ for (int j = 0; j < 4; ++j) {
+ float2 f = __half22float2(half_buf[j]);
+ bf16_buf[j] = __float22bfloat162_rn(f);
+ }
+ buffer_store_dwordx4(*reinterpret_cast<int32x4_t*>(bf16_buf),
+ dst_buffer.descriptor, dst_offset, 0, 0);
+ } else {
+ buffer_store_dwordx4(tA[i], dst_buffer.descriptor, dst_offset, 0, 0);
+ }
+ dst_offset += kAtomStride * sizeof(int32x4_t);
+ }
+ }
+};
+
+} // namespace quickreduce
\ No newline at end of file
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 40a14678ad87..d2932d4fa019 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -729,6 +729,24 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle);
custom_ar.def("free_shared_buffer", &free_shared_buffer);
+#ifdef USE_ROCM
+ // Quick Reduce all-reduce kernels
+ custom_ar.def(
+ "qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
+ "cast_bf2half) -> ()");
+ custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
+
+ custom_ar.def("init_custom_qr", &init_custom_qr);
+ custom_ar.def("qr_destroy", &qr_destroy);
+
+ custom_ar.def("qr_get_handle", &qr_get_handle);
+
+ custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
+ custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles);
+
+ // Max input size in bytes
+ custom_ar.def("qr_max_size", &qr_max_size);
+#endif
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
diff --git a/docker/Dockerfile b/docker/Dockerfile
index cf9c245a9517..a71b052bfca2 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -6,30 +6,106 @@
# docs/assets/contributing/dockerfile-stages-dependency.png
ARG CUDA_VERSION=12.8.1
+ARG PYTHON_VERSION=3.12
+
+# By parameterizing the base images, we allow third parties to use their own
+# base images. One use case is hermetic builds with base images stored in
+# private registries that use different repository naming conventions.
+#
+# Example:
+# docker build --build-arg BUILD_BASE_IMAGE=registry.acme.org/mirror/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
+ARG BUILD_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
+ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
+
+# By parameterizing the Deadsnakes repository URL, we allow third parties to
+# use their own mirror. When doing so, we don't benefit from the transparent
+# installation of the GPG key of the PPA, as done by add-apt-repository, so we
+# also need a URL for the GPG key.
+ARG DEADSNAKES_MIRROR_URL
+ARG DEADSNAKES_GPGKEY_URL
+
+# The PyPA get-pip.py script is a self-contained script+zip file that provides
+# both the installer script and the pip base85-encoded zip archive. This allows
+# bootstrapping pip in environments where a distribution package does not exist.
+#
+# By parameterizing the URL for the get-pip.py installation script, we allow
+# third parties to use their own copy of the script stored in a private mirror.
+# We set the default value to the PyPA-owned get-pip.py script.
+#
+# Reference: https://pip.pypa.io/en/stable/installation/#get-pip-py
+ARG GET_PIP_URL="https://bootstrap.pypa.io/get-pip.py"
+
+# PIP supports fetching packages from custom indexes, allowing third parties
+# to host packages in private mirrors. The PIP_INDEX_URL and
+# PIP_EXTRA_INDEX_URL are standard PIP environment variables used to override
+# the default indexes. By leaving them empty by default, PIP will use its
+# default indexes unless the build process overrides them.
+#
+# uv uses different variables. We set them by default to the same values as
+# PIP, but they can be overridden.
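+#
+# Example (hypothetical private mirror URL):
+# docker build --build-arg PIP_INDEX_URL=https://pypi.example.org/simple .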
+ARG PIP_INDEX_URL
+ARG PIP_EXTRA_INDEX_URL
+ARG UV_INDEX_URL=${PIP_INDEX_URL}
+ARG UV_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}
+
+# PyTorch provides its own indexes for standard and nightly builds
+ARG PYTORCH_CUDA_INDEX_BASE_URL=https://download.pytorch.org/whl
+ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL=https://download.pytorch.org/whl/nightly
+
+# PIP supports multiple authentication schemes, including keyring.
+# By parameterizing the PIP_KEYRING_PROVIDER variable and setting it to
+# disabled by default, we allow third parties to use keyring authentication
+# for their private Python indexes, while keeping the default behavior of no
+# authentication.
+#
+# Reference: https://pip.pypa.io/en/stable/topics/authentication/#keyring-support
+ARG PIP_KEYRING_PROVIDER=disabled
+ARG UV_KEYRING_PROVIDER=${PIP_KEYRING_PROVIDER}
+
#################### BASE BUILD IMAGE ####################
# prepare basic build environment
-FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
-ARG CUDA_VERSION=12.8.1
-ARG PYTHON_VERSION=3.12
+FROM ${BUILD_BASE_IMAGE} AS base
+ARG CUDA_VERSION
+ARG PYTHON_VERSION
ARG TARGETPLATFORM
ENV DEBIAN_FRONTEND=noninteractive
+ARG DEADSNAKES_MIRROR_URL
+ARG DEADSNAKES_GPGKEY_URL
+ARG GET_PIP_URL
+
# Install Python and other dependencies
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt-get update -y \
&& apt-get install -y ccache software-properties-common git curl sudo \
- && for i in 1 2 3; do \
- add-apt-repository -y ppa:deadsnakes/ppa && break || \
- { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
- done \
+ && if [ ! -z ${DEADSNAKES_MIRROR_URL} ] ; then \
+ if [ ! -z "${DEADSNAKES_GPGKEY_URL}" ] ; then \
+ mkdir -p -m 0755 /etc/apt/keyrings ; \
+ curl -L ${DEADSNAKES_GPGKEY_URL} | gpg --dearmor > /etc/apt/keyrings/deadsnakes.gpg ; \
+ sudo chmod 644 /etc/apt/keyrings/deadsnakes.gpg ; \
+ echo "deb [signed-by=/etc/apt/keyrings/deadsnakes.gpg] ${DEADSNAKES_MIRROR_URL} $(lsb_release -cs) main" > /etc/apt/sources.list.d/deadsnakes.list ; \
+ fi ; \
+ else \
+ for i in 1 2 3; do \
+ add-apt-repository -y ppa:deadsnakes/ppa && break || \
+ { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
+ done ; \
+ fi \
&& apt-get update -y \
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
- && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
+ && curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \
&& python3 --version && python3 -m pip --version
+
+ARG PIP_INDEX_URL UV_INDEX_URL
+ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
+ARG PYTORCH_CUDA_INDEX_BASE_URL
+ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL
+ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER
+
# Install uv for faster pip installs
RUN --mount=type=cache,target=/root/.cache/uv \
python3 -m pip install uv
@@ -63,21 +139,25 @@ WORKDIR /workspace
# after this step
RUN --mount=type=cache,target=/root/.cache/uv \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
- uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 "torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319"; \
- uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 --pre pytorch_triton==3.3.0+gitab727c40; \
+ uv pip install --system \
+ --index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
+ "torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319"; \
+ uv pip install --system \
+ --index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
+ --pre pytorch_triton==3.3.0+gitab727c40; \
fi
COPY requirements/common.txt requirements/common.txt
COPY requirements/cuda.txt requirements/cuda.txt
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/cuda.txt \
- --extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
+ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
# cuda arch list used by torch
# can be useful for both `dev` and `test`
# explicitly set the list to avoid issues with torch 2.2
# see https://github.com/pytorch/pytorch/pull/123243
-ARG torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0+PTX'
+ARG torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0 12.0'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
# Override the arch list for flash-attn to reduce the binary size
ARG vllm_fa_cmake_gpu_arches='80-real;90-real'
@@ -88,6 +168,10 @@ ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches}
FROM base AS build
ARG TARGETPLATFORM
+ARG PIP_INDEX_URL UV_INDEX_URL
+ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
+ARG PYTORCH_CUDA_INDEX_BASE_URL
+
# install build dependencies
COPY requirements/build.txt requirements/build.txt
@@ -98,7 +182,7 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match"
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/build.txt \
- --extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
+ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
COPY . .
ARG GIT_REPO_CHECK=0
@@ -113,6 +197,8 @@ ARG nvcc_threads=8
ENV NVCC_THREADS=$nvcc_threads
ARG USE_SCCACHE
+ARG SCCACHE_DOWNLOAD_URL=https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz
+ARG SCCACHE_ENDPOINT
ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
ARG SCCACHE_REGION_NAME=us-west-2
ARG SCCACHE_S3_NO_CREDENTIALS=0
@@ -121,10 +207,11 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=.git,target=.git \
if [ "$USE_SCCACHE" = "1" ]; then \
echo "Installing sccache..." \
- && curl -L -o sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz \
+ && curl -L -o sccache.tar.gz ${SCCACHE_DOWNLOAD_URL} \
&& tar -xzf sccache.tar.gz \
&& sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \
&& rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \
+ && if [ ! -z ${SCCACHE_ENDPOINT} ] ; then export SCCACHE_ENDPOINT=${SCCACHE_ENDPOINT} ; fi \
&& export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \
&& export SCCACHE_REGION=${SCCACHE_REGION_NAME} \
&& export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \
@@ -162,6 +249,10 @@ RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
#################### DEV IMAGE ####################
FROM base as dev
+ARG PIP_INDEX_URL UV_INDEX_URL
+ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
+ARG PYTORCH_CUDA_INDEX_BASE_URL
+
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500
@@ -176,21 +267,25 @@ COPY requirements/test.txt requirements/test.txt
COPY requirements/dev.txt requirements/dev.txt
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/dev.txt \
- --extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
+ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
#################### DEV IMAGE ####################
#################### vLLM installation IMAGE ####################
# image with vLLM installed
# TODO: Restore to base image after FlashInfer AOT wheel fixed
-FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base
-ARG CUDA_VERSION=12.8.1
-ARG PYTHON_VERSION=3.12
+FROM ${FINAL_BASE_IMAGE} AS vllm-base
+ARG CUDA_VERSION
+ARG PYTHON_VERSION
WORKDIR /vllm-workspace
ENV DEBIAN_FRONTEND=noninteractive
ARG TARGETPLATFORM
SHELL ["/bin/bash", "-c"]
+ARG DEADSNAKES_MIRROR_URL
+ARG DEADSNAKES_GPGKEY_URL
+ARG GET_PIP_URL
+
RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment
@@ -200,17 +295,33 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& apt-get update -y \
&& apt-get install -y ccache software-properties-common git curl wget sudo vim python3-pip \
&& apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
- && for i in 1 2 3; do \
- add-apt-repository -y ppa:deadsnakes/ppa && break || \
- { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
- done \
+ && if [ ! -z ${DEADSNAKES_MIRROR_URL} ] ; then \
+ if [ ! -z "${DEADSNAKES_GPGKEY_URL}" ] ; then \
+ mkdir -p -m 0755 /etc/apt/keyrings ; \
+ curl -L ${DEADSNAKES_GPGKEY_URL} | gpg --dearmor > /etc/apt/keyrings/deadsnakes.gpg ; \
+ sudo chmod 644 /etc/apt/keyrings/deadsnakes.gpg ; \
+ echo "deb [signed-by=/etc/apt/keyrings/deadsnakes.gpg] ${DEADSNAKES_MIRROR_URL} $(lsb_release -cs) main" > /etc/apt/sources.list.d/deadsnakes.list ; \
+ fi ; \
+ else \
+ for i in 1 2 3; do \
+ add-apt-repository -y ppa:deadsnakes/ppa && break || \
+ { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
+ done ; \
+ fi \
&& apt-get update -y \
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
- && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
+ && curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \
&& python3 --version && python3 -m pip --version
+
+ARG PIP_INDEX_URL UV_INDEX_URL
+ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
+ARG PYTORCH_CUDA_INDEX_BASE_URL
+ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL
+ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER
+
# Install uv for faster pip installs
RUN --mount=type=cache,target=/root/.cache/uv \
python3 -m pip install uv
@@ -232,19 +343,23 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
# after this step
RUN --mount=type=cache,target=/root/.cache/uv \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
- uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 "torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319"; \
- uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 --pre pytorch_triton==3.3.0+gitab727c40; \
+ uv pip install --system \
+ --index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
+ "torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319" ; \
+ uv pip install --system \
+ --index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
+ --pre pytorch_triton==3.3.0+gitab727c40 ; \
fi
# Install vllm wheel first, so that torch etc will be installed.
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
--mount=type=cache,target=/root/.cache/uv \
uv pip install --system dist/*.whl --verbose \
- --extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
+ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
# If we need to build FlashInfer wheel before its release:
# $ # Note we remove 7.0 from the arch list compared to the list below, since FlashInfer only supports sm75+
-# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a'
+# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0'
# $ git clone https://github.com/flashinfer-ai/flashinfer.git --recursive
# $ cd flashinfer
# $ git checkout v0.2.6.post1
@@ -254,15 +369,20 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
# -rw-rw-r-- 1 mgoin mgoin 205M Jun 9 18:03 flashinfer_python-0.2.6.post1-cp39-abi3-linux_x86_64.whl
# $ # upload the wheel to a public location, e.g. https://wheels.vllm.ai/flashinfer/v0.2.6.post1/flashinfer_python-0.2.6.post1-cp39-abi3-linux_x86_64.whl
+# Allow specifying a version, Git revision or local .whl file
+ARG FLASHINFER_CUDA128_INDEX_URL="https://download.pytorch.org/whl/cu128/flashinfer"
+ARG FLASHINFER_CUDA128_WHEEL="flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl"
+ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
+ARG FLASHINFER_GIT_REF="v0.2.6.post1"
RUN --mount=type=cache,target=/root/.cache/uv \
. /etc/environment && \
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
# FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use
if [[ "$CUDA_VERSION" == 12.8* ]]; then \
- uv pip install --system https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl; \
+ uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL} ; \
else \
- export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a' && \
- git clone https://github.com/flashinfer-ai/flashinfer.git --single-branch --branch v0.2.6.post1 --recursive && \
+ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0' && \
+ git clone ${FLASHINFER_GIT_REPO} --single-branch --branch ${FLASHINFER_GIT_REF} --recursive && \
# Needed to build AOT kernels
(cd flashinfer && \
python3 -m flashinfer.aot && \
@@ -286,7 +406,7 @@ uv pip list
COPY requirements/build.txt requirements/build.txt
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/build.txt \
- --extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
+ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
#################### vLLM installation IMAGE ####################
@@ -297,6 +417,11 @@ FROM vllm-base AS test
ADD . /vllm-workspace/
+ARG PYTHON_VERSION
+
+ARG PIP_INDEX_URL UV_INDEX_URL
+ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
+
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500
@@ -307,7 +432,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4"
# install development dependencies (for testing)
-RUN --mount=type=cache,target=/root/.cache/uv \
+RUN --mount=type=cache,target=/root/.cache/uv \
CUDA_MAJOR="${CUDA_VERSION%%.*}"; \
if [ "$CUDA_MAJOR" -ge 12 ]; then \
uv pip install --system -r requirements/dev.txt; \
@@ -323,7 +448,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
ENV HF_HUB_ENABLE_HF_TRANSFER 1
# Copy in the v1 package for testing (it isn't distributed yet)
-COPY vllm/v1 /usr/local/lib/python3.12/dist-packages/vllm/v1
+COPY vllm/v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1
# doc requires source code
# we hide them inside `test_docs/` , so that this source code
@@ -340,6 +465,9 @@ RUN mv mkdocs.yaml test_docs/
FROM vllm-base AS vllm-openai-base
ARG TARGETPLATFORM
+ARG PIP_INDEX_URL UV_INDEX_URL
+ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
+
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500
diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu
index 3e9fa0e7af2d..13bd03c5696a 100644
--- a/docker/Dockerfile.cpu
+++ b/docker/Dockerfile.cpu
@@ -66,7 +66,7 @@ ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512}
WORKDIR /workspace/vllm
RUN --mount=type=cache,target=/root/.cache/uv \
- --mount=type=bind,src=requirements/build.txt,target=requirements/build.txt \
+ --mount=type=bind,src=requirements/cpu-build.txt,target=requirements/build.txt \
uv pip install -r requirements/build.txt
COPY . .
@@ -79,6 +79,22 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=.git,target=.git \
VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel
+######################### TEST DEPS #########################
+FROM base AS vllm-test-deps
+
+WORKDIR /workspace/vllm
+
+RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
+ cp requirements/test.in requirements/cpu-test.in && \
+ sed -i '/mamba_ssm/d' requirements/cpu-test.in && \
+ sed -i 's/torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \
+ sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \
+ sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \
+ uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu
+
+RUN --mount=type=cache,target=/root/.cache/uv \
+ uv pip install -r requirements/cpu-test.txt
+
######################### DEV IMAGE #########################
FROM vllm-build AS vllm-dev
@@ -97,28 +113,19 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=.git,target=.git \
VLLM_TARGET_DEVICE=cpu python3 setup.py develop
+COPY --from=vllm-test-deps /workspace/vllm/requirements/cpu-test.txt requirements/test.txt
+
RUN --mount=type=cache,target=/root/.cache/uv \
- --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
- cp requirements/test.in requirements/test-cpu.in && \
- sed -i '/mamba_ssm/d' requirements/test-cpu.in && \
- uv pip compile requirements/test-cpu.in -o requirements/test.txt && \
uv pip install -r requirements/dev.txt && \
pre-commit install --hook-type pre-commit --hook-type commit-msg
ENTRYPOINT ["bash"]
######################### TEST IMAGE #########################
-FROM base AS vllm-test
+FROM vllm-test-deps AS vllm-test
WORKDIR /workspace/
-RUN --mount=type=cache,target=/root/.cache/uv \
- --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
- cp requirements/test.in requirements/test-cpu.in && \
- sed -i '/mamba_ssm/d' requirements/test-cpu.in && \
- uv pip compile requirements/test-cpu.in -o requirements/cpu-test.txt && \
- uv pip install -r requirements/cpu-test.txt
-
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=vllm-build,src=/workspace/vllm/dist,target=dist \
uv pip install dist/*.whl
diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu
index 681102b9d18b..466ba9833363 100644
--- a/docker/Dockerfile.xpu
+++ b/docker/Dockerfile.xpu
@@ -35,6 +35,7 @@ RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi
ENV VLLM_TARGET_DEVICE=xpu
+ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,source=.git,target=.git \
diff --git a/docs/.nav.yml b/docs/.nav.yml
index a9c594c29177..e679807f7534 100644
--- a/docs/.nav.yml
+++ b/docs/.nav.yml
@@ -48,7 +48,12 @@ nav:
- General:
- glob: contributing/*
flatten_single_child_sections: true
- - Model Implementation: contributing/model
+ - Model Implementation:
+ - contributing/model/README.md
+ - contributing/model/basic.md
+ - contributing/model/registration.md
+ - contributing/model/tests.md
+ - contributing/model/multimodal.md
- Design Documents:
- V0: design
- V1: design/v1
diff --git a/docs/README.md b/docs/README.md
index 0c6aff5fa07c..9fb3137b3192 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -40,7 +40,7 @@ vLLM is flexible and easy to use with:
- OpenAI-compatible API server
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, IBM Power CPUs, TPU, and AWS Trainium and Inferentia Accelerators.
- Prefix caching support
-- Multi-lora support
+- Multi-LoRA support
For more information, check out the following:
diff --git a/docs/contributing/README.md b/docs/contributing/README.md
index c0c338b42695..83525436be13 100644
--- a/docs/contributing/README.md
+++ b/docs/contributing/README.md
@@ -151,6 +151,14 @@ the terms of the DCO.
Using `-s` with `git commit` will automatically add this header.
+!!! tip
+ You can enable automatic sign-off via your IDE:
+
+ - **PyCharm**: Click on the `Show Commit Options` icon to the right of the `Commit and Push...` button in the `Commit` window.
+ It will bring up a `git` window where you can modify the `Author` and enable `Sign-off commit`.
+ - **VSCode**: Open the [Settings editor](https://code.visualstudio.com/docs/configure/settings)
+ and enable the `Git: Always Sign Off` (`git.alwaysSignOff`) field.
+
### PR Title and Classification
Only specific types of PRs will be reviewed. The PR title is prefixed
diff --git a/docs/contributing/incremental_build.md b/docs/contributing/incremental_build.md
index 8efa34825eca..14c3aaead51e 100644
--- a/docs/contributing/incremental_build.md
+++ b/docs/contributing/incremental_build.md
@@ -1,4 +1,4 @@
-# Incremental Compilation Workflow for vLLM Development
+# Incremental Compilation Workflow
When working on vLLM's C++/CUDA kernels located in the `csrc/` directory, recompiling the entire project with `uv pip install -e .` for every change can be time-consuming. An incremental compilation workflow using CMake allows for faster iteration by only recompiling the necessary components after an initial setup. This guide details how to set up and use such a workflow, which complements your editable Python installation.
diff --git a/docs/contributing/model/README.md b/docs/contributing/model/README.md
index b7727f02c11b..63abb7991050 100644
--- a/docs/contributing/model/README.md
+++ b/docs/contributing/model/README.md
@@ -1,21 +1,23 @@
---
-title: Adding a New Model
+title: Summary
---
[](){ #new-model }
-This section provides more information on how to integrate a [PyTorch](https://pytorch.org/) model into vLLM.
+!!! important
+ Many decoder language models can now be automatically loaded using the [Transformers backend][transformers-backend] without having to implement them in vLLM. See if `vllm serve <model>` works first!
-Contents:
+vLLM models are specialized [PyTorch](https://pytorch.org/) models that take advantage of various [features][compatibility-matrix] to optimize their performance.
-- [Basic](basic.md)
-- [Registration](registration.md)
-- [Tests](tests.md)
-- [Multimodal](multimodal.md)
+The complexity of integrating a model into vLLM depends heavily on the model's architecture.
+The process is considerably easier if the model shares a similar architecture with an existing model in vLLM.
+However, this can be more complex for models that include new operators (e.g., a new attention mechanism).
-!!! note
- The complexity of adding a new model depends heavily on the model's architecture.
- The process is considerably straightforward if the model shares a similar architecture with an existing model in vLLM.
- However, for models that include new operators (e.g., a new attention mechanism), the process can be a bit more complex.
+Read through these pages for a step-by-step guide:
+
+- [Basic Model](basic.md)
+- [Registering a Model](registration.md)
+- [Unit Testing](tests.md)
+- [Multi-Modal Support](multimodal.md)
!!! tip
If you are encountering issues while integrating your model into vLLM, feel free to open a [GitHub issue](https://github.com/vllm-project/vllm/issues)
diff --git a/docs/contributing/model/basic.md b/docs/contributing/model/basic.md
index 644d21482ef6..d552cd06be20 100644
--- a/docs/contributing/model/basic.md
+++ b/docs/contributing/model/basic.md
@@ -1,5 +1,5 @@
---
-title: Implementing a Basic Model
+title: Basic Model
---
[](){ #new-model-basic }
diff --git a/docs/contributing/model/registration.md b/docs/contributing/model/registration.md
index a6dc1e32dfb9..758caa72cd4a 100644
--- a/docs/contributing/model/registration.md
+++ b/docs/contributing/model/registration.md
@@ -1,5 +1,5 @@
---
-title: Registering a Model to vLLM
+title: Registering a Model
---
[](){ #new-model-registration }
diff --git a/docs/contributing/model/tests.md b/docs/contributing/model/tests.md
index a8cb457453b9..c7bcc02a8b80 100644
--- a/docs/contributing/model/tests.md
+++ b/docs/contributing/model/tests.md
@@ -1,5 +1,5 @@
---
-title: Writing Unit Tests
+title: Unit Testing
---
[](){ #new-model-tests }
diff --git a/docs/deployment/frameworks/helm.md b/docs/deployment/frameworks/helm.md
index cff8af2c09d2..d929665e8a3d 100644
--- a/docs/deployment/frameworks/helm.md
+++ b/docs/deployment/frameworks/helm.md
@@ -5,9 +5,9 @@ title: Helm
A Helm chart to deploy vLLM for Kubernetes
-Helm is a package manager for Kubernetes. It will help you to deploy vLLM on k8s and automate the deployment of vLLM Kubernetes applications. With Helm, you can deploy the same framework architecture with different configurations to multiple namespaces by overriding variable values.
+Helm is a package manager for Kubernetes. It helps automate the deployment of vLLM applications on Kubernetes. With Helm, you can deploy the same framework architecture with different configurations to multiple namespaces by overriding variable values.
-This guide will walk you through the process of deploying vLLM with Helm, including the necessary prerequisites, steps for helm installation and documentation on architecture and values file.
+This guide will walk you through the process of deploying vLLM with Helm, including the necessary prerequisites, the steps for Helm installation, and documentation on the architecture and values file.
## Prerequisites
@@ -16,17 +16,23 @@ Before you begin, ensure that you have the following:
- A running Kubernetes cluster
- NVIDIA Kubernetes Device Plugin (`k8s-device-plugin`): This can be found at [https://github.com/NVIDIA/k8s-device-plugin](https://github.com/NVIDIA/k8s-device-plugin)
- Available GPU resources in your cluster
-- S3 with the model which will be deployed
+- An S3 bucket containing the model to be deployed
## Installing the chart
To install the chart with the release name `test-vllm`:
```bash
-helm upgrade --install --create-namespace --namespace=ns-vllm test-vllm . -f values.yaml --set secrets.s3endpoint=$ACCESS_POINT --set secrets.s3bucketname=$BUCKET --set secrets.s3accesskeyid=$ACCESS_KEY --set secrets.s3accesskey=$SECRET_KEY
+helm upgrade --install --create-namespace \
+ --namespace=ns-vllm test-vllm . \
+ -f values.yaml \
+ --set secrets.s3endpoint=$ACCESS_POINT \
+ --set secrets.s3bucketname=$BUCKET \
+ --set secrets.s3accesskeyid=$ACCESS_KEY \
+ --set secrets.s3accesskey=$SECRET_KEY
```
-## Uninstalling the Chart
+## Uninstalling the chart
To uninstall the `test-vllm` deployment:
@@ -39,57 +45,59 @@ chart **including persistent volumes** and deletes the release.
## Architecture
-
+
## Values
-| Key | Type | Default | Description |
-|--------------------------------------------|---------|----------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------|
-| autoscaling | object | {"enabled":false,"maxReplicas":100,"minReplicas":1,"targetCPUUtilizationPercentage":80} | Autoscaling configuration |
-| autoscaling.enabled | bool | false | Enable autoscaling |
-| autoscaling.maxReplicas | int | 100 | Maximum replicas |
-| autoscaling.minReplicas | int | 1 | Minimum replicas |
-| autoscaling.targetCPUUtilizationPercentage | int | 80 | Target CPU utilization for autoscaling |
-| configs | object | {} | Configmap |
-| containerPort | int | 8000 | Container port |
-| customObjects | list | [] | Custom Objects configuration |
-| deploymentStrategy | object | {} | Deployment strategy configuration |
-| externalConfigs | list | [] | External configuration |
-| extraContainers | list | [] | Additional containers configuration |
-| extraInit | object | {"pvcStorage":"1Gi","s3modelpath":"relative_s3_model_path/opt-125m", "awsEc2MetadataDisabled": true} | Additional configuration for the init container |
-| extraInit.pvcStorage | string | "50Gi" | Storage size of the s3 |
-| extraInit.s3modelpath | string | "relative_s3_model_path/opt-125m" | Path of the model on the s3 which hosts model weights and config files |
-| extraInit.awsEc2MetadataDisabled | boolean | true | Disables the use of the Amazon EC2 instance metadata service |
-| extraPorts | list | [] | Additional ports configuration |
-| gpuModels | list | ["TYPE_GPU_USED"] | Type of gpu used |
-| image | object | {"command":["vllm","serve","/data/","--served-model-name","opt-125m","--host","0.0.0.0","--port","8000"],"repository":"vllm/vllm-openai","tag":"latest"} | Image configuration |
-| image.command | list | ["vllm","serve","/data/","--served-model-name","opt-125m","--host","0.0.0.0","--port","8000"] | Container launch command |
-| image.repository | string | "vllm/vllm-openai" | Image repository |
-| image.tag | string | "latest" | Image tag |
-| livenessProbe | object | {"failureThreshold":3,"httpGet":{"path":"/health","port":8000},"initialDelaySeconds":15,"periodSeconds":10} | Liveness probe configuration |
-| livenessProbe.failureThreshold | int | 3 | Number of times after which if a probe fails in a row, Kubernetes considers that the overall check has failed: the container is not alive |
-| livenessProbe.httpGet | object | {"path":"/health","port":8000} | Configuration of the Kubelet http request on the server |
-| livenessProbe.httpGet.path | string | "/health" | Path to access on the HTTP server |
-| livenessProbe.httpGet.port | int | 8000 | Name or number of the port to access on the container, on which the server is listening |
-| livenessProbe.initialDelaySeconds | int | 15 | Number of seconds after the container has started before liveness probe is initiated |
-| livenessProbe.periodSeconds | int | 10 | How often (in seconds) to perform the liveness probe |
-| maxUnavailablePodDisruptionBudget | string | "" | Disruption Budget Configuration |
-| readinessProbe | object | {"failureThreshold":3,"httpGet":{"path":"/health","port":8000},"initialDelaySeconds":5,"periodSeconds":5} | Readiness probe configuration |
-| readinessProbe.failureThreshold | int | 3 | Number of times after which if a probe fails in a row, Kubernetes considers that the overall check has failed: the container is not ready |
-| readinessProbe.httpGet | object | {"path":"/health","port":8000} | Configuration of the Kubelet http request on the server |
-| readinessProbe.httpGet.path | string | "/health" | Path to access on the HTTP server |
-| readinessProbe.httpGet.port | int | 8000 | Name or number of the port to access on the container, on which the server is listening |
-| readinessProbe.initialDelaySeconds | int | 5 | Number of seconds after the container has started before readiness probe is initiated |
-| readinessProbe.periodSeconds | int | 5 | How often (in seconds) to perform the readiness probe |
-| replicaCount | int | 1 | Number of replicas |
-| resources | object | {"limits":{"cpu":4,"memory":"16Gi","nvidia.com/gpu":1},"requests":{"cpu":4,"memory":"16Gi","nvidia.com/gpu":1}} | Resource configuration |
-| resources.limits."nvidia.com/gpu" | int | 1 | Number of gpus used |
-| resources.limits.cpu | int | 4 | Number of CPUs |
-| resources.limits.memory | string | "16Gi" | CPU memory configuration |
-| resources.requests."nvidia.com/gpu" | int | 1 | Number of gpus used |
-| resources.requests.cpu | int | 4 | Number of CPUs |
-| resources.requests.memory | string | "16Gi" | CPU memory configuration |
-| secrets | object | {} | Secrets configuration |
-| serviceName | string | Service name | |
-| servicePort | int | 80 | Service port |
-| labels.environment | string | test | Environment name |
+The following table describes configurable parameters of the chart in `values.yaml`:
+
+| Key | Type | Default | Description |
+|-----|------|---------|-------------|
+| autoscaling | object | {"enabled":false,"maxReplicas":100,"minReplicas":1,"targetCPUUtilizationPercentage":80} | Autoscaling configuration |
+| autoscaling.enabled | bool | false | Enable autoscaling |
+| autoscaling.maxReplicas | int | 100 | Maximum replicas |
+| autoscaling.minReplicas | int | 1 | Minimum replicas |
+| autoscaling.targetCPUUtilizationPercentage | int | 80 | Target CPU utilization for autoscaling |
+| configs | object | {} | Configmap |
+| containerPort | int | 8000 | Container port |
+| customObjects | list | [] | Custom Objects configuration |
+| deploymentStrategy | object | {} | Deployment strategy configuration |
+| externalConfigs | list | [] | External configuration |
+| extraContainers | list | [] | Additional containers configuration |
+| extraInit | object | {"pvcStorage":"1Gi","s3modelpath":"relative_s3_model_path/opt-125m", "awsEc2MetadataDisabled": true} | Additional configuration for the init container |
+| extraInit.pvcStorage | string | "1Gi" | Storage size of the s3 |
+| extraInit.s3modelpath | string | "relative_s3_model_path/opt-125m" | Path of the model on the s3 which hosts model weights and config files |
+| extraInit.awsEc2MetadataDisabled | boolean | true | Disables the use of the Amazon EC2 instance metadata service |
+| extraPorts | list | [] | Additional ports configuration |
+| gpuModels | list | ["TYPE_GPU_USED"] | Type of gpu used |
+| image | object | {"command":["vllm","serve","/data/","--served-model-name","opt-125m","--host","0.0.0.0","--port","8000"],"repository":"vllm/vllm-openai","tag":"latest"} | Image configuration |
+| image.command | list | ["vllm","serve","/data/","--served-model-name","opt-125m","--host","0.0.0.0","--port","8000"] | Container launch command |
+| image.repository | string | "vllm/vllm-openai" | Image repository |
+| image.tag | string | "latest" | Image tag |
+| livenessProbe | object | {"failureThreshold":3,"httpGet":{"path":"/health","port":8000},"initialDelaySeconds":15,"periodSeconds":10} | Liveness probe configuration |
+| livenessProbe.failureThreshold | int | 3 | Number of times after which if a probe fails in a row, Kubernetes considers that the overall check has failed: the container is not alive |
+| livenessProbe.httpGet | object | {"path":"/health","port":8000} | Configuration of the kubelet http request on the server |
+| livenessProbe.httpGet.path | string | "/health" | Path to access on the HTTP server |
+| livenessProbe.httpGet.port | int | 8000 | Name or number of the port to access on the container, on which the server is listening |
+| livenessProbe.initialDelaySeconds | int | 15 | Number of seconds after the container has started before liveness probe is initiated |
+| livenessProbe.periodSeconds | int | 10 | How often (in seconds) to perform the liveness probe |
+| maxUnavailablePodDisruptionBudget | string | "" | Disruption Budget Configuration |
+| readinessProbe | object | {"failureThreshold":3,"httpGet":{"path":"/health","port":8000},"initialDelaySeconds":5,"periodSeconds":5} | Readiness probe configuration |
+| readinessProbe.failureThreshold | int | 3 | Number of times after which if a probe fails in a row, Kubernetes considers that the overall check has failed: the container is not ready |
+| readinessProbe.httpGet | object | {"path":"/health","port":8000} | Configuration of the kubelet http request on the server |
+| readinessProbe.httpGet.path | string | "/health" | Path to access on the HTTP server |
+| readinessProbe.httpGet.port | int | 8000 | Name or number of the port to access on the container, on which the server is listening |
+| readinessProbe.initialDelaySeconds | int | 5 | Number of seconds after the container has started before readiness probe is initiated |
+| readinessProbe.periodSeconds | int | 5 | How often (in seconds) to perform the readiness probe |
+| replicaCount | int | 1 | Number of replicas |
+| resources | object | {"limits":{"cpu":4,"memory":"16Gi","nvidia.com/gpu":1},"requests":{"cpu":4,"memory":"16Gi","nvidia.com/gpu":1}} | Resource configuration |
+| resources.limits."nvidia.com/gpu" | int | 1 | Number of GPUs used |
+| resources.limits.cpu | int | 4 | Number of CPUs |
+| resources.limits.memory | string | "16Gi" | CPU memory configuration |
+| resources.requests."nvidia.com/gpu" | int | 1 | Number of GPUs used |
+| resources.requests.cpu | int | 4 | Number of CPUs |
+| resources.requests.memory | string | "16Gi" | CPU memory configuration |
+| secrets | object | {} | Secrets configuration |
+| serviceName | string | "" | Service name |
+| servicePort | int | 80 | Service port |
+| labels.environment | string | test | Environment name |
diff --git a/docs/design/v1/prefix_caching.md b/docs/design/v1/prefix_caching.md
index e87e4c6a48b7..2d3c8412894a 100644
--- a/docs/design/v1/prefix_caching.md
+++ b/docs/design/v1/prefix_caching.md
@@ -117,8 +117,8 @@ There are two design points to highlight:
1. We allocate all KVCacheBlock when initializing the KV cache manager to be a block pool. This avoids Python object creation overheads and can easily track all blocks all the time.
2. We introduce doubly linked list pointers directly in the KVCacheBlock, so that we could construct a free queue directly. This gives us two benefits:
- 1. We could have O(1) complexity moving elements in the middle to the tail.
- 2. We could avoid introducing another Python queue (e.g., `deque`) which has a wrapper to the elements.
+ 1. We could have O(1) complexity moving elements in the middle to the tail.
+ 2. We could avoid introducing another Python queue (e.g., `deque`) which has a wrapper to the elements.
As a result, we will have the following components when the KV cache manager is initialized:
@@ -135,19 +135,19 @@ As a result, we will have the following components when the KV cache manager is
**New request:** Workflow for the scheduler to schedule a new request with KV cache block allocation:
-1. The scheduler calls `kv_cache_manager.get_computed_blocks()` to get a sequence of blocks that have already been computed. This is done by hashing the prompt tokens in the request and looking up Cache Blocks.
+1. The scheduler calls `kv_cache_manager.get_computed_blocks()` to get a sequence of blocks that have already been computed. This is done by hashing the prompt tokens in the request and looking up cache blocks.
2. The scheduler calls `kv_cache_manager.allocate_slots()`. It does the following steps:
- 1. Compute the number of new required blocks, and return if there are no sufficient blocks to allocate.
- 2. "Touch" the computed blocks. It increases the reference count of the computed block by one, and removes the block from the free queue if the block wasn't used by other requests. This is to avoid these computed blocks being evicted. See the example in the next section for illustration.
- 3. Allocate new blocks by popping the heads of the free queue. If the head block is a cached block, this also "evicts" the block so that no other requests can reuse it anymore from now on.
- 4. If an allocated block is already full of tokens, we immediately add it to the Cache Block, so that the block can be reused by other requests in the same batch.
+ 1. Compute the number of new required blocks, and return if there are no sufficient blocks to allocate.
+ 2. "Touch" the computed blocks. It increases the reference count of the computed block by one, and removes the block from the free queue if the block wasn't used by other requests. This is to avoid these computed blocks being evicted. See the example in the next section for illustration.
+ 3. Allocate new blocks by popping the heads of the free queue. If the head block is a cached block, this also "evicts" the block so that no other requests can reuse it anymore from now on.
+ 4. If an allocated block is already full of tokens, we immediately add it to the cache block, so that the block can be reused by other requests in the same batch.
**Running request:** Workflow for the scheduler to schedule a running request with KV cache block allocation:
1. The scheduler calls `kv_cache_manager.allocate_slots()`. It does the following steps:
- 1. Compute the number of new required blocks, and return if there are no sufficient blocks to allocate.
- 2. Allocate new blocks by popping the heads of the free queue. If the head block is a cached block, this also "evicts" the block so that no other requests can reuse it anymore from now on.
- 3. Append token IDs to the slots in existing blocks as well as the new blocks. If a block is full, we add it to the Cache Block to cache it.
+ 1. Compute the number of new required blocks, and return if there are no sufficient blocks to allocate.
+ 2. Allocate new blocks by popping the heads of the free queue. If the head block is a cached block, this also "evicts" the block so that no other requests can reuse it anymore from now on.
+ 3. Append token IDs to the slots in existing blocks as well as the new blocks. If a block is full, we add it to the cache block to cache it.
**Duplicated blocks**
Assuming block size is 4 and you send a request (Request 1) with prompt ABCDEF and decoding length 3:
@@ -199,7 +199,7 @@ When a request is finished, we free all its blocks if no other requests are usin
When the head block (least recently used block) of the free queue is cached, we have to evict the block to prevent it from being used by other requests. Specifically, eviction involves the following steps:
1. Pop the block from the head of the free queue. This is the LRU block to be evicted.
-2. Remove the block ID from the Cache Block.
+2. Remove the block ID from the cache block.
3. Remove the block hash.
## Example
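
The allocation and eviction workflow described above reduces to a small amount of bookkeeping around the block pool and the free queue. Below is a minimal Python sketch of that behavior only; the names (`KVBlock`, `FreeQueueSketch`) and the use of an `OrderedDict` standing in for the intrusive doubly linked list are illustrative assumptions, not vLLM's actual implementation.

```python
# Minimal sketch of the block pool / free queue behavior described above.
# All names are illustrative; this is not vLLM's real KV cache manager.
from collections import OrderedDict


class KVBlock:
    def __init__(self, block_id: int):
        self.block_id = block_id
        self.ref_cnt = 0
        self.block_hash = None  # set once the block is full and cached


class FreeQueueSketch:
    """LRU free list; an OrderedDict stands in for the doubly linked list."""

    def __init__(self, num_blocks: int):
        # Block pool: every block is allocated up front.
        self.pool = [KVBlock(i) for i in range(num_blocks)]
        self.free = OrderedDict((b.block_id, b) for b in self.pool)
        self.cached = {}  # block_hash -> KVBlock

    def touch(self, block: KVBlock) -> None:
        """Reuse a computed block: bump its ref count, pull it off the free queue."""
        block.ref_cnt += 1
        self.free.pop(block.block_id, None)  # O(1) removal, even from the middle

    def allocate(self) -> KVBlock:
        """Pop the LRU head; if it was cached, evict it first."""
        _, block = self.free.popitem(last=False)  # 1. pop the LRU head
        if block.block_hash is not None:
            self.cached.pop(block.block_hash, None)  # 2. remove it from the cache map
            block.block_hash = None                  # 3. drop the block hash
        block.ref_cnt = 1
        return block

    def free_block(self, block: KVBlock) -> None:
        """Release a block; when its ref count reaches zero, append it to the tail."""
        block.ref_cnt -= 1
        if block.ref_cnt == 0:
            self.free[block.block_id] = block  # tail append keeps LRU ordering
```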
diff --git a/docs/mkdocs/javascript/slack_and_forum.js b/docs/mkdocs/javascript/slack_and_forum.js
new file mode 100644
index 000000000000..9a9233223836
--- /dev/null
+++ b/docs/mkdocs/javascript/slack_and_forum.js
@@ -0,0 +1,56 @@
+/**
+ * slack_and_forum.js
+ *
+ * Adds a custom Slack and Forum button to the MkDocs Material header.
+ *
+ */
+
+window.addEventListener('DOMContentLoaded', () => {
+ const headerInner = document.querySelector('.md-header__inner');
+
+ if (headerInner) {
+ const slackButton = document.createElement('button');
+ slackButton.className = 'slack-button';
+ slackButton.title = 'Join us on Slack';
+ slackButton.style.border = 'none';
+ slackButton.style.background = 'transparent';
+ slackButton.style.cursor = 'pointer';
+
+ slackButton.innerHTML = `
+
+ `;
+
+ slackButton.addEventListener('click', () => {
+ window.open('https://slack.vllm.ai', '_blank', 'noopener');
+ });
+
+ const forumButton = document.createElement('button');
+ forumButton.className = 'forum-button';
+ forumButton.title = 'Join the Forum';
+ forumButton.style.border = 'none';
+ forumButton.style.background = 'transparent';
+ forumButton.style.cursor = 'pointer';
+
+ forumButton.innerHTML = `
+
+ `;
+
+ forumButton.addEventListener('click', () => {
+ window.open('https://discuss.vllm.ai/', '_blank', 'noopener');
+ });
+
+ const githubSource = document.querySelector('.md-header__source');
+ if (githubSource) {
+ githubSource.parentNode.insertBefore(slackButton, githubSource.nextSibling);
+ githubSource.parentNode.insertBefore(forumButton, slackButton.nextSibling);
+ }
+ }
+});
diff --git a/docs/mkdocs/stylesheets/extra.css b/docs/mkdocs/stylesheets/extra.css
index 220657f83d5f..248711f491b9 100644
--- a/docs/mkdocs/stylesheets/extra.css
+++ b/docs/mkdocs/stylesheets/extra.css
@@ -108,3 +108,29 @@ body[data-md-color-scheme="slate"] .md-nav__item--section > label.md-nav__link .
.md-content__button-wrapper a:hover {
color: var(--md-accent-fg-color);
}
+
+/* Slack and Forum css */
+.slack-button,
+.forum-button {
+ display: inline-flex;
+ align-items: center;
+ justify-content: center;
+ margin-left: 0.4rem;
+ height: 24px;
+}
+
+.slack-button img {
+ height: 18px;
+ filter: none !important;
+}
+
+.slack-button:hover,
+.forum-button:hover {
+ opacity: 0.7;
+}
+
+.forum-button svg {
+ height: 28px;
+ opacity: 0.9;
+ transform: translateY(2px);
+}
diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index a435c59a3042..0248700292ae 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -329,6 +329,7 @@ Specified using `--task generate`.
| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat` etc. | | ✅️ | ✅️ |
| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat` etc. | | ✅️ | ✅️ |
| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3` etc. | | ✅️ | ✅️ |
+| `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst` etc. | | ✅️ | ✅️ |
| `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅️ | ✅️ | ✅️ |
| `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅️ | ✅️ |
| `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅️ | ✅️ |
@@ -336,6 +337,7 @@ Specified using `--task generate`.
| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅️ | ✅️ | ✅️ |
| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅️ | ✅️ | ✅️ |
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅️ | ✅️ | ✅️ |
+| `Gemma3nForConditionalGeneration` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅️ |
| `GlmForCausalLM` | GLM-4 | `THUDM/glm-4-9b-chat-hf`, etc. | ✅️ | ✅️ | ✅️ |
| `Glm4ForCausalLM` | GLM-4-0414 | `THUDM/GLM-4-32B-0414`, etc. | ✅️ | ✅️ | ✅️ |
| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅️ | ✅️ |
@@ -392,6 +394,9 @@ Specified using `--task generate`.
!!! note
Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
+!!! note
+ Only text inputs are currently supported for `Gemma3nForConditionalGeneration`. To use this model, please upgrade Hugging Face Transformers to version 4.53.0.
+
### Pooling Models
See [this page](./pooling_models.md) for more information on how to use pooling models.
@@ -427,7 +432,7 @@ Specified using `--task embed`.
See [relevant issue on HF Transformers](https://github.com/huggingface/transformers/issues/34882).
!!! note
- `jinaai/jina-embeddings-v3` supports multiple tasks through lora, while vllm temporarily only supports text-matching tasks by merging lora weights.
+ `jinaai/jina-embeddings-v3` supports multiple tasks through LoRA, while vLLM temporarily supports only text-matching tasks by merging LoRA weights.
!!! note
The second-generation GTE model (mGTE-TRM) is named `NewModel`. The name `NewModel` is too generic, so you should set `--hf-overrides '{"architectures": ["GteNewModel"]}'` to specify the use of the `GteNewModel` architecture.
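
For reference, the same override can be applied in offline usage through the `hf_overrides` engine argument. The snippet below is a hedged sketch: the model ID is a placeholder for the mGTE-TRM checkpoint you actually serve, not a recommendation.

```python
# Offline sketch of the --hf-overrides example above.
from vllm import LLM

llm = LLM(
    model="your-org/mgte-trm-checkpoint",  # placeholder model ID
    task="embed",
    hf_overrides={"architectures": ["GteNewModel"]},
)
outputs = llm.embed(["vLLM is a fast inference engine."])
print(len(outputs[0].outputs.embedding))  # embedding dimensionality
```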
diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md
index 00756e719992..a3f1ef9fd8b6 100644
--- a/docs/serving/openai_compatible_server.md
+++ b/docs/serving/openai_compatible_server.md
@@ -146,11 +146,6 @@ completion = client.chat.completions.create(
Only `X-Request-Id` HTTP request header is supported for now. It can be enabled
with `--enable-request-id-headers`.
-> Note that enablement of the headers can impact performance significantly at high QPS
-> rates. We recommend implementing HTTP headers at the router level (e.g. via Istio),
-> rather than within the vLLM layer for this reason.
-> See [this PR](https://github.com/vllm-project/vllm/pull/11529) for more details.
-
??? Code
```python
diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py
index 3eccb4e11ab6..dbf8ed58cc47 100644
--- a/examples/offline_inference/data_parallel.py
+++ b/examples/offline_inference/data_parallel.py
@@ -64,6 +64,18 @@ def parse_args():
parser.add_argument(
"--trust-remote-code", action="store_true", help="Trust remote code."
)
+ parser.add_argument(
+ "--max-num-seqs",
+ type=int,
+ default=64,
+ help=("Maximum number of sequences to be processed in a single iteration."),
+ )
+ parser.add_argument(
+ "--gpu-memory-utilization",
+ type=float,
+ default=0.8,
+ help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
+ )
return parser.parse_args()
@@ -77,6 +89,8 @@ def main(
GPUs_per_dp_rank,
enforce_eager,
trust_remote_code,
+ max_num_seqs,
+ gpu_memory_utilization,
):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
@@ -127,6 +141,8 @@ def start(rank):
enforce_eager=enforce_eager,
enable_expert_parallel=True,
trust_remote_code=trust_remote_code,
+ max_num_seqs=max_num_seqs,
+ gpu_memory_utilization=gpu_memory_utilization,
)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
@@ -181,6 +197,8 @@ def start(rank):
tp_size,
args.enforce_eager,
args.trust_remote_code,
+ args.max_num_seqs,
+ args.gpu_memory_utilization,
),
)
proc.start()
diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py
deleted file mode 100644
index f4193fdb8bd3..000000000000
--- a/examples/offline_inference/eagle.py
+++ /dev/null
@@ -1,144 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import argparse
-import json
-import os
-
-from transformers import AutoTokenizer
-
-from vllm import LLM, SamplingParams
-from vllm.v1.metrics.reader import Counter, Vector
-
-
-def load_prompts(dataset_path, num_prompts):
- if os.path.exists(dataset_path):
- prompts = []
- try:
- with open(dataset_path) as f:
- for line in f:
- data = json.loads(line)
- prompts.append(data["turns"][0])
- except Exception as e:
- print(f"Error reading dataset: {e}")
- return []
- else:
- prompts = ["The future of AI is", "The president of the United States is"]
-
- return prompts[:num_prompts]
-
-
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--dataset",
- type=str,
- default="./examples/data/gsm8k.jsonl",
- help="downloaded from the eagle repo "
- "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/",
- )
- parser.add_argument(
- "--method", type=str, default="eagle", choices=["eagle", "eagle3"]
- )
- parser.add_argument("--max_num_seqs", type=int, default=8)
- parser.add_argument("--num_prompts", type=int, default=80)
- parser.add_argument("--num_spec_tokens", type=int, default=2)
- parser.add_argument("--tp", type=int, default=1)
- parser.add_argument("--draft_tp", type=int, default=1)
- parser.add_argument("--enforce_eager", action="store_true")
- parser.add_argument("--enable_chunked_prefill", action="store_true")
- parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
- parser.add_argument("--temp", type=float, default=0)
- return parser.parse_args()
-
-
-def main():
- args = parse_args()
-
- model_dir = "meta-llama/Llama-3.1-8B-Instruct"
-
- if args.method == "eagle":
- eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
- elif args.method == "eagle3":
- eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
- else:
- raise ValueError(f"unknown method: {args.method}")
-
- max_model_len = 2048
-
- tokenizer = AutoTokenizer.from_pretrained(model_dir)
-
- prompts = load_prompts(args.dataset, args.num_prompts)
-
- prompt_ids = [
- tokenizer.apply_chat_template(
- [{"role": "user", "content": prompt}], add_generation_prompt=True
- )
- for prompt in prompts
- ]
-
- llm = LLM(
- model=model_dir,
- trust_remote_code=True,
- tensor_parallel_size=args.tp,
- enable_chunked_prefill=args.enable_chunked_prefill,
- max_num_batched_tokens=args.max_num_batched_tokens,
- enforce_eager=args.enforce_eager,
- max_model_len=max_model_len,
- max_num_seqs=args.max_num_seqs,
- gpu_memory_utilization=0.8,
- speculative_config={
- "method": args.method,
- "model": eagle_dir,
- "num_speculative_tokens": args.num_spec_tokens,
- "draft_tensor_parallel_size": args.draft_tp,
- "max_model_len": max_model_len,
- },
- disable_log_stats=False,
- )
-
- sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)
-
- outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)
-
- # print the generated text
- for output in outputs:
- print("-" * 50)
- print(f"prompt: {output.prompt}")
- print(f"generated text: {output.outputs[0].text}")
- print("-" * 50)
-
- try:
- metrics = llm.get_metrics()
- except AssertionError:
- print("Metrics are not supported in the V0 engine.")
- return
-
- num_drafts = num_accepted = 0
- acceptance_counts = [0] * args.num_spec_tokens
- for metric in metrics:
- if metric.name == "vllm:spec_decode_num_drafts":
- assert isinstance(metric, Counter)
- num_drafts += metric.value
- elif metric.name == "vllm:spec_decode_num_accepted_tokens":
- assert isinstance(metric, Counter)
- num_accepted += metric.value
- elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
- assert isinstance(metric, Vector)
- for pos in range(len(metric.values)):
- acceptance_counts[pos] += metric.values[pos]
-
- print("-" * 50)
- print(f"mean acceptance length: {1 + (num_accepted / num_drafts):.2f}")
- print("-" * 50)
-
- # print acceptance at each token position
- for i in range(len(acceptance_counts)):
- print(f"acceptance at token {i}:{acceptance_counts[i] / num_drafts:.2f}")
-
-
-if __name__ == "__main__":
- print(
- "[WARNING] Use examples/offline_inference/spec_decode.py"
- " instead of this script."
- )
- main()
diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py
index eece8beced51..90d103e5cb05 100644
--- a/examples/offline_inference/spec_decode.py
+++ b/examples/offline_inference/spec_decode.py
@@ -16,29 +16,24 @@ def parse_args():
parser = FlexibleArgumentParser()
add_dataset_parser(parser)
parser.add_argument(
- "--dataset",
+ "--method",
type=str,
- default="./examples/data/gsm8k.jsonl",
- help="downloaded from the eagle repo "
- "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/",
+ default="eagle",
+ choices=["ngram", "eagle", "eagle3", "mtp"],
)
- parser.add_argument(
- "--method", type=str, default="eagle", choices=["ngram", "eagle", "eagle3"]
- )
- parser.add_argument("--max-num-seqs", type=int, default=8)
parser.add_argument("--num-spec-tokens", type=int, default=2)
parser.add_argument("--prompt-lookup-max", type=int, default=5)
parser.add_argument("--prompt-lookup-min", type=int, default=2)
parser.add_argument("--tp", type=int, default=1)
- parser.add_argument("--draft-tp", type=int, default=1)
parser.add_argument("--enforce-eager", action="store_true")
parser.add_argument("--enable-chunked-prefill", action="store_true")
- parser.add_argument("--max-num-batched-tokens", type=int, default=2048)
parser.add_argument("--temp", type=float, default=0)
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--top-k", type=int, default=-1)
parser.add_argument("--print-output", action="store_true")
parser.add_argument("--output-len", type=int, default=256)
+ parser.add_argument("--model-dir", type=str, default=None)
+ parser.add_argument("--eagle-dir", type=str, default=None)
return parser.parse_args()
@@ -46,9 +41,10 @@ def main():
args = parse_args()
args.endpoint_type = "openai-chat"
- model_dir = "meta-llama/Llama-3.1-8B-Instruct"
+ model_dir = args.model_dir
+ if args.model_dir is None:
+ model_dir = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
- max_model_len = 2048
prompts = get_samples(args, tokenizer)
# add_special_tokens is False to avoid adding bos twice when using chat templates
@@ -57,16 +53,16 @@ def main():
]
if args.method == "eagle" or args.method == "eagle3":
- if args.method == "eagle":
+ eagle_dir = args.eagle_dir
+ if args.method == "eagle" and eagle_dir is None:
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
- elif args.method == "eagle3":
+
+ elif args.method == "eagle3" and eagle_dir is None:
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
speculative_config = {
"method": args.method,
"model": eagle_dir,
"num_speculative_tokens": args.num_spec_tokens,
- "draft_tensor_parallel_size": args.draft_tp,
- "max_model_len": max_model_len,
}
elif args.method == "ngram":
speculative_config = {
@@ -74,7 +70,6 @@ def main():
"num_speculative_tokens": args.num_spec_tokens,
"prompt_lookup_max": args.prompt_lookup_max,
"prompt_lookup_min": args.prompt_lookup_min,
- "max_model_len": max_model_len,
}
else:
raise ValueError(f"unknown method: {args.method}")
@@ -86,7 +81,6 @@ def main():
enable_chunked_prefill=args.enable_chunked_prefill,
max_num_batched_tokens=args.max_num_batched_tokens,
enforce_eager=args.enforce_eager,
- max_model_len=max_model_len,
max_num_seqs=args.max_num_seqs,
gpu_memory_utilization=0.8,
speculative_config=speculative_config,
@@ -110,27 +104,41 @@ def main():
print("Metrics are not supported in the V0 engine.")
return
- num_drafts = num_accepted = 0
+ total_num_output_tokens = sum(
+ len(output.outputs[0].token_ids) for output in outputs
+ )
+ num_drafts = 0
+ num_draft_tokens = 0
+ num_accepted_tokens = 0
acceptance_counts = [0] * args.num_spec_tokens
for metric in metrics:
if metric.name == "vllm:spec_decode_num_drafts":
assert isinstance(metric, Counter)
num_drafts += metric.value
+ elif metric.name == "vllm:spec_decode_num_draft_tokens":
+ assert isinstance(metric, Counter)
+ num_draft_tokens += metric.value
elif metric.name == "vllm:spec_decode_num_accepted_tokens":
assert isinstance(metric, Counter)
- num_accepted += metric.value
+ num_accepted_tokens += metric.value
elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
assert isinstance(metric, Vector)
for pos in range(len(metric.values)):
acceptance_counts[pos] += metric.values[pos]
print("-" * 50)
- print(f"mean acceptance length: {1 + (num_accepted / num_drafts):.2f}")
+ print(f"total_num_output_tokens: {total_num_output_tokens}")
+ print(f"num_drafts: {num_drafts}")
+ print(f"num_draft_tokens: {num_draft_tokens}")
+ print(f"num_accepted_tokens: {num_accepted_tokens}")
+ acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1
+ print(f"mean acceptance length: {acceptance_length:.2f}")
print("-" * 50)
# print acceptance at each token position
for i in range(len(acceptance_counts)):
- print(f"acceptance at token {i}:{acceptance_counts[i] / num_drafts:.2f}")
+ acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0
+ print(f"acceptance at token {i}: {acceptance_rate:.2f}")
if __name__ == "__main__":
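
The reworked summary above boils down to two ratios over the speculative decoding counters. A tiny worked example with invented counter values (not real measurements) shows the arithmetic:

```python
# Worked example for the acceptance metrics printed above (invented numbers).
num_drafts = 100              # vllm:spec_decode_num_drafts
num_draft_tokens = 200        # vllm:spec_decode_num_draft_tokens (2 per draft)
num_accepted_tokens = 140     # vllm:spec_decode_num_accepted_tokens
acceptance_counts = [90, 50]  # vllm:spec_decode_num_accepted_tokens_per_pos

# Each draft step also yields one token from the target model (the bonus or
# correction token), hence the leading "1 +".
mean_acceptance_length = 1 + num_accepted_tokens / num_drafts    # 2.40
per_position_rate = [c / num_drafts for c in acceptance_counts]  # [0.90, 0.50]
print(mean_acceptance_length, per_position_rate)
```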
diff --git a/mkdocs.yaml b/mkdocs.yaml
index 9fb3fed8b8ac..45b6ffadbeb7 100644
--- a/mkdocs.yaml
+++ b/mkdocs.yaml
@@ -127,6 +127,7 @@ extra_javascript:
- mkdocs/javascript/run_llm_widget.js
- https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS_HTML
- mkdocs/javascript/edit_and_feedback.js
+ - mkdocs/javascript/slack_and_forum.js
# Makes the url format end in .html rather than act as a dir
# So index.md generates as index.html and is available under URL /index.html
diff --git a/pyproject.toml b/pyproject.toml
index e8c2403af064..fb45572d265b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -150,6 +150,7 @@ skip_gitignore = true
markers = [
"skip_global_cleanup",
"core_model: enable this model test in each PR instead of only nightly",
+ "hybrid_model: models that contain mamba layers (including pure SSM and hybrid architectures)",
"cpu_model: enable this model test in CPU tests",
"split: run this test as part of a split",
"distributed: run this test only in distributed GPU tests",
diff --git a/requirements/common.txt b/requirements/common.txt
index 9a9ae1d93896..6cc304e5b1f6 100644
--- a/requirements/common.txt
+++ b/requirements/common.txt
@@ -37,7 +37,7 @@ pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
setuptools>=77.0.3,<80; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
einops # Required for Qwen2-VL.
-compressed-tensors == 0.10.1 # required for compressed-tensors
+compressed-tensors == 0.10.2 # required for compressed-tensors
depyf==0.18.0 # required for profiling and debugging with compilation config
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
watchfiles # required for http server to monitor the updates of TLS files
diff --git a/requirements/cpu-build.txt b/requirements/cpu-build.txt
new file mode 100644
index 000000000000..37f072202bd7
--- /dev/null
+++ b/requirements/cpu-build.txt
@@ -0,0 +1,12 @@
+# Temporarily used for x86 CPU backend to avoid performance regression of torch>2.6.0+cpu,
+# see https://github.com/pytorch/pytorch/pull/151218
+cmake>=3.26.1
+ninja
+packaging>=24.2
+setuptools>=77.0.3,<80.0.0
+setuptools-scm>=8
+--extra-index-url https://download.pytorch.org/whl/cpu
+torch==2.6.0+cpu
+wheel
+jinja2>=3.1.6
+regex
diff --git a/requirements/cpu.txt b/requirements/cpu.txt
index 8742898cff00..df3a3393563a 100644
--- a/requirements/cpu.txt
+++ b/requirements/cpu.txt
@@ -8,7 +8,7 @@ numba == 0.61.2; python_version > '3.9'
packaging>=24.2
setuptools>=77.0.3,<80.0.0
--extra-index-url https://download.pytorch.org/whl/cpu
-torch==2.7.0+cpu; platform_machine == "x86_64"
+torch==2.6.0+cpu; platform_machine == "x86_64" # torch>2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218
torch==2.7.0; platform_system == "Darwin"
torch==2.7.0; platform_machine == "ppc64le" or platform_machine == "aarch64"
@@ -23,6 +23,7 @@ datasets # for benchmark scripts
# Intel Extension for PyTorch, only for x86_64 CPUs
intel-openmp==2024.2.1; platform_machine == "x86_64"
-intel_extension_for_pytorch==2.7.0; platform_machine == "x86_64"
+intel_extension_for_pytorch==2.6.0; platform_machine == "x86_64" # torch>2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218
py-libnuma; platform_system != "Darwin"
psutil; platform_system != "Darwin"
+triton==3.2.0; platform_machine == "x86_64" # Triton is required for torch 2.6+cpu, as it is imported in torch.compile.
diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt
index 00acda366260..fd0b0fac12a9 100644
--- a/requirements/nightly_torch_test.txt
+++ b/requirements/nightly_torch_test.txt
@@ -1,47 +1,50 @@
-# Dependency that able to run entrypoints test
-# pytest and its extensions
+# testing
pytest
-pytest-asyncio
+tensorizer>=2.9.0
pytest-forked
-pytest-mock
+pytest-asyncio
pytest-rerunfailures
pytest-shard
pytest-timeout
-librosa # required by audio tests in entrypoints/openai
-sentence-transformers # required for embedding tests
-transformers==4.52.4
-transformers_stream_generator # required for qwen-vl test
-numba == 0.61.2; python_version > '3.9'
# testing utils
-boto3
-botocore
-datasets
-ray >= 2.10.0
+backoff # required for phi4mm test
+blobfile # required for kimi-vl test
+einops # required for MPT, qwen-vl and Mamba
+httpx
+librosa # required for audio tests
+vocos # required for minicpmo_26 test
peft
-runai-model-streamer==0.11.0
-runai-model-streamer-s3==0.11.0
-tensorizer>=2.9.0
-lm-eval==0.4.8
-buildkite-test-collector==0.1.9
+pqdm
+ray[cgraph,default]>=2.43.0, !=2.44.* # Ray Compiled Graph, required by pipeline parallelism tests
+sentence-transformers # required for embedding tests
+soundfile # required for audio tests
+jiwer # required for audio tests
+timm # required for internvl test
+transformers_stream_generator # required for qwen-vl test
+matplotlib # required for qwen-vl test
+mistral_common[opencv] >= 1.6.2 # required for pixtral test
+num2words # required for smolvlm test
+opencv-python-headless >= 4.11.0 # required for video test
+datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test
-
-# required for quantization test
+mteb>=1.38.11, <2 # required for mteb test
+transformers==4.52.4
+tokenizers==0.21.1
+huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
+schemathesis>=3.39.15 # Required for openai schema test.
+# quantization
bitsandbytes>=0.45.3
+buildkite-test-collector==0.1.9
-# required for minicpmo_26 test
-vector_quantize_pytorch
-vocos
-
-# required for Basic Models Test
-blobfile # required for kimi-vl test
-matplotlib # required for qwen-vl test
-# required for Multi-Modal Models Test (Standard)
-num2words # required for smolvlm test
-pqdm
-timm # required for internvl test
-mistral-common==1.6.2
+genai_perf==0.0.8
+tritonclient==2.51.0
-schemathesis==3.39.15 # Required for openai schema test.
-mteb>=1.38.11, <2 # required for mteb test
+numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
+numba == 0.61.2; python_version > '3.9'
+numpy
+runai-model-streamer==0.11.0
+runai-model-streamer-s3==0.11.0
+fastsafetensors>=0.1.10
+pydantic>=2.10 # 2.9 leads to error on python 3.10
diff --git a/requirements/test.in b/requirements/test.in
index e8f44059fcf8..85c96df8e8f4 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -42,6 +42,7 @@ schemathesis>=3.39.15 # Required for openai schema test.
bitsandbytes>=0.45.3
buildkite-test-collector==0.1.9
+
genai_perf==0.0.8
tritonclient==2.51.0
@@ -51,4 +52,4 @@ numpy
runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0
fastsafetensors>=0.1.10
-pydantic>=2.10 # 2.9 leads to error on python 3.10
\ No newline at end of file
+pydantic>=2.10 # 2.9 leads to error on python 3.10
diff --git a/requirements/xpu.txt b/requirements/xpu.txt
index 3cb6a4a8adda..0d95dc57152d 100644
--- a/requirements/xpu.txt
+++ b/requirements/xpu.txt
@@ -9,6 +9,7 @@ setuptools>=77.0.3,<80.0.0
wheel
jinja2>=3.1.6
datasets # for benchmark scripts
+numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
torch==2.7.0+xpu
torchaudio
diff --git a/tests/config/test_config_generation.py b/tests/config/test_config_generation.py
new file mode 100644
index 000000000000..024e81fccc5f
--- /dev/null
+++ b/tests/config/test_config_generation.py
@@ -0,0 +1,38 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+
+from vllm.engine.arg_utils import EngineArgs
+from vllm.model_executor.layers.quantization.quark.utils import deep_compare
+
+
+def test_cuda_empty_vs_unset_configs(monkeypatch: pytest.MonkeyPatch):
+ """Test that configs created with normal (untouched) CUDA_VISIBLE_DEVICES
+ and CUDA_VISIBLE_DEVICES="" are equivalent. This ensures consistent
+ behavior regardless of whether GPU visibility is disabled via empty string
+ or left in its normal state.
+ """
+
+ def create_config():
+ engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite",
+ trust_remote_code=True)
+ return engine_args.create_engine_config()
+
+ # Create config with CUDA_VISIBLE_DEVICES set normally
+ normal_config = create_config()
+
+ # Create config with CUDA_VISIBLE_DEVICES=""
+ with monkeypatch.context() as m:
+ m.setenv("CUDA_VISIBLE_DEVICES", "")
+ empty_config = create_config()
+
+ normal_config_dict = vars(normal_config)
+ empty_config_dict = vars(empty_config)
+
+ # Remove instance_id before comparison as it's expected to be different
+ normal_config_dict.pop("instance_id", None)
+ empty_config_dict.pop("instance_id", None)
+
+ assert deep_compare(normal_config_dict, empty_config_dict), (
+ "Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=\"\""
+ " should be equivalent")
diff --git a/tests/distributed/test_eplb_algo.py b/tests/distributed/test_eplb_algo.py
new file mode 100644
index 000000000000..e47ccba99c81
--- /dev/null
+++ b/tests/distributed/test_eplb_algo.py
@@ -0,0 +1,292 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+from vllm.distributed.eplb.rebalance_algo import rebalance_experts
+
+
+def test_basic_rebalance():
+ """Test basic rebalancing functionality"""
+ # Example from https://github.com/deepseek-ai/eplb
+ weight = torch.tensor([
+ [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
+ [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
+ ])
+
+ num_layers = weight.shape[0]
+ num_replicas = 16
+ num_groups = 4
+ num_nodes = 2
+ num_gpus = 8
+
+ phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
+ num_groups, num_nodes,
+ num_gpus)
+
+ # Verify output shapes
+ assert phy2log.shape == (
+ 2,
+ 16,
+ ), f"Expected `phy2log` shape (2, 16), got {phy2log.shape}"
+ assert (log2phy.shape[0] == 2
+ ), f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}"
+ assert (
+ log2phy.shape[1] == 12
+ ), f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}"
+ assert logcnt.shape == (
+ 2,
+ 12,
+ ), f"Expected `logcnt` shape (2, 12), got {logcnt.shape}"
+
+ # Verify physical to logical expert mapping range is correct
+ assert torch.all(phy2log >= 0) and torch.all(
+ phy2log < 12), "Physical to logical mapping should be in range [0, 12)"
+
+ # Verify expert count reasonableness
+ assert torch.all(
+ logcnt >= 1), "Each logical expert should have at least 1 replica"
+ assert (
+ torch.sum(logcnt, dim=1).sum() == num_replicas *
+ num_layers), f"Total replicas should be {num_replicas * num_layers}"
+
+ # Verify expected output
+ expected_phy2log = torch.tensor([
+ [5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1],
+ [7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1],
+ ])
+ assert torch.all(phy2log == expected_phy2log)
+
+ expected_logcnt = torch.tensor([[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1],
+ [1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]])
+ assert torch.all(logcnt == expected_logcnt)
+
+
+def test_single_gpu_case():
+ """Test single GPU case"""
+ weight = torch.tensor([[10, 20, 30, 40]])
+ num_replicas = 4
+ num_groups = 1
+ num_nodes = 1
+ num_gpus = 1
+
+ phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
+ num_groups, num_nodes,
+ num_gpus)
+
+ # Verify shapes
+ assert phy2log.shape == (1, 4)
+ assert log2phy.shape[0] == 1
+ assert log2phy.shape[1] == 4
+ assert logcnt.shape == (1, 4)
+
+ # Verify all logical experts are mapped
+ assert set(phy2log[0].tolist()) == {0, 1, 2, 3}
+
+
+def test_equal_weights():
+ """Test case with equal weights"""
+ weight = torch.tensor([[50, 50, 50, 50, 50, 50, 50, 50]])
+ num_replicas = 8
+ num_groups = 2
+ num_nodes = 2
+ num_gpus = 4
+
+ phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
+ num_groups, num_nodes,
+ num_gpus)
+
+ # Verify shapes
+ assert phy2log.shape == (1, 8)
+ assert logcnt.shape == (1, 8)
+
+ # With equal weights, each expert should have exactly one replica
+ assert torch.all(
+ logcnt == 1
+ ), "With equal weights and no replication, " \
+ "each expert should have exactly 1 replica"
+
+
+def test_extreme_weight_imbalance():
+ """Test extreme weight imbalance case"""
+ weight = torch.tensor([[1000, 1, 1, 1, 1, 1, 1, 1]])
+ num_replicas = 12
+ num_groups = 2
+ num_nodes = 2
+ num_gpus = 4
+
+ phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
+ num_groups, num_nodes,
+ num_gpus)
+
+ # Verify shapes
+ assert phy2log.shape == (1, 12)
+ assert logcnt.shape == (1, 8)
+
+ # Expert with highest weight (index 0) should have more replicas
+ assert (
+ logcnt[0, 0]
+ > logcnt[0, 1]), "Expert with highest weight should have more replicas"
+
+
+def test_multiple_layers():
+ """Test multiple layers case"""
+ weight = torch.tensor([
+ [10, 20, 30, 40, 50, 60], # First layer
+ [60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern)
+ [25, 25, 25, 25, 25, 25], # Third layer (equal weights)
+ ])
+ num_replicas = 8
+ num_groups = 2
+ num_nodes = 2
+ num_gpus = 4
+
+ phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
+ num_groups, num_nodes,
+ num_gpus)
+
+ # Verify shapes
+ assert phy2log.shape == (3, 8)
+ assert logcnt.shape == (3, 6)
+
+ # Verify expert allocation is reasonable for each layer
+ for layer in range(3):
+ assert torch.all(phy2log[layer] >= 0) and torch.all(
+ phy2log[layer] < 6
+ ), f"Layer {layer} physical to logical mapping" \
+ "should be in range [0, 6)"
+ assert (torch.sum(logcnt[layer]) == num_replicas
+ ), f"Layer {layer} total replicas should be {num_replicas}"
+
+
+def test_parameter_validation():
+ """Test parameter validation"""
+ weight = torch.tensor([[10, 20, 30, 40]])
+
+ # Test the non-divisible case - this should complete without raising,
+ # because the function falls back to the global load-balancing strategy
+ phy2log, log2phy, logcnt = rebalance_experts(weight, 8, 3, 2, 4)
+ assert phy2log.shape == (1, 8)
+ assert logcnt.shape == (1, 4)
+
+ # Test cases that will actually cause errors:
+ # num_physical_experts not divisible by num_gpus
+ with pytest.raises(AssertionError):
+ rebalance_experts(weight, 7, 2, 2, 4) # 7 not divisible by 4
+
+
+def test_small_scale_hierarchical():
+ """Test small-scale hierarchical load balancing"""
+ weight = torch.tensor([
+ [100, 50, 200, 75, 150, 25, 300, 80], # 8 experts
+ ])
+ num_replicas = 12
+ num_groups = 4 # 4 groups, 2 experts each
+ num_nodes = 2 # 2 nodes
+ num_gpus = 4 # 4 GPUs
+
+ phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
+ num_groups, num_nodes,
+ num_gpus)
+
+ # Verify basic constraints
+ assert phy2log.shape == (1, 12)
+ assert logcnt.shape == (1, 8)
+ assert torch.sum(logcnt) == num_replicas
+ assert torch.all(logcnt >= 1)
+
+ # Expert with highest weight should have more replicas
+ max_weight_expert = torch.argmax(weight[0])
+ assert (logcnt[0, max_weight_expert]
+ >= 2), "Highest weight expert should have multiple replicas"
+
+
+def test_global_load_balance_fallback():
+ """Test global load balancing fallback case"""
+ # When num_groups % num_nodes != 0, should fall back to global load
+ # balancing
+ weight = torch.tensor([[10, 20, 30, 40, 50, 60]])
+ num_replicas = 8
+ num_groups = 3 # Cannot be divided evenly by num_nodes=2
+ num_nodes = 2
+ num_gpus = 4
+
+ phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
+ num_groups, num_nodes,
+ num_gpus)
+
+ # Should work normally, just using global load balancing strategy
+ assert phy2log.shape == (1, 8)
+ assert logcnt.shape == (1, 6)
+ assert torch.sum(logcnt) == num_replicas
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_device_compatibility(device):
+ """Test device compatibility"""
+ if device == "cuda" and not torch.cuda.is_available():
+ pytest.skip("CUDA not available")
+
+ weight = torch.tensor([[10, 20, 30, 40]], device=device)
+ num_replicas = 6
+ num_groups = 2
+ num_nodes = 1
+ num_gpus = 2
+
+ phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
+ num_groups, num_nodes,
+ num_gpus)
+
+ # Function will convert to CPU internally, but should handle different
+ # device inputs normally
+ assert phy2log.shape == (1, 6)
+ assert logcnt.shape == (1, 4)
+
+
+def test_additional_cases():
+ """Test more edge cases and different parameter combinations"""
+
+ # Test case 1: Large-scale distributed setup
+ weight1 = torch.tensor(
+ [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]])
+ phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8)
+
+ assert phy2log1.shape == (1, 24)
+ assert logcnt1.shape == (1, 16)
+ assert torch.sum(logcnt1) == 24
+
+ # Test case 2: Different weight distributions
+ weight2 = torch.tensor([
+ [200, 150, 100, 50, 25, 12], # Decreasing weights
+ [12, 25, 50, 100, 150, 200], # Increasing weights
+ ])
+ phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2)
+
+ assert phy2log2.shape == (2, 10)
+ assert logcnt2.shape == (2, 6)
+
+ # Verify high-weight experts have more replicas
+ for layer in range(2):
+ max_weight_idx = torch.argmax(weight2[layer])
+ assert logcnt2[layer, max_weight_idx] >= 2
+
+
+if __name__ == "__main__":
+ weight = torch.tensor([
+ [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
+ [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
+ ])
+
+ num_replicas = 16
+ num_groups = 4
+ num_nodes = 2
+ num_gpus = 8
+
+ phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas,
+ num_groups, num_nodes,
+ num_gpus)
+ print(phy2log)
+
+ test_basic_rebalance()
diff --git a/tests/distributed/test_eplb_execute.py b/tests/distributed/test_eplb_execute.py
new file mode 100644
index 000000000000..de9ed1eabbac
--- /dev/null
+++ b/tests/distributed/test_eplb_execute.py
@@ -0,0 +1,504 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import multiprocessing
+import os
+import random
+
+import pytest
+import torch
+import torch.distributed
+
+from vllm.distributed.eplb.rebalance_execute import (
+ rearrange_expert_weights_inplace)
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+ get_tp_group,
+ init_distributed_environment)
+from vllm.utils import update_environment_variables
+
+
+def distributed_run(fn, world_size):
+ number_of_processes = world_size
+ processes: list[multiprocessing.Process] = []
+ for i in range(number_of_processes):
+ env: dict[str, str] = {}
+ env['RANK'] = str(i)
+ env['LOCAL_RANK'] = str(i)
+ env['WORLD_SIZE'] = str(number_of_processes)
+ env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
+ env['MASTER_ADDR'] = 'localhost'
+ env['MASTER_PORT'] = '12345'
+ p = multiprocessing.Process(target=fn, args=(env, ))
+ processes.append(p)
+ p.start()
+
+ for p in processes:
+ p.join()
+
+ for p in processes:
+ assert p.exitcode == 0
+
+
+def worker_fn_wrapper(fn):
+ # `multiprocessing.Process` cannot accept environment variables directly
+ # so we need to pass the environment variables as arguments
+ # and update the environment variables in the function
+ def wrapped_fn(env):
+ update_environment_variables(env)
+ local_rank = os.environ['LOCAL_RANK']
+ device = torch.device(f"cuda:{local_rank}")
+ torch.cuda.set_device(device)
+ init_distributed_environment()
+
+ # Ensure each worker process has the same random seed
+ random.seed(42)
+ torch.manual_seed(42)
+
+ fn()
+
+ return wrapped_fn
+
+
+def create_expert_indices_with_redundancy(
+ num_layers: int,
+ num_logical_experts: int,
+ total_physical_experts: int,
+ redundancy_config: list[int], # redundancy for each logical expert
+) -> torch.Tensor:
+ """
+ Create expert indices with redundancy.
+
+ Args:
+ num_layers: number of layers
+ num_logical_experts: number of logical experts
+ total_physical_experts: total number of physical experts
+ redundancy_config: redundancy for each logical expert
+
+ Returns:
+ indices: Shape (num_layers, total_physical_experts)
+ """
+ assert sum(redundancy_config) == total_physical_experts
+ assert len(redundancy_config) == num_logical_experts
+
+ indices = torch.zeros(num_layers, total_physical_experts, dtype=torch.long)
+
+ for layer in range(num_layers):
+ physical_pos = 0
+ for logical_expert_id, redundancy in enumerate(redundancy_config):
+ for _ in range(redundancy):
+ indices[layer, physical_pos] = logical_expert_id
+ physical_pos += 1
+
+ # Shuffle the indices at dim 1
+ for layer in range(num_layers):
+ indices[layer] = indices[layer][torch.randperm(indices.shape[1])]
+
+ return indices
+
+
+def create_expert_weights(
+ num_layers: int,
+ num_local_experts: int,
+ hidden_sizes: list[int],
+ rank: int,
+ device: torch.device,
+ physical_to_logical_mapping: torch.Tensor,
+) -> list[list[torch.Tensor]]:
+ """
+ Create fake expert weights tensor for testing.
+
+ Use `arange` to generate predictable weights values, based on logical
+ expert ID.
+ All replicas of the same logical expert should have the same weights.
+
+ Args:
+ physical_to_logical_mapping: Shape (num_layers, num_local_experts)
+ mapping[layer, physical_pos] = logical_expert_id
+ """
+ expert_weights = []
+
+ for layer in range(num_layers):
+ layer_weights = []
+ for weight_idx, hidden_size in enumerate(hidden_sizes):
+ weight_tensor = torch.zeros(num_local_experts,
+ hidden_size,
+ device=device,
+ dtype=torch.float32)
+
+ for local_expert in range(num_local_experts):
+ # Get the logical expert ID for this physical expert
+ global_pos = rank * num_local_experts + local_expert
+ logical_expert_id = physical_to_logical_mapping[
+ layer, global_pos].item()
+
+ # Generate weights based on logical expert ID
+ # (so that all replicas of the same logical expert have the
+ # same weights)
+ base_value = (logical_expert_id * 1000 + layer * 100 +
+ weight_idx * 10)
+ weight_tensor[local_expert] = torch.arange(base_value,
+ base_value +
+ hidden_size,
+ device=device,
+ dtype=torch.float32)
+
+ layer_weights.append(weight_tensor)
+ expert_weights.append(layer_weights)
+
+ return expert_weights
+
+
+def create_redundancy_config(
+ num_logical_experts: int,
+ num_physical_experts: int,
+) -> list[int]:
+ """Create a redundancy configuration."""
+ redundancy_config = [1] * num_logical_experts
+ remaining = num_physical_experts - num_logical_experts
+ # Randomly assign the remaining physical experts to the logical experts
+ for _ in range(remaining):
+ redundancy_config[random.choice(range(num_logical_experts))] += 1
+ return redundancy_config
+
+
+def verify_expert_weights_after_shuffle(
+ expert_weights: list[list[torch.Tensor]],
+ new_indices: torch.Tensor,
+ hidden_sizes: list[int],
+ ep_rank: int,
+ num_local_experts: int,
+):
+ """Verify the weights after shuffling are correct."""
+ num_layers = len(expert_weights)
+
+ for layer in range(num_layers):
+ for weight_idx, hidden_size in enumerate(hidden_sizes):
+ weight_tensor = expert_weights[layer][weight_idx]
+
+ for local_expert in range(num_local_experts):
+ # Calculate the global expert ID for this local expert
+ global_pos = ep_rank * num_local_experts + local_expert
+ expected_logical_expert = new_indices[layer, global_pos].item()
+
+ # Check if the weights are correct
+ actual_weights = weight_tensor[local_expert]
+ expected_base = (expected_logical_expert * 1000 + layer * 100 +
+ weight_idx * 10)
+ expected_weights = torch.arange(expected_base,
+ expected_base + hidden_size,
+ device=actual_weights.device,
+ dtype=actual_weights.dtype)
+
+ torch.testing.assert_close(
+ actual_weights,
+ expected_weights,
+ msg=f"Layer {layer}, weight {weight_idx},"
+ f"local expert {local_expert}: "
+ f"weights do not match. "
+ f"Expected logical expert {expected_logical_expert}")
+
+
+def verify_redundant_experts_have_same_weights(
+ expert_weights: list[list[torch.Tensor]],
+ indices: torch.Tensor,
+ hidden_sizes: list[int],
+ world_size: int,
+ num_local_experts: int,
+):
+ """
+ Verify that all replicas of the same logical expert have the same weights.
+ """
+ num_layers = len(expert_weights)
+ total_physical_experts = world_size * num_local_experts
+
+ for layer in range(num_layers):
+ # Collect weights for all physical experts for each weight matrix
+ all_weights: list[torch.Tensor] = []
+
+ for weight_idx, hidden_size in enumerate(hidden_sizes):
+ # Create tensor to store all expert weights
+ # Shape: [total_physical_experts, hidden_size]
+ gathered_weights = torch.zeros(
+ total_physical_experts,
+ hidden_size,
+ device=expert_weights[layer][weight_idx].device,
+ dtype=expert_weights[layer][weight_idx].dtype)
+
+ # Use all_gather to collect expert weights from current node
+ # expert_weights[layer][weight_idx] shape:
+ # [num_local_experts, hidden_size]
+ local_weights = expert_weights[layer][
+ weight_idx] # [num_local_experts, hidden_size]
+
+ # Split tensor along dim 0 into a list for all_gather
+ gathered_weights_list = torch.chunk(gathered_weights,
+ world_size,
+ dim=0)
+
+ torch.distributed.all_gather(
+ # Output list: each element corresponds to one rank's weights
+ list(gathered_weights_list),
+ local_weights # Input: current rank's local weights
+ )
+
+ all_weights.append(gathered_weights)
+
+ # Verify that all replicas of the same logical expert have the same
+ # weights
+ logical_expert_weights: dict[int, dict[int, torch.Tensor]] = {}
+
+ for physical_pos in range(total_physical_experts):
+ logical_expert_id = int(indices[layer, physical_pos].item())
+
+ if logical_expert_id not in logical_expert_weights:
+ # First time encountering this logical expert, save its weights
+ logical_expert_weights[logical_expert_id] = {
+ weight_idx: all_weights[weight_idx][physical_pos]
+ for weight_idx in range(len(hidden_sizes))
+ }
+ else:
+ # Verify that current physical expert's weights match the
+ # previously saved logical expert weights
+ for weight_idx in range(len(hidden_sizes)):
+ torch.testing.assert_close(
+ all_weights[weight_idx][physical_pos],
+ logical_expert_weights[logical_expert_id][weight_idx],
+ msg=f"Layer {layer}, weight {weight_idx},"
+ f"logical expert {logical_expert_id}: "
+ f"Physical expert {physical_pos} has different weights"
+ f"than expected")
+
+
+@pytest.mark.parametrize(
+ "world_size,num_layers,num_local_experts,num_logical_experts",
+ [
+ # 2 GPU, 2 experts per GPU
+ # 3 logical experts, 4 physical experts, 1 redundant experts
+ (2, 1, 2, 3),
+ # 2 GPU, 3 experts per GPU
+ # 4 logical experts, 6 physical experts, 2 redundant experts
+ (2, 2, 3, 4),
+ # 2 GPU, 8 experts per GPU
+ # 16 logical experts, 16 physical experts, 0 redundant experts
+ (2, 4, 8, 16),
+ # 4 GPU, 2 experts per GPU
+ # 6 logical experts, 8 physical experts, 2 redundant experts
+ (4, 1, 2, 6),
+ # 4 GPU, 2 experts per GPU
+ # 5 logical experts, 8 physical experts, 3 redundant experts
+ (4, 2, 2, 5),
+ # 4 GPU, 8 experts per GPU
+ # 16 logical experts, 32 physical experts, 16 redundant experts
+ (4, 8, 8, 16),
+ ])
+def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
+ num_local_experts,
+ num_logical_experts):
+ """Test the functionality of rearranging expert weights with redundancy."""
+
+ if torch.cuda.device_count() < world_size:
+ pytest.skip(f"Need at least {world_size} GPUs to run the test")
+
+ @worker_fn_wrapper
+ def worker_fn():
+ # Initialize model parallel (using tensor parallel as an entrypoint
+ # to expert parallel)
+ ensure_model_parallel_initialized(
+ tensor_model_parallel_size=world_size,
+ pipeline_model_parallel_size=1)
+
+ ep_group = get_tp_group().cpu_group
+ ep_rank = torch.distributed.get_rank()
+ device = torch.device(f"cuda:{ep_rank}")
+
+ # Test parameters
+ total_physical_experts = world_size * num_local_experts
+ hidden_sizes = [32, 64] # Two different weight matrices
+
+ # Create old expert indices (with redundancy)
+ redundancy_config = create_redundancy_config(num_logical_experts,
+ total_physical_experts)
+
+ old_indices = create_expert_indices_with_redundancy(
+ num_layers,
+ num_logical_experts,
+ total_physical_experts,
+ redundancy_config,
+ )
+
+ # Create new expert indices (with redundancy)
+ new_redundancy_config = create_redundancy_config(
+ num_logical_experts, total_physical_experts)
+ new_indices = create_expert_indices_with_redundancy(
+ num_layers,
+ num_logical_experts,
+ total_physical_experts,
+ new_redundancy_config,
+ )
+
+ # Create expert weights
+ expert_weights = create_expert_weights(num_layers, num_local_experts,
+ hidden_sizes, ep_rank, device,
+ old_indices)
+
+ # Execute weight rearrangement
+ rearrange_expert_weights_inplace(
+ old_indices,
+ new_indices,
+ expert_weights,
+ ep_group,
+ is_profile=False,
+ )
+
+ # Verify the rearrangement result
+ verify_expert_weights_after_shuffle(
+ expert_weights,
+ new_indices,
+ hidden_sizes,
+ ep_rank,
+ num_local_experts,
+ )
+
+ verify_redundant_experts_have_same_weights(
+ expert_weights,
+ new_indices,
+ hidden_sizes,
+ world_size,
+ num_local_experts,
+ )
+
+ distributed_run(worker_fn, world_size)
+
+
+@pytest.mark.parametrize("world_size", [2, 4])
+def test_rearrange_expert_weights_no_change(world_size):
+ """
+ Test that when the indices do not change, the weights should remain
+ unchanged.
+ """
+
+ if torch.cuda.device_count() < world_size:
+ pytest.skip(f"Need at least {world_size} GPUs to run the test")
+
+ @worker_fn_wrapper
+ def worker_fn():
+ ensure_model_parallel_initialized(
+ tensor_model_parallel_size=world_size,
+ pipeline_model_parallel_size=1)
+
+ ep_group = get_tp_group().cpu_group
+ ep_rank = torch.distributed.get_rank()
+ device = torch.device(f"cuda:{ep_rank}")
+
+ num_layers = 2
+ num_local_experts = 2
+ total_physical_experts = world_size * num_local_experts
+ num_logical_experts = total_physical_experts // 2 # Some redundancy
+ hidden_sizes = [32, 64]
+
+ # Create redundancy configuration
+ redundancy_config = [2] * num_logical_experts
+
+ # Same indices - no change
+ indices = create_expert_indices_with_redundancy(
+ num_layers, num_logical_experts, total_physical_experts,
+ redundancy_config)
+
+ expert_weights = create_expert_weights(num_layers, num_local_experts,
+ hidden_sizes, ep_rank, device,
+ indices)
+
+ # Save original weights
+ original_weights = []
+ for layer_weights in expert_weights:
+ layer_copy = []
+ for weight in layer_weights:
+ layer_copy.append(weight.clone())
+ original_weights.append(layer_copy)
+
+ # Execute rearrangement (should be no change)
+ rearrange_expert_weights_inplace(
+ indices,
+ indices, # Same indices
+ expert_weights,
+ ep_group,
+ is_profile=False)
+
+ # Verify that the weights have not changed
+ for layer in range(num_layers):
+ for weight_idx in range(len(hidden_sizes)):
+ torch.testing.assert_close(
+ expert_weights[layer][weight_idx],
+ original_weights[layer][weight_idx],
+ msg=f"Layer {layer}, weight {weight_idx} should remain "
+ f"unchanged")
+
+ distributed_run(worker_fn, world_size)
+
+
+@pytest.mark.parametrize("world_size", [2, 4])
+def test_rearrange_expert_weights_profile_mode(world_size):
+ """Test profile mode (should not copy actual weights)"""
+
+ if torch.cuda.device_count() < world_size:
+ pytest.skip(f"Need at least {world_size} GPUs to run the test")
+
+ @worker_fn_wrapper
+ def worker_fn():
+ ensure_model_parallel_initialized(
+ tensor_model_parallel_size=world_size,
+ pipeline_model_parallel_size=1)
+
+ ep_group = get_tp_group().cpu_group
+ ep_rank = torch.distributed.get_rank()
+ device = torch.device(f"cuda:{ep_rank}")
+
+ num_layers = 1
+ num_local_experts = 2
+ total_physical_experts = world_size * num_local_experts
+ num_logical_experts = total_physical_experts // 2
+ hidden_sizes = [32]
+
+ # Create different index distributions
+ old_redundancy = create_redundancy_config(num_logical_experts,
+ total_physical_experts)
+ new_redundancy = create_redundancy_config(num_logical_experts,
+ total_physical_experts)
+
+ old_indices = create_expert_indices_with_redundancy(
+ num_layers, num_logical_experts, total_physical_experts,
+ old_redundancy)
+ new_indices = create_expert_indices_with_redundancy(
+ num_layers, num_logical_experts, total_physical_experts,
+ new_redundancy)
+
+ expert_weights = create_expert_weights(num_layers, num_local_experts,
+ hidden_sizes, ep_rank, device,
+ old_indices)
+
+ # Save original weights
+ original_weights = []
+ for layer_weights in expert_weights:
+ layer_copy = []
+ for weight in layer_weights:
+ layer_copy.append(weight.clone())
+ original_weights.append(layer_copy)
+
+ # Execute profile mode rearrangement
+ rearrange_expert_weights_inplace(
+ old_indices,
+ new_indices,
+ expert_weights,
+ ep_group,
+ is_profile=True # Profile mode
+ )
+
+ # In profile mode, the weights should remain unchanged
+ for layer in range(num_layers):
+ for weight_idx in range(len(hidden_sizes)):
+ torch.testing.assert_close(
+ expert_weights[layer][weight_idx],
+ original_weights[layer][weight_idx],
+ msg="In profile mode, the weights should remain unchanged")
+
+ distributed_run(worker_fn, world_size)
diff --git a/tests/distributed/test_node_count.py b/tests/distributed/test_node_count.py
new file mode 100644
index 000000000000..e3c36ef5ef37
--- /dev/null
+++ b/tests/distributed/test_node_count.py
@@ -0,0 +1,43 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import os
+
+import torch.distributed as dist
+
+from vllm.distributed.parallel_state import _node_count
+from vllm.distributed.utils import StatelessProcessGroup
+from vllm.utils import get_ip, get_open_port
+
+if __name__ == "__main__":
+ dist.init_process_group(backend="gloo")
+
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+
+ if rank == 0:
+ port = get_open_port()
+ ip = get_ip()
+ dist.broadcast_object_list([ip, port], src=0)
+ else:
+ recv = [None, None]
+ dist.broadcast_object_list(recv, src=0)
+ ip, port = recv
+
+ stateless_pg = StatelessProcessGroup.create(ip, port, rank, world_size)
+
+ for pg in [dist.group.WORLD, stateless_pg]:
+ test_result = _node_count(pg)
+
+ # Expected node count based on the NUM_NODES environment variable
+ expected = int(os.environ.get("NUM_NODES", "1"))
+
+ assert test_result == expected, \
+ f"Expected {expected} nodes, got {test_result}"
+
+ if pg == dist.group.WORLD:
+ print(f"Node count test passed! Got {test_result} nodes "
+ f"when using torch distributed!")
+ else:
+ print(f"Node count test passed! Got {test_result} nodes "
+ f"when using StatelessProcessGroup!")
diff --git a/tests/distributed/test_quick_all_reduce.py b/tests/distributed/test_quick_all_reduce.py
new file mode 100644
index 000000000000..a4added29144
--- /dev/null
+++ b/tests/distributed/test_quick_all_reduce.py
@@ -0,0 +1,138 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import random
+
+import pytest
+import ray
+import torch
+import torch.distributed as dist
+
+from vllm.distributed.communication_op import ( # noqa
+ tensor_model_parallel_all_reduce)
+from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
+ get_tp_group, graph_capture)
+from vllm.platforms import current_platform
+
+from ..utils import (ensure_model_parallel_initialized,
+ init_test_distributed_environment, multi_process_parallel)
+
+torch.manual_seed(42)
+random.seed(44)
+# Size over 8MB is sufficient for custom quick allreduce.
+test_sizes = [
+ random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8)
+]
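+# Round each size down to a multiple of 8 elements.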
+for i, v in enumerate(test_sizes):
+ test_sizes[i] -= v % 8
+
+
+@ray.remote(num_gpus=1, max_calls=1)
+def graph_quickreduce(
+ monkeypatch: pytest.MonkeyPatch,
+ tp_size,
+ pp_size,
+ rank,
+ distributed_init_port,
+):
+ with monkeypatch.context() as m:
+ m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
+ device = torch.device(f"cuda:{rank}")
+ torch.cuda.set_device(device)
+ init_test_distributed_environment(tp_size, pp_size, rank,
+ distributed_init_port)
+ ensure_model_parallel_initialized(tp_size, pp_size)
+ group = get_tensor_model_parallel_group().device_group
+
+        # A small all_reduce for warmup.
+        # This is needed because device communicators might be created lazily
+        # (e.g. NCCL). It ensures the communicator is initialized before any
+        # communication happens, so that this group can be used for graph
+        # capture immediately.
+ data = torch.zeros(1)
+ data = data.to(device=device)
+ torch.distributed.all_reduce(data, group=group)
+ torch.cuda.synchronize()
+ del data
+
+        # We use the first group to communicate once, the second group to
+        # communicate twice, and so on. This demonstrates that each group
+        # can communicate independently.
+ num_communication = rank // tp_size + 1
+
+ for sz in test_sizes:
+ for dtype in [torch.float16, torch.bfloat16]:
+ with graph_capture(device=device) as graph_capture_context:
+ inp1 = torch.randint(1,
+ 23, (sz, ),
+ dtype=dtype,
+ device=torch.cuda.current_device())
+ inp2 = torch.randint(-23,
+ 1, (sz, ),
+ dtype=dtype,
+ device=torch.cuda.current_device())
+ torch.cuda.synchronize()
+ graph = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(graph,
+ stream=graph_capture_context.stream):
+ for _ in range(num_communication):
+ out1 = tensor_model_parallel_all_reduce(inp1)
+ dist.all_reduce(inp1, group=group)
+ out2 = tensor_model_parallel_all_reduce(inp2)
+ dist.all_reduce(inp2, group=group)
+ graph.replay()
+ torch.testing.assert_close(out1, inp1, atol=2.5, rtol=0.1)
+ torch.testing.assert_close(out2, inp2, atol=2.5, rtol=0.1)
+
+
+@ray.remote(num_gpus=1, max_calls=1)
+def eager_quickreduce(
+ monkeypatch: pytest.MonkeyPatch,
+ tp_size,
+ pp_size,
+ rank,
+ distributed_init_port,
+):
+ with monkeypatch.context() as m:
+ m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
+ device = torch.device(f"cuda:{rank}")
+ torch.cuda.set_device(device)
+
+ init_test_distributed_environment(tp_size, pp_size, rank,
+ distributed_init_port)
+
+ # Size over 8MB is sufficient for custom quick allreduce.
+ sz = 16 * 1024 * 1024
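+        # The quick-reduce communicator is exposed via the TP group's device
+        # communicator.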
+ fa = get_tp_group().device_communicator.qr_comm
+ inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)],
+ dtype=torch.float16,
+ device=device)
+ out = fa.quick_all_reduce(inp)
+ torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
+
+ inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)],
+ dtype=torch.bfloat16,
+ device=device)
+ out = fa.quick_all_reduce(inp)
+ torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
+
+
+@pytest.mark.skipif(not current_platform.is_rocm(),
+ reason="only test quick allreduce for rocm")
+@pytest.mark.parametrize("quant_mode", ["FP", "INT8", "INT6", "INT4"])
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
+@pytest.mark.parametrize("test_target", [graph_quickreduce, eager_quickreduce])
+def test_custom_quick_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
+ pipeline_parallel_size, test_target,
+ quant_mode):
+ world_size = tp_size * pipeline_parallel_size
+ if world_size > torch.cuda.device_count():
+ pytest.skip("Not enough GPUs to run the test.")
+
+ monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode)
+
+ multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size,
+ test_target)
diff --git a/tests/entrypoints/openai/test_optional_middleware.py b/tests/entrypoints/openai/test_optional_middleware.py
new file mode 100644
index 000000000000..882fa0886ce3
--- /dev/null
+++ b/tests/entrypoints/openai/test_optional_middleware.py
@@ -0,0 +1,116 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Tests for middleware that's off by default and can be toggled through
+server arguments, mainly --api-key and --enable-request-id-headers.
+"""
+
+from http import HTTPStatus
+
+import pytest
+import requests
+
+from ...utils import RemoteOpenAIServer
+
+# Use a small embeddings model for faster startup and a smaller memory
+# footprint. Since we are not testing any chat functionality, using a
+# chat-capable model is overkill.
+MODEL_NAME = "intfloat/multilingual-e5-small"
+
+
+@pytest.fixture(scope="module")
+def server(request: pytest.FixtureRequest):
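+    # Indirect parametrization passes extra server CLI args via request.param;
+    # accept either a single string or a list of strings.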
+ passed_params = []
+ if hasattr(request, "param"):
+ passed_params = request.param
+ if isinstance(passed_params, str):
+ passed_params = [passed_params]
+
+ args = [
+ "--task",
+ "embed",
+ # use half precision for speed and memory savings in CI environment
+ "--dtype",
+ "float16",
+ "--max-model-len",
+ "512",
+ "--enforce-eager",
+ "--max-num-seqs",
+ "2",
+ *passed_params
+ ]
+ with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+ yield remote_server
+
+
+@pytest.mark.asyncio
+async def test_no_api_token(server: RemoteOpenAIServer):
+ response = requests.get(server.url_for("v1/models"))
+ assert response.status_code == HTTPStatus.OK
+
+
+@pytest.mark.asyncio
+async def test_no_request_id_header(server: RemoteOpenAIServer):
+ response = requests.get(server.url_for("health"))
+ assert "X-Request-Id" not in response.headers
+
+
+@pytest.mark.parametrize(
+ "server",
+ [["--api-key", "test"]],
+ indirect=True,
+)
+@pytest.mark.asyncio
+async def test_missing_api_token(server: RemoteOpenAIServer):
+ response = requests.get(server.url_for("v1/models"))
+ assert response.status_code == HTTPStatus.UNAUTHORIZED
+
+
+@pytest.mark.parametrize(
+ "server",
+ [["--api-key", "test"]],
+ indirect=True,
+)
+@pytest.mark.asyncio
+async def test_passed_api_token(server: RemoteOpenAIServer):
+ response = requests.get(server.url_for("v1/models"),
+ headers={"Authorization": "Bearer test"})
+ assert response.status_code == HTTPStatus.OK
+
+
+@pytest.mark.parametrize(
+ "server",
+ [["--api-key", "test"]],
+ indirect=True,
+)
+@pytest.mark.asyncio
+async def test_not_v1_api_token(server: RemoteOpenAIServer):
+    # The authorization check is skipped for any path that doesn't start
+    # with /v1 (e.g. /health); it only applies to /v1 endpoints such as
+    # /v1/chat/completions.
+ response = requests.get(server.url_for("health"))
+ assert response.status_code == HTTPStatus.OK
+
+
+@pytest.mark.parametrize(
+ "server",
+ ["--enable-request-id-headers"],
+ indirect=True,
+)
+@pytest.mark.asyncio
+async def test_enable_request_id_header(server: RemoteOpenAIServer):
+ response = requests.get(server.url_for("health"))
+ assert "X-Request-Id" in response.headers
+ assert len(response.headers.get("X-Request-Id", "")) == 32
+
+
+@pytest.mark.parametrize(
+ "server",
+ ["--enable-request-id-headers"],
+ indirect=True,
+)
+@pytest.mark.asyncio
+async def test_custom_request_id_header(server: RemoteOpenAIServer):
+ response = requests.get(server.url_for("health"),
+ headers={"X-Request-Id": "Custom"})
+ assert "X-Request-Id" in response.headers
+ assert response.headers.get("X-Request-Id") == "Custom"
diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py
index 2d7cf39a8cca..475427f43928 100644
--- a/tests/kernels/moe/test_deepep_deepgemm_moe.py
+++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py
@@ -6,7 +6,6 @@
"""
import dataclasses
-import importlib
from typing import Optional
import pytest
@@ -21,18 +20,11 @@
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.platforms import current_platform
+from vllm.utils import has_deep_ep, has_deep_gemm
-from .deepep_utils import ProcessGroupInfo, parallel_launch
+from .utils import ProcessGroupInfo, parallel_launch
-has_deep_ep = importlib.util.find_spec("deep_ep") is not None
-
-try:
- import deep_gemm
- has_deep_gemm = True
-except ImportError:
- has_deep_gemm = False
-
-if has_deep_ep:
+if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
@@ -40,19 +32,21 @@
from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
-if has_deep_gemm:
+if has_deep_gemm():
+ import deep_gemm
+
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts)
requires_deep_ep = pytest.mark.skipif(
- not has_deep_ep,
+ not has_deep_ep(),
reason="Requires deep_ep kernels",
)
requires_deep_gemm = pytest.mark.skipif(
- not has_deep_gemm,
+ not has_deep_gemm(),
reason="Requires deep_gemm kernels",
)
diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py
index 7e029ea95055..80a36dc39712 100644
--- a/tests/kernels/moe/test_deepep_moe.py
+++ b/tests/kernels/moe/test_deepep_moe.py
@@ -4,7 +4,6 @@
"""
import dataclasses
-import importlib
from typing import Optional, Union
import pytest
@@ -22,12 +21,11 @@
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.platforms import current_platform
+from vllm.utils import has_deep_ep
-from .deepep_utils import ProcessGroupInfo, parallel_launch
+from .utils import ProcessGroupInfo, parallel_launch
-has_deep_ep = importlib.util.find_spec("deep_ep") is not None
-
-if has_deep_ep:
+if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
@@ -36,7 +34,7 @@
from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
requires_deep_ep = pytest.mark.skipif(
- not has_deep_ep,
+ not has_deep_ep(),
reason="Requires deep_ep kernels",
)
diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py
new file mode 100644
index 000000000000..5d2690904cea
--- /dev/null
+++ b/tests/kernels/moe/test_deepgemm.py
@@ -0,0 +1,225 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Unit-test DeepGEMM FP8 kernels (no DeepEP).
+Compare DeepGEMM path against the Triton fallback inside vLLM's fused_experts.
+"""
+
+import importlib
+import math
+
+import pytest
+import torch
+
+# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+ per_token_group_quant_fp8)
+from vllm.utils import cdiv
+
+has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
+
+if has_deep_gemm:
+ import deep_gemm
+ BLOCK_M = deep_gemm.get_m_alignment_for_contiguous_layout()
+ BLOCK_SIZE = [BLOCK_M, BLOCK_M]
+
+requires_deep_gemm = pytest.mark.skipif(
+ not has_deep_gemm,
+ reason="Requires deep_gemm kernels",
+)
+
+
+def calc_diff(x: torch.Tensor, y: torch.Tensor):
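+    # Normalized similarity metric: returns 0 when x and y are identical and
+    # grows as they diverge.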
+ x, y = x.double(), y.double()
+ denominator = (x * x + y * y).sum()
+ sim = 2 * (x * y).sum() / denominator
+ return 1 - sim
+
+
+def per_block_cast_to_fp8(
+ x: torch.Tensor,
+ block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
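+    # Pad x up to multiples of (128, block_size_n), compute one amax per
+    # 128 x block_size_n tile, and rescale each tile into the FP8 e4m3 range
+    # (|x| <= 448), returning the quantized tensor and the per-tile scales.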
+ assert x.dim() == 2
+ m, n = x.shape
+ x_padded = torch.zeros(
+ (cdiv(m, 128) * 128, cdiv(n, block_size_n) * block_size_n),
+ dtype=x.dtype,
+ device=x.device)
+ x_padded[:m, :n] = x
+ x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
+ x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+ x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
+ x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
+ scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
+ return x_scaled_sub, scales
+
+
+def make_block_quant_fp8_weights(
+ e: int,
+ n: int,
+ k: int,
+ block_size: list[int],
+):
+ """
+ Generate (w1, w2) expert weights and their per-block scale tensors
+ in FP8 block-quantized format.
+
+ w1 shape: (E, 2N, K)
+ w2 shape: (E, K, N)
+ """
+ dtype = torch.bfloat16
+ fp8_max, fp8_min = torch.finfo(torch.float8_e4m3fn).max, torch.finfo(
+ torch.float8_e4m3fn).min
+
+ # bf16 reference weights
+ w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=dtype) / 10
+ w2_bf16 = torch.randn(e, k, n, device="cuda", dtype=dtype) / 10
+ w1_bf16.clamp_(fp8_min, fp8_max)
+ w2_bf16.clamp_(fp8_min, fp8_max)
+
+ block_n, block_k = block_size
+ n_tiles_w1 = math.ceil((2 * n) / block_n)
+ k_tiles_w1 = math.ceil(k / block_k)
+ n_tiles_w2 = math.ceil(k / block_n)
+ k_tiles_w2 = math.ceil(n / block_k)
+
+ w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
+ w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
+ w1_s = torch.empty(e,
+ n_tiles_w1,
+ k_tiles_w1,
+ device="cuda",
+ dtype=torch.float32)
+ w2_s = torch.empty(e,
+ n_tiles_w2,
+ k_tiles_w2,
+ device="cuda",
+ dtype=torch.float32)
+
+ for i in range(e):
+ w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
+ w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
+
+ return w1, w2, w1_s, w2_s
+
+
+def run_single_case(m, n, k, topk, num_experts, block_size):
+ """
+ Run one (M,N,K) configuration on a single GPU and assert DeepGEMM ==
+ Triton baseline within tolerance.
+ """
+ tokens_bf16 = torch.randn(
+ m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
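+    # Only the per-group activation scales are needed here; the quantized
+    # activations themselves are discarded.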
+ _, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])
+
+ # expert weight tensors
+ w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
+ block_size)
+
+ router_logits = torch.randn(m,
+ num_experts,
+ device="cuda",
+ dtype=torch.float32)
+ topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
+ topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
+
+    # Triton reference
+ out_triton = fused_experts(
+ hidden_states=tokens_bf16,
+ w1=w1,
+ w2=w2,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ inplace=False,
+ use_fp8_w8a8=True,
+ w1_scale=w1_s,
+ w2_scale=w2_s,
+ a1_scale=a1_scale,
+ block_shape=block_size,
+ allow_deep_gemm=False,
+ )
+
+ # DeepGemm
+ out_deepgemm = fused_experts(
+ hidden_states=tokens_bf16,
+ w1=w1,
+ w2=w2,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ inplace=False,
+ use_fp8_w8a8=True,
+ w1_scale=w1_s,
+ w2_scale=w2_s,
+ a1_scale=a1_scale,
+ block_shape=block_size,
+ allow_deep_gemm=True,
+ )
+
+ base = out_triton.abs().mean()
+ atol = 0.1 * base.clamp(min=1e-2) # 10% of mean, but not lower than 1e-3
+ rtol = 0.05
+ # ----- Compare -----
+ torch.testing.assert_close(
+ out_deepgemm.to(torch.float32),
+ out_triton.to(torch.float32),
+ rtol=rtol,
+ atol=float(atol),
+ )
+
+
+# Note: W1 has shape (E, 2N, K), so N = 512
+# can trigger the deepgemm path.
+MNKs = [
+ (1024, 512, 128),
+ (1024, 512, 512),
+ (2048, 512, 512),
+ (512, 1024, 1024),
+ (512, 2048, 2048),
+ (4096, 4096, 1024),
+]
+
+TOPKS = [2, 6]
+NUM_EXPERTS = [32]
+
+
+@pytest.mark.parametrize("mnk", MNKs)
+@pytest.mark.parametrize("topk", TOPKS)
+@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
+@requires_deep_gemm
+def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):
+
+ with monkeypatch.context() as m:
+ m.setenv("VLLM_USE_DEEP_GEMM", "1")
+
+ _fused_moe_mod = importlib.import_module(
+ "vllm.model_executor.layers.fused_moe.fused_moe")
+
+ call_counter = {"cnt": 0}
+
+ orig_fn = _fused_moe_mod.deep_gemm_moe_fp8
+
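+        # Wrap deep_gemm_moe_fp8 so the test can verify that the DeepGEMM
+        # path is actually taken.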
+ def _spy_deep_gemm_moe_fp8(*args, **kwargs):
+ call_counter["cnt"] += 1
+ return orig_fn(*args, **kwargs)
+
+ monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8",
+ _spy_deep_gemm_moe_fp8)
+
+ m, n, k = mnk
+
+ if topk > num_experts:
+ pytest.skip(f"topk={topk} > num_experts={num_experts}")
+
+ run_single_case(
+ m=m,
+ n=n,
+ k=k,
+ topk=topk,
+ num_experts=num_experts,
+ block_size=BLOCK_SIZE,
+ )
+
+ # ensure that the DeepGEMM path was indeed taken.
+ assert call_counter["cnt"] == 1, \
+ f"DeepGEMM path was not executed during the test. " \
+ f"Call counter: {call_counter['cnt']}"
diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py
index 0caf14f040bb..ee2bdc838b0d 100644
--- a/tests/kernels/moe/test_pplx_cutlass_moe.py
+++ b/tests/kernels/moe/test_pplx_cutlass_moe.py
@@ -15,7 +15,7 @@
FusedMoEModularKernel)
from vllm.platforms import current_platform
-from .deepep_utils import ProcessGroupInfo, parallel_launch
+from .utils import ProcessGroupInfo, parallel_launch
try:
from pplx_kernels import AllToAll
diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py
index c4ad3af6802d..1da14eddff31 100644
--- a/tests/kernels/moe/test_pplx_moe.py
+++ b/tests/kernels/moe/test_pplx_moe.py
@@ -29,7 +29,7 @@
FusedMoEModularKernel)
from vllm.platforms import current_platform
-from .deepep_utils import ProcessGroupInfo, parallel_launch
+from .utils import ProcessGroupInfo, parallel_launch
requires_pplx = pytest.mark.skipif(
not has_pplx,
diff --git a/tests/kernels/moe/deepep_utils.py b/tests/kernels/moe/utils.py
similarity index 97%
rename from tests/kernels/moe/deepep_utils.py
rename to tests/kernels/moe/utils.py
index 117f1babdf62..e4cd8386e102 100644
--- a/tests/kernels/moe/deepep_utils.py
+++ b/tests/kernels/moe/utils.py
@@ -4,6 +4,7 @@
"""
import dataclasses
import importlib
+import os
import traceback
from typing import Callable, Optional
@@ -13,6 +14,8 @@
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
+from vllm.model_executor.layers.fused_moe.utils import find_free_port
+
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
@@ -92,7 +95,7 @@ def parallel_launch(
world_size,
world_size,
0,
- "tcp://localhost:29500",
+ f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{find_free_port()}",
worker,
) + args,
nprocs=world_size,
diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py
index a94215ee397b..140f00294765 100644
--- a/tests/model_executor/test_enabled_custom_ops.py
+++ b/tests/model_executor/test_enabled_custom_ops.py
@@ -28,42 +28,49 @@ class Relu3(ReLUSquaredActivation):
@pytest.mark.parametrize(
- "env, torch_level, ops_enabled, default_on",
+ "env, torch_level, use_inductor, ops_enabled, default_on",
[
# Default values based on compile level
- ("", 0, [True] * 4, True),
- ("", 1, [True] * 4, True),
- ("", 2, [True] * 4, True), # All by default
- ("", 3, [False] * 4, False),
- ("", 4, [False] * 4, False), # None by default
+ # - All by default (no Inductor compilation)
+ ("", 0, False, [True] * 4, True),
+ ("", 1, True, [True] * 4, True),
+ ("", 2, False, [True] * 4, True),
+ # - None by default (with Inductor)
+ ("", 3, True, [False] * 4, False),
+ ("", 4, True, [False] * 4, False),
+ # - All by default (without Inductor)
+ ("", 3, False, [True] * 4, True),
+ ("", 4, False, [True] * 4, True),
# Explicitly enabling/disabling
#
# Default: all
#
# All but SiluAndMul
- ("+rms_norm,-silu_and_mul", 0, [1, 0, 1, 1], True),
+ ("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True),
# Only ReLU3
- ("none,-rms_norm,+relu3", 0, [0, 0, 0, 1], False),
+ ("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False),
# All but SiluAndMul
- ("all,-silu_and_mul", 1, [1, 0, 1, 1], True),
+ ("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
# All but ReLU3 (even if ReLU2 is on)
- ("-relu3,relu2", 1, [1, 1, 1, 0], True),
- # GeluAndMul and SiluAndMul
- ("none,-relu3,+gelu_and_mul,+silu_and_mul", 2, [0, 1, 1, 0], False),
+ ("-relu3,relu2", 3, False, [1, 1, 1, 0], True),
+ # RMSNorm and SiluAndMul
+ ("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
# All but RMSNorm
- ("-rms_norm", 2, [0, 1, 1, 1], True),
+ ("-rms_norm", 3, False, [0, 1, 1, 1], True),
#
# Default: none
#
# Only ReLU3
- ("-silu_and_mul,+relu3", 3, [0, 0, 0, 1], False),
+ ("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
# All but RMSNorm
- ("all,-rms_norm", 4, [0, 1, 1, 1], True),
+ ("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
])
-def test_enabled_ops(env: str, torch_level: int, ops_enabled: list[int],
- default_on: bool):
- vllm_config = VllmConfig(compilation_config=CompilationConfig(
- level=torch_level, custom_ops=env.split(",")))
+def test_enabled_ops(env: str, torch_level: int, use_inductor: bool,
+ ops_enabled: list[int], default_on: bool):
+ vllm_config = VllmConfig(
+ compilation_config=CompilationConfig(use_inductor=bool(use_inductor),
+ level=torch_level,
+ custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config):
assert CustomOp.default_on() == default_on
diff --git a/tests/models/language/generation/test_gemma.py b/tests/models/language/generation/test_gemma.py
index ed0f0c19a041..5be4ae874e61 100644
--- a/tests/models/language/generation/test_gemma.py
+++ b/tests/models/language/generation/test_gemma.py
@@ -7,14 +7,21 @@
@pytest.mark.parametrize("model", MODELS)
-def test_dummy_loader(vllm_runner, model: str) -> None:
- with vllm_runner(
- model,
- load_format="dummy",
- ) as llm:
- normalizers = llm.collective_rpc(lambda self: self.worker.model_runner.
- model.model.normalizer.cpu().item())
- assert np.allclose(
- normalizers,
- llm.llm_engine.model_config.hf_config.hidden_size**0.5,
- rtol=1e-3)
+def test_dummy_loader(vllm_runner, monkeypatch, model: str) -> None:
+ with monkeypatch.context() as m:
+ m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+ with vllm_runner(
+ model,
+ load_format="dummy",
+ ) as llm:
+ if model == "google/gemma-3-4b-it":
+ normalizers = llm.model.collective_rpc(
+ lambda self: self.model_runner.model.language_model.model.
+ normalizer.cpu().item())
+ config = llm.model.llm_engine.model_config.hf_config.text_config
+ else:
+ normalizers = llm.model.collective_rpc(
+ lambda self: self.model_runner.model.model.normalizer.cpu(
+ ).item())
+ config = llm.model.llm_engine.model_config.hf_config
+ assert np.allclose(normalizers, config.hidden_size**0.5, rtol=2e-3)
diff --git a/tests/models/language/generation/test_granitemoehybrid.py b/tests/models/language/generation/test_granitemoehybrid.py
deleted file mode 100644
index 952449f28415..000000000000
--- a/tests/models/language/generation/test_granitemoehybrid.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import pytest
-
-from ...utils import check_logprobs_close
-
-# Path of the checkpoints
-MODELS = [
- "ibm-granite/granite-4.0-tiny-preview",
-]
-
-
-@pytest.mark.skip(
- reason="Granite 4.0 is not yet available in huggingface transformers")
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
-@pytest.mark.parametrize("max_tokens", [64])
-@pytest.mark.parametrize("num_logprobs", [5])
-def test_model_equivalence_to_hf_greedy(
- hf_runner,
- vllm_runner,
- example_prompts,
- model: str,
- dtype: str,
- max_tokens: int,
- num_logprobs: int,
-):
- with vllm_runner(model, dtype=dtype) as vllm_model:
- vllm_outputs = vllm_model.generate_greedy_logprobs(
- example_prompts, max_tokens, num_logprobs)
-
- with hf_runner(model, dtype=dtype) as hf_model:
- hf_outputs = hf_model.generate_greedy_logprobs_limit(
- example_prompts, max_tokens, num_logprobs)
-
- check_logprobs_close(
- outputs_0_lst=hf_outputs,
- outputs_1_lst=vllm_outputs,
- name_0="hf",
- name_1="vllm",
- )
diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py
index 90c4cd968e7a..e6dd6c35e64d 100644
--- a/tests/models/language/generation/test_hybrid.py
+++ b/tests/models/language/generation/test_hybrid.py
@@ -9,6 +9,9 @@
from ...utils import check_logprobs_close, check_outputs_equal
+# Mark all tests in this module as hybrid-model tests
+pytestmark = pytest.mark.hybrid_model
+
# NOTE: The first model in each list is taken as the primary model,
# meaning that it will be used in all tests in this file
# The rest of the models will only be tested by test_models
@@ -25,8 +28,9 @@
HYBRID_MODELS = [
"ai21labs/Jamba-tiny-dev",
- # NOTE: ibm-granite/granite-4.0-tiny-preview are skipped currently as
- # it is not yet available in huggingface transformers
+    # NOTE: Currently the test fails due to an HF transformers issue fixed in:
+    # https://github.com/huggingface/transformers/pull/39033
+    # We will enable the vLLM test for Granite after the next HF transformers
+    # release.
# "ibm-granite/granite-4.0-tiny-preview",
# NOTE: Running Plamo2 in transformers implementation requires to install
# causal-conv1d package, which is not listed as a test dependency as it's
diff --git a/tests/models/language/generation/test_mistral.py b/tests/models/language/generation/test_mistral.py
index bdd857ff5062..c70698ede37a 100644
--- a/tests/models/language/generation/test_mistral.py
+++ b/tests/models/language/generation/test_mistral.py
@@ -10,6 +10,7 @@
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolCall, MistralToolParser)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
+from vllm.transformers_utils.tokenizer import MistralTokenizer
from ...utils import check_logprobs_close
@@ -318,3 +319,53 @@ def test_mistral_guided_decoding(
schema=SAMPLE_JSON_SCHEMA)
except jsonschema.exceptions.ValidationError:
pytest.fail("Generated response is not valid with JSON schema")
+
+
+def test_mistral_function_call_nested_json():
+ """Ensure that the function-name regex captures the entire outer-most
+ JSON block, including nested braces."""
+
+ # Create a minimal stub tokenizer that provides the few attributes the
+ # parser accesses (`version` and `get_vocab`).
+ class _StubMistralTokenizer(MistralTokenizer):
+ version = 11 # Satisfy the version check
+
+ def __init__(self):
+ pass
+
+ @staticmethod
+ def get_vocab():
+ # Provide the special TOOL_CALLS token expected by the parser.
+ return {"[TOOL_CALLS]": 0}
+
+ tokenizer = _StubMistralTokenizer()
+ parser = MistralToolParser(tokenizer)
+
+ # Craft a model output featuring nested JSON inside the arguments.
+ args_dict = {
+ "city": "Dallas",
+ "state": "TX",
+ "unit": "fahrenheit",
+ "sub_dict": {
+ "foo": "bar",
+ "inner": {
+ "x": 1,
+ "y": 2
+ }
+ },
+ }
+
+ model_output = (
+ f"{parser.bot_token}get_current_weather{json.dumps(args_dict)}")
+
+ parsed = parser.extract_tool_calls(model_output, None)
+
+ # Assertions: the tool call is detected and the full nested JSON is parsed
+ # without truncation.
+ assert parsed.tools_called
+
+ assert MistralToolCall.is_valid_id(parsed.tool_calls[0].id)
+ assert parsed.tool_calls[0].function.name == "get_current_weather"
+ assert json.loads(parsed.tool_calls[0].function.arguments) == args_dict
+ # No additional content outside the tool call should be returned.
+ assert parsed.content is None
diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py
index 21d55c418c36..0284e69f3f0e 100644
--- a/tests/models/language/pooling/mteb_utils.py
+++ b/tests/models/language/pooling/mteb_utils.py
@@ -43,7 +43,7 @@ def encode(
# issues by randomizing the order.
r = self.rng.permutation(len(sentences))
sentences = [sentences[i] for i in r]
- outputs = self.model.encode(sentences, use_tqdm=False)
+ outputs = self.model.embed(sentences, use_tqdm=False)
embeds = np.array(outputs)
embeds = embeds[np.argsort(r)]
return embeds
@@ -250,16 +250,19 @@ def mteb_test_rerank_models(hf_runner,
with vllm_runner(model_info.name,
task="score",
max_model_len=None,
+ max_num_seqs=8,
**vllm_extra_kwargs) as vllm_model:
+ model_config = vllm_model.model.llm_engine.model_config
+
if model_info.architecture:
- assert (model_info.architecture
- in vllm_model.model.llm_engine.model_config.architectures)
+ assert (model_info.architecture in model_config.architectures)
+ assert model_config.hf_config.num_labels == 1
vllm_main_score = run_mteb_rerank(VllmMtebEncoder(vllm_model),
tasks=MTEB_RERANK_TASKS,
languages=MTEB_RERANK_LANGS)
- vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
+ vllm_dtype = model_config.dtype
with hf_runner(model_info.name, is_cross_encoder=True,
dtype="float32") as hf_model:
diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py
index 496850b19af4..9d63339737ce 100644
--- a/tests/models/multimodal/generation/test_common.py
+++ b/tests/models/multimodal/generation/test_common.py
@@ -107,6 +107,8 @@
),
limit_mm_per_prompt={"image": 4},
)],
+ # TODO: Revert to "auto" when CPU backend can use torch > 2.6
+ dtype="bfloat16" if current_platform.is_cpu() else "auto",
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
"paligemma": VLMTestInfo(
diff --git a/tests/models/multimodal/generation/vlm_utils/builders.py b/tests/models/multimodal/generation/vlm_utils/builders.py
index 7d20dd66089b..03c08240d6a8 100644
--- a/tests/models/multimodal/generation/vlm_utils/builders.py
+++ b/tests/models/multimodal/generation/vlm_utils/builders.py
@@ -203,6 +203,9 @@ def build_embedding_inputs_from_test_info(
images = [asset.pil_image for asset in image_assets]
embeds = test_info.convert_assets_to_embeddings(image_assets)
+ if test_info.dtype != "auto":
+ dtype = getattr(torch, test_info.dtype) # type: ignore
+ embeds = [e.to(dtype=dtype) for e in embeds]
assert len(images) == len(model_prompts)
inputs = build_single_image_inputs(images, model_prompts, size_wrapper)
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 4a587e39ad4c..e56dd19bec67 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -70,6 +70,12 @@ class _HfExamplesInfo:
length that is too large to fit into memory in CI.
"""
+ revision: Optional[str] = None
+ """
+ The specific revision (commit hash, tag, or branch) to use for the model.
+ If not specified, the default revision will be used.
+ """
+
def check_transformers_version(
self,
*,
@@ -164,6 +170,8 @@ def check_available_online(
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
+ "Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501
+ min_transformers_version="4.53"),
"GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"),
"Glm4ForCausalLM": _HfExamplesInfo("THUDM/GLM-4-9B-0414"),
"GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2",
@@ -205,7 +213,8 @@ def check_available_online(
"MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B",
trust_remote_code=True),
"MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01",
- trust_remote_code=True),
+ trust_remote_code=True,
+ revision="a59aa9cbc53b9fb8742ca4e9e1531b9802b6fdc3"), # noqa: E501
"MiniMaxM1ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-M1-40k",
trust_remote_code=True),
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
@@ -259,6 +268,8 @@ def check_available_online(
"Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"),
"MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True),
+ "Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst",
+ min_transformers_version="4.53"),
# [Encoder-decoder]
"BartModel": _HfExamplesInfo("facebook/bart-base"),
"BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py
index 54e8cd597bfc..df72607767fd 100644
--- a/tests/models/test_initialization.py
+++ b/tests/models/test_initialization.py
@@ -31,12 +31,20 @@ def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
text_config = hf_config.get_text_config()
+    # Ensure at least 2 experts per group,
+    # since `grouped_topk` assumes top-2.
+ num_experts = getattr(text_config, 'n_group', 1) * 2
+
text_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
- "num_experts": 2,
+ "num_experts": num_experts,
"num_experts_per_tok": 2,
- "num_local_experts": 2,
+ "num_local_experts": num_experts,
+ # Otherwise there will not be any expert layers
+ "first_k_dense_replace": 0,
+ # To avoid OOM on DeepSeek-V3
+ "n_routed_experts": num_experts,
})
if hasattr(hf_config, "vision_config"):
@@ -80,6 +88,7 @@ def _initialize_kv_caches_v1(self, vllm_config):
model_info.default,
tokenizer=model_info.tokenizer,
tokenizer_mode=model_info.tokenizer_mode,
+ revision=model_info.revision,
speculative_config={
"model": model_info.speculative_model,
"num_speculative_tokens": 1,
diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py
index ef0ad613d525..59de35644c12 100644
--- a/tests/models/test_oot_registration.py
+++ b/tests/models/test_oot_registration.py
@@ -53,7 +53,9 @@ def test_oot_registration_embedding(
with monkeypatch.context() as m:
m.setenv("VLLM_PLUGINS", "register_dummy_model")
prompts = ["Hello, my name is", "The text does not matter"]
- llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy")
+ llm = LLM(model=dummy_gemma2_embedding_path,
+ load_format="dummy",
+ max_model_len=2048)
outputs = llm.embed(prompts)
for output in outputs:
diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py
index 49b02279d61b..3feee01dadf7 100644
--- a/tests/mq_llm_engine/test_error_handling.py
+++ b/tests/mq_llm_engine/test_error_handling.py
@@ -66,7 +66,7 @@ async def test_evil_forward(tmp_socket):
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
- request_id=uuid.uuid4()):
+ request_id=str(uuid.uuid4())):
pass
assert client.errored
@@ -115,7 +115,7 @@ async def test_failed_health_check(tmp_socket):
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
- request_id=uuid.uuid4()):
+ request_id=str(uuid.uuid4())):
pass
client.close()
@@ -157,7 +157,7 @@ async def test_failed_abort(tmp_socket):
async for _ in client.generate(
prompt="Hello my name is",
sampling_params=SamplingParams(max_tokens=10),
- request_id=uuid.uuid4()):
+ request_id=str(uuid.uuid4())):
pass
assert "KeyError" in repr(execinfo.value)
assert client.errored
@@ -189,7 +189,7 @@ async def do_generate(client):
params = SamplingParams(min_tokens=2048, max_tokens=2048)
async for _ in client.generate(prompt="Hello my name is",
sampling_params=params,
- request_id=uuid.uuid4()):
+ request_id=str(uuid.uuid4())):
pass
tasks = [asyncio.create_task(do_generate(client)) for _ in range(10)]
@@ -289,7 +289,7 @@ async def test_engine_process_death(tmp_socket):
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
- request_id=uuid.uuid4()):
+ request_id=str(uuid.uuid4())):
pass
# And the health check should show the engine is dead
diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index 516bf4513816..3646ad6c481b 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -17,7 +17,7 @@
CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4,
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
- CompressedTensorsWNA16)
+ CompressedTensorsWNA16, cutlass_fp4_supported)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
sparse_cutlass_supported)
from vllm.platforms import current_platform
@@ -668,8 +668,8 @@ def check_model(model):
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
if isinstance(qkv_proj.scheme, scheme) or isinstance(
- qkv_proj.scheme, CompressedTensorsW4A16Fp4
- ) and not CompressedTensorsW4A4Fp4.cutlass_fp4_supported():
+ qkv_proj.scheme,
+ CompressedTensorsW4A16Fp4) and not cutlass_fp4_supported():
assert True
else:
raise AssertionError("FP4 Scheme Mismatch")
diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py
index fd838285aba7..7c369feec415 100644
--- a/tests/spec_decode/e2e/test_eagle_correctness.py
+++ b/tests/spec_decode/e2e/test_eagle_correctness.py
@@ -370,6 +370,10 @@ def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
+        # 2 blocks for the small prompt, 256 // 16 for the generated tokens.
+ "num_gpu_blocks_override": 2 + 256 // 16,
+ "max_model_len": (2 + 256 // 16) * 16,
+
# Skip cuda graph recording for fast test.
"enforce_eager": True,
@@ -420,6 +424,10 @@ def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
+        # 2 blocks for the small prompt, 256 // 16 for the generated tokens.
+ "num_gpu_blocks_override": 2 + 256 // 16,
+ "max_model_len": (2 + 256 // 16) * 16,
+
# Skip cuda graph recording for fast test.
"enforce_eager": True,
diff --git a/tests/standalone_tests/pytorch_nightly_dependency.sh b/tests/standalone_tests/pytorch_nightly_dependency.sh
new file mode 100644
index 000000000000..cb531e13ecb8
--- /dev/null
+++ b/tests/standalone_tests/pytorch_nightly_dependency.sh
@@ -0,0 +1,42 @@
+#!/bin/sh
+# This script checks that the nightly torch packages are not overridden by the test dependencies
+
+set -e
+set -x
+
+cd /vllm-workspace/
+
+rm -rf .venv
+
+uv venv .venv
+
+source .venv/bin/activate
+
+# check the environment
+uv pip freeze
+
+echo ">>> Installing nightly torch packages"
+uv pip install --quiet torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu128
+
+echo ">>> Capturing torch-related versions before requirements install"
+uv pip freeze | grep -E '^torch|^torchvision|^torchaudio' | sort > before.txt
+echo "Before:"
+cat before.txt
+
+echo ">>> Installing requirements/nightly_torch_test.txt"
+uv pip install --quiet -r requirements/nightly_torch_test.txt
+
+echo ">>> Capturing torch-related versions after requirements install"
+uv pip freeze | grep -E '^torch|^torchvision|^torchaudio' | sort > after.txt
+echo "After:"
+cat after.txt
+
+echo ">>> Comparing versions"
+if diff before.txt after.txt; then
+ echo "torch version not overridden."
+else
+ echo "torch version overridden by nightly_torch_test.txt, \
+    if the dependency is not triggered by the pytorch nightly test,\
+ please add the dependency to the list 'white_list' in tools/generate_nightly_torch_test.py"
+ exit 1
+fi
diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 8994816a3017..652a556659fe 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -10,7 +10,7 @@
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
-from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
@@ -198,7 +198,7 @@ def test_schedule(enable_prefix_caching: Optional[bool],
# Test initial scheduling
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
- assert len(output.scheduled_cached_reqs) == 0
+ assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
# Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items():
@@ -225,7 +225,7 @@ def test_schedule_multimodal_requests():
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
- assert len(output.scheduled_cached_reqs) == 0
+ assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
for req_id, num_tokens in output.num_scheduled_tokens.items():
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
@@ -259,7 +259,7 @@ def test_schedule_partial_requests():
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 3
- assert len(output.scheduled_cached_reqs) == 0
+ assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
assert scheduler.max_num_encoder_input_tokens == 1024
@@ -295,7 +295,7 @@ def test_schedule_partial_requests():
output = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output.scheduled_new_reqs) == 0
- assert len(output.scheduled_cached_reqs) == 2
+ assert output.scheduled_cached_reqs.num_reqs == 2
assert len(output.finished_req_ids) == 0
assert output.num_scheduled_tokens[requests[0].request_id] == 1
assert output.num_scheduled_tokens[requests[1].request_id] == 700
@@ -319,7 +319,7 @@ def test_no_mm_input_chunking():
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
- assert len(output.scheduled_cached_reqs) == 0
+ assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
# We want to only see the 400 text tokens at the start scheduled
assert output.num_scheduled_tokens[requests[0].request_id] == 400
@@ -342,7 +342,7 @@ def test_no_mm_input_chunking():
output = scheduler.schedule()
assert len(scheduler.running) == 1
assert len(output.scheduled_new_reqs) == 0
- assert len(output.scheduled_cached_reqs) == 1
+ assert output.scheduled_cached_reqs.num_reqs == 1
assert len(output.finished_req_ids) == 0
assert output.num_scheduled_tokens[requests[0].request_id] == 800
@@ -379,7 +379,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 3
- assert len(output.scheduled_cached_reqs) == 0
+ assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
# The first request is scheduled partially - 400.
@@ -408,7 +408,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
output1 = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output1.scheduled_new_reqs) == 0
- assert len(output1.scheduled_cached_reqs) == 3
+ assert output1.scheduled_cached_reqs.num_reqs == 3
assert len(output1.finished_req_ids) == 0
assert output1.num_scheduled_tokens[requests[0].request_id] == 400
assert output1.num_scheduled_tokens[requests[1].request_id] == 400
@@ -430,7 +430,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
output2 = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output2.scheduled_new_reqs) == 0
- assert len(output2.scheduled_cached_reqs) == 3
+ assert output2.scheduled_cached_reqs.num_reqs == 3
assert len(output2.finished_req_ids) == 0
assert output2.num_scheduled_tokens[requests[0].request_id] == 1
assert output2.num_scheduled_tokens[requests[1].request_id] == 1
@@ -449,23 +449,24 @@ def test_stop_via_update_from_output():
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
- scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
- scheduled_cached_reqs=[],
- num_scheduled_tokens={
- requests[0].request_id: 1,
- requests[1].request_id: 2
- },
- total_num_scheduled_tokens=3,
- scheduled_encoder_inputs={},
- scheduled_spec_decode_tokens={
- requests[0].request_id: [],
- requests[1].request_id: [10]
- },
- num_common_prefix_blocks=0,
- finished_req_ids=set(),
- free_encoder_input_ids=[],
- structured_output_request_ids={},
- grammar_bitmask=None)
+ scheduler_output = SchedulerOutput(
+ scheduled_new_reqs=[],
+ scheduled_cached_reqs=CachedRequestData.make_empty(),
+ num_scheduled_tokens={
+ requests[0].request_id: 1,
+ requests[1].request_id: 2
+ },
+ total_num_scheduled_tokens=3,
+ scheduled_encoder_inputs={},
+ scheduled_spec_decode_tokens={
+ requests[0].request_id: [],
+ requests[1].request_id: [10]
+ },
+ num_common_prefix_blocks=0,
+ finished_req_ids=set(),
+ free_encoder_input_ids=[],
+ structured_output_request_ids={},
+ grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
@@ -501,23 +502,25 @@ def test_stop_via_update_from_output():
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
- scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
- scheduled_cached_reqs=[],
- num_scheduled_tokens={
- requests[0].request_id: 3,
- requests[1].request_id: 2
- },
- total_num_scheduled_tokens=5,
- scheduled_encoder_inputs={},
- scheduled_spec_decode_tokens={
- requests[0].request_id: [10, 42],
- requests[1].request_id: [13]
- },
- num_common_prefix_blocks=0,
- finished_req_ids=set(),
- free_encoder_input_ids=[],
- structured_output_request_ids={},
- grammar_bitmask=None)
+ scheduler_output = SchedulerOutput(
+ scheduled_new_reqs=[],
+ scheduled_cached_reqs=CachedRequestData.make_empty(),
+ num_scheduled_tokens={
+ requests[0].request_id: 3,
+ requests[1].request_id: 2
+ },
+ total_num_scheduled_tokens=5,
+ scheduled_encoder_inputs={},
+ scheduled_spec_decode_tokens={
+ requests[0].request_id: [10, 42],
+ requests[1].request_id: [13]
+ },
+ num_common_prefix_blocks=0,
+ finished_req_ids=set(),
+ free_encoder_input_ids=[],
+ structured_output_request_ids={},
+ grammar_bitmask=None,
+ )
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
@@ -551,23 +554,25 @@ def test_stop_via_update_from_output():
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
- scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
- scheduled_cached_reqs=[],
- num_scheduled_tokens={
- requests[0].request_id: 3,
- requests[1].request_id: 1
- },
- total_num_scheduled_tokens=4,
- scheduled_encoder_inputs={},
- scheduled_spec_decode_tokens={
- requests[0].request_id: [10, 11],
- requests[1].request_id: []
- },
- num_common_prefix_blocks=0,
- finished_req_ids=set(),
- free_encoder_input_ids=[],
- structured_output_request_ids={},
- grammar_bitmask=None)
+ scheduler_output = SchedulerOutput(
+ scheduled_new_reqs=[],
+ scheduled_cached_reqs=CachedRequestData.make_empty(),
+ num_scheduled_tokens={
+ requests[0].request_id: 3,
+ requests[1].request_id: 1
+ },
+ total_num_scheduled_tokens=4,
+ scheduled_encoder_inputs={},
+ scheduled_spec_decode_tokens={
+ requests[0].request_id: [10, 11],
+ requests[1].request_id: []
+ },
+ num_common_prefix_blocks=0,
+ finished_req_ids=set(),
+ free_encoder_input_ids=[],
+ structured_output_request_ids={},
+ grammar_bitmask=None,
+ )
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
@@ -603,7 +608,7 @@ def test_stop_via_update_from_output():
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
- scheduled_cached_reqs=[],
+ scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
@@ -1208,7 +1213,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0
- assert len(scheduler._cached_reqs_data) == 0
# EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0
diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py
index 16c36cd5c6b9..d5ff78c1449a 100644
--- a/tests/v1/engine/test_engine_core_client.py
+++ b/tests/v1/engine/test_engine_core_client.py
@@ -8,8 +8,10 @@
import uuid
from threading import Thread
from typing import Optional
+from unittest.mock import MagicMock
import pytest
+import torch
from transformers import AutoTokenizer
from tests.utils import multi_gpu_test
@@ -517,3 +519,72 @@ def kill_first_child():
)
assert "Engine core initialization failed" in str(e_info.value)
+
+
+@create_new_process_for_each_test()
+def test_engine_core_proc_instantiation_cuda_empty(
+ monkeypatch: pytest.MonkeyPatch):
+ """
+ Test that EngineCoreProc can be instantiated when CUDA_VISIBLE_DEVICES
+ is empty. This ensures the engine frontend does not need access to GPUs.
+ """
+
+ from vllm.v1.engine.core import EngineCoreProc
+ from vllm.v1.executor.abstract import Executor
+
+ # Create a simple mock executor instead of a complex custom class
+ mock_executor_class = MagicMock(spec=Executor)
+
+ def create_mock_executor(vllm_config):
+ mock_executor = MagicMock()
+
+ # Only implement the methods that are actually called during init
+ from vllm.v1.kv_cache_interface import FullAttentionSpec
+ mock_spec = FullAttentionSpec(block_size=16,
+ num_kv_heads=1,
+ head_size=64,
+ dtype=torch.float16,
+ use_mla=False)
+
+ mock_executor.get_kv_cache_specs.return_value = [{
+ "default": mock_spec
+ }]
+ mock_executor.determine_available_memory.return_value = [
+ 1024 * 1024 * 1024
+ ]
+ mock_executor.initialize_from_config.return_value = None
+ mock_executor.max_concurrent_batches = 1
+
+ return mock_executor
+
+ mock_executor_class.side_effect = create_mock_executor
+
+ with monkeypatch.context() as m:
+ m.setenv("VLLM_USE_V1", "1")
+ m.setenv("CUDA_VISIBLE_DEVICES", "") # No CUDA devices
+
+ from vllm.v1.utils import EngineZmqAddresses
+
+ def mock_startup_handshake(self, handshake_socket, on_head_node,
+ parallel_config):
+ return EngineZmqAddresses(inputs=["tcp://127.0.0.1:5555"],
+ outputs=["tcp://127.0.0.1:5556"],
+ coordinator_input=None,
+ coordinator_output=None)
+
+ # Background processes are not important here
+ m.setattr(EngineCoreProc, "startup_handshake", mock_startup_handshake)
+
+ vllm_config = EngineArgs(
+ model="deepseek-ai/DeepSeek-V2-Lite",
+ trust_remote_code=True).create_engine_config()
+ engine_core_proc = EngineCoreProc(
+ vllm_config=vllm_config,
+ on_head_node=True,
+ handshake_address="tcp://127.0.0.1:12345",
+ executor_class=mock_executor_class,
+ log_stats=False,
+ engine_index=0,
+ )
+
+ engine_core_proc.shutdown()
diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index ab9729aae2e9..e30a250449aa 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -7,6 +7,8 @@
from typing import Optional
from unittest.mock import patch
+import pytest
+
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
NixlConnectorWorker)
@@ -161,7 +163,8 @@ def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
super().__init__(*args, **kwargs)
self._hand_shake_latency = hand_shake_latency
- def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
+ def _nixl_handshake(self, host: str, port: int,
+ remote_tp_size: int) -> dict[int, str]:
# Mimic slow _nixl_handshake, as well as bypass zmq communication.
time.sleep(self._hand_shake_latency)
# These should've been done in register_kv_caches(), called by
@@ -177,10 +180,10 @@ def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
num_blocks=1,
- tp_size=1,
block_len=self.block_len,
attn_backend_name=self.backend_name,
- ))
+ ),
+ remote_tp_size=remote_tp_size)
return {0: remote_agent_name}
@@ -233,6 +236,8 @@ def test_multi_xfer_one_engine(
"localhost",
"remote_port":
1234,
+ "remote_tp_size":
+ 1,
})
connector.bind_connector_metadata(metadata)
@@ -259,13 +264,23 @@ def test_multi_xfer_one_engine(
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
+ @pytest.mark.parametrize("decode_tp_size, prefill_tp_size", [
+ (1, 1),
+ (2, 1),
+ (4, 2),
+ (4, 4),
+ ])
def test_async_load_kv(
- self,
- # dist_init is a fixture that initializes the distributed environment.
- dist_init):
+ self,
+ # Fixture that initializes the distributed environment.
+ dist_init,
+ # Simulate consumer-producer TP sizes.
+ decode_tp_size,
+ prefill_tp_size):
"""Test that NixlConnector's start_load_kv should be non-blocking."""
vllm_config = create_vllm_config()
+ vllm_config.parallel_config.tensor_parallel_size = decode_tp_size
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
@@ -280,6 +295,7 @@ def test_async_load_kv(
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
"remote_port": 1234,
+ "remote_tp_size": prefill_tp_size,
})
connector.bind_connector_metadata(metadata)
@@ -329,6 +345,7 @@ def test_concurrent_load_kv(
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
"remote_port": 1234,
+ "remote_tp_size": 1,
})
connector.bind_connector_metadata(metadata)
diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
index ff36a281c413..12a71d97e8d2 100644
--- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
+++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
@@ -66,7 +66,7 @@ def test_basic_lifecycle():
assert len(scheduler_output.finished_req_ids) == 1
assert request_id in scheduler_output.finished_req_ids
assert len(scheduler_output.scheduled_new_reqs) == 0
- assert len(scheduler_output.scheduled_cached_reqs) == 0
+ assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler.finished_req_ids) == 0
# (2b): execute_model()
@@ -81,7 +81,7 @@ def test_basic_lifecycle():
assert len(scheduler.running) == 0
assert len(scheduler_output.finished_req_ids) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0
- assert len(scheduler_output.scheduled_cached_reqs) == 0
+ assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler.finished_req_ids) == 0
# (3b): execute_model()
diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
index a1156306dc4b..f89970bf2c80 100644
--- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
+++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
@@ -36,7 +36,7 @@ def test_basic_lifecycle():
# Nothing running and empty scheduler output.
assert len(scheduler.running) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0
- assert len(scheduler_output.scheduled_cached_reqs) == 0
+ assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler_output.num_scheduled_tokens) == 0
assert scheduler_output.total_num_scheduled_tokens == 0
@@ -158,7 +158,7 @@ def test_interleaved_lifecycle():
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
- assert len(scheduler_output.scheduled_cached_reqs) == 1
+ assert scheduler_output.scheduled_cached_reqs.num_reqs == 1
model_runner_output = create_model_runner_output(
[request_local_a, request_local_b])
@@ -169,7 +169,7 @@ def test_interleaved_lifecycle():
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0
- assert len(scheduler_output.scheduled_cached_reqs) == 2
+ assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
model_runner_output = create_model_runner_output(
reqs=[request_local_a, request_local_b])
@@ -177,14 +177,14 @@ def test_interleaved_lifecycle():
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0
- assert len(scheduler_output.scheduled_cached_reqs) == 2
+ assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
# STEP 4: KVs arrive.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0
- assert len(scheduler_output.scheduled_cached_reqs) == 2
+ assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
model_runner_output = create_model_runner_output(
[request_local_a, request_local_b],
@@ -196,7 +196,7 @@ def test_interleaved_lifecycle():
assert len(scheduler.running) == 3
assert len(scheduler.waiting) == 0
assert len(scheduler_output.scheduled_new_reqs) == 1
- assert len(scheduler_output.scheduled_cached_reqs) == 2
+ assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
model_runner_output = create_model_runner_output(
[request_local_a, request_local_b, request_remote])
diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py
index 61f59f35f75b..983d900606fc 100644
--- a/tests/v1/kv_connector/unit/utils.py
+++ b/tests/v1/kv_connector/unit/utils.py
@@ -25,7 +25,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0
assert len(scheduler.finished_recving_kv_req_ids) == 0
- assert len(scheduler._cached_reqs_data) == 0
# EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0
diff --git a/tests/v1/tpu/test_kv_cache_update_kernel.py b/tests/v1/tpu/test_kv_cache_update_kernel.py
new file mode 100644
index 000000000000..63a1f6777e4d
--- /dev/null
+++ b/tests/v1/tpu/test_kv_cache_update_kernel.py
@@ -0,0 +1,71 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import numpy as np
+import pytest
+import torch
+import torch_xla
+
+import vllm.v1.attention.backends.pallas # noqa: F401
+from vllm.platforms import current_platform
+
+
+@pytest.mark.skipif(not current_platform.is_tpu(),
+ reason="This is a test for TPU only")
+@pytest.mark.parametrize("page_size", [32, 33])
+@pytest.mark.parametrize("combined_kv_head_num", [2, 16])
+@pytest.mark.parametrize("head_dim", [128, 256])
+@pytest.mark.parametrize("num_slices_per_block", [4, 8])
+def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
+ head_dim: int, num_slices_per_block: int):
+ page_num = 1000
+ padded_num_tokens = 128
+ kv_cache_cpu = torch.zeros(
+ (page_num * page_size, combined_kv_head_num, head_dim),
+ dtype=torch.bfloat16,
+ device="cpu")
+ kv_cache_xla = kv_cache_cpu.to(torch_xla.device())
+ new_kv_cpu = torch.randn(
+ (padded_num_tokens, combined_kv_head_num, head_dim),
+ dtype=torch.bfloat16,
+ device="cpu")
+ new_kv_xla = new_kv_cpu.to(torch_xla.device())
+ slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9],
+ dtype=np.int32)
+ kv_cache_start_indices = np.array([
+ page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6,
+ page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3
+ ],
+ dtype=np.int32)
+ new_kv_cache_indices = np.concatenate(
+ [np.array([0], dtype=np.int32),
+ np.cumsum(slice_lens[:-1])])
+ slot_mapping = np.stack(
+ [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
+ padded_size = (slot_mapping.shape[0] + num_slices_per_block -
+ 1) // num_slices_per_block * num_slices_per_block
+ slot_mapping = np.pad(slot_mapping,
+ [[0, padded_size - slot_mapping.shape[0]], [0, 0]],
+ constant_values=0)
+ slot_mapping = np.transpose(slot_mapping)
+ slot_mapping_cpu = torch.tensor(slot_mapping,
+ device="cpu",
+ dtype=torch.int32)
+ slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
+ torch_xla.sync()
+
+ torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
+ new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
+ new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size,
+ num_slices_per_block)
+ kv_cache_xla.copy_(new_kv_cache_xla)
+ torch_xla.sync()
+
+ for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices,
+ slice_lens):
+ kv_cache_cpu[ci:ci + sl, :, :] = new_kv_cpu[ni:ni + sl, :, :]
+
+ assert torch.allclose(kv_cache_xla.cpu(),
+ kv_cache_cpu,
+ atol=1e-4,
+ rtol=1e-4)
diff --git a/tests/v1/tpu/test_pallas.py b/tests/v1/tpu/test_pallas.py
index 3a9d80847a16..e279edfffbc7 100644
--- a/tests/v1/tpu/test_pallas.py
+++ b/tests/v1/tpu/test_pallas.py
@@ -47,7 +47,7 @@ class FakeAttentionLayer:
key = torch.zeros(num_tokens, num_kv_heads * head_size)
value = torch.zeros(num_tokens, num_kv_heads * head_size)
kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size)
- slot_mapping = torch.zeros(num_tokens, dtype=torch.int64)
+ slot_mapping = torch.zeros((3, num_tokens), dtype=torch.int64)
max_num_reqs = 8
max_num_blocks_per_req = 8
block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req),
@@ -65,6 +65,7 @@ class FakeAttentionLayer:
context_lens=context_lens,
query_start_loc=query_start_loc,
num_seqs=num_seqs,
+ num_slices_per_kv_cache_update_block=8,
)
with patch("torch.ops.xla.ragged_paged_attention"
diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py
index d22ddf5c7e58..40db0b2afe0d 100644
--- a/tests/v1/tpu/worker/test_tpu_model_runner.py
+++ b/tests/v1/tpu/worker/test_tpu_model_runner.py
@@ -82,7 +82,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
return SchedulerOutput(
scheduled_new_reqs=new_reqs,
- scheduled_cached_reqs=[],
+ scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens={},
@@ -161,7 +161,7 @@ def test_update_states_request_finished(model_runner):
# finish req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
- scheduled_cached_reqs=[],
+ scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
@@ -191,7 +191,7 @@ def test_update_states_request_resumed(model_runner):
# unschedule req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
- scheduled_cached_reqs=[],
+ scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
@@ -209,16 +209,16 @@ def test_update_states_request_resumed(model_runner):
# resume req
cached_req_data = CachedRequestData(
- req_id=req_id,
- resumed_from_preemption=False,
- new_token_ids=[],
- new_block_ids=([], ),
- num_computed_tokens=0,
+ req_ids=[req_id],
+ resumed_from_preemption=[False],
+ new_token_ids=[[]],
+ new_block_ids=[([], )],
+ num_computed_tokens=[0],
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
- scheduled_cached_reqs=[cached_req_data],
+ scheduled_cached_reqs=cached_req_data,
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
@@ -249,7 +249,7 @@ def test_update_states_no_changes(model_runner):
# schedule req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
- scheduled_cached_reqs=[],
+ scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
@@ -284,7 +284,7 @@ def test_update_states_request_unscheduled(model_runner):
# unschedule req_1
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
- scheduled_cached_reqs=[],
+ scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_ids[0]: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
@@ -587,3 +587,17 @@ def test_init_kv_cache_with_kv_sharing_valid():
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
+
+
+def test_most_model_len(monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setenv("VLLM_TPU_MOST_MODEL_LEN", "2048")
+ vllm_config = get_vllm_config()
+ vllm_config.model_config.max_model_len = 32000
+ vllm_config.scheduler_config.max_num_seqs = 1200
+ model_runner = get_model_runner(vllm_config)
+
+ # verify model runner will adjust num_reqs to avoid SMEM OOM.
+ assert model_runner.num_reqs_most_model_len == 1200
+ # num_page_per_req = 32k // 128
+ # num_reqs = 1024 ** 2 // 2 // num_page_per_req // 4 = 524
+ assert model_runner.num_reqs_max_model_len == 524
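A quick standalone sanity check of the arithmetic in the comment above (the constants are copied from the comment, with "32k" read as max_model_len = 32000 rather than 32768; they are assumptions taken from the comment, not read from the model runner itself):

```python
# Sanity check of the arithmetic in the comment above; constants are
# copied from the comment, not from the model runner code.
max_model_len = 32000
page_size = 128
num_page_per_req = max_model_len // page_size          # 250
num_reqs = 1024 ** 2 // 2 // num_page_per_req // 4     # 524
assert num_reqs == 524
```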
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index 583a88d8e6ec..c739b23b90dc 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -133,7 +133,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
return SchedulerOutput(
scheduled_new_reqs=new_reqs,
- scheduled_cached_reqs=[],
+ scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens={},
@@ -199,7 +199,7 @@ def test_update_states_request_finished(model_runner):
# finish req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
- scheduled_cached_reqs=[],
+ scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
@@ -231,7 +231,7 @@ def test_update_states_request_resumed(model_runner):
# unschedule req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
- scheduled_cached_reqs=[],
+ scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
@@ -249,16 +249,16 @@ def test_update_states_request_resumed(model_runner):
# resume req
cached_req_data = CachedRequestData(
- req_id=req_id,
- resumed_from_preemption=False,
- new_token_ids=[],
- new_block_ids=([], ),
- num_computed_tokens=0,
+ req_ids=[req_id],
+ resumed_from_preemption=[False],
+ new_token_ids=[[]],
+ new_block_ids=([[0]], ),
+ num_computed_tokens=[0],
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
- scheduled_cached_reqs=[cached_req_data],
+ scheduled_cached_reqs=cached_req_data,
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
@@ -339,7 +339,7 @@ def test_update_states_no_changes(model_runner):
# schedule req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
- scheduled_cached_reqs=[],
+ scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
@@ -376,7 +376,7 @@ def test_update_states_request_unscheduled(model_runner):
# unschedule req_1
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
- scheduled_cached_reqs=[],
+ scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_ids[0]: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
diff --git a/tools/generate_nightly_torch_test.py b/tools/generate_nightly_torch_test.py
new file mode 100644
index 000000000000..a3d7f7a609ba
--- /dev/null
+++ b/tools/generate_nightly_torch_test.py
@@ -0,0 +1,34 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Generates specialized requirements files for nightly PyTorch testing.
+
+This script reads the main test requirements input file (`requirements/test.in`)
+and splits its content into two files:
+1. `requirements/nightly_torch_test.txt`: Contains all dependencies
+   except the PyTorch-related ones.
+2. `torch_nightly_test.txt`: Contains only PyTorch-related packages.
+"""
+
+input_file = "requirements/test.in"
+output_file = "requirements/nightly_torch_test.txt"
+
+# Whitelist of packages that are not directly compatible with PyTorch nightly
+# when installed with pip. Please add your package to this list if it is
+# incompatible or makes the dependency test fail.
+white_list = ["torch", "torchaudio", "torchvision", "mamba_ssm"]
+
+with open(input_file) as f:
+ lines = f.readlines()
+
+skip_next = False
+
+for line in lines:
+ if skip_next:
+ if line.startswith((" ", "\t")) or line.strip() == "":
+ continue
+ skip_next = False
+
+ if any(k in line.lower() for k in white_list):
+ skip_next = True
+ continue
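Note that the hunk above ends right after the filtering loop, before anything is written to `output_file`; the rest of the script is not shown here. A minimal sketch of what the remainder might look like, restating the loop with the missing collection and write step added (`filtered_lines` is an assumed name, not necessarily what the actual script uses):

```python
# Hypothetical completion: collect the lines that survive the whitelist filter
# and write them to the output requirements file. Continues the script above.
filtered_lines = []
skip_next = False
for line in lines:
    if skip_next:
        if line.startswith((" ", "\t")) or line.strip() == "":
            continue
        skip_next = False
    if any(k in line.lower() for k in white_list):
        skip_next = True
        continue
    filtered_lines.append(line)

with open(output_file, "w") as f:
    f.writelines(filtered_lines)
```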
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 761704463782..89b6603fdf89 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -5,7 +5,6 @@
from typing import TYPE_CHECKING, Optional, Union
import torch
-import torch.library
import vllm.envs as envs
from vllm.logger import init_logger
@@ -1288,8 +1287,7 @@ def scaled_fp8_quant(
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else:
- # num_token_padding not implemented for this case
- assert (scale.numel() == 1 and num_token_padding is None)
+ assert scale.numel() == 1
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
return output, scale
@@ -1767,6 +1765,38 @@ def LLMM_Silu(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor,
torch.ops._rocm_C.LLMM_Silu(a, b, out, rows_per_block)
+# quick all reduce
+def init_custom_qr(rank: int,
+ world_size: int,
+ qr_max_size: Optional[int] = None) -> int:
+ return torch.ops._C_custom_ar.init_custom_qr(rank, world_size, qr_max_size)
+
+
+def qr_destroy(fa: int) -> None:
+ torch.ops._C_custom_ar.qr_destroy(fa)
+
+
+def qr_all_reduce(fa: int,
+ inp: torch.Tensor,
+ out: torch.Tensor,
+ quant_level: int,
+ cast_bf2half: bool = False) -> None:
+ torch.ops._C_custom_ar.qr_all_reduce(fa, inp, out, quant_level,
+ cast_bf2half)
+
+
+def qr_get_handle(fa: int) -> torch.Tensor:
+ return torch.ops._C_custom_ar.qr_get_handle(fa)
+
+
+def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
+ return torch.ops._C_custom_ar.qr_open_handles(fa, handles)
+
+
+def qr_max_size() -> int:
+ return torch.ops._C_custom_ar.qr_max_size()
+
+
def get_flash_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py
index ae63e06030dd..2be02411ec05 100644
--- a/vllm/_ipex_ops.py
+++ b/vllm/_ipex_ops.py
@@ -228,6 +228,111 @@ def reshape_and_cache(
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slot_mapping)
+ @staticmethod
+ def reshape_and_cache_flash(
+ key: torch.Tensor,
+ value: torch.Tensor,
+ key_cache: torch.Tensor,
+ value_cache: torch.Tensor,
+ slot_mapping: torch.Tensor,
+ kv_cache_dtype: str,
+ k_scale: Optional[torch.Tensor] = None,
+ v_scale: Optional[torch.Tensor] = None,
+ k_scale_float: float = 1.0,
+ v_scale_float: float = 1.0,
+ ) -> None:
+ assert kv_cache_dtype == "auto"
+ # TODO: support FP8 kv cache.
+ ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
+ key, value, key_cache, value_cache, slot_mapping)
+
+ @staticmethod
+ def flash_attn_varlen_func(
+ out: torch.Tensor,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ cu_seqlens_q: torch.Tensor,
+ seqused_k: torch.Tensor, # we don't support this in ipex kernel
+ max_seqlen_q: int,
+ max_seqlen_k: int,
+ softmax_scale: float,
+ causal: bool,
+ block_table: torch.Tensor,
+ alibi_slopes: Optional[torch.Tensor],
+ window_size: Optional[list[int]] = None,
+ softcap: Optional[float] = 0.0,
+ cu_seqlens_k: Optional[torch.Tensor] = None,
+        # The following parameters are not used by the ipex kernel currently;
+        # they are kept for API compatibility with CUDA.
+ scheduler_metadata=None,
+ fa_version: int = 2,
+ q_descale=None,
+ k_descale=None,
+ v_descale=None,
+ ):
+ if cu_seqlens_k is None:
+            # The ipex kernel takes cu_seqlens_k rather than seqused_k,
+            # so derive it here when it is not provided.
+ cu_seqlens_k = torch.cumsum(seqused_k, dim=0)
+ cu_seqlens_k = torch.cat([
+ torch.tensor([0], device=seqused_k.device, dtype=torch.int32),
+ cu_seqlens_k
+ ]).to(torch.int32)
+
+ real_window_size: tuple[int, int]
+ if window_size is None:
+ real_window_size = (-1, -1)
+ else:
+ assert len(window_size) == 2
+ real_window_size = (window_size[0], window_size[1])
+ return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+ out,
+ q.contiguous(),
+ k,
+ v,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ softmax_scale,
+ causal,
+ block_table,
+ alibi_slopes,
+ softcap=softcap,
+ window_size_left=real_window_size[0],
+ window_size_right=real_window_size[1],
+ k_scale=1.0,
+ v_scale=1.0,
+ )
+
+ @staticmethod
+ def get_scheduler_metadata(
+ batch_size,
+ max_seqlen_q,
+ max_seqlen_k,
+ num_heads_q,
+ num_heads_kv,
+ headdim,
+ cache_seqlens: torch.Tensor,
+ qkv_dtype=torch.bfloat16,
+ headdim_v=None,
+ cu_seqlens_q: Optional[torch.Tensor] = None,
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
+ cache_leftpad: Optional[torch.Tensor] = None,
+ page_size: Optional[int] = None,
+ max_seqlen_k_new=0,
+ causal=False,
+ window_size=(-1, -1), # -1 means infinite context window
+ has_softcap=False,
+ num_splits=0, # Can be tuned for speed
+ pack_gqa=None, # Can be tuned for speed
+ sm_margin=0, # Can be tuned if some SMs are used for communication
+ ) -> None:
+ logger.warning_once(
+ "get_scheduler_metadata is not implemented for ipex_ops, "
+ "returning None.")
+ return None
+
@staticmethod
def copy_blocks(key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index f7d230c5d7d6..0c79aaf13551 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -306,12 +306,16 @@ def __init__(
block_size=16,
is_attention_free=False)
backend = backend_name_to_enum(attn_backend.get_name())
- if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
- backend = _Backend.XFORMERS
+ if current_platform.is_rocm():
+ # currently, only torch_sdpa is supported on rocm
+ self.attn_backend = _Backend.TORCH_SDPA
+ else:
+ if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
+ backend = _Backend.XFORMERS
- self.attn_backend = backend if backend in {
- _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
- } else _Backend.TORCH_SDPA
+ self.attn_backend = backend if backend in {
+ _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
+ } else _Backend.TORCH_SDPA
def forward(
self,
diff --git a/vllm/attention/ops/pallas_kv_cache_update.py b/vllm/attention/ops/pallas_kv_cache_update.py
new file mode 100644
index 000000000000..1a92b10e4f9c
--- /dev/null
+++ b/vllm/attention/ops/pallas_kv_cache_update.py
@@ -0,0 +1,117 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import functools
+
+import jax
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
+
+
+def _kv_cache_update_kernel(
+ # Prefetch
+ slices_ref, # [3, num_slices], list of (kv_cache_start, new_kv_start,
+ # slice_len)
+ # Input
+ new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim]
+ kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads,
+ # head_dim]
+ # Output
+ _, # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
+ # Scratch
+ scratch, # [num_slices_per_block, page_size, num_combined_kv_heads,
+ # head_dim]
+ sem,
+):
+ async_copies = []
+ block_idx = pl.program_id(0)
+ num_slices_per_block = scratch.shape[0]
+
+ # Copy from new_kv_hbm_ref to scratch
+ for i in range(num_slices_per_block):
+ offset_i = i + block_idx * num_slices_per_block
+ new_kv_start = slices_ref[1, offset_i]
+ length = slices_ref[2, offset_i]
+ async_copy = pltpu.make_async_copy(
+ new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
+ scratch.at[i, pl.ds(0, length), ...],
+ sem,
+ )
+ async_copy.start()
+ async_copies.append(async_copy)
+
+ for async_copy in async_copies:
+ async_copy.wait()
+
+ # Copy from scratch to kv_cache_hbm_ref
+ async_copies.clear()
+ for i in range(num_slices_per_block):
+ offset_i = i + block_idx * num_slices_per_block
+ kv_cache_start = slices_ref[0, offset_i]
+ length = slices_ref[2, offset_i]
+ async_copy = pltpu.make_async_copy(
+ scratch.at[i, pl.ds(0, length), ...],
+ kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
+ sem,
+ )
+ async_copy.start()
+ async_copies.append(async_copy)
+ for async_copy in async_copies:
+ async_copy.wait()
+
+
+@functools.partial(
+ jax.jit,
+ static_argnames=["page_size", "num_slices_per_block"],
+)
+def kv_cache_update(
+ new_kv: jax.Array, # [total_num_token, num_combined_kv_heads, head_dim]
+ slices: jax.
+    Array,  # [3, num_slices], list of (kv_cache_start, new_kv_start, slice_len)
+ kv_cache: jax.
+ Array, # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
+ *,
+ page_size: int = 32,
+ num_slices_per_block: int = 8,
+):
+ assert slices.shape[1] % num_slices_per_block == 0
+ _, num_combined_kv_heads, head_dim = new_kv.shape
+ assert kv_cache.shape[1] == num_combined_kv_heads
+ assert kv_cache.shape[2] == head_dim
+ assert head_dim % 128 == 0
+    # TODO: Add a dynamic check to make sure that all the slice lengths are
+    # smaller than or equal to page_size
+
+ in_specs = [
+ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+ ]
+
+ out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
+ out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]
+
+ scalar_prefetches = [slices]
+ scratch = pltpu.VMEM(
+ (num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
+ new_kv.dtype,
+ )
+
+ scratch_shapes = [
+ scratch,
+ pltpu.SemaphoreType.DMA,
+ ]
+
+ kernel = pl.pallas_call(
+ _kv_cache_update_kernel,
+ grid_spec=pltpu.PrefetchScalarGridSpec(
+ num_scalar_prefetch=len(scalar_prefetches),
+ in_specs=in_specs,
+ out_specs=out_specs,
+ grid=(slices.shape[1] // num_slices_per_block, ),
+ scratch_shapes=scratch_shapes,
+ ),
+ out_shape=out_shape,
+ input_output_aliases={len(scalar_prefetches) + 1: 0},
+ )
+
+ return kernel(*scalar_prefetches, new_kv, kv_cache)[0]
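For orientation, here is a minimal sketch of how `kv_cache_update` might be invoked on a TPU host. The shapes, slice values, and padding below are illustrative assumptions; the real call site is the Pallas attention backend, which is not shown in this diff.

```python
# Illustrative-only invocation of kv_cache_update; requires a TPU runtime.
import jax.numpy as jnp
import numpy as np

page_size, num_slices_per_block = 32, 8
new_kv = jnp.zeros((128, 16, 128), jnp.bfloat16)                 # [tokens, kv_heads, head_dim]
kv_cache = jnp.zeros((1000 * page_size, 16, 128), jnp.bfloat16)

# Each column of `slices` is (kv_cache_start, new_kv_start, slice_len); the
# column count is padded to a multiple of num_slices_per_block (zeros = no-op).
slices = np.zeros((3, num_slices_per_block), dtype=np.int32)
slices[:, 0] = (2 * page_size, 0, 7)   # write 7 new tokens at the start of page 2

updated = kv_cache_update(new_kv, jnp.asarray(slices), kv_cache,
                          page_size=page_size,
                          num_slices_per_block=num_slices_per_block)
```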
diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py
index 69cde06fd72e..f8b00565f051 100644
--- a/vllm/attention/utils/fa_utils.py
+++ b/vllm/attention/utils/fa_utils.py
@@ -4,13 +4,27 @@
from vllm import envs
from vllm.logger import init_logger
+from vllm.platforms import current_platform
logger = init_logger(__name__)
+if current_platform.is_cuda():
+ from vllm import _custom_ops as ops
+ reshape_and_cache_flash = ops.reshape_and_cache_flash
+ from vllm.vllm_flash_attn import (flash_attn_varlen_func,
+ get_scheduler_metadata)
+elif current_platform.is_xpu():
+ from vllm._ipex_ops import ipex_ops as ops
+ reshape_and_cache_flash = ops.reshape_and_cache_flash
+ flash_attn_varlen_func = ops.flash_attn_varlen_func
+ get_scheduler_metadata = ops.get_scheduler_metadata
+
def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
# import here to avoid circular dependencies
from vllm.platforms import current_platform
+ if current_platform.is_xpu():
+ return 2
try:
from vllm.vllm_flash_attn.flash_attn_interface import (
fa_version_unsupported_reason, is_fa_version_supported)
@@ -50,6 +64,9 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
def flash_attn_supports_fp8() -> bool:
- from vllm.platforms import current_platform
return get_flash_attn_version() == 3 and \
current_platform.get_device_capability().major == 9
+
+
+def is_flash_attn_varlen_func_available() -> bool:
+ return current_platform.is_cuda() or current_platform.is_xpu()
diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py
index 3efbe5695711..b3688d2340e4 100644
--- a/vllm/benchmarks/datasets.py
+++ b/vllm/benchmarks/datasets.py
@@ -320,6 +320,8 @@ def __init__(
**kwargs,
) -> None:
super().__init__(**kwargs)
+ random.seed(self.random_seed)
+ np.random.seed(self.random_seed)
def sample(
self,
@@ -376,10 +378,11 @@ def sample(
# [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
# To avoid uncontrolled change of the prompt length,
# the encoded sequence is truncated before being decode again.
+ total_input_len = prefix_len + int(input_lens[i])
re_encoded_sequence = tokenizer.encode(
- prompt, add_special_tokens=False)[:input_lens[i]]
+ prompt, add_special_tokens=False)[:total_input_len]
prompt = tokenizer.decode(re_encoded_sequence)
- total_input_len = prefix_len + int(input_lens[i])
+ total_input_len = len(re_encoded_sequence)
requests.append(
SampleRequest(
prompt=prompt,
@@ -692,7 +695,8 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
dataset_path=args.dataset_path).
sample(tokenizer=tokenizer, num_requests=args.num_prompts),
"random":
- lambda: RandomDataset(dataset_path=args.dataset_path).sample(
+ lambda: RandomDataset(random_seed=args.seed,
+ dataset_path=args.dataset_path).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
prefix_len=args.random_prefix_len,
diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py
index 302f655f424a..419284cca042 100644
--- a/vllm/benchmarks/serve.py
+++ b/vllm/benchmarks/serve.py
@@ -631,6 +631,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="The label (prefix) of the benchmark results. If not specified, "
"the endpoint type will be used as the label.",
)
+ parser.add_argument(
+ "--backend",
+ type=str,
+ default="vllm",
+ choices=list(ASYNC_REQUEST_FUNCS.keys()),
+ )
parser.add_argument(
"--base-url",
type=str,
diff --git a/vllm/config.py b/vllm/config.py
index 6883ec29a184..63eac9bd8903 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -569,6 +569,10 @@ def __post_init__(self) -> None:
else:
self.truncation_side = "right"
+ model_info, arch = self.registry.inspect_model_cls(self.architectures)
+ self._model_info = model_info
+ self._architecture = arch
+
self.pooler_config = self._init_pooler_config()
self.dtype = _get_and_verify_dtype(
@@ -660,8 +664,18 @@ def registry(self):
@property
def architectures(self) -> list[str]:
+ # architectures in the model config.
return getattr(self.hf_config, "architectures", [])
+ @property
+ def architecture(self) -> str:
+ # The architecture vllm actually used.
+ return self._architecture
+
+ @property
+ def model_info(self) -> dict[str, Any]:
+ return self._model_info
+
def maybe_pull_model_tokenizer_for_s3(self, model: str,
tokenizer: str) -> None:
"""Pull model/tokenizer from S3 to temporary directory when needed.
@@ -1470,7 +1484,7 @@ class CacheConfig:
sizes up to 32 are supported. On HPU devices, block size defaults to 128.
This config has no static default. If left unspecified by the user, it will
- be set in `Platform.check_and_update_configs()` based on the current
+ be set in `Platform.check_and_update_config()` based on the current
platform."""
gpu_memory_utilization: float = 0.9
"""The fraction of GPU memory to be used for the model executor, which can
@@ -1775,6 +1789,25 @@ class ParallelConfig:
"""Backend to use for data parallel, either "mp" or "ray"."""
enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
+ enable_eplb: bool = False
+ """Enable expert parallelism load balancing for MoE layers."""
+ num_redundant_experts: int = 0
+ """Number of redundant experts to use for expert parallelism."""
+ eplb_window_size: int = 1000
+ """Window size for expert load recording."""
+ eplb_step_interval: int = 3000
+ """
+ Interval for rearranging experts in expert parallelism.
+
+ Note that if this is greater than the EPLB window size, only the metrics
+ of the last `eplb_window_size` steps will be used for rearranging experts.
+ """
+ eplb_log_balancedness: bool = False
+ """
+    Log the balancedness of expert parallelism at each step.
+    This is turned off by default since it causes extra communication overhead.
+ """
+
max_parallel_loading_workers: Optional[int] = None
"""Maximum number of parallel loading workers when loading model
sequentially in multiple batches. To avoid RAM OOM when using tensor
@@ -1845,18 +1878,41 @@ def get_next_dp_init_port(self) -> int:
return answer
def stateless_init_dp_group(self) -> "ProcessGroup":
+ # NOTE: In high-concurrency scenarios multiple processes
+ # can pick the same (currently free) port through a race
+ # condition when calling `get_open_port()`. When the first
+ # process binds the port the others will subsequently fail
+ # with `torch.distributed.DistNetworkError: EADDRINUSE`.
+ # To make the initialization more robust we retry a few times
+ # with a fresh port whenever this specific error is observed.
+ from torch.distributed import DistNetworkError
+
from vllm.distributed.utils import (
stateless_init_torch_distributed_process_group)
- # use gloo since the engine process might not have cuda device
- dp_group = stateless_init_torch_distributed_process_group(
- self.data_parallel_master_ip,
- self.get_next_dp_init_port(),
- self.data_parallel_rank,
- self.data_parallel_size,
- backend="gloo")
+ max_retries = 5
+ last_exc: Optional[Exception] = None
+ for _ in range(max_retries):
+ try:
+ # use gloo since the engine process might not have cuda device
+ return stateless_init_torch_distributed_process_group(
+ self.data_parallel_master_ip,
+ self.get_next_dp_init_port(),
+ self.data_parallel_rank,
+ self.data_parallel_size,
+ backend="gloo")
+ except DistNetworkError as e:
+ # We only want to retry when the root cause is EADDRINUSE.
+ if "EADDRINUSE" in str(e):
+ logger.warning(
+ "Address already in use. Retrying with a new port.")
+ last_exc = e
+ continue # try again with a new port
+ raise e
- return dp_group
+ # If we get here all retries have failed.
+ assert last_exc is not None
+ raise last_exc
@staticmethod
def has_unfinished_dp(dp_group: "ProcessGroup",
@@ -1913,6 +1969,20 @@ def __post_init__(self) -> None:
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
logger.info("Disabling V1 multiprocessing for external launcher.")
+ if self.enable_eplb:
+ if not current_platform.is_cuda():
+ raise ValueError(
+ "Expert parallelism load balancing is only supported on "
+ "CUDA devices now.")
+ if self.num_redundant_experts < 0:
+ raise ValueError(
+ "num_redundant_experts must be non-negative, but got "
+ f"{self.num_redundant_experts}.")
+ else:
+ if self.num_redundant_experts != 0:
+ raise ValueError(
+                        "num_redundant_experts should only be set when EPLB "
+                        f"is enabled, but got {self.num_redundant_experts}.")
if self.distributed_executor_backend is None and self.world_size > 1:
# We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group.
@@ -3924,7 +3994,8 @@ class CompilationConfig:
- 'none,+op1,+op2' to enable only op1 and op2
By default, all custom ops are enabled when running without Inductor and
- disabled when running with Inductor (compile_level >= Inductor)."""
+ disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
+ Inductor generates (fused) Triton kernels for disabled custom ops."""
splitting_ops: list[str] = field(default_factory=list)
"""A list of ops to split the full graph into subgraphs, used in piecewise
compilation."""
@@ -3933,10 +4004,13 @@ class CompilationConfig:
use_inductor: bool = True
"""Whether to use inductor compilation:
- - False: inductor compilation is not used. graph runs in eager.
- - True: inductor compilation is used. one graph for symbolic shape
- is compiled. In addition, compile for compile_sizes,
- using configurations in inductor_compile_config."""
+ - False: inductor compilation is not used. graph runs in eager
+ (custom_ops enabled by default).
+ - True: inductor compilation is used (custom_ops disabled by default).
+ One graph for symbolic shape and one graph per size in compile_sizes
+ are compiled using configurations in inductor_compile_config.
+
+ This setting is ignored if level 0 and \
@@ -4668,11 +4732,21 @@ def _set_cudagraph_sizes(self):
batch_size_capture_list)
def recalculate_max_model_len(self, max_model_len: int):
+ # Can only be called in try_verify_and_update_config
model_config = self.model_config
max_model_len = model_config.get_and_verify_max_len(max_model_len)
self.model_config.max_model_len = max_model_len
self.scheduler_config.max_model_len = max_model_len
- self.compute_hash()
+
+ def try_verify_and_update_config(self):
+ architecture = getattr(self.model_config, "architecture", None)
+ if architecture is None:
+ return
+
+ from vllm.model_executor.models.config import MODELS_CONFIG_MAP
+ cls = MODELS_CONFIG_MAP.get(architecture, None)
+ if cls is not None:
+ cls.verify_and_update_config(self)
def __str__(self):
return (
diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py
index 35f2fd0ba9e2..85f87cb21edc 100644
--- a/vllm/distributed/device_communicators/all2all.py
+++ b/vllm/distributed/device_communicators/all2all.py
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import importlib.util
from typing import TYPE_CHECKING, Any
import torch
@@ -8,6 +7,7 @@
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
+from vllm.utils import has_deep_ep, has_pplx
from .base_device_communicator import All2AllManagerBase, Cache
@@ -80,8 +80,8 @@ class PPLXAll2AllManager(All2AllManagerBase):
"""
def __init__(self, cpu_group):
- has_pplx = importlib.util.find_spec("pplx_kernels") is not None
- assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa
+ assert has_pplx(
+ ), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa
super().__init__(cpu_group)
if self.internode:
@@ -133,8 +133,8 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
"""
def __init__(self, cpu_group):
- has_deepep = importlib.util.find_spec("deep_ep") is not None
- assert has_deepep, "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa
+ assert has_deep_ep(
+ ), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa
super().__init__(cpu_group)
self.handle_cache = Cache()
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index 055d91690e67..3958d566b174 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -8,6 +8,7 @@
import vllm.envs as envs
from vllm.logger import init_logger
+from vllm.platforms import current_platform
from .base_device_communicator import DeviceCommunicatorBase
@@ -41,6 +42,8 @@ def __init__(self,
CustomAllreduce)
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator)
+ from vllm.distributed.device_communicators.quick_all_reduce import (
+ QuickAllReduce)
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
@@ -50,6 +53,7 @@ def __init__(self,
)
self.ca_comm: Optional[CustomAllreduce] = None
+ self.qr_comm: Optional[QuickAllReduce] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
@@ -57,6 +61,14 @@ def __init__(self,
device=self.device,
)
+ if current_platform.is_rocm():
+ # Initialize a custom quick all-reduce implementation for AMD.
+ # Quick reduce is designed as a complement to custom allreduce.
+ # Based on quickreduce (https://github.com/mk1-project/quickreduce).
+            # On ROCm, 'use_custom_allreduce==True' currently implies an
+            # MI300-series GPU.
+ self.qr_comm = QuickAllReduce(group=self.cpu_group,
+ device=self.device)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
@@ -79,8 +91,14 @@ def __init__(self,
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
def all_reduce(self, input_):
- # always try custom allreduce first,
- # and then pynccl.
+        # Always try quick allreduce first, then custom allreduce,
+        # and then pynccl (quick allreduce is ROCm MI300-series only).
+ qr_comm = self.qr_comm
+ if qr_comm is not None and not qr_comm.disabled and \
+ qr_comm.should_quick_allreduce(input_):
+ out = qr_comm.quick_all_reduce(input_)
+ assert out is not None
+ return out
ca_comm = self.ca_comm
if ca_comm is not None and not ca_comm.disabled and \
ca_comm.should_custom_ar(input_):
diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py
new file mode 100644
index 000000000000..c61231e2d33f
--- /dev/null
+++ b/vllm/distributed/device_communicators/quick_all_reduce.py
@@ -0,0 +1,278 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from enum import Enum
+from typing import Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+import vllm.envs as envs
+from vllm import _custom_ops as ops
+from vllm.config import get_current_vllm_config
+from vllm.distributed.parallel_state import in_the_same_node_as
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.utils import cuda_device_count_stateless
+
+logger = init_logger(__name__)
+
+try:
+ ops.qr_max_size()
+ quick_ar = True
+except Exception:
+    # Quick all-reduce kernels are unavailable (e.g. on CPU or CUDA-only builds)
+ quick_ar = False
+
+
+def is_weak_contiguous(inp: torch.Tensor):
+ return inp.is_contiguous() or (inp.storage().nbytes() -
+ inp.storage_offset() * inp.element_size()
+ == inp.numel() * inp.element_size())
+
+
+class QuickReduceRegime(Enum):
+ FP = 0
+ INT8 = 1
+ INT6 = 2
+ INT4 = 3
+ NONE = 4
+
+
+MB = 1024 * 1024
+
+
+class QuickAllReduce:
+
+ _SUPPORTED_WORLD_SIZES = [2, 4, 8]
+ _SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
+ # The following data is based on kernel tests.
+ # In this order [FP, INT8, INT6, INT4].
+ _QR_MIN_SIZE = {
+ (torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB],
+ (torch.float16, 4): [1 * MB, 16 * MB, 4 * MB, 2 * MB],
+ (torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB],
+ (torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB],
+ (torch.bfloat16, 4): [8 * MB, 64 * MB, 64 * MB, 16 * MB],
+ (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],
+ }
+
+ def __init__(self, group: ProcessGroup,
+ device: Union[int, str, torch.device]) -> None:
+ """
+ Custom allreduce provides non-destructive acceleration and is
+ available for CUDA and ROCm MI300 series.
+
+ Custom quick allreduce leverages quantization for further
+ acceleration on ROCm. It currently supports Q8, Q6, and Q4
+ quantization formats and FP(float16, bfloat16).
+
+ Quick allreduce is designed as a complement to custom allreduce.
+ Its initialization requires even stricter conditions.
+
+ Only the ROCm MI300 series is supported for quick allreduce at
+ this time.
+
+ Args:
+ group: the process group to work on. If None, it will use the
+ default process group.
+            device: the device to bind the QuickAllReduce to. If None,
+                it will be bound to f"cuda:{local_rank}".
+            It is the caller's responsibility to make sure each communicator
+            is bound to a unique device, and all communicators in this group
+            are on the same node.
+ """
+ self.disabled = True
+ if not self._rocm_arch_available():
+ logger.debug(
+ "Custom quick allreduce is only supported on ROCm MI300 series."
+ )
+ return
+
+ if not quick_ar:
+ # disable because of missing quick reduce library
+ # e.g. in a cuda environment
+ logger.info("Custom quick allreduce is disabled because "
+ "of missing custom quick allreduce library")
+ return
+
+ self.group = group
+ assert dist.get_backend(group) != dist.Backend.NCCL, (
+ "Custom quick allreduce should be attached to a non-NCCL group.")
+ if not all(in_the_same_node_as(group, source_rank=0)):
+ # No need to initialize custom quick allreduce for
+ # multi-node case.
+ logger.warning("Custom quick allreduce is disabled because this "
+ "process group spans across nodes.")
+ return
+ rank = dist.get_rank(group=self.group)
+ world_size = dist.get_world_size(group=self.group)
+ self.rank = rank
+ self.world_size = world_size
+ if world_size == 1:
+ # No need to initialize QuickReduce for single GPU case.
+ return
+
+ if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES:
+ logger.warning(
+ "Custom quick allreduce is disabled due to an "
+ "unsupported world size: %d. Supported world sizes: %s.",
+ world_size, str(QuickAllReduce._SUPPORTED_WORLD_SIZES))
+ return
+
+ if isinstance(device, int):
+ device = torch.device(f"cuda:{device}")
+ elif isinstance(device, str):
+ device = torch.device(device)
+ assert isinstance(device, torch.device)
+ self.device = device
+
+ cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
+ if cuda_visible_devices:
+ device_ids = list(map(int, cuda_visible_devices.split(",")))
+ else:
+ device_ids = list(range(cuda_device_count_stateless()))
+ physical_device_id = device_ids[device.index]
+ tensor = torch.tensor([physical_device_id],
+ dtype=torch.int,
+ device="cpu")
+ gather_list = [
+ torch.tensor([0], dtype=torch.int, device="cpu")
+ for _ in range(self.world_size)
+ ]
+ dist.all_gather(gather_list, tensor, group=self.group)
+ physical_device_ids = [t.item() for t in gather_list]
+
+        # Test full GPU interconnect connectivity first; this filters out most
+        # of the cases where custom quick allreduce is not supported, since it
+        # checks hardware and driver support for the link (XGMI on ROCm).
+ assert current_platform.is_cuda_alike()
+ self.fully_connected = current_platform.is_fully_connected(
+ physical_device_ids)
+ if self.world_size > 2 and not self.fully_connected:
+ logger.debug(
+ "Custom quick allreduce is disabled because it's not supported "
+ "on more than two PCIe-only GPUs. ")
+ return
+
+ self.init_quick_all_reduce()
+
+ def init_quick_all_reduce(self):
+        # On ROCm, bfloat16 kernels are slower than fp16
+        # due to slower math operations.
+        # If the environment variable is set to 1, we convert the input to fp16.
+ self.use_fp16_kernels = envs.VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16
+ regime_str = envs.VLLM_ROCM_QUICK_REDUCE_QUANTIZATION
+ if regime_str not in QuickReduceRegime.__members__:
+            logger.warning(
+                "Custom quick allreduce: invalid quantization level %s. "
+                "Supported levels: %s", regime_str,
+                list(QuickReduceRegime.__members__.keys()))
+ return
+
+ if regime_str == "NONE":
+ logger.debug("Custom quick allreduce is disabled based "
+ "on env variable "
+ "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION='NONE'")
+ return
+ self.qr_quant_level = QuickReduceRegime[regime_str]
+ vllm_config = get_current_vllm_config()
+ if vllm_config is not None and \
+ hasattr(vllm_config, "model_config") and \
+ hasattr(vllm_config.model_config, "dtype"):
+ dtype = vllm_config.model_config.dtype
+ if dtype not in [torch.float16, torch.bfloat16]:
+ logger.debug(
+ "Custom quick allreduce disabled: only supports "
+                    "float16 and bfloat16, but got %s.", dtype)
+ return
+
+ if dtype == torch.bfloat16 and self.use_fp16_kernels:
+ logger.info(
+ "Custom quick allreduce: BF16 inputs will be converted "
+ "to FP16 to improve performance. set "
+ "envs.VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16=0 "
+ "to turn off.")
+
+ # VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB
+ qr_max_size = envs.VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB
+ if qr_max_size is not None:
+ if qr_max_size < 1:
+ logger.info(
+                    "You should not set a max_size smaller than 1MB, which "
+                    "can lead to errors or fall back to custom allreduce or "
+                    "RCCL with degraded performance."
+ )
+ qr_max_size = qr_max_size * MB
+ self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size)
+ self.qr_max_size = qr_max_size if qr_max_size is not None \
+ else ops.qr_max_size()
+ self.create_shared_buffer()
+ self.disabled = False
+
+ def _rocm_arch_available(self):
+ if not current_platform.is_rocm():
+ return False
+ try:
+ props = torch.cuda.get_device_properties(0)
+ gcn_arch = getattr(props, "gcnArchName", "")
+ supported_archs = ['gfx94', 'gfx95']
+ return any(gfx in gcn_arch for gfx in supported_archs)
+ except Exception as e:
+            logger.warning(
+                "Failed to determine the ROCm architecture for quick "
+                "allreduce: %s", e)
+ return False
+
+ def create_shared_buffer(self):
+ """
+ Creates a shared buffer for quickreduce.
+ Has to be called after init_custom_qr
+ """
+ handle = ops.qr_get_handle(self._ptr)
+ world_size = dist.get_world_size(group=self.group)
+ handles = [None] * world_size
+ dist.all_gather_object(handles, handle, group=self.group)
+ ops.qr_open_handles(self._ptr, handles)
+
+ def should_quick_allreduce(self, inp: torch.Tensor):
+ """
+        Check whether quick allreduce should be used for this input.
+ """
+ if self.disabled:
+ return False
+ if inp.dtype not in self._SUPPORTED_DTYPES:
+ return False
+ inp_size = inp.numel() * inp.element_size()
+ # custom quick allreduce requires input byte size to be
+ # multiples of 16
+ if inp_size % 16 != 0:
+ return False
+ if not is_weak_contiguous(inp):
+ return False
+ dtype = inp.dtype
+ if self.use_fp16_kernels:
+ dtype = torch.float16
+ return inp_size <= self.qr_max_size and \
+ inp_size >= self._QR_MIN_SIZE[(dtype, self.world_size)]\
+ [self.qr_quant_level.value]
+
+ def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None):
+ """Performs an out-of-place custom quick all reduce."""
+ # quick allreduce doesn't require a separate graph mode,
+ # as QR uses static IPC buffer.
+ if out is None:
+ out = torch.empty_like(inp)
+ ops.qr_all_reduce(self._ptr, inp, out, self.qr_quant_level.value,
+ self.use_fp16_kernels)
+ return out
+
+ def close(self):
+ if not self.disabled and getattr(self, "_ptr", None):
+ if ops is not None:
+ ops.qr_destroy(self._ptr)
+ self._ptr = 0
+ self.disabled = True
+
+ def __del__(self):
+ self.close()
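To make the `_QR_MIN_SIZE` gating concrete, the following standalone sketch reproduces the size check from `should_quick_allreduce` for one (dtype, world_size) entry; the tensor and the chosen table entry are illustrative assumptions.

```python
# Standalone illustration of the message-size gating used above.
import torch

MB = 1024 * 1024
# FP threshold for (torch.float16, world_size=8), copied from _QR_MIN_SIZE.
fp_min_size = 16 * MB

inp = torch.empty(16 * 1024, 1024, dtype=torch.float16)   # 32 MB of fp16 data
inp_size = inp.numel() * inp.element_size()

eligible = (
    inp.dtype in (torch.float16, torch.bfloat16)
    and inp_size % 16 == 0          # byte size must be a multiple of 16
    and inp_size >= fp_min_size     # large enough for the FP regime at ws=8
)
# The real check additionally requires inp_size <= qr_max_size and weak
# contiguity of the input tensor.
print(eligible)  # True for this example
```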
diff --git a/vllm/distributed/eplb/__init__.py b/vllm/distributed/eplb/__init__.py
new file mode 100644
index 000000000000..c87b039afd73
--- /dev/null
+++ b/vllm/distributed/eplb/__init__.py
@@ -0,0 +1,7 @@
+# SPDX-License-Identifier: Apache-2.0
+'''
+Expert parallelism load balancer (EPLB).
+'''
+
+from .eplb_state import *
+from .rebalance_algo import *
diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py
new file mode 100644
index 000000000000..2185df865c1f
--- /dev/null
+++ b/vllm/distributed/eplb/eplb_state.py
@@ -0,0 +1,431 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Expert parallelism load balancer (EPLB) metrics and states.
+
+# Glossary
+
+- **Logical Expert**: An expert that is part of the model's logical structure.
+ It holds a set of weights and is replicated across multiple physical
+ experts.
+- **Redundant Expert**: To achieve load balancing, for some popular logical
+ experts, we create additional copies of the expert weights. During inference,
+ each of these copies can be routed to by the same set of tokens.
+- **Physical Expert**: An expert that is instantiated on a specific device.
+ It is a replica of a logical expert and can be rearranged across devices.
+ I.e., one logical expert may have multiple sets of weights initialized on
+ different devices, and each of these sets is a physical expert.
+- **Local Physical Expert**: A physical expert that is instantiated on the
+ current device.
+
+For example: DeepSeek-R1 has 256 logical experts, so each MoE layer
+has 256 sets of linear layer weights in the model parameters. If we add 32
+redundant experts, DeepSeek-R1 will have 256 + 32 = 288 physical experts in
+total. And when deploying, we'll have 288 sets of linear layer weights for each
+MoE layer. If we have 32 EP ranks, then each GPU will hold 288 / 32 = 9 local
+physical experts.
+"""
+
+import time
+from collections.abc import Sequence
+from dataclasses import dataclass
+
+import torch
+from torch.distributed import all_gather, all_reduce
+
+from vllm.config import ParallelConfig
+from vllm.distributed.parallel_state import get_ep_group, get_node_count
+from vllm.logger import init_logger
+from vllm.model_executor.models.interfaces import MixtureOfExperts
+
+from .rebalance_algo import rebalance_experts
+from .rebalance_execute import rearrange_expert_weights_inplace
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class EplbState:
+ """EPLB metrics."""
+
+ physical_to_logical_map: torch.Tensor
+ """
+ Mapping from physical experts to logical experts.
+
+ Shape: (num_moe_layers, num_physical_experts)
+
+ # Example
+
+ For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
+ EP ranks, the mapping could look like this:
+
+ ```
+ [[0, 1, 2, 3, 0, 1],
+ [0, 2, 0, 1, 0, 3]]
+ ```
+ """
+ logical_to_physical_map: torch.Tensor
+ """
+ Mapping from logical experts to physical experts.
+
+ This is a sparse matrix, where -1 indicates no mapping.
+
+ Shape: (num_moe_layers, num_logical_experts, num_redundant_experts + 1)
+
+ # Example
+
+ For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
+ EP ranks, the mapping could look like this:
+
+ ```
+ [[[0, 4, -1],
+ [1, 5, -1],
+ [2, -1, -1],
+ [3, -1, -1]],
+ [[0, 2, 4],
+ [3, -1, -1],
+ [1, -1, -1],
+ [5, -1, -1]]]
+ ```
+ """
+ logical_replica_count: torch.Tensor
+ """
+ Number of replicas for each logical expert.
+ This is exactly the non-`-1` count in the `logical_to_physical_map`.
+
+ Shape: (num_moe_layers, num_logical_experts)
+
+    # Example
+
+    For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
+    EP ranks, the count could look like this:
+
+    ```
+    [[2, 2, 1, 1],
+     [3, 1, 1, 1]]
+    ```
+    """
+
+ expert_load_pass: torch.Tensor
+ """
+ Expert load during this forward pass.
+ We use the token count each expert processes as the load.
+
+ Shape: (num_moe_layers, num_local_physical_experts)
+ """
+ expert_load_window: torch.Tensor
+ """
+ A sliding window of expert load.
+
+ Shape: (window_size, num_moe_layers, num_local_physical_experts)
+ """
+ expert_load_window_step: int = 0
+ """
+ Current step in the sliding window.
+
+ Different from `expert_rearrangement_step`, each EP rank may have its own
+ `expert_load_window_step`.
+ """
+ expert_load_window_size: int = 0
+ """
+ Size of the expert load sliding window.
+ This is a constant and is taken from the config.
+ """
+
+ expert_rearrangement_step: int = 0
+ """
+ Steps after last rearrangement.
+ Will trigger a rearrangement if it exceeds the threshold.
+
+ NOTE: Keep in mind that all EP ranks need to have the same
+ `expert_rearrangement_step` value to ensure synchronization.
+ Otherwise, the rearrangement will hang at collective
+ communication calls.
+ """
+ expert_rearrangement_step_interval: int = 0
+ """
+ Interval for expert rearrangement steps.
+ This is a constant and is taken from the config.
+ """
+
+ @staticmethod
+ def build_initial_global_physical_to_logical_map(
+ num_routed_experts: int,
+ num_redundant_experts: int,
+ ) -> Sequence[int]:
+ """
+ Build an initial expert arrangement using the following structure:
+ [original routed experts, redundant experts]
+
+ Returns:
+ physical_to_logical_map (Sequence[int]): A list of integers,
+ where each integer is the index of the logical expert
+ that the corresponding physical expert maps to.
+ """
+ global_physical_to_logical_map = list(range(num_routed_experts))
+ global_physical_to_logical_map += [
+ i % num_routed_experts for i in range(num_redundant_experts)
+ ]
+ return global_physical_to_logical_map
+
+ @classmethod
+ def build(
+ cls,
+ model: MixtureOfExperts,
+ device: torch.device,
+ parallel_config: ParallelConfig,
+ ) -> "EplbState":
+ """
+ Build the initial EPLB state.
+ """
+ physical_to_logical_map_list = (
+ cls.build_initial_global_physical_to_logical_map(
+ model.num_routed_experts,
+ model.num_redundant_experts,
+ ))
+ physical_to_logical_map = torch.tensor(
+ physical_to_logical_map_list,
+ device=device,
+ )
+ logical_to_physical_map = torch.full(
+ (model.num_logical_experts, model.num_redundant_experts + 1),
+ -1,
+ device=device,
+ )
+ logical_replica_count = torch.zeros(
+ (model.num_logical_experts, ),
+ device=device,
+ dtype=torch.long,
+ )
+
+ for i in range(model.num_physical_experts):
+ logical_idx = physical_to_logical_map[i]
+ logical_to_physical_map[logical_idx,
+ logical_replica_count[logical_idx]] = i
+ logical_replica_count[logical_idx] += 1
+
+ # Duplicate initial mapping for all layers
+ physical_to_logical_map = physical_to_logical_map.unsqueeze(0).expand(
+ model.num_moe_layers,
+ -1,
+ ).contiguous()
+ logical_to_physical_map = logical_to_physical_map.unsqueeze(0).expand(
+ model.num_moe_layers,
+ -1,
+ -1,
+ ).contiguous()
+ logical_replica_count = logical_replica_count.unsqueeze(0).expand(
+ model.num_moe_layers,
+ -1,
+ ).contiguous()
+
+ expert_load_pass = torch.zeros(
+ (model.num_moe_layers, model.num_local_physical_experts),
+ dtype=torch.int32,
+ device=device,
+ )
+ expert_load_window_size = parallel_config.eplb_window_size
+ expert_load_window = torch.zeros(
+ (expert_load_window_size, model.num_moe_layers,
+ model.num_local_physical_experts),
+ dtype=torch.int32,
+ device=device,
+ )
+
+ # Set the initial progress of rearrangement to 3/4
+ eplb_step_interval = parallel_config.eplb_step_interval
+ expert_rearrangement_step = max(
+ 0, eplb_step_interval - eplb_step_interval // 4)
+
+ model.set_eplb_state(
+ expert_load_pass,
+ logical_to_physical_map,
+ logical_replica_count,
+ )
+
+ return cls(
+ physical_to_logical_map,
+ logical_to_physical_map,
+ logical_replica_count,
+ expert_load_pass,
+ expert_load_window,
+ expert_load_window_size=expert_load_window_size,
+ expert_rearrangement_step=expert_rearrangement_step,
+ expert_rearrangement_step_interval=eplb_step_interval,
+ )
+
+ def step(self,
+ model: MixtureOfExperts,
+ is_dummy: bool = False,
+ is_profile: bool = False,
+ log_stats: bool = False) -> None:
+ """
+ Step the EPLB state.
+
+ Args:
+ model (MixtureOfExperts): The MoE model.
+ is_dummy (bool): If `True`, this is a dummy step and the load
+ metrics recorded in this forward pass will not count. Defaults
+ to `False`.
+ is_profile (bool): If `True`, perform a dummy rearrangement
+ with maximum communication cost. This is used in `profile_run`
+ to reserve enough memory for the communication buffer.
+ log_stats (bool): If `True`, log the expert load metrics.
+
+ # Stats
+ The metrics are all summed up across layers.
+ - `avg_tokens`: The average load across ranks.
+ - `max_tokens`: The maximum load across ranks.
+ - `balancedness`: The ratio of average load to maximum load.
+ """
+
+ if is_profile:
+ self.rearrange(model, is_profile=True)
+ return
+
+ if is_dummy:
+ # Do not record load metrics for dummy steps
+ self.expert_load_pass.zero_()
+
+ if log_stats:
+ # `num_tokens`: (num_moe_layers,)
+ num_tokens = self.expert_load_pass.sum(dim=-1)
+
+ # Collect load metrics from all ranks
+ ep_group = get_ep_group().device_group
+ num_tokens_list = [
+ torch.empty_like(num_tokens) for _ in range(ep_group.size())
+ ]
+ all_gather(num_tokens_list, num_tokens, group=ep_group)
+ # Stack to get (num_ranks, num_moe_layers)
+ num_tokens_per_rank = torch.stack(num_tokens_list).float()
+
+            # Compute balancedness ratio, summed over layers:
+            #   (sum of per-layer mean load across ranks) /
+            #   (sum of per-layer max load across ranks)
+ avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0)
+ max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum(
+ dim=0)
+
+ # Just to make type checker happy
+ tokens_tensors: list[float] = torch.stack(
+ [avg_tokens_tensor, max_tokens_tensor]).tolist()
+ avg_tokens, max_tokens = tokens_tensors
+ balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0
+
+ if ep_group.rank() == 0:
+ logger.info(
+ "EPLB step: avg_tokens=%.2f, max_tokens=%d, "
+ "balancedness=%.4f", avg_tokens, max_tokens, balancedness)
+
+ # Update the expert load sliding window
+ if not is_dummy:
+ self.expert_load_window[self.expert_load_window_step] = (
+ self.expert_load_pass.clone())
+ self.expert_load_window_step += 1
+ if self.expert_load_window_step >= self.expert_load_window_size:
+ self.expert_load_window_step = 0
+ self.expert_load_pass.zero_()
+
+ # Step the expert rearrangement step
+ # Note that even if this is a dummy step, we still increment the
+ # rearrangement step and perform rearrangement to ensure all ranks are
+ # performing collective communication.
+ self.expert_rearrangement_step += 1
+ if (self.expert_rearrangement_step
+ >= self.expert_rearrangement_step_interval):
+ self.expert_rearrangement_step = 0
+ self.rearrange(model)
+
+ def rearrange(self,
+ model: MixtureOfExperts,
+ is_profile: bool = False) -> None:
+ """
+ Rearrange the experts according to the current load.
+ """
+
+ ep_group = get_ep_group().device_group
+ ep_rank = ep_group.rank()
+
+ time_start = None
+ is_main_rank = ep_rank == 0
+ if is_main_rank:
+ torch.cuda.synchronize()
+ time_start = time.perf_counter()
+ logger.info("Rearranging experts %s...",
+ "(profile)" if is_profile else "")
+
+ # This mapping is only used here, so we do not store it in the state
+ physical_expert_start = ep_rank * model.num_local_physical_experts
+ physical_expert_end = (physical_expert_start +
+ model.num_local_physical_experts)
+ # (num_moe_layers, num_local_physical_experts)
+ local_physical_to_logical_map = self.physical_to_logical_map[
+ :,
+ physical_expert_start:physical_expert_end,
+ ]
+
+ # Map the local physical expert load to global logical experts
+ logical_expert_load_window = torch.zeros(
+ self.expert_load_window_size,
+ model.num_moe_layers,
+ model.num_logical_experts,
+ dtype=self.expert_load_window.dtype,
+ device=self.expert_load_window.device,
+ )
+ logical_expert_load_window.scatter_add_(
+ dim=-1,
+ index=local_physical_to_logical_map.unsqueeze(0).expand_as(
+ self.expert_load_window).long(),
+ src=self.expert_load_window,
+ )
+
+ # Perform all-reduce to get the expert load across all ranks
+ global_expert_load_window = logical_expert_load_window.sum(dim=0)
+ all_reduce(global_expert_load_window, group=ep_group)
+
+ # TODO(bowen): Treat differently for prefill and decode nodes
+ num_replicas = model.num_physical_experts
+ num_groups = model.num_expert_groups
+ num_nodes = get_node_count()
+ num_gpus = ep_group.size()
+
+ if num_gpus % num_nodes != 0:
+ logger.warning_once(
+ f"num_gpus % num_nodes != 0, "
+ "not using hierarchical rearrangement algorithm.\n"
+ f"{num_gpus=}, {num_nodes=}")
+
+ # Get new expert mappings
+ (
+ new_physical_to_logical_map,
+ new_logical_to_physical_map,
+ new_logical_replica_count,
+ ) = (rebalance_experts(
+ global_expert_load_window,
+ num_replicas,
+ num_groups,
+ num_nodes,
+ num_gpus,
+ ))
+
+ # Update expert weights
+ rearrange_expert_weights_inplace(
+ self.physical_to_logical_map,
+ new_physical_to_logical_map,
+ model.expert_weights,
+ ep_group,
+ is_profile,
+ )
+
+ if not is_profile:
+ self.physical_to_logical_map.copy_(new_physical_to_logical_map)
+ self.logical_to_physical_map.copy_(new_logical_to_physical_map)
+ self.logical_replica_count.copy_(new_logical_replica_count)
+
+ if is_main_rank:
+ assert time_start is not None
+ torch.cuda.synchronize()
+ time_end = time.perf_counter()
+ logger.info(
+ "Rearranged experts%sin %.2f seconds.",
+ " (profile) " if is_profile else " ",
+ time_end - time_start,
+ )
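For reference (illustrative only, not applied by this diff), the `scatter_add_` in `rearrange()` folds the per-slot physical load into global logical-expert load; with made-up numbers:

    import torch

    # This rank hosts 2 physical slots that currently serve logical experts [3, 1].
    local_phys_to_logical = torch.tensor([[3, 1]])     # (num_layers=1, local_phys=2)
    load_window = torch.tensor([[[5., 7.]]])           # (window=1, layers=1, local_phys=2)

    logical_load = torch.zeros(1, 1, 4)                # 4 logical experts in total
    logical_load.scatter_add_(
        dim=-1,
        index=local_phys_to_logical.unsqueeze(0).expand_as(load_window).long(),
        src=load_window,
    )
    # logical_load -> tensor([[[0., 7., 0., 5.]]])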
diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py
new file mode 100644
index 000000000000..7ad6d566b55b
--- /dev/null
+++ b/vllm/distributed/eplb/rebalance_algo.py
@@ -0,0 +1,233 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Expert parallelism load balancer (EPLB) for vLLM.
+
+This module implements the core rearrangement algorithm.
+
+The rearrangement algorithm is adapted from
+[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
+
+Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example
+on how the EPLB algorithm works.
+"""
+
+import torch
+
+
+def balanced_packing(weight: torch.Tensor,
+ num_packs: int) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+    Pack n weighted objects to m packs, such that each pack contains exactly
+ n/m objects and the weights of all packs are as balanced as possible.
+
+ Parameters:
+ weight: [X, n], the weight of each item
+ num_packs: number of packs
+
+ Returns:
+ pack_index: [X, n], the pack index of each item
+ rank_in_pack: [X, n], the rank of the item in the pack
+ """
+ num_layers, num_groups = weight.shape
+ assert num_groups % num_packs == 0
+ groups_per_pack = num_groups // num_packs
+
+ if groups_per_pack == 1:
+ pack_index = torch.arange(weight.size(-1),
+ dtype=torch.int64,
+ device=weight.device).expand(weight.shape)
+ rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
+ return pack_index, rank_in_pack
+
+ indices = weight.float().sort(-1, descending=True).indices.cpu()
+ pack_index = torch.full_like(weight,
+ fill_value=-1,
+ dtype=torch.int64,
+ device="cpu")
+ rank_in_pack = torch.full_like(pack_index, fill_value=-1)
+ for i in range(num_layers):
+ pack_weights = [0] * num_packs
+ pack_items = [0] * num_packs
+ for group in indices[i]:
+ pack = min(
+ (i
+ for i in range(num_packs) if pack_items[i] < groups_per_pack),
+ key=pack_weights.__getitem__,
+ )
+ assert pack_items[pack] < groups_per_pack
+ pack_index[i, group] = pack
+ rank_in_pack[i, group] = pack_items[pack]
+ pack_weights[pack] += weight[i, group]
+ pack_items[pack] += 1
+ return pack_index, rank_in_pack
+
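A tiny illustrative call (not part of the diff), assuming the module added above is importable once this patch lands:

    import torch
    from vllm.distributed.eplb.rebalance_algo import balanced_packing

    # One layer, 4 groups with loads [9, 1, 5, 5], packed onto 2 nodes.
    weight = torch.tensor([[9., 1., 5., 5.]])
    pack_index, rank_in_pack = balanced_packing(weight, num_packs=2)
    # Greedy by descending load: 9 -> pack 0, 5 -> pack 1, 5 -> pack 1, 1 -> pack 0
    # pack_index   == tensor([[0, 0, 1, 1]])
    # rank_in_pack == tensor([[0, 1, 0, 1]])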
+
+def replicate_experts(
+ weight: torch.Tensor,
+ num_phy: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Replicate `num_log` experts to `num_phy` replicas, such that the maximum
+ load of all replicas is minimized.
+
+ Parameters:
+ weight: [X, num_log]
+ num_phy: total number of experts after replication
+
+ Returns:
+ phy2log: [X, num_phy], logical expert id of each physical expert
+ rank: [X, num_phy], the replica rank
+ logcnt: [X, num_log], number of replicas for each logical expert
+ """
+ n, num_log = weight.shape
+ num_redundant = num_phy - num_log
+ assert num_redundant >= 0
+ device = weight.device
+ phy2log = torch.arange(num_phy, dtype=torch.int64,
+ device=device).repeat(n, 1)
+ rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
+ logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
+ arangen = torch.arange(n, dtype=torch.int64, device=device)
+ for i in range(num_log, num_phy):
+ redundant_indices = (weight / logcnt).max(dim=-1).indices
+ phy2log[:, i] = redundant_indices
+ rank[:, i] = logcnt[arangen, redundant_indices]
+ logcnt[arangen, redundant_indices] += 1
+ return phy2log, rank, logcnt
+
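Likewise, a toy run of `replicate_experts` (illustration only, not part of the patch):

    import torch
    from vllm.distributed.eplb.rebalance_algo import replicate_experts

    # One layer, 4 logical experts with loads [8, 4, 2, 2], spread over 6 physical slots.
    weight = torch.tensor([[8., 4., 2., 2.]])
    phy2log, rank, logcnt = replicate_experts(weight, num_phy=6)
    # phy2log == tensor([[0, 1, 2, 3, 0, 0]])   # the hottest expert receives both extra replicas
    # rank    == tensor([[0, 0, 0, 0, 1, 2]])
    # logcnt  == tensor([[3, 1, 1, 1]])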
+
+def rebalance_experts_hierarchical(
+ weight: torch.Tensor,
+ num_physical_experts: int,
+ num_groups: int,
+ num_nodes: int,
+ num_gpus: int,
+):
+ """
+ Parameters:
+ weight: [num_moe_layers, num_logical_experts]
+ num_physical_experts: number of physical experts after replication
+ num_groups: number of expert groups
+ num_nodes: number of server nodes, where the intra-node network
+            (e.g., NVLink) is faster
+ num_gpus: number of GPUs, must be a multiple of `num_nodes`
+
+ Returns:
+ physical_to_logical_map: [num_moe_layers, num_physical_experts]
+ logical_to_physical_map: [num_moe_layers, num_logical_experts, X]
+ logical_count: [num_moe_layers, num_logical_experts]
+ """
+ num_layers, num_logical_experts = weight.shape
+ assert num_logical_experts % num_groups == 0
+ group_size = num_logical_experts // num_groups
+ assert num_groups % num_nodes == 0
+ groups_per_node = num_groups // num_nodes
+ assert num_gpus % num_nodes == 0
+ assert num_physical_experts % num_gpus == 0
+ phy_experts_per_gpu = num_physical_experts // num_gpus
+
+ def inverse(perm: torch.Tensor) -> torch.Tensor:
+ inv = torch.empty_like(perm)
+ inv.scatter_(
+ 1,
+ perm,
+ torch.arange(perm.size(1), dtype=torch.int64,
+ device=perm.device).expand(perm.shape),
+ )
+ return inv
+
+ # Step 1: pack groups to nodes
+ tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
+ group_pack_index, group_rank_in_pack = balanced_packing(
+ tokens_per_group, num_nodes)
+ log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) *
+ group_size).unsqueeze(-1) +
+ torch.arange(group_size,
+ dtype=torch.int64,
+ device=group_pack_index.device)).flatten(-2)
+ mlog2log = inverse(log2mlog)
+
+ # Step 2: construct redundant experts within nodes
+ # [num_layers * num_nodes, num_logical_experts // num_nodes]
+ tokens_per_mlog = weight.gather(-1, mlog2log).view(
+ -1, num_logical_experts // num_nodes)
+ phy2mlog, phyrank, mlogcnt = replicate_experts(
+ tokens_per_mlog, num_physical_experts // num_nodes)
+
+ # Step 3: pack physical_experts to GPUs
+ # [num_layers * num_nodes, num_physical_experts // num_nodes]
+ tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
+ pack_index, rank_in_pack = balanced_packing(tokens_per_phy,
+ num_gpus // num_nodes)
+ phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
+ pphy2phy = inverse(phy2pphy)
+
+ pphy2mlog = phy2mlog.gather(
+ -1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes]
+ pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + torch.arange(
+ 0,
+ num_logical_experts,
+ num_logical_experts // num_nodes,
+ device=group_pack_index.device,
+ ).view(1, -1, 1)).flatten(-2)
+ pphy2log = mlog2log.gather(-1, pphy2mlog)
+ pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
+ logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
+ return pphy2log, pphyrank, logcnt
+
+
+def rebalance_experts(
+ weight: torch.Tensor,
+ num_replicas: int,
+ num_groups: int,
+ num_nodes: int,
+ num_gpus: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Entry point for expert-parallelism load balancer.
+
+ Parameters:
+ weight: [layers, num_logical_experts], the load statistics for all
+ logical experts
+ num_replicas: number of physical experts, must be a multiple of
+ `num_gpus`
+ num_groups: number of expert groups
+ num_nodes: number of server nodes, where the intra-node network
+            (e.g., NVLink) is faster
+ num_gpus: number of GPUs, must be a multiple of `num_nodes`
+
+ Returns:
+ physical_to_logical_map: [layers, num_replicas], the expert index of
+ each replica
+ logical_to_physical_map: [layers, num_logical_experts, X], the replica
+ indices for each expert
+ expert_count: [layers, num_logical_experts], number of physical
+ replicas for each logical expert
+ """
+ num_layers, num_logical_experts = weight.shape
+ weight = weight.float().cpu()
+ if num_groups % num_nodes == 0:
+ # use hierarchical load-balance policy
+ phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
+ weight, num_replicas, num_groups, num_nodes, num_gpus)
+ else:
+ # use global load-balance policy
+ phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
+ weight, num_replicas, 1, 1, num_gpus)
+ num_redundant_experts = num_replicas - num_logical_experts
+ maxlogcnt = num_redundant_experts + 1
+ log2phy: torch.Tensor = torch.full(
+ (num_layers, num_logical_experts, maxlogcnt),
+ -1,
+ dtype=torch.int64,
+ device=logcnt.device,
+ )
+ log2phy.view(num_layers, -1).scatter_(
+ -1,
+ phy2log * maxlogcnt + phyrank,
+ torch.arange(num_replicas, dtype=torch.int64,
+ device=log2phy.device).expand(num_layers, -1),
+ )
+ return phy2log, log2phy, logcnt
+
+
+__all__ = ["rebalance_experts"]
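End-to-end, the public entry point can be exercised like this (illustrative sketch with toy numbers; only shapes and totals are checked, the exact placement depends on the algorithm):

    import torch
    from vllm.distributed.eplb.rebalance_algo import rebalance_experts

    # 1 MoE layer, 4 logical experts, 6 physical slots over 2 GPUs on a single node.
    load = torch.tensor([[12., 4., 2., 2.]])
    phy2log, log2phy, logcnt = rebalance_experts(
        load, num_replicas=6, num_groups=4, num_nodes=1, num_gpus=2)

    assert phy2log.shape == (1, 6)      # logical expert id held by each physical slot
    assert log2phy.shape == (1, 4, 3)   # up to (num_redundant + 1) slots per expert, -1 padded
    assert logcnt.sum() == 6            # replica counts add up to the physical slot count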
diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py
new file mode 100644
index 000000000000..cf173c734afd
--- /dev/null
+++ b/vllm/distributed/eplb/rebalance_execute.py
@@ -0,0 +1,306 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+The actual execution of the rearrangement.
+
+This involves the exchange of expert weights between GPUs.
+"""
+
+from collections.abc import Iterable, MutableSequence, Sequence
+from functools import partial
+
+import torch
+from torch.distributed import (P2POp, ProcessGroup, all_gather,
+ batch_isend_irecv, get_global_rank)
+
+
+def idx_local_to_global(
+ local_idx: int,
+ local_cnt: int,
+ ep_rank: int,
+) -> int:
+ """
+ Convert a local expert index to a global expert index.
+ """
+ return ep_rank * local_cnt + local_idx
+
+
+def idx_global_to_local(
+ global_idx: int,
+ local_cnt: int,
+ ep_rank: int,
+) -> int:
+ """
+ Convert a global expert index to a local expert index.
+ """
+ return global_idx - ep_rank * local_cnt
+
+
+def global_idx_to_rank(
+ global_idx: int,
+ local_cnt: int,
+) -> int:
+ """
+ Convert a global expert index to a rank index.
+ """
+ return global_idx // local_cnt
+
+
+def get_ep_ranks_with_expert(
+ idx: int,
+ num_local_experts: int,
+ old_indices: Sequence[int],
+ new_indices: Sequence[int],
+) -> tuple[MutableSequence[int], MutableSequence[int]]:
+ """
+ Get the ranks of the experts that need to be exchanged.
+
+ Args:
+ idx: The index of the expert.
+ num_local_experts: The number of local experts.
+ old_indices: The old indices of the experts.
+ new_indices: The new indices of the experts.
+
+ Returns:
+ A tuple of two lists:
+ - The ranks of the experts that need to be sent.
+ - The ranks of the experts that need to be received.
+ """
+ global2rank = partial(
+ global_idx_to_rank,
+ local_cnt=num_local_experts,
+ )
+
+ ranks_to_send: list[int] = []
+ ranks_to_recv: list[int] = []
+
+ for i, e in enumerate(old_indices):
+ if e == idx:
+ rank = global2rank(i)
+ if not ranks_to_send or ranks_to_send[-1] != rank:
+ ranks_to_send.append(rank)
+
+ for i, e in enumerate(new_indices):
+ if e == idx:
+ rank = global2rank(i)
+ if not ranks_to_recv or ranks_to_recv[-1] != rank:
+ ranks_to_recv.append(rank)
+
+ # Remove those ranks that can get this expert locally.
+ ranks_to_send_set = set(ranks_to_send)
+ ranks_to_recv_actual = [
+ rank for rank in ranks_to_recv if rank not in ranks_to_send_set
+ ]
+
+ return ranks_to_send, ranks_to_recv_actual
+
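A quick illustration (not part of the patch) of how send/recv ranks are derived for a single logical expert:

    from vllm.distributed.eplb.rebalance_execute import get_ep_ranks_with_expert

    # 3 EP ranks, 2 local physical slots each (physical slot -> logical expert id).
    old = [0, 1,  2, 3,  0, 2]   # rank 0 | rank 1 | rank 2
    new = [0, 2,  1, 3,  0, 1]

    send, recv = get_ep_ranks_with_expert(1, 2, old, new)
    # send == [0]      only rank 0 currently holds expert 1
    # recv == [1, 2]   ranks 1 and 2 need it and cannot copy it locally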
+
+def shuffle_layer(
+ num_local_experts: int,
+ ep_rank: int,
+ old_indices: Sequence[int],
+ new_indices: Sequence[int],
+ expert_weights: Iterable[torch.Tensor],
+ expert_weights_buffer: Sequence[torch.Tensor],
+ ep_group: ProcessGroup,
+) -> None:
+ """
+ Perform expert weights rearrangement of one layer.
+ """
+ local2global = partial(
+ idx_local_to_global,
+ local_cnt=num_local_experts,
+ ep_rank=ep_rank,
+ )
+
+ # 0. Do nothing for experts that did not change.
+ is_unchanged = [
+ old_indices[local2global(i)] == new_indices[local2global(i)]
+ for i in range(num_local_experts)
+ ]
+
+ # 1. Perform weight copy inside the local rank.
+ is_received_locally = is_unchanged[:]
+ for src in range(num_local_experts):
+ src_global = local2global(src)
+ for dst in range(num_local_experts):
+ dst_global = local2global(dst)
+ if is_received_locally[dst]:
+ continue
+ if old_indices[src_global] == new_indices[dst_global]:
+ is_received_locally[dst] = True
+ for weight, buffer in zip(expert_weights,
+ expert_weights_buffer):
+ buffer[dst].copy_(weight[src])
+
+ p2p_ops: list[P2POp] = []
+
+ # 2. Initiate sending of weights.
+ experts_send_loc: dict[int, int] = {}
+ for src in range(num_local_experts):
+ expert = old_indices[local2global(src)]
+ if expert in experts_send_loc:
+ continue
+ experts_send_loc[expert] = src
+
+ # We need to sort here to match send/recv
+ for expert, src in sorted(experts_send_loc.items()):
+ ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
+ expert,
+ num_local_experts,
+ old_indices,
+ new_indices,
+ )
+
+ # Calculate the ranks to send by this rank
+ num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
+ sender_pos = ranks_to_send.index(ep_rank)
+ recv_begin = sender_pos * num_dst_per_sender
+ recv_end = recv_begin + num_dst_per_sender
+ recv_ranks = ranks_to_recv[recv_begin:recv_end]
+
+ # Tackle remainders
+ remainder_start = len(ranks_to_send) * num_dst_per_sender
+ recver_pos = remainder_start + sender_pos
+ if recver_pos < len(ranks_to_recv):
+ recv_ranks.append(ranks_to_recv[recver_pos])
+
+ for dst in recv_ranks:
+ dst_global = get_global_rank(ep_group, dst)
+ p2p_ops += [
+ P2POp(
+ torch.distributed.isend,
+ weight[src],
+ dst_global,
+ ) for weight in expert_weights
+ ]
+
+ # 3. Initiate receiving of weights.
+ experts_recv_loc: dict[int, int] = {}
+ for dst in range(num_local_experts):
+ if is_received_locally[dst]:
+ continue
+ expert = new_indices[local2global(dst)]
+ if expert in experts_recv_loc:
+ continue
+ experts_recv_loc[expert] = dst
+
+ # We need to sort here to match send/recv
+ for expert, dst in sorted(experts_recv_loc.items()):
+ ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
+ expert,
+ num_local_experts,
+ old_indices,
+ new_indices,
+ )
+
+ # Calculate the rank to recv by this rank
+ num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
+ recver_pos = ranks_to_recv.index(ep_rank)
+ remainder_start = len(ranks_to_send) * num_dst_per_sender
+ if recver_pos < remainder_start:
+ src = ranks_to_send[recver_pos // num_dst_per_sender]
+ else:
+ src = ranks_to_send[recver_pos - remainder_start]
+
+ src_global = get_global_rank(ep_group, src)
+ p2p_ops += [
+ P2POp(
+ torch.distributed.irecv,
+ weight[dst],
+ src_global,
+ ) for weight in expert_weights_buffer
+ ]
+
+ # 4. Execute the P2P operations. The real communication happens here.
+ if p2p_ops:
+ reqs = batch_isend_irecv(p2p_ops)
+ for req in reqs:
+ req.wait()
+
+ # 5. Copy the weights from the buffer back to the original weights.
+ for dst in range(num_local_experts):
+ if is_unchanged[dst]:
+ continue
+ if is_received_locally[dst]:
+ for weight, buffer in zip(expert_weights, expert_weights_buffer):
+ weight[dst].copy_(buffer[dst])
+ else:
+ expert = new_indices[local2global(dst)]
+ src = experts_recv_loc[expert]
+ for weight, buffer in zip(expert_weights, expert_weights_buffer):
+ weight[dst].copy_(buffer[src])
+
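The sender-to-receiver assignment above can be sanity-checked with plain Python (hypothetical rank lists, mirroring the arithmetic in steps 2 and 3):

    ranks_to_send = [0, 3]
    ranks_to_recv = [1, 2, 4, 5, 6]
    num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)   # 2
    remainder_start = len(ranks_to_send) * num_dst_per_sender       # 4

    for sender_pos, sender in enumerate(ranks_to_send):
        targets = ranks_to_recv[sender_pos * num_dst_per_sender:
                                (sender_pos + 1) * num_dst_per_sender]
        if remainder_start + sender_pos < len(ranks_to_recv):
            targets.append(ranks_to_recv[remainder_start + sender_pos])
        print(sender, targets)   # 0 [1, 2, 6]  /  3 [4, 5]

    # The receiver side recovers its unique source rank with the mirrored
    # index arithmetic, so every isend has a matching irecv.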
+
+def rearrange_expert_weights_inplace(
+ old_global_expert_indices: torch.Tensor,
+ new_global_expert_indices: torch.Tensor,
+ expert_weights: Sequence[Iterable[torch.Tensor]],
+ ep_group: ProcessGroup,
+ is_profile: bool = False,
+) -> None:
+ """
+ Rearranges the expert weights in place according to the new expert indices.
+
+    The values of the indices arguments are logical expert ids, while their
+    positions (keys) are physical expert indices.
+
+ Args:
+ old_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
+ new_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
+        expert_weights: A nested sequence of length num_moe_layers, each
+            holding weight_count tensors of shape
+            (num_local_physical_experts, hidden_size_i).
+ For example, a linear layer may have up and down projection,
+ so weight_count = 2. Each weight's hidden size can be different.
+ ep_group: The device process group for expert parallelism.
+ is_profile (bool): If `True`, do not perform any actual weight copy.
+ This is used during profile run, where we only perform dummy
+ communications to reserve enough memory for the buffers.
+ """
+ num_moe_layers, num_physical_experts = old_global_expert_indices.shape
+ assert len(expert_weights) == num_moe_layers
+
+ num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
+ assert new_global_expert_indices.shape == (num_moe_layers,
+ num_physical_experts)
+
+ ep_rank = ep_group.rank()
+ ep_size = ep_group.size()
+ assert num_physical_experts == ep_size * num_local_physical_experts
+
+ # A buffer to hold the expert weights in one layer during the exchange.
+ # NOTE: Currently we assume the same weights across different layers
+ # have the same shape.
+ expert_weights_buffer = [torch.empty_like(w) for w in expert_weights[0]]
+
+ if is_profile:
+ # Maximum send size is to send all local experts to all ranks,
+ # So we use a dummy `all_gather` to reserve enough communication buffer
+ for weight, buffer in zip(expert_weights[0], expert_weights_buffer):
+ # A `/dev/null`-like buffer to avoid real memory allocation
+ dummy_recv_buffer = [buffer for _ in range(ep_size)]
+ # NOTE(bowen): Needed this barrier to avoid OOM during actual
+ # execution. I'm not very sure why this is needed
+ torch.distributed.barrier()
+ all_gather(
+ dummy_recv_buffer,
+ weight,
+ group=ep_group,
+ )
+ return
+
+ for layer in range(num_moe_layers):
+ # NOTE(bowen): We need this synchronize to run, but I don't know why.
+ # If you figure out the reason, please let me know -- thank you!
+ torch.cuda.synchronize()
+ shuffle_layer(
+ num_local_physical_experts,
+ ep_rank,
+ old_global_expert_indices[layer].tolist(),
+ new_global_expert_indices[layer].tolist(),
+ expert_weights[layer],
+ expert_weights_buffer,
+ ep_group,
+ )
+
+
+__all__ = ["rearrange_expert_weights_inplace"]
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index a962a9241d73..7a077dce7706 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -62,7 +62,6 @@ class NixlAgentMetadata(
agent_metadata: bytes
kv_caches_base_addr: list[int]
num_blocks: int
- tp_size: int
block_len: int
attn_backend_name: str
@@ -73,7 +72,8 @@ class ReqMeta:
remote_block_ids: list[int]
remote_host: str
remote_port: int
- remote_engine_id: EngineId
+ remote_engine_id: str
+ tp_size: int
class NixlConnectorMetadata(KVConnectorMetadata):
@@ -93,6 +93,8 @@ def add_new_req(
remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
+ # P workers don't need to receive tp_size from proxy here.
+ tp_size=kv_transfer_params.get("tp_size", 1),
)
@@ -298,8 +300,21 @@ def request_finished(
logger.debug(
"NIXLConnector request_finished, request_status=%s, "
"kv_transfer_params=%s", request.status, params)
+ if not params:
+ return False, None
+
+ if params.get("do_remote_prefill"):
+ # If do_remote_prefill is still True when the request is finished,
+ # update_state_after_alloc must not have been called (the request
+ # must have been aborted before it was scheduled).
+ # To avoid stranding the prefill blocks in the prefill instance,
+ # we must add empty block_ids to _reqs_need_recv so that our
+ # worker side will notify and free blocks in the prefill instance.
+ self._reqs_need_recv[request.request_id] = (request, [])
+ params["do_remote_prefill"] = False
+ return False, None
- if (params is None or not params.get("do_remote_decode")
+ if (not params.get("do_remote_decode")
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
return False, None
@@ -317,7 +332,7 @@ def request_finished(
remote_engine_id=self.engine_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
- )
+ tp_size=self.vllm_config.parallel_config.tensor_parallel_size)
class NixlConnectorWorker:
@@ -460,7 +475,8 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,
"Connection listener got unexpected message %s", msg)
sock.send_multipart((identity, b"", encoded_data))
- def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
+ def _nixl_handshake(self, host: str, port: int,
+ remote_tp_size: int) -> dict[int, str]:
"""Do a NIXL handshake with a remote instance."""
start_time = time.perf_counter()
@@ -469,7 +485,7 @@ def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
- def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]:
+ def handshake(path: str, rank: int) -> str:
# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock:
sock.send(GET_META_MSG)
@@ -479,33 +495,25 @@ def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]:
got_metadata_time = time.perf_counter()
# Register Remote agent.
- remote_agent_name = self.add_remote_agent(metadata, rank)
+ remote_agent_name = self.add_remote_agent(
+ metadata, rank, remote_tp_size)
setup_agent_time = time.perf_counter()
logger.debug("NIXL handshake: get metadata took: %s",
got_metadata_time - start_time)
logger.debug("NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time)
- return metadata, remote_agent_name
-
- # Handshake with remote agent-rank0 first to get the tp_size of remote
- path = make_zmq_path("tcp", host, port)
- logger.debug("Querying master rank metadata on path: %s", path)
- rank_to_agent_name: dict[int, str] = {}
- metadata, rank_to_agent_name[0] = handshake(path, 0)
+ return remote_agent_name
- # Handshake only with the other TP remote the current local rank will
+ # Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
- tp_ratio = self._tp_size[self.engine_id] // metadata.tp_size
+ tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
p_remote_rank = self.tp_rank // tp_ratio
- if p_remote_rank > 0:
- path = make_zmq_path("tcp", host, port + p_remote_rank)
- logger.debug("Querying metadata on path: %s at remote rank %s",
- path, p_remote_rank)
- _, rank_to_agent_name[p_remote_rank] = handshake(
- path, p_remote_rank)
-
- return rank_to_agent_name
+ path = make_zmq_path("tcp", host, port + p_remote_rank)
+ logger.debug("Querying metadata on path: %s at remote rank %s", path,
+ p_remote_rank)
+ # Remote rank -> agent name.
+ return {p_remote_rank: handshake(path, p_remote_rank)}
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
@@ -632,7 +640,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
num_blocks=self.num_blocks,
- tp_size=self.world_size,
block_len=self.block_len,
attn_backend_name=self.backend_name)
ready_event = threading.Event()
@@ -646,7 +653,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
def add_remote_agent(self,
nixl_agent_meta: NixlAgentMetadata,
- remote_tp_rank: int = 0) -> str:
+ remote_tp_rank: int = 0,
+ remote_tp_size: int = 1) -> str:
"""
Add the remote NIXL agent and prepare the descriptors for reading cache
blocks from remote.
@@ -691,9 +699,9 @@ def add_remote_agent(self,
return self._remote_agents[engine_id][remote_tp_rank]
if engine_id in self._tp_size:
- assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
+ assert self._tp_size[engine_id] == remote_tp_size
else:
- self._tp_size[engine_id] = nixl_agent_meta.tp_size
+ self._tp_size[engine_id] = remote_tp_size
# We may eventually enable this after asserting equality in cache
# layout and close outputs.
assert nixl_agent_meta.attn_backend_name == self.backend_name
@@ -743,33 +751,31 @@ def add_remote_agent(self,
# rank. With heterogeneous TP, prepare the descriptors by splitting the
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
- p_remote_tp_rank = self.tp_rank // tp_ratio
# Only register the remote's descriptors if current rank pulls from it.
- if p_remote_tp_rank == remote_tp_rank:
- self.kv_caches_base_addr[
- engine_id] = nixl_agent_meta.kv_caches_base_addr
- rank_offset = self.tp_rank % tp_ratio * self.block_len \
- if not (self.use_mla or is_kv_replicated) else 0
- # Register all remote blocks, but only the corresponding kv heads.
- for base_addr in nixl_agent_meta.kv_caches_base_addr:
- for block_id in range(nixl_agent_meta.num_blocks):
- block_offset = block_id * nixl_agent_meta.block_len
- # For each block, grab the heads chunk belonging to rank_i
- # of size remote_nheads // tp_ratio, which correspond to
- # self.block_len == remote_block_len//tp_ratio bytes.
- addr = base_addr + block_offset + rank_offset
- # (addr, len, device id)
- blocks_data.append((addr, self.block_len, remote_tp_rank))
- logger.debug(
- "Created %s blocks for dst engine %s with remote rank %s and "
- "local rank %s", len(blocks_data), engine_id, remote_tp_rank,
- self.tp_rank)
+ self.kv_caches_base_addr[
+ engine_id] = nixl_agent_meta.kv_caches_base_addr
+ rank_offset = self.tp_rank % tp_ratio * self.block_len \
+ if not (self.use_mla or is_kv_replicated) else 0
+ # Register all remote blocks, but only the corresponding kv heads.
+ for base_addr in nixl_agent_meta.kv_caches_base_addr:
+ for block_id in range(nixl_agent_meta.num_blocks):
+ block_offset = block_id * nixl_agent_meta.block_len
+ # For each block, grab the heads chunk belonging to rank_i
+ # of size remote_nheads // tp_ratio, which correspond to
+ # self.block_len == remote_block_len//tp_ratio bytes.
+ addr = base_addr + block_offset + rank_offset
+ # (addr, len, device id)
+ blocks_data.append((addr, self.block_len, remote_tp_rank))
+ logger.debug(
+ "Created %s blocks for dst engine %s with remote rank %s and "
+ "local rank %s", len(blocks_data), engine_id, remote_tp_rank,
+ self.tp_rank)
- # Register with NIXL.
- descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
- self.dst_xfer_side_handles[
- engine_id] = self.nixl_wrapper.prep_xfer_dlist(
- remote_agent_name, descs)
+ # Register with NIXL.
+ descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+ self.dst_xfer_side_handles[
+ engine_id] = self.nixl_wrapper.prep_xfer_dlist(
+ remote_agent_name, descs)
return remote_agent_name
@@ -904,7 +910,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
if fut is None:
fut = self._handshake_initiation_executor.submit(
self._nixl_handshake, meta.remote_host,
- meta.remote_port)
+ meta.remote_port, meta.tp_size)
self._handshake_futures[remote_engine_id] = fut
def done_callback(f: Future[dict[int, str]],
@@ -944,13 +950,9 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
remote_block_ids=meta.remote_block_ids,
)
- def _read_blocks(
- self,
- local_block_ids: list[int],
- remote_block_ids: list[int],
- dst_engine_id: str,
- request_id: str,
- ):
+ def _read_blocks(self, local_block_ids: list[int],
+ remote_block_ids: list[int], dst_engine_id: str,
+ request_id: str):
# NOTE(rob): having the staging blocks be on the READER side is
# not going to work well (since we will have to call rearrange tensors).
# after we detect the txn is complete (which means we cannot make the
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
index a47deaf91272..2f870971ded7 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
@@ -371,45 +371,48 @@ def build_connector_meta(
block_size=self._block_size)
self._requests_need_load.pop(new_req.req_id)
- for cached_req in scheduler_output.scheduled_cached_reqs:
+ cached_reqs = scheduler_output.scheduled_cached_reqs
+ for i, req_id in enumerate(cached_reqs.req_ids):
+ num_computed_tokens = cached_reqs.num_computed_tokens[i]
+ new_block_ids = cached_reqs.new_block_ids[i]
+ resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
+
if self.is_producer:
num_scheduled_tokens = (
- scheduler_output.num_scheduled_tokens)[cached_req.req_id]
- num_tokens = (num_scheduled_tokens +
- cached_req.num_computed_tokens)
- assert cached_req.req_id in self.chunked_prefill
- block_ids = cached_req.new_block_ids[0]
- if not cached_req.resumed_from_preemption:
- block_ids = (self.chunked_prefill[cached_req.req_id][0] +
- block_ids)
- prompt_token_ids = self.chunked_prefill[cached_req.req_id][1]
+ scheduler_output.num_scheduled_tokens)[req_id]
+ num_tokens = (num_scheduled_tokens + num_computed_tokens)
+ assert req_id in self.chunked_prefill
+ block_ids = new_block_ids[0]
+ if not resumed_from_preemption:
+ block_ids = (self.chunked_prefill[req_id][0] + block_ids)
+ prompt_token_ids = self.chunked_prefill[req_id][1]
# the request's prompt is chunked prefill again
if num_tokens < len(prompt_token_ids):
- self.chunked_prefill[cached_req.req_id] = (
- block_ids, prompt_token_ids)
+ self.chunked_prefill[req_id] = (block_ids,
+ prompt_token_ids)
continue
# the request's prompt is all prefilled finally
- meta.add_request(request_id=cached_req.req_id,
+ meta.add_request(request_id=req_id,
token_ids=prompt_token_ids,
block_ids=block_ids,
block_size=self._block_size)
- self.chunked_prefill.pop(cached_req.req_id, None)
+ self.chunked_prefill.pop(req_id, None)
continue
# NOTE(rob): here we rely on the resumed requests being
# the first N requests in the list scheduled_cache_reqs.
- if not cached_req.resumed_from_preemption:
+ if not resumed_from_preemption:
break
- if cached_req.req_id in self._requests_need_load:
- request, _ = self._requests_need_load.pop(cached_req.req_id)
- total_tokens = cached_req.num_computed_tokens + 1
+ if req_id in self._requests_need_load:
+ request, _ = self._requests_need_load.pop(req_id)
+ total_tokens = num_computed_tokens + 1
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
- block_ids = cached_req.new_block_ids[0]
+ block_ids = new_block_ids[0]
- meta.add_request(request_id=cached_req.req_id,
+ meta.add_request(request_id=req_id,
token_ids=token_ids,
block_ids=block_ids,
block_size=self._block_size)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
index f86b92692a0e..0bceee19f873 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
@@ -304,23 +304,28 @@ def build_connector_meta(
block_size=self._block_size,
is_store=True)
- for cached_req in scheduler_output.scheduled_cached_reqs:
+ cached_reqs = scheduler_output.scheduled_cached_reqs
+ for i, req_id in enumerate(cached_reqs.req_ids):
+ num_computed_tokens = cached_reqs.num_computed_tokens[i]
+ new_token_ids = cached_reqs.new_token_ids[i]
+ new_block_ids = cached_reqs.new_block_ids[i]
+ resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
+
# NOTE(rob): here we rely on the resumed requests being
# the first N requests in the list scheduled_cache_reqs.
- if not cached_req.resumed_from_preemption:
+ if not resumed_from_preemption:
break
- if cached_req.req_id in self._requests_need_load:
+ if req_id in self._requests_need_load:
# NOTE(rob): cached_req_data does not have the full
# list of token ids (only new tokens). So we look it
# up in the actual request object.
- request = self._requests_need_load[cached_req.req_id]
- total_tokens = (len(cached_req.new_token_ids) +
- cached_req.num_computed_tokens)
+ request = self._requests_need_load[req_id]
+ total_tokens = (len(new_token_ids) + num_computed_tokens)
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
- block_ids = cached_req.new_block_ids[0]
+ block_ids = new_block_ids[0]
meta.add_request(token_ids=token_ids,
block_ids=block_ids,
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 126160b09553..50dbbf50e9fc 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -802,6 +802,7 @@ def combine(self, hidden_states) -> torch.Tensor:
_WORLD: Optional[GroupCoordinator] = None
+_NODE_COUNT: Optional[int] = None
def get_world_group() -> GroupCoordinator:
@@ -961,10 +962,13 @@ def init_distributed_environment(
local_rank = envs.LOCAL_RANK
else:
local_rank = rank
- global _WORLD
+ global _WORLD, _NODE_COUNT
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend)
+ _NODE_COUNT = _node_count(_WORLD.cpu_group)
+ logger.debug("Detected %d nodes in the distributed environment",
+ _NODE_COUNT)
else:
assert _WORLD.world_size == torch.distributed.get_world_size(), (
"world group already initialized with a different world size")
@@ -1164,6 +1168,13 @@ def get_tensor_model_parallel_rank():
return get_tp_group().rank_in_group
+def get_node_count() -> int:
+ """Return the total number of nodes in the distributed environment. """
+ assert _NODE_COUNT is not None, (
+ "distributed environment is not initialized")
+ return _NODE_COUNT
+
+
def destroy_model_parallel():
"""Set the groups to none and destroy them."""
global _TP
@@ -1189,10 +1200,11 @@ def destroy_model_parallel():
def destroy_distributed_environment():
- global _WORLD
+ global _WORLD, _NODE_COUNT
if _WORLD:
_WORLD.destroy()
_WORLD = None
+ _NODE_COUNT = None
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
@@ -1301,3 +1313,42 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
aggregated_data += rank_data
return [x == 1 for x in aggregated_data.tolist()]
+
+
+def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
+ """
+ Returns the total number of nodes in the process group.
+
+ Args:
+ pg: The process group to analyze
+
+ Returns:
+ int: The total number of nodes
+ """
+ if isinstance(pg, ProcessGroup):
+ world_size = torch.distributed.get_world_size(group=pg)
+ else:
+ world_size = pg.world_size
+
+ if world_size == 1:
+ return 1
+
+ # Build node assignment map
+ node_assignment = [0] * world_size # rank -> node_id
+ next_node_id = 0
+
+ for current_rank in range(world_size):
+ if node_assignment[current_rank] != 0:
+ continue # Already assigned to a node
+
+ # Assign current rank to a new node
+ next_node_id += 1
+ node_assignment[current_rank] = next_node_id
+
+ # Find all ranks on the same node as current_rank
+ same_node_flags = in_the_same_node_as(pg, current_rank)
+ for other_rank, is_same_node in enumerate(same_node_flags):
+ if is_same_node and node_assignment[other_rank] == 0:
+ node_assignment[other_rank] = next_node_id
+
+ return next_node_id
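The node-counting loop can be illustrated without a real process group by substituting a hostname list for `in_the_same_node_as()` (hypothetical sketch, not part of the patch):

    # 4 ranks spread over 2 hosts.
    hosts = ["node-a", "node-a", "node-b", "node-b"]

    def same_node_flags(rank: int) -> list[bool]:
        return [h == hosts[rank] for h in hosts]

    node_assignment = [0] * len(hosts)
    next_node_id = 0
    for rank in range(len(hosts)):
        if node_assignment[rank]:
            continue                      # already assigned to a node
        next_node_id += 1
        node_assignment[rank] = next_node_id
        for other, is_same in enumerate(same_node_flags(rank)):
            if is_same and node_assignment[other] == 0:
                node_assignment[other] = next_node_id
    print(next_node_id)                   # 2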
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 9d1008b6b350..6c908f88b9a9 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -320,6 +320,11 @@ class EngineArgs:
data_parallel_rpc_port: Optional[int] = None
data_parallel_backend: str = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
+ enable_eplb: bool = ParallelConfig.enable_eplb
+ num_redundant_experts: int = ParallelConfig.num_redundant_experts
+ eplb_window_size: int = ParallelConfig.eplb_window_size
+ eplb_step_interval: int = ParallelConfig.eplb_step_interval
+ eplb_log_balancedness: bool = ParallelConfig.eplb_log_balancedness
max_parallel_loading_workers: Optional[
int] = ParallelConfig.max_parallel_loading_workers
block_size: Optional[BlockSize] = CacheConfig.block_size
@@ -666,6 +671,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parallel_group.add_argument(
"--enable-expert-parallel",
**parallel_kwargs["enable_expert_parallel"])
+ parallel_group.add_argument("--enable-eplb",
+ **parallel_kwargs["enable_eplb"])
+ parallel_group.add_argument("--num-redundant-experts",
+ **parallel_kwargs["num_redundant_experts"])
+ parallel_group.add_argument("--eplb-window-size",
+ **parallel_kwargs["eplb_window_size"])
+ parallel_group.add_argument("--eplb-step-interval",
+ **parallel_kwargs["eplb_step_interval"])
+ parallel_group.add_argument("--eplb-log-balancedness",
+ **parallel_kwargs["eplb_log_balancedness"])
parallel_group.add_argument(
"--max-parallel-loading-workers",
**parallel_kwargs["max_parallel_loading_workers"])
@@ -1135,6 +1150,11 @@ def create_engine_config(
data_parallel_rpc_port=data_parallel_rpc_port,
data_parallel_backend=data_parallel_backend,
enable_expert_parallel=self.enable_expert_parallel,
+ enable_eplb=self.enable_eplb,
+ num_redundant_experts=self.num_redundant_experts,
+ eplb_window_size=self.eplb_window_size,
+ eplb_step_interval=self.eplb_step_interval,
+ eplb_log_balancedness=self.eplb_log_balancedness,
max_parallel_loading_workers=self.max_parallel_loading_workers,
disable_custom_all_reduce=self.disable_custom_all_reduce,
ray_workers_use_nsight=self.ray_workers_use_nsight,
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 8fccf9bd2aa0..25fa1c3058be 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -687,6 +687,10 @@ def add_request(
>>> # continue the request processing
>>> ...
"""
+ if not isinstance(request_id, str):
+ raise TypeError(
+ f"request_id must be a string, got {type(request_id)}")
+
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 7951c49f5da0..35ee52ab4601 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -293,6 +293,7 @@ def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]:
return None
+@lru_cache(maxsize=32)
def _detect_content_format(
chat_template: str,
*,
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 681633a2aff7..f3fd15486271 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -14,7 +14,7 @@
import tempfile
import uuid
from argparse import Namespace
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Awaitable
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
@@ -30,8 +30,9 @@
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.concurrency import iterate_in_threadpool
-from starlette.datastructures import State
+from starlette.datastructures import URL, Headers, MutableHeaders, State
from starlette.routing import Mount
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
from typing_extensions import assert_never
import vllm.envs as envs
@@ -1061,6 +1062,74 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
return None
+class AuthenticationMiddleware:
+ """
+ Pure ASGI middleware that authenticates each request by checking
+ if the Authorization header exists and equals "Bearer {api_key}".
+
+ Notes
+ -----
+ There are two cases in which authentication is skipped:
+ 1. The HTTP method is OPTIONS.
+ 2. The request path doesn't start with /v1 (e.g. /health).
+ """
+
+ def __init__(self, app: ASGIApp, api_token: str) -> None:
+ self.app = app
+ self.api_token = api_token
+
+ def __call__(self, scope: Scope, receive: Receive,
+ send: Send) -> Awaitable[None]:
+ if scope["type"] not in ("http",
+ "websocket") or scope["method"] == "OPTIONS":
+ # scope["type"] can be "lifespan" or "startup" for example,
+ # in which case we don't need to do anything
+ return self.app(scope, receive, send)
+ root_path = scope.get("root_path", "")
+ url_path = URL(scope=scope).path.removeprefix(root_path)
+ headers = Headers(scope=scope)
+ # Type narrow to satisfy mypy.
+ if url_path.startswith("/v1") and headers.get(
+ "Authorization") != f"Bearer {self.api_token}":
+ response = JSONResponse(content={"error": "Unauthorized"},
+ status_code=401)
+ return response(scope, receive, send)
+ return self.app(scope, receive, send)
+
+
+class XRequestIdMiddleware:
+ """
+    Middleware that sets the X-Request-Id header for each response
+    to a random uuid4 (hex) value if the header isn't already
+    present in the request; otherwise it reuses the provided request id.
+ """
+
+ def __init__(self, app: ASGIApp) -> None:
+ self.app = app
+
+ def __call__(self, scope: Scope, receive: Receive,
+ send: Send) -> Awaitable[None]:
+ if scope["type"] not in ("http", "websocket"):
+ return self.app(scope, receive, send)
+
+ # Extract the request headers.
+ request_headers = Headers(scope=scope)
+
+ async def send_with_request_id(message: Message) -> None:
+ """
+ Custom send function to mutate the response headers
+ and append X-Request-Id to it.
+ """
+ if message["type"] == "http.response.start":
+ response_headers = MutableHeaders(raw=message["headers"])
+ request_id = request_headers.get("X-Request-Id",
+ uuid.uuid4().hex)
+ response_headers.append("X-Request-Id", request_id)
+ await send(message)
+
+ return self.app(scope, receive, send_with_request_id)
+
+
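The pure-ASGI middleware can be exercised outside the real server; a hedged sketch with a toy FastAPI app (not part of the patch), assuming AuthenticationMiddleware is importable after this change:

    from fastapi import FastAPI
    from fastapi.testclient import TestClient
    from vllm.entrypoints.openai.api_server import AuthenticationMiddleware

    app = FastAPI()

    @app.get("/v1/models")
    def list_models():
        return {"data": []}

    app.add_middleware(AuthenticationMiddleware, api_token="secret")
    client = TestClient(app)

    assert client.get("/v1/models").status_code == 401   # no Authorization header
    ok = client.get("/v1/models", headers={"Authorization": "Bearer secret"})
    assert ok.status_code == 200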
def build_app(args: Namespace) -> FastAPI:
if args.disable_fastapi_docs:
app = FastAPI(openapi_url=None,
@@ -1108,33 +1177,10 @@ async def validation_exception_handler(_: Request,
# Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
if token := args.api_key or envs.VLLM_API_KEY:
-
- @app.middleware("http")
- async def authentication(request: Request, call_next):
- if request.method == "OPTIONS":
- return await call_next(request)
- url_path = request.url.path
- if app.root_path and url_path.startswith(app.root_path):
- url_path = url_path[len(app.root_path):]
- if not url_path.startswith("/v1"):
- return await call_next(request)
- if request.headers.get("Authorization") != "Bearer " + token:
- return JSONResponse(content={"error": "Unauthorized"},
- status_code=401)
- return await call_next(request)
+ app.add_middleware(AuthenticationMiddleware, api_token=token)
if args.enable_request_id_headers:
- logger.warning(
- "CAUTION: Enabling X-Request-Id headers in the API Server. "
- "This can harm performance at high QPS.")
-
- @app.middleware("http")
- async def add_request_id(request: Request, call_next):
- request_id = request.headers.get(
- "X-Request-Id") or uuid.uuid4().hex
- response = await call_next(request)
- response.headers["X-Request-Id"] = request_id
- return response
+ app.add_middleware(XRequestIdMiddleware)
if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
logger.warning("CAUTION: Enabling log response in the API Server. "
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index dd4bd53046a3..f9bec8451868 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -216,7 +216,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"--enable-request-id-headers",
action="store_true",
help="If specified, API server will add X-Request-Id header to "
- "responses. Caution: this hurts performance at high QPS.")
+ "responses.")
parser.add_argument(
"--enable-auto-tool-choice",
action="store_true",
diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py
index b23cf6cab097..6c16e5324531 100644
--- a/vllm/entrypoints/openai/speech_to_text.py
+++ b/vllm/entrypoints/openai/speech_to_text.py
@@ -24,6 +24,7 @@
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
+from vllm.model_executor.model_loader.utils import get_model_architecture
from vllm.outputs import RequestOutput
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import PlaceholderModule
@@ -38,118 +39,10 @@
logger = init_logger(__name__)
-# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
-# TODO these configs should live somewhere with the model so we can support
-# additional ones
-
-ISO639_1_SUPPORTED_LANGS = {
- "af": "Afrikaans",
- "ar": "Arabic",
- "hy": "Armenian",
- "az": "Azerbaijani",
- "be": "Belarusian",
- "bs": "Bosnian",
- "bg": "Bulgarian",
- "ca": "Catalan",
- "zh": "Chinese",
- "hr": "Croatian",
- "cs": "Czech",
- "da": "Danish",
- "nl": "Dutch",
- "en": "English",
- "et": "Estonian",
- "fi": "Finnish",
- "fr": "French",
- "gl": "Galician",
- "de": "German",
- "el": "Greek",
- "he": "Hebrew",
- "hi": "Hindi",
- "hu": "Hungarian",
- "is": "Icelandic",
- "id": "Indonesian",
- "it": "Italian",
- "ja": "Japanese",
- "kn": "Kannada",
- "kk": "Kazakh",
- "ko": "Korean",
- "lv": "Latvian",
- "lt": "Lithuanian",
- "mk": "Macedonian",
- "ms": "Malay",
- "mr": "Marathi",
- "mi": "Maori",
- "ne": "Nepali",
- "no": "Norwegian",
- "fa": "Persian",
- "pl": "Polish",
- "pt": "Portuguese",
- "ro": "Romanian",
- "ru": "Russian",
- "sr": "Serbian",
- "sk": "Slovak",
- "sl": "Slovenian",
- "es": "Spanish",
- "sw": "Swahili",
- "sv": "Swedish",
- "tl": "Tagalog",
- "ta": "Tamil",
- "th": "Thai",
- "tr": "Turkish",
- "uk": "Ukrainian",
- "ur": "Urdu",
- "vi": "Vietnamese",
- "cy": "Welsh"
-}
-ISO639_1_OTHER_LANGS = {
- "lo": "Lao",
- "jw": "Javanese",
- "tk": "Turkmen",
- "yi": "Yiddish",
- "so": "Somali",
- "bn": "Bengali",
- "nn": "Norwegian Nynorsk",
- "si": "Sinhala",
- "yo": "Yoruba",
- "sa": "Sanskrit",
- "mi": "Mฤori",
- "fo": "Faroese", # codespell:ignore
- "mt": "Maltese",
- "tg": "Tajik",
- "mg": "Malagasy",
- "haw": "Hawaiian",
- "km": "Khmer",
- "br": "Breton",
- "ps": "Pashto",
- "ln": "Lingala",
- "la": "Latin",
- "ml": "Malayalam",
- "sq": "Albanian",
- "su": "Sundanese",
- "eu": "Basque",
- "ka": "Georgian",
- "uz": "Uzbek",
- "sn": "Shona",
- "ht": "Haitian",
- "as": "Assamese",
- "mn": "Mongolian",
- "te": "Telugu",
- "pa": "Panjabi",
- "tt": "Tatar",
- "gu": "Gujarati",
- "oc": "Occitan",
- "ha": "Hausa",
- "ba": "Bashkir",
- "my": "Burmese",
- "sd": "Sindhi",
- "am": "Amharic",
- "lb": "Luxembourgish",
- "bo": "Tibetan"
-}
-
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
# TODO configurable
MAX_AUDIO_CLIP_FILESIZE_MB = 25
+MAX_AUDIO_CLIP_SECONDS = 30
OVERLAP_CHUNK_SECOND = 1
MIN_ENERGY_WINDOW_SIZE = 1600 # 1600 ~ 100ms for 16000 Hz audio
@@ -177,10 +70,13 @@ def __init__(
self.default_sampling_params = (
self.model_config.get_diff_sampling_param())
processor = cached_get_processor(model_config.model)
- self.max_audio_clip_s = processor.feature_extractor.chunk_length
+ self.max_audio_clip_s = processor.feature_extractor.chunk_length \
+ if hasattr(processor.feature_extractor, 'chunk_length') \
+ else MAX_AUDIO_CLIP_SECONDS
self.model_sr = processor.feature_extractor.sampling_rate
self.hop_length = processor.feature_extractor.hop_length
self.task_type = task_type
+ self.model_cls, _ = get_model_architecture(model_config)
if self.default_sampling_params:
logger.info(
@@ -196,21 +92,8 @@ async def _preprocess_speech_to_text(
# TODO language should be optional and can be guessed.
# For now we default to en. See
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
- lang_token = f"<|{request.language}|>" if request.language else "<|en|>"
- if request.language:
- if request.language in ISO639_1_SUPPORTED_LANGS:
- pass
- elif request.language in ISO639_1_OTHER_LANGS:
- logger.warning(
- "The selected language %s has limited accuracy with"
- " reported WER>=0.5. Results may be less accurate "
- "for this choice.", request.language)
- else:
- raise ValueError(
- f"Unsupported language: {request.language}."
- "Language should be one of:" +
- f" {list(ISO639_1_SUPPORTED_LANGS.values())}" +
- f"or {list(ISO639_1_OTHER_LANGS.values())}")
+ lang = request.language or "en"
+ self.model_cls.validate_language(lang) # type: ignore[attr-defined]
if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
raise ValueError("Maximum file size exceeded.")
@@ -221,7 +104,9 @@ async def _preprocess_speech_to_text(
y, sr = librosa.load(bytes_, sr=self.model_sr)
duration = librosa.get_duration(y=y, sr=sr)
- chunks = [y] if duration < 30 else self._split_audio(y, int(sr))
+ chunks = [y
+ ] if duration < self.max_audio_clip_s else self._split_audio(
+ y, int(sr))
prompts = []
for chunk in chunks:
prompt = {
@@ -232,8 +117,9 @@ async def _preprocess_speech_to_text(
},
},
"decoder_prompt":
- (f"<|startoftranscript|>{lang_token}"
- f"<|{self.task_type}|><|notimestamps|>{request.prompt}")
+ self.model_cls.
+ get_decoder_prompt( # type: ignore[attr-defined]
+ lang, self.task_type, request.prompt)
}
prompts.append(cast(PromptType, prompt))
return prompts, duration
diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
index ab1cfd4b6eab..c0691f122904 100644
--- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
@@ -77,8 +77,8 @@ def __init__(self, tokenizer: AnyTokenizer):
self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
if _is_fn_name_regex_support(self.model_tokenizer):
- self.fn_name_regex = re.compile(r'([a-zA-Z0-9_-]+)(\{.*?\})',
- re.DOTALL)
+ self.fn_name_regex = re.compile(
+ r'([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)', re.DOTALL)
else:
self.fn_name_regex = None
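For reference (not part of the patch), the new lookahead lets the lazy brace match grow across nested JSON objects instead of stopping at the first closing brace:

    import re

    fn_name_regex = re.compile(
        r'([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)', re.DOTALL)

    text = 'get_weather{"city": {"name": "Paris"}}'
    print(fn_name_regex.findall(text))
    # [('get_weather', '{"city": {"name": "Paris"}}')]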
diff --git a/vllm/envs.py b/vllm/envs.py
index 745ca626cda1..a46341ccdb15 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -119,6 +119,7 @@
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
+ VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
VLLM_USE_DEEP_GEMM: bool = False
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
@@ -134,6 +135,9 @@
VLLM_KV_CACHE_LAYOUT: Optional[str] = None
VLLM_COMPUTE_NANS_IN_LOGITS: bool = False
VLLM_USE_NVFP4_CT_EMULATIONS: bool = False
+ VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE"
+ VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
+ VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
def get_default_cache_root():
@@ -689,6 +693,31 @@ def get_vllm_port() -> Optional[int]:
lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
("true", "1")),
+ # Custom quick allreduce kernel for MI3* cards
+ # Choice of quantization level: FP, INT8, INT6, INT4 or NONE
+ # Recommended for large models to get allreduce
+ "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION":
+ lambda: os.getenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", "NONE").upper(),
+
+ # Custom quick allreduce kernel for MI3* cards
+ # Due to the lack of the bfloat16 asm instruction, bfloat16
+ # kernels are slower than fp16,
+ # If environment variable is set to 1, the input is converted to fp16
+ "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16":
+ lambda:
+ (os.getenv("VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", "True").lower() in
+ ("true", "1")),
+
+ # Custom quick allreduce kernel for MI3* cards.
+ # Controls the maximum allowed number of data bytes(MB) for custom quick
+ # allreduce communication.
+ # Default: 2048 MB.
+ # Data exceeding this size will use either custom allreduce or RCCL
+ # communication.
+ "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB":
+ lambda: maybe_convert_int(
+ os.environ.get("VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", None)),
+
# If set, when running in Quark emulation mode, do not dequantize the
# weights at load time. Instead, dequantize weights on-the-fly during
# kernel execution.
@@ -833,6 +862,8 @@ def get_vllm_port() -> Optional[int]:
"VLLM_TPU_BUCKET_PADDING_GAP":
lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0,
+ "VLLM_TPU_MOST_MODEL_LEN":
+ lambda: maybe_convert_int(os.environ.get("VLLM_TPU_MOST_MODEL_LEN", None)),
# Allow use of DeepGemm kernels for fused moe ops.
"VLLM_USE_DEEP_GEMM":
diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py
index a3f05ec5ea3f..84e8ddd8e274 100644
--- a/vllm/executor/ray_distributed_executor.py
+++ b/vllm/executor/ray_distributed_executor.py
@@ -73,7 +73,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
def _init_executor(self) -> None:
self.forward_dag: Optional[ray.dag.CompiledDAG] = None
- if envs.VLLM_USE_V1:
+ if envs.VLLM_USE_V1 and not current_platform.is_xpu():
# V1 uses SPMD worker and compiled DAG
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py
index 66e78833f52a..fc6e190e5480 100644
--- a/vllm/inputs/registry.py
+++ b/vllm/inputs/registry.py
@@ -5,7 +5,9 @@
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
import torch
+from packaging.version import Version
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
+from transformers import __version__ as TRANSFORMERS_VERSION
from typing_extensions import TypeVar
from vllm.jsontree import JSONTree, json_map_leaves
@@ -128,9 +130,13 @@ def get_hf_processor(
/,
**kwargs: object,
) -> _P:
+        # Transformers 4.53.0 has an issue with passing the tokenizer when
+        # initializing the processor, so we skip passing it for that version.
+ # See: https://github.com/vllm-project/vllm/issues/20224
+ if Version(TRANSFORMERS_VERSION) != Version("4.53.0"):
+ kwargs["tokenizer"] = self.tokenizer
return super().get_hf_processor(
typ,
- tokenizer=self.tokenizer,
**kwargs,
)
diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py
index 1680b723d6a2..9c88721fb278 100644
--- a/vllm/model_executor/custom_op.py
+++ b/vllm/model_executor/custom_op.py
@@ -141,16 +141,16 @@ def enabled(cls) -> bool:
@staticmethod
def default_on() -> bool:
"""
- On by default if level < CompilationLevel.PIECEWISE
+ On by default if PyTorch Inductor is not used.
Specifying 'all' or 'none' in custom_op takes precedence.
"""
from vllm.config import CompilationLevel
compilation_config = get_current_vllm_config().compilation_config
- custom_ops = compilation_config.custom_ops
- count_none = custom_ops.count("none")
- count_all = custom_ops.count("all")
- return compilation_config.level < CompilationLevel.PIECEWISE and \
- not count_none > 0 or count_all > 0
+ default_on = (compilation_config.level < CompilationLevel.PIECEWISE
+ or not compilation_config.use_inductor)
+ count_none = compilation_config.custom_ops.count("none")
+ count_all = compilation_config.custom_ops.count("all")
+ return default_on and not count_none > 0 or count_all > 0
# Dictionary of all custom ops (classes, indexed by registered name).
# To check if an op with a name is enabled, call .enabled() on the class.
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 22eadc7290f3..f781494fe396 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -147,6 +147,57 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+@CustomOp.register("gelu_and_mul_sparse")
+class GeluAndMulSparse(CustomOp):
+ """An activation function for GeluAndMulSparse.
+ This activation function is used in Gemma3n. It computes:
+ up_proj = self.up_proj(x)
+ gate_proj = self.gate_proj(x)
+ gate_proj = self._gaussian_topk(gate_proj) # sparsity
+ activations = self.act_fn(gate_proj) # gelu
+ down_proj = self.down_proj(activations * up_proj)
+ Shapes:
+ x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
+ return: (num_tokens, d) or (batch_size, seq_len, d)
+ """
+
+ def __init__(self, activation_sparsity: float, approximate: str = "none"):
+ super().__init__()
+ # Gelu.
+ self.approximate = approximate
+ if approximate not in ("none", "tanh"):
+ raise ValueError(f"Unknown approximate mode: {approximate}")
+
+ # Sparsity.
+ if activation_sparsity == 0.0:
+ raise ValueError(
+ "activation_sparsity is 0.0. Please use GeluAndMul.")
+ target_sparsity_tensor = torch.tensor(activation_sparsity,
+ dtype=torch.float32)
+ normal_dist = torch.distributions.normal.Normal(0, 1)
+ self.std_multiplier = normal_dist.icdf(target_sparsity_tensor)
+
+ def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
+ """Get % sparse percentile of the Gaussian distribution."""
+ # NOTE(rob): for TP>1, we could all-gather to get the means/std.
+ # But we do not do this because in expectation they are the same
+ # and in practice the eval scores are good without gathering.
+ mean = torch.mean(x, dim=-1, keepdim=True)
+ std = torch.std(x, dim=-1, keepdim=True, unbiased=False)
+ cutoff_x = mean + std * self.std_multiplier
+ return nn.functional.relu(x - cutoff_x)
+
+ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+ """PyTorch-native implementation equivalent to forward()."""
+ d = x.shape[-1] // 2
+ out = self._gaussian_topk(x[..., :d])
+ out = F.gelu(out, approximate=self.approximate)
+ return out * x[..., d:]
+
+ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+ return self.forward_native(x)
+
+
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
"""An activation function for GeGLU.
diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index 70836879d17c..b54ac80535a4 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
-import importlib.util
from typing import Optional
import torch
@@ -11,8 +10,6 @@
logger = init_logger(__name__)
-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-
@triton.jit
def _silu_mul_fp8_quant_deep_gemm(
diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
index 050d9520ca01..321fb0351ad9 100644
--- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
-import importlib.util
from typing import Optional
import torch
@@ -12,14 +11,13 @@
_moe_permute)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
-from vllm.model_executor.layers.fused_moe.utils import (
- _resize_cache, per_token_group_quant_fp8)
-from vllm.utils import round_up
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+ per_token_group_quant_fp8)
+from vllm.utils import has_deep_gemm, round_up
logger = init_logger(__name__)
-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-
@functools.cache
def deep_gemm_block_shape() -> list[int]:
@@ -41,7 +39,7 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
gemm kernel. All of M, N, K and the quantization block_shape must be
aligned by `dg.get_m_alignment_for_contiguous_layout()`.
"""
- if not has_deep_gemm:
+ if not has_deep_gemm():
logger.debug("DeepGemm disabled: deep_gemm not available.")
return False
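The module-level `importlib.util.find_spec` probes are replaced by a shared `has_deep_gemm()` helper imported from `vllm.utils`. A rough equivalent of such a helper (sketch only; the real one may differ) caches the lookup so the import machinery is not hit on every call:

```python
import functools
import importlib.util

@functools.cache
def has_deep_gemm() -> bool:
    # True when the optional deep_gemm package is importable.
    return importlib.util.find_spec("deep_gemm") is not None
```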
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index c1bae033c2b4..e6f555d315d8 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1,11 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import importlib
from abc import abstractmethod
+from collections.abc import Iterable
from dataclasses import dataclass
from enum import Enum
-from typing import Callable, Optional, Union
+from typing import Callable, Literal, Optional, Union, overload
import torch
import torch.nn.functional as F
@@ -20,6 +20,7 @@
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
+from vllm.distributed.eplb.eplb_state import EplbState
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
@@ -30,10 +31,7 @@
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
-from vllm.utils import direct_register_custom_op
-
-has_pplx = importlib.util.find_spec("pplx_kernels") is not None
-has_deepep = importlib.util.find_spec("deep_ep") is not None
+from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts
@@ -41,9 +39,9 @@
from .modular_kernel import (FusedMoEModularKernel,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize)
- if has_pplx:
+ if has_pplx():
from .pplx_prepare_finalize import PplxPrepareAndFinalize
- if has_deepep:
+ if has_deep_ep():
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE,
DeepEPLLPrepareAndFinalize)
@@ -54,6 +52,8 @@
if is_rocm_aiter_moe_enabled():
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk as grouped_topk)
+elif current_platform.is_cpu():
+ pass
else:
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
if current_platform.is_tpu():
@@ -433,6 +433,10 @@ def apply(
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
+ enable_eplb: bool = False,
+ expert_load_view: Optional[torch.Tensor] = None,
+ logical_to_physical_map: Optional[torch.Tensor] = None,
+ logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
@@ -572,7 +576,15 @@ def apply(
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
+ enable_eplb: bool = False,
+ expert_load_view: Optional[torch.Tensor] = None,
+ logical_to_physical_map: Optional[torch.Tensor] = None,
+ logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
+ if enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for `UnquantizedFusedMoEMethod` yet.")
+
return self.forward(
x=x,
layer=layer,
@@ -819,6 +831,7 @@ class FusedMoE(torch.nn.Module):
reduce_results: Whether to all all_reduce on the output of the layer
renomalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization configure.
+ enable_eplb: Whether to enable expert parallelism load balancer.
"""
def __init__(
@@ -843,6 +856,8 @@ def __init__(
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
+ enable_eplb: bool = False,
+ num_redundant_experts: int = 0,
):
super().__init__()
if params_dtype is None:
@@ -858,7 +873,7 @@ def __init__(
get_dp_group().world_size),
vllm_parallel_config=vllm_config.parallel_config))
- self.global_num_experts = num_experts
+ self.global_num_experts = num_experts + num_redundant_experts
# For smuggling this layer into the fused moe custom op
compilation_config = vllm_config.compilation_config
@@ -867,8 +882,20 @@ def __init__(
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
+ self.enable_eplb = enable_eplb
+ self.expert_load_view: Optional[torch.Tensor] = None
+ self.logical_to_physical_map: Optional[torch.Tensor] = None
+ self.logical_replica_count: Optional[torch.Tensor] = None
+
# Determine expert maps
if self.use_ep:
+ if self.enable_eplb:
+ assert self.global_num_experts % self.ep_size == 0, \
+ "EPLB currently only supports even distribution of " \
+ "experts across ranks."
+ else:
+ assert num_redundant_experts == 0, \
+ "Redundant experts are only supported with EPLB."
self.local_num_experts, self.expert_map = determine_expert_map(
ep_size=self.ep_size,
ep_rank=self.ep_rank,
@@ -935,6 +962,20 @@ def __init__(
assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method
+ if self.enable_eplb:
+ from vllm.model_executor.layers.quantization.fp8 import (
+ Fp8MoEMethod)
+ if not isinstance(quant_method, Fp8MoEMethod):
+ # TODO: Add support for additional quantization methods.
+ # The implementation for other quantization methods does not
+ # contain essential differences, but the current quant API
+ # design causes duplicated work when extending to new
+ # quantization methods, so I'm leaving it for now.
+ # If you plan to add support for more quantization methods,
+ # please refer to the implementation in `Fp8MoEMethod`.
+ raise NotImplementedError("EPLB is only supported for FP8 "
+ "quantization for now.")
+
moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": hidden_size,
@@ -963,8 +1004,9 @@ def __init__(
dtype=act_dtype,
device=torch.cuda.current_device())
+ # Note: here we use `num_experts`, which is the logical expert count
self.batched_router_logits = torch.zeros(
- (envs.VLLM_MOE_DP_CHUNK_SIZE, self.global_num_experts),
+ (envs.VLLM_MOE_DP_CHUNK_SIZE, num_experts),
dtype=act_dtype,
device=torch.cuda.current_device())
@@ -1128,13 +1170,33 @@ def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
return expert_id
return self.expert_map[expert_id].item()
+ @overload
def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
- shard_id: str, expert_id: int) -> None:
+ shard_id: str, expert_id: int,
+ return_success: Literal[False]) -> None:
+ ...
+ @overload
+ def weight_loader(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor, weight_name: str,
+ shard_id: str, expert_id: int,
+ return_success: Literal[True]) -> bool:
+ ...
+
+ def weight_loader(self,
+ param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor,
+ weight_name: str,
+ shard_id: str,
+ expert_id: int,
+ return_success: bool = False) -> Optional[bool]:
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
if expert_id == -1:
- return
+ # Failed to load this param since it's not local to this rank
+ return False if return_success else None
+ # Hereafter, `expert_id` is the local physical expert id
+
quant_method_name = self.quant_method.__class__.__name__
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
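The `return_success` flag is threaded through `weight_loader` via `typing.overload` so type checkers see the return type depend on the flag while the implementation returns `Optional[bool]`. A self-contained sketch of the same pattern (illustrative function, not the vLLM signature):

```python
from typing import Literal, Optional, overload

@overload
def load(expert_is_local: bool, return_success: Literal[False] = ...) -> None: ...
@overload
def load(expert_is_local: bool, return_success: Literal[True]) -> bool: ...

def load(expert_is_local: bool, return_success: bool = False) -> Optional[bool]:
    if not expert_is_local:
        # Weight belongs to an expert hosted on another rank: skip it.
        return False if return_success else None
    # ... copy the weight here ...
    return True if return_success else None

assert load(True, return_success=True) is True
assert load(False) is None
```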
@@ -1161,7 +1223,7 @@ def weight_loader(self, param: torch.nn.Parameter,
if is_gguf_weight_type:
param.weight_type = loaded_weight.item()
param.data.copy_(loaded_weight)
- return
+ return True if return_success else None
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
@@ -1184,6 +1246,7 @@ def weight_loader(self, param: torch.nn.Parameter,
param.materialize(final_shape, dtype=loaded_weight.dtype)
expert_data = param.data if full_load else param.data[expert_id]
+
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# this is needed for compressed-tensors only
@@ -1200,7 +1263,7 @@ def weight_loader(self, param: torch.nn.Parameter,
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
- return
+ return True if return_success else None
# Case g_idx
if "g_idx" in weight_name:
@@ -1209,8 +1272,9 @@ def weight_loader(self, param: torch.nn.Parameter,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank)
- return
+ return True if return_success else None
+ # TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
if "ModelOpt" in quant_method_name:
if ('weight_scale_2' in weight_name
or 'input_scale' in weight_name):
@@ -1225,9 +1289,9 @@ def weight_loader(self, param: torch.nn.Parameter,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank)
- return
+ return True if return_success else None
- # Case weight scales, zero_points and offset
+ # Case weight scales, zero_points and offset, weight/input global scales
if ("scale" in weight_name or "zero" in weight_name
or "offset" in weight_name):
# load the weight scales and zp based on the quantization scheme
@@ -1262,7 +1326,7 @@ def weight_loader(self, param: torch.nn.Parameter,
else:
raise ValueError(
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
- return
+ return True if return_success else None
# Case weight_shape
if "weight_shape" in weight_name:
@@ -1270,7 +1334,7 @@ def weight_loader(self, param: torch.nn.Parameter,
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
- return
+ return True if return_success else None
# Case model weights
if "weight" in weight_name:
@@ -1280,23 +1344,77 @@ def weight_loader(self, param: torch.nn.Parameter,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank)
- return
+ return True if return_success else None
+
+ return False if return_success else None
+
+ def get_expert_weights(self) -> Iterable[torch.Tensor]:
+ weights = list(self.named_parameters())
+ assert all(weight.is_contiguous() for _, weight in weights)
+
+ # Filter out the non-expert weights.
+ # `e_score_correction_bias` is a bias for each logical expert,
+ # with shape (num_logical_experts,), not an expert weight.
+ NON_EXPERT_WEIGHTS = {
+ "e_score_correction_bias",
+ }
+
+ return [
+ weight.view(self.local_num_experts, -1) for name, weight in weights
+ if name not in NON_EXPERT_WEIGHTS
+ ]
+
+ def set_eplb_state(
+ self,
+ moe_layer_idx: int,
+ expert_load_view: torch.Tensor,
+ logical_to_physical_map: torch.Tensor,
+ logical_replica_count: torch.Tensor,
+ ) -> None:
+ """
+ Register the EPLB state in this layer.
+
+ This is used later in the forward pass, where we get the expert mapping
+ and record the load metrics in `expert_load_view`.
+ """
+ self.expert_load_view = expert_load_view[moe_layer_idx]
+ self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
+ self.logical_replica_count = logical_replica_count[moe_layer_idx]
@staticmethod
- def select_experts(hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
- top_k: int,
- use_grouped_topk: bool,
- renormalize: bool,
- topk_group: Optional[int] = None,
- num_expert_group: Optional[int] = None,
- custom_routing_function: Optional[Callable] = None,
- scoring_func: str = "softmax",
- e_score_correction_bias: Optional[torch.Tensor] = None,
- indices_type: Optional[torch.dtype] = None):
+ def select_experts(
+ hidden_states: torch.Tensor,
+ router_logits: torch.Tensor,
+ top_k: int,
+ use_grouped_topk: bool,
+ renormalize: bool,
+ topk_group: Optional[int] = None,
+ num_expert_group: Optional[int] = None,
+ custom_routing_function: Optional[Callable] = None,
+ scoring_func: str = "softmax",
+ e_score_correction_bias: Optional[torch.Tensor] = None,
+ indices_type: Optional[torch.dtype] = None,
+ enable_eplb: bool = False,
+ expert_map: Optional[torch.Tensor] = None,
+ expert_load_view: Optional[torch.Tensor] = None,
+ logical_to_physical_map: Optional[torch.Tensor] = None,
+ logical_replica_count: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Route the input hidden states to the top-k experts based on the
+ router logits.
+
+ Returns:
+ (topk_weights, topk_ids) (tuple[torch.Tensor, torch.Tensor]):
+ The weights and *global physical* expert ids of the top-k experts.
+
+ **Compatibility**: When EPLB is not enabled, the returned ids are
+ equivalent to global logical ids, so they should be compatible with
+ plain MoE implementations without redundant experts.
+ """
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
- # DeekSeekv2 uses grouped_top_k
+ # DeepSeekv2 uses grouped_top_k
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
@@ -1328,6 +1446,74 @@ def select_experts(hidden_states: torch.Tensor,
if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type)
+ if enable_eplb:
+ assert expert_load_view is not None
+ assert logical_to_physical_map is not None
+ assert logical_replica_count is not None
+
+ # 1. Convert the logical expert ids to physical expert ids
+ # Directly select a random replica for each logical expert
+
+ # TODO: maybe optimize this by using specified kernels,
+ # or compute pseudo-random indices by modulo
+
+ # In case `indices_type` is not `torch.long` or `torch.int`,
+ # e.g. `torch.uint32` as required by dispatch/combine kernels
+ topk_ids_long = topk_ids.long()
+ replica_indices = (
+ torch.rand_like(topk_ids, dtype=torch.float) *
+ logical_replica_count[topk_ids_long]).long().unsqueeze(-1)
+ physical_ids = logical_to_physical_map[topk_ids_long].gather(
+ -1, replica_indices).squeeze(-1)
+
+ topk_ids = physical_ids
+
+ # 2. Record expert load metrics.
+
+ # TODO(bowen): When using `FusedMoEModularKernel`, this
+ # can be done in a more unified way, since
+ # `FusedMoEPrepareAndFinalize` will return the expert
+ # token count, in some cases directly from the kernel.
+ # However, now there are many code paths not using
+ # the modular kernel, e.g. calling `fused_experts`,
+ # so we keep the logic here for now.
+ #
+ # If a later refactor moves all the MoE kernel calls
+ # to the modular kernel, we can move this logic there
+ # for better efficiency.
+
+ # `expert_load_view`: (num_logical_experts,)
+
+ # Mask out non-local experts
+ if expert_map is not None:
+ topk_ids_local = expert_map[topk_ids]
+ topk_ids_flatten = topk_ids_local.flatten()
+ else:
+ topk_ids_flatten = topk_ids.flatten()
+
+ # Should be equivalent to:
+ # ```
+ # topk_ids_masked = topk_ids_local[topk_ids_local >= 0]
+ # expert_load_view += topk_ids_masked.bincount(
+ # minlength=expert_load_view.shape[0])
+ # ```
+ # We use `scatter_add_` since `bincount` cannot be compiled
+
+ # Performance optimization:
+ # `masked_fill` is significantly faster than `masked_select`
+ invalid_mask = topk_ids_flatten < 0
+ # Replace invalid expert ids with 0 (just a dummy position)
+ # to avoid out-of-bounds errors in scatter_add_
+ index = topk_ids_flatten.masked_fill_(invalid_mask, 0)
+ # `src` is the valid mask, which is 1 for valid and 0 for invalid
+ src = ~invalid_mask
+
+ expert_load_view.scatter_add_(dim=0,
+ index=index.long(),
+ src=src.to(expert_load_view))
+
+ topk_ids = topk_ids.to(dtype=indices_type)
+
return topk_weights, topk_ids
def must_reduce_shared_expert_outputs(self) -> bool:
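The EPLB branch of `select_experts` does two things: it picks a random physical replica for each routed logical expert, and it accumulates per-expert token counts with `scatter_add_`, a compile-friendly stand-in for `bincount`. A self-contained sketch with made-up shapes (not the vLLM API; the real code additionally masks out experts that are not local to the rank via `expert_map`):

```python
import torch

logical_to_physical_map = torch.tensor([[0, 4], [1, 5], [2, 2], [3, 3]])
logical_replica_count = torch.tensor([2, 2, 1, 1])
num_physical_experts = 6

topk_ids = torch.tensor([[0, 1], [2, 0]])   # logical ids, shape (num_tokens, k)

# 1. Logical -> physical: pick a random replica index per routed expert.
replica = (torch.rand_like(topk_ids, dtype=torch.float) *
           logical_replica_count[topk_ids]).long().unsqueeze(-1)
physical_ids = logical_to_physical_map[topk_ids].gather(-1, replica).squeeze(-1)

# 2. Load accounting: count routed tokens per physical expert.
expert_load_view = torch.zeros(num_physical_experts, dtype=torch.long)
flat = physical_ids.flatten()
expert_load_view.scatter_add_(0, flat, torch.ones_like(flat))
# expert_load_view[i] now counts how many routed tokens hit physical expert i.
```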
@@ -1408,6 +1594,10 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
+ enable_eplb=self.enable_eplb,
+ expert_load_view=self.expert_load_view,
+ logical_to_physical_map=self.logical_to_physical_map,
+ logical_replica_count=self.logical_replica_count,
)
if not skip_result_store:
@@ -1465,6 +1655,10 @@ def forward_impl(self, hidden_states: torch.Tensor,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
+ enable_eplb=self.enable_eplb,
+ expert_load_view=self.expert_load_view,
+ logical_to_physical_map=self.logical_to_physical_map,
+ logical_replica_count=self.logical_replica_count,
)
if do_naive_dispatch_combine:
@@ -1479,16 +1673,30 @@ def forward_impl(self, hidden_states: torch.Tensor,
@classmethod
def make_expert_params_mapping(
- cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
+ cls,
+ ckpt_gate_proj_name: str,
+ ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
- num_experts: int) -> list[tuple[str, str, int, str]]:
+ num_experts: int,
+ num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]:
+
+ num_physical_experts = num_experts + num_redundant_experts
+
+ # In the returned mapping:
+ # - `expert_id` is the physical expert id
+ # - `weight_name` contains the weight name of the logical expert
+ # So we map each physical expert id to its logical id in `weight_name`
+ physical_to_logical_map = \
+ EplbState.build_initial_global_physical_to_logical_map(
+ num_experts, num_redundant_experts)
return [
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_" if weight_name
in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
- f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
- for expert_id in range(num_experts) for shard_id, weight_name in [
+ f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.",
+ expert_id, shard_id) for expert_id in range(num_physical_experts)
+ for shard_id, weight_name in [
("w1", ckpt_gate_proj_name),
("w2", ckpt_down_proj_name),
("w3", ckpt_up_proj_name),
@@ -1533,7 +1741,8 @@ def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
direct_register_custom_op(
op_name="moe_forward",
op_func=moe_forward,
- mutates_args=[],
+ mutates_args=["hidden_states"],
fake_impl=moe_forward_fake,
dispatch_key=current_platform.dispatch_key,
+ tags=(torch.Tag.needs_fixed_stride_order, ),
)
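With redundant experts, `make_expert_params_mapping` enumerates *physical* slots while checkpoint weight names remain keyed by *logical* expert id. A sketch with a hypothetical layout (the real initial map comes from `EplbState.build_initial_global_physical_to_logical_map`):

```python
num_experts, num_redundant_experts = 4, 2
# Assume redundant physical slots simply replicate the first logical experts.
physical_to_logical = list(range(num_experts)) + list(range(num_redundant_experts))

mapping = [
    ("experts.w2_", f"experts.{physical_to_logical[pid]}.down_proj.", pid, "w2")
    for pid in range(num_experts + num_redundant_experts)
]
for entry in mapping[-2:]:
    print(entry)  # physical slots 4 and 5 load from logical experts 0 and 1
```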
diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py
index 692482c2ea69..4c91e697f8e9 100644
--- a/vllm/model_executor/layers/fused_moe/utils.py
+++ b/vllm/model_executor/layers/fused_moe/utils.py
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import socket
+from contextlib import closing
from math import prod
from typing import Optional
@@ -96,3 +98,10 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
else:
return m[idx, ...]
+
+
+def find_free_port():
+ with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+ s.bind(('', 0))
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ return s.getsockname()[1]
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 56d803c6baf1..aff54bc495b2 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -482,7 +482,15 @@ def apply(
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
+ enable_eplb: bool = False,
+ expert_load_view: Optional[torch.Tensor] = None,
+ logical_to_physical_map: Optional[torch.Tensor] = None,
+ logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
+ if enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for `AWQMoEMethod` yet.")
+
assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input:
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index d21abb2741a2..4f87b2a44f0a 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -33,6 +33,8 @@
find_matched_target, is_activation_quantization_format,
should_ignore_layer)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
+ cutlass_fp4_supported)
from vllm.platforms import current_platform
logger = init_logger(__name__)
@@ -375,7 +377,7 @@ def _get_scheme_from_parts(
if is_activation_quantization_format(self.quant_format):
if self._is_fp4a4_nvfp4(weight_quant, input_quant):
- if CompressedTensorsW4A4Fp4.cutlass_fp4_supported(
+ if cutlass_fp4_supported(
) or envs.VLLM_USE_NVFP4_CT_EMULATIONS:
return CompressedTensorsW4A4Fp4()
else:
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index f14131c5f05b..fa4ce5668091 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
-import importlib
from enum import Enum
from typing import Callable, Optional
@@ -22,20 +21,23 @@
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_moe_marlin_supports_layer, marlin_make_workspace_new,
marlin_moe_permute_scales)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+ prepare_moe_fp4_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_moe_fp8_layer_for_marlin)
+from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
+ cutlass_fp4_supported)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
-
-has_pplx = importlib.util.find_spec("pplx_kernels") is not None
+from vllm.utils import has_pplx
if current_platform.is_cuda_alike():
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize)
- if has_pplx:
+ if has_pplx():
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
@@ -48,12 +50,11 @@ class GPTQMarlinState(Enum):
__all__ = [
- "CompressedTensorsMoEMethod",
- "CompressedTensorsW8A8Fp8MoEMethod",
+ "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
"CompressedTensorsW8A8Fp8MoECutlassMethod",
"CompressedTensorsW8A8Int8MoEMethod",
- "CompressedTensorsWNA16MarlinMoEMethod",
- "CompressedTensorsWNA16MoEMethod",
+ "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod",
+ "CompressedTensorsW4A4MoeMethod"
]
@@ -86,6 +87,8 @@ def get_moe_method(
else:
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
+ elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
+ return CompressedTensorsW4A4MoeMethod()
elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
@@ -97,6 +100,268 @@ def get_moe_method(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
+class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
+
+ def __init__(self):
+ self.use_marlin = not cutlass_fp4_supported()
+ self.group_size = 16
+
+ def create_weights(self, layer: torch.nn.Module, num_experts: int,
+ hidden_size: int, intermediate_size_per_partition: int,
+ params_dtype: torch.dtype, **extra_weight_attrs):
+
+ layer.num_experts = num_experts
+ layer.params_dtype = params_dtype
+
+ w13_weight = torch.nn.Parameter(
+ torch.empty(
+ num_experts,
+ 2 * intermediate_size_per_partition,
+ # 2 fp4 items are packed in the input dimension
+ hidden_size // 2,
+ requires_grad=False,
+ dtype=torch.uint8),
+ requires_grad=False)
+ layer.register_parameter("w13_weight_packed", w13_weight)
+ set_weight_attrs(w13_weight, extra_weight_attrs)
+
+ w2_weight = torch.nn.Parameter(
+ torch.empty(
+ num_experts,
+ hidden_size,
+ # 2 fp4 items are packed in the input dimension
+ intermediate_size_per_partition // 2,
+ dtype=torch.uint8),
+ requires_grad=False)
+ layer.register_parameter("w2_weight_packed", w2_weight)
+ set_weight_attrs(w2_weight, extra_weight_attrs)
+
+ # Weight Scales
+ w13_weight_scale = torch.nn.Parameter(
+ torch.empty(
+ num_experts,
+ 2 * intermediate_size_per_partition,
+ # 2 fp4 items are packed in the input dimension
+ hidden_size // self.group_size,
+ dtype=torch.float8_e4m3fn),
+ requires_grad=False)
+ layer.register_parameter("w13_weight_scale", w13_weight_scale)
+ extra_weight_attrs.update(
+ {"quant_method": FusedMoeWeightScaleSupported.GROUP.value})
+ set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+
+ w2_weight_scale = torch.nn.Parameter(
+ torch.empty(
+ num_experts,
+ hidden_size,
+ # 2 fp4 items are packed in the input dimension
+ intermediate_size_per_partition // self.group_size,
+ dtype=torch.float8_e4m3fn),
+ requires_grad=False)
+ layer.register_parameter("w2_weight_scale", w2_weight_scale)
+ extra_weight_attrs.update(
+ {"quant_method": FusedMoeWeightScaleSupported.GROUP.value})
+ set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+ # Weight Global Scales
+ w13_weight_scale_2 = torch.nn.Parameter(torch.empty(
+ num_experts, 2, dtype=torch.float32),
+ requires_grad=False)
+ layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2)
+ extra_weight_attrs.update(
+ {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+ set_weight_attrs(w13_weight_scale_2, extra_weight_attrs)
+
+ w2_weight_scale_2 = torch.nn.Parameter(torch.empty(
+ num_experts, dtype=torch.float32),
+ requires_grad=False)
+ layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2)
+ extra_weight_attrs.update(
+ {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+ set_weight_attrs(w2_weight_scale_2, extra_weight_attrs)
+
+ # Input Global Scales
+ w13_input_scale = torch.nn.Parameter(torch.empty(num_experts,
+ 2,
+ dtype=torch.float32),
+ requires_grad=False)
+ layer.register_parameter("w13_input_global_scale", w13_input_scale)
+ extra_weight_attrs.update(
+ {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+ set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+ w2_input_scale = torch.nn.Parameter(torch.empty(num_experts,
+ dtype=torch.float32),
+ requires_grad=False)
+ layer.register_parameter("w2_input_global_scale", w2_input_scale)
+ extra_weight_attrs.update(
+ {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+ set_weight_attrs(w2_input_scale, extra_weight_attrs)
+
+ def swizzle_blockscale(self, scale: torch.tensor):
+ assert (scale.dtype == torch.float8_e4m3fn)
+ # Pad and blockwise interleave weight_scale
+ scale_ndim = scale.ndim
+ if scale.ndim == 2:
+ scale = scale.unsqueeze(0)
+ assert scale.ndim == 3
+ B, M, K = scale.shape
+ round_up_multiple = lambda x, m: (x + m - 1) // m * m
+ M_padded = round_up_multiple(M, 128)
+ K_padded = round_up_multiple(K, 4)
+ padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
+ padded_scale[:B, :M, :K] = scale
+ batches, rows, cols = padded_scale.shape
+ assert rows % 128 == 0
+ assert cols % 4 == 0
+ padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
+ cols // 4, 4)
+ swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
+ swizzled_scale = swizzled_scale.contiguous().cuda()
+ return (swizzled_scale.reshape(M, K)
+ if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
+
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+
+ # From packed to weight
+ layer.w13_weight = torch.nn.Parameter(layer.w13_weight_packed.data,
+ requires_grad=False)
+
+ layer.w2_weight = torch.nn.Parameter(layer.w2_weight_packed.data,
+ requires_grad=False)
+
+ if not torch.allclose(layer.w13_weight_global_scale[:, 0],
+ layer.w13_weight_global_scale[:, 1]):
+ logger.warning_once(
+ "w1_weight_global_scale must match w3_weight_global_scale. "
+ "Accuracy may be affected.")
+
+ # Take inverse of global scale saved to disk
+ layer.w13_weight_scale_2 = torch.nn.Parameter(
+ 1 / layer.w13_weight_global_scale[:, 0], requires_grad=False)
+
+ layer.w2_weight_scale_2 = torch.nn.Parameter(
+ 1 / layer.w2_weight_global_scale.data, requires_grad=False)
+
+ if self.use_marlin:
+ prepare_moe_fp4_layer_for_marlin(layer)
+ return
+
+ # swizzle weight scales
+ layer.w13_blockscale_swizzled = torch.nn.Parameter(
+ self.swizzle_blockscale(layer.w13_weight_scale),
+ requires_grad=False)
+
+ layer.w2_blockscale_swizzled = torch.nn.Parameter(
+ self.swizzle_blockscale(layer.w2_weight_scale),
+ requires_grad=False)
+
+ # w13
+ w13_input_global_scale = layer.w13_input_global_scale.max(
+ dim=1).values.to(torch.float32)
+
+ layer.g1_alphas = torch.nn.Parameter(
+ ((1 / w13_input_global_scale) * layer.w13_weight_scale_2),
+ requires_grad=False)
+
+ layer.w13_input_scale_quant = torch.nn.Parameter(
+ (w13_input_global_scale), requires_grad=False)
+
+ # w2
+ layer.g2_alphas = torch.nn.Parameter(
+ ((1 / layer.w2_input_global_scale) * layer.w2_weight_scale_2).to(
+ torch.float32),
+ requires_grad=False)
+
+ layer.w2_input_scale_quant = torch.nn.Parameter(
+ (layer.w2_input_global_scale), requires_grad=False)
+
+ def apply(
+ self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ router_logits: torch.Tensor,
+ top_k: int,
+ renormalize: bool,
+ use_grouped_topk: bool = False,
+ topk_group: Optional[int] = None,
+ num_expert_group: Optional[int] = None,
+ global_num_experts: int = -1,
+ expert_map: Optional[torch.Tensor] = None,
+ custom_routing_function: Optional[Callable] = None,
+ scoring_func: str = "softmax",
+ e_score_correction_bias: Optional[torch.Tensor] = None,
+ apply_router_weight_on_input: bool = False,
+ activation: str = "silu",
+ enable_eplb: bool = False,
+ expert_load_view: Optional[torch.Tensor] = None,
+ logical_to_physical_map: Optional[torch.Tensor] = None,
+ logical_replica_count: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if enable_eplb:
+ raise NotImplementedError("EPLB not supported for "
+ "`CompressedTensorsW4A4MoeMethod` yet.")
+
+ topk_weights, topk_ids = FusedMoE.select_experts(
+ hidden_states=x,
+ router_logits=router_logits,
+ use_grouped_topk=use_grouped_topk,
+ top_k=top_k,
+ renormalize=renormalize,
+ topk_group=topk_group,
+ num_expert_group=num_expert_group,
+ custom_routing_function=custom_routing_function,
+ scoring_func=scoring_func,
+ e_score_correction_bias=e_score_correction_bias,
+ )
+
+ if self.use_marlin:
+ return torch.ops.vllm.fused_marlin_moe(
+ x,
+ layer.w13_weight,
+ layer.w2_weight,
+ layer.w13_weight_scale,
+ layer.w2_weight_scale,
+ router_logits,
+ topk_weights,
+ topk_ids,
+ global_scale1=layer.w13_weight_scale_2,
+ global_scale2=layer.w2_weight_scale_2,
+ quant_type_id=scalar_types.float4_e2m1f.id,
+ global_num_experts=global_num_experts,
+ expert_map=expert_map)
+
+ assert activation == "silu", "Only SiLU activation is supported."
+ assert not apply_router_weight_on_input, (
+ "Router weight on input is not "
+ "supported for CompressedTensorsW4A4MoeMethod.")
+ assert expert_map is None, ("Expert Parallelism / expert_map "
+ "is currently not supported for "
+ "CompressedTensorsW4A4MoeMethod.")
+
+ from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+ cutlass_moe_fp4)
+
+ # Cutlass moe takes in activations in BF16/Half precision
+ # and fp4 quantized weights loaded from the checkpoint
+ return cutlass_moe_fp4(a=x,
+ w1_fp4=layer.w13_weight,
+ w1_blockscale=layer.w13_blockscale_swizzled,
+ w1_alphas=layer.g1_alphas,
+ w2_fp4=layer.w2_weight,
+ w2_blockscale=layer.w2_blockscale_swizzled,
+ w2_alphas=layer.g2_alphas,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ m=x.shape[0],
+ n=layer.w2_weight.shape[2] * 2,
+ k=x.shape[1],
+ e=layer.w13_weight.shape[0],
+ a1_gscale=layer.w13_input_scale_quant,
+ a2_gscale=layer.w2_input_scale_quant,
+ device=x.device).to(x.dtype)
+
+
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
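`swizzle_blockscale` lays the block scales out as interleaved 128x4 tiles before they reach the FP4 cutlass path. A shape-only sketch (float32 stands in for the fp8 scales, and the sizes are assumed to already be multiples of 128 and 4, so the zero-padding step is a no-op):

```python
import torch

scale = torch.arange(8 * 256 * 8, dtype=torch.float32).reshape(8, 256, 8)
B, M, K = scale.shape                        # (experts, rows, scale groups)
tiled = scale.reshape(B, M // 128, 4, 32, K // 4, 4)
swizzled = tiled.permute(0, 1, 4, 3, 2, 5).contiguous()
print(swizzled.reshape(B, M, K).shape)       # torch.Size([8, 256, 8])
```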
@@ -331,7 +596,15 @@ def apply(
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
+ enable_eplb: bool = False,
+ expert_load_view: Optional[torch.Tensor] = None,
+ logical_to_physical_map: Optional[torch.Tensor] = None,
+ logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
+ if enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for "
+ "`CompressedTensorsW8A8Fp8MoEMethod` yet.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
@@ -569,7 +842,7 @@ def select_gemm_impl(self, prepare_finalize, moe):
use_batched_format=True,
)
- if has_pplx and isinstance(
+ if has_pplx() and isinstance(
prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
# no expert_map support in this case
@@ -593,7 +866,15 @@ def apply(
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
+ enable_eplb: bool = False,
+ expert_load_view: Optional[torch.Tensor] = None,
+ logical_to_physical_map: Optional[torch.Tensor] = None,
+ logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
+ if enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for "
+ "`CompressedTensorsW8A8Fp8MoECutlassMethod` yet.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
@@ -722,7 +1003,16 @@ def apply(
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
+ enable_eplb: bool = False,
+ expert_load_view: Optional[torch.Tensor] = None,
+ logical_to_physical_map: Optional[torch.Tensor] = None,
+ logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
+ if enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for "
+ "`CompressedTensorsW8A8Int8MoEMethod` yet.")
+
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
@@ -1012,7 +1302,16 @@ def apply(
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
+ enable_eplb: bool = False,
+ expert_load_view: Optional[torch.Tensor] = None,
+ logical_to_physical_map: Optional[torch.Tensor] = None,
+ logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
+ if enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for "
+ "`CompressedTensorsWNA16MarlinMoEMethod` yet.")
+
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
assert not apply_router_weight_on_input, (
@@ -1228,7 +1527,15 @@ def apply(
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
+ enable_eplb: bool = False,
+ expert_load_view: Optional[torch.Tensor] = None,
+ logical_to_physical_map: Optional[torch.Tensor] = None,
+ logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
+ if enable_eplb:
+ raise NotImplementedError("EPLB not supported for "
+ "`CompressedTensorsWNA16MoEMethod` yet.")
+
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
index ec1d4a6c0efa..65cbc49d2640 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
@@ -5,8 +5,7 @@
from torch.nn.parameter import Parameter
import vllm.envs as envs
-from vllm._custom_ops import (cutlass_scaled_fp4_mm,
- cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
+from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
@@ -15,7 +14,6 @@
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
-from vllm.platforms import current_platform
logger = init_logger(__name__)
@@ -33,15 +31,6 @@ def get_min_capability(cls) -> int:
return 80
return 100
- @classmethod
- def cutlass_fp4_supported(cls) -> bool:
- if not current_platform.is_cuda():
- return False
- capability_tuple = current_platform.get_device_capability()
- capability = -1 if capability_tuple is None else capability_tuple.to_int( # noqa: E501
- )
- return cutlass_scaled_mm_supports_fp4(capability)
-
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
diff --git a/vllm/model_executor/layers/quantization/deepgemm.py b/vllm/model_executor/layers/quantization/deepgemm.py
index 1d40f4915a1b..e4cf64740758 100644
--- a/vllm/model_executor/layers/quantization/deepgemm.py
+++ b/vllm/model_executor/layers/quantization/deepgemm.py
@@ -1,15 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
-import importlib.util
import logging
import torch
from vllm.platforms import current_platform
from vllm.triton_utils import triton
-from vllm.utils import direct_register_custom_op
+from vllm.utils import direct_register_custom_op, has_deep_gemm
-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-if has_deep_gemm:
+if has_deep_gemm():
import deep_gemm
logger = logging.getLogger(__name__)
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index 01b0064f0805..47eca80609e0 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -117,7 +117,15 @@ def apply(
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
+ enable_eplb: bool = False,
+ expert_load_view: Optional[torch.Tensor] = None,
+ logical_to_physical_map: Optional[torch.Tensor] = None,
+ logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
+ if enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for `ExpertsInt8MoEMethod` yet.")
+
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 145a2e96417c..8e7e44f30eb8 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
-import importlib.util
from typing import Any, Callable, Optional, Union
import torch
@@ -38,13 +37,12 @@
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
+from vllm.utils import has_deep_gemm
ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = init_logger(__name__)
-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-
def _is_col_major(x: torch.Tensor) -> bool:
assert x.dim() == 3
@@ -455,7 +453,7 @@ def __init__(self, quant_config: Fp8Config):
# Check for DeepGemm support.
self.allow_deep_gemm = False
if envs.VLLM_USE_DEEP_GEMM:
- if not has_deep_gemm:
+ if not has_deep_gemm():
logger.warning_once("Failed to import DeepGemm kernels.")
elif not self.block_quant:
logger.warning_once("Model is not block quantized. Not using "
@@ -829,7 +827,16 @@ def apply(
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
+ enable_eplb: bool = False,
+ expert_load_view: Optional[torch.Tensor] = None,
+ logical_to_physical_map: Optional[torch.Tensor] = None,
+ logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
+ if enable_eplb:
+ assert expert_load_view is not None
+ assert logical_to_physical_map is not None
+ assert logical_replica_count is not None
+ assert isinstance(layer, FusedMoE)
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
@@ -843,6 +850,11 @@ def apply(
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
+ enable_eplb=enable_eplb,
+ expert_map=expert_map,
+ expert_load_view=expert_load_view,
+ logical_to_physical_map=logical_to_physical_map,
+ logical_replica_count=logical_replica_count,
)
if self.rocm_aiter_moe_enabled:
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index 9c8f74545d37..86da04c39989 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -520,7 +520,15 @@ def apply(
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
+ enable_eplb: bool = False,
+ expert_load_view: Optional[torch.Tensor] = None,
+ logical_to_physical_map: Optional[torch.Tensor] = None,
+ logical_replica_count: Optional[torch.Tensor] = None,
):
+ if enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for `GGUFMoEMethod` yet.")
+
assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input:
raise NotImplementedError(
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index e9b8dc3266b4..48ab04c9ab37 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -635,7 +635,15 @@ def apply(
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
+ enable_eplb: bool = False,
+ expert_load_view: Optional[torch.Tensor] = None,
+ logical_to_physical_map: Optional[torch.Tensor] = None,
+ logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
+ if enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for `GPTQMarlinMoEMethod` yet.")
+
assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input:
raise NotImplementedError(
diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py
index 31ad96eccaf3..428e9b882bca 100644
--- a/vllm/model_executor/layers/quantization/ipex_quant.py
+++ b/vllm/model_executor/layers/quantization/ipex_quant.py
@@ -15,7 +15,7 @@
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.platforms import current_platform
-MIN_IPEX_VERSION = "2.7.0"
+MIN_IPEX_VERSION = "2.6.0"
class IPEXConfig(QuantizationConfig):
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 3f79b203aa17..e35db5b31dba 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -664,7 +664,15 @@ def apply(
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
+ enable_eplb: bool = False,
+ expert_load_view: Optional[torch.Tensor] = None,
+ logical_to_physical_map: Optional[torch.Tensor] = None,
+ logical_replica_count: Optional[torch.Tensor] = None,
):
+ if enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
+
if self.use_marlin:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index 3aa23f068257..c5055a02fa3d 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -297,7 +297,15 @@ def apply(
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
+ enable_eplb: bool = False,
+ expert_load_view: Optional[torch.Tensor] = None,
+ logical_to_physical_map: Optional[torch.Tensor] = None,
+ logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
+ if enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for `MoeWNA16Method` yet.")
+
from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = FusedMoE.select_experts(
diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py
index 693b3dfcfa59..740a4b48e454 100644
--- a/vllm/model_executor/layers/quantization/quark/quark.py
+++ b/vllm/model_executor/layers/quantization/quark/quark.py
@@ -325,11 +325,7 @@ def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
is_fp8_w8a8_supported = self._check_scheme_supported(
QuarkW8A8Fp8.get_min_capability(), error=False)
if is_fp8_w8a8_supported:
- weight_qscheme = cast(str, weight_config.get("qscheme"))
- input_static = (input_config is not None and
- not cast(bool, input_config.get("is_dynamic")))
- return QuarkW8A8Fp8(qscheme=weight_qscheme,
- is_static_input_scheme=input_static)
+ return QuarkW8A8Fp8(weight_config, input_config)
elif self._is_static_tensor_w8a8(weight_config, input_config):
weight_qscheme = cast(str, weight_config.get("qscheme"))
return QuarkW8A8Int8(qscheme=weight_qscheme,
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 4c2da4c8b04e..a040c430cbca 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -205,7 +205,15 @@ def apply(
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
+ enable_eplb: bool = False,
+ expert_load_view: Optional[torch.Tensor] = None,
+ logical_to_physical_map: Optional[torch.Tensor] = None,
+ logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
+ if enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.")
+
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
index 47e0a492b23b..c7bc98184d0e 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Callable, Optional
+from typing import Any, Callable, Optional, cast
import torch
from torch.nn import Parameter
@@ -19,10 +19,19 @@
class QuarkW8A8Fp8(QuarkScheme):
- def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
- self.qscheme = qscheme
- self.is_static_input_scheme = is_static_input_scheme
- self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
+ def __init__(self, weight_config: dict[str, Any],
+ input_config: Optional[dict[str, Any]]):
+ self.weight_qscheme = cast(str, weight_config.get("qscheme"))
+ self.is_static_input_scheme: bool = False
+ self.input_qscheme: Optional[str] = None
+ if input_config is not None:
+ self.is_static_input_scheme = not cast(
+ bool, input_config.get("is_dynamic"))
+ self.input_qscheme = cast(str, input_config.get("qscheme"))
+ self.use_per_token_if_dynamic = (not self.is_static_input_scheme \
+ and self.input_qscheme == "per_channel")
+ self.fp8_linear = Fp8LinearOp(
+ use_per_token_if_dynamic=self.use_per_token_if_dynamic)
self.out_dtype = torch.get_default_dtype()
@classmethod
@@ -34,7 +43,7 @@ def process_weights_after_loading(self, layer) -> None:
# If per tensor, when we have a fused module (e.g. QKV) with per
# tensor scales (thus N scales being passed to the kernel),
# requantize so we can always run per tensor
- if self.qscheme == "per_tensor":
+ if self.weight_qscheme == "per_tensor":
if current_platform.is_rocm():
input_scale = getattr(layer, 'input_scale', None)
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
@@ -58,7 +67,7 @@ def process_weights_after_loading(self, layer) -> None:
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
# If channelwise, scales are already lined up, so just transpose.
- elif self.qscheme == "per_channel":
+ elif self.weight_qscheme == "per_channel":
weight = layer.weight
if current_platform.is_fp8_fnuz():
@@ -73,13 +82,15 @@ def process_weights_after_loading(self, layer) -> None:
requires_grad=False)
else:
weight_scale = layer.weight_scale.data
-
+ if self.use_per_token_if_dynamic:
+ weight_scale = weight_scale.view(-1, 1)
layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
else:
- raise ValueError(f"Unknown quantization scheme {self.qscheme}")
+ raise ValueError(
+ f"Unknown quantization scheme {self.weight_qscheme}")
# INPUT SCALE
if self.is_static_input_scheme:
@@ -109,14 +120,14 @@ def create_weights(self, layer: torch.nn.Module,
# WEIGHT SCALE
# TODO: update create_xxx_parameter functions to return
# the newly added parameters
- if self.qscheme == "per_channel":
+ if self.weight_qscheme == "per_channel":
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes)),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
else:
- assert self.qscheme == "per_tensor"
+ assert self.weight_qscheme == "per_tensor"
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
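The reworked `QuarkW8A8Fp8` now derives its flags from the raw weight/input config dicts instead of pre-computed arguments. Hypothetical config fragments illustrating that derivation (field names follow the diff above; the values are made up):

```python
weight_config = {"qscheme": "per_channel"}
input_config = {"qscheme": "per_channel", "is_dynamic": True}

weight_qscheme = weight_config["qscheme"]
is_static_input_scheme = (input_config is not None
                          and not input_config["is_dynamic"])
use_per_token_if_dynamic = (not is_static_input_scheme
                            and input_config["qscheme"] == "per_channel")
print(weight_qscheme, is_static_input_scheme, use_per_token_if_dynamic)
# per_channel False True -> dynamic per-token activation quantization
```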
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 3a0fb83d627a..c38a445c571b 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -3,7 +3,6 @@
# Adapted from https://github.com/sgl-project/sglang/pull/2575
import functools
-import importlib.util
import json
import os
from typing import Any, Callable, Optional, Union
@@ -19,10 +18,9 @@
CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
-from vllm.utils import cdiv, direct_register_custom_op
+from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
logger = init_logger(__name__)
-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
@@ -109,7 +107,7 @@ def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
"""
return (current_platform.is_cuda()
- and current_platform.is_device_capability(90) and has_deep_gemm
+ and current_platform.is_device_capability(90) and has_deep_gemm()
and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16
and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
index d5ce6d7ad757..fb3287d3b89e 100644
--- a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
@@ -2,9 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
+from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
+from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
-__all__ = ["break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant"]
+__all__ = [
+ "break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant",
+ "cutlass_fp4_supported"
+]
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
@@ -12,6 +17,14 @@
dtype=torch.float32)
+def cutlass_fp4_supported() -> bool:
+ if not current_platform.is_cuda():
+ return False
+ capability_tuple = current_platform.get_device_capability()
+ capability = -1 if capability_tuple is None else capability_tuple.to_int()
+ return cutlass_scaled_mm_supports_fp4(capability)
+
+
def break_fp4_bytes(a, dtype):
assert a.dtype == torch.uint8
m, n = a.shape
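
The new cutlass_fp4_supported() helper gates the CUTLASS FP4 path on the CUDA device capability reported by the platform. A hedged usage sketch; the pick_fp4_backend() wrapper below is hypothetical and only illustrates how a caller might branch on the probe:

# Illustrative only: branch on the capability probe defined above.
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
    cutlass_fp4_supported)

def pick_fp4_backend() -> str:
    # Hypothetical helper: take the hardware path when the CUTLASS
    # scaled-mm FP4 kernel is available, otherwise fall back to emulation.
    return "cutlass" if cutlass_fp4_supported() else "emulation"

print(pick_fp4_backend())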
diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py
index 09857ef297f0..0c46d170e88d 100644
--- a/vllm/model_executor/model_loader/bitsandbytes_loader.py
+++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -20,8 +20,6 @@
get_tensor_model_parallel_world_size)
# yapf: enable
from vllm.logger import init_logger
-# yapf conflicts with isort for this block
-# yapf: disable
from vllm.model_executor.layers.linear import (LinearBase,
MergedColumnParallelLinear,
QKVParallelLinear,
@@ -39,6 +37,8 @@
set_weight_attrs)
from vllm.platforms import current_platform
+# yapf conflicts with isort for this block
+
logger = init_logger(__name__)
@@ -54,11 +54,17 @@ def __init__(self, load_config: LoadConfig):
self.unsharded_weights_modules: list[str] = []
# Save the module names that are sharded by column.
self.column_sharded_weights_modules: list[str] = []
+ # Modules whose weights might be fused on disk;
+ # we need their output_sizes to shard them in flight correctly with TP.
+ self.maybe_fused_weights_modules: dict[str, list[int]] = {}
# Store all module names (from transformers) that support
# BNB quantization.
self.target_modules: list[str] = []
# mapping weight names from transformers to vllm.
self.weight_mapper: Callable = lambda name: name
+ self.pre_quant: bool = False
+ self.load_8bit: bool = False
+ self.is_pool_model: bool = False
def _get_weight_files(
self,
@@ -134,13 +140,14 @@ def _prepare_weights(self, model_name_or_path: str,
return hf_weights_files, use_safetensors
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
- def _maybe_pool_model(module_name:str):
+
+ def _maybe_pool_model(module_name: str):
# For pool model, we need to add the prefix `model.`
# for the weight name if possible.
if self.is_pool_model and self.target_modules[0]. \
startswith("model.") and not module_name.startswith(
"model."):
- return "model."+module_name
+ return "model." + module_name
return module_name
@@ -159,8 +166,7 @@ def _maybe_pool_model(module_name:str):
# mapping weight names from transformers to vllm while preserving
# original names.
mapped_name = self.weight_mapper(org_name)
- mapped_name=_maybe_pool_model(mapped_name)
-
+ mapped_name = _maybe_pool_model(mapped_name)
yield org_name, mapped_name, param
@@ -168,8 +174,6 @@ def _get_quantized_weights_iterator(
self,
model_name_or_path: str,
revision: Optional[str],
- pre_quant: bool,
- load_8bit: bool,
) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str,
Any]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
@@ -192,8 +196,8 @@ def _get_quantized_weights_iterator(
quant_state_dict: dict[str, Any] = {}
- if pre_quant:
- if load_8bit:
+ if self.pre_quant:
+ if self.load_8bit:
return self._quantized_8bit_generator(
hf_weights_files, use_safetensors,
quant_state_dict), quant_state_dict
@@ -390,10 +394,13 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
yield org_weight_name, processed_weight
def _get_bnb_target_modules(self, model: nn.Module) -> None:
-
+ """
+ Identify and collect all modules that support BitsAndBytes
+ quantization.
+ """
for name, module in model.named_modules():
- if (isinstance(module, LinearBase) and
- hasattr(module.quant_method, "quant_config")):
+ if (isinstance(module, LinearBase)
+ and hasattr(module.quant_method, "quant_config")):
if modules_info := self.modules_mapping.get_sub_modules(name):
# Map vllm's names to transformers's names.
rep_name, sub_modules = modules_info
@@ -409,29 +416,11 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None:
), "vllm currently does not support BNB quantization for"
f" {type(model).__name__}"
- def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
- if not hasattr(model, "load_weights"):
- raise AttributeError(
- "The required method 'load_weights' is not defined in class"
- f" {type(model).__name__}.")
-
- if not hasattr(model, "packed_modules_mapping"):
- raise AttributeError(
- f"Model {type(model).__name__} does not support BitsAndBytes "
- "quantization yet. No 'packed_modules_mapping' found.")
- self.is_pool_model=is_pooling_model(model)
-
- self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
-
- # For some models like Molmo, we need to use hf_to_vllm_mapper
- # to ensure correct loading of weights.
- if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
- self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
-
- # Modules whose weights might have fused on disk
- # we need their output_sizes to make shard in flight correctly with TP
- self.maybe_fused_weights_modules: dict[str, list[int]] = {}
- self._get_bnb_target_modules(model)
+ def _classify_module_sharding(self, model: nn.Module):
+ """
+ Categorize modules based on their weight sharding requirements
+ for tensor parallelism.
+ """
for name, module in model.named_modules():
# Some modules like `ReplicatedLinear` should not have their weights
# sharded. The reason for implementing it this way is to avoid new
@@ -449,19 +438,27 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
elif isinstance(module, (RowParallelLinear, )):
self.column_sharded_weights_modules.append(name)
- self.model_type = type(model).__name__
+ def _verify_model_compatibility(self, model: nn.Module,
+ model_config: ModelConfig) -> None:
+ """
+ Verify that the model is compatible with BitsAndBytes quantization.
+ """
+ if not hasattr(model, "load_weights"):
+ raise AttributeError(
+ "The required method 'load_weights' is not defined in class"
+ f" {type(model).__name__}.")
- logger.info("Loading weights with BitsAndBytes quantization. "
- "May take a while ...")
+ if not hasattr(model, "packed_modules_mapping"):
+ raise AttributeError(
+ f"Model {type(model).__name__} does not support BitsAndBytes "
+ "quantization yet. No 'packed_modules_mapping' found.")
quant_config = getattr(model_config.hf_config, "quantization_config",
None)
-
- pre_quant = False
if quant_config is not None:
quant_method = quant_config.get("quant_method")
if quant_method == "bitsandbytes":
- pre_quant = True
+ self.pre_quant = True
else:
raise ValueError(
f"BitsAndBytes loader does not support {quant_method} "
@@ -469,20 +466,43 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
# The quant_states in pre_quantized models cannot work with a split
# weight tensor. So TP does not work with pre_quantized bnb models.
- if pre_quant and get_tensor_model_parallel_world_size() > 1:
+ if self.pre_quant and get_tensor_model_parallel_world_size() > 1:
raise ValueError(
"Prequant BitsAndBytes models with tensor parallelism is not "
"supported. Please try with pipeline parallelism.")
+ if self.pre_quant:
+ self.load_8bit = quant_config.get("load_in_8bit", False)
+
+ def _initialize_loader_state(self, model: nn.Module,
+ model_config: ModelConfig) -> None:
+ """
+ Initialize the loader's internal state based on the model and
+ configuration.
+ """
+ self.is_pool_model = is_pooling_model(model)
+ self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
- load_8bit = False
- if pre_quant:
- load_8bit = quant_config.get("load_in_8bit", False)
+ # For some models like Molmo, we need to use hf_to_vllm_mapper
+ # to ensure correct loading of weights.
+ if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
+ self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
- qweight_iterator, quant_state_dict = (
- self._get_quantized_weights_iterator(model_config.model,
- model_config.revision,
- pre_quant, load_8bit))
+ self._get_bnb_target_modules(model)
+ self._classify_module_sharding(model)
+ def load_weights(self, model: nn.Module,
+ model_config: ModelConfig) -> None:
+
+ self._verify_model_compatibility(model, model_config)
+ self._initialize_loader_state(model, model_config)
+
+ logger.info("Loading weights with BitsAndBytes quantization. "
+ "May take a while ...")
+ qweight_iterator, quant_state_dict = (
+ self._get_quantized_weights_iterator(
+ model_config.model,
+ model_config.revision,
+ ))
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(qweight_iterator)
# Some models may have weights loading tracker unimplemented.
@@ -562,10 +582,11 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
offsets = torch.tensor(offsets).cpu()
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
- if load_8bit:
+ if self.load_8bit:
set_weight_attrs(
param, {"matmul_state": [None] * len(quant_states)})
torch.cuda.empty_cache()
+
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)
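
The bitsandbytes_loader refactor above splits the old monolithic load_weights into _verify_model_compatibility, _initialize_loader_state, and _classify_module_sharding, and moves pre_quant / load_8bit / is_pool_model into instance state so _get_quantized_weights_iterator no longer takes them as arguments. A structure-only sketch of the resulting control flow, with stub bodies standing in for the real helpers:

# Structure-only sketch of the refactored load path; the stubs only make
# the orchestration order visible, they do not reproduce the real logic.
from types import SimpleNamespace

class BnbLoaderSketch:
    def _verify_model_compatibility(self, model, model_config):
        # Real code: check load_weights / packed_modules_mapping, read the
        # quantization_config, and set self.pre_quant / self.load_8bit.
        self.pre_quant, self.load_8bit = False, False

    def _initialize_loader_state(self, model, model_config):
        # Real code: pooling detection, param mapping, target modules, then
        # _get_bnb_target_modules() and _classify_module_sharding().
        self.is_pool_model = False

    def _get_quantized_weights_iterator(self, path, revision):
        # Real code branches on self.pre_quant / self.load_8bit.
        return iter(()), {}

    def load_weights(self, model, model_config):
        self._verify_model_compatibility(model, model_config)
        self._initialize_loader_state(model, model_config)
        iterator, quant_state_dict = self._get_quantized_weights_iterator(
            model_config.model, model_config.revision)
        return iterator, quant_state_dict

loader = BnbLoaderSketch()
loader.load_weights(model=None,
                    model_config=SimpleNamespace(model="stub", revision=None))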
diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py
index 0f22393c79d9..0b7350f07d3f 100644
--- a/vllm/model_executor/models/bert_with_rope.py
+++ b/vllm/model_executor/models/bert_with_rope.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
-from copy import deepcopy
from typing import Optional
import torch
@@ -12,7 +11,6 @@
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.logger import init_logger
from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
get_act_fn)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -30,8 +28,6 @@
from vllm.model_executor.models.utils import WeightsMapper
from vllm.sequence import IntermediateTensors
-logger = init_logger(__name__)
-
class BertWithRopeEmbedding(nn.Module):
@@ -408,7 +404,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.vllm_config = vllm_config
- self.config = self.config_verify(vllm_config)
+ self.config = vllm_config.model_config.hf_config
self.embeddings = BertWithRopeEmbedding(self.config)
self.encoder = BertWithRopeEncoder(
vllm_config=vllm_config,
@@ -416,9 +412,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
rotary_kwargs=self.config.rotary_kwargs,
prefix=f"{prefix}.encoder")
- def config_verify(self, vllm_config):
- raise NotImplementedError
-
def forward(
self,
input_ids: Optional[torch.Tensor],
@@ -490,95 +483,6 @@ class NomicBertModel(BertWithRope):
"norm2": "mlp_ln",
})
- def config_verify(self, vllm_config):
- config = vllm_config.model_config.hf_config
-
- assert config.__class__.__name__ == "NomicBertConfig"
- assert config.activation_function in ["swiglu", "gelu"]
- config.position_embedding_type = getattr(config,
- "position_embedding_type",
- "rope")
-
- if config.activation_function == "swiglu":
- config.hidden_act = "silu"
- else:
- config.hidden_act = config.activation_function
-
- assert (config.mlp_fc1_bias == config.mlp_fc2_bias ==
- config.qkv_proj_bias)
- config.bias = config.qkv_proj_bias
-
- assert config.rotary_emb_scale_base is None
- assert not config.rotary_emb_interleaved
-
- config.layer_norm_eps = config.layer_norm_epsilon
- config.intermediate_size = config.n_inner
- config.hidden_size = config.n_embd
- config.num_hidden_layers = config.n_layer
-
- head_dim = config.hidden_size // config.num_attention_heads
- rotary_emb_dim = head_dim * config.rotary_emb_fraction
- max_trained_positions = getattr(config, "max_trained_positions", 2048)
- config.rotary_kwargs = {
- "head_size": head_dim,
- "rotary_dim": rotary_emb_dim,
- "max_position": max_trained_positions,
- "base": getattr(config, "rope_theta", config.rotary_emb_base),
- "rope_scaling": getattr(config, "rope_scaling", None)
- }
-
- # we ignore config.rotary_scaling_factor so that for datasets shorter
- # than max_trained_positions 2048, the results are consistent
- # with SentenceTransformer.
- # The context extension uses vllm style rope_theta and rope_scaling.
- # See #17785 #18755
- if (not vllm_config.model_config.hf_overrides
- and vllm_config.model_config.original_max_model_len is None):
- # Default
- # Reset max_model_len to max_trained_positions.
- # nomic-embed-text-v2-moe the length is set to 512
- # by sentence_bert_config.json.
- max_model_len_before = vllm_config.model_config.max_model_len
- max_model_len = min(vllm_config.model_config.max_model_len,
- max_trained_positions)
-
- vllm_config.recalculate_max_model_len(max_model_len)
- logger.warning(
- "Nomic context extension is disabled. "
- "Changing max_model_len from %s to %s. "
- "To enable context extension, see: "
- "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
- max_model_len_before, vllm_config.model_config.max_model_len)
- else:
- # We need to re-verify max_model_len to avoid lengths
- # greater than position_embedding.
- model_config = vllm_config.model_config
- hf_text_config = model_config.hf_text_config
-
- if isinstance(model_config.hf_overrides, dict):
- # hf_overrides_kw
- max_model_len = model_config.hf_overrides.get(
- "max_model_len", vllm_config.model_config.max_model_len)
- else:
- # hf_overrides_fn
- # This might be overridden by sentence_bert_config.json.
- max_model_len = vllm_config.model_config.max_model_len
-
- # reset hf_text_config for recalculate_max_model_len.
- if hasattr(hf_text_config, "max_model_len"):
- delattr(hf_text_config, "max_model_len")
- hf_text_config.max_position_embeddings = max_trained_positions
- hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"]
-
- # The priority of sentence_bert_config.json is higher
- # than max_position_embeddings
- encoder_config = deepcopy(model_config.encoder_config)
- encoder_config.pop("max_seq_length", None)
- model_config.encoder_config = encoder_config
-
- vllm_config.recalculate_max_model_len(max_model_len)
- return config
-
class GteNewModel(BertWithRope):
# for https://huggingface.co/Alibaba-NLP/new-impl
@@ -600,24 +504,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
layer.mlp.gate_up_proj.bias = None
layer.mlp.gate_up_proj.skip_bias_add = True
- def config_verify(self, vllm_config):
- config = vllm_config.model_config.hf_config
-
- assert config.__class__.__name__ == "NewConfig"
- assert config.hidden_act == "gelu"
-
- config.hidden_act = "geglu"
-
- head_dim = config.hidden_size // config.num_attention_heads
- config.rotary_kwargs = {
- "head_size": head_dim,
- "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
- "max_position": config.max_position_embeddings,
- "base": config.rope_theta,
- "rope_scaling": getattr(config, "rope_scaling", None)
- }
- return config
-
def split_up_gate_proj(self, weights: Iterable[tuple[str, torch.Tensor]]):
n = "mlp.up_gate_proj"
for name, weight in weights:
@@ -652,24 +538,6 @@ class SnowflakeGteNewModel(GteNewModel):
"attention.o_proj": "attn.out_proj",
})
- def config_verify(self, vllm_config):
- config = vllm_config.model_config.hf_config
-
- assert config.__class__.__name__ == "GteConfig"
- assert config.hidden_act == "gelu"
-
- config.hidden_act = "geglu"
-
- head_dim = config.hidden_size // config.num_attention_heads
- config.rotary_kwargs = {
- "head_size": head_dim,
- "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
- "max_position": config.max_position_embeddings,
- "base": config.rope_theta,
- "rope_scaling": getattr(config, "rope_scaling", None)
- }
- return config
-
class JinaRobertaModel(BertWithRope):
# for https://huggingface.co/jinaai/jina-embeddings-v3
@@ -685,21 +553,6 @@ class JinaRobertaModel(BertWithRope):
"norm2": "mlp_ln",
})
- def config_verify(self, vllm_config):
- config = vllm_config.model_config.hf_config
-
- assert config.__class__.__name__ == "XLMRobertaFlashConfig"
-
- head_dim = config.hidden_size // config.num_attention_heads
- config.rotary_kwargs = {
- "head_size": head_dim,
- "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
- "max_position": config.max_position_embeddings,
- "base": getattr(config, "rope_theta", config.rotary_emb_base),
- "rope_scaling": getattr(config, "rope_scaling", None)
- }
- return config
-
def forward(
self,
input_ids: torch.Tensor,
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
new file mode 100644
index 000000000000..7b5345704ad0
--- /dev/null
+++ b/vllm/model_executor/models/config.py
@@ -0,0 +1,200 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from copy import deepcopy
+from typing import TYPE_CHECKING
+
+from vllm.logger import init_logger
+
+if TYPE_CHECKING:
+ from vllm.config import VllmConfig
+
+logger = init_logger(__name__)
+
+
+class VerifyAndUpdateConfig:
+
+ @staticmethod
+ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+ raise NotImplementedError
+
+
+class GteNewModelConfig(VerifyAndUpdateConfig):
+
+ @staticmethod
+ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+ config = vllm_config.model_config.hf_config
+
+ assert config.__class__.__name__ == "NewConfig"
+ assert config.hidden_act == "gelu"
+
+ config.hidden_act = "geglu"
+
+ head_dim = config.hidden_size // config.num_attention_heads
+ config.rotary_kwargs = {
+ "head_size": head_dim,
+ "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
+ "max_position": config.max_position_embeddings,
+ "base": config.rope_theta,
+ "rope_scaling": getattr(config, "rope_scaling", None)
+ }
+
+
+class JinaRobertaModelConfig(VerifyAndUpdateConfig):
+
+ @staticmethod
+ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+ config = vllm_config.model_config.hf_config
+
+ if config.position_embedding_type == "rotary":
+ assert config.__class__.__name__ == "XLMRobertaFlashConfig"
+
+ head_dim = config.hidden_size // config.num_attention_heads
+ config.rotary_kwargs = {
+ "head_size": head_dim,
+ "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
+ "max_position": config.max_position_embeddings,
+ "base": getattr(config, "rope_theta", config.rotary_emb_base),
+ "rope_scaling": getattr(config, "rope_scaling", None)
+ }
+
+
+class NomicBertModelConfig(VerifyAndUpdateConfig):
+
+ @staticmethod
+ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+ config = vllm_config.model_config.hf_config
+
+ assert config.__class__.__name__ == "NomicBertConfig"
+ assert config.activation_function in ["swiglu", "gelu"]
+ config.position_embedding_type = getattr(config,
+ "position_embedding_type",
+ "rope")
+
+ if config.activation_function == "swiglu":
+ config.hidden_act = "silu"
+ else:
+ config.hidden_act = config.activation_function
+
+ assert (config.mlp_fc1_bias == config.mlp_fc2_bias ==
+ config.qkv_proj_bias)
+ config.bias = config.qkv_proj_bias
+
+ assert config.rotary_emb_scale_base is None
+ assert not config.rotary_emb_interleaved
+
+ config.layer_norm_eps = config.layer_norm_epsilon
+ config.intermediate_size = config.n_inner
+ config.hidden_size = config.n_embd
+ config.num_hidden_layers = config.n_layer
+
+ head_dim = config.hidden_size // config.num_attention_heads
+ rotary_emb_dim = head_dim * config.rotary_emb_fraction
+ max_trained_positions = getattr(config, "max_trained_positions", 2048)
+ config.rotary_kwargs = {
+ "head_size": head_dim,
+ "rotary_dim": rotary_emb_dim,
+ "max_position": max_trained_positions,
+ "base": getattr(config, "rope_theta", config.rotary_emb_base),
+ "rope_scaling": getattr(config, "rope_scaling", None)
+ }
+
+ # we ignore config.rotary_scaling_factor so that for datasets shorter
+ # than max_trained_positions 2048, the results are consistent
+ # with SentenceTransformer.
+ # The context extension uses vllm style rope_theta and rope_scaling.
+ # See #17785 #18755
+ if (not vllm_config.model_config.hf_overrides
+ and vllm_config.model_config.original_max_model_len is None):
+ # Default
+ # Reset max_model_len to max_trained_positions.
+ # For nomic-embed-text-v2-moe, the length is set to 512
+ # by sentence_bert_config.json.
+ max_model_len_before = vllm_config.model_config.max_model_len
+ max_model_len = min(vllm_config.model_config.max_model_len,
+ max_trained_positions)
+
+ vllm_config.recalculate_max_model_len(max_model_len)
+ logger.warning(
+ "Nomic context extension is disabled. "
+ "Changing max_model_len from %s to %s. "
+ "To enable context extension, see: "
+ "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
+ max_model_len_before, vllm_config.model_config.max_model_len)
+ else:
+ # We need to re-verify max_model_len to avoid lengths
+ # greater than position_embedding.
+ model_config = vllm_config.model_config
+ hf_text_config = model_config.hf_text_config
+
+ if isinstance(model_config.hf_overrides, dict):
+ # hf_overrides_kw
+ max_model_len = model_config.hf_overrides.get(
+ "max_model_len", vllm_config.model_config.max_model_len)
+ else:
+ # hf_overrides_fn
+ # This might be overridden by sentence_bert_config.json.
+ max_model_len = vllm_config.model_config.max_model_len
+
+ # reset hf_text_config for recalculate_max_model_len.
+ if hasattr(hf_text_config, "max_model_len"):
+ delattr(hf_text_config, "max_model_len")
+ hf_text_config.max_position_embeddings = max_trained_positions
+ hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"]
+
+ # The priority of sentence_bert_config.json is higher
+ # than max_position_embeddings
+ encoder_config = deepcopy(model_config.encoder_config)
+ encoder_config.pop("max_seq_length", None)
+ model_config.encoder_config = encoder_config
+
+ vllm_config.recalculate_max_model_len(max_model_len)
+
+
+class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
+
+ @staticmethod
+ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+ config = vllm_config.model_config.hf_config
+
+ is_original_qwen3_reranker = getattr(config,
+ "is_original_qwen3_reranker",
+ False)
+
+ if not is_original_qwen3_reranker:
+ return
+
+ tokens = getattr(config, "classifier_from_token", None)
+ assert tokens is not None and len(tokens) == 2, \
+ ("Try loading the original Qwen3 Reranker?, see: "
+ "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
+ config.num_labels = 1
+
+
+class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
+
+ @staticmethod
+ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+ config = vllm_config.model_config.hf_config
+
+ assert config.__class__.__name__ == "GteConfig"
+ assert config.hidden_act == "gelu"
+
+ config.hidden_act = "geglu"
+
+ head_dim = config.hidden_size // config.num_attention_heads
+ config.rotary_kwargs = {
+ "head_size": head_dim,
+ "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
+ "max_position": config.max_position_embeddings,
+ "base": config.rope_theta,
+ "rope_scaling": getattr(config, "rope_scaling", None)
+ }
+
+
+MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
+ "GteModel": SnowflakeGteNewModelConfig,
+ "GteNewModel": GteNewModelConfig,
+ "NomicBertModel": NomicBertModelConfig,
+ "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
+ "XLMRobertaModel": JinaRobertaModelConfig,
+}
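
The new MODELS_CONFIG_MAP replaces the per-class config_verify overrides deleted from bert_with_rope.py: each architecture name maps to a VerifyAndUpdateConfig subclass whose static verify_and_update_config mutates the HF config in place. A sketch of how a caller can dispatch on the registry; the maybe_update_config wrapper is illustrative, and the actual hook site is not shown in this diff:

# Illustrative dispatch over the registry defined above.
from vllm.model_executor.models.config import MODELS_CONFIG_MAP

def maybe_update_config(architecture: str, vllm_config) -> None:
    cls = MODELS_CONFIG_MAP.get(architecture)
    if cls is not None:
        # Static method: mutates vllm_config.model_config.hf_config in place,
        # e.g. filling rotary_kwargs for NomicBertModel or GteNewModel.
        cls.verify_and_update_config(vllm_config)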
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 0f996d04e6e8..2fa1294b79b9 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -23,7 +23,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeepseekV2/DeepseekV3 model."""
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable
from typing import Any, Optional, Union
import torch
@@ -32,8 +33,10 @@
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, ModelConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
+ get_current_vllm_config)
+from vllm.distributed import (get_ep_group, get_pp_group,
+ get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -51,7 +54,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
-from .interfaces import SupportsPP
+from .interfaces import MixtureOfExperts, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@@ -99,11 +102,17 @@ def __init__(
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
+ enable_eplb: bool = False,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
- self.n_shared_experts = config.n_shared_experts
+
+ self.ep_group = get_ep_group().device_group
+ self.ep_rank = self.ep_group.rank()
+ self.ep_size = self.ep_group.size()
+ self.n_routed_experts: int = config.n_routed_experts
+ self.n_shared_experts: int = config.n_shared_experts
if config.hidden_act != "silu":
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
@@ -120,6 +129,22 @@ def __init__(
else:
self.gate.e_score_correction_bias = None
+ # Load balancing settings.
+ vllm_config = get_current_vllm_config()
+ parallel_config = vllm_config.parallel_config
+ self.enable_eplb = enable_eplb
+
+ self.n_redundant_experts = parallel_config.num_redundant_experts
+ self.n_logical_experts = self.n_routed_experts
+ self.n_physical_experts = (self.n_logical_experts +
+ self.n_redundant_experts)
+ self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
+ self.physical_expert_start = (self.ep_rank *
+ self.n_local_physical_experts)
+ self.physical_expert_end = (self.physical_expert_start +
+ self.n_local_physical_experts)
+
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
@@ -133,7 +158,9 @@ def __init__(
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
- e_score_correction_bias=self.gate.e_score_correction_bias)
+ e_score_correction_bias=self.gate.e_score_correction_bias,
+ enable_eplb=self.enable_eplb,
+ num_redundant_experts=self.n_redundant_experts)
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
@@ -503,6 +530,7 @@ def __init__(
model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
+ enable_eplb: bool = False,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -543,6 +571,7 @@ def __init__(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
+ enable_eplb=enable_eplb,
)
else:
self.mlp = DeepseekV2MLP(
@@ -615,6 +644,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
+ enable_eplb = vllm_config.parallel_config.enable_eplb
self.config = config
self.vocab_size = config.vocab_size
@@ -636,6 +666,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
+ enable_eplb=enable_eplb,
),
prefix=f"{prefix}.layers")
@@ -681,7 +712,7 @@ def forward(
return hidden_states
-class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
+class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -700,6 +731,44 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
+ self.expert_weights = []
+
+ # Set MoE hyperparameters
+ self.num_moe_layers = (config.num_hidden_layers -
+ config.first_k_dense_replace)
+ self.num_expert_groups = config.n_group
+
+ self.moe_layers: list[FusedMoE] = []
+ for layer in self.model.layers:
+ assert isinstance(layer, DeepseekV2DecoderLayer)
+ if isinstance(layer.mlp, DeepseekV2MoE):
+ self.moe_layers.append(layer.mlp.experts)
+
+ # Pick the last layer, since the first layers may be dense.
+ example_moe = typing.cast(
+ DeepseekV2MoE, self.model.layers[config.num_hidden_layers - 1].mlp)
+ self.num_logical_experts = example_moe.n_logical_experts
+ self.num_physical_experts = example_moe.n_physical_experts
+ self.num_local_physical_experts = example_moe.n_local_physical_experts
+ self.num_routed_experts = example_moe.n_routed_experts
+ self.num_shared_experts = example_moe.n_shared_experts
+ self.num_redundant_experts = example_moe.n_redundant_experts
+
+ def set_eplb_state(
+ self,
+ expert_load_view: torch.Tensor,
+ logical_to_physical_map: torch.Tensor,
+ logical_replica_count: torch.Tensor,
+ ) -> None:
+ for layer_idx, layer in enumerate(self.moe_layers):
+ # Register the expert weights.
+ self.expert_weights.append(layer.get_expert_weights())
+ layer.set_eplb_state(
+ moe_layer_idx=layer_idx,
+ expert_load_view=expert_load_view,
+ logical_to_physical_map=logical_to_physical_map,
+ logical_replica_count=logical_replica_count,
+ )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
@@ -752,7 +821,8 @@ def load_weights(self, weights: Iterable[tuple[str,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
- num_experts=self.config.n_routed_experts)
+ num_experts=self.config.n_routed_experts,
+ num_redundant_experts=self.num_redundant_experts)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
@@ -789,24 +859,45 @@ def load_weights(self, weights: Iterable[tuple[str,
weight_loader(param, loaded_weight, shard_id)
break
else:
+ is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
- name = name.replace(weight_name, param_name)
- if is_pp_missing_parameter(name, self):
+ # This is an expert weight; it should not be loaded
+ # as a regular (non-expert) weight later.
+ is_expert_weight = True
+
+ # Do not modify `name` since the loop may continue here
+ # Instead, create a new variable
+ name_mapped = name.replace(weight_name, param_name)
+
+ if is_pp_missing_parameter(name_mapped, self):
continue
- param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(param,
- loaded_weight,
- name,
- shard_id=shard_id,
- expert_id=expert_id)
- break
+ param = params_dict[name_mapped]
+ # Ask the weight loader to report whether loading succeeded;
+ # otherwise we might skip experts whose replicas live on
+ # other ranks.
+ weight_loader = typing.cast(Callable[..., bool],
+ param.weight_loader)
+ success = weight_loader(param,
+ loaded_weight,
+ name_mapped,
+ shard_id=shard_id,
+ expert_id=expert_id,
+ return_success=True)
+ if success:
+ name = name_mapped
+ break
else:
+ if is_expert_weight:
+ # We've checked that this is an expert weight
+ # However it's not mapped locally to this rank
+ # So we simply skip it
+ continue
+
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
@@ -824,6 +915,7 @@ def load_weights(self, weights: Iterable[tuple[str,
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
+
return loaded_params
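
The EPLB fields added to DeepseekV2MoE derive per-rank expert bookkeeping from three inputs: n_physical = n_logical + n_redundant, n_local = n_physical // ep_size, and each EP rank owns the half-open slice [ep_rank * n_local, (ep_rank + 1) * n_local). A small worked example under assumed sizes (256 logical experts, 32 redundant, EP size 8; the numbers are illustrative only):

# Worked example of the expert-partition arithmetic used in DeepseekV2MoE.
n_logical, n_redundant, ep_size = 256, 32, 8   # assumed values, not from the diff
n_physical = n_logical + n_redundant            # 288
n_local = n_physical // ep_size                 # 36 physical experts per rank
for ep_rank in range(ep_size):
    start = ep_rank * n_local
    end = start + n_local
    # rank 0 -> [0, 36), rank 1 -> [36, 72), ..., rank 7 -> [252, 288)
    assert end <= n_physical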
diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py
new file mode 100644
index 000000000000..01a27d02a304
--- /dev/null
+++ b/vllm/model_executor/models/dots1.py
@@ -0,0 +1,535 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# Copyright 2025 The rednote-hilab team.
+# Copyright 2023 The vLLM team.
+# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only dots1 model."""
+from collections.abc import Iterable
+from typing import Any, Optional, Union
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from vllm.attention import Attention
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
+from vllm.distributed import (get_pp_group,
+ get_tensor_model_parallel_world_size,
+ tensor_model_parallel_all_reduce)
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+ QKVParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from .interfaces import SupportsPP
+from .utils import (PPMissingLayer, is_pp_missing_parameter,
+ make_empty_intermediate_tensors_factory, make_layers,
+ maybe_prefix)
+
+
+class Dots1MLP(nn.Module):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str,
+ quant_config: Optional[QuantizationConfig] = None,
+ reduce_results: bool = True,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.gate_up_proj = MergedColumnParallelLinear(
+ hidden_size, [intermediate_size] * 2,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate_up_proj")
+ self.down_proj = RowParallelLinear(intermediate_size,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ reduce_results=reduce_results,
+ prefix=f"{prefix}.down_proj")
+ if hidden_act != "silu":
+ raise ValueError(f"Unsupported activation: {hidden_act}. "
+ "Only silu is supported for now.")
+ self.act_fn = SiluAndMul()
+
+ def forward(self, x):
+ gate_up, _ = self.gate_up_proj(x)
+ x = self.act_fn(gate_up)
+ x, _ = self.down_proj(x)
+ return x
+
+
+class Dots1MoE(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.routed_scaling_factor = config.routed_scaling_factor
+ self.n_shared_experts = config.n_shared_experts
+
+ if config.hidden_act != "silu":
+ raise ValueError(f"Unsupported activation: {config.hidden_act}. "
+ "Only silu is supported for now.")
+
+ self.gate = ReplicatedLinear(config.hidden_size,
+ config.n_routed_experts,
+ bias=False,
+ quant_config=None,
+ prefix=f"{prefix}.gate")
+ if config.topk_method == "noaux_tc":
+ self.gate.e_score_correction_bias = (nn.Parameter(
+ torch.empty(config.n_routed_experts)))
+ else:
+ self.gate.e_score_correction_bias = None
+
+ self.experts = FusedMoE(
+ num_experts=config.n_routed_experts,
+ top_k=config.num_experts_per_tok,
+ hidden_size=config.hidden_size,
+ intermediate_size=config.moe_intermediate_size,
+ reduce_results=False,
+ renormalize=config.norm_topk_prob,
+ quant_config=quant_config,
+ use_grouped_topk=True,
+ num_expert_group=config.n_group,
+ topk_group=config.topk_group,
+ prefix=f"{prefix}.experts",
+ scoring_func=config.scoring_func,
+ e_score_correction_bias=self.gate.e_score_correction_bias)
+
+ if config.n_shared_experts is not None:
+ intermediate_size = (config.moe_intermediate_size *
+ config.n_shared_experts)
+ self.shared_experts = Dots1MLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=intermediate_size,
+ hidden_act=config.hidden_act,
+ quant_config=quant_config,
+ reduce_results=False,
+ prefix=f"{prefix}.shared_experts",
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ num_tokens, hidden_dim = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_dim)
+ if self.n_shared_experts is not None:
+ shared_output = self.shared_experts(hidden_states)
+ router_logits, _ = self.gate(hidden_states)
+ final_hidden_states = self.experts(
+ hidden_states=hidden_states,
+ router_logits=router_logits) * self.routed_scaling_factor
+ if shared_output is not None:
+ final_hidden_states = final_hidden_states + shared_output
+ if self.tp_size > 1:
+ final_hidden_states = tensor_model_parallel_all_reduce(
+ final_hidden_states)
+ return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+class Dots1Attention(nn.Module):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ config: PretrainedConfig,
+ rope_theta: float = 10000,
+ rope_scaling: Optional[dict[str, Any]] = None,
+ max_position_embeddings: int = 8192,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.hidden_size = hidden_size
+ tp_size = get_tensor_model_parallel_world_size()
+ self.total_num_heads = num_heads
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+ self.total_num_kv_heads = num_kv_heads
+ if self.total_num_kv_heads >= tp_size:
+ # Number of KV heads is greater than TP size, so we partition
+ # the KV heads across multiple tensor parallel GPUs.
+ assert self.total_num_kv_heads % tp_size == 0
+ else:
+ # Number of KV heads is less than TP size, so we replicate
+ # the KV heads across multiple tensor parallel GPUs.
+ assert tp_size % self.total_num_kv_heads == 0
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+ self.head_dim = getattr(config, "head_dim",
+ hidden_size // self.total_num_heads)
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+ self.scaling = self.head_dim**-0.5
+ self.rope_theta = rope_theta
+ self.max_position_embeddings = max_position_embeddings
+ attention_bias = config.attention_bias
+
+ self.qkv_proj = QKVParallelLinear(
+ hidden_size,
+ self.head_dim,
+ self.total_num_heads,
+ self.total_num_kv_heads,
+ bias=attention_bias,
+ quant_config=quant_config,
+ )
+
+ self.o_proj = RowParallelLinear(
+ self.total_num_heads * self.head_dim,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ )
+
+ self.rotary_emb = get_rope(
+ self.head_dim,
+ rotary_dim=self.head_dim,
+ max_position=max_position_embeddings,
+ base=rope_theta,
+ rope_scaling=rope_scaling,
+ )
+ self.attn = Attention(
+ self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn",
+ )
+ self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+ def forward(self, positions: torch.Tensor,
+ hidden_states: torch.Tensor) -> torch.Tensor:
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+ q = self.q_norm(q.reshape(-1, self.num_heads,
+ self.head_dim)).reshape(q.shape)
+ k = self.k_norm(k.reshape(-1, self.num_kv_heads,
+ self.head_dim)).reshape(k.shape)
+ q, k = self.rotary_emb(positions, q, k)
+ attn_output = self.attn(q, k, v)
+ output, _ = self.o_proj(attn_output)
+ return output
+
+
+class Dots1DecoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ prefix: str,
+ model_config: ModelConfig,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> None:
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ rope_theta = getattr(config, "rope_theta", 10000)
+ rope_scaling = getattr(config, "rope_scaling", None)
+ max_position_embeddings = getattr(config, "max_position_embeddings",
+ 8192)
+ layer_idx = int(prefix.split(sep='.')[-1])
+ self.layer_idx = layer_idx
+
+ self.self_attn = Dots1Attention(
+ hidden_size=self.hidden_size,
+ num_heads=config.num_attention_heads,
+ num_kv_heads=config.num_key_value_heads,
+ config=config,
+ rope_theta=rope_theta,
+ rope_scaling=rope_scaling,
+ max_position_embeddings=max_position_embeddings,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.self_attn",
+ )
+ if (config.n_routed_experts is not None
+ and layer_idx >= config.first_k_dense_replace
+ and layer_idx % config.moe_layer_freq == 0):
+ self.mlp = Dots1MoE(config=config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp")
+ else:
+ self.mlp = Dots1MLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ )
+ self.input_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+ self.routed_scaling_factor = config.routed_scaling_factor
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ residual: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(
+ hidden_states, residual)
+ hidden_states = self.self_attn(positions=positions,
+ hidden_states=hidden_states)
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual)
+ hidden_states = self.mlp(hidden_states)
+ return hidden_states, residual
+
+
+class Dots1Model(nn.Module):
+
+ fall_back_to_pt_during_load = False
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+
+ config = vllm_config.model_config.hf_config
+ model_config = vllm_config.model_config
+ cache_config = vllm_config.cache_config
+ quant_config = vllm_config.quant_config
+ self.config = config
+
+ self.vocab_size = config.vocab_size
+
+ if get_pp_group().is_first_rank:
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.embed_tokens")
+ else:
+ self.embed_tokens = PPMissingLayer()
+
+ self.start_layer, self.end_layer, self.layers = make_layers(
+ config.num_hidden_layers,
+ lambda prefix: Dots1DecoderLayer(
+ config,
+ prefix,
+ model_config=model_config,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ ),
+ prefix=f"{prefix}.layers")
+
+ if get_pp_group().is_last_rank:
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ else:
+ self.norm = PPMissingLayer()
+ self.make_empty_intermediate_tensors = (
+ make_empty_intermediate_tensors_factory(
+ ["hidden_states", "residual"], config.hidden_size))
+
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.embed_tokens(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.get_input_embeddings(input_ids)
+ residual = None
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"]
+ residual = intermediate_tensors["residual"]
+ for layer in self.layers[self.start_layer:self.end_layer]:
+ hidden_states, residual = layer(
+ positions,
+ hidden_states,
+ residual,
+ )
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors({
+ "hidden_states": hidden_states,
+ "residual": residual
+ })
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+
+@support_torch_compile
+class Dots1ForCausalLM(nn.Module, SupportsPP):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ self.config = config
+ self.quant_config = quant_config
+ self.model = Dots1Model(vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "model"))
+ if get_pp_group().is_last_rank:
+ self.lm_head = ParallelLMHead(config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config)
+ else:
+ self.lm_head = PPMissingLayer()
+ self.logits_processor = LogitsProcessor(config.vocab_size)
+ self.make_empty_intermediate_tensors = (
+ self.model.make_empty_intermediate_tensors)
+
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.get_input_embeddings(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ hidden_states = self.model(
+ input_ids,
+ positions,
+ intermediate_tensors,
+ inputs_embeds,
+ )
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ logits = self.logits_processor(self.lm_head, hidden_states,
+ sampling_metadata)
+ return logits
+
+ def make_empty_intermediate_tensors(
+ self, batch_size: int, dtype: torch.dtype,
+ device: torch.device) -> IntermediateTensors:
+ return IntermediateTensors({
+ "hidden_states":
+ torch.zeros((batch_size, self.config.hidden_size),
+ dtype=dtype,
+ device=device),
+ "residual":
+ torch.zeros((batch_size, self.config.hidden_size),
+ dtype=dtype,
+ device=device),
+ })
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=self.config.n_routed_experts)
+
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+ for name, loaded_weight in weights:
+ if "rotary_emb.inv_freq" in name:
+ continue
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ if (("mlp.experts." in name) and name not in params_dict):
+ continue
+ name = name.replace(weight_name, param_name)
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ if is_pp_missing_parameter(name, self):
+ continue
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param,
+ loaded_weight,
+ name,
+ shard_id=shard_id,
+ expert_id=expert_id)
+ break
+ else:
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
+ if is_pp_missing_parameter(name, self):
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
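
Dots1MoE builds both the routed experts and the shared-expert MLP with reduce_results=False, scales the routed output by routed_scaling_factor, adds the shared output, and only then performs a single tensor-parallel all-reduce. A toy sketch of that combine step with plain tensors; the sizes and scaling value are assumptions for illustration:

# Toy illustration of the routed + shared combine in Dots1MoE.forward.
import torch

num_tokens, hidden = 4, 8                      # assumed toy sizes
routed_scaling_factor = 2.5                    # assumed; read from config in the model
routed_out = torch.randn(num_tokens, hidden)   # stands in for self.experts(...)
shared_out = torch.randn(num_tokens, hidden)   # stands in for self.shared_experts(...)

final = routed_out * routed_scaling_factor + shared_out
# With TP > 1 the model then calls tensor_model_parallel_all_reduce(final)
# exactly once, instead of reducing inside each sub-module.
print(final.shape)  # torch.Size([4, 8])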
diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py
new file mode 100644
index 000000000000..7d163320e0d6
--- /dev/null
+++ b/vllm/model_executor/models/gemma3n.py
@@ -0,0 +1,811 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright 2025 The vLLM team.
+# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from transformers.models.gemma3n.configuration_gemma3n import Gemma3nTextConfig
+
+from vllm.attention import Attention
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
+ GeluAndMul,
+ GeluAndMulSparse)
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from .utils import (AutoWeightsLoader, extract_layer_index,
+ is_pp_missing_parameter, make_layers, maybe_prefix)
+
+logger = init_logger(__name__)
+
+
+class Gemma3nAltUp(nn.Module):
+ """Alternating updates (Altup)
+ The AltUp module wraps transformer layers. The `predict` step modifies the
+ input to the transformer layer, and the `correct` step propagates the output
+ of the transformer layer to the sparsely updated dimensions.
+ See more in the research paper:
+ https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ rms_norm_eps: float,
+ altup_num_inputs: int,
+ altup_coef_clip: float,
+ altup_active_idx: int,
+ prefix: str,
+ ):
+ super().__init__()
+
+ self.altup_num_inputs = altup_num_inputs
+ self.altup_active_idx = altup_active_idx
+ self.altup_coef_clip = altup_coef_clip
+
+ self.correction_coefs = ReplicatedLinear(
+ altup_num_inputs,
+ altup_num_inputs,
+ bias=False,
+ prefix=f"{prefix}.correction_coefs",
+ return_bias=False,
+ )
+ self.prediction_coefs = ReplicatedLinear(
+ altup_num_inputs,
+ altup_num_inputs**2,
+ bias=False,
+ prefix=f"{prefix}.prediction_coefs",
+ return_bias=False,
+ )
+ self.modality_router = ReplicatedLinear(
+ hidden_size,
+ altup_num_inputs,
+ bias=False,
+ prefix=f"{prefix}.modality_router",
+ return_bias=False,
+ )
+ self.router_norm = RMSNorm(
+ hidden_size=hidden_size,
+ eps=rms_norm_eps,
+ )
+ self.router_input_scale = torch.tensor(
+ hidden_size**-1.0, dtype=self.modality_router.weight.dtype)
+ self.correct_output_scale = nn.Parameter(
+ torch.zeros(hidden_size, dtype=torch.float32))
+
+ def _compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
+ router_inputs = self.router_norm(x) * self.router_input_scale
+ routed = self.modality_router(router_inputs)
+ return torch.tanh(routed.float()).type_as(x)
+
+ def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
+ return (corrected.type_as(self.correct_output_scale) *
+ self.correct_output_scale).type_as(corrected)
+
+ def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ # hidden: [altup_num_inputs, num_tokens, hidden_size]
+ # modalities: [num_tokens, num_altup_inputs]
+ # all_coefs: [num_tokens, num_altup_inputs ** 2]
+ modalities = self._compute_router_modalities(
+ hidden_states[self.altup_active_idx])
+ all_coefs = self.prediction_coefs(modalities)
+
+ # Reshape and transpose the 2D matrix for the matmul.
+ # all_coefs_T: [num_tokens, num_altup_inputs, num_altup_inputs]
+ all_coefs_T = all_coefs.reshape(
+ -1,
+ self.altup_num_inputs,
+ self.altup_num_inputs,
+ ).permute(0, 2, 1)
+
+ # hidden_states to [num_tokens, hidden_size, altup_num_inputs]
+ predictions = torch.matmul(hidden_states.permute(1, 2, 0), all_coefs_T)
+ # [altup_num_inputs, num_tokens, hidden_size]
+ predictions = predictions.permute(2, 0, 1)
+ predictions += hidden_states
+ return predictions.contiguous()
+
+ def correct(self, predictions: torch.Tensor,
+ activated: torch.Tensor) -> torch.Tensor:
+ # predictions: [altup_num_inputs, num_tokens, hidden_size]
+ # activated: [num_tokens, hidden_size]
+ # modalities: [num_tokens, altup_num_inputs]
+ modalities = self._compute_router_modalities(activated)
+ # innovation: [num_tokens, hidden_size]
+ innovation = activated - predictions[self.altup_active_idx]
+ # innovation: [altup_num_inputs, num_tokens, hidden_size]
+ innovation = innovation.repeat(self.altup_num_inputs, 1, 1)
+
+ # Permute all_coefs to [altup_num_inputs, num_tokens]: each entry
+ # is a per-input scalar, then add a trailing singleton dim so it
+ # broadcasts over hidden_size.
+ # all_coefs: [num_tokens, altup_num_inputs]
+ all_coefs = self.correction_coefs(modalities) + 1.0
+ # all_coefs: [altup_num_inputs, num_tokens, 1]
+ all_coefs = all_coefs.T.unsqueeze(-1)
+
+ # Elementwise (broadcast over hidden_size).
+ corrected = torch.mul(innovation, all_coefs)
+ corrected += predictions
+
+ return corrected.contiguous()
+
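
The shape comments in Gemma3nAltUp.predict compress a three-step reshuffle: the coefficient matrix [num_tokens, n*n] is reshaped to [num_tokens, n, n] and transposed, the hidden stack [n, num_tokens, hidden] is permuted to [num_tokens, hidden, n] for the batched matmul, and the result is permuted back before the residual add. A standalone shape check under assumed toy sizes:

# Standalone shape walk-through of the AltUp predict() matmul (toy sizes only).
import torch

n_inputs, num_tokens, hidden = 4, 3, 8                     # assumed toy sizes
hidden_states = torch.randn(n_inputs, num_tokens, hidden)
all_coefs = torch.randn(num_tokens, n_inputs * n_inputs)   # prediction_coefs output

all_coefs_T = all_coefs.reshape(-1, n_inputs, n_inputs).permute(0, 2, 1)
predictions = torch.matmul(hidden_states.permute(1, 2, 0), all_coefs_T)
predictions = predictions.permute(2, 0, 1) + hidden_states
assert predictions.shape == (n_inputs, num_tokens, hidden)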
+
+class Gemma3nLaurelBlock(nn.Module):
+ """Learned Augmented Residual Layer"""
+
+ def __init__(self, hidden_size: int, laurel_rank: int, rms_norm_eps: float,
+ prefix: str):
+ super().__init__()
+
+ self.linear_left = ColumnParallelLinear(
+ hidden_size,
+ laurel_rank,
+ bias=False,
+ prefix=f"{prefix}.linear_left",
+ return_bias=False,
+ )
+ self.linear_right = RowParallelLinear(laurel_rank,
+ hidden_size,
+ bias=False,
+ prefix=f"{prefix}.linear_right",
+ return_bias=False)
+ self.post_laurel_norm = RMSNorm(
+ hidden_size=hidden_size,
+ eps=rms_norm_eps,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ laurel_x = self.linear_left(x)
+ laurel_x = self.linear_right(laurel_x)
+ normed_laurel_x = self.post_laurel_norm(laurel_x)
+ return x + normed_laurel_x
+
+
+class Gemma3nMLP(nn.Module):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_activation: str,
+ activation_sparsity: float = 0.0,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.gate_up_proj = MergedColumnParallelLinear(
+ hidden_size,
+ [intermediate_size] * 2,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate_up_proj",
+ )
+ self.down_proj = RowParallelLinear(
+ intermediate_size,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.down_proj",
+ )
+ if hidden_activation != "gelu_pytorch_tanh":
+ raise ValueError(
+ "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
+ "function. Please set `hidden_act` and `hidden_activation` to "
+ "`gelu_pytorch_tanh`.")
+
+ self.act_fn = GeluAndMulSparse(
+ activation_sparsity=activation_sparsity,
+ approximate="tanh") if activation_sparsity > 0.0 else GeluAndMul(
+ approximate="tanh")
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ gate_up, _ = self.gate_up_proj(x)
+ x = self.act_fn(gate_up)
+ x, _ = self.down_proj(x)
+ return x
+
+
+class Gemma3nAttention(nn.Module):
+
+ def __init__(self,
+ config: Gemma3nTextConfig,
+ hidden_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ head_dim: int,
+ max_position_embeddings: int,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "") -> None:
+ super().__init__()
+ self.config = config
+ self.hidden_size = hidden_size
+ tp_size = get_tensor_model_parallel_world_size()
+ self.total_num_heads = num_heads
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+ self.total_num_kv_heads = num_kv_heads
+ if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is at least the TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+ assert self.total_num_kv_heads % tp_size == 0
+ else:
+ # Number of KV heads is less than TP size, so we replicate
+ # the KV heads across multiple tensor parallel GPUs.
+ assert tp_size % self.total_num_kv_heads == 0
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+ self.head_dim = head_dim
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+
+ self.qkv_proj = QKVParallelLinear(
+ hidden_size,
+ self.head_dim,
+ self.total_num_heads,
+ self.total_num_kv_heads,
+ bias=config.attention_bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
+ )
+ self.o_proj = RowParallelLinear(
+ self.total_num_heads * self.head_dim,
+ hidden_size,
+ bias=config.attention_bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj",
+ )
+ self.q_norm = RMSNorm(hidden_size=self.head_dim,
+ eps=config.rms_norm_eps)
+ self.k_norm = RMSNorm(hidden_size=self.head_dim,
+ eps=config.rms_norm_eps)
+ self.v_norm = RMSNorm(hidden_size=self.head_dim,
+ eps=config.rms_norm_eps,
+ has_weight=False)
+
+ layer_idx = extract_layer_index(prefix)
+ if config.layer_types[layer_idx] == "sliding_attention":
+ self.sliding_window = config.sliding_window
+ rope_theta = config.rope_local_base_freq
+ rope_scaling = {"rope_type": "default"}
+ else:
+ self.sliding_window = None
+ rope_theta = config.rope_theta
+ rope_scaling = config.rope_scaling
+
+ first_kv_shared_layer_idx = (config.num_hidden_layers -
+ config.num_kv_shared_layers)
+ self.is_kv_shared = layer_idx >= first_kv_shared_layer_idx
+
+ if self.is_kv_shared:
+ # Last full attention layer is 1 before sharing
+ # Last sliding attention layer is 2 before sharing
+ offset = 2 if self.sliding_window is not None else 1
+ kv_shared_layer_index = first_kv_shared_layer_idx - offset
+ kv_sharing_target_layer_name = f"model.language_model.layers.{kv_shared_layer_index}.self_attn.attn" # noqa: E501
+ else:
+ kv_sharing_target_layer_name = None
+
+ self.rotary_emb = get_rope(
+ self.head_dim,
+ rotary_dim=self.head_dim,
+ max_position=max_position_embeddings,
+ base=rope_theta,
+ is_neox_style=True,
+ rope_scaling=rope_scaling,
+ )
+
+ self.attn = Attention(
+ num_heads=self.num_heads,
+ head_size=self.head_dim,
+ scale=1.0,
+ num_kv_heads=self.num_kv_heads,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ per_layer_sliding_window=self.sliding_window,
+ kv_sharing_target_layer_name=kv_sharing_target_layer_name,
+ prefix=f"{prefix}.attn")
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ **kwargs,
+ ) -> torch.Tensor:
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+ q = q.unflatten(-1, (self.num_heads, self.head_dim))
+ q = self.q_norm(q)
+ q = q.flatten(-2, -1)
+ k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
+ k = self.k_norm(k)
+ k = k.flatten(-2, -1)
+ v = v.unflatten(-1, (self.num_kv_heads, self.head_dim))
+ v = self.v_norm(v)
+ v = v.flatten(-2, -1)
+
+ q, k = self.rotary_emb(positions, q, k)
+ attn_output = self.attn(q, k, v)
+
+ output, _ = self.o_proj(attn_output)
+ return output
+
+
+class Gemma3nDecoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ config: Gemma3nTextConfig,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.altup_active_idx = config.altup_active_idx
+ assert config.altup_correct_scale
+
+ self.altup = Gemma3nAltUp(
+ hidden_size=config.hidden_size,
+ rms_norm_eps=config.rms_norm_eps,
+ altup_num_inputs=config.altup_num_inputs,
+ altup_coef_clip=config.altup_coef_clip,
+ altup_active_idx=config.altup_active_idx,
+ prefix=f"{prefix}.altup",
+ )
+ self.self_attn = Gemma3nAttention(
+ config=config,
+ hidden_size=config.hidden_size,
+ num_heads=config.num_attention_heads,
+ num_kv_heads=config.num_key_value_heads,
+ head_dim=config.head_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.self_attn",
+ )
+ self.mlp = Gemma3nMLP(
+ hidden_size=config.hidden_size,
+ # NOTE: Matformer https://github.com/huggingface/transformers/blob/a52478253bbe522a420e88ea3940d4d98a935300/src/transformers/models/gemma3n/modular_gemma3n.py#L258 # noqa: E501
+ intermediate_size=config.intermediate_size[extract_layer_index(
+ prefix)],
+ hidden_activation=config.hidden_activation,
+ quant_config=quant_config,
+ activation_sparsity=config.activation_sparsity_pattern[
+ extract_layer_index(prefix)],
+ prefix=f"{prefix}.mlp",
+ )
+ self.laurel = Gemma3nLaurelBlock(
+ hidden_size=config.hidden_size,
+ laurel_rank=config.laurel_rank,
+ rms_norm_eps=config.rms_norm_eps,
+ prefix=f"{prefix}.laurel",
+ )
+
+        # NOTE(rob): these should be ColumnParallelLinear and
+        # RowParallelLinear, but per_layer_input_gate(x) is combined
+        # elementwise with per_layer_input, which cannot be sharded,
+        # so we replicate for now.
+ self.per_layer_input_gate = ReplicatedLinear(
+ config.hidden_size,
+ config.hidden_size_per_layer_input,
+ bias=False,
+ prefix=f"{prefix}.per_layer_input_gate",
+ return_bias=False,
+ )
+ self.per_layer_projection = ReplicatedLinear(
+ config.hidden_size_per_layer_input,
+ config.hidden_size,
+ bias=False,
+ prefix=f"{prefix}.per_layer_projection",
+ return_bias=False,
+ )
+
+ # LayerNorms.
+ self.input_layernorm = RMSNorm(
+ config.hidden_size,
+ eps=config.rms_norm_eps,
+ )
+ self.post_attention_layernorm = RMSNorm(
+ config.hidden_size,
+ eps=config.rms_norm_eps,
+ )
+ self.pre_feedforward_layernorm = RMSNorm(
+ config.hidden_size,
+ eps=config.rms_norm_eps,
+ )
+ self.post_feedforward_layernorm = RMSNorm(
+ config.hidden_size,
+ eps=config.rms_norm_eps,
+ )
+ self.post_per_layer_input_norm = RMSNorm(
+ config.hidden_size,
+ eps=config.rms_norm_eps,
+ )
+
+ self.act_fn = _ACTIVATION_REGISTRY[config.hidden_activation]
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ per_layer_input: torch.Tensor,
+ **kwargs,
+    ) -> torch.Tensor:
+
+        # AltUp (predict).
+ predictions = self.altup.predict(hidden_states)
+ active_prediction = predictions[self.altup_active_idx]
+ active_prediction_normed = self.input_layernorm(active_prediction)
+ laurel_output = self.laurel(active_prediction_normed)
+
+ # Attention.
+ attn = self.self_attn(
+ positions=positions,
+ hidden_states=active_prediction_normed,
+ **kwargs,
+ )
+ attn = self.post_attention_layernorm(attn)
+ attn_gated = attn + active_prediction
+ attn_laurel = (attn_gated + laurel_output) / torch.sqrt(
+ torch.tensor(2.0))
+
+ # MLP.
+ attn_norm = self.pre_feedforward_layernorm(attn_laurel)
+ attn_ffw = self.mlp(attn_norm)
+ attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw)
+ attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
+
+        # AltUp (correct).
+ corrected_predictions = self.altup.correct(predictions,
+ attn_ffw_laurel_gated)
+ first_prediction = corrected_predictions[self.altup_active_idx]
+ first_prediction = self.altup.scale_corrected_output(first_prediction)
+
+ # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
+ first_prediction = self.per_layer_input_gate(first_prediction)
+ first_prediction = self.act_fn(first_prediction)
+ first_prediction = torch.mul(first_prediction, per_layer_input)
+
+ # per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
+ first_prediction = self.per_layer_projection(first_prediction)
+ first_prediction = self.post_per_layer_input_norm(first_prediction)
+ corrected_predictions[1:] += first_prediction
+
+ return corrected_predictions
+
+
+@support_torch_compile
+class Gemma3nTextModel(nn.Module):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config.text_config
+ cache_config = vllm_config.cache_config
+ quant_config = vllm_config.quant_config
+ self.config = config
+ self.quant_config = quant_config
+
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ prefix=f"{prefix}.embed_tokens",
+ )
+ self.embed_scale = torch.tensor(
+ config.hidden_size**0.5,
+ dtype=self.embed_tokens.weight.dtype,
+ )
+ self.embed_tokens_per_layer = VocabParallelEmbedding(
+ config.vocab_size_per_layer_input,
+ config.num_hidden_layers * config.hidden_size_per_layer_input,
+ prefix=f"{prefix}.per_layer_embed_tokens",
+ )
+ self.embed_scale_per_layer = torch.tensor(
+ config.hidden_size_per_layer_input**0.5,
+ dtype=self.embed_tokens.weight.dtype,
+ )
+ self.per_layer_model_projection = ColumnParallelLinear(
+ config.hidden_size,
+ config.num_hidden_layers * config.hidden_size_per_layer_input,
+ bias=False,
+ gather_output=True,
+ return_bias=False,
+ prefix=f"{prefix}.per_layer_model_projection",
+ )
+ self.per_layer_projection_norm = RMSNorm(
+ hidden_size=config.hidden_size_per_layer_input,
+ eps=config.rms_norm_eps,
+ )
+ self.per_layer_input_scale = torch.rsqrt(torch.tensor(2.0)).to(
+ self.embed_tokens.weight.dtype)
+ self.per_layer_projection_scale = torch.tensor(
+ config.hidden_size**0.5,
+ dtype=self.embed_tokens.weight.dtype,
+ )
+ self.altup_projections = nn.ModuleList([
+ ColumnParallelLinear(
+ config.hidden_size,
+ config.hidden_size,
+ bias=False,
+ gather_output=True,
+ return_bias=False,
+ prefix=f"{prefix}.{idx-1}.altup_projections",
+ ) for idx in range(1, self.config.altup_num_inputs)
+ ])
+ self.altup_unembed_projections = nn.ModuleList([
+ ColumnParallelLinear(
+ config.hidden_size,
+ config.hidden_size,
+ bias=False,
+ gather_output=True,
+ return_bias=False,
+ prefix=f"{prefix}.{idx-1}.altup_unembed_projections",
+ ) for idx in range(1, self.config.altup_num_inputs)
+ ])
+
+ # Transformer blocks.
+ self.start_layer, self.end_layer, self.layers = make_layers(
+ config.num_hidden_layers,
+ lambda prefix: Gemma3nDecoderLayer(
+ config, cache_config, quant_config, prefix=prefix),
+ prefix=f"{prefix}.layers")
+ self.norm = RMSNorm(
+ config.hidden_size,
+ eps=config.rms_norm_eps,
+ )
+ self.eps = torch.tensor(torch.finfo().min)
+
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.embed_tokens(input_ids) * self.embed_scale
+
+ def get_per_layer_input_embeddings(
+ self, input_ids: torch.Tensor) -> torch.Tensor:
+        # vocab_size_per_layer_input < vocab_size, so token ids outside the
+        # per-layer vocab are mapped to 0. This matches the HF
+        # implementation.
+ per_layer_inputs_mask = torch.logical_and(
+ input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input)
+ per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids,
+ torch.zeros_like(input_ids))
+ return self.embed_tokens_per_layer(
+ per_layer_inputs_tokens) * self.embed_scale_per_layer
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor],
+ positions: torch.Tensor,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ if inputs_embeds is not None:
+ hidden_states_0 = inputs_embeds
+ else:
+ hidden_states_0 = self.get_input_embeddings(input_ids)
+
+ # Per layer inputs.
+ if input_ids is None:
+ raise ValueError("Passing None for input ids is not supported.")
+ per_layer_inputs = self.get_per_layer_input_embeddings(input_ids)
+ per_layer_inputs = per_layer_inputs.reshape(
+ -1, self.config.num_hidden_layers,
+ self.config.hidden_size_per_layer_input)
+ per_layer_projection = self.per_layer_model_projection(hidden_states_0)
+ per_layer_projection = per_layer_projection.reshape(
+ *hidden_states_0.shape[:-1],
+ self.config.num_hidden_layers,
+ self.config.hidden_size_per_layer_input,
+ )
+ per_layer_projection = self.per_layer_projection_norm(
+ per_layer_projection)
+ per_layer_inputs = per_layer_projection + per_layer_inputs
+ per_layer_inputs *= self.per_layer_input_scale
+
+ # Altup embed.
+ hidden_states = [hidden_states_0] * self.config.altup_num_inputs
+ target_magnitude = torch.mean(hidden_states_0**2, dim=-1,
+ keepdim=True)**0.5
+ for i in range(1, self.config.altup_num_inputs):
+ hidden_states[i] = self.altup_projections[i - 1](hidden_states[i])
+ new_magnitude = torch.mean(hidden_states[i]**2,
+ dim=-1,
+ keepdim=True)**0.5
+ hidden_states[i] *= target_magnitude / torch.maximum(
+ new_magnitude, self.eps)
+ hidden_states = torch.stack(hidden_states, dim=0)
+
+ # Transformer blocks.
+ for layer_idx, layer in enumerate(self.layers):
+ # [altup_num_inputs, num_tokens, hidden_size]
+ hidden_states = layer(
+ positions=positions,
+ hidden_states=hidden_states,
+ per_layer_input=per_layer_inputs[:, layer_idx, :],
+ **kwargs,
+ )
+
+ # Altup unembed.
+ target_magnitude = torch.mean(hidden_states[0]**2,
+ dim=-1,
+ keepdim=True)**0.5
+ for i in range(1, self.config.altup_num_inputs):
+ hidden_states[i] = self.altup_unembed_projections[i - 1](
+ hidden_states[i])
+ new_magnitude = torch.mean(hidden_states[i]**2,
+ dim=-1,
+ keepdim=True)**0.5
+ hidden_states[i] *= target_magnitude / torch.maximum(
+ new_magnitude, self.eps)
+ # [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size]
+ hidden_states = torch.mean(hidden_states, dim=0)
+
+ return self.norm(hidden_states)
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+ for name, loaded_weight in weights:
+ if (self.quant_config is not None and
+ (scale_name := self.quant_config.get_cache_scale(name))):
+ # Loading kv cache scales for compressed-tensors quantization
+ param = params_dict[scale_name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ loaded_weight = loaded_weight[0]
+ weight_loader(param, loaded_weight)
+ loaded_params.add(scale_name)
+ continue
+ for (param_name, shard_name, shard_id) in stacked_params_mapping:
+ if shard_name not in name:
+ continue
+ # Avoid spurious match with ".up_proj".
+ if "altup_projections" in name:
+ continue
+ name = name.replace(shard_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ if is_pp_missing_parameter(name, self):
+ continue
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ # Remapping the name of FP8 kv-scale.
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
+ if is_pp_missing_parameter(name, self):
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+
+ return loaded_params
+
+
+class Gemma3nModel(nn.Module):
+
+ def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ self.language_model = Gemma3nTextModel(vllm_config=vllm_config,
+ prefix=maybe_prefix(
+ prefix, "language_model"))
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor],
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ return self.language_model(input_ids=input_ids,
+ positions=positions,
+ inputs_embeds=inputs_embeds,
+ **kwargs)
+
+
+class Gemma3nForConditionalGeneration(nn.Module):
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ config = vllm_config.model_config.hf_config
+ lora_config = vllm_config.lora_config
+ del lora_config # Unused.
+ super().__init__()
+ self.config = config
+ self.model = Gemma3nModel(vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "model"))
+ self.logits_processor = LogitsProcessor(
+ config.text_config.vocab_size,
+ soft_cap=config.text_config.final_logit_softcapping)
+
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.language_model.get_input_embeddings(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ hidden_states = self.model(input_ids, positions, intermediate_tensors,
+ inputs_embeds, **kwargs)
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: Optional[SamplingMetadata],
+ ) -> Optional[torch.Tensor]:
+ logits = self.logits_processor(self.model.language_model.embed_tokens,
+ hidden_states, sampling_metadata)
+ return logits
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(self,
+ skip_substrs=([
+ "embed_audio.", "embed_vision.",
+ "audio_tower.", "vision_tower."
+ ]))
+ return loader.load_weights(weights)
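
The shape bookkeeping in `Gemma3nAltUp.predict()` above is easy to lose track of. Below is a minimal, self-contained sketch of the same reshape/permute/matmul sequence with toy dimensions and random coefficients; it is illustrative only and not part of the patch.

```python
import torch

altup_num_inputs, num_tokens, hidden_size = 4, 3, 8

# Stacked hidden states: [altup_num_inputs, num_tokens, hidden_size]
hidden_states = torch.randn(altup_num_inputs, num_tokens, hidden_size)
# Router output, one coefficient per (input_i, input_j) pair per token.
all_coefs = torch.randn(num_tokens, altup_num_inputs * altup_num_inputs)

# [num_tokens, altup_num_inputs, altup_num_inputs], transposed for the matmul.
all_coefs_T = all_coefs.reshape(-1, altup_num_inputs,
                                altup_num_inputs).permute(0, 2, 1)

# [num_tokens, hidden_size, altup_num_inputs] @ [num_tokens, n, n]
predictions = torch.matmul(hidden_states.permute(1, 2, 0), all_coefs_T)
# Back to [altup_num_inputs, num_tokens, hidden_size], plus the residual.
predictions = predictions.permute(2, 0, 1) + hidden_states

assert predictions.shape == (altup_num_inputs, num_tokens, hidden_size)
```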
diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py
index 26b5b3ac1534..33e8626209d5 100644
--- a/vllm/model_executor/models/granitemoehybrid.py
+++ b/vllm/model_executor/models/granitemoehybrid.py
@@ -15,7 +15,8 @@
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.linear import (QKVParallelLinear,
+ RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
@@ -36,8 +37,9 @@
from .granitemoeshared import GraniteMoeSharedMLP
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
SupportsQuant, SupportsV0Only)
-from .utils import (AutoWeightsLoader, make_empty_intermediate_tensors_factory,
- make_layers, maybe_prefix)
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
+ make_empty_intermediate_tensors_factory, make_layers,
+ maybe_prefix)
class GraniteMoeHybridMambaDecoderLayer(nn.Module):
@@ -220,35 +222,37 @@ def __init__(
self.hidden_size = config.hidden_size
self.attention_bias = config.attention_bias
self.attention_multiplier = config.attention_multiplier
- self.num_heads = config.num_attention_heads
- self.head_dim = self.hidden_size // self.num_heads
- self.num_key_value_heads = config.num_key_value_heads
-
- self.q_proj = ReplicatedLinear(self.hidden_size,
- self.num_heads * self.head_dim,
- bias=self.attention_bias,
- quant_config=quant_config,
- prefix=f"{prefix}.q_proj")
-
- self.k_proj = ReplicatedLinear(self.hidden_size,
- self.num_key_value_heads *
- self.head_dim,
- bias=self.attention_bias,
- quant_config=quant_config,
- prefix=f"{prefix}.k_proj")
-
- self.v_proj = ReplicatedLinear(self.hidden_size,
- self.num_key_value_heads *
- self.head_dim,
- bias=self.attention_bias,
- quant_config=quant_config,
- prefix=f"{prefix}.v_proj")
-
- self.o_proj = ReplicatedLinear(self.hidden_size,
- self.hidden_size,
- bias=self.attention_bias,
- quant_config=quant_config,
- prefix=f"{prefix}.o_proj")
+ self.total_num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.total_num_heads
+ self.total_num_kv_heads = config.num_key_value_heads
+
+ # TensorParallel logic
+ tp_size = get_tensor_model_parallel_world_size()
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+ if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is at least the TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+ assert self.total_num_kv_heads % tp_size == 0
+ else:
+ # Number of KV heads is less than TP size, so we replicate
+ # the KV heads across multiple tensor parallel GPUs.
+ assert tp_size % self.total_num_kv_heads == 0
+ self.num_key_value_heads = max(1, self.total_num_kv_heads // tp_size)
+
+ self.qkv_proj = QKVParallelLinear(self.hidden_size,
+ self.head_dim,
+ self.total_num_heads,
+ self.total_num_kv_heads,
+ bias=self.attention_bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj")
+
+ self.o_proj = RowParallelLinear(self.hidden_size,
+ self.hidden_size,
+ bias=self.attention_bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj")
if config.position_embedding_type == "rope":
self.rotary_emb = get_rope(
@@ -278,9 +282,12 @@ def forward(
hidden_states: torch.Tensor,
) -> torch.Tensor:
- query = self.q_proj(hidden_states)[0]
- key = self.k_proj(hidden_states)[0]
- value = self.v_proj(hidden_states)[0]
+ qkv, _ = self.qkv_proj(hidden_states)
+ query, key, value = qkv.split([
+ self.num_heads * self.head_dim, self.num_key_value_heads *
+ self.head_dim, self.num_key_value_heads * self.head_dim
+ ],
+ dim=-1)
if self.rotary_emb is not None:
query, key = self.rotary_emb(positions, query, key)
@@ -401,6 +408,12 @@ def forward(
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ (".qkv_proj", ".q_proj", "q"),
+ (".qkv_proj", ".k_proj", "k"),
+ (".qkv_proj", ".v_proj", "v"),
+ ]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
@@ -411,6 +424,15 @@ def _load(n, p):
weight_loader(param, p)
loaded_params.add(n)
+ def _load_shard(n, p, shard_id):
+ # Skip layers on other devices.
+ if not is_pp_missing_parameter(n, self):
+ param = params_dict[n]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, p, shard_id)
+ loaded_params.add(n)
+
def _load_expert(n, p, name, shard_id, expert_id):
param = params_dict[n]
weight_loader = getattr(param, "weight_loader",
@@ -465,7 +487,15 @@ def _load_expert(n, p, name, shard_id, expert_id):
".block_sparse_moe.gate.weight")
_load(gate_name, p)
else:
- _load(n, p)
+ loaded = False
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name in n:
+ _load_shard(n.replace(weight_name, param_name),
+ p,
+ shard_id=shard_id)
+ loaded = True
+ if not loaded:
+ _load(n, p)
return loaded_params
@@ -473,7 +503,13 @@ def _load_expert(n, p, name, shard_id, expert_id):
class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
SupportsPP, IsHybrid, SupportsV0Only,
SupportsQuant):
- packed_modules_mapping = {}
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ }
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index f759f8f1f273..ad59fe79edcb 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Iterable, MutableSequence
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
Union, overload, runtime_checkable)
@@ -426,6 +427,73 @@ def is_hybrid(
return isinstance(model, IsHybrid)
+@runtime_checkable
+class MixtureOfExperts(Protocol):
+ """
+ Check if the model is a mixture of experts (MoE) model.
+ """
+
+ expert_weights: MutableSequence[Iterable[Tensor]]
+ """
+ Expert weights saved in this rank.
+
+ The first dimension is the layer, and the second dimension is different
+ parameters in the layer, e.g. up/down projection weights.
+ """
+
+ num_moe_layers: int
+ """Number of MoE layers in this model."""
+
+ num_expert_groups: int
+ """Number of expert groups in this model."""
+
+ num_logical_experts: int
+ """Number of logical experts in this model."""
+
+ num_physical_experts: int
+ """Number of physical experts in this model."""
+
+ num_local_physical_experts: int
+ """Number of local physical experts in this model."""
+
+ num_routed_experts: int
+ """Number of routed experts in this model."""
+
+ num_shared_experts: int
+ """Number of shared experts in this model."""
+
+ num_redundant_experts: int
+ """Number of redundant experts in this model."""
+
+ def set_eplb_state(
+ self,
+ expert_load_view: Tensor,
+ logical_to_physical_map: Tensor,
+ logical_replica_count: Tensor,
+ ) -> None:
+ """
+ Register the EPLB state in the MoE model.
+
+ Since these are views of the actual EPLB state, any changes made by
+ the EPLB algorithm are automatically reflected in the model's behavior
+ without requiring additional method calls to set new states.
+
+ You should also collect model's `expert_weights` here instead of in
+ the weight loader, since after initial weight loading, further
+ processing like quantization may be applied to the weights.
+
+ Args:
+ expert_load_view: A view of the expert load metrics tensor.
+ logical_to_physical_map: Mapping from logical to physical experts.
+ logical_replica_count: Count of replicas for each logical expert.
+ """
+ ...
+
+
+def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]:
+ return isinstance(model, MixtureOfExperts)
+
+
@runtime_checkable
class HasNoOps(Protocol):
has_noops: ClassVar[Literal[True]] = True
@@ -531,6 +599,17 @@ class SupportsTranscription(Protocol):
supports_transcription: ClassVar[Literal[True]] = True
+ @classmethod
+ def get_decoder_prompt(cls, language: str, task_type: str,
+ prompt: str) -> str:
+ """Get the decoder prompt for the ASR model."""
+ ...
+
+ @classmethod
+ def validate_language(cls, language: str) -> bool:
+ """Check if the model supports a specific ISO639_1 language."""
+ ...
+
@overload
def supports_transcription(
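
The `set_eplb_state` docstring relies on the registered tensors being views into the EPLB state. A tiny illustration (plain PyTorch, not vLLM API) of why in-place updates are then visible without any further method calls:

```python
import torch

expert_load = torch.zeros(2, 4)      # e.g. [num_moe_layers, num_physical_experts]
view_held_by_model = expert_load[0]  # what a model would keep from set_eplb_state()

expert_load[0, 1] += 7               # the EPLB algorithm updates its own tensor
print(view_held_by_model)            # tensor([0., 7., 0., 0.]) -- change is visible
```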
diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py
index 216c1f1c7ff7..1224ba7abc75 100644
--- a/vllm/model_executor/models/qwen3.py
+++ b/vllm/model_executor/models/qwen3.py
@@ -400,22 +400,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
def load_weights_from_original_qwen3_reranker(
self, weights: Iterable[tuple[str, torch.Tensor]]):
- tokens = getattr(self.config, "classifier_from_token", None)
- assert tokens is not None and len(tokens) == 2, \
- ("Try loading the original Qwen3 Reranker?, see: "
- "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
- self.config.num_labels = 1
model_config = self.vllm_config.model_config
-
+ tokens = getattr(self.config, "classifier_from_token", None)
device = self.score.weight.device
- self.score = RowParallelLinear(self.config.hidden_size,
- self.config.num_labels,
- quant_config=self.quant_config,
- input_is_parallel=False,
- bias=False,
- prefix=maybe_prefix(
- self.prefix, "score")).to(device)
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
@@ -443,5 +431,6 @@ def load_weights_from_original_qwen3_reranker(
self.score.weight.data.copy_(weight)
del self.lm_head
- loaded_weights.add("classifier.weight")
+ loaded_weights.add("score.weight")
loaded_weights.discard("lm_head.weight")
+ return loaded_weights
diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index 417d7b22088b..90a28192eccb 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -386,6 +386,11 @@ def load_weights(self, weights: Iterable[tuple[str,
("gate_up_proj", "up_proj", 1),
]
+ # Skip loading extra parameters for GPTQ/modelopt models.
+ ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale",
+ ".v_scale", "_v_scale", ".weight_scale",
+ "_weight_scale", ".input_scale", "_input_scale")
+
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
@@ -410,10 +415,11 @@ def load_weights(self, weights: Iterable[tuple[str,
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
- # Skip loading extra bias for GPTQ models.
- if ((name.endswith(".bias") or name.endswith("_bias"))
- and name not in params_dict):
+
+ # Skip loading extra parameters for GPTQ/modelopt models.
+ if name.endswith(ignore_suffixes) and name not in params_dict:
continue
+
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
@@ -433,9 +439,9 @@ def load_weights(self, weights: Iterable[tuple[str,
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
- # Skip loading extra bias for GPTQ models.
- if ((name.endswith(".bias") or name.endswith("_bias"))
- and name not in params_dict):
+ # Skip loading extra parameters for GPTQ/modelopt models.
+ if name.endswith(
+ ignore_suffixes) and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
@@ -446,9 +452,9 @@ def load_weights(self, weights: Iterable[tuple[str,
expert_id=expert_id)
break
else:
- # Skip loading extra bias for GPTQ models.
- if ((name.endswith(".bias") or name.endswith("_bias"))
- and name not in params_dict):
+ # Skip loading extra parameters for GPTQ/modelopt models.
+ if name.endswith(
+ ignore_suffixes) and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index faeaf6ef68cc..d566146662b8 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -52,12 +52,15 @@
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
+ "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
+ #TODO(ywang96): Support multimodal gemma3n
+ "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"), # noqa: E501
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index 8cf2a009d667..5a0094fa749f 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -41,6 +41,113 @@
logger = init_logger(__name__)
+# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
+
+ISO639_1_SUPPORTED_LANGS = {
+ "af": "Afrikaans",
+ "ar": "Arabic",
+ "hy": "Armenian",
+ "az": "Azerbaijani",
+ "be": "Belarusian",
+ "bs": "Bosnian",
+ "bg": "Bulgarian",
+ "ca": "Catalan",
+ "zh": "Chinese",
+ "hr": "Croatian",
+ "cs": "Czech",
+ "da": "Danish",
+ "nl": "Dutch",
+ "en": "English",
+ "et": "Estonian",
+ "fi": "Finnish",
+ "fr": "French",
+ "gl": "Galician",
+ "de": "German",
+ "el": "Greek",
+ "he": "Hebrew",
+ "hi": "Hindi",
+ "hu": "Hungarian",
+ "is": "Icelandic",
+ "id": "Indonesian",
+ "it": "Italian",
+ "ja": "Japanese",
+ "kn": "Kannada",
+ "kk": "Kazakh",
+ "ko": "Korean",
+ "lv": "Latvian",
+ "lt": "Lithuanian",
+ "mk": "Macedonian",
+ "ms": "Malay",
+ "mr": "Marathi",
+ "mi": "Maori",
+ "ne": "Nepali",
+ "no": "Norwegian",
+ "fa": "Persian",
+ "pl": "Polish",
+ "pt": "Portuguese",
+ "ro": "Romanian",
+ "ru": "Russian",
+ "sr": "Serbian",
+ "sk": "Slovak",
+ "sl": "Slovenian",
+ "es": "Spanish",
+ "sw": "Swahili",
+ "sv": "Swedish",
+ "tl": "Tagalog",
+ "ta": "Tamil",
+ "th": "Thai",
+ "tr": "Turkish",
+ "uk": "Ukrainian",
+ "ur": "Urdu",
+ "vi": "Vietnamese",
+ "cy": "Welsh"
+}
+ISO639_1_OTHER_LANGS = {
+ "lo": "Lao",
+ "jw": "Javanese",
+ "tk": "Turkmen",
+ "yi": "Yiddish",
+ "so": "Somali",
+ "bn": "Bengali",
+ "nn": "Norwegian Nynorsk",
+ "si": "Sinhala",
+ "yo": "Yoruba",
+ "sa": "Sanskrit",
+    "mi": "Māori",
+ "fo": "Faroese", # codespell:ignore
+ "mt": "Maltese",
+ "tg": "Tajik",
+ "mg": "Malagasy",
+ "haw": "Hawaiian",
+ "km": "Khmer",
+ "br": "Breton",
+ "ps": "Pashto",
+ "ln": "Lingala",
+ "la": "Latin",
+ "ml": "Malayalam",
+ "sq": "Albanian",
+ "su": "Sundanese",
+ "eu": "Basque",
+ "ka": "Georgian",
+ "uz": "Uzbek",
+ "sn": "Shona",
+ "ht": "Haitian",
+ "as": "Assamese",
+ "mn": "Mongolian",
+ "te": "Telugu",
+ "pa": "Panjabi",
+ "tt": "Tatar",
+ "gu": "Gujarati",
+ "oc": "Occitan",
+ "ha": "Hausa",
+ "ba": "Bashkir",
+ "my": "Burmese",
+ "sd": "Sindhi",
+ "am": "Amharic",
+ "lb": "Luxembourgish",
+ "bo": "Tibetan"
+}
+
class WhisperAudioInputs(TypedDict):
input_features: NestedTensors
@@ -731,6 +838,28 @@ def load_weights(self, weights: Iterable[tuple[str,
weights = _create_fake_bias_for_k_proj(weights)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+ @classmethod
+ def validate_language(cls, language: str) -> bool:
+ if language in ISO639_1_SUPPORTED_LANGS:
+ return True
+ elif language in ISO639_1_OTHER_LANGS:
+ logger.warning(
+ "The selected language %s has limited accuracy with"
+ " reported WER>=0.5. Results may be less accurate "
+ "for this choice.", language)
+ return True
+ else:
+            raise ValueError(
+                f"Unsupported language: {language}. "
+                "Language should be one of: "
+                f"{list(ISO639_1_SUPPORTED_LANGS.values())} "
+                f"or {list(ISO639_1_OTHER_LANGS.values())}.")
+
+ @classmethod
+ def get_decoder_prompt(cls, language: str, task_type: str,
+ prompt: str) -> str:
+ return (f"<|startoftranscript|><|{language}|><|{task_type}|>"
+ f"<|notimestamps|>{prompt}")
+
def _create_fake_bias_for_k_proj(
weights: Iterable[tuple[str, torch.Tensor]]
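
A usage sketch for the two new classmethods, assuming a vLLM build that includes this patch; the printed prompt is for the English/transcribe case.

```python
from vllm.model_executor.models.whisper import WhisperForConditionalGeneration

# "en" is in ISO639_1_SUPPORTED_LANGS; "yi" is only in ISO639_1_OTHER_LANGS,
# so it is accepted but logs a low-accuracy warning. Unknown codes raise.
assert WhisperForConditionalGeneration.validate_language("en")
assert WhisperForConditionalGeneration.validate_language("yi")

prompt = WhisperForConditionalGeneration.get_decoder_prompt(
    language="en", task_type="transcribe", prompt="")
print(prompt)  # <|startoftranscript|><|en|><|transcribe|><|notimestamps|>
```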
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index f962fafabf50..0f08bf986333 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -173,17 +173,12 @@ def is_sleep_mode_available(self) -> bool:
@classmethod
def device_id_to_physical_device_id(cls, device_id: int):
- if cls.device_control_env_var in os.environ:
+ # Treat empty device control env var as unset. This is a valid
+ # configuration in Ray setups where the engine is launched in
+ # a CPU-only placement group located on a GPU node.
+ if cls.device_control_env_var in os.environ and os.environ[
+ cls.device_control_env_var] != "":
device_ids = os.environ[cls.device_control_env_var].split(",")
- if device_ids == [""]:
- msg = (f"{cls.device_control_env_var} is set to empty string, "
- "which means current platform support is disabled. If "
- "you are using ray, please unset the environment "
- f"variable `{cls.device_control_env_var}` inside the "
- "worker/actor. Check "
- "https://github.com/vllm-project/vllm/issues/8402 for "
- "more information.")
- raise RuntimeError(msg)
physical_device_id = device_ids[device_id]
return int(physical_device_id)
else:
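
The behavioural change in `device_id_to_physical_device_id` (an empty device-control variable is now treated as unset instead of raising) is easiest to see in isolation. `physical_id` below is a hypothetical stand-in for the same logic, not vLLM API:

```python
import os

def physical_id(device_id: int, env_var: str = "CUDA_VISIBLE_DEVICES") -> int:
    # Empty string is treated like "unset" (Ray CPU-only placement groups on
    # GPU nodes set the variable to "").
    if env_var in os.environ and os.environ[env_var] != "":
        return int(os.environ[env_var].split(",")[device_id])
    return device_id

os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6"
print(physical_id(1))  # 5

os.environ["CUDA_VISIBLE_DEVICES"] = ""
print(physical_id(1))  # 1 -- previously this configuration raised RuntimeError
```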
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 07e52017f5a5..0387e348965d 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -122,16 +122,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
PallasAttentionBackend)
cache_config.block_size = PallasAttentionBackend.get_page_size(
vllm_config) # type: ignore[assignment]
- min_page_size = PallasAttentionBackend.get_min_page_size(
- vllm_config)
- if min_page_size > cache_config.block_size:
- logger.warning(
- "Increase the page size from %s to %s to make sure there's"
- "no SMEM OOM",
- cache_config.block_size,
- min_page_size,
- )
- cache_config.block_size = min_page_size # type: ignore[assignment]
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 73f6f3d41767..f361f5e2616e 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -1,18 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
from typing import TYPE_CHECKING, Optional
import torch
+import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
- from vllm.config import VllmConfig
+ from vllm.config import ModelConfig, VllmConfig
else:
+ ModelConfig = None
VllmConfig = None
logger = init_logger(__name__)
@@ -35,8 +38,13 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
use_mla: bool) -> str:
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
- logger.info("Using IPEX attention backend.")
- return "vllm.attention.backends.ipex_attn.IpexAttnBackend"
+ use_v1 = envs.VLLM_USE_V1
+ if use_v1:
+ logger.info("Using Flash Attention backend on V1 engine.")
+ return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
+ else:
+ logger.info("Using IPEX attention backend.")
+ return "vllm.attention.backends.ipex_attn.IpexAttnBackend"
@classmethod
def get_device_capability(
@@ -67,25 +75,27 @@ def inference_mode(cls):
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
cache_config = vllm_config.cache_config
+        # In V1 (or with IPEX chunked prefill), the block size is 64.
if cache_config and cache_config.block_size is None:
- cache_config.block_size = 16
-
- # check and update model config
- model_config = vllm_config.model_config
- if model_config.dtype == torch.bfloat16:
- bf16_supported = cls.device_support_bf16()
- if not bf16_supported:
+ if envs.VLLM_USE_V1:
+ cache_config.block_size = 64
+ else:
+ cache_config.block_size = 16
+
+        # Instances created via VllmConfig() typically have model_config set
+        # to None by default, so guard against that before checking and
+        # updating the model config.
+ if vllm_config.model_config is not None:
+ model_config = vllm_config.model_config
+ if model_config.dtype == torch.bfloat16:
+ bf16_supported = cls.device_support_bf16()
+ if not bf16_supported:
+ model_config.dtype = torch.float16
+ if not model_config.enforce_eager:
logger.warning(
- "bfloat16 is only supported on Intel Data Center GPU, "
- "Intel Arc GPU is not supported yet. Your device is %s,"
- " which is not supported. will fallback to float16",
- cls.get_device_name())
- model_config.dtype = torch.float16
- if not model_config.enforce_eager:
- logger.warning(
- "CUDA graph is not supported on XPU, fallback to the eager "
- "mode.")
- model_config.enforce_eager = True
+ "CUDA graph is not supported on XPU, fallback to the eager "
+ "mode.")
+ model_config.enforce_eager = True
if vllm_config.speculative_config is not None:
raise NotImplementedError(
@@ -96,21 +106,27 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# check and update parallel config
parallel_config = vllm_config.parallel_config
- if parallel_config.worker_cls == "auto":
+ if envs.VLLM_USE_V1:
+ parallel_config.worker_cls =\
+ "vllm.v1.worker.xpu_worker.XPUWorker"
+ else:
parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"
if parallel_config.distributed_executor_backend is None:
- parallel_config.distributed_executor_backend = "ray"
+ if parallel_config.world_size > 1:
+ parallel_config.distributed_executor_backend = "ray"
+ else:
+ parallel_config.distributed_executor_backend = "uni"
elif parallel_config.distributed_executor_backend == "mp":
# FIXME(kunshang):
# spawn needs calling `if __name__ == '__main__':``
# fork is not supported for xpu start new process.
- logger.error(
- "Both start methods (spawn and fork) have issue "
- "on XPU if you use mp backend, setting it to ray instead.")
- parallel_config.distributed_executor_backend = "ray"
-
- elif parallel_config.distributed_executor_backend != "ray":
+ if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn":
+ os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+ logger.warning(
+                    "Please use spawn as the start method to use the mp "
+                    "backend.")
+ elif parallel_config.distributed_executor_backend != "ray" and \
+ parallel_config.distributed_executor_backend != "uni":
logger.warning(
"%s is not supported on XPU, fallback to ray distributed"
" executor backend.",
@@ -142,15 +158,35 @@ def get_current_memory_usage(cls,
@classmethod
def device_support_bf16(cls) -> bool:
device_name = cls.get_device_name().lower()
- if device_name.count("arc") > 0:
+ if cls.is_client_gpu_a770():
+            logger.warning("Intel Arc A770 has a known bfloat16 accuracy "
+                           "issue, falling back to float16.")
return False
- elif device_name.count("data center gpu") > 0:
- return True
else:
- logger.warning("Unknown device name %s, always use float16",
- device_name)
- return False
+ logger.info(
+ "Device name %s supports bfloat16. Please file an issue "
+ "if you encounter any accuracy problems with bfloat16.",
+ device_name)
+ return True
+
+ @classmethod
+ def is_data_center_gpu(cls) -> bool:
+ device_name = cls.get_device_name().lower()
+ return device_name.count("data center gpu") > 0
+
+ @classmethod
+ def is_client_gpu_a770(cls) -> bool:
+ device_name = cls.get_device_name().lower()
+ return device_name.count("a770") > 0
@classmethod
def get_device_communicator_cls(cls) -> str:
return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa
+
+ @classmethod
+ def supports_v1(cls, model_config: ModelConfig) -> bool:
+ return True
+
+ @classmethod
+ def device_count(cls) -> int:
+ return torch.xpu.device_count()
diff --git a/vllm/utils.py b/vllm/utils.py
index fdefda901c4d..7eb3c1e347cd 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -2929,3 +2929,31 @@ def is_torch_equal_or_newer(target: str) -> bool:
def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool:
torch_version = version.parse(torch_version)
return torch_version >= version.parse(target)
+
+
+@cache
+def _has_module(module_name: str) -> bool:
+ """Return True if *module_name* can be found in the current environment.
+
+ The result is cached so that subsequent queries for the same module incur
+ no additional overhead.
+ """
+ return importlib.util.find_spec(module_name) is not None
+
+
+def has_pplx() -> bool:
+ """Whether the optional `pplx_kernels` package is available."""
+
+ return _has_module("pplx_kernels")
+
+
+def has_deep_ep() -> bool:
+ """Whether the optional `deep_ep` package is available."""
+
+ return _has_module("deep_ep")
+
+
+def has_deep_gemm() -> bool:
+ """Whether the optional `deep_gemm` package is available."""
+
+ return _has_module("deep_gemm")
\ No newline at end of file
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index ef65d2ea36e4..527b31153410 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -14,10 +14,16 @@
from vllm.attention.layer import Attention
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
- get_flash_attn_version)
+ get_flash_attn_version,
+ is_flash_attn_varlen_func_available)
+
+if is_flash_attn_varlen_func_available():
+ from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
+ get_scheduler_metadata,
+ reshape_and_cache_flash)
+
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
-from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
@@ -28,10 +34,6 @@
if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
-if current_platform.is_cuda():
- from vllm.vllm_flash_attn import (flash_attn_varlen_func,
- get_scheduler_metadata)
-
logger = init_logger(__name__)
@@ -443,7 +445,7 @@ def forward(
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
- torch.ops._C_cache_ops.reshape_and_cache_flash(
+ reshape_and_cache_flash(
key,
value,
key_cache,
diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index 1069578cfd29..49f0772c62d1 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -5,8 +5,12 @@
from typing import Any, Optional
import torch
-# Required to register custom ops.
+import torch_xla.core.xla_builder as xb
import torch_xla.experimental.custom_kernel # noqa: F401
+# Required to register custom ops.
+from torch.library import impl
+from torch_xla._internal.jax_workarounds import requires_jax
+from torch_xla.experimental.custom_kernel import XLA_LIB
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
@@ -48,13 +52,7 @@ def get_kv_cache_shape(
) -> tuple[int, ...]:
padded_head_size = cdiv(
head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
- num_blocks = num_blocks * head_size // padded_head_size
- if padded_head_size != head_size:
- logger.warning_once(
- "head size is padded to %d, and num_blocks is adjusted to %d"
- " accordingly", padded_head_size, num_blocks)
- head_size = padded_head_size
- return (num_blocks, block_size, num_kv_heads * 2, head_size)
+ return (num_blocks, block_size, num_kv_heads * 2, padded_head_size)
@staticmethod
def swap_blocks(
@@ -77,6 +75,11 @@ def get_min_page_size(vllm_config: VllmConfig) -> int:
min_page_size = 1 << (min_page_size - 1).bit_length()
return min_page_size
+ @staticmethod
+ def get_max_num_seqs(model_len: int, page_size: int) -> int:
+ num_page_per_req = cdiv(model_len, page_size)
+ return 1024 * 1024 // 2 // num_page_per_req // 4
+
# TPU has limited SREGs (scalar registers), if page_size is too small, we
# can spill SREGs easily which leads to bad performance. The strategy we
# apply here is trying to split max-model-len to 16 pages which make the
@@ -108,6 +111,7 @@ class PallasMetadata:
context_lens: torch.Tensor
query_start_loc: torch.Tensor
num_seqs: torch.Tensor
+ num_slices_per_kv_cache_update_block: int
class PallasAttentionBackendImpl(AttentionImpl):
@@ -213,7 +217,9 @@ def forward(
# Write input keys and values to the KV cache.
# Skip this if sharing KV cache with an earlier attention layer.
slot_mapping = attn_metadata.slot_mapping
- write_to_kv_cache(key, value, kv_cache, slot_mapping)
+ write_to_kv_cache(
+ key, value, kv_cache, slot_mapping,
+ attn_metadata.num_slices_per_kv_cache_update_block)
output = torch.ops.xla.ragged_paged_attention(
query,
@@ -245,6 +251,7 @@ def write_to_kv_cache(
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
+ num_slices_per_kv_cache_update_block: int,
) -> None:
""" Write the key and values to the KV cache.
@@ -252,9 +259,9 @@ def write_to_kv_cache(
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
-
+        num_slices_per_kv_cache_update_block: number of slot-mapping slices
+            handled per kernel block by the KV cache update op.
"""
- _, _, num_combined_kv_heads, head_size = kv_cache.shape
+ _, page_size, num_combined_kv_heads, head_size = kv_cache.shape
head_size = cdiv(head_size,
TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
@@ -263,4 +270,41 @@ def write_to_kv_cache(
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
kv_cache = kv_cache.flatten(0, 1)
- kv_cache.index_copy_(0, slot_mapping, kv)
+ new_kv_cache = torch.ops.xla.kv_cache_update_op(
+ kv, slot_mapping, kv_cache, page_size,
+ num_slices_per_kv_cache_update_block)
+ # NOTE: the in-place copy will be optimized away by XLA compiler.
+ kv_cache.copy_(new_kv_cache)
+
+
+@requires_jax
+def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
+ kv_cache: torch.Tensor, page_size: int,
+ num_slices_per_block: int):
+ from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
+ new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), {
+ "page_size": page_size,
+ "num_slices_per_block": num_slices_per_block
+ })
+ return new_kv_cache
+
+
+XLA_LIB.define(
+ "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, "
+ "int page_size, int num_slices_per_block) -> Tensor", )
+
+
+@impl(XLA_LIB, "kv_cache_update_op", "XLA")
+def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
+ kv_cache: torch.Tensor, page_size: int,
+ num_slices_per_block: int) -> torch.Tensor:
+ new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
+ page_size, num_slices_per_block)
+ return new_kv_cache
+
+
+@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
+def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
+ kv_cache: torch.Tensor, page_size: int,
+ num_slices_per_block: int) -> torch.Tensor:
+ return kv_cache
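
The new `get_max_num_seqs` bound is pure arithmetic over the per-request page count. One plausible reading of the constants (an assumption, not stated in the patch) is a ~1 MiB SMEM budget, halved for headroom, with one 4-byte page index per page per request:

```python
def get_max_num_seqs(model_len: int, page_size: int) -> int:
    num_page_per_req = -(-model_len // page_size)  # ceil division
    return 1024 * 1024 // 2 // num_page_per_req // 4

print(get_max_num_seqs(model_len=8192, page_size=128))    # 2048
print(get_max_num_seqs(model_len=131072, page_size=256))  # 256
```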
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index e011e95efd41..dc8ff2261306 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -243,8 +243,8 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
self.runner.device, non_blocking=True)
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
self.runner.device, non_blocking=True)
- local_max_query_len = seqlens_q_local_np.max()
- local_max_seq_len = virt_k_seqlens_np.max()
+ local_max_query_len = int(seqlens_q_local_np.max())
+ local_max_seq_len = int(virt_k_seqlens_np.max())
local_scheduler_metadata = schedule(
batch_size=local_query_start_loc.shape[0] - 1,
cu_query_lens=local_query_start_loc,
@@ -253,6 +253,17 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
max_seq_len=local_max_seq_len,
causal=True)
+ local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
+ dtype=torch.int32,
+ device=self.runner.device)
+ local_cu_seq_lens[1:] = torch.cumsum(
+ torch.from_numpy(virt_k_seqlens_np).to(
+ device=self.runner.device,
+ dtype=torch.int32,
+ non_blocking=True),
+ dim=0)
+
local_attn_metadata = \
AiterFlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=local_query_start_loc,
@@ -260,6 +271,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
local_block_table=virt_block_table_tensor,
local_max_query_len=local_max_query_len,
local_max_seq_len=local_max_seq_len,
+ local_cu_seq_lens=local_cu_seq_lens,
local_scheduler_metadata=local_scheduler_metadata,
)
@@ -368,6 +380,7 @@ class LocalAttentionMetadata:
local_block_table: torch.Tensor
local_max_query_len: int
local_max_seq_len: int
+ local_cu_seq_lens: torch.Tensor
local_scheduler_metadata: Optional[torch.Tensor]
local_attn_metadata: Optional[LocalAttentionMetadata] = None
@@ -387,6 +400,7 @@ def __init__(
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
if blocksparse_params is not None:
@@ -408,6 +422,7 @@ def __init__(
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0.
self.logits_soft_cap = logits_soft_cap
+ self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -478,22 +493,25 @@ def forward(
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
- # Reshape the input keys and values and store them in the cache.
- # NOTE(woosuk): Here, key and value are padded while slot_mapping is
- # not padded. However, we don't need to do key[:num_actual_tokens] and
- # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
- # the slot_mapping's shape to determine the number of actual tokens.
key_cache, value_cache = kv_cache.unbind(0)
- torch.ops._C_cache_ops.reshape_and_cache_flash(
- key,
- value,
- key_cache,
- value_cache,
- attn_metadata.slot_mapping,
- self.kv_cache_dtype,
- layer._k_scale,
- layer._v_scale,
- )
+ if self.kv_sharing_target_layer_name is None:
+ # Reshape the input keys and values and store them in the cache.
+ # Skip this if sharing KV cache with an earlier attention layer.
+ # NOTE(woosuk): Here, key and value are padded while slot_mapping is
+ # not padded. However, we don't need to do key[:num_actual_tokens]
+ # and value[:num_actual_tokens] because the reshape_and_cache_flash
+ # op uses the slot_mapping's shape to determine the number of
+ # actual tokens.
+ torch.ops._C_cache_ops.reshape_and_cache_flash(
+ key,
+ value,
+ key_cache,
+ value_cache,
+ attn_metadata.slot_mapping,
+ self.kv_cache_dtype,
+ layer._k_scale,
+ layer._v_scale,
+ )
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(torch.float8_e4m3fnuz)
@@ -541,7 +559,8 @@ def forward(
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
- cu_seqlens_k=cu_seq_lens,
+ cu_seqlens_k=(cu_seq_lens if not use_local_attn else
+ local_metadata.local_cu_seq_lens),
)
_, num_heads, head_size = query.shape
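
`local_cu_seq_lens` follows the usual cu_seqlens convention for varlen attention kernels: a zero-prefixed cumulative sum of the per-sequence KV lengths. A minimal sketch (illustrative values only):

```python
import torch

virt_k_seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)
local_cu_seq_lens = torch.zeros(virt_k_seqlens.numel() + 1, dtype=torch.int32)
local_cu_seq_lens[1:] = torch.cumsum(virt_k_seqlens, dim=0)
print(local_cu_seq_lens)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
```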
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index 6f31031a1086..efc5b3012ec2 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -83,29 +83,27 @@ def anon_repr(self):
@dataclass
class CachedRequestData:
- req_id: str
+ req_ids: list[str]
# If resumed_from_preemption is False, new_block_ids will be appended to
# the request's block IDs. If True, new_block_ids will be used as the
# request's block IDs instead of appending to the existing block IDs.
- resumed_from_preemption: bool
- new_token_ids: list[int]
- new_block_ids: tuple[list[int], ...]
- num_computed_tokens: int
+ resumed_from_preemption: list[bool]
+ new_token_ids: list[list[int]]
+ new_block_ids: list[tuple[list[int], ...]]
+ num_computed_tokens: list[int]
+
+ @property
+ def num_reqs(self) -> int:
+ return len(self.req_ids)
@classmethod
- def from_request(
- cls,
- request: Request,
- resumed_from_preemption: bool,
- new_token_ids: list[int],
- new_block_ids: tuple[list[int], ...],
- ) -> CachedRequestData:
+ def make_empty(cls) -> CachedRequestData:
return cls(
- req_id=request.request_id,
- resumed_from_preemption=resumed_from_preemption,
- new_token_ids=new_token_ids,
- new_block_ids=new_block_ids,
- num_computed_tokens=request.num_computed_tokens,
+ req_ids=[],
+ resumed_from_preemption=[],
+ new_token_ids=[],
+ new_block_ids=[],
+ num_computed_tokens=[],
)
@@ -119,7 +117,7 @@ class SchedulerOutput:
# list of the requests that have been scheduled before.
# Since the request's data is already cached in the worker processes,
# we only send the diff to minimize the communication cost.
- scheduled_cached_reqs: list[CachedRequestData]
+ scheduled_cached_reqs: CachedRequestData
# req_id -> num_scheduled_tokens
# Number of tokens scheduled for each request.
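
To make the new columnar layout concrete, here is how a consumer walks the parallel lists of a batched `CachedRequestData`; the values are made up for illustration and this is not code from the patch:

```python
req_ids = ["req-0", "req-1"]
resumed_from_preemption = [False, True]
new_token_ids = [[11], [7, 8]]
new_block_ids = [([3],), ([40, 41],)]
num_computed_tokens = [16, 0]

for req_id, resumed, tokens, blocks, computed in zip(
        req_ids, resumed_from_preemption, new_token_ids, new_block_ids,
        num_computed_tokens):
    # Resumed-from-preemption requests replace their block ids; running
    # requests append the new block ids to the existing ones.
    action = "replace" if resumed else "append"
    print(f"{req_id}: {action} {blocks}, {len(tokens)} new token(s), "
          f"{computed} tokens already computed")
```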
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 00b0844a5660..20a40d74f311 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -3,8 +3,9 @@
from __future__ import annotations
+import itertools
import time
-from collections import defaultdict, deque
+from collections import defaultdict
from collections.abc import Iterable
from typing import Any, Optional, Union
@@ -117,12 +118,6 @@ def __init__(
# KV Connector: requests in process of async KV loading or recving
self.finished_recving_kv_req_ids: set[str] = set()
- # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
- # them at each scheduling step.
- # Request id -> deque of CachedRequestData
- self._cached_reqs_data: dict[
- str, deque[CachedRequestData]] = defaultdict(deque)
-
# Encoder-related.
# Calculate encoder cache size if applicable
# NOTE: For now we use the same budget for both compute and space.
@@ -547,27 +542,16 @@ def schedule(self) -> SchedulerOutput:
req_to_new_block_ids[req.request_id])
for req in scheduled_new_reqs
]
- resumed_reqs_data = [
- self._make_cached_request_data(
- req,
- num_scheduled_tokens[req.request_id],
- len(scheduled_spec_decode_tokens.get(req.request_id, ())),
- req_to_new_block_ids[req.request_id],
- resumed_from_preemption=True,
- ) for req in scheduled_resumed_reqs
- ]
- running_reqs_data = [
- self._make_cached_request_data(
- req,
- num_scheduled_tokens[req.request_id],
- len(scheduled_spec_decode_tokens.get(req.request_id, ())),
- req_to_new_block_ids[req.request_id],
- resumed_from_preemption=False,
- ) for req in scheduled_running_reqs
- ]
+ cached_reqs_data = self._make_cached_request_data(
+ scheduled_running_reqs,
+ scheduled_resumed_reqs,
+ num_scheduled_tokens,
+ scheduled_spec_decode_tokens,
+ req_to_new_block_ids,
+ )
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
- scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
+ scheduled_cached_reqs=cached_reqs_data,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
@@ -613,34 +597,39 @@ def schedule(self) -> SchedulerOutput:
def _make_cached_request_data(
self,
- request: Request,
- num_scheduled_tokens: int,
- num_scheduled_spec_tokens: int,
- new_block_ids: tuple[list[int], ...],
- resumed_from_preemption: bool,
+ running_reqs: list[Request],
+ resumed_reqs: list[Request],
+ num_scheduled_tokens: dict[str, int],
+ spec_decode_tokens: dict[str, list[int]],
+ req_to_new_block_ids: dict[str, tuple[list[int], ...]],
) -> CachedRequestData:
- # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
- # them at each scheduling step.
- num_computed_tokens = request.num_computed_tokens
- num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
- new_token_ids = request.all_token_ids[
- num_computed_tokens:num_computed_tokens + num_regular_tokens]
-
- req_data_queue = self._cached_reqs_data.get(request.request_id)
- if req_data_queue:
- req_data = req_data_queue.popleft()
- req_data.resumed_from_preemption = resumed_from_preemption
- req_data.new_token_ids = new_token_ids
- req_data.new_block_ids = new_block_ids
- req_data.num_computed_tokens = num_computed_tokens
- else:
- # No cached request data, or all cached request data has been
- # used by the scheduled requests.
- req_data = CachedRequestData.from_request(request,
- resumed_from_preemption,
- new_token_ids,
- new_block_ids)
- return req_data
+ req_ids: list[str] = []
+ new_token_ids: list[list[int]] = []
+ new_block_ids: list[tuple[list[int], ...]] = []
+ num_computed_tokens: list[int] = []
+
+ for req in itertools.chain(running_reqs, resumed_reqs):
+ req_id = req.request_id
+ req_ids.append(req_id)
+ num_tokens = (num_scheduled_tokens[req_id] -
+ len(spec_decode_tokens.get(req_id, ())))
+ token_ids = req.all_token_ids[req.num_computed_tokens:req.
+ num_computed_tokens + num_tokens]
+ new_token_ids.append(token_ids)
+ new_block_ids.append(req_to_new_block_ids[req_id])
+ num_computed_tokens.append(req.num_computed_tokens)
+ # Because resumed_reqs is usually empty, it is more efficient to do
+ # in-place appending so that we don't need to allocate a new list.
+ resumed_from_preemption = [False] * len(running_reqs)
+ resumed_from_preemption += [True] * len(resumed_reqs)
+
+ return CachedRequestData(
+ req_ids=req_ids,
+ resumed_from_preemption=resumed_from_preemption,
+ new_token_ids=new_token_ids,
+ new_block_ids=new_block_ids,
+ num_computed_tokens=num_computed_tokens,
+ )
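
A small check of the ordering contract established above: running requests come first, resumed requests after, and resumed_from_preemption stays aligned with req_ids (IDs illustrative).

running = ["r0", "r1", "r2"]   # scheduled_running_reqs (ids only)
resumed = ["p0"]               # scheduled_resumed_reqs (ids only)
req_ids = running + resumed
resumed_from_preemption = [False] * len(running) + [True] * len(resumed)
assert list(zip(req_ids, resumed_from_preemption)) == [
    ("r0", False), ("r1", False), ("r2", False), ("p0", True)]
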
def _try_schedule_encoder_inputs(
self,
@@ -870,19 +859,11 @@ def update_from_output(
if not stopped:
new_running.append(request)
+ self.running = new_running
# KV Connector: update state for finished KV Transfers.
self._update_from_kv_xfer_finished(model_runner_output)
- # Return the cached request data to the queue so they can be reused.
- for req_data in scheduler_output.scheduled_cached_reqs:
- # NOTE(rob): since we free stopped reqs above, adding stopped reqs
- # to _cached_reqs_data will cause a memory leak.
- if req_data.req_id not in self.finished_req_ids:
- self._cached_reqs_data[req_data.req_id].append(req_data)
-
- self.running = new_running
-
# Create EngineCoreOutputs for all clients that have requests with
# outputs in this step.
engine_core_outputs = {
@@ -965,13 +946,11 @@ def finish_requests(
self._free_request(request)
def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
-
assert request.is_finished()
delay_free_blocks, kv_xfer_params = self._connector_finished(request)
self.encoder_cache_manager.free(request)
request_id = request.request_id
- self._cached_reqs_data.pop(request_id, None)
self.finished_req_ids.add(request_id)
if self.finished_req_ids_dict is not None:
self.finished_req_ids_dict[request.client_index].add(request_id)
@@ -983,7 +962,6 @@ def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
def _free_blocks(self, request: Request):
assert request.is_finished()
- assert request.request_id not in self._cached_reqs_data
self.kv_cache_manager.free(request)
self.kv_cache_manager.free_block_hashes(request)
del self.requests[request.request_id]
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 25fab2713114..a2328c37ba0c 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -192,6 +192,11 @@ def add_request(
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
+ # Validate the request_id type.
+ if not isinstance(request_id, str):
+ raise TypeError(
+ f"request_id must be a string, got {type(request_id)}")
+
# Process raw inputs into the request.
prompt_str, request = self.processor.process_inputs(
request_id, prompt, params, arrival_time, lora_request,
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index a0b170ba55ad..7e7703df2cf1 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -173,6 +173,12 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
params.guided_decoding.backend = engine_level_backend
# Request content validation
+ if (isinstance(params.guided_decoding.choice, list)
+ and not params.guided_decoding.choice):
+ # It is invalid for choice to be an empty list
+ raise ValueError(f"Choice '{params.guided_decoding.choice}' "
+ "cannot be an empty list")
+
if engine_level_backend.startswith("xgrammar"):
# xgrammar with no fallback
validate_xgrammar_grammar(params)
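
A minimal standalone sketch of the empty-choice validation added above; GuidedDecodingParamsStub is a stand-in for illustration, not the real vLLM class.

from dataclasses import dataclass
from typing import Optional

@dataclass
class GuidedDecodingParamsStub:
    choice: Optional[list] = None

def validate_choice(gd: GuidedDecodingParamsStub) -> None:
    # Mirrors the processor check: an empty choice list is rejected.
    if isinstance(gd.choice, list) and not gd.choice:
        raise ValueError(f"Choice '{gd.choice}' cannot be an empty list")

validate_choice(GuidedDecodingParamsStub(choice=["yes", "no"]))  # accepted
try:
    validate_choice(GuidedDecodingParamsStub(choice=[]))
except ValueError as err:
    print(err)  # Choice '[]' cannot be an empty list
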
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 40639fdf2433..29d39de212f8 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -21,6 +21,7 @@
from vllm.compilation.counter import compilation_counter
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
+from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
@@ -33,7 +34,8 @@
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
-from vllm.model_executor.models.interfaces import has_step_pooler
+from vllm.model_executor.models.interfaces import (has_step_pooler,
+ is_mixture_of_experts)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality
@@ -43,7 +45,7 @@
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
check_use_alibi, get_dtype_size,
- is_pin_memory_available)
+ is_pin_memory_available, round_up)
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
@@ -150,6 +152,13 @@ def __init__(
# Sampler
self.sampler = Sampler()
+ self.eplb_state: Optional[EplbState] = None
+ """
+ State of the expert parallelism load balancer.
+
+ Will be lazily initialized when the model is loaded.
+ """
+
# Lazy initializations
# self.model: nn.Module # Set after load_model
# Initialize in initialize_kv_cache
@@ -461,34 +470,36 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
req_ids_to_add.append(req_id)
# Update the states of the running/resumed requests.
- for req_data in scheduler_output.scheduled_cached_reqs:
- req_id = req_data.req_id
+ req_data = scheduler_output.scheduled_cached_reqs
+ for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
+ num_computed_tokens = req_data.num_computed_tokens[i]
+ new_token_ids = req_data.new_token_ids[i]
+ new_block_ids = req_data.new_block_ids[i]
+ resumed_from_preemption = req_data.resumed_from_preemption[i]
# Update the cached states.
- num_computed_tokens = req_data.num_computed_tokens
req_state.num_computed_tokens = num_computed_tokens
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec decode tokens.
- num_new_tokens = (num_computed_tokens +
- len(req_data.new_token_ids) -
+ num_new_tokens = (num_computed_tokens + len(new_token_ids) -
req_state.num_tokens)
if num_new_tokens == 1:
# Avoid slicing list in most common case.
- req_state.output_token_ids.append(req_data.new_token_ids[-1])
+ req_state.output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0:
req_state.output_token_ids.extend(
- req_data.new_token_ids[-num_new_tokens:])
+ new_token_ids[-num_new_tokens:])
# Update the block IDs.
- if not req_data.resumed_from_preemption:
+ if not resumed_from_preemption:
# Append the new blocks to the existing block IDs.
- for block_ids, new_block_ids in zip(req_state.block_ids,
- req_data.new_block_ids):
- block_ids.extend(new_block_ids)
+ for block_ids, new_ids in zip(req_state.block_ids,
+ new_block_ids):
+ block_ids.extend(new_ids)
else:
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
- req_state.block_ids = req_data.new_block_ids
+ req_state.block_ids = new_block_ids
req_index = self.input_batch.req_id_to_index.get(req_id)
if req_index is None:
@@ -501,14 +512,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens)
- self.input_batch.block_table.append_row(req_data.new_block_ids,
- req_index)
+ self.input_batch.block_table.append_row(new_block_ids, req_index)
# Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens
- end_token_index = num_computed_tokens + len(req_data.new_token_ids)
+ end_token_index = num_computed_tokens + len(new_token_ids)
self.input_batch.token_ids_cpu[
- req_index,
- start_token_index:end_token_index] = req_data.new_token_ids
+ req_index, start_token_index:end_token_index] = new_token_ids
self.input_batch.num_tokens_no_spec[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu.
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
@@ -1178,6 +1187,24 @@ def sync_and_slice_intermediate_tensors(
for k, v in self.intermediate_tensors.items()
})
+ def eplb_step(self,
+ is_dummy: bool = False,
+ is_profile: bool = False) -> None:
+ """
+ Step for the EPLB (Expert Parallelism Load Balancing) state.
+ """
+ if not self.parallel_config.enable_eplb:
+ return
+
+ assert self.eplb_state is not None
+ assert is_mixture_of_experts(self.model)
+ self.eplb_state.step(
+ self.model,
+ is_dummy,
+ is_profile,
+ log_stats=self.parallel_config.eplb_log_balancedness,
+ )
+
def get_dp_padding(self,
num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
dp_size = self.vllm_config.parallel_config.data_parallel_size
@@ -1281,7 +1308,6 @@ def execute_model(
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.compilation_config.pass_config. \
enable_sequence_parallelism and tp_size > 1:
- from vllm.utils import round_up
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else:
num_input_tokens = num_scheduled_tokens
@@ -1362,6 +1388,8 @@ def execute_model(
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
+ aux_hidden_states = None
+
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping mirco-batches
@@ -1484,25 +1512,67 @@ def execute_model(
if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
- elif self.speculative_config.method == "ngram":
+ else:
+ spec_token_ids = self.propose_draft_token_ids(
+ scheduler_output,
+ valid_sampled_token_ids,
+ sampling_metadata,
+ hidden_states,
+ sample_hidden_states,
+ aux_hidden_states,
+ spec_decode_metadata,
+ attn_metadata,
+ )
+
+ # Clear KVConnector state after all KVs are generated.
+ if has_kv_transfer_group():
+ get_kv_transfer_group().clear_connector_metadata()
+
+ self.eplb_step()
+
+ return ModelRunnerOutput(
+ req_ids=self.input_batch.req_ids,
+ req_id_to_index=self.input_batch.req_id_to_index,
+ sampled_token_ids=valid_sampled_token_ids,
+ spec_token_ids=spec_token_ids,
+ logprobs=logprobs_lists,
+ prompt_logprobs_dict=prompt_logprobs_dict,
+ pooler_output=[],
+ finished_sending=finished_sending,
+ finished_recving=finished_recving,
+ num_nans_in_logits=num_nans_in_logits,
+ )
+
+ def propose_draft_token_ids(
+ self,
+ scheduler_output: "SchedulerOutput",
+ sampled_token_ids: list[list[int]],
+ sampling_metadata: SamplingMetadata,
+ hidden_states: torch.Tensor,
+ sample_hidden_states: torch.Tensor,
+ aux_hidden_states: Optional[torch.Tensor],
+ spec_decode_metadata: Optional[SpecDecodeMetadata],
+ attn_metadata: dict[str, Any],
+ ) -> list[list[int]]:
+ num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+ if self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer)
- spec_token_ids = self.generate_draft_token_ids(
- valid_sampled_token_ids, sampling_metadata)
+ spec_token_ids = self.propose_ngram_draft_token_ids(
+ sampled_token_ids)
elif self.speculative_config.method == "medusa":
assert isinstance(self.drafter, MedusaProposer)
- if max_gen_len == 1:
+ if sample_hidden_states.shape[0] == len(sampled_token_ids):
+ # The input to the target model does not include draft tokens.
hidden_states = sample_hidden_states
else:
indices = []
offset = 0
for num_draft, tokens in zip(
spec_decode_metadata.num_draft_tokens,
- valid_sampled_token_ids):
+ sampled_token_ids):
indices.append(offset + len(tokens) - 1)
offset += num_draft + 1
-
- indices = torch.tensor(indices,
- device=sample_hidden_states.device)
+ indices = torch.tensor(indices, device=self.device)
hidden_states = sample_hidden_states[indices]
spec_token_ids = self.drafter.propose(
@@ -1513,7 +1583,7 @@ def execute_model(
assert isinstance(self.drafter, EagleProposer)
# TODO(woosuk): Refactor the loop.
next_token_ids: list[int] = []
- for i, token_ids in enumerate(valid_sampled_token_ids):
+ for i, token_ids in enumerate(sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
@@ -1543,7 +1613,8 @@ def execute_model(
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
- target_positions = positions[:num_scheduled_tokens]
+ # TODO(woosuk): Support M-RoPE.
+ target_positions = self.positions[:num_scheduled_tokens]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states],
@@ -1556,7 +1627,7 @@ def execute_model(
# TODO(woosuk): Refactor this.
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
- n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
+ n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens_tensor = async_tensor_h2d(
@@ -1571,7 +1642,8 @@ def execute_model(
num_tokens,
)
target_token_ids = self.input_ids[token_indices]
- target_positions = positions[token_indices]
+ # TODO(woosuk): Support M-RoPE.
+ target_positions = self.positions[token_indices]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1)
@@ -1590,23 +1662,7 @@ def execute_model(
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()
-
- # Clear KVConnector state after all KVs are generated.
- if has_kv_transfer_group():
- get_kv_transfer_group().clear_connector_metadata()
-
- return ModelRunnerOutput(
- req_ids=self.input_batch.req_ids,
- req_id_to_index=self.input_batch.req_id_to_index,
- sampled_token_ids=valid_sampled_token_ids,
- spec_token_ids=spec_token_ids,
- logprobs=logprobs_lists,
- prompt_logprobs_dict=prompt_logprobs_dict,
- pooler_output=[],
- finished_sending=finished_sending,
- finished_recving=finished_recving,
- num_nans_in_logits=num_nans_in_logits,
- )
+ return spec_token_ids
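
The rejected-token arithmetic in the EAGLE branch above can be checked with a small worked example (values illustrative): the target model verifies n draft tokens plus one bonus token per request, and sampled_token_ids[i] holds the tokens that were actually accepted.

num_draft_tokens = [3, 2, 0]                    # drafts proposed per request
sampled_token_ids = [[5, 9], [4, 4, 7], [11]]   # accepted (+ bonus) tokens
num_rejected_tokens = [
    n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
    for i, n in enumerate(num_draft_tokens)
]
assert num_rejected_tokens == [2, 0, 0]
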
def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
@@ -1654,10 +1710,9 @@ def get_finished_kv_transfers(
scheduler_output.finished_req_ids)
return None, None
- def generate_draft_token_ids(
+ def propose_ngram_draft_token_ids(
self,
sampled_token_ids: list[list[int]],
- sampling_metadata: SamplingMetadata,
) -> list[list[int]]:
# TODO(woosuk): Optimize.
draft_token_ids: list[list[int]] = []
@@ -1729,6 +1784,16 @@ def load_model(self) -> None:
time_after_load - time_before_load)
prepare_communication_buffer_for_model(self.model)
+ if is_mixture_of_experts(
+ self.model) and self.parallel_config.enable_eplb:
+ logger.info("EPLB is enabled for model %s.",
+ self.model_config.model)
+ self.eplb_state = EplbState.build(
+ self.model,
+ self.device,
+ self.parallel_config,
+ )
+
def save_tensorized_model(
self,
tensorizer_config: "TensorizerConfig",
@@ -1887,6 +1952,8 @@ def _dummy_run(
self,
num_tokens: int,
capture_attn_cudagraph: bool = False,
+ skip_eplb: bool = False,
+ is_profile: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
# Padding for DP
@@ -1983,6 +2050,16 @@ def _dummy_run(
assert isinstance(self.drafter, EagleProposer)
self.drafter.dummy_run(num_tokens)
+ # For dummy runs, we typically skip EPLB since there are no real
+ # requests to process.
+ # However, in DP settings, some DP ranks may have no requests and
+ # therefore execute dummy batches. In such cases, we still have to
+ # trigger EPLB so that all ranks execute the rearrangement in sync;
+ # otherwise the ranks with real requests would block on DP.
+ if not skip_eplb:
+ self.eplb_step(is_dummy=True, is_profile=is_profile)
+
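
A toy, non-distributed sketch (not vLLM code) of why dummy batches still step EPLB under data parallelism: every rank must advance the same step counter so the rearrangement fires in lockstep, regardless of whether a rank had real work.

class ToyEplbState:
    def __init__(self, window: int = 4):
        self.window = window
        self.step_count = 0

    def step(self, is_dummy: bool = False) -> bool:
        # Dummy steps contribute no load statistics but still advance the
        # counter, so every rank reaches the rearrangement point together.
        self.step_count += 1
        return self.step_count % self.window == 0

states = [ToyEplbState() for _ in range(2)]       # two DP ranks
schedules = [[False, False, True, False],          # rank 0: one dummy step
             [False, True, False, False]]          # rank 1: one dummy step
for step in range(4):
    decisions = [s.step(is_dummy=sched[step])
                 for s, sched in zip(states, schedules)]
    assert len(set(decisions)) == 1                # ranks stay in lockstep
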
logit_indices = np.cumsum(num_scheduled_tokens) - 1
return hidden_states, hidden_states[logit_indices]
@@ -2175,8 +2252,9 @@ def profile_run(self) -> None:
# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
+ # Add `is_profile` here to pre-allocate communication buffers
hidden_states, last_hidden_states \
- = self._dummy_run(self.max_num_tokens)
+ = self._dummy_run(self.max_num_tokens, is_profile=True)
if get_pp_group().is_last_rank:
if self.is_pooling_model:
output = self._dummy_pooler_run(hidden_states)
@@ -2210,10 +2288,15 @@ def capture_model(self) -> None:
for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes),
desc="Capturing CUDA graphs",
total=len(self.cudagraph_batch_sizes)):
+ # We skip EPLB here since we don't want to record dummy metrics
for _ in range(
self.compilation_config.cudagraph_num_of_warmups):
- self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg)
- self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg)
+ self._dummy_run(num_tokens,
+ capture_attn_cudagraph=full_cg,
+ skip_eplb=True)
+ self._dummy_run(num_tokens,
+ capture_attn_cudagraph=full_cg,
+ skip_eplb=True)
end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index b0f80c701325..9e7e44d06861 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -259,9 +259,10 @@ def compile_or_warm_up_model(self) -> None:
x for x in warmup_sizes if x not in
self.vllm_config.compilation_config.cudagraph_capture_sizes
]
+ # We skip EPLB here since we don't want to record dummy metrics
for size in sorted(warmup_sizes, reverse=True):
logger.info("Compile and warming up model for size %d", size)
- self.model_runner._dummy_run(size)
+ self.model_runner._dummy_run(size, skip_eplb=True)
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
@@ -274,8 +275,12 @@ def compile_or_warm_up_model(self) -> None:
max_num_reqs = min(self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens)
+ # We skip EPLB here since we don't want to record dummy metrics
hidden_states, last_hidden_states = \
- self.model_runner._dummy_run(num_tokens=max_num_reqs)
+ self.model_runner._dummy_run(
+ num_tokens=max_num_reqs,
+ skip_eplb=True,
+ )
if self.model_runner.is_pooling_model:
self.model_runner._dummy_pooler_run(hidden_states)
else:
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 774caa1a3d98..0cc218bdb646 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -37,8 +37,8 @@
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec,
SlidingWindowSpec)
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
- ModelRunnerOutput)
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
+ LogprobsTensors, ModelRunnerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache
@@ -53,12 +53,11 @@
logger = init_logger(__name__)
-# Here we utilize the behavior that out-of-bound index is ignored.
-# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
-_PAD_SLOT_ID = 1_000_000_000
INVALID_TOKEN_ID = -1
# Smallest output size
MIN_NUM_SEQS = 8
+# Block size used for the KV cache update kernel
+NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8
#########################################################
@@ -150,7 +149,11 @@ def __init__(
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_model_len = model_config.max_model_len
+ self.most_model_len = envs.VLLM_TPU_MOST_MODEL_LEN
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
+ self.num_blocks_per_most_len_req = cdiv(
+ self.most_model_len,
+ self.block_size) if self.most_model_len is not None else None
# InputBatch needs to work with sampling tensors greater than padding
# to avoid dynamic shapes. Also, avoid suboptimal alignment.
self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)
@@ -220,12 +223,19 @@ def __init__(
dtype=torch.int32,
device="cpu")
self.positions_np = self.positions_cpu.numpy()
-
self.block_table_cpu = torch.zeros(
(self.max_num_reqs, self.max_num_blocks_per_req),
dtype=torch.int32,
device="cpu")
-
+ # Adjust num_reqs to avoid SMEM OOM.
+ self.num_reqs_most_model_len = min(
+ PallasAttentionBackend.get_max_num_seqs(self.most_model_len,
+ self.block_size),
+ self.max_num_reqs) if self.most_model_len is not None else None
+ self.num_reqs_max_model_len = min(
+ PallasAttentionBackend.get_max_num_seqs(self.max_model_len,
+ self.block_size),
+ self.max_num_reqs)
self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1,
dtype=torch.int32,
device="cpu",
@@ -408,21 +418,24 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
req_ids_to_add.append(req_id)
# Update the states of the running/resumed requests.
- for req_data in scheduler_output.scheduled_cached_reqs:
- req_id = req_data.req_id
+ req_data = scheduler_output.scheduled_cached_reqs
+ for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
+ num_computed_tokens = req_data.num_computed_tokens[i]
+ new_block_ids = req_data.new_block_ids[i]
+ resumed_from_preemption = req_data.resumed_from_preemption[i]
# Update the cached states.
- req_state.num_computed_tokens = req_data.num_computed_tokens
- if not req_data.resumed_from_preemption:
+ req_state.num_computed_tokens = num_computed_tokens
+ if not resumed_from_preemption:
# Append the new blocks to the existing block IDs.
- for block_ids, new_block_ids in zip(req_state.block_ids,
- req_data.new_block_ids):
- block_ids.extend(new_block_ids)
+ for block_ids, new_ids in zip(req_state.block_ids,
+ new_block_ids):
+ block_ids.extend(new_ids)
else:
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
- req_state.block_ids = req_data.new_block_ids
+ req_state.block_ids = new_block_ids
req_index = self.input_batch.req_id_to_index.get(req_id)
if req_index is None:
@@ -434,9 +447,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
# Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = (
- req_data.num_computed_tokens)
- self.input_batch.block_table.append_row(req_data.new_block_ids,
- req_index)
+ num_computed_tokens)
+ self.input_batch.block_table.append_row(new_block_ids, req_index)
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
@@ -515,25 +527,113 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return kv_cache_spec
- def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
- total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
- assert total_num_scheduled_tokens > 0
+ def _get_slot_mapping_metadata(self, num_reqs,
+ num_scheduled_tokens_per_req):
+ """
+ Computes metadata for mapping slots to blocks in the key-value (KV)
+ cache for a batch of requests.
+
+ This function determines, for each request in the batch, how the
+ scheduled tokens are distributed across memory blocks, and generates
+ metadata needed to map slices of tokens to their corresponding positions
+ in the KV cache.
+
+ Args:
+ num_reqs (int): Number of requests in the current batch.
+ num_scheduled_tokens_per_req (int or np.ndarray): Number of tokens
+ to be scheduled for each request.
+
+ Returns:
+ np.ndarray: A 2D array of shape (total_block_len, 3), where each row
+ contains:
+ - kv_cache_start_index (int): The starting index in the KV cache
+ for the corresponding slice.
+ - new_kv_start_index (int): The starting index in the new KV
+ cache for the corresponding slice.
+ - slice_len (int): The length of the slice.
+ """
+ slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs]
+ slices_end = self.input_batch.num_computed_tokens_cpu[:num_reqs] + \
+ num_scheduled_tokens_per_req
+ local_block_start_idx = slices_start // self.block_size
+ local_block_end_idx = (slices_end - 1) // self.block_size
+ no_repeat_req_indices = self.arange_np[:num_reqs]
+ global_block_start_idx = (
+ no_repeat_req_indices * self.max_num_blocks_per_req +
+ local_block_start_idx)
+ block_lens = local_block_end_idx - local_block_start_idx + 1
+ global_block_start_idx = np.repeat(global_block_start_idx, block_lens)
+ slice_arange = np.concatenate([self.arange_np[:n] for n in block_lens])
+ global_block_indices = global_block_start_idx + slice_arange
+ block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
+ block_numbers = block_table_cpu.flatten()[global_block_indices].numpy()
+ total_block_len = np.sum(block_lens)
+ slot_mapping_slices = np.repeat(np.array([[0, self.block_size]],
+ dtype=np.int32),
+ total_block_len,
+ axis=0)
+ cu_block_lens = np.zeros(len(block_lens) + 1, dtype=np.int32)
+ np.cumsum(block_lens, out=cu_block_lens[1:])
+ for req_idx in range(num_reqs):
+ slot_mapping_slices[cu_block_lens[req_idx]][
+ 0] = slices_start[req_idx] % self.block_size
+ slot_mapping_slices[
+ cu_block_lens[req_idx + 1] -
+ 1][1] = (slices_end[req_idx] - 1) % self.block_size + 1
+ slice_lens = slot_mapping_slices[:, 1] - slot_mapping_slices[:, 0]
+ cu_slices_lens = np.zeros(len(slice_lens) + 1, dtype=np.int32)
+ np.cumsum(slice_lens, out=cu_slices_lens[1:])
+ kv_cache_start_indices = slot_mapping_slices[:, 0] + \
+ (block_numbers * self.block_size)
+ new_kv_start_indices = cu_slices_lens[:-1]
+ slot_mapping_metadata = np.stack(
+ [kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1)
+ return slot_mapping_metadata
+
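
A single-request walk-through of the slot-mapping metadata computed above (block size and physical block numbers are illustrative; the real code vectorizes this over the whole batch with numpy).

import numpy as np

# block_size=4; the request already has 6 computed tokens and receives 5
# new tokens, so it touches local blocks 1 and 2 of its block-table row.
block_size = 4
num_computed, num_scheduled = 6, 5
block_table_row = np.array([2, 7, 9, 0])   # physical block per local block

start, end = num_computed, num_computed + num_scheduled
local_blocks = np.arange(start // block_size, (end - 1) // block_size + 1)
rows = []
new_kv_offset = 0
for b in local_blocks:
    lo = start % block_size if b == local_blocks[0] else 0
    hi = (end - 1) % block_size + 1 if b == local_blocks[-1] else block_size
    rows.append((block_table_row[b] * block_size + lo, new_kv_offset, hi - lo))
    new_kv_offset += hi - lo

# Each row: (kv_cache_start_index, new_kv_start_index, slice_len)
assert rows == [(30, 0, 2), (36, 2, 3)]
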
+ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
+ start_index: int):
+ assert scheduler_output.total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
+ assert start_index < num_reqs
# Get the number of scheduled tokens for each request.
+ use_max_model_len = self.most_model_len is None
num_scheduled_tokens_per_req = []
max_num_scheduled_tokens_all_reqs = 0
- for req_id in self.input_batch.req_ids[:num_reqs]:
+ end_index = start_index
+
+ # Use either most_model_len or max_model_len depending on request size.
+ for i in range(start_index, num_reqs):
+ req_id = self.input_batch.req_ids[i]
assert req_id is not None
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+ if not use_max_model_len and num_tokens > self.most_model_len:
+ use_max_model_len = True
num_scheduled_tokens_per_req.append(num_tokens)
- max_num_scheduled_tokens_all_reqs = max(
- max_num_scheduled_tokens_all_reqs, num_tokens)
+ if use_max_model_len:
+ if len(num_scheduled_tokens_per_req) > self.num_reqs_max_model_len:
+ num_scheduled_tokens_per_req = \
+ num_scheduled_tokens_per_req[:self.num_reqs_max_model_len]
+ end_index = start_index + self.num_reqs_max_model_len
+ else:
+ end_index = num_reqs
+ else:
+ if len(num_scheduled_tokens_per_req
+ ) > self.num_reqs_most_model_len:
+ num_scheduled_tokens_per_req = \
+ num_scheduled_tokens_per_req[:self.num_reqs_most_model_len]
+ end_index = start_index + self.num_reqs_most_model_len
+ else:
+ end_index = num_reqs
+ max_num_scheduled_tokens_all_reqs = max(num_scheduled_tokens_per_req)
num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req,
dtype=np.int32)
+ total_num_scheduled_tokens = sum(num_scheduled_tokens_per_req)
assert max_num_scheduled_tokens_all_reqs > 0
+ num_reqs = len(num_scheduled_tokens_per_req)
+
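
A simplified approximation of the bucket-selection rule above (function and limit names are hypothetical; the real loop flips to the max-model-len configuration while iterating and then truncates the slice): as soon as any request in the slice needs more tokens than VLLM_TPU_MOST_MODEL_LEN, the max-model-len configuration is used, and the slice is capped at that configuration's request limit.

def pick_bucket(num_tokens_per_req, most_model_len, n_most, n_max):
    use_max = most_model_len is None or any(
        t > most_model_len for t in num_tokens_per_req)
    limit = n_max if use_max else n_most
    return ("max" if use_max else "most"), num_tokens_per_req[:limit]

assert pick_bucket([16, 16, 4096], most_model_len=2048, n_most=64, n_max=8) \
    == ("max", [16, 16, 4096])
assert pick_bucket([16, 16, 16], most_model_len=2048, n_most=2, n_max=8) \
    == ("most", [16, 16])
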
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
# For each scheduled token, what are the corresponding req index.
@@ -567,26 +667,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
torch.from_numpy(token_indices),
out=self.input_ids_cpu[:total_num_scheduled_tokens])
- # Calculate the slot mapping.
- # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
- # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
- # where K is the max_num_blocks_per_req and the block size is 2.
- # NOTE(woosuk): We can't simply use `token_indices // block_size` here
- # because M (max_model_len) is not necessarily divisible by block_size.
- # req_indices: # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
- block_table_indices = (req_indices * self.max_num_blocks_per_req +
- positions_np // self.block_size)
- # NOTE(woosuk): We use torch.index_select instead of np.take here
- # because torch.index_select is much faster than np.take for large
- # tensors.
- block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
- block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
- block_offsets = positions_np % self.block_size
- np.add(block_numbers * self.block_size,
- block_offsets,
- out=self.input_batch.block_table[0].
- slot_mapping_np[:total_num_scheduled_tokens])
-
# Prepare the attention metadata.
self.query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens_per_req,
@@ -609,19 +689,42 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
self.position_ids = self.positions_cpu[:
padded_total_num_scheduled_tokens].to(
self.device)
- self.input_batch.block_table[0].slot_mapping_cpu[
- total_num_scheduled_tokens:] = _PAD_SLOT_ID
- slot_mapping = (
- self.input_batch.block_table[0].
- slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(
- self.device))
- block_tables = self.block_table_cpu[:self.max_num_reqs]
- block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
- self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
+ if use_max_model_len:
+ block_tables = self.block_table_cpu[:self.num_reqs_max_model_len, :
+ self.max_num_blocks_per_req]
+ block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
+ self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
+ query_start_loc = self.query_start_loc_cpu[:self.
+ num_reqs_max_model_len +
+ 1].to(self.device)
+ seq_lens = self.seq_lens_cpu[:self.num_reqs_max_model_len].to(
+ self.device)
+ else:
+ block_tables = self.block_table_cpu[:self.
+ num_reqs_most_model_len, :self.
+ num_blocks_per_most_len_req]
+ block_tables[:num_reqs, :self.num_blocks_per_most_len_req] = (
+ self.input_batch.block_table[0].get_cpu_tensor()
+ [:num_reqs, :self.num_blocks_per_most_len_req])
+ query_start_loc = self.query_start_loc_cpu[:self.
+ num_reqs_most_model_len +
+ 1].to(self.device)
+ seq_lens = self.seq_lens_cpu[:self.num_reqs_most_model_len].to(
+ self.device)
block_tables = block_tables.to(self.device)
- query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
- self.device)
- seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)
+
+ slot_mapping_metadata = self._get_slot_mapping_metadata(
+ num_reqs, num_scheduled_tokens_per_req)
+ padded_num_slices = _get_padded_num_kv_cache_update_slices(
+ padded_total_num_scheduled_tokens, self.max_num_reqs,
+ self.block_size)
+ slot_mapping_metadata = np.pad(
+ slot_mapping_metadata,
+ [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
+ constant_values=0)
+ slot_mapping_metadata = np.transpose(slot_mapping_metadata)
+ slot_mapping_metadata = torch.tensor(slot_mapping_metadata,
+ device=self.device)
if self.lora_config is not None:
# We need to respect padding when activating LoRA adapters
@@ -635,13 +738,15 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
padded_num_scheduled_tokens_per_req)
attn_metadata = PallasMetadata(
- slot_mapping=slot_mapping,
+ slot_mapping=slot_mapping_metadata,
block_tables=block_tables,
context_lens=seq_lens,
query_start_loc=query_start_loc,
num_seqs=torch.tensor([num_reqs],
dtype=torch.int32,
device=self.device),
+ num_slices_per_kv_cache_update_block=
+ NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
)
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
# request in the batch. While we should not sample any token from this
@@ -672,7 +777,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
layer_name: attn_metadata
for layer_name in layer_names
}
- return per_layer_attn_metadata, logits_indices, padded_num_reqs
+ return per_layer_attn_metadata, logits_indices, padded_num_reqs,\
+ num_reqs, end_index
def _scatter_placeholders(
self,
@@ -847,52 +953,84 @@ def execute_model(
else:
mm_embeds = []
xm.mark_step()
- # Prepare inputs
- attn_metadata, logits_indices, padded_num_reqs = self._prepare_inputs(
- scheduler_output)
- input_ids, inputs_embeds = self._get_model_inputs(
- self.input_ids, mm_embeds)
- xm.mark_step()
- num_reqs = self.input_batch.num_reqs
- # Run the decoder
- with set_forward_context(
- attn_metadata,
- self.vllm_config,
- num_tokens=scheduler_output.total_num_scheduled_tokens):
- hidden_states = self.model(
- input_ids=input_ids,
- positions=self.position_ids,
- inputs_embeds=inputs_embeds,
- )
- hidden_states = self.select_hidden_states(hidden_states,
- logits_indices)
- logits = self.compute_logits(hidden_states)
- tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
- from_input_batch(self.input_batch, padded_num_reqs, self.device)
- if scheduler_output.grammar_bitmask is not None:
- require_struct_decoding, grammar_bitmask_padded, arange = \
- self.prepare_structured_decoding_input(logits, scheduler_output)
- logits = self.structured_decode(require_struct_decoding,
- grammar_bitmask_padded, logits,
- arange)
- selected_token_ids = self.sample_from_logits_func(
- logits, tpu_sampling_metadata)
- # NOTE (NickLucche) Use the original logits (before any penalties or
- # temperature scaling) for the top-k logprobs. We can't enforce it due
- # to recompilations outside torch.compiled code, so just make sure
- # `sample_from_logits` does not modify the logits in-place.
- logprobs = self.gather_logprobs(logits, selected_token_ids) \
- if tpu_sampling_metadata.logprobs else None
-
- # Remove padding on cpu and keep dynamic op outside of xla graph.
- selected_token_ids = selected_token_ids.cpu()[:num_reqs]
- logprobs_lists = logprobs.tolists() \
- if tpu_sampling_metadata.logprobs else None
+ # Prepare inputs. The requests might be split into multiple
+ # executions; combine the results of each execution.
+ start_index = 0
+ combined_selected_tokens: list[torch.Tensor] = []
+ combined_logprobs: list[LogprobsLists] = []
+ while start_index < self.input_batch.num_reqs:
+ attn_metadata, logits_indices, padded_num_reqs, num_reqs,\
+ end_index = self._prepare_inputs(scheduler_output, start_index)
+ input_ids, inputs_embeds = self._get_model_inputs(
+ self.input_ids, mm_embeds)
+ xm.mark_step()
+ # Run the decoder
+ with set_forward_context(
+ attn_metadata,
+ self.vllm_config,
+ num_tokens=scheduler_output.total_num_scheduled_tokens):
+ hidden_states = self.model(
+ input_ids=input_ids,
+ positions=self.position_ids,
+ inputs_embeds=inputs_embeds,
+ )
+ hidden_states = self.select_hidden_states(hidden_states,
+ logits_indices)
+ logits = self.compute_logits(hidden_states)
+ tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
+ from_input_batch(self.input_batch, padded_num_reqs, self.device)
+ if scheduler_output.grammar_bitmask is not None:
+ require_struct_decoding, grammar_bitmask_padded, arange = \
+ self.prepare_structured_decoding_input(logits,
+ scheduler_output)
+ logits = self.structured_decode(require_struct_decoding,
+ grammar_bitmask_padded, logits,
+ arange)
+ selected_token_ids = self.sample_from_logits_func(
+ logits, tpu_sampling_metadata)
+ # NOTE (NickLucche) Use the original logits (before any penalties or
+ # temperature scaling) for the top-k logprobs. We can't enforce it
+ # due to recompilations outside torch.compiled code, so just make
+ # sure `sample_from_logits` does not modify the logits in-place.
+ logprobs = self.gather_logprobs(logits, selected_token_ids) \
+ if tpu_sampling_metadata.logprobs else None
+
+ # Remove padding on cpu and keep dynamic op outside of xla graph.
+ selected_token_ids = selected_token_ids.cpu()[:num_reqs]
+
+ combined_selected_tokens.append(selected_token_ids)
+ if tpu_sampling_metadata.logprobs:
+ combined_logprobs.append(logprobs.tolists())
+
+ start_index = end_index
+
+ selected_token_ids = torch.cat(combined_selected_tokens, dim=0)
+ if tpu_sampling_metadata.logprobs:
+
+ def concat_lists(input_lists):
+ result = []
+ for input_list in input_lists:
+ result.extend(input_list)
+ return result
+
+ logprobs_lists = LogprobsLists(logprob_token_ids=concat_lists(
+ [lp.logprob_token_ids for lp in combined_logprobs]),
+ logprobs=concat_lists([
+ lp.logprobs
+ for lp in combined_logprobs
+ ]),
+ sampled_token_ranks=concat_lists([
+ lp.sampled_token_ranks
+ for lp in combined_logprobs
+ ]))
+ else:
+ logprobs_lists = None
# Update the cache state concurrently. Code above will not block until
# we use `selected_token_ids`. Add mark_step if post-processing changes
request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
discard_sampled_tokens_req_indices = []
+ num_reqs = self.input_batch.num_reqs
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
assert req_id is not None
req_state = self.requests[req_id]
@@ -1020,7 +1158,8 @@ def load_model(self) -> None:
self.sampler = TPUSampler()
@torch.no_grad()
- def _dummy_run(self, num_tokens: int) -> None:
+ def _dummy_run(self, num_tokens: int, num_reqs: int,
+ num_blocks: int) -> None:
if self.is_multimodal_model:
input_ids = None
inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
@@ -1030,20 +1169,21 @@ def _dummy_run(self, num_tokens: int) -> None:
input_ids = torch.zeros((num_tokens),
dtype=torch.int32).to(self.device)
inputs_embeds = None
- actual_num_reqs = min(num_tokens, self.max_num_reqs)
+ actual_num_reqs = min(num_tokens, num_reqs)
position_ids = torch.zeros(num_tokens,
dtype=torch.int32).to(self.device)
- slot_mapping = torch.zeros(num_tokens,
- dtype=torch.int64).to(self.device)
- block_tables = torch.zeros(
- (self.max_num_reqs, self.block_table_cpu.shape[1]),
- dtype=torch.int32).to(self.device)
- query_lens = [1] * self.max_num_reqs
+ padded_num_slices = _get_padded_num_kv_cache_update_slices(
+ num_tokens, self.max_num_reqs, self.block_size)
+ slot_mapping = torch.zeros((3, padded_num_slices),
+ dtype=torch.int32).to(self.device)
+ block_tables = torch.zeros((num_reqs, num_blocks),
+ dtype=torch.int32).to(self.device)
+ query_lens = [1] * num_reqs
query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
dtype=torch.int32),
dim=0,
dtype=torch.int32).to(self.device)
- context_lens = torch.ones((self.max_num_reqs, ),
+ context_lens = torch.ones((num_reqs, ),
dtype=torch.int32).to(self.device)
num_seqs = torch.tensor([actual_num_reqs],
dtype=torch.int32).to(self.device)
@@ -1053,6 +1193,8 @@ def _dummy_run(self, num_tokens: int) -> None:
context_lens=context_lens,
query_start_loc=query_start_loc,
num_seqs=num_seqs,
+ num_slices_per_kv_cache_update_block=
+ NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
)
if self.is_multimodal_model:
@@ -1061,6 +1203,9 @@ def _dummy_run(self, num_tokens: int) -> None:
torch._dynamo.mark_dynamic(input_ids, 0)
torch._dynamo.mark_dynamic(position_ids, 0)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
+ torch._dynamo.mark_dynamic(attn_metadata.block_tables, (0, 1))
+ torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
+ torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
layer_names = get_layers_from_vllm_config(self.vllm_config,
Attention).keys()
@@ -1152,7 +1297,11 @@ def _precompile_backbone(self) -> None:
start = time.perf_counter()
for num_tokens in self.num_tokens_paddings:
logger.info(" -- num_tokens: %d", num_tokens)
- self._dummy_run(num_tokens)
+ self._dummy_run(num_tokens, self.num_reqs_max_model_len,
+ self.max_num_blocks_per_req)
+ if self.most_model_len is not None:
+ self._dummy_run(num_tokens, self.num_reqs_most_model_len,
+ self.num_blocks_per_most_len_req)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in %.2f [secs].", end - start)
@@ -1341,7 +1490,11 @@ def profile_run(
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
# Trigger compilation for general shape.
- self._dummy_run(num_tokens)
+ self._dummy_run(num_tokens, self.num_reqs_max_model_len,
+ self.max_num_blocks_per_req)
+ if self.most_model_len is not None:
+ self._dummy_run(num_tokens, self.num_reqs_most_model_len,
+ self.num_blocks_per_most_len_req)
xm.mark_step()
xm.wait_device_ops()
@@ -1646,6 +1799,19 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
return paddings[index]
+def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
+ page_size: int) -> int:
+ """Calculates the padded number of KV cache update slices to avoid
+ recompilation."""
+ padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
+ padded_num_slices = min(padded_num_slices, num_tokens)
+ padded_num_slices = (
+ padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1
+ ) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \
+ NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
+ return padded_num_slices
+
+
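
Worked numbers for _get_padded_num_kv_cache_update_slices above, restated standalone under the same assumption of 8 slices per kernel block; the padding keeps the compiled shape stable across steps.

NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8

def padded_slices(num_tokens: int, max_num_reqs: int, page_size: int) -> int:
    n = min(2 * max_num_reqs + num_tokens // page_size, num_tokens)
    blk = NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
    return (n + blk - 1) // blk * blk   # round up to a multiple of blk

assert padded_slices(num_tokens=512, max_num_reqs=32, page_size=16) == 96
assert padded_slices(num_tokens=16, max_num_reqs=32, page_size=16) == 16
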
def replace_set_lora(model):
def _tpu_set_lora(
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 87af8e476707..a64ce881fe31 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -18,7 +18,8 @@
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
+from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
KVCacheSpec)
@@ -221,7 +222,17 @@ def determine_available_memory(self) -> int:
usable_memory_size = int(total_memory_size *
self.cache_config.gpu_memory_utilization)
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
-
+ head_size = self.model_config.get_head_size()
+ if head_size > 0:
+ padded_head_size = cdiv(
+ head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+ if padded_head_size != head_size:
+ logger.warning_once("head size is padded to %d",
+ padded_head_size)
+ # Scale the usable KV cache memory down by head_size/padded_head_size
+ # so that the padded head size does not cause OOM errors.
+ tpu_kv_cache_bytes = (tpu_kv_cache_bytes * head_size //
+ padded_head_size)
return int(tpu_kv_cache_bytes)
def execute_model(
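
A hedged arithmetic sketch of the head-size padding adjustment above; TPU_HEAD_SIZE_ALIGNMENT is assumed to be 128 here for illustration and may differ in vllm/v1/attention/backends/pallas.py.

from math import ceil

head_size = 96
alignment = 128                                   # assumed alignment
padded_head_size = ceil(head_size / alignment) * alignment   # -> 128
tpu_kv_cache_bytes = 8 * 1024**3                  # profiled budget (example)
adjusted = tpu_kv_cache_bytes * head_size // padded_head_size
# Each cached token occupies 128/96 of its nominal size after padding, so
# only 96/128 = 75% of the raw budget is advertised to the block allocator.
assert adjusted == 6 * 1024**3
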
diff --git a/vllm/v1/worker/xpu_model_runner.py b/vllm/v1/worker/xpu_model_runner.py
new file mode 100644
index 000000000000..55d116dcd496
--- /dev/null
+++ b/vllm/v1/worker/xpu_model_runner.py
@@ -0,0 +1,32 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import TYPE_CHECKING
+
+import torch
+
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+
+if TYPE_CHECKING:
+ pass
+
+logger = init_logger(__name__)
+
+
+class XPUModelRunner(GPUModelRunner):
+ """A model runner for XPU devices."""
+
+ def __init__(
+ self,
+ vllm_config: VllmConfig,
+ device: torch.device,
+ ):
+ super().__init__(vllm_config, device)
+ # FIXME: To be verified.
+ self.cascade_attn_enabled = False
+
+ def _init_device_properties(self) -> None:
+ pass
+
+ def _sync_device(self) -> None:
+ torch.xpu.synchronize()
diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py
new file mode 100644
index 000000000000..d9ea03986566
--- /dev/null
+++ b/vllm/v1/worker/xpu_worker.py
@@ -0,0 +1,164 @@
+# SPDX-License-Identifier: Apache-2.0
+import os
+
+import torch
+import torch.distributed
+
+import vllm.envs as envs
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor import set_random_seed
+from vllm.platforms import current_platform
+from vllm.v1.worker.gpu_worker import (Worker,
+ init_worker_distributed_environment)
+from vllm.v1.worker.xpu_model_runner import XPUModelRunner
+
+logger = init_logger(__name__)
+
+
+class XPUWorker(Worker):
+ """A XPU worker class."""
+
+ def __init__(
+ self,
+ vllm_config: VllmConfig,
+ local_rank: int,
+ rank: int,
+ distributed_init_method: str,
+ is_driver_worker: bool = False,
+ ):
+ super().__init__(vllm_config, local_rank, rank,
+ distributed_init_method, is_driver_worker)
+ device_config = self.device_config
+ assert device_config.device_type == "xpu"
+ assert current_platform.is_xpu()
+
+ # Torch profiler. Enabled and configured through env vars:
+ # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
+ if envs.VLLM_TORCH_PROFILER_DIR:
+ torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
+ logger.info("Profiling enabled. Traces will be saved to: %s",
+ torch_profiler_trace_dir)
+ self.profiler = torch.profiler.profile(
+ activities=[
+ torch.profiler.ProfilerActivity.CPU,
+ torch.profiler.ProfilerActivity.XPU,
+ ],
+ with_stack=True,
+ on_trace_ready=torch.profiler.tensorboard_trace_handler(
+ torch_profiler_trace_dir, use_gzip=True))
+ else:
+ self.profiler = None
+
+ # We provide this function because `torch.xpu.mem_get_info()` doesn't
+ # return the correct free_gpu_memory on Intel client GPUs, so we need
+ # to calculate/estimate it ourselves.
+ def xpu_get_mem_info(self):
+ if current_platform.is_data_center_gpu():
+ return torch.xpu.mem_get_info()
+ else:
+ _, total_gpu_memory = torch.xpu.mem_get_info()
+ # FIXME: memory_allocated() doesn't count non-torch allocations, and
+ # we don't have any API to get them, so we assume 128 MiB for them.
+ used_memory = torch.xpu.memory_allocated()
+ non_torch_allocations = 128 * 1024 * 1024
+ free_gpu_memory = total_gpu_memory - (used_memory +
+ non_torch_allocations)
+ return free_gpu_memory, total_gpu_memory
+
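
Illustrative numbers for the client-GPU free-memory estimate above (all sizes hypothetical): free memory is approximated as total memory minus torch-allocated memory minus the fixed 128 MiB allowance for non-torch allocations.

total_gpu_memory = 16 * 1024**3        # 16 GiB card (hypothetical)
torch_allocated = 5 * 1024**3          # what torch.xpu.memory_allocated() saw
non_torch_allowance = 128 * 1024**2    # fixed 128 MiB assumption
free_estimate = total_gpu_memory - (torch_allocated + non_torch_allowance)
assert free_estimate == int(10.875 * 1024**3)
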
+ @torch.inference_mode()
+ def determine_available_memory(self) -> int:
+ """Profiles the peak memory usage of the model to determine how many
+ KV blocks may be allocated without OOMs.
+ The engine first profiles the existing memory usage.
+ Then, it calculates the maximum possible number of GPU and CPU blocks
+ that can be allocated with the remaining free memory.
+ .. tip::
+ You may limit the usage of GPU memory
+ by adjusting the `gpu_memory_utilization` parameter.
+ """
+ # Profile the memory usage of the model and get the maximum number of
+ # cache blocks that can be allocated with the remaining free memory.
+ torch.xpu.empty_cache()
+ torch.xpu.reset_peak_memory_stats()
+
+ free_gpu_memory, total_gpu_memory = torch.xpu.mem_get_info()
+ current_allocated_bytes = torch.xpu.memory_allocated()
+ msg = ("Before memory profiling run, "
+ f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, "
+ f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, "
+ f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.")
+ logger.info(msg)
+ # Execute a forward pass with dummy inputs to profile the memory usage
+ # of the model.
+ self.model_runner.profile_run()
+
+ free_gpu_memory, _ = self.xpu_get_mem_info()
+ # NOTE(woosuk): Here we assume that the other processes using the same
+ # GPU did not change their memory usage during the profiling.
+ assert self.init_gpu_memory > free_gpu_memory, (
+ "Error in memory profiling. "
+ f"Initial free memory {self.init_gpu_memory}, current free memory"
+ f" {free_gpu_memory}. This happens when the GPU memory was "
+ "not properly cleaned up before initializing the vLLM instance.")
+
+ # Get the peak memory allocation recorded by torch
+ peak_memory = torch.xpu.memory_stats()["allocated_bytes.all.peak"]
+
+ torch.xpu.empty_cache()
+ torch_allocated_bytes = torch.xpu.memory_stats(
+ )["allocated_bytes.all.current"]
+ total_allocated_bytes = self.xpu_get_mem_info(
+ )[1] - self.xpu_get_mem_info()[0]
+
+ non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
+ if non_torch_allocations > 0:
+ peak_memory += non_torch_allocations
+ available_kv_cache_memory = (
+ total_gpu_memory * self.cache_config.gpu_memory_utilization -
+ peak_memory)
+
+ msg = ("After memory profiling run, "
+ f"peak memory usage is {peak_memory / 1024**2:.2f} MB,"
+ f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, "
+ f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, "
+ f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.")
+ logger.info(msg)
+
+ return int(available_kv_cache_memory)
+
+ def init_device(self):
+ if self.device_config.device.type == "xpu" and current_platform.is_xpu(
+ ):
+ self.device = torch.device(f"xpu:{self.local_rank}")
+ torch.xpu.set_device(self.device)
+ torch.xpu.empty_cache()
+ self.init_gpu_memory = torch.xpu.get_device_properties(
+ self.local_rank).total_memory
+ else:
+ raise RuntimeError(
+ f"Not support device type: {self.device_config.device}")
+
+ ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", "drmfd")
+ ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
+ ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
+ str(self.parallel_config.world_size))
+ os.environ["CCL_ZE_IPC_EXCHANGE"] = ENV_CCL_ZE_IPC_EXCHANGE
+ os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
+ os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
+ os.environ["LOCAL_RANK"] = str(self.local_rank)
+ dist_backend = "ccl"
+
+ init_worker_distributed_environment(self.vllm_config, self.rank,
+ self.distributed_init_method,
+ self.local_rank, dist_backend)
+
+ # global all_reduce needed for overall oneccl warm up
+ torch.distributed.all_reduce(torch.zeros(1).xpu())
+
+ # Set random seed.
+ set_random_seed(self.model_config.seed)
+
+ # Construct the model runner
+ self.model_runner = XPUModelRunner( # type: ignore
+ self.vllm_config, self.device)