From 0f6ae437b491b980fb162c02e34b739157691a14 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Tue, 29 Jul 2025 14:33:40 -0700 Subject: [PATCH 01/11] init version Signed-off-by: yewentao256 --- csrc/quantization/fp8/common.cu | 45 ++++++------- csrc/quantization/fp8/common.cuh | 66 ------------------- .../fused_kernels/layernorm_utils.cuh | 2 - 3 files changed, 21 insertions(+), 92 deletions(-) diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 0e1eab66f0b9..ad04148da7d4 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -1,6 +1,6 @@ #include "common.cuh" #include "dispatch_utils.h" - +#include "../vectorization_utils.cuh" #include #ifndef USE_ROCM @@ -21,8 +21,12 @@ __global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out, // Invert the scale so that we can use multiplications to avoid expensive // division. const float inverted_scale = 1.0f / (*scale); - scaled_fp8_conversion_vec( - out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x); + vectorize_with_alignment<16>( + input, out, num_elems, tid, blockDim.x * gridDim.x, + [=] __device__(fp8_type & dst, const scalar_t& src) { + dst = scaled_fp8_conversion(static_cast(src), + inverted_scale); + }); } template @@ -38,19 +42,14 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( scalar_t const* __restrict__ token_input = &input[offset]; fp8_type* __restrict__ token_output = &out[offset]; - // For vectorization, token_input and token_output pointers need to be - // aligned at 32-byte and 16-byte addresses respectively. - bool const can_vectorize = hidden_size % 16 == 0; - + // 1) compute per-token absmax float absmax_val = 0.0f; - if (can_vectorize) { - absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x); - } else { - for (int i = tid; i < hidden_size; i += blockDim.x) { - float const x = static_cast(token_input[i]); - absmax_val = fmaxf(absmax_val, fabsf(x)); - } - } + vectorize_read_with_alignment<16>(token_input, hidden_size, tid, blockDim.x, + [&] __device__(const scalar_t& src) { + const float v = + fabsf(static_cast(src)); + absmax_val = fmaxf(absmax_val, v); + }); using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStorage; @@ -70,16 +69,14 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( } __syncthreads(); + // 2) quantize // Note that we don't use inverted scales so we can match FBGemm impl. - if (can_vectorize) { - scaled_fp8_conversion_vec( - token_output, token_input, token_scale, hidden_size, tid, blockDim.x); - } else { - for (int i = tid; i < hidden_size; i += blockDim.x) { - token_output[i] = scaled_fp8_conversion( - static_cast(token_input[i]), token_scale); - } - } + vectorize_with_alignment<16>( + token_input, token_output, hidden_size, tid, blockDim.x, + [=] __device__(fp8_type & dst, const scalar_t& src) { + dst = scaled_fp8_conversion(static_cast(src), + token_scale); + }); } } // namespace vllm diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh index d36f94a8f10d..973d0bfc1648 100644 --- a/csrc/quantization/fp8/common.cuh +++ b/csrc/quantization/fp8/common.cuh @@ -96,70 +96,4 @@ __global__ void segmented_max_reduction(float* __restrict__ scale, } } -template -__device__ float thread_max_vec(scalar_t const* __restrict__ input, - int64_t const num_elems, int const tid, - int const step) { - constexpr size_t VEC_SIZE = 16; - using scalarxN_t = vec_n_t; - // Vectorized input/output to better utilize memory bandwidth. 
- auto const* vectorized_in = reinterpret_cast(input); - - // num_elems / VEC_SIZE (which is 16) - int64_t const num_vec_elems = num_elems >> 4; - float absmax_val = 0.0f; - -#pragma unroll - for (int64_t i = tid; i < num_vec_elems; i += step) { - scalarxN_t in_vec = vectorized_in[i]; -#pragma unroll - for (int j = 0; j < VEC_SIZE; ++j) { - absmax_val = fmaxf(absmax_val, fabsf(in_vec.val[j])); - } - } - - // Handle the remaining elements if num_elems is not divisible by VEC_SIZE - for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) { - absmax_val = fmaxf(absmax_val, fabsf(input[i])); - } - - return absmax_val; -} - -template -__device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out, - scalar_t const* __restrict__ input, - float const scale, - int64_t const num_elems, - int const tid, int const step) { - constexpr size_t VEC_SIZE = 16; - using scalarxN_t = vec_n_t; - using float8xN_t = q8_n_t; - // Vectorized input/output to better utilize memory bandwidth. - auto const* vectorized_in = reinterpret_cast(input); - auto* vectorized_out = reinterpret_cast(out); - - // num_elems / VEC_SIZE (which is 16) - int64_t const num_vec_elems = num_elems >> 4; - -#pragma unroll - for (int64_t i = tid; i < num_vec_elems; i += step) { - scalarxN_t in_vec = vectorized_in[i]; - float8xN_t out_vec; - -#pragma unroll - for (int j = 0; j < VEC_SIZE; ++j) { - out_vec.val[j] = scaled_fp8_conversion( - static_cast(in_vec.val[j]), scale); - } - vectorized_out[i] = out_vec; - } - - // Handle the remaining elements if num_elems is not divisible by VEC_SIZE - for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) { - out[i] = scaled_fp8_conversion( - static_cast(input[i]), scale); - } -} - } // namespace vllm diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index 3f188872d80d..3fb80afe806f 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -287,8 +287,6 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, const int VEC_SIZE = 4; int32_t const num_vec_elems = hidden_size >> 2; -// TODO(luka/varun) extract into type-agnostic vectorized quant function to -// replace scaled_fp8_conversion_vec #pragma unroll 4 for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { vec4_t const in = vec_input[i]; From 2e3b778dbac1c0fde8cbc80be5dedb6930261a98 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Wed, 30 Jul 2025 12:39:55 -0400 Subject: [PATCH 02/11] non-contiguous support Signed-off-by: yewentao256 --- csrc/quantization/fp8/common.cu | 261 +++++++++++++++++++++++++++----- vllm/_custom_ops.py | 7 +- 2 files changed, 224 insertions(+), 44 deletions(-) diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index ad04148da7d4..7be060b4fd7a 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -79,29 +79,164 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( }); } +template +__global__ void scaled_fp8_quant_kernel_strided( + fp8_type* __restrict__ out, const scalar_t* __restrict__ input, + const float* __restrict__ scale, int hidden_size, int64_t in_row_stride, + int64_t out_row_stride) { + const int token_idx = blockIdx.x; // one token per block + const int tid = threadIdx.x; + + const scalar_t* token_in = input + token_idx * in_row_stride; + fp8_type* token_out = out + token_idx * out_row_stride; + + const float inv_scale = 1.0f / (*scale); + + 
vectorize_with_alignment<16>( + token_in, token_out, hidden_size, tid, blockDim.x, + [=] __device__(fp8_type & dst, const scalar_t& src) { + dst = scaled_fp8_conversion(static_cast(src), + inv_scale); + }); +} + +template +__global__ void segmented_max_reduction_strided( + float* __restrict__ scale, const scalar_t* __restrict__ input, + int hidden_size, int64_t in_row_stride, int64_t num_tokens) { + __shared__ float cache[256]; + int tid = threadIdx.x; + + // Each thread processes multiple rows in a round-robin fashion. + float local_max = 0.0f; + for (int64_t token = blockIdx.x * blockDim.x + tid; token < num_tokens; + token += blockDim.x * gridDim.x) { + const scalar_t* row_ptr = input + token * in_row_stride; + // Traverse the row +#pragma unroll 4 + for (int e = 0; e < hidden_size; ++e) { + float v = fabsf(static_cast(row_ptr[e])); + local_max = fmaxf(local_max, v); + } + } + + cache[tid] = local_max; + __syncthreads(); + + // Reduction inside block + for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) { + if (tid < offset) { + cache[tid] = fmaxf(cache[tid], cache[tid + offset]); + } + __syncthreads(); + } + + if (tid == 0) { + atomicMaxFloat(scale, cache[0] / quant_type_max_v); + } +} + +template +__global__ void scaled_fp8_quant_kernel_strided_dynamic( + fp8_type* __restrict__ out, const scalar_t* __restrict__ input, + const float* __restrict__ scale, int hidden_size, int64_t in_row_stride, + int64_t out_row_stride) { + const int token_idx = blockIdx.x; + const int tid = threadIdx.x; + + const scalar_t* token_in = input + token_idx * in_row_stride; + fp8_type* token_out = out + token_idx * out_row_stride; + + const float reciprocal_scale = 1.0f / (*scale); + vectorize_with_alignment<16>( + token_in, token_out, hidden_size, tid, blockDim.x, + [=] __device__(fp8_type & dst, const scalar_t& src) { + dst = scaled_fp8_conversion(static_cast(src), + reciprocal_scale); + }); +} + +template +__global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided( + fp8_type* __restrict__ out, float* __restrict__ scale, + const scalar_t* __restrict__ input, const float* __restrict__ scale_ub, + int hidden_size, int64_t in_row_stride, int64_t out_row_stride) { + const int token_idx = blockIdx.x; + const int tid = threadIdx.x; + + const scalar_t* token_in = input + token_idx * in_row_stride; + fp8_type* token_out = out + token_idx * out_row_stride; + + // 1) per-token absmax + float absmax_val = 0.f; + vectorize_read_with_alignment<16>( + token_in, hidden_size, tid, blockDim.x, [&] __device__(scalar_t v) { + absmax_val = fmaxf(absmax_val, fabsf(static_cast(v))); + }); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmp; + float block_max = BlockReduce(tmp).Reduce(absmax_val, cub::Max{}, blockDim.x); + + __shared__ float token_scale; + if (tid == 0) { + token_scale = scale_ub ? 
fminf(block_max, *scale_ub) : block_max; + token_scale = fmaxf(token_scale / quant_type_max_v, + min_scaling_factor::val()); + scale[token_idx] = token_scale; + } + __syncthreads(); + + // 2) quantize + vectorize_with_alignment<16>( + token_in, token_out, hidden_size, tid, blockDim.x, + [=] __device__(fp8_type & dst, const scalar_t& src) { + dst = scaled_fp8_conversion(static_cast(src), + token_scale); + }); +} + } // namespace vllm void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] torch::Tensor const& input, // [..., d] torch::Tensor const& scale) // [1] { - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); - int const block_size = 256; - int const num_tokens = input.numel() / input.size(-1); - int const num_elems = input.numel(); - dim3 const grid(num_tokens); - dim3 const block(block_size); + TORCH_CHECK(input.stride(-1) == 1, + "last dimension of input must be contiguous"); + TORCH_CHECK(out.stride(-1) == 1, + "last dimension of output must be contiguous"); + + const int hidden_size = input.size(-1); + const int num_tokens = input.numel() / hidden_size; + const int block_size = 256; + dim3 grid(num_tokens); + dim3 block(block_size); + + const int64_t in_row_stride = input.stride(-2); + const int64_t out_row_stride = out.stride(-2); + const bool is_contig_rows = + (in_row_stride == hidden_size) && (out_row_stride == hidden_size); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] { VLLM_DISPATCH_FP8_TYPES( out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] { - vllm::scaled_fp8_quant_kernel - <<>>( - out.data_ptr(), input.data_ptr(), - scale.data_ptr(), num_elems); + if (is_contig_rows) { + const int num_elems = input.numel(); + vllm::scaled_fp8_quant_kernel + <<>>( + out.data_ptr(), input.data_ptr(), + scale.data_ptr(), num_elems); + } else { + vllm::scaled_fp8_quant_kernel_strided + <<>>( + out.data_ptr(), input.data_ptr(), + scale.data_ptr(), hidden_size, in_row_stride, + out_row_stride); + } }); }); } @@ -110,27 +245,56 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] torch::Tensor const& input, // [..., d] torch::Tensor& scale) // [1] { - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); - int const block_size = 256; - int const num_tokens = input.numel() / input.size(-1); - int const num_elems = input.numel(); - dim3 const grid(num_tokens); - dim3 const block(block_size); + TORCH_CHECK(input.stride(-1) == 1, + "last dimension of input must be contiguous"); + TORCH_CHECK(out.stride(-1) == 1, + "last dimension of output must be contiguous"); + + const int hidden_size = input.size(-1); + const int num_tokens = input.numel() / hidden_size; + const int block_size = 256; + dim3 grid(num_tokens); + dim3 block(block_size); + + const int64_t in_row_stride = input.stride(-2); + const int64_t out_row_stride = out.stride(-2); + const bool is_contig_rows = + (in_row_stride == hidden_size) && (out_row_stride == hidden_size); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // scale tensor should be initialised to <=0 before reduction + if (!is_contig_rows) { + scale.fill_(0.0f); + } + VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] { VLLM_DISPATCH_FP8_TYPES( out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] { - 
vllm::segmented_max_reduction - <<>>(scale.data_ptr(), - input.data_ptr(), - num_elems); - vllm::scaled_fp8_quant_kernel - <<>>( - out.data_ptr(), input.data_ptr(), - scale.data_ptr(), num_elems); + if (is_contig_rows) { + const int num_elems = input.numel(); + vllm::segmented_max_reduction + <<>>(scale.data_ptr(), + input.data_ptr(), + num_elems); + vllm::scaled_fp8_quant_kernel + <<>>( + out.data_ptr(), input.data_ptr(), + scale.data_ptr(), num_elems); + } else { + vllm::segmented_max_reduction_strided + <<>>( + scale.data_ptr(), input.data_ptr(), + hidden_size, in_row_stride, (int64_t)num_tokens); + + vllm::scaled_fp8_quant_kernel_strided_dynamic + <<>>( + out.data_ptr(), input.data_ptr(), + scale.data_ptr(), hidden_size, in_row_stride, + out_row_stride); + } }); }); } @@ -139,14 +303,21 @@ void dynamic_per_token_scaled_fp8_quant( torch::Tensor& out, // [..., d] torch::Tensor const& input, // [..., d] torch::Tensor& scales, std::optional const& scale_ub) { - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(input.stride(-1) == 1, + "last dimension of input must be contiguous"); + TORCH_CHECK(out.stride(-1) == 1, + "last dimension of output must be contiguous"); + + const int hidden_size = input.size(-1); + const int num_tokens = input.numel() / hidden_size; + const int block_size = 256; + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, block_size)); - int const hidden_size = input.size(-1); - int const num_tokens = input.numel() / hidden_size; - int const block_size = 256; - dim3 const grid(num_tokens); - dim3 const block(std::min(hidden_size, block_size)); + const int64_t in_row_stride = input.stride(-2); + const int64_t out_row_stride = out.stride(-2); + const bool is_contig_rows = + (in_row_stride == hidden_size) && (out_row_stride == hidden_size); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -156,13 +327,23 @@ void dynamic_per_token_scaled_fp8_quant( VLLM_DISPATCH_FP8_TYPES( out.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] { - vllm::dynamic_per_token_scaled_fp8_quant_kernel - <<>>( - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() - : nullptr, - hidden_size); + if (is_contig_rows) { + vllm::dynamic_per_token_scaled_fp8_quant_kernel + <<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), + scale_ub.has_value() ? scale_ub->data_ptr() + : nullptr, + hidden_size); + } else { + vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided< + scalar_t, fp8_t><<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), + scale_ub.has_value() ? 
scale_ub->data_ptr() + : nullptr, + hidden_size, in_row_stride, out_row_stride); + } }); }); } diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 35345b1be01c..91a99821e83f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1279,14 +1279,13 @@ def scaled_fp8_quant( device=input.device, dtype=torch.float32) torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input.contiguous(), scale, scale_ub) + output, input, scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_fp8_quant(output, input.contiguous(), - scale) + torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: assert scale.numel() == 1, f"{scale.shape}" - torch.ops._C.static_scaled_fp8_quant(output, input.contiguous(), scale) + torch.ops._C.static_scaled_fp8_quant(output, input, scale) return output, scale From 9f5d4ab209ceb9329e54b9e28d661a1ab50262e9 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Wed, 30 Jul 2025 13:50:29 -0400 Subject: [PATCH 03/11] use strided only Signed-off-by: yewentao256 --- csrc/quantization/fp8/common.cu | 151 +++++-------------------------- csrc/quantization/fp8/common.cuh | 41 --------- 2 files changed, 22 insertions(+), 170 deletions(-) diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 7be060b4fd7a..e9833955ca98 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -11,74 +11,6 @@ namespace vllm { -template -__global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out, - const scalar_t* __restrict__ input, - const float* __restrict__ scale, - int64_t num_elems) { - int tid = blockDim.x * blockIdx.x + threadIdx.x; - - // Invert the scale so that we can use multiplications to avoid expensive - // division. - const float inverted_scale = 1.0f / (*scale); - vectorize_with_alignment<16>( - input, out, num_elems, tid, blockDim.x * gridDim.x, - [=] __device__(fp8_type & dst, const scalar_t& src) { - dst = scaled_fp8_conversion(static_cast(src), - inverted_scale); - }); -} - -template -__global__ void dynamic_per_token_scaled_fp8_quant_kernel( - fp8_type* __restrict__ out, float* __restrict__ scale, - scalar_t const* __restrict__ input, float const* __restrict__ scale_ub, - const int hidden_size) { - int const tid = threadIdx.x; - int const token_idx = blockIdx.x; - - // Use int64 to avoid overflowing an int32 when calculating this offset - int64_t offset = static_cast(token_idx) * hidden_size; - scalar_t const* __restrict__ token_input = &input[offset]; - fp8_type* __restrict__ token_output = &out[offset]; - - // 1) compute per-token absmax - float absmax_val = 0.0f; - vectorize_read_with_alignment<16>(token_input, hidden_size, tid, blockDim.x, - [&] __device__(const scalar_t& src) { - const float v = - fabsf(static_cast(src)); - absmax_val = fmaxf(absmax_val, v); - }); - - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage reduceStorage; - float const block_absmax_val_maybe = - BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x); - __shared__ float token_scale; - if (tid == 0) { - if (scale_ub) { - token_scale = fminf(block_absmax_val_maybe, *scale_ub); - } else { - token_scale = block_absmax_val_maybe; - } - // token scale computation - token_scale = fmaxf(token_scale / quant_type_max_v, - min_scaling_factor::val()); - scale[token_idx] = token_scale; - } - __syncthreads(); - - // 2) quantize - // Note that we don't use inverted scales so we can match FBGemm impl. 
- vectorize_with_alignment<16>( - token_input, token_output, hidden_size, tid, blockDim.x, - [=] __device__(fp8_type & dst, const scalar_t& src) { - dst = scaled_fp8_conversion(static_cast(src), - token_scale); - }); -} - template __global__ void scaled_fp8_quant_kernel_strided( fp8_type* __restrict__ out, const scalar_t* __restrict__ input, @@ -215,8 +147,6 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] const int64_t in_row_stride = input.stride(-2); const int64_t out_row_stride = out.stride(-2); - const bool is_contig_rows = - (in_row_stride == hidden_size) && (out_row_stride == hidden_size); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -224,19 +154,11 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] { VLLM_DISPATCH_FP8_TYPES( out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] { - if (is_contig_rows) { - const int num_elems = input.numel(); - vllm::scaled_fp8_quant_kernel - <<>>( - out.data_ptr(), input.data_ptr(), - scale.data_ptr(), num_elems); - } else { - vllm::scaled_fp8_quant_kernel_strided - <<>>( - out.data_ptr(), input.data_ptr(), - scale.data_ptr(), hidden_size, in_row_stride, - out_row_stride); - } + vllm::scaled_fp8_quant_kernel_strided + <<>>( + out.data_ptr(), input.data_ptr(), + scale.data_ptr(), hidden_size, in_row_stride, + out_row_stride); }); }); } @@ -258,43 +180,27 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] const int64_t in_row_stride = input.stride(-2); const int64_t out_row_stride = out.stride(-2); - const bool is_contig_rows = - (in_row_stride == hidden_size) && (out_row_stride == hidden_size); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // scale tensor should be initialised to <=0 before reduction - if (!is_contig_rows) { - scale.fill_(0.0f); - } + scale.fill_(0.0f); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] { VLLM_DISPATCH_FP8_TYPES( out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] { - if (is_contig_rows) { - const int num_elems = input.numel(); - vllm::segmented_max_reduction - <<>>(scale.data_ptr(), - input.data_ptr(), - num_elems); - vllm::scaled_fp8_quant_kernel - <<>>( - out.data_ptr(), input.data_ptr(), - scale.data_ptr(), num_elems); - } else { - vllm::segmented_max_reduction_strided - <<>>( - scale.data_ptr(), input.data_ptr(), - hidden_size, in_row_stride, (int64_t)num_tokens); - - vllm::scaled_fp8_quant_kernel_strided_dynamic - <<>>( - out.data_ptr(), input.data_ptr(), - scale.data_ptr(), hidden_size, in_row_stride, - out_row_stride); - } + vllm::segmented_max_reduction_strided + <<>>( + scale.data_ptr(), input.data_ptr(), + hidden_size, in_row_stride, (int64_t)num_tokens); + + vllm::scaled_fp8_quant_kernel_strided_dynamic + <<>>( + out.data_ptr(), input.data_ptr(), + scale.data_ptr(), hidden_size, in_row_stride, + out_row_stride); }); }); } @@ -316,8 +222,6 @@ void dynamic_per_token_scaled_fp8_quant( const int64_t in_row_stride = input.stride(-2); const int64_t out_row_stride = out.stride(-2); - const bool is_contig_rows = - (in_row_stride == hidden_size) && (out_row_stride == hidden_size); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -327,23 +231,12 @@ void dynamic_per_token_scaled_fp8_quant( 
VLLM_DISPATCH_FP8_TYPES( out.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] { - if (is_contig_rows) { - vllm::dynamic_per_token_scaled_fp8_quant_kernel - <<>>( - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() - : nullptr, - hidden_size); - } else { - vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided< - scalar_t, fp8_t><<>>( - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() - : nullptr, - hidden_size, in_row_stride, out_row_stride); - } + vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided< + scalar_t, fp8_t><<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), + scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, + hidden_size, in_row_stride, out_row_stride); }); }); } diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh index 973d0bfc1648..1aad6330c44b 100644 --- a/csrc/quantization/fp8/common.cuh +++ b/csrc/quantization/fp8/common.cuh @@ -55,45 +55,4 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val, #endif } -// Compute the absolute maximum m of the input tensor and store -// m / float8_e4m3::max() in *scale. Each thread block performs a -// reduction tree and the memory in scale is atomically updated. -// So to get the right answer, *scale needs to be initialized to -// a value <= 0.0 and we need to wait for all thread blocks to -// finish before consuming *scale. -template -__global__ void segmented_max_reduction(float* __restrict__ scale, - const scalar_t* __restrict__ input, - int64_t num_elems) { - __shared__ float cache[256]; - int64_t i = blockDim.x * blockIdx.x + threadIdx.x; - - // First store maximum for all values processes by - // the current thread in cache[threadIdx.x] - scalar_t tmp = 0.0; - while (i < num_elems) { - float x = static_cast(input[i]); - tmp = fmaxf(tmp, fabsf(x)); - i += blockDim.x * gridDim.x; - } - cache[threadIdx.x] = tmp; - - __syncthreads(); - - // Now perform parallel reduction within the thread block - int ib = blockDim.x / 2; - while (ib != 0) { - if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) { - cache[threadIdx.x] = cache[threadIdx.x + ib]; - } - __syncthreads(); - ib /= 2; - } - // Finally, since cache[0] contains the maximum for this thread block, - // atomically write the max to the target location - if (threadIdx.x == 0) { - atomicMaxFloat(scale, cache[0] / quant_type_max_v); - } -} - } // namespace vllm From 51e96885e46fbd5da6bd7073d95678b063a62e8b Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Wed, 30 Jul 2025 13:55:22 -0400 Subject: [PATCH 04/11] empty instead of zeros Signed-off-by: yewentao256 --- vllm/_custom_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 91a99821e83f..e6f69e2344ef 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1281,7 +1281,7 @@ def scaled_fp8_quant( torch.ops._C.dynamic_per_token_scaled_fp8_quant( output, input, scale, scale_ub) else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) + scale = torch.empty(1, device=input.device, dtype=torch.float32) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: assert scale.numel() == 1, f"{scale.shape}" From 040706fffd57c6e34010470d6f4ede6c089d06d0 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Wed, 30 Jul 2025 14:19:10 -0400 Subject: [PATCH 05/11] fix segmented_max_reduction_strided issue Signed-off-by: yewentao256 --- 
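Notes: the strided per-tensor reduction now assigns one thread block per token; each thread strides across its row, the per-thread maxima are reduced through shared memory, and a single atomicMaxFloat per block folds the row maximum into the global scale. For reference, a minimal PyTorch sketch of the per-tensor dynamic scale this reduction is meant to reproduce (the helper name and the e4m3 default below are illustrative assumptions, not kernel API):

    import torch

    def ref_dynamic_per_tensor_scale(x: torch.Tensor,
                                     fp8_dtype=torch.float8_e4m3fn) -> torch.Tensor:
        # abs/amax follow the logical (possibly non-contiguous) view,
        # which is what the kernel reads through in_row_stride.
        amax = x.abs().amax().float()
        return amax / torch.finfo(fp8_dtype).max
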
csrc/quantization/fp8/common.cu | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index e9833955ca98..c9f69a946763 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -38,24 +38,26 @@ __global__ void segmented_max_reduction_strided( int hidden_size, int64_t in_row_stride, int64_t num_tokens) { __shared__ float cache[256]; int tid = threadIdx.x; + int64_t token_idx = blockIdx.x; - // Each thread processes multiple rows in a round-robin fashion. - float local_max = 0.0f; - for (int64_t token = blockIdx.x * blockDim.x + tid; token < num_tokens; - token += blockDim.x * gridDim.x) { - const scalar_t* row_ptr = input + token * in_row_stride; - // Traverse the row -#pragma unroll 4 - for (int e = 0; e < hidden_size; ++e) { - float v = fabsf(static_cast(row_ptr[e])); - local_max = fmaxf(local_max, v); - } + // one block per token. Guard in case gridDim.x > num_tokens. + if (token_idx >= num_tokens) { + return; + } + + const scalar_t* row_ptr = input + token_idx * in_row_stride; + + // each thread scans elements of the row in a strided fashion. + float thread_max = 0.0f; + for (int e = tid; e < hidden_size; e += blockDim.x) { + float v = fabsf(static_cast(row_ptr[e])); + thread_max = fmaxf(thread_max, v); } - cache[tid] = local_max; + cache[tid] = thread_max; __syncthreads(); - // Reduction inside block + // prallel reduction to find row max. for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) { if (tid < offset) { cache[tid] = fmaxf(cache[tid], cache[tid + offset]); @@ -63,6 +65,7 @@ __global__ void segmented_max_reduction_strided( __syncthreads(); } + // thread 0 updates global scale (per-tensor) atomically. 
if (tid == 0) { atomicMaxFloat(scale, cache[0] / quant_type_max_v); } From 219b9bb663ed7b25c933e253ebfd9583ff2aab0e Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Wed, 30 Jul 2025 14:37:37 -0400 Subject: [PATCH 06/11] fix int64 overflow Signed-off-by: yewentao256 --- csrc/quantization/fp8/common.cu | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index c9f69a946763..8f076171fbf1 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -99,8 +99,11 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided( const int token_idx = blockIdx.x; const int tid = threadIdx.x; - const scalar_t* token_in = input + token_idx * in_row_stride; - fp8_type* token_out = out + token_idx * out_row_stride; + // Use int64 to avoid overflowing an int32 when calculating this offset + int64_t in_offset = static_cast(token_idx) * in_row_stride; + int64_t out_offset = static_cast(token_idx) * out_row_stride; + const scalar_t* token_in = input + in_offset; + fp8_type* token_out = out + out_offset; // 1) per-token absmax float absmax_val = 0.f; From 698e6341b4a1721717481910e659953d799861cb Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Wed, 30 Jul 2025 16:40:43 -0400 Subject: [PATCH 07/11] update through comments Signed-off-by: yewentao256 --- csrc/quantization/fp8/common.cu | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 8f076171fbf1..df82d958efc6 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -37,7 +37,7 @@ __global__ void segmented_max_reduction_strided( float* __restrict__ scale, const scalar_t* __restrict__ input, int hidden_size, int64_t in_row_stride, int64_t num_tokens) { __shared__ float cache[256]; - int tid = threadIdx.x; + const int tid = threadIdx.x; int64_t token_idx = blockIdx.x; // one block per token. Guard in case gridDim.x > num_tokens. @@ -57,7 +57,7 @@ __global__ void segmented_max_reduction_strided( cache[tid] = thread_max; __syncthreads(); - // prallel reduction to find row max. + // parallel reduction to find row max. 
for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) { if (tid < offset) { cache[tid] = fmaxf(cache[tid], cache[tid + offset]); @@ -114,7 +114,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp; - float block_max = BlockReduce(tmp).Reduce(absmax_val, cub::Max{}, blockDim.x); + const float block_max = + BlockReduce(tmp).Reduce(absmax_val, cub::Max{}, blockDim.x); __shared__ float token_scale; if (tid == 0) { @@ -200,7 +201,8 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] vllm::segmented_max_reduction_strided <<>>( scale.data_ptr(), input.data_ptr(), - hidden_size, in_row_stride, (int64_t)num_tokens); + hidden_size, in_row_stride, + static_cast(num_tokens)); vllm::scaled_fp8_quant_kernel_strided_dynamic <<>>( From 53e0f2c719b92271739d077859fbe889d2614296 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Wed, 30 Jul 2025 16:44:19 -0400 Subject: [PATCH 08/11] use int64_t Signed-off-by: yewentao256 --- csrc/quantization/fp8/common.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index df82d958efc6..d6b3419f9a77 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -16,7 +16,7 @@ __global__ void scaled_fp8_quant_kernel_strided( fp8_type* __restrict__ out, const scalar_t* __restrict__ input, const float* __restrict__ scale, int hidden_size, int64_t in_row_stride, int64_t out_row_stride) { - const int token_idx = blockIdx.x; // one token per block + const int64_t token_idx = blockIdx.x; // one token per block const int tid = threadIdx.x; const scalar_t* token_in = input + token_idx * in_row_stride; @@ -76,7 +76,7 @@ __global__ void scaled_fp8_quant_kernel_strided_dynamic( fp8_type* __restrict__ out, const scalar_t* __restrict__ input, const float* __restrict__ scale, int hidden_size, int64_t in_row_stride, int64_t out_row_stride) { - const int token_idx = blockIdx.x; + const int64_t token_idx = blockIdx.x; const int tid = threadIdx.x; const scalar_t* token_in = input + token_idx * in_row_stride; @@ -96,7 +96,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided( fp8_type* __restrict__ out, float* __restrict__ scale, const scalar_t* __restrict__ input, const float* __restrict__ scale_ub, int hidden_size, int64_t in_row_stride, int64_t out_row_stride) { - const int token_idx = blockIdx.x; + const int64_t token_idx = blockIdx.x; const int tid = threadIdx.x; // Use int64 to avoid overflowing an int32 when calculating this offset From 13be9a26d6887ee9b4616ed527c4928e874f2f10 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Wed, 30 Jul 2025 16:49:44 -0400 Subject: [PATCH 09/11] update 0 Signed-off-by: yewentao256 --- csrc/quantization/fp8/common.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index d6b3419f9a77..5fe5dd04bd89 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -2,6 +2,7 @@ #include "dispatch_utils.h" #include "../vectorization_utils.cuh" #include +#include #ifndef USE_ROCM #include @@ -192,7 +193,8 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // scale tensor should be initialised to <=0 before reduction - scale.fill_(0.0f); + AT_CUDA_CHECK( + cudaMemsetAsync(scale.data_ptr(), 0, sizeof(float), stream)); 
VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] { From c56ac2bf48ee85525c471a3756034e1b41bcc1a4 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Wed, 30 Jul 2025 17:29:39 -0400 Subject: [PATCH 10/11] add back comments Signed-off-by: yewentao256 --- csrc/quantization/fused_kernels/layernorm_utils.cuh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index 3fb80afe806f..3f188872d80d 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -287,6 +287,8 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, const int VEC_SIZE = 4; int32_t const num_vec_elems = hidden_size >> 2; +// TODO(luka/varun) extract into type-agnostic vectorized quant function to +// replace scaled_fp8_conversion_vec #pragma unroll 4 for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { vec4_t const in = vec_input[i]; From 7dd15865792811f01f8c43d39c9846b8152c4de3 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 1 Aug 2025 17:10:00 -0400 Subject: [PATCH 11/11] add unit test Signed-off-by: yewentao256 --- tests/quantization/test_fp8.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index e5ab7b3dd3cf..0b37c83c92c2 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -194,3 +194,36 @@ def per_tensor_dequantize(tensor, inv_scale, dtype): ref_y, per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale, dtype)) + + # non-contiguous input with padding + m, n, padded_stride = 975, 512, 576 + padded_tensor = (torch.randn(size=(m, padded_stride), device="cuda") * + 13).to(dtype) + x_nc = padded_tensor[:, :n] # shape (m, n) with stride (padded_stride, 1) + + assert not x_nc.is_contiguous() + assert x_nc.stride(0) == padded_stride + + # dynamic quantization + ref_y_nc, inv_scale_nc = ops.scaled_fp8_quant(x_nc, None) + ref_y_nc = per_tensor_dequantize(ref_y_nc, inv_scale_nc, dtype) + + # reference dynamic quantization + y_nc = quantize_ref(x_nc, inv_scale_nc) + torch.testing.assert_close( + ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype)) + + # static quantization + y_nc, _ = ops.scaled_fp8_quant(x_nc, inv_scale_nc) + torch.testing.assert_close( + ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype)) + + # padding after non-contiguous input quantization + y_nc_pad, _ = ops.scaled_fp8_quant(x_nc, + inv_scale_nc, + num_token_padding=m + 10) + assert y_nc_pad.shape[0] == m + 10 + torch.testing.assert_close( + ref_y_nc, + per_tensor_dequantize(torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]), + inv_scale_nc, dtype))
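
Usage note: with this series, ops.scaled_fp8_quant accepts row-strided inputs directly (only the last dimension must be contiguous), so the previous input.contiguous() copies in vllm/_custom_ops.py are gone. A minimal sketch mirroring the new unit test; the ops import alias follows the test file, and the bfloat16 dtype and shapes are illustrative choices, requiring a CUDA build of the _C extension:

    import torch
    from vllm import _custom_ops as ops

    m, n, padded_stride = 16, 512, 576
    padded = torch.randn(m, padded_stride, device="cuda", dtype=torch.bfloat16)
    x = padded[:, :n]            # shape (m, n), row stride 576, non-contiguous
    assert not x.is_contiguous()

    # dynamic per-tensor quantization on the strided view (no copy)
    y, scale = ops.scaled_fp8_quant(x, None)

    # static quantization reusing that scale
    y_static, _ = ops.scaled_fp8_quant(x, scale)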