Skip to content

Commit 4e70e4a

Browse files
authored
[CUTLASS] Add FP8 gemm kernels (#17408)
This PR introduces the sm90a FP8 kernels from CUTLASS. These kernels are helpful in the cases of small `M`, where cuBLAS has unoptimized performance.
1 parent 5648a8e commit 4e70e4a

File tree

5 files changed

+349
-15
lines changed

5 files changed

+349
-15
lines changed

cmake/modules/contrib/CUTLASS.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ if(USE_CUDA AND USE_CUTLASS)
5858
if (CMAKE_CUDA_ARCHITECTURES MATCHES "90a")
5959
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm.cu)
6060
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm.cu)
61+
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_gemm.cu)
6162
endif()
6263
if(TVM_CUTLASS_RUNTIME_SRCS)
6364
add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS})

src/runtime/contrib/cublas/cublas.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,11 +194,13 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
194194
&bias->data, sizeof(float*)));
195195
}
196196

197-
if (scaleA != nullptr && scaleB != nullptr) {
197+
if (scaleA != nullptr) {
198198
auto scaleA_data = static_cast<char*>(scaleA->data) + scaleA->byte_offset;
199-
auto scaleB_data = static_cast<char*>(scaleB->data) + scaleB->byte_offset;
200199
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
201200
&scaleA_data, sizeof(float*)));
201+
}
202+
if (scaleB != nullptr) {
203+
auto scaleB_data = static_cast<char*>(scaleB->data) + scaleB->byte_offset;
202204
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
203205
&scaleB_data, sizeof(float*)));
204206
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include <cuda_fp16.h>
21+
#include <float.h>
22+
#include <tvm/runtime/ndarray.h>
23+
#include <tvm/runtime/packed_func.h>
24+
#include <tvm/runtime/registry.h>
25+
26+
#include "../cublas/cublas_utils.h"
27+
#include "gemm_runner.cuh"
28+
29+
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
30+
31+
struct KernelTraitsM64 {
32+
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
33+
using TileShape = Shape<_64, _64, _128>;
34+
using ClusterShape = Shape<_1, _8, _1>;
35+
};
36+
37+
namespace tvm {
38+
namespace runtime {
39+
40+
template <typename ElementA, typename ElementB, typename ElementC>
41+
void tvm_cutlass_fp8_gemm(NDArray x, NDArray weight, NDArray workspace, NDArray alpha,
42+
NDArray out) {
43+
// Workspace is used for storing device-side gemm arguments and cutlass internal workspace.
44+
// Recommened size is 4MB.
45+
auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
46+
ICHECK(func != nullptr);
47+
CHECK_GE(x->ndim, 2);
48+
CHECK_EQ(weight->ndim, 2);
49+
CHECK_EQ(workspace->ndim, 1);
50+
CHECK_GE(out->ndim, 2);
51+
CHECK_EQ(alpha->dtype.code, kDLFloat);
52+
CHECK_EQ(alpha->dtype.bits, 32);
53+
CHECK_EQ(alpha->ndim, 1);
54+
CHECK_EQ(alpha->shape[0], 1);
55+
int64_t m = 1;
56+
for (int i = 0; i < x->ndim - 1; ++i) {
57+
m *= x->shape[i];
58+
}
59+
int64_t n = weight->shape[0];
60+
CHECK_EQ(x->shape[x->ndim - 1], weight->shape[1]) << "Only col-major weight is supported now.";
61+
int64_t k = x->shape[x->ndim - 1];
62+
const float* beta = nullptr;
63+
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
64+
if (m <= 64) {
65+
cutlass_gemm<KernelTraitsM64>(
66+
static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
67+
static_cast<uint8_t*>(workspace->data), workspace->shape[0], m, n, k,
68+
static_cast<float*>(alpha->data), beta, static_cast<ElementC*>(out->data), stream);
69+
} else {
70+
tvm::contrib::CuBlasLtThreadEntry* cublas_entry =
71+
tvm::contrib::CuBlasLtThreadEntry::ThreadLocal();
72+
tvm::contrib::CallCublasLt(cublas_entry->handle, stream, cublas_entry->matmul_pref_desc,
73+
x.operator->(), weight.operator->(), nullptr, alpha.operator->(),
74+
nullptr, out.operator->(), /*transa=*/false, /*transb=*/true,
75+
cublas_entry->workspace_ptr, cublas_entry->workspace_size,
76+
CUBLASLT_EPILOGUE_DEFAULT, std::nullopt);
77+
}
78+
}
79+
80+
TVM_REGISTER_GLOBAL("cutlass.gemm_e5m2_e5m2_fp16")
81+
.set_body_typed(
82+
tvm_cutlass_fp8_gemm<cutlass::float_e5m2_t, cutlass::float_e5m2_t, cutlass::half_t>);
83+
84+
TVM_REGISTER_GLOBAL("cutlass.gemm_e5m2_e4m3_fp16")
85+
.set_body_typed(
86+
tvm_cutlass_fp8_gemm<cutlass::float_e5m2_t, cutlass::float_e4m3_t, cutlass::half_t>);
87+
88+
TVM_REGISTER_GLOBAL("cutlass.gemm_e4m3_e4m3_fp16")
89+
.set_body_typed(
90+
tvm_cutlass_fp8_gemm<cutlass::float_e4m3_t, cutlass::float_e4m3_t, cutlass::half_t>);
91+
92+
} // namespace runtime
93+
} // namespace tvm
94+
95+
#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include <fstream>
21+
#include <iostream>
22+
#include <sstream>
23+
#include <variant>
24+
#include <vector>
25+
26+
#include "../../cuda/cuda_common.h"
27+
28+
// clang-format off
29+
#include "cutlass/cutlass.h"
30+
31+
#include "cute/tensor.hpp"
32+
#include "cutlass/tensor_ref.h"
33+
#include "cutlass/epilogue/collective/default_epilogue.hpp"
34+
#include "cutlass/epilogue/thread/linear_combination.h"
35+
#include "cutlass/gemm/dispatch_policy.hpp"
36+
#include "cutlass/gemm/gemm.h"
37+
#include "cutlass/gemm/collective/collective_builder.hpp"
38+
#include "cutlass/epilogue/collective/collective_builder.hpp"
39+
#include "cutlass/gemm/device/gemm_universal_adapter.h"
40+
#include "cutlass/gemm/kernel/gemm_universal.hpp"
41+
// clang-format on
42+
43+
#define CUTLASS_CHECK(status) \
44+
{ \
45+
cutlass::Status error = status; \
46+
CHECK(error == cutlass::Status::kSuccess) \
47+
<< "Got cutlass error: " << cutlassGetStatusString(error); \
48+
}
49+
50+
using namespace cute;
51+
using ProblemShape = Shape<int, int, int>; // <M, N, K>
52+
53+
template <typename KernelTraits, typename ElementA, typename ElementB, typename ElementC,
54+
typename LayoutA = cutlass::layout::RowMajor,
55+
typename LayoutB = cutlass::layout::ColumnMajor,
56+
typename LayoutC = cutlass::layout::RowMajor>
57+
struct CutlassGemmRunner {
58+
static constexpr int AlignmentA =
59+
128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements
60+
// (up to 16 bytes)
61+
62+
static constexpr int AlignmentB =
63+
128 / cutlass::sizeof_bits<ElementB>::value; // Alignment of B matrix in units of elements
64+
// (up to 16 bytes)
65+
66+
static constexpr int AlignmentC =
67+
128 / cutlass::sizeof_bits<ElementC>::value; // Alignment of C matrix in units of elements
68+
// (up to 16 bytes)
69+
70+
// Core kernel configurations
71+
using ElementAccumulator = float; // Element type for internal accumulation
72+
using ScaleType = std::variant<ElementAccumulator, const ElementAccumulator*>;
73+
using ArchTag =
74+
cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
75+
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
76+
using TileShape = typename KernelTraits::TileShape;
77+
using ClusterShape = typename KernelTraits::ClusterShape;
78+
using StageCountType =
79+
cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
80+
using KernelSchedule = typename KernelTraits::KernelSchedule; // Kernel to launch
81+
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; // Epilogue to launch
82+
83+
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
84+
ArchTag, OperatorClass, TileShape, ClusterShape,
85+
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator,
86+
ElementC, LayoutC, AlignmentC, ElementC, LayoutC, AlignmentC, EpilogueSchedule>::CollectiveOp;
87+
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
88+
ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB,
89+
ElementAccumulator, TileShape, ClusterShape,
90+
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
91+
sizeof(typename CollectiveEpilogue::SharedStorage))>,
92+
KernelSchedule>::CollectiveOp;
93+
94+
using GemmKernel =
95+
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
96+
97+
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
98+
99+
using StrideA = typename Gemm::GemmKernel::StrideA;
100+
using StrideB = typename Gemm::GemmKernel::StrideB;
101+
using StrideC = typename Gemm::GemmKernel::StrideC;
102+
using StrideD = typename Gemm::GemmKernel::StrideD;
103+
104+
void run_gemm(const ElementA* ptr_A, const ElementB* ptr_B, const ElementC* ptr_C,
105+
ElementC* ptr_D, ProblemShape* problem_size, StrideA* stride_A, StrideB* stride_B,
106+
StrideC* stride_C, StrideD* stride_D, uint8_t* workspace, int64_t workspace_size,
107+
ScaleType alpha, ScaleType beta, cudaStream_t stream) {
108+
cutlass::KernelHardwareInfo hw_info;
109+
hw_info.device_id = 0;
110+
hw_info.sm_count =
111+
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
112+
typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm,
113+
*problem_size,
114+
{ptr_A, *stride_A, ptr_B, *stride_B},
115+
{{}, ptr_C, *stride_C, ptr_D, *stride_D},
116+
// {epilogue_params, ptr_C, *stride_C, ptr_D, *stride_D},
117+
hw_info};
118+
119+
ICHECK(alpha.index() == beta.index()) << "alpha and beta must have the same type";
120+
if (std::holds_alternative<ElementAccumulator>(alpha)) {
121+
arguments.epilogue.thread.alpha = std::get<ElementAccumulator>(alpha);
122+
arguments.epilogue.thread.beta = std::get<ElementAccumulator>(beta);
123+
} else if (std::holds_alternative<const ElementAccumulator*>(alpha)) {
124+
arguments.epilogue.thread.alpha_ptr = std::get<const ElementAccumulator*>(alpha);
125+
arguments.epilogue.thread.beta_ptr = std::get<const ElementAccumulator*>(beta);
126+
} else {
127+
LOG(FATAL) << "Unsupported alpha and beta type";
128+
throw;
129+
}
130+
131+
Gemm gemm_op;
132+
CUTLASS_CHECK(gemm_op.can_implement(arguments));
133+
CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
134+
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
135+
CUTLASS_CHECK(gemm_op.run(stream));
136+
}
137+
};
138+
139+
template <typename KernelTraits, typename ElementA, typename ElementB, typename ElementC>
140+
void cutlass_gemm(ElementA* x, ElementB* weight, uint8_t* workspace, int64_t workspace_size,
141+
int64_t m, int64_t n, int64_t k, std::variant<float, const float*> alpha,
142+
std::variant<float, const float*> beta, ElementC* out, cudaStream_t stream) {
143+
using Runner = CutlassGemmRunner<KernelTraits, ElementA, ElementB, ElementC>;
144+
using StrideA = typename Runner::StrideA;
145+
using StrideB = typename Runner::StrideB;
146+
using StrideC = typename Runner::StrideC;
147+
148+
Runner runner;
149+
StrideA stride_A = cute::make_stride(k, Int<1>{}, int64_t{0});
150+
StrideB stride_B = cute::make_stride(k, Int<1>{}, int64_t{0});
151+
StrideC stride_D = cute::make_stride(n, Int<1>{}, int64_t{0});
152+
ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n), static_cast<int>(k)};
153+
runner.run_gemm(x, weight, out, out, &problem_size, &stride_A, &stride_B, &stride_D, &stride_D,
154+
workspace, workspace_size, alpha, beta, stream);
155+
}

0 commit comments

Comments
 (0)