
Commit 04a3f7b

mgoin authored and rshaw@neuralmagic.com committed
[Kernel] Expand FP8 support to Ampere GPUs using FP8 Marlin (vllm-project#5975)
1 parent 55c44bc commit 04a3f7b

File tree

11 files changed: +1587 −44 lines

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -171,6 +171,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
   "csrc/quantization/gptq_marlin/gptq_marlin.cu"
   "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
+  "csrc/quantization/fp8/fp8_marlin.cu"
   "csrc/custom_all_reduce.cu"
   "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
   "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"

csrc/ops.h

Lines changed: 5 additions & 0 deletions
@@ -93,6 +93,11 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,
                                  int64_t num_bits);

+torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
+                              torch::Tensor& b_scales, torch::Tensor& workspace,
+                              int64_t num_bits, int64_t size_m, int64_t size_n,
+                              int64_t size_k);
+
 bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);

 void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,

csrc/quantization/fp8/fp8_marlin.cu

Lines changed: 1308 additions & 0 deletions
Large diffs are not rendered by default.
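The 1,308-line CUDA kernel itself is not rendered in this view. Conceptually, FP8 Marlin is a weight-only (W8A16) GEMM: the FP8 weights are dequantized to the activation dtype on the fly and multiplied against fp16/bf16 activations using per-channel scales. A slow pure-PyTorch sketch of that numeric behaviour (the function name is illustrative, not part of the commit):

import torch

def fp8_w8a16_reference(a: torch.Tensor, w_fp8: torch.Tensor,
                        scales: torch.Tensor) -> torch.Tensor:
    # a: (M, K) fp16/bf16 activations; w_fp8: (K, N) float8_e4m3fn weights;
    # scales: per-output-channel scales broadcastable to (1, N).
    w = w_fp8.to(a.dtype) * scales.to(a.dtype)  # dequantize the weights only
    return a @ w  # the GEMM itself runs in the activation dtype (W8A16)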

csrc/torch_bindings.cpp

Lines changed: 4 additions & 0 deletions
@@ -137,6 +137,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("gptq_marlin_repack", &gptq_marlin_repack);
   ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);

+  // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
+  ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
+  ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
+
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
   // quantization.
   ops.def(
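Once the extension is rebuilt, the op registered above becomes dispatchable from Python under the torch.ops._C namespace; the thin wrapper added to vllm/_custom_ops.py further down calls it directly. A quick sanity check, assuming a built installation of this branch:

import torch
from vllm import _custom_ops as ops  # importing vllm loads the compiled _C extension

assert hasattr(torch.ops._C, "fp8_marlin_gemm")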

docs/source/quantization/fp8.rst

Lines changed: 2 additions & 1 deletion
@@ -4,7 +4,8 @@ FP8
 ==================

 vLLM supports FP8 (8-bit floating point) weight and activation quantization using hardware acceleration on GPUs such as Nvidia H100 and AMD MI300x.
-Currently, only Hopper and Ada Lovelace GPUs are supported.
+Currently, only Hopper and Ada Lovelace GPUs are officially supported for W8A8.
+Ampere GPUs are supported for W8A16 (weight-only FP8) utilizing Marlin kernels.
 Quantization of models with FP8 allows for a 2x reduction in model memory requirements and up to a 1.6x improvement in throughput with minimal impact on accuracy.

 Please visit the HF collection of `quantized FP8 checkpoints of popular LLMs ready to use with vLLM <https://huggingface.co/collections/neuralmagic/fp8-llms-for-vllm-666742ed2b78b7ac8df13127>`_.
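In practice this means FP8 checkpoints (or on-the-fly FP8 quantization of an fp16 checkpoint) can now also be used on A100-class hardware. A minimal usage sketch; the model name is illustrative and not part of this diff:

from vllm import LLM

# On Ada/Hopper this uses the W8A8 path; on Ampere the same flag now falls
# back to the FP8 Marlin weight-only (W8A16) kernels added by this commit.
llm = LLM(model="facebook/opt-125m", quantization="fp8")
print(llm.generate("Hello, my name is")[0].outputs[0].text)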

docs/source/quantization/supported_hardware.rst

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ Implementation Volta Turing Ampere Ada Hopper AMD GPU Intel GPU x86
 AQLM         ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
 AWQ          ❌ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
 DeepSpeedFP  ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
-FP8          ❌ ❌ ❌ ✅ ✅ ❌ ❌ ❌ ❌ ❌
+FP8          ❌ ❌ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
 Marlin       ❌ ❌ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
 GPTQ         ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
 SqueezeLLM   ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌

tests/kernels/test_marlin_gemm.py

Lines changed: 84 additions & 4 deletions
@@ -9,15 +9,16 @@
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.gptq_marlin import (
     GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
-    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
+    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS,
+    marlin_permute_scales)
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
     GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
 from vllm.model_executor.layers.quantization.utils.marlin_perms import (
     marlin_perm)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     MarlinWorkspace, compute_max_diff, is_marlin_supported, marlin_24_quantize,
-    marlin_quantize, marlin_weights)
+    marlin_quantize, marlin_weights, pack_fp8_to_int32)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     gptq_pack, quantize_weights, sort_weights)

@@ -43,9 +44,11 @@
     (67, 13, 11),
 ]

+DTYPES = [torch.float16, torch.bfloat16]

-def rand_data(shape):
-    return torch.randn(shape, dtype=torch.half, device="cuda")
+
+def rand_data(shape, dtype=torch.float16):
+    return torch.randn(shape, dtype=dtype, device="cuda")


 @pytest.mark.skipif(not is_marlin_supported(),
@@ -222,3 +225,80 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
     print("max_diff = {}".format(max_diff))

     assert max_diff < 0.04
+
+
+@pytest.mark.skipif(not is_marlin_supported(),
+                    reason="Marlin is not supported on this GPU type.")
+@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
+@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
+@pytest.mark.parametrize("num_bits", [8])
+@pytest.mark.parametrize("group_size", [-1])
+@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
+@pytest.mark.parametrize("dtype", DTYPES)
+def test_fp8_marlin_gemm(
+    k_chunk,
+    n_chunk,
+    num_bits,
+    group_size,
+    mnk_factors,
+    dtype,
+):
+    m_factor, n_factor, k_factor = mnk_factors
+
+    size_m = m_factor
+    size_k = k_chunk * k_factor
+    size_n = n_chunk * n_factor
+
+    print(f"MNK = {size_m} {size_n} {size_k}")
+    print(f"groupsize = {group_size}")
+
+    a_input = rand_data((size_m, size_k), dtype=dtype)
+    b_weight = rand_data((size_k, size_n), dtype=dtype)
+
+    # WEIGHTS
+    fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None)
+    # Repack weights to gptq format (packed int32 elements)
+    packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)
+    # Repack weights to marlin format
+    marlin_qweight = ops.gptq_marlin_repack(
+        b_q_weight=packed_gptq_qweight,
+        perm=torch.empty(0, dtype=torch.int, device="cuda"),
+        size_k=size_k,
+        size_n=size_n,
+        num_bits=8,
+    )
+
+    # WEIGHT SCALES
+    # Currently Marlin doesn't support per-tensor scales, so we
+    # expand it to channelwise
+    scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
+    # Permute scales
+    marlin_scales = marlin_permute_scales(
+        s=scales,
+        size_k=size_k,
+        size_n=size_n,
+        group_size=-1,
+        num_bits=8,
+    )
+
+    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
+                                GPTQ_MARLIN_MAX_PARALLEL)
+
+    output = ops.fp8_marlin_gemm(
+        a=a_input,
+        b_q_weight=marlin_qweight,
+        b_scales=marlin_scales,
+        workspace=workspace.scratch,
+        num_bits=num_bits,
+        size_m=a_input.shape[0],
+        size_n=b_weight.shape[1],
+        size_k=a_input.shape[1],
+    )
+    output_ref = torch.matmul(a_input, b_weight)
+
+    torch.cuda.synchronize()
+
+    max_diff = compute_max_diff(output, output_ref)
+    print("max_diff = {}".format(max_diff))
+
+    assert max_diff < 0.04
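The new test leans on pack_fp8_to_int32 (imported from marlin_utils above; that file's diff is not part of this excerpt) to get the FP8 weights into the int32-packed layout that gptq_marlin_repack expects. A minimal sketch of that packing idea, assuming four FP8 (e4m3) bytes are packed little-endian into each int32 along the K dimension; the helper's actual implementation may differ:

import torch

def pack_fp8_to_int32_sketch(fp8_tensor: torch.Tensor) -> torch.Tensor:
    # Pack groups of four FP8 values along dim 0 into one int32, mirroring
    # the GPTQ-style packed layout consumed by gptq_marlin_repack.
    assert fp8_tensor.dtype == torch.float8_e4m3fn
    assert fp8_tensor.shape[0] % 4 == 0
    byte_view = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]).view(torch.uint8)
    return (byte_view[:, 0].to(torch.int32)
            | (byte_view[:, 1].to(torch.int32) << 8)
            | (byte_view[:, 2].to(torch.int32) << 16)
            | (byte_view[:, 3].to(torch.int32) << 24)).contiguous()

# Example: a (K, N) = (8, 2) FP8 weight packs into a (2, 2) int32 tensor.
w = torch.randn(8, 2).to(torch.float8_e4m3fn)
print(pack_fp8_to_int32_sketch(w).shape)  # torch.Size([2, 2])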

tests/quantization/test_fp8.py

Lines changed: 14 additions & 5 deletions
@@ -7,7 +7,7 @@

 from tests.nm_utils.utils_skip import should_skip_test_group
 from tests.quantization.utils import is_quant_method_supported
-from vllm._custom_ops import scaled_fp8_quant
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod

 if should_skip_test_group(group_name="TEST_QUANTIZATION"):
@@ -40,7 +40,16 @@ def test_load_fp16_model(vllm_runner) -> None:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         fc1 = model.model.decoder.layers[0].fc1
         assert isinstance(fc1.quant_method, Fp8LinearMethod)
-        assert fc1.weight.dtype == torch.float8_e4m3fn
+
+        capability = torch.cuda.get_device_capability()
+        capability = capability[0] * 10 + capability[1]
+        if capability >= 89:
+            # For GPUs with hardware support, we keep weights in fp8
+            assert fc1.weight.dtype == torch.float8_e4m3fn
+        else:
+            # For GPUs without hardware support, we pack the fp8 weights
+            # for weight-only quantization using Marlin kernels
+            assert fc1.weight.dtype == torch.int32


 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
@@ -68,19 +77,19 @@ def per_tensor_dequantize(tensor, inv_scale, dtype):
     x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)

     # Dynamic quantization
-    ref_y, inv_scale = scaled_fp8_quant(x, None)
+    ref_y, inv_scale = ops.scaled_fp8_quant(x, None)
     ref_y = per_tensor_dequantize(ref_y, inv_scale, dtype)

     # Reference dynamic quantizaton
     y = quantize_ref(x, inv_scale)
     assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))

     # Static quantization
-    y, _ = scaled_fp8_quant(x, inv_scale)
+    y, _ = ops.scaled_fp8_quant(x, inv_scale)
     assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))

     # Padding
-    y, _ = scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
+    y, _ = ops.scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
     assert y.shape[0] == 17
     assert torch.allclose(
         ref_y,
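The capability check in the first hunk above is what routes between the two FP8 paths. torch.cuda.get_device_capability() returns (major, minor), so 80/86 are Ampere, 89 is Ada Lovelace, and 90 is Hopper. The same decision as a small sketch (the helper name is illustrative):

import torch

def fp8_weight_path() -> str:
    # SM 8.9 (Ada) and 9.0 (Hopper) have native FP8 support, so weights stay in
    # float8_e4m3fn (W8A8); SM 8.0/8.6 (Ampere) get the FP8 Marlin weight-only
    # path, where the weights are repacked into int32 (W8A16).
    major, minor = torch.cuda.get_device_capability()
    return "W8A8 (native fp8)" if major * 10 + minor >= 89 else "W8A16 (fp8 Marlin)"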

vllm/_custom_ops.py

Lines changed: 9 additions & 0 deletions
@@ -271,6 +271,15 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                                          size_k, is_k_full)


+# fp8 marlin
+def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+                    b_scales: torch.Tensor, workspace: torch.Tensor,
+                    num_bits: int, size_m: int, size_n: int,
+                    size_k: int) -> torch.Tensor:
+    return torch.ops._C.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
+                                        num_bits, size_m, size_n, size_k)
+
+
 # fp8
 def scaled_fp8_quant(
     input: torch.Tensor,
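For context, a hypothetical helper (not part of this commit) showing how a linear layer could call the new wrapper once its weight has already been repacked into the Marlin FP8 format; the kernel expects a 2-D activation, so any leading batch/sequence dimensions are flattened first:

import torch
from vllm import _custom_ops as ops

def fp8_marlin_linear(x: torch.Tensor, marlin_qweight: torch.Tensor,
                      marlin_scales: torch.Tensor, workspace: torch.Tensor,
                      out_features: int) -> torch.Tensor:
    # Collapse any leading dimensions so the GEMM sees a (size_m, size_k) input.
    x_2d = x.reshape(-1, x.shape[-1])
    out = ops.fp8_marlin_gemm(a=x_2d, b_q_weight=marlin_qweight,
                              b_scales=marlin_scales, workspace=workspace,
                              num_bits=8, size_m=x_2d.shape[0],
                              size_n=out_features, size_k=x_2d.shape[-1])
    return out.reshape(*x.shape[:-1], out_features)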
