Skip to content

Commit 7524c09

Browse files
authored
Merge pull request #87 from sstamenk/rocm_enabled_warpsize_fix
warpSize is being made non-constexpr in ROCm 7.0
2 parents 48a551f + 7d4854e commit 7524c09

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

csrc/kernels.hip

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2109,7 +2109,11 @@ __global__ void kdequant_mm_int32_fp16(
21092109
#define DENORM 1.0f/127.0f
21102110
#define MAX_SPARSE_COUNT 32
21112111
#define SMEM_SIZE 8*256
2112-
#define WARP_SIZE warpSize
2112+
#if defined(__GFX9__)
2113+
#define WARP_SIZE 64
2114+
#else
2115+
#define WARP_SIZE 32
2116+
#endif
21132117
template <typename T, int SPMM_ITEMS, int BITS>
21142118
__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
21152119
{
@@ -2708,13 +2712,13 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
27082712
// load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
27092713
// 4 warps -> 4 loads per iter
27102714
// 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block
2711-
typedef hipcub::WarpReduce<float, warpSize> WarpReduce;
2712-
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/warpSize];
2715+
typedef hipcub::WarpReduce<float, WARP_SIZE> WarpReduce;
2716+
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/WARP_SIZE];
27132717

2714-
const int warp_idx = threadIdx.x / warpSize;
2715-
const int warp_lane = threadIdx.x % warpSize;
2716-
const int row_B = (THREADS/warpSize)*blockIdx.x + warp_idx;
2717-
const int offset_B = ldb*row_B;
2718+
const int warp_idx = threadIdx.x / WARP_SIZE;
2719+
const int warp_lane = threadIdx.x % WARP_SIZE;
2720+
const int row_B = (THREADS/WARP_SIZE)*blockIdx.x + warp_idx;
2721+
const int offset_B = ldb * row_B;
27182722
const int num_values_8bit = num_values_4bit/2;
27192723
float local_C = 0.0f;
27202724

@@ -2732,7 +2736,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
27322736

27332737
// A: [1, K]
27342738
// B: [M, K]
2735-
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += warpSize*num_values_4bit)
2739+
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += WARP_SIZE*num_values_4bit)
27362740
{
27372741
const int inner_idx_halved = inner_idx/2;
27382742

csrc/ops.hip

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@
2020

2121
#define ERR_NOT_IMPLEMENTED 100
2222

23+
#if defined(__GFX9__)
24+
#define WARP_SIZE 64
25+
#else
26+
#define WARP_SIZE 32
27+
#endif
28+
2329
using namespace BinSearch;
2430
using std::cout;
2531
using std::endl;
@@ -692,7 +698,7 @@ template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int
692698
//warpsize - 32
693699
int num_blocks = (m+3)/4;
694700
//warpsize - 64
695-
if (warpSize == 64) {
701+
if (WARP_SIZE == 64) {
696702
num_blocks = (m+1)/2;
697703
}
698704

0 commit comments

Comments
 (0)