11# default base image
2- ARG BASE_IMAGE="rocm/pytorch:rocm6.1. 2_ubuntu20.04_py3.9_pytorch_staging "
2+ ARG BASE_IMAGE="rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0 "
33
44ARG COMMON_WORKDIR=/app
55
66# The following ARGs should be "0" or "1". If "1", the respective component will be built and installed on top of the base image
7- ARG BUILD_HIPBLASLT="1 "
7+ ARG BUILD_HIPBLASLT="0 "
88ARG BUILD_RCCL="1"
99ARG BUILD_FA="1"
1010ARG BUILD_TRITON="1"
11+ ARG BUILD_PYTORCH="1"
1112# This ARG should also be "0" or "1". If "1", the vLLM development directory is obtained via git clone.
1213# If "0", it is copied in from the local working directory.
1314ARG REMOTE_VLLM="0"
@@ -39,11 +40,12 @@ WORKDIR ${COMMON_WORKDIR}
3940# -----------------------
4041# hipBLASLt build stages
4142FROM base AS build_hipblaslt
42- ARG HIPBLASLT_BRANCH="6f65c6e"
43- RUN git clone https://github.com/ROCm/hipBLASLt \
43+ ARG HIPBLASLT_BRANCH="e6da924"
44+ RUN apt-get purge -y hipblaslt \
45+ && git clone https://github.com/ROCm/hipBLASLt.git \
4446 && cd hipBLASLt \
4547 && git checkout ${HIPBLASLT_BRANCH} \
46- && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} \
48+ && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \
4749 && cd build/release \
4850 && make package
4951FROM scratch AS export_hipblaslt_1
@@ -55,7 +57,7 @@ FROM export_hipblaslt_${BUILD_HIPBLASLT} AS export_hipblaslt
5557# -----------------------
5658# RCCL build stages
5759FROM base AS build_rccl
58- ARG RCCL_BRANCH="73221b4 "
60+ ARG RCCL_BRANCH="rocm-6.2.0 "
5961RUN git clone https://github.com/ROCm/rccl \
6062 && cd rccl \
6163 && git checkout ${RCCL_BRANCH} \
@@ -69,7 +71,7 @@ FROM export_rccl_${BUILD_RCCL} AS export_rccl
6971# -----------------------
7072# flash attn build stages
7173FROM base AS build_flash_attn
72- ARG FA_BRANCH="ae7928c "
74+ ARG FA_BRANCH="3cea2fb "
7375ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
7476RUN git clone ${FA_REPO} \
7577 && cd flash-attention \
@@ -85,9 +87,9 @@ FROM export_flash_attn_${BUILD_FA} AS export_flash_attn
8587# -----------------------
8688# Triton build stages
8789FROM base AS build_triton
88- ARG TRITON_BRANCH="6ddb79b "
89- ARG TRITON_REPO="https://github.com/OpenAI /triton.git"
90- RUN git clone ${TRITON_REPO} \
90+ ARG TRITON_BRANCH="e192dba "
91+ ARG TRITON_REPO="https://github.com/triton-lang /triton.git"
92+ RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \
9193 && cd triton \
9294 && git checkout ${TRITON_BRANCH} \
9395 && cd python \
@@ -105,6 +107,36 @@ RUN cd /opt/rocm/share/amd_smi \
105107FROM scratch AS export_amdsmi
106108COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /
107109
110+ FROM base as build_pytorch
111+ # A commit to fix the output scaling factor issue in _scaled_mm
112+ # Not yet in 2.5.0-rc1
113+ ARG PYTORCH_BRANCH="cedc116"
114+ ARG PYTORCH_VISION_BRANCH="v0.19.1"
115+ ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
116+ ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
117+ #RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
118+ #if ls /install/*.deb; then \
119+ # apt-get purge -y hipblaslt \
120+ # && dpkg -i /install/*.deb \
121+ # && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
122+ # && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
123+ #fi
124+ RUN git clone ${PYTORCH_REPO} pytorch \
125+ && cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
126+ && python tools/amd_build/build_amd.py \
127+ && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
128+ && pip install dist/*.whl \
129+ && cd .. \
130+ && git clone ${PYTORCH_VISION_REPO} vision \
131+ && cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
132+ && python3 setup.py bdist_wheel --dist-dir=dist
133+ FROM scratch as export_pytorch_1
134+ ARG COMMON_WORKDIR
135+ COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /
136+ COPY --from=build_pytorch ${COMMON_WORKDIR}/vision/dist/*.whl /
137+ FROM scratch as export_pytorch_0
138+ from export_pytorch_${BUILD_PYTORCH} as export_pytorch
139+
108140# -----------------------
109141# vLLM (and gradlib) fetch stages
110142FROM base AS fetch_vllm_0
@@ -129,6 +161,11 @@ if ls /install/*.deb; then \
129161 && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
130162 && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
131163fi
164+ # Install pytorch
165+ RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
166+ if ls /install/*.whl; then \
167+ pip install /install/*.whl; \
168+ fi
132169# Build vLLM
133170RUN cd vllm \
134171 && python3 setup.py clean --all && python3 setup.py bdist_wheel --dist-dir=dist
@@ -197,6 +234,13 @@ RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
197234 pip uninstall -y amdsmi \
198235 && pip install /install/*.whl;
199236
237+ RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
238+ if ls /install/*.whl; then \
239+ # Preemptively uninstall to prevent pip same-version no-installs
240+ pip uninstall -y torch torchvision \
241+ && pip install /install/*.whl; \
242+ fi
243+
200244RUN python3 -m pip install --upgrade numba scipy huggingface-hub[cli]
201245
202246# Install vLLM (and gradlib)
0 commit comments