|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one |
| 3 | + * or more contributor license agreements. See the NOTICE file |
| 4 | + * distributed with this work for additional information |
| 5 | + * regarding copyright ownership. The ASF licenses this file |
| 6 | + * to you under the Apache License, Version 2.0 (the |
| 7 | + * "License"); you may not use this file except in compliance |
| 8 | + * with the License. You may obtain a copy of the License at |
| 9 | + * |
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | + * |
| 12 | + * Unless required by applicable law or agreed to in writing, |
| 13 | + * software distributed under the License is distributed on an |
| 14 | + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | + * KIND, either express or implied. See the License for the |
| 16 | + * specific language governing permissions and limitations |
| 17 | + * under the License. |
| 18 | + */ |
| 19 | + |
| 20 | +#include <fstream> |
| 21 | +#include <iostream> |
| 22 | +#include <sstream> |
| 23 | +#include <variant> |
| 24 | +#include <vector> |
| 25 | + |
| 26 | +#include "../../cuda/cuda_common.h" |
| 27 | + |
| 28 | +// clang-format off |
| 29 | +#include "cutlass/cutlass.h" |
| 30 | + |
| 31 | +#include "cute/tensor.hpp" |
| 32 | +#include "cutlass/tensor_ref.h" |
| 33 | +#include "cutlass/epilogue/collective/default_epilogue.hpp" |
| 34 | +#include "cutlass/epilogue/thread/linear_combination.h" |
| 35 | +#include "cutlass/gemm/dispatch_policy.hpp" |
| 36 | +#include "cutlass/gemm/gemm.h" |
| 37 | +#include "cutlass/gemm/collective/collective_builder.hpp" |
| 38 | +#include "cutlass/epilogue/collective/collective_builder.hpp" |
| 39 | +#include "cutlass/gemm/device/gemm_universal_adapter.h" |
| 40 | +#include "cutlass/gemm/kernel/gemm_universal.hpp" |
| 41 | +// clang-format on |
| 42 | + |
| 43 | +#define CUTLASS_CHECK(status) \ |
| 44 | + { \ |
| 45 | + cutlass::Status error = status; \ |
| 46 | + CHECK(error == cutlass::Status::kSuccess) \ |
| 47 | + << "Got cutlass error: " << cutlassGetStatusString(error); \ |
| 48 | + } |
| 49 | + |
| 50 | +using namespace cute; |
| 51 | +using ProblemShape = Shape<int, int, int>; // <M, N, K> |
| 52 | + |
| 53 | +template <typename KernelTraits, typename ElementA, typename ElementB, typename ElementC, |
| 54 | + typename LayoutA = cutlass::layout::RowMajor, |
| 55 | + typename LayoutB = cutlass::layout::ColumnMajor, |
| 56 | + typename LayoutC = cutlass::layout::RowMajor> |
| 57 | +struct CutlassGemmRunner { |
| 58 | + static constexpr int AlignmentA = |
| 59 | + 128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements |
| 60 | + // (up to 16 bytes) |
| 61 | + |
| 62 | + static constexpr int AlignmentB = |
| 63 | + 128 / cutlass::sizeof_bits<ElementB>::value; // Alignment of B matrix in units of elements |
| 64 | + // (up to 16 bytes) |
| 65 | + |
| 66 | + static constexpr int AlignmentC = |
| 67 | + 128 / cutlass::sizeof_bits<ElementC>::value; // Alignment of C matrix in units of elements |
| 68 | + // (up to 16 bytes) |
| 69 | + |
| 70 | + // Core kernel configurations |
| 71 | + using ElementAccumulator = float; // Element type for internal accumulation |
| 72 | + using ScaleType = std::variant<ElementAccumulator, const ElementAccumulator*>; |
| 73 | + using ArchTag = |
| 74 | + cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature |
| 75 | + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag |
| 76 | + using TileShape = typename KernelTraits::TileShape; |
| 77 | + using ClusterShape = typename KernelTraits::ClusterShape; |
| 78 | + using StageCountType = |
| 79 | + cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size |
| 80 | + using KernelSchedule = typename KernelTraits::KernelSchedule; // Kernel to launch |
| 81 | + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; // Epilogue to launch |
| 82 | + |
| 83 | + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< |
| 84 | + ArchTag, OperatorClass, TileShape, ClusterShape, |
| 85 | + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, |
| 86 | + ElementC, LayoutC, AlignmentC, ElementC, LayoutC, AlignmentC, EpilogueSchedule>::CollectiveOp; |
| 87 | + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< |
| 88 | + ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, |
| 89 | + ElementAccumulator, TileShape, ClusterShape, |
| 90 | + cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>( |
| 91 | + sizeof(typename CollectiveEpilogue::SharedStorage))>, |
| 92 | + KernelSchedule>::CollectiveOp; |
| 93 | + |
| 94 | + using GemmKernel = |
| 95 | + cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>; |
| 96 | + |
| 97 | + using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; |
| 98 | + |
| 99 | + using StrideA = typename Gemm::GemmKernel::StrideA; |
| 100 | + using StrideB = typename Gemm::GemmKernel::StrideB; |
| 101 | + using StrideC = typename Gemm::GemmKernel::StrideC; |
| 102 | + using StrideD = typename Gemm::GemmKernel::StrideD; |
| 103 | + |
| 104 | + void run_gemm(const ElementA* ptr_A, const ElementB* ptr_B, const ElementC* ptr_C, |
| 105 | + ElementC* ptr_D, ProblemShape* problem_size, StrideA* stride_A, StrideB* stride_B, |
| 106 | + StrideC* stride_C, StrideD* stride_D, uint8_t* workspace, int64_t workspace_size, |
| 107 | + ScaleType alpha, ScaleType beta, cudaStream_t stream) { |
| 108 | + cutlass::KernelHardwareInfo hw_info; |
| 109 | + hw_info.device_id = 0; |
| 110 | + hw_info.sm_count = |
| 111 | + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); |
| 112 | + typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm, |
| 113 | + *problem_size, |
| 114 | + {ptr_A, *stride_A, ptr_B, *stride_B}, |
| 115 | + {{}, ptr_C, *stride_C, ptr_D, *stride_D}, |
| 116 | + // {epilogue_params, ptr_C, *stride_C, ptr_D, *stride_D}, |
| 117 | + hw_info}; |
| 118 | + |
| 119 | + ICHECK(alpha.index() == beta.index()) << "alpha and beta must have the same type"; |
| 120 | + if (std::holds_alternative<ElementAccumulator>(alpha)) { |
| 121 | + arguments.epilogue.thread.alpha = std::get<ElementAccumulator>(alpha); |
| 122 | + arguments.epilogue.thread.beta = std::get<ElementAccumulator>(beta); |
| 123 | + } else if (std::holds_alternative<const ElementAccumulator*>(alpha)) { |
| 124 | + arguments.epilogue.thread.alpha_ptr = std::get<const ElementAccumulator*>(alpha); |
| 125 | + arguments.epilogue.thread.beta_ptr = std::get<const ElementAccumulator*>(beta); |
| 126 | + } else { |
| 127 | + LOG(FATAL) << "Unsupported alpha and beta type"; |
| 128 | + throw; |
| 129 | + } |
| 130 | + |
| 131 | + Gemm gemm_op; |
| 132 | + CUTLASS_CHECK(gemm_op.can_implement(arguments)); |
| 133 | + CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments)); |
| 134 | + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream)); |
| 135 | + CUTLASS_CHECK(gemm_op.run(stream)); |
| 136 | + } |
| 137 | +}; |
| 138 | + |
| 139 | +template <typename KernelTraits, typename ElementA, typename ElementB, typename ElementC> |
| 140 | +void cutlass_gemm(ElementA* x, ElementB* weight, uint8_t* workspace, int64_t workspace_size, |
| 141 | + int64_t m, int64_t n, int64_t k, std::variant<float, const float*> alpha, |
| 142 | + std::variant<float, const float*> beta, ElementC* out, cudaStream_t stream) { |
| 143 | + using Runner = CutlassGemmRunner<KernelTraits, ElementA, ElementB, ElementC>; |
| 144 | + using StrideA = typename Runner::StrideA; |
| 145 | + using StrideB = typename Runner::StrideB; |
| 146 | + using StrideC = typename Runner::StrideC; |
| 147 | + |
| 148 | + Runner runner; |
| 149 | + StrideA stride_A = cute::make_stride(k, Int<1>{}, int64_t{0}); |
| 150 | + StrideB stride_B = cute::make_stride(k, Int<1>{}, int64_t{0}); |
| 151 | + StrideC stride_D = cute::make_stride(n, Int<1>{}, int64_t{0}); |
| 152 | + ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n), static_cast<int>(k)}; |
| 153 | + runner.run_gemm(x, weight, out, out, &problem_size, &stride_A, &stride_B, &stride_D, &stride_D, |
| 154 | + workspace, workspace_size, alpha, beta, stream); |
| 155 | +} |
0 commit comments