Commit a1c8f37

dynamic dispatch of fp8 kernels (#14245)
Signed-off-by: Jeff Daily <[email protected]>
1 parent 08a1a11 commit a1c8f37

25 files changed: +292 additions, −159 deletions

benchmarks/kernels/benchmark_moe.py

Lines changed: 1 addition & 2 deletions
@@ -18,8 +18,7 @@
 from vllm.platforms import current_platform
 from vllm.utils import FlexibleArgumentParser

-FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm(
-) else torch.float8_e4m3fn
+FP8_DTYPE = current_platform.fp8_dtype()


 class BenchmarkConfig(TypedDict):

csrc/dispatch_utils.h

Lines changed: 26 additions & 6 deletions
@@ -6,6 +6,11 @@

 #include <torch/all.h>

+// Need a special dispatch case macro since we will nest the FP8 dispatch.
+// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'.
+#define AT_DISPATCH_FP8_CASE(enum_type, ...) \
+  AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__)
+
 #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)         \
   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
@@ -14,17 +19,32 @@
 #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

-// TODO(luka/varun): use FP8_TYPE macro after refactoring
-#ifndef USE_ROCM
-  #define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                    \
-    AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
-    AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
-#else
+// ROCm devices might use either fn or fnuz, so set up dispatch table for both.
+// A host-based check at runtime will create a preferred FP8 type for ROCm
+// such that the correct kernel is dispatched.
+#ifdef USE_ROCM
+  #define VLLM_DISPATCH_CASE_FP8_TYPES(...)                          \
+    AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
+    AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
+
   #define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                      \
+    AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)   \
     AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
     AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
+#else
+  #define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
+    AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
+
+  #define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                    \
+    AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
+    AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
 #endif

+// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
+// See AT_DISPATCH_FP8_CASE above.
+#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
+
 #define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
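To make the nested dispatch concrete, the following is a simplified, hand-written sketch of what dispatching on the output dtype amounts to on a ROCm build. It is illustrative only: the real macros expand through AT_DISPATCH_SWITCH and AT_PRIVATE_CASE_TYPE_USING_HINT rather than an explicit switch, and launch_fp8_kernel is a hypothetical stand-in for whichever templated kernel launch sits inside the dispatch lambda.

#include <torch/all.h>

// Hypothetical stand-in for the kernel launch placed inside the lambda.
template <typename scalar_t, typename fp8_t>
void launch_fp8_kernel(const torch::Tensor& out, const torch::Tensor& input) {
  // The real code would launch a templated CUDA/HIP kernel here, e.g.
  // some_kernel<scalar_t, fp8_t><<<grid, block, 0, stream>>>(...).
}

// Illustrative equivalent of nesting VLLM_DISPATCH_FP8_TYPES inside
// VLLM_DISPATCH_FLOATING_TYPES on a ROCm build: pick fp8_t from the output
// tensor's dtype and instantiate the kernel for that type.
template <typename scalar_t>
void dispatch_on_fp8_dtype(const torch::Tensor& out,
                           const torch::Tensor& input) {
  switch (out.scalar_type()) {
    case at::ScalarType::Float8_e4m3fn:
      launch_fp8_kernel<scalar_t, c10::Float8_e4m3fn>(out, input);
      break;
    case at::ScalarType::Float8_e4m3fnuz:  // case only present on ROCm
      launch_fp8_kernel<scalar_t, c10::Float8_e4m3fnuz>(out, input);
      break;
    default:
      TORCH_CHECK(false, "unsupported fp8 output dtype");
  }
}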

csrc/layernorm_quant_kernels.cu

Lines changed: 33 additions & 25 deletions
@@ -21,9 +21,9 @@
 namespace vllm {

 // TODO(woosuk): Further optimize this kernel.
-template <typename scalar_t>
+template <typename scalar_t, typename fp8_type>
 __global__ void rms_norm_static_fp8_quant_kernel(
-    FP8_TYPE* __restrict__ out,           // [..., hidden_size]
+    fp8_type* __restrict__ out,           // [..., hidden_size]
     const scalar_t* __restrict__ input,   // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float* __restrict__ scale,      // [1]
@@ -52,18 +52,18 @@ __global__ void rms_norm_static_fp8_quant_kernel(
     float x = (float)input[blockIdx.x * hidden_size + idx];
     float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
     out[blockIdx.x * hidden_size + idx] =
-        scaled_fp8_conversion<true>(out_norm, scale_inv);
+        scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
   }
 }

 /* Function specialization in the case of FP16/BF16 tensors.
    Additional optimizations we can make in this case are
    packed and vectorized operations, which help with the
    memory latency bottleneck. */
-template <typename scalar_t, int width>
+template <typename scalar_t, int width, typename fp8_type>
 __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
 fused_add_rms_norm_static_fp8_quant_kernel(
-    FP8_TYPE* __restrict__ out,           // [..., hidden_size]
+    fp8_type* __restrict__ out,           // [..., hidden_size]
     scalar_t* __restrict__ input,         // [..., hidden_size]
     scalar_t* __restrict__ residual,      // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
@@ -114,18 +114,18 @@ fused_add_rms_norm_static_fp8_quant_kernel(
 #pragma unroll
     for (int i = 0; i < width; ++i) {
       out[id * width + i] =
-          scaled_fp8_conversion<true>(float(temp.data[i]), scale_inv);
+          scaled_fp8_conversion<true, fp8_type>(float(temp.data[i]), scale_inv);
     }
   }
 }

 /* Generic fused_add_rms_norm_kernel
    The width field is not used here but necessary for other specializations.
  */
-template <typename scalar_t, int width>
+template <typename scalar_t, int width, typename fp8_type>
 __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
 fused_add_rms_norm_static_fp8_quant_kernel(
-    FP8_TYPE* __restrict__ out,           // [..., hidden_size]
+    fp8_type* __restrict__ out,           // [..., hidden_size]
     scalar_t* __restrict__ input,         // [..., hidden_size]
     scalar_t* __restrict__ residual,      // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
@@ -158,7 +158,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
     float x = (float)residual[blockIdx.x * hidden_size + idx];
     float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
     out[blockIdx.x * hidden_size + idx] =
-        scaled_fp8_conversion<true>(out_norm, scale_inv);
+        scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
   }
 }

@@ -176,25 +176,33 @@ void rms_norm_static_fp8_quant(torch::Tensor& out,  // [..., hidden_size]
   dim3 block(std::min(hidden_size, 1024));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
-    vllm::rms_norm_static_fp8_quant_kernel<scalar_t>
-        <<<grid, block, 0, stream>>>(
-            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
-            weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), epsilon,
-            num_tokens, hidden_size);
-  });
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "rms_norm_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(), "rms_norm_kernel_fp8_type", [&] {
+              vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+                      weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),
+                      epsilon, num_tokens, hidden_size);
+            });
+      });
 }

-#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                     \
-  VLLM_DISPATCH_FLOATING_TYPES(                                              \
-      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                \
-        vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, width>    \
-            <<<grid, block, 0, stream>>>(                                    \
-                out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),        \
-                residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),  \
-                scale.data_ptr<float>(), epsilon, num_tokens, hidden_size);  \
+#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                     \
+  VLLM_DISPATCH_FLOATING_TYPES(                                              \
+      input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] {    \
+        VLLM_DISPATCH_FP8_TYPES(                                             \
+            out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] {   \
+              vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t,     \
+                                                               width, fp8_t> \
+                  <<<grid, block, 0, stream>>>(                              \
+                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),     \
+                      residual.data_ptr<scalar_t>(),                         \
+                      weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),  \
+                      epsilon, num_tokens, hidden_size);                     \
+            });                                                              \
       });
-
 void fused_add_rms_norm_static_fp8_quant(
     torch::Tensor& out,     // [..., hidden_size],
     torch::Tensor& input,   // [..., hidden_size]
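The "host-based check at runtime" referred to in dispatch_utils.h is what determines out.scalar_type() before launchers like the ones above run; in vLLM that choice is made on the Python side via current_platform.fp8_dtype() (see the benchmarks/kernels/benchmark_moe.py change earlier in this commit). Purely as an illustration, and not code from this commit, such a check could look like the sketch below; the gfx942 test and the function name preferred_fp8_dtype are assumptions used only for the example.

#include <cstring>
#include <hip/hip_runtime.h>
#include <torch/all.h>

// Hypothetical sketch: pick the FP8 dtype that the current ROCm device
// implements in hardware, so the matching kernel instantiation is dispatched.
inline at::ScalarType preferred_fp8_dtype() {
  hipDeviceProp_t props;
  if (hipGetDeviceProperties(&props, /*deviceId=*/0) != hipSuccess) {
    return at::ScalarType::Float8_e4m3fn;  // conservative fallback
  }
  // Assumption for illustration: MI300-class (gfx942) devices use the fnuz
  // flavor of e4m3, while other targets fall back to the OCP fn flavor.
  if (std::strstr(props.gcnArchName, "gfx942") != nullptr) {
    return at::ScalarType::Float8_e4m3fnuz;
  }
  return at::ScalarType::Float8_e4m3fn;
}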

csrc/quantization/fp8/amd/quant_utils.cuh

Lines changed: 22 additions & 0 deletions
@@ -13,6 +13,28 @@ namespace vllm {
 namespace fp8 {
 #ifdef ENABLE_FP8

+// Use hardware cvt instruction for fp8 on rocm
+template <typename fp8_type>
+__device__ __forceinline__ fp8_type cvt_c10(float const r) {
+  return {};
+}
+
+template <>
+__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) {
+  return c10::Float8_e4m3fn(
+      __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation,
+                             __hip_fp8_e4m3::__default_interpret),
+      c10::Float8_e4m3fn::from_bits());
+}
+
+template <>
+__device__ __forceinline__ c10::Float8_e4m3fnuz cvt_c10(float const r) {
+  return c10::Float8_e4m3fnuz(
+      __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3_fnuz::__default_saturation,
+                             __hip_fp8_e4m3_fnuz::__default_interpret),
+      c10::Float8_e4m3fnuz::from_bits());
+}
+
 template <typename Tout, typename Tin>
 __inline__ __device__ Tout vec_conversion(const Tin& x) {
   return x;
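A note on how cvt_c10 is meant to be consumed: the scaled conversion helpers used by the kernels in this commit (scaled_fp8_conversion in csrc/quantization/fp8/common.cuh) are now templated on fp8_type and can route to the matching hardware conversion. The helper below is a hypothetical sketch of that call pattern, assuming this header is included; it is not the actual common.cuh implementation, and scaled_to_fp8 is an invented name.

// Hypothetical sketch: scale the value by the inverted scale, then let the
// hardware cvt instruction (via cvt_c10) handle rounding and saturation into
// whichever FP8 flavor fp8_type names.
template <typename fp8_type>
__device__ __forceinline__ fp8_type scaled_to_fp8(float const val,
                                                  float const inv_scale) {
  return vllm::fp8::cvt_c10<fp8_type>(val * inv_scale);
}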

csrc/quantization/fp8/common.cu

Lines changed: 43 additions & 25 deletions
@@ -11,8 +11,8 @@

 namespace vllm {

-template <typename scalar_t>
-__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
+template <typename scalar_t, typename fp8_type>
+__global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out,
                                         const scalar_t* __restrict__ input,
                                         const float* __restrict__ scale,
                                         int64_t num_elems) {
@@ -25,20 +25,21 @@ __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
       out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
 }

-template <typename scalar_t>
+template <typename scalar_t, typename fp8_type>
 __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
-    FP8_TYPE* __restrict__ out, float* __restrict__ scale,
+    fp8_type* __restrict__ out, float* __restrict__ scale,
     scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
     const int hidden_size) {
-  float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
+  float const min_scaling_factor =
+      1.0f / (fp8_e4m3_adjusted_max_v<fp8_type> * 512.f);

   int const tid = threadIdx.x;
   int const token_idx = blockIdx.x;

   // Use int64 to avoid overflowing an int32 when calculating this offset
   int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
   scalar_t const* __restrict__ token_input = &input[offset];
-  FP8_TYPE* __restrict__ token_output = &out[offset];
+  fp8_type* __restrict__ token_output = &out[offset];

   // For vectorization, token_input and token_output pointers need to be
   // aligned at 8-byte and 4-byte addresses respectively.
@@ -66,7 +67,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
       token_scale = block_absmax_val_maybe;
     }
     // token scale computation
-    token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
+    token_scale = max(token_scale / fp8_e4m3_adjusted_max_v<fp8_type>,
+                      min_scaling_factor);
     scale[token_idx] = token_scale;
   }
   __syncthreads();
@@ -77,7 +79,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
         token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
   } else {
     for (int i = tid; i < hidden_size; i += blockDim.x) {
-      token_output[i] = scaled_fp8_conversion<false>(
+      token_output[i] = scaled_fp8_conversion<false, fp8_type>(
          static_cast<float>(token_input[i]), token_scale);
     }
   }
@@ -96,10 +98,14 @@ void static_scaled_fp8_quant(torch::Tensor& out,  // [..., d]
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
-        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
-            scale.data_ptr<float>(), num_elems);
+      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
+              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+                      scale.data_ptr<float>(), num_elems);
+            });
       });
 }

@@ -114,12 +120,18 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out,  // [..., d]
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
-        vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
-            scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
-        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
-            scale.data_ptr<float>(), num_elems);
+      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
+              vllm::segmented_max_reduction<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
+                                               input.data_ptr<scalar_t>(),
+                                               num_elems);
+              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+                      scale.data_ptr<float>(), num_elems);
+            });
       });
 }

@@ -138,12 +150,18 @@ void dynamic_per_token_scaled_fp8_quant(
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
-        vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
-            <<<grid, block, 0, stream>>>(
-                out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
-                input.data_ptr<scalar_t>(),
-                scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
-                hidden_size);
+      input.scalar_type(),
+      "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(),
+            "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
+              vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
+                      input.data_ptr<scalar_t>(),
+                      scale_ub.has_value() ? scale_ub->data_ptr<float>()
+                                           : nullptr,
+                      hidden_size);
+            });
       });
 }
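fp8_e4m3_adjusted_max_v<fp8_type> replaces the single FP8_E4M3_MAX constant so that each FP8 flavor clamps to its own range: the largest finite e4m3fn value is 448, while e4m3fnuz tops out at 240. The trait itself lives in the fp8 common header, which is not shown in this commit view; the sketch below is only an assumed illustration of its shape, not the actual definition.

#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>

// Assumed illustration (not the actual header): a per-type maximum that the
// quantization kernels above divide by when computing scales.
template <typename T>
struct fp8_e4m3_adjusted_max;

template <>
struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fn> {
  static constexpr float value = 448.0f;  // largest finite e4m3fn value
};

template <>
struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fnuz> {
  static constexpr float value = 240.0f;  // largest finite e4m3fnuz value
};

template <typename T>
inline constexpr float fp8_e4m3_adjusted_max_v =
    fp8_e4m3_adjusted_max<T>::value;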
