[NVIDIA] Support nvfp4 cutlass gemm #13571
@@ -0,0 +1,37 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#if defined ENABLE_NVFP4 && ENABLE_NVFP4
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
                                  torch::Tensor const& B,
                                  torch::Tensor const& A_sf,
                                  torch::Tensor const& B_sf,
                                  torch::Tensor const& alpha);
#endif

void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
                           torch::Tensor const& B, torch::Tensor const& A_sf,
                           torch::Tensor const& B_sf,
                           torch::Tensor const& alpha) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel, vLLM should "
                                     "be compiled using CUDA 12.8 and target "
                                     "compute capability 100 or above.");
}
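The entry point above only declares and dispatches to the SM100a kernel; the torch op registration itself is not part of this excerpt. As a rough illustration of how such a function is usually exposed to Python (the namespace, schema string, and placement here are assumptions, not necessarily what this PR does), a binding could look roughly like:

#include <torch/library.h>

// Hypothetical binding sketch only; names and namespace are illustrative.
TORCH_LIBRARY_FRAGMENT(_C, m) {
  m.def(
      "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
      " Tensor block_scale_a, Tensor block_scale_b, Tensor alpha) -> ()");
  m.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
}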
@@ -0,0 +1,280 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cutlass_extensions/common.hpp"

#include "cutlass/cutlass.h"

#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"

#include "cutlass/util/packed_stride.hpp"

using namespace cute;

#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// Kernel Perf config
template <typename T>
struct KernelTraits;

template <>
struct KernelTraits<float> {
  using MmaTileShape = Shape<_128, _128, _256>;
  using ClusterShape = Shape<_1, _1, _1>;
  using PerSmTileShape_MNK = Shape<_128, _128, _256>;
};

template <>
struct KernelTraits<cutlass::half_t> {
  using MmaTileShape = Shape<_256, _256, _256>;
  using ClusterShape = Shape<_4, _4, _1>;
  using PerSmTileShape_MNK = Shape<_128, _256, _256>;
};

template <>
struct KernelTraits<cutlass::bfloat16_t> {
  using MmaTileShape = Shape<_256, _256, _256>;
  using ClusterShape = Shape<_4, _4, _1>;
  using PerSmTileShape_MNK = Shape<_128, _256, _256>;
};
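
// Note (added for clarity, not in the PR): MmaTileShape and ClusterShape are
// consumed by the mainloop CollectiveBuilder in Fp4GemmSm100 below, while
// PerSmTileShape_MNK sizes the epilogue CollectiveBuilder.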

[Review comment] future work: I think we should try to unify this in the future with

template <typename T>
struct Fp4GemmSm100 {
  // A matrix configuration
  using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
  using LayoutATag = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 32;

  // B matrix configuration
  using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
  using LayoutBTag = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 32;

  // C/D matrix configuration
  using ElementD = T;
  using ElementC = T;
  using LayoutCTag = cutlass::layout::RowMajor;
  using LayoutDTag = cutlass::layout::RowMajor;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
  // Kernel functional config
  using ElementAccumulator = float;
  using ArchTag = cutlass::arch::Sm100;
  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;

  // Kernel Perf config
  using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
  using ClusterShape = typename KernelTraits<T>::ClusterShape;
  using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag, OperatorClass, PerSmTileShape_MNK, ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
          ElementAccumulator, ElementC, LayoutCTag, AlignmentC, ElementD,
          LayoutDTag, AlignmentD,
          cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;

  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass, ElementA, LayoutATag, AlignmentA, ElementB,
          LayoutBTag, AlignmentB, ElementAccumulator, MmaTileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{}));
  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using LayoutB = decltype(cute::make_layout(make_shape(0, 0, 0), StrideB{}));
  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using LayoutC = decltype(cute::make_layout(make_shape(0, 0, 0), StrideC{}));
  using StrideD = typename Gemm::GemmKernel::StrideD;
  using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
};

template <typename T>
typename T::Gemm::Arguments args_from_options(
    at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
    at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
    int64_t M, int64_t N, int64_t K) {
  using ElementA = typename T::Gemm::ElementA;
  using ElementB = typename T::Gemm::ElementB;
  using ElementSFA = cutlass::float_ue4m3_t;
  using ElementSFB = cutlass::float_ue4m3_t;
  using ElementD = typename T::Gemm::ElementD;
  using ElementCompute = float;
  using StrideA = typename T::StrideA;
  using StrideB = typename T::StrideB;
  using StrideD = typename T::StrideD;
  using Sm100BlkScaledConfig =
      typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;

  int m = static_cast<int>(M);
  int n = static_cast<int>(N);
  int k = static_cast<int>(K);
  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1});

  auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(
      cute::make_shape(m, n, k, 1));
  auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(
      cute::make_shape(m, n, k, 1));

  typename T::Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      {// Mainloop arguments
       static_cast<ElementA const*>(A.data_ptr()), stride_A,
       static_cast<ElementB const*>(B.data_ptr()), stride_B,
       static_cast<ElementSFA const*>(A_sf.data_ptr()), layout_SFA,
       static_cast<ElementSFB const*>(B_sf.data_ptr()), layout_SFB},
      {// Epilogue arguments
       {},  // epilogue.thread
       static_cast<ElementD const*>(D.data_ptr()),
       stride_D,
       static_cast<ElementD*>(D.data_ptr()),
       stride_D}};
  auto& fusion_args = arguments.epilogue.thread;
  fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
  return arguments;
}

[Review comment] future work: we should see if we can unify this with

template <typename T>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
             at::Tensor const& A_sf, at::Tensor const& B_sf,
             at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
             cudaStream_t stream) {
  typename Fp4GemmSm100<T>::Gemm gemm;

  auto arguments =
      args_from_options<Fp4GemmSm100<T>>(D, A, B, A_sf, B_sf, alpha, m, n, k);

  size_t workspace_size = Fp4GemmSm100<T>::Gemm::get_workspace_size(arguments);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  CUTLASS_CHECK(gemm.can_implement(arguments));

  CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));

  CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
}
#else
template <typename T>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
             at::Tensor const& A_sf, at::Tensor const& B_sf,
             at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
             cudaStream_t stream) {
  TORCH_CHECK(false, "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
                     "a CUTLASS 3.8 source directory to enable support.");
}
#endif  // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

#define CHECK_TYPE(x, st, m) \
  TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
  TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, st, m) \
  CHECK_TH_CUDA(x, m);        \
  CHECK_CONTIGUOUS(x, m);     \
  CHECK_TYPE(x, st, m)

constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;

void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
                                  torch::Tensor const& B,
                                  torch::Tensor const& A_sf,
                                  torch::Tensor const& B_sf,
                                  torch::Tensor const& alpha) {
  CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
  CHECK_INPUT(B, FLOAT4_E2M1X2, "b");

  CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
  CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");

  CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");

  TORCH_CHECK(A.dim() == 2, "a must be a matrix");
  TORCH_CHECK(B.dim() == 2, "b must be a matrix");
  TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
              "a and b shapes cannot be multiplied (", A.sizes()[0], "x",
              A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");

  auto const m = A.sizes()[0];
  auto const n = B.sizes()[0];
  auto const k = A.sizes()[1] * 2;

  constexpr int alignment = 32;
  TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
              ", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
              "), k: ", k, ".");
  TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
              ", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");

  auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
  int rounded_m = round_up(m, 128);
  int rounded_n = round_up(n, 128);
  // Since k is divisible by 32 (alignment), k / 16 is guaranteed to be an
  // integer.
  int rounded_k = round_up(k / 16, 4);
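  // Worked example (illustrative, not from the PR): for m = 1000, k = 4096,
  // the expected swizzled scale-factor shape for A_sf is
  // round_up(1000, 128) x round_up(4096 / 16, 4) = 1024 x 256.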

  TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
  TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
  TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
              "scale_a and scale_b shapes cannot be multiplied (",
              A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
              "x", B_sf.sizes()[1], ")");
  TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
              "scale_a must be padded and swizzled to a shape (", rounded_m,
              "x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
              A_sf.sizes()[1], ")");
  TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
              "scale_b must be padded and swizzled to a shape (", rounded_n,
              "x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
              B_sf.sizes()[1], ")");

  auto out_dtype = D.dtype();
  at::cuda::CUDAGuard device_guard{(char)A.get_device()};
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());

  if (out_dtype == at::ScalarType::Half) {
    runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else if (out_dtype == at::ScalarType::BFloat16) {
    runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else if (out_dtype == at::ScalarType::Float) {
    runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else {
    TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
  }
}
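To make the calling contract concrete, here is a minimal caller-side sketch (added for illustration; `example_nvfp4_mm` and the tensor setup are hypothetical, not code from the PR). It only mirrors the shape and dtype checks enforced by cutlass_scaled_fp4_mm_sm100a above: k counts unpacked fp4 elements (two per byte), and the scale factors must already be in the padded, swizzled layout.

// Hypothetical usage sketch; allocates inputs with the shapes/dtypes the
// checks above require and invokes the kernel entry point. Real inputs would
// of course contain quantized data rather than uninitialized memory.
void example_nvfp4_mm(int64_t m, int64_t n, int64_t k) {
  auto opts = torch::TensorOptions().device(torch::kCUDA);
  auto round_up = [](int64_t x, int64_t y) { return (x + y - 1) / y * y; };

  // Packed fp4 operands: two e2m1 values per byte, so the last dim is k / 2.
  auto A = torch::empty({m, k / 2}, opts.dtype(at::ScalarType::Byte));
  auto B = torch::empty({n, k / 2}, opts.dtype(at::ScalarType::Byte));
  // Block scale factors, padded and swizzled to
  // (round_up(rows, 128), round_up(k / 16, 4)).
  auto A_sf = torch::empty({round_up(m, 128), round_up(k / 16, 4)},
                           opts.dtype(at::ScalarType::Float8_e4m3fn));
  auto B_sf = torch::empty({round_up(n, 128), round_up(k / 16, 4)},
                           opts.dtype(at::ScalarType::Float8_e4m3fn));
  // Global scale, read by the epilogue from device memory as a single float.
  auto alpha = torch::ones({1}, opts.dtype(at::ScalarType::Float));
  // Output is preallocated; its dtype (half / bf16 / float) picks the kernel.
  auto D = torch::empty({m, n}, opts.dtype(at::ScalarType::BFloat16));

  cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
}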
[Review comment] Just noticed this PR didn't actually update CUTLASS to 3.8 (see line 250 below)