 #define NUM 4
 #define NUM_BLOCK 4096
 
-__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
+__device__ static float fp4_dequantization_lut[8] = {
+    0.0f,            // 0b000
+    0.005208333333f, // 0b001
+    0.66666667f,     // 0b010
+    1.0f,            // 0b011
+    0.33333333f,     // 0b100
+    0.5f,            // 0b101
+    0.16666667f,     // 0b110
+    0.25f            // 0b111
+};
+
+__device__ static float nf4_dequantization_lut[16] = {
+    -1.0f,                 // 0b0000
+    -0.6961928009986877f,  // 0b0001
+    -0.5250730514526367f,  // 0b0010
+    -0.39491748809814453f, // 0b0011
+    -0.28444138169288635f, // 0b0100
+    -0.18477343022823334f, // 0b0101
+    -0.09105003625154495f, // 0b0110
+    0.0f,                  // 0b0111
+    0.07958029955625534f,  // 0b1000
+    0.16093020141124725f,  // 0b1001
+    0.24611230194568634f,  // 0b1010
+    0.33791524171829224f,  // 0b1011
+    0.44070982933044434f,  // 0b1100
+    0.5626170039176941f,   // 0b1101
+    0.7229568362236023f,   // 0b1110
+    1.0f                   // 0b1111
+};
 
 // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
 // Luckily we have atomicMax and atomicMin in ROCm
 
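The linked StackOverflow question covers the usual CUDA workaround for float `atomicMax`, which this ROCm port does not need because HIP exposes float `atomicMax`/`atomicMin` directly. For reference only, a sketch of the CAS-loop workaround from that thread (the helper name is illustrative and nothing like it is used in this file):

```cpp
// Classic atomicCAS-based float atomicMax, as in the linked StackOverflow answer.
// Shown only for context; ROCm/HIP provides atomicMax(float*, float) natively.
__device__ static float atomicMaxFloat(float* address, float val)
{
    int* address_as_int = reinterpret_cast<int*>(address);
    int old = *address_as_int, assumed;
    do {
        assumed = old;
        // Swap in the max of the stored value and val; retry if another thread won the race.
        old = atomicCAS(address_as_int, assumed,
                        __float_as_int(fmaxf(val, __int_as_float(assumed))));
    } while (assumed != old);
    return __int_as_float(old);
}
```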
-
-__device__ float dDequantizeFP4Tree(unsigned char val, float absmax)
-{
-  float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
-  if((val & 0b0100) == 4) // 0
-    if((val & 0b0010) == 2) // 01
-      if((val & 0b0001) == 1) // 111
-        return 0.25000000f*absmax*sign; // 1111
-      else
-        return 0.16666667f*absmax*sign; // 1110
-    else
-      if((val & 0b0001) == 1) // 110
-        return 0.50000000f*absmax*sign; // 1101
-      else
-        return 0.33333333f*absmax*sign; // 1100
-  else
-    if((val & 0b0010) == 2) // 10
-      if((val & 0b0001) == 1) // 101
-        return 1.00000000f*absmax*sign; // 1011
-      else
-        return 0.66666667f*absmax*sign; // 1010
-    else
-      if((val & 0b0001) == 1) // 100
-        return 5.208333333e-03f*absmax*sign; // 1001
-      else
-        return 0.00000000f*absmax*sign; // 1000
+__device__ __forceinline__ float dDequantizeFP4Tree(unsigned char val) {
+    float sign = 1.0f - 2 * ((val & 0b1000) >> 3);
+    return fp4_dequantization_lut[val & 0b111] * sign;
 }
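The rewrite above replaces the nested branch tree with a lookup on the low three bits plus a branch-free sign derived from bit 3, and moves the absmax scaling out to the caller (see the kDequantizeBlockwise hunk further down). A minimal host-side cross-check that the LUT path reproduces the removed tree, assuming the condensed `fp4_tree` reference below faithfully mirrors the old branches (all names here are illustrative, not part of the patch):

```cpp
// Host-side check: LUT + arithmetic sign vs. the removed branch tree, for all 16 codes.
#include <cassert>
#include <cmath>

static const float fp4_lut[8] = {0.0f, 0.005208333333f, 0.66666667f, 1.0f,
                                 0.33333333f, 0.5f,      0.16666667f, 0.25f};

// Condensed form of the removed dDequantizeFP4Tree (same magnitudes, same branch order).
static float fp4_tree(unsigned char val, float absmax)
{
    float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
    if ((val & 0b0100) == 4) {
        if ((val & 0b0010) == 2)
            return ((val & 0b0001) == 1 ? 0.25000000f : 0.16666667f) * absmax * sign;
        return ((val & 0b0001) == 1 ? 0.50000000f : 0.33333333f) * absmax * sign;
    }
    if ((val & 0b0010) == 2)
        return ((val & 0b0001) == 1 ? 1.00000000f : 0.66666667f) * absmax * sign;
    return ((val & 0b0001) == 1 ? 5.208333333e-03f : 0.00000000f) * absmax * sign;
}

int main()
{
    const float absmax = 3.5f; // arbitrary per-block scale
    for (unsigned v = 0; v < 16; ++v) {
        float sign = 1.0f - 2 * ((v & 0b1000) >> 3);
        float lut_val = fp4_lut[v & 0b111] * sign * absmax; // new path: scale applied by caller
        assert(std::fabs(lut_val - fp4_tree((unsigned char)v, absmax)) < 1e-6f);
    }
    return 0;
}
```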
 
 __device__ unsigned char dQuantizeFP4(float x)
@@ -101,61 +106,7 @@ __device__ unsigned char dQuantizeFP4(float x)
   return 0b0000+sign;
 }
 
-
-__device__ __forceinline__ float dDequantizeNF4(unsigned char val)
-{
-
-  // the values for this tree was generated by test_normal_map_tree
-  // in the file tests/test_functional.py
-  if((val & 0b1000) == 8)
-    if((val & 0b0100) == 4) // 1
-      if((val & 0b0010) == 2) // 11
-        if((val & 0b0001) == 1) // 111
-          return 1.0f;
-        else
-          return 0.7229568362236023f;
-      else
-        if((val & 0b0001) == 1) // 110
-          return 0.5626170039176941f;
-        else
-          return 0.44070982933044434f;
-    else
-      if((val & 0b0010) == 2) // 10
-        if((val & 0b0001) == 1) // 101
-          return 0.33791524171829224f;
-        else
-          return 0.24611230194568634f;
-      else
-        if((val & 0b0001) == 1) // 100
-          return 0.16093020141124725f;
-        else
-          return 0.07958029955625534f;
-
-  else
-    if((val & 0b0100) == 4) // 0
-      if((val & 0b0010) == 2) // 01
-        if((val & 0b0001) == 1) // 011
-          return 0.0f;
-        else
-          return -0.09105003625154495f;
-      else
-        if((val & 0b0001) == 1) // 010
-          return -0.18477343022823334f;
-        else
-          return -0.28444138169288635f;
-    else
-      if((val & 0b0010) == 2) // 00
-        if((val & 0b0001) == 1) // 001
-          return -0.39491748809814453f;
-        else
-          return -0.5250730514526367f;
-      else
-        if((val & 0b0001) == 1) // 000
-          return -0.6961928009986877f;
-        else
-          return -1.0f;
-
-}
+__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_dequantization_lut[val & 0x0F]; }
 
 __device__ unsigned char dQuantizeNF4(float x)
 {
@@ -456,7 +407,6 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
         LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
       }
 
-      unsigned char packed_4bit = 0;
       switch(DATA_TYPE)
       {
         case General8bit:
@@ -473,18 +423,16 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
           #pragma unroll NUM_PER_TH
           for(int j = 0; j < NUM_PER_TH/2; j++)
           {
-            packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
-            packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
-            qvals[j] = packed_4bit;
+            qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
+            qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
           }
           break;
         case NF4:
           #pragma unroll NUM_PER_TH
           for(int j = 0; j < NUM_PER_TH/2; j++)
           {
-            packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
-            packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
-            qvals[j] = packed_4bit;
+            qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
+            qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
           }
           break;
       }
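Two 4-bit codes are packed per output byte: element 2*j goes into the high nibble and element 2*j+1 into the low nibble. Writing the high nibble with `=` and OR-ing in the low nibble keeps every byte self-contained, so the shared `packed_4bit` accumulator (which was never reset between iterations) is no longer needed. Below is a small host-side round-trip sketch of that packing convention; `nearest_nf4` is an illustrative stand-in for the kernel's `dQuantizeNF4`, and plain division/multiplication by `absmax` stands in for the scaling the kernel folds into `local_abs_max`:

```cpp
// Round-trip sketch of the nibble packing used above (high nibble = element 2*j,
// low nibble = element 2*j+1). Names are illustrative, not from the kernel file.
#include <cmath>
#include <cstdio>

static const float nf4_lut[16] = {
    -1.0f, -0.6961928009986877f, -0.5250730514526367f, -0.39491748809814453f,
    -0.28444138169288635f, -0.18477343022823334f, -0.09105003625154495f, 0.0f,
    0.07958029955625534f, 0.16093020141124725f, 0.24611230194568634f,
    0.33791524171829224f, 0.44070982933044434f, 0.5626170039176941f,
    0.7229568362236023f, 1.0f};

// Nearest-codebook-entry quantizer, standing in for the kernel's dQuantizeNF4.
static unsigned char nearest_nf4(float x) {
    unsigned char best = 0;
    for (unsigned char i = 1; i < 16; ++i)
        if (std::fabs(x - nf4_lut[i]) < std::fabs(x - nf4_lut[best])) best = i;
    return best;
}

int main() {
    const float absmax = 2.0f;       // per-block scale
    const float a = 0.9f, b = -1.3f; // two adjacent elements of a block
    // Quantize: map into [-1, 1] with the block scale, then pack two codes per byte.
    unsigned char packed = nearest_nf4(a / absmax) << 4;
    packed |= nearest_nf4(b / absmax);
    // Dequantize: unpack the nibbles, look up, and rescale by absmax.
    float a_hat = nf4_lut[packed >> 4] * absmax;
    float b_hat = nf4_lut[packed & 0x0F] * absmax;
    std::printf("%f -> %f, %f -> %f\n", a, a_hat, b, b_hat);
    return 0;
}
```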
@@ -546,8 +494,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
         #pragma unroll NUM_PER_TH
         for(int j = 0; j < NUM_PER_TH; j++)
         {
-          vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
-          vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
+          vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4) * local_abs_max;
+          vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F) * local_abs_max;
         }
         break;
       case NF4:
@@ -2109,7 +2057,11 @@ __global__ void kdequant_mm_int32_fp16(
21092057#define DENORM 1 .0f /127 .0f
21102058#define MAX_SPARSE_COUNT 32
21112059#define SMEM_SIZE 8 *256
2112- #define WARP_SIZE warpSize
2060+ #if defined(__GFX9__)
2061+ #define WARP_SIZE 64
2062+ #else
2063+ #define WARP_SIZE 32
2064+ #endif
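In HIP, `warpSize` is a device-side value rather than a compile-time constant, so it cannot be used as a template argument or a static array bound; the macro above pins the wavefront width at compile time instead (64 on gfx9 wave64 hardware, 32 otherwise), and the later hunks substitute it into the `hipcub::WarpReduce` typedef and the shared-memory sizing. A hedged host-side sketch of a sanity check that the compile-time choice matches the device actually in use (the check itself is not part of this patch):

```cpp
// Minimal host-side sanity check: verify the compile-time WARP_SIZE matches the
// wavefront size reported by the active device. Illustrative only.
#include <hip/hip_runtime.h>
#include <cstdio>
#include <cstdlib>

#ifndef WARP_SIZE
#define WARP_SIZE 64 // assumption: built for a gfx9 (wave64) target
#endif

int main() {
    hipDeviceProp_t props;
    if (hipGetDeviceProperties(&props, 0) != hipSuccess) {
        std::fprintf(stderr, "no HIP device available\n");
        return EXIT_FAILURE;
    }
    if (props.warpSize != WARP_SIZE) {
        std::fprintf(stderr, "compile-time WARP_SIZE=%d != device warpSize=%d\n",
                     WARP_SIZE, props.warpSize);
        return EXIT_FAILURE;
    }
    std::printf("WARP_SIZE=%d matches the device\n", WARP_SIZE);
    return 0;
}
```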
 template <typename T, int SPMM_ITEMS, int BITS>
 __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
 {
@@ -2503,7 +2455,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
 
   #pragma unroll 16
   for(int i = 0; i < 16; i++)
-    quant_map[i] = nf4_data[i];
+    quant_map[i] = nf4_dequantization_lut[i];
   // __shared__ T quant_map[16*160];
 
   T local_A[2];
@@ -2708,13 +2660,13 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
   // load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
   // 4 warps -> 4 loads per iter
   // 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block
-  typedef hipcub::WarpReduce<float, warpSize> WarpReduce;
-  __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/warpSize];
+  typedef hipcub::WarpReduce<float, WARP_SIZE> WarpReduce;
+  __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/WARP_SIZE];
 
-  const int warp_idx = threadIdx.x / warpSize;
-  const int warp_lane = threadIdx.x % warpSize;
-  const int row_B = (THREADS/warpSize)*blockIdx.x + warp_idx;
-  const int offset_B = ldb*row_B;
+  const int warp_idx = threadIdx.x / WARP_SIZE;
+  const int warp_lane = threadIdx.x % WARP_SIZE;
+  const int row_B = (THREADS/WARP_SIZE)*blockIdx.x + warp_idx;
+  const int offset_B = ldb * row_B;
   const int num_values_8bit = num_values_4bit/2;
   float local_C = 0.0f;
 
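`hipcub::WarpReduce` takes its logical warp width as a template parameter, which is the concrete reason `warpSize` had to become the compile-time `WARP_SIZE` in this hunk. A standalone sketch of the same per-warp reduction pattern; the kernel and buffer names below are illustrative and not part of this file:

```cpp
// Minimal per-warp sum with hipcub::WarpReduce. WARP_SIZE must be a compile-time
// constant here, which is exactly what the macro introduced above provides.
#include <hip/hip_runtime.h>
#include <hipcub/hipcub.hpp>

#define THREADS 256
#if defined(__GFX9__)
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif

__global__ void warp_sums(const float* in, float* out, int n) {
    typedef hipcub::WarpReduce<float, WARP_SIZE> WarpReduce;
    __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / WARP_SIZE];

    const int warp_idx = threadIdx.x / WARP_SIZE;
    const int warp_lane = threadIdx.x % WARP_SIZE;
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;

    float val = tid < n ? in[tid] : 0.0f;
    // Each warp reduces its own WARP_SIZE values; lane 0 of the warp holds the result.
    float sum = WarpReduce(temp_storage[warp_idx]).Sum(val);
    if (warp_lane == 0)
        out[blockIdx.x * (THREADS / WARP_SIZE) + warp_idx] = sum;
}
```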
@@ -2732,7 +2684,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
 
   // A: [1, K]
   // B: [M, K]
-  for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += warpSize*num_values_4bit)
+  for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += WARP_SIZE*num_values_4bit)
   {
     const int inner_idx_halved = inner_idx/2;
 