
Commit 164ce38

6.2 dockerfile (#176)

* Trying to modernize the Dockerfile, pinning rccl, triton, pytorch, and hipblaslt to the latest required versions
* Dockerfile fixes. Using the scaling factors in scaled_mm where they are required by torch 2.5 and accepted by other versions
* Building torchvision too when building torch
* gradlib, not being a CMake project, doesn't inherit `target_compile_definitions(hipblaslt PUBLIC LEGACY_HIPBLAS_DIRECT)`
* Using a specific torch commit with the scaled_mm fix until it lands in mainline. Fixed scaled_mm in gradlib for no reason at all
* No point in pinning hipblaslt to the rocm6.2 release; if we want to build it, we'll want 0.10
* Removed the torch requirement

1 parent b53c35d · commit 164ce38
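
The thread running through these changes is the torch._scaled_mm API shift: torch 2.5 requires scale_a/scale_b and returns a single tensor, while earlier releases tolerated calls without scales and returned an (output, amax) tuple. A minimal version-tolerant sketch (the wrapper name and defaults are ours, not part of this commit):

import torch

# One-element identity scale, mirroring the commit's approach for the
# "no real scaling" case.
ONE = torch.ones(1, dtype=torch.float32, device='cuda')

def scaled_mm_compat(a, b, out_dtype, scale_a=None, scale_b=None, bias=None):
    # torch >= 2.5 requires scale_a/scale_b; identity scales are a no-op.
    out = torch._scaled_mm(a, b,
                           scale_a=scale_a if scale_a is not None else ONE,
                           scale_b=scale_b if scale_b is not None else ONE,
                           out_dtype=out_dtype,
                           bias=bias)
    # torch < 2.5 returned an (output, amax) tuple; keep only the output.
    return out[0] if isinstance(out, tuple) else out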

File tree

4 files changed (+79, -27 lines)


Dockerfile.rocm

Lines changed: 54 additions & 10 deletions
@@ -1,13 +1,14 @@
 # default base image
-ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
+ARG BASE_IMAGE="rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0"

 ARG COMMON_WORKDIR=/app

 # The following ARGs should be "0" or "1". If "1", the respective component will be built and installed on top of the base image
-ARG BUILD_HIPBLASLT="1"
+ARG BUILD_HIPBLASLT="0"
 ARG BUILD_RCCL="1"
 ARG BUILD_FA="1"
 ARG BUILD_TRITON="1"
+ARG BUILD_PYTORCH="1"
 # This ARG should also be "0" or "1". If "1", the vLLM development directory is obtained via git clone.
 # If "0", it is copied in from the local working directory.
 ARG REMOTE_VLLM="0"
@@ -39,11 +40,12 @@ WORKDIR ${COMMON_WORKDIR}
 # -----------------------
 # hipBLASLt build stages
 FROM base AS build_hipblaslt
-ARG HIPBLASLT_BRANCH="6f65c6e"
-RUN git clone https://github.com/ROCm/hipBLASLt \
+ARG HIPBLASLT_BRANCH="e6da924"
+RUN apt-get purge -y hipblaslt \
+    && git clone https://github.com/ROCm/hipBLASLt.git \
     && cd hipBLASLt \
     && git checkout ${HIPBLASLT_BRANCH} \
-    && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} \
+    && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \
     && cd build/release \
     && make package
 FROM scratch AS export_hipblaslt_1
@@ -55,7 +57,7 @@ FROM export_hipblaslt_${BUILD_HIPBLASLT} AS export_hipblaslt
 # -----------------------
 # RCCL build stages
 FROM base AS build_rccl
-ARG RCCL_BRANCH="73221b4"
+ARG RCCL_BRANCH="rocm-6.2.0"
 RUN git clone https://github.com/ROCm/rccl \
     && cd rccl \
     && git checkout ${RCCL_BRANCH} \
@@ -69,7 +71,7 @@ FROM export_rccl_${BUILD_RCCL} AS export_rccl
 # -----------------------
 # flash attn build stages
 FROM base AS build_flash_attn
-ARG FA_BRANCH="ae7928c"
+ARG FA_BRANCH="3cea2fb"
 ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
 RUN git clone ${FA_REPO} \
     && cd flash-attention \
@@ -85,9 +87,9 @@ FROM export_flash_attn_${BUILD_FA} AS export_flash_attn
 # -----------------------
 # Triton build stages
 FROM base AS build_triton
-ARG TRITON_BRANCH="6ddb79b"
-ARG TRITON_REPO="https://github.com/OpenAI/triton.git"
-RUN git clone ${TRITON_REPO} \
+ARG TRITON_BRANCH="e192dba"
+ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
+RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \
     && cd triton \
     && git checkout ${TRITON_BRANCH} \
     && cd python \
@@ -105,6 +107,36 @@ RUN cd /opt/rocm/share/amd_smi \
 FROM scratch AS export_amdsmi
 COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /

+FROM base AS build_pytorch
+# A commit to fix the output scaling factor issue in _scaled_mm
+# Not yet in 2.5.0-rc1
+ARG PYTORCH_BRANCH="cedc116"
+ARG PYTORCH_VISION_BRANCH="v0.19.1"
+ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
+ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
+#RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
+#if ls /install/*.deb; then \
+#    apt-get purge -y hipblaslt \
+#    && dpkg -i /install/*.deb \
+#    && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
+#    && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
+#fi
+RUN git clone ${PYTORCH_REPO} pytorch \
+    && cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
+    && python tools/amd_build/build_amd.py \
+    && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
+    && pip install dist/*.whl \
+    && cd .. \
+    && git clone ${PYTORCH_VISION_REPO} vision \
+    && cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
+    && python3 setup.py bdist_wheel --dist-dir=dist
+FROM scratch AS export_pytorch_1
+ARG COMMON_WORKDIR
+COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /
+COPY --from=build_pytorch ${COMMON_WORKDIR}/vision/dist/*.whl /
+FROM scratch AS export_pytorch_0
+FROM export_pytorch_${BUILD_PYTORCH} AS export_pytorch
+
 # -----------------------
 # vLLM (and gradlib) fetch stages
 FROM base AS fetch_vllm_0
@@ -129,6 +161,11 @@ if ls /install/*.deb; then \
     && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
     && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
 fi
+# Install pytorch
+RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
+    if ls /install/*.whl; then \
+    pip install /install/*.whl; \
+    fi
 # Build vLLM
 RUN cd vllm \
     && python3 setup.py clean --all && python3 setup.py bdist_wheel --dist-dir=dist
@@ -197,6 +234,13 @@ RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
     pip uninstall -y amdsmi \
     && pip install /install/*.whl;

+RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
+    if ls /install/*.whl; then \
+    # Preemptively uninstall to prevent pip same-version no-installs
+    pip uninstall -y torch torchvision \
+    && pip install /install/*.whl; \
+    fi
+
 RUN python3 -m pip install --upgrade numba scipy huggingface-hub[cli]

 # Install vLLM (and gradlib)
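
The new build_pytorch/export_pytorch stages reuse the toggle idiom already used for hipBLASLt, RCCL, flash-attention, and Triton: the BUILD_* ARG picks between a populated export stage and an empty scratch stage, so the guarded `if ls /install/*.whl` install above is a no-op when nothing was built. A minimal sketch of the idiom with placeholder names:

# syntax=docker/dockerfile:1
# Placeholder stage names; the ARG is declared before the first FROM
# so it is visible in FROM lines.
ARG BUILD_FOO="1"

FROM alpine AS build_foo
RUN mkdir -p /dist && touch /dist/foo-1.0-py3-none-any.whl

# Populated variant: artifacts copied out of the build stage.
FROM scratch AS export_foo_1
COPY --from=build_foo /dist/*.whl /

# Empty variant: a bind mount of this stage contains no wheels.
FROM scratch AS export_foo_0

# The ARG selects which variant the alias resolves to.
FROM export_foo_${BUILD_FOO} AS export_foo

Passing --build-arg BUILD_FOO=0 at build time then makes export_foo empty, and a downstream guarded pip install silently skips.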

gradlib/gradlib/GemmTuner.py

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,7 @@
 atol = 1

 CACHE_INVALIDATE_BUFFERS = int(os.getenv("CACHE_INVALIDATE_BUFFERS", "37"))
+ONE = torch.ones(1, dtype=torch.float32, device='cuda')


 class Gemm:
@@ -68,6 +69,8 @@ def check_gemm_ref(self, libtype, solidx):
         if self.indtype == torch.float8_e4m3fnuz:
             ref, _ = torch._scaled_mm(self.inp,
                                       self.weights.t(),
+                                      scale_a=ONE,
+                                      scale_b=ONE,
                                       out_dtype=self.outdtype)
         else:
             ref = F.linear(self.inp, self.weights)
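
For fp8 inputs the reference path is _scaled_mm itself; with identity scales it behaves as a plain fp8 GEMM against which tuned solutions are checked. A standalone sketch of that reference computation (shapes are illustrative; requires a GPU whose torch build supports float8_e4m3fnuz, e.g. ROCm):

import torch

ONE = torch.ones(1, dtype=torch.float32, device='cuda')

# Illustrative shapes; GemmTuner uses its own inp/weights tensors.
inp = torch.randn(8, 64, device='cuda').to(torch.float8_e4m3fnuz)
weights = torch.randn(32, 64, device='cuda').to(torch.float8_e4m3fnuz)

ref = torch._scaled_mm(inp, weights.t(),
                       scale_a=ONE, scale_b=ONE,
                       out_dtype=torch.float16)
# torch < 2.5 returns (output, amax); normalize to a single tensor.
if isinstance(ref, tuple):
    ref, _ = ref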

gradlib/setup.py

Lines changed: 10 additions & 4 deletions
@@ -125,11 +125,14 @@
         extra_compile_args={
             'cxx': [
                 '-O3',
+                '-DLEGACY_HIPBLAS_DIRECT=ON',
             ],
             'nvcc': [
-                '-O3', '-U__CUDA_NO_HALF_OPERATORS__',
+                '-O3',
+                '-U__CUDA_NO_HALF_OPERATORS__',
                 '-U__CUDA_NO_HALF_CONVERSIONS__',
-                "-ftemplate-depth=1024"
+                "-ftemplate-depth=1024",
+                '-DLEGACY_HIPBLAS_DIRECT=ON',
             ] + extra_args
         }))
     ext_modules.append(
@@ -142,11 +145,14 @@
         extra_compile_args={
             'cxx': [
                 '-O3',
+                '-DLEGACY_HIPBLAS_DIRECT=ON',
             ],
             'nvcc': [
-                '-O3', '-U__CUDA_NO_HALF_OPERATORS__',
+                '-O3',
+                '-U__CUDA_NO_HALF_OPERATORS__',
                 '-U__CUDA_NO_HALF_CONVERSIONS__',
-                "-ftemplate-depth=1024"
+                "-ftemplate-depth=1024",
+                '-DLEGACY_HIPBLAS_DIRECT=ON',
             ] + extra_args
         }))
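
As the commit message notes, gradlib is not built with CMake, so it never inherits hipBLASLt's `target_compile_definitions(hipblaslt PUBLIC LEGACY_HIPBLAS_DIRECT)` and the macro must be supplied by hand. Repeating the -D flag per compiler works; a hedged alternative is setuptools' define_macros, which applies the macro to every compile invocation of the extension (extension name and sources below are placeholders):

from setuptools import Extension

# Placeholder extension: define_macros injects -DLEGACY_HIPBLAS_DIRECT=ON
# into each compile step without listing it under both 'cxx' and 'nvcc'.
ext = Extension(
    'gradlib_example',
    sources=['gemm_hip.cpp'],
    define_macros=[('LEGACY_HIPBLAS_DIRECT', 'ON')],
)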

vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 12 additions & 13 deletions
@@ -10,7 +10,7 @@
 # providing scaling factor for result. This value is created
 # as global value to avoid multiple tensor allocations, and
 # can be removed once pytorch fixes the bug.
-TORCH_SCALED_MM_SCALE_RESULT = torch.ones(1).cuda() if is_hip() else None
+TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None


 def cutlass_fp8_supported() -> bool:
@@ -132,20 +132,17 @@ def apply_fp8_linear(
     per_tensor_weights = (weight_scale.numel() == 1)
     per_tensor_activations = (x_scale.numel() == 1)

+    global TORCH_DEVICE_IDENTITY
+    if TORCH_DEVICE_IDENTITY.device != weight.device:
+        TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
     if per_tensor_weights and per_tensor_activations:
         # Fused GEMM_DQ
-        global TORCH_SCALED_MM_SCALE_RESULT
-        if TORCH_SCALED_MM_SCALE_RESULT.device != weight.device:
-            TORCH_SCALED_MM_SCALE_RESULT = TORCH_SCALED_MM_SCALE_RESULT.to(
-                weight.device)
-        output = torch._scaled_mm(
-            qinput,
-            weight,
-            out_dtype=out_dtype,
-            scale_a=x_scale,
-            scale_b=weight_scale,
-            scale_result=TORCH_SCALED_MM_SCALE_RESULT,
-            bias=bias)
+        output = torch._scaled_mm(qinput,
+                                  weight,
+                                  out_dtype=out_dtype,
+                                  scale_a=x_scale,
+                                  scale_b=weight_scale,
+                                  bias=bias)
     # A fix for discrepancy in scaled_mm which returns tuple
     # for torch < 2.5 and a single value in torch >= 2.5
     if type(output) is tuple and len(output) == 2:
@@ -173,6 +170,8 @@ def apply_fp8_linear(
         # Output in fp32 to allow subsequent ops to happen in-place
         output, _ = torch._scaled_mm(qinput,
                                      weight,
+                                     scale_a=TORCH_DEVICE_IDENTITY,
+                                     scale_b=TORCH_DEVICE_IDENTITY,
                                      out_dtype=torch.float32)
         # Unpad (undo num_token_padding)
         output = torch.narrow(output, 0, 0, input.shape[0])
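
Hoisting the device check out of the per-tensor branch means both the fused GEMM_DQ path and the fp32 fallback above see TORCH_DEVICE_IDENTITY on the right device without allocating a fresh tensor per call. The caching pattern in isolation (the helper name is ours):

import torch

# Module-level identity scale, migrated lazily to the device of first use
# so hot-path callers never allocate a new one-element tensor.
_IDENTITY = torch.ones(1, dtype=torch.float32)

def identity_scale_on(device: torch.device) -> torch.Tensor:
    global _IDENTITY
    if _IDENTITY.device != device:
        _IDENTITY = _IDENTITY.to(device)
    return _IDENTITY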
