diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 30cad3e34..abc1fd223 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -8,7 +8,7 @@ from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr from ..._ops import register_kernel -from ...cextension import HIP_ENVIRONMENT, lib +from ...cextension import ROCM_WARP_SIZE_64, lib @register_kernel("bitsandbytes::int8_linear_matmul", "cuda") @@ -211,7 +211,7 @@ def _get_col_absmax( def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: torch._check_is_size(blocksize) - if HIP_ENVIRONMENT: + if ROCM_WARP_SIZE_64: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) @@ -269,7 +269,7 @@ def _( def _dequantize_blockwise_impl( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor ) -> None: - if HIP_ENVIRONMENT: + if ROCM_WARP_SIZE_64: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) @@ -303,7 +303,7 @@ def _dequantize_blockwise_impl( def _( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - if HIP_ENVIRONMENT: + if ROCM_WARP_SIZE_64: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) @@ -385,7 +385,7 @@ def _dequantize_4bit_impl( dtype: torch.dtype, out: torch.Tensor, ) -> None: - if HIP_ENVIRONMENT: + if ROCM_WARP_SIZE_64: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 2eb584a66..188576225 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -9,7 +9,13 @@ import torch from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR -from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_cuda_version_tuple, get_rocm_gpu_arch +from bitsandbytes.cuda_specs import ( + CUDASpecs, + get_cuda_specs, + get_cuda_version_tuple, + get_rocm_gpu_arch, + get_rocm_warpsize, +) logger = logging.getLogger(__name__) @@ -298,6 +304,7 @@ def get_native_library() -> BNBNativeLibrary: ROCM_GPU_ARCH = get_rocm_gpu_arch() +ROCM_WARP_SIZE_64 = True if get_rocm_warpsize() == 64 else False HIP_ENVIRONMENT = False BNB_BACKEND = "CPU" diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 32563a159..71e7568a9 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -100,3 +100,29 @@ def get_rocm_gpu_arch() -> str: """, ) return "unknown" + + +def get_rocm_warpsize() -> int: + """Get ROCm warp size.""" + logger = logging.getLogger(__name__) + try: + if torch.version.hip: + result = subprocess.run(["rocminfo"], capture_output=True, text=True) + match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout) + if match: + return int(match.group(1)) + else: + # default to 64 to be safe + return 64 + else: + # nvidia cards always use 32 warp size + return 32 + except Exception as e: + logger.error(f"Could not detect ROCm warp size: {e}. Defaulting to 64. (some 4-bit functions may not work!)") + if torch.cuda.is_available(): + logger.warning( + """ +ROCm warp size detection failed despite ROCm being available. 
+ """, + ) + return 64 diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 7cca33dcf..f603c01df 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -15,7 +15,7 @@ from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import HIP_ENVIRONMENT, lib +from .cextension import ROCM_WARP_SIZE_64, lib name2qmap = {} @@ -806,7 +806,7 @@ def quantize_fp4( quant_storage=torch.uint8, ): if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + blocksize = 64 if not ROCM_WARP_SIZE_64 else 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage) @@ -819,7 +819,7 @@ def quantize_nf4( quant_storage=torch.uint8, ): if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + blocksize = 64 if not ROCM_WARP_SIZE_64 else 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage) @@ -857,7 +857,7 @@ def quantize_4bit( """ if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + blocksize = 64 if not ROCM_WARP_SIZE_64 else 128 input_shape = A.shape @@ -912,7 +912,7 @@ def dequantize_fp4( blocksize: Optional[int] = None, ) -> torch.Tensor: if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + blocksize = 64 if not ROCM_WARP_SIZE_64 else 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") @@ -924,7 +924,7 @@ def dequantize_nf4( blocksize: Optional[int] = None, ) -> torch.Tensor: if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + blocksize = 64 if not ROCM_WARP_SIZE_64 else 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") @@ -964,7 +964,7 @@ def dequantize_4bit( """ if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + blocksize = 64 if not ROCM_WARP_SIZE_64 else 128 if quant_state is None: assert absmax is not None and out is not None diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index c36fb68a6..96cd356c0 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -11,7 +11,7 @@ import torch.nn.functional as F import bitsandbytes as bnb -from bitsandbytes.cextension import HIP_ENVIRONMENT +from bitsandbytes.cextension import ROCM_WARP_SIZE_64 from bitsandbytes.functional import QuantState from bitsandbytes.optim import GlobalOptimManager from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer @@ -221,7 +221,7 @@ def __new__( data = torch.empty(0) if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + blocksize = 64 if not ROCM_WARP_SIZE_64 else 128 self = torch.Tensor._make_subclass(cls, data, requires_grad) self.blocksize = blocksize diff --git a/csrc/common_hip.cuh b/csrc/common_hip.cuh index 1d9d9afe0..3ea545e9a 100644 --- a/csrc/common_hip.cuh +++ b/csrc/common_hip.cuh @@ -1,7 +1,11 @@ #pragma once -#define BNB_WARP_SIZE warpSize +#ifdef __GFX9__ + #define BNB_WARP_SIZE 64 +#else + #define BNB_WARP_SIZE 32 +#endif // These are set based on current BNB support for CDNA 2 & RDNA 3. 
// Update as needed for future archs. -#define BNB_MAX_THREADS_PER_SM 2048 +#define BNB_MAX_THREADS_PER_CU 2048 #define BNB_BF16_AVAILABLE true diff --git a/csrc/kernels.hip b/csrc/kernels.hip index bef6cffa6..d7a5d21e4 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -1881,7 +1881,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char // rowStats [rows] // out [rows, cols] template <typename T, int THREADS, int SPARSE_DECOMP> -__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) +__launch_bounds__(1024, BNB_MAX_THREADS_PER_CU / 1024) __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) { // For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32. @@ -1945,7 +1945,7 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat } template <typename T, int THREADS, int SPARSE_DECOMP> -__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) +__launch_bounds__(1024, BNB_MAX_THREADS_PER_CU / 1024) __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) { using BlockReduceT = hipcub::BlockReduce<float, THREADS>; @@ -2057,11 +2057,6 @@ __global__ void kdequant_mm_int32_fp16( #define DENORM 1.0f/127.0f #define MAX_SPARSE_COUNT 32 #define SMEM_SIZE 8*256 -#if defined(__GFX9__) - #define WARP_SIZE 64 -#else - #define WARP_SIZE 32 -#endif template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) { @@ -2082,9 +2077,9 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1]; const int local_row_idx = rowidx[offset]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int warp_idx = threadIdx.x % WARP_SIZE; - const int warp_offset = (warp_id*WARP_SIZE)*SPMM_ITEMS; + const int warp_id = threadIdx.x / BNB_WARP_SIZE; + const int warp_idx = threadIdx.x % BNB_WARP_SIZE; + const int warp_offset = (warp_id*BNB_WARP_SIZE)*SPMM_ITEMS; const int num_items = BITS == 8 ? 8 : 8; int idx_col_B = warp_offset; int local_idx_col_B_offset = 0; @@ -2104,7 +2099,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o } // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads.
// 32*256=8192 - // we expect each warp to be SPMM_ITEMS*WARP_SIZE apart + // we expect each warp to be SPMM_ITEMS*BNB_WARP_SIZE apart // we have a total of 128 bytes for the bank with a bank size of 4 bytes // added 3 bytes = 6 values between warps should reduce bank conflicts __shared__ half smem_dequant_stats[SMEM_SIZE]; @@ -2657,15 +2652,15 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc { // per threadblock: - // load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps] + // load step-by-step in chunks of [BNB_WARP_SIZE,warps]: 1xBNB_WARP_SIZE * [BNB_WARP_SIZE,warps] -> [1,warps] // 4 warps -> 4 loads per iter - // 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block - typedef hipcub::WarpReduce<float, WARP_SIZE> WarpReduce; - __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/WARP_SIZE]; + // 1xBNB_WARP_SIZE * BNB_WARP_SIZEx4 -> 1x4 outputs per thread block + typedef hipcub::WarpReduce<float, BNB_WARP_SIZE> WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/BNB_WARP_SIZE]; - const int warp_idx = threadIdx.x / WARP_SIZE; - const int warp_lane = threadIdx.x % WARP_SIZE; - const int row_B = (THREADS/WARP_SIZE)*blockIdx.x + warp_idx; + const int warp_idx = threadIdx.x / BNB_WARP_SIZE; + const int warp_lane = threadIdx.x % BNB_WARP_SIZE; + const int row_B = (THREADS/BNB_WARP_SIZE)*blockIdx.x + warp_idx; const int offset_B = ldb * row_B; const int num_values_8bit = num_values_4bit/2; float local_C = 0.0f; @@ -2684,7 +2679,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc // A: [1, K] // B: [M, K] - for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += WARP_SIZE*num_values_4bit) + for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += BNB_WARP_SIZE*num_values_4bit) { const int inner_idx_halved = inner_idx/2; @@ -2996,7 +2991,9 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit) MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) -//MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) +#if BNB_WARP_SIZE == 32 + MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) +#endif MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) @@ -3004,7 +3001,9 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4) MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) -//MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) +#if BNB_WARP_SIZE == 32 + MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) +#endif MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4) MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4) @@ -3012,7 +3011,9 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) -//MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) +#if BNB_WARP_SIZE == 32 + MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) +#endif MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) @@ -3021,7 +3022,9 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit) MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) -//MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) +#if BNB_WARP_SIZE == 32 + MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
+#endif MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) @@ -3029,7 +3032,9 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) -//MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) +#if BNB_WARP_SIZE == 32 + MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) +#endif MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) @@ -3037,7 +3042,9 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) -//MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) +#if BNB_WARP_SIZE == 32 + MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) +#endif MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, General8bit) MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 1, General8bit) @@ -3046,7 +3053,9 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, General8bit) MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, General8bit) MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, General8bit) MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, General8bit) -//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit) +#if BNB_WARP_SIZE == 32 + MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit) +#endif MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, FP4) MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, FP4) @@ -3054,7 +3063,9 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4) MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, FP4) MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, FP4) MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, FP4) -//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4) +#if BNB_WARP_SIZE == 32 + MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4) +#endif MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, NF4) MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, NF4) @@ -3062,7 +3073,9 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4) MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, NF4) MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, NF4) MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4) -//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4) +#if BNB_WARP_SIZE == 32 + MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4) +#endif template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); diff --git a/csrc/ops.hip b/csrc/ops.hip index b26d138e1..a7ab32fdc 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -5,6 +5,7 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. 
+#include <common_hip.cuh> #include #include #include @@ -20,12 +21,6 @@ #define ERR_NOT_IMPLEMENTED 100 -#if defined(__GFX9__) - #define WARP_SIZE 64 -#else - #define WARP_SIZE 32 -#endif - using namespace BinSearch; using std::cout; using std::endl; @@ -63,8 +58,8 @@ template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(floa hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(128), 0, 0, code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 128) hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(64), 0, 0, code, A, absmax, out, rand, rand_offset, n); - //else if(blocksize == 64) - // hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 64 && BNB_WARP_SIZE == 32) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n); CUDA_CHECK_RETURN(hipPeekAtLastError()); @@ -698,7 +693,7 @@ template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int //warpsize - 32 int num_blocks = (m+3)/4; //warpsize - 64 - if (WARP_SIZE == 64) { + if (BNB_WARP_SIZE == 64) { num_blocks = (m+1)/2; } diff --git a/tests/test_functional.py b/tests/test_functional.py index 072e3b4f5..762298852 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -10,7 +10,7 @@ import bitsandbytes as bnb from bitsandbytes import functional as F -from bitsandbytes.cextension import HIP_ENVIRONMENT +from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH, ROCM_WARP_SIZE_64 from tests.helpers import ( BOOLEAN_TUPLES, TRUE_FALSE, @@ -96,7 +96,7 @@ class Test8BitBlockwiseQuantizeFunctional: @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) @pytest.mark.parametrize( "blocksize", - [4096, 2048, 1024, 512, 256, 128, 64] if not HIP_ENVIRONMENT else [4096, 2048, 1024, 512, 256, 128], + [4096, 2048, 1024, 512, 256, 128, 64] if not ROCM_WARP_SIZE_64 else [4096, 2048, 1024, 512, 256, 128], ) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed): @@ -1109,7 +1109,7 @@ class TestQuantize4BitFunctional: @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize( "blocksize", - [64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096], + [64, 128, 256, 512, 1024, 2048, 4096] if not ROCM_WARP_SIZE_64 else [128, 256, 512, 1024, 2048, 4096], ) def test_4bit_quant(self, device, dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype): @@ -1180,7 +1180,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize")) + @pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize")) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype) def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype): @@ -1247,7 +1247,7 @@ def test_bench_4bit_dequant(self, quant_type): # print((time.time()-t0)/iters*1e6) @pytest.mark.skipif( - HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64" + ROCM_WARP_SIZE_64,
reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64" ) @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 1c5e77a32..398fb83d3 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -8,7 +8,7 @@ import torch import bitsandbytes as bnb -from bitsandbytes.cextension import HIP_ENVIRONMENT +from bitsandbytes.cextension import ROCM_WARP_SIZE_64 from tests.helpers import ( TRUE_FALSE, describe_dtype, @@ -192,7 +192,7 @@ def test_linear_serialization( @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) +@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_copy_param(device, quant_type, blocksize, compress_statistics): if device == "hpu" and not is_supported_on_hpu(quant_type): @@ -249,7 +249,7 @@ def test_params4bit_torch_chunk_split(device, quant_type): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) +@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): if device == "hpu" and not is_supported_on_hpu(quant_type): @@ -278,7 +278,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) +@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics): if device == "hpu" and not is_supported_on_hpu(quant_type): diff --git a/tests/test_ops.py b/tests/test_ops.py index 02472630e..da589005e 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -4,7 +4,7 @@ import torch import bitsandbytes -from bitsandbytes.cextension import HIP_ENVIRONMENT +from bitsandbytes.cextension import ROCM_WARP_SIZE_64 from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, is_supported_on_hpu # torch.library.opcheck is only available in torch 2.4 and later. 
@@ -102,7 +102,7 @@ def test_int8_scaled_mm(self, device, dtype, has_bias): class TestInt8BlockwiseQuantOps: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) def test_quantize_blockwise(self, device, dtype, blocksize): if device == "cpu": if dtype != torch.float32: pytest.skip("CPU implementation is only available for float32") @@ -126,7 +126,7 @@ def test_quantize_blockwise(self, device, dtype, blocksize): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) def test_dequantize_blockwise(self, device, dtype, blocksize): if device == "cpu" and dtype != torch.float32: pytest.skip("CPU implementation is only available for float32") @@ -152,7 +152,7 @@ class Test4bitBlockwiseQuantOps: @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype): pytest.skip("This configuration is not supported on HPU.") @@ -176,7 +176,7 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype): pytest.skip("This configuration is not supported on HPU.") @@ -210,8 +210,8 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) - @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) + @pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet") def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype): pytest.skip("This configuration is not supported on HPU.")