diff --git a/setup.py b/setup.py
index cabaad01cf..0915f6ae1e 100644
--- a/setup.py
+++ b/setup.py
@@ -433,6 +433,7 @@ def get_extensions():
                     "to_sparse_semi_structured_cutlass_sm9x_f8.cu",
                 ),
                 os.path.join(extensions_cuda_dir, "activation24", "sparsify24.cu"),
+                os.path.join(extensions_cuda_dir, "activation24", "sparse_gemm.cu"),
             ]
             for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]:
                 cutlass_90a_sources.append(
diff --git a/test/sparsity/test_activation24.py b/test/sparsity/test_activation24.py
index 65b7cfd8d2..420bf4328a 100644
--- a/test/sparsity/test_activation24.py
+++ b/test/sparsity/test_activation24.py
@@ -8,6 +8,7 @@
     PerRow,
     quantize_,
 )
+from torchao.quantization.quant_api import _float8_cutlass_quant
 
 torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = True
 
@@ -141,3 +142,68 @@ def srelu_linear(x):
 
     custom_output = reference_linear_copy(input_tensor)
     torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01)
+
+
+@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch at least SM90")
+def test_sparse24_fp8_sm90_cutlass_gemm_eye(
+    M=512, K=256, dtype=torch.float8_e4m3fn
+) -> None:
+    torch.manual_seed(0)
+
+    A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda()
+    A_aqt = _float8_cutlass_quant(A_dense, dtype)
+    A = A_aqt.tensor_impl.float8_data
+
+    # NOTE: CUTLASS compression kernel expects the input to be *exactly*
+    # 2:4 sparse already (eg it does not select the largest values)
+    A_packed, A_mdata = to_sparse_semi_structured_cutlass_sm9x_f8(A)
+    assert torch.allclose(
+        A_packed.float().sum(), A.float().sum()
+    )  # Check all values are there
+
+    # Check MM without scale
+    eye = torch.eye(A.shape[1], device=A.device, dtype=A.dtype).T
+    A_reconstructed = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
+        A_packed, A_mdata, eye
+    )
+    assert torch.allclose(A.float(), A_reconstructed.float())
+
+    # Check MM with scale
+    b_scale = torch.randn([1, A.shape[1]], device=eye.device, dtype=torch.float32)
+    a_scale = torch.randn([A.shape[0], 1], device=eye.device, dtype=torch.float32)
+    A_reconstructed = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
+        A_packed, A_mdata, eye, a_scale=a_scale, b_scale=b_scale
+    )
+    assert torch.allclose(
+        A.float() * b_scale * a_scale, A_reconstructed.float(), rtol=0.01
+    )
+
+
+@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch at least SM90")
+def test_sparse24_fp8_sm90_cutlass_gemm_random_tensor(
+    M=512, N=1024, K=256, dtype=torch.float8_e4m3fn
+) -> None:
+    def _to_fp8_rowwise(x: torch.Tensor, dtype):
+        max_v = torch.finfo(dtype).max
+        x_scale = (x.abs().max(1, keepdim=True)[0] / max_v).float()
+        x = (x / x_scale).to(dtype)
+        return x, x_scale
+
+    torch.manual_seed(0)
+    A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda()
+    A, a_scale = _to_fp8_rowwise(A_dense, dtype)
+
+    B_dense = torch.randn([N, K], device="cuda", dtype=torch.bfloat16)
+    B, b_scale = _to_fp8_rowwise(B_dense, dtype)
+
+    B = B.T
+    b_scale = b_scale.T
+
+    A_packed, A_mdata = to_sparse_semi_structured_cutlass_sm9x_f8(A)
+    out_sparse = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
+        A_packed, A_mdata, B, a_scale=a_scale, b_scale=b_scale
+    )
+    out_ref = torch._scaled_mm(
+        A, B, scale_a=a_scale, scale_b=b_scale, out_dtype=out_sparse.dtype
+    )
+    assert torch.allclose(out_sparse, out_ref, rtol=0.01, atol=0.01)
diff --git a/torchao/csrc/cuda/activation24/sparse_gemm.cu b/torchao/csrc/cuda/activation24/sparse_gemm.cu
new file mode 100644
index 0000000000..776766794e
--- /dev/null
+++ b/torchao/csrc/cuda/activation24/sparse_gemm.cu
@@ -0,0 +1,351 @@
+#include
+#include
+#include
+#include
+#include
+
+#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \
+    defined(CUDA_VERSION) && (CUDA_VERSION >= 12020)
+#define BUILD_SM90_24_FP8_CUTLASS_GEMM
+#endif
+
+#if defined(BUILD_SM90_24_FP8_CUTLASS_GEMM)
+#include
+#include
+#include
+#include "cutlass/arch/wmma.h"
+#include "cutlass/bfloat16.h"
+#include "cutlass/cuda_host_adapter.hpp"
+#include "cutlass/cutlass.h"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/numeric_types.h"
+#include "cutlass/transform/device/transform_universal_adapter.hpp"
+
+#include
+#include
+
+using namespace at;
+
+namespace {
+#define CUTLASS_STATUS_CHECK(status)              \
+  {                                               \
+    TORCH_CHECK(                                  \
+        status == cutlass::Status::kSuccess,      \
+        "Got CUTLASS error: ",                    \
+        cutlass::cutlassGetStatusString(status)); \
+  }
+
+template <typename T>
+struct identity {
+  CUTLASS_HOST_DEVICE
+  T operator()(T lhs) const {
+    return lhs;
+  }
+};
+
+template <typename ElementA>
+struct SparseRowwiseKernel;
+
+template <>
+struct SparseRowwiseKernel<cutlass::float_e4m3_t> {
+  static constexpr auto kElementOutAt = at::ScalarType::BFloat16;
+  static constexpr auto kElementAAt = at::ScalarType::Float8_e4m3fn;
+
+  using ElementA = cutlass::float_e4m3_t;
+  using ElementB = cutlass::float_e4m3_t;
+  using ElementOut = cutlass::bfloat16_t;
+  using ElementAccumulator = float;
+
+  using TileShape = cute::Shape;
+
+  // Epilogue visitor tree
+  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
+  using AScale =
+      cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, float>;
+  using BScale =
+      cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, float>;
+  using Multiply = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiplies,
+      float,
+      float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+  using Cast = cutlass::epilogue::fusion::Sm90Compute<
+      identity,
+      ElementOut,
+      float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+  using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
+      Cast,
+      cutlass::epilogue::fusion::Sm90EVT<
+          Multiply,
+          BScale,
+          cutlass::epilogue::fusion::Sm90EVT<Multiply, AScale, Accum>>>;
+
+  using CollectiveEpilogue =
+      typename cutlass::epilogue::collective::CollectiveBuilder<
+          cutlass::arch::Sm90,
+          cutlass::arch::OpClassSparseTensorOp,
+          TileShape,
+          cute::Shape,
+          cutlass::epilogue::collective::EpilogueTileAuto,
+          ElementAccumulator,
+          float,
+          ElementOut,
+          cutlass::layout::RowMajor,
+          1,
+          ElementOut,
+          cutlass::layout::RowMajor,
+          1,
+          cutlass::epilogue::TmaWarpSpecializedCooperative,
+          EpilogueEVT>::CollectiveOp;
+
+  using CollectiveMainloop =
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          cutlass::arch::Sm90,
+          cutlass::arch::OpClassSparseTensorOp,
+          ElementA,
+          cutlass::layout::RowMajor,
+          32,
+          ElementB,
+          cutlass::layout::ColumnMajor,
+          16,
+          ElementAccumulator,
+          cute::Shape,
+          cute::Shape,
+          cutlass::gemm::collective::StageCountAutoCarveout(
+              sizeof(typename CollectiveEpilogue::SharedStorage))>,
+          cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>::
+          CollectiveOp;
+
+  // Gemm operator
+  // cutlass3x_sm90_sptensorop_s64x256x64spgemm_e4m3_e4m3_f32_bf16_bf16_128x256x128_2x1x1_0_tnt_align32_warpspecialized_cooperative_fp8_fastaccum_epi_tma
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      cute::Shape,
+      CollectiveMainloop,
+      CollectiveEpilogue,
+      cutlass::gemm::PersistentScheduler>;
+  using ElementE = CollectiveMainloop::ElementE;
+};
+
+template <>
+struct SparseRowwiseKernel<cutlass::bfloat16_t> {
+  static constexpr auto kElementOutAt = at::ScalarType::BFloat16;
+  static constexpr auto kElementAAt = at::ScalarType::BFloat16;
+
+  using ElementA = cutlass::bfloat16_t;
+  using ElementB = cutlass::bfloat16_t;
+  using ElementOut = cutlass::bfloat16_t;
+
+  using TileShape = cute::Shape;
+
+  // Epilogue visitor tree
+  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
+  using AScale =
+      cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, float>;
+  using BScale =
+      cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, float>;
+  using Multiply = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiplies,
+      float,
+      float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+  using Cast = cutlass::epilogue::fusion::Sm90Compute<
+      identity,
+      ElementOut,
+      float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+  using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
+      Cast,
+      cutlass::epilogue::fusion::Sm90EVT<
+          Multiply,
+          BScale,
+          cutlass::epilogue::fusion::Sm90EVT<Multiply, AScale, Accum>>>;
+
+  using CollectiveEpilogue =
+      typename cutlass::epilogue::collective::CollectiveBuilder<
+          cutlass::arch::Sm90,
+          cutlass::arch::OpClassSparseTensorOp,
+          TileShape,
+          cute::Shape,
+          cutlass::epilogue::collective::EpilogueTileAuto,
+          float,
+          float,
+          ElementOut,
+          cutlass::layout::RowMajor,
+          1,
+          ElementOut,
+          cutlass::layout::RowMajor,
+          1,
+          cutlass::epilogue::TmaWarpSpecializedCooperative,
+          EpilogueEVT>::CollectiveOp;
+
+  using CollectiveMainloop =
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          cutlass::arch::Sm90,
+          cutlass::arch::OpClassSparseTensorOp,
+          ElementA,
+          cutlass::layout::RowMajor,
+          16,
+          ElementB,
+          cutlass::layout::ColumnMajor,
+          16,
+          float,
+          cute::Shape,
+          cute::Shape,
+          cutlass::gemm::collective::StageCountAutoCarveout(
+              sizeof(typename CollectiveEpilogue::SharedStorage))>,
+          cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp;
+
+  // Gemm operator
+  // cutlass3x_sm90_sptensorop_s64x128x32spgemm_bf16_bf16_f32_void_f32_128x128x64_2x1x1_0_ttn_align16_warpspecialized_cooperative_epi_tma
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      cute::Shape,
+      CollectiveMainloop,
+      CollectiveEpilogue,
+      cutlass::gemm::PersistentScheduler>;
+  using ElementE = CollectiveMainloop::ElementE;
+};
+
+template
+Tensor _sparse24_fp8_sm90_cutlass_gemm(
+    const Tensor& tensor_a,
+    const Tensor& tensor_e, // metadata for `A`
+    const Tensor& tensor_b,
+    // *,
+    std::optional<at::Tensor> a_scale,
+    std::optional<at::Tensor> b_scale,
+    int64_t swizzle_size,
+    std::string swizzle_axis,
+    int64_t sm_count) {
+  std::optional<at::cuda::CUDAGuard> device_guard;
+  if (!kIsMeta) {
+    device_guard.emplace(tensor_a.device());
+  }
+
+  using K = SparseRowwiseKernel;
+
+  // For now, only CC 9.x devices are supported.
+  if (!kIsMeta) {
+    const auto dprops = at::cuda::getCurrentDeviceProperties();
+    TORCH_CHECK(
+        dprops && dprops->major == 9,
+        "_sparse24_gemm_fp8_sm90: Supported only on GPUs with "
+        "compute capability 9.x");
+  }
+
+  // Validate layouts of input tensors.
+  TORCH_CHECK(tensor_a.device() == tensor_b.device());
+  TORCH_CHECK(tensor_a.device() == tensor_e.device());
+  TORCH_CHECK(tensor_a.dim() == 2);
+  TORCH_CHECK(tensor_b.dim() == 2);
+  TORCH_CHECK(tensor_a.scalar_type() == tensor_b.scalar_type());
+  TORCH_CHECK(tensor_a.scalar_type() == K::kElementAAt);
+  TORCH_CHECK(tensor_b.stride(0) == 1, "B must be column-major");
+  TORCH_CHECK(tensor_a.is_contiguous());
+  TORCH_CHECK(tensor_b.t().is_contiguous());
+  int64_t a_rows = tensor_a.size(0);
+  if (a_scale.has_value()) {
+    TORCH_CHECK(a_scale->is_contiguous());
+    TORCH_CHECK(a_scale->scalar_type() == at::ScalarType::Float);
+    TORCH_CHECK(a_scale->device() == tensor_a.device());
+    TORCH_CHECK(a_scale->dim() == 2);
+    TORCH_CHECK(a_scale->size(0) == a_rows);
+    TORCH_CHECK(a_scale->size(1) == 1);
+  }
+  if (b_scale.has_value()) {
+    TORCH_CHECK(b_scale->is_contiguous());
+    TORCH_CHECK(b_scale->scalar_type() == at::ScalarType::Float);
+    TORCH_CHECK(b_scale->device() == tensor_b.device());
+    TORCH_CHECK(b_scale->dim() == 2);
+    TORCH_CHECK(b_scale->size(0) == 1);
+    TORCH_CHECK(b_scale->size(1) == tensor_b.size(1));
+  }
+
+  typename K::GemmKernel::Arguments args;
+  args.mode = cutlass::gemm::GemmUniversalMode::kGemm;
+  args.problem_shape = cute::make_shape(
+      int(a_rows), int(tensor_b.size(1)), int(tensor_b.size(0)), 1);
+  Tensor out = tensor_a.new_empty(
+      {cute::get<0>(args.problem_shape), cute::get<1>(args.problem_shape)},
+      at::TensorOptions().dtype(K::kElementOutAt));
+
+  args.mainloop.ptr_A =
+      reinterpret_cast<typename K::ElementA const*>(tensor_a.data_ptr());
+  args.mainloop.ptr_B =
+      static_cast<typename K::ElementB const*>(tensor_b.data_ptr());
+  args.mainloop.ptr_E =
+      reinterpret_cast<typename K::ElementE const*>(tensor_e.data_ptr());
+  args.epilogue.ptr_C = nullptr;
+  args.epilogue.ptr_D = static_cast<typename K::ElementOut*>(out.data_ptr());
+
+  float const* a_scale_ptr =
+      (float const*)(a_scale.has_value() ? a_scale->data_ptr() : nullptr);
+  float const* b_scale_ptr =
+      (float const*)(b_scale.has_value() ? b_scale->data_ptr() : nullptr);
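+  // Scale pointers are wired into the epilogue visitor tree
+  // Cast(BScale * (AScale * Accum)): op_0 of the outer multiply is the
+  // row-broadcast b_scale, its op_1 is the inner multiply whose op_0 is the
+  // column-broadcast a_scale; a null pointer falls back to the scalar
+  // default_scale defined below.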
+  float default_scale = 1.0f; // used if ptr is nullptr
+  auto& cast_op = args.epilogue.thread;
+  auto& mulB_op = cast_op.op_0;
+  mulB_op.op_0 = {b_scale_ptr, default_scale};
+  auto& mulA_op = mulB_op.op_1;
+  mulA_op.op_0 = {a_scale_ptr, default_scale};
+
+  args.mainloop.layout_a =
+      K::CollectiveMainloop::SparseConfig::fill_layoutA(args.problem_shape);
+  args.mainloop.layout_e =
+      K::CollectiveMainloop::SparseConfig::fill_layoutE(args.problem_shape);
+  args.mainloop.dB = cute::make_int_tuple_from(
+      tensor_b.stride(1), 0);
+  args.epilogue.dC = cute::make_int_tuple_from(
+      out.stride(0), 0);
+  args.epilogue.dD = cute::make_int_tuple_from(
+      out.stride(0), 0);
+
+  /* Query device SM count to pass onto the kernel as an argument, where needed
+   */
+  args.hw_info.device_id = tensor_a.device().index();
+  args.hw_info.sm_count = sm_count;
+  args.scheduler.max_swizzle_size = swizzle_size;
+  using Enum_t = decltype(args.scheduler.raster_order);
+  if (swizzle_axis == "n") {
+    args.scheduler.raster_order = Enum_t::AlongN;
+  } else {
+    TORCH_CHECK(
+        swizzle_axis == "m",
+        "Invalid value for swizzle_axis ('",
+        swizzle_axis,
+        "')");
+    args.scheduler.raster_order = Enum_t::AlongM;
+  }
+
+  using Gemm =
+      cutlass::gemm::device::GemmUniversalAdapter<typename K::GemmKernel>;
+  int64_t device_op_workspace_size = Gemm::get_workspace_size(args);
+  Tensor workspace = tensor_a.new_empty(
+      {device_op_workspace_size},
+      at::TensorOptions().dtype(at::ScalarType::Byte));
+
+  Gemm gemm;
+  // Check the problem size is supported or not
+  CUTLASS_STATUS_CHECK(gemm.can_implement(args));
+
+  auto status = gemm.run(
+      args, (void*)workspace.data_ptr(), at::cuda::getCurrentCUDAStream());
+  CUTLASS_STATUS_CHECK(status);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+  return out;
+}
+} // namespace
+
+TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchao::sparse24_fp8_sm90_cutlass_gemm"),
+      TORCH_FN(_sparse24_fp8_sm90_cutlass_gemm));
+}
+
+TORCH_LIBRARY_IMPL(torchao, Meta, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchao::sparse24_fp8_sm90_cutlass_gemm"),
+      TORCH_FN(_sparse24_fp8_sm90_cutlass_gemm));
+}
+#endif
diff --git a/torchao/ops.py b/torchao/ops.py
index faebdbd5d1..b91bb8ae18 100644
--- a/torchao/ops.py
+++ b/torchao/ops.py
@@ -42,6 +42,9 @@
 lib.define(
     "sparse24_sm90_sparsify(Tensor input, str metadata_fmt, str activation, str sp_selection_algo, *, ScalarType? dtype = None, Tensor? scale=None) -> (Tensor, Tensor)"
 )
+lib.define(
+    "sparse24_fp8_sm90_cutlass_gemm(Tensor a, Tensor a_mdata, Tensor b, *, Tensor? a_scale = None, Tensor? b_scale = None, int swizzle_size=8, str swizzle_axis='n', int sm_count=128) -> Tensor"
+)
 lib.define(
     "swizzle_mm(Tensor mat1, Tensor mat2, bool mat1_is_swizzled, bool mat2_is_swizzled) -> Tensor"
 )
@@ -840,6 +843,28 @@ def sparse24_sm90_sparsify(
     )
+
+
+def sparse24_fp8_sm90_cutlass_gemm(
+    a: Tensor,
+    meta: Tensor,
+    b: Tensor,
+    *,
+    a_scale: Optional[Tensor] = None,
+    b_scale: Optional[Tensor] = None,
+    swizzle_size: int = 8,
+    swizzle_axis: str = "n",
+    sm_count: int = 128,
+) -> Tensor:
+    return torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
+        a,
+        meta,
+        b,
+        a_scale=a_scale,
+        b_scale=b_scale,
+        swizzle_size=swizzle_size,
+        swizzle_axis=swizzle_axis,
+        sm_count=sm_count,
+    )
 
 
 def swizzle_mm(
     mat1: Tensor, mat2: Tensor, mat1_is_swizzled: bool, mat2_is_swizzled: bool
 ) -> Tensor:
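
Usage note (not part of the diff): a minimal sketch of how the new op composes with the existing CUTLASS compression helper, mirroring the tests above. It assumes an SM90 GPU and that create_semi_structured_tensor and to_sparse_semi_structured_cutlass_sm9x_f8 are available as imported in test/sparsity/test_activation24.py; the activation must already be exactly 2:4 sparse before compression.

    import torch

    M, N, K = 512, 1024, 256
    # A must already be exactly 2:4 sparse; the compression kernel does not prune.
    A = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda().to(torch.float8_e4m3fn)
    # B is passed column-major (stride(0) == 1), i.e. as the transpose of a row-major [N, K] weight.
    B = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn).T

    A_packed, A_mdata = to_sparse_semi_structured_cutlass_sm9x_f8(A)
    out = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(
        A_packed, A_mdata, B,
        a_scale=torch.ones(M, 1, device="cuda"),  # float32, rowwise scale for A
        b_scale=torch.ones(1, N, device="cuda"),  # float32, per-column scale for B
    )  # -> [M, N] bfloat16 output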