1- # default base image
2- ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
3-
4- FROM $BASE_IMAGE
5-
6- ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
7-
8- RUN echo "Base image is $BASE_IMAGE"
9-
10- ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
11- ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
12-
1+ # Default ROCm 6.1 base image
2+ ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
3+
4+ # Tested and supported base rocm/pytorch images
5+ ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu20.04_py3.9_pytorch_2.0.1" \
6+ ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" \
7+     ROCm_6_1_BASE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
8+
9+ # Default ROCm ARCHes to build vLLM for.
10+ ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"
11+
12+ # Whether to build CK-based flash-attention
13+ # If 0, will not build flash attention
14+ # This is useful for gfx target where flash-attention is not supported
15+ # (i.e. those that do not appear in `FA_GFX_ARCHS`)
16+ # Triton FA is used by default on ROCm now so this is unnecessary.
17+ ARG BUILD_FA="1"
1318ARG FA_GFX_ARCHS="gfx90a;gfx942"
14- RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
15-
1619ARG FA_BRANCH="ae7928c"
17- RUN echo "FA_BRANCH is $FA_BRANCH"
1820
19- # whether to build flash-attention
20- # if 0, will not build flash attention
21- # this is useful for gfx target where flash-attention is not supported
22- # In that case, we need to use the python reference attention implementation in vllm
23- ARG BUILD_FA="1"
24-
25- # whether to build triton on rocm
21+ # Whether to build triton on rocm
2622ARG BUILD_TRITON="1"
23+ ARG TRITON_BRANCH="0ef1848"
2724
28- # Install some basic utilities
29- RUN apt-get update && apt-get install python3 python3-pip -y
25+ ### Base image build stage
26+ FROM $BASE_IMAGE AS base
27+
28+ # Import arg(s) defined before this build stage
29+ ARG PYTORCH_ROCM_ARCH
3030
3131# Install some basic utilities
32+ RUN apt-get update && apt-get install python3 python3-pip -y
3233RUN apt-get update && apt-get install -y \
3334 curl \
3435 ca-certificates \
@@ -39,79 +40,159 @@ RUN apt-get update && apt-get install -y \
3940 build-essential \
4041 wget \
4142 unzip \
42- nvidia-cuda-toolkit \
4343 tmux \
4444 ccache \
4545 && rm -rf /var/lib/apt/lists/*
4646
47- ### Mount Point ###
48- # When launching the container, mount the code directory to /app
47+ # When launching the container, mount the code directory to /vllm-workspace
4948ARG APP_MOUNT=/vllm-workspace
50- VOLUME [ ${APP_MOUNT} ]
5149WORKDIR ${APP_MOUNT}
5250
53- RUN python3 -m pip install --upgrade pip
54- RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
51+ RUN pip install --upgrade pip
52+ # Remove sccache so it doesn't interfere with ccache
53+ # TODO: implement sccache support across components
54+ RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)"
55+ # Install torch == 2.4.0 on ROCm
56+ RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
57+ *"rocm-5.7"*) \
58+ pip uninstall -y torch \
59+ && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
60+ --index-url https://download.pytorch.org/whl/nightly/rocm5.7;; \
61+ *"rocm-6.0"*) \
62+ pip uninstall -y torch \
63+ && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
64+ --index-url https://download.pytorch.org/whl/nightly/rocm6.0;; \
65+ *"rocm-6.1"*) \
66+ pip uninstall -y torch \
67+ && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
68+ --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
69+ *) ;; esac
5570
5671ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
5772ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
5873ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
5974ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
6075
61- # Install ROCm flash-attention
62- RUN if [ "$BUILD_FA" = "1" ]; then \
63- mkdir libs \
76+ ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
77+ ENV CCACHE_DIR=/root/.cache/ccache
78+
79+
80+ ### AMD-SMI build stage
81+ FROM base AS build_amdsmi
82+ # Build amdsmi wheel always
83+ RUN cd /opt/rocm/share/amd_smi \
84+ && pip wheel . --wheel-dir=/install
85+
86+
87+ ### Flash-Attention wheel build stage
88+ FROM base AS build_fa
89+ ARG BUILD_FA
90+ ARG FA_GFX_ARCHS
91+ ARG FA_BRANCH
92+ # Build ROCm flash-attention wheel if `BUILD_FA = 1`
93+ RUN --mount=type=cache,target=${CCACHE_DIR} \
94+ if [ "$BUILD_FA" = "1" ]; then \
95+ mkdir -p libs \
6496 && cd libs \
6597 && git clone https://github.com/ROCm/flash-attention.git \
6698 && cd flash-attention \
67- && git checkout ${FA_BRANCH} \
99+     && git checkout "${FA_BRANCH}" \
68100 && git submodule update --init \
69- && export GPU_ARCHS=${FA_GFX_ARCHS} \
70- && if [ "$BASE_IMAGE" = "$ROCm_5_7_BASE" ]; then \
71- patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
72- && python3 setup.py install \
73- && cd ..; \
101+ && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
102+ *"rocm-5.7"*) \
103+ export VLLM_TORCH_PATH="$(python3 -c 'import torch; print(torch.__path__[0])')" \
104+ && patch "${VLLM_TORCH_PATH}"/utils/hipify/hipify_python.py hipify_patch.patch;; \
105+ *) ;; esac \
106+ && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
107+ # Create an empty directory otherwise as later build stages expect one
108+ else mkdir -p /install; \
74109 fi
75110
76- # Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
77- # Manually removed it so that later steps of numpy upgrade can continue
78- RUN if [ "$BASE_IMAGE" = "$ROCm_6_0_BASE" ]; then \
79- rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
80111
81- # build triton
82- RUN if [ "$BUILD_TRITON" = "1" ]; then \
112+ ### Triton wheel build stage
113+ FROM base AS build_triton
114+ ARG BUILD_TRITON
115+ ARG TRITON_BRANCH
116+ # Build triton wheel if `BUILD_TRITON = 1`
117+ RUN --mount=type=cache,target=${CCACHE_DIR} \
118+ if [ "$BUILD_TRITON" = "1" ]; then \
83119 mkdir -p libs \
84120 && cd libs \
85- && pip uninstall -y triton \
86- && git clone https://github.com/ROCm/triton.git \
87- && cd triton/python \
88- && pip3 install . \
89- && cd ../..; \
121+ && git clone https://github.com/OpenAI/triton.git \
122+ && cd triton \
123+ && git checkout "${TRITON_BRANCH}" \
124+ && cd python \
125+ && python3 setup.py bdist_wheel --dist-dir=/install; \
126+ # Create an empty directory otherwise as later build stages expect one
127+ else mkdir -p /install; \
90128 fi
91129
92- WORKDIR /vllm-workspace
130+
131+ ### Final vLLM build stage
132+ FROM base AS final
133+ # Import the vLLM development directory from the build context
93134COPY . .
94135
95- #RUN python3 -m pip install pynvml # to be removed eventually
96- RUN python3 -m pip install --upgrade pip numba
136+ # Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
137+ # Manually remove it so that later steps of numpy upgrade can continue
138+ RUN case "$(which python3)" in \
139+ *"/opt/conda/envs/py_3.9"*) \
140+ rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \
141+ *) ;; esac
142+
143+ # Package upgrades for useful functionality or to avoid dependency issues
144+ RUN --mount=type=cache,target=/root/.cache/pip \
145+ pip install --upgrade numba scipy huggingface-hub[cli]
97146
98- # make sure punica kernels are built (for LoRA)
147+ # Make sure punica kernels are built (for LoRA)
99148ENV VLLM_INSTALL_PUNICA_KERNELS=1
100149# Workaround for ray >= 2.10.0
101150ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
151+ # Silences the HF Tokenizers warning
152+ ENV TOKENIZERS_PARALLELISM=false
102153
103- ENV VLLM_NCCL_SO_PATH=/opt/rocm/lib/librccl.so
104-
105- ENV CCACHE_DIR=/root/.cache/ccache
106- RUN --mount=type=cache,target=/root/.cache/ccache \
154+ RUN --mount=type=cache,target=${CCACHE_DIR} \
107155 --mount=type=cache,target=/root/.cache/pip \
108156 pip install -U -r requirements-rocm.txt \
109- && if [ "$BASE_IMAGE" = "$ROCm_6_0_BASE" ]; then \
110- patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch; fi \
111- && python3 setup.py install \
112- && export VLLM_PYTHON_VERSION=$(python -c "import sys; print(str(sys.version_info.major) + str(sys.version_info.minor))") \
113- && cp build/lib.linux-x86_64-cpython-${VLLM_PYTHON_VERSION}/vllm/*.so vllm/ \
114- && cd ..
157+ && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
158+ *"rocm-6.0"*) \
159+ patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h rocm_patch/rocm_bf16.patch;; \
160+ *"rocm-6.1"*) \
161+ # Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM
162+ wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P rocm_patch \
163+ && cp rocm_patch/libamdhip64.so.6 /opt/rocm/lib/libamdhip64.so.6 \
164+ # Prevent interference if torch bundles its own HIP runtime
165+ && rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \
166+ *) ;; esac \
167+ && python3 setup.py clean --all \
168+ && python3 setup.py develop
169+
170+ # Copy amdsmi wheel into final image
171+ RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
172+ mkdir -p libs \
173+ && cp /install/*.whl libs \
174+ # Preemptively uninstall to avoid same-version no-installs
175+ && pip uninstall -y amdsmi;
115176
177+ # Copy triton wheel(s) into final image if they were built
178+ RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
179+ mkdir -p libs \
180+ && if ls /install/*.whl; then \
181+ cp /install/*.whl libs \
182+ # Preemptively uninstall to avoid same-version no-installs
183+ && pip uninstall -y triton; fi
184+
185+ # Copy flash-attn wheel(s) into final image if they were built
186+ RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
187+ mkdir -p libs \
188+ && if ls /install/*.whl; then \
189+ cp /install/*.whl libs \
190+ # Preemptively uninstall to avoid same-version no-installs
191+ && pip uninstall -y flash-attn; fi
192+
193+ # Install wheels that were built to the final image
194+ RUN --mount=type=cache,target=/root/.cache/pip \
195+ if ls libs/*.whl; then \
196+ pip install libs/*.whl; fi
116197
117198CMD ["/bin/bash"]
0 commit comments