diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake index fa4a608f6161..11224a8d1f90 100644 --- a/cmake/modules/contrib/CUTLASS.cmake +++ b/cmake/modules/contrib/CUTLASS.cmake @@ -58,6 +58,7 @@ if(USE_CUDA AND USE_CUTLASS) if (CMAKE_CUDA_ARCHITECTURES MATCHES "90a") list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm.cu) list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm.cu) + list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_gemm.cu) endif() if(TVM_CUTLASS_RUNTIME_SRCS) add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS}) diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 8925080abfbc..c9a01fc24e06 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -194,11 +194,13 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, &bias->data, sizeof(float*))); } - if (scaleA != nullptr && scaleB != nullptr) { + if (scaleA != nullptr) { auto scaleA_data = static_cast(scaleA->data) + scaleA->byte_offset; - auto scaleB_data = static_cast(scaleB->data) + scaleB->byte_offset; CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scaleA_data, sizeof(float*))); + } + if (scaleB != nullptr) { + auto scaleB_data = static_cast(scaleB->data) + scaleB->byte_offset; CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scaleB_data, sizeof(float*))); } diff --git a/src/runtime/contrib/cutlass/fp8_gemm.cu b/src/runtime/contrib/cutlass/fp8_gemm.cu new file mode 100644 index 000000000000..67e502a163cc --- /dev/null +++ b/src/runtime/contrib/cutlass/fp8_gemm.cu @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "../cublas/cublas_utils.h" +#include "gemm_runner.cuh" + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +struct KernelTraitsM64 { + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using TileShape = Shape<_64, _64, _128>; + using ClusterShape = Shape<_1, _8, _1>; +}; + +namespace tvm { +namespace runtime { + +template +void tvm_cutlass_fp8_gemm(NDArray x, NDArray weight, NDArray workspace, NDArray alpha, + NDArray out) { + // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. + // Recommended size is 4MB. 
+ auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream"); + ICHECK(func != nullptr); + CHECK_GE(x->ndim, 2); + CHECK_EQ(weight->ndim, 2); + CHECK_EQ(workspace->ndim, 1); + CHECK_GE(out->ndim, 2); + CHECK_EQ(alpha->dtype.code, kDLFloat); + CHECK_EQ(alpha->dtype.bits, 32); + CHECK_EQ(alpha->ndim, 1); + CHECK_EQ(alpha->shape[0], 1); + int64_t m = 1; + for (int i = 0; i < x->ndim - 1; ++i) { + m *= x->shape[i]; + } + int64_t n = weight->shape[0]; + CHECK_EQ(x->shape[x->ndim - 1], weight->shape[1]) << "Only col-major weight is supported now."; + int64_t k = x->shape[x->ndim - 1]; + const float* beta = nullptr; + cudaStream_t stream = static_cast((*func)().operator void*()); + if (m <= 64) { + cutlass_gemm( + static_cast(x->data), static_cast(weight->data), + static_cast(workspace->data), workspace->shape[0], m, n, k, + static_cast(alpha->data), beta, static_cast(out->data), stream); + } else { + tvm::contrib::CuBlasLtThreadEntry* cublas_entry = + tvm::contrib::CuBlasLtThreadEntry::ThreadLocal(); + tvm::contrib::CallCublasLt(cublas_entry->handle, stream, cublas_entry->matmul_pref_desc, + x.operator->(), weight.operator->(), nullptr, alpha.operator->(), + nullptr, out.operator->(), /*transa=*/false, /*transb=*/true, + cublas_entry->workspace_ptr, cublas_entry->workspace_size, + CUBLASLT_EPILOGUE_DEFAULT, std::nullopt); + } +} + +TVM_REGISTER_GLOBAL("cutlass.gemm_e5m2_e5m2_fp16") + .set_body_typed( + tvm_cutlass_fp8_gemm); + +TVM_REGISTER_GLOBAL("cutlass.gemm_e5m2_e4m3_fp16") + .set_body_typed( + tvm_cutlass_fp8_gemm); + +TVM_REGISTER_GLOBAL("cutlass.gemm_e4m3_e4m3_fp16") + .set_body_typed( + tvm_cutlass_fp8_gemm); + +} // namespace runtime +} // namespace tvm + +#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED diff --git a/src/runtime/contrib/cutlass/gemm_runner.cuh b/src/runtime/contrib/cutlass/gemm_runner.cuh new file mode 100644 index 000000000000..c664f6cf6f0b --- /dev/null +++ b/src/runtime/contrib/cutlass/gemm_runner.cuh @@ -0,0 +1,155 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "../../cuda/cuda_common.h" + +// clang-format off +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +// clang-format on + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + CHECK(error == cutlass::Status::kSuccess) \ + << "Got cutlass error: " << cutlassGetStatusString(error); \ + } + +using namespace cute; +using ProblemShape = Shape; // + +template +struct CutlassGemmRunner { + static constexpr int AlignmentA = + 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of elements + // (up to 16 bytes) + + static constexpr int AlignmentB = + 128 / cutlass::sizeof_bits::value; // Alignment of B matrix in 
units of elements + // (up to 16 bytes) + + static constexpr int AlignmentC = + 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of elements + // (up to 16 bytes) + + // Core kernel configurations + using ElementAccumulator = float; // Element type for internal accumulation + using ScaleType = std::variant; + using ArchTag = + cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TileShape = typename KernelTraits::TileShape; + using ClusterShape = typename KernelTraits::ClusterShape; + using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size + using KernelSchedule = typename KernelTraits::KernelSchedule; // Kernel to launch + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; // Epilogue to launch + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, ElementC, LayoutC, AlignmentC, EpilogueSchedule>::CollectiveOp; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + void run_gemm(const ElementA* ptr_A, const 
ElementB* ptr_B, const ElementC* ptr_C, + ElementC* ptr_D, ProblemShape* problem_size, StrideA* stride_A, StrideB* stride_B, + StrideC* stride_C, StrideD* stride_D, uint8_t* workspace, int64_t workspace_size, + ScaleType alpha, ScaleType beta, cudaStream_t stream) { + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm, + *problem_size, + {ptr_A, *stride_A, ptr_B, *stride_B}, + {{}, ptr_C, *stride_C, ptr_D, *stride_D}, + // {epilogue_params, ptr_C, *stride_C, ptr_D, *stride_D}, + hw_info}; + + ICHECK(alpha.index() == beta.index()) << "alpha and beta must have the same type"; + if (std::holds_alternative(alpha)) { + arguments.epilogue.thread.alpha = std::get(alpha); + arguments.epilogue.thread.beta = std::get(beta); + } else if (std::holds_alternative(alpha)) { + arguments.epilogue.thread.alpha_ptr = std::get(alpha); + arguments.epilogue.thread.beta_ptr = std::get(beta); + } else { + LOG(FATAL) << "Unsupported alpha and beta type"; + throw; + } + + Gemm gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments)); + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream)); + CUTLASS_CHECK(gemm_op.run(stream)); + } +}; + +template +void cutlass_gemm(ElementA* x, ElementB* weight, uint8_t* workspace, int64_t workspace_size, + int64_t m, int64_t n, int64_t k, std::variant alpha, + std::variant beta, ElementC* out, cudaStream_t stream) { + using Runner = CutlassGemmRunner; + using StrideA = typename Runner::StrideA; + using StrideB = typename Runner::StrideB; + using StrideC = typename Runner::StrideC; + + Runner runner; + StrideA stride_A = cute::make_stride(k, Int<1>{}, int64_t{0}); + StrideB stride_B = cute::make_stride(k, Int<1>{}, int64_t{0}); + StrideC stride_D = cute::make_stride(n, Int<1>{}, 
int64_t{0}); + ProblemShape problem_size{static_cast(m), static_cast(n), static_cast(k)}; + runner.run_gemm(x, weight, out, out, &problem_size, &stride_A, &stride_B, &stride_D, &stride_D, + workspace, workspace_size, alpha, beta, stream); +} diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 154a68e1169c..bc80323b753e 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -15,26 +15,27 @@ # specific language governing permissions and limitations # under the License. import logging -import tempfile import math +import tempfile + import ml_dtypes +import numpy as np + import tvm -from tvm import relay +import tvm.testing +from tvm import auto_scheduler, relay from tvm.contrib.cudnn import conv_output_shape -import numpy as np -from tvm.relay import op as _op -from tvm.runtime.vm import VirtualMachine -from tvm.relay.op.contrib.cutlass import partition_for_cutlass -from tvm import auto_scheduler -from tvm.relay.transform import FirstOrderGradient, ToMixedPrecision, InferType from tvm.contrib.cutlass import ( - has_cutlass, - num_cutlass_partitions, finalize_modules, finalize_modules_vm, + has_cutlass, + num_cutlass_partitions, ) from tvm.contrib.pickle_memoize import memoize -import tvm.testing +from tvm.relay import op as _op +from tvm.relay.op.contrib.cutlass import partition_for_cutlass +from tvm.relay.transform import FirstOrderGradient, InferType, ToMixedPrecision +from tvm.runtime.vm import VirtualMachine logging.basicConfig(level=logging.INFO) @@ -1189,13 +1190,13 @@ def test_group_gemm_sm90(): atol=1, ) verify_group_gemm( - "cutlass.group_gemm_e4m3_e5m2_fp16", + "cutlass.group_gemm_e5m2_e4m3_fp16", 8, 16, 16, 4, - "e4m3_float8", "e5m2_float8", + "e4m3_float8", "float16", True, rtol=1e-1, @@ -1203,5 +1204,85 @@ def test_group_gemm_sm90(): ) +def verify_gemm(func_name, M, N, K, x_dtype, weight_dtype, out_dtype, scale_value, rtol, atol): + gemm_func = 
tvm.get_global_func(func_name, allow_missing=True) + if gemm_func is None: + print(f"Skipped as {func_name} is not available") + return + + @memoize("tvm.contrib.cutlass.test_fp8_gemm_sm90") + def get_ref_data(): + a_np = get_random_ndarray((M, K), "float16") + b_np = get_random_ndarray((N, K), "float16") + c_np = a_np @ b_np.T * scale_value + return a_np, b_np, c_np + + def to_numpy_dtype(dtype): + mapping = {"e5m2_float8": ml_dtypes.float8_e5m2, "e4m3_float8": ml_dtypes.float8_e4m3fn} + return mapping.get(dtype, dtype) + + a_np, b_np, c_np = get_ref_data() + dev = tvm.cuda(0) + a_nd = tvm.nd.array(a_np.astype(to_numpy_dtype(x_dtype)), device=dev) + b_nd = tvm.nd.array(b_np.astype(to_numpy_dtype(weight_dtype)), device=dev) + c_nd = tvm.nd.empty(c_np.shape, dtype=out_dtype, device=dev) + workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=dev) + scale = tvm.nd.array(np.array([scale_value], dtype="float32"), device=dev) + gemm_func(a_nd, b_nd, workspace, scale, c_nd) + tvm.testing.assert_allclose(c_nd.asnumpy(), c_np, rtol=rtol, atol=atol) + + +@tvm.testing.requires_cutlass +def test_fp8_gemm_sm90(): + verify_gemm( + "cutlass.gemm_e5m2_e5m2_fp16", + 8, + 16, + 16, + "e5m2_float8", + "e5m2_float8", + "float16", + 1.5, + rtol=1e-1, + atol=1, + ) + verify_gemm( + "cutlass.gemm_e4m3_e4m3_fp16", + 8, + 16, + 16, + "e4m3_float8", + "e4m3_float8", + "float16", + 1.5, + rtol=1e-1, + atol=1, + ) + verify_gemm( + "cutlass.gemm_e4m3_e4m3_fp16", + 32, + 16, + 16, + "e4m3_float8", + "e4m3_float8", + "float16", + 1.5, + rtol=1e-1, + atol=1, + ) + verify_gemm( + "cutlass.gemm_e5m2_e4m3_fp16", + 8, + 16, + 16, + "e5m2_float8", + "e4m3_float8", + "float16", + 1.5, + rtol=1e-1, + atol=1, + ) + + if __name__ == "__main__": tvm.testing.main()