From ad66935e724f0ffd52ed9a49b50a57ba3fff647b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 31 Jan 2024 10:12:46 +0000 Subject: [PATCH 01/33] Add CUTLASS as a submodule --- .gitmodules | 3 +++ third_party/cutlass | 1 + 2 files changed, 4 insertions(+) create mode 100644 .gitmodules create mode 160000 third_party/cutlass diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000000..281cb2d85d91 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/cutlass"] + path = third_party/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/third_party/cutlass b/third_party/cutlass new file mode 160000 index 000000000000..39c6a83f231d --- /dev/null +++ b/third_party/cutlass @@ -0,0 +1 @@ +Subproject commit 39c6a83f231d6db2bc6b9c251e7add77d68cbfb4 From 396e537d26f0986be0426da1fb983ad14df62beb Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 31 Jan 2024 10:19:46 +0000 Subject: [PATCH 02/33] Port CUTLASS extensions --- csrc/cutlass_extensions/arch/mma.h | 120 +++ csrc/cutlass_extensions/compute_occupancy.h | 61 ++ .../epilogue/thread/fused_activations.h | 105 +++ .../epilogue_per_row_per_col_scale.h | 352 +++++++++ .../threadblock/epilogue_tensor_op_int32.h | 282 +++++++ csrc/cutlass_extensions/epilogue_helpers.h | 139 ++++ .../gemm/device/gemm_universal_base_compat.h | 438 +++++++++++ .../gemm/kernel/default_fpA_intB_traits.h | 138 ++++ .../gemm/kernel/default_int8_traits.h | 57 ++ .../gemm/kernel/fpA_intB_gemm.h | 574 +++++++++++++++ .../gemm/kernel/gemm_moe_problem_visitor.h | 73 ++ .../gemm/kernel/gemm_with_epilogue_visitor.h | 545 ++++++++++++++ .../gemm/kernel/mixed_gemm_B_layout.h | 114 +++ .../gemm/kernel/moe_cutlass_kernel.h | 526 +++++++++++++ .../gemm/kernel/moe_problem_visitor.h | 344 +++++++++ .../gemm/threadblock/default_dq_mma.h | 125 ++++ .../threadblock/default_dq_mma_multistage.h | 297 ++++++++ .../threadblock/default_dq_mma_pipelined.h | 249 +++++++ .../gemm/threadblock/default_mma.h | 290 ++++++++ .../gemm/threadblock/default_mma_bf16.h | 353 +++++++++ .../gemm/threadblock/dq_mma_base.h | 257 +++++++ .../gemm/threadblock/dq_mma_multistage.h | 110 +++ .../dq_mma_multistage_finegrained.h | 691 ++++++++++++++++++ .../threadblock/dq_mma_multistage_percol.h | 636 ++++++++++++++++ .../gemm/threadblock/dq_mma_pipelined.h | 397 ++++++++++ .../gemm/warp/default_mma_tensor_op.h | 107 +++ .../warp/mma_tensorop_compute_B_with_f16.h | 301 ++++++++ .../gemm/warp/mma_tensorop_dequantizer.h | 646 ++++++++++++++++ csrc/cutlass_extensions/gemm_configs.h | 72 ++ .../interleaved_numeric_conversion.h | 447 +++++++++++ .../tile_interleaved_layout.h | 66 ++ .../fine_grained_scale_zero_iterator.h | 248 +++++++ .../cutlass_extensions/weight_only_quant_op.h | 58 ++ 33 files changed, 9218 insertions(+) create mode 100644 csrc/cutlass_extensions/arch/mma.h create mode 100644 csrc/cutlass_extensions/compute_occupancy.h create mode 100644 csrc/cutlass_extensions/epilogue/thread/fused_activations.h create mode 100644 csrc/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h create mode 100644 csrc/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h create mode 100644 csrc/cutlass_extensions/epilogue_helpers.h create mode 100644 csrc/cutlass_extensions/gemm/device/gemm_universal_base_compat.h create mode 100644 csrc/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h create mode 100644 csrc/cutlass_extensions/gemm/kernel/default_int8_traits.h create mode 100644 
csrc/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h create mode 100644 csrc/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h create mode 100644 csrc/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h create mode 100644 csrc/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h create mode 100644 csrc/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h create mode 100644 csrc/cutlass_extensions/gemm/kernel/moe_problem_visitor.h create mode 100644 csrc/cutlass_extensions/gemm/threadblock/default_dq_mma.h create mode 100644 csrc/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h create mode 100644 csrc/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h create mode 100644 csrc/cutlass_extensions/gemm/threadblock/default_mma.h create mode 100644 csrc/cutlass_extensions/gemm/threadblock/default_mma_bf16.h create mode 100644 csrc/cutlass_extensions/gemm/threadblock/dq_mma_base.h create mode 100644 csrc/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h create mode 100644 csrc/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h create mode 100644 csrc/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h create mode 100644 csrc/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h create mode 100644 csrc/cutlass_extensions/gemm/warp/default_mma_tensor_op.h create mode 100644 csrc/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h create mode 100644 csrc/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h create mode 100644 csrc/cutlass_extensions/gemm_configs.h create mode 100644 csrc/cutlass_extensions/interleaved_numeric_conversion.h create mode 100644 csrc/cutlass_extensions/tile_interleaved_layout.h create mode 100644 csrc/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h create mode 100644 csrc/cutlass_extensions/weight_only_quant_op.h diff --git a/csrc/cutlass_extensions/arch/mma.h b/csrc/cutlass_extensions/arch/mma.h new file mode 100644 index 000000000000..2362da4f7f2d --- /dev/null +++ b/csrc/cutlass_extensions/arch/mma.h @@ -0,0 +1,120 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Templates exposing architecture support for multiply-add operations
+*/
+
+#pragma once
+#include "cutlass_extensions/weight_only_quant_op.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass
+{
+namespace arch
+{
+
+// Tag which triggers an MMA that dequantizes interleaved B to the type of A before the multiply-add.
+struct OpMultiplyAddDequantizeInterleavedBToA;
+
+/*
+  Below we have extra tags to signal what kind of dequantization we want to do
+  (per-column scale, fine-grained scale only, fine-grained scale with zero). This still lets us use
+  the existing template infrastructure (incl. that in CUTLASS). However, we
+  split the tagged operator back into OpMultiplyAddDequantizeInterleavedBToA along
+  with the quantization op before instantiating the GEMM pieces.
+
+  Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of
+  code we need to duplicate.
+ */
+struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
+struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
+struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
+
+// The default just forwards the original operator
+template <typename MmaOp, WeightOnlyQuantOp QuantOp_>
+struct TagOperator
+{
+    using TaggedOperator = MmaOp;
+};
+
+// Specializations below attach more information to the operator
+template <>
+struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>
+{
+    using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
+};
+
+template <>
+struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>
+{
+    using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
+};
+
+template <>
+struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>
+{
+    using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
+};
+
+// Here we instantiate some structs to "detag" the tagged operator. This splits it back into the original
+// operator plus the extra information. If no extra info was tagged, the dequant op defaults to per-column
+// scaling.
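+//
+// As a quick sketch of how the tag/detag pair (below) is meant to round-trip the quantization mode
+// through a single operator type -- all names used here are the ones defined in this header:
+//
+//   using Tagged = typename TagOperator<OpMultiplyAddDequantizeInterleavedBToA,
+//       WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>::TaggedOperator;
+//   static_assert(DetagOperator<Tagged>::QuantOp == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, "");
+//   static_assert(platform::is_same<DetagOperator<Tagged>::Operator,
+//       OpMultiplyAddDequantizeInterleavedBToA>::value, "");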
+template <typename TaggedMmaOp>
+struct DetagOperator
+{
+    using Operator = TaggedMmaOp;
+    static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
+};
+
+template <>
+struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_percol_scale>
+{
+    using Operator = OpMultiplyAddDequantizeInterleavedBToA;
+    static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
+};
+
+template <>
+struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scale>
+{
+    using Operator = OpMultiplyAddDequantizeInterleavedBToA;
+    static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
+};
+
+template <>
+struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias>
+{
+    using Operator = OpMultiplyAddDequantizeInterleavedBToA;
+    static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
+};
+
+} // namespace arch
+} // namespace cutlass
diff --git a/csrc/cutlass_extensions/compute_occupancy.h b/csrc/cutlass_extensions/compute_occupancy.h
new file mode 100644
index 000000000000..23821e1d1008
--- /dev/null
+++ b/csrc/cutlass_extensions/compute_occupancy.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <cuda_runtime_api.h>
+
+#include "cutlass/device_kernel.h"
+#include "tensorrt_llm/common/cudaUtils.h"
+
+namespace tensorrt_llm
+{
+namespace cutlass_extensions
+{
+
+template <typename GemmKernel>
+inline int compute_occupancy_for_kernel()
+{
+
+    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
+
+    if (smem_size > (48 << 10))
+    {
+        cudaFuncAttributes attr;
+        int device = 0;
+        int max_smem_per_block = 0;
+        tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
+        tensorrt_llm::common::check_cuda_error(
+            cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
+        tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::Kernel<GemmKernel>));
+        if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block))
+        {
+            // This should mean that
+            // cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)
+            // wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this
+            // configuration.
+            return 0;
+        }
+    }
+
+    int max_active_blocks = -1;
+    tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));
+
+    return max_active_blocks;
+}
+
+} // namespace cutlass_extensions
+} // namespace tensorrt_llm
diff --git a/csrc/cutlass_extensions/epilogue/thread/fused_activations.h b/csrc/cutlass_extensions/epilogue/thread/fused_activations.h
new file mode 100644
index 000000000000..2ed13dde1920
--- /dev/null
+++ b/csrc/cutlass_extensions/epilogue/thread/fused_activations.h
@@ -0,0 +1,105 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination with a maximum operation used by epilogues. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/half.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace epilogue +{ +namespace thread +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +__forceinline__ __device__ float copysignf_pos(float a, float b) +{ + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; +} + +__forceinline__ __device__ float tanh_opt(float x) +{ +#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) + const float exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); +#else + return fast_tanh(x); +#endif +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +template <> +struct GELU_taylor +{ + static const bool kIsHeavy = true; + + CUTLASS_DEVICE + float operator()(float const& z) const + { + + float k0 = float(0.7978845608028654); + float k1 = float(0.044715); + + return float(cutlass::constants::half() * z + * (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); + } + + using Params = LinearCombinationGenericParams; + + CUTLASS_DEVICE + float operator()(float const& scalar, Params const& params_) const + { + return this->operator()(scalar); + } +}; + +} // namespace thread +} // 
namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h b/csrc/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h new file mode 100644 index 000000000000..1781fc3ac94c --- /dev/null +++ b/csrc/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h @@ -0,0 +1,352 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column. 
+ + original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h + +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_conversion.h" +#include "tensorrt_llm/common/quantization.h" + +namespace tk = tensorrt_llm::common; + +namespace cutlass +{ +namespace epilogue +{ +namespace threadblock +{ + +template +class EpilogueVisitorPerRowPerCol +{ +public: + using ThreadblockShape = ThreadblockShape_; + static int const kThreadCount = ThreadCount; + + using ScaleTileIterator = ScaleTileIterator_; + using OutputTileIterator = OutputTileIterator_; + using ElementwiseFunctor = ElementwiseFunctor_; + + static int const kIterations = OutputTileIterator::kIterations; + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + using ElementOutput = typename OutputTileIterator::Element; + using LayoutOutput = cutlass::layout::RowMajor; + using ElementAccumulator = ElementAccumulator_; + + using AlphaScaleElementType = typename ScaleTileIterator::Element; + + using ElementCompute = ElementCompute_; + using AccumulatorFragment = Array; + using ComputeFragment = Array; + using OutputVector = Array; + + static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; + static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); + + /// Argument structure + struct Arguments + { + + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + Arguments() + : batch_stride_alpha(0) + , batch_stride_C(0) + , batch_stride_D(0) + { + } + + Arguments(typename ElementwiseFunctor::Params elementwise_) + : elementwise(elementwise_) + , batch_stride_alpha(0) + , batch_stride_C(0) + , batch_stride_D(0) + { + } + + Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, + int64_t batch_stride_C_, int64_t batch_stride_D_) + : elementwise(elementwise_) + , batch_stride_alpha(batch_stride_alpha_) + , batch_stride_C(batch_stride_C_) + , batch_stride_D(batch_stride_D_) + { + } + }; + + struct Params + { + + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args) + : elementwise(args.elementwise) + , batch_stride_alpha(args.batch_stride_alpha) + , batch_stride_C(args.batch_stride_C) + , batch_stride_D(args.batch_stride_D) + { + } + }; + + /// Shared storage + struct SharedStorage + { + }; + +private: + Params const& params_; + SharedStorage& shared_storage_; + MatrixCoord extent_; + MatrixCoord extent_real_; + ElementwiseFunctor elementwise_; + + const bool per_token_quant_; + const bool per_channel_quant_; + + AlphaScaleElementType* ptr_alpha_row_; + AlphaScaleElementType* ptr_alpha_col_; + ScaleTileIterator iterator_alpha_col_; + OutputTileIterator iterator_C_; + OutputTileIterator iterator_D_; + + AlphaScaleElementType element_alpha_row_ = 1.0f; + AlphaScaleElementType element_alpha_col_ = 1.0f; + typename ScaleTileIterator::Fragment fragment_alpha_col_; + typename OutputTileIterator::Fragment fragment_C_; + typename OutputTileIterator::Fragment fragment_D_; 
+ + ElementAccumulator beta_; + + int column_offset_; + + MatrixCoord thread_offset_; + +public: + CUTLASS_DEVICE + EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage, + cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx, + typename ScaleTileIterator::Params params_alpha_col, typename OutputTileIterator::Params params_C, + typename OutputTileIterator::Params params_D, tk::QuantMode quant_option, AlphaScaleElementType* ptr_alpha_row, + AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C, + typename OutputTileIterator::Element* ptr_D, + cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), int column_offset = 0, + cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)) + : params_(params) + , shared_storage_(shared_storage) + , extent_(problem_size) + , elementwise_(params.elementwise) + , per_token_quant_(quant_option.hasPerTokenScaling()) + , per_channel_quant_(quant_option.hasPerChannelScaling()) + , ptr_alpha_row_(ptr_alpha_row) + , ptr_alpha_col_(ptr_alpha_col) + , iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset) + , iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset) + , iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset) + , extent_real_(problem_size_real) + { + beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta); + + if (beta_ == ElementAccumulator()) + { + iterator_C_.clear_mask(); + } + + if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) + { + element_alpha_col_ = *ptr_alpha_col_; + } + + if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) + { + element_alpha_row_ = *ptr_alpha_row_; + } + } + + /// Helper to indicate split-K behavior + CUTLASS_DEVICE + void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme + int split_k_slices) + { ///< Total number of split-K slices + } + + /// Called to set the batch index + CUTLASS_DEVICE + void set_batch_index(int batch_idx) + { + iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha); + iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); + iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); + } + + /// Called at the start of the epilogue just before iterating over accumulator slices + CUTLASS_DEVICE + void begin_epilogue() + { + if (per_channel_quant_) + { + iterator_alpha_col_.load(fragment_alpha_col_); + } + } + + /// Called at the start of one step before starting accumulator exchange + CUTLASS_DEVICE + void begin_step(int step_idx) + { + fragment_D_.clear(); + fragment_C_.clear(); + + if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) + { + iterator_C_.load(fragment_C_); + ++iterator_C_; + } + } + + /// Called at the start of a row + CUTLASS_DEVICE + void begin_row(int row_idx) + { + // load alpha_row in begin_step only when per token(row) scaling is used + if (per_token_quant_) + { + int thread_offset_row + = iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row(); + + arch::global_load( + element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row()); + } + } + + /// Called after accumulators have been exchanged for each accumulator vector + CUTLASS_DEVICE + void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) 
+ { + + NumericArrayConverter source_converter; + + ComputeFragment result = source_converter(accum); + if (per_channel_quant_) + { + ComputeFragment alpha_col = reinterpret_cast(&fragment_alpha_col_)[column_idx]; + result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_); + } + else + { + result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_); + } + + // Convert to the output + NumericArrayConverter output_converter; + OutputVector& output = reinterpret_cast(&fragment_D_)[frag_idx]; + output = output_converter(result); + } + + /// Called at the end of a row + CUTLASS_DEVICE + void end_row(int row_idx) {} + + /// Called after all accumulator elements have been visited + CUTLASS_DEVICE + void end_step(int step_idx) + { + + iterator_D_.store(fragment_D_); + ++iterator_D_; + } + + /// Called after all steps have been completed + CUTLASS_DEVICE + void end_epilogue() {} + +private: + CUTLASS_DEVICE + ComputeFragment per_token_channel_scale_accumulator_( + ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) + { + + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) + { + result[i] = accum[i] * (scale_col[i] * scale_row); + } + + return result; + } + + CUTLASS_DEVICE + ComputeFragment per_token_scale_accumulator_( + ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row) + { + + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) + { + result[i] = accum[i] * (scale_col * scale_row); + } + + return result; + } +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/csrc/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h b/csrc/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h new file mode 100644 index 000000000000..6f26d7901703 --- /dev/null +++ b/csrc/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h @@ -0,0 +1,282 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + + original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h + +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/platform/platform.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" +#include "cutlass/epilogue/thread/linear_combination_gelu.h" +#include "cutlass/epilogue/thread/linear_combination_hardswish.h" +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_relu0.h" +#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" + +#include "cutlass/epilogue/thread/conversion_op.h" +#include "cutlass/epilogue/thread/reduction_op.h" + +#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" + +#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" +#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" +#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace epilogue +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +namespace detail +{ + +/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts. 
+template +struct DefaultIteratorsTensorOp +{ + using WarpTileIterator + = cutlass::epilogue::warp::TileIteratorTensorOpMixed; + + using SharedLoadIterator + = cutlass::epilogue::threadblock::SharedLoadIteratorMixed; + + static int const kFragmentsPerIteration = 2; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load output tile from shared memory in epilogue. +/// +/// Satisfies: ReadableTileIterator +/// +template +class SharedLoadIteratorMixed +{ +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = int32_t; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; + + static int const kThreads = ThreadMap::kThreads; + + /// Fragment object + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + /// Vector type used for SMEM loads + using LoadType = AlignedArray::value, ThreadMap::kElementsPerAccess), + const_min(16, kAlignment)>; + + static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; + +private: + // + // Data members + // + + /// Byte-level pointer + LoadType const* pointers_[kLoadsPerAccess]; + + /// Stride along adjacent rows in units of LoadType + int stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + SharedLoadIteratorMixed(TensorRef ref, int thread_idx) + : stride_((ref.stride(0) / LoadType::kElements)) + { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + + // Initialize pointers + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) + { + pointers_[i] = reinterpret_cast(ref.data()); + + int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; + int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess; + + col_idx += (bank_offset + i) % kLoadsPerAccess; + + pointers_[i] += thread_offset.row() * stride_ + col_idx; + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) + { + pointers_[i] += pointer_offset / LoadType::kElements; + } + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& offset) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) + { + pointers_[i] + += offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements; + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const + { + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) + { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) + { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) + { + + int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ + + group * ThreadMap::Delta::kGroup * stride_ + cluster * 
ThreadMap::Delta::kCluster * stride_ + + pointer_offset / LoadType::kElements; + + int frag_row_idx + = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + LoadType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) + { + + int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kLoadsPerAccess; ++v) + { + + int vector_idx + = (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess); + + LoadType const* memory_pointer = pointers_[v] + row_ptr_offset; + + frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx]; + } + } + } + } + } + } + + /// Loads a fragment + CUTLASS_DEVICE + void load(Fragment& frag) const + { + + load_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/epilogue_helpers.h b/csrc/cutlass_extensions/epilogue_helpers.h new file mode 100644 index 000000000000..54ba2465f76e --- /dev/null +++ b/csrc/cutlass_extensions/epilogue_helpers.h @@ -0,0 +1,139 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * @file epilogue_helpers.h + * + * This file includes types for the epilogues. The empty structs exist so we can signal to template + * code the type of epilogue we want to run, and let the underlying code specify the details such as + * element types, accumulator type and elements per vector access. 
+ * + */ + +#pragma once + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_silu.h" +#include "cutlass_extensions/epilogue/thread/fused_activations.h" + +namespace tensorrt_llm +{ +namespace cutlass_extensions +{ + +struct EpilogueOpBiasSilu +{ +}; + +struct EpilogueOpBiasReLU +{ +}; + +struct EpilogueOpBiasFtGelu +{ +}; + +struct EpilogueOpDefaultSilu +{ +}; + +struct EpilogueOpDefaultReLU +{ +}; + +struct EpilogueOpDefaultFtGelu +{ +}; + +struct EpilogueOpBias +{ +}; + +struct EpilogueOpDefault +{ +}; + +template +struct Epilogue +{ +}; + +constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationSilu; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationGeneric; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombination; +}; + +constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationSilu; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationGeneric; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombination; +}; + +} // namespace cutlass_extensions +} // namespace tensorrt_llm diff --git a/csrc/cutlass_extensions/gemm/device/gemm_universal_base_compat.h b/csrc/cutlass_extensions/gemm/device/gemm_universal_base_compat.h new file mode 100644 index 000000000000..2edd5a228b47 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/device/gemm_universal_base_compat.h @@ -0,0 +1,438 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ + +#pragma once + +// #include + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_universal.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" + +#include "cutlass/trace.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace device +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/* + This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088) + It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs + and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs. + + Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support + that feature at the moment. 
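+
+    As a usage sketch (the kernel, workspace and stream names below are placeholders), this wrapper is
+    driven the same way as cutlass::gemm::device::GemmUniversal:
+
+        using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<SomeGemmKernel>;
+        Gemm gemm_op;
+        Gemm::Arguments args = ...;                          // kernel-level Arguments
+        size_t workspace_bytes = Gemm::get_workspace_size(args);
+        // allocate workspace_bytes of device memory at workspace_ptr, then:
+        gemm_op(args, workspace_ptr, stream);                // initialize(...) followed by run(stream)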
+ */ + +template +class GemmUniversalBaseCompat +{ +public: + using GemmKernel = GemmKernel_; + using ThreadblockShape = typename GemmKernel::Mma::Shape; + + using ElementA = typename GemmKernel::ElementA; + using LayoutA = typename GemmKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = GemmKernel::kTransformA; + + using ElementB = typename GemmKernel::ElementB; + using LayoutB = typename GemmKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = GemmKernel::kTransformB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename GemmKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using Operator = typename GemmKernel::Operator; + + /// Argument structure + using Arguments = typename GemmKernel::Arguments; + +protected: + /// Kernel parameters object + typename GemmKernel::Params params_; + +protected: + /// Private helper to obtain the grid dimensions with fix-up for split-K + static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) + { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + gemm_k_size = args.problem_size.k(); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) + { + + int const kAlignK + = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) + { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + } + +public: + /// Constructs the GEMM. + GemmUniversalBaseCompat() {} + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const& args) + { + + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + ThreadblockSwizzle threadblock_swizzle; + dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); + + if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) + { + + return Status::kErrorInvalidProblem; + } + + return GemmKernel::can_implement(args); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) + { + + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()"); + + size_t workspace_bytes = 0; + + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) + { + + // Split-K parallel always requires a temporary workspace + workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k()); + } + else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) + { + + // Serial split-K only requires a temporary workspace if the number of partitions along the + // GEMM K dimension is greater than one. 
+ workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); + } + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape); + + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) + { + + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()"); + + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n" + << " result = {" << result << "}"); + + return result; + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) + { + + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()"); + + int max_active_blocks = -1; + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + if (smem_size <= (48 << 10)) + { + + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, smem_size); + + if (result == cudaSuccess) + { + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + } + else + { + + // Query assuming zero shared memory then compute occupancy limit based on SMEM + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, 0); + + if (result != cudaSuccess) + { + + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); + + return -1; + } + + if (smem_capacity < 0) + { + int device_idx = 0; + result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) + { + return -1; + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) + { + return -1; + } + + smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); + } + + int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); + + CUTLASS_TRACE_HOST(" occupancy: " << occupancy); + + return occupancy; + } + + CUTLASS_TRACE_HOST(" returning internal error"); + + return -1; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) + { + + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace " + << workspace << ", stream: " << (stream ? 
"non-null" : "null")); + + size_t workspace_bytes = get_workspace_size(args); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + if (workspace_bytes) + { + + if (!workspace) + { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + + return Status::kErrorWorkspaceNull; + } + + if (args.mode == GemmUniversalMode::kGemm) + { + CUTLASS_TRACE_HOST(" clearing device workspace"); + cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); + + if (result != cudaSuccess) + { + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + + return Status::kErrorInternal; + } + } + } + + // Get CUDA grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + // Initialize the Params structure + params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast(workspace)); + + // Specify shared memory capacity for kernel. + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) + { + cudaError_t result + = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) + { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) + { + + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) + { + return Status::kErrorWorkspaceNull; + } + + params_.update(args, workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) + { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()"); + + // + // Configure grid and block dimensions + // + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + // + // Launch kernel + // + + CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes"); + + // Launch + cutlass::Kernel<<>>(params_); + + // + // Query for errors + // + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) + { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) + { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) + { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) + { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/csrc/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h new file mode 100644 index 000000000000..1886a253fb00 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h @@ -0,0 +1,138 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/bfloat16.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +template +struct MixedGemmArchTraits +{ +}; + +template +struct MixedGemmArchTraits +{ + static constexpr int Stages = 2; + using OperatorClass = cutlass::arch::OpClassSimt; + using AccType = float; + using LayoutB = cutlass::layout::ColumnMajor; + + static constexpr int ElementsPerAccessA = 1; + static constexpr int ElementsPerAccessB = 1; + static constexpr int ElementsPerAccessC = 1; + static constexpr int ThreadblockK = 8; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// ========================= Volta Traits =========================== +// Volta will always dequantize after the global memory load. +// This will instantiate any HMMA tensorcore kernels for Volta. +// Note that volta does not have native bfloat support so weights and activations will be casted to fp16 +// and compute will happen in fp16 then will be converted for bf16 output. 
+template +struct MixedGemmArchTraits::value + || cutlass::platform::is_same::value>::type> +{ +private: + using LayoutDetails = LayoutDetailsB; + +public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Turing Traits ============================== +// Note that turing does not have native bfloat support so weights and activations will be casted to fp16 +// and compute will happen in fp16 then will be converted for bf16 output. +template +struct MixedGemmArchTraits::value + || cutlass::platform::is_same::value>::type> +{ +private: + using LayoutDetails = LayoutDetailsB; + +public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Ampere Traits ============================== +template +struct MixedGemmArchTraits::value + || cutlass::platform::is_same::value>::type> +{ +private: + using LayoutDetails = LayoutDetailsB; + +public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using Operator = typename LayoutDetails::Operator; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/csrc/cutlass_extensions/gemm/kernel/default_int8_traits.h b/csrc/cutlass_extensions/gemm/kernel/default_int8_traits.h new file mode 100644 index 000000000000..58b98a015368 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/kernel/default_int8_traits.h @@ -0,0 +1,57 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +template +struct Int8GemmArchTraits +{ + using OperatorClass = cutlass::arch::OpClassSimt; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; +}; + +// ======================= Turing Traits ============================== +template <> +struct Int8GemmArchTraits +{ + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; +}; + +// ======================= Ampere Traits ============================== +template <> +struct Int8GemmArchTraits +{ + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/csrc/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/csrc/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h new file mode 100644 index 000000000000..36ae924eebd2 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h @@ -0,0 +1,574 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail +{ +template +inline constexpr bool dependent_false_v = false; +} + +template +struct GemmFpAIntB +{ + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Element; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Mma::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformA; + + // Type definitions about the mainloop. + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + + /// Parameters structure + struct Arguments + { + GemmUniversalMode mode = GemmUniversalMode::kGemm; + + cutlass::gemm::GemmCoord problem_size; + int group_size; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + + // Control serial split-k + int batch_count; + + typename EpilogueOutputOp::Params output_op; + + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // Included so we can use Gemm Universal + int batch_stride_D = 0; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Arguments() {} + + CUTLASS_HOST_DEVICE + Arguments(cutlass::gemm::GemmCoord const& problem_size, const int group_size, + typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor, + typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(), + int const* gather_A_indices = nullptr, int const* 
gather_B_indices = nullptr, + int const* scatter_D_indices = nullptr) + : problem_size(problem_size) + , group_size(group_size) + , ref_A(ref_A) + , ref_B(ref_B) + , ref_scale(ref_scale) + , ref_zero(ref_zero) + , ref_C(ref_C) + , ref_D(ref_D) + , batch_count(serial_split_k_factor) + , output_op(output_op) + , gather_A_indices(gather_A_indices) + , gather_B_indices(gather_B_indices) + , scatter_D_indices(scatter_D_indices) + { + } + }; + + /// Parameters structure + struct Params + { + cutlass::gemm::GemmCoord problem_size; + int group_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::Params params_scale; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename EpilogueOutputOp::Params output_op; + int* semaphore; + int gemm_k_size; + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0) + , semaphore(0) + , gemm_k_size(0) + { + } + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, const int gemm_k_size, + void* workspace = nullptr) + : problem_size(args.problem_size) + , group_size(args.group_size) + , grid_tiled_shape(grid_tiled_shape) + , swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)) + , params_A(args.ref_A.layout()) + , ref_A(args.ref_A) + , params_B(args.ref_B.layout()) + , ref_B(args.ref_B) + , params_scale(args.ref_scale.layout()) + , ref_scale(args.ref_scale) + , ref_zero(args.ref_zero) + , params_C(args.ref_C.layout()) + , ref_C(args.ref_C) + , params_D(args.ref_D.layout()) + , ref_D(args.ref_D) + , output_op(args.output_op) + , semaphore(static_cast(workspace)) + , gemm_k_size(gemm_k_size) + , gather_A_indices(args.gather_A_indices) + , gather_B_indices(args.gather_B_indices) + , scatter_D_indices(args.scatter_D_indices) + { + } + }; + + /// Shared memory storage structure + union SharedStorage + { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + GemmFpAIntB() {} + + /// Determines whether kernel satisfies alignment + CUTLASS_HOST_DEVICE + static Status can_implement(Arguments const& args) + { + + static int const kAlignmentA + = (platform::is_same>::value) ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB + = (platform::is_same>::value) ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + + static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements; + + static int const kAlignmentC = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 
64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(args.ref_A, kAlignmentA)) + { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_B, kAlignmentB)) + { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) + { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_zero, kAlignmentScale)) + { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_C, kAlignmentC)) + { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_D, kAlignmentC)) + { + return Status::kErrorMisalignedOperand; + } + + if (!args.ref_scale.good()) + { + return Status::kErrorNotSupported; + } + + if constexpr (hasZero(Mma::QuantOp)) + { + if (!args.ref_zero.good()) + { + return Status::kErrorNotSupported; + } + } + else + { + if (args.ref_zero.good()) + { + return Status::kErrorNotSupported; + } + } + + if constexpr (isFinegrained(Mma::QuantOp)) + { + if (args.group_size != 64 && args.group_size != 128) + { + return Status::kErrorNotSupported; + } + } + + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) + { + + return 0; + } + + // The dummy template parameter is not used and exists so that we can compile this code using + // a standard earlier than C++17. Prior to C++17, fully specialized templates HAD to exists in + // a namespace + template + struct KernelRunner + { + CUTLASS_DEVICE + static void run_kernel(Params const& params, SharedStorage& shared_storage) + { + CUTLASS_NOT_IMPLEMENTED(); + } + }; + + // Initializes the fine grained scale+bias iterator. Needed since the fine grained iterator + // has a different constructor signature than a regular cutlass iterator + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) + { + + return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size); + } + + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) + { + + return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset); + } + + template + struct KernelRunner + { + CUTLASS_DEVICE + static void run_kernel(Params const& params, SharedStorage& shared_storage) + { + using LayoutB = typename Mma::IteratorB::Layout; + static_assert(platform::is_same::value && kInterleave == 1 + || platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset + = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() + || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) + { + + return; + } + + // 
Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, + threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; + + typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64; + typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0; + cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices); + + typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(), + {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B, + params.gather_B_indices); + + typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1; + typename Mma::IteratorScale iterator_scale = initialize_scale( + params.params_scale, params.ref_scale.data(), params.ref_zero.data(), + {scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) + { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); + } + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) + { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. 
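+ // For serial split-K, the "source" read back here may end up being D rather than C.
+ // Rough walk-through (illustrative, with grid_tiled_shape.k() == 2):
+ //   * partition k == 0 accumulates its slice of K, writes its partial result to D,
+ //     then releases the semaphore with lock = 1;
+ //   * partition k == 1 waits for lock == 1, re-points iterator_C at iterator_D so the
+ //     previous partial sum becomes the epilogue source, applies the output op, and,
+ //     being the last partition, resets the semaphore to 0 for the next grid.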
+ typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), + params.problem_size.mn(), thread_idx, threadblock_offset, params.scatter_D_indices); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), + params.problem_size.mn(), thread_idx, threadblock_offset, params.scatter_D_indices); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) + { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) + { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) + { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) + { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else + { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } + }; + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) + { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ <= 900) + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#else + static_assert( + false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels."); +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/csrc/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h b/csrc/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h new file mode 100644 index 000000000000..80a4d8560859 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h @@ -0,0 +1,73 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Scheduler for grouped GEMM +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "cutlass/matrix_coord.h" + +#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" +#include "cutlass_extensions/gemm/kernel/moe_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct GemmMoeProblemVisitor + : public MoeProblemVisitor, ThreadblockShape, + GroupScheduleMode_, PrefetchTileCount, ThreadCount> +{ + + static bool const kTransposed = Transposed; + + using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; + using Base + = MoeProblemVisitor; + using Params = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + + // + // Methods + // + CUTLASS_DEVICE + GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) + : Base(params_, shared_storage_, block_idx) + { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h b/csrc/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h new file mode 100644 index 000000000000..54602754279f --- /dev/null +++ b/csrc/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h @@ -0,0 +1,545 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GEMM kernel to support the epilogue visitor model + for customized softmax partial reduction epilogue fusion. + + This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once + its usage has been stabilized. For now, it is included in this example to demonstrate + some basic output fusion options. + + original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" +#include "cutlass/trace.h" + +#include "cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h" + +namespace tk = tensorrt_llm::common; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmWithEpilogueVisitor +{ +public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueVisitor = typename Epilogue::Visitor; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using TensorRefA = TensorRef; + + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using TensorRefB = TensorRef; + + using ElementCompute = typename EpilogueVisitor::ElementCompute; + using LayoutAlphaCol = cutlass::layout::RowMajor; + using LayoutAlphaRow = cutlass::layout::ColumnMajor; + using TensorRefAlphaCol = TensorRef; + using TensorRefAlphaRow = TensorRef; + + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename Epilogue::Layout; + using TensorRefC = TensorRef; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + using EpilogueOutputOp = + typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const 
kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment + = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + + // + // Structures + // + + /// Argument structure + struct Arguments + { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + TensorRefA ref_A; + TensorRefB ref_B; + tk::QuantMode quant_option; + TensorRefAlphaCol ref_alpha_col; + TensorRefAlphaRow ref_alpha_row; + TensorRefC ref_C; + TensorRefC ref_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_D; + + typename EpilogueVisitor::Arguments epilogue_visitor; + + // + // Methods + // + + Arguments() + : mode(GemmUniversalMode::kGemm) + , batch_count(1) + { + } + + /// constructs an arguments structure + Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_, + TensorRefB ref_B_, tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_, + TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, int64_t batch_stride_A_, + int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_) + : mode(mode_) + , problem_size(problem_size_) + , batch_count(batch_count_) + , ref_A(ref_A_) + , ref_B(ref_B_) + , quant_option(quant_option_) + , ref_alpha_col(ref_alpha_col_) + , ref_alpha_row(ref_alpha_row_) + , ref_C(ref_C_) + , ref_D(ref_D_) + , batch_stride_A(batch_stride_A_) + , batch_stride_B(batch_stride_B_) + , batch_stride_D(0) + , epilogue_visitor(epilogue_visitor_) + { + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params + { + + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row; + typename EpilogueVisitor::OutputTileIterator::Params params_C; + typename EpilogueVisitor::OutputTileIterator::Params params_D; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void* ptr_A; + void* ptr_B; + tk::QuantMode quant_option; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row; + ElementC* ptr_C; + ElementC* ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + + typename EpilogueVisitor::Params epilogue_visitor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0) + , params_A(0) + , params_B(0) + , params_alpha_col(0) + , params_C(0) + , params_D(0) + , batch_count(0) + , gemm_k_size(0) + , mode(cutlass::gemm::GemmUniversalMode::kGemm) + , ptr_A(nullptr) + , ptr_B(nullptr) + , ptr_alpha_col(nullptr) + , ptr_alpha_row(nullptr) + , ptr_C(nullptr) + , ptr_D(nullptr) + , batch_stride_A(0) + , batch_stride_B(0) + { + } + + Params( + Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_) + : problem_size(args.problem_size) + , swizzle_log_tile(0) + , params_A(args.ref_A.layout()) + , params_B(args.ref_B.layout()) + , params_alpha_col(args.ref_alpha_col.layout()) + , params_alpha_row(args.ref_alpha_col.layout()) + , params_C(args.ref_C.layout()) + , params_D(args.ref_D.layout()) + , mode(args.mode) + , batch_count(args.batch_count) + , 
gemm_k_size(args.problem_size.k()) + , ptr_A(args.ref_A.data()) + , ptr_B(args.ref_B.data()) + , quant_option(args.quant_option) + , ptr_alpha_col(args.ref_alpha_col.data()) + , ptr_alpha_row(args.ref_alpha_row.data()) + , ptr_C(args.ref_C.data()) + , ptr_D(args.ref_D.data()) + , batch_stride_A(args.batch_stride_A) + , batch_stride_B(args.batch_stride_B) + , epilogue_visitor(args.epilogue_visitor) + { + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) + { + + int const kAlignK + = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) + { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage + { + + typename Mma::SharedStorage main_loop; + + struct + { + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + } epilogue; + }; + +public: + // + // Methods + // + + CUTLASS_DEVICE + GemmWithEpilogueVisitor() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) + { + + CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) + { + isAMisaligned = problem_size.k() % kAlignmentA; + } + else if (platform::is_same::value) + { + isAMisaligned = problem_size.m() % kAlignmentA; + } + else if (platform::is_same>::value + || platform::is_same>::value) + { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) + { + isBMisaligned = problem_size.n() % kAlignmentB; + } + else if (platform::is_same::value) + { + isBMisaligned = problem_size.k() % kAlignmentB; + } + else if (platform::is_same>::value + || platform::is_same>::value) + { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) + { + isCMisaligned = problem_size.n() % kAlignmentC; + } + else if (platform::is_same::value) + { + isCMisaligned = problem_size.m() % kAlignmentC; + } + else if (platform::is_same>::value + || platform::is_same>::value) + { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) + { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) + { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) + { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) + { + return can_implement(args.problem_size); + } + + static size_t 
get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) + { + + return 0; + } + +#define SPLIT_K_ENABLED 1 + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) + { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() + || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) + { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA* ptr_A = static_cast(params.ptr_A); + ElementB* ptr_B = static_cast(params.ptr_B); + +#if SPLIT_K_ENABLED + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) + { + + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) + { + + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == GemmUniversalMode::kBatched) + { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } + else if (params.mode == GemmUniversalMode::kArray) + { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } +#endif + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
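+ // Every lane of a warp computes the same value for threadIdx.x / 32, but the compiler
+ // cannot prove that, so reading lane 0's copy via __shfl_sync makes the uniformity
+ // explicit. Spelled out (equivalent to the line below):
+ //   int warp_idx_per_lane = threadIdx.x / 32;                              // same in every lane
+ //   int warp_idx_uniform  = __shfl_sync(0xffffffff, warp_idx_per_lane, 0); // provably uniform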
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // + // Construct the epilogue visitor + // + + EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor, + params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C, + params.params_D, params.quant_option, params.ptr_alpha_row, params.ptr_alpha_col, params.ptr_C, + params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m()); + + if (params.mode == GemmUniversalMode::kGemm) + { + // Indicate which position in a serial reduction the output operator is currently updating + epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) + { + epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); + } + + // Construct the epilogue + Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(epilogue_visitor, accumulators); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/csrc/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h new file mode 100644 index 000000000000..b8176eb52167 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -0,0 +1,114 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is + quantized. 
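+
+ Worked example (numbers follow from the LayoutDetailsB specializations below): a
+ 128-byte cache line holds 128 eight-bit or 256 four-bit weight elements, so with
+ ThreadblockK = 64 the interleaved layouts use ColumnsInterleaved = 128 / 64 = 2 for
+ uint8_t and 256 / 64 = 4 for uint4b_t.
+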
The preprocessing code reads this template to know how to organize the quantized weight matrices + to be consumed by CUTLASS. + + Note that for int4, ThreadBlockK MUST be 64. + + */ + +#pragma once + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/platform/platform.h" + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/tile_interleaved_layout.h" + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +template +struct LayoutDetailsB +{ +}; + +// Volta specialiations. Volta will dequantize before STS, so we need a different operator +template +struct LayoutDetailsB +{ + static constexpr int ThreadblockK = 64; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 8; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. +// TODO - Switch this to column major for weights since gemms should be more performant. +template +struct LayoutDetailsB= 75>::type> +{ + static constexpr int ThreadblockK = 64; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB= 75>::type> +{ + static constexpr int ThreadblockK = 64; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA, +// which signals that we want to dequantize after loading from smem. +template +struct LayoutDetailsB= 75>::type> +{ + static constexpr int ThreadblockK = 64; + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + +public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +template +struct LayoutDetailsB= 75>::type> +{ + static constexpr int ThreadblockK = 64; + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + +public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/csrc/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h b/csrc/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h new file mode 100644 index 000000000000..4c5c8cc64f43 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h @@ -0,0 +1,526 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" + +#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" +#include "cutlass_extensions/tile_interleaved_layout.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +// This section exists to that we can use the same kernel code for regular gemm and dequantizing gemms. +// It will dispatch to the dequantizing gemm if the Mma type has an Iterator for scales in global. +template +using void_t = void; + +template +struct use_dq_gemm : platform::false_type +{ +}; + +template +struct use_dq_gemm> : platform::true_type +{ +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MoeFCGemm +{ +public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = false; + + // Optional transpose + using MapArguments = kernel::detail::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. + static_assert(!kTransposed, "Transpose problem not supported"); + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. 
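+ // For instance (illustrative numbers only, not a requirement of this kernel): a
+ // 128x128x64 ThreadblockShape tiled by a 64x64x64 WarpShape gives WarpCount
+ // <2, 2, 1>, i.e. 4 warps and kThreadCount = 32 * 4 = 128 threads per threadblock.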
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor + = GemmMoeProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments + { + + // + // Data members + // + + int problem_count; + int threadblock_count; + int group_size; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; + + int64_t* total_rows_before_expert; + int64_t gemm_n; + int64_t gemm_k; + + // Only used by device-level operator + GemmCoord* host_problem_sizes; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0) + , threadblock_count(0) + , ptr_A(nullptr) + , ptr_B(nullptr) + , weight_scales(nullptr) + , ptr_C(nullptr) + , ptr_D(nullptr) + , total_rows_before_expert(nullptr) + , gemm_n(0) + , gemm_k(0) + , host_problem_sizes(nullptr) + { + } + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op, + const ElementA* ptr_A, const ElementB* ptr_B, const ElementScale* weight_scales, const ElementC* ptr_C, + ElementC* ptr_D, int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, + GemmCoord* host_problem_sizes = nullptr) + : problem_count(problem_count) + , threadblock_count(threadblock_count) + , group_size(group_size) + , output_op(output_op) + , ptr_A(const_cast(ptr_A)) + , ptr_B(const_cast(ptr_B)) + , weight_scales(const_cast(weight_scales)) + , ptr_C(const_cast(ptr_C)) + , ptr_D(ptr_D) + , total_rows_before_expert(total_rows_before_expert) + , gemm_n(gemm_n) + , gemm_k(gemm_k) + , host_problem_sizes(nullptr) + { + if (platform::is_same::value || platform::is_same::value) + { + assert(weight_scales); + } + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params + { + + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + int group_size; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : ptr_A(nullptr) + , ptr_B(nullptr) + , weight_scales(nullptr) + , ptr_C(nullptr) + , ptr_D(nullptr) + { + } + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : problem_visitor( + args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count) + , threadblock_count(args.threadblock_count) + , group_size(args.group_size) + , output_op(args.output_op) + , ptr_A(args.ptr_A) + , ptr_B(args.ptr_B) + , weight_scales(args.weight_scales) + , ptr_C(args.ptr_C) + , ptr_D(args.ptr_D) + { + } + + CUTLASS_HOST_DEVICE + void update(Arguments 
const& args, void* workspace = nullptr, int tile_count = 0) + { + + problem_visitor = typename ProblemVisitor::Params( + args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + weight_scales = args.weight_scales; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + } + }; + + /// Shared memory storage structure + union SharedStorage + { + typename ProblemVisitor::SharedStorage problem_visitor; + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + +public: + // + // Methods + // + + CUTLASS_DEVICE + MoeFCGemm() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) + { + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) + { + if (platform::is_same::value || platform::is_same::value) + { + if (args.weight_scales == nullptr) + { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - weight scales are required for uint8_t and uint4b_t"); + return Status::kInvalid; + } + } + else if (args.weight_scales != nullptr) + { + CUTLASS_TRACE_HOST( + "MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t"); + return Status::kInvalid; + } + else if (args.group_size != args.gemm_k) + { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - scale shape should be (1, gemm_n)"); + return Status::kInvalid; + } + // Handle the case the input is too short + else if (args.gemm_n < Mma::IteratorB::AccessType::kElements) + { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - gemm_n is smaller than the input alignment"); + return Status::kInvalid; + } + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) + { + + return 0; + } + + // The dummy template parameter is not used and exists so that we can compile this code using + // a standard earlier than C++17. Prior to C++17, fully specialized templates HAD to exists in + // a namespace + template + struct KernelRunner + { + CUTLASS_DEVICE + static void run_kernel(Params const& params, SharedStorage& shared_storage) + { + CUTLASS_NOT_IMPLEMENTED(); + } + }; + + template + struct KernelRunner + { + CUTLASS_DEVICE + static void run_kernel(Params const& params, SharedStorage& shared_storage) + { + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + static_assert(platform::is_same::value && kInterleave == 1 + || platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // + // Problem visitor. 
+ // + ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + const int64_t gemm_k = params.problem_visitor.gemm_k; + const int64_t gemm_n = params.problem_visitor.gemm_n; + int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; + + // Outer 'persistent' loop to iterate over tiles + int loop = 0; + while (problem_visitor.next_tile()) + { + loop++; + + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + cutlass::gemm::GemmCoord threadblock_offset( + int(cta_idx / grid_shape.n()) * Mma::Shape::kM, int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0); + + // Load element pointers. Exchange pointers and strides if working on the transpose + const int64_t rows_to_jump + = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1]; + ElementA* ptr_A = reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; + typename LayoutA::LongIndex ldm_A = gemm_k; + + char* byte_ptr_B = ((char*) params.ptr_B) + problem_idx * bytes_per_expert_matrix; + ElementB* ptr_B = reinterpret_cast(byte_ptr_B); + typename LayoutB::LongIndex ldm_B + = platform::is_same::value ? gemm_n : gemm_k * kInterleave; + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + 0, + }; + + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; + + cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A); + + typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B, + {problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, tb_offset_B); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + auto CreateMMA = [&]() + { + if constexpr (use_dq_gemm::value) + return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); + else + return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + }; + Mma mma = CreateMMA(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the previous tile. 
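+ // SharedStorage is a union of the problem-visitor, mainloop and epilogue storage, so
+ // this barrier keeps the next tile's mainloop from overwriting shared memory that the
+ // previous tile's epilogue (or the problem visitor) may still be reading.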
+ __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n(); + + if constexpr (use_dq_gemm::value) + { + const MatrixCoord scale_extent = {1, problem_size.n()}; + typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()), + weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale); + + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); + } + else + { + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + } + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + ElementC* ptr_C = reinterpret_cast(params.ptr_C) + problem_idx * gemm_n; + ElementC* ptr_D = reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; + + LayoutC layout_C(0); + LayoutC layout_D(gemm_n); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset.mn()); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn()); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // Next tile + problem_visitor.advance(gridDim.x); + } + } + }; + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) + { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + // TODO Update the arch to Sm90 once CUTLASS hopper specialisations are available + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/kernel/moe_problem_visitor.h b/csrc/cutlass_extensions/gemm/kernel/moe_problem_visitor.h new file mode 100644 index 000000000000..cd9270d1414d --- /dev/null +++ b/csrc/cutlass_extensions/gemm/kernel/moe_problem_visitor.h @@ -0,0 +1,344 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Base scheduler for grouped problems, using MoE +*/ + +#pragma once + +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct BaseMoeProblemVisitor +{ + using ThreadblockShape = ThreadblockShape_; + + struct ProblemInfo + { + static int32_t const kNoPrefetchEntry = -1; + int32_t problem_idx; + int32_t problem_start; + + CUTLASS_DEVICE + ProblemInfo() + : problem_idx(kNoPrefetchEntry) + , problem_start(kNoPrefetchEntry) + { + } + + CUTLASS_DEVICE + ProblemInfo(int32_t problem_idx_, int32_t problem_start_) + : problem_idx(problem_idx_) + , problem_start(problem_start_) + { + } + }; + + struct Params + { + int64_t const* last_row_for_problem; + int64_t gemm_n; + int64_t gemm_k; + int32_t problem_count; + void const* workspace; + int32_t tile_count; + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + Params() + : last_row_for_problem(nullptr) + , gemm_n(0) + , gemm_k(0) + , problem_count(0) + , workspace(nullptr) + , tile_count(0) + { + } + + /// Ctor + CUTLASS_HOST_DEVICE + Params(int64_t const* last_row_for_problem, int64_t gemm_n, int64_t gemm_k, int32_t problem_count, + void const* workspace = nullptr, int32_t tile_count = 0) + : last_row_for_problem(last_row_for_problem) + , gemm_n(gemm_n) + , gemm_k(gemm_k) + , problem_count(problem_count) + , workspace(workspace) + , tile_count(tile_count) + { + } + }; + + Params const& params; + int32_t tile_idx; + int32_t problem_tile_start; + int32_t problem_idx; + + // + // Methods + // + CUTLASS_DEVICE + BaseMoeProblemVisitor(Params const& params_, int32_t block_idx) + : params(params_) + , tile_idx(block_idx) + , problem_tile_start(0) + , problem_idx(0) + { + } + + /// Get the grid shape + CUTLASS_HOST_DEVICE + static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) + { + + return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), + ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1); + } + + /// Gets the global tile index + CUTLASS_HOST_DEVICE + int32_t tile_index() const + { + return tile_idx; + } + + /// Gets the index of the problem + CUTLASS_HOST_DEVICE + int32_t problem_index() const + { + return problem_idx; + } + + CUTLASS_HOST_DEVICE + int32_t threadblock_idx() const + { + return tile_idx - problem_tile_start; + } + + CUTLASS_DEVICE + void advance(int32_t grid_size) + { + tile_idx += grid_size; + } + + CUTLASS_HOST_DEVICE + static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) + { + ProblemSizeHelper::possibly_transpose_problem(problem); + } + + 
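+ // `last_row_for_problem[i]` holds the running (inclusive) total of rows assigned to
+ // problems 0..i, so problem i owns rows [last_row_for_problem[i-1], last_row_for_problem[i]).
+ // Illustrative example: last_row_for_problem = {128, 128, 384} describes three problems
+ // whose gemm_m is 128, 0 and 256 respectively, matching the subtraction in problem_size() below.
+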
/// Returns the problem size for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size() const + { + return problem_size(problem_idx); + } + + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size(int idx) const + { + const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1]; + const int64_t current_problem_row = params.last_row_for_problem[idx]; + const int64_t gemm_m = current_problem_row - prev_problem_row; + GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k)); + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + CUTLASS_HOST_DEVICE + static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) + { + return ProblemSizeHelper::tile_count(grid); + } + + static int32_t group_tile_count(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count) + { + int32_t total_tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) + { + auto problem = host_problem_sizes_ptr[i]; + possibly_transpose_problem(problem); + auto grid = grid_shape(problem); + total_tiles += tile_count(grid); + } + + return total_tiles; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MoeProblemVisitor; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// ProblemVisitor that performs all scheduling on device +// +template +struct MoeProblemVisitor : public BaseMoeProblemVisitor +{ + using Base = BaseMoeProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + static bool const kRequiresPrecomputation = false; + static int const kThreadsPerWarp = 32; + + struct SharedStorage + { + }; + + // Final tile of the problem loaded by this thread. Each thread will hold + // a separate value. + int32_t problem_ending_tile; + + SharedStorage& shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) + : Base(params_, block_idx) + , problem_ending_tile(0) + , shared_storage(shared_storage_) + { + this->problem_idx = -1 * kThreadsPerWarp; + this->problem_tile_start = 0; + } + + CUTLASS_DEVICE + bool next_tile() + { + // Check whether the tile to compute is within the range of the current problem. + int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp); + if (this->tile_idx < problem_tile_end) + { + return true; + } + + // Check whether the tile to compute is within the current group of problems fetched by the warp. + // The last tile for this group is the final tile of the problem held by the final thread in the warp. + int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); + + // Keep the starting problem for this group in `problem_idx`. This is done to reduce + // register pressure. The starting problem for this group is simply the first problem + // in the group most recently fetched by the warp. + int32_t& group_problem_start = this->problem_idx; + group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp; + + // Keep the starting tile for this group in `problem_tile_start`. This is done to reduce + // register pressure. 
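+ //
+ // In the loop below, each lane of the warp takes one problem from the group, computes that
+ // problem's tile count, and a warp-wide inclusive prefix sum turns the per-lane counts into
+ // per-problem ending tile indices. Illustrative example: if the group's problems need
+ // 3, 1, 4 and 2 tiles, lanes 0..3 end up holding 3, 4, 8 and 10; a tile_idx 5 tiles past the
+ // group start then maps to problem 2, because exactly two lanes hold an ending tile <= 5
+ // (see the __ballot_sync/__popc computation further below).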
+ int32_t& group_tile_start = this->problem_tile_start; + + // Each thread in the warp processes a separate problem to advance until + // reaching a problem whose starting tile is less less than tile_idx. + while (group_tile_end <= this->tile_idx) + { + group_problem_start += kThreadsPerWarp; + if (group_problem_start > this->params.problem_count) + { + return false; + } + + // Since `group_tile_start` is a reference to `this->problem_tile_start`, this + // also sets `this->problem_tile_start`. The fact that `this->problem_tile_start` + // is also set here is used later in `next_tile`. + group_tile_start = group_tile_end; + + int lane_idx = threadIdx.x % kThreadsPerWarp; + int32_t lane_problem = group_problem_start + lane_idx; + + // Compute the number of tiles in the problem assigned to each thread. + problem_ending_tile = 0; + if (lane_problem < this->params.problem_count) + { + cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + problem_ending_tile = this->tile_count(grid); + } + + // Compute a warp-wide inclusive prefix sum to compute the ending tile index of + // each thread's problem. + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kThreadsPerWarp; i <<= 1) + { + int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i); + if (lane_idx >= i) + { + problem_ending_tile += val; + } + } + + // The total tile count for this group is now in the final position of the prefix sum + int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); + + problem_ending_tile += group_tile_start; + group_tile_end += tiles_in_group; + } + + // The next problem to process is the first one that does not have ending tile position + // that is greater than or equal to tile index. + int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx)); + + this->problem_idx = group_problem_start + problem_idx_in_group; + + // The starting tile for this problem is the ending tile of the previous problem. In cases + // where `problem_idx_in_group` is the first problem in the group, we do not need to reset + // `problem_tile_start`, because it is set to the previous group's ending tile in the while + // loop above. + if (problem_idx_in_group > 0) + { + this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1); + } + + return true; + } + + static size_t get_workspace_size( + const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count) + { + return 0; + } + + static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count, + int32_t block_count, void* host_workspace_ptr) + { + } +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/csrc/cutlass_extensions/gemm/threadblock/default_dq_mma.h b/csrc/cutlass_extensions/gemm/threadblock/default_dq_mma.h new file mode 100644 index 000000000000..a10ed85a8b0b --- /dev/null +++ b/csrc/cutlass_extensions/gemm/threadblock/default_dq_mma.h @@ -0,0 +1,125 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ +//////////////////////////////////////////////////////////////////////////////// + +// We need to distinguish here, since we want volta support. It is too much effort +// to write shared memory iterators that are probably needed for volta to function +// properly. As a result, we allow converters both after the LDG (for volta) and after +// the LDS for Turing+. +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Warp level Mma + typename MmaOperator, + /// Math operation perform by warp level operator + typename MathOperator> +struct SetConverters +{ +}; + +// Dequantize after LDG, so set transforms accordingly +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters +{ + using TransformAfterLDG + = FastInterleavedAndBiasedNumericArrayConverter; + + using TransformAfterLDS = NumericArrayConverter; +}; + +// Dequantize after LDS, so set transforms accordingly + +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters +{ + using TransformAfterLDG = NumericArrayConverter; + + using TransformAfterLDS + = FastInterleavedAndBiasedNumericArrayConverter; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale_, + /// Layout for the scale operand + typename LayoutScale_, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// + typename Enable = void> +struct DqMma; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git 
a/csrc/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/csrc/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h new file mode 100644 index 000000000000..bd4c16ee0194 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h @@ -0,0 +1,297 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass_extensions/arch/mma.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage.h" +#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "cutlass_extensions/tile_interleaved_layout.h" + +#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" +#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultScaleIterators; + +// Fine grained iterators +template +struct DefaultScaleIterators> +{ + using IteratorScale + = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, + Layout, 0, Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +// Per column iterators +template +struct DefaultScaleIterators> +{ + // ThreadMap for scale iterator + static_assert((MmaShape::kN % Alignment) == 0, ""); + +private: + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, + MmaShape::kN / Alignment, Alignment>; + +public: + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, + Element, Layout, 0, IteratorScaleThreadMap, Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for elementA + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// 
Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// + typename Operator_, + /// + SharedMemoryClearOption SharedMemoryClear> +struct DqMma= 80 && !layout::IsColumnMajorTileInterleave::value)>::type> +{ + + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, + AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementB, LayoutB, 0, ThreadMapB, + AccessTypeB>; + + using ScaleIterators = DefaultScaleIterators; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; +}; + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// + typename Operator_, + /// + SharedMemoryClearOption SharedMemoryClear> +struct DqMma= 80 && layout::IsColumnMajorTileInterleave::value)>::type> +{ + + 
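+ // This specialization handles a B operand stored in a column-major tile-interleaved layout.
+ // The weights are assumed to have been pre-shuffled offline so that the columns a threadblock
+ // consumes together sit contiguously in global memory; the B global-memory iterator defined
+ // below is re-shaped to match that interleaving, while the math itself is unchanged.
+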
static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, + AccessTypeA>; + +private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape + = MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, OriginalThreadMap::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + +public: + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator; + + using ScaleIterators = DefaultScaleIterators; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/csrc/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h b/csrc/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h new file mode 100644 index 000000000000..f94e1950e589 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h @@ -0,0 +1,249 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass_extensions/arch/mma.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h" +#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "cutlass_extensions/tile_interleaved_layout.h" + +#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator_> +struct DqMma::value)>::type> +{ + + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + + static constexpr bool DqAfterLDG = platform::is_same::value; + static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; + using MmaCoreElementA = typename platform::conditional::type; + using MmaCoreElementB = typename platform::conditional::type; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementB, LayoutB, 0, + typename MmaCore::IteratorThreadMapB, kAlignmentB>; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap + = transform::PitchLinearStripminedThreadMap, 
+ MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>; + + // Define iterators over tiles from the scale operand + using IteratorScale + = cutlass::transform::threadblock::PredicatedTileIterator, + ElementScale, LayoutScale, 0, IteratorScaleThreadMap, kAlignmentScale>; + + using SmemScaleType = typename platform::conditional::type; + using SmemIteratorScale + = cutlass::transform::threadblock::PredicatedTileIterator, + SmemScaleType, LayoutScale, 0, IteratorScaleThreadMap, kAlignmentScale>; + + using Converters = SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; +}; + +// Specialization to handle column major interleave B +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator_> +struct DqMma::value)>::type> +{ + + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + + static constexpr bool DqAfterLDG = platform::is_same::value; + static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; + using MmaCoreElementA = typename platform::conditional::type; + using MmaCoreElementB = typename platform::conditional::type; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA>; + +private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape + = MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, OriginalThreadMap::kThreads, + layout::PitchLinearShape, + 
MmaCore::kAccessSizeInBits / sizeof_bits::value>; + +public: + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap + = transform::PitchLinearStripminedThreadMap, + MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>; + + // Define iterators over tiles from the scale operand + using IteratorScale + = cutlass::transform::threadblock::PredicatedTileIterator, + ElementScale, LayoutScale, 0, IteratorScaleThreadMap, kAlignmentScale>; + + using SmemScaleType = typename platform::conditional::type; + using SmemIteratorScale + = cutlass::transform::threadblock::PredicatedTileIterator, + SmemScaleType, LayoutScale, 0, IteratorScaleThreadMap, kAlignmentScale>; + + using Converters = SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/csrc/cutlass_extensions/gemm/threadblock/default_mma.h b/csrc/cutlass_extensions/gemm/threadblock/default_mma.h new file mode 100644 index 000000000000..8f5cb8a71b9c --- /dev/null +++ b/csrc/cutlass_extensions/gemm/threadblock/default_mma.h @@ -0,0 +1,290 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" +#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h" + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating 
architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. 
Helps avoid reg spills on +// large tile when not enough shared mem is present to do 3+ stage +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma +{ + + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, AccessTypeA, + GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, AccessTypeB, + GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/csrc/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/csrc/cutlass_extensions/gemm/threadblock/default_mma_bf16.h new file mode 100644 index 000000000000..0a952900cd79 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -0,0 +1,353 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma +{ + +private: + // Conversions only needed pre-ampere. This will trigger mma pipeline, so we convert before STS. + static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; + using MmaElementA = typename platform::conditional::type; + using MmaElementB = typename platform::conditional::type; + +public: + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, bfloat16_t, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, bfloat16_t, LayoutB, 0, + typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined; +}; + +// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. 
Helps avoid reg spills on +// large tile when not enough shared mem is present to do 3+ stage +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma +{ + + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, + AccessTypeA, GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, + AccessTypeB, GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & 
int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory 
clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/csrc/cutlass_extensions/gemm/threadblock/dq_mma_base.h b/csrc/cutlass_extensions/gemm/threadblock/dq_mma_base.h new file mode 100644 index 000000000000..dff66be7593f --- /dev/null +++ b/csrc/cutlass_extensions/gemm/threadblock/dq_mma_base.h @@ -0,0 +1,257 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
+*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass_extensions/weight_only_quant_op.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// +// SFINAE trick so I can keep the same loop code for Volta and dispatch to the +// correct warp level mma. On volta, all data is stored to shared memory as FP16. +template +CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, + typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, typename WarpMma::FragmentC const& C, + const int warp_tileB_k_offset) +{ + warp_mma(D, A, B, C); +} + +template +CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, + typename WarpMma::TransformedFragmentA const& A, typename WarpMma::TransformedFragmentB const& B, + typename WarpMma::FragmentC const& C, const int warp_tileB_k_offset) +{ + warp_mma(D, A, B, C, warp_tileB_k_offset); +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// The type of the scales + typename ElementScale_, + /// Number of stages, + int Stages, + /// The dequantizing op to be performed. + WeightOnlyQuantOp DequantOp, + /// Used for partial specialization, + typename Enable = bool> +class DqMmaBase +{ +public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + ///< Type of the scale to be loaded + using ElementScale = ElementScale_; + + static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, ""); + + // Finegrained scales get streamed in via cp.async + static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1; + // We always have scales. + static constexpr int ScaleElementsPerStage = Shape::kN; + // We sometimes have a bias + static constexpr int BiasElementsPerStage = hasZero(DequantOp) ? Shape::kN : 0; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM operations + static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + static constexpr int kNumKIterationsPerWarpBLoad + = Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; + + static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); + static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad; + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage + { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA + = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB + = MatrixShape; + + /// Shape of the shared memory buffer for the scales for the B matrix. + using ShapeScale = MatrixShape; + /// Shape of the shared memory buffer for the biases of the B matrix. + using ShapeZero = MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_scale; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_zero; + + public: + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() + { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() + { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() + { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() + { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + +protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx) + , warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) + { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h 
b/csrc/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h new file mode 100644 index 000000000000..3c4036dd8cc5 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h @@ -0,0 +1,110 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
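+/// Note: only a forward declaration is given here; the actual mainloops live in the two
+/// partial specializations included at the bottom of this header, selected on the
+/// WeightOnlyQuantOp template parameter (fine-grained group scales vs. per-column scales).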
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Used for partial specialization + typename Enable = void> +class DqMmaMultistage; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h" +#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h" diff --git a/csrc/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h b/csrc/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h new file mode 100644 index 000000000000..76564f14ba2d --- /dev/null +++ b/csrc/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h @@ -0,0 +1,691 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +class DqMmaMultistage> + : public DqMmaBase +{ +public: + ///< Base class + using Base = DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using 
LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, ""); + static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); + + /// Internal structure exposed for introspection. + struct Detail + { + + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA + = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB + = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + +private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave + = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + +private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory + SmemIteratorScale smem_iterator_scale_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + /// The group size for quantization + int group_size, + ///< ID within the threadblock + int 
thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx) + , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) + , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) + , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), + shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& iterator_scale, int stage = -1, int k_iter = -1) + { + static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1."); + + typename IteratorScale::AccessType* gmem_scale_ptr = iterator_scale.get_scale(); + typename IteratorScale::AccessType* gmem_zero_ptr = iterator_scale.get_zero(); + + typename IteratorScale::AccessType* smem_scale_ptr + = reinterpret_cast(this->smem_iterator_scale_.get_scale()); + typename IteratorScale::AccessType* smem_zero_ptr + = reinterpret_cast(this->smem_iterator_scale_.get_zero()); + + int const kSrcBytes = sizeof_bits::value * IteratorScale::kAlignment / 8; + + cutlass::arch::cp_async(smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid()); + + if (gmem_zero_ptr != nullptr) + { + cutlass::arch::cp_async(smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid()); + } + + if (iterator_scale.group_size_ == 64) + { + iterator_scale.add_tile_offset({1, 0}); + } + else if (iterator_scale.group_size_ == 128) + { + if (iterator_scale.row_groupsize64_ & 0x1) + { + iterator_scale.add_tile_offset({1, 0}); + } + } + + iterator_scale.row_groupsize64_++; + + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, IteratorB& iterator_B, IteratorScale& iterator_scale, + int group_start_A = 0, int group_start_B = 0) + { + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) + { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) + { + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value + * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < 
IteratorA::kAccessesPerVector; ++v) + { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + else + { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) + { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) + { + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value + * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) + { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + else + { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) + { + + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) + { + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) + { + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) + { + int const kSrcBytes = sizeof_bits::value + * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? 
kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) + { + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) + { + int const kSrcBytes = sizeof_bits::value + * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. + // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) + { + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) + { + + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) + { + + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Wait until we have at least one committed global fetch stage. 
(#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + typename Dequantizer::FragmentZero warp_frag_zeros; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + warp_dequantizer_.add_pointer_offset(Shape::kN); + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) + { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) + { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) + { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + typename TransformBAfterLDS::result_type converted_frag_B + = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros); + + run_warp_mma( + warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) + { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance( + iterator_A, iterator_B, iterator_scale, group_start_iteration_A, group_start_iteration_B); + + // This is the first group of a given stage, so we issue the loads for the B scales immediately. 
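+                    // copy_scales_and_advance() issues a single cp.async per stage for the
+                    // scales (and zeros, when present); it advances the global scale tile
+                    // every stage for group_size == 64 and every other stage for
+                    // group_size == 128.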
+ if (group_start_iteration_B == 0) + { + copy_scales_and_advance(iterator_scale); + } + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) + { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance( + iterator_A, iterator_B, iterator_scale, group_start_iteration_A, group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - + // #committed) + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) + { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } + else + { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) + { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + smem_read_stage_idx = 0; + } + else + { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + } + } + + // Load the scale needed for the next tile iteration. + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + // Update internal pointer to set of scales in shared memory. + warp_dequantizer_.add_pointer_offset(Shape::kN); + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h b/csrc/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h new file mode 100644 index 000000000000..5ec515c28712 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h @@ -0,0 +1,636 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
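+/// This partial specialization handles per-column scales (one scale per output column of
+/// B): the scale fragment is loaded once in the prologue, stored to shared memory, and
+/// reused for every K iteration, unlike the fine-grained variant, which streams scales
+/// stage by stage.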
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +class DqMmaMultistage> + : public DqMmaBase +{ +public: + ///< Base class + using Base = DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. 
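+    /// The kAccessesPerGroup* constants below spread the cp.async copies of one stage
+    /// evenly (rounding up) across the warp-level MMA iterations; e.g. 8 copy iterations
+    /// over 4 warp MMA steps gives 2 copies issued per step (illustrative numbers).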
+ struct Detail + { + + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA + = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB + = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + +private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave + = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + +private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared memory + SmemIteratorScale smem_iterator_scale_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + ///< Group size for quantization. 
Not used by this main loop since it assumes per-column + const int group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx) + , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) + , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) + , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance( + IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) + { + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) + { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) + { + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value + * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) + { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + else + { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) + { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) + { + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value + * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) + { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + else + { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, 
iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) + { + + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // NOTE - switch to ldg.sts + // Issue this first, so cp.async.commit_group will commit this load as well. + // Note: we do not commit here and this load will commit in the same group as + // the first load of A. + FragmentScale tb_frag_scales; + tb_frag_scales.clear(); + iterator_scale.load(tb_frag_scales); + this->smem_iterator_scale_.store(tb_frag_scales); + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) + { + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) + { + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) + { + int const kSrcBytes = sizeof_bits::value + * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) + { + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) + { + int const kSrcBytes = sizeof_bits::value + * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. 
+ // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) + { + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) + { + + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) + { + + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + warp_dequantizer_.load(warp_frag_scales); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) + { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) + { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
+
+                this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
+                this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
+                ++this->warp_tile_iterator_A_;
+
+                const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
+                const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
+                if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
+                {
+                    this->warp_tile_iterator_B_.set_kgroup_index(
+                        (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
+                    this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
+                    ++this->warp_tile_iterator_B_;
+                }
+
+                typename TransformBAfterLDS::result_type converted_frag_B
+                    = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
+                warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
+
+                run_warp_mma(
+                    warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
+
+                // Issue global->shared copies for this stage
+                if (warp_mma_k < Base::kWarpGemmIterations - 1)
+                {
+                    int group_start_iteration_A, group_start_iteration_B;
+
+                    group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
+                    group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
+
+                    copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
+                }
+
+                if (warp_mma_k + 2 == Base::kWarpGemmIterations)
+                {
+                    int group_start_iteration_A, group_start_iteration_B;
+                    group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
+                    group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
+
+                    copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
+
+                    // Inserts a memory fence between stages of cp.async instructions.
+                    cutlass::arch::cp_async_fence();
+
+                    // Waits until kStages-2 stages have committed.
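+                    // With at most kStages-2 stages still in flight, the stage that will be
+                    // read from shared memory next is guaranteed to have finished its
+                    // cp.async copies before the __syncthreads() below releases the warps.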
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) + { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } + else + { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) + { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + smem_read_stage_idx = 0; + } + else + { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + } + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h b/csrc/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h new file mode 100644 index 000000000000..e8f5a92c3f02 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h @@ -0,0 +1,397 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" +#include "cutlass_extensions/gemm_configs.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Converter for B matrix applied immediately after the LDG (before STS) + typename TransformBAfterLDG_, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Used for partial specialization + typename Enable = bool> +class DqMmaPipelined : public DqMmaBase +{ +public: + ///< Base class + using Base = DqMmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = 
Policy_; ///< Policy describing tuning details + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + using TransformBAfterLDG = TransformBAfterLDG_; + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); + +private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave + = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + +protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared memory + SmemIteratorScale smem_iterator_scale_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaPipelined(typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + const int group_size, ///< Will not be used, just to adapt to finegrained modifications and make the compilation + ///< successful. Because DqMmaPipelined is only enabled for sm<80, so even if this + ///< argument is not added, it does not affect compilation for sm>=80. 
+ int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx) + , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) + , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) + , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) + { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + IteratorScale iterator_scale, ///< iterator over scale operand in global memory + FragmentC const& src_accum) + { ///< source accumulator tile + + // + // Prologue + // + TransformBAfterLDG ldg_converter; + TransformBAfterLDS lds_converter; + + using TransformA + = NumericArrayConverter; + + using TransformScale = NumericArrayConverter; + + // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want + // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. 
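+        // (Illustrative sketch, not part of the ported kernel: TransformA / TransformScale are
+        // element-wise fragment converters. For example, a bf16 -> fp16 conversion of an
+        // 8-element fragment would look like
+        //     cutlass::NumericArrayConverter<cutlass::half_t, cutlass::bfloat16_t, 8> cvt;
+        //     cutlass::Array<cutlass::half_t, 8> dst = cvt(src);
+        // and when the shared-memory element type matches the global-memory element type the
+        // converter degenerates to a copy.)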
+ TransformA transformA; + TransformScale transformScale; + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + FragmentScale tb_frag_scales; + + using WarpFragmentScale = typename Dequantizer::FragmentScale; + WarpFragmentScale warp_frag_scales; + + tb_frag_A.clear(); + tb_frag_B.clear(); + tb_frag_scales.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + iterator_scale.load(tb_frag_scales); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transformA(tb_frag_A)); + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + this->smem_iterator_scale_.store(transformScale(tb_frag_scales)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + warp_dequantizer_.load(warp_frag_scales); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) + { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) + { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) + { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transformA(tb_frag_A)); + + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) + { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } + else + { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. 
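+            // (For illustration, assuming kNumKIterationsPerWarpBLoad == 2 and
+            // kWarpGemmIterations == 8: each B fragment read from shared memory covers two
+            // compute steps, so kWarpGemmIterationsForB == 4. At warp_mma_k == 5 the offsets
+            // above give warp_tileB_k_compute_offset == 1 and warp_tileB_k_load_offset == 2,
+            // i.e. the second half of B fragment 2 is consumed while the load of fragment 3
+            // is issued below.)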
+ if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) + { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + if (warp_mma_k == 0) + { + + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + typename TransformBAfterLDS::result_type converted_frag_B + = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + run_warp_mma( + warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/csrc/cutlass_extensions/gemm/warp/default_mma_tensor_op.h new file mode 100644 index 000000000000..c8160c59d2d8 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/warp/default_mma_tensor_op.h @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. 
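+
+           The specialization below targets mixed-input GEMMs in which the B operand is a
+           narrower integer type than A: only the K extent of the shared-memory load
+           instruction is widened, LoadInstructionK = 8 * sizeof_bits(ElementA) /
+           sizeof_bits(ElementB), e.g. 8 * 16 / 8 = 16 for int8 B and 8 * 16 / 4 = 32 for
+           int4 B when A is a 16-bit type.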
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/default_mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op.h" + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" + +namespace cutlass +{ +namespace gemm +{ +namespace warp +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for m-by-n-by-kgroup +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements, + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Number of partitions along K dimension + int PartitionsK, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor> +struct DefaultMmaTensorOp +{ + +private: + // Shape for computing the FP16s + using ComputeInstructionShape = InstructionShape_; + + // Chosen so we get K=16 for int8 and K=32 for int4. + static constexpr int LoadInstructionK = 8 * sizeof_bits::value / sizeof_bits::value; + + // Shape for loading the narrow data type from shared memory + using LoadInstructionShape = GemmShape; + +public: + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma, + cutlass::MatrixShape<1, 1>>; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/csrc/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h new file mode 100644 index 000000000000..8e94516945eb --- /dev/null +++ b/csrc/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -0,0 +1,301 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations targeting + Tensor Cores. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" + +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace warp +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Instruction shape to override shared memory iterators with + typename SharedMemoryInstructionShape_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool> +class MmaTensorOpComputeBWithF16 +{ +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert((platform::is_same::value + && platform::is_same::value) + || (platform::is_same::value + && platform::is_same::value + && ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports underlying HMMA"); + + static_assert(platform::is_same::value + || (platform::is_same::value && ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+"); + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + + static_assert( + SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load"); + static_assert( + SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load"); + + static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; + + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + +public: + /// Iterates over the A operand in memory + using IteratorA + = MmaTensorOpMultiplicandTileIterator, Operand::kA, ElementA, LayoutA, + MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator, Operand::kB, ElementB, + LayoutB, MatrixShape, Policy::OpDelta::kRow, + kThreadCount, kPartitionsK>; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = Array; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename 
Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Number of mma operations performed + using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; + +public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + +public: + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C, + const int warp_tileB_k_offset) const + { + + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + static_assert( + TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, + "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of " + "B"); + + D = C; + + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) + { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) + { + + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + + int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; + if (AccumulatorsInRowMajor) + { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } + else + { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) + { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) + { + + int n_serpentine = ((m % 2) ? 
(MmaIterations::kColumn - 1 - n) : n); + + int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; + if (AccumulatorsInRowMajor) + { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } + else + { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } +#else + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/csrc/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h new file mode 100644 index 000000000000..bdac36fd95d9 --- /dev/null +++ b/csrc/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h @@ -0,0 +1,646 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 
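+
+           In particular, MmaTensorOpDequantizer applies per-column scales (and, where the
+           quantization scheme provides them, zero points) to a B fragment just before the
+           warp-level MMA, i.e. it computes w_dequant = scale * w_quant + zero element-wise
+           on the fragment, with the zero term omitted for scale-only modes.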
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" + +#include "cutlass/functional.h" +#include "cutlass/platform/platform.h" + +#include "cutlass_extensions/weight_only_quant_op.h" +#include "tensorrt_llm/common/cudaBf16Wrapper.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace warp +{ + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Matrix multiply operator + typename MmaOperator_, + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + Operand Operand, + /// Data type of Scale elements + typename Element_, + /// Layout of operand + typename Layout_, + /// Number of threads participating in one matrix operation + int Threads, + /// + WeightOnlyQuantOp QuantOp_, + /// + typename Enable = void> +class MmaTensorOpDequantizer; + +//////////////////////////////////////////////////////////////////////////////// +// Bfloat specialization for Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer= 80 + && platform::is_same::value>::type> +{ + +public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. 
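+    // (Illustrative example: with an m16n8k16 HMMA compute instruction and the widened load
+    // shapes chosen in default_mma_tensor_op.h, this works out to 16 / 16 = 1 for int8
+    // weights and 32 / 16 = 2 for int4 weights.)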
+ static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = bfloat16_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = Array; + using FragmentZero = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, const int warp_idx_n, const int lane_idx) + { + const int warp_offset = warp_idx_n * Shape::kN; + const int quad = lane_idx / 4; + const int thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) + { + pointer_zero_ = smem_zeros.data() + thread_offset; + } + } + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) + { + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag) + { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn + == FragmentDequantizedOperand::kElements, + ""); + + const __nv_bfloat16* scale_ptr = reinterpret_cast(&scale_frag); + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); + __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) + { + operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); + } + } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should + // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid + // numerous conversion instructions in GEMM main loop. 
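+        // (For reference, the fast path above computes w_dequant = w_quant * column_scale on
+        // packed bf16 pairs, two elements per __hmul2.)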
+ arch::device_breakpoint(); +#endif + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) + { + if constexpr (hasZero(QuantOp)) + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } + else + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + } + + CUTLASS_DEVICE + void dequantize( + FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag, const FragmentScale& zero_frag) + { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn + == FragmentDequantizedOperand::kElements, + ""); + + const __nv_bfloat16* scale_ptr = reinterpret_cast(&scale_frag); + const __nv_bfloat16* zero_ptr = reinterpret_cast(&zero_frag); + + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); + __nv_bfloat162 zerox2 = __bfloat162bfloat162(zero_ptr[mma_n_iter]); + __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + if constexpr (hasZero(QuantOp)) + { + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) + { + operand_bf16x2_ptr[ii] = __hfma2(operand_bf16x2_ptr[ii], scalex2, zerox2); + } + } + else + { + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) + { + operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); + } + } + } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should + // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid + // numerous conversion instructions in GEMM main loop. + arch::device_breakpoint(); +#endif + } + + // Adds a pointer offset in units of elements. 
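+    // (Used to step the scale and zero pointers between quantization groups in the
+    // fine-grained path; e.g. if the scales were laid out as a row-major [K / group_size, N]
+    // tensor, advancing one group would be an offset of N elements. That layout is an
+    // assumption for illustration; the caller supplies the actual offset.)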
+ CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) + { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } + +private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Specialization for Turing & Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer= 75 + && platform::is_same::value>::type> +{ + +public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = Array; + using FragmentZero = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, const int warp_idx_n, const int lane_idx) + { + const int warp_offset = warp_idx_n * Shape::kN; + const int quad = lane_idx / 4; + const int thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) + { + pointer_zero_ = smem_zeros.data() + thread_offset; + } + } + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) + { + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag) + { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn + == FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) + { + if constexpr (hasZero(QuantOp)) + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < 
MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } + else + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + } + + CUTLASS_DEVICE + void dequantize( + FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag, const FragmentScale& zero_frag) + { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn + == FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + + if constexpr (hasZero(QuantOp)) + { + plus plus_op; + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + operand_frag_ptr[mma_n_iter] + = plus_op(mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]), zero_frag[mma_n_iter]); + } + } + else + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + } + + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) + { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } + +private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Specialization for Volta A x RowMajor B tensorOp, for 32x32x4 interleaved gemm +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer::value + && platform::is_same::value>::type> +{ + +public: + static_assert(platform::is_same>::value, ""); + + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + /// Warp mma shape + using Shape = Shape_; + + // Fragment to hold scale data to apply to B before mma + // Each 32x32x4 matmul uses 8 elements from B. 
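+    // (Illustrative example: with a warp tile of Shape::kN == 64, TileNIterations == 2, so a
+    // thread holds 2 x 8 == 16 scales and load() below fetches them as two 8-wide vectors at
+    // column offsets 0 and 32 from its base column.)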
+ static constexpr int ColsPerMmaTile = 32; + static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile; + using FragmentScale = Array; + using AccessType = Array; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx) + { + const int warp_offset = warp_idx_n * Shape::kN; + const int base_col = lane_idx & 0xF8; + const int thread_offset = warp_offset + base_col; + pointer_ = smem_scales.data() + thread_offset; + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) + { + AccessType* scale_frag_ptr = reinterpret_cast(&scale_frag); + + CUTLASS_PRAGMA_UNROLL + for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) + { + // We jump by 32 here since volta does <32x32x4> super mmas inside a warp. + scale_frag_ptr[tile_iter] = *reinterpret_cast(pointer_ + ColsPerMmaTile * tile_iter); + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag) + { + static_assert(FragmentScale::kElements == FragmentDequantizedOperand::kElements, ""); + + multiplies mul_op; + operand_frag = mul_op(operand_frag, scale_frag); + } + +private: + ElementScale const* pointer_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Specialization for Volta A x ColumnMajor B tensorOp, for 32x32x4 interleaved gemm +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer::value + && platform::is_same::value>::type> +{ + +public: + static_assert(platform::is_same>::value, ""); + + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + /// Warp mma shape + using Shape = Shape_; + + // Fragment to hold scale data to apply to B before mma + // Each 32x32x4 matmul uses 8 elements from B. 
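+    // (Illustrative example: in this column-major-B case only 2 scales are needed per
+    // 32-column tile, read 4 columns apart; with Shape::kN == 64 a thread loads the scales at
+    // columns {0, 4, 32, 36} relative to its base column.)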
+ static constexpr int ColsPerMmaTile = 32; + static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile; + using FragmentScale = Array; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx) + { + const int warp_offset = warp_idx_n * Shape::kN; + const int base_col = lane_idx & 0xF8 + lane_idx % 4; + const int thread_offset = warp_offset + base_col; + pointer_ = smem_scales.data() + thread_offset; + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) + { + CUTLASS_PRAGMA_UNROLL + for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) + { + // We jump by 32 here since volta does <32x32x4> super mmas inside a warp. + // For col major B, each thread will jump 4 cols to get its next value inside + // of the super mma. + CUTLASS_PRAGMA_UNROLL + for (int mma_iter = 0; mma_iter < 2; ++mma_iter) + { + scale_frag[tile_iter * 2 + mma_iter] = pointer_[ColsPerMmaTile * tile_iter + 4 * mma_iter]; + } + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag) + { + using MmaOperandB = typename ArchMmaOperator::FragmentB; + static constexpr int total_n_mmas = 2 * TileNIterations; + static_assert(MmaOperandB::kElements * total_n_mmas == FragmentDequantizedOperand::kElements, ""); + + multiplies mul_op; + + MmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < total_n_mmas; ++mma_n_iter) + { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + +private: + ElementScale const* pointer_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm_configs.h b/csrc/cutlass_extensions/gemm_configs.h new file mode 100644 index 000000000000..11180c4260d8 --- /dev/null +++ b/csrc/cutlass_extensions/gemm_configs.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace tensorrt_llm +{ +namespace cutlass_extensions +{ +// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape +// in the kernel layout details when doing weight only quantization. 
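+// (How to read the names below, for example: CtaShape32x128x64_WarpShape32x32x64 is a
+// 32x128x64 (MxNxK) threadblock tile computed by four warps that each own a 32x32x64 warp
+// tile; the trailing 64 is the K extent that must line up with ThreadblockK in the kernel
+// layout details when weights are quantized.)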
+enum class CutlassTileConfig +{ + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + // SiMT config + CtaShape128x128x8_WarpShape64x64x8, + + // TensorCore configs CTA_N = 128, CTA_K = 64 + // Warp configs for M=32 + CtaShape32x128x64_WarpShape32x32x64, + + // Warp configs for M=64 + CtaShape64x128x64_WarpShape32x64x64, + CtaShape64x64x128_WarpShape32x64x64, + CtaShape64x128x64_WarpShape64x32x64, + + // Warp configs for M=128 + CtaShape128x64x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x64x64, + CtaShape128x128x64_WarpShape128x32x64, + CtaShape128x256x64_WarpShape64x64x64, + + // Warp configs for M=256 + CtaShape256x128x64_WarpShape64x64x64 +}; + +enum class SplitKStyle +{ + NO_SPLIT_K, + SPLIT_K_SERIAL, + // SPLIT_K_PARALLEL // Not supported yet +}; + +struct CutlassGemmConfig +{ + CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; + SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; + int split_k_factor = -1; + int stages = -1; +}; + +} // namespace cutlass_extensions +} // namespace tensorrt_llm diff --git a/csrc/cutlass_extensions/interleaved_numeric_conversion.h b/csrc/cutlass_extensions/interleaved_numeric_conversion.h new file mode 100644 index 000000000000..44ba79680e69 --- /dev/null +++ b/csrc/cutlass_extensions/interleaved_numeric_conversion.h @@ -0,0 +1,447 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
+ \file + \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register +*/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/array.h" +#include "cutlass/half.h" +#include "cutlass/numeric_types.h" + +namespace cutlass +{ + +// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low +// bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally +// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. +// This converter will uninterleave the data and subtract the bias while converting to the result type. +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + + // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16. + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* bf16_result_ptr = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + // Construct FP32s, bfloat does not have enough mantissa for IADD trick + uint32_t* 
fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + + // Subtract out fp32_base + 128 to make the unsigned integer signed. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 4; ++ii) + { + fp32_intermediates[ii] -= 8388736.f; + } + + // Truncate the fp32 representation and pack up as bfloat16s. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 2; ++ii) + { + bf16_result_ptr[ii] + = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632); + } +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. + result.clear(); // Suppress compiler warning + arch::device_breakpoint(); +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + + // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing + // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. + // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and + // elt_67 to fp16 without having to shift them to the bottom bits before hand. + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue + // immediately before required. 
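+        // (Worked example for one element: an int4 value v in [-8, 7] is stored biased as
+        // u = v + 8. OR-ing a bottom nibble into 0x6400 builds the fp16 value 1024 + u, and
+        // subtracting the magic 1032 below gives u - 8 = v. A top nibble sits in mantissa bits
+        // 4..7, producing 1024 + 16*u, so fma by 1/16 with addend -72 (= -(1024/16 + 8)) also
+        // recovers v without an extra shift.)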
+ const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the + // half2 ctor. In this case, I chose performance reliability over code readability. + + // This is the half2 {1032, 1032} represented as an integer. + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + static constexpr uint32_t NEG_72 = 0xd480d480; + + // Finally, we construct the output numbers. + // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. 
+ static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop. + // No shift needed for first item. + uint32_t i4s = source_i4s; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + CUTLASS_PRAGMA_UNROLL + for (int ii = 1; ii < result_type::kElements / 2; ++ii) + { + i4s >>= sizeof_bits::value; + // (i4s & 0x000f000f) | 0x43004300 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } + + // This is the BF16 {-136, -136} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + + // Finally, we construct the output numbers. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < result_type::kElements / 2; ++ii) + { + // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. + arch::device_breakpoint(); + result.clear(); // Suppress compiler warning. +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/tile_interleaved_layout.h b/csrc/cutlass_extensions/tile_interleaved_layout.h new file mode 100644 index 000000000000..5a0cd2957082 --- /dev/null +++ b/csrc/cutlass_extensions/tile_interleaved_layout.h @@ -0,0 +1,66 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines new layouts needed for MoE +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/pitch_linear_coord.h" + +namespace cutlass +{ +namespace layout +{ + +template +struct ColumnMajorTileInterleave +{ + static constexpr int kRowsPerTile = RowsPerTile; + static constexpr int kColumnsInterleaved = ColumnsInterleaved; +}; + +template +struct IsColumnMajorTileInterleave +{ + static constexpr bool value = false; +}; + +template +struct IsColumnMajorTileInterleave> +{ + static constexpr bool value = true; +}; + +} // namespace layout +} // namespace cutlass diff --git a/csrc/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h b/csrc/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h new file mode 100644 index 000000000000..f8e46f1d2ab7 --- /dev/null +++ b/csrc/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h @@ -0,0 +1,248 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates for visiting scales to be used when dequantizing the weights for weight-only GEMM + quantization. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace transform +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +template +class FineGrainedScaleZeroIterator; + +template +class FineGrainedScaleZeroIterator +{ +public: + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = 0; + static int const kAlignment = Alignment_; + + static int const kAccessesPerVector = 1; + + /// Row index of scales corresponding to the groupsize of 64 + int row_groupsize64_; + int group_size_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using AccessType = AlignedArray; + + // For compatibility with existing iterator interface + struct Params + { + LongIndex stride_ = 0; + + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_ = 0; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : stride_(layout.stride(0)) + { + inc_advance_ = Shape::kRow * stride_ * sizeof_bits::value / 8; + } + }; + +private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + +private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const params_; + + /// Internal pointer to first access of tile + BytePointer pointer_scale_; + BytePointer pointer_zero_; + + bool is_valid_ = false; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + 
CUTLASS_DEVICE + FineGrainedScaleZeroIterator( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of scale tensor + Pointer pointer_scale, + ///< Pointer to start of zero tensor + Pointer pointer_zero, + ///< Extent of the scale and bias + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + ///< Group size + int group_size) + : params_(params) + , pointer_scale_(reinterpret_cast(const_cast(pointer_scale))) + , pointer_zero_(reinterpret_cast(const_cast(pointer_zero))) + { + row_groupsize64_ = threadblock_offset.row(); + group_size_ = group_size; + + const LongIndex tb_row_byte_offset + = threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits::value / 8; + const LongIndex tb_col_byte_offset = threadblock_offset.column() * sizeof_bits::value / 8; + pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset); + + if (pointer_zero_ != nullptr) + { + pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset); + } + + static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment; + + const int thread_row = thread_id / THREADS_PER_ROW; + const int thread_col = thread_id % THREADS_PER_ROW; + + const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits::value / 8; + const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits::value / 8; + pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset); + if (pointer_zero_ != nullptr) + { + pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset); + } + + // For the rows, we must check that we are within the extent AND the tile to avoid extra reads on + // a given iteration. The same threads will be responsible for issues reads since the number of scales + // read in a given iteration is a constant. Therefore, we should never have to update is_valid_ + // outside of the constructor. 
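+        // Each thread issues exactly one aligned access per tile (kAccessesPerVector == 1), so a single
+        // boolean predicate is enough; clear_mask() can later force it to false, but it never needs to be
+        // recomputed.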
+ const int global_row = threadblock_offset.row() + thread_row; + const int global_col = threadblock_offset.column() + thread_col * kAlignment; + + const bool row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow; + const bool col_in_bounds = global_col < extent.column(); + + is_valid_ = row_in_bounds && col_in_bounds; + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object + Pointer pointer_scale, ///< Pointer to start of scale tensor + Pointer pointer_zero, ///< Pointer to start of zero tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + int group_size) + : FineGrainedScaleZeroIterator( + params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), group_size) + { + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) + { + const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_; + const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits::value / 8; + pointer_scale_ += row_byte_offset + col_byte_offset; + if (pointer_zero_ != nullptr) + { + pointer_zero_ += row_byte_offset + col_byte_offset; + } + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE void clear_mask(bool enable = true) + { + is_valid_ &= (!enable); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const + { + return is_valid_; + } + + /// Returns a scale pointer + CUTLASS_HOST_DEVICE + AccessType* get_scale() const + { + return reinterpret_cast(pointer_scale_); + } + + /// Returns a zero pointer + CUTLASS_HOST_DEVICE + AccessType* get_zero() const + { + return reinterpret_cast(pointer_zero_); + } +}; + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/csrc/cutlass_extensions/weight_only_quant_op.h b/csrc/cutlass_extensions/weight_only_quant_op.h new file mode 100644 index 000000000000..64774428e9f9 --- /dev/null +++ b/csrc/cutlass_extensions/weight_only_quant_op.h @@ -0,0 +1,58 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. +*/ + +#pragma once + +namespace cutlass +{ + +enum class WeightOnlyQuantOp +{ + UNDEFINED, + PER_COLUMN_SCALE_ONLY, + FINEGRAINED_SCALE_ONLY, + FINEGRAINED_SCALE_AND_ZEROS +}; + +constexpr bool isFinegrained(WeightOnlyQuantOp op) +{ + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; +} + +constexpr bool hasZero(WeightOnlyQuantOp op) +{ + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; +} + +} // namespace cutlass From 0cd943679c69aaa282e4f105093f97f6764edb0d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 31 Jan 2024 10:20:47 +0000 Subject: [PATCH 03/33] Port MoE kernels --- csrc/moe/moe_kernels.cu | 1112 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 1112 insertions(+) create mode 100644 csrc/moe/moe_kernels.cu diff --git a/csrc/moe/moe_kernels.cu b/csrc/moe/moe_kernels.cu new file mode 100644 index 000000000000..c42fd5370202 --- /dev/null +++ b/csrc/moe/moe_kernels.cu @@ -0,0 +1,1112 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/workspace.h" +#include +#include +#include +#include + +// Ignore CUTLASS warnings about type punning +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" + +#include "cutlass/array.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass_extensions/epilogue/thread/fused_activations.h" + +#pragma GCC diagnostic pop + +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.h" + +#ifndef CUDART_VERSION +#error CUDART_VERSION Undefined! 
+#elif (CUDART_VERSION >= 11050) +#include +#include +#include +#else +#include "3rdparty/cub/cub.cuh" +#include "3rdparty/cub/device/device_radix_sort.cuh" +#include "3rdparty/cub/util_type.cuh" +#endif + +using namespace tensorrt_llm::kernels; +using namespace tensorrt_llm::common; + +namespace tensorrt_llm::kernels +{ + +static constexpr int WARP_SIZE = 32; + +// ====================== Softmax things =============================== +// We have our own implementation of softmax here so we can support transposing the output +// in the softmax kernel when we extend this module to support expert-choice routing. +template +__launch_bounds__(TPB) __global__ + void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols) +{ + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + __shared__ float normalizing_factor; + __shared__ float float_max; + + const int thread_row_offset = blockIdx.x * num_cols; + + cub::Sum sum; + float threadData(-FLT_MAX); + + // Don't touch finished rows. + if ((finished != nullptr) && finished[blockIdx.x]) + { + return; + } + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) + { + const int idx = thread_row_offset + ii; + threadData = max(input[idx], threadData); + } + + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + if (threadIdx.x == 0) + { + float_max = maxElem; + } + __syncthreads(); + + threadData = 0; + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) + { + const int idx = thread_row_offset + ii; + threadData += exp((static_cast(input[idx]) - float_max)); + } + + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); + + if (threadIdx.x == 0) + { + normalizing_factor = 1.f / Z; + } + __syncthreads(); + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) + { + const int idx = thread_row_offset + ii; + const float val = exp((static_cast(input[idx]) - float_max)) * normalizing_factor; + output[idx] = val; + } +} + +template +__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output, + int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert) +{ + + using cub_kvp = cub::KeyValuePair; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + const int num_rows = gridDim.x; + const int block_row = blockIdx.x; + + const bool row_is_active = finished ? 
!finished[block_row] : true; + const int thread_read_offset = blockIdx.x * num_experts; + for (int k_idx = 0; k_idx < k; ++k_idx) + { + thread_kvp.key = 0; + thread_kvp.value = -1.f; // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) + { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs_after_softmax[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) + { + const int prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) + { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) + { + // Ignore experts the node isn't responsible for with expert parallelism + const int expert = result_kvp.key; + const bool node_uses_expert = expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + const int idx = k * block_row + k_idx; + output[idx] = result_kvp.value; + indices[idx] = should_process_row ? (expert - start_expert) : num_experts; + assert(indices[idx] >= 0); + source_rows[idx] = k_idx * num_rows + block_row; + } + __syncthreads(); + } +} + +// ====================== TopK softmax things =============================== + +/* + A Top-K gating softmax written to exploit when the number of experts in the MoE layers + are a small power of 2. This allows us to cleanly share the rows among the threads in + a single warp and eliminate communication between warps (so no need to use shared mem). + + It fuses the softmax, max and argmax into a single kernel. + + Limitations: + 1) This implementation is intended for when the number of experts is a small power of 2. + 2) This implementation assumes k is small, but will work for any k. +*/ + +template +__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ + void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices, + int* source_rows, const int k, const int start_expert, const int end_expert) +{ + // We begin by enforcing compile time assertions and setting up compile time constants. + static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); + static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2"); + static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); + static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); + + // Number of bytes each thread pulls in per load + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static constexpr int ELTS_PER_ROW = NUM_EXPERTS; + static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; + static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; + + // Restrictions based on previous section. + static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); + static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); + static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2"); + static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size"); + + // We have NUM_EXPERTS elements per row. 
We specialize for small #experts + static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT; + static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW; + static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; + + // Restrictions for previous section. + static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp"); + + // ===================== From this point, we finally start computing run-time variables. ======================== + + // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps. + // This, each block processes a chunk of rows. We start by computing the start row for each block. + const int cta_base_row = blockIdx.x * ROWS_PER_CTA; + + // Now, using the base row per thread block, we compute the base row per warp. + const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP; + + // The threads in a warp are split into sub-groups that will work on a row. + // We compute row offset for each thread sub-group + const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW; + const int thread_row = warp_base_row + thread_row_in_warp; + + // Threads with indices out of bounds should early exit here. + if (thread_row >= num_rows) + { + return; + } + const bool row_is_active = finished ? !finished[thread_row] : true; + + // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the + // row it will read. + const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW; + + // Now, we compute the group each thread belong to in order to determine the first column to start loads. + const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; + const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; + const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; + + // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory, + // this can support all powers of 2 up to 16. + using AccessType = cutlass::AlignedArray; + + // Finally, we pull in the data from global mem + cutlass::Array row_chunk; + AccessType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk); + const AccessType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) + { + row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + } + + // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just + // convert to float afterwards for the exp + sum reduction. + float thread_max = row_chunk[0]; +#pragma unroll + for (int ii = 1; ii < VPT; ++ii) + { + thread_max = max(thread_max, row_chunk[ii]); + } + +// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { + thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW)); + } + + // From this point, thread max in all the threads have the max within the row. + // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum. + float row_sum = 0; +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) + { + row_chunk[ii] = expf(row_chunk[ii] - thread_max); + row_sum += row_chunk[ii]; + } + +// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern. 
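+// The XOR-shuffle butterfly pairs lanes whose IDs differ by `mask`; halving the mask each round means that
+// after log2(THREADS_PER_ROW) rounds every lane in the sub-group holds the full row sum without touching
+// shared memory.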
+#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { + row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); + } + + // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables + // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to + // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row. + // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the + // argmax after computing the softmax. + const float reciprocal_row_sum = 1.f / row_sum; + +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) + { + row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum; + } + + // Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along + // with the max index. + int start_col = first_elt_read_by_thread; + static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + + for (int k_idx = 0; k_idx < k; ++k_idx) + { + // First, each thread does the local argmax + float max_val = row_chunk[0]; + int expert = start_col; +#pragma unroll + for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) + { +#pragma unroll + for (int ii = 0; ii < ELTS_PER_LDG; ++ii) + { + float val = row_chunk[ldg * ELTS_PER_LDG + ii]; + + // No check on the experts here since columns with the smallest index are processed first and only + // updated if > (not >=) + if (val > max_val) + { + max_val = val; + expert = col + ii; + } + } + } + +// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max. +// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can +// then blank out their max with -inf and the warp can run more iterations... +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { + float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); + int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); + + // We want lower indices to "win" in every thread so we break ties this way + if (other_max > max_val || (other_max == max_val && other_expert < expert)) + { + max_val = other_max; + expert = other_expert; + } + } + + // Write the max for this k iteration to global memory. + if (thread_group_idx == 0) + { + // Add a guard to ignore experts not included by this node + const bool node_uses_expert = expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + // The lead thread from each sub-group will write out the final results to global memory. (This will be a + // single) thread per row of the input/output matrices. + const int idx = k * thread_row + k_idx; + output[idx] = max_val; + indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS; + source_rows[idx] = k_idx * num_rows + thread_row; + } + + // Finally, we clear the value in the thread with the current max if there is another iteration to run. + if (k_idx + 1 < k) + { + const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; + const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW; + + // Only the thread in the group which produced the max will reset the "winning" value to -inf. 
+ if (thread_group_idx == thread_to_clear_in_group) + { + const int offset_for_expert = expert % ELTS_PER_LDG; + // Safe to set to any negative value since row_chunk values must be between 0 and 1. + row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f; + } + } + } +} + +namespace detail +{ +// Constructs some constants needed to partition the work across threads at compile time. +template +struct TopkConstants +{ + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); + static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; + static constexpr int THREADS_PER_ROW = EXPERTS / VPT; + static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; +}; +} // namespace detail + +template +void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices, + int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream) +{ + static constexpr std::size_t MAX_BYTES_PER_LDG = 16; + + static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); + using Constants = detail::TopkConstants; + static constexpr int VPT = Constants::VPT; + static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; + const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; + + dim3 block_dim(WARP_SIZE, WARPS_PER_TB); + topkGatingSoftmax<<>>( + input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert); +} + +void topkGatingSoftmaxKernelLauncher(const float* input, const bool* finished, float* output, + float* softmax_temp_output, int* indices, int* source_row, const int num_rows, const int num_experts, const int k, + const int start_expert, const int end_expert, cudaStream_t stream) +{ + static constexpr int WARPS_PER_TB = 4; + + switch (num_experts) + { + case 1: + { + topkGatingSoftmaxLauncherHelper<1, WARPS_PER_TB>( + input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream); + break; + } + case 2: + { + topkGatingSoftmaxLauncherHelper<2, WARPS_PER_TB>( + input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream); + break; + } + case 4: + { + topkGatingSoftmaxLauncherHelper<4, WARPS_PER_TB>( + input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream); + break; + } + case 8: + { + topkGatingSoftmaxLauncherHelper<8, WARPS_PER_TB>( + input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream); + break; + } + case 16: + { + topkGatingSoftmaxLauncherHelper<16, WARPS_PER_TB>( + input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream); + break; + } + case 32: + { + topkGatingSoftmaxLauncherHelper<32, WARPS_PER_TB>( + input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream); + break; + } + case 64: + { + topkGatingSoftmaxLauncherHelper<64, WARPS_PER_TB>( + input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream); + break; + } + case 128: + { + topkGatingSoftmaxLauncherHelper<128, WARPS_PER_TB>( + input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream); + break; + } + case 256: 
+ { + topkGatingSoftmaxLauncherHelper<256, WARPS_PER_TB>( + input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream); + break; + } + default: + { + static constexpr int TPB = 256; + TLLM_CHECK(softmax_temp_output != nullptr); + moeSoftmax<<>>(input, finished, softmax_temp_output, num_experts); + moeTopK<<>>( + softmax_temp_output, finished, output, indices, source_row, num_experts, k, start_expert, end_expert); + } + } +} + +// ========================== CUB Sorting things ==================================== +CubKeyValueSorter::CubKeyValueSorter() + : num_experts_(0) + , num_bits_(sizeof(int) * 8) +{ +} + +CubKeyValueSorter::CubKeyValueSorter(const int num_experts) + : num_experts_(num_experts) + , num_bits_((int) log2(num_experts) + 1) +{ +} + +void CubKeyValueSorter::updateNumExperts(const int num_experts) +{ + num_experts_ = num_experts; + num_bits_ = (int) log2(num_experts) + 1; +} + +size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs, const int num_experts) +{ + size_t num_bits = (int) log2(num_experts) + 1; + size_t required_storage = 0; + int* null_int = nullptr; + cub::DeviceRadixSort::SortPairs( + NULL, required_storage, null_int, null_int, null_int, null_int, num_key_value_pairs, 0, num_bits); + return required_storage; +} + +void CubKeyValueSorter::run(void* workspace, const size_t workspace_size, const int* keys_in, int* keys_out, + const int* values_in, int* values_out, const size_t num_key_value_pairs, cudaStream_t stream) +{ + size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs, num_experts_); + size_t actual_ws_size = workspace_size; + + TLLM_CHECK_WITH_INFO(expected_ws_size <= workspace_size, + "[CubKeyValueSorter::run] The allocated workspace is too small to run this problem."); + cub::DeviceRadixSort::SortPairs( + workspace, actual_ws_size, keys_in, keys_out, values_in, values_out, num_key_value_pairs, 0, num_bits_, stream); +} + +// ============================== Infer GEMM sizes ================================= +// TODO Could linear search be better for small # experts +__device__ inline int findTotalEltsLeqTarget(const int* sorted_indices, const int arr_length, const int target) +{ + int64_t low = 0, high = arr_length - 1, target_location = -1; + while (low <= high) + { + int64_t mid = (low + high) / 2; + + if (sorted_indices[mid] > target) + { + high = mid - 1; + } + else + { + low = mid + 1; + target_location = mid; + } + } + return target_location + 1; +} + +// Sets up the gemm assuming the inputs, experts and outputs are stored in row major order. +// Assumes we want to perform output = matmul(inputs, experts) + bias +// +// "total_rows_before_expert" contains the index one past the last occurrence of the corresponding expert. +// e.g. Index 0 is the start offset of expert 1, the final entry is the total number of active rows +__global__ void computeTotalRowsBeforeExpertKernel(const int* sorted_experts, const int sorted_experts_len, + const int64_t num_experts, int64_t* total_rows_before_expert) +{ + // First, compute the global tid. We only need 1 thread per expert. + const int expert = blockIdx.x * blockDim.x + threadIdx.x; + if (expert >= num_experts) + { + return; + } + + // This should construct the last index where each expert occurs. + total_rows_before_expert[expert] = findTotalEltsLeqTarget(sorted_experts, sorted_experts_len, expert); +} + +// ========================== Permutation things ======================================= + +// Duplicated and permutes rows for MoE. 
In addition, reverse the permutation map to help with finalizing routing. + +// "expanded_x_row" simply means that the number of values is num_rows x k. It is "expanded" since we will have to +// duplicate some rows in the input matrix to match the dimensions. Duplicates will always get routed to separate +// experts in the end. + +// Note that the expanded_dest_row_to_expanded_source_row map referred to here has indices in the range (0, +// k*rows_in_input - 1). However, it is set up so that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input +// all map to row 0 in the original matrix. Thus, to know where to read in the source matrix, we simply take the modulus +// of the expanded index. + +template +__global__ void expandInputRowsKernel(const T* unpermuted_input, T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, + const int num_rows, const int64_t* num_dest_rows, const int cols) +{ + + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the + // reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. + const int expanded_dest_row = blockIdx.x; + const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + if (threadIdx.x == 0) + { + expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_dest_row; + } + + if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) + { + // Duplicate and permute rows + const int source_row = expanded_source_row % num_rows; + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* dest_row_ptr = permuted_output + expanded_dest_row * cols; + + for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) + { + dest_row_ptr[tid] = source_row_ptr[tid]; + } + } +} + +template +void expandInputRowsKernelLauncher(const T* unpermuted_input, T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, + const int num_rows, const int64_t* num_valid_tokens_ptr, const int cols, const int k, cudaStream_t stream) +{ + const int blocks = num_rows * k; + const int threads = std::min(cols, 1024); + auto func = (num_valid_tokens_ptr != nullptr) ? expandInputRowsKernel : expandInputRowsKernel; + func<<>>(unpermuted_input, permuted_output, expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, num_rows, num_valid_tokens_ptr, cols); +} + +enum class ScaleMode : int +{ + NO_SCALE = 0, + DEFAULT = 1, + RENORM_SCALE = 2, +}; + +// Final kernel to unpermute and scale +// This kernel unpermutes the original data, does the k-way reduction and performs the final skip connection. 
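+// Worked example of the expanded-row numbering for num_rows = 3, k = 2: expanded source rows 0..2 hold the
+// k_idx = 0 copy of tokens 0..2 and rows 3..5 hold the k_idx = 1 copy, so the k copies of original row r
+// live at r + k_idx * num_rows. Each block below reduces exactly those k rows for its token after mapping
+// them through expanded_source_row_to_expanded_dest_row.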
+template +__global__ void finalizeMoeRoutingKernel(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1, + const T* skip_2, const T* bias, const float* scales, const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, const int cols, const int k, const int64_t* num_valid_ptr) +{ + const int original_row = blockIdx.x; + const int num_rows = gridDim.x; + const auto offset = original_row * cols; + T* reduced_row_ptr = reduced_unpermuted_output + offset; + const T* skip_1_row_ptr{}; + const T* skip_2_row_ptr{}; + + if (RESIDUAL_NUM >= 1) + { + skip_1_row_ptr = skip_1 + offset; + } + + if (RESIDUAL_NUM == 2) + { + skip_2_row_ptr = skip_2 + offset; + } + const int64_t num_valid = *num_valid_ptr; + for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) + { + T thread_output{0.f}; + float row_rescale{0.f}; + for (int k_idx = 0; k_idx < k; ++k_idx) + { + const int expanded_original_row = original_row + k_idx * num_rows; + const int expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row]; + + const int64_t k_offset = original_row * k + k_idx; + const float row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset]; + if constexpr (SCALE_MODE == ScaleMode::RENORM_SCALE) + { + row_rescale = row_rescale + row_scale; + } + + // Check after row sum has accumulated + if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) + { + continue; + } + + const T* expanded_permuted_rows_row_ptr = expanded_permuted_rows + expanded_permuted_row * cols; + + const int expert_idx = expert_for_source_row[k_offset]; + + const T* bias_ptr = bias + expert_idx * cols; + const T bias_value = HAS_BIAS ? bias_ptr[tid] : T(0.f); + + thread_output = static_cast(thread_output) + + row_scale * static_cast(expanded_permuted_rows_row_ptr[tid] + bias_value); + } + + if (SCALE_MODE == ScaleMode::RENORM_SCALE && (!CHECK_SKIPPED || thread_output)) + { + assert(row_rescale != 0.f); + thread_output = static_cast(thread_output) / row_rescale; + } + + if (RESIDUAL_NUM == 1) + { + thread_output = thread_output + skip_1_row_ptr[tid]; + } + else if (RESIDUAL_NUM == 2) + { + thread_output = thread_output + skip_1_row_ptr[tid] + skip_2_row_ptr[tid]; + } + reduced_row_ptr[tid] = thread_output; + } +} + +template +void finalizeMoeRoutingKernelLauncherSelectBias(const T* expanded_permuted_rows, T* reduced_unpermuted_output, + const T* skip_1, const T* skip_2, const T* bias, const float* scales, + const int* expanded_source_row_to_expanded_dest_row, const int* expert_for_source_row, const int num_rows, + const int cols, const int k, const int64_t* num_valid_ptr, MOEParallelismConfig parallelism_config, + MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream) +{ + const int blocks = num_rows; + const int threads = std::min(cols, 1024); + + // Only add bias on rank 0 for tensor parallelism + const bool is_rank_0 = parallelism_config.tp_rank == 0; + const bool has_bias = bias != nullptr && is_rank_0; + + const bool check_finished = num_valid_ptr != nullptr; + + ScaleMode renorm_scales = ScaleMode::DEFAULT; + if (normalization_mode == MOEExpertScaleNormalizationMode::RENORMALIZE) + { + renorm_scales = k == 1 ? 
ScaleMode::NO_SCALE : ScaleMode::RENORM_SCALE; + } + + using FuncPtr = decltype(&finalizeMoeRoutingKernel); + FuncPtr func_map[2][3][2] + = {{ + {&finalizeMoeRoutingKernel, + &finalizeMoeRoutingKernel}, + {&finalizeMoeRoutingKernel, + &finalizeMoeRoutingKernel}, + {&finalizeMoeRoutingKernel, + &finalizeMoeRoutingKernel}, + }, + { + {&finalizeMoeRoutingKernel, + &finalizeMoeRoutingKernel}, + {&finalizeMoeRoutingKernel, + &finalizeMoeRoutingKernel}, + {&finalizeMoeRoutingKernel, + &finalizeMoeRoutingKernel}, + }}; + auto* const func = func_map[check_finished][int(renorm_scales)][has_bias]; + func<<>>(expanded_permuted_rows, reduced_unpermuted_output, skip_1, skip_2, bias, + scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k, num_valid_ptr); +} + +template +void finalizeMoeRoutingKernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1, + const T* skip_2, const T* bias, const float* scales, const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, const int num_rows, const int cols, const int k, const int64_t* num_valid_ptr, + MOEParallelismConfig parallelism_config, MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream) +{ + // If we are not rank 0 we should not add any residuals because the allreduce would sum multiple copies + const bool is_rank_0 = parallelism_config.tp_rank == 0; + if (skip_1 == nullptr || !is_rank_0) + { + assert(skip_2 == nullptr); + finalizeMoeRoutingKernelLauncherSelectBias(expanded_permuted_rows, reduced_unpermuted_output, skip_1, + skip_2, bias, scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row, num_rows, cols, k, + num_valid_ptr, parallelism_config, normalization_mode, stream); + } + else if (skip_2 == nullptr) + { + finalizeMoeRoutingKernelLauncherSelectBias(expanded_permuted_rows, reduced_unpermuted_output, skip_1, + skip_2, bias, scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row, num_rows, cols, k, + num_valid_ptr, parallelism_config, normalization_mode, stream); + } + else + { + finalizeMoeRoutingKernelLauncherSelectBias(expanded_permuted_rows, reduced_unpermuted_output, skip_1, + skip_2, bias, scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row, num_rows, cols, k, + num_valid_ptr, parallelism_config, normalization_mode, stream); + } +} + +// ============================== Gated Activation ================================= + +template +__global__ void doGatedActivationKernel( + T* output, const T* gemm_result, const int64_t* num_valid_tokens_ptr, size_t inter_size) +{ + const int tid = threadIdx.x; + const int token = blockIdx.x; + if (num_valid_tokens_ptr && token >= *num_valid_tokens_ptr) + { + return; + } + + ActFn fn{}; + output = output + token * inter_size; + gemm_result = gemm_result + token * inter_size * 2; + for (int i = tid; i < inter_size; i += blockDim.x) + { + T fc1_value = gemm_result[i]; + // BF16 isn't supported, use FP32 for activation function + float gate_value = gemm_result[i + inter_size]; + T gate_act = fn(gate_value); + output[i] = fc1_value * gate_act; + } +} + +template +void doGatedActivation(T* output, const T* gemm_result, const int64_t* num_valid_tokens_ptr, int inter_size, + int num_tokens, ActivationType activation_type, cudaStream_t stream) +{ + const int blocks = num_tokens; + const int threads = std::min(inter_size, 1024); + + // TODO Instead of T use a vectored type if performance would benefit + // TODO For some reason Volta fails on GELU_taylor 
here with Warp Illegal Instruction. + auto* fn = activation_type == ActivationType::Swiglu + ? &doGatedActivationKernel> + : &doGatedActivationKernel>; + fn<<>>(output, gemm_result, num_valid_tokens_ptr, inter_size); +} + +template +std::vector CutlassMoeFCRunner::getWorkspaceBufferSizes(const int num_rows, + const int hidden_size, const int inter_size, const int num_experts, const int num_experts_per_node, const int k, + ActivationType activation_type) const +{ + const size_t num_moe_inputs = k * num_rows; + const size_t buf_size = num_moe_inputs * hidden_size; + const size_t interbuf_elems = num_moe_inputs * inter_size; + const size_t glu_inter_elems = isGatedActivation(activation_type) ? (interbuf_elems * 2) : 0; + int num_softmax_outs = 0; + + const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + if (!is_pow_2 || num_experts > 256) + { + num_softmax_outs = num_rows * num_experts; + } + + size_t source_rows_size = num_moe_inputs * sizeof(int); + size_t permuted_rows_size = num_moe_inputs * sizeof(int); + size_t permuted_experts_size = num_moe_inputs * sizeof(int); + size_t permuted_data_size = buf_size * sizeof(T); + size_t total_rows_before_expert_size = num_experts_per_node * sizeof(int64_t); + size_t softmax_out_size = num_softmax_outs * sizeof(float); + size_t glu_inter_size = glu_inter_elems * sizeof(T); + size_t fc1_result_size = interbuf_elems * sizeof(T); + size_t sorter_size = CubKeyValueSorter::getWorkspaceSize(num_rows, num_experts); + + std::vector workspace{ + source_rows_size, + permuted_rows_size, + permuted_experts_size, + permuted_data_size, + total_rows_before_expert_size, + softmax_out_size, + glu_inter_size, + // These pointers reuse the same memory + std::max(fc1_result_size, sorter_size), + }; + return workspace; +} + +template +size_t CutlassMoeFCRunner::getWorkspaceSize(const int num_rows, const int hidden_size, + const int inter_size, const int num_experts, const int k, ActivationType activation_type, + MOEParallelismConfig parallelism_config) const +{ + const int ep_size = parallelism_config.ep_size; + TLLM_CHECK_WITH_INFO(num_experts % ep_size == 0, "Number of experts must be a multiple of tp size"); + auto workspace = getWorkspaceBufferSizes( + num_rows, hidden_size, inter_size, num_experts, num_experts / ep_size, k, activation_type); + return tensorrt_llm::common::calculateTotalWorkspaceSize(workspace.data(), workspace.size()); +} + +template +void CutlassMoeFCRunner::configureWsPtrs(char* ws_ptr, const int num_rows, const int hidden_size, + const int inter_size, const int num_experts, const int num_experts_per_node, const int k, + ActivationType activation_type) +{ + auto workspace = getWorkspaceBufferSizes( + num_rows, hidden_size, inter_size, num_experts, num_experts_per_node, k, activation_type); + + std::vector ws_sliced{(int8_t*) ws_ptr}; + for (auto size : workspace) + { + ws_sliced.push_back(nextWorkspacePtr(ws_sliced.back(), size)); + } + + source_rows_ = (int*) ws_sliced[0]; + permuted_rows_ = (int*) ws_sliced[1]; + permuted_experts_ = (int*) ws_sliced[2]; + permuted_data_ = (T*) ws_sliced[3]; + + total_rows_before_expert_ = (int64_t*) ws_sliced[4]; + + softmax_out_ = nullptr; + const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + if (!is_pow_2 || num_experts > 256) + { + softmax_out_ = (float*) ws_sliced[5]; + } + + glu_inter_result_ = (T*) ws_sliced[6]; + + // These pointers are aliased. 
Since the sort ws can be overwritten after it is finished + sorter_ws_ = (char*) ws_sliced[7]; + fc1_result_ = (T*) ws_sliced[7]; +} + +template +void CutlassMoeFCRunner::runMoe(const void* input_activations_void, const float* gating_output, + const void* fc1_expert_weights_void, const void* fc1_scales_void, const void* fc1_expert_biases_void, + ActivationType fc1_activation_type, const void* fc2_expert_weights_void, const void* fc2_scales_void, + const void* fc2_expert_biases_void, const int num_rows, const int hidden_size, const int inter_size, + const int num_experts, const int k, char* workspace_ptr, void* final_output_void, void* fc2_result_void, + const bool* finished, const int active_rows, void* expert_scales_void, + int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, MOEParallelismConfig parallelism_config, + MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream) +{ + static constexpr bool scales_required + = std::is_same::value || std::is_same::value; + + auto* input_activations = static_cast(input_activations_void); + auto* fc1_expert_weights = static_cast(fc1_expert_weights_void); + auto* fc1_scales = static_cast(fc1_scales_void); + auto* fc1_expert_biases = static_cast(fc1_expert_biases_void); + auto* fc2_expert_weights = static_cast(fc2_expert_weights_void); + auto* fc2_scales = static_cast(fc2_scales_void); + auto* fc2_expert_biases = static_cast(fc2_expert_biases_void); + auto* final_output = static_cast(final_output_void); + auto* fc2_result = static_cast(fc2_result_void); + auto* expert_scales = static_cast(expert_scales_void); + + TLLM_CHECK(input_activations); + TLLM_CHECK(gating_output); + TLLM_CHECK(fc1_expert_weights); + TLLM_CHECK(fc2_expert_weights); + TLLM_CHECK(workspace_ptr); + TLLM_CHECK(fc2_result); + TLLM_CHECK(expert_scales); + TLLM_CHECK(expanded_source_row_to_expanded_dest_row); + TLLM_CHECK(expert_for_source_row); + TLLM_CHECK(num_experts % parallelism_config.ep_size == 0); + + if (scales_required) + { + TLLM_CHECK_WITH_INFO(fc1_scales != nullptr, "Scales expected but scale for first matmul is a null pointer"); + TLLM_CHECK_WITH_INFO(fc2_scales != nullptr, "Scales expected but scale for second matmul is a null pointer"); + } + else + { + TLLM_CHECK_WITH_INFO(fc1_scales == nullptr, "Scales are ignored for fp32/fp16/bf16 but received scale for FC1"); + TLLM_CHECK_WITH_INFO(fc2_scales == nullptr, "Scales are ignored for fp32/fp16/bf16 but received scale for FC2"); + } + + const int num_experts_per_node = num_experts / parallelism_config.ep_size; + const int start_expert = num_experts_per_node * parallelism_config.ep_rank; + const int end_expert = start_expert + num_experts_per_node; + + configureWsPtrs( + workspace_ptr, num_rows, hidden_size, inter_size, num_experts, num_experts_per_node, k, fc1_activation_type); + topkGatingSoftmaxKernelLauncher(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row, + source_rows_, num_rows, num_experts, k, start_expert, end_expert, stream); + + sync_check_cuda_error(); + + sorter_.updateNumExperts(num_experts); + const int sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows, num_experts)); + sorter_.run((void*) sorter_ws_, sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, source_rows_, + permuted_rows_, k * num_rows, stream); + + sync_check_cuda_error(); + + // Upper bound on number of expanded rows + const int expanded_active_expert_rows = k * active_rows; + computeTotalRowsBeforeExpert( + permuted_experts_, 
expanded_active_expert_rows, num_experts_per_node, total_rows_before_expert_, stream); + + sync_check_cuda_error(); + + const bool needs_num_valid = finished || parallelism_config.ep_size > 1; + const int64_t* num_valid_tokens_ptr + = needs_num_valid ? total_rows_before_expert_ + num_experts_per_node - 1 : nullptr; + expandInputRowsKernelLauncher(input_activations, permuted_data_, permuted_rows_, + expanded_source_row_to_expanded_dest_row, num_rows, num_valid_tokens_ptr, hidden_size, k, stream); + + sync_check_cuda_error(); + + if (!isGatedActivation(fc1_activation_type)) + { + moe_gemm_runner_.moeGemmBiasAct(permuted_data_, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_result_, + total_rows_before_expert_, expanded_active_expert_rows, inter_size, hidden_size, num_experts_per_node, + fc1_activation_type, stream); + } + else + { + const size_t fc1_out_size = inter_size * 2; + // Run the GEMM with activation function overridden with `Identity`, we do the activation separately + moe_gemm_runner_.moeGemmBiasAct(permuted_data_, fc1_expert_weights, fc1_scales, fc1_expert_biases, + glu_inter_result_, total_rows_before_expert_, expanded_active_expert_rows, fc1_out_size, hidden_size, + num_experts_per_node, ActivationType::Identity, stream); + + sync_check_cuda_error(); + + doGatedActivation(fc1_result_, glu_inter_result_, num_valid_tokens_ptr, inter_size, num_rows * k, + fc1_activation_type, stream); + } + + sync_check_cuda_error(); + + moe_gemm_runner_.moeGemm(fc1_result_, fc2_expert_weights, fc2_scales, fc2_result, total_rows_before_expert_, + expanded_active_expert_rows, hidden_size, inter_size, num_experts_per_node, stream); + + sync_check_cuda_error(); + + finalizeMoeRoutingKernelLauncher(fc2_result, final_output, + // TODO pass 'skip' connections (residuals) + nullptr, nullptr, fc2_expert_biases, expert_scales, expanded_source_row_to_expanded_dest_row, + expert_for_source_row, num_rows, hidden_size, k, num_valid_tokens_ptr, parallelism_config, normalization_mode, + stream); + + sync_check_cuda_error(); +} + +template +void CutlassMoeFCRunner::computeTotalRowsBeforeExpert(const int* sorted_indices, + const int total_indices, const int num_experts, int64_t* total_rows_before_expert, cudaStream_t stream) +{ + const int threads = std::min(1024, num_experts); + const int blocks = (num_experts + threads - 1) / threads; + + computeTotalRowsBeforeExpertKernel<<>>( + sorted_indices, total_indices, num_experts, total_rows_before_expert); +} + +// ==================== Helper for getting load balanced routing for profiling ================================== + +template +__global__ void initRoutingKernelDiagonal(void* data_void, int num_experts, int num_tokens, int k, int stride) +{ + assert(k == 1 || (stride % num_experts) != 0); + int token = blockIdx.x * blockDim.x + threadIdx.x; + if (token >= num_tokens) + { + return; + } + T* data = (T*) data_void + token * num_experts; + int start = token % num_experts; + for (int i = 0; i < k; i++) + { + data[start] = T{1.f}; + start += stride; + if (start >= num_experts) // Wrap + start -= num_experts; + } +} + +void makeLoadBalancedRoutingConfiguration( + void* data_void, int num_experts, int num_tokens, int k, nvinfer1::DataType type, cudaStream_t stream) +{ + size_t item_size = sizeof(float); + auto* func = &initRoutingKernelDiagonal; + if (type == nvinfer1::DataType::kHALF) + { + func = &initRoutingKernelDiagonal; + item_size = sizeof(half); + } +#ifdef ENABLE_BF16 + else if (type == nvinfer1::DataType::kBF16) + { + func = 
&initRoutingKernelDiagonal<__nv_bfloat16>; + item_size = sizeof(__nv_bfloat16); + } +#endif + + check_cuda_error(cudaMemsetAsync(data_void, 0x0, num_experts * num_tokens * item_size, stream)); + + int stride = tensorrt_llm::common::ceilDiv(num_experts, k); + + int blockDim = 256; + int gridDim = tensorrt_llm::common::ceilDiv(num_tokens, blockDim); + func<<>>(data_void, num_experts, num_tokens, k, stride); + + sync_check_cuda_error(); +} + +// ==================== Variable batched GEMM specializations ================================== +template class CutlassMoeFCRunner; + +#ifdef ENABLE_BF16 +template class CutlassMoeFCRunner<__nv_bfloat16, __nv_bfloat16>; +template class CutlassMoeFCRunner<__nv_bfloat16, uint8_t>; +template class CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>; +#endif + +template class CutlassMoeFCRunner; +template class CutlassMoeFCRunner; +template class CutlassMoeFCRunner; + +} // namespace tensorrt_llm::kernels From cb4524c643d5b533adb9a9b00ee8ca7a6936c672 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 31 Jan 2024 10:29:48 +0000 Subject: [PATCH 04/33] Move moe_kernels --- csrc/{moe => }/moe_kernels.cu | 77 ++++------------------------------- 1 file changed, 8 insertions(+), 69 deletions(-) rename csrc/{moe => }/moe_kernels.cu (95%) diff --git a/csrc/moe/moe_kernels.cu b/csrc/moe_kernels.cu similarity index 95% rename from csrc/moe/moe_kernels.cu rename to csrc/moe_kernels.cu index c42fd5370202..0fde9d429ed3 100644 --- a/csrc/moe/moe_kernels.cu +++ b/csrc/moe_kernels.cu @@ -21,10 +21,6 @@ #include #include -// Ignore CUTLASS warnings about type punning -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" - #include "cutlass/array.h" #include "cutlass/epilogue/thread/activation.h" #include "cutlass/numeric_conversion.h" @@ -36,20 +32,14 @@ #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.h" -#ifndef CUDART_VERSION -#error CUDART_VERSION Undefined! 
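For reference, the load-balanced routing helper above fills a zeroed logits buffer so that token t activates expert t % num_experts and then every ceil(num_experts / k)-th expert after it, wrapping around, which spreads rows evenly across experts when profiling. A host-side sketch of the same pattern (plain C++, illustrative only; it does not reproduce the kernel launch):

#include <vector>

// CPU version of the diagonal routing pattern used for profiling.
std::vector<float> makeDiagonalRouting(int num_experts, int num_tokens, int k) {
    std::vector<float> logits(static_cast<size_t>(num_tokens) * num_experts, 0.f);
    const int stride = (num_experts + k - 1) / k;  // ceilDiv(num_experts, k)
    for (int token = 0; token < num_tokens; ++token) {
        int expert = token % num_experts;
        for (int i = 0; i < k; ++i) {
            logits[static_cast<size_t>(token) * num_experts + expert] = 1.f;
            expert += stride;
            if (expert >= num_experts) expert -= num_experts;  // wrap around
        }
    }
    return logits;
}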
-#elif (CUDART_VERSION >= 11050) #include #include #include -#else -#include "3rdparty/cub/cub.cuh" -#include "3rdparty/cub/device/device_radix_sort.cuh" -#include "3rdparty/cub/util_type.cuh" -#endif -using namespace tensorrt_llm::kernels; -using namespace tensorrt_llm::common; +// FIXME(woosuk) +#ifndef ENABLE_BF16 +#define ENABLE_BF16 +#endif namespace tensorrt_llm::kernels { @@ -1045,68 +1035,17 @@ void CutlassMoeFCRunner::computeTotalRowsBeforeExpert(con sorted_indices, total_indices, num_experts, total_rows_before_expert); } -// ==================== Helper for getting load balanced routing for profiling ================================== - -template -__global__ void initRoutingKernelDiagonal(void* data_void, int num_experts, int num_tokens, int k, int stride) -{ - assert(k == 1 || (stride % num_experts) != 0); - int token = blockIdx.x * blockDim.x + threadIdx.x; - if (token >= num_tokens) - { - return; - } - T* data = (T*) data_void + token * num_experts; - int start = token % num_experts; - for (int i = 0; i < k; i++) - { - data[start] = T{1.f}; - start += stride; - if (start >= num_experts) // Wrap - start -= num_experts; - } -} - -void makeLoadBalancedRoutingConfiguration( - void* data_void, int num_experts, int num_tokens, int k, nvinfer1::DataType type, cudaStream_t stream) -{ - size_t item_size = sizeof(float); - auto* func = &initRoutingKernelDiagonal; - if (type == nvinfer1::DataType::kHALF) - { - func = &initRoutingKernelDiagonal; - item_size = sizeof(half); - } -#ifdef ENABLE_BF16 - else if (type == nvinfer1::DataType::kBF16) - { - func = &initRoutingKernelDiagonal<__nv_bfloat16>; - item_size = sizeof(__nv_bfloat16); - } -#endif - - check_cuda_error(cudaMemsetAsync(data_void, 0x0, num_experts * num_tokens * item_size, stream)); - - int stride = tensorrt_llm::common::ceilDiv(num_experts, k); - - int blockDim = 256; - int gridDim = tensorrt_llm::common::ceilDiv(num_tokens, blockDim); - func<<>>(data_void, num_experts, num_tokens, k, stride); - - sync_check_cuda_error(); -} - // ==================== Variable batched GEMM specializations ================================== template class CutlassMoeFCRunner; #ifdef ENABLE_BF16 template class CutlassMoeFCRunner<__nv_bfloat16, __nv_bfloat16>; -template class CutlassMoeFCRunner<__nv_bfloat16, uint8_t>; -template class CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>; +// template class CutlassMoeFCRunner<__nv_bfloat16, uint8_t>; +// template class CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>; #endif template class CutlassMoeFCRunner; -template class CutlassMoeFCRunner; -template class CutlassMoeFCRunner; +// template class CutlassMoeFCRunner; +// template class CutlassMoeFCRunner; } // namespace tensorrt_llm::kernels From c19120761825e1d66d9eacf3301caf870079c094 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 31 Jan 2024 10:30:25 +0000 Subject: [PATCH 05/33] Port MoE GEMM --- csrc/moe_gemm/moe_gemm_kernels.h | 81 ++++ csrc/moe_gemm/moe_gemm_kernels_bf16_bf16.cu | 24 + csrc/moe_gemm/moe_gemm_kernels_bf16_uint4.cu | 24 + csrc/moe_gemm/moe_gemm_kernels_bf16_uint8.cu | 24 + csrc/moe_gemm/moe_gemm_kernels_fp16_fp16.cu | 22 + csrc/moe_gemm/moe_gemm_kernels_fp16_uint4.cu | 22 + csrc/moe_gemm/moe_gemm_kernels_fp16_uint8.cu | 22 + csrc/moe_gemm/moe_gemm_kernels_fp32_fp32.cu | 22 + csrc/moe_gemm/moe_gemm_kernels_template.h | 440 +++++++++++++++++++ 9 files changed, 681 insertions(+) create mode 100644 csrc/moe_gemm/moe_gemm_kernels.h create mode 100644 csrc/moe_gemm/moe_gemm_kernels_bf16_bf16.cu create mode 100644 
csrc/moe_gemm/moe_gemm_kernels_bf16_uint4.cu create mode 100644 csrc/moe_gemm/moe_gemm_kernels_bf16_uint8.cu create mode 100644 csrc/moe_gemm/moe_gemm_kernels_fp16_fp16.cu create mode 100644 csrc/moe_gemm/moe_gemm_kernels_fp16_uint4.cu create mode 100644 csrc/moe_gemm/moe_gemm_kernels_fp16_uint8.cu create mode 100644 csrc/moe_gemm/moe_gemm_kernels_fp32_fp32.cu create mode 100644 csrc/moe_gemm/moe_gemm_kernels_template.h diff --git a/csrc/moe_gemm/moe_gemm_kernels.h b/csrc/moe_gemm/moe_gemm_kernels.h new file mode 100644 index 000000000000..3f4def7d7152 --- /dev/null +++ b/csrc/moe_gemm/moe_gemm_kernels.h @@ -0,0 +1,81 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h" +#include +#include + +namespace tensorrt_llm +{ + +// Note update moe.py to match +enum class ActivationType +{ + Gelu = 0, + Relu, + Silu, + Swiglu, + Geglu, + Identity, + InvalidType +}; + +constexpr bool isGatedActivation(ActivationType activation_type) +{ + return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu; +} + +template +class MoeGemmRunner +{ +public: + MoeGemmRunner(); + + void setBestConfig(std::optional best_config) + { + best_config_ = std::move(best_config); + } + + void moeGemmBiasAct(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + ActivationType activation_type, cudaStream_t stream); + + void moeGemm(const T* A, const WeightType* B, const T* weight_scales, T* C, int64_t* total_rows_before_expert, + int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream); + + std::vector getConfigs(); + +private: + template + void dispatchToArch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy = nullptr); + + template + void runGemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cudaStream_t stream); + +private: + int sm_; + int multi_processor_count_; + std::optional best_config_{}; +}; + +} // namespace tensorrt_llm diff --git a/csrc/moe_gemm/moe_gemm_kernels_bf16_bf16.cu b/csrc/moe_gemm/moe_gemm_kernels_bf16_bf16.cu new file mode 100644 index 000000000000..42699295b7e8 --- /dev/null +++ b/csrc/moe_gemm/moe_gemm_kernels_bf16_bf16.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 
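Since isGatedActivation declared above singles out Swiglu and Geglu, it may help to spell out what the separate gated-activation pass computes: the first GEMM emits 2 * inter_size values per row, and one half gates the other through the activation before FC2. A rough per-row sketch (SwiGLU shown; which half acts as the gate is an assumption made purely for illustration):

#include <cmath>
#include <vector>

// Per-row sketch of a gated activation: out[i] = fc[i] * silu(gate[i]).
// `fused` is assumed to hold 2 * inter_size values for one row.
std::vector<float> swigluRow(const std::vector<float>& fused, int inter_size) {
    std::vector<float> out(inter_size);
    for (int i = 0; i < inter_size; ++i) {
        const float x = fused[i];               // linear half
        const float g = fused[i + inter_size];  // gate half (assumed layout)
        const float silu = g / (1.f + std::exp(-g));
        out[i] = x * silu;
    }
    return out;
}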
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +#ifdef ENABLE_BF16 +template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16>; +#endif +} // namespace tensorrt_llm diff --git a/csrc/moe_gemm/moe_gemm_kernels_bf16_uint4.cu b/csrc/moe_gemm/moe_gemm_kernels_bf16_uint4.cu new file mode 100644 index 000000000000..b5d129ca91c0 --- /dev/null +++ b/csrc/moe_gemm/moe_gemm_kernels_bf16_uint4.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +#ifdef ENABLE_BF16 +template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t>; +#endif +} // namespace tensorrt_llm diff --git a/csrc/moe_gemm/moe_gemm_kernels_bf16_uint8.cu b/csrc/moe_gemm/moe_gemm_kernels_bf16_uint8.cu new file mode 100644 index 000000000000..174d5a7b907e --- /dev/null +++ b/csrc/moe_gemm/moe_gemm_kernels_bf16_uint8.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +#ifdef ENABLE_BF16 +template class MoeGemmRunner<__nv_bfloat16, uint8_t>; +#endif +} // namespace tensorrt_llm diff --git a/csrc/moe_gemm/moe_gemm_kernels_fp16_fp16.cu b/csrc/moe_gemm/moe_gemm_kernels_fp16_fp16.cu new file mode 100644 index 000000000000..f57d91f9d810 --- /dev/null +++ b/csrc/moe_gemm/moe_gemm_kernels_fp16_fp16.cu @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
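The per-dtype .cu files in this patch all follow the same scheme: the full template definition lives in moe_gemm_kernels_template.h and each translation unit explicitly instantiates MoeGemmRunner for a single (T, WeightType) pair, a common way to bound per-file compile time and let the build parallelize. A generic, self-contained sketch of explicit instantiation (Runner is a hypothetical stand-in, not part of the ported code):

#include <cstdio>

// The definition is emitted once per explicitly instantiated type below,
// analogous to each moe_gemm_kernels_*.cu instantiating one type pair.
template <typename T>
struct Runner {
    void run() { std::printf("instantiated for a %zu-byte type\n", sizeof(T)); }
};

// Explicit instantiations: force code generation in this translation unit.
template struct Runner<float>;
template struct Runner<double>;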
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +template class MoeGemmRunner; +} diff --git a/csrc/moe_gemm/moe_gemm_kernels_fp16_uint4.cu b/csrc/moe_gemm/moe_gemm_kernels_fp16_uint4.cu new file mode 100644 index 000000000000..3f4b0bb718fd --- /dev/null +++ b/csrc/moe_gemm/moe_gemm_kernels_fp16_uint4.cu @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +template class MoeGemmRunner; +} diff --git a/csrc/moe_gemm/moe_gemm_kernels_fp16_uint8.cu b/csrc/moe_gemm/moe_gemm_kernels_fp16_uint8.cu new file mode 100644 index 000000000000..a8d2d5e6c8eb --- /dev/null +++ b/csrc/moe_gemm/moe_gemm_kernels_fp16_uint8.cu @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +template class MoeGemmRunner; +} diff --git a/csrc/moe_gemm/moe_gemm_kernels_fp32_fp32.cu b/csrc/moe_gemm/moe_gemm_kernels_fp32_fp32.cu new file mode 100644 index 000000000000..6b57aae1d844 --- /dev/null +++ b/csrc/moe_gemm/moe_gemm_kernels_fp32_fp32.cu @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +template class MoeGemmRunner; +} diff --git a/csrc/moe_gemm/moe_gemm_kernels_template.h b/csrc/moe_gemm/moe_gemm_kernels_template.h new file mode 100644 index 000000000000..19a480dc8986 --- /dev/null +++ b/csrc/moe_gemm/moe_gemm_kernels_template.h @@ -0,0 +1,440 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Ignore CUTLASS warnings about type punning +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" + +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" + +#include "cutlass_extensions/compute_occupancy.h" +#include "cutlass_extensions/epilogue_helpers.h" +#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" +#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" +#include "cutlass_extensions/gemm/threadblock/default_mma.h" + +#pragma GCC diagnostic pop + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" +#include +#include +#include +#include + +namespace tensorrt_llm +{ + +// ============================= Variable batched Gemm things =========================== +template +void genericMoeGemmKernelLauncher(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, const int multi_processor_count, cudaStream_t stream, + int* kernel_occupancy = nullptr) +{ +#ifdef ENABLE_BF16 + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value + || cutlass::platform::is_same::value, + "Specialized for bfloat16, half, float"); +#else + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for half, float"); +#endif + + static_assert(cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value, + ""); + + // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. 
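As a standalone illustration of the element-type mapping done just below: device-side half (and, when enabled, __nv_bfloat16) is re-expressed as the corresponding CUTLASS numeric type so the GEMM templates accept it, while float passes through unchanged. A reduced sketch using std::conditional (the ported code uses cutlass::platform::conditional and also handles bfloat16):

#include <cuda_fp16.h>
#include <type_traits>
#include "cutlass/half.h"  // cutlass::half_t

// float -> float, half -> cutlass::half_t (bf16 omitted for brevity).
template <typename T>
using CutlassElementOf =
    typename std::conditional<std::is_same<T, half>::value, cutlass::half_t, T>::type;

static_assert(std::is_same<CutlassElementOf<float>, float>::value, "float is unchanged");
static_assert(std::is_same<CutlassElementOf<half>, cutlass::half_t>::value,
              "half maps to cutlass::half_t");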
+ using ElementType_ = + typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; +#ifdef ENABLE_BF16 + using ElementType = + typename cutlass::platform::conditional::value, + cutlass::bfloat16_t, ElementType_>::type; +#else + using ElementType = ElementType_; +#endif + + using CutlassWeightType_ = + typename cutlass::platform::conditional::value, cutlass::half_t, + WeightType>::type; +#ifdef ENABLE_BF16 + using CutlassWeightType = + typename cutlass::platform::conditional::value, + cutlass::bfloat16_t, CutlassWeightType_>::type; +#else + using CutlassWeightType = CutlassWeightType_; +#endif + + // We need separate config for each architecture since we will target different tensorcore instructions. For float, + // we do not target TCs. + using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + + using EpilogueOp = typename tensorrt_llm::cutlass_extensions::Epilogue::Op; + + // Finally, set up the kernel. + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped::GemmKernel; + + using GemmKernel = cutlass::gemm::kernel::MoeFCGemm; + + using GemmGrouped = cutlass::gemm::device::GemmGrouped; + + if (kernel_occupancy != nullptr) + { + *kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel(); + return; + } + int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); + TLLM_CHECK_WITH_INFO(occupancy != 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel"); + const int threadblock_count = multi_processor_count * occupancy; + + typename EpilogueOp::Params epilogue_op( + ElementAccumulator(1.f), biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); + + const int group_size = gemm_k; + typename GemmGrouped::Arguments args(num_experts, threadblock_count, group_size, epilogue_op, + reinterpret_cast(A), reinterpret_cast(B), + reinterpret_cast(weight_scales), reinterpret_cast(biases), + reinterpret_cast(C), total_rows_before_expert, gemm_n, gemm_k); + + GemmGrouped gemm; + + auto can_implement = gemm.can_implement(args); + TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess, + "MoE FC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement))); + + auto init_status = gemm.initialize(args); + TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, + "Failed to initialize cutlass variable batched gemm. Error: " + + std::string(cutlassGetStatusString(init_status))); + + auto run_status = gemm.run(stream); + TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, + "Failed to run cutlass variable batched gemm. Error: " + std::string(cutlassGetStatusString(run_status))); +} + +template +struct dispatch_stages +{ + static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) + { + TLLM_THROW("Cutlass fpA_intB gemm. 
Not instantiated for arch %d with stages set to %d", + arch::kMinComputeCapability, Stages); + } +}; + +template +struct dispatch_stages +{ + static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) + { + genericMoeGemmKernelLauncher(A, B, + weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, stream, occupancy); + } +}; + +template +struct dispatch_stages 2)>::type> +{ + static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) + { + genericMoeGemmKernelLauncher(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + } +}; + +template +void dispatchGemmConfig(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) +{ + switch (gemm_config.stages) + { + case 2: + using DispatcherStages2 = dispatch_stages; + DispatcherStages2::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; + case 3: + using DispatcherStages3 = dispatch_stages; + DispatcherStages3::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; + case 4: + using DispatcherStages4 = dispatch_stages; + DispatcherStages4::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; + default: TLLM_THROW("dispatchGemmConfig does not support stages %d", gemm_config.stages); break; + } +} + +// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32. +// This overload is only enabled when T == WeightType. 
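Stepping back to dispatchGemmConfig above: it is an instance of a common pattern in which a runtime integer (the pipeline stage count from the config) is mapped onto a compile-time template parameter through a switch over the supported values. A minimal, self-contained sketch of the same idea (names are illustrative):

#include <cstdio>
#include <stdexcept>

template <int Stages>
void launchWithStages() {
    // A real implementation would launch the stage-specialized kernel here.
    std::printf("launching a %d-stage pipeline\n", Stages);
}

void dispatchStages(int stages) {
    switch (stages) {
    case 2: launchWithStages<2>(); break;
    case 3: launchWithStages<3>(); break;
    case 4: launchWithStages<4>(); break;
    default: throw std::invalid_argument("unsupported stage count");
    }
}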
+template ::value && std::is_same::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) +{ + switch (gemm_config.tile_config) + { + case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + TLLM_THROW("GEMM config should have already been set by heuristic."); + break; + default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break; + } +} + +// Tensorop GEMM overload +// Overload for quantize MoE GEMMs. 
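A note on the tile names used in the switch above: CtaShapeMxNxK_WarpShapeWmxWnxWk means the threadblock tile M x N x K is covered by (M / Wm) * (N / Wn) warps (the K dimensions match in these configs), which in turn fixes the CTA's thread count. A tiny sketch of that arithmetic:

// Warps per CTA for a threadblock tile (cta_m x cta_n) covered by warp tiles
// (warp_m x warp_n), assuming matching K dimensions.
constexpr int warpsPerCta(int cta_m, int cta_n, int warp_m, int warp_n) {
    return (cta_m / warp_m) * (cta_n / warp_n);
}

static_assert(warpsPerCta(128, 128, 64, 32) == 8, "8 warps, 256 threads per CTA");
static_assert(warpsPerCta(32, 128, 32, 32) == 4, "4 warps, 128 threads per CTA");
static_assert(warpsPerCta(64, 128, 32, 64) == 4, "4 warps, 128 threads per CTA");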
We disable some warp configs here since they will not be used and we can improve +// compile time +template ::value && !std::is_same::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) +{ + switch (gemm_config.tile_config) + { + case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + TLLM_THROW("GEMM config should have already been set by heuristic."); + break; + default: TLLM_THROW("Config is invalid for mixed type tensorop GEMM."); break; + } +} + +// This overload will handle simt gemms. It is disabled via SFINAE for tensorop. 
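The three dispatchMoeGemmToCutlass overloads (same-type tensor-op above, mixed-type tensor-op above, and the fp32 SIMT one that follows) are selected purely by SFINAE on the (T, WeightType) pair. A reduced, self-contained sketch of that selection, with the overload bodies replaced by tags:

#include <type_traits>

// Exactly one overload is enabled for any (T, WeightType) pair.
template <typename T, typename WeightType,
    typename std::enable_if<!std::is_same<T, float>::value
        && std::is_same<T, WeightType>::value>::type* = nullptr>
const char* whichMoeGemmOverload() { return "tensor-op, same input/weight type"; }

template <typename T, typename WeightType,
    typename std::enable_if<!std::is_same<T, float>::value
        && !std::is_same<T, WeightType>::value>::type* = nullptr>
const char* whichMoeGemmOverload() { return "tensor-op, weight-only quantized"; }

template <typename T, typename WeightType,
    typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
const char* whichMoeGemmOverload() { return "SIMT (fp32)"; }

For example, whichMoeGemmOverload<float, float>() resolves to the SIMT overload, while a mixed pair resolves to the weight-only quantized one.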
+template ::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) +{ + switch (gemm_config.tile_config) + { + case cutlass_extensions::CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + TLLM_THROW("GEMM config should have already been set by heuristic."); + break; + default: TLLM_THROW("Unsupported config for float MoE gemm."); break; + } +} + +template +std::vector MoeGemmRunner::getConfigs() +{ + static constexpr bool is_weight_only = !std::is_same::value; + static constexpr bool only_simt_configs = std::is_same::value; + std::vector candidate_configs + = kernels::cutlass_kernels::get_candidate_configs(sm_, is_weight_only, only_simt_configs); + return candidate_configs; +} + +template +MoeGemmRunner::MoeGemmRunner() +{ + int device{-1}; + tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device)); + sm_ = tensorrt_llm::common::getSMVersion(); + tensorrt_llm::common::check_cuda_error( + cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); +} + +template +template +void MoeGemmRunner::dispatchToArch(const T* A, const WeightType* B, const T* weight_scales, + const T* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy) +{ + if (sm_ >= 70 && sm_ < 75) + { + dispatchMoeGemmToCutlass(A, B, weight_scales, biases, C, + total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, + stream, occupancy); + } + else if (sm_ >= 75 && sm_ < 80) + { + dispatchMoeGemmToCutlass(A, B, weight_scales, biases, C, + total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, + stream, occupancy); + } + else if (sm_ >= 80 && sm_ < 90) + { + dispatchMoeGemmToCutlass(A, B, weight_scales, biases, C, + total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, + stream, occupancy); + } + else if (sm_ >= 90) + { + // TODO Update the arch to Sm90 once CUTLASS hopper specialisations are available + dispatchMoeGemmToCutlass(A, B, weight_scales, biases, C, + total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, + stream, occupancy); + } + else + { + TLLM_THROW("Arch unsupported for MoE GEMM"); + } +} + +template +template +void MoeGemmRunner::runGemm(const T* A, const WeightType* B, const T* weight_scales, + const T* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, cudaStream_t stream) +{ + auto chosen_conf = this->best_config_; + if (!chosen_conf) + { + auto candidate_configs = getConfigs(); + std::vector 
occupancies(candidate_configs.size()); + + for (size_t ii = 0; ii < candidate_configs.size(); ++ii) + { + dispatchToArch(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, + gemm_k, num_experts, candidate_configs[ii], stream, &occupancies[ii]); + } + + static constexpr int workspace_bytes = 0; // No workspace for MoE GEMMs. + static constexpr int split_k_limit = 1; // MoE GEMM does not support split-k. + + static constexpr bool is_weight_only = !std::is_same::value; + chosen_conf = kernels::cutlass_kernels::estimate_best_config_from_occupancies(candidate_configs, occupancies, + total_rows, gemm_n, gemm_k, num_experts, split_k_limit, workspace_bytes, multi_processor_count_, + is_weight_only); + } + assert(chosen_conf); + dispatchToArch(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, *chosen_conf, stream); +} + +template +void MoeGemmRunner::moeGemmBiasAct(const T* A, const WeightType* B, const T* weight_scales, + const T* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, ActivationType activation_type, cudaStream_t stream) +{ + switch (activation_type) + { + case ActivationType::Relu: + runGemm( + A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream); + break; + case ActivationType::Gelu: + runGemm( + A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream); + break; + case ActivationType::Silu: + runGemm( + A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream); + break; + case ActivationType::Identity: + runGemm( + A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream); + break; + case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break; + default: TLLM_THROW("Invalid activation type."); break; + } +} + +template +void MoeGemmRunner::moeGemm(const T* A, const WeightType* B, const T* weight_scales, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cudaStream_t stream) +{ + runGemm( + A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream); +} + +} // namespace tensorrt_llm From cfa4554708befedfe6cdba0af37310e6244e1895 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 31 Jan 2024 10:32:29 +0000 Subject: [PATCH 06/33] Port CUTLASS kernels --- csrc/cutlass_kernels/cutlass_heuristic.cpp | 286 +++++++ csrc/cutlass_kernels/cutlass_heuristic.h | 40 + .../cutlass_kernels/cutlass_preprocessors.cpp | 761 ++++++++++++++++++ csrc/cutlass_kernels/cutlass_preprocessors.h | 64 ++ .../bf16_int4_gemm_fg_scalebias.cu | 31 + .../bf16_int4_gemm_fg_scaleonly.cu | 31 + .../fpA_intB_gemm/bf16_int4_gemm_per_col.cu | 31 + .../bf16_int8_gemm_fg_scalebias.cu | 31 + .../bf16_int8_gemm_fg_scaleonly.cu | 30 + .../fpA_intB_gemm/bf16_int8_gemm_per_col.cu | 30 + .../fp16_int4_gemm_fg_scalebias.cu | 29 + .../fp16_int4_gemm_fg_scaleonly.cu | 28 + .../fpA_intB_gemm/fp16_int4_gemm_per_col.cu | 28 + .../fp16_int8_gemm_fg_scalebias.cu | 28 + .../fp16_int8_gemm_fg_scaleonly.cu | 28 + .../fpA_intB_gemm/fp16_int8_gemm_per_col.cu | 28 + .../fpA_intB_gemm/fpA_intB_gemm.h | 120 +++ .../fpA_intB_gemm/fpA_intB_gemm_template.h | 487 +++++++++++ csrc/cutlass_kernels/int8_gemm/int8_gemm.h | 93 +++ .../int8_gemm/int8_gemm_bf16.cu 
| 32 + .../int8_gemm/int8_gemm_fp16.cu | 30 + .../int8_gemm/int8_gemm_fp32.cu | 30 + .../int8_gemm/int8_gemm_int32.cu | 30 + .../int8_gemm/int8_gemm_template.h | 388 +++++++++ .../moe_gemm/moe_gemm_kernels.h | 81 ++ .../moe_gemm/moe_gemm_kernels_bf16_bf16.cu | 24 + .../moe_gemm/moe_gemm_kernels_bf16_uint4.cu | 24 + .../moe_gemm/moe_gemm_kernels_bf16_uint8.cu | 24 + .../moe_gemm/moe_gemm_kernels_fp16_fp16.cu | 22 + .../moe_gemm/moe_gemm_kernels_fp16_uint4.cu | 22 + .../moe_gemm/moe_gemm_kernels_fp16_uint8.cu | 22 + .../moe_gemm/moe_gemm_kernels_fp32_fp32.cu | 22 + .../moe_gemm/moe_gemm_kernels_template.h | 440 ++++++++++ 33 files changed, 3395 insertions(+) create mode 100644 csrc/cutlass_kernels/cutlass_heuristic.cpp create mode 100644 csrc/cutlass_kernels/cutlass_heuristic.h create mode 100644 csrc/cutlass_kernels/cutlass_preprocessors.cpp create mode 100644 csrc/cutlass_kernels/cutlass_preprocessors.h create mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_fg_scalebias.cu create mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_fg_scaleonly.cu create mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_per_col.cu create mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_fg_scalebias.cu create mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_fg_scaleonly.cu create mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_per_col.cu create mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_fg_scalebias.cu create mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_fg_scaleonly.cu create mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_per_col.cu create mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_fg_scalebias.cu create mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_fg_scaleonly.cu create mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_per_col.cu create mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h create mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h create mode 100644 csrc/cutlass_kernels/int8_gemm/int8_gemm.h create mode 100644 csrc/cutlass_kernels/int8_gemm/int8_gemm_bf16.cu create mode 100644 csrc/cutlass_kernels/int8_gemm/int8_gemm_fp16.cu create mode 100644 csrc/cutlass_kernels/int8_gemm/int8_gemm_fp32.cu create mode 100644 csrc/cutlass_kernels/int8_gemm/int8_gemm_int32.cu create mode 100644 csrc/cutlass_kernels/int8_gemm/int8_gemm_template.h create mode 100644 csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels.h create mode 100644 csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu create mode 100644 csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu create mode 100644 csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu create mode 100644 csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu create mode 100644 csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu create mode 100644 csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu create mode 100644 csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu create mode 100644 csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h diff --git a/csrc/cutlass_kernels/cutlass_heuristic.cpp b/csrc/cutlass_kernels/cutlass_heuristic.cpp new file mode 100644 index 000000000000..db77569e374d --- /dev/null +++ b/csrc/cutlass_kernels/cutlass_heuristic.cpp @@ -0,0 +1,286 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" +#include "tensorrt_llm/common/cudaBf16Wrapper.h" + +#ifndef _WIN32 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // #ifndef _WIN32 + +#include "cutlass/gemm/gemm.h" +#include "cutlass/numeric_types.h" + +#ifndef _WIN32 +#pragma GCC diagnostic pop +#endif // #ifndef _WIN32 + +#include +#include + +using namespace tensorrt_llm::cutlass_extensions; + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +struct TileShape +{ + int m; + int n; +}; + +TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) +{ + switch (tile_config) + { + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: return TileShape{32, 128}; + case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: return TileShape{64, 64}; + case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: return TileShape{64, 128}; + case CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: return TileShape{128, 64}; + case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64: + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: return TileShape{128, 128}; + case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: return TileShape{128, 256}; + case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: return TileShape{256, 128}; + default: throw std::runtime_error("[TensorRT-LLm Error][get_grid_shape_for_config] Invalid config"); + } +} + +bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, const TileShape tile_shape, + const int split_k_factor, const size_t workspace_bytes, const bool is_weight_only) +{ + + // All tile sizes have a k_tile of 64. + static constexpr int k_tile = 64; + + // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k + if (is_weight_only) + { + if ((k % k_tile) != 0) + { + return false; + } + + if ((k % split_k_factor) != 0) + { + return false; + } + + const int k_elements_per_split = k / split_k_factor; + if ((k_elements_per_split % k_tile) != 0) + { + return false; + } + } + + // Check that the workspace has sufficient space for this split-k factor + const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; + const int required_ws_bytes = split_k_factor == 1 ? 
0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; + + if (required_ws_bytes > workspace_bytes) + { + return false; + } + + return true; +} + +std::vector get_candidate_tiles( + const int sm, const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only) +{ + enum class CutlassGemmType : char + { + Default, + WeightOnly, + Simt, + Int8 + }; + + CutlassGemmType gemm_type = CutlassGemmType::Default; + if (simt_configs_only) + { + gemm_type = CutlassGemmType::Simt; + } + else if (is_weight_only) + { + gemm_type = CutlassGemmType::WeightOnly; + } + else if (int8_configs_only) + { + gemm_type = CutlassGemmType::Int8; + } + + std::vector base_configs{ + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64}; + if (sm >= 75) + { + base_configs.push_back(CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64); + } + + switch (gemm_type) + { + case CutlassGemmType::Simt: return {CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; + case CutlassGemmType::WeightOnly: + if (sm >= 75) + { + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64}; + } + else + { + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64}; + } + case CutlassGemmType::Int8: + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; + default: return base_configs; + } +} + +std::vector get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only, + const bool int8_configs_only, const int max_split_k) +{ + std::vector tiles + = get_candidate_tiles(sm, is_weight_only, simt_configs_only, int8_configs_only); + + std::vector candidate_configs; + const int min_stages = int8_configs_only ? 3 : 2; + const int max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2); + for (const auto& tile_config : tiles) + { + for (int stages = min_stages; stages <= max_stages; ++stages) + { + CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages}; + candidate_configs.push_back(config); + if (sm >= 75) + { + for (int split_k_factor = 2; split_k_factor <= max_split_k; ++split_k_factor) + { + auto config = CutlassGemmConfig{tile_config, SplitKStyle::SPLIT_K_SERIAL, split_k_factor, stages}; + candidate_configs.push_back(config); + } + } + } + } + + return candidate_configs; +} + +CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector& candidate_configs, + const std::vector& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts, + const int split_k_limit, const size_t workspace_bytes, const int multi_processor_count, const int is_weight_only) +{ + + if (occupancies.size() != candidate_configs.size()) + { + throw std::runtime_error( + "[TensorRT-LLm Error][estimate_best_config_from_occupancies] occpancies and " + "candidate configs vectors must have equal length."); + } + + CutlassGemmConfig best_config; + // Score will be [0, 1]. The objective is to minimize this score. + // It represents the fraction of SM resources unused in the last wave. 
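To make the scoring below concrete: for a given tile shape and split-k factor the heuristic counts how many CTAs the problem needs, how many fit in one wave (occupancy times SM count), and scores the config by how much of the final wave would sit idle. A standalone sketch of that score (0 means the last wave is exactly full):

// Wave-quantization score: ceil(ctas / ctas_per_wave) - ctas / ctas_per_wave.
float waveQuantizationScore(int ctas_in_m, int ctas_in_n, int split_k_factor,
                            int occupancy_per_sm, int multi_processor_count) {
    const int ctas_for_problem = ctas_in_m * ctas_in_n * split_k_factor;
    const int ctas_per_wave = occupancy_per_sm * multi_processor_count;
    const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
    const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave);
    return float(num_waves_total) - num_waves_fractional;
}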
+ float config_score = 1.0f; + int config_waves = INT_MAX; + int current_m_tile = 0; + + const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit; + for (int ii = 0; ii < candidate_configs.size(); ++ii) + { + CutlassGemmConfig candidate_config = candidate_configs[ii]; + TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config); + int occupancy = occupancies[ii]; + + if (occupancy == 0) + { + continue; + } + + // Keep small tile sizes when possible. + if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile + && current_m_tile < tile_shape.m) + { + continue; + } + + const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; + + for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) + { + if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) + { + const int ctas_per_wave = occupancy * multi_processor_count; + const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor; + + const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; + const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave); + const float current_score = float(num_waves_total) - num_waves_fractional; + + const float score_slack = 0.1f; + if (current_score < config_score + || ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) + { + config_score = current_score; + config_waves = num_waves_total; + SplitKStyle split_style + = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; + best_config = CutlassGemmConfig{ + candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages}; + current_m_tile = tile_shape.m; + } + else if (current_score == config_score + && (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor + || current_m_tile < tile_shape.m)) + { + // Prefer deeper pipeline or smaller split-k + SplitKStyle split_style + = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; + best_config = CutlassGemmConfig{ + candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages}; + current_m_tile = tile_shape.m; + config_waves = num_waves_total; + } + } + } + } + + if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) + { + throw std::runtime_error("[TensorRT-LLm Error] Heurisitc failed to find a valid config."); + } + + return best_config; +} + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/cutlass_heuristic.h b/csrc/cutlass_kernels/cutlass_heuristic.h new file mode 100644 index 000000000000..071998406be2 --- /dev/null +++ b/csrc/cutlass_kernels/cutlass_heuristic.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "cutlass_extensions/gemm_configs.h" +#include "tensorrt_llm/common/cudaUtils.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +std::vector get_candidate_configs(int sm, + const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only = false, + const int max_split_k = 1); + +tensorrt_llm::cutlass_extensions::CutlassGemmConfig estimate_best_config_from_occupancies( + const std::vector& candidate_configs, + const std::vector& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts, + const int split_k_limit, const size_t workspace_bytes, const int multi_processor_count, const int is_weight_only); + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/cutlass_preprocessors.cpp b/csrc/cutlass_kernels/cutlass_preprocessors.cpp new file mode 100644 index 000000000000..24d9af815915 --- /dev/null +++ b/csrc/cutlass_kernels/cutlass_preprocessors.cpp @@ -0,0 +1,761 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h" +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaBf16Wrapper.h" +#include "tensorrt_llm/common/stringUtils.h" + +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" + +using namespace tensorrt_llm::common; + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +int get_bits_in_quant_type(QuantType quant_type) +{ + switch (quant_type) + { + case QuantType::INT8_WEIGHT_ONLY: return 8; + case QuantType::PACKED_INT4_WEIGHT_ONLY: return 4; + default: TLLM_CHECK_WITH_INFO(false, "Invalid quant_type"); return -1; + } +} + +struct LayoutDetails +{ + enum class Layout + { + UNKNOWN, + ROW_MAJOR, + COLUMN_MAJOR + }; + + Layout layoutB = Layout::UNKNOWN; + int rows_per_column_tile = 1; + int columns_interleaved = 1; + + bool uses_imma_ldsm = false; +}; + +template +struct getLayoutDetails +{ +}; + +template <> +struct getLayoutDetails +{ + LayoutDetails operator()() + { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::ROW_MAJOR; + return layout_details; + } +}; + +template <> +struct getLayoutDetails +{ + LayoutDetails operator()() + { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; + return layout_details; + } +}; + +template +struct getLayoutDetails> +{ + LayoutDetails operator()() + { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; + layout_details.rows_per_column_tile = RowsPerTile; + layout_details.columns_interleaved = ColumnsInterleaved; + return layout_details; + } +}; + +template +LayoutDetails getLayoutDetailsForArchAndQuantType() +{ + + using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB; + using LayoutB = typename CompileTraits::Layout; + 
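The getLayoutDetails specializations above follow a small but useful pattern: each compile-time layout type (row-major, column-major, column-major with tile interleaving) is mapped to a plain runtime descriptor that the preprocessing code can branch on. A generic sketch with placeholder tag types (not the actual CUTLASS layouts):

// Placeholder layout tags standing in for the CUTLASS layout types.
struct RowMajorTag {};
struct ColumnMajorTag {};
template <int RowsPerTile, int ColumnsInterleaved> struct ColumnMajorTileInterleaveTag {};

struct LayoutDescriptor {
    bool row_major;
    int rows_per_column_tile;
    int columns_interleaved;
};

template <typename Layout> struct DescribeLayout;  // primary template left undefined

template <> struct DescribeLayout<RowMajorTag> {
    LayoutDescriptor operator()() const { return {true, 1, 1}; }
};
template <> struct DescribeLayout<ColumnMajorTag> {
    LayoutDescriptor operator()() const { return {false, 1, 1}; }
};
template <int R, int C> struct DescribeLayout<ColumnMajorTileInterleaveTag<R, C>> {
    LayoutDescriptor operator()() const { return {false, R, C}; }
};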
using MmaOperator = typename CompileTraits::Operator; + LayoutDetails details = getLayoutDetails()(); + details.uses_imma_ldsm = std::is_same::value; + return details; +} + +template +LayoutDetails getLayoutDetailsForArch(QuantType quant_type) +{ + LayoutDetails details; + if (quant_type == QuantType::INT8_WEIGHT_ONLY) + { + details = getLayoutDetailsForArchAndQuantType(); + } + else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) + { + details = getLayoutDetailsForArchAndQuantType(); + } + else + { + TLLM_CHECK_WITH_INFO(false, "Unsupported quantization type"); + } + return details; +} + +LayoutDetails getLayoutDetailsForTransform(QuantType quant_type) +{ + const int arch = getSMVersion(); + if (arch >= 70 && arch < 75) + { + return getLayoutDetailsForArch(quant_type); + } + else if (arch >= 75 && arch < 80) + { + return getLayoutDetailsForArch(quant_type); + } + else if (arch >= 80 && arch <= 90) + { + return getLayoutDetailsForArch(quant_type); + } + else + { + TLLM_CHECK_WITH_INFO(false, "Unsupported Arch"); + return LayoutDetails(); + } +} + +// Permutes the rows of B for Turing and Ampere. Throws an error for other architectures. +// The data is permuted such that: +// For int8, each group of 16 rows is permuted using the map below: +// 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15 +// For int4, each group of 32 rows is permuted using the map below: +// 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31 +void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8_t* quantized_tensor, + const std::vector& shape, QuantType quant_type, const int64_t arch_version) +{ + + // We only want to run this step for weight only quant. + TLLM_CHECK(quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY || quant_type == QuantType::INT8_WEIGHT_ONLY); + + TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + const int BITS_PER_ELT = get_bits_in_quant_type(quant_type); + const int K = 16 / BITS_PER_ELT; + const int ELTS_PER_BYTE = 8 / BITS_PER_ELT; + const int ELTS_PER_REG = 32 / BITS_PER_ELT; + + const uint32_t* input_byte_ptr = reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = reinterpret_cast(permuted_quantized_tensor); + + int MMA_SHAPE_N = 8; + int B_ROWS_PER_MMA = 8 * K; + const int elts_in_int32 = 32 / BITS_PER_ELT; + + const int num_vec_cols = num_cols / elts_in_int32; + + TLLM_CHECK_WITH_INFO( + arch_version >= 75, "Unsupported Arch. Pre-volta not supported. Column interleave not needed on Volta."); + + TLLM_CHECK_WITH_INFO(num_rows % B_ROWS_PER_MMA == 0, + fmtstr("Invalid shape for quantized tensor. Number of rows of quantized matrix must be a multiple of %d", + B_ROWS_PER_MMA)); + TLLM_CHECK_WITH_INFO(num_cols % MMA_SHAPE_N == 0, + fmtstr("Invalid shape for quantized tensor. On turing/Ampere, the number of cols must be a multiple of %d.", + MMA_SHAPE_N)); + + // The code is written as below so it works for both int8 and packed int4. 
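The permutation maps quoted in the comment above can be regenerated from the same index arithmetic the loop below uses; running this sketch with 8-bit elements prints 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15, and with 4-bit elements it prints the 32-entry map.

#include <cstdio>

// Prints which source row feeds each destination row inside one MMA tile,
// for a given element width (8 for int8, 4 for packed int4).
void printRowPermutation(int bits_per_elt) {
    const int K = 16 / bits_per_elt;
    const int ELTS_PER_REG = 32 / bits_per_elt;
    const int B_ROWS_PER_MMA = 8 * K;
    for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) {
        const int tile_read_row =
            8 * ((tile_row % ELTS_PER_REG) / 2) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG);
        std::printf("%d ", tile_read_row);
    }
    std::printf("\n");
}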
+ for (int expert = 0; expert < num_experts; ++expert) + { + const int64_t matrix_offset = expert * int64_t(num_rows) * int64_t(num_vec_cols); + for (int base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) + { + for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) + { + + for (int write_col = 0; write_col < num_vec_cols; ++write_col) + { + const int write_row = base_row + tile_row; + const int tile_read_row + = 8 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG); + const int read_row = base_row + tile_read_row; + const int read_col = write_col; + + const int64_t read_offset = matrix_offset + int64_t(read_row) * num_vec_cols + read_col; + const int64_t write_offset = matrix_offset + int64_t(write_row) * num_vec_cols + write_col; + + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } + } + } + } +} + +// We need to use this transpose to correctly handle packed int4 and int8 data +// The reason this code is relatively complex is that the "trivial" loops took a substantial +// amount of time to transpose leading to long preprocessing times. This seemed to be a big +// issue for relatively large models. +template +void subbyte_transpose_impl( + int8_t* transposed_quantized_tensor, const int8_t* quantized_tensor, const std::vector& shape) +{ + const int bits_per_elt = get_bits_in_quant_type(quant_type); + + TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + const size_t col_bytes = num_cols * bits_per_elt / 8; + const size_t col_bytes_trans = num_rows * bits_per_elt / 8; + const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes; + + const uint8_t* input_byte_ptr = reinterpret_cast(quantized_tensor); + uint8_t* output_byte_ptr = reinterpret_cast(transposed_quantized_tensor); + + static_assert(quant_type == QuantType::INT8_WEIGHT_ONLY || quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY, ""); + static constexpr int ELTS_PER_BYTE = quant_type == QuantType::INT8_WEIGHT_ONLY ? 1 : 2; + + static constexpr int M_TILE_L1 = 64; + static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE; + uint8_t cache_buf[M_TILE_L1][N_TILE_L1]; + + static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1); + + // We assume the dims are a multiple of vector width. Our kernels only handle dims which are multiples + // of 64 for weight-only quantization. As a result, this seemed like a reasonable tradeoff because it + // allows GCC to emit vector instructions. + TLLM_CHECK_WITH_INFO(!(col_bytes_trans % VECTOR_WIDTH) && !(col_bytes % VECTOR_WIDTH), + fmtstr("Number of bytes for rows and cols must be a multiple of %d. 
However, num_rows_bytes = %ld and " + "num_col_bytes = %ld.", + VECTOR_WIDTH, col_bytes_trans, col_bytes)); + + const int num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1; + const int num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1; + + for (size_t expert = 0; expert < num_experts; ++expert) + { + const size_t matrix_offset = expert * num_rows * col_bytes; + for (size_t row_tile_start = 0; row_tile_start < num_rows; row_tile_start += M_TILE_L1) + { + for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; col_tile_start_byte += N_TILE_L1) + { + + const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); + const int col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes); + + for (int ii = 0; ii < M_TILE_L1; ++ii) + { + const int row = row_tile_start + ii; + + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) + { + const int col = col_tile_start_byte + jj; + + const size_t logical_src_offset = matrix_offset + row * col_bytes + col; + + if (row < row_limit && col < col_limit) + { + for (int v = 0; v < VECTOR_WIDTH; ++v) + { + cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v]; + } + } + } + } + + if (quant_type == QuantType::INT8_WEIGHT_ONLY) + { + for (int ii = 0; ii < M_TILE_L1; ++ii) + { + for (int jj = ii + 1; jj < N_TILE_L1; ++jj) + { + std::swap(cache_buf[ii][jj], cache_buf[jj][ii]); + } + } + } + else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) + { + + for (int ii = 0; ii < M_TILE_L1; ++ii) + { + // Using M_TILE_L1 here is deliberate since we assume that the cache tile + // is square in the number of elements (not necessarily the number of bytes). + for (int jj = ii + 1; jj < M_TILE_L1; ++jj) + { + const int ii_byte = ii / ELTS_PER_BYTE; + const int ii_bit_offset = ii % ELTS_PER_BYTE; + + const int jj_byte = jj / ELTS_PER_BYTE; + const int jj_bit_offset = jj % ELTS_PER_BYTE; + + uint8_t src_elt = 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset)); + uint8_t tgt_elt = 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset)); + + cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset)); + + cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset)); + } + } + } + else + { + TLLM_CHECK_WITH_INFO(false, "Unsupported quantization type."); + } + + const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE; + const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE; + + const int row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols); + const int col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); + + for (int ii = 0; ii < M_TILE_L1; ++ii) + { + const int row = row_tile_start_trans + ii; + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) + { + const int col = col_tile_start_byte_trans + jj; + + const size_t logical_tgt_offset = matrix_offset + row * col_bytes_trans + col; + + if (row < row_limit_trans && col < col_limit_trans) + { + for (int v = 0; v < VECTOR_WIDTH; ++v) + { + output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v]; + } + } + } + } + } + } + } +} + +void subbyte_transpose(int8_t* transposed_quantized_tensor, const int8_t* quantized_tensor, + const std::vector& shape, QuantType quant_type) +{ + + if (quant_type == QuantType::INT8_WEIGHT_ONLY) + { + subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); + } + else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) + { + 
subbyte_transpose_impl( + transposed_quantized_tensor, quantized_tensor, shape); + } + else + { + TLLM_CHECK_WITH_INFO(false, "Invalid quant_tye"); + } +} + +void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor, const size_t num_elts) +{ + for (int ii = 0; ii < num_elts; ++ii) + { + int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128); + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to match the int4 layout. This has no + // performance benefit and is purely so that int4 and int8 have the same layout. + // Pictorially, this does the following: + // bit 32 0 + // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) + + TLLM_CHECK_WITH_INFO(num_elts % 4 == 0, "Dimensions of int8 tensor must be a multiple of 4 for register relayout"); + for (size_t base = 0; base < num_elts; base += 4) + { + std::swap(int8_tensor[base + 1], int8_tensor[base + 2]); + } +} + +void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const size_t num_elts) +{ + const int num_bytes = num_elts / 2; + + // Step 1 will be to transform all the int4s to unsigned in order to make the dequantize take as little + // instructions as possible in the CUDA code. + for (size_t ii = 0; ii < num_bytes; ++ii) + { + int8_t transformed_packed_int4s = 0; + int8_t transformed_first_elt + = (int8_t(packed_int4_tensor[ii] << 4) >> 4) + 8; // The double shift here is to ensure sign extension + int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4) + 8; + + TLLM_CHECK_WITH_INFO( + transformed_first_elt >= 0 && transformed_first_elt <= 15, "Illegal result for int4 transform (first elt)"); + TLLM_CHECK_WITH_INFO(transformed_second_elt >= 0 && transformed_second_elt <= 15, + "Illegal result for int4 transform (second elt)"); + + // We don't need to mask in these ops since everything should be in the range 0-15 + transformed_packed_int4s |= transformed_first_elt; + transformed_packed_int4s |= (transformed_second_elt << 4); + packed_int4_tensor[ii] = transformed_packed_int4s; + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to minimize the number of shift & logical + // instructions That are needed to extract the int4s in the GEMM main loop. Pictorially, the loop below will do the + // following: Take as input a 32 bit register with layout: bit 32 0 + // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 4 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits) + + TLLM_CHECK_WITH_INFO(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a multiple of 8 for register relayout"); + const size_t num_registers = num_bytes / 4; + + uint32_t* register_ptr = reinterpret_cast(packed_int4_tensor); + for (size_t ii = 0; ii < num_registers; ++ii) + { + const uint32_t current_register = register_ptr[ii]; + uint32_t transformed_register = 0; + + for (int dest_idx = 0; dest_idx < 8; ++dest_idx) + { + const int src_idx = dest_idx < 4 ? 
2 * dest_idx : 2 * (dest_idx - 4) + 1; + const int src_shift = 4 * src_idx; + const int dest_shift = 4 * dest_idx; + + const uint32_t src_bits = (current_register >> src_shift) & 0xF; + transformed_register |= (src_bits << dest_shift); + } + register_ptr[ii] = transformed_register; + } +} + +void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type) +{ + if (quant_type == QuantType::INT8_WEIGHT_ONLY) + { + add_bias_and_interleave_int8s_inplace(tensor, num_elts); + } + else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) + { + add_bias_and_interleave_int4s_inplace(tensor, num_elts); + } + else + { + TLLM_CHECK_WITH_INFO(false, "Invalid quantization type for interleaving."); + } +} + +void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const int8_t* quantized_tensor, + const std::vector& shape, QuantType quant_type, LayoutDetails details) +{ + + // We only want to run this step for weight only quant. + TLLM_CHECK(quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY || quant_type == QuantType::INT8_WEIGHT_ONLY); + + TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + const int BITS_PER_ELT = get_bits_in_quant_type(quant_type); + const int elts_in_int32 = 32 / BITS_PER_ELT; + + const int rows_per_tile = details.rows_per_column_tile; + + TLLM_CHECK_WITH_INFO(!(num_rows % elts_in_int32), + fmtstr("The number of rows must be a multiple of %d but the number of rows is %ld.", elts_in_int32, num_rows)); + + const uint32_t* input_byte_ptr = reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = reinterpret_cast(interleaved_quantized_tensor); + + TLLM_CHECK_WITH_INFO(!(num_rows % rows_per_tile), + fmtstr("The number of rows must be a multiple of %d but the number of rows is %ld.", rows_per_tile, num_rows)); + + const int num_vec_rows = num_rows / elts_in_int32; + const int vec_rows_per_tile = rows_per_tile / elts_in_int32; + const int interleave = details.columns_interleaved; + + for (int expert = 0; expert < num_experts; ++expert) + { + const int64_t matrix_offset = expert * int64_t(num_vec_rows) * int64_t(num_cols); + for (int read_col = 0; read_col < num_cols; ++read_col) + { + const int64_t write_col = read_col / interleave; + for (int base_vec_row = 0; base_vec_row < num_vec_rows; base_vec_row += vec_rows_per_tile) + { + for (int vec_read_row = base_vec_row; + vec_read_row < std::min(num_vec_rows, base_vec_row + vec_rows_per_tile); ++vec_read_row) + { + const int64_t vec_write_row = interleave * base_vec_row + + vec_rows_per_tile * (read_col % interleave) + vec_read_row % vec_rows_per_tile; + + const int64_t read_offset = matrix_offset + int64_t(read_col) * num_vec_rows + vec_read_row; + const int64_t write_offset + = matrix_offset + int64_t(write_col) * num_vec_rows * interleave + vec_write_row; + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } + } + } + } +} + +void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight, + const std::vector& shape, QuantType quant_type) +{ + LayoutDetails details = getLayoutDetailsForTransform(quant_type); + + TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + + size_t num_elts = 1; + for (const auto& dim : shape) + { + 
num_elts *= dim; + } + + const size_t num_bytes = num_elts * get_bits_in_quant_type(quant_type) / 8; + + std::vector src_buf(num_bytes); + std::vector dst_buf(num_bytes); + std::copy(row_major_quantized_weight, row_major_quantized_weight + num_bytes, src_buf.begin()); + + // Works on row major data, so issue this permutation first. + if (details.uses_imma_ldsm) + { + const int arch = getSMVersion(); + permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type, arch); + src_buf.swap(dst_buf); + } + + if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR) + { + subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type); + src_buf.swap(dst_buf); + } + + if (details.columns_interleaved > 1) + { + interleave_column_major_tensor(dst_buf.data(), src_buf.data(), shape, quant_type, details); + src_buf.swap(dst_buf); + } + + add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type); + std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight); +} + +/* + Arguments: + input_weight_ptr - the weight tensor to be quantized. Must be 2-D or 3-D and of type FP16. + + quant_type - the type of the output quantization weight. + + This function does symmetric quantization on 2-D or 3-D tensors. It uses the full int range and assumes the + zero-point is zero and will automatically construct the scales. + + It always quantizes the last axis of the tensor. For 3-D tensors, it operates in "batched" mode where the tensor is + viewed as a stack of matrices and a scale is produced for each column of every matrix. + +Outputs + processed_quantized_weight - quantized AND processed weight for GEMM. This MUST be used with the CUTLASS GEMM + unprocessed_quantized_weight - quantized but unprocessed weights. Useful for reference checking. + scale_ptr - scales for the quantized weight. + + Note that the returned quantized_weights will be preprocessed in a way to accelerate the mixed type GEMM. The data + layout may not make sense if printed. + + Shapes: + quant_type == int8: + If weight is a [m,n] matrix, quantized_weights will have shape [m,n] and scales of shape [n] + If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m,n] and scales of shape [b,n] + quant_type == int4: + If weight is a [m,n] matrix, quantized_weights will have shape [m, ceil(n/2)] and scales of shape [n] + If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m, ceil(n/2)] and scales of shape + [b,n] + + The quantized_weight will be of type torch.int8 and have two int4 values packed in a single byte. This is the + reason for halving the shape. At the time of writing this code, there was not an elegant way to handle this kind + of batched quantization using torch's quantized tensors (to the best of the author's knowledge). Scale tensors + must have a dimension of 1, which breaks the semantics we need for batched weights. + */ + +template +void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, + ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector& shape, QuantType quant_type) +{ + + TLLM_CHECK_WITH_INFO(processed_quantized_weight, "Processed quantized tensor is NULL"); + TLLM_CHECK_WITH_INFO(scale_ptr, "Scale output pointer is NULL"); + TLLM_CHECK_WITH_INFO(input_weight_ptr, "Input weight pointer is NULL"); + + TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 
1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + const int bits_in_type = get_bits_in_quant_type(quant_type); + const int bytes_per_out_col = num_cols * bits_in_type / 8; + + std::vector weight_buf; + if (unprocessed_quantized_weight == nullptr) + { + weight_buf.resize(num_experts * num_rows * num_cols); + unprocessed_quantized_weight = weight_buf.data(); + } + + const int input_mat_size = num_rows * num_cols; + const int quantized_mat_size = num_rows * bytes_per_out_col; + const float quant_range_scale = 1.f / float(1 << (bits_in_type - 1)); + + std::vector per_col_max(num_cols); + + for (int expert = 0; expert < num_experts; ++expert) + { + const WeightType* current_weight = input_weight_ptr + expert * input_mat_size; + int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size; + + // First we find the per column max for this expert weight. + for (int jj = 0; jj < num_cols; ++jj) + { + per_col_max[jj] = 0.f; + } + + for (int ii = 0; ii < num_rows; ++ii) + { + const WeightType* current_weight_row = current_weight + ii * num_cols; + for (int jj = 0; jj < num_cols; ++jj) + { + per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj]))); + } + } + + // Then, we construct the scales + ComputeType* current_scales = scale_ptr + expert * num_cols; + for (int jj = 0; jj < num_cols; ++jj) + { + per_col_max[jj] *= quant_range_scale; + current_scales[jj] = ComputeType(per_col_max[jj]); + } + + // Finally, construct the weights. + for (int ii = 0; ii < num_rows; ++ii) + { + int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col; + const WeightType* current_weight_row = current_weight + ii * num_cols; + for (int jj = 0; jj < bytes_per_out_col; ++jj) + { + + if (quant_type == QuantType::INT8_WEIGHT_ONLY) + { + const float col_scale = per_col_max[jj]; + const float weight_elt = float(current_weight_row[jj]); + const float scaled_weight = round(weight_elt / col_scale); + const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight))); + current_quantized_weight_row[jj] = clipped_weight; + } + else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) + { + + // We will pack two int4 elements per iteration of the inner loop. + int8_t packed_int4s = 0; + for (int packed_idx = 0; packed_idx < 2; ++packed_idx) + { + const int input_idx = 2 * jj + packed_idx; + if (input_idx < num_cols) + { + const float col_scale = per_col_max[input_idx]; + const float weight_elt = float(current_weight_row[input_idx]); + const float scaled_weight = round(weight_elt / col_scale); + int int_weight = int(scaled_weight); + const int8_t clipped_weight = std::max(-8, std::min(7, int_weight)); + + // Kill the sign extension bits (hence 0x0F mask) then shift to upper bits + // if packing the second int4 and or the bits into the final result. 
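As a reference for the scale construction and clipping steps above (the nibble-packing statement described by the preceding comment continues just below), here is a minimal host-side sketch of the int8 path. It is illustrative only: the zero-column guard exists only in this sketch, and the packing and CUTLASS preprocessing steps are omitted.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Per-column symmetric int8 quantization of a row-major [num_rows x num_cols] matrix:
// scale_j = absmax_j / 2^(bits-1), q_ij = clip(round(w_ij / scale_j), -128, 127).
static void symmetric_quantize_int8_reference(const std::vector<float>& weight, int num_rows, int num_cols,
    std::vector<int8_t>& quantized, std::vector<float>& scales)
{
    quantized.assign(static_cast<size_t>(num_rows) * num_cols, 0);
    scales.assign(num_cols, 0.f);

    for (int jj = 0; jj < num_cols; ++jj)
    {
        float col_max = 0.f;
        for (int ii = 0; ii < num_rows; ++ii)
        {
            col_max = std::max(col_max, std::fabs(weight[static_cast<size_t>(ii) * num_cols + jj]));
        }
        scales[jj] = col_max / 128.f; // quant_range_scale = 1 / 2^(8 - 1)
    }

    for (int ii = 0; ii < num_rows; ++ii)
    {
        for (int jj = 0; jj < num_cols; ++jj)
        {
            // Guard all-zero columns in this standalone sketch only.
            const float col_scale = scales[jj] != 0.f ? scales[jj] : 1.f;
            const float scaled = std::round(weight[static_cast<size_t>(ii) * num_cols + jj] / col_scale);
            quantized[static_cast<size_t>(ii) * num_cols + jj]
                = static_cast<int8_t>(std::max(-128.f, std::min(127.f, scaled)));
        }
    }
}

int main()
{
    const std::vector<float> w = {0.5f, -2.0f, 1.0f, 4.0f}; // 2 x 2, row-major
    std::vector<int8_t> q;
    std::vector<float> s;
    symmetric_quantize_int8_reference(w, 2, 2, q, s);
    std::printf("scales: %f %f  q: %d %d %d %d\n", s[0], s[1], q[0], q[1], q[2], q[3]);
    return 0;
}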
+ packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx)); + } + } + current_quantized_weight_row[jj] = packed_int4s; + } + else + { + TLLM_CHECK_WITH_INFO(false, "Unsupported quantization type"); + } + } + } + } + + preprocess_weights_for_mixed_gemm(processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type); +} + +template void symmetric_quantize( + int8_t*, int8_t*, half*, const float*, const std::vector&, QuantType); + +template void symmetric_quantize( + int8_t*, int8_t*, half*, const half*, const std::vector&, QuantType); + +#ifdef ENABLE_BF16 +template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( + int8_t*, int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector&, QuantType); + +template void symmetric_quantize<__nv_bfloat16, float>( + int8_t*, int8_t*, __nv_bfloat16*, const float*, const std::vector&, QuantType); +#endif + +template +void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, const WeightType* input_weight_ptr, + const std::vector& shape, QuantType quant_type) +{ + symmetric_quantize(processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type); +} + +template void symmetric_quantize(int8_t*, float*, const float*, const std::vector&, QuantType); + +template void symmetric_quantize(int8_t*, half*, const float*, const std::vector&, QuantType); + +template void symmetric_quantize(int8_t*, half*, const half*, const std::vector&, QuantType); + +#ifdef ENABLE_BF16 +template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( + int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector&, QuantType); + +template void symmetric_quantize<__nv_bfloat16, half>( + int8_t*, __nv_bfloat16*, const half*, const std::vector&, QuantType); + +template void symmetric_quantize( + int8_t*, half*, const __nv_bfloat16*, const std::vector&, QuantType); + +template void symmetric_quantize<__nv_bfloat16, float>( + int8_t*, __nv_bfloat16*, const float*, const std::vector&, QuantType); +#endif + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/cutlass_preprocessors.h b/csrc/cutlass_kernels/cutlass_preprocessors.h new file mode 100644 index 000000000000..f93790e54108 --- /dev/null +++ b/csrc/cutlass_kernels/cutlass_preprocessors.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include "tensorrt_llm/common/cudaUtils.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +enum class QuantType +{ + INT8_WEIGHT_ONLY, + PACKED_INT4_WEIGHT_ONLY +}; +int get_bits_in_quant_type(QuantType quant_type); + +// Shapes here can be 2 or 3D. 
2-D shapes are [num_rows, num_cols] +// 3-D shapes are [num_experts, num_rows, num_cols] +void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8_t* quantized_tensor, + const std::vector& shape, QuantType quant_type, const int64_t arch_version); + +void subbyte_transpose(int8_t* transposed_quantized_tensor, const int8_t* quantized_tensor, + const std::vector& shape, QuantType quant_type); + +void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type); + +void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight, + const std::vector& shape, QuantType quant_type); + +template +void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, const WeightType* input_weight_ptr, + const std::vector& shape, QuantType quant_type); + +// This is exposed so that we can write tests that use the processed weights for CUTLASS but the unprocessed weight +// to implement a simple reference implementation. +template +void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, + ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector& shape, QuantType quant_type); + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_fg_scalebias.cu b/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_fg_scalebias.cu new file mode 100644 index 000000000000..e4783fdefd16 --- /dev/null +++ b/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_fg_scalebias.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +#ifdef ENABLE_BF16 +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>; +#endif +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_fg_scaleonly.cu b/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_fg_scaleonly.cu new file mode 100644 index 000000000000..8934a2c0df4e --- /dev/null +++ b/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_fg_scaleonly.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +#ifdef ENABLE_BF16 +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>; +#endif +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_per_col.cu b/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_per_col.cu new file mode 100644 index 000000000000..b3fa996a87c9 --- /dev/null +++ b/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_per_col.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +#ifdef ENABLE_BF16 +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t, + cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>; +#endif +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_fg_scalebias.cu b/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_fg_scalebias.cu new file mode 100644 index 000000000000..064e4dbde97b --- /dev/null +++ b/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_fg_scalebias.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +#ifdef ENABLE_BF16 +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>; +#endif +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_fg_scaleonly.cu b/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_fg_scaleonly.cu new file mode 100644 index 000000000000..0dbdfabe0a69 --- /dev/null +++ b/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_fg_scaleonly.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +#ifdef ENABLE_BF16 +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>; +#endif +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_per_col.cu b/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_per_col.cu new file mode 100644 index 000000000000..6701d0637ec4 --- /dev/null +++ b/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_per_col.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +#ifdef ENABLE_BF16 +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t, cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>; +#endif +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_fg_scalebias.cu b/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_fg_scalebias.cu new file mode 100644 index 000000000000..45e0f4c0f8d1 --- /dev/null +++ b/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_fg_scalebias.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_fg_scaleonly.cu b/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_fg_scaleonly.cu new file mode 100644 index 000000000000..113c6c61741d --- /dev/null +++ b/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_fg_scaleonly.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_per_col.cu b/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_per_col.cu new file mode 100644 index 000000000000..6e69985edc54 --- /dev/null +++ b/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_per_col.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_fg_scalebias.cu b/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_fg_scalebias.cu new file mode 100644 index 000000000000..51e33974f76d --- /dev/null +++ b/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_fg_scalebias.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_fg_scaleonly.cu b/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_fg_scaleonly.cu new file mode 100644 index 000000000000..148cfb519e19 --- /dev/null +++ b/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_fg_scaleonly.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_per_col.cu b/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_per_col.cu new file mode 100644 index 000000000000..35d199f58f14 --- /dev/null +++ b/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_per_col.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h b/csrc/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h new file mode 100644 index 000000000000..c805f7a4e000 --- /dev/null +++ b/csrc/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cutlass_extensions/gemm_configs.h" +#include "cutlass_extensions/weight_only_quant_op.h" +#include + +namespace tkc = tensorrt_llm::cutlass_extensions; + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +// TRT Activation Type does not have Gelu or Silu +enum class ActivationType +{ + Gelu, + Relu, + Silu, + Identity, + InvalidType +}; + +/* + This runner only supports: + T in {half, __nv_bfloat} WeightType in {int8_t, cutlass::uint4b_t} + + Activations, biases, scales and outputs are all assumed to be row-major. + + However, it is assumed that B is in a special format governed by cutlass_extensions/gemm/kernel/mixed_gemm_B_layout. + In this case, B must be preprocessed using the cutlass weight only quant preprocessors. The weight preprocessor + will instantiate the layout and preprocess based on the instantiation, so layout changes should only require + modifications to mix_gemm_B_layout.h. +*/ + +class CutlassFpAIntBGemmRunnerInterface +{ +public: + CutlassFpAIntBGemmRunnerInterface() {} + + virtual ~CutlassFpAIntBGemmRunnerInterface() {} + + virtual void gemm(const void* A, const void* B, const void* weight_scales, void* C, int m, int n, int k, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) + = 0; + + virtual void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, + const void* biases, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig, + char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) + = 0; + + // Returns desired workspace size in bytes. 
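Before the interface's remaining declarations (getWorkspaceSize, getConfigs and the concrete runner class) continue below, a sketch of how a caller might drive it may help. Everything here is hypothetical driver code, not part of the patch: the template arguments mirror the per-column fp16/int4 instantiation added later in this series, the include paths follow those used by the .cu files, and all pointers are assumed to be caller-managed device memory; error checking is omitted.

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <vector>

#include "cutlass/numeric_types.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"

namespace tkc = tensorrt_llm::cutlass_extensions;
using tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner;

// B_preprocessed must already have gone through preprocess_weights_for_mixed_gemm.
void run_fp16_int4_per_col_gemm(
    const void* A, const void* B_preprocessed, const void* scales, void* C, int m, int n, int k, cudaStream_t stream)
{
    CutlassFpAIntBGemmRunner<half, cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY> runner;

    // A real integration would profile the candidate configs; here we simply take the first one.
    const std::vector<tkc::CutlassGemmConfig> configs = runner.getConfigs();

    const size_t workspace_bytes = runner.getWorkspaceSize(m, n, k);
    void* workspace = nullptr;
    cudaMalloc(&workspace, workspace_bytes);

    runner.gemm(A, B_preprocessed, scales, C, m, n, k, configs.front(), static_cast<char*>(workspace),
        workspace_bytes, stream);

    cudaStreamSynchronize(stream);
    cudaFree(workspace);
}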
+ virtual size_t getWorkspaceSize(const int m, const int n, const int k) = 0; + + virtual std::vector getConfigs() const = 0; + +protected: + static constexpr int SPLIT_K_LIMIT = 7; + static constexpr int MIN_M_TILE = 32; + static constexpr int MIN_N_TILE = 64; +}; + +template +class CutlassFpAIntBGemmRunner : public virtual CutlassFpAIntBGemmRunnerInterface +{ +public: + CutlassFpAIntBGemmRunner(); + ~CutlassFpAIntBGemmRunner(); + + void gemm(const void* A, const void* B, const void* weight_scales, void* C, int m, int n, int k, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, + cudaStream_t stream) override; + + void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, + const void* biases, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig, + char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) override; + + // Disabled since the fused GEMM, activation kernels will not be used in v1. + + // void gemm_bias_act(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, int m, int n, + // int k, ActivationType activation_type, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t + // stream); + + // Returns desired workspace size in bytes. + size_t getWorkspaceSize(const int m, const int n, const int k) override; + + std::vector getConfigs() const override; + +private: + template + void dispatch_to_arch(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points, + const T* biases, T* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config, + char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr); + +private: + int sm_; + int multi_processor_count_; +}; + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h b/csrc/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h new file mode 100644 index 000000000000..b816b111c96e --- /dev/null +++ b/csrc/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -0,0 +1,487 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef _WIN32 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // #ifndef _WIN32 + +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass_extensions/compute_occupancy.h" +#include "cutlass_extensions/gemm/device/gemm_universal_base_compat.h" + +#include "cutlass_extensions/epilogue_helpers.h" +#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" +#include "cutlass_extensions/gemm/kernel/fpA_intB_gemm.h" +#include "cutlass_extensions/gemm/threadblock/default_mma.h" +#include "cutlass_extensions/gemm_configs.h" + +#ifndef _WIN32 +#pragma GCC diagnostic pop +#endif // #ifndef _WIN32 + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h" + +namespace tk = tensorrt_llm::common; +namespace tkc = tensorrt_llm::cutlass_extensions; + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +template +void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T* weight_scales, + const T* weight_zero_points, const T* biases, T* C, int m, int n, int k, const int group_size, + tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, + int* occupancy = nullptr) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + +#ifdef ENABLE_BF16 + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value + || cutlass::platform::is_same::value, + "Specialized for bfloat16, half, float"); +#else + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for half, float"); +#endif + + static_assert(cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value, + ""); + + // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. + using ElementType_ = + typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; +#ifdef ENABLE_BF16 + using ElementType = + typename cutlass::platform::conditional::value, + cutlass::bfloat16_t, ElementType_>::type; +#else + using ElementType = ElementType_; +#endif + + using CutlassWeightType_ = + typename cutlass::platform::conditional::value, cutlass::half_t, + WeightType>::type; +#ifdef ENABLE_BF16 + using CutlassWeightType = + typename cutlass::platform::conditional::value, + cutlass::bfloat16_t, CutlassWeightType_>::type; +#else + using CutlassWeightType = CutlassWeightType_; +#endif + + // We need separate config for each architecture since we will target different tensorcore instructions. For float, + // we do not target TCs. 
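The conditional aliases above boil down to a simple compile-time mapping from the runner's API types to CUTLASS element types. A reduced illustration follows (using std::conditional in place of cutlass::platform::conditional, and omitting the bfloat16 branch guarded by ENABLE_BF16); the arch-specific traits are then selected just below.

#include <cuda_fp16.h>
#include <type_traits>

#include "cutlass/half.h"

// half maps to cutlass::half_t; every other type is passed through unchanged.
template <typename T>
using CutlassElementType = typename std::conditional<std::is_same<T, half>::value, cutlass::half_t, T>::type;

static_assert(std::is_same<CutlassElementType<half>, cutlass::half_t>::value, "half maps to cutlass::half_t");
static_assert(std::is_same<CutlassElementType<float>, float>::value, "float is passed through unchanged");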
+ using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + + using EpilogueOp = typename tkc::Epilogue::Op; + + using Operator = typename MixedGemmArchTraits::Operator; + using TaggedOperator = typename cutlass::arch::TagOperator::TaggedOperator; + + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm, Stages, true, + TaggedOperator>::GemmKernel; + + using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB; + + if (occupancy != nullptr) + { + *occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel(); + return; + } + + using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat; + + const int ldb = cutlass::platform::is_same::value + ? n + : k * GemmKernel::kInterleave; + + if (weight_scales == nullptr) + { + throw std::runtime_error("Weight scales must always be set to a non-null value."); + } + + if constexpr (cutlass::isFinegrained(QuantOp)) + { + if (group_size != 64 && group_size != 128) + { + throw std::runtime_error("Only group size 64 and 128 supported for fine grained kernels."); + } + + if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY) + { + if (weight_zero_points != nullptr) + { + throw std::runtime_error("Weight zero pointer must be a nullptr for scale only fine grained"); + } + } + else if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) + { + if (weight_zero_points == nullptr) + { + throw std::runtime_error("Weight zero pointer must be valid for scale and bias fine grained"); + } + } + } + else + { + if (group_size != k) + { + throw std::runtime_error("Invalid group size for per column scaling kernels."); + } + + if (weight_zero_points != nullptr) + { + throw std::runtime_error("Weight zero-points must be null when running per column scaling"); + } + } + + const int ld_scale_zero = cutlass::isFinegrained(QuantOp) ? n : 0; + ElementAccumulator output_op_beta = (biases == nullptr) ? ElementAccumulator(0.f) : ElementAccumulator(1.f); + typename Gemm::Arguments args({m, n, k}, group_size, {reinterpret_cast(const_cast(A)), k}, + {reinterpret_cast(const_cast(B)), ldb}, + {reinterpret_cast(const_cast(weight_scales)), ld_scale_zero}, + {reinterpret_cast(const_cast(weight_zero_points)), ld_scale_zero}, + {reinterpret_cast(const_cast(biases)), 0}, {reinterpret_cast(C), n}, + gemm_config.split_k_factor, {ElementAccumulator(1.f), output_op_beta}); + + // This assertion is enabled because because for the column interleaved layout, K MUST be a multiple of + // threadblockK. The reason for this is that the default pitchlinear iterators are used to handle walking over the + // interleaved matrix. The way masking in handled in these do not map to the interleaved layout. We need to write + // our own predicated iterator in order to relax this limitation. + if (GemmKernel::kInterleave > 1 + && ((k % MixedGemmArchTraits::ThreadblockK) + || ((k / gemm_config.split_k_factor) % MixedGemmArchTraits::ThreadblockK))) + { + throw std::runtime_error("Temp assertion: k must be multiple of threadblockK"); + } + + Gemm gemm; + if (gemm.get_workspace_size(args) > workspace_bytes) + { + TLLM_LOG_WARNING( + "Requested split-k but workspace size insufficient. Falling back to non-split-k implementation."); + // If requested split-k factor will require more workspace bytes, revert to standard gemm. 
+ args.batch_count = 1; + } + + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) + { + std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: " + + std::string(cutlassGetStatusString(can_implement)); + throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] " + err_msg); + } + + auto init_status = gemm.initialize(args, workspace, stream); + if (init_status != cutlass::Status::kSuccess) + { + std::string err_msg + = "Failed to initialize cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(init_status)); + throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] " + err_msg); + } + + auto run_status = gemm.run(stream); + if (run_status != cutlass::Status::kSuccess) + { + std::string err_msg + = "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status)); + throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] " + err_msg); + } +} + +// This filters out invalid template combinations that we DON'T want instantiated in CUTLASS. For example, +// instantiating SM=75, Stages=3 is invalid so we would need to filter that out. Fine grained +// quanitzation is only supported on Ampere+ GPUs. +template +void filter_and_run_mixed_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points, + const T* biases, T* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config, + char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) +{ + + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + if constexpr (cutlass::isFinegrained(QuantOp) && arch::kMinComputeCapability < 80) + { + // Finegrained only supported on Ampere + std::string err_msg = "Cutlass fpA_intB gemm not implemented for arch " + + std::to_string(arch::kMinComputeCapability) + " with finegraind weight-only quantization."; + throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg); + } + else if constexpr (Stages > 2 && arch::kMinComputeCapability < 80) + { + // Multistage only supported on Ampere + std::string err_msg = "Cutlass fpA_intB gemm not supported for arch " + + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); + throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg); + } + else + { + generic_mixed_gemm_kernelLauncher(A, B, weight_scales, weight_zero_points, biases, C, m, n, k, group_size, gemm_config, workspace, + workspace_bytes, stream, occupancy); + } +} + +template +void dispatch_gemm_config(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points, + const T* biases, T* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config, + char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) +{ + + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + switch (gemm_config.stages) + { + case 2: + filter_and_run_mixed_gemm(A, B, + weight_scales, weight_zero_points, biases, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, + stream, occupancy); + break; + case 3: + filter_and_run_mixed_gemm(A, B, + weight_scales, weight_zero_points, biases, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, + stream, occupancy); + break; + case 4: + filter_and_run_mixed_gemm(A, B, + weight_scales, weight_zero_points, biases, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, + stream, occupancy); + break; + default: + std::string 
err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages); + throw std::runtime_error("[TensorRT-LLm Error][dispatch_gemm_config] " + err_msg); + break; + } +} + +template +void dispatch_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points, + const T* biases, T* C, int m, int n, int k, const int group_size, char* workspace, size_t workspace_bytes, + tkc::CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy = nullptr) +{ + + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + // Note that SIMT configs are omitted here since they are not supported for fpA_intB. + // We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best + // for mixed type gemms. + switch (gemm_config.tile_config) + { + case tkc::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, weight_zero_points, biases, C, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, weight_zero_points, biases, C, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + if (arch::kMinComputeCapability < 75) + { + TLLM_CHECK_WITH_INFO(false, "Invalid config on Volta"); + } + else + { + dispatch_gemm_config, + cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, weight_zero_points, biases, C, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + } + break; + case tkc::CutlassTileConfig::Undefined: + throw std::runtime_error("[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config undefined."); + break; + case tkc::CutlassTileConfig::ChooseWithHeuristic: + throw std::runtime_error( + "[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config should have already been set by " + "heuristic."); + break; + default: + throw std::runtime_error( + "[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM."); + break; + } +} + +template +CutlassFpAIntBGemmRunner::CutlassFpAIntBGemmRunner() +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + int device{-1}; + tk::check_cuda_error(cudaGetDevice(&device)); + sm_ = tk::getSMVersion(); + tk::check_cuda_error(cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); +} + +template +CutlassFpAIntBGemmRunner::~CutlassFpAIntBGemmRunner() +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); +} + +template +template +void CutlassFpAIntBGemmRunner::dispatch_to_arch(const T* A, const WeightType* B, + const T* weight_scales, const T* weight_zero_points, const T* biases, T* C, int m, int n, int k, + const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace_ptr, const size_t workspace_bytes, + cudaStream_t stream, int* occupancy) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + if (sm_ >= 70 && sm_ < 75) + { + dispatch_gemm_to_cutlass(A, B, weight_scales, + weight_zero_points, biases, C, m, n, k, group_size, workspace_ptr, workspace_bytes, gemm_config, stream, + occupancy); + } + else if (sm_ >= 75 && sm_ < 80) + { + dispatch_gemm_to_cutlass(A, B, weight_scales, + weight_zero_points, biases, C, m, n, k, group_size, workspace_ptr, workspace_bytes, gemm_config, stream, + 
occupancy); + } + else if (sm_ >= 80 && sm_ <= 90) + { + dispatch_gemm_to_cutlass(A, B, weight_scales, + weight_zero_points, biases, C, m, n, k, group_size, workspace_ptr, workspace_bytes, gemm_config, stream, + occupancy); + } + else + { + throw std::runtime_error( + "[TensorRT-LLm Error][CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS mixed type " + "GEMM"); + } +} + +// Disabled since the fused GEMM, activation kernels will not be used in v1. + +// template +// void CutlassFpAIntBGemmRunner::gemm_bias_act(const T* A, const WeightType* B, const T* +// weight_scales, +// const T* biases, T* C, int m, int n, int k, ActivationType activation_type, char* workspace_ptr, +// const size_t workspace_bytes, cudaStream_t stream) +// { +// TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + +// switch (activation_type) +// { +// case ActivationType::Relu: +// run_gemm( +// A, B, weight_scales, biases, C, m, n, k, workspace_ptr, workspace_bytes, stream); +// break; +// case ActivationType::Gelu: +// run_gemm( +// A, B, weight_scales, biases, C, m, n, k, workspace_ptr, workspace_bytes, stream); +// break; +// case ActivationType::Silu: +// run_gemm( +// A, B, weight_scales, biases, C, m, n, k, workspace_ptr, workspace_bytes, stream); +// break; +// case ActivationType::Identity: +// run_gemm(A, B, weight_scales, biases, C, m, n, k, workspace_ptr, workspace_bytes, +// stream); break; +// case ActivationType::InvalidType: TLLM_CHECK_WITH_INFO(false, "Activation type for fpA_intB must be +// valid."); break; default: +// { +// TLLM_CHECK_WITH_INFO(false, "Invalid activation type."); +// } +// } +// } + +template +void CutlassFpAIntBGemmRunner::gemm(const void* A, const void* B, const void* weight_scales, + const void* weight_zero_points, const void* biases, void* C, int m, int n, int k, const int group_size, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + if constexpr ((QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) + || (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY)) + { + dispatch_to_arch((const T*) A, (const WeightType*) B, (const T*) weight_scales, + (const T*) weight_zero_points, (const T*) biases, (T*) C, m, n, k, group_size, gemmConfig, workspace_ptr, + workspace_bytes, stream, nullptr); + } + else + { + throw std::runtime_error( + "Overload with scale, zero and group size only supported for fine grained bias template."); + } +} + +template +void CutlassFpAIntBGemmRunner::gemm(const void* A, const void* B, const void* weight_scales, + void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, + cudaStream_t stream) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY) + { + dispatch_to_arch((const T*) A, (const WeightType*) B, (const T*) weight_scales, nullptr, + nullptr, (T*) C, m, n, k, k, gemmConfig, workspace_ptr, workspace_bytes, stream, nullptr); + } + else + { + throw std::runtime_error("Overload with scale only (and no group size) only supported for per column scaling."); + } +} + +template +std::vector CutlassFpAIntBGemmRunner::getConfigs() const +{ + static constexpr bool is_weight_only = !std::is_same::value; + std::vector candidateConfigs + = get_candidate_configs(sm_, is_weight_only, false, false, SPLIT_K_LIMIT); + return candidateConfigs; +} + +template +size_t CutlassFpAIntBGemmRunner::getWorkspaceSize(const 
int m, const int n, const int k) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + // These are the min tile sizes for each config, which would launch the maximum number of blocks + const int max_grid_m = cutlass::ceil_div(m, MIN_M_TILE); + const int max_grid_n = cutlass::ceil_div(n, MIN_N_TILE); + // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim. + return static_cast(max_grid_m * max_grid_n * SPLIT_K_LIMIT * 4); +} + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/int8_gemm/int8_gemm.h b/csrc/cutlass_kernels/int8_gemm/int8_gemm.h new file mode 100644 index 000000000000..f06ba4b4d85a --- /dev/null +++ b/csrc/cutlass_kernels/int8_gemm/int8_gemm.h @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cutlass_extensions/gemm_configs.h" +#include "tensorrt_llm/common/quantization.h" +#include + +namespace tk = tensorrt_llm::common; +namespace tkc = tensorrt_llm::cutlass_extensions; + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +/* + This runner supports: + int8_t inputs (A and B) + float alpha scalings (either per-col, or per-col x per-row) + T output (D) where T = {float, half, __nv_bfloat16} // TODO + + Activations, biases, scales and outputs are all assumed to be row-major. + Weights are assumed to be column-major. +*/ + +class CutlassInt8GemmRunnerInterface +{ +public: + CutlassInt8GemmRunnerInterface() {} + + virtual ~CutlassInt8GemmRunnerInterface() {} + + virtual void gemm(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, + const float* alphaRow, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr, + const size_t workspaceBytes, cudaStream_t stream) + = 0; + + // Returns desired workspace size in bytes. + virtual size_t getWorkspaceSize(const int m, const int n, const int k) = 0; + + virtual std::vector getConfigs() const = 0; + +protected: + static constexpr int SPLIT_K_LIMIT = 7; + static constexpr int MIN_M_TILE = 32; + static constexpr int MIN_N_TILE = 64; +}; + +template +class CutlassInt8GemmRunner : public virtual CutlassInt8GemmRunnerInterface +{ +public: + CutlassInt8GemmRunner(); + ~CutlassInt8GemmRunner(); + + void gemm(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, const float* alphaRow, + void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr, + const size_t workspaceBytes, cudaStream_t stream) override; + + // Returns desired workspace size in bytes. 
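As a brief aside before the remaining declarations of this class: the comment block above describes what this runner expects from its caller, and the hypothetical sketch below (not part of this patch; the buffer names, the chosen config and the use of cudaMallocAsync are illustrative assumptions) shows the intended call sequence of querying the workspace size, picking a candidate config and launching the GEMM.

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "cutlass_kernels/int8_gemm/int8_gemm.h" // adjust to however csrc/ is exposed on the include path

// Hypothetical caller; dA/dB are int8 device buffers, dAlphaCol/dAlphaRow hold the
// per-column and per-row float scales, dC receives the half-precision output.
void exampleInt8Gemm(const int8_t* dA, const int8_t* dB, const float* dAlphaCol,
    const float* dAlphaRow, half* dC, int m, int n, int k, tk::QuantMode quantOption,
    cudaStream_t stream)
{
    tensorrt_llm::kernels::cutlass_kernels::CutlassInt8GemmRunner<half> runner;

    // Worst-case split-k reduction scratch: roughly 4 bytes per block times SPLIT_K_LIMIT slices.
    const size_t wsBytes = runner.getWorkspaceSize(m, n, k);
    char* dWorkspace = nullptr;
    cudaMallocAsync(reinterpret_cast<void**>(&dWorkspace), wsBytes, stream);

    // A real caller would profile or apply a heuristic over runner.getConfigs();
    // taking the first candidate keeps the sketch short.
    const tkc::CutlassGemmConfig config = runner.getConfigs().front();

    runner.gemm(dA, dB, quantOption, dAlphaCol, dAlphaRow, dC, m, n, k, config,
        dWorkspace, wsBytes, stream);
    cudaFreeAsync(dWorkspace, stream);
}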
+ size_t getWorkspaceSize(const int m, const int n, const int k) override; + + std::vector getConfigs() const override; + +private: + void dispatchToArch(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, + const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr, + const size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr); + + int mSm; + int mMultiProcessorCount; +}; + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/int8_gemm/int8_gemm_bf16.cu b/csrc/cutlass_kernels/int8_gemm/int8_gemm_bf16.cu new file mode 100644 index 000000000000..a3633bc0992a --- /dev/null +++ b/csrc/cutlass_kernels/int8_gemm/int8_gemm_bf16.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +#ifdef ENABLE_BF16 +template class CutlassInt8GemmRunner<__nv_bfloat16>; +#endif + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/int8_gemm/int8_gemm_fp16.cu b/csrc/cutlass_kernels/int8_gemm/int8_gemm_fp16.cu new file mode 100644 index 000000000000..7189956d5d03 --- /dev/null +++ b/csrc/cutlass_kernels/int8_gemm/int8_gemm_fp16.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +template class CutlassInt8GemmRunner; + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/int8_gemm/int8_gemm_fp32.cu b/csrc/cutlass_kernels/int8_gemm/int8_gemm_fp32.cu new file mode 100644 index 000000000000..861a2d4ff0f3 --- /dev/null +++ b/csrc/cutlass_kernels/int8_gemm/int8_gemm_fp32.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +template class CutlassInt8GemmRunner; // for compilation only + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/int8_gemm/int8_gemm_int32.cu b/csrc/cutlass_kernels/int8_gemm/int8_gemm_int32.cu new file mode 100644 index 000000000000..6814b00e0286 --- /dev/null +++ b/csrc/cutlass_kernels/int8_gemm/int8_gemm_int32.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +template class CutlassInt8GemmRunner; + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/int8_gemm/int8_gemm_template.h b/csrc/cutlass_kernels/int8_gemm/int8_gemm_template.h new file mode 100644 index 000000000000..55bdc98df251 --- /dev/null +++ b/csrc/cutlass_kernels/int8_gemm/int8_gemm_template.h @@ -0,0 +1,388 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef _WIN32 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // #ifndef _WIN32 + +// clang-format off +#include +#include +#include +#include +#include +// clang-format on + +#include "cutlass_extensions/compute_occupancy.h" +#include "cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h" +#include "cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h" +#include "cutlass_extensions/epilogue_helpers.h" +#include "cutlass_extensions/gemm_configs.h" + +#include "cutlass_extensions/gemm/kernel/default_int8_traits.h" +#include "cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h" + +#ifndef _WIN32 +#pragma GCC diagnostic pop +#endif // #ifndef _WIN32 + +#include "tensorrt_llm/common/allocator.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" +#include "tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm.h" +#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h" + +#include +#include + +namespace tk = tensorrt_llm::common; +namespace tkc = tensorrt_llm::cutlass_extensions; + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +template +void genericInt8GemmKernelLauncher(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, + const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, + size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + using ElementInput = int8_t; + + using ElementOutput_ = + typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; +#ifdef ENABLE_BF16 + using ElementOutput = + typename cutlass::platform::conditional::value, + cutlass::bfloat16_t, ElementOutput_>::type; +#else + using ElementOutput = ElementOutput_; +#endif + + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + using OperatorClass = typename cutlass::gemm::kernel::Int8GemmArchTraits::OperatorClass; + using InstructionShape = typename cutlass::gemm::kernel::Int8GemmArchTraits::InstructionShape; + + using DefaultGemmConf = typename cutlass::gemm::device::DefaultGemmConfiguration; + using GemmOp = typename DefaultGemmConf::Operator; + using EpilogueOp = typename DefaultGemmConf::EpilogueOutputOp; + + // only TN is supported (s8 * s8 + s32) + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm::GemmKernel; + + using AlphaColTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + cutlass::epilogue::threadblock::OutputTileOptimalThreadMap< + typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Shape, + typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Count, + GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::kThreads, + GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess, cutlass::sizeof_bits::value>, + ElementCompute>; + + // Epilogue visitor + using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerCol; + + /// Epilogue + using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue::Epilogue; + + // GEMM + using GemmKernel + = cutlass::gemm::kernel::GemmWithEpilogueVisitor; + + if (occupancy != nullptr) + { + *occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel(); + return; + } + + using 
Gemm = cutlass::gemm::device::GemmUniversalBaseCompat; + + typename EpilogueOp::Params linearScalingParams; // TODO: right now it's unused (scaling is done in + // visitor, no activation needed) + typename Gemm::Arguments args{cutlass::gemm::GemmUniversalMode::kBatched, {m, n, k}, 1, + {reinterpret_cast(const_cast(A)), k}, + {reinterpret_cast(const_cast(B)), k}, quantOption, + {reinterpret_cast(const_cast(alphaCol)), 0}, + {reinterpret_cast(const_cast(alphaRow)), 0}, {nullptr, 0}, + {reinterpret_cast(C), n}, 0, 0, + typename EpilogueVisitor::Arguments(linearScalingParams, 0, 0, 0)}; + + Gemm gemm; + // TODO: handle that + if (gemm.get_workspace_size(args) > workspaceBytes) + { + TLLM_LOG_WARNING( + "Requested split-k but workspace size insufficient. Falling back to non-split-k implementation."); + // If requested split-k factor will require more workspace bytes, revert to standard gemm. + args.batch_count = 1; + } + + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) + { + std::string errMsg = "int8gemm cutlass kernel will fail for params. Error: " + + std::string(cutlassGetStatusString(can_implement)); + throw std::runtime_error("[TensorRT-LLM Error][int8gemm Runner] " + errMsg); + } + + auto initStatus = gemm.initialize(args, workspace, stream); + if (initStatus != cutlass::Status::kSuccess) + { + std::string errMsg + = "Failed to initialize cutlass int8 gemm. Error: " + std::string(cutlassGetStatusString(initStatus)); + throw std::runtime_error("[TensorRT-LLM Error][int8gemm Runner] " + errMsg); + } + + auto runStatus = gemm.run(stream); + if (runStatus != cutlass::Status::kSuccess) + { + std::string errMsg + = "Failed to run cutlass int8 gemm. Error: " + std::string(cutlassGetStatusString(runStatus)); + throw std::runtime_error("[TensorRT-LLM Error][int8gemm Runner] " + errMsg); + } +} + +template +struct dispatchStages +{ + static void dispatch(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, + const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, + size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) + { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + std::string errMsg = "Cutlass int8 gemm. 
Not instantiates for arch " + + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); + throw std::runtime_error("[TensorRT-LLM Error][dispatchStages::dispatch] " + errMsg); + } +}; + +template +struct dispatchStages +{ + static void dispatch(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, + const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, + size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) + { + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + genericInt8GemmKernelLauncher(A, B, quantOption, alphaCol, alphaRow, C, + m, n, k, gemmConfig, workspace, workspaceBytes, stream, occupancy); + } +}; + +template +struct dispatchStages 2)>::type> +{ + static void dispatch(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, + const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, + size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) + { + + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + genericInt8GemmKernelLauncher(A, B, quantOption, + alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, workspaceBytes, stream, occupancy); + } +}; + +template +void dispatchGemmConfig(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, + const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, + size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) +{ + + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + switch (gemmConfig.stages) + { + case 2: + using DispatcherStages2 = dispatchStages; + DispatcherStages2::dispatch(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case 3: + using DispatcherStages3 = dispatchStages; + DispatcherStages3::dispatch(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case 4: + using DispatcherStages4 = dispatchStages; + DispatcherStages4::dispatch(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case 5: + using DispatcherStages5 = dispatchStages; + DispatcherStages5::dispatch(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case 6: + using DispatcherStages6 = dispatchStages; + DispatcherStages6::dispatch(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + default: + std::string errMsg = "dispatchGemmConfig does not support stages " + std::to_string(gemmConfig.stages); + throw std::runtime_error("[TensorRT-LLM Error][dispatch_gemm_config] " + errMsg); + break; + } +} + +template +void dispatchGemmToCutlass(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, + const float* alphaRow, T* C, int m, int n, int k, char* workspace, size_t workspaceBytes, + tkc::CutlassGemmConfig gemmConfig, cudaStream_t stream, int* occupancy = nullptr) +{ + + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + switch (gemmConfig.tile_config) + { + case tkc::CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: + dispatchGemmConfig, cutlass::gemm::GemmShape<64, 32, 64>>(A, B, + quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, workspaceBytes, stream, occupancy); + break; + case 
tkc::CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: + dispatchGemmConfig, cutlass::gemm::GemmShape<64, 64, 64>>(A, B, + quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, workspaceBytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatchGemmConfig, cutlass::gemm::GemmShape<32, 32, 64>>(A, B, + quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, workspaceBytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatchGemmConfig, cutlass::gemm::GemmShape<64, 32, 64>>(A, B, + quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, workspaceBytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: + dispatchGemmConfig, cutlass::gemm::GemmShape<32, 64, 64>>(A, B, + quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, workspaceBytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: + dispatchGemmConfig, cutlass::gemm::GemmShape<64, 64, 64>>(A, B, + quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, workspaceBytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::Undefined: + throw std::runtime_error("[TensorRT-LLM Error][int8][dispatch_gemm_to_cutlass] gemm config undefined."); + break; + case tkc::CutlassTileConfig::ChooseWithHeuristic: + throw std::runtime_error( + "[TensorRT-LLM Error][int8][dispatch_gemm_to_cutlass] gemm config should have already been set by " + "heuristic."); + break; + default: + throw std::runtime_error( + "[TensorRT-LLM Error][int8][dispatch_gemm_to_cutlass] Config is invalid for int8 GEMM."); + break; + } +} + +template +CutlassInt8GemmRunner::CutlassInt8GemmRunner() +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + int device{-1}; + tk::check_cuda_error(cudaGetDevice(&device)); + mSm = tk::getSMVersion(); + tk::check_cuda_error(cudaDeviceGetAttribute(&mMultiProcessorCount, cudaDevAttrMultiProcessorCount, device)); +} + +template +CutlassInt8GemmRunner::~CutlassInt8GemmRunner() +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); +} + +template +void CutlassInt8GemmRunner::dispatchToArch(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, + const float* alphaCol, const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, + char* workspacePtr, const size_t workspaceBytes, cudaStream_t stream, int* occupancy) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + if (mSm >= 70 && mSm < 72) + { + dispatchGemmToCutlass(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, workspacePtr, + workspaceBytes, gemmConfig, stream, occupancy); + } + else if (mSm >= 72 && mSm < 75) + { + dispatchGemmToCutlass(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, workspacePtr, + workspaceBytes, gemmConfig, stream, occupancy); + } + else if (mSm >= 75 && mSm < 80) + { + dispatchGemmToCutlass(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, workspacePtr, + workspaceBytes, gemmConfig, stream, occupancy); + } + else if (mSm >= 80 && mSm <= 90) + { + dispatchGemmToCutlass(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, workspacePtr, + workspaceBytes, gemmConfig, stream, occupancy); + } + else + { + throw std::runtime_error( + "[TensorRT-LLM Error][CutlassInt8GemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS int8 GEMM"); + } +} + +template +void CutlassInt8GemmRunner::gemm(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, + const float* alphaRow, 
void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr, + const size_t workspaceBytes, cudaStream_t stream) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + dispatchToArch(A, B, quantOption, alphaCol, alphaRow, reinterpret_cast(C), m, n, k, gemmConfig, workspacePtr, + workspaceBytes, stream); +} + +template +std::vector CutlassInt8GemmRunner::getConfigs() const +{ + static constexpr bool isWeightOnly = false; + std::vector candidateConfigs + = get_candidate_configs(mSm, isWeightOnly, mSm <= 70, /* SIMT configs */ + true, SPLIT_K_LIMIT); /* INT8 configs */ + return candidateConfigs; +} + +template +size_t CutlassInt8GemmRunner::getWorkspaceSize(const int m, const int n, const int k) +{ + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); + // These are the min tile sizes for each config, which would launch the maximum number of blocks + const int maxGridM = cutlass::ceil_div(m, MIN_M_TILE); + const int maxGridN = cutlass::ceil_div(n, MIN_N_TILE); + // We need 4 bytes per block in the worst case. We launch SPLIT_K_LIMIT in z dim. + return static_cast(maxGridM * maxGridN * SPLIT_K_LIMIT * 4); +} + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels.h b/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels.h new file mode 100644 index 000000000000..3f4def7d7152 --- /dev/null +++ b/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels.h @@ -0,0 +1,81 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#pragma once +#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h" +#include +#include + +namespace tensorrt_llm +{ + +// Note update moe.py to match +enum class ActivationType +{ + Gelu = 0, + Relu, + Silu, + Swiglu, + Geglu, + Identity, + InvalidType +}; + +constexpr bool isGatedActivation(ActivationType activation_type) +{ + return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu; +} + +template +class MoeGemmRunner +{ +public: + MoeGemmRunner(); + + void setBestConfig(std::optional best_config) + { + best_config_ = std::move(best_config); + } + + void moeGemmBiasAct(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + ActivationType activation_type, cudaStream_t stream); + + void moeGemm(const T* A, const WeightType* B, const T* weight_scales, T* C, int64_t* total_rows_before_expert, + int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream); + + std::vector getConfigs(); + +private: + template + void dispatchToArch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy = nullptr); + + template + void runGemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cudaStream_t stream); + +private: + int sm_; + int multi_processor_count_; + std::optional best_config_{}; +}; + +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu b/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu new file mode 100644 index 000000000000..42699295b7e8 --- /dev/null +++ b/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +#ifdef ENABLE_BF16 +template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16>; +#endif +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu b/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu new file mode 100644 index 000000000000..b5d129ca91c0 --- /dev/null +++ b/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +#ifdef ENABLE_BF16 +template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t>; +#endif +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu b/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu new file mode 100644 index 000000000000..174d5a7b907e --- /dev/null +++ b/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +#ifdef ENABLE_BF16 +template class MoeGemmRunner<__nv_bfloat16, uint8_t>; +#endif +} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu b/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu new file mode 100644 index 000000000000..f57d91f9d810 --- /dev/null +++ b/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +template class MoeGemmRunner; +} diff --git a/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu b/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu new file mode 100644 index 000000000000..3f4b0bb718fd --- /dev/null +++ b/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +template class MoeGemmRunner; +} diff --git a/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu b/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu new file mode 100644 index 000000000000..a8d2d5e6c8eb --- /dev/null +++ b/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +template class MoeGemmRunner; +} diff --git a/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu b/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu new file mode 100644 index 000000000000..6b57aae1d844 --- /dev/null +++ b/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +template class MoeGemmRunner; +} diff --git a/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h b/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h new file mode 100644 index 000000000000..19a480dc8986 --- /dev/null +++ b/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h @@ -0,0 +1,440 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Ignore CUTLASS warnings about type punning +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" + +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" + +#include "cutlass_extensions/compute_occupancy.h" +#include "cutlass_extensions/epilogue_helpers.h" +#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" +#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" +#include "cutlass_extensions/gemm/threadblock/default_mma.h" + +#pragma GCC diagnostic pop + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" +#include +#include +#include +#include + +namespace tensorrt_llm +{ + +// ============================= Variable batched Gemm things =========================== +template +void genericMoeGemmKernelLauncher(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, const int multi_processor_count, cudaStream_t stream, + int* kernel_occupancy = nullptr) +{ +#ifdef ENABLE_BF16 + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value + || cutlass::platform::is_same::value, + "Specialized for bfloat16, half, float"); +#else + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for half, float"); +#endif + + static_assert(cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value, + ""); + + // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. + using ElementType_ = + typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; +#ifdef ENABLE_BF16 + using ElementType = + typename cutlass::platform::conditional::value, + cutlass::bfloat16_t, ElementType_>::type; +#else + using ElementType = ElementType_; +#endif + + using CutlassWeightType_ = + typename cutlass::platform::conditional::value, cutlass::half_t, + WeightType>::type; +#ifdef ENABLE_BF16 + using CutlassWeightType = + typename cutlass::platform::conditional::value, + cutlass::bfloat16_t, CutlassWeightType_>::type; +#else + using CutlassWeightType = CutlassWeightType_; +#endif + + // We need separate config for each architecture since we will target different tensorcore instructions. For float, + // we do not target TCs. + using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + + using EpilogueOp = typename tensorrt_llm::cutlass_extensions::Epilogue::Op; + + // Finally, set up the kernel. 
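One detail worth isolating before the kernel is assembled below is the conditional element-type mapping used above, which remaps the device-side half and __nv_bfloat16 types onto their CUTLASS equivalents while leaving float untouched. A minimal compile-time sketch of that mapping (not part of this patch; the alias name is illustrative and bf16 support is assumed to be enabled):

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cutlass/bfloat16.h>
#include <cutlass/half.h>
#include <cutlass/platform/platform.h>

// half -> cutlass::half_t, __nv_bfloat16 -> cutlass::bfloat16_t, everything else unchanged.
template <typename T>
using CutlassElementOf = typename cutlass::platform::conditional<
    cutlass::platform::is_same<T, half>::value, cutlass::half_t,
    typename cutlass::platform::conditional<cutlass::platform::is_same<T, __nv_bfloat16>::value,
        cutlass::bfloat16_t, T>::type>::type;

static_assert(cutlass::platform::is_same<CutlassElementOf<half>, cutlass::half_t>::value, "");
static_assert(cutlass::platform::is_same<CutlassElementOf<__nv_bfloat16>, cutlass::bfloat16_t>::value, "");
static_assert(cutlass::platform::is_same<CutlassElementOf<float>, float>::value, "");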
+ using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped::GemmKernel; + + using GemmKernel = cutlass::gemm::kernel::MoeFCGemm; + + using GemmGrouped = cutlass::gemm::device::GemmGrouped; + + if (kernel_occupancy != nullptr) + { + *kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel(); + return; + } + int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); + TLLM_CHECK_WITH_INFO(occupancy != 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel"); + const int threadblock_count = multi_processor_count * occupancy; + + typename EpilogueOp::Params epilogue_op( + ElementAccumulator(1.f), biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); + + const int group_size = gemm_k; + typename GemmGrouped::Arguments args(num_experts, threadblock_count, group_size, epilogue_op, + reinterpret_cast(A), reinterpret_cast(B), + reinterpret_cast(weight_scales), reinterpret_cast(biases), + reinterpret_cast(C), total_rows_before_expert, gemm_n, gemm_k); + + GemmGrouped gemm; + + auto can_implement = gemm.can_implement(args); + TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess, + "MoE FC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement))); + + auto init_status = gemm.initialize(args); + TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, + "Failed to initialize cutlass variable batched gemm. Error: " + + std::string(cutlassGetStatusString(init_status))); + + auto run_status = gemm.run(stream); + TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, + "Failed to run cutlass variable batched gemm. Error: " + std::string(cutlassGetStatusString(run_status))); +} + +template +struct dispatch_stages +{ + static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) + { + TLLM_THROW("Cutlass fpA_intB gemm. 
Not instantiated for arch %d with stages set to %d", + arch::kMinComputeCapability, Stages); + } +}; + +template +struct dispatch_stages +{ + static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) + { + genericMoeGemmKernelLauncher(A, B, + weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, stream, occupancy); + } +}; + +template +struct dispatch_stages 2)>::type> +{ + static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) + { + genericMoeGemmKernelLauncher(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + } +}; + +template +void dispatchGemmConfig(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) +{ + switch (gemm_config.stages) + { + case 2: + using DispatcherStages2 = dispatch_stages; + DispatcherStages2::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; + case 3: + using DispatcherStages3 = dispatch_stages; + DispatcherStages3::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; + case 4: + using DispatcherStages4 = dispatch_stages; + DispatcherStages4::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; + default: TLLM_THROW("dispatchGemmConfig does not support stages %d", gemm_config.stages); break; + } +} + +// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32. +// This overload is only enabled when T == WeightType. 
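The two comments above describe how the overloads of dispatchMoeGemmToCutlass that follow are selected: each overload is constrained with std::enable_if, so exactly one definition survives substitution for a given (T, WeightType) pair. A stripped-down sketch of the same selection mechanism (illustrative names only, with plain host types standing in for the half and quantized weight types):

#include <cstdio>
#include <type_traits>

// SIMT-style overload: chosen only when the activation type is float.
template <typename T, typename W,
    typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
void dispatchExample() { std::printf("simt path\n"); }

// Same-type tensor-op overload: non-float activations with weights of the same type.
template <typename T, typename W,
    typename std::enable_if<!std::is_same<T, float>::value && std::is_same<T, W>::value>::type* = nullptr>
void dispatchExample() { std::printf("same-type tensor-op path\n"); }

// Mixed-type overload: non-float activations with a different (quantized) weight type.
template <typename T, typename W,
    typename std::enable_if<!std::is_same<T, float>::value && !std::is_same<T, W>::value>::type* = nullptr>
void dispatchExample() { std::printf("mixed-type tensor-op path\n"); }

int main()
{
    dispatchExample<float, float>();   // simt path
    dispatchExample<double, double>(); // same-type tensor-op path (double stands in for half/bf16)
    dispatchExample<double, int>();    // mixed-type tensor-op path (int stands in for uint8/uint4)
    return 0;
}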
+template ::value && std::is_same::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) +{ + switch (gemm_config.tile_config) + { + case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + TLLM_THROW("GEMM config should have already been set by heuristic."); + break; + default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break; + } +} + +// Tensorop GEMM overload +// Overload for quantize MoE GEMMs. 
We disable some warp configs here since they will not be used and we can improve +// compile time +template ::value && !std::is_same::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) +{ + switch (gemm_config.tile_config) + { + case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + TLLM_THROW("GEMM config should have already been set by heuristic."); + break; + default: TLLM_THROW("Config is invalid for mixed type tensorop GEMM."); break; + } +} + +// This overload will handle simt gemms. It is disabled via SFINAE for tensorop. 
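Before the SIMT overload that the comment above introduces, one way to read the CtaShape/WarpShape pairs in these switches is that dividing the threadblock tile by the warp tile gives the warp count per CTA, which in turn fixes the launch block size at 32 threads per warp. A compile-time sketch of that arithmetic (illustrative, not part of this patch):

#include <cutlass/gemm/gemm.h>

// Warps per threadblock implied by a (ThreadblockShape, WarpShape) pair. For all
// pairs instantiated above the k extents match, so there is no warp-level k split.
template <typename ThreadblockShape, typename WarpShape>
struct WarpsPerCta
{
    static constexpr int kM = ThreadblockShape::kM / WarpShape::kM;
    static constexpr int kN = ThreadblockShape::kN / WarpShape::kN;
    static constexpr int kK = ThreadblockShape::kK / WarpShape::kK;
    static constexpr int kCount = kM * kN * kK;
};

// Mixed-type pair CtaShape128x128x64 / WarpShape128x32x64: 1 x 4 x 1 = 4 warps (128 threads).
static_assert(WarpsPerCta<cutlass::gemm::GemmShape<128, 128, 64>,
                  cutlass::gemm::GemmShape<128, 32, 64>>::kCount == 4, "");
// Same-type pair CtaShape128x128x64 / WarpShape64x32x64: 2 x 4 x 1 = 8 warps (256 threads).
static_assert(WarpsPerCta<cutlass::gemm::GemmShape<128, 128, 64>,
                  cutlass::gemm::GemmShape<64, 32, 64>>::kCount == 8, "");
// SIMT pair CtaShape128x128x8 / WarpShape64x64x8 used by the float overload below: 2 x 2 x 1 = 4 warps.
static_assert(WarpsPerCta<cutlass::gemm::GemmShape<128, 128, 8>,
                  cutlass::gemm::GemmShape<64, 64, 8>>::kCount == 4, "");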
+template ::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) +{ + switch (gemm_config.tile_config) + { + case cutlass_extensions::CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + TLLM_THROW("GEMM config should have already been set by heuristic."); + break; + default: TLLM_THROW("Unsupported config for float MoE gemm."); break; + } +} + +template +std::vector MoeGemmRunner::getConfigs() +{ + static constexpr bool is_weight_only = !std::is_same::value; + static constexpr bool only_simt_configs = std::is_same::value; + std::vector candidate_configs + = kernels::cutlass_kernels::get_candidate_configs(sm_, is_weight_only, only_simt_configs); + return candidate_configs; +} + +template +MoeGemmRunner::MoeGemmRunner() +{ + int device{-1}; + tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device)); + sm_ = tensorrt_llm::common::getSMVersion(); + tensorrt_llm::common::check_cuda_error( + cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); +} + +template +template +void MoeGemmRunner::dispatchToArch(const T* A, const WeightType* B, const T* weight_scales, + const T* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy) +{ + if (sm_ >= 70 && sm_ < 75) + { + dispatchMoeGemmToCutlass(A, B, weight_scales, biases, C, + total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, + stream, occupancy); + } + else if (sm_ >= 75 && sm_ < 80) + { + dispatchMoeGemmToCutlass(A, B, weight_scales, biases, C, + total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, + stream, occupancy); + } + else if (sm_ >= 80 && sm_ < 90) + { + dispatchMoeGemmToCutlass(A, B, weight_scales, biases, C, + total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, + stream, occupancy); + } + else if (sm_ >= 90) + { + // TODO Update the arch to Sm90 once CUTLASS hopper specialisations are available + dispatchMoeGemmToCutlass(A, B, weight_scales, biases, C, + total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, + stream, occupancy); + } + else + { + TLLM_THROW("Arch unsupported for MoE GEMM"); + } +} + +template +template +void MoeGemmRunner::runGemm(const T* A, const WeightType* B, const T* weight_scales, + const T* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, cudaStream_t stream) +{ + auto chosen_conf = this->best_config_; + if (!chosen_conf) + { + auto candidate_configs = getConfigs(); + std::vector 
occupancies(candidate_configs.size()); + + for (size_t ii = 0; ii < candidate_configs.size(); ++ii) + { + dispatchToArch(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, + gemm_k, num_experts, candidate_configs[ii], stream, &occupancies[ii]); + } + + static constexpr int workspace_bytes = 0; // No workspace for MoE GEMMs. + static constexpr int split_k_limit = 1; // MoE GEMM does not support split-k. + + static constexpr bool is_weight_only = !std::is_same::value; + chosen_conf = kernels::cutlass_kernels::estimate_best_config_from_occupancies(candidate_configs, occupancies, + total_rows, gemm_n, gemm_k, num_experts, split_k_limit, workspace_bytes, multi_processor_count_, + is_weight_only); + } + assert(chosen_conf); + dispatchToArch(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, *chosen_conf, stream); +} + +template +void MoeGemmRunner::moeGemmBiasAct(const T* A, const WeightType* B, const T* weight_scales, + const T* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, ActivationType activation_type, cudaStream_t stream) +{ + switch (activation_type) + { + case ActivationType::Relu: + runGemm( + A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream); + break; + case ActivationType::Gelu: + runGemm( + A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream); + break; + case ActivationType::Silu: + runGemm( + A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream); + break; + case ActivationType::Identity: + runGemm( + A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream); + break; + case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break; + default: TLLM_THROW("Invalid activation type."); break; + } +} + +template +void MoeGemmRunner::moeGemm(const T* A, const WeightType* B, const T* weight_scales, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cudaStream_t stream) +{ + runGemm( + A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream); +} + +} // namespace tensorrt_llm From 90ccdfa7d3ee8c336b771446c9055dbd83c1c64e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 31 Jan 2024 10:32:42 +0000 Subject: [PATCH 07/33] Remove MoE gemm --- csrc/moe_gemm/moe_gemm_kernels.h | 81 ---- csrc/moe_gemm/moe_gemm_kernels_bf16_bf16.cu | 24 - csrc/moe_gemm/moe_gemm_kernels_bf16_uint4.cu | 24 - csrc/moe_gemm/moe_gemm_kernels_bf16_uint8.cu | 24 - csrc/moe_gemm/moe_gemm_kernels_fp16_fp16.cu | 22 - csrc/moe_gemm/moe_gemm_kernels_fp16_uint4.cu | 22 - csrc/moe_gemm/moe_gemm_kernels_fp16_uint8.cu | 22 - csrc/moe_gemm/moe_gemm_kernels_fp32_fp32.cu | 22 - csrc/moe_gemm/moe_gemm_kernels_template.h | 440 ------------------- 9 files changed, 681 deletions(-) delete mode 100644 csrc/moe_gemm/moe_gemm_kernels.h delete mode 100644 csrc/moe_gemm/moe_gemm_kernels_bf16_bf16.cu delete mode 100644 csrc/moe_gemm/moe_gemm_kernels_bf16_uint4.cu delete mode 100644 csrc/moe_gemm/moe_gemm_kernels_bf16_uint8.cu delete mode 100644 csrc/moe_gemm/moe_gemm_kernels_fp16_fp16.cu delete mode 100644 csrc/moe_gemm/moe_gemm_kernels_fp16_uint4.cu delete mode 100644 csrc/moe_gemm/moe_gemm_kernels_fp16_uint8.cu delete mode 100644 
csrc/moe_gemm/moe_gemm_kernels_fp32_fp32.cu delete mode 100644 csrc/moe_gemm/moe_gemm_kernels_template.h diff --git a/csrc/moe_gemm/moe_gemm_kernels.h b/csrc/moe_gemm/moe_gemm_kernels.h deleted file mode 100644 index 3f4def7d7152..000000000000 --- a/csrc/moe_gemm/moe_gemm_kernels.h +++ /dev/null @@ -1,81 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h" -#include -#include - -namespace tensorrt_llm -{ - -// Note update moe.py to match -enum class ActivationType -{ - Gelu = 0, - Relu, - Silu, - Swiglu, - Geglu, - Identity, - InvalidType -}; - -constexpr bool isGatedActivation(ActivationType activation_type) -{ - return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu; -} - -template -class MoeGemmRunner -{ -public: - MoeGemmRunner(); - - void setBestConfig(std::optional best_config) - { - best_config_ = std::move(best_config); - } - - void moeGemmBiasAct(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - ActivationType activation_type, cudaStream_t stream); - - void moeGemm(const T* A, const WeightType* B, const T* weight_scales, T* C, int64_t* total_rows_before_expert, - int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream); - - std::vector getConfigs(); - -private: - template - void dispatchToArch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy = nullptr); - - template - void runGemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cudaStream_t stream); - -private: - int sm_; - int multi_processor_count_; - std::optional best_config_{}; -}; - -} // namespace tensorrt_llm diff --git a/csrc/moe_gemm/moe_gemm_kernels_bf16_bf16.cu b/csrc/moe_gemm/moe_gemm_kernels_bf16_bf16.cu deleted file mode 100644 index 42699295b7e8..000000000000 --- a/csrc/moe_gemm/moe_gemm_kernels_bf16_bf16.cu +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -#ifdef ENABLE_BF16 -template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16>; -#endif -} // namespace tensorrt_llm diff --git a/csrc/moe_gemm/moe_gemm_kernels_bf16_uint4.cu b/csrc/moe_gemm/moe_gemm_kernels_bf16_uint4.cu deleted file mode 100644 index b5d129ca91c0..000000000000 --- a/csrc/moe_gemm/moe_gemm_kernels_bf16_uint4.cu +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -#ifdef ENABLE_BF16 -template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t>; -#endif -} // namespace tensorrt_llm diff --git a/csrc/moe_gemm/moe_gemm_kernels_bf16_uint8.cu b/csrc/moe_gemm/moe_gemm_kernels_bf16_uint8.cu deleted file mode 100644 index 174d5a7b907e..000000000000 --- a/csrc/moe_gemm/moe_gemm_kernels_bf16_uint8.cu +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -#ifdef ENABLE_BF16 -template class MoeGemmRunner<__nv_bfloat16, uint8_t>; -#endif -} // namespace tensorrt_llm diff --git a/csrc/moe_gemm/moe_gemm_kernels_fp16_fp16.cu b/csrc/moe_gemm/moe_gemm_kernels_fp16_fp16.cu deleted file mode 100644 index f57d91f9d810..000000000000 --- a/csrc/moe_gemm/moe_gemm_kernels_fp16_fp16.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -template class MoeGemmRunner; -} diff --git a/csrc/moe_gemm/moe_gemm_kernels_fp16_uint4.cu b/csrc/moe_gemm/moe_gemm_kernels_fp16_uint4.cu deleted file mode 100644 index 3f4b0bb718fd..000000000000 --- a/csrc/moe_gemm/moe_gemm_kernels_fp16_uint4.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -template class MoeGemmRunner; -} diff --git a/csrc/moe_gemm/moe_gemm_kernels_fp16_uint8.cu b/csrc/moe_gemm/moe_gemm_kernels_fp16_uint8.cu deleted file mode 100644 index a8d2d5e6c8eb..000000000000 --- a/csrc/moe_gemm/moe_gemm_kernels_fp16_uint8.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -template class MoeGemmRunner; -} diff --git a/csrc/moe_gemm/moe_gemm_kernels_fp32_fp32.cu b/csrc/moe_gemm/moe_gemm_kernels_fp32_fp32.cu deleted file mode 100644 index 6b57aae1d844..000000000000 --- a/csrc/moe_gemm/moe_gemm_kernels_fp32_fp32.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -template class MoeGemmRunner; -} diff --git a/csrc/moe_gemm/moe_gemm_kernels_template.h b/csrc/moe_gemm/moe_gemm_kernels_template.h deleted file mode 100644 index 19a480dc8986..000000000000 --- a/csrc/moe_gemm/moe_gemm_kernels_template.h +++ /dev/null @@ -1,440 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Ignore CUTLASS warnings about type punning -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" - -#include "cutlass/array.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/gemm/device/gemm_grouped.h" -#include "cutlass/gemm/kernel/default_gemm_grouped.h" - -#include "cutlass_extensions/compute_occupancy.h" -#include "cutlass_extensions/epilogue_helpers.h" -#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" -#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" -#include "cutlass_extensions/gemm/threadblock/default_mma.h" - -#pragma GCC diagnostic pop - -#include "tensorrt_llm/common/assert.h" -#include "tensorrt_llm/common/cudaUtils.h" -#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" -#include -#include -#include -#include - -namespace tensorrt_llm -{ - -// ============================= Variable batched Gemm things =========================== -template -void genericMoeGemmKernelLauncher(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, const int multi_processor_count, cudaStream_t stream, - int* kernel_occupancy = nullptr) -{ -#ifdef ENABLE_BF16 - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value - || cutlass::platform::is_same::value, - "Specialized for bfloat16, half, float"); -#else - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, - "Specialized for half, float"); -#endif - - static_assert(cutlass::platform::is_same::value - || cutlass::platform::is_same::value - || cutlass::platform::is_same::value, - ""); - - // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. 
- using ElementType_ = - typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; -#ifdef ENABLE_BF16 - using ElementType = - typename cutlass::platform::conditional::value, - cutlass::bfloat16_t, ElementType_>::type; -#else - using ElementType = ElementType_; -#endif - - using CutlassWeightType_ = - typename cutlass::platform::conditional::value, cutlass::half_t, - WeightType>::type; -#ifdef ENABLE_BF16 - using CutlassWeightType = - typename cutlass::platform::conditional::value, - cutlass::bfloat16_t, CutlassWeightType_>::type; -#else - using CutlassWeightType = CutlassWeightType_; -#endif - - // We need separate config for each architecture since we will target different tensorcore instructions. For float, - // we do not target TCs. - using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; - using ElementAccumulator = typename MixedGemmArchTraits::AccType; - - using EpilogueOp = typename tensorrt_llm::cutlass_extensions::Epilogue::Op; - - // Finally, set up the kernel. - using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped::GemmKernel; - - using GemmKernel = cutlass::gemm::kernel::MoeFCGemm; - - using GemmGrouped = cutlass::gemm::device::GemmGrouped; - - if (kernel_occupancy != nullptr) - { - *kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel(); - return; - } - int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); - TLLM_CHECK_WITH_INFO(occupancy != 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel"); - const int threadblock_count = multi_processor_count * occupancy; - - typename EpilogueOp::Params epilogue_op( - ElementAccumulator(1.f), biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); - - const int group_size = gemm_k; - typename GemmGrouped::Arguments args(num_experts, threadblock_count, group_size, epilogue_op, - reinterpret_cast(A), reinterpret_cast(B), - reinterpret_cast(weight_scales), reinterpret_cast(biases), - reinterpret_cast(C), total_rows_before_expert, gemm_n, gemm_k); - - GemmGrouped gemm; - - auto can_implement = gemm.can_implement(args); - TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess, - "MoE FC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement))); - - auto init_status = gemm.initialize(args); - TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, - "Failed to initialize cutlass variable batched gemm. Error: " - + std::string(cutlassGetStatusString(init_status))); - - auto run_status = gemm.run(stream); - TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, - "Failed to run cutlass variable batched gemm. Error: " + std::string(cutlassGetStatusString(run_status))); -} - -template -struct dispatch_stages -{ - static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, - int* occupancy = nullptr) - { - TLLM_THROW("Cutlass fpA_intB gemm. 
Not instantiated for arch %d with stages set to %d", - arch::kMinComputeCapability, Stages); - } -}; - -template -struct dispatch_stages -{ - static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, - int* occupancy = nullptr) - { - genericMoeGemmKernelLauncher(A, B, - weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, - multi_processor_count, stream, occupancy); - } -}; - -template -struct dispatch_stages 2)>::type> -{ - static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, - int* occupancy = nullptr) - { - genericMoeGemmKernelLauncher(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, num_experts, - gemm_config, multi_processor_count, stream, occupancy); - } -}; - -template -void dispatchGemmConfig(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, - int* occupancy = nullptr) -{ - switch (gemm_config.stages) - { - case 2: - using DispatcherStages2 = dispatch_stages; - DispatcherStages2::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, - num_experts, gemm_config, multi_processor_count, stream, occupancy); - break; - case 3: - using DispatcherStages3 = dispatch_stages; - DispatcherStages3::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, - num_experts, gemm_config, multi_processor_count, stream, occupancy); - break; - case 4: - using DispatcherStages4 = dispatch_stages; - DispatcherStages4::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, num_rows, gemm_n, gemm_k, - num_experts, gemm_config, multi_processor_count, stream, occupancy); - break; - default: TLLM_THROW("dispatchGemmConfig does not support stages %d", gemm_config.stages); break; - } -} - -// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32. -// This overload is only enabled when T == WeightType. 
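The stage-dispatch machinery in this hunk is easier to follow in isolation. The sketch below is a simplified, standalone illustration only, with hypothetical names and no CUTLASS dependency: a dispatch_stages-style trait is specialized per architecture and stage count at compile time, and a dispatchGemmConfig-style switch maps the runtime stage count onto one of those specializations.

#include <cstdio>
#include <stdexcept>
#include <type_traits>

// Hypothetical stand-ins for the cutlass::arch::Sm70 / Sm80 tags.
struct Sm70 { static constexpr int kMinComputeCapability = 70; };
struct Sm80 { static constexpr int kMinComputeCapability = 80; };

// Primary template: arch/stage combinations that were never instantiated fail loudly.
template <typename Arch, int Stages, typename Enable = void>
struct DispatchStagesSketch
{
    static void dispatch()
    {
        throw std::runtime_error("not instantiated for this arch with this stage count");
    }
};

// Two-stage (pipelined) kernels exist for every architecture.
template <typename Arch>
struct DispatchStagesSketch<Arch, 2>
{
    static void dispatch() { std::printf("2-stage kernel for sm%d+\n", Arch::kMinComputeCapability); }
};

// Multistage (more than two stages) kernels exist only on sm80 and newer.
template <int Stages>
struct DispatchStagesSketch<Sm80, Stages, typename std::enable_if<(Stages > 2)>::type>
{
    static void dispatch() { std::printf("%d-stage kernel for sm80+\n", Stages); }
};

// Runtime stage count -> compile-time specialization, mirroring dispatchGemmConfig above.
template <typename Arch>
void dispatchGemmConfigSketch(int stages)
{
    switch (stages)
    {
    case 2: DispatchStagesSketch<Arch, 2>::dispatch(); break;
    case 3: DispatchStagesSketch<Arch, 3>::dispatch(); break;
    case 4: DispatchStagesSketch<Arch, 4>::dispatch(); break;
    default: throw std::runtime_error("unsupported stage count");
    }
}

int main()
{
    dispatchGemmConfigSketch<Sm80>(4); // resolves to the multistage specialization
    dispatchGemmConfigSketch<Sm70>(2); // resolves to the 2-stage specialization
    return 0;
}

The SFINAE guard on the sm80 specialization is what keeps multistage kernels from ever being instantiated for Volta or Turing, while unsupported combinations still compile and report the problem at runtime.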
-template ::value && std::is_same::value>::type* = nullptr> -void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, cudaStream_t stream, - int* occupancy = nullptr) -{ - switch (gemm_config.tile_config) - { - case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; - case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: - TLLM_THROW("GEMM config should have already been set by heuristic."); - break; - default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break; - } -} - -// Tensorop GEMM overload -// Overload for quantize MoE GEMMs. 
We disable some warp configs here since they will not be used and we can improve -// compile time -template ::value && !std::is_same::value>::type* = nullptr> -void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, cudaStream_t stream, - int* occupancy = nullptr) -{ - switch (gemm_config.tile_config) - { - case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; - case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: - TLLM_THROW("GEMM config should have already been set by heuristic."); - break; - default: TLLM_THROW("Config is invalid for mixed type tensorop GEMM."); break; - } -} - -// This overload will handle simt gemms. It is disabled via SFINAE for tensorop. 
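The dispatchMoeGemmToCutlass overloads in this hunk (the two tensor-op ones above and the SIMT one that follows) select each other through a plain enable_if pattern. The sketch below is a standalone, hypothetical illustration of just that selection logic; real element types such as half are replaced by ordinary host types so it compiles without CUDA or CUTLASS.

#include <cstdint>
#include <cstdio>
#include <type_traits>

// Same-type tensor-op path: activations and weights share one non-float type.
template <typename T, typename WeightType,
    typename std::enable_if<!std::is_same<T, float>::value && std::is_same<T, WeightType>::value>::type* = nullptr>
void dispatchMoeGemmSketch()
{
    std::printf("same-type tensor-op path\n");
}

// Mixed-type tensor-op path: weights are a narrower (quantized) type than the activations.
template <typename T, typename WeightType,
    typename std::enable_if<!std::is_same<T, float>::value && !std::is_same<T, WeightType>::value>::type* = nullptr>
void dispatchMoeGemmSketch()
{
    std::printf("mixed-type (weight-only quantized) tensor-op path\n");
}

// SIMT path: float activations do not target tensor cores.
template <typename T, typename WeightType,
    typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
void dispatchMoeGemmSketch()
{
    std::printf("float SIMT path\n");
}

int main()
{
    dispatchMoeGemmSketch<float, float>();        // SIMT overload
    dispatchMoeGemmSketch<short, short>();        // same-type overload (stand-in for half/half)
    dispatchMoeGemmSketch<short, std::uint8_t>(); // mixed-type overload (stand-in for half/int8)
    return 0;
}

Whichever overload is selected then narrows gemm_config.tile_config down to a concrete threadblock and warp GemmShape pair, which is what the switch statements in this hunk do.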
-template ::value>::type* = nullptr> -void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, cudaStream_t stream, - int* occupancy = nullptr) -{ - switch (gemm_config.tile_config) - { - case cutlass_extensions::CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: - dispatchGemmConfig, - cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, - gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; - case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: - TLLM_THROW("GEMM config should have already been set by heuristic."); - break; - default: TLLM_THROW("Unsupported config for float MoE gemm."); break; - } -} - -template -std::vector MoeGemmRunner::getConfigs() -{ - static constexpr bool is_weight_only = !std::is_same::value; - static constexpr bool only_simt_configs = std::is_same::value; - std::vector candidate_configs - = kernels::cutlass_kernels::get_candidate_configs(sm_, is_weight_only, only_simt_configs); - return candidate_configs; -} - -template -MoeGemmRunner::MoeGemmRunner() -{ - int device{-1}; - tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device)); - sm_ = tensorrt_llm::common::getSMVersion(); - tensorrt_llm::common::check_cuda_error( - cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); -} - -template -template -void MoeGemmRunner::dispatchToArch(const T* A, const WeightType* B, const T* weight_scales, - const T* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy) -{ - if (sm_ >= 70 && sm_ < 75) - { - dispatchMoeGemmToCutlass(A, B, weight_scales, biases, C, - total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, - stream, occupancy); - } - else if (sm_ >= 75 && sm_ < 80) - { - dispatchMoeGemmToCutlass(A, B, weight_scales, biases, C, - total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, - stream, occupancy); - } - else if (sm_ >= 80 && sm_ < 90) - { - dispatchMoeGemmToCutlass(A, B, weight_scales, biases, C, - total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, - stream, occupancy); - } - else if (sm_ >= 90) - { - // TODO Update the arch to Sm90 once CUTLASS hopper specialisations are available - dispatchMoeGemmToCutlass(A, B, weight_scales, biases, C, - total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, - stream, occupancy); - } - else - { - TLLM_THROW("Arch unsupported for MoE GEMM"); - } -} - -template -template -void MoeGemmRunner::runGemm(const T* A, const WeightType* B, const T* weight_scales, - const T* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, cudaStream_t stream) -{ - auto chosen_conf = this->best_config_; - if (!chosen_conf) - { - auto candidate_configs = getConfigs(); - std::vector 
occupancies(candidate_configs.size()); - - for (size_t ii = 0; ii < candidate_configs.size(); ++ii) - { - dispatchToArch(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, - gemm_k, num_experts, candidate_configs[ii], stream, &occupancies[ii]); - } - - static constexpr int workspace_bytes = 0; // No workspace for MoE GEMMs. - static constexpr int split_k_limit = 1; // MoE GEMM does not support split-k. - - static constexpr bool is_weight_only = !std::is_same::value; - chosen_conf = kernels::cutlass_kernels::estimate_best_config_from_occupancies(candidate_configs, occupancies, - total_rows, gemm_n, gemm_k, num_experts, split_k_limit, workspace_bytes, multi_processor_count_, - is_weight_only); - } - assert(chosen_conf); - dispatchToArch(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, - num_experts, *chosen_conf, stream); -} - -template -void MoeGemmRunner::moeGemmBiasAct(const T* A, const WeightType* B, const T* weight_scales, - const T* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, ActivationType activation_type, cudaStream_t stream) -{ - switch (activation_type) - { - case ActivationType::Relu: - runGemm( - A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream); - break; - case ActivationType::Gelu: - runGemm( - A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream); - break; - case ActivationType::Silu: - runGemm( - A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream); - break; - case ActivationType::Identity: - runGemm( - A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream); - break; - case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break; - default: TLLM_THROW("Invalid activation type."); break; - } -} - -template -void MoeGemmRunner::moeGemm(const T* A, const WeightType* B, const T* weight_scales, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cudaStream_t stream) -{ - runGemm( - A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream); -} - -} // namespace tensorrt_llm From 77a5c8d63ddf348bac193934c405816046e802cd Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 2 Feb 2024 02:21:18 +0000 Subject: [PATCH 08/33] Remove unused CUTLASS kernels --- .../bf16_int4_gemm_fg_scalebias.cu | 31 -- .../bf16_int4_gemm_fg_scaleonly.cu | 31 -- .../fpA_intB_gemm/bf16_int4_gemm_per_col.cu | 31 -- .../bf16_int8_gemm_fg_scalebias.cu | 31 -- .../bf16_int8_gemm_fg_scaleonly.cu | 30 -- .../fpA_intB_gemm/bf16_int8_gemm_per_col.cu | 30 -- .../fp16_int4_gemm_fg_scalebias.cu | 29 -- .../fp16_int4_gemm_fg_scaleonly.cu | 28 - .../fpA_intB_gemm/fp16_int4_gemm_per_col.cu | 28 - .../fp16_int8_gemm_fg_scalebias.cu | 28 - .../fp16_int8_gemm_fg_scaleonly.cu | 28 - .../fpA_intB_gemm/fp16_int8_gemm_per_col.cu | 28 - .../fpA_intB_gemm/fpA_intB_gemm.h | 120 ----- .../fpA_intB_gemm/fpA_intB_gemm_template.h | 487 ------------------ csrc/cutlass_kernels/int8_gemm/int8_gemm.h | 93 ---- .../int8_gemm/int8_gemm_bf16.cu | 32 -- .../int8_gemm/int8_gemm_fp16.cu | 30 -- .../int8_gemm/int8_gemm_fp32.cu | 30 -- .../int8_gemm/int8_gemm_int32.cu | 30 -- .../int8_gemm/int8_gemm_template.h | 388 -------------- 20 files changed, 
1563 deletions(-) delete mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_fg_scalebias.cu delete mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_fg_scaleonly.cu delete mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_per_col.cu delete mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_fg_scalebias.cu delete mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_fg_scaleonly.cu delete mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_per_col.cu delete mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_fg_scalebias.cu delete mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_fg_scaleonly.cu delete mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_per_col.cu delete mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_fg_scalebias.cu delete mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_fg_scaleonly.cu delete mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_per_col.cu delete mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h delete mode 100644 csrc/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h delete mode 100644 csrc/cutlass_kernels/int8_gemm/int8_gemm.h delete mode 100644 csrc/cutlass_kernels/int8_gemm/int8_gemm_bf16.cu delete mode 100644 csrc/cutlass_kernels/int8_gemm/int8_gemm_fp16.cu delete mode 100644 csrc/cutlass_kernels/int8_gemm/int8_gemm_fp32.cu delete mode 100644 csrc/cutlass_kernels/int8_gemm/int8_gemm_int32.cu delete mode 100644 csrc/cutlass_kernels/int8_gemm/int8_gemm_template.h diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_fg_scalebias.cu b/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_fg_scalebias.cu deleted file mode 100644 index e4783fdefd16..000000000000 --- a/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_fg_scalebias.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ -#ifdef ENABLE_BF16 -template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t, - cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>; -#endif -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_fg_scaleonly.cu b/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_fg_scaleonly.cu deleted file mode 100644 index 8934a2c0df4e..000000000000 --- a/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_fg_scaleonly.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ -#ifdef ENABLE_BF16 -template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t, - cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>; -#endif -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_per_col.cu b/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_per_col.cu deleted file mode 100644 index b3fa996a87c9..000000000000 --- a/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int4_gemm_per_col.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ -#ifdef ENABLE_BF16 -template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t, - cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>; -#endif -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_fg_scalebias.cu b/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_fg_scalebias.cu deleted file mode 100644 index 064e4dbde97b..000000000000 --- a/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_fg_scalebias.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ -#ifdef ENABLE_BF16 -template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t, - cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>; -#endif -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_fg_scaleonly.cu b/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_fg_scaleonly.cu deleted file mode 100644 index 0dbdfabe0a69..000000000000 --- a/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_fg_scaleonly.cu +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ -#ifdef ENABLE_BF16 -template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>; -#endif -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_per_col.cu b/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_per_col.cu deleted file mode 100644 index 6701d0637ec4..000000000000 --- a/csrc/cutlass_kernels/fpA_intB_gemm/bf16_int8_gemm_per_col.cu +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ -#ifdef ENABLE_BF16 -template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t, cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>; -#endif -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_fg_scalebias.cu b/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_fg_scalebias.cu deleted file mode 100644 index 45e0f4c0f8d1..000000000000 --- a/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_fg_scalebias.cu +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ -template class CutlassFpAIntBGemmRunner; -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_fg_scaleonly.cu b/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_fg_scaleonly.cu deleted file mode 100644 index 113c6c61741d..000000000000 --- a/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_fg_scaleonly.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ -template class CutlassFpAIntBGemmRunner; -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_per_col.cu b/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_per_col.cu deleted file mode 100644 index 6e69985edc54..000000000000 --- a/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int4_gemm_per_col.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ -template class CutlassFpAIntBGemmRunner; -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_fg_scalebias.cu b/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_fg_scalebias.cu deleted file mode 100644 index 51e33974f76d..000000000000 --- a/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_fg_scalebias.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ -template class CutlassFpAIntBGemmRunner; -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_fg_scaleonly.cu b/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_fg_scaleonly.cu deleted file mode 100644 index 148cfb519e19..000000000000 --- a/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_fg_scaleonly.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ -template class CutlassFpAIntBGemmRunner; -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_per_col.cu b/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_per_col.cu deleted file mode 100644 index 35d199f58f14..000000000000 --- a/csrc/cutlass_kernels/fpA_intB_gemm/fp16_int8_gemm_per_col.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ -template class CutlassFpAIntBGemmRunner; -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h b/csrc/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h deleted file mode 100644 index c805f7a4e000..000000000000 --- a/csrc/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "cutlass_extensions/gemm_configs.h" -#include "cutlass_extensions/weight_only_quant_op.h" -#include - -namespace tkc = tensorrt_llm::cutlass_extensions; - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ - -// TRT Activation Type does not have Gelu or Silu -enum class ActivationType -{ - Gelu, - Relu, - Silu, - Identity, - InvalidType -}; - -/* - This runner only supports: - T in {half, __nv_bfloat} WeightType in {int8_t, cutlass::uint4b_t} - - Activations, biases, scales and outputs are all assumed to be row-major. - - However, it is assumed that B is in a special format governed by cutlass_extensions/gemm/kernel/mixed_gemm_B_layout. - In this case, B must be preprocessed using the cutlass weight only quant preprocessors. The weight preprocessor - will instantiate the layout and preprocess based on the instantiation, so layout changes should only require - modifications to mix_gemm_B_layout.h. -*/ - -class CutlassFpAIntBGemmRunnerInterface -{ -public: - CutlassFpAIntBGemmRunnerInterface() {} - - virtual ~CutlassFpAIntBGemmRunnerInterface() {} - - virtual void gemm(const void* A, const void* B, const void* weight_scales, void* C, int m, int n, int k, - tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) - = 0; - - virtual void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, - const void* biases, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig, - char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) - = 0; - - // Returns desired workspace size in bytes. 
- virtual size_t getWorkspaceSize(const int m, const int n, const int k) = 0; - - virtual std::vector getConfigs() const = 0; - -protected: - static constexpr int SPLIT_K_LIMIT = 7; - static constexpr int MIN_M_TILE = 32; - static constexpr int MIN_N_TILE = 64; -}; - -template -class CutlassFpAIntBGemmRunner : public virtual CutlassFpAIntBGemmRunnerInterface -{ -public: - CutlassFpAIntBGemmRunner(); - ~CutlassFpAIntBGemmRunner(); - - void gemm(const void* A, const void* B, const void* weight_scales, void* C, int m, int n, int k, - tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, - cudaStream_t stream) override; - - void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, - const void* biases, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig, - char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) override; - - // Disabled since the fused GEMM, activation kernels will not be used in v1. - - // void gemm_bias_act(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, int m, int n, - // int k, ActivationType activation_type, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t - // stream); - - // Returns desired workspace size in bytes. - size_t getWorkspaceSize(const int m, const int n, const int k) override; - - std::vector getConfigs() const override; - -private: - template - void dispatch_to_arch(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points, - const T* biases, T* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config, - char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr); - -private: - int sm_; - int multi_processor_count_; -}; - -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h b/csrc/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h deleted file mode 100644 index b816b111c96e..000000000000 --- a/csrc/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h +++ /dev/null @@ -1,487 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
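As a hedged illustration of how the runner interface above was typically driven before its removal: the sketch below is not part of the patch, assumes the deleted fpA_intB_gemm.h header is still on the include path, and infers (rather than quotes) the include path and the fp16/int8 per-column template arguments from the files this patch deletes.

#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
// Path as in the upstream TensorRT-LLM tree; adjust for this repository's csrc/ layout.
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"

void runFp16Int8PerColumnGemm(const half* A, const uint8_t* B, const half* weight_scales, half* C,
    int m, int n, int k, cudaStream_t stream)
{
    using namespace tensorrt_llm::kernels::cutlass_kernels;

    // fp16 activations, int8 weights, one dequantization scale per output column.
    CutlassFpAIntBGemmRunner<half, uint8_t, cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY> runner;

    // Workspace is only consumed when a split-k config is selected; size it for the worst case.
    const size_t workspace_bytes = runner.getWorkspaceSize(m, n, k);
    char* workspace = nullptr;
    cudaMalloc(reinterpret_cast<void**>(&workspace), workspace_bytes);

    // Real callers profile every candidate config and cache the fastest; taking the first is enough here.
    const auto configs = runner.getConfigs();
    if (!configs.empty())
    {
        runner.gemm(A, B, weight_scales, C, m, n, k, configs.front(), workspace, workspace_bytes, stream);
    }

    cudaFree(workspace);
}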
- */ - -#ifndef _WIN32 -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#endif // #ifndef _WIN32 - -#include "cutlass/gemm/kernel/default_gemm.h" -#include "cutlass_extensions/compute_occupancy.h" -#include "cutlass_extensions/gemm/device/gemm_universal_base_compat.h" - -#include "cutlass_extensions/epilogue_helpers.h" -#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" -#include "cutlass_extensions/gemm/kernel/fpA_intB_gemm.h" -#include "cutlass_extensions/gemm/threadblock/default_mma.h" -#include "cutlass_extensions/gemm_configs.h" - -#ifndef _WIN32 -#pragma GCC diagnostic pop -#endif // #ifndef _WIN32 - -#include "tensorrt_llm/common/assert.h" -#include "tensorrt_llm/common/cudaUtils.h" -#include "tensorrt_llm/common/logger.h" -#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" -#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h" - -namespace tk = tensorrt_llm::common; -namespace tkc = tensorrt_llm::cutlass_extensions; - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ - -template -void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T* weight_scales, - const T* weight_zero_points, const T* biases, T* C, int m, int n, int k, const int group_size, - tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, - int* occupancy = nullptr) -{ - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); - -#ifdef ENABLE_BF16 - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value - || cutlass::platform::is_same::value, - "Specialized for bfloat16, half, float"); -#else - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, - "Specialized for half, float"); -#endif - - static_assert(cutlass::platform::is_same::value - || cutlass::platform::is_same::value - || cutlass::platform::is_same::value, - ""); - - // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. - using ElementType_ = - typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; -#ifdef ENABLE_BF16 - using ElementType = - typename cutlass::platform::conditional::value, - cutlass::bfloat16_t, ElementType_>::type; -#else - using ElementType = ElementType_; -#endif - - using CutlassWeightType_ = - typename cutlass::platform::conditional::value, cutlass::half_t, - WeightType>::type; -#ifdef ENABLE_BF16 - using CutlassWeightType = - typename cutlass::platform::conditional::value, - cutlass::bfloat16_t, CutlassWeightType_>::type; -#else - using CutlassWeightType = CutlassWeightType_; -#endif - - // We need separate config for each architecture since we will target different tensorcore instructions. For float, - // we do not target TCs. 
- using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; - using ElementAccumulator = typename MixedGemmArchTraits::AccType; - - using EpilogueOp = typename tkc::Epilogue::Op; - - using Operator = typename MixedGemmArchTraits::Operator; - using TaggedOperator = typename cutlass::arch::TagOperator::TaggedOperator; - - using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm, Stages, true, - TaggedOperator>::GemmKernel; - - using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB; - - if (occupancy != nullptr) - { - *occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel(); - return; - } - - using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat; - - const int ldb = cutlass::platform::is_same::value - ? n - : k * GemmKernel::kInterleave; - - if (weight_scales == nullptr) - { - throw std::runtime_error("Weight scales must always be set to a non-null value."); - } - - if constexpr (cutlass::isFinegrained(QuantOp)) - { - if (group_size != 64 && group_size != 128) - { - throw std::runtime_error("Only group size 64 and 128 supported for fine grained kernels."); - } - - if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY) - { - if (weight_zero_points != nullptr) - { - throw std::runtime_error("Weight zero pointer must be a nullptr for scale only fine grained"); - } - } - else if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) - { - if (weight_zero_points == nullptr) - { - throw std::runtime_error("Weight zero pointer must be valid for scale and bias fine grained"); - } - } - } - else - { - if (group_size != k) - { - throw std::runtime_error("Invalid group size for per column scaling kernels."); - } - - if (weight_zero_points != nullptr) - { - throw std::runtime_error("Weight zero-points must be null when running per column scaling"); - } - } - - const int ld_scale_zero = cutlass::isFinegrained(QuantOp) ? n : 0; - ElementAccumulator output_op_beta = (biases == nullptr) ? ElementAccumulator(0.f) : ElementAccumulator(1.f); - typename Gemm::Arguments args({m, n, k}, group_size, {reinterpret_cast(const_cast(A)), k}, - {reinterpret_cast(const_cast(B)), ldb}, - {reinterpret_cast(const_cast(weight_scales)), ld_scale_zero}, - {reinterpret_cast(const_cast(weight_zero_points)), ld_scale_zero}, - {reinterpret_cast(const_cast(biases)), 0}, {reinterpret_cast(C), n}, - gemm_config.split_k_factor, {ElementAccumulator(1.f), output_op_beta}); - - // This assertion is enabled because because for the column interleaved layout, K MUST be a multiple of - // threadblockK. The reason for this is that the default pitchlinear iterators are used to handle walking over the - // interleaved matrix. The way masking in handled in these do not map to the interleaved layout. We need to write - // our own predicated iterator in order to relax this limitation. - if (GemmKernel::kInterleave > 1 - && ((k % MixedGemmArchTraits::ThreadblockK) - || ((k / gemm_config.split_k_factor) % MixedGemmArchTraits::ThreadblockK))) - { - throw std::runtime_error("Temp assertion: k must be multiple of threadblockK"); - } - - Gemm gemm; - if (gemm.get_workspace_size(args) > workspace_bytes) - { - TLLM_LOG_WARNING( - "Requested split-k but workspace size insufficient. Falling back to non-split-k implementation."); - // If requested split-k factor will require more workspace bytes, revert to standard gemm. 
- args.batch_count = 1; - } - - auto can_implement = gemm.can_implement(args); - if (can_implement != cutlass::Status::kSuccess) - { - std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: " - + std::string(cutlassGetStatusString(can_implement)); - throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] " + err_msg); - } - - auto init_status = gemm.initialize(args, workspace, stream); - if (init_status != cutlass::Status::kSuccess) - { - std::string err_msg - = "Failed to initialize cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(init_status)); - throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] " + err_msg); - } - - auto run_status = gemm.run(stream); - if (run_status != cutlass::Status::kSuccess) - { - std::string err_msg - = "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status)); - throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] " + err_msg); - } -} - -// This filters out invalid template combinations that we DON'T want instantiated in CUTLASS. For example, -// instantiating SM=75, Stages=3 is invalid so we would need to filter that out. Fine grained -// quanitzation is only supported on Ampere+ GPUs. -template -void filter_and_run_mixed_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points, - const T* biases, T* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config, - char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) -{ - - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); - if constexpr (cutlass::isFinegrained(QuantOp) && arch::kMinComputeCapability < 80) - { - // Finegrained only supported on Ampere - std::string err_msg = "Cutlass fpA_intB gemm not implemented for arch " - + std::to_string(arch::kMinComputeCapability) + " with finegraind weight-only quantization."; - throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg); - } - else if constexpr (Stages > 2 && arch::kMinComputeCapability < 80) - { - // Multistage only supported on Ampere - std::string err_msg = "Cutlass fpA_intB gemm not supported for arch " - + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); - throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg); - } - else - { - generic_mixed_gemm_kernelLauncher(A, B, weight_scales, weight_zero_points, biases, C, m, n, k, group_size, gemm_config, workspace, - workspace_bytes, stream, occupancy); - } -} - -template -void dispatch_gemm_config(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points, - const T* biases, T* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config, - char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) -{ - - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); - switch (gemm_config.stages) - { - case 2: - filter_and_run_mixed_gemm(A, B, - weight_scales, weight_zero_points, biases, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, - stream, occupancy); - break; - case 3: - filter_and_run_mixed_gemm(A, B, - weight_scales, weight_zero_points, biases, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, - stream, occupancy); - break; - case 4: - filter_and_run_mixed_gemm(A, B, - weight_scales, weight_zero_points, biases, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, - stream, occupancy); - break; - default: - std::string 
err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages); - throw std::runtime_error("[TensorRT-LLm Error][dispatch_gemm_config] " + err_msg); - break; - } -} - -template -void dispatch_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points, - const T* biases, T* C, int m, int n, int k, const int group_size, char* workspace, size_t workspace_bytes, - tkc::CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy = nullptr) -{ - - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); - - // Note that SIMT configs are omitted here since they are not supported for fpA_intB. - // We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best - // for mixed type gemms. - switch (gemm_config.tile_config) - { - case tkc::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, weight_zero_points, biases, C, m, n, k, - group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); - break; - case tkc::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, weight_zero_points, biases, C, m, n, k, - group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); - break; - case tkc::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: - if (arch::kMinComputeCapability < 75) - { - TLLM_CHECK_WITH_INFO(false, "Invalid config on Volta"); - } - else - { - dispatch_gemm_config, - cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, weight_zero_points, biases, C, m, n, k, - group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); - } - break; - case tkc::CutlassTileConfig::Undefined: - throw std::runtime_error("[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config undefined."); - break; - case tkc::CutlassTileConfig::ChooseWithHeuristic: - throw std::runtime_error( - "[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config should have already been set by " - "heuristic."); - break; - default: - throw std::runtime_error( - "[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM."); - break; - } -} - -template -CutlassFpAIntBGemmRunner::CutlassFpAIntBGemmRunner() -{ - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); - int device{-1}; - tk::check_cuda_error(cudaGetDevice(&device)); - sm_ = tk::getSMVersion(); - tk::check_cuda_error(cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); -} - -template -CutlassFpAIntBGemmRunner::~CutlassFpAIntBGemmRunner() -{ - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); -} - -template -template -void CutlassFpAIntBGemmRunner::dispatch_to_arch(const T* A, const WeightType* B, - const T* weight_scales, const T* weight_zero_points, const T* biases, T* C, int m, int n, int k, - const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace_ptr, const size_t workspace_bytes, - cudaStream_t stream, int* occupancy) -{ - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); - if (sm_ >= 70 && sm_ < 75) - { - dispatch_gemm_to_cutlass(A, B, weight_scales, - weight_zero_points, biases, C, m, n, k, group_size, workspace_ptr, workspace_bytes, gemm_config, stream, - occupancy); - } - else if (sm_ >= 75 && sm_ < 80) - { - dispatch_gemm_to_cutlass(A, B, weight_scales, - weight_zero_points, biases, C, m, n, k, group_size, workspace_ptr, workspace_bytes, gemm_config, stream, - 
occupancy); - } - else if (sm_ >= 80 && sm_ <= 90) - { - dispatch_gemm_to_cutlass(A, B, weight_scales, - weight_zero_points, biases, C, m, n, k, group_size, workspace_ptr, workspace_bytes, gemm_config, stream, - occupancy); - } - else - { - throw std::runtime_error( - "[TensorRT-LLm Error][CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS mixed type " - "GEMM"); - } -} - -// Disabled since the fused GEMM, activation kernels will not be used in v1. - -// template -// void CutlassFpAIntBGemmRunner::gemm_bias_act(const T* A, const WeightType* B, const T* -// weight_scales, -// const T* biases, T* C, int m, int n, int k, ActivationType activation_type, char* workspace_ptr, -// const size_t workspace_bytes, cudaStream_t stream) -// { -// TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); - -// switch (activation_type) -// { -// case ActivationType::Relu: -// run_gemm( -// A, B, weight_scales, biases, C, m, n, k, workspace_ptr, workspace_bytes, stream); -// break; -// case ActivationType::Gelu: -// run_gemm( -// A, B, weight_scales, biases, C, m, n, k, workspace_ptr, workspace_bytes, stream); -// break; -// case ActivationType::Silu: -// run_gemm( -// A, B, weight_scales, biases, C, m, n, k, workspace_ptr, workspace_bytes, stream); -// break; -// case ActivationType::Identity: -// run_gemm(A, B, weight_scales, biases, C, m, n, k, workspace_ptr, workspace_bytes, -// stream); break; -// case ActivationType::InvalidType: TLLM_CHECK_WITH_INFO(false, "Activation type for fpA_intB must be -// valid."); break; default: -// { -// TLLM_CHECK_WITH_INFO(false, "Invalid activation type."); -// } -// } -// } - -template -void CutlassFpAIntBGemmRunner::gemm(const void* A, const void* B, const void* weight_scales, - const void* weight_zero_points, const void* biases, void* C, int m, int n, int k, const int group_size, - tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) -{ - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); - if constexpr ((QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) - || (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY)) - { - dispatch_to_arch((const T*) A, (const WeightType*) B, (const T*) weight_scales, - (const T*) weight_zero_points, (const T*) biases, (T*) C, m, n, k, group_size, gemmConfig, workspace_ptr, - workspace_bytes, stream, nullptr); - } - else - { - throw std::runtime_error( - "Overload with scale, zero and group size only supported for fine grained bias template."); - } -} - -template -void CutlassFpAIntBGemmRunner::gemm(const void* A, const void* B, const void* weight_scales, - void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, - cudaStream_t stream) -{ - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); - - if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY) - { - dispatch_to_arch((const T*) A, (const WeightType*) B, (const T*) weight_scales, nullptr, - nullptr, (T*) C, m, n, k, k, gemmConfig, workspace_ptr, workspace_bytes, stream, nullptr); - } - else - { - throw std::runtime_error("Overload with scale only (and no group size) only supported for per column scaling."); - } -} - -template -std::vector CutlassFpAIntBGemmRunner::getConfigs() const -{ - static constexpr bool is_weight_only = !std::is_same::value; - std::vector candidateConfigs - = get_candidate_configs(sm_, is_weight_only, false, false, SPLIT_K_LIMIT); - return candidateConfigs; -} - -template -size_t CutlassFpAIntBGemmRunner::getWorkspaceSize(const 
int m, const int n, const int k) -{ - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); - // These are the min tile sizes for each config, which would launch the maximum number of blocks - const int max_grid_m = cutlass::ceil_div(m, MIN_M_TILE); - const int max_grid_n = cutlass::ceil_div(n, MIN_N_TILE); - // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim. - return static_cast(max_grid_m * max_grid_n * SPLIT_K_LIMIT * 4); -} - -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/int8_gemm/int8_gemm.h b/csrc/cutlass_kernels/int8_gemm/int8_gemm.h deleted file mode 100644 index f06ba4b4d85a..000000000000 --- a/csrc/cutlass_kernels/int8_gemm/int8_gemm.h +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "cutlass_extensions/gemm_configs.h" -#include "tensorrt_llm/common/quantization.h" -#include - -namespace tk = tensorrt_llm::common; -namespace tkc = tensorrt_llm::cutlass_extensions; - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ - -/* - This runner supports: - int8_t inputs (A and B) - float alpha scalings (either per-col, or per-col x per-row) - T output (D) where T = {float, half, __nv_bfloat16} // TODO - - Activations, biases, scales and outputs are all assumed to be row-major. - Weights are assumed to be column-major. -*/ - -class CutlassInt8GemmRunnerInterface -{ -public: - CutlassInt8GemmRunnerInterface() {} - - virtual ~CutlassInt8GemmRunnerInterface() {} - - virtual void gemm(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, - const float* alphaRow, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr, - const size_t workspaceBytes, cudaStream_t stream) - = 0; - - // Returns desired workspace size in bytes. - virtual size_t getWorkspaceSize(const int m, const int n, const int k) = 0; - - virtual std::vector getConfigs() const = 0; - -protected: - static constexpr int SPLIT_K_LIMIT = 7; - static constexpr int MIN_M_TILE = 32; - static constexpr int MIN_N_TILE = 64; -}; - -template -class CutlassInt8GemmRunner : public virtual CutlassInt8GemmRunnerInterface -{ -public: - CutlassInt8GemmRunner(); - ~CutlassInt8GemmRunner(); - - void gemm(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, const float* alphaRow, - void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr, - const size_t workspaceBytes, cudaStream_t stream) override; - - // Returns desired workspace size in bytes. 
- size_t getWorkspaceSize(const int m, const int n, const int k) override; - - std::vector getConfigs() const override; - -private: - void dispatchToArch(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, - const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr, - const size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr); - - int mSm; - int mMultiProcessorCount; -}; - -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/int8_gemm/int8_gemm_bf16.cu b/csrc/cutlass_kernels/int8_gemm/int8_gemm_bf16.cu deleted file mode 100644 index a3633bc0992a..000000000000 --- a/csrc/cutlass_kernels/int8_gemm/int8_gemm_bf16.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ - -#ifdef ENABLE_BF16 -template class CutlassInt8GemmRunner<__nv_bfloat16>; -#endif - -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/int8_gemm/int8_gemm_fp16.cu b/csrc/cutlass_kernels/int8_gemm/int8_gemm_fp16.cu deleted file mode 100644 index 7189956d5d03..000000000000 --- a/csrc/cutlass_kernels/int8_gemm/int8_gemm_fp16.cu +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ - -template class CutlassInt8GemmRunner; - -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/int8_gemm/int8_gemm_fp32.cu b/csrc/cutlass_kernels/int8_gemm/int8_gemm_fp32.cu deleted file mode 100644 index 861a2d4ff0f3..000000000000 --- a/csrc/cutlass_kernels/int8_gemm/int8_gemm_fp32.cu +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ - -template class CutlassInt8GemmRunner; // for compilation only - -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/int8_gemm/int8_gemm_int32.cu b/csrc/cutlass_kernels/int8_gemm/int8_gemm_int32.cu deleted file mode 100644 index 6814b00e0286..000000000000 --- a/csrc/cutlass_kernels/int8_gemm/int8_gemm_int32.cu +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ - -template class CutlassInt8GemmRunner; - -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/int8_gemm/int8_gemm_template.h b/csrc/cutlass_kernels/int8_gemm/int8_gemm_template.h deleted file mode 100644 index 55bdc98df251..000000000000 --- a/csrc/cutlass_kernels/int8_gemm/int8_gemm_template.h +++ /dev/null @@ -1,388 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef _WIN32 -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#endif // #ifndef _WIN32 - -// clang-format off -#include -#include -#include -#include -#include -// clang-format on - -#include "cutlass_extensions/compute_occupancy.h" -#include "cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h" -#include "cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h" -#include "cutlass_extensions/epilogue_helpers.h" -#include "cutlass_extensions/gemm_configs.h" - -#include "cutlass_extensions/gemm/kernel/default_int8_traits.h" -#include "cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h" - -#ifndef _WIN32 -#pragma GCC diagnostic pop -#endif // #ifndef _WIN32 - -#include "tensorrt_llm/common/allocator.h" -#include "tensorrt_llm/common/cudaUtils.h" -#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" -#include "tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm.h" -#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h" - -#include -#include - -namespace tk = tensorrt_llm::common; -namespace tkc = tensorrt_llm::cutlass_extensions; - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ - -template -void genericInt8GemmKernelLauncher(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, - const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, - size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) -{ - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); - - using ElementInput = int8_t; - - using ElementOutput_ = - typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; -#ifdef ENABLE_BF16 - using ElementOutput = - typename cutlass::platform::conditional::value, - cutlass::bfloat16_t, ElementOutput_>::type; -#else - using ElementOutput = ElementOutput_; -#endif - - using ElementAccumulator = int32_t; - using ElementCompute = float; - - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; - - using OperatorClass = typename cutlass::gemm::kernel::Int8GemmArchTraits::OperatorClass; - using InstructionShape = typename cutlass::gemm::kernel::Int8GemmArchTraits::InstructionShape; - - using DefaultGemmConf = typename cutlass::gemm::device::DefaultGemmConfiguration; - using GemmOp = typename DefaultGemmConf::Operator; - using EpilogueOp = typename DefaultGemmConf::EpilogueOutputOp; - - // only TN is supported (s8 * s8 + s32) - using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm::GemmKernel; - - using AlphaColTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< - cutlass::epilogue::threadblock::OutputTileOptimalThreadMap< - typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Shape, - typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Count, - GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::kThreads, - GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess, cutlass::sizeof_bits::value>, - ElementCompute>; - - // Epilogue visitor - using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerCol; - - /// Epilogue - using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue::Epilogue; - - // GEMM - using GemmKernel - = cutlass::gemm::kernel::GemmWithEpilogueVisitor; - - if (occupancy != nullptr) - { - *occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel(); - return; - } - - using 
Gemm = cutlass::gemm::device::GemmUniversalBaseCompat; - - typename EpilogueOp::Params linearScalingParams; // TODO: right now it's unused (scaling is done in - // visitor, no activation needed) - typename Gemm::Arguments args{cutlass::gemm::GemmUniversalMode::kBatched, {m, n, k}, 1, - {reinterpret_cast(const_cast(A)), k}, - {reinterpret_cast(const_cast(B)), k}, quantOption, - {reinterpret_cast(const_cast(alphaCol)), 0}, - {reinterpret_cast(const_cast(alphaRow)), 0}, {nullptr, 0}, - {reinterpret_cast(C), n}, 0, 0, - typename EpilogueVisitor::Arguments(linearScalingParams, 0, 0, 0)}; - - Gemm gemm; - // TODO: handle that - if (gemm.get_workspace_size(args) > workspaceBytes) - { - TLLM_LOG_WARNING( - "Requested split-k but workspace size insufficient. Falling back to non-split-k implementation."); - // If requested split-k factor will require more workspace bytes, revert to standard gemm. - args.batch_count = 1; - } - - auto can_implement = gemm.can_implement(args); - if (can_implement != cutlass::Status::kSuccess) - { - std::string errMsg = "int8gemm cutlass kernel will fail for params. Error: " - + std::string(cutlassGetStatusString(can_implement)); - throw std::runtime_error("[TensorRT-LLM Error][int8gemm Runner] " + errMsg); - } - - auto initStatus = gemm.initialize(args, workspace, stream); - if (initStatus != cutlass::Status::kSuccess) - { - std::string errMsg - = "Failed to initialize cutlass int8 gemm. Error: " + std::string(cutlassGetStatusString(initStatus)); - throw std::runtime_error("[TensorRT-LLM Error][int8gemm Runner] " + errMsg); - } - - auto runStatus = gemm.run(stream); - if (runStatus != cutlass::Status::kSuccess) - { - std::string errMsg - = "Failed to run cutlass int8 gemm. Error: " + std::string(cutlassGetStatusString(runStatus)); - throw std::runtime_error("[TensorRT-LLM Error][int8gemm Runner] " + errMsg); - } -} - -template -struct dispatchStages -{ - static void dispatch(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, - const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, - size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) - { - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); - std::string errMsg = "Cutlass int8 gemm. 
Not instantiates for arch " - + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); - throw std::runtime_error("[TensorRT-LLM Error][dispatchStages::dispatch] " + errMsg); - } -}; - -template -struct dispatchStages -{ - static void dispatch(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, - const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, - size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) - { - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); - genericInt8GemmKernelLauncher(A, B, quantOption, alphaCol, alphaRow, C, - m, n, k, gemmConfig, workspace, workspaceBytes, stream, occupancy); - } -}; - -template -struct dispatchStages 2)>::type> -{ - static void dispatch(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, - const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, - size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) - { - - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); - genericInt8GemmKernelLauncher(A, B, quantOption, - alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, workspaceBytes, stream, occupancy); - } -}; - -template -void dispatchGemmConfig(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, - const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, - size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) -{ - - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); - switch (gemmConfig.stages) - { - case 2: - using DispatcherStages2 = dispatchStages; - DispatcherStages2::dispatch(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, - workspaceBytes, stream, occupancy); - break; - case 3: - using DispatcherStages3 = dispatchStages; - DispatcherStages3::dispatch(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, - workspaceBytes, stream, occupancy); - break; - case 4: - using DispatcherStages4 = dispatchStages; - DispatcherStages4::dispatch(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, - workspaceBytes, stream, occupancy); - break; - case 5: - using DispatcherStages5 = dispatchStages; - DispatcherStages5::dispatch(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, - workspaceBytes, stream, occupancy); - break; - case 6: - using DispatcherStages6 = dispatchStages; - DispatcherStages6::dispatch(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, - workspaceBytes, stream, occupancy); - break; - default: - std::string errMsg = "dispatchGemmConfig does not support stages " + std::to_string(gemmConfig.stages); - throw std::runtime_error("[TensorRT-LLM Error][dispatch_gemm_config] " + errMsg); - break; - } -} - -template -void dispatchGemmToCutlass(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, - const float* alphaRow, T* C, int m, int n, int k, char* workspace, size_t workspaceBytes, - tkc::CutlassGemmConfig gemmConfig, cudaStream_t stream, int* occupancy = nullptr) -{ - - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); - - switch (gemmConfig.tile_config) - { - case tkc::CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: - dispatchGemmConfig, cutlass::gemm::GemmShape<64, 32, 64>>(A, B, - quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, workspaceBytes, stream, occupancy); - break; - case 
tkc::CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: - dispatchGemmConfig, cutlass::gemm::GemmShape<64, 64, 64>>(A, B, - quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, workspaceBytes, stream, occupancy); - break; - case tkc::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatchGemmConfig, cutlass::gemm::GemmShape<32, 32, 64>>(A, B, - quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, workspaceBytes, stream, occupancy); - break; - case tkc::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: - dispatchGemmConfig, cutlass::gemm::GemmShape<64, 32, 64>>(A, B, - quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, workspaceBytes, stream, occupancy); - break; - case tkc::CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: - dispatchGemmConfig, cutlass::gemm::GemmShape<32, 64, 64>>(A, B, - quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, workspaceBytes, stream, occupancy); - break; - case tkc::CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: - dispatchGemmConfig, cutlass::gemm::GemmShape<64, 64, 64>>(A, B, - quantOption, alphaCol, alphaRow, C, m, n, k, gemmConfig, workspace, workspaceBytes, stream, occupancy); - break; - case tkc::CutlassTileConfig::Undefined: - throw std::runtime_error("[TensorRT-LLM Error][int8][dispatch_gemm_to_cutlass] gemm config undefined."); - break; - case tkc::CutlassTileConfig::ChooseWithHeuristic: - throw std::runtime_error( - "[TensorRT-LLM Error][int8][dispatch_gemm_to_cutlass] gemm config should have already been set by " - "heuristic."); - break; - default: - throw std::runtime_error( - "[TensorRT-LLM Error][int8][dispatch_gemm_to_cutlass] Config is invalid for int8 GEMM."); - break; - } -} - -template -CutlassInt8GemmRunner::CutlassInt8GemmRunner() -{ - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); - int device{-1}; - tk::check_cuda_error(cudaGetDevice(&device)); - mSm = tk::getSMVersion(); - tk::check_cuda_error(cudaDeviceGetAttribute(&mMultiProcessorCount, cudaDevAttrMultiProcessorCount, device)); -} - -template -CutlassInt8GemmRunner::~CutlassInt8GemmRunner() -{ - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); -} - -template -void CutlassInt8GemmRunner::dispatchToArch(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, - const float* alphaCol, const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, - char* workspacePtr, const size_t workspaceBytes, cudaStream_t stream, int* occupancy) -{ - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); - if (mSm >= 70 && mSm < 72) - { - dispatchGemmToCutlass(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, workspacePtr, - workspaceBytes, gemmConfig, stream, occupancy); - } - else if (mSm >= 72 && mSm < 75) - { - dispatchGemmToCutlass(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, workspacePtr, - workspaceBytes, gemmConfig, stream, occupancy); - } - else if (mSm >= 75 && mSm < 80) - { - dispatchGemmToCutlass(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, workspacePtr, - workspaceBytes, gemmConfig, stream, occupancy); - } - else if (mSm >= 80 && mSm <= 90) - { - dispatchGemmToCutlass(A, B, quantOption, alphaCol, alphaRow, C, m, n, k, workspacePtr, - workspaceBytes, gemmConfig, stream, occupancy); - } - else - { - throw std::runtime_error( - "[TensorRT-LLM Error][CutlassInt8GemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS int8 GEMM"); - } -} - -template -void CutlassInt8GemmRunner::gemm(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, - const float* alphaRow, 
void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr, - const size_t workspaceBytes, cudaStream_t stream) -{ - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); - dispatchToArch(A, B, quantOption, alphaCol, alphaRow, reinterpret_cast(C), m, n, k, gemmConfig, workspacePtr, - workspaceBytes, stream); -} - -template -std::vector CutlassInt8GemmRunner::getConfigs() const -{ - static constexpr bool isWeightOnly = false; - std::vector candidateConfigs - = get_candidate_configs(mSm, isWeightOnly, mSm <= 70, /* SIMT configs */ - true, SPLIT_K_LIMIT); /* INT8 configs */ - return candidateConfigs; -} - -template -size_t CutlassInt8GemmRunner::getWorkspaceSize(const int m, const int n, const int k) -{ - TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); - // These are the min tile sizes for each config, which would launch the maximum number of blocks - const int maxGridM = cutlass::ceil_div(m, MIN_M_TILE); - const int maxGridN = cutlass::ceil_div(m, MIN_N_TILE); - // We need 4 bytes per block in the worst case. We launch SPLIT_K_LIMIT in z dim. - return static_cast(maxGridM * maxGridN * SPLIT_K_LIMIT * 4); -} - -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm From f1583defce858174b93b855ef8865cb3844c4c07 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 2 Feb 2024 06:11:32 +0000 Subject: [PATCH 09/33] Minor --- csrc/pybind.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 8a8235691ab8..b36d25969716 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -48,8 +48,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &rotary_embedding, "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); -#ifndef USE_ROCM // Quantization ops +#ifndef USE_ROCM ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); #endif From de7a74969ed8c49605c924297c2f5c1154b6b56b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 2 Feb 2024 06:13:26 +0000 Subject: [PATCH 10/33] Add topk_softmax kernels --- MANIFEST.in | 1 + csrc/moe/moe_ops.cc | 7 + csrc/moe/moe_ops.h | 9 + csrc/moe/topk_softmax_kernels.cu | 494 +++++++++++++++++++++++++++++++ setup.py | 12 + 5 files changed, 523 insertions(+) create mode 100644 csrc/moe/moe_ops.cc create mode 100644 csrc/moe/moe_ops.h create mode 100644 csrc/moe/topk_softmax_kernels.cu diff --git a/MANIFEST.in b/MANIFEST.in index 0c897cf147f1..5e218f8a30a2 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,3 +2,4 @@ include LICENSE include requirements.txt recursive-include csrc * +recursive-include third_party * diff --git a/csrc/moe/moe_ops.cc b/csrc/moe/moe_ops.cc new file mode 100644 index 000000000000..3acc67e57761 --- /dev/null +++ b/csrc/moe/moe_ops.cc @@ -0,0 +1,7 @@ +#include "moe_ops.h" + +#include + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("topk_softmax", &topk_softmax, "Top-k softmax for MoE"); +} diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h new file mode 100644 index 000000000000..a01be3e426d7 --- /dev/null +++ b/csrc/moe/moe_ops.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +void topk_softmax( + torch::Tensor& topk_weights, + torch::Tensor& topk_indices, + torch::Tensor& token_expert_indices, + torch::Tensor& gating_output); diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu new file mode 100644 index 000000000000..ee4aee28408e --- /dev/null +++ b/csrc/moe/topk_softmax_kernels.cu @@ -0,0 +1,494 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 
NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include +#include +#include +#include + +#include "cutlass/array.h" + +#include +#include + +// FIXME(woosuk) +#ifndef ENABLE_BF16 +#define ENABLE_BF16 +#endif + +namespace vllm { + +static constexpr int WARP_SIZE = 32; + +// ====================== Softmax things =============================== +// We have our own implementation of softmax here so we can support transposing the output +// in the softmax kernel when we extend this module to support expert-choice routing. +template +__launch_bounds__(TPB) __global__ + void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols) +{ + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + __shared__ float normalizing_factor; + __shared__ float float_max; + + const int thread_row_offset = blockIdx.x * num_cols; + + cub::Sum sum; + float threadData(-FLT_MAX); + + // Don't touch finished rows. + if ((finished != nullptr) && finished[blockIdx.x]) + { + return; + } + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) + { + const int idx = thread_row_offset + ii; + threadData = max(input[idx], threadData); + } + + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + if (threadIdx.x == 0) + { + float_max = maxElem; + } + __syncthreads(); + + threadData = 0; + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) + { + const int idx = thread_row_offset + ii; + threadData += exp((static_cast(input[idx]) - float_max)); + } + + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); + + if (threadIdx.x == 0) + { + normalizing_factor = 1.f / Z; + } + __syncthreads(); + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) + { + const int idx = thread_row_offset + ii; + const float val = exp((static_cast(input[idx]) - float_max)) * normalizing_factor; + output[idx] = val; + } +} + +template +__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output, + int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert) +{ + + using cub_kvp = cub::KeyValuePair; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + const int num_rows = gridDim.x; + const int block_row = blockIdx.x; + + const bool row_is_active = finished ? 
!finished[block_row] : true; + const int thread_read_offset = blockIdx.x * num_experts; + for (int k_idx = 0; k_idx < k; ++k_idx) + { + thread_kvp.key = 0; + thread_kvp.value = -1.f; // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) + { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs_after_softmax[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) + { + const int prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) + { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) + { + // Ignore experts the node isn't responsible for with expert parallelism + const int expert = result_kvp.key; + const bool node_uses_expert = expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + const int idx = k * block_row + k_idx; + output[idx] = result_kvp.value; + indices[idx] = should_process_row ? (expert - start_expert) : num_experts; + assert(indices[idx] >= 0); + source_rows[idx] = k_idx * num_rows + block_row; + } + __syncthreads(); + } +} + +// ====================== TopK softmax things =============================== + +/* + A Top-K gating softmax written to exploit when the number of experts in the MoE layers + are a small power of 2. This allows us to cleanly share the rows among the threads in + a single warp and eliminate communication between warps (so no need to use shared mem). + + It fuses the softmax, max and argmax into a single kernel. + + Limitations: + 1) This implementation is intended for when the number of experts is a small power of 2. + 2) This implementation assumes k is small, but will work for any k. +*/ + +template +__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ + void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices, + int* source_rows, const int k, const int start_expert, const int end_expert) +{ + // We begin by enforcing compile time assertions and setting up compile time constants. + static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); + static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2"); + static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); + static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); + + // Number of bytes each thread pulls in per load + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static constexpr int ELTS_PER_ROW = NUM_EXPERTS; + static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; + static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; + + // Restrictions based on previous section. + static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); + static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); + static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2"); + static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size"); + + // We have NUM_EXPERTS elements per row. 
We specialize for small #experts + static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT; + static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW; + static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; + + // Restrictions for previous section. + static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp"); + + // ===================== From this point, we finally start computing run-time variables. ======================== + + // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps. + // Thus, each block processes a chunk of rows. We start by computing the start row for each block. + const int cta_base_row = blockIdx.x * ROWS_PER_CTA; + + // Now, using the base row per thread block, we compute the base row per warp. + const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP; + + // The threads in a warp are split into sub-groups that will work on a row. + // We compute the row offset for each thread sub-group. + const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW; + const int thread_row = warp_base_row + thread_row_in_warp; + + // Threads with indices out of bounds should early exit here. + if (thread_row >= num_rows) + { + return; + } + const bool row_is_active = finished ? !finished[thread_row] : true; + + // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the + // row it will read. + const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW; + + // Now, we compute the group each thread belongs to in order to determine the first column to start loads. + const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; + const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; + const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; + + // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory, + // this can support all powers of 2 up to 16. + using AccessType = cutlass::AlignedArray<float, ELTS_PER_LDG>; + + // Finally, we pull in the data from global mem + cutlass::Array<float, VPT> row_chunk; + AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk); + const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr); +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) + { + row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + } + + // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just + // convert to float afterwards for the exp + sum reduction. + float thread_max = row_chunk[0]; +#pragma unroll + for (int ii = 1; ii < VPT; ++ii) + { + thread_max = max(thread_max, row_chunk[ii]); + } + +// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { + thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW)); + } + + // From this point, thread_max in all the threads holds the max within the row. + // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum. + float row_sum = 0; +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) + { + row_chunk[ii] = expf(row_chunk[ii] - thread_max); + row_sum += row_chunk[ii]; + } + +// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a butterfly pattern.
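+// [Editor's illustration, not part of the original kernel] A concrete trace of the XOR butterfly used here,
+// for a sub-group of THREADS_PER_ROW = 4 lanes (e.g. the 16-expert instantiation) with partial sums {1, 2, 3, 4}:
+//   mask = 2: each lane exchanges with lane^2 -> sums become {1+3, 2+4, 3+1, 4+2} = {4, 6, 4, 6}
+//   mask = 1: each lane exchanges with lane^1 -> sums become {4+6, 6+4, 4+6, 6+4} = {10, 10, 10, 10}
+// After log2(THREADS_PER_ROW) steps every lane in the sub-group holds the full row sum, so no separate
+// broadcast is needed afterwards.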
+#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { + row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); + } + + // From this point, all threads have the max and the sum for their rows in the thread_max and row_sum variables + // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to + // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row. + // However, this kernel will likely not be a bottleneck and it seems better to more closely match torch and find the + // argmax after computing the softmax. + const float reciprocal_row_sum = 1.f / row_sum; + +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) + { + row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum; + } + + // Now, row_chunk contains the softmax of the row chunk. Next, we want to find the topk elements in each row, along + // with the max index. + int start_col = first_elt_read_by_thread; + static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + + for (int k_idx = 0; k_idx < k; ++k_idx) + { + // First, each thread does the local argmax + float max_val = row_chunk[0]; + int expert = start_col; +#pragma unroll + for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) + { +#pragma unroll + for (int ii = 0; ii < ELTS_PER_LDG; ++ii) + { + float val = row_chunk[ldg * ELTS_PER_LDG + ii]; + + // No check on the experts here since columns with the smallest index are processed first and only + // updated if > (not >=) + if (val > max_val) + { + max_val = val; + expert = col + ii; + } + } + } + +// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max. +// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can +// then blank out its max with -inf and the warp can run more iterations... +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { + float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); + int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); + + // We want lower indices to "win" in every thread so we break ties this way + if (other_max > max_val || (other_max == max_val && other_expert < expert)) + { + max_val = other_max; + expert = other_expert; + } + } + + // Write the max for this k iteration to global memory. + if (thread_group_idx == 0) + { + // Add a guard to ignore experts not included by this node + const bool node_uses_expert = expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + // The lead thread from each sub-group will write out the final results to global memory. (This will be a + // single thread per row of the input/output matrices.) + const int idx = k * thread_row + k_idx; + output[idx] = max_val; + indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS; + source_rows[idx] = k_idx * num_rows + thread_row; + } + + // Finally, we clear the value in the thread with the current max if there is another iteration to run. + if (k_idx + 1 < k) + { + const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; + const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW; + + // Only the thread in the group which produced the max will reset the "winning" value to -inf.
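+ // [Editor's illustration, not part of the original kernel] With NUM_EXPERTS = 16 the constants below work
+ // out to ELTS_PER_LDG = 4 and THREADS_PER_ROW = 4, so COLS_PER_GROUP_LDG = 16. If expert 9 won this
+ // iteration, then ldg_group_for_expert = 9 / 16 = 0 and thread_to_clear_in_group = (9 / 4) % 4 = 2: thread 2
+ // of the sub-group (which loaded columns 8..11) clears row_chunk[0 * 4 + (9 % 4)] = row_chunk[1], i.e.
+ // exactly the element holding column 9.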
+ if (thread_group_idx == thread_to_clear_in_group) + { + const int offset_for_expert = expert % ELTS_PER_LDG; + // Safe to set to any negative value since row_chunk values must be between 0 and 1. + row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f; + } + } + } +} + +namespace detail +{ +// Constructs some constants needed to partition the work across threads at compile time. +template +struct TopkConstants +{ + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); + static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; + static constexpr int THREADS_PER_ROW = EXPERTS / VPT; + static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; +}; +} // namespace detail + +template +void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices, + int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream) +{ + static constexpr std::size_t MAX_BYTES_PER_LDG = 16; + + static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); + using Constants = detail::TopkConstants; + static constexpr int VPT = Constants::VPT; + static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; + const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; + + dim3 block_dim(WARP_SIZE, WARPS_PER_TB); + topkGatingSoftmax<<>>( + input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert); +} + +#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \ + topkGatingSoftmaxLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indicies, \ + token_expert_indices, num_tokens, topk, 0, num_experts, \ + stream); + +void topkGatingSoftmaxKernelLauncher( + const float* gating_output, + float* topk_weights, + int* topk_indicies, + int* token_expert_indices, + float* softmax_workspace, + const int num_tokens, + const int num_experts, + const int topk, + cudaStream_t stream) { + static constexpr int WARPS_PER_TB = 4; + switch (num_experts) { + case 1: + LAUNCH_SOFTMAX(1, WARPS_PER_TB); + break; + case 2: + LAUNCH_SOFTMAX(2, WARPS_PER_TB); + break; + case 4: + LAUNCH_SOFTMAX(4, WARPS_PER_TB); + break; + case 8: + LAUNCH_SOFTMAX(8, WARPS_PER_TB); + break; + case 16: + LAUNCH_SOFTMAX(16, WARPS_PER_TB); + break; + case 32: + LAUNCH_SOFTMAX(32, WARPS_PER_TB); + break; + case 64: + LAUNCH_SOFTMAX(64, WARPS_PER_TB); + break; + case 128: + LAUNCH_SOFTMAX(128, WARPS_PER_TB); + break; + case 256: + LAUNCH_SOFTMAX(256, WARPS_PER_TB); + break; + default: { + TORCH_CHECK(softmax_workspace != nullptr, + "softmax_workspace must be provided for num_experts that are not a power of 2."); + static constexpr int TPB = 256; + moeSoftmax<<>>( + gating_output, nullptr, softmax_workspace, num_experts); + moeTopK<<>>( + softmax_workspace, nullptr, topk_weights, topk_indicies, token_expert_indices, + num_experts, topk, 0, num_experts); + } + } +} + +} // namespace vllm + +void topk_softmax( + torch::Tensor& topk_weights, // [num_tokens, topk] + torch::Tensor& topk_indices, // [num_tokens, topk] + torch::Tensor& token_expert_indices, // [num_tokens, topk] + torch::Tensor& gating_output) // [num_tokens, num_experts] +{ + const int num_experts = gating_output.size(-1); + 
const int num_tokens = gating_output.numel() / num_experts; + const int topk = topk_weights.size(-1); + + const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + const bool needs_workspace = !is_pow_2 || num_experts > 256; + const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0; + + const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); + torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options()); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + vllm::topkGatingSoftmaxKernelLauncher( + gating_output.data_ptr(), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + stream); +} diff --git a/setup.py b/setup.py index 3e2127855a75..d00abd1c4b23 100644 --- a/setup.py +++ b/setup.py @@ -317,6 +317,18 @@ def get_torch_arch_list() -> Set[str]: vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") vllm_extension_sources.append("csrc/custom_all_reduce.cu") + abs_root_dir = os.path.abspath(ROOT_DIR) + ext_modules.append( + CUDAExtension( + name="vllm._moe_C", + sources=["csrc/moe/moe_ops.cc"] + glob("csrc/moe/*.cu"), + include_dirs=[os.path.join(abs_root_dir, "third_party/cutlass/include/")], + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS_PUNICA, # FIXME + }, + )) + if not _is_neuron(): vllm_extension = CUDAExtension( name="vllm._C", From e5c62e8d193ee76b513d26c919cad28eeb27b8b5 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 2 Feb 2024 06:31:40 +0000 Subject: [PATCH 11/33] Remove unnecessary headers --- csrc/moe/topk_softmax_kernels.cu | 6 ------ 1 file changed, 6 deletions(-) diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index ee4aee28408e..114a5dddc9ab 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -25,15 +25,9 @@ #include #include "cutlass/array.h" - #include #include -// FIXME(woosuk) -#ifndef ENABLE_BF16 -#define ENABLE_BF16 -#endif - namespace vllm { static constexpr int WARP_SIZE = 32; From e127d9b77726d06f9a505d92ac45d5f786135378 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 2 Feb 2024 06:49:50 +0000 Subject: [PATCH 12/33] Add MoE namespace --- csrc/moe/topk_softmax_kernels.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index 114a5dddc9ab..84a48aa32bed 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -29,6 +29,7 @@ #include namespace vllm { +namespace moe { static constexpr int WARP_SIZE = 32; @@ -456,6 +457,7 @@ void topkGatingSoftmaxKernelLauncher( } } +} // namespace moe } // namespace vllm void topk_softmax( @@ -475,7 +477,7 @@ void topk_softmax( const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options()); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - vllm::topkGatingSoftmaxKernelLauncher( + vllm::moe::topkGatingSoftmaxKernelLauncher( gating_output.data_ptr(), topk_weights.data_ptr(), topk_indices.data_ptr(), From c3096a02a2f9ebd48d7e2a35058e1ec82cd0c0d8 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 2 Feb 2024 07:34:42 +0000 Subject: [PATCH 13/33] Minor --- csrc/moe/topk_softmax_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index 84a48aa32bed..804aec4c0fb7 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -475,8 +475,8 @@ void topk_softmax( const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0; const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); - torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options()); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options()); vllm::moe::topkGatingSoftmaxKernelLauncher( gating_output.data_ptr(), topk_weights.data_ptr(), From 9a561cc798b178e1dbf4f0cf9565dcc76126434f Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 2 Feb 2024 08:35:27 +0000 Subject: [PATCH 14/33] Add permute_kernels --- csrc/moe/moe_ops.cc | 3 +- csrc/moe/moe_ops.h | 8 ++ csrc/moe/permute_kernels.cu | 243 ++++++++++++++++++++++++++++++++++++ 3 files changed, 253 insertions(+), 1 deletion(-) create mode 100644 csrc/moe/permute_kernels.cu diff --git a/csrc/moe/moe_ops.cc b/csrc/moe/moe_ops.cc index 3acc67e57761..edc9e0fe6dd4 100644 --- a/csrc/moe/moe_ops.cc +++ b/csrc/moe/moe_ops.cc @@ -3,5 +3,6 @@ #include PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("topk_softmax", &topk_softmax, "Top-k softmax for MoE"); + m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs."); + m.def("expand_and_permute", &expand_and_permute, "Expand and permute the input tokens."); } diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index a01be3e426d7..e18c6f66d0d2 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -7,3 +7,11 @@ void topk_softmax( torch::Tensor& topk_indices, torch::Tensor& token_expert_indices, torch::Tensor& gating_output); + +void expand_and_permute( + torch::Tensor& permuted_tokens, + torch::Tensor& cum_num_tokens_per_expert, + torch::Tensor& reverse_permutation_map, + torch::Tensor& input_tokens, + torch::Tensor& topk_indices, + torch::Tensor& token_expert_indices); diff --git a/csrc/moe/permute_kernels.cu b/csrc/moe/permute_kernels.cu new file mode 100644 index 000000000000..3ee66c2b230c --- /dev/null +++ b/csrc/moe/permute_kernels.cu @@ -0,0 +1,243 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include + +#include "../dispatch_utils.h" + +#include +#include +#include +#include + +#include +#include +#include + +namespace vllm { +namespace moe { + +// ========================== CUB Sorting things ==================================== +size_t get_workspace_size_for_radix_sort( + const size_t num_key_value_pairs, + const int num_buckets) +{ + size_t num_bits = (int) log2(num_buckets) + 1; + size_t required_storage = 0; + int* null_int = nullptr; + cub::DeviceRadixSort::SortPairs( + NULL, required_storage, null_int, null_int, null_int, null_int, + num_key_value_pairs, 0, num_bits); + return required_storage; +} + +void radix_sort( + const int* keys_in, + int* keys_out, + const int* values_in, + int* values_out, + void* workspace, + size_t workspace_size, + const int num_buckets, + const size_t num_key_value_pairs, + cudaStream_t stream) +{ + size_t num_bits = (int) log2(num_buckets) + 1; + cub::DeviceRadixSort::SortPairs( + workspace, workspace_size, keys_in, keys_out, values_in, values_out, + num_key_value_pairs, 0, num_bits, stream); +} + +// ============================== Infer GEMM sizes ================================= +// TODO Could linear search be better for small # experts +__device__ inline int findTotalEltsLeqTarget(const int* sorted_indices, const int arr_length, const int target) +{ + int64_t low = 0, high = arr_length - 1, target_location = -1; + while (low <= high) + { + int64_t mid = (low + high) / 2; + + if (sorted_indices[mid] > target) + { + high = mid - 1; + } + else + { + low = mid + 1; + target_location = mid; + } + } + return target_location + 1; +} + +// Sets up the gemm assuming the inputs, experts and outputs are stored in row major order. +// Assumes we want to perform output = matmul(inputs, experts) + bias +// +// "total_rows_before_expert" contains the index one past the last occurrence of the corresponding expert. +// e.g. Index 0 is the start offset of expert 1, the final entry is the total number of active rows +__global__ void computeTotalRowsBeforeExpertKernel(const int* sorted_experts, const int sorted_experts_len, + const int64_t num_experts, int64_t* total_rows_before_expert) +{ + // First, compute the global tid. We only need 1 thread per expert. + const int expert = blockIdx.x * blockDim.x + threadIdx.x; + if (expert >= num_experts) + { + return; + } + + // This should construct the last index where each expert occurs. + total_rows_before_expert[expert] = findTotalEltsLeqTarget(sorted_experts, sorted_experts_len, expert); +} + +void computeTotalRowsBeforeExpert(const int* sorted_indices, const int total_indices, const int num_experts, + int64_t* total_rows_before_expert, cudaStream_t stream) +{ + const int threads = std::min(1024, num_experts); + const int blocks = (num_experts + threads - 1) / threads; + + computeTotalRowsBeforeExpertKernel<<>>( + sorted_indices, total_indices, num_experts, total_rows_before_expert); +} + + +// ========================== Permutation things ======================================= + +// Duplicated and permutes rows for MoE. In addition, reverse the permutation map to help with finalizing routing. + +// "expanded_x_row" simply means that the number of values is num_rows x k. It is "expanded" since we will have to +// duplicate some rows in the input matrix to match the dimensions. Duplicates will always get routed to separate +// experts in the end. + +// Note that the expanded_dest_row_to_expanded_source_row map referred to here has indices in the range (0, +// k*rows_in_input - 1). 
However, it is set up so that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input +// all map to row 0 in the original matrix. Thus, to know where to read in the source matrix, we simply take the modulus +// of the expanded index. + +template +__global__ void expandInputRowsKernel(const T* unpermuted_input, T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, + const int num_rows, const int64_t* num_dest_rows, const int cols) +{ + + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the + // reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. + const int expanded_dest_row = blockIdx.x; + const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + if (threadIdx.x == 0) + { + expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_dest_row; + } + + if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) + { + // Duplicate and permute rows + const int source_row = expanded_source_row % num_rows; + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* dest_row_ptr = permuted_output + expanded_dest_row * cols; + + for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) + { + dest_row_ptr[tid] = source_row_ptr[tid]; + } + } +} + +template +void expandInputRowsKernelLauncher( + T* output, + int* reverse_permutation_map, + const T* input_tokens, + const int* sorted_token_expert_indices, + const int num_tokens, + const int hidden_size, + const int topk, + cudaStream_t stream) +{ + const int64_t blocks = num_tokens * topk; + const int threads = std::min(hidden_size, 1024); + expandInputRowsKernel<<>>( + input_tokens, output, sorted_token_expert_indices, reverse_permutation_map, + num_tokens, nullptr, hidden_size); +} + +} // namespace moe +} // namespace vllm + +void expand_and_permute( + torch::Tensor& permuted_tokens, // [num_tokens * topk, hidden_size] + torch::Tensor& cum_num_tokens_per_expert, // [num_experts] + torch::Tensor& reverse_permutation_map, // [num_tokens * topk] + torch::Tensor& input_tokens, // [num_tokens, hidden_size] + torch::Tensor& topk_indices, // [num_tokens, topk] + torch::Tensor& token_expert_indices) // [num_tokens, topk] +{ + const int num_experts = cum_num_tokens_per_expert.size(0); + const int topk = topk_indices.size(-1); + const int num_tokens = topk_indices.numel() / topk; + const int hidden_size = input_tokens.size(-1); + + const size_t num_expanded_tokens = num_tokens * topk; + int64_t workspace_size_bytes = (int64_t) vllm::moe::get_workspace_size_for_radix_sort( + num_expanded_tokens, num_experts); + workspace_size_bytes = (workspace_size_bytes + 15) / 16 * 16; + + const at::cuda::OptionalCUDAGuard device_guard(device_of(input_tokens)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + torch::Tensor cub_workspace = torch::empty( + {workspace_size_bytes / input_tokens.element_size()}, input_tokens.options()); + torch::Tensor sorted_topk_indices = torch::empty_like(topk_indices); + torch::Tensor sorted_token_expert_indices = torch::empty_like(token_expert_indices); + + // Sort the token_expert_indices using topk_indices as the key + vllm::moe::radix_sort( + topk_indices.data_ptr(), + sorted_topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + sorted_token_expert_indices.data_ptr(), + 
cub_workspace.data_ptr(), + workspace_size_bytes, + num_experts, + num_expanded_tokens, + stream); + + // Compute the cumulative number of tokens per expert + vllm::moe::computeTotalRowsBeforeExpert( + sorted_topk_indices.data_ptr(), + num_expanded_tokens, + num_experts, + cum_num_tokens_per_expert.data_ptr(), + stream); + + // Expand and permute the input tokens + VLLM_DISPATCH_FLOATING_TYPES( + input_tokens.scalar_type(), "expandInputRowsKernelLauncher", + [&] { + vllm::moe::expandInputRowsKernelLauncher( + permuted_tokens.data_ptr(), + reverse_permutation_map.data_ptr(), + input_tokens.data_ptr(), + sorted_token_expert_indices.data_ptr(), + num_tokens, + hidden_size, + topk, + stream); + }); +} From ba07256b51355b9fbc8472d76ac8d1ae5fb34311 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 2 Feb 2024 09:12:25 +0000 Subject: [PATCH 15/33] Remove unused --- .../cutlass_kernels/cutlass_preprocessors.cpp | 761 ------------------ csrc/cutlass_kernels/cutlass_preprocessors.h | 64 -- 2 files changed, 825 deletions(-) delete mode 100644 csrc/cutlass_kernels/cutlass_preprocessors.cpp delete mode 100644 csrc/cutlass_kernels/cutlass_preprocessors.h diff --git a/csrc/cutlass_kernels/cutlass_preprocessors.cpp b/csrc/cutlass_kernels/cutlass_preprocessors.cpp deleted file mode 100644 index 24d9af815915..000000000000 --- a/csrc/cutlass_kernels/cutlass_preprocessors.cpp +++ /dev/null @@ -1,761 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h" -#include "tensorrt_llm/common/assert.h" -#include "tensorrt_llm/common/cudaBf16Wrapper.h" -#include "tensorrt_llm/common/stringUtils.h" - -#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" - -using namespace tensorrt_llm::common; - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ - -int get_bits_in_quant_type(QuantType quant_type) -{ - switch (quant_type) - { - case QuantType::INT8_WEIGHT_ONLY: return 8; - case QuantType::PACKED_INT4_WEIGHT_ONLY: return 4; - default: TLLM_CHECK_WITH_INFO(false, "Invalid quant_type"); return -1; - } -} - -struct LayoutDetails -{ - enum class Layout - { - UNKNOWN, - ROW_MAJOR, - COLUMN_MAJOR - }; - - Layout layoutB = Layout::UNKNOWN; - int rows_per_column_tile = 1; - int columns_interleaved = 1; - - bool uses_imma_ldsm = false; -}; - -template -struct getLayoutDetails -{ -}; - -template <> -struct getLayoutDetails -{ - LayoutDetails operator()() - { - LayoutDetails layout_details; - layout_details.layoutB = LayoutDetails::Layout::ROW_MAJOR; - return layout_details; - } -}; - -template <> -struct getLayoutDetails -{ - LayoutDetails operator()() - { - LayoutDetails layout_details; - layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; - return layout_details; - } -}; - -template -struct getLayoutDetails> -{ - LayoutDetails operator()() - { - LayoutDetails layout_details; - layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; - layout_details.rows_per_column_tile = RowsPerTile; - layout_details.columns_interleaved = ColumnsInterleaved; - return layout_details; - } -}; - -template -LayoutDetails getLayoutDetailsForArchAndQuantType() -{ - - using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB; - using LayoutB = typename CompileTraits::Layout; - using MmaOperator = typename CompileTraits::Operator; - LayoutDetails details = getLayoutDetails()(); - details.uses_imma_ldsm = std::is_same::value; - return details; -} - -template -LayoutDetails getLayoutDetailsForArch(QuantType quant_type) -{ - LayoutDetails details; - if (quant_type == QuantType::INT8_WEIGHT_ONLY) - { - details = getLayoutDetailsForArchAndQuantType(); - } - else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) - { - details = getLayoutDetailsForArchAndQuantType(); - } - else - { - TLLM_CHECK_WITH_INFO(false, "Unsupported quantization type"); - } - return details; -} - -LayoutDetails getLayoutDetailsForTransform(QuantType quant_type) -{ - const int arch = getSMVersion(); - if (arch >= 70 && arch < 75) - { - return getLayoutDetailsForArch(quant_type); - } - else if (arch >= 75 && arch < 80) - { - return getLayoutDetailsForArch(quant_type); - } - else if (arch >= 80 && arch <= 90) - { - return getLayoutDetailsForArch(quant_type); - } - else - { - TLLM_CHECK_WITH_INFO(false, "Unsupported Arch"); - return LayoutDetails(); - } -} - -// Permutes the rows of B for Turing and Ampere. Throws an error for other architectures. 
-// The data is permuted such that: -// For int8, each group of 16 rows is permuted using the map below: -// 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15 -// For int4, each group of 32 rows is permuted using the map below: -// 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31 -void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8_t* quantized_tensor, - const std::vector& shape, QuantType quant_type, const int64_t arch_version) -{ - - // We only want to run this step for weight only quant. - TLLM_CHECK(quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY || quant_type == QuantType::INT8_WEIGHT_ONLY); - - TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); - const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - - const int BITS_PER_ELT = get_bits_in_quant_type(quant_type); - const int K = 16 / BITS_PER_ELT; - const int ELTS_PER_BYTE = 8 / BITS_PER_ELT; - const int ELTS_PER_REG = 32 / BITS_PER_ELT; - - const uint32_t* input_byte_ptr = reinterpret_cast(quantized_tensor); - uint32_t* output_byte_ptr = reinterpret_cast(permuted_quantized_tensor); - - int MMA_SHAPE_N = 8; - int B_ROWS_PER_MMA = 8 * K; - const int elts_in_int32 = 32 / BITS_PER_ELT; - - const int num_vec_cols = num_cols / elts_in_int32; - - TLLM_CHECK_WITH_INFO( - arch_version >= 75, "Unsupported Arch. Pre-volta not supported. Column interleave not needed on Volta."); - - TLLM_CHECK_WITH_INFO(num_rows % B_ROWS_PER_MMA == 0, - fmtstr("Invalid shape for quantized tensor. Number of rows of quantized matrix must be a multiple of %d", - B_ROWS_PER_MMA)); - TLLM_CHECK_WITH_INFO(num_cols % MMA_SHAPE_N == 0, - fmtstr("Invalid shape for quantized tensor. On turing/Ampere, the number of cols must be a multiple of %d.", - MMA_SHAPE_N)); - - // The code is written as below so it works for both int8 and packed int4. - for (int expert = 0; expert < num_experts; ++expert) - { - const int64_t matrix_offset = expert * int64_t(num_rows) * int64_t(num_vec_cols); - for (int base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) - { - for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) - { - - for (int write_col = 0; write_col < num_vec_cols; ++write_col) - { - const int write_row = base_row + tile_row; - const int tile_read_row - = 8 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG); - const int read_row = base_row + tile_read_row; - const int read_col = write_col; - - const int64_t read_offset = matrix_offset + int64_t(read_row) * num_vec_cols + read_col; - const int64_t write_offset = matrix_offset + int64_t(write_row) * num_vec_cols + write_col; - - output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; - } - } - } - } -} - -// We need to use this transpose to correctly handle packed int4 and int8 data -// The reason this code is relatively complex is that the "trivial" loops took a substantial -// amount of time to transpose leading to long preprocessing times. This seemed to be a big -// issue for relatively large models. -template -void subbyte_transpose_impl( - int8_t* transposed_quantized_tensor, const int8_t* quantized_tensor, const std::vector& shape) -{ - const int bits_per_elt = get_bits_in_quant_type(quant_type); - - TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); - const size_t num_experts = shape.size() == 2 ? 
1 : shape[0]; - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - - const size_t col_bytes = num_cols * bits_per_elt / 8; - const size_t col_bytes_trans = num_rows * bits_per_elt / 8; - const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes; - - const uint8_t* input_byte_ptr = reinterpret_cast(quantized_tensor); - uint8_t* output_byte_ptr = reinterpret_cast(transposed_quantized_tensor); - - static_assert(quant_type == QuantType::INT8_WEIGHT_ONLY || quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY, ""); - static constexpr int ELTS_PER_BYTE = quant_type == QuantType::INT8_WEIGHT_ONLY ? 1 : 2; - - static constexpr int M_TILE_L1 = 64; - static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE; - uint8_t cache_buf[M_TILE_L1][N_TILE_L1]; - - static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1); - - // We assume the dims are a multiple of vector width. Our kernels only handle dims which are multiples - // of 64 for weight-only quantization. As a result, this seemed like a reasonable tradeoff because it - // allows GCC to emit vector instructions. - TLLM_CHECK_WITH_INFO(!(col_bytes_trans % VECTOR_WIDTH) && !(col_bytes % VECTOR_WIDTH), - fmtstr("Number of bytes for rows and cols must be a multiple of %d. However, num_rows_bytes = %ld and " - "num_col_bytes = %ld.", - VECTOR_WIDTH, col_bytes_trans, col_bytes)); - - const int num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1; - const int num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1; - - for (size_t expert = 0; expert < num_experts; ++expert) - { - const size_t matrix_offset = expert * num_rows * col_bytes; - for (size_t row_tile_start = 0; row_tile_start < num_rows; row_tile_start += M_TILE_L1) - { - for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; col_tile_start_byte += N_TILE_L1) - { - - const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); - const int col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes); - - for (int ii = 0; ii < M_TILE_L1; ++ii) - { - const int row = row_tile_start + ii; - - for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) - { - const int col = col_tile_start_byte + jj; - - const size_t logical_src_offset = matrix_offset + row * col_bytes + col; - - if (row < row_limit && col < col_limit) - { - for (int v = 0; v < VECTOR_WIDTH; ++v) - { - cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v]; - } - } - } - } - - if (quant_type == QuantType::INT8_WEIGHT_ONLY) - { - for (int ii = 0; ii < M_TILE_L1; ++ii) - { - for (int jj = ii + 1; jj < N_TILE_L1; ++jj) - { - std::swap(cache_buf[ii][jj], cache_buf[jj][ii]); - } - } - } - else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) - { - - for (int ii = 0; ii < M_TILE_L1; ++ii) - { - // Using M_TILE_L1 here is deliberate since we assume that the cache tile - // is square in the number of elements (not necessarily the number of bytes). 
- for (int jj = ii + 1; jj < M_TILE_L1; ++jj) - { - const int ii_byte = ii / ELTS_PER_BYTE; - const int ii_bit_offset = ii % ELTS_PER_BYTE; - - const int jj_byte = jj / ELTS_PER_BYTE; - const int jj_bit_offset = jj % ELTS_PER_BYTE; - - uint8_t src_elt = 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset)); - uint8_t tgt_elt = 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset)); - - cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset)); - cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset)); - - cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset)); - cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset)); - } - } - } - else - { - TLLM_CHECK_WITH_INFO(false, "Unsupported quantization type."); - } - - const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE; - const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE; - - const int row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols); - const int col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); - - for (int ii = 0; ii < M_TILE_L1; ++ii) - { - const int row = row_tile_start_trans + ii; - for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) - { - const int col = col_tile_start_byte_trans + jj; - - const size_t logical_tgt_offset = matrix_offset + row * col_bytes_trans + col; - - if (row < row_limit_trans && col < col_limit_trans) - { - for (int v = 0; v < VECTOR_WIDTH; ++v) - { - output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v]; - } - } - } - } - } - } - } -} - -void subbyte_transpose(int8_t* transposed_quantized_tensor, const int8_t* quantized_tensor, - const std::vector& shape, QuantType quant_type) -{ - - if (quant_type == QuantType::INT8_WEIGHT_ONLY) - { - subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); - } - else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) - { - subbyte_transpose_impl( - transposed_quantized_tensor, quantized_tensor, shape); - } - else - { - TLLM_CHECK_WITH_INFO(false, "Invalid quant_tye"); - } -} - -void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor, const size_t num_elts) -{ - for (int ii = 0; ii < num_elts; ++ii) - { - int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128); - } - - // Step 2 will transform the layout of a 32-bit register in CUDA in order to match the int4 layout. This has no - // performance benefit and is purely so that int4 and int8 have the same layout. - // Pictorially, this does the following: - // bit 32 0 - // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits) - // - // And it will rearrange the output 32 bit register to be the following: - // bit 32 0 - // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) - - TLLM_CHECK_WITH_INFO(num_elts % 4 == 0, "Dimensions of int8 tensor must be a multiple of 4 for register relayout"); - for (size_t base = 0; base < num_elts; base += 4) - { - std::swap(int8_tensor[base + 1], int8_tensor[base + 2]); - } -} - -void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const size_t num_elts) -{ - const int num_bytes = num_elts / 2; - - // Step 1 will be to transform all the int4s to unsigned in order to make the dequantize take as little - // instructions as possible in the CUDA code. 
- for (size_t ii = 0; ii < num_bytes; ++ii) - { - int8_t transformed_packed_int4s = 0; - int8_t transformed_first_elt - = (int8_t(packed_int4_tensor[ii] << 4) >> 4) + 8; // The double shift here is to ensure sign extension - int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4) + 8; - - TLLM_CHECK_WITH_INFO( - transformed_first_elt >= 0 && transformed_first_elt <= 15, "Illegal result for int4 transform (first elt)"); - TLLM_CHECK_WITH_INFO(transformed_second_elt >= 0 && transformed_second_elt <= 15, - "Illegal result for int4 transform (second elt)"); - - // We don't need to mask in these ops since everything should be in the range 0-15 - transformed_packed_int4s |= transformed_first_elt; - transformed_packed_int4s |= (transformed_second_elt << 4); - packed_int4_tensor[ii] = transformed_packed_int4s; - } - - // Step 2 will transform the layout of a 32-bit register in CUDA in order to minimize the number of shift & logical - // instructions That are needed to extract the int4s in the GEMM main loop. Pictorially, the loop below will do the - // following: Take as input a 32 bit register with layout: bit 32 0 - // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 4 bits) - // - // And it will rearrange the output 32 bit register to be the following: - // bit 32 0 - // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits) - - TLLM_CHECK_WITH_INFO(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a multiple of 8 for register relayout"); - const size_t num_registers = num_bytes / 4; - - uint32_t* register_ptr = reinterpret_cast(packed_int4_tensor); - for (size_t ii = 0; ii < num_registers; ++ii) - { - const uint32_t current_register = register_ptr[ii]; - uint32_t transformed_register = 0; - - for (int dest_idx = 0; dest_idx < 8; ++dest_idx) - { - const int src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1; - const int src_shift = 4 * src_idx; - const int dest_shift = 4 * dest_idx; - - const uint32_t src_bits = (current_register >> src_shift) & 0xF; - transformed_register |= (src_bits << dest_shift); - } - register_ptr[ii] = transformed_register; - } -} - -void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type) -{ - if (quant_type == QuantType::INT8_WEIGHT_ONLY) - { - add_bias_and_interleave_int8s_inplace(tensor, num_elts); - } - else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) - { - add_bias_and_interleave_int4s_inplace(tensor, num_elts); - } - else - { - TLLM_CHECK_WITH_INFO(false, "Invalid quantization type for interleaving."); - } -} - -void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const int8_t* quantized_tensor, - const std::vector& shape, QuantType quant_type, LayoutDetails details) -{ - - // We only want to run this step for weight only quant. - TLLM_CHECK(quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY || quant_type == QuantType::INT8_WEIGHT_ONLY); - - TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); - const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? 
shape[1] : shape[2]; - - const int BITS_PER_ELT = get_bits_in_quant_type(quant_type); - const int elts_in_int32 = 32 / BITS_PER_ELT; - - const int rows_per_tile = details.rows_per_column_tile; - - TLLM_CHECK_WITH_INFO(!(num_rows % elts_in_int32), - fmtstr("The number of rows must be a multiple of %d but the number of rows is %ld.", elts_in_int32, num_rows)); - - const uint32_t* input_byte_ptr = reinterpret_cast(quantized_tensor); - uint32_t* output_byte_ptr = reinterpret_cast(interleaved_quantized_tensor); - - TLLM_CHECK_WITH_INFO(!(num_rows % rows_per_tile), - fmtstr("The number of rows must be a multiple of %d but the number of rows is %ld.", rows_per_tile, num_rows)); - - const int num_vec_rows = num_rows / elts_in_int32; - const int vec_rows_per_tile = rows_per_tile / elts_in_int32; - const int interleave = details.columns_interleaved; - - for (int expert = 0; expert < num_experts; ++expert) - { - const int64_t matrix_offset = expert * int64_t(num_vec_rows) * int64_t(num_cols); - for (int read_col = 0; read_col < num_cols; ++read_col) - { - const int64_t write_col = read_col / interleave; - for (int base_vec_row = 0; base_vec_row < num_vec_rows; base_vec_row += vec_rows_per_tile) - { - for (int vec_read_row = base_vec_row; - vec_read_row < std::min(num_vec_rows, base_vec_row + vec_rows_per_tile); ++vec_read_row) - { - const int64_t vec_write_row = interleave * base_vec_row - + vec_rows_per_tile * (read_col % interleave) + vec_read_row % vec_rows_per_tile; - - const int64_t read_offset = matrix_offset + int64_t(read_col) * num_vec_rows + vec_read_row; - const int64_t write_offset - = matrix_offset + int64_t(write_col) * num_vec_rows * interleave + vec_write_row; - output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; - } - } - } - } -} - -void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight, - const std::vector& shape, QuantType quant_type) -{ - LayoutDetails details = getLayoutDetailsForTransform(quant_type); - - TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); - - size_t num_elts = 1; - for (const auto& dim : shape) - { - num_elts *= dim; - } - - const size_t num_bytes = num_elts * get_bits_in_quant_type(quant_type) / 8; - - std::vector src_buf(num_bytes); - std::vector dst_buf(num_bytes); - std::copy(row_major_quantized_weight, row_major_quantized_weight + num_bytes, src_buf.begin()); - - // Works on row major data, so issue this permutation first. - if (details.uses_imma_ldsm) - { - const int arch = getSMVersion(); - permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type, arch); - src_buf.swap(dst_buf); - } - - if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR) - { - subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type); - src_buf.swap(dst_buf); - } - - if (details.columns_interleaved > 1) - { - interleave_column_major_tensor(dst_buf.data(), src_buf.data(), shape, quant_type, details); - src_buf.swap(dst_buf); - } - - add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type); - std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight); -} - -/* - Arguments: - input_weight_ptr - the weight tensor to be quantized. Must be 2-D or 3-D and of type FP16. - - quant_type - the type of the output quantization weight. - - This function does symmetric quantization on 2-D or 3-D tensors. 
It uses the full int range and assumes the - zero-point is zero and will automatically construct the scales. - - It always quantizes the last axis of the tensor. For 3-D tensors, it operates in "batched" mode where the tensor is - viewed as a stack of matrices and a scale is produced for each column of every matrix. - -Outputs - processed_quantized_weight - quantized AND processed weight for GEMM. This MUST be used with the CUTLASS GEMM - unprocessed_quantized_weight - quantized but unprocessed weights. Useful for reference checking. - scale_ptr - scales for the quantized weight. - - Note that the returned quantized_weights will be preprocessed in a way to accelerate the mixed type GEMM. The data - layout may not make sense if printed. - - Shapes: - quant_type == int8: - If weight is a [m,n] matrix, quantized_weights will have shape [m,n] and scales of shape [n] - If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m,n] and scales of shape [b,n] - quant_type == int4: - If weight is a [m,n] matrix, quantized_weights will have shape [m, ceil(n/2)] and scales of shape [n] - If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m, ceil(n/2)] and scales of shape - [b,n] - - The quantized_weight will be of type torch.int8 and have two int4 values packed in a single byte. This is the - reason for halving the shape. At the time of writing this code, there was not an elegant way to handle this kind - of batched quantization using torch's quantized tensors (to the best of the author's knowledge). Scale tensors - must have a dimension of 1, which breaks the semantics we need for batched weights. - */ - -template -void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, - ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector& shape, QuantType quant_type) -{ - - TLLM_CHECK_WITH_INFO(processed_quantized_weight, "Processed quantized tensor is NULL"); - TLLM_CHECK_WITH_INFO(scale_ptr, "Scale output pointer is NULL"); - TLLM_CHECK_WITH_INFO(input_weight_ptr, "Input weight pointer is NULL"); - - TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); - const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - - const int bits_in_type = get_bits_in_quant_type(quant_type); - const int bytes_per_out_col = num_cols * bits_in_type / 8; - - std::vector weight_buf; - if (unprocessed_quantized_weight == nullptr) - { - weight_buf.resize(num_experts * num_rows * num_cols); - unprocessed_quantized_weight = weight_buf.data(); - } - - const int input_mat_size = num_rows * num_cols; - const int quantized_mat_size = num_rows * bytes_per_out_col; - const float quant_range_scale = 1.f / float(1 << (bits_in_type - 1)); - - std::vector per_col_max(num_cols); - - for (int expert = 0; expert < num_experts; ++expert) - { - const WeightType* current_weight = input_weight_ptr + expert * input_mat_size; - int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size; - - // First we find the per column max for this expert weight. 
- for (int jj = 0; jj < num_cols; ++jj) - { - per_col_max[jj] = 0.f; - } - - for (int ii = 0; ii < num_rows; ++ii) - { - const WeightType* current_weight_row = current_weight + ii * num_cols; - for (int jj = 0; jj < num_cols; ++jj) - { - per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj]))); - } - } - - // Then, we construct the scales - ComputeType* current_scales = scale_ptr + expert * num_cols; - for (int jj = 0; jj < num_cols; ++jj) - { - per_col_max[jj] *= quant_range_scale; - current_scales[jj] = ComputeType(per_col_max[jj]); - } - - // Finally, construct the weights. - for (int ii = 0; ii < num_rows; ++ii) - { - int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col; - const WeightType* current_weight_row = current_weight + ii * num_cols; - for (int jj = 0; jj < bytes_per_out_col; ++jj) - { - - if (quant_type == QuantType::INT8_WEIGHT_ONLY) - { - const float col_scale = per_col_max[jj]; - const float weight_elt = float(current_weight_row[jj]); - const float scaled_weight = round(weight_elt / col_scale); - const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight))); - current_quantized_weight_row[jj] = clipped_weight; - } - else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) - { - - // We will pack two int4 elements per iteration of the inner loop. - int8_t packed_int4s = 0; - for (int packed_idx = 0; packed_idx < 2; ++packed_idx) - { - const int input_idx = 2 * jj + packed_idx; - if (input_idx < num_cols) - { - const float col_scale = per_col_max[input_idx]; - const float weight_elt = float(current_weight_row[input_idx]); - const float scaled_weight = round(weight_elt / col_scale); - int int_weight = int(scaled_weight); - const int8_t clipped_weight = std::max(-8, std::min(7, int_weight)); - - // Kill the sign extension bits (hence 0x0F mask) then shift to upper bits - // if packing the second int4 and or the bits into the final result. 
- packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx)); - } - } - current_quantized_weight_row[jj] = packed_int4s; - } - else - { - TLLM_CHECK_WITH_INFO(false, "Unsupported quantization type"); - } - } - } - } - - preprocess_weights_for_mixed_gemm(processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type); -} - -template void symmetric_quantize( - int8_t*, int8_t*, half*, const float*, const std::vector&, QuantType); - -template void symmetric_quantize( - int8_t*, int8_t*, half*, const half*, const std::vector&, QuantType); - -#ifdef ENABLE_BF16 -template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( - int8_t*, int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector&, QuantType); - -template void symmetric_quantize<__nv_bfloat16, float>( - int8_t*, int8_t*, __nv_bfloat16*, const float*, const std::vector&, QuantType); -#endif - -template -void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, const WeightType* input_weight_ptr, - const std::vector& shape, QuantType quant_type) -{ - symmetric_quantize(processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type); -} - -template void symmetric_quantize(int8_t*, float*, const float*, const std::vector&, QuantType); - -template void symmetric_quantize(int8_t*, half*, const float*, const std::vector&, QuantType); - -template void symmetric_quantize(int8_t*, half*, const half*, const std::vector&, QuantType); - -#ifdef ENABLE_BF16 -template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( - int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector&, QuantType); - -template void symmetric_quantize<__nv_bfloat16, half>( - int8_t*, __nv_bfloat16*, const half*, const std::vector&, QuantType); - -template void symmetric_quantize( - int8_t*, half*, const __nv_bfloat16*, const std::vector&, QuantType); - -template void symmetric_quantize<__nv_bfloat16, float>( - int8_t*, __nv_bfloat16*, const float*, const std::vector&, QuantType); -#endif - -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/cutlass_kernels/cutlass_preprocessors.h b/csrc/cutlass_kernels/cutlass_preprocessors.h deleted file mode 100644 index f93790e54108..000000000000 --- a/csrc/cutlass_kernels/cutlass_preprocessors.h +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -#include "tensorrt_llm/common/cudaUtils.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ - -enum class QuantType -{ - INT8_WEIGHT_ONLY, - PACKED_INT4_WEIGHT_ONLY -}; -int get_bits_in_quant_type(QuantType quant_type); - -// Shapes here can be 2 or 3D. 
2-D shapes are [num_rows, num_cols] -// 3-D shapes are [num_experts, num_rows, num_cols] -void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8_t* quantized_tensor, - const std::vector& shape, QuantType quant_type, const int64_t arch_version); - -void subbyte_transpose(int8_t* transposed_quantized_tensor, const int8_t* quantized_tensor, - const std::vector& shape, QuantType quant_type); - -void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type); - -void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight, - const std::vector& shape, QuantType quant_type); - -template -void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, const WeightType* input_weight_ptr, - const std::vector& shape, QuantType quant_type); - -// This is exposed so that we can write tests that use the processed weights for CUTLASS but the unprocessed weight -// to implement a simple reference implementation. -template -void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, - ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector& shape, QuantType quant_type); - -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm From def2ccdcfc7fec32f252db2b1e59a95e4a61e6c8 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 2 Feb 2024 09:17:11 +0000 Subject: [PATCH 16/33] Move --- csrc/{cutlass_kernels => cutlass_utils}/cutlass_heuristic.cpp | 0 csrc/{cutlass_kernels => cutlass_utils}/cutlass_heuristic.h | 0 .../moe_gemm/moe_gemm_kernels.h | 2 +- .../moe_gemm/moe_gemm_kernels_bf16_bf16.cu | 0 .../moe_gemm/moe_gemm_kernels_bf16_uint4.cu | 0 .../moe_gemm/moe_gemm_kernels_bf16_uint8.cu | 0 .../moe_gemm/moe_gemm_kernels_fp16_fp16.cu | 0 .../moe_gemm/moe_gemm_kernels_fp16_uint4.cu | 0 .../moe_gemm/moe_gemm_kernels_fp16_uint8.cu | 0 .../moe_gemm/moe_gemm_kernels_fp32_fp32.cu | 0 .../moe_gemm => moe}/moe_gemm_kernels_template.h | 4 ++-- 11 files changed, 3 insertions(+), 3 deletions(-) rename csrc/{cutlass_kernels => cutlass_utils}/cutlass_heuristic.cpp (100%) rename csrc/{cutlass_kernels => cutlass_utils}/cutlass_heuristic.h (100%) rename csrc/{cutlass_kernels => cutlass_utils}/moe_gemm/moe_gemm_kernels.h (97%) rename csrc/{cutlass_kernels => cutlass_utils}/moe_gemm/moe_gemm_kernels_bf16_bf16.cu (100%) rename csrc/{cutlass_kernels => cutlass_utils}/moe_gemm/moe_gemm_kernels_bf16_uint4.cu (100%) rename csrc/{cutlass_kernels => cutlass_utils}/moe_gemm/moe_gemm_kernels_bf16_uint8.cu (100%) rename csrc/{cutlass_kernels => cutlass_utils}/moe_gemm/moe_gemm_kernels_fp16_fp16.cu (100%) rename csrc/{cutlass_kernels => cutlass_utils}/moe_gemm/moe_gemm_kernels_fp16_uint4.cu (100%) rename csrc/{cutlass_kernels => cutlass_utils}/moe_gemm/moe_gemm_kernels_fp16_uint8.cu (100%) rename csrc/{cutlass_kernels => cutlass_utils}/moe_gemm/moe_gemm_kernels_fp32_fp32.cu (100%) rename csrc/{cutlass_kernels/moe_gemm => moe}/moe_gemm_kernels_template.h (99%) diff --git a/csrc/cutlass_kernels/cutlass_heuristic.cpp b/csrc/cutlass_utils/cutlass_heuristic.cpp similarity index 100% rename from csrc/cutlass_kernels/cutlass_heuristic.cpp rename to csrc/cutlass_utils/cutlass_heuristic.cpp diff --git a/csrc/cutlass_kernels/cutlass_heuristic.h b/csrc/cutlass_utils/cutlass_heuristic.h similarity index 100% rename from csrc/cutlass_kernels/cutlass_heuristic.h rename to csrc/cutlass_utils/cutlass_heuristic.h 
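For reference, the row bookkeeping performed by the expand_and_permute op added earlier (csrc/moe/permute_kernels.cu) can be sketched with plain PyTorch operations. This is only an illustrative equivalent of the CUDA path, not part of the patch itself: the helper and variable names below are invented for the example, and a stable torch.sort stands in for the CUB radix sort.

    import torch

    def expand_and_permute_reference(input_tokens, topk_indices, token_expert_indices, num_experts):
        # input_tokens:         [num_tokens, hidden_size]
        # topk_indices:         [num_tokens, topk]  expert id chosen for each (token, k) slot
        # token_expert_indices: [num_tokens, topk]  expanded source row id (k * num_tokens + token)
        num_tokens = input_tokens.shape[0]
        flat_experts = topk_indices.flatten()
        flat_source_rows = token_expert_indices.flatten()

        # Sort the expanded rows by expert id (the CUDA op uses cub::DeviceRadixSort::SortPairs,
        # which is likewise stable for equal keys).
        sorted_experts, order = torch.sort(flat_experts, stable=True)
        sorted_source_rows = flat_source_rows[order].long()

        # cum_num_tokens_per_expert[e] = index one past the last row routed to expert e,
        # i.e. the per-expert row offsets used to size the grouped GEMMs.
        counts = torch.bincount(sorted_experts, minlength=num_experts)
        cum_num_tokens_per_expert = torch.cumsum(counts, dim=0)

        # Duplicate and permute rows: expanded row r reads original token r % num_tokens.
        permuted_tokens = input_tokens[sorted_source_rows % num_tokens]

        # Reverse map (expanded source row -> destination row), used later to un-permute
        # and reduce the top-k expert outputs back into token order.
        reverse_permutation_map = torch.empty_like(sorted_source_rows)
        reverse_permutation_map[sorted_source_rows] = torch.arange(
            sorted_source_rows.numel(), device=sorted_source_rows.device)

        return permuted_tokens, cum_num_tokens_per_expert, reverse_permutation_map

In the patch, these three steps correspond to radix_sort, computeTotalRowsBeforeExpert, and expandInputRowsKernelLauncher, respectively.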
diff --git a/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels.h b/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels.h similarity index 97% rename from csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels.h rename to csrc/cutlass_utils/moe_gemm/moe_gemm_kernels.h index 3f4def7d7152..b5aee286e77b 100644 --- a/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels.h +++ b/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels.h @@ -16,7 +16,7 @@ */ #pragma once -#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h" +#include "../cutlass_extensions/gemm_configs.h" #include #include diff --git a/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu b/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_bf16_bf16.cu similarity index 100% rename from csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu rename to csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_bf16_bf16.cu diff --git a/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu b/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_bf16_uint4.cu similarity index 100% rename from csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu rename to csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_bf16_uint4.cu diff --git a/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu b/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_bf16_uint8.cu similarity index 100% rename from csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu rename to csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_bf16_uint8.cu diff --git a/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu b/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp16_fp16.cu similarity index 100% rename from csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu rename to csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp16_fp16.cu diff --git a/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu b/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp16_uint4.cu similarity index 100% rename from csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu rename to csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp16_uint4.cu diff --git a/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu b/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp16_uint8.cu similarity index 100% rename from csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu rename to csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp16_uint8.cu diff --git a/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu b/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp32_fp32.cu similarity index 100% rename from csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu rename to csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp32_fp32.cu diff --git a/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h b/csrc/moe/moe_gemm_kernels_template.h similarity index 99% rename from csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h rename to csrc/moe/moe_gemm_kernels_template.h index 19a480dc8986..275c625e8734 100644 --- a/csrc/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h +++ b/csrc/moe/moe_gemm_kernels_template.h @@ -34,8 +34,8 @@ #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaUtils.h" -#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" +#include "../cutlass_kernels/cutlass_heuristic.h" +#include "../cutlass_kernels/moe_gemm/moe_gemm_kernels.h" #include #include #include From 72256cceca0860a072bea221cba248bc446ee89e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 2 Feb 
2024 09:22:18 +0000 Subject: [PATCH 17/33] Move --- .../moe_gemm/moe_gemm_kernels_bf16_uint4.cu | 24 ------------------- .../moe_gemm/moe_gemm_kernels_bf16_uint8.cu | 24 ------------------- .../moe_gemm/moe_gemm_kernels_fp16_uint4.cu | 22 ----------------- .../moe_gemm/moe_gemm_kernels_fp16_uint8.cu | 22 ----------------- .../moe_gemm => moe}/moe_gemm_kernels.h | 0 .../moe_gemm_kernels_bf16_bf16.cu | 2 +- .../moe_gemm_kernels_fp16_fp16.cu | 2 +- .../moe_gemm_kernels_fp32_fp32.cu | 2 +- 8 files changed, 3 insertions(+), 95 deletions(-) delete mode 100644 csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_bf16_uint4.cu delete mode 100644 csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_bf16_uint8.cu delete mode 100644 csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp16_uint4.cu delete mode 100644 csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp16_uint8.cu rename csrc/{cutlass_utils/moe_gemm => moe}/moe_gemm_kernels.h (100%) rename csrc/{cutlass_utils/moe_gemm => moe}/moe_gemm_kernels_bf16_bf16.cu (90%) rename csrc/{cutlass_utils/moe_gemm => moe}/moe_gemm_kernels_fp16_fp16.cu (89%) rename csrc/{cutlass_utils/moe_gemm => moe}/moe_gemm_kernels_fp32_fp32.cu (89%) diff --git a/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_bf16_uint4.cu b/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_bf16_uint4.cu deleted file mode 100644 index b5d129ca91c0..000000000000 --- a/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_bf16_uint4.cu +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -#ifdef ENABLE_BF16 -template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t>; -#endif -} // namespace tensorrt_llm diff --git a/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_bf16_uint8.cu b/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_bf16_uint8.cu deleted file mode 100644 index 174d5a7b907e..000000000000 --- a/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_bf16_uint8.cu +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -#ifdef ENABLE_BF16 -template class MoeGemmRunner<__nv_bfloat16, uint8_t>; -#endif -} // namespace tensorrt_llm diff --git a/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp16_uint4.cu b/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp16_uint4.cu deleted file mode 100644 index 3f4b0bb718fd..000000000000 --- a/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp16_uint4.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -template class MoeGemmRunner; -} diff --git a/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp16_uint8.cu b/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp16_uint8.cu deleted file mode 100644 index a8d2d5e6c8eb..000000000000 --- a/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp16_uint8.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -template class MoeGemmRunner; -} diff --git a/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels.h b/csrc/moe/moe_gemm_kernels.h similarity index 100% rename from csrc/cutlass_utils/moe_gemm/moe_gemm_kernels.h rename to csrc/moe/moe_gemm_kernels.h diff --git a/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_bf16_bf16.cu b/csrc/moe/moe_gemm_kernels_bf16_bf16.cu similarity index 90% rename from csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_bf16_bf16.cu rename to csrc/moe/moe_gemm_kernels_bf16_bf16.cu index 42699295b7e8..fee2550fbca3 100644 --- a/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_bf16_bf16.cu +++ b/csrc/moe/moe_gemm_kernels_bf16_bf16.cu @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" +#include "moe_gemm_kernels_template.h" namespace tensorrt_llm { diff --git a/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp16_fp16.cu b/csrc/moe/moe_gemm_kernels_fp16_fp16.cu similarity index 89% rename from csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp16_fp16.cu rename to csrc/moe/moe_gemm_kernels_fp16_fp16.cu index f57d91f9d810..ea958cd6cc23 100644 --- a/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp16_fp16.cu +++ b/csrc/moe/moe_gemm_kernels_fp16_fp16.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" +#include "moe_gemm_kernels_template.h" namespace tensorrt_llm { diff --git a/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp32_fp32.cu b/csrc/moe/moe_gemm_kernels_fp32_fp32.cu similarity index 89% rename from csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp32_fp32.cu rename to csrc/moe/moe_gemm_kernels_fp32_fp32.cu index 6b57aae1d844..6b27ab8e9c1a 100644 --- a/csrc/cutlass_utils/moe_gemm/moe_gemm_kernels_fp32_fp32.cu +++ b/csrc/moe/moe_gemm_kernels_fp32_fp32.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" +#include "moe_gemm_kernels_template.h" namespace tensorrt_llm { From e86fd06fd7c88ae9653b88d0d564b30eacdadd59 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 4 Feb 2024 03:58:27 +0000 Subject: [PATCH 18/33] Remove --- csrc/moe_kernels.cu | 1051 ------------------------------------------- 1 file changed, 1051 deletions(-) delete mode 100644 csrc/moe_kernels.cu diff --git a/csrc/moe_kernels.cu b/csrc/moe_kernels.cu deleted file mode 100644 index 0fde9d429ed3..000000000000 --- a/csrc/moe_kernels.cu +++ /dev/null @@ -1,1051 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/common/workspace.h" -#include -#include -#include -#include - -#include "cutlass/array.h" -#include "cutlass/epilogue/thread/activation.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" -#include "cutlass_extensions/epilogue/thread/fused_activations.h" - -#pragma GCC diagnostic pop - -#include "tensorrt_llm/common/cudaUtils.h" -#include "tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.h" - -#include -#include -#include - -// FIXME(woosuk) -#ifndef ENABLE_BF16 -#define ENABLE_BF16 -#endif - -namespace tensorrt_llm::kernels -{ - -static constexpr int WARP_SIZE = 32; - -// ====================== Softmax things =============================== -// We have our own implementation of softmax here so we can support transposing the output -// in the softmax kernel when we extend this module to support expert-choice routing. 
-template -__launch_bounds__(TPB) __global__ - void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols) -{ - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmpStorage; - - __shared__ float normalizing_factor; - __shared__ float float_max; - - const int thread_row_offset = blockIdx.x * num_cols; - - cub::Sum sum; - float threadData(-FLT_MAX); - - // Don't touch finished rows. - if ((finished != nullptr) && finished[blockIdx.x]) - { - return; - } - - for (int ii = threadIdx.x; ii < num_cols; ii += TPB) - { - const int idx = thread_row_offset + ii; - threadData = max(input[idx], threadData); - } - - const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); - if (threadIdx.x == 0) - { - float_max = maxElem; - } - __syncthreads(); - - threadData = 0; - - for (int ii = threadIdx.x; ii < num_cols; ii += TPB) - { - const int idx = thread_row_offset + ii; - threadData += exp((static_cast(input[idx]) - float_max)); - } - - const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); - - if (threadIdx.x == 0) - { - normalizing_factor = 1.f / Z; - } - __syncthreads(); - - for (int ii = threadIdx.x; ii < num_cols; ii += TPB) - { - const int idx = thread_row_offset + ii; - const float val = exp((static_cast(input[idx]) - float_max)) * normalizing_factor; - output[idx] = val; - } -} - -template -__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output, - int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert) -{ - - using cub_kvp = cub::KeyValuePair; - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmpStorage; - - cub_kvp thread_kvp; - cub::ArgMax arg_max; - - const int num_rows = gridDim.x; - const int block_row = blockIdx.x; - - const bool row_is_active = finished ? !finished[block_row] : true; - const int thread_read_offset = blockIdx.x * num_experts; - for (int k_idx = 0; k_idx < k; ++k_idx) - { - thread_kvp.key = 0; - thread_kvp.value = -1.f; // This is OK because inputs are probabilities - - cub_kvp inp_kvp; - for (int expert = threadIdx.x; expert < num_experts; expert += TPB) - { - const int idx = thread_read_offset + expert; - inp_kvp.key = expert; - inp_kvp.value = inputs_after_softmax[idx]; - - for (int prior_k = 0; prior_k < k_idx; ++prior_k) - { - const int prior_winning_expert = indices[k * block_row + prior_k]; - - if (prior_winning_expert == expert) - { - inp_kvp = thread_kvp; - } - } - - thread_kvp = arg_max(inp_kvp, thread_kvp); - } - - const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); - if (threadIdx.x == 0) - { - // Ignore experts the node isn't responsible for with expert parallelism - const int expert = result_kvp.key; - const bool node_uses_expert = expert >= start_expert && expert < end_expert; - const bool should_process_row = row_is_active && node_uses_expert; - - const int idx = k * block_row + k_idx; - output[idx] = result_kvp.value; - indices[idx] = should_process_row ? (expert - start_expert) : num_experts; - assert(indices[idx] >= 0); - source_rows[idx] = k_idx * num_rows + block_row; - } - __syncthreads(); - } -} - -// ====================== TopK softmax things =============================== - -/* - A Top-K gating softmax written to exploit when the number of experts in the MoE layers - are a small power of 2. 
This allows us to cleanly share the rows among the threads in - a single warp and eliminate communication between warps (so no need to use shared mem). - - It fuses the softmax, max and argmax into a single kernel. - - Limitations: - 1) This implementation is intended for when the number of experts is a small power of 2. - 2) This implementation assumes k is small, but will work for any k. -*/ - -template -__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ - void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices, - int* source_rows, const int k, const int start_expert, const int end_expert) -{ - // We begin by enforcing compile time assertions and setting up compile time constants. - static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); - static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2"); - static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); - static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); - - // Number of bytes each thread pulls in per load - static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); - static constexpr int ELTS_PER_ROW = NUM_EXPERTS; - static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; - static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; - - // Restrictions based on previous section. - static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); - static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); - static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2"); - static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size"); - - // We have NUM_EXPERTS elements per row. We specialize for small #experts - static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT; - static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW; - static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; - - // Restrictions for previous section. - static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp"); - - // ===================== From this point, we finally start computing run-time variables. ======================== - - // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps. - // This, each block processes a chunk of rows. We start by computing the start row for each block. - const int cta_base_row = blockIdx.x * ROWS_PER_CTA; - - // Now, using the base row per thread block, we compute the base row per warp. - const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP; - - // The threads in a warp are split into sub-groups that will work on a row. - // We compute row offset for each thread sub-group - const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW; - const int thread_row = warp_base_row + thread_row_in_warp; - - // Threads with indices out of bounds should early exit here. - if (thread_row >= num_rows) - { - return; - } - const bool row_is_active = finished ? !finished[thread_row] : true; - - // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the - // row it will read. 
- const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW; - - // Now, we compute the group each thread belong to in order to determine the first column to start loads. - const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; - const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; - const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; - - // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory, - // this can support all powers of 2 up to 16. - using AccessType = cutlass::AlignedArray; - - // Finally, we pull in the data from global mem - cutlass::Array row_chunk; - AccessType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk); - const AccessType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); -#pragma unroll - for (int ii = 0; ii < LDG_PER_THREAD; ++ii) - { - row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; - } - - // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just - // convert to float afterwards for the exp + sum reduction. - float thread_max = row_chunk[0]; -#pragma unroll - for (int ii = 1; ii < VPT; ++ii) - { - thread_max = max(thread_max, row_chunk[ii]); - } - -// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce. -#pragma unroll - for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) - { - thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW)); - } - - // From this point, thread max in all the threads have the max within the row. - // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum. - float row_sum = 0; -#pragma unroll - for (int ii = 0; ii < VPT; ++ii) - { - row_chunk[ii] = expf(row_chunk[ii] - thread_max); - row_sum += row_chunk[ii]; - } - -// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern. -#pragma unroll - for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) - { - row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); - } - - // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables - // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to - // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row. - // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the - // argmax after computing the softmax. - const float reciprocal_row_sum = 1.f / row_sum; - -#pragma unroll - for (int ii = 0; ii < VPT; ++ii) - { - row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum; - } - - // Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along - // with the max index. 
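// [Illustrative sketch] The max/sum reductions above and the argmax reduce below all use the
// same xor-shuffle "butterfly" pattern. A standalone version for a full 32-lane warp, shown
// only to make the pattern explicit (the kernel restricts it to THREADS_PER_ROW lanes):
__device__ inline float warp_reduce_max_example(float val)
{
#pragma unroll
    for (int mask = 16; mask > 0; mask /= 2)
    {
        val = fmaxf(val, __shfl_xor_sync(0xFFFFFFFFu, val, mask));
    }
    return val; // every lane of the warp now holds the warp-wide maximum
}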
- int start_col = first_elt_read_by_thread; - static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; - - for (int k_idx = 0; k_idx < k; ++k_idx) - { - // First, each thread does the local argmax - float max_val = row_chunk[0]; - int expert = start_col; -#pragma unroll - for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) - { -#pragma unroll - for (int ii = 0; ii < ELTS_PER_LDG; ++ii) - { - float val = row_chunk[ldg * ELTS_PER_LDG + ii]; - - // No check on the experts here since columns with the smallest index are processed first and only - // updated if > (not >=) - if (val > max_val) - { - max_val = val; - expert = col + ii; - } - } - } - -// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max. -// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can -// then blank out their max with -inf and the warp can run more iterations... -#pragma unroll - for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) - { - float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); - int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); - - // We want lower indices to "win" in every thread so we break ties this way - if (other_max > max_val || (other_max == max_val && other_expert < expert)) - { - max_val = other_max; - expert = other_expert; - } - } - - // Write the max for this k iteration to global memory. - if (thread_group_idx == 0) - { - // Add a guard to ignore experts not included by this node - const bool node_uses_expert = expert >= start_expert && expert < end_expert; - const bool should_process_row = row_is_active && node_uses_expert; - - // The lead thread from each sub-group will write out the final results to global memory. (This will be a - // single) thread per row of the input/output matrices. - const int idx = k * thread_row + k_idx; - output[idx] = max_val; - indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS; - source_rows[idx] = k_idx * num_rows + thread_row; - } - - // Finally, we clear the value in the thread with the current max if there is another iteration to run. - if (k_idx + 1 < k) - { - const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; - const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW; - - // Only the thread in the group which produced the max will reset the "winning" value to -inf. - if (thread_group_idx == thread_to_clear_in_group) - { - const int offset_for_expert = expert % ELTS_PER_LDG; - // Safe to set to any negative value since row_chunk values must be between 0 and 1. - row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f; - } - } - } -} - -namespace detail -{ -// Constructs some constants needed to partition the work across threads at compile time. 
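// For concreteness (illustrative numbers, assuming the 16-byte loads chosen by the launcher
// below): with EXPERTS = 8, each thread holds VPT = 4 floats loaded as a single 16-byte vector,
// THREADS_PER_ROW = 2 lanes cooperate on one row, and a warp covers ROWS_PER_WARP = 16 rows;
// with EXPERTS = 256, VPT = 8 (two 16-byte loads per thread), all 32 lanes work on a single
// row, and a warp covers exactly one row.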
-template -struct TopkConstants -{ - static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); - static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); - static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); - static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; - static constexpr int THREADS_PER_ROW = EXPERTS / VPT; - static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; -}; -} // namespace detail - -template -void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices, - int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream) -{ - static constexpr std::size_t MAX_BYTES_PER_LDG = 16; - - static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); - using Constants = detail::TopkConstants; - static constexpr int VPT = Constants::VPT; - static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; - const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; - const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; - - dim3 block_dim(WARP_SIZE, WARPS_PER_TB); - topkGatingSoftmax<<>>( - input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert); -} - -void topkGatingSoftmaxKernelLauncher(const float* input, const bool* finished, float* output, - float* softmax_temp_output, int* indices, int* source_row, const int num_rows, const int num_experts, const int k, - const int start_expert, const int end_expert, cudaStream_t stream) -{ - static constexpr int WARPS_PER_TB = 4; - - switch (num_experts) - { - case 1: - { - topkGatingSoftmaxLauncherHelper<1, WARPS_PER_TB>( - input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream); - break; - } - case 2: - { - topkGatingSoftmaxLauncherHelper<2, WARPS_PER_TB>( - input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream); - break; - } - case 4: - { - topkGatingSoftmaxLauncherHelper<4, WARPS_PER_TB>( - input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream); - break; - } - case 8: - { - topkGatingSoftmaxLauncherHelper<8, WARPS_PER_TB>( - input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream); - break; - } - case 16: - { - topkGatingSoftmaxLauncherHelper<16, WARPS_PER_TB>( - input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream); - break; - } - case 32: - { - topkGatingSoftmaxLauncherHelper<32, WARPS_PER_TB>( - input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream); - break; - } - case 64: - { - topkGatingSoftmaxLauncherHelper<64, WARPS_PER_TB>( - input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream); - break; - } - case 128: - { - topkGatingSoftmaxLauncherHelper<128, WARPS_PER_TB>( - input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream); - break; - } - case 256: - { - topkGatingSoftmaxLauncherHelper<256, WARPS_PER_TB>( - input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, stream); - break; - } - default: - { - static constexpr int TPB = 256; - TLLM_CHECK(softmax_temp_output != nullptr); - moeSoftmax<<>>(input, finished, softmax_temp_output, num_experts); - moeTopK<<>>( - softmax_temp_output, finished, output, indices, source_row, 
num_experts, k, start_expert, end_expert); - } - } -} - -// ========================== CUB Sorting things ==================================== -CubKeyValueSorter::CubKeyValueSorter() - : num_experts_(0) - , num_bits_(sizeof(int) * 8) -{ -} - -CubKeyValueSorter::CubKeyValueSorter(const int num_experts) - : num_experts_(num_experts) - , num_bits_((int) log2(num_experts) + 1) -{ -} - -void CubKeyValueSorter::updateNumExperts(const int num_experts) -{ - num_experts_ = num_experts; - num_bits_ = (int) log2(num_experts) + 1; -} - -size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs, const int num_experts) -{ - size_t num_bits = (int) log2(num_experts) + 1; - size_t required_storage = 0; - int* null_int = nullptr; - cub::DeviceRadixSort::SortPairs( - NULL, required_storage, null_int, null_int, null_int, null_int, num_key_value_pairs, 0, num_bits); - return required_storage; -} - -void CubKeyValueSorter::run(void* workspace, const size_t workspace_size, const int* keys_in, int* keys_out, - const int* values_in, int* values_out, const size_t num_key_value_pairs, cudaStream_t stream) -{ - size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs, num_experts_); - size_t actual_ws_size = workspace_size; - - TLLM_CHECK_WITH_INFO(expected_ws_size <= workspace_size, - "[CubKeyValueSorter::run] The allocated workspace is too small to run this problem."); - cub::DeviceRadixSort::SortPairs( - workspace, actual_ws_size, keys_in, keys_out, values_in, values_out, num_key_value_pairs, 0, num_bits_, stream); -} - -// ============================== Infer GEMM sizes ================================= -// TODO Could linear search be better for small # experts -__device__ inline int findTotalEltsLeqTarget(const int* sorted_indices, const int arr_length, const int target) -{ - int64_t low = 0, high = arr_length - 1, target_location = -1; - while (low <= high) - { - int64_t mid = (low + high) / 2; - - if (sorted_indices[mid] > target) - { - high = mid - 1; - } - else - { - low = mid + 1; - target_location = mid; - } - } - return target_location + 1; -} - -// Sets up the gemm assuming the inputs, experts and outputs are stored in row major order. -// Assumes we want to perform output = matmul(inputs, experts) + bias -// -// "total_rows_before_expert" contains the index one past the last occurrence of the corresponding expert. -// e.g. Index 0 is the start offset of expert 1, the final entry is the total number of active rows -__global__ void computeTotalRowsBeforeExpertKernel(const int* sorted_experts, const int sorted_experts_len, - const int64_t num_experts, int64_t* total_rows_before_expert) -{ - // First, compute the global tid. We only need 1 thread per expert. - const int expert = blockIdx.x * blockDim.x + threadIdx.x; - if (expert >= num_experts) - { - return; - } - - // This should construct the last index where each expert occurs. - total_rows_before_expert[expert] = findTotalEltsLeqTarget(sorted_experts, sorted_experts_len, expert); -} - -// ========================== Permutation things ======================================= - -// Duplicated and permutes rows for MoE. In addition, reverse the permutation map to help with finalizing routing. - -// "expanded_x_row" simply means that the number of values is num_rows x k. It is "expanded" since we will have to -// duplicate some rows in the input matrix to match the dimensions. Duplicates will always get routed to separate -// experts in the end. 
- -// Note that the expanded_dest_row_to_expanded_source_row map referred to here has indices in the range (0, -// k*rows_in_input - 1). However, it is set up so that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input -// all map to row 0 in the original matrix. Thus, to know where to read in the source matrix, we simply take the modulus -// of the expanded index. - -template -__global__ void expandInputRowsKernel(const T* unpermuted_input, T* permuted_output, - const int* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, - const int num_rows, const int64_t* num_dest_rows, const int cols) -{ - - // Reverse permutation map. - // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the - // reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 - // thread block will be responsible for all k summations. - const int expanded_dest_row = blockIdx.x; - const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; - if (threadIdx.x == 0) - { - expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_dest_row; - } - - if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) - { - // Duplicate and permute rows - const int source_row = expanded_source_row % num_rows; - - const T* source_row_ptr = unpermuted_input + source_row * cols; - T* dest_row_ptr = permuted_output + expanded_dest_row * cols; - - for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) - { - dest_row_ptr[tid] = source_row_ptr[tid]; - } - } -} - -template -void expandInputRowsKernelLauncher(const T* unpermuted_input, T* permuted_output, - const int* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, - const int num_rows, const int64_t* num_valid_tokens_ptr, const int cols, const int k, cudaStream_t stream) -{ - const int blocks = num_rows * k; - const int threads = std::min(cols, 1024); - auto func = (num_valid_tokens_ptr != nullptr) ? expandInputRowsKernel : expandInputRowsKernel; - func<<>>(unpermuted_input, permuted_output, expanded_dest_row_to_expanded_source_row, - expanded_source_row_to_expanded_dest_row, num_rows, num_valid_tokens_ptr, cols); -} - -enum class ScaleMode : int -{ - NO_SCALE = 0, - DEFAULT = 1, - RENORM_SCALE = 2, -}; - -// Final kernel to unpermute and scale -// This kernel unpermutes the original data, does the k-way reduction and performs the final skip connection. 
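// [Illustrative sketch] A host-side reference of the k-way reduction performed by the kernel
// below, with bias, skip connections and renormalization omitted for brevity. All names are
// illustrative; only the index arithmetic mirrors the kernel.
inline void unpermute_and_reduce_reference(
    const float* expanded_rows,  // [k * num_rows, cols] expert outputs, in permuted order
    float* output,               // [num_rows, cols] final mixture per token
    const float* scales,         // [num_rows, k] top-k routing weights
    const int* source_to_dest,   // [k * num_rows] inverse of the permutation applied on input
    int num_rows, int cols, int k)
{
    for (int row = 0; row < num_rows; ++row)
    {
        for (int col = 0; col < cols; ++col)
        {
            float acc = 0.f;
            for (int k_idx = 0; k_idx < k; ++k_idx)
            {
                // Copy k_idx of this token lives at expanded source index row + k_idx * num_rows;
                // source_to_dest says where the earlier permutation actually placed it.
                const int dest = source_to_dest[row + k_idx * num_rows];
                acc += scales[row * k + k_idx] * expanded_rows[dest * cols + col];
            }
            output[row * cols + col] = acc;
        }
    }
}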
-template -__global__ void finalizeMoeRoutingKernel(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1, - const T* skip_2, const T* bias, const float* scales, const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, const int cols, const int k, const int64_t* num_valid_ptr) -{ - const int original_row = blockIdx.x; - const int num_rows = gridDim.x; - const auto offset = original_row * cols; - T* reduced_row_ptr = reduced_unpermuted_output + offset; - const T* skip_1_row_ptr{}; - const T* skip_2_row_ptr{}; - - if (RESIDUAL_NUM >= 1) - { - skip_1_row_ptr = skip_1 + offset; - } - - if (RESIDUAL_NUM == 2) - { - skip_2_row_ptr = skip_2 + offset; - } - const int64_t num_valid = *num_valid_ptr; - for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) - { - T thread_output{0.f}; - float row_rescale{0.f}; - for (int k_idx = 0; k_idx < k; ++k_idx) - { - const int expanded_original_row = original_row + k_idx * num_rows; - const int expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row]; - - const int64_t k_offset = original_row * k + k_idx; - const float row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset]; - if constexpr (SCALE_MODE == ScaleMode::RENORM_SCALE) - { - row_rescale = row_rescale + row_scale; - } - - // Check after row sum has accumulated - if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) - { - continue; - } - - const T* expanded_permuted_rows_row_ptr = expanded_permuted_rows + expanded_permuted_row * cols; - - const int expert_idx = expert_for_source_row[k_offset]; - - const T* bias_ptr = bias + expert_idx * cols; - const T bias_value = HAS_BIAS ? bias_ptr[tid] : T(0.f); - - thread_output = static_cast(thread_output) - + row_scale * static_cast(expanded_permuted_rows_row_ptr[tid] + bias_value); - } - - if (SCALE_MODE == ScaleMode::RENORM_SCALE && (!CHECK_SKIPPED || thread_output)) - { - assert(row_rescale != 0.f); - thread_output = static_cast(thread_output) / row_rescale; - } - - if (RESIDUAL_NUM == 1) - { - thread_output = thread_output + skip_1_row_ptr[tid]; - } - else if (RESIDUAL_NUM == 2) - { - thread_output = thread_output + skip_1_row_ptr[tid] + skip_2_row_ptr[tid]; - } - reduced_row_ptr[tid] = thread_output; - } -} - -template -void finalizeMoeRoutingKernelLauncherSelectBias(const T* expanded_permuted_rows, T* reduced_unpermuted_output, - const T* skip_1, const T* skip_2, const T* bias, const float* scales, - const int* expanded_source_row_to_expanded_dest_row, const int* expert_for_source_row, const int num_rows, - const int cols, const int k, const int64_t* num_valid_ptr, MOEParallelismConfig parallelism_config, - MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream) -{ - const int blocks = num_rows; - const int threads = std::min(cols, 1024); - - // Only add bias on rank 0 for tensor parallelism - const bool is_rank_0 = parallelism_config.tp_rank == 0; - const bool has_bias = bias != nullptr && is_rank_0; - - const bool check_finished = num_valid_ptr != nullptr; - - ScaleMode renorm_scales = ScaleMode::DEFAULT; - if (normalization_mode == MOEExpertScaleNormalizationMode::RENORMALIZE) - { - renorm_scales = k == 1 ? 
ScaleMode::NO_SCALE : ScaleMode::RENORM_SCALE; - } - - using FuncPtr = decltype(&finalizeMoeRoutingKernel); - FuncPtr func_map[2][3][2] - = {{ - {&finalizeMoeRoutingKernel, - &finalizeMoeRoutingKernel}, - {&finalizeMoeRoutingKernel, - &finalizeMoeRoutingKernel}, - {&finalizeMoeRoutingKernel, - &finalizeMoeRoutingKernel}, - }, - { - {&finalizeMoeRoutingKernel, - &finalizeMoeRoutingKernel}, - {&finalizeMoeRoutingKernel, - &finalizeMoeRoutingKernel}, - {&finalizeMoeRoutingKernel, - &finalizeMoeRoutingKernel}, - }}; - auto* const func = func_map[check_finished][int(renorm_scales)][has_bias]; - func<<>>(expanded_permuted_rows, reduced_unpermuted_output, skip_1, skip_2, bias, - scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k, num_valid_ptr); -} - -template -void finalizeMoeRoutingKernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1, - const T* skip_2, const T* bias, const float* scales, const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, const int num_rows, const int cols, const int k, const int64_t* num_valid_ptr, - MOEParallelismConfig parallelism_config, MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream) -{ - // If we are not rank 0 we should not add any residuals because the allreduce would sum multiple copies - const bool is_rank_0 = parallelism_config.tp_rank == 0; - if (skip_1 == nullptr || !is_rank_0) - { - assert(skip_2 == nullptr); - finalizeMoeRoutingKernelLauncherSelectBias(expanded_permuted_rows, reduced_unpermuted_output, skip_1, - skip_2, bias, scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row, num_rows, cols, k, - num_valid_ptr, parallelism_config, normalization_mode, stream); - } - else if (skip_2 == nullptr) - { - finalizeMoeRoutingKernelLauncherSelectBias(expanded_permuted_rows, reduced_unpermuted_output, skip_1, - skip_2, bias, scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row, num_rows, cols, k, - num_valid_ptr, parallelism_config, normalization_mode, stream); - } - else - { - finalizeMoeRoutingKernelLauncherSelectBias(expanded_permuted_rows, reduced_unpermuted_output, skip_1, - skip_2, bias, scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row, num_rows, cols, k, - num_valid_ptr, parallelism_config, normalization_mode, stream); - } -} - -// ============================== Gated Activation ================================= - -template -__global__ void doGatedActivationKernel( - T* output, const T* gemm_result, const int64_t* num_valid_tokens_ptr, size_t inter_size) -{ - const int tid = threadIdx.x; - const int token = blockIdx.x; - if (num_valid_tokens_ptr && token >= *num_valid_tokens_ptr) - { - return; - } - - ActFn fn{}; - output = output + token * inter_size; - gemm_result = gemm_result + token * inter_size * 2; - for (int i = tid; i < inter_size; i += blockDim.x) - { - T fc1_value = gemm_result[i]; - // BF16 isn't supported, use FP32 for activation function - float gate_value = gemm_result[i + inter_size]; - T gate_act = fn(gate_value); - output[i] = fc1_value * gate_act; - } -} - -template -void doGatedActivation(T* output, const T* gemm_result, const int64_t* num_valid_tokens_ptr, int inter_size, - int num_tokens, ActivationType activation_type, cudaStream_t stream) -{ - const int blocks = num_tokens; - const int threads = std::min(inter_size, 1024); - - // TODO Instead of T use a vectored type if performance would benefit - // TODO For some reason Volta fails on GELU_taylor 
here with Warp Illegal Instruction. - auto* fn = activation_type == ActivationType::Swiglu - ? &doGatedActivationKernel> - : &doGatedActivationKernel>; - fn<<>>(output, gemm_result, num_valid_tokens_ptr, inter_size); -} - -template -std::vector CutlassMoeFCRunner::getWorkspaceBufferSizes(const int num_rows, - const int hidden_size, const int inter_size, const int num_experts, const int num_experts_per_node, const int k, - ActivationType activation_type) const -{ - const size_t num_moe_inputs = k * num_rows; - const size_t buf_size = num_moe_inputs * hidden_size; - const size_t interbuf_elems = num_moe_inputs * inter_size; - const size_t glu_inter_elems = isGatedActivation(activation_type) ? (interbuf_elems * 2) : 0; - int num_softmax_outs = 0; - - const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); - if (!is_pow_2 || num_experts > 256) - { - num_softmax_outs = num_rows * num_experts; - } - - size_t source_rows_size = num_moe_inputs * sizeof(int); - size_t permuted_rows_size = num_moe_inputs * sizeof(int); - size_t permuted_experts_size = num_moe_inputs * sizeof(int); - size_t permuted_data_size = buf_size * sizeof(T); - size_t total_rows_before_expert_size = num_experts_per_node * sizeof(int64_t); - size_t softmax_out_size = num_softmax_outs * sizeof(float); - size_t glu_inter_size = glu_inter_elems * sizeof(T); - size_t fc1_result_size = interbuf_elems * sizeof(T); - size_t sorter_size = CubKeyValueSorter::getWorkspaceSize(num_rows, num_experts); - - std::vector workspace{ - source_rows_size, - permuted_rows_size, - permuted_experts_size, - permuted_data_size, - total_rows_before_expert_size, - softmax_out_size, - glu_inter_size, - // These pointers reuse the same memory - std::max(fc1_result_size, sorter_size), - }; - return workspace; -} - -template -size_t CutlassMoeFCRunner::getWorkspaceSize(const int num_rows, const int hidden_size, - const int inter_size, const int num_experts, const int k, ActivationType activation_type, - MOEParallelismConfig parallelism_config) const -{ - const int ep_size = parallelism_config.ep_size; - TLLM_CHECK_WITH_INFO(num_experts % ep_size == 0, "Number of experts must be a multiple of tp size"); - auto workspace = getWorkspaceBufferSizes( - num_rows, hidden_size, inter_size, num_experts, num_experts / ep_size, k, activation_type); - return tensorrt_llm::common::calculateTotalWorkspaceSize(workspace.data(), workspace.size()); -} - -template -void CutlassMoeFCRunner::configureWsPtrs(char* ws_ptr, const int num_rows, const int hidden_size, - const int inter_size, const int num_experts, const int num_experts_per_node, const int k, - ActivationType activation_type) -{ - auto workspace = getWorkspaceBufferSizes( - num_rows, hidden_size, inter_size, num_experts, num_experts_per_node, k, activation_type); - - std::vector ws_sliced{(int8_t*) ws_ptr}; - for (auto size : workspace) - { - ws_sliced.push_back(nextWorkspacePtr(ws_sliced.back(), size)); - } - - source_rows_ = (int*) ws_sliced[0]; - permuted_rows_ = (int*) ws_sliced[1]; - permuted_experts_ = (int*) ws_sliced[2]; - permuted_data_ = (T*) ws_sliced[3]; - - total_rows_before_expert_ = (int64_t*) ws_sliced[4]; - - softmax_out_ = nullptr; - const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); - if (!is_pow_2 || num_experts > 256) - { - softmax_out_ = (float*) ws_sliced[5]; - } - - glu_inter_result_ = (T*) ws_sliced[6]; - - // These pointers are aliased. 
Since the sort ws can be overwritten after it is finished - sorter_ws_ = (char*) ws_sliced[7]; - fc1_result_ = (T*) ws_sliced[7]; -} - -template -void CutlassMoeFCRunner::runMoe(const void* input_activations_void, const float* gating_output, - const void* fc1_expert_weights_void, const void* fc1_scales_void, const void* fc1_expert_biases_void, - ActivationType fc1_activation_type, const void* fc2_expert_weights_void, const void* fc2_scales_void, - const void* fc2_expert_biases_void, const int num_rows, const int hidden_size, const int inter_size, - const int num_experts, const int k, char* workspace_ptr, void* final_output_void, void* fc2_result_void, - const bool* finished, const int active_rows, void* expert_scales_void, - int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, MOEParallelismConfig parallelism_config, - MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream) -{ - static constexpr bool scales_required - = std::is_same::value || std::is_same::value; - - auto* input_activations = static_cast(input_activations_void); - auto* fc1_expert_weights = static_cast(fc1_expert_weights_void); - auto* fc1_scales = static_cast(fc1_scales_void); - auto* fc1_expert_biases = static_cast(fc1_expert_biases_void); - auto* fc2_expert_weights = static_cast(fc2_expert_weights_void); - auto* fc2_scales = static_cast(fc2_scales_void); - auto* fc2_expert_biases = static_cast(fc2_expert_biases_void); - auto* final_output = static_cast(final_output_void); - auto* fc2_result = static_cast(fc2_result_void); - auto* expert_scales = static_cast(expert_scales_void); - - TLLM_CHECK(input_activations); - TLLM_CHECK(gating_output); - TLLM_CHECK(fc1_expert_weights); - TLLM_CHECK(fc2_expert_weights); - TLLM_CHECK(workspace_ptr); - TLLM_CHECK(fc2_result); - TLLM_CHECK(expert_scales); - TLLM_CHECK(expanded_source_row_to_expanded_dest_row); - TLLM_CHECK(expert_for_source_row); - TLLM_CHECK(num_experts % parallelism_config.ep_size == 0); - - if (scales_required) - { - TLLM_CHECK_WITH_INFO(fc1_scales != nullptr, "Scales expected but scale for first matmul is a null pointer"); - TLLM_CHECK_WITH_INFO(fc2_scales != nullptr, "Scales expected but scale for second matmul is a null pointer"); - } - else - { - TLLM_CHECK_WITH_INFO(fc1_scales == nullptr, "Scales are ignored for fp32/fp16/bf16 but received scale for FC1"); - TLLM_CHECK_WITH_INFO(fc2_scales == nullptr, "Scales are ignored for fp32/fp16/bf16 but received scale for FC2"); - } - - const int num_experts_per_node = num_experts / parallelism_config.ep_size; - const int start_expert = num_experts_per_node * parallelism_config.ep_rank; - const int end_expert = start_expert + num_experts_per_node; - - configureWsPtrs( - workspace_ptr, num_rows, hidden_size, inter_size, num_experts, num_experts_per_node, k, fc1_activation_type); - topkGatingSoftmaxKernelLauncher(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row, - source_rows_, num_rows, num_experts, k, start_expert, end_expert, stream); - - sync_check_cuda_error(); - - sorter_.updateNumExperts(num_experts); - const int sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows, num_experts)); - sorter_.run((void*) sorter_ws_, sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, source_rows_, - permuted_rows_, k * num_rows, stream); - - sync_check_cuda_error(); - - // Upper bound on number of expanded rows - const int expanded_active_expert_rows = k * active_rows; - computeTotalRowsBeforeExpert( - permuted_experts_, 
expanded_active_expert_rows, num_experts_per_node, total_rows_before_expert_, stream); - - sync_check_cuda_error(); - - const bool needs_num_valid = finished || parallelism_config.ep_size > 1; - const int64_t* num_valid_tokens_ptr - = needs_num_valid ? total_rows_before_expert_ + num_experts_per_node - 1 : nullptr; - expandInputRowsKernelLauncher(input_activations, permuted_data_, permuted_rows_, - expanded_source_row_to_expanded_dest_row, num_rows, num_valid_tokens_ptr, hidden_size, k, stream); - - sync_check_cuda_error(); - - if (!isGatedActivation(fc1_activation_type)) - { - moe_gemm_runner_.moeGemmBiasAct(permuted_data_, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_result_, - total_rows_before_expert_, expanded_active_expert_rows, inter_size, hidden_size, num_experts_per_node, - fc1_activation_type, stream); - } - else - { - const size_t fc1_out_size = inter_size * 2; - // Run the GEMM with activation function overridden with `Identity`, we do the activation separately - moe_gemm_runner_.moeGemmBiasAct(permuted_data_, fc1_expert_weights, fc1_scales, fc1_expert_biases, - glu_inter_result_, total_rows_before_expert_, expanded_active_expert_rows, fc1_out_size, hidden_size, - num_experts_per_node, ActivationType::Identity, stream); - - sync_check_cuda_error(); - - doGatedActivation(fc1_result_, glu_inter_result_, num_valid_tokens_ptr, inter_size, num_rows * k, - fc1_activation_type, stream); - } - - sync_check_cuda_error(); - - moe_gemm_runner_.moeGemm(fc1_result_, fc2_expert_weights, fc2_scales, fc2_result, total_rows_before_expert_, - expanded_active_expert_rows, hidden_size, inter_size, num_experts_per_node, stream); - - sync_check_cuda_error(); - - finalizeMoeRoutingKernelLauncher(fc2_result, final_output, - // TODO pass 'skip' connections (residuals) - nullptr, nullptr, fc2_expert_biases, expert_scales, expanded_source_row_to_expanded_dest_row, - expert_for_source_row, num_rows, hidden_size, k, num_valid_tokens_ptr, parallelism_config, normalization_mode, - stream); - - sync_check_cuda_error(); -} - -template -void CutlassMoeFCRunner::computeTotalRowsBeforeExpert(const int* sorted_indices, - const int total_indices, const int num_experts, int64_t* total_rows_before_expert, cudaStream_t stream) -{ - const int threads = std::min(1024, num_experts); - const int blocks = (num_experts + threads - 1) / threads; - - computeTotalRowsBeforeExpertKernel<<>>( - sorted_indices, total_indices, num_experts, total_rows_before_expert); -} - -// ==================== Variable batched GEMM specializations ================================== -template class CutlassMoeFCRunner; - -#ifdef ENABLE_BF16 -template class CutlassMoeFCRunner<__nv_bfloat16, __nv_bfloat16>; -// template class CutlassMoeFCRunner<__nv_bfloat16, uint8_t>; -// template class CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>; -#endif - -template class CutlassMoeFCRunner; -// template class CutlassMoeFCRunner; -// template class CutlassMoeFCRunner; - -} // namespace tensorrt_llm::kernels From 612f9614e1d95b0cee768699c732080156dce22a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 4 Feb 2024 03:59:33 +0000 Subject: [PATCH 19/33] Add MoE MLP --- csrc/moe/moe_mlp_kernels.cu | 171 ++++++++++++++++++++++++++++++++++++ csrc/moe/moe_ops.cc | 1 + csrc/moe/moe_ops.h | 9 ++ 3 files changed, 181 insertions(+) create mode 100644 csrc/moe/moe_mlp_kernels.cu diff --git a/csrc/moe/moe_mlp_kernels.cu b/csrc/moe/moe_mlp_kernels.cu new file mode 100644 index 000000000000..f71511d78843 --- /dev/null +++ b/csrc/moe/moe_mlp_kernels.cu @@ 
-0,0 +1,171 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include "../dispatch_utils.h" +#include "moe_gemm_kernels.h" + +#include +#include +#include + +#include "cutlass/array.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass_extensions/epilogue/thread/fused_activations.h" + +namespace tensorrt_llm { + +// ============================== Gated Activation ================================= +template +__global__ void doGatedActivationKernel( + T* output, const T* gemm_result, const int64_t* num_valid_tokens_ptr, size_t inter_size) +{ + const int tid = threadIdx.x; + const int token = blockIdx.x; + if (num_valid_tokens_ptr && token >= *num_valid_tokens_ptr) + { + return; + } + + ActFn fn{}; + output = output + token * inter_size; + gemm_result = gemm_result + token * inter_size * 2; + for (int i = tid; i < inter_size; i += blockDim.x) + { + T fc1_value = gemm_result[i]; + // BF16 isn't supported, use FP32 for activation function + float gate_value = gemm_result[i + inter_size]; + T gate_act = fn(gate_value); + output[i] = fc1_value * gate_act; + } +} + +template +void doGatedActivation(T* output, const T* gemm_result, const int64_t* num_valid_tokens_ptr, int inter_size, + int num_tokens, ActivationType activation_type, cudaStream_t stream) +{ + const int blocks = num_tokens; + const int threads = std::min(inter_size, 1024); + + // TODO Instead of T use a vectored type if performance would benefit + // TODO For some reason Volta fails on GELU_taylor here with Warp Illegal Instruction. + auto* fn = activation_type == ActivationType::Swiglu + ? &doGatedActivationKernel> + : &doGatedActivationKernel>; + fn<<>>(output, gemm_result, num_valid_tokens_ptr, inter_size); +} + +template +void run_moe_mlp( + T* moe_output, + T* fc1_output, + T* glu_output, + const T* input_tokens, + int64_t* cum_num_tokens_per_expert, + const T* fc1_expert_weights, + const T* fc1_expert_biases, + ActivationType fc1_activation_type, + const T* fc2_expert_weights, + const int64_t num_expanded_tokens, + const int hidden_size, + const int inter_size, + const int num_experts, + cudaStream_t stream) +{ + // FIXME(woosuk): The MoE GEMM runner is created for each call. This is inefficient. 
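    // [Illustrative note] A possible follow-up for the FIXME above would be to cache the runner,
    // e.g. as a function-local static per template instantiation:
    //     static tensorrt_llm::MoeGemmRunner<T, T> cached_runner;
    // This assumes the runner only holds device properties (such as the SM version) and is safe
    // to reuse across calls on the same device; that assumption is not verified here.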
+ tensorrt_llm::MoeGemmRunner moe_gemm_runner; + // Compute FC1 + if (!tensorrt_llm::isGatedActivation(fc1_activation_type)) { + moe_gemm_runner.moeGemmBiasAct( + input_tokens, fc1_expert_weights, nullptr, fc1_expert_biases, fc1_output, + cum_num_tokens_per_expert, num_expanded_tokens, inter_size, hidden_size, num_experts, + fc1_activation_type, stream); + } else { + const size_t fc1_out_size = inter_size * 2; + // Run the GEMM with activation function overridden with `Identity`, we do the activation separately + moe_gemm_runner.moeGemmBiasAct( + input_tokens, fc1_expert_weights, nullptr, fc1_expert_biases, glu_output, + cum_num_tokens_per_expert, num_expanded_tokens, fc1_out_size, hidden_size, num_experts, + ActivationType::Identity, stream); + doGatedActivation( + fc1_output, glu_output, nullptr, inter_size, num_expanded_tokens, + fc1_activation_type, stream); + } + // Compute FC2 + moe_gemm_runner.moeGemm( + fc1_output, fc2_expert_weights, nullptr, moe_output, cum_num_tokens_per_expert, + num_expanded_tokens, hidden_size, inter_size, num_experts, stream); +} + +} // namespace tensorrt_llm + +// FIXME(woosuk) +#define LAUNCH_MOE_MLP(scalar_t, nv_t) \ + tensorrt_llm::run_moe_mlp( \ + (nv_t *) moe_output.data_ptr(), \ + (nv_t *) fc1_output.data_ptr(), \ + (nv_t *) glu_output.data_ptr(), \ + (nv_t *) input_tokens.data_ptr(), \ + cum_num_tokens_per_expert.data_ptr(), \ + (nv_t *) fc1_expert_weights.data_ptr(), \ + (nv_t *) (fc1_expert_biases.has_value() ? fc1_expert_biases.value().data_ptr() : nullptr), \ + fc1_activation_type_enum, \ + (nv_t *) fc2_expert_weights.data_ptr(), \ + num_expanded_tokens, \ + hidden_size, \ + inter_size, \ + num_experts, \ + stream); + +void moe_mlp( + torch::Tensor& moe_output, // [num_tokens * topk, hidden_size] + torch::Tensor& input_tokens, // [num_tokens * topk, hidden_size] + torch::Tensor& cum_num_tokens_per_expert, // [num_experts] + torch::Tensor& fc1_expert_weights, // [num_experts, inter_size or 2 * inter_size, hidden_size] + const c10::optional& fc1_expert_biases, // [num_experts, inter_size] + int fc1_activation_type, + torch::Tensor& fc2_expert_weights) // [num_experts, hidden_size, inter_size] +{ + const int64_t num_expanded_tokens = input_tokens.numel() / input_tokens.size(-1); + const int num_experts = fc2_expert_weights.size(0); + const int hidden_size = fc2_expert_weights.size(1); + const int inter_size = fc2_expert_weights.size(2); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(input_tokens)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + tensorrt_llm::ActivationType fc1_activation_type_enum = static_cast(fc1_activation_type); + torch::Tensor fc1_output = torch::empty({num_expanded_tokens, inter_size}, input_tokens.options()); + const bool is_glu = tensorrt_llm::isGatedActivation(fc1_activation_type_enum); + const int64_t glu_output_size = is_glu ? 
num_expanded_tokens * inter_size * 2 : 0; + torch::Tensor glu_output = torch::empty({glu_output_size}, input_tokens.options()); + + auto dtype = input_tokens.dtype(); + if (dtype == at::ScalarType::Float) { + LAUNCH_MOE_MLP(float, float); + } else if (dtype == at::ScalarType::Half) { + LAUNCH_MOE_MLP(at::Half, half); + // } else if (dtype == at::ScalarType::BFloat16) { + // LAUNCH_MOE_MLP(__nv_bfloat16); + } else { + TORCH_CHECK(false, "Unsupported data type: ", dtype); + } +} diff --git a/csrc/moe/moe_ops.cc b/csrc/moe/moe_ops.cc index edc9e0fe6dd4..f928a6cf991f 100644 --- a/csrc/moe/moe_ops.cc +++ b/csrc/moe/moe_ops.cc @@ -5,4 +5,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs."); m.def("expand_and_permute", &expand_and_permute, "Expand and permute the input tokens."); + m.def("moe_mlp", &moe_mlp, "Apply the MoE MLP."); } diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index e18c6f66d0d2..e3a45601057e 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -15,3 +15,12 @@ void expand_and_permute( torch::Tensor& input_tokens, torch::Tensor& topk_indices, torch::Tensor& token_expert_indices); + +void moe_mlp( + torch::Tensor& moe_output, + torch::Tensor& input_tokens, + torch::Tensor& cum_num_tokens_per_expert, + torch::Tensor& fc1_expert_weights, + const c10::optional& fc1_expert_biases, + int fc1_activation_type, + torch::Tensor& fc2_expert_weights); From 0bf8fb9fcf4d9e75a12b16f07659df2a1da42943 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 4 Feb 2024 04:00:41 +0000 Subject: [PATCH 20/33] Add cudaUtils --- csrc/cutlass_extensions/compute_occupancy.h | 2 +- csrc/cutlass_extensions/cudaUtils.h | 89 +++++++++++++++++++++ csrc/cutlass_utils/cutlass_heuristic.h | 11 ++- 3 files changed, 100 insertions(+), 2 deletions(-) create mode 100644 csrc/cutlass_extensions/cudaUtils.h diff --git a/csrc/cutlass_extensions/compute_occupancy.h b/csrc/cutlass_extensions/compute_occupancy.h index 23821e1d1008..97bf693e7092 100644 --- a/csrc/cutlass_extensions/compute_occupancy.h +++ b/csrc/cutlass_extensions/compute_occupancy.h @@ -18,7 +18,7 @@ #include #include "cutlass/device_kernel.h" -#include "tensorrt_llm/common/cudaUtils.h" +#include "cudaUtils.h" namespace tensorrt_llm { diff --git a/csrc/cutlass_extensions/cudaUtils.h b/csrc/cutlass_extensions/cudaUtils.h new file mode 100644 index 000000000000..1bd90f42ba65 --- /dev/null +++ b/csrc/cutlass_extensions/cudaUtils.h @@ -0,0 +1,89 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tensorrt_llm::common +{ + +/* **************************** debug tools ********************************* */ +static const char* _cudaGetErrorEnum(cudaError_t error) +{ + return cudaGetErrorString(error); +} + +static const char* _cudaGetErrorEnum(cublasStatus_t error) +{ + switch (error) + { + case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; + + case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; + + case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; + + case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; + + case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; + + case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; + + case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; + + case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; + + case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; + + case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; + } + return ""; +} + +// FIXME(woosuk) +template +void check(T result, char const* const func, const char* const file, int const line) +{ + if (result) + { + throw std::runtime_error("ERROR!"); + } +} + +#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) +#define check_cuda_error_2(val, file, line) check((val), #val, file, line) + +inline int getSMVersion() +{ + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + int sm_major = 0; + int sm_minor = 0; + check_cuda_error(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device)); + check_cuda_error(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); + return sm_major * 10 + sm_minor; +} + +} // namespace tensorrt_llm::common diff --git a/csrc/cutlass_utils/cutlass_heuristic.h b/csrc/cutlass_utils/cutlass_heuristic.h index 071998406be2..95f50c637232 100644 --- a/csrc/cutlass_utils/cutlass_heuristic.h +++ b/csrc/cutlass_utils/cutlass_heuristic.h @@ -17,7 +17,16 @@ #pragma once #include "cutlass_extensions/gemm_configs.h" -#include "tensorrt_llm/common/cudaUtils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace tensorrt_llm { From c09179da42fc5c894868583a2246c28d2282806d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 4 Feb 2024 04:06:30 +0000 Subject: [PATCH 21/33] Fix headers --- .../gemm/warp/mma_tensorop_dequantizer.h | 3 ++- csrc/cutlass_utils/cutlass_heuristic.cpp | 13 ++---------- csrc/moe/moe_gemm_kernels.h | 3 ++- csrc/moe/moe_gemm_kernels_template.h | 21 +++++++++++-------- setup.py | 7 +++++-- 5 files changed, 23 insertions(+), 24 deletions(-) diff --git a/csrc/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/csrc/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h index bdac36fd95d9..3b3fcd0f2e00 100644 --- a/csrc/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h +++ b/csrc/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h @@ -53,7 +53,8 @@ #include "cutlass/platform/platform.h" #include "cutlass_extensions/weight_only_quant_op.h" -#include "tensorrt_llm/common/cudaBf16Wrapper.h" +// FIXME(woosuk) +#include //////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_utils/cutlass_heuristic.cpp b/csrc/cutlass_utils/cutlass_heuristic.cpp index db77569e374d..6ad1cddc6a90 100644 --- 
a/csrc/cutlass_utils/cutlass_heuristic.cpp +++ b/csrc/cutlass_utils/cutlass_heuristic.cpp @@ -14,21 +14,12 @@ * limitations under the License. */ -#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" -#include "tensorrt_llm/common/cudaBf16Wrapper.h" - -#ifndef _WIN32 -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#endif // #ifndef _WIN32 +#include "cutlass_heuristic.h" +#include #include "cutlass/gemm/gemm.h" #include "cutlass/numeric_types.h" -#ifndef _WIN32 -#pragma GCC diagnostic pop -#endif // #ifndef _WIN32 - #include #include diff --git a/csrc/moe/moe_gemm_kernels.h b/csrc/moe/moe_gemm_kernels.h index b5aee286e77b..2cc69f57b8f1 100644 --- a/csrc/moe/moe_gemm_kernels.h +++ b/csrc/moe/moe_gemm_kernels.h @@ -16,7 +16,8 @@ */ #pragma once -#include "../cutlass_extensions/gemm_configs.h" + +#include "cutlass_extensions/gemm_configs.h" #include #include diff --git a/csrc/moe/moe_gemm_kernels_template.h b/csrc/moe/moe_gemm_kernels_template.h index 275c625e8734..3a0beaf2924f 100644 --- a/csrc/moe/moe_gemm_kernels_template.h +++ b/csrc/moe/moe_gemm_kernels_template.h @@ -13,10 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -// Ignore CUTLASS warnings about type punning -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma once #include "cutlass/array.h" #include "cutlass/numeric_conversion.h" @@ -30,17 +27,23 @@ #include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" #include "cutlass_extensions/gemm/threadblock/default_mma.h" -#pragma GCC diagnostic pop +#include "../cutlass_utils/cutlass_heuristic.h" +#include "moe_gemm_kernels.h" -#include "tensorrt_llm/common/assert.h" -#include "tensorrt_llm/common/cudaUtils.h" -#include "../cutlass_kernels/cutlass_heuristic.h" -#include "../cutlass_kernels/moe_gemm/moe_gemm_kernels.h" #include #include #include #include +// FIXME(woosuk) +#define TLLM_THROW(...) \ + do \ + { \ + throw std::runtime_error("ERROR!"); \ + } while (0) + +#define TLLM_CHECK_WITH_INFO(...) 
;;\ + namespace tensorrt_llm { diff --git a/setup.py b/setup.py index d00abd1c4b23..2605a633c59e 100644 --- a/setup.py +++ b/setup.py @@ -321,8 +321,11 @@ def get_torch_arch_list() -> Set[str]: ext_modules.append( CUDAExtension( name="vllm._moe_C", - sources=["csrc/moe/moe_ops.cc"] + glob("csrc/moe/*.cu"), - include_dirs=[os.path.join(abs_root_dir, "third_party/cutlass/include/")], + sources=["csrc/cutlass_utils/cutlass_heuristic.cpp"] + glob("csrc/moe/*.cu") + glob("csrc/moe/*.cc"), + include_dirs=[ + os.path.join(abs_root_dir, "third_party/cutlass/include/"), + os.path.join(abs_root_dir, "csrc/"), + ], extra_compile_args={ "cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS_PUNICA, # FIXME From 2ab65df77668416b0232afb66e24136b3ce5bed4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 4 Feb 2024 06:07:04 +0000 Subject: [PATCH 22/33] Enable BF16 --- csrc/moe/moe_gemm_kernels_bf16_bf16.cu | 2 -- csrc/moe/moe_gemm_kernels_template.h | 1 + csrc/moe/moe_mlp_kernels.cu | 4 ++-- setup.py | 2 +- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/csrc/moe/moe_gemm_kernels_bf16_bf16.cu b/csrc/moe/moe_gemm_kernels_bf16_bf16.cu index fee2550fbca3..c0ca12814bbb 100644 --- a/csrc/moe/moe_gemm_kernels_bf16_bf16.cu +++ b/csrc/moe/moe_gemm_kernels_bf16_bf16.cu @@ -18,7 +18,5 @@ namespace tensorrt_llm { -#ifdef ENABLE_BF16 template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16>; -#endif } // namespace tensorrt_llm diff --git a/csrc/moe/moe_gemm_kernels_template.h b/csrc/moe/moe_gemm_kernels_template.h index 3a0beaf2924f..3a94e1d0ba5d 100644 --- a/csrc/moe/moe_gemm_kernels_template.h +++ b/csrc/moe/moe_gemm_kernels_template.h @@ -32,6 +32,7 @@ #include #include +#include #include #include diff --git a/csrc/moe/moe_mlp_kernels.cu b/csrc/moe/moe_mlp_kernels.cu index f71511d78843..78467781afea 100644 --- a/csrc/moe/moe_mlp_kernels.cu +++ b/csrc/moe/moe_mlp_kernels.cu @@ -163,8 +163,8 @@ void moe_mlp( LAUNCH_MOE_MLP(float, float); } else if (dtype == at::ScalarType::Half) { LAUNCH_MOE_MLP(at::Half, half); - // } else if (dtype == at::ScalarType::BFloat16) { - // LAUNCH_MOE_MLP(__nv_bfloat16); + } else if (dtype == at::ScalarType::BFloat16) { + LAUNCH_MOE_MLP(at::BFloat16, __nv_bfloat16); } else { TORCH_CHECK(false, "Unsupported data type: ", dtype); } diff --git a/setup.py b/setup.py index 2605a633c59e..3500e4fee4d5 100644 --- a/setup.py +++ b/setup.py @@ -328,7 +328,7 @@ def get_torch_arch_list() -> Set[str]: ], extra_compile_args={ "cxx": CXX_FLAGS, - "nvcc": NVCC_FLAGS_PUNICA, # FIXME + "nvcc": NVCC_FLAGS_PUNICA + ["-DENABLE_BF16"], # FIXME }, )) From c74fc79a73537787e9b1402935dfd76ea7cfb565 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 4 Feb 2024 06:07:17 +0000 Subject: [PATCH 23/33] Err msg --- csrc/cutlass_extensions/cudaUtils.h | 30 ++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/csrc/cutlass_extensions/cudaUtils.h b/csrc/cutlass_extensions/cudaUtils.h index 1bd90f42ba65..db6fd6e5af5e 100644 --- a/csrc/cutlass_extensions/cudaUtils.h +++ b/csrc/cutlass_extensions/cudaUtils.h @@ -20,12 +20,17 @@ #include #include #include + #include #include #include #include #include +#include +#include +#include + namespace tensorrt_llm::common { @@ -62,13 +67,36 @@ static const char* _cudaGetErrorEnum(cublasStatus_t error) return ""; } +static std::string vformat(char const* fmt, va_list args) +{ + va_list args0; + va_copy(args0, args); + auto const size = std::vsnprintf(nullptr, 0, fmt, args0); + if (size <= 0) + return ""; + + std::string stringBuf(size, char{}); + 
auto const size2 = std::vsnprintf(&stringBuf[0], size + 1, fmt, args); + return stringBuf; +} + +static std::string fmtstr(char const* format, ...) +{ + va_list args; + va_start(args, format); + std::string result = vformat(format, args); + va_end(args); + return result; +}; + // FIXME(woosuk) template void check(T result, char const* const func, const char* const file, int const line) { if (result) { - throw std::runtime_error("ERROR!"); + throw std::runtime_error( + fmtstr("[ERROR] CUDA runtime error in %s: %s %s:%d\n", func, _cudaGetErrorEnum(result), file, line)); } } From 6320de43de66df54087b8db5cfab3af2f72a541a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 4 Feb 2024 07:35:47 +0000 Subject: [PATCH 24/33] Add unpermute_and_reduce --- csrc/moe/moe_ops.cc | 1 + csrc/moe/moe_ops.h | 7 ++ csrc/moe/unpermute_kernels.cu | 197 ++++++++++++++++++++++++++++++++++ 3 files changed, 205 insertions(+) create mode 100644 csrc/moe/unpermute_kernels.cu diff --git a/csrc/moe/moe_ops.cc b/csrc/moe/moe_ops.cc index f928a6cf991f..bf1da01895f6 100644 --- a/csrc/moe/moe_ops.cc +++ b/csrc/moe/moe_ops.cc @@ -6,4 +6,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs."); m.def("expand_and_permute", &expand_and_permute, "Expand and permute the input tokens."); m.def("moe_mlp", &moe_mlp, "Apply the MoE MLP."); + m.def("unpermute_and_reduce", &unpermute_and_reduce, "Unpermute and reduce the MoE outputs."); } diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index e3a45601057e..94b2c7e8542e 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -24,3 +24,10 @@ void moe_mlp( const c10::optional& fc1_expert_biases, int fc1_activation_type, torch::Tensor& fc2_expert_weights); + +void unpermute_and_reduce( + torch::Tensor& output_tokens, + torch::Tensor& experts_output, + torch::Tensor& topk_weights, + torch::Tensor& topk_indices, + torch::Tensor& reverse_permutation_map); diff --git a/csrc/moe/unpermute_kernels.cu b/csrc/moe/unpermute_kernels.cu new file mode 100644 index 000000000000..1f4044ef9511 --- /dev/null +++ b/csrc/moe/unpermute_kernels.cu @@ -0,0 +1,197 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include "../dispatch_utils.h" + +namespace vllm { +namespace moe { + +enum class MOEExpertScaleNormalizationMode : int +{ + NONE = 0, //!< Run the softmax on all scales and select the topk + RENORMALIZE, //!< Renormalize the selected scales so they sum to one. This is equivalent to only running softmax on + //!< the topk selected experts +}; + +enum class ScaleMode : int +{ + NO_SCALE = 0, + DEFAULT = 1, + RENORM_SCALE = 2, +}; + +// Final kernel to unpermute and scale +// This kernel unpermutes the original data, does the k-way reduction and performs the final skip connection. 
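+//
+// Compile-time switches referenced in the body (the template parameter list
+// is assumed to carry them): T is the activation element type, RESIDUAL_NUM
+// (0-2) selects how many skip tensors are added, HAS_BIAS adds the per-expert
+// bias vector, SCALE_MODE picks one of the ScaleMode values above, and
+// CHECK_SKIPPED drops expanded rows at or beyond *num_valid_ptr.
+// Launch shape: one thread block per original (unpermuted) row; each thread
+// strides over `cols`, gathers the k permuted expert outputs for its row via
+// expanded_source_row_to_expanded_dest_row, weights them by `scales`,
+// optionally renormalizes by the sum of the selected scales, and writes the
+// reduced row into reduced_unpermuted_output.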
+template +__global__ void finalizeMoeRoutingKernel(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1, + const T* skip_2, const T* bias, const float* scales, const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, const int cols, const int k, const int64_t* num_valid_ptr) +{ + const int original_row = blockIdx.x; + const int num_rows = gridDim.x; + const auto offset = original_row * cols; + T* reduced_row_ptr = reduced_unpermuted_output + offset; + const T* skip_1_row_ptr{}; + const T* skip_2_row_ptr{}; + + if (RESIDUAL_NUM >= 1) + { + skip_1_row_ptr = skip_1 + offset; + } + + if (RESIDUAL_NUM == 2) + { + skip_2_row_ptr = skip_2 + offset; + } + const int64_t num_valid = *num_valid_ptr; + for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) + { + T thread_output{0.f}; + float row_rescale{0.f}; + for (int k_idx = 0; k_idx < k; ++k_idx) + { + const int expanded_original_row = original_row + k_idx * num_rows; + const int expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row]; + + const int64_t k_offset = original_row * k + k_idx; + const float row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset]; + if constexpr (SCALE_MODE == ScaleMode::RENORM_SCALE) + { + row_rescale = row_rescale + row_scale; + } + + // Check after row sum has accumulated + if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) + { + continue; + } + + const T* expanded_permuted_rows_row_ptr = expanded_permuted_rows + expanded_permuted_row * cols; + + const int expert_idx = expert_for_source_row[k_offset]; + + const T* bias_ptr = bias + expert_idx * cols; + const T bias_value = HAS_BIAS ? bias_ptr[tid] : T(0.f); + + thread_output = static_cast(thread_output) + + row_scale * static_cast(expanded_permuted_rows_row_ptr[tid] + bias_value); + } + + if (SCALE_MODE == ScaleMode::RENORM_SCALE && (!CHECK_SKIPPED || thread_output)) + { + assert(row_rescale != 0.f); + thread_output = static_cast(thread_output) / row_rescale; + } + + if (RESIDUAL_NUM == 1) + { + thread_output = thread_output + skip_1_row_ptr[tid]; + } + else if (RESIDUAL_NUM == 2) + { + thread_output = thread_output + skip_1_row_ptr[tid] + skip_2_row_ptr[tid]; + } + reduced_row_ptr[tid] = thread_output; + } +} + +template +void finalizeMoeRoutingKernelLauncherSelectBias(const T* expanded_permuted_rows, T* reduced_unpermuted_output, + const T* skip_1, const T* skip_2, const T* bias, const float* scales, + const int* expanded_source_row_to_expanded_dest_row, const int* expert_for_source_row, const int num_rows, + const int cols, const int k, const int64_t* num_valid_ptr, const bool has_bias, + MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream) +{ + const int blocks = num_rows; + const int threads = std::min(cols, 1024); + + const bool check_finished = num_valid_ptr != nullptr; + + ScaleMode renorm_scales = ScaleMode::DEFAULT; + if (normalization_mode == MOEExpertScaleNormalizationMode::RENORMALIZE) + { + renorm_scales = k == 1 ? 
ScaleMode::NO_SCALE : ScaleMode::RENORM_SCALE; + } + + using FuncPtr = decltype(&finalizeMoeRoutingKernel); + FuncPtr func_map[2][3][2] + = {{ + {&finalizeMoeRoutingKernel, + &finalizeMoeRoutingKernel}, + {&finalizeMoeRoutingKernel, + &finalizeMoeRoutingKernel}, + {&finalizeMoeRoutingKernel, + &finalizeMoeRoutingKernel}, + }, + { + {&finalizeMoeRoutingKernel, + &finalizeMoeRoutingKernel}, + {&finalizeMoeRoutingKernel, + &finalizeMoeRoutingKernel}, + {&finalizeMoeRoutingKernel, + &finalizeMoeRoutingKernel}, + }}; + auto* const func = func_map[check_finished][int(renorm_scales)][has_bias]; + func<<>>(expanded_permuted_rows, reduced_unpermuted_output, skip_1, skip_2, bias, + scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k, num_valid_ptr); +} + +template +void finalizeMoeRoutingKernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, + const float* topk_weights, const int* expanded_source_row_to_expanded_dest_row, const int* expert_for_source_row, + const int num_rows, const int cols, const int k, cudaStream_t stream) +{ + finalizeMoeRoutingKernelLauncherSelectBias( + expanded_permuted_rows, reduced_unpermuted_output, nullptr, nullptr, nullptr, + topk_weights, expanded_source_row_to_expanded_dest_row, expert_for_source_row, + num_rows, cols, k, nullptr, false, MOEExpertScaleNormalizationMode::RENORMALIZE, stream); +} + +} // namespace moe +} // namespace vllm + +void unpermute_and_reduce( + torch::Tensor& output_tokens, // [num_tokens, hidden_size] + torch::Tensor& experts_output, // [num_tokens * topk, hidden_size] + torch::Tensor& topk_weights, // [num_tokens, topk] + torch::Tensor& topk_indices, // [num_tokens, topk] + torch::Tensor& reverse_permutation_map) // [num_tokens * topk] +{ + const int hidden_size = output_tokens.size(-1); + const int num_tokens = output_tokens.numel() / hidden_size; + const int topk = topk_weights.size(-1); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(output_tokens)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + experts_output.scalar_type(), "finalizeMoeRoutingKernelLauncher", + [&] { + vllm::moe::finalizeMoeRoutingKernelLauncher( + experts_output.data_ptr(), + output_tokens.data_ptr(), + topk_weights.data_ptr(), + reverse_permutation_map.data_ptr(), + topk_indices.data_ptr(), + num_tokens, + hidden_size, + topk, + stream); + }); +} From 9b57e39a56f46817abc80c7ed92780940c5cb9ee Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 5 Feb 2024 07:40:45 +0000 Subject: [PATCH 25/33] Add renormalize --- csrc/moe/moe_ops.h | 3 ++- csrc/moe/unpermute_kernels.cu | 10 +++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 94b2c7e8542e..52a346817ae9 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -30,4 +30,5 @@ void unpermute_and_reduce( torch::Tensor& experts_output, torch::Tensor& topk_weights, torch::Tensor& topk_indices, - torch::Tensor& reverse_permutation_map); + torch::Tensor& reverse_permutation_map, + bool renormalize); diff --git a/csrc/moe/unpermute_kernels.cu b/csrc/moe/unpermute_kernels.cu index 1f4044ef9511..f88f3abe198f 100644 --- a/csrc/moe/unpermute_kernels.cu +++ b/csrc/moe/unpermute_kernels.cu @@ -156,12 +156,14 @@ void finalizeMoeRoutingKernelLauncherSelectBias(const T* expanded_permuted_rows, template void finalizeMoeRoutingKernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const float* topk_weights, const int* 
expanded_source_row_to_expanded_dest_row, const int* expert_for_source_row, - const int num_rows, const int cols, const int k, cudaStream_t stream) + const int num_rows, const int cols, const int k, bool renormalize, cudaStream_t stream) { + const MOEExpertScaleNormalizationMode normalization_mode = renormalize ? MOEExpertScaleNormalizationMode::RENORMALIZE + : MOEExpertScaleNormalizationMode::NONE; finalizeMoeRoutingKernelLauncherSelectBias( expanded_permuted_rows, reduced_unpermuted_output, nullptr, nullptr, nullptr, topk_weights, expanded_source_row_to_expanded_dest_row, expert_for_source_row, - num_rows, cols, k, nullptr, false, MOEExpertScaleNormalizationMode::RENORMALIZE, stream); + num_rows, cols, k, nullptr, false, normalization_mode, stream); } } // namespace moe @@ -172,7 +174,8 @@ void unpermute_and_reduce( torch::Tensor& experts_output, // [num_tokens * topk, hidden_size] torch::Tensor& topk_weights, // [num_tokens, topk] torch::Tensor& topk_indices, // [num_tokens, topk] - torch::Tensor& reverse_permutation_map) // [num_tokens * topk] + torch::Tensor& reverse_permutation_map, // [num_tokens * topk] + bool renormalize) { const int hidden_size = output_tokens.size(-1); const int num_tokens = output_tokens.numel() / hidden_size; @@ -192,6 +195,7 @@ void unpermute_and_reduce( num_tokens, hidden_size, topk, + renormalize, stream); }); } From 55fae45a2162aa469f08bb58e84a0dafd6467f98 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 5 Feb 2024 07:41:03 +0000 Subject: [PATCH 26/33] Add FusedMoE --- vllm/model_executor/layers/fused_moe.py | 62 +++++++++++++++++++++++++ vllm/model_executor/models/deepseek.py | 40 ++++++++++------ 2 files changed, 87 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe.py b/vllm/model_executor/layers/fused_moe.py index eed2e83bed7f..c8256695e7cc 100644 --- a/vllm/model_executor/layers/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe.py @@ -287,3 +287,65 @@ def fused_moe(hidden_states: torch.Tensor, out=hidden_states) return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) + + +import vllm._moe_C as moe_kernels + +def fused_moe_( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +) -> torch.Tensor: + num_tokens = gating_output.shape[:-1].numel() + num_experts = gating_output.shape[-1] + hidden_size = hidden_states.shape[-1] + dtype = hidden_states.dtype + device = hidden_states.device + # print(hidden_states.shape, w1.shape, w2.shape, gating_output.shape) + + topk_weights = torch.empty(num_tokens, topk, dtype=torch.float32, device=device) + topk_indices = torch.empty(num_tokens, topk, dtype=torch.int32, device=device) + token_expert_indicies = torch.empty_like(topk_indices) + moe_kernels.topk_softmax( + topk_weights, + topk_indices, + token_expert_indicies, + gating_output.float(), + ) + + permuted_tokens = torch.empty(num_tokens * topk, hidden_size, dtype=dtype, device=device) + cum_num_tokens_per_expert = torch.empty(num_experts, dtype=torch.long, device=device) + reverse_permutation_map = torch.empty(num_tokens * topk, dtype=torch.int32, device=device) + moe_kernels.expand_and_permute( + permuted_tokens, + cum_num_tokens_per_expert, + reverse_permutation_map, + hidden_states, + topk_indices, + token_expert_indicies, + ) + + mlp_output = torch.empty_like(permuted_tokens) + moe_kernels.moe_mlp( + mlp_output, + permuted_tokens, + cum_num_tokens_per_expert, + w1, + None, + 3, + w2, + ) + + output_tokens = 
torch.empty_like(hidden_states) + moe_kernels.unpermute_and_reduce( + output_tokens, + mlp_output, + topk_weights, + topk_indices, + reverse_permutation_map, + renormalize, + ) + return output_tokens diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index fc727b8e661b..ca2cb8b6ceb3 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -31,7 +31,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe import fused_moe, fused_moe_ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -156,20 +156,30 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (batch * sequence_length, n_experts) router_logits, _ = self.gate(hidden_states) - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, - self.top_k, - dim=-1) - - if self.config.norm_topk_prob: - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - - final_hidden_states = fused_moe(hidden_states, - self.w1, - self.w2, - routing_weights, - selected_experts, - inplace=True) + if False: + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, + self.top_k, + dim=-1) + + if self.config.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + final_hidden_states = fused_moe(hidden_states, + self.w1, + self.w2, + routing_weights, + selected_experts, + inplace=True) + else: + final_hidden_states = fused_moe_( + hidden_states, + self.w1, + self.w2, + router_logits, + self.top_k, + renormalize=self.config.norm_topk_prob, + ) if self.config.n_shared_experts is not None: final_hidden_states = final_hidden_states + shared_output From fb9c52426d096dc31fb01eea94dcde300cc0e52e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 7 Feb 2024 17:15:57 +0000 Subject: [PATCH 27/33] Minor fix --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f2dd2d0d7f13..daf54bb73fef 100644 --- a/setup.py +++ b/setup.py @@ -343,7 +343,7 @@ def get_torch_arch_list() -> Set[str]: ext_modules.append( CUDAExtension( name="vllm._moe_C", - sources=["csrc/cutlass_utils/cutlass_heuristic.cpp", "csrc/moe/*.cpp"] + glob("csrc/moe/*.cu"), + sources=["csrc/cutlass_utils/cutlass_heuristic.cpp", "csrc/moe/moe_ops.cpp"] + glob("csrc/moe/*.cu"), include_dirs=[ os.path.join(abs_root_dir, "third_party/cutlass/include/"), os.path.join(abs_root_dir, "csrc/"), From bd52cbb1025d7d90f1487f7f2a741319cca81366 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Feb 2024 14:07:40 -0800 Subject: [PATCH 28/33] Use autotuned config --- vllm/model_executor/layers/fused_moe.py | 31 +++++++++++++++---------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe.py b/vllm/model_executor/layers/fused_moe.py index ddb36e6c0277..35add5969c0d 100644 --- a/vllm/model_executor/layers/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe.py @@ -218,6 +218,7 @@ def fused_moe( topk: int, renormalize: bool, inplace: bool = False, + fused_moe_config = None, ) -> 
torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. @@ -248,6 +249,9 @@ def fused_moe( M, _ = hidden_states.shape E, N, _ = w1.shape + if M <= 64: + return fused_moe_(hidden_states, w1, w2, gating_output, topk, renormalize, inplace) + if is_hip(): # The MoE kernels are not yet supported on ROCm. routing_weights = torch.softmax(gating_output, @@ -279,21 +283,24 @@ def fused_moe( if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - config = { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - } - - if topk_ids.numel() <= w1.shape[0]: + if not fused_moe_config: config = { - 'BLOCK_SIZE_M': 16, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 64, - 'GROUP_SIZE_M': 1 + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 } + if topk_ids.numel() <= w1.shape[0]: + config = { + 'BLOCK_SIZE_M': 16, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 1 + } + else: + config = fused_moe_config[min(fused_moe_config.keys(), key=lambda x: abs(x - M))] + intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), device=hidden_states.device, dtype=hidden_states.dtype) From 855d98a98bb0970256e286605a7a37a7bd595328 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Feb 2024 14:36:53 -0800 Subject: [PATCH 29/33] fix --- vllm/model_executor/layers/fused_moe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe.py b/vllm/model_executor/layers/fused_moe.py index 35add5969c0d..4c63b836d02e 100644 --- a/vllm/model_executor/layers/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe.py @@ -250,7 +250,8 @@ def fused_moe( E, N, _ = w1.shape if M <= 64: - return fused_moe_(hidden_states, w1, w2, gating_output, topk, renormalize, inplace) + assert inplace + return fused_moe_(hidden_states, w1, w2, gating_output, topk, renormalize) if is_hip(): # The MoE kernels are not yet supported on ROCm. From 437edcfaa9efc0b94d04b32af8dc888a3d438f66 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Feb 2024 16:00:45 -0800 Subject: [PATCH 30/33] add config --- vllm/model_executor/models/mixtral.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 0100624a44d7..1b66e533f08e 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -21,6 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Mixtral model.""" +import os from typing import List, Optional, Tuple import torch @@ -78,6 +79,11 @@ def __init__( self.hidden_size = hidden_size self.intermediate_size = intermediate_size // self.tp_size + self.fused_moe_config = None + if "VLLM_MIXTRAL_FUSE_MOE_CONFIG" in os.environ: + data = json.load(os.environ["VLLM_MIXTRAL_FUSE_MOE_CONFIG"]) + self.fused_moe_config = {int(key): val for key, val in data.items()} + if params_dtype is None: params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype @@ -133,7 +139,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_logits, self.top_k, renormalize=True, - inplace=True) + inplace=True, + self.fused_moe_config) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( From e13bc81b72a5b996decc348b1d9eaf20225373e8 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Feb 2024 16:02:54 -0800 Subject: [PATCH 31/33] fix --- vllm/model_executor/models/mixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 1b66e533f08e..0d11f192cbb4 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -140,7 +140,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: self.top_k, renormalize=True, inplace=True, - self.fused_moe_config) + fused_moe_config=self.fused_moe_config) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( From 4b94875a8b04c4b196dd5d52e989a87148a3b206 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Feb 2024 16:04:08 -0800 Subject: [PATCH 32/33] update --- vllm/model_executor/models/mixtral.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 0d11f192cbb4..89a4dc20a5dc 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -81,6 +81,7 @@ def __init__( self.fused_moe_config = None if "VLLM_MIXTRAL_FUSE_MOE_CONFIG" in os.environ: + import json data = json.load(os.environ["VLLM_MIXTRAL_FUSE_MOE_CONFIG"]) self.fused_moe_config = {int(key): val for key, val in data.items()} From 5bd4256540f64891adffeabb3f93427985e70e0f Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Feb 2024 16:06:37 -0800 Subject: [PATCH 33/33] update --- vllm/model_executor/models/mixtral.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 89a4dc20a5dc..e122d35f7875 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -82,7 +82,8 @@ def __init__( self.fused_moe_config = None if "VLLM_MIXTRAL_FUSE_MOE_CONFIG" in os.environ: import json - data = json.load(os.environ["VLLM_MIXTRAL_FUSE_MOE_CONFIG"]) + with open(os.environ["VLLM_MIXTRAL_FUSE_MOE_CONFIG"]) as f: + data = json.load(f) self.fused_moe_config = {int(key): val for key, val in data.items()} if params_dtype is None: