diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 94c0944c838c..09ff4ee11097 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -172,6 +172,8 @@ steps: - tests/v1/engine/test_engine_core_client.py - tests/distributed/test_symm_mem_allreduce.py commands: + # https://github.com/NVIDIA/nccl/issues/1838 + - export NCCL_CUMEM_HOST_ENABLE=0 # test with torchrun tp=2 and external_dp=2 - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py # test with torchrun tp=2 and pp=2 @@ -527,8 +529,7 @@ steps: # since torchao nightly is only compatible with torch nightly currently # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now # we can only upgrade after this is resolved - # TODO(jerryzh168): resolve the above comment - - uv pip install --system torchao==0.13.0 + - pip install --pre torchao==0.15.0.dev20251014 --index-url https://download.pytorch.org/whl/nightly/cu128 - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ - label: LM Eval Small Models # 53min @@ -944,6 +945,8 @@ steps: - tests/v1/shutdown - tests/v1/worker/test_worker_memory_snapshot.py commands: + # https://github.com/NVIDIA/nccl/issues/1838 + - export NCCL_CUMEM_HOST_ENABLE=0 - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 121bdb750de5..2b7f70bea1ca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,7 +38,7 @@ repos: rev: 0.9.1 hooks: - id: pip-compile - args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu128, --python-platform, x86_64-manylinux_2_28] + args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --extra-index-url, https://download.pytorch.org/whl/test/cu128, --python-platform, x86_64-manylinux_2_28] files: ^requirements/test\.(in|txt)$ - repo: local hooks: diff --git a/CMakeLists.txt b/CMakeLists.txt index 005590445361..e1059d949653 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,8 +49,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1 # requirements.txt files and should be kept consistent. 
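The two `export NCCL_CUMEM_HOST_ENABLE=0` lines added to the CI steps above work around the linked NCCL issue. As a hedged illustration only (not part of the patch), reproducing those distributed tests locally under `torchrun` with the same workaround looks roughly like the sketch below; the variable has to be set before NCCL is initialized, i.e. before the first process-group creation, which is what the CI `export` lines guarantee.

```python
# Sketch only, assuming a multi-GPU host with the NCCL backend available; mirrors the
# CI workaround for https://github.com/NVIDIA/nccl/issues/1838 when running the
# distributed tests outside Buildkite.
import os

# Must be set before NCCL is initialized (first collective / process-group creation).
os.environ.setdefault("NCCL_CUMEM_HOST_ENABLE", "0")

import torch
import torch.distributed as dist


def main() -> None:
    # torchrun supplies RANK/WORLD_SIZE/LOCAL_RANK, matching the
    # `torchrun --nproc-per-node=4 ...` commands in the pipeline above.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    x = torch.ones(1, device="cuda")
    dist.all_reduce(x)
    print(f"rank {dist.get_rank()}: all_reduce -> {x.item()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```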
The ROCm torch # versions are derived from docker/Dockerfile.rocm # -set(TORCH_SUPPORTED_VERSION_CUDA "2.8.0") -set(TORCH_SUPPORTED_VERSION_ROCM "2.8.0") +set(TORCH_SUPPORTED_VERSION_CUDA "2.9.0") +set(TORCH_SUPPORTED_VERSION_ROCM "2.9.0") # # Try to find python package with an executable that exactly matches diff --git a/docker/Dockerfile b/docker/Dockerfile index f9e07acb855c..b05beb648b27 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -55,7 +55,7 @@ ARG UV_INDEX_URL=${PIP_INDEX_URL} ARG UV_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL} # PyTorch provides its own indexes for standard and nightly builds -ARG PYTORCH_CUDA_INDEX_BASE_URL=https://download.pytorch.org/whl +ARG PYTORCH_CUDA_INDEX_BASE_URL=https://download.pytorch.org/whl/test ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL=https://download.pytorch.org/whl/nightly # PIP supports multiple authentication schemes, including keyring @@ -356,6 +356,13 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist uv pip install --system dist/*.whl --verbose \ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') +# TODO (huydhn): Remove this once xformers is released for 2.9.0 +RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' + . /etc/environment + export TORCH_CUDA_ARCH_LIST='7.5 8.0+PTX 9.0a' + uv pip install --system --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.32.post2" +BASH + # Install FlashInfer pre-compiled kernel cache and binaries # https://docs.flashinfer.ai/installation.html RUN --mount=type=cache,target=/root/.cache/uv \ @@ -422,6 +429,7 @@ ARG PYTHON_VERSION ARG PIP_INDEX_URL UV_INDEX_URL ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL +ARG PYTORCH_CUDA_INDEX_BASE_URL # This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out # Reference: https://github.com/astral-sh/uv/pull/1694 @@ -434,7 +442,8 @@ ENV UV_LINK_MODE=copy RUN --mount=type=cache,target=/root/.cache/uv \ CUDA_MAJOR="${CUDA_VERSION%%.*}"; \ if [ "$CUDA_MAJOR" -ge 12 ]; then \ - uv pip install --system -r requirements/dev.txt; \ + uv pip install --system -r requirements/dev.txt \ + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. 
-f1,2 | tr -d '.'); \ fi # install development dependencies (for testing) diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index 2aed1872ee85..a4e72f019374 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -111,9 +111,13 @@ FROM base AS vllm-test-deps WORKDIR /workspace/vllm +# TODO: Update to 2.9.0 when there is a new build for intel_extension_for_pytorch for that version RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \ cp requirements/test.in requirements/cpu-test.in && \ sed -i '/mamba_ssm/d' requirements/cpu-test.in && \ + sed -i 's/^torch==.*/torch==2.8.0/g' requirements/cpu-test.in && \ + sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \ + sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \ uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu RUN --mount=type=cache,target=/root/.cache/uv \ diff --git a/docs/contributing/ci/update_pytorch_version.md b/docs/contributing/ci/update_pytorch_version.md index 3dae62dd5d94..ad7be3ac9cb4 100644 --- a/docs/contributing/ci/update_pytorch_version.md +++ b/docs/contributing/ci/update_pytorch_version.md @@ -87,7 +87,7 @@ is ineffective. While ongoing efforts like [#17419](gh-issue:17419) address the long build time at its source, the current workaround is to set `VLLM_CI_BRANCH` -to a custom branch provided by @khluu (`VLLM_CI_BRANCH=khluu/use_postmerge_q`) +to a custom branch provided by @khluu (`VLLM_CI_BRANCH=khluu/long_build`) when manually triggering a build on Buildkite. This branch accomplishes two things: 1. Increase the timeout limit to 10 hours so that the build doesn't time out. @@ -107,28 +107,24 @@ source to unblock the update process. ### FlashInfer -Here is how to build and install it from source with `torch2.7.0+cu128` in vLLM [Dockerfile](https://github.com/vllm-project/vllm/blob/27bebcd89792d5c4b08af7a65095759526f2f9e1/docker/Dockerfile#L259-L271): +After #25782, the pre-compiled FlashInfer wheel can be built using tools/flashinfer-build.sh +script. The new wheel can then be uploaded to [PyTorch test index](https://download.pytorch.org/whl/test/cu128/flashinfer_python-0.3.1-cp39-abi3-linux_x86_64.whl) and used during the update. + +During PyTorch 2.9 update, using the old FlashInfer wheel built for +2.8 led to a crash with the following error: ```bash -export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0 10.0+PTX' -export FLASHINFER_ENABLE_SM90=1 -uv pip install --system \ - --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.6.post1" +terminate called after throwing an instance of 'std::bad_array_new_length' ``` -One caveat is that building FlashInfer from source adds approximately 30 -minutes to the vLLM build time. Therefore, it's preferable to cache the wheel in a -public location for immediate installation, such as [this FlashInfer wheel link](https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl). For future releases, contact the PyTorch release -team if you want to get the package published there. 
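Since the updated docs above call out an ABI crash (`std::bad_array_new_length`) when a FlashInfer wheel built for torch 2.8 is loaded against torch 2.9, a small sanity check can catch a mismatched image before the test suite runs. This is a hypothetical helper, not something the PR adds; it only reads installed package metadata.

```python
# Hypothetical helper (not part of this PR): report the torch / flashinfer-python /
# xformers versions baked into an image so a wheel built against the wrong torch
# release is spotted before it crashes at kernel-load time.
import importlib.metadata as md

import torch


def report_versions() -> None:
    print("torch            :", torch.__version__)   # e.g. 2.9.0+cu128
    print("torch CUDA       :", torch.version.cuda)  # e.g. 12.8
    for pkg in ("flashinfer-python", "xformers"):
        try:
            print(f"{pkg:<17}:", md.version(pkg))
        except md.PackageNotFoundError:
            print(f"{pkg:<17}: not installed")


if __name__ == "__main__":
    report_versions()
```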
- ### xFormers Similar to FlashInfer, here is how to build and install xFormers from source: ```bash -export TORCH_CUDA_ARCH_LIST='7.0 7.5 8.0 8.9 9.0 10.0+PTX' +export TORCH_CUDA_ARCH_LIST='7.5 8.0+PTX 9.0a' MAX_JOBS=16 uv pip install --system \ - --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.30" + --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.32.post2" ``` ## Update all the different vLLM platforms diff --git a/pyproject.toml b/pyproject.toml index 95dda76063bc..690ef49fec5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = [ "packaging>=24.2", "setuptools>=77.0.3,<80.0.0", "setuptools-scm>=8.0", - "torch == 2.8.0", + "torch == 2.9.0", "wheel", "jinja2", ] diff --git a/requirements/build.txt b/requirements/build.txt index 5f826a1afa14..ba09eaab70e8 100644 --- a/requirements/build.txt +++ b/requirements/build.txt @@ -4,7 +4,7 @@ ninja packaging>=24.2 setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 -torch==2.8.0 +torch==2.9.0 wheel jinja2>=3.1.6 regex diff --git a/requirements/cuda.txt b/requirements/cuda.txt index 06956415d072..18ab9ce25ea2 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -5,11 +5,11 @@ numba == 0.61.2 # Required for N-gram speculative decoding # Dependencies for NVIDIA GPUs ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1. -torch==2.8.0 -torchaudio==2.8.0 +torch==2.9.0 +torchaudio==2.9.0 # These must be updated alongside torch -torchvision==0.23.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version +torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version # https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1 -xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8 +# xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8 # FlashInfer should be updated together with the Dockerfile -flashinfer-python==0.4.0 \ No newline at end of file +flashinfer-python==0.4.0 diff --git a/requirements/rocm-build.txt b/requirements/rocm-build.txt index a86a8ab6df14..ca6043fc560e 100644 --- a/requirements/rocm-build.txt +++ b/requirements/rocm-build.txt @@ -1,10 +1,10 @@ # Common dependencies -r common.txt ---extra-index-url https://download.pytorch.org/whl/rocm6.3 -torch==2.8.0 -torchvision==0.23.0 -torchaudio==2.8.0 +--extra-index-url https://download.pytorch.org/whl/test/rocm6.3 +torch==2.9.0 +torchvision==0.24.0 +torchaudio==2.9.0 triton==3.3.0 cmake>=3.26.1,<4 diff --git a/requirements/test.in b/requirements/test.in index f0941d3c5918..55792a6aee62 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -24,9 +24,9 @@ soundfile # required for audio tests jiwer # required for audio tests tblib # for pickling test exceptions timm >=1.0.17 # required for internvl and gemma3n-mm test -torch==2.8.0 -torchaudio==2.8.0 -torchvision==0.23.0 +torch==2.9.0 +torchaudio==2.9.0 +torchvision==0.24.0 transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test mistral_common[image,audio] >= 1.8.5 # required for voxtral test diff --git a/requirements/test.txt b/requirements/test.txt index 03fbdcc8d453..cfea1b48a70b 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,5 +1,5 @@ 
# This file was autogenerated by uv via the following command: -# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu128 --python-platform x86_64-manylinux_2_28 +# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --python-platform x86_64-manylinux_2_28 absl-py==2.1.0 # via rouge-score accelerate==1.0.1 @@ -17,7 +17,6 @@ aiohttp==3.13.0 # aiohttp-cors # datasets # fsspec - # gpt-oss # lm-eval # ray aiohttp-cors==0.8.1 @@ -44,6 +43,10 @@ argcomplete==3.5.1 # via datamodel-code-generator arrow==1.3.0 # via isoduration +async-timeout==5.0.1 + # via + # aiohttp + # redis attrs==24.2.0 # via # aiohttp @@ -104,8 +107,6 @@ chardet==5.2.0 # via mbstrdecoder charset-normalizer==3.4.0 # via requests -chz==0.3.0 - # via gpt-oss click==8.1.7 # via # black @@ -176,9 +177,7 @@ distlib==0.3.9 dnspython==2.7.0 # via email-validator docker==7.1.0 - # via - # gpt-oss - # mlflow + # via mlflow docopt==0.6.2 # via num2words docstring-parser==0.17.0 @@ -203,10 +202,13 @@ eval-type-backport==0.2.2 # via mteb evaluate==0.4.3 # via lm-eval -fastapi==0.116.1 +exceptiongroup==1.3.0 # via - # gpt-oss - # mlflow-skinny + # anyio + # hypothesis + # pytest +fastapi==0.116.1 + # via mlflow-skinny fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -281,8 +283,6 @@ google-resumable-media==2.7.2 # via google-cloud-storage googleapis-common-protos==1.70.0 # via google-api-core -gpt-oss==0.0.8 - # via -r requirements/test.in graphene==3.4.3 # via mlflow graphql-core==3.2.6 @@ -310,8 +310,6 @@ hf-xet==1.1.7 # via huggingface-hub hiredis==3.0.0 # via tensorizer -html2text==2025.4.15 - # via gpt-oss httpcore==1.0.6 # via httpx httpx==0.27.2 @@ -446,7 +444,6 @@ lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b772215 lxml==5.3.0 # via # blobfile - # gpt-oss # sacrebleu mako==1.3.10 # via alembic @@ -600,7 +597,7 @@ nvidia-cusparse-cu12==12.5.8.93 # torch nvidia-cusparselt-cu12==0.7.1 # via torch -nvidia-nccl-cu12==2.27.3 +nvidia-nccl-cu12==2.27.5 # via torch nvidia-nvjitlink-cu12==12.8.93 # via @@ -608,6 +605,8 @@ nvidia-nvjitlink-cu12==12.8.93 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 # torch +nvidia-nvshmem-cu12==3.3.20 + # via torch nvidia-nvtx-cu12==12.8.90 # via torch omegaconf==2.3.0 @@ -616,8 +615,6 @@ omegaconf==2.3.0 # lightning open-clip-torch==2.32.0 # via -r requirements/test.in -openai-harmony==0.0.4 - # via gpt-oss opencensus==0.11.4 # via ray opencensus-context==0.1.3 @@ -789,12 +786,10 @@ pydantic==2.12.0 # albumentations # datamodel-code-generator # fastapi - # gpt-oss # lightly # mistral-common # mlflow-skinny # mteb - # openai-harmony # pydantic-extra-types # ray pydantic-core==2.41.1 @@ -925,7 +920,6 @@ requests==2.32.3 # evaluate # google-api-core # google-cloud-storage - # gpt-oss # huggingface-hub # lightly # lm-eval @@ -1016,8 +1010,6 @@ setuptools==77.0.3 # via # lightning-utilities # pytablewriter - # torch - # triton shapely==2.1.1 # via # geopandas @@ -1069,8 +1061,6 @@ starlette-testclient==0.4.1 # via schemathesis statsmodels==0.14.4 # via genai-perf -structlog==25.4.0 - # via gpt-oss sympy==1.13.3 # via # einx @@ -1085,15 +1075,12 @@ tcolorpy==0.1.6 # via pytablewriter tenacity==9.1.2 # via - # gpt-oss # lm-eval # plotly tensorboardx==2.6.4 # via lightning tensorizer==2.10.1 # via -r requirements/test.in -termcolor==3.1.0 - # via gpt-oss terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e # via 
-r requirements/test.in threadpoolctl==3.5.0 @@ -1104,7 +1091,6 @@ tifffile==2025.3.30 # terratorch tiktoken==0.12.0 # via - # gpt-oss # lm-eval # mistral-common timm==1.0.17 @@ -1118,11 +1104,18 @@ tokenizers==0.22.0 # via # -r requirements/test.in # transformers +toml==0.10.2 + # via datamodel-code-generator tomli==2.2.1 - # via schemathesis + # via + # alembic + # black + # coverage + # pytest + # schemathesis tomli-w==1.2.0 # via schemathesis -torch==2.8.0+cu128 +torch==2.9.0+cu128 # via # -r requirements/test.in # accelerate @@ -1151,12 +1144,12 @@ torch==2.8.0+cu128 # torchvision # vector-quantize-pytorch # vocos -torchaudio==2.8.0+cu128 +torchaudio==2.9.0+cu128 # via # -r requirements/test.in # encodec # vocos -torchgeo==0.7.0 +torchgeo==0.6.2 # via terratorch torchmetrics==1.7.4 # via @@ -1164,7 +1157,7 @@ torchmetrics==1.7.4 # pytorch-lightning # terratorch # torchgeo -torchvision==0.23.0+cu128 +torchvision==0.24.0+cu128 # via # -r requirements/test.in # lightly @@ -1205,7 +1198,7 @@ transformers==4.56.2 # transformers-stream-generator transformers-stream-generator==0.0.5 # via -r requirements/test.in -triton==3.4.0 +triton==3.5.0 # via torch tritonclient==2.51.0 # via @@ -1227,7 +1220,9 @@ typing-extensions==4.15.0 # aiosignal # albumentations # alembic - # chz + # anyio + # black + # exceptiongroup # fastapi # graphene # huggingface-hub @@ -1237,6 +1232,7 @@ typing-extensions==4.15.0 # mistral-common # mlflow-skinny # mteb + # multidict # opentelemetry-api # opentelemetry-sdk # opentelemetry-semantic-conventions @@ -1245,12 +1241,13 @@ typing-extensions==4.15.0 # pydantic-core # pydantic-extra-types # pytorch-lightning + # rich # sqlalchemy # torch - # torchgeo # typer # typeshed-client # typing-inspection + # uvicorn typing-inspection==0.4.2 # via pydantic tzdata==2024.2 @@ -1267,9 +1264,7 @@ urllib3==2.2.3 # responses # tritonclient uvicorn==0.35.0 - # via - # gpt-oss - # mlflow-skinny + # via mlflow-skinny vector-quantize-pytorch==1.21.2 # via -r requirements/test.in virtualenv==20.31.2 @@ -1288,7 +1283,7 @@ word2number==1.1 # via lm-eval wrapt==1.17.2 # via smart-open -xarray==2025.7.1 +xarray==2025.6.1 # via rioxarray xxhash==3.5.0 # via diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index d1f741479acf..246239b87d5f 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -20,6 +20,7 @@ set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.utils import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from .. 
import silly_attention # noqa: F401 @@ -193,9 +194,8 @@ def run_model( @pytest.mark.parametrize("use_inductor_graph_partition", [False, True]) def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool): - if use_inductor_graph_partition: - # FIXME(luka/boyuan): this currently fails - pytest.skip("Inductor graph partition not supported with multi-graph") + if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("inductor graph partition is only available in PyTorch 2.9+") outputs = [] diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py index 4d60899a628a..05131c888b8e 100644 --- a/tests/compile/test_decorator.py +++ b/tests/compile/test_decorator.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest import torch from torch import nn @@ -14,6 +15,7 @@ set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.utils import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from . import silly_attention # noqa: F401 @@ -65,19 +67,40 @@ def run_model( return output.cpu() -def test_ignore_torch_compile_decorator(): - # vllmcompile +@pytest.mark.parametrize("use_inductor_graph_partition", [True, False]) +def test_ignore_torch_compile_decorator(use_inductor_graph_partition, monkeypatch): + # disable compile cache so that we can count the number of compilations + # appropriately + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("inductor graph partition is only available in PyTorch 2.9+") + + # Compile piecewise with VLLM_COMPILE vllm_config = VllmConfig( compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, use_cudagraph=True, splitting_ops=["silly::attention"], cudagraph_capture_sizes=[1, 2], - use_inductor_graph_partition=False, # TODO test both? + use_inductor_graph_partition=use_inductor_graph_partition, ) ) cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE + expected_num_graphs_seen = 1 + expected_num_cudagraph_captured = ( + 4 # num_cudagraph_sizes * num cudagraphs to capture + ) + if use_inductor_graph_partition: + expected_num_piecewise_graphs_seen = 1 + expected_num_piecewise_capturable_graphs_seen = 1 + expected_num_backend_compilations = 1 + else: + expected_num_piecewise_graphs_seen = 3 + expected_num_piecewise_capturable_graphs_seen = 2 + expected_num_backend_compilations = 2 + @support_torch_compile class A(nn.Module): def __init__( @@ -104,12 +127,11 @@ class C(B): ... # A has support_torch_compile with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=3, - num_piecewise_capturable_graphs_seen=2, - num_backend_compilations=2, - num_cudagraph_captured=4, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=expected_num_graphs_seen, + num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen, + num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen, + num_backend_compilations=expected_num_backend_compilations, + num_cudagraph_captured=expected_num_cudagraph_captured, ): run_model(vllm_config, mod_A, cudagraph_runtime_mode) @@ -131,12 +153,11 @@ class C(B): ... 
# C's support_torch_compile should override B's ignore_torch_compile with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=3, - num_piecewise_capturable_graphs_seen=2, - num_backend_compilations=2, - num_cudagraph_captured=4, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=expected_num_graphs_seen, + num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen, + num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen, + num_backend_compilations=expected_num_backend_compilations, + num_cudagraph_captured=expected_num_cudagraph_captured, ): run_model(vllm_config, mod_C, cudagraph_runtime_mode) @@ -179,7 +200,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def test_conditional_compile_enable_if(): +@pytest.mark.parametrize("use_inductor_graph_partition", [True, False]) +def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch): + # disable compile cache so that we can count the number of compilations + # appropriately + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("inductor graph partition is only available in PyTorch 2.9+") + vllm_config = VllmConfig( cache_config=CacheConfig( kv_sharing_fast_prefill=True, @@ -189,7 +218,7 @@ def test_conditional_compile_enable_if(): use_cudagraph=True, splitting_ops=["silly::attention"], cudagraph_capture_sizes=[1, 2], - use_inductor_graph_partition=False, # TODO test both + use_inductor_graph_partition=use_inductor_graph_partition, ), ) cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE @@ -197,17 +226,26 @@ def test_conditional_compile_enable_if(): with set_current_vllm_config(vllm_config): mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda() + if use_inductor_graph_partition: + expected_num_piecewise_graphs_seen = 2 + expected_num_piecewise_capturable_graphs_seen = 2 + expected_num_backend_compilations = 2 + else: + expected_num_piecewise_graphs_seen = 6 + expected_num_piecewise_capturable_graphs_seen = 4 + expected_num_backend_compilations = 4 + # A has support_torch_compile but enable_if fn returns False # enalbe_if will be True for B, so we expect mod1 and mod2 # to be compiled with compilation_counter.expect( num_graphs_seen=2, - num_piecewise_graphs_seen=6, + num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen, # 3 piecewise graphs per instance of B() - num_piecewise_capturable_graphs_seen=4, - num_backend_compilations=4, + num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen, + num_backend_compilations=expected_num_backend_compilations, num_cudagraph_captured=8, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + # num_cudagraph_sizes * num cudagraphable graphs to capture ): run_model(vllm_config, mod_A, cudagraph_runtime_mode) @@ -222,20 +260,30 @@ def test_conditional_compile_enable_if(): use_cudagraph=True, splitting_ops=["silly::attention"], cudagraph_capture_sizes=[1, 2], - use_inductor_graph_partition=False, # TODO test both? 
+ use_inductor_graph_partition=use_inductor_graph_partition, ), ) with set_current_vllm_config(vllm_config): mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda() + if use_inductor_graph_partition: + expected_num_piecewise_graphs_seen = 1 + expected_num_piecewise_capturable_graphs_seen = 1 + expected_num_backend_compilations = 1 + else: + # 3 attn ops and 4 non-attn ops + expected_num_piecewise_graphs_seen = 7 + expected_num_piecewise_capturable_graphs_seen = 4 + expected_num_backend_compilations = 4 + with compilation_counter.expect( num_graphs_seen=1, - num_piecewise_graphs_seen=7, + num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen, # 3 attn ops and 4 non-attn ops - num_piecewise_capturable_graphs_seen=4, - num_backend_compilations=4, + num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen, + num_backend_compilations=expected_num_backend_compilations, num_cudagraph_captured=8, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + # num_cudagraph_sizes * num cudagraphable graphs to capture ): run_model(vllm_config, mod_A, cudagraph_runtime_mode) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 2d290771f9ad..1b696ce5137e 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -184,6 +184,9 @@ def test_inductor_graph_partition_attn_fusion(caplog_vllm): pytest.skip("inductor graph partition is only available in PyTorch 2.9+") model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8" + if current_platform.get_device_capability()[0] < 10: + pytest.skip(f"{model} can only be loaded by B200 or above") + compilation_config = CompilationConfig( mode=CompilationMode.VLLM_COMPILE, use_inductor_graph_partition=True, diff --git a/tests/standalone_tests/python_only_compile.sh b/tests/standalone_tests/python_only_compile.sh index 7cc5ef659649..7232ee3a090a 100644 --- a/tests/standalone_tests/python_only_compile.sh +++ b/tests/standalone_tests/python_only_compile.sh @@ -18,7 +18,9 @@ apt autoremove -y echo 'import os; os.system("touch /tmp/changed.file")' >> vllm/__init__.py -VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL=1 VLLM_USE_PRECOMPILED=1 pip3 install -vvv -e . +# TESTING, TO BE REMOVED +VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL=1 VLLM_USE_PRECOMPILED=1 pip3 install -vvv -e . \ + --extra-index-url https://download.pytorch.org/whl/test/cu128 # Run the script python3 -c 'import vllm' diff --git a/tests/utils.py b/tests/utils.py index 5bfdf703390e..fe35fb784cb4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1021,8 +1021,8 @@ def create_new_process_for_each_test( assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'" - if method == "fork": - return fork_new_process_for_each_test + # if method == "fork": + # return fork_new_process_for_each_test return spawn_new_process_for_each_test diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index a34fb0bf920c..2b57f9ba79fc 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -366,7 +366,7 @@ class CompilationConfig: FULL_AND_PIECEWISE instead. """ - use_inductor_graph_partition: bool = False + use_inductor_graph_partition: bool = is_torch_equal_or_newer("2.9.0") """Use inductor graph partition to split the graph at cudagraph_unsafe ops. This partition happens at inductor codegen time after all passes and fusions are finished. 
It generates a single `call` function which wraps diff --git a/vllm/env_override.py b/vllm/env_override.py index 7f9054e73846..82d1071369ea 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -5,6 +5,7 @@ import torch from vllm.logger import init_logger +from vllm.utils import _is_torch_equal, is_torch_equal_or_newer logger = init_logger(__name__) @@ -21,3 +22,189 @@ os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" # see https://github.com/vllm-project/vllm/issues/10619 torch._inductor.config.compile_threads = 1 + + +# ======================================== +# torch 2.9 Inductor Scheduler monkeypatch +# ======================================== +# This change monkeypatches a function in Inductor to work around the following +# bug: https://github.com/vllm-project/vllm/issues/26678 +# +# The bug occurs when `use_inductor_graph_partition` is turned on and there +# exists operators inside of `splitting_ops` that have an in-place mutation. In +# vllm, this specifically occurs on the operator +# vllm.unified_attention_with_output. In this case, inductor does not populate +# the inductor IR's `origin_node` field, causing an assertion error when trying +# to access the node's `origin_node` field. +# +# So, we will monkeypatch torch._inductor.scheduler.Scheduler.should_partition +# so that it does not access the inductor IR node's `origin_node` field and just +# returns True if a node is registered as having a custom partition function. +# This is ok for now since vllm's implementation of the custom partition +# functions just return True. +# ======================================== + + +def should_partition_patched(self, node, should_log: bool = False) -> bool: + # This is a patched version of + # torch._inductor.scheduler.Scheduler.should_partition that modifies + # the following piece of code so that we always return True: + # https://github.com/pytorch/pytorch/blob/ecb53078faf86ca1b33277df33b82985675bb011/torch/_inductor/scheduler.py#L4712-L4724 + """Return True if we should partition the inductor graph on this node""" + + import torch._inductor.ir as ir + from torch._inductor.scheduler import ( + BaseSchedulerNode, + FusedSchedulerNode, + _custom_should_partition_fns, + ) + from torch._inductor.utils import ( + _unstable_customized_partition_wrapper, + is_cudagraph_unsafe_op, + maybe_log_cudagraph_partition, + ) + + # Allow users to manually specify if a node should be partitioned + # Can only do this for FallbackKernels + ir_node = node.node + if isinstance(ir_node, ir.FallbackKernel): + operator = ir_node.op_overload + if operator is not None and operator in _custom_should_partition_fns: + return True + + # When not using cudagraphs, keep all kernels in the `call` function + # instead of graph partition functions, since graph partition only brings + # benefit to cudagraph + if ( + not torch._inductor.config.triton.cudagraphs + and _unstable_customized_partition_wrapper.wrapper is None + ): + return True + + # avoid duplicating logs when should_partition is called multiple times + # on the same node + def noop_log(msg: str, node: BaseSchedulerNode | None) -> None: + return + + log_partition_reason = maybe_log_cudagraph_partition if should_log else noop_log + + if isinstance(node, FusedSchedulerNode): + return any(self.should_partition(snode) for snode in node.snodes) + + assert node.node is not None + + if not node.is_gpu(): + log_partition_reason("non gpu ops", node=node) + + return True + + if isinstance(node.node, ir.DeviceCopy): + 
log_partition_reason("DeviceCopy ops", node=node) + return True + + if isinstance(node.node, ir.Conditional): + log_partition_reason("Conditional ops", node=node) + return True + + if getattr(node.node, "unbacked_bindings", None): + log_partition_reason("unbacked binding ops", node=node) + return True + + if is_cudagraph_unsafe_op(node.node): + log_partition_reason("CUDAGraph-unsafe custom ops", node=node) + return True + + return False + + +def _update_scheduler_patched(self) -> None: + # Copied from torch._inductor.graph.GrahLowering._update_scheduler. Patches + # this method so that we can patch Scheduler.should_partition with the + # function above + """ + (Re)initializes the scheduler member. When initializing the scheduler, no CUBIN + files should be generated (to avoid biasing any benchmarks and pessimizing + fusion decisions). + """ + import torch._inductor.config as config + from torch._inductor.scheduler import Scheduler + + Scheduler.should_partition = should_partition_patched + + with config.patch("triton.store_cubin", False): + self.scheduler = Scheduler(self.operations) + + +# =================================================== +# torch 2.9 Inductor PythonWrapperCodegen monkeypatch +# =================================================== +# This change monkeypatches memory_plan_reuse in pytorch 2.9.0 to work around +# a test failure for test_multi_graph_piecewise_compile_outputs_equal. +# For more context, see https://github.com/pytorch/pytorch/pull/165514. + + +def memory_plan_reuse_patched(self): + import torch._inductor.ir as ir + from torch._inductor.codegen.wrapper import ( + EnterSubgraphLine, + ExitSubgraphLine, + MemoryPlanningLine, + MemoryPlanningState, + SubgraphPythonWrapperCodegen, + ) + from torch._inductor.virtualized import V + + def get_output_names(graph_outputs) -> list[str]: + import itertools + + names = [] + shape_counter = itertools.count(0) + none_counter = itertools.count(0) + for node in graph_outputs: + if isinstance(node, ir.NoneAsConstantBuffer): + names.append(f"{V.graph.name}_none{next(none_counter)}") + elif isinstance(node, ir.ShapeAsConstantBuffer): + names.append(f"{V.graph.name}_shape{next(shape_counter)}") + else: + names.append(node.get_name()) + return names + + if ( + isinstance(V.graph.wrapper_code, SubgraphPythonWrapperCodegen) + and V.graph.wrapper_code.partition_signatures is not None + ): + out_names = get_output_names( + V.graph.wrapper_code.partition_signatures.output_nodes + ) + else: + out_names = V.graph.get_output_names() + + while ( + self.lines + and isinstance(self.lines[-1], MemoryPlanningLine) + and self.lines[-1].node.name not in out_names # type: ignore[attr-defined] + ): + # these lines will be pointless + self.lines.pop() + + # codegen allocations in two passes + planning_states = [MemoryPlanningState()] + past_planning_states = [] + for i in range(len(self.lines)): + line = self.lines[i] + if isinstance(line, MemoryPlanningLine): + self.lines[i] = line.plan(planning_states[-1]) + elif isinstance(line, EnterSubgraphLine): + planning_states.append(MemoryPlanningState()) + elif isinstance(line, ExitSubgraphLine): + past_planning_states.append(planning_states.pop()) + past_planning_states.append(planning_states.pop()) + assert len(planning_states) == 0 + + +if _is_torch_equal("2.9.0"): + from torch._inductor.graph import GraphLowering + GraphLowering._update_scheduler = _update_scheduler_patched + + from torch._inductor.codegen.wrapper import PythonWrapperCodegen + PythonWrapperCodegen.memory_plan_reuse = 
memory_plan_reuse_patched diff --git a/vllm/envs.py b/vllm/envs.py index b5c7f325f670..1d7441b0ab2d 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -63,7 +63,7 @@ VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True VLLM_XLA_USE_SPMD: bool = False - VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn"] = "fork" + VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn"] = "spawn" VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") VLLM_ASSETS_CACHE_MODEL_CLEAN: bool = False VLLM_IMAGE_FETCH_TIMEOUT: int = 5 @@ -133,7 +133,7 @@ VLLM_DP_RANK: int = 0 VLLM_DP_RANK_LOCAL: int = -1 VLLM_DP_SIZE: int = 1 - VLLM_USE_STANDALONE_COMPILE: bool = False + VLLM_USE_STANDALONE_COMPILE: bool = True VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_PORT: int = 0 VLLM_MOE_DP_CHUNK_SIZE: int = 256 @@ -498,9 +498,9 @@ def get_vllm_port() -> int | None: # In torch <= 2.7 we ignore this flag; in torch >= 2.8 this is # disabled by default. "VLLM_USE_STANDALONE_COMPILE": lambda: os.environ.get( - "VLLM_USE_STANDALONE_COMPILE", "0" - ) - == "1", + "VLLM_USE_STANDALONE_COMPILE", "1" + ).lower() + in ["true", "1"], # Debug pattern matching inside custom passes. # Should be set to the fx.Node name (e.g. 'getitem_34' or 'scaled_mm_3'). "VLLM_PATTERN_MATCH_DEBUG": lambda: os.environ.get( @@ -675,7 +675,7 @@ def get_vllm_port() -> int | None: # Use dedicated multiprocess context for workers. # Both spawn and fork work "VLLM_WORKER_MULTIPROC_METHOD": env_with_choices( - "VLLM_WORKER_MULTIPROC_METHOD", "fork", ["spawn", "fork"] + "VLLM_WORKER_MULTIPROC_METHOD", "spawn", ["spawn", "fork"] ), # Path to the cache for storing downloaded assets "VLLM_ASSETS_CACHE": lambda: os.path.expanduser( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 9f66e47dcb96..6d5e64dcaa7f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1126,7 +1126,7 @@ def fused_topk_bias( # This is used by the Deepseek-V2 and Deepseek-V3 model -@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +# @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index bb5d3a688094..f2776490e795 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -3265,6 +3265,33 @@ def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool: return torch_version >= version.parse(target) +def _is_torch_equal(target: str) -> bool: + assert target.count(".") == 2 + torch_version = str(torch.__version__) + torch_version = version.parse(torch_version) + # torch version is like "2.6.0.dev20240101" or "2.6.0.dev20240101+cpu" + # or "2.6.0+cu128" but never "2.6.0.1" + return ( + torch_version >= version.parse(target) + and version.parse(target + ".1") > torch_version + ) + + +def is_torch_equal(target: str) -> bool: + """Check if the installed torch version is == the target version. + + Args: + target: a version string, like "2.6.0". + + Returns: + Whether the condition meets. + """ + try: + return _is_torch_equal(target) + except Exception: + return Version(importlib.metadata.version("torch")) == Version(target) + + @cache def _has_module(module_name: str) -> bool: """Return True if *module_name* can be found in the current environment.
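A short usage sketch for the version helpers added above, assuming a vLLM checkout where `vllm.utils` exposes them exactly as in this diff. It shows which version strings the exact-match gate accepts, which is what decides whether the torch 2.9.0 Inductor monkeypatches in `vllm/env_override.py` are applied.

```python
# Usage sketch, assuming vllm.utils provides these helpers as defined in the diff.
from vllm.utils import is_torch_equal, is_torch_equal_or_newer

# Gate used by the tests that require inductor graph partition (PyTorch 2.9+ only):
print(is_torch_equal_or_newer("2.9.0.dev"))  # True on 2.9.0 nightlies, RCs, and releases

# Gate used by vllm/env_override.py for the Scheduler / wrapper-codegen patches:
print(is_torch_equal("2.9.0"))  # True for "2.9.0" and "2.9.0+cu128";
                                # False for "2.9.1" and for 2.9.0 nightly builds
```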