64 changes: 54 additions & 10 deletions Dockerfile.rocm
@@ -1,13 +1,14 @@
# default base image
ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
ARG BASE_IMAGE="rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0"

ARG COMMON_WORKDIR=/app

# The following ARGs should be "0" or "1". If "1", the respective component will be built and installed on top of the base image
ARG BUILD_HIPBLASLT="1"
ARG BUILD_HIPBLASLT="0"
ARG BUILD_RCCL="1"
ARG BUILD_FA="1"
ARG BUILD_TRITON="1"
ARG BUILD_PYTORCH="1"
# This ARG should also be "0" or "1". If "1", the vLLM development directory is obtained via git clone.
# If "0", it is copied in from the local working directory.
ARG REMOTE_VLLM="0"
@@ -39,11 +40,12 @@ WORKDIR ${COMMON_WORKDIR}
# -----------------------
# hipBLASLt build stages
FROM base AS build_hipblaslt
ARG HIPBLASLT_BRANCH="6f65c6e"
RUN git clone https://github.com/ROCm/hipBLASLt \
ARG HIPBLASLT_BRANCH="e6da924"
RUN apt-get purge -y hipblaslt \
&& git clone https://github.com/ROCm/hipBLASLt.git \
&& cd hipBLASLt \
&& git checkout ${HIPBLASLT_BRANCH} \
&& SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} \
&& SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \
&& cd build/release \
&& make package
FROM scratch AS export_hipblaslt_1
@@ -55,7 +57,7 @@ FROM export_hipblaslt_${BUILD_HIPBLASLT} AS export_hipblaslt
# -----------------------
# RCCL build stages
FROM base AS build_rccl
ARG RCCL_BRANCH="73221b4"
ARG RCCL_BRANCH="rocm-6.2.0"
RUN git clone https://github.com/ROCm/rccl \
&& cd rccl \
&& git checkout ${RCCL_BRANCH} \
@@ -69,7 +71,7 @@ FROM export_rccl_${BUILD_RCCL} AS export_rccl
# -----------------------
# flash attn build stages
FROM base AS build_flash_attn
ARG FA_BRANCH="ae7928c"
ARG FA_BRANCH="3cea2fb"
Comment on lines -72 to +74

Review comment:
I believe FA depends on Torch. So you'll have to install the Torch wheels before building the FA wheels in this stage (same as what's done in the vLLM stage).

ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
RUN git clone ${FA_REPO} \
&& cd flash-attention \
@@ -85,9 +87,9 @@ FROM export_flash_attn_${BUILD_FA} AS export_flash_attn
# -----------------------
# Triton build stages
FROM base AS build_triton
ARG TRITON_BRANCH="6ddb79b"
ARG TRITON_REPO="https://github.com/OpenAI/triton.git"
RUN git clone ${TRITON_REPO} \
ARG TRITON_BRANCH="e192dba"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \

Review comment:
Nit: only pybind11 is necessary: the rest are inside the base container.

&& cd triton \
&& git checkout ${TRITON_BRANCH} \
&& cd python \
@@ -105,6 +107,36 @@ RUN cd /opt/rocm/share/amd_smi \
FROM scratch AS export_amdsmi
COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /

FROM base as build_pytorch
# A commit to fix the output scaling factor issue in _scaled_mm
# Not yet in 2.5.0-rc1
ARG PYTORCH_BRANCH="cedc116"
ARG PYTORCH_VISION_BRANCH="v0.19.1"
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
#RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
#if ls /install/*.deb; then \
# apt-get purge -y hipblaslt \
# && dpkg -i /install/*.deb \
# && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
# && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
#fi
RUN git clone ${PYTORCH_REPO} pytorch \
&& cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
&& python tools/amd_build/build_amd.py \
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
&& pip install dist/*.whl \
&& cd .. \
&& git clone ${PYTORCH_VISION_REPO} vision \
&& cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
&& python3 setup.py bdist_wheel --dist-dir=dist
FROM scratch as export_pytorch_1
ARG COMMON_WORKDIR
COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /
COPY --from=build_pytorch ${COMMON_WORKDIR}/vision/dist/*.whl /
FROM scratch as export_pytorch_0
from export_pytorch_${BUILD_PYTORCH} as export_pytorch

# -----------------------
# vLLM (and gradlib) fetch stages
FROM base AS fetch_vllm_0
@@ -129,6 +161,11 @@ if ls /install/*.deb; then \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
fi
# Install pytorch

Review comment:
I think we can remove the hipBLASLt install before this because our FP8 no longer needs it.

Collaborator (PR author):
gradlib and tuned gemm need it

Review comment:
I hate to say it, but they work for me without installing hipBLASLt here. And it has always worked.
The old FP8 implementation seems to have used some hipBLASLt API that is unstable, whereas gradlib does not.

RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
if ls /install/*.whl; then \
pip install /install/*.whl; \
fi
# Build vLLM
RUN cd vllm \
&& python3 setup.py clean --all && python3 setup.py bdist_wheel --dist-dir=dist
@@ -197,6 +234,13 @@ RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
pip uninstall -y amdsmi \
&& pip install /install/*.whl;

RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
if ls /install/*.whl; then \
# Preemptively uninstall to prevent pip same-version no-installs
pip uninstall -y torch torchvision \
&& pip install /install/*.whl; \
fi

RUN python3 -m pip install --upgrade numba scipy huggingface-hub[cli]

# Install vLLM (and gradlib)
3 changes: 3 additions & 0 deletions gradlib/gradlib/GemmTuner.py
@@ -15,6 +15,7 @@
atol = 1

CACHE_INVALIDATE_BUFFERS = int(os.getenv("CACHE_INVALIDATE_BUFFERS", "37"))
ONE = torch.ones(1, dtype=torch.float32, device='cuda')
Review comment (mawong-amd, Sep 12, 2024):
Let's add this as a class member of the Gemm class? If I'm not mistaken this will run the moment anything from this file is imported (even if it's not used) which can prematurely initialize the CUDA context.
Specifically: let's initialize the class member to None. Then in __init__, if this class member is not set, initialize it.

Collaborator (PR author):
Same here

Review comment (mawong-amd, Sep 13, 2024):
Same here about how the PR author should not resolve unresolved conversations.
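
For illustration, a minimal sketch of the lazy class-member pattern suggested above, assuming PyTorch with a ROCm/CUDA device available; the _one attribute name and the constructor signature are hypothetical, not the merged code:

import torch

class Gemm:
    # Shared identity scale factor for torch._scaled_mm. Left as None at
    # import time so that merely importing this module does not create a
    # CUDA context.
    _one = None

    def __init__(self, indtype, outdtype):
        self.indtype = indtype
        self.outdtype = outdtype
        # The first construction initializes the shared tensor on the GPU.
        if Gemm._one is None:
            Gemm._one = torch.ones(1, dtype=torch.float32, device='cuda')

check_gemm_ref would then pass Gemm._one as scale_a and scale_b instead of the module-level ONE tensor.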



class Gemm:
@@ -68,6 +69,8 @@ def check_gemm_ref(self, libtype, solidx):
if self.indtype == torch.float8_e4m3fnuz:
ref, _ = torch._scaled_mm(self.inp,
self.weights.t(),
scale_a=ONE,
scale_b=ONE,
out_dtype=self.outdtype)
else:
ref = F.linear(self.inp, self.weights)
14 changes: 10 additions & 4 deletions gradlib/setup.py
@@ -125,11 +125,14 @@
extra_compile_args={
'cxx': [
'-O3',
'-DLEGACY_HIPBLAS_DIRECT=ON',
],
'nvcc': [
'-O3', '-U__CUDA_NO_HALF_OPERATORS__',
'-O3',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
"-ftemplate-depth=1024"
"-ftemplate-depth=1024",
'-DLEGACY_HIPBLAS_DIRECT=ON',
] + extra_args
}))
ext_modules.append(
@@ -142,11 +145,14 @@
extra_compile_args={
'cxx': [
'-O3',
'-DLEGACY_HIPBLAS_DIRECT=ON',
],
'nvcc': [
'-O3', '-U__CUDA_NO_HALF_OPERATORS__',
'-O3',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
"-ftemplate-depth=1024"
"-ftemplate-depth=1024",
'-DLEGACY_HIPBLAS_DIRECT=ON',
] + extra_args
}))

25 changes: 12 additions & 13 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -10,7 +10,7 @@
# providing scaling factor for result. This value is created
# as global value to avoid multiple tensor allocations, and
# can be removed once pytorch fixes the bug.
TORCH_SCALED_MM_SCALE_RESULT = torch.ones(1).cuda() if is_hip() else None
TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None

Review comment:
Same issue in initializing a global CUDA tensor as in gradlib GemmTuner.py, with a similar workaround.
I'm concerned that this is a premature optimization. Even if this was done as a multiple allocation: this would not be an issue in CUDA graph mode, while the overhead for eager mode remains to be determined (for such a small tensor, PyTorch's allocator should be able to supply a cached allocation).

Collaborator (PR author):
Do you have any numbers to back that up?
If you want to make a change to an existing feature, please create a separate PR with a justification

Review comment (mawong-amd, Sep 13, 2024):
Ah I see this is a mistake by a PR in upstream.
I'll agree that this one shouldn't be changed in this PR. However, I would suggest not compounding the error in gradlib.
As for numbers: if this was initialized as None as a global and only initialized as a Tensor once when first used, which is what I suggested, you would not have any performance concerns on top of what's done currently.
Also, conversations should be marked resolved by the conversation starter when possible, not when the PR author wishes to close discussions they feel are inconvenient.
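
A minimal sketch of the lazily initialized alternative described above, assuming the same module context; the helper name _get_device_identity is hypothetical and this is not the code that was merged:

import torch

# None at import time so that importing this module never touches the GPU.
TORCH_DEVICE_IDENTITY = None

def _get_device_identity(device: torch.device) -> torch.Tensor:
    # Return a cached ones(1) scale tensor, creating or recreating it only
    # on first use or when the requested device changes.
    global TORCH_DEVICE_IDENTITY
    if (TORCH_DEVICE_IDENTITY is None
            or TORCH_DEVICE_IDENTITY.device != device):
        TORCH_DEVICE_IDENTITY = torch.ones(1,
                                           dtype=torch.float32,
                                           device=device)
    return TORCH_DEVICE_IDENTITY

apply_fp8_linear could then call _get_device_identity(weight.device) for scale_a and scale_b, avoiding both the import-time CUDA initialization and the in-function device check.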



def cutlass_fp8_supported() -> bool:
@@ -132,20 +132,17 @@ def apply_fp8_linear(
per_tensor_weights = (weight_scale.numel() == 1)
per_tensor_activations = (x_scale.numel() == 1)

global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY.device != weight.device:
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
Review comment (mawong-amd, Sep 13, 2024):
This is a really bad code smell that is further evidence for why this should not be initialized as a global, but rather should be initialized when it is first used, where the correct device is already set.

if per_tensor_weights and per_tensor_activations:
# Fused GEMM_DQ
global TORCH_SCALED_MM_SCALE_RESULT
if TORCH_SCALED_MM_SCALE_RESULT.device != weight.device:
TORCH_SCALED_MM_SCALE_RESULT = TORCH_SCALED_MM_SCALE_RESULT.to(
weight.device)
output = torch._scaled_mm(
qinput,
weight,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
scale_result=TORCH_SCALED_MM_SCALE_RESULT,
bias=bias)
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
@@ -173,6 +170,8 @@ def apply_fp8_linear(
# Output in fp32 to allow subsequent ops to happen in-place
output, _ = torch._scaled_mm(qinput,
weight,
scale_a=TORCH_DEVICE_IDENTITY,
scale_b=TORCH_DEVICE_IDENTITY,
out_dtype=torch.float32)
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input.shape[0])