From 980b5e6c257c4147ad7508225162a4b1e67d75d2 Mon Sep 17 00:00:00 2001 From: kushanam Date: Mon, 24 Feb 2025 18:20:14 -0800 Subject: [PATCH 1/9] add cutlass support for blackwell fp8 gemm --- CMakeLists.txt | 16 +- .../scaled_mm_epilogues_c3x_blackwell.hpp | 384 ++++++++++++++++++ .../cutlass_w8a8/c3x/cutlass_gemm_caller.cuh | 7 +- .../cutlass_w8a8/c3x/scaled_mm.cuh | 64 +++ .../cutlass_w8a8/c3x/scaled_mm_kernels.hpp | 6 + .../cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu | 25 ++ .../c3x/scaled_mm_sm100_fp8_dispatch.cuh | 67 +++ .../cutlass_w8a8/scaled_mm_c3x.cu | 21 + .../cutlass_w8a8/scaled_mm_entry.cu | 12 +- 9 files changed, 591 insertions(+), 11 deletions(-) create mode 100644 csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x_blackwell.hpp create mode 100644 csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu create mode 100644 csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh diff --git a/CMakeLists.txt b/CMakeLists.txt index c5fc2f3c1aaf..0607156326b2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,7 +31,7 @@ set(ignoreMe "${VLLM_PYTHON_PATH}") set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12") # Supported NVIDIA architectures. -set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0") +set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0") # Supported AMD GPU architectures. set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101") @@ -297,7 +297,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Only build Marlin kernels if we are building for at least some compatible archs. # Keep building Marlin for 9.0 as there are some group sizes and shapes that # are not supported by Machete yet. - cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") if (MARLIN_ARCHS) set(MARLIN_SRCS "csrc/quantization/fp8/fp8_marlin.cu" @@ -335,14 +335,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require # CUDA 12.0 or later (and only work on Hopper, 9.0a for now). - cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a;10.0a;10.1a;12.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu" - "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu") + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu" + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") @@ -369,7 +370,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) # kernels for the remaining archs that are not already built for 3x. cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS - "7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}") + "7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") # subtract out the archs that are already built for 3x list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) if (SCALED_MM_2X_ARCHS) @@ -394,7 +395,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # 2:4 Sparse Kernels # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor - # require CUDA 12.2 or later (and only work on Hopper, 9.0a for now). + # require CUDA 12.2 or later (and only work on Hopper and Blackwell). if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS) set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") set_gencode_flags_for_srcs( @@ -514,6 +515,7 @@ define_gpu_extension_target( COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) @@ -537,7 +539,7 @@ set_gencode_flags_for_srcs( CUDA_ARCHS "${CUDA_ARCHS}") if(VLLM_GPU_LANG STREQUAL "CUDA") - cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") if (MARLIN_MOE_ARCHS) set(MARLIN_MOE_SRC "csrc/moe/marlin_kernels/marlin_moe_kernel.h" diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x_blackwell.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x_blackwell.hpp new file mode 100644 index 000000000000..3df1ed090935 --- /dev/null +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x_blackwell.hpp @@ -0,0 +1,384 @@ +#pragma once + +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" + +/* + This file defines custom epilogues for fusing channel scales, token scales, + bias, and activation zero-points onto a GEMM operation using the + CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later. + + Epilogues must contain a public type named EVTCompute of type Sm90EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. +*/ + +namespace vllm::c3x_blackwell { + +using namespace cute; + +template +struct identity { + CUTLASS_HOST_DEVICE + T operator()(T lhs) const { return lhs; } +}; + +template +struct TrivialEpilogue { + private: + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + using Compute = cutlass::epilogue::fusion::Sm90Compute< + cutlass::epilogue::thread::Identity, ElementD, ElementAcc, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + template + static ArgumentType prepare_args(Args... args) { + return {}; + } +}; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] classes + */ +template +struct ScaledEpilogueBase { + protected: + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + template + using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< + 0 /*Stages*/, TileShape, T, + Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< + 0 /*Stages*/, TileShape, T, + Stride, Int<1>, Int<0>>>; + + // Don't want to support nullptr by default + template + using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, TileShape, T, T, + Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + + // Don't want to support nullptr by default + template + using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< + 0 /*Stages*/, TileShape, T, T, + Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. + template + static auto args_from_tensor(torch::Tensor const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = static_cast(tensor.data_ptr()); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } else { + static_assert(!std::is_same_v> && + !std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. + template + static auto args_from_tensor(std::optional const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; + static_assert(std::is_same_v> || + std::is_same_v>); + return Arguments{data_ptr}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + torch.scaled_mm_. + + A and B may be both either int8 or fp8_e4m3. A can be + quantized per-tensor or per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, {}}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. + */ +template +struct ScaledEpilogueBias + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, bias_args, {}}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogueBias, but the + * bias is a column vector instead of a row vector. Useful e.g. if we are + * computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels. + */ +template +struct ScaledEpilogueColumnBias + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template ColLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, bias_args, {}}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzp + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + std::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; + } +}; + +/* + * This epilogue supports per-token azp by computing and applying + * the correction term using a rank-1 update. If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzpToken + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + torch::Tensor const& azp, + std::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; + } +}; + +}; // namespace vllm::c3x diff --git a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh index 69a3f64cb0b0..c6aad5b5643a 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh @@ -16,6 +16,7 @@ #include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/util/packed_stride.hpp" #include "core/math.hpp" #include "cutlass_extensions/common.hpp" @@ -73,11 +74,13 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, using StrideA = cute::Stride, int64_t>; using StrideB = cute::Stride, int64_t>; - using StrideC = typename Gemm::StrideC; + using StrideC = typename Gemm::GemmKernel::StrideC; StrideA a_stride{lda, cute::Int<1>{}, 0}; StrideB b_stride{ldb, cute::Int<1>{}, 0}; - StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}}; + // StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}}; + StrideC c_stride = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(ldc, 1, 0)); typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b); diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh index d2f43e2b7a89..ffbab00e2a4f 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh @@ -88,4 +88,68 @@ struct cutlass_3x_gemm { struct GemmKernel : public KernelType {}; }; +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule> +struct cutlass_3x_gemm_sm100 { + using ElementAB = ElementAB_; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = + 128 / cutlass::sizeof_bits::value; + + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = + 128 / cutlass::sizeof_bits::value; + + using ElementD = ElementD_; + using LayoutD = cutlass::layout::ColumnMajor; + static constexpr int AlignmentD = + 128 / cutlass::sizeof_bits::value; + + using ElementC = void; + using LayoutC = LayoutD; + static constexpr int AlignmentC = AlignmentD; + + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + using Epilogue = Epilogue_; + + // MMA type + using ElementAccumulator = float; + + // Epilogue types + using ElementBias = cutlass::half_t; + using ElementCompute = float; + using ElementAux = ElementD; + using LayoutAux = LayoutD; + using ElementAmax = float; + + using StrideD = Stride, Int<0>>; + using StrideC = StrideD; + + using EVTCompute = typename Epilogue::EVTCompute; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, + ElementD, LayoutC, AlignmentD, EpilogueSchedule, + EVTCompute>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB, + LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, CollectiveMainloop, CollectiveEpilogue, void>; +}; + } // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp index 7ede9e067477..85272804774d 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp @@ -30,4 +30,10 @@ void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out, torch::Tensor const& a_scales, torch::Tensor const& b_scales); +void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias); + } // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu new file mode 100644 index 000000000000..0e65ad17ca70 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu @@ -0,0 +1,25 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_sm100_fp8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x_blackwell.hpp" + +namespace vllm { + +void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias) { + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + if (bias) { + TORCH_CHECK(bias->dtype() == out.dtype(), + "currently bias dtype must match output dtype ", out.dtype()); + return cutlass_scaled_mm_sm100_fp8_epilogue< + c3x_blackwell::ScaledEpilogueBias>(out, a, b, a_scales, b_scales, + *bias); + } else { + return cutlass_scaled_mm_sm100_fp8_epilogue( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh new file mode 100644 index 000000000000..2a2cb7393a7c --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh @@ -0,0 +1,67 @@ +#pragma once + +#include "scaled_mm.cuh" +#include "cutlass_gemm_caller.cuh" + +/** + * This file defines Gemm kernel configurations for SM100 (fp8) based on the + * Gemm shape. + */ + +namespace vllm { + +using c3x::cutlass_gemm_caller; + +template typename Epilogue> +struct sm100_fp8_config_default { + static_assert(std::is_same()); + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_128, _128, _64>; + using ClusterShape = Shape<_2, _2, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm_sm100; +}; + +template typename Epilogue, + typename... EpilogueArgs> +inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + + using Cutlass3xGemmDefault = + typename sm100_fp8_config_default::Cutlass3xGemm; + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); +} + +template