Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions csrc/common_hip.cuh
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
#pragma once

// Compile-time warp (wavefront) size for HIP builds.
// CDNA / GFX9-class GPUs (e.g. MI200/MI300) use 64-wide wavefronts;
// RDNA-class GPUs use 32-wide wavefronts. A compile-time constant is
// required here (rather than the device-side `warpSize` builtin) because
// this value is used in template arguments, __launch_bounds__, and
// preprocessor conditionals, all of which need a constant expression.
#ifdef __GFX9__
#define BNB_WARP_SIZE 64
#else
#define BNB_WARP_SIZE 32
#endif

// These are set based on current BNB support for CDNA 2 & RDNA 3. Update as needed for future archs
#define BNB_MAX_THREADS_PER_CU 2048
#define BNB_BF16_AVAILABLE true
51 changes: 25 additions & 26 deletions csrc/kernels.hip
Original file line number Diff line number Diff line change
Expand Up @@ -1933,7 +1933,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
// rowStats [rows]
// out [rows, cols]
template<typename T, int THREADS, int SPARSE_DECOMP>
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
__launch_bounds__(1024, BNB_MAX_THREADS_PER_CU / 1024)
__global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) {

// For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32.
Expand Down Expand Up @@ -1997,7 +1997,7 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat
}

template<typename T, int THREADS, int SPARSE_DECOMP>
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
__launch_bounds__(1024, BNB_MAX_THREADS_PER_CU / 1024)
__global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) {
using BlockReduceT = hipcub::BlockReduce<float, THREADS>;

Expand Down Expand Up @@ -2109,7 +2109,6 @@ __global__ void kdequant_mm_int32_fp16(
#define DENORM 1.0f/127.0f
#define MAX_SPARSE_COUNT 32
#define SMEM_SIZE 8*256
#define WARP_SIZE warpSize
template <typename T, int SPMM_ITEMS, int BITS>
__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
{
Expand All @@ -2130,9 +2129,9 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1];
const int local_row_idx = rowidx[offset];

const int warp_id = threadIdx.x / WARP_SIZE;
const int warp_idx = threadIdx.x % WARP_SIZE;
const int warp_offset = (warp_id*WARP_SIZE)*SPMM_ITEMS;
const int warp_id = threadIdx.x / BNB_WARP_SIZE;
const int warp_idx = threadIdx.x % BNB_WARP_SIZE;
const int warp_offset = (warp_id*BNB_WARP_SIZE)*SPMM_ITEMS;
const int num_items = BITS == 8 ? 8 : 8;
int idx_col_B = warp_offset;
int local_idx_col_B_offset = 0;
Expand All @@ -2152,7 +2151,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
}

// each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=8192
// we expect each warp to be SPMM_ITEMS*WARP_SIZE apart
// we expect each warp to be SPMM_ITEMS*BNB_WARP_SIZE apart
// we have a total of 128 bytes for the bank with a bank size of 4 bytes
// added 3 bytes = 6 values between warps should reduce bank conflicts
__shared__ half smem_dequant_stats[SMEM_SIZE];
Expand Down Expand Up @@ -2705,16 +2704,16 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
{

// per threadblock:
// load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
// load step-by-step in chunks of [BNB_WARP_SIZE,warps]: 1xBNB_WARP_SIZE * [BNB_WARP_SIZE,warps] -> [1,warps]
// 4 warps -> 4 loads per iter
// 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block
typedef hipcub::WarpReduce<float, warpSize> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/warpSize];

const int warp_idx = threadIdx.x / warpSize;
const int warp_lane = threadIdx.x % warpSize;
const int row_B = (THREADS/warpSize)*blockIdx.x + warp_idx;
const int offset_B = ldb*row_B;
// 1xBNB_WARP_SIZE * BNB_WARP_SIZEx4 -> 1x4 outputs per thread block
typedef hipcub::WarpReduce<float, BNB_WARP_SIZE> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/BNB_WARP_SIZE];

const int warp_idx = threadIdx.x / BNB_WARP_SIZE;
const int warp_lane = threadIdx.x % BNB_WARP_SIZE;
const int row_B = (THREADS/BNB_WARP_SIZE)*blockIdx.x + warp_idx;
const int offset_B = ldb * row_B;
const int num_values_8bit = num_values_4bit/2;
float local_C = 0.0f;

Expand All @@ -2732,7 +2731,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc

// A: [1, K]
// B: [M, K]
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += warpSize*num_values_4bit)
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += BNB_WARP_SIZE*num_values_4bit)
{
const int inner_idx_halved = inner_idx/2;

Expand Down Expand Up @@ -3044,7 +3043,7 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit)
#if WARP_SIZE == 32
#if BNB_WARP_SIZE == 32
MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
#endif

Expand All @@ -3054,7 +3053,7 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
#if WARP_SIZE == 32
#if BNB_WARP_SIZE == 32
MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
#endif

Expand All @@ -3064,7 +3063,7 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
#if WARP_SIZE == 32
#if BNB_WARP_SIZE == 32
MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
#endif

Expand All @@ -3075,7 +3074,7 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
#if WARP_SIZE == 32
#if BNB_WARP_SIZE == 32
MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
#endif

Expand All @@ -3085,7 +3084,7 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
#if WARP_SIZE == 32
#if BNB_WARP_SIZE == 32
MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
#endif

Expand All @@ -3095,7 +3094,7 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4)
MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
#if WARP_SIZE == 32
#if BNB_WARP_SIZE == 32
MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
#endif

Expand All @@ -3106,7 +3105,7 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, General8bit)
#if WARP_SIZE == 32
#if BNB_WARP_SIZE == 32
MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit)
#endif

Expand All @@ -3116,7 +3115,7 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, FP4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, FP4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, FP4)
#if WARP_SIZE == 32
#if BNB_WARP_SIZE == 32
MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4)
#endif

Expand All @@ -3126,7 +3125,7 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, NF4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4)
#if WARP_SIZE == 32
#if BNB_WARP_SIZE == 32
MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4)
#endif

Expand Down
2 changes: 1 addition & 1 deletion csrc/ops.hip
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int
//warpsize - 32
int num_blocks = (m+3)/4;
//warpsize - 64
if (warpSize == 64) {
if (BNB_WARP_SIZE == 64) {
num_blocks = (m+1)/2;
}

Expand Down