diff --git a/CMakeLists.txt b/CMakeLists.txt
index 062aa0b..ee74f27 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
 project(driss_torch LANGUAGES CXX CUDA)
 
 # Set the C++ standard for all targets
-set(CMAKE_CXX_STANDARD 20)
+set(CMAKE_CXX_STANDARD 20) # This might be unsafe since PyTorch uses C++17
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
 # Enable better clangd support
diff --git a/benchmarks/benchmark_saturated_casting.py b/benchmarks/benchmark_saturated_casting.py
index d66dd45..9fe6526 100644
--- a/benchmarks/benchmark_saturated_casting.py
+++ b/benchmarks/benchmark_saturated_casting.py
@@ -6,6 +6,7 @@
 import torch
 
 from driss_torch import saturated_cast
+from float8_experimental.float8_utils import amax_to_scale
 from jsonargparse import CLI
 from tabulate import tabulate
@@ -22,16 +23,17 @@
 def eager_scaled_quant(
     a: torch.Tensor,
-    scale: torch.Tensor,
+    amax: torch.Tensor,
     fp8_dtype: torch.dtype,
 ):
     """Quantize tensor to fp8 using a delayed scale and calculate abs_max
 
     Args:
         a: Input tensor to quantize
-        scale: Scale to apply to input tensor, calculated from previous abs_max
+        amax: The amax (max absolute value) of the input tensor
         fp8_dtype: FP8 datatype to quantize to
     """
+    scale = amax_to_scale(amax, fp8_dtype, a.dtype)
     out = a * scale
     out = torch.where(out > torch.finfo(fp8_dtype).max, torch.finfo(fp8_dtype).max, out)
     out = torch.where(out < -1 * torch.finfo(fp8_dtype).max, -1 * torch.finfo(fp8_dtype).max, out)
@@ -97,21 +99,17 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
         config.num_rows, config.num_cols, dtype=config.high_precision_dtype, device=device
     )
     cuda_hp_tensor = high_precision_tensor.clone()
-    cuda_scale = torch.ones(1, dtype=torch.bfloat16, device=device)
+    cuda_amax = torch.abs(high_precision_tensor).max().to(torch.float32)
 
-    eager_abs_max = torch.abs(high_precision_tensor).max().to(torch.float32)
-
-    scale = torch.finfo(config.low_precision_dtype).max / eager_abs_max
-    scale = scale.to(torch.float32)
-    scale = torch.ones(1, dtype=torch.float32, device=device)
+    eager_amax = torch.abs(high_precision_tensor).max().to(torch.float32)
 
     # Correctness check:
-    cuda_out = saturated_cast(cuda_hp_tensor, config.low_precision_dtype, cuda_scale)
+    cuda_out, cuda_scale = saturated_cast(cuda_hp_tensor, eager_amax, config.low_precision_dtype)
     cuda_out_hp = cuda_out.to(config.high_precision_dtype)
-    eager_out = eager_scaled_quant(high_precision_tensor, scale, config.low_precision_dtype).to(
-        config.high_precision_dtype
-    )
+    eager_out = eager_scaled_quant(
+        high_precision_tensor, eager_amax, config.low_precision_dtype
+    ).to(config.high_precision_dtype)
     eager_out_hp = eager_out.to(config.high_precision_dtype)
 
     torch.testing.assert_close(cuda_out_hp, eager_out_hp, rtol=1e-3, atol=1e-3)
@@ -119,20 +117,20 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     cuda_time = benchmark_torch_function_in_microseconds(
         saturated_cast,
         cuda_hp_tensor,
+        eager_amax,
        config.low_precision_dtype,
-        cuda_scale,
     )
     pytorch_time = benchmark_torch_function_in_microseconds(
         eager_scaled_quant,
         high_precision_tensor,
-        scale,
+        eager_amax,
         config.low_precision_dtype,
     )
     compiled_pytorch_fn = torch.compile(eager_scaled_quant, fullgraph=True)
     compiled_pytorch_time = benchmark_torch_function_in_microseconds(
         compiled_pytorch_fn,
         high_precision_tensor,
-        scale,
+        eager_amax,
         config.low_precision_dtype,
     )
     return ExperimentResult(
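Note: the eager reference path above reduces to the standalone sketch below. `amax_to_scale_ref` is a hypothetical stand-in for `float8_experimental.float8_utils.amax_to_scale`, written from the formula the fused kernel uses (`scale = dtype_max / max(amax, 1e-12)`); the library's exact implementation may differ.

```python
import torch

def amax_to_scale_ref(amax: torch.Tensor, fp8_dtype: torch.dtype) -> torch.Tensor:
    # Hypothetical stand-in: scale = dtype_max / max(amax, eps).
    return torch.finfo(fp8_dtype).max / torch.clamp(amax, min=1e-12)

x = torch.randn(64, 64, dtype=torch.bfloat16, device="cuda")
amax = x.abs().max().to(torch.float32)
scale = amax_to_scale_ref(amax, torch.float8_e4m3fn)

# Saturate before casting, mirroring eager_scaled_quant's torch.where pair.
fp8_max = torch.finfo(torch.float8_e4m3fn).max
y = torch.clamp(x.float() * scale, min=-fp8_max, max=fp8_max).to(torch.float8_e4m3fn)
```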
diff --git a/driss_torch/__init__.py b/driss_torch/__init__.py
index f7f28b9..52b9853 100644
--- a/driss_torch/__init__.py
+++ b/driss_torch/__init__.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Optional
+from typing import Tuple
 
 import torch
 
@@ -17,31 +17,22 @@ def list_ops():
     return ops.__dir__()
 
 
-def add_one(x: torch.Tensor) -> torch.Tensor:
-    """Add one to a tensor.
-    This is a dummy test op to demonstrate how to add custom ops to PyTorch.
-    Args:
-        x: The input tensor.
-    Returns:
-        The output tensor.
-    """
-    return ops.add_one(x)
-
-
 def saturated_cast(
     x: torch.Tensor,
-    scale: torch.Tensor,
+    amax: torch.Tensor,
     out_dtype: torch.dtype,
     transpose: bool = False,
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """This op takes in a tensor and returns the fp8 saturated-cast version of it.
     Args:
         x: The input tensor.
         out_dtype: The output data type, must be a float8 dtype.
-        scale: An optional on device tensor, this is expected to be a singleton tensor whose value will be multiplied
-            x before casting
+        amax: An on-device singleton tensor holding max(abs(x)). We use it to
+            compute the scale before casting, via the formula
+            `scale = dtype_max / max(amax, 1e-12)`.
         transpose: If True, transpose the input tensor during casting.
     Returns:
-        The output tensor.
+        The output tensor and the on-device scale tensor.
     """
-    return ops.saturated_cast(x, scale, out_dtype, transpose)
+    assert not transpose, "Transpose is not supported yet"
+    return ops.saturated_cast(x, amax, out_dtype, transpose)
diff --git a/driss_torch/abstract_impls.py b/driss_torch/abstract_impls.py
index f96abf1..0393f5e 100644
--- a/driss_torch/abstract_impls.py
+++ b/driss_torch/abstract_impls.py
@@ -1,14 +1,14 @@
 import torch
 from torch.library import impl_abstract
 
-print(__name__)
-
 
 @impl_abstract("DrissTorch::saturated_cast")
 def saturated_cast_meta(
     x: torch.Tensor,
-    scale: torch.Tensor,
+    amax: torch.Tensor,
     out_dtype: torch.dtype,
     transpose: bool = False,
 ):
-    return torch.empty_like(x, dtype=out_dtype)
+    return torch.empty_like(x, dtype=out_dtype), torch.empty(
+        (), device=x.device, dtype=torch.float32
+    )
diff --git a/src/add.cu b/src/add.cu
deleted file mode 100644
index 8091ad1..0000000
--- a/src/add.cu
+++ /dev/null
@@ -1,40 +0,0 @@
-#include "add.h"
-#include "utils.h"
-#include 
-#include 
-#include 
-#include 
-#include 
-
-namespace driss_torch {
-
-using namespace at;
-
-namespace {
-template <typename T>
-__global__ void add_one_kernel(const T *const input, T *const output,
-                               const int64_t N) {
-  // Grid-strided loop
-  for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < N;
-       i += blockDim.x * gridDim.x) {
-    output[i] = input[i] + 1;
-  }
-}
-} // namespace
-
-Tensor add_one(const Tensor &input) {
-  auto output = torch::zeros_like(input);
-
-  AT_DISPATCH_ALL_TYPES(input.scalar_type(), "add_one_cuda", [&]() {
-    const auto block_size = 128;
-    const auto num_blocks =
-        std::min(65535L, ceil_div(input.numel(), block_size));
-    add_one_kernel<<<num_blocks, block_size>>>(
-        input.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), input.numel());
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-  });
-
-  return output;
-}
-
-} // namespace driss_torch
diff --git a/src/include/add.h b/src/include/add.h
deleted file mode 100644
index 6e17bbf..0000000
--- a/src/include/add.h
+++ /dev/null
@@ -1,7 +0,0 @@
-
-#pragma once
-#include 
-
-namespace driss_torch {
-at::Tensor add_one(const at::Tensor &input);
-} // namespace driss_torch
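The abstract impl above is what lets fake tensors and `torch.compile` trace the op without running CUDA. A quick sanity check, sketched under the assumption that the extension is importable (note `FakeTensorMode` lives in a private module and may move between releases):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

import driss_torch  # registers DrissTorch::saturated_cast and its abstract impl

with FakeTensorMode():
    x = torch.empty(128, 256, dtype=torch.bfloat16, device="cuda")
    amax = torch.empty((), dtype=torch.float32, device="cuda")
    out, scale = torch.ops.DrissTorch.saturated_cast(x, amax, torch.float8_e4m3fn, False)
    assert out.shape == x.shape and out.dtype == torch.float8_e4m3fn
    assert scale.shape == () and scale.dtype == torch.float32
```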
diff --git a/src/include/saturated_cast.h b/src/include/saturated_cast.h
index 4d4d780..35cd2dd 100644
--- a/src/include/saturated_cast.h
+++ b/src/include/saturated_cast.h
@@ -4,9 +4,6 @@
 
 namespace driss_torch {
-at::Tensor saturated_cast(const at::Tensor &input, const at::Tensor &attn_mask,
+std::tuple<at::Tensor, at::Tensor> saturated_cast(const at::Tensor &input, const at::Tensor &amax,
                           at::ScalarType dtype, bool transpose);
-at::Tensor saturated_cast_meta(const at::Tensor &input,
-                               const at::Tensor &attn_mask,
-                               at::ScalarType dtype, bool transpose);
 } // namespace driss_torch
diff --git a/src/include/utils.h b/src/include/utils.h
index d13456d..fd13d41 100644
--- a/src/include/utils.h
+++ b/src/include/utils.h
@@ -1,11 +1,18 @@
 #pragma once
 
+#include <utility>
 #include 
 #include 
 
-
 namespace driss_torch {
+
+template <typename... Args>
+__device__ void thread_zero_print(const char *fmt, Args &&...args) {
+  if (threadIdx.x == 0 && blockIdx.x == 0) {
+    printf(fmt, std::forward<Args>(args)...);
+  }
+}
+
 // error checking macro
 #define cudaCheckErrors(msg)                                                   \
   do {                                                                         \
diff --git a/src/register_ops.cpp b/src/register_ops.cpp
index b377bf9..b155006 100644
--- a/src/register_ops.cpp
+++ b/src/register_ops.cpp
@@ -2,14 +2,11 @@
 #include 
 
 // Custom op headers
-#include "add.h"
 #include "saturated_cast.h"
 
 TORCH_LIBRARY(DrissTorch, m) {
   m.impl_abstract_pystub("driss_torch.abstract_impls");
-  m.def("add_one(Tensor input) -> Tensor");
-  m.impl("add_one", c10::DispatchKey::CUDA, TORCH_FN(driss_torch::add_one));
   // Saturated cast func from bf16 to fp8 types
-  m.def("saturated_cast(Tensor input, Tensor scale, ScalarType dtype, bool transpose) -> Tensor");
+  m.def("saturated_cast(Tensor input, Tensor amax, ScalarType dtype, bool transpose) -> (Tensor, Tensor)");
   m.impl("saturated_cast", c10::DispatchKey::CUDA, TORCH_FN(driss_torch::saturated_cast));
 }
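With the `m.def` above, the op now returns a `(Tensor, Tensor)` tuple. The registered schema can be checked from Python (a sketch, assuming the extension is built and loaded; `_schema` is a private attribute):

```python
import torch
import driss_torch  # loads the shared library and its registrations

print(torch.ops.DrissTorch.saturated_cast.default._schema)
# Expected, per the m.def above:
# DrissTorch::saturated_cast(Tensor input, Tensor amax, ScalarType dtype, bool transpose) -> (Tensor, Tensor)
```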
diff --git a/src/saturated_cast.cu b/src/saturated_cast.cu
index e20621f..beef296 100644
--- a/src/saturated_cast.cu
+++ b/src/saturated_cast.cu
@@ -9,47 +9,66 @@
 #include 
 #include 
+#include <algorithm>
 #include 
 
 namespace driss_torch {
 using namespace at;
 
 namespace {
-
+__forceinline__ __device__ void set_scale(float *scale, float scaler) {
+  if (threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 &&
+      blockIdx.y == 0) {
+    *scale = scaler;
+  }
+}
 #define DISPATCH_KERNEL_SINGLE(T)                                             \
   saturated_cast_kernel_single<<<grid, block>>>(                              \
       static_cast<T *>(input.data_ptr()),                                     \
       static_cast<__nv_fp8_storage_t *>(output.data_ptr()), n_rows, n_cols,   \
-      out_dtype, static_cast<float *>(scale.data_ptr()))
+      out_dtype, static_cast<float *>(amax.data_ptr()),                       \
+      static_cast<float *>(scale.data_ptr()))
 
 #define DISPATCH_KERNEL_DOUBLE_COALESCED(T)                                   \
   saturated_cast_kernel_double_coalesced<<<grid, block>>>(                    \
       static_cast<T *>(input.data_ptr()),                                     \
       static_cast<__nv_fp8x2_storage_t *>(output.data_ptr()), n_rows, n_cols, \
-      out_dtype, static_cast<float *>(scale.data_ptr()))
+      out_dtype, static_cast<float *>(amax.data_ptr()),                       \
+      static_cast<float *>(scale.data_ptr()))
 
 #define DISPATCH_KERNEL_DOUBLE_COALESCED_FLAT(T)                              \
   saturated_cast_kernel_double_coalesced_flat                                 \
       <<<grid, block>>>(                                                      \
          static_cast<T *>(input.data_ptr()),                                  \
          static_cast<__nv_fp8x2_storage_t *>(output.data_ptr()),              \
-         packed_numel, out_dtype, static_cast<float *>(scale.data_ptr()))
+         packed_numel, out_dtype, static_cast<float *>(amax.data_ptr()),      \
+         static_cast<float *>(scale.data_ptr()))
+
+float __forceinline__ __device__
+get_dtype_max(__nv_fp8_interpretation_t out_dtype) {
+  return out_dtype == __nv_fp8_interpretation_t::__NV_E4M3 ? 448.0f : 57344.0f;
+}
 
 template <typename HPType>
-__global__ void saturated_cast_kernel_single(
-    HPType *input, __nv_fp8_storage_t *output, int n_rows, int n_cols,
-    __nv_fp8_interpretation_t out_dtype, float *scaler) {
+__global__ void
+saturated_cast_kernel_single(HPType const *input, __nv_fp8_storage_t *output,
+                             const int n_rows, const int n_cols,
+                             __nv_fp8_interpretation_t out_dtype,
+                             const float *amax, float *scale) {
   int row = blockIdx.y * blockDim.y + threadIdx.y;
   int col = blockIdx.x * blockDim.x + threadIdx.x;
   // Assume row major
+  const float dtype_max = get_dtype_max(out_dtype);
   const int global_index = row * n_cols + col;
   if (row < n_rows && col < n_cols) {
+    const float scaler = dtype_max / std::max((*amax), 1e-12f);
+    set_scale(scale, scaler);
     if constexpr (std::is_same_v<HPType, __nv_bfloat16>) {
-      const HPType scaled_input = __hmul(input[global_index], (*scaler));
+      const HPType scaled_input = __hmul(input[global_index], scaler);
       output[global_index] = __nv_cvt_bfloat16raw_to_fp8(
           scaled_input, __nv_saturation_t::__NV_SATFINITE, out_dtype);
     } else {
-      const HPType scaled_input = input[global_index] * (*scaler);
+      const HPType scaled_input = input[global_index] * scaler;
       output[global_index] = __nv_cvt_float_to_fp8(
           scaled_input, __nv_saturation_t::__NV_SATFINITE, out_dtype);
     }
@@ -60,10 +79,13 @@ template <typename PackedHPType, int coarse_factor>
 __global__ void saturated_cast_kernel_double_coalesced_flat(
     PackedHPType const *__restrict input,
     __nv_fp8x2_storage_t *__restrict output, const int numels,
-    __nv_fp8_interpretation_t out_dtype, float const *scaler) {
+    __nv_fp8_interpretation_t out_dtype, float const *amax, float *scale) {
   const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * coarse_factor;
   const int stride = 1;
-  const PackedHPType scale_2 = {(*scaler), (*scaler)};
+  const float dtype_max = get_dtype_max(out_dtype);
+  const float scaler = dtype_max / std::max((*amax), 1e-12f);
+  set_scale(scale, scaler);
+  const PackedHPType scale_2 = {scaler, scaler};
   PackedHPType scaled_inputs[coarse_factor];
 
 #pragma unroll
@@ -81,8 +103,8 @@ __global__ void saturated_cast_kernel_double_coalesced_flat(
       scaled_inputs[i] = __hmul2(scaled_inputs[i], scale_2);
     } else {
       // I can't find the right fmul2 for this??
-      scaled_inputs[i] = {scaled_inputs[i].x * (*scaler),
-                          scaled_inputs[i].y * (*scaler)};
+      scaled_inputs[i] = {scaled_inputs[i].x * scaler,
+                          scaled_inputs[i].y * scaler};
     }
   }
 }
@@ -107,12 +129,15 @@ template <typename PackedHPType, int coarse_factor>
 __global__ void saturated_cast_kernel_double_coalesced(
     PackedHPType const *__restrict input,
     __nv_fp8x2_storage_t *__restrict output, int n_rows, int n_cols,
-    __nv_fp8_interpretation_t out_dtype, float const *scaler) {
+    __nv_fp8_interpretation_t out_dtype, float const *amax, float *scale) {
   int row = blockIdx.y * blockDim.y + threadIdx.y;
   int col = (blockIdx.x * blockDim.x + threadIdx.x) * coarse_factor;
   const int row_stride = n_cols;
   const int col_stride = 1;
-  const PackedHPType scale_2 = {(*scaler), (*scaler)};
+  const float dtype_max = get_dtype_max(out_dtype);
+  const float scaler = dtype_max / std::max((*amax), 1e-12f);
+  set_scale(scale, scaler);
+  const PackedHPType scale_2 = {scaler, scaler};
   PackedHPType scaled_inputs[coarse_factor];
 
 #pragma unroll
@@ -130,8 +155,8 @@ __global__ void saturated_cast_kernel_double_coalesced(
       scaled_inputs[i] = __hmul2(scaled_inputs[i], scale_2);
     } else {
       // I can't find the right fmul2 for this??
-      scaled_inputs[i] = {scaled_inputs[i].x * (*scaler),
-                          scaled_inputs[i].y * (*scaler)};
+      scaled_inputs[i] = {scaled_inputs[i].x * scaler,
+                          scaled_inputs[i].y * scaler};
     }
   }
 }
@@ -167,7 +192,7 @@ enum KernelChoice { single, coalesced, coalesced_flat };
 
 void dispatch_best_kernel(const Tensor &input, const Tensor &output,
                           __nv_fp8_interpretation_t out_dtype,
-                          const Tensor &scale, bool transpose) {
+                          const Tensor &amax, Tensor &scale, bool transpose) {
   const int n_rows = input.size(0);
   const int n_cols = input.size(1);
   const int block_size_x = 32;
@@ -224,8 +249,9 @@
   }
 }
 } // namespace
 
-Tensor saturated_cast_meta(const Tensor &input, const Tensor &scale,
-                           ScalarType dtype, bool transpose) {
+std::tuple<Tensor, Tensor> saturated_cast(const Tensor &input,
+                                          const Tensor &amax, ScalarType dtype,
+                                          bool transpose) {
   TORCH_CHECK(dtype == at::kFloat8_e4m3fn || dtype == at::kFloat8_e5m2,
               "Output tensor must be of type Float8_e4m3fn or Float8_e5m2")
 
@@ -233,38 +259,28 @@ Tensor saturated_cast_meta(const Tensor &input, const Tensor &scale,
                   input.scalar_type() == at::kFloat,
               "Input tensor must be of type BFloat16 or Float, but got ",
               input.dtype());
-  TORCH_CHECK(scale.scalar_type() == at::kFloat,
-              "Scale tensor must be of type Float, but got ", scale.dtype())
-
-  auto output = torch::empty_like(input, input.options().dtype(dtype));
-  return output;
-}
+  TORCH_CHECK(input.dim() == 2, "Input tensor must be 2D, but got ",
+              input.dim());
 
-Tensor saturated_cast(const Tensor &input, const Tensor &scale,
-                      ScalarType dtype, bool transpose) {
-  TORCH_CHECK(dtype == at::kFloat8_e4m3fn || dtype == at::kFloat8_e5m2,
-              "Output tensor must be of type Float8_e4m3fn or Float8_e5m2")
-
-  TORCH_CHECK(input.scalar_type() == at::kBFloat16 ||
-                  input.scalar_type() == at::kFloat,
-              "Input tensor must be of type BFloat16 or Float, but got ",
-              input.dtype());
-  TORCH_CHECK(scale.scalar_type() == at::kFloat,
-              "Scale tensor must be of type Float, but got ", scale.dtype())
-  TORCH_CHECK(input.dim() == 2, "Input tensor must be 2D, but got ", input.dim());
-  TORCH_CHECK(scale.numel() == 1, "Scale tensor must be a scalar, but got ",
-              scale.numel());
+  TORCH_CHECK(amax.scalar_type() == at::kFloat,
+              "Amax tensor must be of type Float, but got ", amax.dtype())
+  TORCH_CHECK(amax.numel() == 1, "Amax tensor must be a scalar, but got ",
+              amax.numel());
 
   // Input must either be transposed or contiguous
   auto strides = input.strides();
   bool is_contiguous = input.is_contiguous();
   bool is_transposed = strides[0] == 1 && strides[1] == input.size(0);
-  bool check_allowed_strides = (is_contiguous || is_transposed) && input.storage_offset() == 0 ;
+  bool check_allowed_strides =
+      (is_contiguous || is_transposed) && input.storage_offset() == 0;
   auto contig_input = check_allowed_strides ? input : input.contiguous();
-  auto output = torch::empty_like(contig_input, contig_input.options().dtype(dtype));
-  dispatch_best_kernel(contig_input, output, dtype_map(dtype), scale, transpose);
-  return output;
+  auto output =
+      torch::empty_like(contig_input, contig_input.options().dtype(dtype));
+  auto scale = torch::empty({}, amax.options().dtype(at::kFloat));
+  dispatch_best_kernel(contig_input, output, dtype_map(dtype), amax, scale,
+                       transpose);
+  return {output, scale};
 }
 } // namespace driss_torch
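`get_dtype_max` above hardcodes the finite maxima of the two fp8 formats; they match what PyTorch itself reports:

```python
import torch

# The values hardcoded in get_dtype_max match torch.finfo.
assert torch.finfo(torch.float8_e4m3fn).max == 448.0
assert torch.finfo(torch.float8_e5m2).max == 57344.0
```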
diff --git a/test/test_add.py b/test/test_add.py
deleted file mode 100644
index 3dc1f96..0000000
--- a/test/test_add.py
+++ /dev/null
@@ -1,9 +0,0 @@
-import torch
-from driss_torch import add_one
-
-
-def test_add_one():
-    shape = (3, 3, 3)
-    a = torch.randint(0, 10, shape, dtype=torch.float).cuda()
-
-    torch.testing.assert_close(add_one(a), a + 1)
diff --git a/test/test_sat_cast.py b/test/test_sat_cast.py
index 559348e..7f63b35 100644
--- a/test/test_sat_cast.py
+++ b/test/test_sat_cast.py
@@ -1,7 +1,7 @@
 import pytest
 import torch
 from driss_torch import saturated_cast
-from float8_experimental.float8_utils import tensor_to_scale
+from float8_experimental.float8_utils import tensor_to_amax, tensor_to_scale
 
 
 def eager_scaled_quant(
@@ -26,16 +26,29 @@
 @pytest.mark.parametrize("num_cols", [7, 17, 127, 512, 3212, 4097])
 @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
 @pytest.mark.parametrize("in_dtype", [torch.bfloat16, torch.float32])
-def test_cast(num_rows: int, num_cols: int, in_dtype: torch.dtype, fp8_dtype: torch.dtype):
+def test_cast_eager(num_rows: int, num_cols: int, in_dtype: torch.dtype, fp8_dtype: torch.dtype):
+    torch.manual_seed(0)
     a = torch.rand(num_rows, num_cols, dtype=in_dtype, device="cuda")
+    amax = tensor_to_amax(a).to(torch.float32)
     scale = tensor_to_scale(a, fp8_dtype)
     cast_pytorch = eager_scaled_quant(a, scale, fp8_dtype)
-    cast_custom = saturated_cast(a, scale, fp8_dtype)
+    cast_custom, scale_custom = saturated_cast(a, amax, fp8_dtype)
     custom_fp32 = cast_custom.to(torch.float32)
     pytorch_fp32 = cast_pytorch.to(torch.float32)
-    torch.testing.assert_close(custom_fp32, pytorch_fp32)
+    torch.testing.assert_close(custom_fp32, pytorch_fp32, atol=1e-5, rtol=0.20)
+    torch.testing.assert_close(scale, scale_custom)
+    # I worked through examples and I am pretty convinced that the fused kernel is more
+    # accurate than eager PyTorch.
+    # For an amax of 0.9999988675117493 the fused kernel produces the scale 57344.066406,
+    # while eager PyTorch produces 57344.0703125.
+    # The exact value is 57344.0 / 0.9999988675117493 = 57344.064941479795.
+    # The error is 0.00146 for the fused kernel and 0.00537 for PyTorch.
+    # The unfortunate part: we then multiply the input tensor by this scale and convert
+    # to fp8e4 or fp8e5, so values that land near the end of the range, where the dynamic
+    # range is coarsest, can differ by a lot in the fp8 tensor even when the epsilon in
+    # the scale is small.
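The accuracy claim in the comment above checks out with plain float arithmetic (values copied from the comment):

```python
amax = 0.9999988675117493
true_scale = 57344.0 / amax     # 57344.064941479795
fused_scale = 57344.066406      # scale computed by the fused kernel
eager_scale = 57344.0703125     # scale computed by eager PyTorch

assert abs(fused_scale - true_scale) < abs(eager_scale - true_scale)
print(abs(fused_scale - true_scale))  # ~0.00146
print(abs(eager_scale - true_scale))  # ~0.00537
```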
 
 
 @pytest.mark.parametrize("num_rows", [3, 64, 512, 4096])
@@ -45,15 +58,18 @@ def test_cast(num_rows: int, num_cols: int, in_dtype: torch.dtype, fp8_dtype: to
 def test_cast_compile(num_rows: int, num_cols: int, in_dtype: torch.dtype, fp8_dtype: torch.dtype):
     torch._dynamo.reset()
     a = torch.rand(num_rows, num_cols, dtype=in_dtype, device="cuda")
+    amax = tensor_to_amax(a).to(torch.float32)
     scale = tensor_to_scale(a, fp8_dtype)
     cast_custom_compile_func = torch.compile(saturated_cast, fullgraph=True)
-    cast_custom = saturated_cast(a, scale, fp8_dtype)
-    cast_custom_compile = cast_custom_compile_func(a, scale, fp8_dtype)
+    cast_custom, scale_custom = saturated_cast(a, amax, fp8_dtype)
+    cast_custom_compile, scale_custom_compile = cast_custom_compile_func(a, amax, fp8_dtype)
     custom_fp32 = cast_custom.to(torch.float32)
     custom_compile_fp32 = cast_custom_compile.to(torch.float32)
 
     torch.testing.assert_close(custom_fp32, custom_compile_fp32)
+    torch.testing.assert_close(scale, scale_custom)
+    torch.testing.assert_close(scale, scale_custom_compile)
 
 
 @pytest.mark.xfail(reason="This test is failing, we need to investigate", strict=True)
@@ -63,7 +79,7 @@ def test_cast_edge_bug():
     cast_pytorch = eager_scaled_quant(a, scale, torch.float8_e5m2)
     cast_custom = saturated_cast(a, scale, torch.float8_e5m2)
-    custom_fp32 = cast_custom.to(torch.float32)
+    custom_fp32, scale_custom = (t.to(torch.float32) for t in cast_custom)
     pytorch_fp32 = cast_pytorch.to(torch.float32)
     MAX_P_output = a.to(torch.float64) * scale.to(torch.float64)
     print("Custom diff is ", torch.max(torch.abs(MAX_P_output - custom_fp32)))
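Taken together, the new two-output API is used like this (a minimal sketch, assuming a CUDA build of the extension):

```python
import torch
from driss_torch import saturated_cast

x = torch.randn(512, 512, dtype=torch.bfloat16, device="cuda")
amax = x.abs().max().to(torch.float32)

# The op returns both the fp8 tensor and the scale it computed on device.
y, scale = saturated_cast(x, amax, torch.float8_e4m3fn)

# Dequantize to recover an approximation of x.
x_hat = (y.to(torch.float32) / scale).to(torch.bfloat16)
```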