From 7dd7b88db2782b3b943bd0184c7ad230c0d97f9e Mon Sep 17 00:00:00 2001 From: sstamenk Date: Fri, 17 Oct 2025 22:13:19 +0000 Subject: [PATCH] Reuse BNB_WARP_SIZE macro --- csrc/common_hip.cuh | 8 +++++-- csrc/kernels.hip | 51 ++++++++++++++++++++++----------------------- csrc/ops.hip | 2 +- 3 files changed, 32 insertions(+), 29 deletions(-) diff --git a/csrc/common_hip.cuh b/csrc/common_hip.cuh index 1d9d9afe0..3ea545e9a 100644 --- a/csrc/common_hip.cuh +++ b/csrc/common_hip.cuh @@ -1,7 +1,11 @@ #pragma once -#define BNB_WARP_SIZE warpSize +#ifdef __GFX9__ + #define BNB_WARP_SIZE 64 +#else + #define BNB_WARP_SIZE 32 +#endif // These are set based on current BNB support for CDNA 2 & RDNA 3. Update as needed for future archs -#define BNB_MAX_THREADS_PER_SM 2048 +#define BNB_MAX_THREADS_PER_CU 2048 #define BNB_BF16_AVAILABLE true diff --git a/csrc/kernels.hip b/csrc/kernels.hip index 6956ebac4..9b94b334d 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -1933,7 +1933,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char // rowStats [rows] // out [rows, cols] template -__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) +__launch_bounds__(1024, BNB_MAX_THREADS_PER_CU / 1024) __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) { // For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32. @@ -1997,7 +1997,7 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat } template -__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) +__launch_bounds__(1024, BNB_MAX_THREADS_PER_CU / 1024) __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) { using BlockReduceT = hipcub::BlockReduce; @@ -2109,7 +2109,6 @@ __global__ void kdequant_mm_int32_fp16( #define DENORM 1.0f/127.0f #define MAX_SPARSE_COUNT 32 #define SMEM_SIZE 8*256 -#define WARP_SIZE warpSize template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) { @@ -2130,9 +2129,9 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1]; const int local_row_idx = rowidx[offset]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int warp_idx = threadIdx.x % WARP_SIZE; - const int warp_offset = (warp_id*WARP_SIZE)*SPMM_ITEMS; + const int warp_id = threadIdx.x / BNB_WARP_SIZE; + const int warp_idx = threadIdx.x % BNB_WARP_SIZE; + const int warp_offset = (warp_id*BNB_WARP_SIZE)*SPMM_ITEMS; const int num_items = BITS == 8 ? 8 : 8; int idx_col_B = warp_offset; int local_idx_col_B_offset = 0; @@ -2152,7 +2151,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o } // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192 - // we expect each warp to be SPMM_ITEMS*WARP_SIZE apart + // we expect each warp to be SPMM_ITEMS*BNB_WARP_SIZE apart // we have a total of 128 bytes for the bank with a bank size of 4 bytes // added 3 bytes = 6 values between warps should reduce bank conflicts __shared__ half smem_dequant_stats[SMEM_SIZE]; @@ -2705,16 +2704,16 @@ template __global__ void kgemm_4bit_inferenc { // per threadblock: - // load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps] + // load step-by-step in chunks of [BNB_WARP_SIZE,warps]: 1xBNB_WARP_SIZE * [BNB_WARP_SIZE,warps] -> [1,warps] // 4 warps -> 4 loads per iter - // 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block - typedef hipcub::WarpReduce WarpReduce; - __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/warpSize]; - - const int warp_idx = threadIdx.x / warpSize; - const int warp_lane = threadIdx.x % warpSize; - const int row_B = (THREADS/warpSize)*blockIdx.x + warp_idx; - const int offset_B = ldb*row_B; + // 1xBNB_WARP_SIZE * BNB_WARP_SIZEx4 -> 1x4 outputs per thread block + typedef hipcub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/BNB_WARP_SIZE]; + + const int warp_idx = threadIdx.x / BNB_WARP_SIZE; + const int warp_lane = threadIdx.x % BNB_WARP_SIZE; + const int row_B = (THREADS/BNB_WARP_SIZE)*blockIdx.x + warp_idx; + const int offset_B = ldb * row_B; const int num_values_8bit = num_values_4bit/2; float local_C = 0.0f; @@ -2732,7 +2731,7 @@ template __global__ void kgemm_4bit_inferenc // A: [1, K] // B: [M, K] - for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += warpSize*num_values_4bit) + for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += BNB_WARP_SIZE*num_values_4bit) { const int inner_idx_halved = inner_idx/2; @@ -3044,7 +3043,7 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit) MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) -#if WARP_SIZE == 32 +#if BNB_WARP_SIZE == 32 MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) #endif @@ -3054,7 +3053,7 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4) MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) -#if WARP_SIZE == 32 +#if BNB_WARP_SIZE == 32 MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) #endif @@ -3064,7 +3063,7 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) -#if WARP_SIZE == 32 +#if BNB_WARP_SIZE == 32 MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) #endif @@ -3075,7 +3074,7 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit) MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) -#if WARP_SIZE == 32 +#if BNB_WARP_SIZE == 32 MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) #endif @@ -3085,7 +3084,7 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) -#if WARP_SIZE == 32 +#if BNB_WARP_SIZE == 32 MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) #endif @@ -3095,7 +3094,7 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) -#if WARP_SIZE == 32 +#if BNB_WARP_SIZE == 32 MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) #endif @@ -3106,7 +3105,7 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, General8bit) MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, General8bit) MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, General8bit) MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, General8bit) -#if WARP_SIZE == 32 +#if BNB_WARP_SIZE == 32 MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit) #endif @@ -3116,7 +3115,7 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4) MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, FP4) MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, FP4) MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, FP4) -#if WARP_SIZE == 32 +#if BNB_WARP_SIZE == 32 MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4) #endif @@ -3126,7 +3125,7 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4) MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, NF4) MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, NF4) MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4) -#if WARP_SIZE == 32 +#if BNB_WARP_SIZE == 32 MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4) #endif diff --git a/csrc/ops.hip b/csrc/ops.hip index 4f7ce9abb..a7ab32fdc 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -693,7 +693,7 @@ template void gemm_4bit_inference_naive(int m, int n, int //warpsize - 32 int num_blocks = (m+3)/4; //warpsize - 64 - if (warpSize == 64) { + if (BNB_WARP_SIZE == 64) { num_blocks = (m+1)/2; }