// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu

#include "linear.h"

#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>
#include <cutlass/util/host_tensor.h>
#include <cutlass/epilogue/thread/linear_combination_silu.h>

#include <cstdint>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <iostream>
#include <torch/torch.h>
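
// linear_silu_a8_w8_bfp32_ofp32:
//   out = SiLU(alpha * (input @ weight^T) + beta * bias), computed in FP32.
//
//   input  : INT8, contiguous, shape [M, K]
//   weight : INT8, contiguous, shape [N, K] (consumed as a column-major K x N
//            operand B, i.e. the GEMM multiplies by weight^T)
//   bias   : FP32, shape [N], broadcast over the M rows
//   alpha  : FP32 scale applied to the INT32 accumulator (typically the
//            product of the input and weight quantization scales)
//   beta   : FP32 scale applied to the bias
//
// On SM80/SM75 the SiLU is fused into the CUTLASS epilogue; on SM70 the kernel
// falls back to a SIMT GEMM and applies SiLU afterwards via torch::silu.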
torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input,  // INT8
                                            torch::Tensor weight, // INT8
                                            torch::Tensor bias,   // FP32
                                            float alpha,          // FP32
                                            float beta            // FP32
) {
  auto M = input.size(0);
  auto N = weight.size(0);
  auto K = input.size(1);

  using ElementOutput = float;
  using ElementAccumulator = int32_t;
  using ElementComputeEpilogue = float;
  using ElementInputA = int8_t; // <- data type of elements in input matrix A
  using ElementInputB = int8_t; // <- data type of elements in input matrix B

  // The code section below describes the matrix layouts of the input and
  // output matrices: Row Major for matrix A, Column Major for matrix B and
  // Row Major for matrix C.
  using LayoutInputA = cutlass::layout::RowMajor;
  using LayoutInputB = cutlass::layout::ColumnMajor;
  using LayoutOutput = cutlass::layout::RowMajor;

// CUDA_ARCH is expected to be defined by the build system (e.g. -DCUDA_ARCH=800).
#if CUDA_ARCH >= 800
  using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
      ElementOutput, // <- data type of output matrix
      128 / cutlass::sizeof_bits<ElementOutput>::value, // <- number of elements per
                                                        // vectorized memory access;
                                                        // for FP32 output this is
                                                        // 128 / 32 = 4. It is also
                                                        // the vector width of the
                                                        // math instructions in the
                                                        // epilogue
      ElementAccumulator,    // <- data type of accumulator
      ElementComputeEpilogue // <- data type of alpha/beta in the linear
                             // combination function
      >;
  using Gemm = cutlass::gemm::device::Gemm<
      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 128, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
#elif CUDA_ARCH >= 750
  using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
      ElementOutput, // <- data type of output matrix
      128 / cutlass::sizeof_bits<ElementOutput>::value, // <- number of elements per
                                                        // vectorized memory access;
                                                        // for FP32 output this is
                                                        // 128 / 32 = 4
      ElementAccumulator,    // <- data type of accumulator
      ElementComputeEpilogue // <- data type of alpha/beta in the linear
                             // combination function
      >;

  using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
      ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
  using Gemm = cutlass::gemm::device::Gemm<
      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
      DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
      DefaultGemmCfg::InstructionShape,
      EpilogueOp>;
#elif CUDA_ARCH >= 700
  // SM70 falls back to a SIMT INT8 GEMM without a fused SiLU epilogue, so the
  // activation is applied afterwards with torch::silu (see USE_TORCH_SILU below).
  #define USE_TORCH_SILU
  using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
      cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
      ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
  using Gemm = cutlass::gemm::device::Gemm<
      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
      cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
      DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
      DefaultGemmCfg::InstructionShape,
      cutlass::epilogue::thread::LinearCombination<
          ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>;
#else
  #error "Unsupported CUDA arch"
#endif

  auto input_size = cutlass::MatrixCoord(M, K);
  auto weight_size = cutlass::MatrixCoord(K, N);
  auto output_size = cutlass::MatrixCoord(M, N);

  auto device = input.device();
  // use the broadcasted bias as the output
  auto out = bias.to(device).view({1, -1}).repeat({M, 1});

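  // Note: CUTLASS epilogues compute D = activation(alpha * acc + beta * C).
  // The broadcast bias tensor `out` is passed below as both the C source and
  // the D destination, so the GEMM result overwrites it in place.
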
  // constexpr int kSparse = Gemm::kSparse;
  // How many elements of A are covered per ElementE
  // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
  // The size of individual meta data
  // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
  cutlass::gemm::GemmCoord problem_size(M, N, K);

  cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(
      input.data_ptr<ElementInputA>(), LayoutInputA::packed(input_size));
  cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(
      weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
  cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(
      out.data_ptr<ElementOutput>(), LayoutOutput::packed(output_size));

  typename Gemm::Arguments arguments{
      problem_size,  // <- problem size of matrix multiplication
      input_ref,     // <- reference to matrix A on device
      weight_ref,    // <- reference to matrix B on device
      out_ref,       // <- reference to matrix C on device
      out_ref,       // <- reference to matrix D on device
      {alpha, beta}, // <- epilogue scalars: alpha scales the accumulator, beta scales C
      1};            // <- split-K slices (1 = no split-K)
  Gemm gemm_op;

  // Using the arguments, query for extra workspace required for matrix
  // multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Check whether the problem size is supported
  cutlass::Status status = gemm_op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot implement");
  }

  // Initialize CUTLASS kernel with arguments and workspace pointer
  status = gemm_op.initialize(arguments, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot initialize");
  }

  // Launch the GEMM kernel
  status = gemm_op();
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot run");
  }
#ifdef USE_TORCH_SILU
#undef USE_TORCH_SILU
  // SM70 path: the GEMM epilogue was a plain linear combination, so apply the
  // SiLU activation here instead.
  out = torch::silu(out);
#endif
  return out;
}
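
// ----------------------------------------------------------------------------
// Usage sketch (assumption): the kernel is presumably registered as a PyTorch
// extension in a separate bindings file; the commented block below only
// illustrates what such a binding could look like. The module name, docstring,
// and Python-side names are illustrative, not taken from the original repo.
//
//   #include <torch/extension.h>
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32,
//           "INT8 x INT8 -> FP32 linear with fused SiLU (CUDA)");
//   }
//
// Python side (illustrative):
//   out = ext.linear_silu_a8_w8_bfp32_ofp32(x_int8, w_int8, bias_fp32,
//                                           input_scale * weight_scale, 1.0)
// ----------------------------------------------------------------------------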