Commit ec870fb

[FEAT] [ROCm]: Add AITER RMS Norm (Layer Norm) Feature (#14959)
Signed-off-by: tjtanaa <[email protected]>
1 parent df14302 commit ec870fb

5 files changed: +173 additions, -29 deletions

Dockerfile.rocm_base

Lines changed: 15 additions & 1 deletion
@@ -12,6 +12,8 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
 ARG FA_BRANCH="b7d29fb"
 ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
+ARG AITER_BRANCH="21d47a9"
+ARG AITER_REPO="https://github.com/ROCm/aiter.git"

 FROM ${BASE_IMAGE} AS base

@@ -129,8 +131,18 @@ RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
 RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
     pip install /install/*.whl

+ARG AITER_REPO
+ARG AITER_BRANCH
+RUN git clone --recursive ${AITER_REPO}
+RUN cd aiter \
+    && git checkout ${AITER_BRANCH} \
+    && git submodule update --init --recursive \
+    && pip install -r requirements.txt \
+    && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter
+
 ARG BASE_IMAGE
 ARG HIPBLASLT_BRANCH
+ARG HIPBLAS_COMMON_BRANCH
 ARG LEGACY_HIPBLASLT_OPTION
 ARG RCCL_BRANCH
 ARG RCCL_REPO
@@ -155,4 +167,6 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
     && echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
     && echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
     && echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
-    && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt
+    && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \
+    && echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
+    && echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt

tests/model_executor/test_enabled_custom_ops.py

Lines changed: 28 additions & 1 deletion
@@ -7,7 +7,10 @@
 from vllm.model_executor.layers.activation import (GeluAndMul,
                                                    ReLUSquaredActivation,
                                                    SiluAndMul)
-from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.layernorm import (
+    RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
+    rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
+from vllm.platforms import current_platform


 # Registered subclass for test
@@ -87,3 +90,27 @@ def test_enabled_ops_invalid(env: str):
                               custom_ops=env.split(",")))
     with set_current_vllm_config(vllm_config):
         RMSNorm(1024).enabled()
+
+
+@pytest.mark.parametrize("add_residual", [True, False])
+@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
+@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
+@pytest.mark.skipif(not current_platform.is_rocm(),
+                    reason="AITER is a feature exclusive for ROCm")
+def test_rms_norm_dispatch(add_residual: bool, use_rocm_aiter: str,
+                           use_rocm_aiter_norm: str, monkeypatch):
+    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
+    monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
+    rms_norm_func = dispatch_cuda_rmsnorm_func(add_residual)
+
+    if not add_residual:
+        if current_platform.is_rocm() and int(use_rocm_aiter) and int(
+                use_rocm_aiter_norm):
+            assert rms_norm_func == rocm_aiter_rms_norm
+        else:
+            assert rms_norm_func == rms_norm
+    elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
+            use_rocm_aiter_norm):
+        assert rms_norm_func == rocm_aiter_fused_add_rms_norm
+    else:
+        assert rms_norm_func == fused_add_rms_norm
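The test above drives the dispatcher through monkeypatched environment variables. Outside the test suite the same selection can be observed directly; a hedged sketch, assuming a ROCm platform with the aiter package installed:

    import os

    # Both switches must be on: VLLM_ROCM_USE_AITER is the parent switch
    # and VLLM_ROCM_USE_AITER_RMSNORM gates the RMS norm op specifically.
    os.environ["VLLM_ROCM_USE_AITER"] = "1"
    os.environ["VLLM_ROCM_USE_AITER_RMSNORM"] = "1"

    from vllm.model_executor.layers.layernorm import (
        dispatch_cuda_rmsnorm_func, rocm_aiter_fused_add_rms_norm,
        rocm_aiter_rms_norm)

    # On ROCm with both flags enabled, the AITER kernels are selected.
    assert dispatch_cuda_rmsnorm_func(True) == rocm_aiter_fused_add_rms_norm
    assert dispatch_cuda_rmsnorm_func(False) == rocm_aiter_rms_norm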

tests/models/decoder_only/language/test_models.py

Lines changed: 40 additions & 10 deletions
@@ -3,7 +3,11 @@

 Run `pytest tests/models/test_models.py`.
 """
+
 import pytest
+import torch
+
+from vllm.platforms import current_platform

 from ...utils import check_logprobs_close

@@ -13,7 +17,21 @@
 # https://github.com/vllm-project/vllm/issues/14524
 REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"]

+# This list contains the models that use AITER kernels.
+# Models outside this list are skipped by the AITER tests.
+# Once more AITER kernels are added, this list will no longer be
+# needed, as all models will call AITER kernels in parts of their
+# operators.
+AITER_MODEL_LIST = [
+    "meta-llama/Llama-3.2-1B-Instruct",
+    "openbmb/MiniCPM3-4B",
+    "Qwen/Qwen-7B",
+    "Qwen/Qwen2.5-0.5B-Instruct",
+    "ehristoforu/Falcon3-MoE-2x7B-Insruct",
+]
+

+# @maybe_test_rocm_aiter
 @pytest.mark.parametrize(
     "model",
     [
@@ -69,19 +87,24 @@
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [5])
-def test_models(
-    hf_runner,
-    vllm_runner,
-    example_prompts,
-    model: str,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-    monkeypatch,
-) -> None:
+@pytest.mark.parametrize(
+    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
+def test_models(hf_runner, vllm_runner, example_prompts, model: str,
+                dtype: str, max_tokens: int, num_logprobs: int,
+                use_rocm_aiter: bool, monkeypatch) -> None:
+
     if model in REQUIRES_V0:
         monkeypatch.setenv("VLLM_USE_V1", "0")

+    if use_rocm_aiter and (model in AITER_MODEL_LIST):
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+    elif use_rocm_aiter and model not in AITER_MODEL_LIST:
+        # Models outside AITER_MODEL_LIST are skipped by the AITER
+        # tests. Once more AITER kernels are added, this list will no
+        # longer be needed, as all models will call AITER kernels in
+        # parts of their operators.
+        pytest.skip(f"Skipping '{model}' model test with AITER kernel.")
+
     with hf_runner(model, dtype=dtype) as hf_model:
         if model.startswith("THUDM/chatglm3"):
             hf_model.model.get_output_embeddings = lambda: \
@@ -100,3 +123,10 @@ def test_models(
         name_0="hf",
         name_1="vllm",
     )
+    if use_rocm_aiter:
+        # Ensure the vLLM engine has deallocated its memory before the
+        # next unit test runs. On ROCm, when using AITER, memory might
+        # not be deallocated completely before the next test case.
+        torch.cuda.synchronize()
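Because VLLM_ROCM_USE_AITER defaults to off, exercising the new path outside the test suite is opt-in. A hedged offline-inference sketch, assuming a ROCm build of vLLM with aiter installed and one of the models from AITER_MODEL_LIST:

    import os

    # Parent switch; VLLM_ROCM_USE_AITER_RMSNORM is on by default.
    os.environ["VLLM_ROCM_USE_AITER"] = "1"

    from vllm import LLM, SamplingParams

    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", dtype="half")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)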

vllm/envs.py

Lines changed: 13 additions & 0 deletions
@@ -75,6 +75,8 @@
 VLLM_SKIP_P2P_CHECK: bool = False
 VLLM_DISABLED_KERNELS: list[str] = []
 VLLM_USE_V1: bool = True
+VLLM_ROCM_USE_AITER: bool = False
+VLLM_ROCM_USE_AITER_RMSNORM: bool = True
 VLLM_ROCM_FP8_PADDING: bool = True
 VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
 VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
@@ -528,6 +530,17 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_USE_V1":
     lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))),

+    # AITER ops are disabled unless explicitly enabled; this acts as a
+    # parent switch that gates all of the other AITER operations.
+    "VLLM_ROCM_USE_AITER":
+    lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
+             ("true", "1")),
+
+    # Use the AITER RMS norm op if AITER ops are enabled.
+    "VLLM_ROCM_USE_AITER_RMSNORM":
+    lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
+             ("true", "1")),
+
     # Pad the fp8 weights to 256 bytes for ROCm
     "VLLM_ROCM_FP8_PADDING":
     lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
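Note the parsing rule the two new lambdas share: a flag is on only when the raw string lower-cases to "true" or "1"; anything else, including "yes" or "on", leaves it off. A small illustration of that rule in isolation:

    def parse_flag(value: str) -> bool:
        # Mirrors the lambdas above: case-insensitive "true" or "1" enables.
        return value.lower() in ("true", "1")

    assert parse_flag("True") and parse_flag("1") and parse_flag("TRUE")
    assert not parse_flag("0") and not parse_flag("yes") and not parse_flag("on")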

vllm/model_executor/layers/layernorm.py

Lines changed: 77 additions & 17 deletions
@@ -5,7 +5,77 @@
 import torch
 import torch.nn as nn

+import vllm.envs as envs
 from vllm.model_executor.custom_op import CustomOp
+from vllm.platforms import current_platform
+
+
+def is_rocm_aiter_rmsnorm_enabled() -> bool:
+    return current_platform.is_rocm() \
+        and envs.VLLM_ROCM_USE_AITER_RMSNORM \
+        and envs.VLLM_ROCM_USE_AITER
+
+
+def rms_norm(x: torch.Tensor, weight: torch.Tensor,
+             variance_epsilon: float) -> torch.Tensor:
+    from vllm import _custom_ops as ops
+    out = torch.empty_like(x)
+    ops.rms_norm(
+        out,
+        x,
+        weight,
+        variance_epsilon,
+    )
+    return out
+
+
+def fused_add_rms_norm(
+        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
+        variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
+    from vllm import _custom_ops as ops
+    ops.fused_add_rms_norm(
+        x,
+        residual,
+        weight,
+        variance_epsilon,
+    )
+    return x, residual
+
+
+def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
+                        variance_epsilon: float) -> torch.Tensor:
+
+    import aiter as rocm_aiter
+    return rocm_aiter.rms_norm(x, weight, variance_epsilon)
+
+
+def rocm_aiter_fused_add_rms_norm(
+        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
+        variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
+
+    import aiter as rocm_aiter
+
+    # Assuming the correct signature for rmsnorm2d_fwd_with_add
+    rocm_aiter.rmsnorm2d_fwd_with_add(
+        x,  # output
+        x,  # input
+        residual,  # residual input
+        residual,  # residual output
+        weight,
+        variance_epsilon,
+    )
+    return x, residual
+
+
+def dispatch_cuda_rmsnorm_func(add_residual: bool):
+    if add_residual:
+        if is_rocm_aiter_rmsnorm_enabled():
+            return rocm_aiter_fused_add_rms_norm
+        return fused_add_rms_norm
+
+    if is_rocm_aiter_rmsnorm_enabled():
+        return rocm_aiter_rms_norm
+    return rms_norm


 @CustomOp.register("rms_norm")
@@ -81,24 +151,14 @@ def forward_cuda(
         if self.variance_size_override is not None:
             return self.forward_native(x, residual)

-        from vllm import _custom_ops as ops
+        add_residual = residual is not None
+        norm_func = dispatch_cuda_rmsnorm_func(add_residual)

-        if residual is not None:
-            ops.fused_add_rms_norm(
-                x,
-                residual,
-                self.weight.data,
-                self.variance_epsilon,
-            )
-            return x, residual
-        out = torch.empty_like(x)
-        ops.rms_norm(
-            out,
-            x,
-            self.weight.data,
-            self.variance_epsilon,
-        )
-        return out
+        if add_residual:
+            return norm_func(x, residual, self.weight.data,
+                             self.variance_epsilon)
+        else:
+            return norm_func(x, self.weight.data, self.variance_epsilon)

     def forward_hpu(
         self,
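From a caller's point of view nothing changes: RMSNorm.forward_cuda still takes an optional residual and simply routes to whichever function the dispatcher returns. A hedged usage sketch (illustrative shapes, dtype, and device; not part of the commit):

    import torch

    from vllm.model_executor.layers.layernorm import RMSNorm

    layer = RMSNorm(1024).to(device="cuda", dtype=torch.float16)
    x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
    residual = torch.randn_like(x)

    out = layer(x)                 # no residual: rms_norm / rocm_aiter_rms_norm
    out, res = layer(x, residual)  # residual: fused_add_rms_norm or AITER variant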
