diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5baa39b6f9e5..b7bfdc6c857b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -319,7 +319,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
 
   # Only build AllSpark kernels if we are building for at least some compatible archs.
   cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
-  if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND ALLSPARK_ARCHS)
+  if (ALLSPARK_ARCHS)
     set(ALLSPARK_SRCS
       "csrc/quantization/gptq_allspark/allspark_repack.cu"
       "csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu")
@@ -330,7 +330,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}")
   else()
     message(STATUS "Not building AllSpark kernels as no compatible archs found"
-                   " in CUDA target architectures, or CUDA not >= 12.0")
+                   " in CUDA target architectures")
   endif()
 
diff --git a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu b/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu
index c4ed98ca64f8..b520f8c32b95 100644
--- a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu
+++ b/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu
@@ -437,9 +437,10 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
     for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) {
 #pragma unroll
       for (int k_idx = 0; k_idx < 2; ++k_idx) {
-        FType low16 = static_cast<FType>(C_frag[m_idx][n_idx][k_idx * 2]);
+        FType low16 =
+            ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2]);
         FType high16 =
-            static_cast<FType>(C_frag[m_idx][n_idx][k_idx * 2 + 1]);
+            ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2 + 1]);
         uint32_t tmp = (reinterpret_cast<uint32_t&>(low16) & 0xffff) |
                        (reinterpret_cast<uint32_t&>(high16) << 16);
         int sts_offset =
@@ -793,7 +794,7 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
   FT scale_reg[4];
   *(reinterpret_cast<uint2*>(scale_reg)) =
       *(reinterpret_cast<const uint2*>(scales + params_nidx));
-  FT zero_reg[4] = {0};
+  FT zero_reg[4];
   if (zeros != nullptr) {
     *(reinterpret_cast<uint2*>(zero_reg)) =
         *(reinterpret_cast<const uint2*>(zeros + params_nidx));
@@ -809,8 +810,10 @@
         reinterpret_cast<typename HalfType<FT>::T2*>(&(fval_reg[ni * 4])));
 #pragma unroll
     for (int ki = 0; ki < 4; ++ki) {
-      fval_reg[ni * 4 + ki] =
-          (fval_reg[ni * 4 + ki] - zero_reg[ni]) * scale_reg[ni];
+      if (zeros != nullptr) {
+        fval_reg[ni * 4 + ki] = __hsub(fval_reg[ni * 4 + ki], zero_reg[ni]);
+      }
+      fval_reg[ni * 4 + ki] = __hmul(fval_reg[ni * 4 + ki], scale_reg[ni]);
       int sts_offset = sts_base_offset + ((ki / 2) * 8 + (ki % 2)) * 32 +
                        ((ni + lane_id % 4) % 4) * 8;
       smem[sts_offset] = fval_reg[ni * 4 + ki];
diff --git a/csrc/quantization/gptq_allspark/allspark_utils.cuh b/csrc/quantization/gptq_allspark/allspark_utils.cuh
index 7aded9a17280..80456c25590d 100644
--- a/csrc/quantization/gptq_allspark/allspark_utils.cuh
+++ b/csrc/quantization/gptq_allspark/allspark_utils.cuh
@@ -7,6 +7,8 @@
 #include <cuda_fp16.h>
 #include <cuda_bf16.h>
 #include <cuda_runtime.h>
+#include "../gptq_marlin/marlin_dtypes.cuh"
+using marlin::ScalarType;
 
 namespace allspark {
 
@@ -66,14 +68,14 @@ __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
     return;
   }
 
-  FType sum(0);
+  float sum = 0.f;
 
   int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix;
   for (int i = 0; i < n_mat; ++i) {
-    sum += C_split[idx + i * matrix_size];
+    sum += ScalarType<FType>::num2float(C_split[idx + i * matrix_size]);
   }
 
-  C[idx] = sum;
+  C[idx] = ScalarType<FType>::float2num(sum);
 }
 
 template <typename FType>