diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 69e36f2804c4..a43cdf6b07ec 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -172,6 +172,8 @@ steps: - tests/v1/engine/test_engine_core_client.py - tests/distributed/test_symm_mem_allreduce.py commands: + # https://github.com/NVIDIA/nccl/issues/1838 + - export NCCL_CUMEM_HOST_ENABLE=0 # test with torchrun tp=2 and external_dp=2 - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py # test with torchrun tp=2 and pp=2 @@ -349,7 +351,8 @@ steps: - python3 offline_inference/basic/embed.py - python3 offline_inference/basic/score.py - python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048 - - python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048 + # https://github.com/vllm-project/vllm/pull/26682 uses slightly more memory in PyTorch 2.9+ causing this test to OOM in 1xL4 GPU + - python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 1536 - label: Platform Tests (CUDA) # 4min timeout_in_minutes: 15 @@ -529,7 +532,7 @@ steps: # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now # we can only upgrade after this is resolved # TODO(jerryzh168): resolve the above comment - - uv pip install --system torchao==0.13.0 + - uv pip install --system torchao==0.13.0 --index-url https://download.pytorch.org/whl/cu129 - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py - label: LM Eval Small Models # 53min @@ -970,6 +973,8 @@ steps: - tests/v1/shutdown - tests/v1/worker/test_worker_memory_snapshot.py commands: + # https://github.com/NVIDIA/nccl/issues/1838 + - export NCCL_CUMEM_HOST_ENABLE=0 - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 121bdb750de5..638c288afe9e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,7 +38,7 @@ repos: rev: 0.9.1 hooks: - id: pip-compile - args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu128, --python-platform, x86_64-manylinux_2_28] + args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --extra-index-url, https://download.pytorch.org/whl/cu129, --python-platform, x86_64-manylinux_2_28] files: ^requirements/test\.(in|txt)$ - repo: local hooks: diff --git a/CMakeLists.txt b/CMakeLists.txt index 005590445361..e1059d949653 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,8 +49,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1 # requirements.txt files and should be kept consistent. 
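The CMake comment above requires the CUDA torch pin to stay consistent with the pins in pyproject.toml and the requirements files, all of which this diff bumps to 2.9.0. Below is a small, hypothetical consistency check (not part of this change or the repo's tooling, assumed to run from the repo root) that greps the pins touched here and asserts they agree.

```python
# Hypothetical consistency check (not part of this change): make sure the
# torch pin in CMakeLists.txt matches the pins in the requirements files.
import re
from pathlib import Path

PIN_SOURCES = {
    "CMakeLists.txt": r'TORCH_SUPPORTED_VERSION_CUDA\s+"([\d.]+)"',
    "requirements/build.txt": r"^torch==([\d.]+)",
    "requirements/cuda.txt": r"^torch==([\d.]+)",
    "requirements/test.in": r"^torch==([\d.]+)",
}


def pinned_torch_versions() -> dict[str, str]:
    pins = {}
    for path, pattern in PIN_SOURCES.items():
        match = re.search(pattern, Path(path).read_text(), re.MULTILINE)
        if match:
            pins[path] = match.group(1)
    return pins


if __name__ == "__main__":
    pins = pinned_torch_versions()
    assert len(set(pins.values())) <= 1, f"torch pins disagree: {pins}"
    print(f"torch pins: {pins}")
```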
The ROCm torch # versions are derived from docker/Dockerfile.rocm # -set(TORCH_SUPPORTED_VERSION_CUDA "2.8.0") -set(TORCH_SUPPORTED_VERSION_ROCM "2.8.0") +set(TORCH_SUPPORTED_VERSION_CUDA "2.9.0") +set(TORCH_SUPPORTED_VERSION_ROCM "2.9.0") # # Try to find python package with an executable that exactly matches diff --git a/docker/Dockerfile b/docker/Dockerfile index 8f482b393c91..f43a5b749458 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -5,7 +5,7 @@ # docs/contributing/dockerfile/dockerfile.md and # docs/assets/contributing/dockerfile-stages-dependency.png -ARG CUDA_VERSION=12.8.1 +ARG CUDA_VERSION=12.9.1 ARG PYTHON_VERSION=3.12 # By parameterizing the base images, we allow third-party to use their own @@ -273,7 +273,7 @@ WORKDIR /vllm-workspace ENV DEBIAN_FRONTEND=noninteractive ARG TARGETPLATFORM -ARG GDRCOPY_CUDA_VERSION=12.8 +ARG GDRCOPY_CUDA_VERSION=12.9 # Keep in line with FINAL_BASE_IMAGE ARG GDRCOPY_OS_VERSION=Ubuntu22_04 @@ -356,6 +356,13 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist uv pip install --system dist/*.whl --verbose \ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') +# TODO (huydhn): Remove this once xformers is released for 2.9.0 +RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' + . /etc/environment + export TORCH_CUDA_ARCH_LIST='7.5 8.0+PTX 9.0a' + uv pip install --system --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.32.post2" +BASH + # Install FlashInfer pre-compiled kernel cache and binaries # https://docs.flashinfer.ai/installation.html RUN --mount=type=cache,target=/root/.cache/uv \ @@ -422,6 +429,7 @@ ARG PYTHON_VERSION ARG PIP_INDEX_URL UV_INDEX_URL ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL +ARG PYTORCH_CUDA_INDEX_BASE_URL # This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out # Reference: https://github.com/astral-sh/uv/pull/1694 @@ -434,7 +442,8 @@ ENV UV_LINK_MODE=copy RUN --mount=type=cache,target=/root/.cache/uv \ CUDA_MAJOR="${CUDA_VERSION%%.*}"; \ if [ "$CUDA_MAJOR" -ge 12 ]; then \ - uv pip install --system -r requirements/dev.txt; \ + uv pip install --system -r requirements/dev.txt \ + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. 
-f1,2 | tr -d '.'); \ fi # install development dependencies (for testing) diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index 2aed1872ee85..a4e72f019374 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -111,9 +111,13 @@ FROM base AS vllm-test-deps WORKDIR /workspace/vllm +# TODO: Update to 2.9.0 when there is a new build for intel_extension_for_pytorch for that version RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \ cp requirements/test.in requirements/cpu-test.in && \ sed -i '/mamba_ssm/d' requirements/cpu-test.in && \ + sed -i 's/^torch==.*/torch==2.8.0/g' requirements/cpu-test.in && \ + sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \ + sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \ uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu RUN --mount=type=cache,target=/root/.cache/uv \ diff --git a/docs/assets/contributing/dockerfile-stages-dependency.png b/docs/assets/contributing/dockerfile-stages-dependency.png index 0838bfa37fe6..f8c104ba1425 100644 Binary files a/docs/assets/contributing/dockerfile-stages-dependency.png and b/docs/assets/contributing/dockerfile-stages-dependency.png differ diff --git a/docs/contributing/ci/update_pytorch_version.md b/docs/contributing/ci/update_pytorch_version.md index 5f6edc2b139c..f983c25f26ee 100644 --- a/docs/contributing/ci/update_pytorch_version.md +++ b/docs/contributing/ci/update_pytorch_version.md @@ -87,7 +87,7 @@ is ineffective. While ongoing efforts like address the long build time at its source, the current workaround is to set `VLLM_CI_BRANCH` -to a custom branch provided by @khluu (`VLLM_CI_BRANCH=khluu/use_postmerge_q`) +to a custom branch provided by @khluu (`VLLM_CI_BRANCH=khluu/long_build`) when manually triggering a build on Buildkite. This branch accomplishes two things: 1. Increase the timeout limit to 10 hours so that the build doesn't time out. @@ -100,35 +100,17 @@ to warm it up so that future builds are faster. ## Update dependencies -Several vLLM dependencies, such as FlashInfer, also depend on PyTorch and need +Several vLLM dependencies like xFormers depend on PyTorch and need to be updated accordingly. Rather than waiting for all of them to publish new releases (which would take too much time), they can be built from source to unblock the update process. -### FlashInfer - -Here is how to build and install it from source with `torch2.7.0+cu128` in vLLM [Dockerfile](https://github.com/vllm-project/vllm/blob/27bebcd89792d5c4b08af7a65095759526f2f9e1/docker/Dockerfile#L259-L271): - -```bash -export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0 10.0+PTX' -export FLASHINFER_ENABLE_SM90=1 -uv pip install --system \ - --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.6.post1" -``` - -One caveat is that building FlashInfer from source adds approximately 30 -minutes to the vLLM build time. Therefore, it's preferable to cache the wheel in a -public location for immediate installation, such as [this FlashInfer wheel link](https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl). For future releases, contact the PyTorch release -team if you want to get the package published there. 
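The doc paragraph above notes that several vLLM dependencies such as xFormers pin against PyTorch and need rebuilding when torch is bumped. As an illustration only (a hypothetical helper, not part of the doc being edited), one way to enumerate installed packages that declare a dependency on torch is:

```python
# Hypothetical helper (not part of the vLLM docs or tooling): list installed
# packages that declare a dependency on torch, i.e. candidates that may need
# to be rebuilt or re-pinned when the torch version is bumped.
import re
from importlib.metadata import distributions


def packages_requiring_torch() -> list[str]:
    hits = set()
    for dist in distributions():
        for req in dist.requires or []:
            # Requirement strings look like "torch>=2.8" or "torch (>=2.8)".
            name = re.split(r"[\s;(<>=!~\[]", req, maxsplit=1)[0].lower()
            if name == "torch":
                hits.add(dist.metadata["Name"])
                break
    return sorted(hits)


if __name__ == "__main__":
    print("\n".join(packages_requiring_torch()))
```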
- ### xFormers -Similar to FlashInfer, here is how to build and install xFormers from source: - ```bash -export TORCH_CUDA_ARCH_LIST='7.0 7.5 8.0 8.9 9.0 10.0+PTX' +export TORCH_CUDA_ARCH_LIST='7.5 8.0+PTX 9.0a' MAX_JOBS=16 uv pip install --system \ - --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.30" + --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.32.post2" ``` ## Update all the different vLLM platforms diff --git a/pyproject.toml b/pyproject.toml index eb9bdb593baa..29ee7f75f070 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = [ "packaging>=24.2", "setuptools>=77.0.3,<80.0.0", "setuptools-scm>=8.0", - "torch == 2.8.0", + "torch == 2.9.0", "wheel", "jinja2", ] diff --git a/requirements/build.txt b/requirements/build.txt index 5f826a1afa14..ba09eaab70e8 100644 --- a/requirements/build.txt +++ b/requirements/build.txt @@ -4,7 +4,7 @@ ninja packaging>=24.2 setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 -torch==2.8.0 +torch==2.9.0 wheel jinja2>=3.1.6 regex diff --git a/requirements/cuda.txt b/requirements/cuda.txt index 411c8de5378b..dd45eb832a96 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -5,11 +5,11 @@ numba == 0.61.2 # Required for N-gram speculative decoding # Dependencies for NVIDIA GPUs ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1. -torch==2.8.0 -torchaudio==2.8.0 +torch==2.9.0 +torchaudio==2.9.0 # These must be updated alongside torch -torchvision==0.23.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version +torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version # https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1 -xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8 +# xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8 # FlashInfer should be updated together with the Dockerfile -flashinfer-python==0.4.1 \ No newline at end of file +flashinfer-python==0.4.1 diff --git a/requirements/rocm-build.txt b/requirements/rocm-build.txt index a86a8ab6df14..51f58e57a785 100644 --- a/requirements/rocm-build.txt +++ b/requirements/rocm-build.txt @@ -1,12 +1,12 @@ # Common dependencies -r common.txt ---extra-index-url https://download.pytorch.org/whl/rocm6.3 -torch==2.8.0 -torchvision==0.23.0 -torchaudio==2.8.0 +--extra-index-url https://download.pytorch.org/whl/rocm6.4 +torch==2.9.0 +torchvision==0.24.0 +torchaudio==2.9.0 -triton==3.3.0 +triton==3.5.0 cmake>=3.26.1,<4 packaging>=24.2 setuptools>=77.0.3,<80.0.0 diff --git a/requirements/test.in b/requirements/test.in index f0941d3c5918..a79ec839dbec 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -24,9 +24,9 @@ soundfile # required for audio tests jiwer # required for audio tests tblib # for pickling test exceptions timm >=1.0.17 # required for internvl and gemma3n-mm test -torch==2.8.0 -torchaudio==2.8.0 -torchvision==0.23.0 +torch==2.9.0 +torchaudio==2.9.0 +torchvision==0.24.0 transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test mistral_common[image,audio] >= 1.8.5 # required for voxtral test @@ -55,4 +55,4 @@ fastsafetensors>=0.1.10 pydantic>=2.12 # 2.11 leads to error on python 3.13 decord==0.6.0 
terratorch @ git+https://github.com/IBM/terratorch.git@1.1.rc3 # required for PrithviMAE test -gpt-oss >= 0.0.7; python_version > '3.11' \ No newline at end of file +gpt-oss >= 0.0.7; python_version > '3.11' diff --git a/requirements/test.txt b/requirements/test.txt index 03fbdcc8d453..5838113c9560 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,5 +1,5 @@ # This file was autogenerated by uv via the following command: -# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu128 --python-platform x86_64-manylinux_2_28 +# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --python-platform x86_64-manylinux_2_28 absl-py==2.1.0 # via rouge-score accelerate==1.0.1 @@ -573,42 +573,44 @@ numpy==1.26.4 # tritonclient # vocos # xarray -nvidia-cublas-cu12==12.8.4.1 +nvidia-cublas-cu12==12.9.1.4 # via # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 # torch -nvidia-cuda-cupti-cu12==12.8.90 +nvidia-cuda-cupti-cu12==12.9.79 # via torch -nvidia-cuda-nvrtc-cu12==12.8.93 +nvidia-cuda-nvrtc-cu12==12.9.86 # via torch -nvidia-cuda-runtime-cu12==12.8.90 +nvidia-cuda-runtime-cu12==12.9.79 # via torch nvidia-cudnn-cu12==9.10.2.21 # via torch -nvidia-cufft-cu12==11.3.3.83 +nvidia-cufft-cu12==11.4.1.4 # via torch -nvidia-cufile-cu12==1.13.1.3 +nvidia-cufile-cu12==1.14.1.1 # via torch -nvidia-curand-cu12==10.3.9.90 +nvidia-curand-cu12==10.3.10.19 # via torch -nvidia-cusolver-cu12==11.7.3.90 +nvidia-cusolver-cu12==11.7.5.82 # via torch -nvidia-cusparse-cu12==12.5.8.93 +nvidia-cusparse-cu12==12.5.10.65 # via # nvidia-cusolver-cu12 # torch nvidia-cusparselt-cu12==0.7.1 # via torch -nvidia-nccl-cu12==2.27.3 +nvidia-nccl-cu12==2.27.5 # via torch -nvidia-nvjitlink-cu12==12.8.93 +nvidia-nvjitlink-cu12==12.9.86 # via # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 # torch -nvidia-nvtx-cu12==12.8.90 +nvidia-nvshmem-cu12==3.3.20 + # via torch +nvidia-nvtx-cu12==12.9.79 # via torch omegaconf==2.3.0 # via @@ -1017,7 +1019,6 @@ setuptools==77.0.3 # lightning-utilities # pytablewriter # torch - # triton shapely==2.1.1 # via # geopandas @@ -1122,7 +1123,7 @@ tomli==2.2.1 # via schemathesis tomli-w==1.2.0 # via schemathesis -torch==2.8.0+cu128 +torch==2.9.0+cu129 # via # -r requirements/test.in # accelerate @@ -1151,7 +1152,7 @@ torch==2.8.0+cu128 # torchvision # vector-quantize-pytorch # vocos -torchaudio==2.8.0+cu128 +torchaudio==2.9.0 # via # -r requirements/test.in # encodec @@ -1164,7 +1165,7 @@ torchmetrics==1.7.4 # pytorch-lightning # terratorch # torchgeo -torchvision==0.23.0+cu128 +torchvision==0.24.0 # via # -r requirements/test.in # lightly @@ -1205,7 +1206,7 @@ transformers==4.56.2 # transformers-stream-generator transformers-stream-generator==0.0.5 # via -r requirements/test.in -triton==3.4.0 +triton==3.5.0 # via torch tritonclient==2.51.0 # via diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index 7399abaec542..929637fc1cf5 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -21,12 +21,18 @@ from ..utils import flat_product, multi_gpu_test +class Matches(NamedTuple): + attention_fusion: int + allreduce_fusion: int = 0 + sequence_parallel: int = 0 + async_tp: int = 0 + + class ModelBackendTestCase(NamedTuple): model_name: str model_kwargs: dict[str, Any] backend: _Backend - attention_fusions: int - allreduce_fusions: int | None = None + matches: Matches MODELS_FP8: list[ModelBackendTestCase] = [] @@ 
-40,15 +46,23 @@ class ModelBackendTestCase(NamedTuple): model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", model_kwargs=dict(max_model_len=1024), backend=_Backend.TRITON_ATTN, - attention_fusions=32, - allreduce_fusions=65, + matches=Matches( + attention_fusion=32, + allreduce_fusion=65, + sequence_parallel=65, + async_tp=128, + ), ), ModelBackendTestCase( model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), backend=_Backend.FLASHINFER, - attention_fusions=48, - allreduce_fusions=96, + matches=Matches( + attention_fusion=48, + allreduce_fusion=96, + sequence_parallel=96, + async_tp=190, + ), ), ] @@ -57,8 +71,12 @@ class ModelBackendTestCase(NamedTuple): model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), backend=_Backend.FLASHINFER, - attention_fusions=48, - allreduce_fusions=96, + matches=Matches( + attention_fusion=48, + allreduce_fusion=96, + sequence_parallel=96, + async_tp=190, + ), ), ] @@ -68,8 +86,12 @@ class ModelBackendTestCase(NamedTuple): model_name="meta-llama/Llama-3.1-8B-Instruct", model_kwargs=dict(max_model_len=1024), backend=_Backend.TRITON_ATTN, - attention_fusions=0, - allreduce_fusions=65, + matches=Matches( + attention_fusion=0, + allreduce_fusion=65, + sequence_parallel=65, + async_tp=128, + ), ), ] @@ -79,19 +101,19 @@ class ModelBackendTestCase(NamedTuple): model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), backend=_Backend.TRITON_ATTN, - attention_fusions=32, + matches=Matches(attention_fusion=32), ), ModelBackendTestCase( model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), backend=_Backend.ROCM_ATTN, - attention_fusions=32, + matches=Matches(attention_fusion=32), ), ModelBackendTestCase( model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), backend=_Backend.ROCM_AITER_UNIFIED_ATTN, - attention_fusions=32, + matches=Matches(attention_fusion=32), ), ] @@ -100,8 +122,7 @@ class ModelBackendTestCase(NamedTuple): @pytest.mark.parametrize( - "model_name, model_kwargs, backend, " - "attention_fusions, allreduce_fusions, custom_ops", + "model_name, model_kwargs, backend, matches, custom_ops", # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8 list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8)) # quant_fp4 only has the custom impl @@ -112,8 +133,7 @@ def test_attn_quant( model_name: str, model_kwargs: dict[str, Any], backend: _Backend, - attention_fusions: int, - allreduce_fusions: int, + matches: Matches, custom_ops: str, inductor_graph_partition: bool, caplog_mp_spawn, @@ -125,6 +145,11 @@ def test_attn_quant( pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): pytest.skip("Inductor graph partition requires torch>=2.9") + if inductor_graph_partition and "fp4" in model_name.lower(): + pytest.skip( + "Known bug for fp4 fusion & inductor partition: " + "https://github.com/vllm-project/vllm/issues/26988" + ) custom_ops_list = custom_ops.split(",") if custom_ops else [] @@ -160,12 +185,12 @@ def test_attn_quant( with caplog_mp_spawn(logging.DEBUG) as log_holder: run_model(compilation_config, model_name, **model_kwargs) - matches = re.findall( + log_matches = re.findall( r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", log_holder.text, ) - assert len(matches) == 1, log_holder.text - assert int(matches[0]) == 
attention_fusions + assert len(log_matches) == 1, log_holder.text + assert int(log_matches[0]) == matches.attention_fusion # TODO(luka) test both in nightly @@ -179,8 +204,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( - "model_name, model_kwargs, backend, " - "attention_fusions, allreduce_fusions, custom_ops", + "model_name, model_kwargs, backend, matches, custom_ops", # Toggle RMSNorm and QuantFP8 for FP8 models list( flat_product( @@ -201,8 +225,7 @@ def test_tp2_attn_quant_allreduce_rmsnorm( model_name: str, model_kwargs: dict, backend: _Backend, - attention_fusions: int, - allreduce_fusions: int, + matches: Matches, custom_ops: str, inductor_graph_partition: bool, caplog_mp_spawn, @@ -210,6 +233,11 @@ def test_tp2_attn_quant_allreduce_rmsnorm( ): if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): pytest.skip("Inductor graph partition requires torch>=2.9") + if inductor_graph_partition and "fp4" in model_name.lower(): + pytest.skip( + "Known bug for fp4 fusion & inductor partition: " + "https://github.com/vllm-project/vllm/issues/26988" + ) custom_ops_list = custom_ops.split(",") if custom_ops else [] @@ -250,23 +278,134 @@ def test_tp2_attn_quant_allreduce_rmsnorm( run_model( compilation_config, model_name, tensor_parallel_size=2, **model_kwargs ) - matches = re.findall( + log_matches = re.findall( r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", log_holder.text, ) - assert len(matches) == 2, log_holder.text + assert len(log_matches) == 2, log_holder.text + + assert int(log_matches[0]) == matches.attention_fusion + assert int(log_matches[1]) == matches.attention_fusion + + log_matches = re.findall( + r"collective_fusion.py:\d+] Replaced (\d+) patterns", + log_holder.text, + ) + assert len(log_matches) == 2, log_holder.text + + assert int(log_matches[0]) == matches.allreduce_fusion + assert int(log_matches[1]) == matches.allreduce_fusion + + +# TODO luka resolve +CUSTOM_OPS_RMS_NORM = ["+rms_norm"] + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "model_name, model_kwargs, backend, matches, custom_ops", + # Toggle RMSNorm and QuantFP8 for FP8 models + list( + flat_product( + MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM) + ) + ) + # Toggle RMSNorm for FP4 models and unquant models + + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), +) +@pytest.mark.parametrize("inductor_graph_partition", [True, False]) +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="sequence parallel only tested on CUDA", +) +def test_tp2_attn_quant_async_tp( + model_name: str, + model_kwargs: dict, + backend: _Backend, + matches: Matches, + custom_ops: str, + inductor_graph_partition: bool, + caplog_mp_spawn, + monkeypatch, +): + if backend == _Backend.FLASHINFER and ( + not current_platform.is_device_capability((10, 0)) or not has_flashinfer() + ): + pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") + if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("Inductor graph partition requires torch>=2.9") + if inductor_graph_partition and "fp4" in model_name.lower(): + pytest.skip( + "Known bug for fp4 fusion & inductor partition: " + "https://github.com/vllm-project/vllm/issues/26988" + ) + + custom_ops_list = custom_ops.split(",") if custom_ops else [] + + if inductor_graph_partition: + mode = CUDAGraphMode.FULL_AND_PIECEWISE + splitting_ops: list[str] | None = 
None + else: + mode = CUDAGraphMode.FULL_DECODE_ONLY + splitting_ops = [] + + # Disable, compile cache to make sure custom passes run. + # Otherwise, we can't verify fusion happened through the logs. + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + # To capture subprocess logs, we need to know whether spawn or fork is used. + # Force spawn as it is more general. + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) + + compilation_config = CompilationConfig( + # Testing properties + use_inductor_graph_partition=inductor_graph_partition, + cudagraph_mode=mode, + custom_ops=custom_ops_list, + splitting_ops=splitting_ops, + # Common + level=CompilationMode.VLLM_COMPILE, + pass_config=PassConfig( + enable_attn_fusion=True, + enable_noop=True, + enable_sequence_parallelism=True, + enable_async_tp=True, + ), + # Inductor caches custom passes by default as well via uuid + inductor_compile_config={"force_disable_caches": True}, + ) + + with caplog_mp_spawn(logging.DEBUG) as log_holder: + run_model( + compilation_config, model_name, tensor_parallel_size=2, **model_kwargs + ) + log_matches = re.findall( + r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", + log_holder.text, + ) + assert len(log_matches) == 2, log_holder.text + + assert int(log_matches[0]) == matches.attention_fusion + assert int(log_matches[1]) == matches.attention_fusion + + log_matches = re.findall( + r"sequence_parallelism.py:\d+] Replaced (\d+) patterns", + log_holder.text, + ) + assert len(log_matches) == 2, log_holder.text - assert int(matches[0]) == attention_fusions - assert int(matches[1]) == attention_fusions + assert int(log_matches[0]) == matches.sequence_parallel + assert int(log_matches[1]) == matches.sequence_parallel - matches = re.findall( + log_matches = re.findall( r"collective_fusion.py:\d+] Replaced (\d+) patterns", log_holder.text, ) - assert len(matches) == 2, log_holder.text + assert len(log_matches) == 2, log_holder.text - assert int(matches[0]) == allreduce_fusions - assert int(matches[1]) == allreduce_fusions + assert int(log_matches[0]) == matches.async_tp + assert int(log_matches[1]) == matches.async_tp def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs): diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index 31b6ddf3c698..81fb375af271 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -99,7 +99,6 @@ def __init__(self, hidden_size=16, intermediate_size=32): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.vllm_config = get_current_vllm_config() self.gate_proj = torch.nn.Parameter( torch.empty((intermediate_size, hidden_size)), requires_grad=False ) @@ -152,16 +151,12 @@ def forward(self, hidden_states, residual): def ops_in_model_before(self): ops_to_remove = [torch.ops.vllm.all_reduce.default] # Always removed by SP # The following are only removed if fusion happens - if ( - self.vllm_config - and self.vllm_config.compilation_config.pass_config.enable_fusion - ): - ops_to_remove.extend( - [ - torch.ops._C.fused_add_rms_norm.default, - torch.ops._C.static_scaled_fp8_quant.default, - ] - ) + config = get_current_vllm_config() + if config.compilation_config.pass_config.enable_fusion: + ops_to_remove.append(torch.ops._C.fused_add_rms_norm.default) + # Only check for static_scaled_fp8_quant if custom quant_fp8 is enabled + if 
"+quant_fp8" in config.compilation_config.custom_ops: + ops_to_remove.append(torch.ops._C.static_scaled_fp8_quant.default) return ops_to_remove def ops_in_model_after(self): @@ -169,24 +164,23 @@ def ops_in_model_after(self): torch.ops.vllm.reduce_scatter.default, torch.ops.vllm.all_gather.default, ] - # The following is only added if fusion happens + # The following is only added if fusion happens and custom quant_fp8 is enabled + config = get_current_vllm_config() if ( - self.vllm_config - and self.vllm_config.compilation_config.pass_config.enable_fusion + config.compilation_config.pass_config.enable_fusion + and "+quant_fp8" in config.compilation_config.custom_ops ): ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default) return ops_to_add def ops_in_model(self): - if ( - self.vllm_config - and self.vllm_config.compilation_config.pass_config.enable_fusion - ): - # If fusion happens, the fused op is the one + config = get_current_vllm_config() + if config.compilation_config.pass_config.enable_fusion: + # If fusion happens with custom quant_fp8, the fused op is the one # we check for (de)functionalization return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default] else: - # If no fusion, the original ops are checked + # If no fusion or using native quant, the original ops are checked return [ torch.ops._C.fused_add_rms_norm.default, # TODO functionalization pass does not handle this yet @@ -195,20 +189,30 @@ def ops_in_model(self): @multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("test_model_cls", [TestModel, TestQuantModel]) +@pytest.mark.parametrize( + "test_model_cls, custom_ops", + [ + (TestModel, ""), + (TestQuantModel, "+quant_fp8"), + (TestQuantModel, "-quant_fp8"), + ], +) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [16]) @pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("enable_fusion", [True, False]) +@pytest.mark.parametrize("dynamic", [False, True]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") def test_sequence_parallelism_pass( test_model_cls: type[torch.nn.Module], + custom_ops: str, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype, enable_fusion: bool, + dynamic: bool, ): num_processes = 2 @@ -220,11 +224,13 @@ def run_torch_spawn(fn, nprocs): args=( num_processes, test_model_cls, + custom_ops, batch_size, seq_len, hidden_size, dtype, enable_fusion, + dynamic, ), nprocs=nprocs, ) @@ -236,11 +242,13 @@ def sequence_parallelism_pass_on_test_model( local_rank: int, world_size: int, test_model_cls: type[torch.nn.Module], + custom_ops: str, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype, enable_fusion: bool, + dynamic: bool, ): current_platform.seed_everything(0) @@ -264,12 +272,14 @@ def sequence_parallelism_pass_on_test_model( initialize_model_parallel(tensor_model_parallel_size=world_size) # configure vllm config for SequenceParallelismPass + custom_ops_list = custom_ops.split(",") if custom_ops else [] compilation_config = CompilationConfig( + custom_ops=custom_ops_list, pass_config=PassConfig( enable_sequence_parallelism=True, enable_fusion=enable_fusion, enable_noop=True, - ) + ), ) # NoOp needed for fusion device_config = DeviceConfig(device=torch.device("cuda")) @@ -317,6 +327,9 @@ def sequence_parallelism_pass_on_test_model( hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) residual = torch.randn((batch_size 
* seq_len, hidden_size), dtype=dtype) + if dynamic: + torch._dynamo.mark_dynamic(hidden_states, 0) + torch._dynamo.mark_dynamic(residual, 0) compiled_model_no_func = torch.compile(model, backend=backend_no_func) compiled_model_no_func(hidden_states, residual) diff --git a/tests/lora/test_chatglm3_tp.py b/tests/lora/test_chatglm3_tp.py index f4f151180dec..8b21a6a6d2dd 100644 --- a/tests/lora/test_chatglm3_tp.py +++ b/tests/lora/test_chatglm3_tp.py @@ -103,7 +103,7 @@ def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files): tensor_parallel_size=4, trust_remote_code=True, fully_sharded_loras=True, - gpu_memory_utilization=0.85, + gpu_memory_utilization=0.8, ) output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 31624a8fdcc0..caca08c394b5 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -10,109 +10,30 @@ from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, +) from vllm.platforms import current_platform from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) -class _RMSNormAndQuantOpHelper: - """Base helper for RMSNorm and RMSNorm + Quantization functionalization.""" +class _SequenceParallelPatternHelper: + """Helper for sequence parallelism patterns.""" def __init__( self, epsilon: float, dtype: torch.dtype, device: str, - quant_op: torch._ops.OpOverload | None = None, - **kwargs, ): self.epsilon = epsilon self.dtype = dtype self.device = device - self.quant_op = quant_op - - def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor): - return torch.ops.higher_order.auto_functionalized( - torch.ops._C.rms_norm.default, - result=result_buffer, - input=input_tensor, - weight=weight_tensor, - epsilon=self.epsilon, - ) - - def _functional_fused_add_rmsnorm( - self, input_tensor, residual_tensor, weight_tensor - ): - return torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=input_tensor, - residual=residual_tensor, - weight=weight_tensor, - epsilon=self.epsilon, - ) - - def _functional_rmsnorm_then_quant( - self, - rmsnorm_result_buffer, - quant_result_buffer, - input_tensor, - weight_tensor, - scale_tensor, - ): - if self.quant_op is None: - raise RuntimeError( - "_RMSNormAndQuantOpHelper was not initialized with a quant_op." - ) - rmsnorm_out_tuple = self._functional_rmsnorm( - rmsnorm_result_buffer, input_tensor, weight_tensor - ) - quant_out_tuple = torch.ops.higher_order.auto_functionalized( - self.quant_op, - result=quant_result_buffer, - input=rmsnorm_out_tuple[1], - scale=scale_tensor, - ) - return quant_out_tuple - - def _functional_fused_add_rmsnorm_then_quant( - self, - quant_result_buffer, - input_tensor, - residual_tensor, - weight_tensor, - scale_tensor, - ): - if self.quant_op is None: - raise RuntimeError( - "_RMSNormAndQuantOpHelper was not initialized with a quant_op." 
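The helpers removed in this hunk built RMSNorm/quant patterns by hand via `auto_functionalized`; the replacement matcher classes lean on inductor's `register_replacement` machinery instead (visible in the `pm.register_replacement(pattern, replacement, ...)` calls kept below). The following is a rough, self-contained sketch of that flow using toy ops rather than vLLM's matchers; decomposition and functionalization details are glossed over, so treat it as illustrative only.

```python
# Simplified sketch of the inductor pattern-matcher flow used by these passes
# (toy ops only; vLLM's MatcherRMSNorm/MatcherQuantFP8 are not shown here).
import torch
from torch._inductor import pattern_matcher as pm
from torch.fx.experimental.proxy_tensor import make_fx

patterns = pm.PatternMatcherPass()


def pattern(x: torch.Tensor, w: torch.Tensor):
    return torch.relu(torch.mm(x, w))


def replacement(x: torch.Tensor, w: torch.Tensor):
    # Stand-in for a "fused" op; a real pass would emit a custom kernel here.
    return torch.clamp(torch.mm(x, w), min=0.0)


example_inputs = [torch.empty(4, 4), torch.empty(4, 4)]
pm.register_replacement(pattern, replacement, example_inputs, pm.fwd_only, patterns)


def model(x, w):
    return torch.relu(torch.mm(x, w)) + 1


gm = make_fx(model)(torch.randn(4, 4), torch.randn(4, 4))
replaced = patterns.apply(gm.graph)
print(f"replaced {replaced} pattern(s)")
```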
- ) - fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm( - input_tensor, residual_tensor, weight_tensor - ) - quant_out_tuple = torch.ops.higher_order.auto_functionalized( - self.quant_op, - result=quant_result_buffer, - input=fused_add_rmsnorm_out_tuple[1], - scale=scale_tensor, - ) - return quant_out_tuple, fused_add_rmsnorm_out_tuple[2] - - -class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper): - """Helper for sequence parallelism patterns.""" - - def __init__( - self, - epsilon: float, - dtype: torch.dtype, - device: str, - quant_op: torch._ops.OpOverload | None = None, - **kwargs, - ): - super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs) self.tp_group = get_tp_group() self.tp_size = get_tensor_model_parallel_world_size() @@ -131,36 +52,34 @@ def _all_gather(self, x: torch.Tensor) -> torch.Tensor: class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + super().__init__(epsilon, dtype, device) + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) + def get_inputs(self): input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) - permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype) - return [input, permute, arg3_1] + return [input, arg3_1] def register(self, pm_pass: PatternMatcherPass): def pattern( input: torch.Tensor, - permute: torch.Tensor, arg3_1: torch.Tensor, ): all_reduce = self._all_reduce(input) - rmsnorm = self._functional_rmsnorm(permute, all_reduce, arg3_1) + rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1) - return rmsnorm[1], all_reduce + return rmsnorm, all_reduce def replacement( input: torch.Tensor, - permute: torch.Tensor, arg3_1: torch.Tensor, ): reduce_scatter = self._reduce_scatter(input) - rmsnorm_result = torch.empty_like(reduce_scatter) - rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, arg3_1) - - all_gather = self._all_gather(rmsnorm[1]) - + rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1) + all_gather = self._all_gather(rmsnorm) return all_gather, reduce_scatter pm.register_replacement( @@ -169,6 +88,10 @@ def replacement( class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + super().__init__(epsilon, dtype, device) + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) + def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -188,10 +111,8 @@ def pattern( rms_norm_weights: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = self._all_reduce(mm_1) - rmsnorm = self._functional_fused_add_rmsnorm( - all_reduce, residual, rms_norm_weights - ) - return rmsnorm[1], rmsnorm[2] + rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual) + return rmsnorm[0], rmsnorm[1] def replacement( residual: torch.Tensor, @@ -199,11 +120,9 @@ def replacement( rms_norm_weights: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: reduce_scatter = self._reduce_scatter(mm_1) - rmsnorm = self._functional_fused_add_rmsnorm( - reduce_scatter, residual, rms_norm_weights - ) - all_gather = self._all_gather(rmsnorm[1]) - return all_gather, rmsnorm[2] + rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual) + all_gather = self._all_gather(rmsnorm[0]) + return all_gather, rmsnorm[1] pm.register_replacement( pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass @@ -211,9 
+130,12 @@ def replacement( class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper): + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + super().__init__(epsilon, dtype, device) + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) + def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) - residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -228,23 +150,19 @@ def pattern( residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: all_reduce = self._all_reduce(mm_1) - rmsnorm = self._functional_fused_add_rmsnorm( - all_reduce, residual, rms_norm_weights - ) - return rmsnorm[1] + rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual) + return rmsnorm[0] def replacement( residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: reduce_scatter = self._reduce_scatter(mm_1) - rmsnorm = self._functional_fused_add_rmsnorm( - reduce_scatter, residual, rms_norm_weights - ) - normalized = self._all_gather(rmsnorm[1]) + rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual) + normalized = self._all_gather(rmsnorm[0]) return normalized pm.register_replacement( @@ -257,52 +175,41 @@ def replacement( class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): def __init__( - self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload + self, + epsilon: float, + dtype: torch.dtype, + device: str, ): - super().__init__(epsilon, dtype, device, quant_op=op) + super().__init__(epsilon, dtype, device) + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) + self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) def get_inputs(self): input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) - rmsnorm_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) - quant_result = torch.empty([1, 8, 4], device=self.device, dtype=FP8_DTYPE) weight = torch.empty([4], device=self.device, dtype=self.dtype) scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) - return [input, rmsnorm_result, quant_result, weight, scale] + return [input, weight, scale] def register(self, pm_pass: PatternMatcherPass): def pattern( input: torch.Tensor, - rmsnorm_result: torch.Tensor, - quant_result: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): all_reduce = self._all_reduce(input) - static_fp8 = self._functional_rmsnorm_then_quant( - rmsnorm_result, quant_result, all_reduce, weight, scale - ) - return static_fp8[1], all_reduce + rms = self.rmsnorm_matcher(all_reduce, weight) + quant, _ = self.quant_matcher(rms, scale) + return quant, all_reduce def replacement( input: torch.Tensor, - rmsnorm_result: torch.Tensor, - quant_result: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): reduce_scatter = self._reduce_scatter(input) - - rmsnorm_result = torch.empty_like( - reduce_scatter, dtype=rmsnorm_result.dtype - ) - quant_result = torch.empty_like( - rmsnorm_result, # Output of RMSNorm - dtype=quant_result.dtype, - ) - static_fp8 = self._functional_rmsnorm_then_quant( - rmsnorm_result, quant_result, reduce_scatter, weight, scale - ) - all_gather = self._all_gather(static_fp8[1]) + rms = self.rmsnorm_matcher(reduce_scatter, weight) + quant, _ = self.quant_matcher(rms, scale) + all_gather = 
self._all_gather(quant) return all_gather, reduce_scatter @@ -312,59 +219,46 @@ def replacement( class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): - def __init__( - self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload - ): - super().__init__(epsilon, dtype, device, quant_op=op) + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + super().__init__(epsilon, dtype, device) + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) + self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) - residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) - result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) - return [ - result, - residual, - mm_1, - rms_norm_weights, - scale, - ] + return [residual, mm_1, rms_norm_weights, scale] def register(self, pm_pass: PatternMatcherPass): def pattern( - result: torch.Tensor, residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = self._all_reduce(mm_1) - static_fp8, rmsnorm_residual_out = ( - self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 - result, all_reduce, residual, rms_norm_weights, scale - ) + rms, residual_out = self.rmsnorm_matcher( + all_reduce, rms_norm_weights, residual ) - return static_fp8[1], rmsnorm_residual_out + quant, _ = self.quant_matcher(rms, scale) + return quant, residual_out def replacement( - result: torch.Tensor, residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: reduce_scatter = self._reduce_scatter(mm_1) - quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype) - static_fp8, rmsnorm_residual_out = ( - self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 - quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale - ) + rms, residual_out = self.rmsnorm_matcher( + reduce_scatter, rms_norm_weights, residual ) - all_gather = self._all_gather(static_fp8[1]) - return all_gather, rmsnorm_residual_out + quant, _ = self.quant_matcher(rms, scale) + all_gather = self._all_gather(quant) + return all_gather, residual_out pm.register_replacement( pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass @@ -372,54 +266,41 @@ def replacement( class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): - def __init__( - self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload - ): - super().__init__(epsilon, dtype, device, quant_op=op) + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + super().__init__(epsilon, dtype, device) + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) + self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) - residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) - result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) - return [ - result, - residual, - mm_1, - rms_norm_weights, - scale, - ] + return [residual, mm_1, rms_norm_weights, scale] def 
register(self, pm_pass: PatternMatcherPass): def pattern( - result: torch.Tensor, residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, scale: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: all_reduce = self._all_reduce(mm_1) - static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant( - result, all_reduce, residual, rms_norm_weights, scale - ) - return static_fp8[1] + rms, _ = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual) + quant, _ = self.quant_matcher(rms, scale) + return quant def replacement( - result: torch.Tensor, residual: torch.Tensor, mm_1: torch.Tensor, rms_norm_weights: torch.Tensor, scale: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: reduce_scatter = self._reduce_scatter(mm_1) - quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype) - static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant( - quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale - ) - normalized = self._all_gather(static_fp8[1]) + rms, _ = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual) + quant, _ = self.quant_matcher(rms, scale) + normalized = self._all_gather(quant) return normalized pm.register_replacement( @@ -457,15 +338,14 @@ def __init__(self, config: VllmConfig): for epsilon in [1e-5, 1e-6]: # RMSNorm + Static FP8 quantization patterns - fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default FirstAllReduceRMSNormStaticFP8Pattern( - epsilon, self.model_dtype, self.device, fp8_quant_op + epsilon, self.model_dtype, self.device ).register(self.patterns) MiddleAllReduceRMSNormStaticFP8Pattern( - epsilon, self.model_dtype, self.device, fp8_quant_op + epsilon, self.model_dtype, self.device ).register(self.patterns) LastAllReduceRMSNormStaticFP8Pattern( - epsilon, self.model_dtype, self.device, fp8_quant_op + epsilon, self.model_dtype, self.device ).register(self.patterns) # Normal RMSNorm patterns diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index ff43e4e826df..931619a42287 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -367,7 +367,7 @@ class CompilationConfig: FULL_AND_PIECEWISE instead. """ - use_inductor_graph_partition: bool = False + use_inductor_graph_partition: bool = is_torch_equal_or_newer("2.9.0") """Use inductor graph partition to split the graph at cudagraph_unsafe ops. This partition happens at inductor codegen time after all passes and fusions are finished. It generates a single `call` function which wraps diff --git a/vllm/envs.py b/vllm/envs.py index 7dcfabe3e044..754e0ef1ec2e 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -134,7 +134,7 @@ VLLM_DP_RANK: int = 0 VLLM_DP_RANK_LOCAL: int = -1 VLLM_DP_SIZE: int = 1 - VLLM_USE_STANDALONE_COMPILE: bool = False + VLLM_USE_STANDALONE_COMPILE: bool = True VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_PORT: int = 0 VLLM_MOE_DP_CHUNK_SIZE: int = 256 @@ -500,9 +500,9 @@ def get_vllm_port() -> int | None: # In torch <= 2.7 we ignore this flag; in torch >= 2.8 this is # disabled by default. "VLLM_USE_STANDALONE_COMPILE": lambda: os.environ.get( - "VLLM_USE_STANDALONE_COMPILE", "0" - ) - == "1", + "VLLM_USE_STANDALONE_COMPILE", "1" + ).lower() + in ["true", "1"], # Debug pattern matching inside custom passes. # Should be set to the fx.Node name (e.g. 'getitem_34' or 'scaled_mm_3'). "VLLM_PATTERN_MATCH_DEBUG": lambda: os.environ.get(
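The changes to `vllm/config/compilation.py` and `vllm/envs.py` above flip two defaults: `use_inductor_graph_partition` now defaults to whether the installed torch is at least 2.9.0, and `VLLM_USE_STANDALONE_COMPILE` defaults on and accepts `true`/`1` case-insensitively. A rough standalone sketch of both idioms follows; it approximates `is_torch_equal_or_newer` with `packaging`, which may differ in detail from vLLM's internal helper.

```python
# Rough sketch of the two defaults changed above; vLLM's actual helpers
# (the envs.py lambdas, is_torch_equal_or_newer) may differ in detail.
import os

import torch
from packaging.version import Version


def env_flag(name: str, default: str = "1") -> bool:
    # Accepts "1"/"true" (case-insensitive), mirroring the new
    # VLLM_USE_STANDALONE_COMPILE parsing.
    return os.environ.get(name, default).lower() in ("true", "1")


def torch_at_least(version: str) -> bool:
    # Approximation of is_torch_equal_or_newer; strips the local build tag
    # such as "+cu129" before comparing.
    return Version(torch.__version__.split("+")[0]) >= Version(version)


use_standalone_compile = env_flag("VLLM_USE_STANDALONE_COMPILE")
use_inductor_graph_partition_default = torch_at_least("2.9.0")
print(use_standalone_compile, use_inductor_graph_partition_default)
```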