
Commit 4965ec4

Authored by tjtanaa
[FEAT] [ROCm] Add AITER int8 scaled gemm kernel (#15433)
Signed-off-by: tjtanaa <[email protected]>
1 parent: 73aa704

File tree

4 files changed: 202 additions & 5 deletions


tests/quantization/test_compressed_tensors.py

Lines changed: 72 additions & 4 deletions
@@ -20,6 +20,23 @@
     sparse_cutlass_supported)
 from vllm.platforms import current_platform
 
+# AITER only supports per-channel-per-channel INT8 gemm
+# and per-tensor-per-tensor INT8 GEMM.
+# It does not support mix precision MM and mix quantization scheme.
+ROCM_AITER_SUPPORTED_INT8_MODEL = [
+    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
+    "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2"
+]
+
+# TritonScaledMMLinearKernel only supports symmetric quantization.
+ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
+    "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
+    "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",
+    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
+    "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2",
+    "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
+]
+
 
 @pytest.fixture(scope="function", autouse=True)
 def use_v0_only(monkeypatch):
@@ -57,6 +74,11 @@ def use_v0_only(monkeypatch):
 )
 def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
     model_path, strategy, quant_type, shape_0, is_symmetric = model_args
+
+    if current_platform.is_rocm(
+    ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
+        pytest.skip(f"Skip model {model_path} as it is not support on ROCm.")
+
     with vllm_runner(model_path, enforce_eager=True) as llm:
 
         def check_model(model):
@@ -123,14 +145,30 @@ def zp_valid(zp: Optional[torch.Tensor]):
 )
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [10])
+@pytest.mark.parametrize(
+    "use_aiter", [True, False] if current_platform.is_rocm() else [False])
 def test_compressed_tensors_w8a8_logprobs(
     hf_runner,
     vllm_runner,
     example_prompts,
     model_path,
     max_tokens,
     num_logprobs,
+    use_aiter,
+    monkeypatch,
 ):
+
+    if current_platform.is_rocm(
+    ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
+        pytest.skip(f"Skip model {model_path} as it is not support on ROCm.")
+
+    if use_aiter:
+        if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
+            pytest.skip(
+                f"Skip model {model_path} as it is not support by aiter.")
+        # this will enable VLLM_ROCM_USE_AITER_LINEAR
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     dtype = "bfloat16"
 
     # skip language translation prompt for the static per tensor asym model
@@ -154,6 +192,9 @@ def test_compressed_tensors_w8a8_logprobs(
         name_1="vllm",
     )
 
+    if current_platform.is_rocm():
+        torch.cuda.synchronize()
+
 
 def test_compressed_tensors_no_enforce_eager(vllm_runner):
     model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
@@ -177,8 +218,27 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
         ),
     ],
 )
-def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
+@pytest.mark.parametrize(
+    "use_aiter", [True, False] if current_platform.is_rocm() else [False])
+def test_compressed_tensors_w8a8_dynamic_per_token(
+    vllm_runner,
+    model_args,
+    use_aiter,
+    monkeypatch,
+):
     model_path, strategy = model_args
+
+    if current_platform.is_rocm(
+    ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
+        pytest.skip(f"Skip model {model_path} as it is not support on ROCm.")
+
+    if use_aiter:
+        if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
+            pytest.skip(
+                f"Skip model {model_path} as it is not support by aiter.")
+        # this will enable VLLM_ROCM_USE_AITER_LINEAR
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     with vllm_runner(model_path, dtype=torch.float16) as llm:
 
         def check_model(model):
@@ -207,6 +267,8 @@ def check_model(model):
         ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4),
     ],
 )
+@pytest.mark.skipif(not current_platform.is_cuda(),
+                    reason="The tests are skipped on non-CUDA platform.")
 def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
     model, strategy, group, pack_factor = wNa16_args
     with vllm_runner(model) as llm:
@@ -231,6 +293,8 @@ def check_model(model):
     assert output
 
 
+@pytest.mark.skipif(not current_platform.is_cuda(),
+                    reason="This test is skipped on non-CUDA platform.")
def test_compressed_tensors_w4a16_marlin24(vllm_runner):
     model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
     with vllm_runner(model_path) as llm:
@@ -271,7 +335,7 @@ def check_model(model):
 
             if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
                 assert len(qkv_proj.input_scale.shape) == 0
-                assert qkv_proj.weight.dtype is torch.float8_e4m3fn
+                assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
                 assert qkv_proj.weight_scale.dtype is torch.float32
                 assert len(qkv_proj.weight_scale.shape) == 0
 
@@ -281,6 +345,8 @@ def check_model(model):
     assert output
 
 
+@pytest.mark.skipif(not current_platform.is_cuda(),
+                    reason="This test is skipped on non-CUDA platform.")
 def test_compressed_tensors_kv_cache(vllm_runner):
     model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
     with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
@@ -309,7 +375,8 @@ def _test_2of4_quant_models(qkv_proj,
 
 
 @pytest.mark.skipif(
-    not current_platform.has_device_capability(90),
+    not current_platform.is_cuda()
+    or not current_platform.has_device_capability(90),
     reason="Sparse FP8 is not yet supported on this GPU type.",
 )
 @pytest.mark.parametrize(
@@ -356,7 +423,8 @@ def check_model(model):
 
 
 @pytest.mark.skipif(
-    not current_platform.has_device_capability(90),
+    not current_platform.is_cuda()
+    or not current_platform.has_device_capability(90),
     reason="Sparse FP8 is not yet supported on this GPU type.",
 )
 @pytest.mark.parametrize(
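
Note on the two model lists near the top of this file: "per-tensor" and "per-channel" refer to the granularity of the INT8 scales. A minimal, illustrative sketch of the difference for a weight of shape [K, N] (the shapes and the symmetric-quantization recipe are assumptions for illustration, not taken from the models above):

    import torch

    K, N = 256, 512
    w = torch.randn(K, N)

    # Per-tensor: a single scale for the whole weight matrix.
    scale_per_tensor = w.abs().max() / 127.0

    # Per-channel: one scale per output channel (one per column), shape [N].
    scale_per_channel = w.abs().amax(dim=0) / 127.0

    # Symmetric INT8 quantization with either granularity.
    w_q_tensor = torch.clamp((w / scale_per_tensor).round(), -128, 127).to(torch.int8)
    w_q_channel = torch.clamp((w / scale_per_channel).round(), -128, 127).to(torch.int8)

The AITER int8 path accepts either granularity on both operands but not a mix of the two (see the assertion in the new aiter.py kernel below), which is why ROCM_AITER_SUPPORTED_INT8_MODEL is the narrower list.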

vllm/envs.py

Lines changed: 8 additions & 0 deletions
@@ -75,6 +75,7 @@
     VLLM_DISABLED_KERNELS: list[str] = []
     VLLM_USE_V1: bool = True
     VLLM_ROCM_USE_AITER: bool = False
+    VLLM_ROCM_USE_AITER_LINEAR: bool = True
     VLLM_ROCM_USE_AITER_MOE: bool = True
     VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
@@ -524,6 +525,13 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
              ("true", "1")),
 
+    # use aiter linear op if aiter ops are enabled
+    # The following list of related ops
+    # - scaled_mm (per-tensor / rowwise)
+    "VLLM_ROCM_USE_AITER_LINEAR":
+    lambda: (os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in
+             ("true", "1")),
+
     # Whether to use aiter moe ops.
     # By default is enabled.
     "VLLM_ROCM_USE_AITER_MOE":

vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -3,6 +3,8 @@
 import os
 from typing import Dict, List, Optional, Type
 
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
+    AiterScaledMMLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
     CutlassScaledMMLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
@@ -17,7 +19,7 @@
 _POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
     PlatformEnum.CPU: [CutlassScaledMMLinearKernel],
     PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
-    PlatformEnum.ROCM: [TritonScaledMMLinearKernel],
+    PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
     PlatformEnum.TPU: [XLAScaledMMLinearKernel],
 }
 
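
Placing AiterScaledMMLinearKernel ahead of TritonScaledMMLinearKernel in the ROCm list matters because candidate kernels are tried in order of preference. An illustrative sketch of that selection pattern; this helper is not vLLM's actual selector, just a sketch of the ordering contract:

    from typing import List, Type

    def pick_first_usable(kernels: List[Type], config) -> Type:
        # Walk the platform's preference list and return the first kernel
        # whose can_implement() accepts the layer configuration.
        failures = []
        for kernel_cls in kernels:
            ok, reason = kernel_cls.can_implement(config)
            if ok:
                return kernel_cls
            failures.append(f"{kernel_cls.__name__}: {reason}")
        raise ValueError("No compatible scaled_mm kernel: " + "; ".join(failures))

On ROCm the AITER kernel is therefore used whenever its can_implement() returns (True, None), and the Triton kernel remains the fallback, for example when aiter is not installed or the flags above are disabled.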

vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py (new file)

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional, Tuple
+
+import torch
+
+import vllm.envs as envs
+from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
+
+from .cutlass import CutlassScaledMMLinearKernel
+from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
+
+
+class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 90
+
+    @classmethod
+    def can_implement(
+            cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
+        if not current_platform.is_rocm():
+            return (
+                False,
+                "AiterScaledMMLinearKernel requires `aiter` which is not " +
+                "currently supported on non-ROCm platform.")
+
+        try:
+            import aiter  # noqa: F401 # deliberately attempt to import aiter
+        except Exception:
+            return (
+                False,
+                "AiterScaledMMLinearKernel requires `aiter` which is not " +
+                "installed on ROCm.")
+        # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
+        if not (
+                envs.VLLM_ROCM_USE_AITER_LINEAR \
+                and envs.VLLM_ROCM_USE_AITER
+        ):
+            return (False, "AiterScaledMMLinearKernel is disabled. " +
+                    "Enable by setting `VLLM_ROCM_USE_AITER=1` " +
+                    "and `VLLM_ROCM_USE_AITER_LINEAR=1`. " +
+                    "`VLLM_ROCM_USE_AITER_LINEAR` default is True.")
+
+        if not c.input_symmetric:
+            return (False,
+                    "AiterScaledMMLinearKernel only supports symmetric " +
+                    "quantization.")
+        return True, None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        super().process_weights_after_loading(layer)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """
+        `AiterScaledMMLinearKernel` implements a fused version of
+        `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
+        where scale_a * a and scale_b * b are implemented using numpy-style
+        broadcasting.
+        Currently only support per-tensor-per-tensor GEMM
+        and per-token-per-channel GEMM through AITER
+        w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
+        AITER block scaled GEMM and mix-precision GEMM.
+        """
+        w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
+
+        # ops.scaled_int8_quant supports both dynamic and static quant:
+        # * dynamic, i_s is None and x_s computed from x.
+        # * static, i_s is scalar and x_s is i_s.
+        symmetric = azp_adj is None
+        assert symmetric, ("AiterScaledMMLinearKernel only supports"
+                           " symmetric quantization.")
+        x_q, x_s, x_zp = ops.scaled_int8_quant(x,
+                                               i_s,
+                                               i_zp,
+                                               symmetric=symmetric)
+
+        assert x_zp is None, ("AiterScaledMMLinearKernel only supports"
+                              " symmetric quantization.")
+        out_dtype = x.dtype
+
+        assert (w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0)
+        assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
+        assert bias is None or bias.shape[0] == w_q.shape[
+            1] and bias.dtype == out_dtype
+
+        m = x_q.shape[0]  # a
+        n = w_q.shape[1]  # b
+
+        per_tensor_scale_a = (x_s.numel() == 1)
+        per_tensor_scale_b = (w_s.numel() == 1)
+        per_token_scale_a = (x_s.numel() == m)
+        per_channel_scale_b = (w_s.numel() == n)
+
+        # @TODO:
+        # Maybe broadcast the per-tensor-scale into per-channel-scale
+        # if one of the scale is a per-channel-scale.
+        # For now, it only supports:
+        # - per-tensor-per-tensor a8w8 scaled GEMM, and
+        # - per-token-per-channel a8w8 scaled GEMM
+        assert ((per_tensor_scale_a and per_tensor_scale_b)
+                or (per_token_scale_a and per_channel_scale_b)), (
+                    "Currently only support per-tensor-per-tensor GEMM " +
+                    " and per-token-per-channel GEMM through AITER"
+                    " w8a8 scaled gemm. `AiterScaledMMLinearKernel` " +
+                    "does not support AITER block scaled GEMM.")
+
+        from aiter import gemm_a8w8_CK
+
+        # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
+        # a to be [M, K]
+        # b to be [N, K]
+        # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
+        return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype)
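
For reference, the docstring's fused expression can be spelled out as an unfused computation. The sketch below mirrors the tensor names used in apply_weights but is illustrative only: it assumes x_q is [M, K] int8, w_q is [K, N] int8, and that bias, when present, is added after the scales are applied (an assumption; the commit does not spell out bias ordering).

    from typing import Optional

    import torch

    def unfused_w8a8_reference(x_q: torch.Tensor, w_q: torch.Tensor,
                               x_s: torch.Tensor, w_s: torch.Tensor,
                               bias: Optional[torch.Tensor] = None,
                               out_dtype: torch.dtype = torch.bfloat16):
        # x_s is a scalar (per-tensor) or has M elements (per-token);
        # w_s is a scalar (per-tensor) or has N elements (per-channel).
        acc = torch.mm(x_q.to(torch.float32), w_q.to(torch.float32))  # [M, N]
        acc = acc * x_s.float().reshape(-1, 1) * w_s.float().reshape(1, -1)
        if bias is not None:
            acc = acc + bias.float()
        return acc.to(out_dtype)

The fused kernel performs the same dequantize-and-accumulate in one pass; note that it receives the weight transposed (w_q.t(), i.e. [N, K]) because, as the comment in apply_weights states, gemm_a8w8_CK expects its b operand in [N, K] layout.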
