
#include "../../reduction_utils.cuh"

+ #ifndef USE_ROCM
+ using FP8_TYPE = c10::Float8_e4m3fn;
+ C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
+     std::numeric_limits<FP8_TYPE>::max();
+ #else
+   #include "amd/hip_float8.h"
+ using FP8_TYPE = c10::Float8_e4m3fnuz;
+ // Using the default max value from PyTorch (240.0) will cause accuracy
+ // issues when running dynamic quantization, so use 224.0f for ROCm.
+ constexpr auto FP8_E4M3_MAX = 224.0f;
+ #endif
+
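For intuition (not part of this change): on the CUDA side FP8_E4M3_MAX resolves to the e4m3fn finite maximum of 448.0f, while the ROCm branch deliberately pins it to 224.0f rather than the 240.0f that std::numeric_limits<c10::Float8_e4m3fnuz>::max() would give. A minimal sketch of the saturating clamp both backends apply before conversion (clamp_to_fp8 is a hypothetical helper; the same clamp appears inside scaled_fp8_conversion below):

// Saturate to the representable e4m3 range so out-of-range inputs map to
// the edge of the range instead of overflowing during the FP8 cast.
C10_HOST_DEVICE inline float clamp_to_fp8(float x) {
  return fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
}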
namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
@@ -21,11 +33,9 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  return old;
}

- #define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
-
template <bool is_scale_inverted>
- __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
-     float const val, float const scale) {
+ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
+                                                           float const scale) {
  float x = 0.0f;
  if constexpr (is_scale_inverted) {
    x = val * scale;
@@ -34,7 +44,13 @@ __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
  }

  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
+ #ifndef USE_ROCM
  return static_cast<c10::Float8_e4m3fn>(r);
+ #else
+   // Use the hardware cvt instruction for fp8 on ROCm
+   return c10::Float8_e4m3fnuz(hip_fp8(r).data,
+                               c10::Float8_e4m3fnuz::from_bits());
+ #endif
}

// Compute the absolute maximum m of the input tensor and store
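As a usage sketch (illustrative; demo_quantize is hypothetical and not part of this diff), the two instantiations differ only in whether the caller pre-inverts the scale; passing is_scale_inverted = true saves a division per element:

// Both forms produce the same FP8 value for a scale s > 0.
__global__ void demo_quantize(FP8_TYPE* __restrict__ out,
                              float const* __restrict__ in, float s,
                              int64_t n) {
  int64_t i = blockIdx.x * int64_t(blockDim.x) + threadIdx.x;
  if (i < n) {
    // Division form: x = val / s.
    // out[i] = vllm::scaled_fp8_conversion<false>(in[i], s);
    // Multiplication form: x = val * (1 / s), with 1/s computed once.
    out[i] = vllm::scaled_fp8_conversion<true>(in[i], 1.0f / s);
  }
}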
@@ -74,8 +90,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
-     atomicMaxFloat(scale,
-                    cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
+     atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
  }
}

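Assuming the host zero-initializes the one-element scale tensor, the atomic max across blocks leaves *scale equal to the result of this serial reference (a sketch for intuition; reference_dynamic_scale is hypothetical):

// Serial equivalent of segmented_max_reduction's final result.
float reference_dynamic_scale(float const* input, int64_t num_elems) {
  float absmax = 0.0f;
  for (int64_t i = 0; i < num_elems; ++i) {
    absmax = fmaxf(absmax, fabsf(input[i]));
  }
  // Chosen so the largest-magnitude element lands exactly at FP8_E4M3_MAX.
  return absmax / FP8_E4M3_MAX;
}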
@@ -88,10 +103,10 @@ struct __align__(8) vec4_t {
};

typedef struct __align__(4) {
-   c10::Float8_e4m3fn x;
-   c10::Float8_e4m3fn y;
-   c10::Float8_e4m3fn z;
-   c10::Float8_e4m3fn w;
+   FP8_TYPE x;
+   FP8_TYPE y;
+   FP8_TYPE z;
+   FP8_TYPE w;
}
float8x4_t;

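float8x4_t packs four FP8 values into four bytes so a quantized quad can be written with a single 32-bit store. A sketch of the intended access pattern (quantize_quad is hypothetical; it assumes the alignment preconditions noted in the per-token kernel below and an element count divisible by 4):

// Load four floats, quantize with a precomputed inverse scale, store once.
__device__ void quantize_quad(FP8_TYPE* __restrict__ out,
                              float const* __restrict__ in,
                              float inv_scale, int64_t i) {
  vllm::vec4_t<float> v = reinterpret_cast<vllm::vec4_t<float> const*>(in)[i];
  vllm::float8x4_t q;
  q.x = vllm::scaled_fp8_conversion<true>(v.x, inv_scale);
  q.y = vllm::scaled_fp8_conversion<true>(v.y, inv_scale);
  q.z = vllm::scaled_fp8_conversion<true>(v.z, inv_scale);
  q.w = vllm::scaled_fp8_conversion<true>(v.w, inv_scale);
  reinterpret_cast<vllm::float8x4_t*>(out)[i] = q;
}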
@@ -124,7 +139,7 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
}

template <typename scalar_t, bool is_scale_inverted>
- __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
+ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                          scalar_t const* __restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
@@ -160,7 +175,7 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
}

template <typename scalar_t>
- __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
+ __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
                                        const scalar_t* __restrict__ input,
                                        const float* __restrict__ scale,
                                        int64_t num_elems) {
@@ -175,7 +190,7 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,

template <typename scalar_t>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
-     c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
+     FP8_TYPE* __restrict__ out, float* __restrict__ scale,
    scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
    const int hidden_size) {
  float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
@@ -184,7 +199,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(

  int const token_idx = blockIdx.x;
  scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
-   c10::Float8_e4m3fn* __restrict__ token_output = &out[token_idx * hidden_size];
+   FP8_TYPE* __restrict__ token_output = &out[token_idx * hidden_size];

  // For vectorization, token_input and token_output pointers need to be
  // aligned at 8-byte and 4-byte addresses respectively.
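The rest of this kernel is elided by the hunk; its per-token logic amounts to the following serial sketch (per_token_scale is hypothetical, and the scale_ub capping is inferred from the parameter name):

// One scale per token row: absmax, optionally capped, then floored so an
// all-zero row cannot yield a zero (non-invertible) scale.
float per_token_scale(float const* token, int hidden_size,
                      float const* scale_ub) {
  float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
  float absmax = 0.0f;
  for (int i = 0; i < hidden_size; ++i) {
    absmax = fmaxf(absmax, fabsf(token[i]));
  }
  if (scale_ub != nullptr) absmax = fminf(absmax, *scale_ub);
  return fmaxf(absmax / FP8_E4M3_MAX, min_scaling_factor);
}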
@@ -241,7 +256,7 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-             out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
+             out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
            scale.data_ptr<float>(), num_elems);
      });
}
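A hypothetical host-side call sequence (tensor shapes and the FP8 ScalarType spelling are assumptions about the caller's PyTorch build, not part of this change; on ROCm the output dtype would be Float8_e4m3fnuz):

// Sketch: static quantization with a caller-supplied scale.
int64_t num_tokens = 16, hidden_size = 4096;
auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);
torch::Tensor input = torch::randn({num_tokens, hidden_size}, opts);
torch::Tensor out =
    torch::empty_like(input, opts.dtype(at::ScalarType::Float8_e4m3fn));
torch::Tensor scale = torch::full({1}, 0.02f, opts.dtype(torch::kFloat32));
static_scaled_fp8_quant(out, input, scale);  // out ~= input / 0.02, clamped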
@@ -261,7 +276,7 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
        vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
            scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-             out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
+             out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
            scale.data_ptr<float>(), num_elems);
      });
}
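The dynamic variant reuses the same quantization kernel but derives the scale on device: a zero-initialized one-element tensor accumulates the block maxima via segmented_max_reduction (the atomic max relies on that zero initialization), then feeds the quantization launch. Continuing the sketch above:

// Sketch: dynamic quantization; scale is an output here, not an input.
torch::Tensor dyn_scale = torch::zeros({1}, opts.dtype(torch::kFloat32));
dynamic_scaled_fp8_quant(out, input, dyn_scale);
// dyn_scale now holds max(|input|) / FP8_E4M3_MAX.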
@@ -284,7 +299,7 @@ void dynamic_per_token_scaled_fp8_quant(
      input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
        vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(
-                 out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
+                 out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
                input.data_ptr<scalar_t>(),
                scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                hidden_size);