
Commit 1186dfc

charlifu authored and Alvant committed
[Feature][Hardware][Amd] Add fp8 Linear Layer for Rocm (vllm-project#7210)
Signed-off-by: Alvant <[email protected]>
1 parent c919da9 commit 1186dfc

7 files changed: +164 -49 lines changed


csrc/quantization/fp8/common.cu
Lines changed: 32 additions & 17 deletions

@@ -9,6 +9,18 @@
 
 #include "../../reduction_utils.cuh"
 
+#ifndef USE_ROCM
+using FP8_TYPE = c10::Float8_e4m3fn;
+C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
+    std::numeric_limits<FP8_TYPE>::max();
+#else
+#include "amd/hip_float8.h"
+using FP8_TYPE = c10::Float8_e4m3fnuz;
+// Using the default max value from pytorch (240.0) will cause accuracy
+// issue when running dynamic quantization. Here use 224.0f for rocm.
+constexpr auto FP8_E4M3_MAX = 224.0f;
+#endif
+
 namespace vllm {
 
 __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
@@ -21,11 +33,9 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
   return old;
 }
 
-#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
-
 template <bool is_scale_inverted>
-__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
-    float const val, float const scale) {
+__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
+                                                          float const scale) {
   float x = 0.0f;
   if constexpr (is_scale_inverted) {
     x = val * scale;
@@ -34,7 +44,13 @@ __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
   }
 
   float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
+#ifndef USE_ROCM
   return static_cast<c10::Float8_e4m3fn>(r);
+#else
+  // Use hardware cvt instruction for fp8 on rocm
+  return c10::Float8_e4m3fnuz(hip_fp8(r).data,
+                              c10::Float8_e4m3fnuz::from_bits());
+#endif
 }
 
 // Compute the absolute maximum m of the input tensor and store
@@ -74,8 +90,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   // Finally, since cache[0] contains the maximum for this thread block,
   // atomically write the max to the target location
   if (threadIdx.x == 0) {
-    atomicMaxFloat(scale,
-                   cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
+    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
   }
 }
 
@@ -88,10 +103,10 @@ struct __align__(8) vec4_t {
 };
 
 typedef struct __align__(4) {
-  c10::Float8_e4m3fn x;
-  c10::Float8_e4m3fn y;
-  c10::Float8_e4m3fn z;
-  c10::Float8_e4m3fn w;
+  FP8_TYPE x;
+  FP8_TYPE y;
+  FP8_TYPE z;
+  FP8_TYPE w;
 }
 float8x4_t;
 
@@ -124,7 +139,7 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
 }
 
 template <typename scalar_t, bool is_scale_inverted>
-__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
+__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                           scalar_t const* __restrict__ input,
                                           float const scale,
                                           int64_t const num_elems,
@@ -160,7 +175,7 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
 }
 
 template <typename scalar_t>
-__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
+__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
                                         const scalar_t* __restrict__ input,
                                         const float* __restrict__ scale,
                                         int64_t num_elems) {
@@ -175,7 +190,7 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
 
 template <typename scalar_t>
 __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
-    c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
+    FP8_TYPE* __restrict__ out, float* __restrict__ scale,
     scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
     const int hidden_size) {
   float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
@@ -184,7 +199,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
   int const token_idx = blockIdx.x;
 
   scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
-  c10::Float8_e4m3fn* __restrict__ token_output = &out[token_idx * hidden_size];
+  FP8_TYPE* __restrict__ token_output = &out[token_idx * hidden_size];
 
   // For vectorization, token_input and token_output pointers need to be
   // aligned at 8-byte and 4-byte addresses respectively.
@@ -241,7 +256,7 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
         vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
+            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
             scale.data_ptr<float>(), num_elems);
       });
 }
@@ -261,7 +276,7 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
         vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
             scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
         vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
+            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
             scale.data_ptr<float>(), num_elems);
       });
 }
@@ -284,7 +299,7 @@ void dynamic_per_token_scaled_fp8_quant(
       input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
         vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
             <<<grid, block, 0, stream>>>(
-                out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
+                out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
                 input.data_ptr<scalar_t>(),
                 scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                 hidden_size);
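
Note (illustrative, not part of the diff): the per-element behavior of scaled_fp8_conversion above is a scale (or inverse-scale) multiply, a clamp to +/-FP8_E4M3_MAX, and a cast to the platform's fp8 type. A rough Python reference of that math, assuming a recent PyTorch with fp8 dtypes (rounding of the final cast may differ slightly from the ROCm hardware cvt path):

    import torch

    def scaled_fp8_conversion_ref(val: torch.Tensor, scale: torch.Tensor,
                                  is_scale_inverted: bool, use_rocm: bool) -> torch.Tensor:
        # Mirror of the kernel's math: scale, clamp to the fp8 range, cast.
        x = val * scale if is_scale_inverted else val / scale
        if use_rocm:
            # ROCm branch in the diff caps at 224.0 and uses e4m3fnuz.
            fp8_max, fp8_dtype = 224.0, torch.float8_e4m3fnuz
        else:
            fp8_dtype = torch.float8_e4m3fn
            fp8_max = torch.finfo(fp8_dtype).max  # 448.0
        return x.clamp(-fp8_max, fp8_max).to(fp8_dtype)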

tests/kernels/quant_utils.py
Lines changed: 22 additions & 11 deletions

@@ -2,6 +2,13 @@
 
 import torch
 
+from vllm.utils import is_hip
+
+# Using the default value (240.0) from pytorch will cause accuracy
+# issue on dynamic quantization models. Here use 224.0 for rocm.
+ROCM_FP8_MAX = 224.0
+FP8_DTYPE = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+
 
 def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
     return torch.as_tensor(x, dtype=torch.float32, device='cuda')
@@ -11,13 +18,15 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
                                 scale_ub: Optional[torch.tensor] = None) \
         -> Tuple[torch.tensor, torch.tensor]:
 
-    assert quant_dtype in [torch.int8, torch.float8_e4m3fn]
+    assert quant_dtype in [torch.int8, FP8_DTYPE]
     if scale_ub is not None:
-        assert quant_dtype == torch.float8_e4m3fn
+        assert quant_dtype == FP8_DTYPE
 
     qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
             else torch.finfo(quant_dtype)
-    qtype_max = as_float32_tensor(qtype_traits.max)
+    qtype_traits_max = ROCM_FP8_MAX if is_hip() else qtype_traits.max
+    qtype_traits_min = -ROCM_FP8_MAX if is_hip() else qtype_traits.min
+    qtype_max = as_float32_tensor(qtype_traits_max)
     s_1 = as_float32_tensor(1.0)
     s_512 = as_float32_tensor(512.0)
 
@@ -37,15 +46,15 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
         iscales = as_float32_tensor(s_1 / scales)
         torch_out = as_float32_tensor(x) * iscales
         torch_out = torch_out.round()
-        torch_out = torch_out.clamp(qtype_traits.min,
-                                    qtype_traits.max).to(quant_dtype)
+        torch_out = torch_out.clamp(qtype_traits_min,
+                                    qtype_traits_max).to(quant_dtype)
     else:
-        assert quant_dtype == torch.float8_e4m3fn
+        assert quant_dtype == FP8_DTYPE
         min_scaling_factor = s_1 / (qtype_max * s_512)
         scales = scales.clamp(min=min_scaling_factor)
         torch_out = as_float32_tensor(x) / scales
-        torch_out = torch_out.clamp(qtype_traits.min,
-                                    qtype_traits.max).to(quant_dtype)
+        torch_out = torch_out.clamp(qtype_traits_min,
+                                    qtype_traits_max).to(quant_dtype)
 
     return torch_out, scales
 
@@ -56,8 +65,10 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
 def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
         -> Tuple[torch.tensor, torch.tensor]:
 
-    fp8_traits = torch.finfo(torch.float8_e4m3fn)
-    fp8_max = as_float32_tensor(fp8_traits.max)
+    fp8_traits = torch.finfo(FP8_DTYPE)
+    fp8_traits_max = ROCM_FP8_MAX if is_hip() else fp8_traits.max
+    fp8_traits_min = -ROCM_FP8_MAX if is_hip() else fp8_traits.min
+    fp8_max = as_float32_tensor(fp8_traits_max)
     one = as_float32_tensor(1.0)
 
     # For fp8, in order to match the cuda kernel output, we have to do exactly
@@ -68,5 +79,5 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
     ref_scale = x_max / fp8_max
     ref_iscale = one / ref_scale
     ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
-        fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn)
+        fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
    return ref_out, ref_scale.view((1, ))
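
Note (illustrative, not part of the tests): ROCM_FP8_MAX = 224.0 is intentionally below the dtype's own limit. PyTorch reports 240.0 as the maximum of float8_e4m3fnuz, and the diff comment explains that using 240.0 hurts accuracy under dynamic quantization, so the reference helpers clamp to 224.0 on ROCm instead. A quick check:

    import torch

    # float8_e4m3fnuz (the ROCm/HIP fp8 format) tops out at 240.0 in torch,
    # while float8_e4m3fn (the CUDA format) tops out at 448.0.
    print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0
    print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0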

tests/kernels/test_fp8_quant.py
Lines changed: 3 additions & 3 deletions

@@ -2,7 +2,8 @@
 import torch
 
 import vllm._custom_ops as ops
-from tests.kernels.quant_utils import (ref_dynamic_per_tensor_fp8_quant,
+from tests.kernels.quant_utils import (FP8_DTYPE,
+                                       ref_dynamic_per_tensor_fp8_quant,
                                        ref_dynamic_per_token_quant)
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -31,8 +32,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
 
     scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \
         if scale_ub else None
-    ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn,
-                                                      scale_ub)
+    ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub)
     ops_out, ops_scales = ops.scaled_fp8_quant(x,
                                                scale_ub=scale_ub,
                                                use_per_token_if_dynamic=True)

vllm/_custom_ops.py
Lines changed: 4 additions & 1 deletion

@@ -369,9 +369,12 @@ def scaled_fp8_quant(
     # This code assumes batch_dim and num_tokens are flattened
     assert (input.ndim == 2)
     shape: Union[Tuple[int, int], torch.Size] = input.shape
+    # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
+    out_dtype: torch.dtype = torch.float8_e4m3fnuz if vllm.utils.is_hip() \
+        else torch.float8_e4m3fn
     if num_token_padding:
         shape = (max(num_token_padding, input.shape[0]), shape[1])
-    output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)
+    output = torch.empty(shape, device=input.device, dtype=out_dtype)
 
     if scale is None:
         if use_per_token_if_dynamic:
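
Note (illustrative, not part of the diff): with this change, callers of ops.scaled_fp8_quant get an e4m3fnuz tensor on ROCm and an e4m3fn tensor on CUDA without any API change. A usage sketch, assuming a CUDA or ROCm device is available:

    import torch
    import vllm._custom_ops as ops
    from vllm.utils import is_hip

    x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
    out, scale = ops.scaled_fp8_quant(x)  # dynamic per-tensor scale

    # The output dtype now tracks the platform instead of being hard-coded.
    expected = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
    assert out.dtype == expected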

vllm/config.py
Lines changed: 1 addition & 1 deletion

@@ -240,7 +240,7 @@ def _parse_quant_hf_config(self):
 
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
-        rocm_supported_quantization = ["gptq", "squeezellm"]
+        rocm_supported_quantization = ["gptq", "squeezellm", "fp8"]
         optimized_quantization_methods = [
             "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
             "fbgemm_fp8", "compressed_tensors", "compressed-tensors"

vllm/model_executor/layers/quantization/fp8.py
Lines changed: 55 additions & 9 deletions

@@ -20,10 +20,11 @@
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
     create_per_tensor_scale_param, cutlass_fp8_supported,
-    per_tensor_dequantize, requantize_with_max_scale)
+    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
+    requantize_with_max_scale)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import print_warning_once
+from vllm.utils import is_hip, print_warning_once
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
@@ -120,6 +121,9 @@ def __init__(self, quant_config: Fp8Config):
         capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
         self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
+        # Disable marlin for rocm
+        if is_hip():
+            self.use_marlin = False
 
     def create_weights(
         self,
@@ -168,6 +172,8 @@ def create_weights(
             scale = create_per_tensor_scale_param(output_partition_sizes,
                                                   **extra_weight_attrs)
             layer.register_parameter("input_scale", scale)
+        else:
+            layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
         # If checkpoint not serialized fp8, quantize the weights.
@@ -202,9 +208,23 @@ def process_weights_after_loading(self, layer: Module) -> None:
             # requantize the logical shards as a single weight.
             else:
                 # Dequant -> Quant with max scale so we can run per tensor.
+                weight = layer.weight
+                weight_scale = layer.weight_scale
+
+                # If rocm, use float8_e4m3fnuz.
+                if is_hip():
+                    weight, weight_scale, input_scale = \
+                        normalize_e4m3fn_to_e4m3fnuz(
+                            weight=weight,
+                            weight_scale=weight_scale,
+                            input_scale=layer.input_scale)
+                    if input_scale is not None:
+                        layer.input_scale = Parameter(input_scale,
+                                                      requires_grad=False)
+
                 weight_scale, weight = requantize_with_max_scale(
-                    weight=layer.weight,
-                    weight_scale=layer.weight_scale,
+                    weight=weight,
+                    weight_scale=weight_scale,
                     logical_widths=layer.logical_widths,
                 )
 
@@ -214,8 +234,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
             if self.quant_config.activation_scheme == "static":
                 layer.input_scale = Parameter(layer.input_scale.max(),
                                               requires_grad=False)
-            else:
-                layer.input_scale = None
 
         if self.use_marlin:
             prepare_fp8_layer_for_marlin(layer)
@@ -346,10 +364,12 @@ def process_weights_after_loading(self, layer: Module) -> None:
 
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
+            # If rocm, use float8_e4m3fnuz as dtype
+            fp8_dtype = torch.float8_e4m3fnuz \
+                if is_hip() else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data,
-                                          dtype=torch.float8_e4m3fn)
-            w2_weight = torch.empty_like(layer.w2_weight.data,
-                                         dtype=torch.float8_e4m3fn)
+                                          dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
@@ -393,6 +413,32 @@ def process_weights_after_loading(self, layer: Module) -> None:
                     layer.w13_input_scale.max(), requires_grad=False)
                 layer.w2_input_scale = torch.nn.Parameter(
                     layer.w2_input_scale.max(), requires_grad=False)
+            # If rocm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # Normalize the weights and scales
+                w13_weight, w13_weight_scale, w13_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w13_weight, layer.w13_weight_scale,
+                        layer.w13_input_scale)
+                w2_weight, w2_weight_scale, w2_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w2_weight, layer.w2_weight_scale,
+                        layer.w2_input_scale)
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                      requires_grad=False)
+                layer.w13_weight_scale = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False)
+                if w13_input_scale is not None:
+                    layer.w13_input_scale = torch.nn.Parameter(
+                        w13_input_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                     requires_grad=False)
+                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
+                                                           requires_grad=False)
+                if w2_input_scale is not None:
+                    layer.w2_input_scale = torch.nn.Parameter(
+                        w2_input_scale, requires_grad=False)
 
             # Fp8 moe kernel needs single weight scale for w13 per expert.
             # We take the max then dequant and requant each expert.
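
Note (illustrative, not part of the diff): normalize_e4m3fn_to_e4m3fnuz is imported above from w8a8_utils; its own diff is the seventh changed file and is not shown in this excerpt. Conceptually, converting an e4m3fn-quantized checkpoint to the ROCm e4m3fnuz format means reinterpreting the stored bytes and compensating in the scales. A sketch of that idea, assuming the hypothetical helper name below (not the commit's exact code):

    import torch
    from typing import Optional, Tuple

    def normalize_e4m3fn_to_e4m3fnuz_sketch(
            weight: torch.Tensor,
            weight_scale: torch.Tensor,
            input_scale: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        assert weight.dtype == torch.float8_e4m3fn
        # Bit pattern 0x80 is negative zero in e4m3fn but NaN in e4m3fnuz,
        # so zero it out before reinterpreting the bytes (in place here).
        as_int8 = weight.view(torch.int8)
        as_int8[as_int8 == -128] = 0
        weight_fnuz = as_int8.view(torch.float8_e4m3fnuz)
        # The same bit pattern represents half the value in e4m3fnuz compared
        # to e4m3fn, so double the scales to keep dequantized values unchanged.
        weight_scale = weight_scale * 2.0
        if input_scale is not None:
            input_scale = input_scale * 2.0
        return weight_fnuz, weight_scale, input_scale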
