From 2a9ecdccf05c2a24288c93430067b96b33f763c8 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Wed, 21 May 2025 16:46:57 -0700 Subject: [PATCH 1/7] [sparse] Add fp8 sparse gemm with rowwise scaling for activation sparsity Summary: We have this gemm already in torchao, but for weight sparsity. For activation sparsity, we need the weights to be stored in column-major format to allow for us to use the selective weight loading kernel for decode. Test Plan: Reviewers: Subscribers: Tasks: Tags: --- setup.py | 1 + test/sparsity/test_activation24.py | 68 +++ torchao/csrc/cuda/activation24/sparse_gemm.cu | 438 ++++++++++++++++++ torchao/ops.py | 17 + 4 files changed, 524 insertions(+) create mode 100644 torchao/csrc/cuda/activation24/sparse_gemm.cu diff --git a/setup.py b/setup.py index cabaad01cf..790f241d52 100644 --- a/setup.py +++ b/setup.py @@ -433,6 +433,7 @@ def get_extensions(): "to_sparse_semi_structured_cutlass_sm9x_f8.cu", ), os.path.join(extensions_cuda_dir, "activation24", "sparsify24.cu"), + os.path.join(extensions_cuda_dir, "activation24", "sparse_gemm.cu") ] for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]: cutlass_90a_sources.append( diff --git a/test/sparsity/test_activation24.py b/test/sparsity/test_activation24.py index 65b7cfd8d2..37374291e4 100644 --- a/test/sparsity/test_activation24.py +++ b/test/sparsity/test_activation24.py @@ -8,6 +8,7 @@ PerRow, quantize_, ) +from torchao.quantization.quant_api import _float8_cutlass_quant torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = True @@ -141,3 +142,70 @@ def srelu_linear(x): custom_output = reference_linear_copy(input_tensor) torch.testing.assert_close(reference_output, custom_output, rtol=0.1, atol=0.01) + + +@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") +def test_sparse24_fp8_sm90_cutlass_gemm_eye( + M=512, K=256, dtype=torch.float8_e4m3fn +) -> None: + torch.manual_seed(0) + + A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda() + A_aqt = _float8_cutlass_quant(A_dense, dtype) + A = A_aqt.tensor_impl.float8_data + + # NOTE: CUTLASS compression kernel expects the input to be *exactly* + # 2:4 sparse already (eg it does not select the largest values) + A_packed, A_mdata = to_sparse_semi_structured_cutlass_sm9x_f8(A) + assert torch.allclose( + A_packed.float().sum(), A.float().sum() + ) # Check all values are there + + # Check MM without scale + eye = torch.eye(A.shape[1], device=A.device, dtype=A.dtype).T + A_reconstructed = torch.ops.torchao._sparse24_fp8_sm90_cutlass_gemm( + A_packed, A_mdata, eye + ) + assert torch.allclose(A.float(), A_reconstructed.float()) + + # Check MM with scale + b_scale = torch.randn([1, A.shape[1]], device=eye.device, dtype=torch.float32) + a_scale = torch.randn([A.shape[0], 1], device=eye.device, dtype=torch.float32) + A_reconstructed = torch.ops.torchao._sparse24_fp8_sm90_cutlass_gemm( + A_packed, A_mdata, eye, a_scale=a_scale, b_scale=b_scale + ) + assert torch.allclose( + A.float() * b_scale * a_scale, A_reconstructed.float(), rtol=0.01 + ) + + +@unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") +def test_sparse24_fp8_sm90_cutlass_gemm_random_tensor( + M=512, N=1024, K=256, dtype=torch.float8_e4m3fn +) -> None: + + def _to_fp8_rowwise(x: torch.Tensor, dtype): + max_v = torch.finfo(dtype).max + x_scale = (x.abs().max(1, keepdim=True)[0] / max_v).float() + x = (x / x_scale).to(dtype) + return x, x_scale + + torch.manual_seed(0) + A_dense = create_semi_structured_tensor(M, K, 
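+        # exactly 2:4-sparse bf16 input -- see the NOTE in the eye test above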
dtype=torch.bfloat16).cuda() + A, a_scale = _to_fp8_rowwise(A_dense, dtype) + + B_dense = torch.randn([N, K], device="cuda", dtype=torch.bfloat16) + B, b_scale = _to_fp8_rowwise(B_dense, dtype) + + B = B.T + b_scale = b_scale.T + + A_packed, A_mdata = to_sparse_semi_structured_cutlass_sm9x_f8(A) + out_sparse = torch.ops.torchao._sparse24_fp8_sm90_cutlass_gemm( + A_packed, A_mdata, B, a_scale=a_scale, b_scale=b_scale + ) + out_ref = torch._scaled_mm( + A, B, scale_a=a_scale, scale_b=b_scale, out_dtype=out_sparse.dtype + ) + assert torch.allclose(out_sparse, out_ref, rtol=0.01, atol=0.01) + diff --git a/torchao/csrc/cuda/activation24/sparse_gemm.cu b/torchao/csrc/cuda/activation24/sparse_gemm.cu new file mode 100644 index 0000000000..ae907a68ec --- /dev/null +++ b/torchao/csrc/cuda/activation24/sparse_gemm.cu @@ -0,0 +1,438 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include "cutlass/arch/wmma.h" +#include "cutlass/bfloat16.h" +#include "cutlass/cuda_host_adapter.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/device/transform_universal_adapter.hpp" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" + +#include +#include + +using namespace at; + +namespace { +#define CUTLASS_STATUS_CHECK(status) \ + { \ + TORCH_CHECK( \ + status == cutlass::Status::kSuccess, \ + "Got CUTLASS error: ", \ + cutlass::cutlassGetStatusString(status)); \ + } + +template +struct identity { + CUTLASS_HOST_DEVICE + T operator()(T lhs) const { + return lhs; + } +}; + +template +struct SparseRowwiseKernel; + +template <> +struct SparseRowwiseKernel { + static constexpr auto kElementOutAt = at::ScalarType::BFloat16; + static constexpr auto kElementAAt = at::ScalarType::Float8_e4m3fn; + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementOut = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using TileShape = cute::Shape; + + // Epilogue visitor tree + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + using AScale = + cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, float>; + using BScale = + cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, float>; + using Multiply = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + float, + float, + cutlass::FloatRoundStyle::round_to_nearest>; + using Cast = cutlass::epilogue::fusion::Sm90Compute< + identity, + ElementOut, + float, + cutlass::FloatRoundStyle::round_to_nearest>; + using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT< + Cast, + cutlass::epilogue::fusion::Sm90EVT< + Multiply, + BScale, + cutlass::epilogue::fusion::Sm90EVT>>; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassSparseTensorOp, + TileShape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + float, + ElementOut, + cutlass::layout::RowMajor, + 1, + ElementOut, + cutlass::layout::RowMajor, + 1, + cutlass::epilogue::TmaWarpSpecializedCooperative, + EpilogueEVT>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassSparseTensorOp, + ElementA, + cutlass::layout::RowMajor, + 32, + ElementB, + 
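+            // B (the weight) is consumed ColumnMajor with 16-element
+            // alignment (the next two arguments) -- the column-major weight
+            // storage called out in the commit message, which is what lets
+            // decode use the selective weight-loading kernel.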
cutlass::layout::ColumnMajor, + 16, + ElementAccumulator, + cute::Shape, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>:: + CollectiveOp; + + // Gemm operator + // cutlass3x_sm90_sptensorop_s64x256x64spgemm_e4m3_e4m3_f32_bf16_bf16_128x256x128_2x1x1_0_tnt_align32_warpspecialized_cooperative_fp8_fastaccum_epi_tma + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>; + using ElementE = CollectiveMainloop::ElementE; +}; + +template <> +struct SparseRowwiseKernel { + static constexpr auto kElementOutAt = at::ScalarType::BFloat16; + static constexpr auto kElementAAt = at::ScalarType::BFloat16; + + using ElementA = cutlass::bfloat16_t; + using ElementB = cutlass::bfloat16_t; + using ElementOut = cutlass::bfloat16_t; + + using TileShape = cute::Shape; + + // Epilogue visitor tree + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + using AScale = + cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, float>; + using BScale = + cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, float>; + using Multiply = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + float, + float, + cutlass::FloatRoundStyle::round_to_nearest>; + using Cast = cutlass::epilogue::fusion::Sm90Compute< + identity, + ElementOut, + float, + cutlass::FloatRoundStyle::round_to_nearest>; + using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT< + Cast, + cutlass::epilogue::fusion::Sm90EVT< + Multiply, + BScale, + cutlass::epilogue::fusion::Sm90EVT>>; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassSparseTensorOp, + TileShape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, + float, + ElementOut, + cutlass::layout::RowMajor, + 1, + ElementOut, + cutlass::layout::RowMajor, + 1, + cutlass::epilogue::TmaWarpSpecializedCooperative, + EpilogueEVT>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassSparseTensorOp, + ElementA, + cutlass::layout::RowMajor, + 16, + ElementB, + cutlass::layout::ColumnMajor, + 16, + float, + cute::Shape, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp; + + // Gemm operator + // cutlass3x_sm90_sptensorop_s64x128x32spgemm_bf16_bf16_f32_void_f32_128x128x64_2x1x1_0_ttn_align16_warpspecialized_cooperative_epi_tma + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>; + using ElementE = CollectiveMainloop::ElementE; +}; + +template +Tensor _sparse24_fp8_sm90_cutlass_gemm( + const Tensor& tensor_a, + const Tensor& tensor_e, // metadata for `A` + const Tensor& tensor_b, + // *, + std::optional a_scale, + std::optional b_scale, + int64_t swizzle_size, + std::string swizzle_axis, + int64_t sm_count) { + std::optional device_guard; + if (!kIsMeta) { + device_guard.emplace(tensor_a.device()); + } + + using K = SparseRowwiseKernel; + + // For now, only CC 9.x devices are supported. 
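+  // (Both kernel configurations above are built with cutlass::arch::Sm90,
+  // OpClassSparseTensorOp, and TMA warp-specialized schedules, so there is
+  // no fallback path for other compute capabilities.)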
+  if (!kIsMeta) {
+    const auto dprops = at::cuda::getCurrentDeviceProperties();
+    TORCH_CHECK(
+        dprops && dprops->major == 9,
+        "_sparse24_fp8_sm90_cutlass_gemm: Supported only on GPUs with "
+        "compute capability 9.x");
+  }
+
+  // Validate layouts of input tensors.
+  TORCH_CHECK(tensor_a.device() == tensor_b.device());
+  TORCH_CHECK(tensor_a.device() == tensor_e.device());
+  TORCH_CHECK(tensor_a.dim() == 2);
+  TORCH_CHECK(tensor_b.dim() == 2);
+  TORCH_CHECK(tensor_a.scalar_type() == tensor_b.scalar_type());
+  TORCH_CHECK(tensor_a.scalar_type() == K::kElementAAt);
+  TORCH_CHECK(tensor_b.stride(0) == 1, "B must be column-major");
+  TORCH_CHECK(tensor_a.is_contiguous());
+  TORCH_CHECK(tensor_b.t().is_contiguous());
+  int64_t a_rows = tensor_a.size(0);
+  if (a_scale.has_value()) {
+    TORCH_CHECK(a_scale->is_contiguous());
+    TORCH_CHECK(a_scale->scalar_type() == at::ScalarType::Float);
+    TORCH_CHECK(a_scale->device() == tensor_a.device());
+    TORCH_CHECK(a_scale->dim() == 2);
+    TORCH_CHECK(a_scale->size(0) == a_rows);
+    TORCH_CHECK(a_scale->size(1) == 1);
+  }
+  if (b_scale.has_value()) {
+    TORCH_CHECK(b_scale->is_contiguous());
+    TORCH_CHECK(b_scale->scalar_type() == at::ScalarType::Float);
+    TORCH_CHECK(b_scale->device() == tensor_b.device());
+    TORCH_CHECK(b_scale->dim() == 2);
+    TORCH_CHECK(b_scale->size(0) == 1);
+    TORCH_CHECK(b_scale->size(1) == tensor_b.size(1));
+  }
+
+  typename K::GemmKernel::Arguments args;
+  args.mode = cutlass::gemm::GemmUniversalMode::kGemm;
+  args.problem_shape = cute::make_shape(
+      int(a_rows), int(tensor_b.size(1)), int(tensor_b.size(0)), 1);
+  Tensor out = tensor_a.new_empty(
+      {cute::get<0>(args.problem_shape), cute::get<1>(args.problem_shape)},
+      at::TensorOptions().dtype(K::kElementOutAt));
+
+  args.mainloop.ptr_A =
+      reinterpret_cast<typename K::ElementA const*>(tensor_a.data_ptr());
+  args.mainloop.ptr_B =
+      static_cast<typename K::ElementB const*>(tensor_b.data_ptr());
+  args.mainloop.ptr_E =
+      reinterpret_cast<typename K::ElementE const*>(tensor_e.data_ptr());
+  args.epilogue.ptr_C = nullptr;
+  args.epilogue.ptr_D =
+      static_cast<typename K::ElementOut*>(out.data_ptr());
+
+  float const* a_scale_ptr =
+      (float const*)(a_scale.has_value() ? a_scale->data_ptr() : nullptr);
+  float const* b_scale_ptr =
+      (float const*)(b_scale.has_value() ?
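+          // (A null scale pointer falls back to `default_scale` (1.0f)
+          //  below. The assignments that follow walk the EVT tree
+          //  Cast(Mul(BScale, Mul(AScale, Accum))): op_0/op_1 select each
+          //  node's children.)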
b_scale->data_ptr() : nullptr); + float default_scale = 1.0f; // used if ptr is nullptr + auto& cast_op = args.epilogue.thread; + auto& mulB_op = cast_op.op_0; + mulB_op.op_0 = {b_scale_ptr, default_scale}; + auto& mulA_op = mulB_op.op_1; + mulA_op.op_0 = {a_scale_ptr, default_scale}; + + args.mainloop.layout_a = + K::CollectiveMainloop::SparseConfig::fill_layoutA(args.problem_shape); + args.mainloop.layout_e = + K::CollectiveMainloop::SparseConfig::fill_layoutE(args.problem_shape); + args.mainloop.dB = cute::make_int_tuple_from( + tensor_b.stride(1), 0); + args.epilogue.dC = cute::make_int_tuple_from( + out.stride(0), 0); + args.epilogue.dD = cute::make_int_tuple_from( + out.stride(0), 0); + + /* Query device SM count to pass onto the kernel as an argument, where needed + */ + args.hw_info.device_id = tensor_a.device().index(); + args.hw_info.sm_count = sm_count; + args.scheduler.max_swizzle_size = swizzle_size; + using Enum_t = decltype(args.scheduler.raster_order); + if (swizzle_axis == "n") { + args.scheduler.raster_order = Enum_t::AlongN; + } else { + TORCH_CHECK( + swizzle_axis == "m", + "Invalid value for swizzle_axis ('", + swizzle_axis, + "')"); + args.scheduler.raster_order = Enum_t::AlongM; + } + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + int64_t device_op_workspace_size = Gemm::get_workspace_size(args); + Tensor workspace = tensor_a.new_empty( + {device_op_workspace_size}, + at::TensorOptions().dtype(at::ScalarType::Byte)); + + Gemm gemm; + // Check the problem size is supported or not + CUTLASS_STATUS_CHECK(gemm.can_implement(args)); + + auto status = gemm.run( + args, (void*)workspace.data_ptr(), at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + return out; +} + +template +std::tuple _sparse24_sm90_cutlass_compress_t(Tensor a) { + std::optional device_guard; + if (!kIsMeta) { + device_guard.emplace(a.device()); + } + + using K = SparseRowwiseKernel; + TORCH_CHECK(a.scalar_type() == K::kElementAAt); + TORCH_CHECK(a.is_contiguous()); + + // Offline compressor kernel + using LayoutA = cutlass::layout::RowMajor; + using ProblemShape = cute::Shape; + using SparseConfig = typename K::CollectiveMainloop::SparseConfig; + using CompressorUtility = + cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, + typename K::ElementA, + LayoutA, + SparseConfig>; + + using CompressorKernel = + cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, + typename K::ElementA, + LayoutA, + SparseConfig, + cutlass::arch::Sm90>; + + using Compressor = + cutlass::transform::device::TransformUniversalAdapter; + + auto problem_shape = + cute::make_shape(int(a.size(0)), 8192, int(a.size(1)), 1); + auto [M, N, k, L] = problem_shape; + auto stride_A = cutlass::make_cute_packed_stride( + cutlass::gemm::TagToStrideA_t{}, cute::make_shape(M, k, L)); + CompressorUtility compressor_utility(problem_shape, stride_A); + + int ME = compressor_utility.get_metadata_m_physical(); + int KE = compressor_utility.get_metadata_k_physical(); + int KC = compressor_utility.get_tensorA_k_physical(); + + auto a_compressed = a.new_empty({M, KC * L}); + auto e = a.new_empty({ME * KE * L}, at::TensorOptions().dtype(at::kByte)); + + if (kIsMeta) { + return std::make_tuple(a_compressed, e); + } + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = a.device().index(); + hw_info.sm_count = 128; + typename Compressor::Arguments arguments{ + problem_shape, + {(typename K::ElementA const*)a.data_ptr(), + stride_A, + 
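+       // Compressor outputs: A repacked to its physical K extent (KC) and
+       // byte metadata `e` recording which 2-of-4 positions were kept.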
(typename K::ElementA*)a_compressed.data_ptr(), + (typename K::ElementE*)e.data_ptr()}, + {hw_info}}; + + Compressor compressor_op; + int64_t workspace_size = Compressor::get_workspace_size(arguments); + Tensor workspace = a.new_empty( + {workspace_size}, at::TensorOptions().dtype(at::ScalarType::Byte)); + + CUTLASS_STATUS_CHECK(compressor_op.can_implement(arguments)); + CUTLASS_STATUS_CHECK( + compressor_op.initialize(arguments, workspace.data_ptr())); + CUTLASS_STATUS_CHECK(compressor_op.run()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return std::make_tuple(a_compressed, e); +} + +template +std::tuple _sparse24_sm90_cutlass_compress(Tensor a) { + if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) { + return _sparse24_sm90_cutlass_compress_t(a); + } + if (a.scalar_type() == at::ScalarType::BFloat16) { + return _sparse24_sm90_cutlass_compress_t(a); + } + TORCH_CHECK(false, "Unsupported dtype for operand"); +} +} // namespace + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchao::_sparse24_fp8_sm90_cutlass_gemm"), + TORCH_FN(_sparse24_fp8_sm90_cutlass_gemm)); + m.impl( + TORCH_SELECTIVE_NAME("torchao::_sparse24_sm90_cutlass_compress"), + TORCH_FN(_sparse24_sm90_cutlass_compress)); +} + +TORCH_LIBRARY_IMPL(torchao, Meta, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchao::_sparse24_fp8_sm90_cutlass_gemm"), + TORCH_FN(_sparse24_fp8_sm90_cutlass_gemm)); + m.impl( + TORCH_SELECTIVE_NAME("torchao::_sparse24_sm90_cutlass_compress"), + TORCH_FN(_sparse24_sm90_cutlass_compress)); +} diff --git a/torchao/ops.py b/torchao/ops.py index faebdbd5d1..60ce3f84ea 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -42,6 +42,12 @@ lib.define( "sparse24_sm90_sparsify(Tensor input, str metadata_fmt, str activation, str sp_selection_algo, *, ScalarType? dtype = None, Tensor? scale=None) -> (Tensor, Tensor)" ) +lib.define( + "_sparse24_fp8_sm90_cutlass_gemm(Tensor a, Tensor a_mdata, Tensor b, *, Tensor? a_scale = None, Tensor? 
b_scale = None, int swizzle_size=8, str swizzle_axis='n', int sm_count=128) -> Tensor" +) +lib.define( + "_sparse24_sm90_cutlass_compress(Tensor a) -> (Tensor, Tensor)" +) lib.define( "swizzle_mm(Tensor mat1, Tensor mat2, bool mat1_is_swizzled, bool mat2_is_swizzled) -> Tensor" ) @@ -839,6 +845,17 @@ def sparse24_sm90_sparsify( input_tensor, metadata_format, activation, algorithm, dtype=dtype, scale=scale ) +def _sparse24_fp8_sm90_cutlass_gemm( + a: Tensor, + meta: Tensor, + b: Tensor, + a_scale: Optional[Tensor], + b_scale: Optional[Tensor], + swizzle_size: int, + swizzle_axis: str, + sm_count: int, +) -> Tensor: + return torch.ops.torchao._sparse24_fp8_sm90_cutlass_gemm(a, meta, b, a_scale=a_scale, b_scale=b_scale, swizzle_size=swizzle_size, swizzle_axis=swizzle_axis, sm_count=sm_count) def swizzle_mm( mat1: Tensor, mat2: Tensor, mat1_is_swizzled: bool, mat2_is_swizzled: bool From 6d85b5c3b876c0eba9432612c93ba2234f141231 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 22 May 2025 05:48:44 -0700 Subject: [PATCH 2/7] remove cutlass compression --- test/sparsity/test_activation24.py | 4 +- torchao/csrc/cuda/activation24/sparse_gemm.cu | 98 +------------------ torchao/ops.py | 9 +- 3 files changed, 7 insertions(+), 104 deletions(-) diff --git a/test/sparsity/test_activation24.py b/test/sparsity/test_activation24.py index 37374291e4..2ad18a2684 100644 --- a/test/sparsity/test_activation24.py +++ b/test/sparsity/test_activation24.py @@ -163,7 +163,7 @@ def test_sparse24_fp8_sm90_cutlass_gemm_eye( # Check MM without scale eye = torch.eye(A.shape[1], device=A.device, dtype=A.dtype).T - A_reconstructed = torch.ops.torchao._sparse24_fp8_sm90_cutlass_gemm( + A_reconstructed = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( A_packed, A_mdata, eye ) assert torch.allclose(A.float(), A_reconstructed.float()) @@ -201,7 +201,7 @@ def _to_fp8_rowwise(x: torch.Tensor, dtype): b_scale = b_scale.T A_packed, A_mdata = to_sparse_semi_structured_cutlass_sm9x_f8(A) - out_sparse = torch.ops.torchao._sparse24_fp8_sm90_cutlass_gemm( + out_sparse = torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( A_packed, A_mdata, B, a_scale=a_scale, b_scale=b_scale ) out_ref = torch._scaled_mm( diff --git a/torchao/csrc/cuda/activation24/sparse_gemm.cu b/torchao/csrc/cuda/activation24/sparse_gemm.cu index ae907a68ec..89a78cc279 100644 --- a/torchao/csrc/cuda/activation24/sparse_gemm.cu +++ b/torchao/csrc/cuda/activation24/sparse_gemm.cu @@ -17,7 +17,6 @@ #include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/numeric_types.h" #include "cutlass/transform/device/transform_universal_adapter.hpp" -#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" #include #include @@ -330,109 +329,16 @@ Tensor _sparse24_fp8_sm90_cutlass_gemm( C10_CUDA_KERNEL_LAUNCH_CHECK(); return out; } - -template -std::tuple _sparse24_sm90_cutlass_compress_t(Tensor a) { - std::optional device_guard; - if (!kIsMeta) { - device_guard.emplace(a.device()); - } - - using K = SparseRowwiseKernel; - TORCH_CHECK(a.scalar_type() == K::kElementAAt); - TORCH_CHECK(a.is_contiguous()); - - // Offline compressor kernel - using LayoutA = cutlass::layout::RowMajor; - using ProblemShape = cute::Shape; - using SparseConfig = typename K::CollectiveMainloop::SparseConfig; - using CompressorUtility = - cutlass::transform::kernel::StructuredSparseCompressorUtility< - ProblemShape, - typename K::ElementA, - LayoutA, - SparseConfig>; - - using CompressorKernel = - cutlass::transform::kernel::StructuredSparseCompressor< - ProblemShape, - typename 
K::ElementA, - LayoutA, - SparseConfig, - cutlass::arch::Sm90>; - - using Compressor = - cutlass::transform::device::TransformUniversalAdapter; - - auto problem_shape = - cute::make_shape(int(a.size(0)), 8192, int(a.size(1)), 1); - auto [M, N, k, L] = problem_shape; - auto stride_A = cutlass::make_cute_packed_stride( - cutlass::gemm::TagToStrideA_t{}, cute::make_shape(M, k, L)); - CompressorUtility compressor_utility(problem_shape, stride_A); - - int ME = compressor_utility.get_metadata_m_physical(); - int KE = compressor_utility.get_metadata_k_physical(); - int KC = compressor_utility.get_tensorA_k_physical(); - - auto a_compressed = a.new_empty({M, KC * L}); - auto e = a.new_empty({ME * KE * L}, at::TensorOptions().dtype(at::kByte)); - - if (kIsMeta) { - return std::make_tuple(a_compressed, e); - } - - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = a.device().index(); - hw_info.sm_count = 128; - typename Compressor::Arguments arguments{ - problem_shape, - {(typename K::ElementA const*)a.data_ptr(), - stride_A, - (typename K::ElementA*)a_compressed.data_ptr(), - (typename K::ElementE*)e.data_ptr()}, - {hw_info}}; - - Compressor compressor_op; - int64_t workspace_size = Compressor::get_workspace_size(arguments); - Tensor workspace = a.new_empty( - {workspace_size}, at::TensorOptions().dtype(at::ScalarType::Byte)); - - CUTLASS_STATUS_CHECK(compressor_op.can_implement(arguments)); - CUTLASS_STATUS_CHECK( - compressor_op.initialize(arguments, workspace.data_ptr())); - CUTLASS_STATUS_CHECK(compressor_op.run()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - return std::make_tuple(a_compressed, e); -} - -template -std::tuple _sparse24_sm90_cutlass_compress(Tensor a) { - if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) { - return _sparse24_sm90_cutlass_compress_t(a); - } - if (a.scalar_type() == at::ScalarType::BFloat16) { - return _sparse24_sm90_cutlass_compress_t(a); - } - TORCH_CHECK(false, "Unsupported dtype for operand"); -} } // namespace TORCH_LIBRARY_IMPL(torchao, CUDA, m) { m.impl( - TORCH_SELECTIVE_NAME("torchao::_sparse24_fp8_sm90_cutlass_gemm"), + TORCH_SELECTIVE_NAME("torchao::sparse24_fp8_sm90_cutlass_gemm"), TORCH_FN(_sparse24_fp8_sm90_cutlass_gemm)); - m.impl( - TORCH_SELECTIVE_NAME("torchao::_sparse24_sm90_cutlass_compress"), - TORCH_FN(_sparse24_sm90_cutlass_compress)); } TORCH_LIBRARY_IMPL(torchao, Meta, m) { m.impl( - TORCH_SELECTIVE_NAME("torchao::_sparse24_fp8_sm90_cutlass_gemm"), + TORCH_SELECTIVE_NAME("torchao::sparse24_fp8_sm90_cutlass_gemm"), TORCH_FN(_sparse24_fp8_sm90_cutlass_gemm)); - m.impl( - TORCH_SELECTIVE_NAME("torchao::_sparse24_sm90_cutlass_compress"), - TORCH_FN(_sparse24_sm90_cutlass_compress)); } diff --git a/torchao/ops.py b/torchao/ops.py index 60ce3f84ea..4b3f6ad525 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -43,10 +43,7 @@ "sparse24_sm90_sparsify(Tensor input, str metadata_fmt, str activation, str sp_selection_algo, *, ScalarType? dtype = None, Tensor? scale=None) -> (Tensor, Tensor)" ) lib.define( - "_sparse24_fp8_sm90_cutlass_gemm(Tensor a, Tensor a_mdata, Tensor b, *, Tensor? a_scale = None, Tensor? b_scale = None, int swizzle_size=8, str swizzle_axis='n', int sm_count=128) -> Tensor" -) -lib.define( - "_sparse24_sm90_cutlass_compress(Tensor a) -> (Tensor, Tensor)" + "sparse24_fp8_sm90_cutlass_gemm(Tensor a, Tensor a_mdata, Tensor b, *, Tensor? a_scale = None, Tensor? 
b_scale = None, int swizzle_size=8, str swizzle_axis='n', int sm_count=128) -> Tensor" ) lib.define( "swizzle_mm(Tensor mat1, Tensor mat2, bool mat1_is_swizzled, bool mat2_is_swizzled) -> Tensor" @@ -845,7 +842,7 @@ def sparse24_sm90_sparsify( input_tensor, metadata_format, activation, algorithm, dtype=dtype, scale=scale ) -def _sparse24_fp8_sm90_cutlass_gemm( +def sparse24_fp8_sm90_cutlass_gemm( a: Tensor, meta: Tensor, b: Tensor, @@ -855,7 +852,7 @@ def _sparse24_fp8_sm90_cutlass_gemm( swizzle_axis: str, sm_count: int, ) -> Tensor: - return torch.ops.torchao._sparse24_fp8_sm90_cutlass_gemm(a, meta, b, a_scale=a_scale, b_scale=b_scale, swizzle_size=swizzle_size, swizzle_axis=swizzle_axis, sm_count=sm_count) + return torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(a, meta, b, a_scale=a_scale, b_scale=b_scale, swizzle_size=swizzle_size, swizzle_axis=swizzle_axis, sm_count=sm_count) def swizzle_mm( mat1: Tensor, mat2: Tensor, mat1_is_swizzled: bool, mat2_is_swizzled: bool From d7909086a3a7f1da873e231bfe3472c35d3f343f Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 22 May 2025 05:50:56 -0700 Subject: [PATCH 3/7] ruff fix --- test/sparsity/test_activation24.py | 4 +--- torchao/ops.py | 13 ++++++++++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/test/sparsity/test_activation24.py b/test/sparsity/test_activation24.py index 2ad18a2684..420bf4328a 100644 --- a/test/sparsity/test_activation24.py +++ b/test/sparsity/test_activation24.py @@ -149,7 +149,7 @@ def test_sparse24_fp8_sm90_cutlass_gemm_eye( M=512, K=256, dtype=torch.float8_e4m3fn ) -> None: torch.manual_seed(0) - + A_dense = create_semi_structured_tensor(M, K, dtype=torch.bfloat16).cuda() A_aqt = _float8_cutlass_quant(A_dense, dtype) A = A_aqt.tensor_impl.float8_data @@ -183,7 +183,6 @@ def test_sparse24_fp8_sm90_cutlass_gemm_eye( def test_sparse24_fp8_sm90_cutlass_gemm_random_tensor( M=512, N=1024, K=256, dtype=torch.float8_e4m3fn ) -> None: - def _to_fp8_rowwise(x: torch.Tensor, dtype): max_v = torch.finfo(dtype).max x_scale = (x.abs().max(1, keepdim=True)[0] / max_v).float() @@ -208,4 +207,3 @@ def _to_fp8_rowwise(x: torch.Tensor, dtype): A, B, scale_a=a_scale, scale_b=b_scale, out_dtype=out_sparse.dtype ) assert torch.allclose(out_sparse, out_ref, rtol=0.01, atol=0.01) - diff --git a/torchao/ops.py b/torchao/ops.py index 4b3f6ad525..b91bb8ae18 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -842,6 +842,7 @@ def sparse24_sm90_sparsify( input_tensor, metadata_format, activation, algorithm, dtype=dtype, scale=scale ) + def sparse24_fp8_sm90_cutlass_gemm( a: Tensor, meta: Tensor, @@ -852,7 +853,17 @@ def sparse24_fp8_sm90_cutlass_gemm( swizzle_axis: str, sm_count: int, ) -> Tensor: - return torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm(a, meta, b, a_scale=a_scale, b_scale=b_scale, swizzle_size=swizzle_size, swizzle_axis=swizzle_axis, sm_count=sm_count) + return torch.ops.torchao.sparse24_fp8_sm90_cutlass_gemm( + a, + meta, + b, + a_scale=a_scale, + b_scale=b_scale, + swizzle_size=swizzle_size, + swizzle_axis=swizzle_axis, + sm_count=sm_count, + ) + def swizzle_mm( mat1: Tensor, mat2: Tensor, mat1_is_swizzled: bool, mat2_is_swizzled: bool From af3e70df7dbae956d6193bbaf820e5d273bfeb29 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 22 May 2025 05:51:57 -0700 Subject: [PATCH 4/7] one more ruff fix --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 790f241d52..0915f6ae1e 100644 --- a/setup.py +++ b/setup.py @@ -433,7 +433,7 @@ def get_extensions(): 
"to_sparse_semi_structured_cutlass_sm9x_f8.cu", ), os.path.join(extensions_cuda_dir, "activation24", "sparsify24.cu"), - os.path.join(extensions_cuda_dir, "activation24", "sparse_gemm.cu") + os.path.join(extensions_cuda_dir, "activation24", "sparse_gemm.cu"), ] for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]: cutlass_90a_sources.append( From 8b1e8ffe13b7029e58f0f1f9fced66009e0efba0 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 22 May 2025 07:19:50 -0700 Subject: [PATCH 5/7] don't build for CUDA 11.8 --- setup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 0915f6ae1e..e181c8cf16 100644 --- a/setup.py +++ b/setup.py @@ -433,7 +433,9 @@ def get_extensions(): "to_sparse_semi_structured_cutlass_sm9x_f8.cu", ), os.path.join(extensions_cuda_dir, "activation24", "sparsify24.cu"), - os.path.join(extensions_cuda_dir, "activation24", "sparse_gemm.cu"), + # 11.8 not supported + if torch.version.cuda >= "12.0": + os.path.join(extensions_cuda_dir, "activation24", "sparse_gemm.cu"), ] for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]: cutlass_90a_sources.append( From 618cd85af2c33786ee67d40c7ddb961cd73c1beb Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 22 May 2025 07:22:05 -0700 Subject: [PATCH 6/7] fix formatting --- setup.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index e181c8cf16..dee34c35c0 100644 --- a/setup.py +++ b/setup.py @@ -433,10 +433,12 @@ def get_extensions(): "to_sparse_semi_structured_cutlass_sm9x_f8.cu", ), os.path.join(extensions_cuda_dir, "activation24", "sparsify24.cu"), - # 11.8 not supported - if torch.version.cuda >= "12.0": - os.path.join(extensions_cuda_dir, "activation24", "sparse_gemm.cu"), ] + # 11.8 not supported + if torch.version.cuda >= "12.0": + cutlass_90a_sources.append( + os.path.join(extensions_cuda_dir, "activation24", "sparse_gemm.cu"), + ) for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]: cutlass_90a_sources.append( os.path.join( From e17ebfdeb5a967b7585c6f07cddaecf1bf1eb160 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 22 May 2025 07:41:27 -0700 Subject: [PATCH 7/7] ifdef to avoid issues --- setup.py | 6 +----- torchao/csrc/cuda/activation24/sparse_gemm.cu | 7 +++++++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index dee34c35c0..0915f6ae1e 100644 --- a/setup.py +++ b/setup.py @@ -433,12 +433,8 @@ def get_extensions(): "to_sparse_semi_structured_cutlass_sm9x_f8.cu", ), os.path.join(extensions_cuda_dir, "activation24", "sparsify24.cu"), + os.path.join(extensions_cuda_dir, "activation24", "sparse_gemm.cu"), ] - # 11.8 not supported - if torch.version.cuda >= "12.0": - cutlass_90a_sources.append( - os.path.join(extensions_cuda_dir, "activation24", "sparse_gemm.cu"), - ) for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]: cutlass_90a_sources.append( os.path.join( diff --git a/torchao/csrc/cuda/activation24/sparse_gemm.cu b/torchao/csrc/cuda/activation24/sparse_gemm.cu index 89a78cc279..776766794e 100644 --- a/torchao/csrc/cuda/activation24/sparse_gemm.cu +++ b/torchao/csrc/cuda/activation24/sparse_gemm.cu @@ -4,6 +4,12 @@ #include #include +#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \ + defined(CUDA_VERSION) && (CUDA_VERSION >= 12020) +#define BUILD_SM90_24_FP8_CUTLASS_GEMM +#endif + +#if defined(BUILD_SM90_24_FP8_CUTLASS_GEMM) #include #include #include @@ -342,3 +348,4 @@ TORCH_LIBRARY_IMPL(torchao, Meta, m) { 
TORCH_SELECTIVE_NAME("torchao::sparse24_fp8_sm90_cutlass_gemm"), TORCH_FN(_sparse24_fp8_sm90_cutlass_gemm)); } +#endif