3 changes: 1 addition & 2 deletions benchmarks/kernels/benchmark_moe.py
@@ -18,8 +18,7 @@
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser

FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm(
) else torch.float8_e4m3fn
FP8_DTYPE = current_platform.fp8_dtype()


class BenchmarkConfig(TypedDict):
32 changes: 26 additions & 6 deletions csrc/dispatch_utils.h
@@ -6,6 +6,11 @@

#include <torch/all.h>

// Need a special dispatch case macro since we will nest the FP8 dispatch.
// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'.
#define AT_DISPATCH_FP8_CASE(enum_type, ...) \
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__)

#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
@@ -14,17 +14,32 @@
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

// TODO(luka/varun): use FP8_TYPE macro after refactoring
#ifndef USE_ROCM
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#else
// ROCm devices might use either fn or fnuz, so set up dispatch table for both.
// A host-based check at runtime will create a preferred FP8 type for ROCm
// such that the correct kernel is dispatched.
#ifdef USE_ROCM
#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)

#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#else
#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)

#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#endif

// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
// See AT_DISPATCH_FP8_CASE above.
#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))

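For reference, a minimal sketch of how the nested dispatch above is meant to be used (not part of this diff; example_launch and example_kernel are hypothetical placeholders). The outer macro binds scalar_t to the activation dtype and the inner one binds fp8_t to the FP8 output dtype, which on ROCm may resolve to either e4m3fn or e4m3fnuz:

// Hypothetical launcher illustrating the nested dispatch pattern.
// 'example_kernel', 'grid', 'block', and 'stream' are placeholders.
void example_launch(torch::Tensor& out, torch::Tensor& input,
                    dim3 grid, dim3 block, cudaStream_t stream) {
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "example_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(), "example_fp8_type", [&] {
              // Inside the inner lambda, 'scalar_t' is the input element type
              // and 'fp8_t' is the concrete FP8 output type.
              example_kernel<scalar_t, fp8_t><<<grid, block, 0, stream>>>(
                  out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
                  input.numel());
            });
      });
}

The launchers updated in csrc/layernorm_quant_kernels.cu and csrc/quantization/fp8/common.cu below follow this pattern.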
58 changes: 33 additions & 25 deletions csrc/layernorm_quant_kernels.cu
@@ -21,9 +21,9 @@
namespace vllm {

// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t>
template <typename scalar_t, typename fp8_type>
__global__ void rms_norm_static_fp8_quant_kernel(
FP8_TYPE* __restrict__ out, // [..., hidden_size]
fp8_type* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float* __restrict__ scale, // [1]
@@ -52,18 +52,18 @@ __global__ void rms_norm_static_fp8_quant_kernel(
float x = (float)input[blockIdx.x * hidden_size + idx];
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true>(out_norm, scale_inv);
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
}
}

/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */
template <typename scalar_t, int width>
template <typename scalar_t, int width, typename fp8_type>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
FP8_TYPE* __restrict__ out, // [..., hidden_size]
fp8_type* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
@@ -114,18 +114,18 @@ fused_add_rms_norm_static_fp8_quant_kernel(
#pragma unroll
for (int i = 0; i < width; ++i) {
out[id * width + i] =
scaled_fp8_conversion<true>(float(temp.data[i]), scale_inv);
scaled_fp8_conversion<true, fp8_type>(float(temp.data[i]), scale_inv);
}
}
}

/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template <typename scalar_t, int width>
template <typename scalar_t, int width, typename fp8_type>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
FP8_TYPE* __restrict__ out, // [..., hidden_size]
fp8_type* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
@@ -158,7 +158,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
float x = (float)residual[blockIdx.x * hidden_size + idx];
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true>(out_norm, scale_inv);
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
}
}

@@ -176,25 +176,33 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_static_fp8_quant_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), epsilon,
num_tokens, hidden_size);
});
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "rms_norm_kernel_scalar_type", [&] {
VLLM_DISPATCH_FP8_TYPES(
out.scalar_type(), "rms_norm_kernel_fp8_type", [&] {
vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),
epsilon, num_tokens, hidden_size);
});
});
}

#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>( \
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
scale.data_ptr<float>(), epsilon, num_tokens, hidden_size); \
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \
VLLM_DISPATCH_FP8_TYPES( \
out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, \
width, fp8_t> \
<<<grid, block, 0, stream>>>( \
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), \
epsilon, num_tokens, hidden_size); \
}); \
});

void fused_add_rms_norm_static_fp8_quant(
torch::Tensor& out, // [..., hidden_size],
torch::Tensor& input, // [..., hidden_size]
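The body of fused_add_rms_norm_static_fp8_quant is truncated by the diff view above. As a hedged sketch only, this is how the LAUNCH_FUSED_ADD_RMS_NORM macro might be invoked; the parameter list after input and the width-8 heuristic are assumptions inferred from the names the macro captures, not code from this PR:

// Hypothetical reconstruction of the truncated launcher; the trailing
// parameters and the alignment/width heuristic are assumptions.
void fused_add_rms_norm_static_fp8_quant(
    torch::Tensor& out,       // [..., hidden_size]
    torch::Tensor& input,     // [..., hidden_size]
    torch::Tensor& residual,  // [..., hidden_size]
    torch::Tensor& weight,    // [hidden_size]
    torch::Tensor& scale,     // [1]
    double epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;
  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // Assumed heuristic: use the packed/vectorized width-8 specialization when
  // hidden_size permits it, otherwise fall back to the generic width-0 kernel.
  if (hidden_size % 8 == 0) {
    LAUNCH_FUSED_ADD_RMS_NORM(8);
  } else {
    LAUNCH_FUSED_ADD_RMS_NORM(0);
  }
}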
22 changes: 22 additions & 0 deletions csrc/quantization/fp8/amd/quant_utils.cuh
@@ -13,6 +13,28 @@ namespace vllm {
namespace fp8 {
#ifdef ENABLE_FP8

// Use hardware cvt instruction for fp8 on rocm
template <typename fp8_type>
__device__ __forceinline__ fp8_type cvt_c10(float const r) {
return {};
}

template <>
__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) {
return c10::Float8_e4m3fn(
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation,
__hip_fp8_e4m3::__default_interpret),
c10::Float8_e4m3fn::from_bits());
}

template <>
__device__ __forceinline__ c10::Float8_e4m3fnuz cvt_c10(float const r) {
return c10::Float8_e4m3fnuz(
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3_fnuz::__default_saturation,
__hip_fp8_e4m3_fnuz::__default_interpret),
c10::Float8_e4m3fnuz::from_bits());
}

template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x) {
return x;
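A brief usage sketch (not from this diff) of the cvt_c10 helpers above: a kernel templated on the c10 FP8 type can call cvt_c10<fp8_t> and get the matching hardware conversion, whether the output tensor is Float8_e4m3fn or Float8_e4m3fnuz.

// Hypothetical kernel; 'convert_to_fp8_kernel' is an illustrative placeholder.
template <typename fp8_t>
__global__ void convert_to_fp8_kernel(fp8_t* __restrict__ out,
                                      const float* __restrict__ in,
                                      int64_t num_elems) {
  int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < num_elems) {
    // Dispatches to the fn or fnuz hardware conversion at compile time.
    out[i] = vllm::fp8::cvt_c10<fp8_t>(in[i]);
  }
}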
68 changes: 43 additions & 25 deletions csrc/quantization/fp8/common.cu
@@ -11,8 +11,8 @@

namespace vllm {

template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
template <typename scalar_t, typename fp8_type>
__global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out,
const scalar_t* __restrict__ input,
const float* __restrict__ scale,
int64_t num_elems) {
@@ -25,20 +25,21 @@ __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
}

template <typename scalar_t>
template <typename scalar_t, typename fp8_type>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
FP8_TYPE* __restrict__ out, float* __restrict__ scale,
fp8_type* __restrict__ out, float* __restrict__ scale,
scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
const int hidden_size) {
float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
float const min_scaling_factor =
1.0f / (fp8_e4m3_adjusted_max_v<fp8_type> * 512.f);

int const tid = threadIdx.x;
int const token_idx = blockIdx.x;

// Use int64 to avoid overflowing an int32 when calculating this offset
int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
scalar_t const* __restrict__ token_input = &input[offset];
FP8_TYPE* __restrict__ token_output = &out[offset];
fp8_type* __restrict__ token_output = &out[offset];

// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
@@ -66,7 +67,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
token_scale = block_absmax_val_maybe;
}
// token scale computation
token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
token_scale = max(token_scale / fp8_e4m3_adjusted_max_v<fp8_type>,
min_scaling_factor);
scale[token_idx] = token_scale;
}
__syncthreads();
@@ -77,7 +79,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
token_output[i] = scaled_fp8_conversion<false>(
token_output[i] = scaled_fp8_conversion<false, fp8_type>(
static_cast<float>(token_input[i]), token_scale);
}
}
@@ -96,10 +98,14 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), num_elems);
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
VLLM_DISPATCH_FP8_TYPES(
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), num_elems);
});
});
}

@@ -114,12 +120,18 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), num_elems);
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
VLLM_DISPATCH_FP8_TYPES(
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
vllm::segmented_max_reduction<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
input.data_ptr<scalar_t>(),
num_elems);
vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), num_elems);
});
});
}

@@ -138,12 +150,18 @@ void dynamic_per_token_scaled_fp8_quant(
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
input.data_ptr<scalar_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
hidden_size);
input.scalar_type(),
"dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
VLLM_DISPATCH_FP8_TYPES(
out.scalar_type(),
"dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
input.data_ptr<scalar_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>()
: nullptr,
hidden_size);
});
});
}
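
For clarity, a host-side illustration (a sketch, not code from this PR) of the per-token scale computation in dynamic_per_token_scaled_fp8_quant_kernel: the scale is the token's absolute maximum, optionally capped by scale_ub, divided by the FP8 type's adjusted maximum, and floored at 1/(max * 512).

#include <algorithm>
#include <cmath>

// Hypothetical reference for the per-token scale math. 'fp8_max' stands in for
// fp8_e4m3_adjusted_max_v<fp8_t> (448.f for Float8_e4m3fn; a smaller adjusted
// value is used for e4m3fnuz on ROCm; exact values are assumed, not taken from
// this diff).
float compute_token_scale(const float* token_input, int hidden_size,
                          float fp8_max, const float* scale_ub) {
  float absmax = 0.f;
  for (int i = 0; i < hidden_size; ++i) {
    absmax = std::max(absmax, std::fabs(token_input[i]));
  }
  // Optionally cap the observed maximum, mirroring the scale_ub argument.
  if (scale_ub != nullptr) {
    absmax = std::min(absmax, *scale_ub);
  }
  float const min_scaling_factor = 1.0f / (fp8_max * 512.f);
  return std::max(absmax / fp8_max, min_scaling_factor);
}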