Merged
Changes from all commits
Commits (50)
214c6e3
add amd fp8 linear layer support
charlifu Jul 31, 2024
b6e2ebc
Merge branch 'main' into amd_fp8
charlifu Jul 31, 2024
9733ec5
fix scaled_mm
charlifu Aug 1, 2024
80d47c6
Merge branch 'main' into amd_fp8
charlifu Aug 1, 2024
8e6a502
fix header file
charlifu Aug 1, 2024
dd0eca8
add comments
charlifu Aug 6, 2024
3113196
Merge branch 'main' into amd_fp8
charlifu Aug 6, 2024
74ab95a
add comments
charlifu Aug 6, 2024
7163b84
fix linter
charlifu Aug 6, 2024
fdefb15
fix linter
charlifu Aug 6, 2024
5f26ac1
change padding size back to 17
charlifu Aug 6, 2024
23624f8
fix linter
charlifu Aug 6, 2024
bcc5287
fix Rocmplatform not found
charlifu Aug 6, 2024
f25e8d6
fix RocmPlatform not found
charlifu Aug 6, 2024
23802f3
address comments
charlifu Aug 6, 2024
937733e
fix linter
charlifu Aug 6, 2024
746607f
use is_hip
charlifu Aug 6, 2024
85642ef
remove import torch.version
charlifu Aug 6, 2024
4e4517b
fix linter
charlifu Aug 6, 2024
92b0640
address comments
charlifu Aug 7, 2024
bcfa571
add unit test
charlifu Aug 8, 2024
4d3c4bb
* fix linter
charlifu Aug 8, 2024
be12b71
Merge branch 'main' into amd_fp8
charlifu Aug 8, 2024
0e39e37
Disable marlin for rocm
charlifu Aug 8, 2024
ff19105
fix linter
charlifu Aug 8, 2024
374c958
fix linter
charlifu Aug 8, 2024
a00654a
modify comments
charlifu Aug 8, 2024
acbd97b
Use constexpr and using key word
charlifu Aug 8, 2024
cadbe32
fix FP8_E4M3_MAX not fount on nvidia
charlifu Aug 8, 2024
3932a3f
fix
charlifu Aug 8, 2024
fccc711
add auto
charlifu Aug 8, 2024
c3b022f
.
charlifu Aug 8, 2024
3ba8c06
fix clang
charlifu Aug 8, 2024
476f4e0
remove C10_HOST_DEVICE for rocm
charlifu Aug 9, 2024
b00aa19
Merge branch 'main' into amd_fp8
charlifu Aug 9, 2024
586fece
.
charlifu Aug 9, 2024
ad2ac93
Merge branch 'main' into amd_fp8
charlifu Aug 11, 2024
21b15c8
Add weight normalization for hip in fp8moe
charlifu Aug 13, 2024
b9580bc
fix linter
charlifu Aug 13, 2024
e4f7af2
Merge branch 'main' into amd_fp8
charlifu Aug 14, 2024
62cbd4d
fix naming in fp8moe
charlifu Aug 14, 2024
d9b1e27
fix linter
charlifu Aug 14, 2024
f8a8c8b
fix comment
charlifu Aug 14, 2024
19346d6
add support for dynamic quant of fp8moe
charlifu Aug 14, 2024
2f8fe27
* using normalize_e4m3fn_to_e4m3fnuz for the name of convert function
charlifu Aug 14, 2024
c56f31e
fix linter
charlifu Aug 14, 2024
492725f
Merge branch 'main' into amd_fp8
charlifu Aug 14, 2024
747d538
Merge branch 'main' into amd_fp8
charlifu Aug 15, 2024
843d8b6
Merge branch 'main' into amd_fp8
charlifu Aug 15, 2024
fb84ca6
Merge branch 'main' into amd_fp8
charlifu Aug 16, 2024
49 changes: 32 additions & 17 deletions csrc/quantization/fp8/common.cu
@@ -9,6 +9,18 @@

#include "../../reduction_utils.cuh"

#ifndef USE_ROCM
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
std::numeric_limits<FP8_TYPE>::max();
#else
#include "amd/hip_float8.h"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from pytorch (240.0) will cause accuracy
// issues when running dynamic quantization. Here we use 224.0f for rocm.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif

namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
@@ -21,11 +21,9 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
return old;
}

#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()

template <bool is_scale_inverted>
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
float const val, float const scale) {
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
float const scale) {
float x = 0.0f;
if constexpr (is_scale_inverted) {
x = val * scale;
@@ -34,7 +44,13 @@ __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
}

float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
#ifndef USE_ROCM
return static_cast<c10::Float8_e4m3fn>(r);
#else
// Use hardware cvt instruction for fp8 on rocm
return c10::Float8_e4m3fnuz(hip_fp8(r).data,
c10::Float8_e4m3fnuz::from_bits());
#endif
}

// Compute the absolute maximum m of the input tensor and store
@@ -74,8 +90,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
if (threadIdx.x == 0) {
atomicMaxFloat(scale,
cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
}
}

@@ -88,10 +103,10 @@ struct __align__(8) vec4_t {
};

typedef struct __align__(4) {
c10::Float8_e4m3fn x;
c10::Float8_e4m3fn y;
c10::Float8_e4m3fn z;
c10::Float8_e4m3fn w;
FP8_TYPE x;
FP8_TYPE y;
FP8_TYPE z;
FP8_TYPE w;
}
float8x4_t;

@@ -124,7 +139,7 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
}

template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
scalar_t const* __restrict__ input,
float const scale,
int64_t const num_elems,
@@ -160,7 +175,7 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
}

template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
const scalar_t* __restrict__ input,
const float* __restrict__ scale,
int64_t num_elems) {
@@ -175,7 +190,7 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,

template <typename scalar_t>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
FP8_TYPE* __restrict__ out, float* __restrict__ scale,
scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
const int hidden_size) {
float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
@@ -184,7 +199,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
int const token_idx = blockIdx.x;

scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
c10::Float8_e4m3fn* __restrict__ token_output = &out[token_idx * hidden_size];
FP8_TYPE* __restrict__ token_output = &out[token_idx * hidden_size];

// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
@@ -241,7 +256,7 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), num_elems);
});
}
@@ -261,7 +276,7 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), num_elems);
});
}
@@ -284,7 +299,7 @@ void dynamic_per_token_scaled_fp8_quant(
input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
input.data_ptr<scalar_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
hidden_size);
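
For reference, here is a minimal sketch (not part of this diff) of the clamp-and-cast that `scaled_fp8_conversion` now performs, assuming a PyTorch build that exposes both float8 dtypes. The 224.0 limit mirrors the ROCm `FP8_E4M3_MAX` defined above, while 448.0 is what `torch.finfo` reports for the e4m3fn type used on other platforms.

```python
# Minimal sketch of the clamp-and-cast done by scaled_fp8_conversion,
# assuming a PyTorch build that exposes both float8 dtypes.
import torch

def scaled_fp8_convert(val: torch.Tensor, scale: torch.Tensor,
                       rocm: bool) -> torch.Tensor:
    if rocm:
        fp8_dtype, fp8_max = torch.float8_e4m3fnuz, 224.0  # not finfo's 240.0
    else:
        fp8_dtype = torch.float8_e4m3fn
        fp8_max = float(torch.finfo(fp8_dtype).max)        # 448.0
    x = val / scale              # the kernel also supports a pre-inverted scale
    x = x.clamp(-fp8_max, fp8_max)
    return x.to(fp8_dtype)       # the ROCm path uses a hardware cvt instruction
```
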
33 changes: 22 additions & 11 deletions tests/kernels/quant_utils.py
@@ -2,6 +2,13 @@

import torch

from vllm.utils import is_hip

# Using the default value (240.0) from pytorch will cause accuracy
# issues on dynamic quantization models. Here we use 224.0 for rocm.
ROCM_FP8_MAX = 224.0
FP8_DTYPE = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn


def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
return torch.as_tensor(x, dtype=torch.float32, device='cuda')
@@ -11,13 +18,15 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
scale_ub: Optional[torch.tensor] = None) \
-> Tuple[torch.tensor, torch.tensor]:

assert quant_dtype in [torch.int8, torch.float8_e4m3fn]
assert quant_dtype in [torch.int8, FP8_DTYPE]
if scale_ub is not None:
assert quant_dtype == torch.float8_e4m3fn
assert quant_dtype == FP8_DTYPE

qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
else torch.finfo(quant_dtype)
qtype_max = as_float32_tensor(qtype_traits.max)
qtype_traits_max = ROCM_FP8_MAX if is_hip() else qtype_traits.max
qtype_traits_min = -ROCM_FP8_MAX if is_hip() else qtype_traits.min
qtype_max = as_float32_tensor(qtype_traits_max)
s_1 = as_float32_tensor(1.0)
s_512 = as_float32_tensor(512.0)

@@ -37,15 +46,15 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
iscales = as_float32_tensor(s_1 / scales)
torch_out = as_float32_tensor(x) * iscales
torch_out = torch_out.round()
torch_out = torch_out.clamp(qtype_traits.min,
qtype_traits.max).to(quant_dtype)
torch_out = torch_out.clamp(qtype_traits_min,
qtype_traits_max).to(quant_dtype)
else:
assert quant_dtype == torch.float8_e4m3fn
assert quant_dtype == FP8_DTYPE
min_scaling_factor = s_1 / (qtype_max * s_512)
scales = scales.clamp(min=min_scaling_factor)
torch_out = as_float32_tensor(x) / scales
torch_out = torch_out.clamp(qtype_traits.min,
qtype_traits.max).to(quant_dtype)
torch_out = torch_out.clamp(qtype_traits_min,
qtype_traits_max).to(quant_dtype)

return torch_out, scales

@@ -56,8 +65,10 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
-> Tuple[torch.tensor, torch.tensor]:

fp8_traits = torch.finfo(torch.float8_e4m3fn)
fp8_max = as_float32_tensor(fp8_traits.max)
fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = ROCM_FP8_MAX if is_hip() else fp8_traits.max
fp8_traits_min = -ROCM_FP8_MAX if is_hip() else fp8_traits.min
fp8_max = as_float32_tensor(fp8_traits_max)
one = as_float32_tensor(1.0)

# For fp8, in order to match the cuda kernel output, we have to do exactly
@@ -68,5 +79,5 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
ref_scale = x_max / fp8_max
ref_iscale = one / ref_scale
ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn)
fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
return ref_out, ref_scale.view((1, ))
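
A short sketch (an assumption-level illustration, not part of the PR) of the per-token scale computation these reference helpers implement; `fp8_max` is `ROCM_FP8_MAX` (224.0) under `is_hip()` and `finfo.max` otherwise.

```python
# Sketch of the per-token dynamic scale computation mirrored by
# ref_dynamic_per_token_quant; fp8_max is ROCM_FP8_MAX (224.0) on ROCm,
# torch.finfo(FP8_DTYPE).max otherwise.
import torch

def per_token_scales(x: torch.Tensor, fp8_max: float) -> torch.Tensor:
    absmax = x.abs().max(dim=-1, keepdim=True).values.to(torch.float32)
    scales = absmax / fp8_max
    # Same floor as the CUDA kernel: min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f)
    return scales.clamp(min=1.0 / (fp8_max * 512.0))
```
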
6 changes: 3 additions & 3 deletions tests/kernels/test_fp8_quant.py
@@ -2,7 +2,8 @@
import torch

import vllm._custom_ops as ops
from tests.kernels.quant_utils import (ref_dynamic_per_tensor_fp8_quant,
from tests.kernels.quant_utils import (FP8_DTYPE,
ref_dynamic_per_tensor_fp8_quant,
ref_dynamic_per_token_quant)

DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -31,8 +32,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,

scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \
if scale_ub else None
ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn,
scale_ub)
ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub)
ops_out, ops_scales = ops.scaled_fp8_quant(x,
scale_ub=scale_ub,
use_per_token_if_dynamic=True)
5 changes: 4 additions & 1 deletion vllm/_custom_ops.py
@@ -369,9 +369,12 @@ def scaled_fp8_quant(
# This code assumes batch_dim and num_tokens are flattened
assert (input.ndim == 2)
shape: Union[Tuple[int, int], torch.Size] = input.shape
# For rocm, the output fp8 dtype is torch.float8_e4m3fnuz
out_dtype: torch.dtype = torch.float8_e4m3fnuz if vllm.utils.is_hip() \
else torch.float8_e4m3fn
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)
output = torch.empty(shape, device=input.device, dtype=out_dtype)

if scale is None:
if use_per_token_if_dynamic:
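
With this change the output buffer allocated by `scaled_fp8_quant` is `float8_e4m3fnuz` on ROCm and `float8_e4m3fn` elsewhere. A small usage sketch, following the call pattern in tests/kernels/test_fp8_quant.py; the shapes and device string are illustrative only.

```python
# Usage sketch: dynamic per-token FP8 quantization through the custom op.
import torch
import vllm._custom_ops as ops
from vllm.utils import is_hip

x = torch.randn(32, 1024, dtype=torch.float16, device="cuda")
out, scales = ops.scaled_fp8_quant(x, use_per_token_if_dynamic=True)

expected_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
assert out.dtype == expected_dtype
```
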
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -240,7 +240,7 @@ def _parse_quant_hf_config(self):

def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = ["gptq", "squeezellm"]
rocm_supported_quantization = ["gptq", "squeezellm", "fp8"]
optimized_quantization_methods = [
"fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
"fbgemm_fp8", "compressed_tensors", "compressed-tensors"
64 changes: 55 additions & 9 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -20,10 +20,11 @@
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, apply_fp8_linear, convert_to_channelwise,
create_per_tensor_scale_param, cutlass_fp8_supported,
per_tensor_dequantize, requantize_with_max_scale)
normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
requantize_with_max_scale)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
from vllm.utils import is_hip, print_warning_once

ACTIVATION_SCHEMES = ["static", "dynamic"]

@@ -120,6 +121,9 @@ def __init__(self, quant_config: Fp8Config):
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
# Disable marlin for rocm
if is_hip():
self.use_marlin = False

def create_weights(
self,
@@ -168,6 +172,8 @@ def create_weights(
scale = create_per_tensor_scale_param(output_partition_sizes,
**extra_weight_attrs)
layer.register_parameter("input_scale", scale)
else:
layer.register_parameter("input_scale", None)

def process_weights_after_loading(self, layer: Module) -> None:
# If checkpoint not serialized fp8, quantize the weights.
@@ -202,9 +208,23 @@ def process_weights_after_loading(self, layer: Module) -> None:
# requantize the logical shards as a single weight.
else:
# Dequant -> Quant with max scale so we can run per tensor.
weight = layer.weight
weight_scale = layer.weight_scale

# If rocm, use float8_e4m3fnuz.
if is_hip():
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=weight_scale,
input_scale=layer.input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)

weight_scale, weight = requantize_with_max_scale(
weight=layer.weight,
weight_scale=layer.weight_scale,
weight=weight,
weight_scale=weight_scale,
logical_widths=layer.logical_widths,
)

@@ -214,8 +234,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
if self.quant_config.activation_scheme == "static":
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
else:
layer.input_scale = None

if self.use_marlin:
prepare_fp8_layer_for_marlin(layer)
@@ -346,10 +364,12 @@ def process_weights_after_loading(self, layer: Module) -> None:

# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
# If rocm, use float8_e4m3fnuz as dtype
fp8_dtype = torch.float8_e4m3fnuz \
if is_hip() else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data,
dtype=torch.float8_e4m3fn)
w2_weight = torch.empty_like(layer.w2_weight.data,
dtype=torch.float8_e4m3fn)
dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
@@ -393,6 +413,32 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.w13_input_scale.max(), requires_grad=False)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False)
# If rocm, normalize the weights and scales to e4m3fnuz
if is_hip():
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
layer.w13_weight, layer.w13_weight_scale,
layer.w13_input_scale)
w2_weight, w2_weight_scale, w2_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
layer.w2_weight, layer.w2_weight_scale,
layer.w2_input_scale)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(w13_weight,
requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(
w13_weight_scale, requires_grad=False)
if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter(
w13_input_scale, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight,
requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
requires_grad=False)
if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter(
w2_input_scale, requires_grad=False)

# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
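
Both the linear and MoE paths call `normalize_e4m3fn_to_e4m3fnuz` (from w8a8_utils) to convert OCP e4m3fn checkpoints to the ROCm fnuz format; the helper itself is not shown in this diff. The sketch below illustrates the kind of conversion such a helper needs to perform, under the assumption that e4m3fnuz's exponent bias of 8 (vs. 7 for e4m3fn) makes a given bit pattern decode to half the value, so the scales must be doubled, and that e4m3fn's -0.0 encoding (0x80) is NaN in e4m3fnuz.

```python
# Hedged sketch of an e4m3fn -> e4m3fnuz normalization; the actual helper
# lives in vllm/model_executor/layers/quantization/utils/w8a8_utils.py and
# may differ in detail.
from typing import Optional, Tuple

import torch

def normalize_e4m3fn_to_e4m3fnuz_sketch(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    bits = weight.view(torch.uint8).clone()
    bits[bits == 0x80] = 0x00            # fn's -0.0 would decode as NaN in fnuz
    weight_fnuz = bits.view(torch.float8_e4m3fnuz)
    # fnuz uses exponent bias 8 vs. fn's 7, so identical bits represent half
    # the value; doubling the scales keeps weight * scale unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight_fnuz, weight_scale, input_scale
```
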