
Commit b72b766

Authored by pnunna93, MISHANMAURYA, amcamd, Prasanth Nunna, and sstamenk
Fix for warpSize deprecation in ROCm 7.0 (bitsandbytes-foundation#1762)
* Port ROCm changes from multi-backend-refactor branch
* Update ops.py
* Update functional.py
* Update ops.py
* Update ops.py
* Update ops.py
* Update ops.py
* Update functional.py
* Update ops.py
* Update ops.py
* Update ops.py
* Update ops.py
* Update functional.py
* Update functional.py
* Update functional.py
* Update functional.py
* Update ops.py
* Update ops.py
* Update ops.py
* Update ops.py
* Update ops.py
* Update ops.py
* Update ops.py
* Update ops.py
* Update ops.py
* Update functional.py
* Update functional.py
* Update functional.py
* Update test_ops.py
* Update test_functional.py
* Update test_ops.py
* Update test_functional.py
* Update test_functional.py
* Update functional.py
* Update functional.py
* Update ops.py
* Update ops.py
* Update test_functional.py
* Update test_functional.py
* Update cextension.py
* Update cuda_specs.py
* Update cuda_specs.py
* Update test_functional.py
* Update test_linear4bit.py
* Update test_cuda_setup_evaluator.py
* Update test_functional.py
* Update modules.py
* Update modules.py
* Update ops.py
* Update test_linear4bit.py
* Update ops.py
* Update ops.py
* Update test_linear4bit.py
* Update test_linear4bit.py
* Update python-package.yml
* Update python-package.yml
* Update python-package.yml
* Update python-package.yml
* Create build-rocm.sh
* Update cuda_specs.py
* Fix trailing whitespace
* Remove conflicts.diff
* update for hipblasVersionMajor >=3
* Update test_functional.py
* Update test_linear4bit.py
* Update test_ops.py
* Update main.py
* Update test_functional.py
* Update test_linear4bit.py
* Update test_ops.py
* Update test_linear4bit.py
* Lint
* Lint
* Update helpers.py
* Update test_functional.py
* Update test_linear4bit.py
* Update test_ops.py
* Lint
* Update pythonInterface.cpp
* lint fix
* lint
* Update pythonInterface.cpp
* revert permissions change
* Fix indentation
* Update kernels_hip.cuh
* Update kernels.hip
* Update ops.hip
* Update ops_hip.cuh
* Update kernels_hip.cuh
* Update kernels.hip
* Update kernels.hip
* Update ops.hip
* Update ops_hip.cuh
* Update ops.hip
* Update CMakeLists.txt
* Update functional.py
* Update cextension.py
* Update cextension.py
* warpSize is being made non constexpr in ROCm 7.0
* Merge pull request ROCm#90 from ROCm/IFU-rocm_enabled-09-23-2025 (Ifu rocm enabled 09 23 2025)
* Fix typo
* unskip test_4bit_quant

---------

Co-authored-by: MISHANMAURYA <[email protected]>
Co-authored-by: MISHANMAUYRA <[email protected]>
Co-authored-by: amcamd <[email protected]>
Co-authored-by: Prasanth Nunna <[email protected]>
Co-authored-by: sstamenk <[email protected]>
1 parent bdb8b2b commit b72b766

File tree

6 files changed: +66 -106 lines


CMakeLists.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -70,6 +70,7 @@ elseif(${COMPUTE_BACKEND} STREQUAL "xpu")
         message(FATAL_ERROR "XPU is not supported on macOS" )
     endif()
     set(BUILD_CUDA OFF)
+    set(BUILD_HIP OFF)
     set(BUILD_MPS OFF)
     set(BUILD_XPU ON)
 else()
```

csrc/kernels.hip

Lines changed: 52 additions & 100 deletions
```diff
@@ -19,37 +19,42 @@
 #define NUM 4
 #define NUM_BLOCK 4096

-__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
+__device__ static float fp4_dequantization_lut[8] = {
+    0.0f,            // 0b000
+    0.005208333333f, // 0b001
+    0.66666667f,     // 0b010
+    1.0f,            // 0b011
+    0.33333333f,     // 0b100
+    0.5f,            // 0b101
+    0.16666667f,     // 0b110
+    0.25f            // 0b111
+};
+
+__device__ static float nf4_dequantization_lut[16] = {
+    -1.0f,                 // 0b0000
+    -0.6961928009986877f,  // 0b0001
+    -0.5250730514526367f,  // 0b0010
+    -0.39491748809814453f, // 0b0011
+    -0.28444138169288635f, // 0b0100
+    -0.18477343022823334f, // 0b0101
+    -0.09105003625154495f, // 0b0110
+    0.0f,                  // 0b0111
+    0.07958029955625534f,  // 0b1000
+    0.16093020141124725f,  // 0b1001
+    0.24611230194568634f,  // 0b1010
+    0.33791524171829224f,  // 0b1011
+    0.44070982933044434f,  // 0b1100
+    0.5626170039176941f,   // 0b1101
+    0.7229568362236023f,   // 0b1110
+    1.0f                   // 0b1111
+};

 // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
 // Luckily we have atomicmax and atomicmin in ROCm

-
-__device__ float dDequantizeFP4Tree(unsigned char val, float absmax)
-{
-  float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
-  if((val & 0b0100) == 4) // 0
-    if((val & 0b0010) == 2) //01
-      if((val & 0b0001) == 1) // 111
-        return 0.25000000f*absmax*sign; // 1111
-      else
-        return 0.16666667f*absmax*sign; // 1110
-    else
-      if((val & 0b0001) == 1) // 110
-        return 0.50000000f*absmax*sign; // 1101
-      else
-        return 0.33333333f*absmax*sign; // 1100
-  else
-    if((val & 0b0010) == 2) //10
-      if((val & 0b0001) == 1) // 101
-        return 1.00000000f*absmax*sign; // 1011
-      else
-        return 0.66666667f*absmax*sign; // 1010
-    else
-      if((val & 0b0001) == 1) // 100
-        return 5.208333333e-03f*absmax*sign; // 1001
-      else
-        return 0.00000000f*absmax*sign; // 1000
+__device__ __forceinline__ float dDequantizeFP4Tree(unsigned char val) {
+    float sign = 1.0f - 2 * ((val & 0b1000) >> 3);
+    return fp4_dequantization_lut[val & 0b111] * sign;
 }

 __device__ unsigned char dQuantizeFP4(float x)
```
```diff
@@ -101,61 +106,7 @@ __device__ unsigned char dQuantizeFP4(float x)
       return 0b0000+sign;
 }

-
-__device__ __forceinline__ float dDequantizeNF4(unsigned char val)
-{
-
-  // the values for this tree was generated by test_normal_map_tree
-  // in the file tests/test_functional.py
-  if((val & 0b1000) == 8)
-    if((val & 0b0100) == 4) // 1
-      if((val & 0b0010) == 2) // 11
-        if((val & 0b0001) == 1) // 111
-          return 1.0f;
-        else
-          return 0.7229568362236023f;
-      else
-        if((val & 0b0001) == 1) // 110
-          return 0.5626170039176941f;
-        else
-          return 0.44070982933044434f;
-    else
-      if((val & 0b0010) == 2) //10
-        if((val & 0b0001) == 1) // 101
-          return 0.33791524171829224f;
-        else
-          return 0.24611230194568634f;
-      else
-        if((val & 0b0001) == 1) // 100
-          return 0.16093020141124725f;
-        else
-          return 0.07958029955625534f;
-
-  else
-    if((val & 0b0100) == 4) // 0
-      if((val & 0b0010) == 2) //01
-        if((val & 0b0001) == 1) // 011
-          return 0.0f;
-        else
-          return -0.09105003625154495f;
-      else
-        if((val & 0b0001) == 1) // 010
-          return -0.18477343022823334f;
-        else
-          return -0.28444138169288635f;
-    else
-      if((val & 0b0010) == 2) //00
-        if((val & 0b0001) == 1) // 001
-          return -0.39491748809814453f;
-        else
-          return -0.5250730514526367f;
-      else
-        if((val & 0b0001) == 1) // 000
-          return -0.6961928009986877f;
-        else
-          return -1.0f;
-
-}
+__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_dequantization_lut[val & 0x0F]; }

 __device__ unsigned char dQuantizeNF4(float x)
 {
```
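The rewritten dequantizers replace the nested bit-test trees with a table lookup: for FP4, bit 3 of the 4-bit code carries the sign and bits 0-2 index fp4_dequantization_lut, while NF4 codes index nf4_dequantization_lut directly. A minimal host-side C++ sketch of the same decode follows; the helper names are hypothetical, and the per-block absmax scale is applied by the caller, as the kDequantizeBlockwise hunk further down shows.

```cpp
#include <cstdio>

// Same table contents as fp4_dequantization_lut in the hunk above.
static const float fp4_lut[8] = {0.0f, 0.005208333333f, 0.66666667f, 1.0f,
                                 0.33333333f, 0.5f, 0.16666667f, 0.25f};

// Decode one 4-bit FP4 code: bit 3 is the sign, bits 0-2 index the table.
// 1.0f - 2 * sign_bit maps 0 -> +1.0 and 1 -> -1.0, matching the kernel.
static float decode_fp4(unsigned char code) {
    float sign = 1.0f - 2 * ((code & 0b1000) >> 3);
    return fp4_lut[code & 0b111] * sign;
}

int main() {
    // One storage byte packs two codes: high nibble first, low nibble second.
    unsigned char packed = 0x3B;  // 0b0011 -> +1.0, 0b1011 -> -1.0
    float absmax = 2.5f;          // per-block scale, applied after the lookup
    printf("%f %f\n", decode_fp4(packed >> 4) * absmax,
           decode_fp4(packed & 0x0F) * absmax);
    return 0;
}
```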
```diff
@@ -456,7 +407,6 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
       LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
     }

-    unsigned char packed_4bit = 0;
     switch(DATA_TYPE)
     {
       case General8bit:
```
```diff
@@ -473,18 +423,16 @@
         #pragma unroll NUM_PER_TH
         for(int j = 0; j < NUM_PER_TH/2; j++)
         {
-          packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
-          packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
-          qvals[j] = packed_4bit;
+          qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
+          qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
         }
         break;
       case NF4:
         #pragma unroll NUM_PER_TH
         for(int j = 0; j < NUM_PER_TH/2; j++)
         {
-          packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
-          packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
-          qvals[j] = packed_4bit;
+          qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
+          qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
         }
         break;
     }
```
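The quantize path also drops the shared packed_4bit accumulator, which the old code OR-ed into without resetting between loop iterations, and instead builds each output byte directly: the first code goes into the high nibble, the second into the low nibble. A small sketch of that packing; dQuantizeFP4 itself is not part of this diff, so the codes below are just example values.

```cpp
#include <cassert>

// Pack two 4-bit codes into one byte, high nibble first, as the kernel does.
static unsigned char pack_nibbles(unsigned char hi, unsigned char lo) {
    return (unsigned char)((hi << 4) | (lo & 0x0F));
}

int main() {
    unsigned char byte = pack_nibbles(0b0011, 0b1011);  // example 4-bit codes
    assert(byte == 0x3B);
    assert((byte >> 4) == 0b0011);   // recovered later via qvals[j] >> 4
    assert((byte & 0x0F) == 0b1011); // recovered later via qvals[j] & 0x0F
    return 0;
}
```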
```diff
@@ -546,8 +494,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
       #pragma unroll NUM_PER_TH
       for(int j = 0; j < NUM_PER_TH; j++)
       {
-        vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
-        vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
+        vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4) * local_abs_max;
+        vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F) * local_abs_max;
       }
       break;
     case NF4:
```
```diff
@@ -2109,7 +2057,11 @@ kdequant_mm_int32_fp16(
 #define DENORM 1.0f/127.0f
 #define MAX_SPARSE_COUNT 32
 #define SMEM_SIZE 8*256
-#define WARP_SIZE warpSize
+#if defined(__GFX9__)
+#define WARP_SIZE 64
+#else
+#define WARP_SIZE 32
+#endif
 template <typename T, int SPMM_ITEMS, int BITS>
 __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
 {
```
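ROCm 7.0 turns the device-side warpSize into a non-constexpr value, so it can no longer appear where a compile-time constant is required; the macro above pins the value per target architecture instead (gfx9 wavefronts are 64 lanes wide, other targets are treated as wave32 here). Host code that only needs the value at run time can still query it; a minimal sketch using the standard HIP runtime API, nothing specific to this patch:

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>

int main() {
    hipDeviceProp_t prop;
    if (hipGetDeviceProperties(&prop, /*deviceId=*/0) != hipSuccess) {
        fprintf(stderr, "no HIP device available\n");
        return 1;
    }
    // 64 on gfx9 (CDNA/MI-series) parts, 32 on wave32 parts. This runtime
    // value cannot size __shared__ arrays or feed template parameters, which
    // is why the kernels rely on the compile-time WARP_SIZE macro above.
    printf("wavefront size: %d\n", prop.warpSize);
    return 0;
}
```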
```diff
@@ -2503,7 +2455,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i

   #pragma unroll 16
   for(int i = 0; i < 16; i++)
-    quant_map[i] = nf4_data[i];
+    quant_map[i] = nf4_dequantization_lut[i];
   //__shared__ T quant_map[16*160];

   T local_A[2];
```
```diff
@@ -2708,13 +2660,13 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
   // load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
   // 4 warps -> 4 loads per iter
   // 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block
-  typedef hipcub::WarpReduce<float, warpSize> WarpReduce;
-  __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/warpSize];
+  typedef hipcub::WarpReduce<float, WARP_SIZE> WarpReduce;
+  __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/WARP_SIZE];

-  const int warp_idx = threadIdx.x / warpSize;
-  const int warp_lane = threadIdx.x % warpSize;
-  const int row_B = (THREADS/warpSize)*blockIdx.x + warp_idx;
-  const int offset_B = ldb*row_B;
+  const int warp_idx = threadIdx.x / WARP_SIZE;
+  const int warp_lane = threadIdx.x % WARP_SIZE;
+  const int row_B = (THREADS/WARP_SIZE)*blockIdx.x + warp_idx;
+  const int offset_B = ldb * row_B;
   const int num_values_8bit = num_values_4bit/2;
   float local_C = 0.0f;

```
```diff
@@ -2732,7 +2684,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc

   // A: [1, K]
   // B: [M, K]
-  for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += warpSize*num_values_4bit)
+  for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += WARP_SIZE*num_values_4bit)
   {

    const int inner_idx_halved = inner_idx/2;
```

csrc/ops.hip

Lines changed: 7 additions & 1 deletion
```diff
@@ -20,6 +20,12 @@

 #define ERR_NOT_IMPLEMENTED 100

+#if defined(__GFX9__)
+#define WARP_SIZE 64
+#else
+#define WARP_SIZE 32
+#endif
+
 using namespace BinSearch;
 using std::cout;
 using std::endl;
```
```diff
@@ -692,7 +698,7 @@ template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int
   //warpsize - 32
   int num_blocks = (m+3)/4;
   //warpsize - 64
-  if (warpSize == 64) {
+  if (WARP_SIZE == 64) {
     num_blocks = (m+1)/2;
   }

```
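The launch-side change mirrors the kernel: each warp of kgemm_4bit_inference_naive handles one row of B, so the grid size is ceil(m / warps-per-block). Assuming 128 threads per block (the block size itself sits outside this hunk), that gives (m+3)/4 at wave32 and (m+1)/2 at wave64, which is what the edited branch computes. A small sketch of that arithmetic:

```cpp
#include <cstdio>

// Grid-size arithmetic behind the hunk above. threads_per_block = 128 is an
// assumption for this sketch; only the warp size differs between branches.
static int num_blocks_for(int m, int threads_per_block, int warp_size) {
    int rows_per_block = threads_per_block / warp_size;   // one row per warp
    return (m + rows_per_block - 1) / rows_per_block;     // ceil(m / rows)
}

int main() {
    // warp_size 32 -> (m+3)/4, warp_size 64 -> (m+1)/2, matching the diff.
    printf("m=10: %d blocks (wave32), %d blocks (wave64)\n",
           num_blocks_for(10, 128, 32), num_blocks_for(10, 128, 64));
    return 0;
}
```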

tests/test_functional.py

Lines changed: 3 additions & 5 deletions
```diff
@@ -10,7 +10,7 @@

 import bitsandbytes as bnb
 from bitsandbytes import functional as F
-from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH
+from bitsandbytes.cextension import HIP_ENVIRONMENT
 from tests.helpers import (
     BOOLEAN_TUPLES,
     TRUE_FALSE,
```
```diff
@@ -463,6 +463,7 @@ def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim):
     @pytest.mark.parametrize("hidden_dim", [32, 1024 * 4], ids=id_formatter("hidden_dim"))
     @pytest.mark.parametrize("batch_dim", [2, 16], ids=id_formatter("batch_dim"))
     @pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
+    @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
     def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose):
         def min_max(x):
             maxA = torch.amax(x, dim=2, keepdim=True)
```
```diff
@@ -1408,10 +1409,7 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
     @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
-    @pytest.mark.skipif(
-        HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a",
-        reason="this test is not supported on ROCm with gfx90a architecture yet",
-    )
+    @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
     def test_gemv_eye_4bit(self, device, storage_type, dtype):
         if device == "hpu" and not is_supported_on_hpu(storage_type, dtype):
             pytest.skip("This configuration is not supported on HPU.")
```

tests/test_linear8bitlt.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -9,6 +9,7 @@
 import torch

 import bitsandbytes as bnb
+from bitsandbytes.cextension import HIP_ENVIRONMENT
 from bitsandbytes.nn.modules import Linear8bitLt
 from tests.helpers import (
     TRUE_FALSE,
```
```diff
@@ -233,6 +234,7 @@ def test_linear8bit_serialization(linear8bit):
 @pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
 @pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
 @pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
+@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
 def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
     if device == "cuda" and platform.system() == "Windows":
         pytest.skip("Triton is not officially supported on Windows")
```

tests/test_ops.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -211,6 +211,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi
     @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512])
+    @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
    def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
         if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
             pytest.skip("This configuration is not supported on HPU.")
```
