 
 namespace vllm {
 
-template <typename scalar_t>
-__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
+template <typename scalar_t, typename fp8_type>
+__global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out,
                                         const scalar_t* __restrict__ input,
                                         const float* __restrict__ scale,
                                         int64_t num_elems) {
@@ -25,20 +25,21 @@ __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
       out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename fp8_type>
 __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
-    FP8_TYPE* __restrict__ out, float* __restrict__ scale,
+    fp8_type* __restrict__ out, float* __restrict__ scale,
     scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
     const int hidden_size) {
-  float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
+  float const min_scaling_factor =
+      1.0f / (fp8_e4m3_adjusted_max_v<fp8_type> * 512.f);
 
   int const tid = threadIdx.x;
   int const token_idx = blockIdx.x;
 
   // Use int64 to avoid overflowing an int32 when calculating this offset
   int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
   scalar_t const* __restrict__ token_input = &input[offset];
-  FP8_TYPE* __restrict__ token_output = &out[offset];
+  fp8_type* __restrict__ token_output = &out[offset];
 
   // For vectorization, token_input and token_output pointers need to be
   // aligned at 8-byte and 4-byte addresses respectively.
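
Note: this hunk swaps the FP8_E4M3_MAX macro for the fp8_e4m3_adjusted_max_v<fp8_type> variable template, so the saturation bound follows the output dtype chosen at dispatch time. The sketch below is a guess at the shape of that trait, not vLLM's actual header; the two specializations and the 224.0f value for the ROCm fnuz variant are assumptions.

    #include <c10/util/Float8_e4m3fn.h>
    #include <c10/util/Float8_e4m3fnuz.h>

    namespace vllm {

    template <typename T>
    struct fp8_e4m3_adjusted_max;

    template <>
    struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fn> {
      // OCP e4m3fn saturates at +/-448.
      static constexpr float val = 448.0f;
    };

    template <>
    struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fnuz> {
      // ROCm's fnuz variant uses a different exponent bias; assuming the
      // bound is halved so scales stay interchangeable with the fn format.
      static constexpr float val = 224.0f;
    };

    template <typename T>
    inline constexpr float fp8_e4m3_adjusted_max_v =
        fp8_e4m3_adjusted_max<T>::val;

    }  // namespace vllm
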
@@ -66,7 +67,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
       token_scale = block_absmax_val_maybe;
     }
     // token scale computation
-    token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
+    token_scale = max(token_scale / fp8_e4m3_adjusted_max_v<fp8_type>,
+                      min_scaling_factor);
     scale[token_idx] = token_scale;
   }
   __syncthreads();
@@ -77,7 +79,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
         token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
   } else {
     for (int i = tid; i < hidden_size; i += blockDim.x) {
-      token_output[i] = scaled_fp8_conversion<false>(
+      token_output[i] = scaled_fp8_conversion<false, fp8_type>(
           static_cast<float>(token_input[i]), token_scale);
     }
   }
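
For intuition: the per-token scale is the token's absmax divided by the fp8 maximum, floored at min_scaling_factor so an all-zero token still yields a finite, invertible scale. A minimal host-side illustration, assuming the e4m3fn maximum of 448 (the absmax value is made up for the example):

    #include <algorithm>
    #include <cstdio>

    int main() {
      // Assumed value of fp8_e4m3_adjusted_max_v<c10::Float8_e4m3fn>.
      float const fp8_max = 448.0f;
      float const min_scaling_factor = 1.0f / (fp8_max * 512.f);  // ~4.36e-6
      float const absmax = 3.5f;  // example per-token max(|x|)
      float const token_scale = std::max(absmax / fp8_max, min_scaling_factor);
      // The kernel then stores roughly round_to_fp8(x / token_scale) for each
      // element, since scaled_fp8_conversion<false, ...> divides by the scale.
      std::printf("token_scale = %g\n", token_scale);  // prints 0.0078125
      return 0;
    }
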
@@ -96,10 +98,14 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
-        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
-            scale.data_ptr<float>(), num_elems);
+      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
+              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+                      scale.data_ptr<float>(), num_elems);
+            });
       });
 }
 
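The launcher now dispatches twice: the outer VLLM_DISPATCH_FLOATING_TYPES binds scalar_t to input.scalar_type(), and the new inner VLLM_DISPATCH_FP8_TYPES binds fp8_t to out.scalar_type(). A plausible sketch of the new macro, assuming it is built on PyTorch's AT_DISPATCH_SWITCH machinery (the real definition in vLLM's dispatch_utils.h may differ):

    // AT_PRIVATE_CASE_TYPE_USING_HINT is the PyTorch dispatch helper that
    // introduces a caller-chosen type alias (fp8_t here) instead of the
    // default scalar_t.
    #define VLLM_DISPATCH_CASE_FP8_TYPES(...)                          \
      AT_PRIVATE_CASE_TYPE_USING_HINT(at::ScalarType::Float8_e4m3fn,   \
                                      fp8_t, __VA_ARGS__)              \
      AT_PRIVATE_CASE_TYPE_USING_HINT(at::ScalarType::Float8_e4m3fnuz, \
                                      fp8_t, __VA_ARGS__)

    #define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
      AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
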
@@ -114,12 +120,18 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
-        vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
-            scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
-        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
-            scale.data_ptr<float>(), num_elems);
+      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
+              vllm::segmented_max_reduction<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
+                                               input.data_ptr<scalar_t>(),
+                                               num_elems);
+              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+                      scale.data_ptr<float>(), num_elems);
+            });
       });
 }
 
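segmented_max_reduction (defined earlier in this file, outside the visible hunks) also gains the fp8_type parameter, presumably because it folds the division by the fp8 maximum into the global scale it writes. A hedged sketch of that shape; the kernel body below is a reconstruction, and atomicMaxFloat is assumed to be vLLM's float atomic-max helper:

    template <typename scalar_t, typename fp8_type>
    __global__ void segmented_max_reduction(float* __restrict__ scale,
                                            const scalar_t* __restrict__ input,
                                            int64_t num_elems) {
      __shared__ float cache[1024];
      int64_t i = blockDim.x * blockIdx.x + threadIdx.x;

      // Grid-strided per-thread absmax over this block's segment.
      float tmp = 0.0f;
      while (i < num_elems) {
        float x = fabsf(static_cast<float>(input[i]));
        tmp = fmaxf(tmp, x);
        i += blockDim.x * gridDim.x;
      }
      cache[threadIdx.x] = tmp;
      __syncthreads();

      // Shared-memory tree reduction; assumes blockDim.x is a power of two.
      for (int ib = blockDim.x / 2; ib > 0; ib /= 2) {
        if (threadIdx.x < ib) {
          cache[threadIdx.x] =
              fmaxf(cache[threadIdx.x], cache[threadIdx.x + ib]);
        }
        __syncthreads();
      }

      // Thread 0 folds the fp8 range into the scale shared across blocks.
      if (threadIdx.x == 0) {
        atomicMaxFloat(scale, cache[0] / fp8_e4m3_adjusted_max_v<fp8_type>);
      }
    }
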
@@ -138,12 +150,18 @@ void dynamic_per_token_scaled_fp8_quant(
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
-        vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
-            <<<grid, block, 0, stream>>>(
-                out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
-                input.data_ptr<scalar_t>(),
-                scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
-                hidden_size);
+      input.scalar_type(),
+      "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(),
+            "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
+              vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
+                      input.data_ptr<scalar_t>(),
+                      scale_ub.has_value() ? scale_ub->data_ptr<float>()
+                                           : nullptr,
+                      hidden_size);
+            });
       });
 }