diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu
index 0e1eab66f0b9..5fe5dd04bd89 100644
--- a/csrc/quantization/fp8/common.cu
+++ b/csrc/quantization/fp8/common.cu
@@ -1,7 +1,8 @@
 #include "common.cuh"
 #include "dispatch_utils.h"
-
+#include "../vectorization_utils.cuh"
 #include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/Exceptions.h>
 
 #ifndef USE_ROCM
   #include <cub/cub.cuh>
@@ -12,74 +13,127 @@ namespace vllm {
 
 template <typename scalar_t, typename fp8_type>
-__global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out,
-                                        const scalar_t* __restrict__ input,
-                                        const float* __restrict__ scale,
-                                        int64_t num_elems) {
-  int tid = blockDim.x * blockIdx.x + threadIdx.x;
-
-  // Invert the scale so that we can use multiplications to avoid expensive
-  // division.
-  const float inverted_scale = 1.0f / (*scale);
-  scaled_fp8_conversion_vec<scalar_t, true>(
-      out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
+__global__ void scaled_fp8_quant_kernel_strided(
+    fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
+    const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
+    int64_t out_row_stride) {
+  const int64_t token_idx = blockIdx.x;  // one token per block
+  const int tid = threadIdx.x;
+
+  const scalar_t* token_in = input + token_idx * in_row_stride;
+  fp8_type* token_out = out + token_idx * out_row_stride;
+
+  const float inv_scale = 1.0f / (*scale);
+
+  vectorize_with_alignment<16>(
+      token_in, token_out, hidden_size, tid, blockDim.x,
+      [=] __device__(fp8_type & dst, const scalar_t& src) {
+        dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
+                                                    inv_scale);
+      });
 }
 
 template <typename scalar_t, typename fp8_type>
-__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
-    fp8_type* __restrict__ out, float* __restrict__ scale,
-    scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
-    const int hidden_size) {
-  int const tid = threadIdx.x;
-  int const token_idx = blockIdx.x;
+__global__ void segmented_max_reduction_strided(
+    float* __restrict__ scale, const scalar_t* __restrict__ input,
+    int hidden_size, int64_t in_row_stride, int64_t num_tokens) {
+  __shared__ float cache[256];
+  const int tid = threadIdx.x;
+  int64_t token_idx = blockIdx.x;
+
+  // one block per token. Guard in case gridDim.x > num_tokens.
+  if (token_idx >= num_tokens) {
+    return;
+  }
 
-  // Use int64 to avoid overflowing an int32 when calculating this offset
-  int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
-  scalar_t const* __restrict__ token_input = &input[offset];
-  fp8_type* __restrict__ token_output = &out[offset];
-
-  // For vectorization, token_input and token_output pointers need to be
-  // aligned at 32-byte and 16-byte addresses respectively.
-  bool const can_vectorize = hidden_size % 16 == 0;
-
-  float absmax_val = 0.0f;
-  if (can_vectorize) {
-    absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
-  } else {
-    for (int i = tid; i < hidden_size; i += blockDim.x) {
-      float const x = static_cast<float>(token_input[i]);
-      absmax_val = fmaxf(absmax_val, fabsf(x));
+  const scalar_t* row_ptr = input + token_idx * in_row_stride;
+
+  // each thread scans elements of the row in a strided fashion.
+  float thread_max = 0.0f;
+  for (int e = tid; e < hidden_size; e += blockDim.x) {
+    float v = fabsf(static_cast<float>(row_ptr[e]));
+    thread_max = fmaxf(thread_max, v);
+  }
+
+  cache[tid] = thread_max;
+  __syncthreads();
+
+  // parallel reduction to find row max.
+  for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
+    if (tid < offset) {
+      cache[tid] = fmaxf(cache[tid], cache[tid + offset]);
     }
+    __syncthreads();
   }
+  // thread 0 updates global scale (per-tensor) atomically.
+  if (tid == 0) {
+    atomicMaxFloat(scale, cache[0] / quant_type_max_v<fp8_type>);
+  }
+}
+
+template <typename scalar_t, typename fp8_type>
+__global__ void scaled_fp8_quant_kernel_strided_dynamic(
+    fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
+    const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
+    int64_t out_row_stride) {
+  const int64_t token_idx = blockIdx.x;
+  const int tid = threadIdx.x;
+
+  const scalar_t* token_in = input + token_idx * in_row_stride;
+  fp8_type* token_out = out + token_idx * out_row_stride;
+
+  const float reciprocal_scale = 1.0f / (*scale);
+  vectorize_with_alignment<16>(
+      token_in, token_out, hidden_size, tid, blockDim.x,
+      [=] __device__(fp8_type & dst, const scalar_t& src) {
+        dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
+                                                    reciprocal_scale);
+      });
+}
+
+template <typename scalar_t, typename fp8_type>
+__global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
+    fp8_type* __restrict__ out, float* __restrict__ scale,
+    const scalar_t* __restrict__ input, const float* __restrict__ scale_ub,
+    int hidden_size, int64_t in_row_stride, int64_t out_row_stride) {
+  const int64_t token_idx = blockIdx.x;
+  const int tid = threadIdx.x;
+
+  // Use int64 to avoid overflowing an int32 when calculating this offset
+  int64_t in_offset = static_cast<int64_t>(token_idx) * in_row_stride;
+  int64_t out_offset = static_cast<int64_t>(token_idx) * out_row_stride;
+  const scalar_t* token_in = input + in_offset;
+  fp8_type* token_out = out + out_offset;
+
+  // 1) per-token absmax
+  float absmax_val = 0.f;
+  vectorize_read_with_alignment<16>(
+      token_in, hidden_size, tid, blockDim.x, [&] __device__(scalar_t v) {
+        absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(v)));
+      });
+
   using BlockReduce = cub::BlockReduce<float, 1024>;
-  __shared__ typename BlockReduce::TempStorage reduceStorage;
-  float const block_absmax_val_maybe =
-      BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
+  __shared__ typename BlockReduce::TempStorage tmp;
+  const float block_max =
+      BlockReduce(tmp).Reduce(absmax_val, cub::Max{}, blockDim.x);
+
   __shared__ float token_scale;
   if (tid == 0) {
-    if (scale_ub) {
-      token_scale = fminf(block_absmax_val_maybe, *scale_ub);
-    } else {
-      token_scale = block_absmax_val_maybe;
-    }
-    // token scale computation
+    token_scale = scale_ub ? fminf(block_max, *scale_ub) : block_max;
     token_scale = fmaxf(token_scale / quant_type_max_v<fp8_type>,
                         min_scaling_factor<fp8_type>::val());
     scale[token_idx] = token_scale;
   }
   __syncthreads();
 
-  // Note that we don't use inverted scales so we can match FBGemm impl.
-  if (can_vectorize) {
-    scaled_fp8_conversion_vec<scalar_t, false>(
-        token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
-  } else {
-    for (int i = tid; i < hidden_size; i += blockDim.x) {
-      token_output[i] = scaled_fp8_conversion<false, fp8_type>(
-          static_cast<float>(token_input[i]), token_scale);
-    }
-  }
+  // 2) quantize
+  vectorize_with_alignment<16>(
+      token_in, token_out, hidden_size, tid, blockDim.x,
+      [=] __device__(fp8_type & dst, const scalar_t& src) {
+        dst = scaled_fp8_conversion<false, fp8_type>(static_cast<float>(src),
+                                                     token_scale);
+      });
 }
 
 }  // namespace vllm
@@ -88,23 +142,31 @@ void static_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor const& scale)  // [1]
 {
-  TORCH_CHECK(input.is_contiguous());
-  TORCH_CHECK(out.is_contiguous());
-  int const block_size = 256;
-  int const num_tokens = input.numel() / input.size(-1);
-  int const num_elems = input.numel();
-  dim3 const grid(num_tokens);
-  dim3 const block(block_size);
+  TORCH_CHECK(input.stride(-1) == 1,
+              "last dimension of input must be contiguous");
+  TORCH_CHECK(out.stride(-1) == 1,
+              "last dimension of output must be contiguous");
+
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
+  const int block_size = 256;
+  dim3 grid(num_tokens);
+  dim3 block(block_size);
+
+  const int64_t in_row_stride = input.stride(-2);
+  const int64_t out_row_stride = out.stride(-2);
+
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
         VLLM_DISPATCH_FP8_TYPES(
             out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
-              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+              vllm::scaled_fp8_quant_kernel_strided<scalar_t, fp8_t>
                   <<<grid, block, 0, stream>>>(
                       out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
-                      scale.data_ptr<float>(), num_elems);
+                      scale.data_ptr<float>(), hidden_size, in_row_stride,
+                      out_row_stride);
             });
       });
 }
@@ -113,27 +175,42 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                               torch::Tensor const& input,  // [..., d]
                               torch::Tensor& scale)        // [1]
 {
-  TORCH_CHECK(input.is_contiguous());
-  TORCH_CHECK(out.is_contiguous());
-  int const block_size = 256;
-  int const num_tokens = input.numel() / input.size(-1);
-  int const num_elems = input.numel();
-  dim3 const grid(num_tokens);
-  dim3 const block(block_size);
+  TORCH_CHECK(input.stride(-1) == 1,
+              "last dimension of input must be contiguous");
+  TORCH_CHECK(out.stride(-1) == 1,
+              "last dimension of output must be contiguous");
+
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
+  const int block_size = 256;
+  dim3 grid(num_tokens);
+  dim3 block(block_size);
+
+  const int64_t in_row_stride = input.stride(-2);
+  const int64_t out_row_stride = out.stride(-2);
+
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // scale tensor should be initialised to <=0 before reduction
+  AT_CUDA_CHECK(
+      cudaMemsetAsync(scale.data_ptr<float>(), 0, sizeof(float), stream));
+
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
         VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
-              vllm::segmented_max_reduction<scalar_t, fp8_t>
-                  <<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
-                                               input.data_ptr<scalar_t>(),
-                                               num_elems);
-              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+              vllm::segmented_max_reduction_strided<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      scale.data_ptr<float>(), input.data_ptr<scalar_t>(),
+                      hidden_size, in_row_stride,
+                      static_cast<int64_t>(num_tokens));
+
+              vllm::scaled_fp8_quant_kernel_strided_dynamic<scalar_t, fp8_t>
                   <<<grid, block, 0, stream>>>(
                       out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
-                      scale.data_ptr<float>(), num_elems);
+                      scale.data_ptr<float>(), hidden_size, in_row_stride,
+                      out_row_stride);
             });
       });
 }
@@ -142,14 +219,19 @@ void dynamic_per_token_scaled_fp8_quant(
     torch::Tensor& out,          // [..., d]
     torch::Tensor const& input,  // [..., d]
     torch::Tensor& scales, std::optional<torch::Tensor> const& scale_ub) {
-  TORCH_CHECK(input.is_contiguous());
-  TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(input.stride(-1) == 1,
+              "last dimension of input must be contiguous");
+  TORCH_CHECK(out.stride(-1) == 1,
+              "last dimension of output must be contiguous");
 
-  int const hidden_size = input.size(-1);
-  int const num_tokens = input.numel() / hidden_size;
-  int const block_size = 256;
-  dim3 const grid(num_tokens);
-  dim3 const block(std::min(hidden_size, block_size));
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
+  const int block_size = 256;
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, block_size));
+
+  const int64_t in_row_stride = input.stride(-2);
+  const int64_t out_row_stride = out.stride(-2);
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -159,13 +241,12 @@ void dynamic_per_token_scaled_fp8_quant(
         VLLM_DISPATCH_FP8_TYPES(
             out.scalar_type(),
             "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
-              vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
-                  <<<grid, block, 0, stream>>>(
-                      out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
-                      input.data_ptr<scalar_t>(),
-                      scale_ub.has_value() ? scale_ub->data_ptr<float>()
-                                           : nullptr,
-                      hidden_size);
+              vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided<
+                  scalar_t, fp8_t><<<grid, block, 0, stream>>>(
+                  out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
+                  input.data_ptr<scalar_t>(),
+                  scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
+                  hidden_size, in_row_stride, out_row_stride);
             });
       });
 }
diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh
index d36f94a8f10d..1aad6330c44b 100644
--- a/csrc/quantization/fp8/common.cuh
+++ b/csrc/quantization/fp8/common.cuh
@@ -55,111 +55,4 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
 #endif
 }
 
-// Compute the absolute maximum m of the input tensor and store
-// m / float8_e4m3::max() in *scale. Each thread block performs a
-// reduction tree and the memory in scale is atomically updated.
-// So to get the right answer, *scale needs to be initialized to
-// a value <= 0.0 and we need to wait for all thread blocks to
-// finish before consuming *scale.
-template <typename scalar_t, typename fp8_type>
-__global__ void segmented_max_reduction(float* __restrict__ scale,
-                                        const scalar_t* __restrict__ input,
-                                        int64_t num_elems) {
-  __shared__ float cache[256];
-  int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
-
-  // First store maximum for all values processes by
-  // the current thread in cache[threadIdx.x]
-  scalar_t tmp = 0.0;
-  while (i < num_elems) {
-    float x = static_cast<float>(input[i]);
-    tmp = fmaxf(tmp, fabsf(x));
-    i += blockDim.x * gridDim.x;
-  }
-  cache[threadIdx.x] = tmp;
-
-  __syncthreads();
-
-  // Now perform parallel reduction within the thread block
-  int ib = blockDim.x / 2;
-  while (ib != 0) {
-    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
-      cache[threadIdx.x] = cache[threadIdx.x + ib];
-    }
-    __syncthreads();
-    ib /= 2;
-  }
-  // Finally, since cache[0] contains the maximum for this thread block,
-  // atomically write the max to the target location
-  if (threadIdx.x == 0) {
-    atomicMaxFloat(scale, cache[0] / quant_type_max_v<fp8_type>);
-  }
-}
-
-template <typename scalar_t>
-__device__ float thread_max_vec(scalar_t const* __restrict__ input,
-                                int64_t const num_elems, int const tid,
-                                int const step) {
-  constexpr size_t VEC_SIZE = 16;
-  using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
-  // Vectorized input/output to better utilize memory bandwidth.
-  auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
-
-  // num_elems / VEC_SIZE (which is 16)
-  int64_t const num_vec_elems = num_elems >> 4;
-  float absmax_val = 0.0f;
-
-#pragma unroll
-  for (int64_t i = tid; i < num_vec_elems; i += step) {
-    scalarxN_t in_vec = vectorized_in[i];
-#pragma unroll
-    for (int j = 0; j < VEC_SIZE; ++j) {
-      absmax_val = fmaxf(absmax_val, fabsf(in_vec.val[j]));
-    }
-  }
-
-  // Handle the remaining elements if num_elems is not divisible by VEC_SIZE
-  for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
-    absmax_val = fmaxf(absmax_val, fabsf(input[i]));
-  }
-
-  return absmax_val;
-}
-
-template <typename scalar_t, bool is_scale_inverted, typename fp8_type>
-__device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out,
-                                          scalar_t const* __restrict__ input,
-                                          float const scale,
-                                          int64_t const num_elems,
-                                          int const tid, int const step) {
-  constexpr size_t VEC_SIZE = 16;
-  using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
-  using float8xN_t = q8_n_t<fp8_type, VEC_SIZE>;
-  // Vectorized input/output to better utilize memory bandwidth.
-  auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
-  auto* vectorized_out = reinterpret_cast<float8xN_t*>(out);
-
-  // num_elems / VEC_SIZE (which is 16)
-  int64_t const num_vec_elems = num_elems >> 4;
-
-#pragma unroll
-  for (int64_t i = tid; i < num_vec_elems; i += step) {
-    scalarxN_t in_vec = vectorized_in[i];
-    float8xN_t out_vec;
-
-#pragma unroll
-    for (int j = 0; j < VEC_SIZE; ++j) {
-      out_vec.val[j] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
-          static_cast<float>(in_vec.val[j]), scale);
-    }
-    vectorized_out[i] = out_vec;
-  }
-
-  // Handle the remaining elements if num_elems is not divisible by VEC_SIZE
-  for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
-    out[i] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
-        static_cast<float>(input[i]), scale);
-  }
-}
-
 }  // namespace vllm
diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py
index e5ab7b3dd3cf..0b37c83c92c2 100644
--- a/tests/quantization/test_fp8.py
+++ b/tests/quantization/test_fp8.py
@@ -194,3 +194,36 @@ def per_tensor_dequantize(tensor, inv_scale, dtype):
         ref_y,
         per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale,
                               dtype))
+
+    # non-contiguous input with padding
+    m, n, padded_stride = 975, 512, 576
+    padded_tensor = (torch.randn(size=(m, padded_stride), device="cuda") *
+                     13).to(dtype)
+    x_nc = padded_tensor[:, :n]  # shape (m, n) with stride (padded_stride, 1)
+
+    assert not x_nc.is_contiguous()
+    assert x_nc.stride(0) == padded_stride
+
+    # dynamic quantization
+    ref_y_nc, inv_scale_nc = ops.scaled_fp8_quant(x_nc, None)
+    ref_y_nc = per_tensor_dequantize(ref_y_nc, inv_scale_nc, dtype)
+
+    # reference dynamic quantization
+    y_nc = quantize_ref(x_nc, inv_scale_nc)
+    torch.testing.assert_close(
+        ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype))
+
+    # static quantization
+    y_nc, _ = ops.scaled_fp8_quant(x_nc, inv_scale_nc)
+    torch.testing.assert_close(
+        ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype))
+
+    # padding after non-contiguous input quantization
+    y_nc_pad, _ = ops.scaled_fp8_quant(x_nc,
+                                       inv_scale_nc,
+                                       num_token_padding=m + 10)
+    assert y_nc_pad.shape[0] == m + 10
+    torch.testing.assert_close(
+        ref_y_nc,
+        per_tensor_dequantize(torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]),
+                              inv_scale_nc, dtype))
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 35345b1be01c..e6f69e2344ef 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1279,14 +1279,13 @@ def scaled_fp8_quant(
                                 device=input.device,
                                 dtype=torch.float32)
             torch.ops._C.dynamic_per_token_scaled_fp8_quant(
-                output, input.contiguous(), scale, scale_ub)
+                output, input, scale, scale_ub)
         else:
-            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
-            torch.ops._C.dynamic_scaled_fp8_quant(output, input.contiguous(),
-                                                  scale)
+            scale = torch.empty(1, device=input.device, dtype=torch.float32)
+            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
     else:
         assert scale.numel() == 1, f"{scale.shape}"
-        torch.ops._C.static_scaled_fp8_quant(output, input.contiguous(), scale)
+        torch.ops._C.static_scaled_fp8_quant(output, input, scale)
     return output, scale
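For reference, a minimal usage sketch of what this change enables, mirroring the new test case above (not part of the patch; assumes a CUDA device and an FP8-capable vLLM build): the quantization ops now accept a strided, row-major view directly, so callers no longer pay for an implicit .contiguous() copy.

    # sketch only: quantize a non-contiguous view without copying it first
    import torch
    from vllm import _custom_ops as ops

    m, n, padded_stride = 975, 512, 576
    padded = torch.randn(m, padded_stride, device="cuda", dtype=torch.float16)
    x = padded[:, :n]            # non-contiguous view: shape (m, n), stride (576, 1)
    assert not x.is_contiguous()

    # dynamic per-tensor FP8 quantization straight on the strided view
    y_fp8, scale = ops.scaled_fp8_quant(x, None)
    print(y_fp8.shape, y_fp8.dtype, scale.item())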