
Commit e3d0a1d

[Quantization] [AMD] Add support for running DeepSeek int8 w8a8 MoE on ROCm (#17558)
Signed-off-by: Randall Smith <[email protected]>
1 parent: d47b605

File tree

2 files changed: +29 -3 lines changed

  vllm/_custom_ops.py
  vllm/model_executor/layers/quantization/utils/int8_utils.py

vllm/_custom_ops.py

Lines changed: 2 additions & 2 deletions
@@ -559,15 +559,15 @@ def cutlass_scaled_mm(a: torch.Tensor,
     scale_a.shape * [1, 128] == a.shape
     scale_b.shape * [128, 128] == b.shape
     """
-    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
     assert bias is None or bias.shape[0] == b.shape[
         1] and bias.dtype == out_dtype
 
     m = a.shape[0]
     n = b.shape[1]
 
-    if current_platform.is_rocm():
+    cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
+    if current_platform.is_rocm() or not cutlass_compatible_b:
         triton_scaled_mm_module = importlib.import_module(
             "vllm.model_executor.layers.quantization.compressed_tensors."
             "triton_scaled_mm")

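The change above turns a hard assertion into a fallback condition: CUTLASS scaled-mm kernels need both dimensions of b to be multiples of 16, and inputs that fail that check, or any call on ROCm, now route to the Triton scaled-mm kernel instead of asserting. A minimal standalone sketch of that dispatch decision, using a hypothetical helper name rather than the vLLM code itself:

import torch

def use_triton_fallback(b: torch.Tensor, is_rocm: bool) -> bool:
    # Hypothetical helper mirroring the condition added in the diff:
    # CUTLASS scaled-mm requires both dims of b to be multiples of 16;
    # anything else (or a ROCm build) takes the Triton path.
    cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    return is_rocm or not cutlass_compatible_b

# A 250 x 4096 weight is not 16-aligned in dim 0, so it would use the
# Triton kernel even on CUDA; a 256 x 4096 weight would use CUTLASS.
print(use_triton_fallback(torch.empty(250, 4096), is_rocm=False))  # True
print(use_triton_fallback(torch.empty(256, 4096), is_rocm=False))  # False
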
vllm/model_executor/layers/quantization/utils/int8_utils.py

Lines changed: 27 additions & 1 deletion
@@ -85,6 +85,32 @@ def block_dequant(
     return x_dq_block
 
 
+if current_platform.is_rocm():
+    from triton.language import core
+
+    # NOTE: This can be removed when hip.libdevice.round() is available.
+    @core.extern
+    def round_f32(arg0, _builder=None):
+        return core.extern_elementwise("",
+                                       "", [arg0], {
+                                           (core.dtype("fp32"), ):
+                                           ("llvm.round", core.dtype("fp32")),
+                                           (core.dtype("fp64"), ):
+                                           ("llvm.round", core.dtype("fp64")),
+                                       },
+                                       is_pure=True,
+                                       _builder=_builder)
+
+    @triton.jit
+    def round_int8(x):
+        return round_f32(x).to(tl.int8)
+else:
+
+    @triton.jit
+    def round_int8(x):
+        return tl.extra.cuda.libdevice.round(x).to(tl.int8)
+
+
 @triton.jit
 def _per_token_quant_int8(
     x_ptr,
@@ -106,7 +132,7 @@ def _per_token_quant_int8(
     absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
     scale_x = absmax / 127
     x_q = x * (127 / absmax)
-    x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
+    x_q = round_int8(x_q)
 
     tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
     tl.store(scale_ptr + row_id, scale_x)

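The new round_int8 shim picks the same rounding semantics on both back ends: on CUDA it keeps calling tl.extra.cuda.libdevice.round, while on ROCm it maps llvm.round through a Triton extern until hip.libdevice.round() exists. For illustration only, the per-row math that _per_token_quant_int8 performs can be sketched in plain PyTorch (hypothetical reference function, not code from this commit; note that torch.round breaks ties to even, whereas llvm.round rounds ties away from zero, so results may differ at exact .5 boundaries):

import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # Reference per-token (per-row) symmetric int8 quantization mirroring
    # the Triton kernel: scale = absmax / 127, x_q = round(x / scale).
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10)
    scale = absmax / 127
    x_q = torch.round(x * (127 / absmax)).to(torch.int8)
    return x_q, scale.squeeze(-1)

x = torch.randn(4, 8)
x_q, scale = per_token_quant_int8_ref(x)
print(x_q.dtype, scale.shape)  # torch.int8 torch.Size([4])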