Skip to content

Commit 50ac4fc

Browse files
committed
Support cutlass nvfp4 gemm
Signed-off-by: kaixih <[email protected]>
1 parent 550d97e commit 50ac4fc

File tree

7 files changed

+582
-0
lines changed

7 files changed

+582
-0
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
267267
"csrc/permute_cols.cu"
268268
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
269269
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
270+
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
270271
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
271272
"csrc/cutlass_extensions/common.cpp")
272273

@@ -383,6 +384,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
383384
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
384385
set(SRCS
385386
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
387+
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
386388
)
387389
set_gencode_flags_for_srcs(
388390
SRCS "${SRCS}"

csrc/ops.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,11 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
152152
int64_t row);
153153

154154
#ifndef USE_ROCM
155+
// Entry point for the CUTLASS NVFP4 scaled matrix multiply added by this
// commit. Declared only for non-ROCm builds (it sits under #ifndef USE_ROCM).
// D is the only non-const tensor, so it is presumably the output; A/B look
// like the operands, A_sf/B_sf their scale factors, and alpha a global
// scale -- NOTE(review): confirm shapes/dtypes against the kernel source.
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
156+
torch::Tensor const& B, torch::Tensor const& A_sf,
157+
torch::Tensor const& B_sf,
158+
torch::Tensor const& alpha);
159+
155160
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
156161
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
157162

csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <torch/all.h>
18+
19+
// Forward declaration of the architecture-specific implementation. It is
// visible only when the build defines ENABLE_NVFP4 to a non-zero value; the
// definition is not in this translation unit (presumably it lives in
// nvfp4_scaled_mm_kernels.cu -- confirm). The parameter list matches
// cutlass_scaled_fp4_mm exactly so the arguments can be forwarded as-is.
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
20+
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
21+
torch::Tensor const& B,
22+
torch::Tensor const& A_sf,
23+
torch::Tensor const& B_sf,
24+
torch::Tensor const& alpha);
25+
#endif
26+
27+
// Dispatches the NVFP4 scaled matrix multiply to the compiled kernel.
//
// When ENABLE_NVFP4 is defined and non-zero, all six arguments are forwarded
// unchanged to cutlass_scaled_fp4_mm_sm100a and the function returns.
// Otherwise the call fails through TORCH_CHECK_NOT_IMPLEMENTED with the
// message "No compiled nvfp4 mm kernel.".
//
// D is the only non-const tensor, so it is presumably the output; A/B look
// like the operands, A_sf/B_sf their scale factors, and alpha a global
// scale -- NOTE(review): confirm against the kernel implementation.
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
28+
torch::Tensor const& B, torch::Tensor const& A_sf,
29+
torch::Tensor const& B_sf,
30+
torch::Tensor const& alpha) {
31+
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
32+
return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
33+
#endif
// Fallback: only reachable when the guarded return above was compiled out.
// The condition is the literal false, so this check always fails.
34+
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel.");
35+
}

0 commit comments

Comments
 (0)