
Commit eb3b816

Merge pull request #1207 from ROCm/device_abstraction
BitsandBytes Enablement on ROCm
2 parents: 701c5aa + 410f499

23 files changed: +6,028 additions, −54 deletions

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
@@ -21,3 +21,4 @@ repos:
     rev: v1.18.2
     hooks:
       - id: typos
+        exclude: ^.*\.hip$

CMakeLists.txt

Lines changed: 79 additions & 3 deletions
@@ -3,7 +3,7 @@
 # For GCC: `cmake -B build . && cmake --build build`
 # For MSVC: `cmake -B build . && cmake --build build --config Release`
 # You can also use the following options and variables
-#  - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend
+#  - COMPUTE_BACKEND: Set to `cpu`, `cuda`, `hip` or `mps` to select the backend
 #  - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support
 #  - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version
 #      is whatever CMake finds on your path.
@@ -26,13 +26,14 @@ endif()
 # Define included source files
 set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp)
 set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
+set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
 set(MPS_FILES csrc/mps_ops.mm)
 set(METAL_FILES csrc/mps_kernels.metal)
 # C++ sources are always included
 list(APPEND SRC_FILES ${CPP_FILES})

-set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, mps)")
-set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps)
+set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)")
+set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps)
 option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)

 if(APPLE)
@@ -49,16 +50,28 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
     endif()
     option(NO_CUBLASLT "Disable CUBLAS" OFF)
     set(BUILD_CUDA ON)
+    set(BUILD_HIP OFF)
+    set(BUILD_MPS OFF)
+    message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}")
+elseif(${COMPUTE_BACKEND} STREQUAL "hip")
+    if(APPLE)
+        message(FATAL_ERROR "HIP is not supported on macOS" )
+    endif()
+    option(NO_CUBLASLT "Disable HIPBLASLT" OFF)
+    set(BUILD_CUDA OFF)
+    set(BUILD_HIP ON)
     set(BUILD_MPS OFF)
     message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}")
 elseif(${COMPUTE_BACKEND} STREQUAL "mps")
     if(NOT APPLE)
         message(FATAL_ERROR "MPS is only supported on macOS" )
     endif()
     set(BUILD_CUDA OFF)
+    set(BUILD_HIP OFF)
     set(BUILD_MPS ON)
 else()
     set(BUILD_CUDA OFF)
+    set(BUILD_HIP OFF)
     set(BUILD_MPS OFF)
 endif()

@@ -158,6 +171,34 @@ if(BUILD_CUDA)
         string(APPEND BNB_OUTPUT_NAME "_nocublaslt")
     endif()
     add_compile_definitions(BUILD_CUDA)
+elseif(BUILD_HIP)
+    enable_language(HIP)
+    message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}")
+    if(DEFINED BNB_ROCM_ARCH)
+        set(CMAKE_HIP_ARCHITECTURES ${BNB_ROCM_ARCH})
+    else()
+        if (NOT AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
+            set(CMAKE_HIP_ARCHITECTURES "gfx908;gfx90a;gfx940;gfx941;gfx942")
+        elseif (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
+            set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
+        endif()
+    endif()
+    message(STATUS "HIP Targets: ${CMAKE_HIP_ARCHITECTURES}")
+
+    list(APPEND SRC_FILES ${HIP_FILES})
+
+    string(APPEND BNB_OUTPUT_NAME "_hip")
+
+    # get hip version
+    execute_process(COMMAND hipconfig --version OUTPUT_VARIABLE HIP_CONFIG_VERSION)
+    string(REGEX MATCH "[0-9]+\\.[0-9]+" HIP_VERSION "${HIP_CONFIG_VERSION}")
+
+    if(NO_CUBLASLT OR HIP_VERSION VERSION_LESS "6.1")
+        string(APPEND BNB_OUTPUT_NAME "_nohipblaslt")
+    endif()
+    add_compile_definitions(__HIP_PLATFORM_AMD__)
+    add_compile_definitions(__HIP_PLATFORM_HCC__)
+    add_compile_definitions(BUILD_HIP)
 elseif(BUILD_MPS)
     if(NOT APPLE)
         message(FATAL_ERROR "MPS is only supported on macOS" )
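Editor's note: the GPU-architecture list above is resolved with a fixed precedence: an explicit BNB_ROCM_ARCH always wins, an already-set CMAKE_HIP_ARCHITECTURES is left untouched, AMDGPU_TARGETS is used next, and the built-in CDNA list (gfx908 through gfx942) is the last resort. A minimal Python sketch of that precedence, illustrative only and not part of the PR:

    # Illustrative sketch of the CMake arch-selection precedence above.
    from typing import List, Optional

    DEFAULT_HIP_ARCHS = ["gfx908", "gfx90a", "gfx940", "gfx941", "gfx942"]

    def resolve_hip_architectures(
        bnb_rocm_arch: Optional[List[str]] = None,        # -DBNB_ROCM_ARCH=...
        cmake_hip_architectures: Optional[List[str]] = None,
        amdgpu_targets: Optional[List[str]] = None,       # -DAMDGPU_TARGETS=...
    ) -> List[str]:
        if bnb_rocm_arch:                  # explicit override always wins
            return bnb_rocm_arch
        if cmake_hip_architectures:        # already set: leave untouched
            return cmake_hip_architectures
        if amdgpu_targets:                 # fall back to AMDGPU_TARGETS
            return amdgpu_targets
        return DEFAULT_HIP_ARCHS           # built-in CDNA default list

    assert resolve_hip_architectures() == DEFAULT_HIP_ARCHS
    assert resolve_hip_architectures(bnb_rocm_arch=["gfx90a"]) == ["gfx90a"]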
@@ -213,6 +254,41 @@ if(BUILD_CUDA)
         CUDA_SEPARABLE_COMPILATION ON
     )
 endif()
+if(BUILD_HIP)
+    if(NOT DEFINED ENV{ROCM_PATH})
+        set(ROCM_PATH /opt/rocm)
+    else()
+        set(ROCM_PATH $ENV{ROCM_PATH})
+    endif()
+    list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
+    macro(find_package_and_print_version PACKAGE_NAME)
+        find_package("${PACKAGE_NAME}" ${ARGN})
+        message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}")
+    endmacro()
+    find_package_and_print_version(hipblas REQUIRED)
+    find_package_and_print_version(hiprand REQUIRED)
+    find_package_and_print_version(hipsparse REQUIRED)
+
+    ## hacky way of excluding hip::amdhip64 (with it linked many tests unexpectedly fail e.g. adam8bit because of inaccuracies)
+    set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
+    set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
+    set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "")
+
+    target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include)
+    target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib)
+    target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse)
+
+    target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP)
+    set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP)
+    set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX)
+
+    if(NO_CUBLASLT OR HIP_VERSION VERSION_LESS "6.1")
+        target_compile_definitions(bitsandbytes PUBLIC NO_HIPBLASLT)
+    else()
+        find_package(hipblaslt)
+        target_link_libraries(bitsandbytes PUBLIC roc::hipblaslt)
+    endif()
+endif()
 if(BUILD_MPS)
     add_dependencies(bitsandbytes metallib)
     target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
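Editor's note: the net effect of the BNB_OUTPUT_NAME logic is that a HIP build produces a binary with a "_hip" suffix, plus "_nohipblaslt" when hipBLASLt is unavailable (NO_CUBLASLT set, or HIP older than 6.1). A hedged Python sketch of the naming rule; the "libbitsandbytes" prefix and ".so" extension are assumptions for a Linux build, while the suffix logic mirrors the CMake above:

    # Sketch of the output-name rule; prefix/extension are Linux assumptions.
    def hip_library_name(hip_version: str, no_cublaslt: bool = False) -> str:
        major, minor = (int(p) for p in hip_version.split(".")[:2])
        name = "libbitsandbytes_hip"
        if no_cublaslt or (major, minor) < (6, 1):  # hipBLASLt needs HIP >= 6.1
            name += "_nohipblaslt"
        return name + ".so"

    assert hip_library_name("6.1") == "libbitsandbytes_hip.so"
    assert hip_library_name("6.0") == "libbitsandbytes_hip_nohipblaslt.so"
    assert hip_library_name("6.2", no_cublaslt=True) == "libbitsandbytes_hip_nohipblaslt.so"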

bitsandbytes/autograd/_functions.py

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,7 @@

 import torch

+from bitsandbytes.cextension import BNB_HIP_VERSION
 import bitsandbytes.functional as F


@@ -222,6 +223,8 @@ def supports_igemmlt(device: torch.device) -> bool:
     """check if this device supports the optimized int8 kernel"""
     if device == torch.device("cpu"):
         return True
+    if torch.version.hip:
+        return False if BNB_HIP_VERSION < 601 else True
     if torch.cuda.get_device_capability(device=device) < (7, 5):
         return False
     device_name = torch.cuda.get_device_name(device=device)
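Editor's note: on ROCm builds of PyTorch, torch.version.hip is set, and BNB_HIP_VERSION appears to encode the HIP version as major*100 + minor, so the 601 threshold corresponds to HIP 6.1 (the same release gate the CMake uses for hipBLASLt). A standalone sketch of the new gate; the encoding is inferred, not confirmed by the diff:

    # Standalone sketch of the ROCm branch added above; the major*100 + minor
    # encoding of BNB_HIP_VERSION is an inference (601 == HIP 6.1).
    def rocm_supports_igemmlt(bnb_hip_version: int) -> bool:
        return bnb_hip_version >= 601  # same as: False if bnb_hip_version < 601 else True

    assert rocm_supports_igemmlt(601)      # HIP 6.1: optimized int8 kernel available
    assert not rocm_supports_igemmlt(600)  # HIP 6.0: falls back to the non-igemmlt path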

bitsandbytes/backends/cuda.py

Lines changed: 73 additions & 24 deletions
@@ -3,7 +3,7 @@

 import torch

-from bitsandbytes.cextension import lib
+from bitsandbytes.cextension import HIP_ENVIRONMENT, lib
 from bitsandbytes.functional import (
     CUBLAS_Context,
     coo_zeros,
@@ -14,6 +14,7 @@
     get_ptr,
     get_transform_buffer,
     is_on_gpu,
+    nvidia_transform,
     post_call,
     pre_call,
     prod,
@@ -184,6 +185,11 @@ def transform(
         state: Optional[Tuple[torch.Size, str]] = None,
         ld=None,
     ):
+        if HIP_ENVIRONMENT:
+            # transform kernel formats (col32/col_turing/col_ampere) are not applicable to ROCm
+            # Use nvidia_transform instead
+            return nvidia_transform(A, to_order, from_order, out, transpose, state, ld)
+
         prev_device = pre_call(A.device)
        if state is None:
            state = (A.shape, from_order)
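Editor's note: col32, col_turing, and col_ampere are tiled layouts specific to NVIDIA tensor-core kernels, so HIP builds short-circuit to nvidia_transform, which performs plain row/col layout conversions. A usage sketch, assuming a ROCm build of bitsandbytes (where the device type is still reported as "cuda") and assuming nvidia_transform defaults from_order to "row":

    # Usage sketch for the ROCm path above (assumptions noted in the lead-in).
    import torch
    from bitsandbytes.functional import nvidia_transform

    A = torch.randint(-128, 127, (32, 64), dtype=torch.int8, device="cuda")
    out, state = nvidia_transform(A, "col")  # plain column-major, no 32-column tiling
    print(state[0], state[1])                # original shape and the new order, "col"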
@@ -266,19 +272,33 @@ def igemmlt(
             return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)

         if dimsA == 2 and out is None:
-            out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row")
+            if HIP_ENVIRONMENT:
+                # Use col format for HIP
+                out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col", "row")
+            else:
+                out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row")
         elif dimsA == 3 and out is None:
-            out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row")
+            if HIP_ENVIRONMENT:
+                # Use col format for HIP
+                out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col", "row")
+            else:
+                out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row")

         assert dimsB != 3, "len(B.shape)==3 not supported"
         assert A.device.type == "cuda"
         assert B.device.type == "cuda"
         assert A.dtype == torch.int8
         assert B.dtype == torch.int8
         assert out.dtype == dtype
-        assert SA[1] == "col32"
-        assert SB[1] in ["col_turing", "col_ampere"]
-        assert Sout[1] == "col32"
+        if HIP_ENVIRONMENT:
+            # Use col format for HIP
+            assert SA[1] == "col"
+            assert SB[1] == "col"
+            assert Sout[1] == "col"
+        else:
+            assert SA[1] == "col32"
+            assert SB[1] in ["col_turing", "col_ampere"]
+            assert Sout[1] == "col32"
         assert (
             shapeA[-1] == shapeB[-1]
         ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}"
@@ -293,17 +313,23 @@ def igemmlt(
         ptrC = get_ptr(out)

         k = shapeA[-1]
-        lda = ct.c_int32(m * 32)
-        if formatB == "col_turing":
-            # turing: tiles with rows filled up to multiple of 8 rows by 32 columns
-            # n = rows
-            ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
+        if HIP_ENVIRONMENT:
+            # Set ld values for col format
+            lda = ct.c_int32(m)
+            ldb = ct.c_int32(shapeB[0])
+            ldc = ct.c_int32(m)
         else:
-            # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
-            # n = rows
-            ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)
+            lda = ct.c_int32(m * 32)
+            if formatB == "col_turing":
+                # turing: tiles with rows filled up to multiple of 8 rows by 32 columns
+                # n = rows
+                ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
+            else:
+                # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
+                # n = rows
+                ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)

-        ldc = ct.c_int32(m * 32)
+            ldc = ct.c_int32(m * 32)
         m = ct.c_int32(m)
         n = ct.c_int32(n)
         k = ct.c_int32(k)
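Editor's note: the leading-dimension arithmetic differs because the plain col layout is ordinary column-major storage (ld is simply the row count), whereas col32 packs 32-column tiles, and the Turing/Ampere layouts for B additionally pad rows to multiples of 8 or 32. A small worked example of the formulas above, with illustrative values:

    # Worked example of the leading-dimension formulas above.
    m, rows = 16, 48        # m: output rows; rows: rows of B (n = rows)
    shapeB0 = rows

    # HIP / plain "col" layout: ld is simply the row count.
    lda_hip, ldb_hip, ldc_hip = m, shapeB0, m                  # 16, 48, 16

    # CUDA "col32" layout: 32-column tiles, so ld scales by 32.
    lda_cuda = m * 32                                          # 512
    ldb_turing = ((rows + 7) // 8) * 8 * 32                    # rows padded to 8  -> 1536
    ldb_ampere = ((rows + 31) // 32) * 32 * 32                 # rows padded to 32 -> 2048
    ldc_cuda = m * 32                                          # 512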
@@ -312,7 +338,7 @@ def igemmlt(
         ptrRowScale = get_ptr(None)
         is_on_gpu([A, B, out])

-        if formatB == "col_turing":
+        if formatB == "col_turing" or HIP_ENVIRONMENT:
             if dtype == torch.int32:
                 has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
             else:
@@ -324,7 +350,7 @@ def igemmlt(
             else:
                 has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)

-        if has_error == 100:  # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
+        if has_error == 100:  # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`, `ops.hip`
             raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)")

         if has_error:
@@ -348,6 +374,9 @@ def mm_dequant(
         new_col_stats: Optional[torch.Tensor] = None,
         bias: Optional[torch.Tensor] = None,
     ):
+        if HIP_ENVIRONMENT:
+            # HIP kernel requires 'row' format
+            A, quant_state = nvidia_transform(A, "row", state=quant_state)
         assert A.dtype == torch.int32
         if bias is not None:
             assert bias.dtype == torch.float16
@@ -386,7 +415,11 @@ def extract_outliers(self, A: torch.Tensor, SA: Tuple[torch.Size, str], idx: torch.Tensor):
         shapeA = SA[0]
         formatA = SA[1]
-        assert formatA in ["col_turing", "col_ampere"]
+        if not HIP_ENVIRONMENT:
+            assert formatA in ["col_turing", "col_ampere"]
+        else:
+            # HIP uses col format
+            assert formatA in ["col"]
         assert A.device.type == "cuda"

         out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
@@ -400,7 +433,7 @@ def extract_outliers(self, A: torch.Tensor, SA: Tuple[torch.Size, str], idx: torch.Tensor):

         prev_device = pre_call(A.device)

-        if formatA == "col_turing":
+        if formatA == "col_turing" or HIP_ENVIRONMENT:
             lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
         elif formatA == "col_ampere":
             lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
@@ -414,11 +447,15 @@ def quantize_4bit(
         A: torch.Tensor,
         absmax: Optional[torch.Tensor] = None,
         out: Optional[torch.Tensor] = None,
-        blocksize=64,
+        blocksize: Optional[int] = None,
         compress_statistics=False,
         quant_type: Literal["fp4", "nf4"] = "fp4",
         quant_storage=torch.uint8,
     ) -> Tuple[torch.Tensor, QuantState]:
+        if blocksize is None:
+            # Some AMD GPUs have warpsize 64
+            # Set default blocksize to 128 (~warpsize 64 in kernel) for HIP
+            blocksize = 64 if not HIP_ENVIRONMENT else 128
         if A.device.type != "cuda":
             raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}")
         if quant_type not in ["fp4", "nf4"]:
@@ -436,7 +473,12 @@ def quantize_4bit(
         mod = dtype2bytes[quant_storage] * 2
         out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device)

-        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
+        # Some AMD GPUs have warpsize 64
+        # Set min blocksize to 128 (~warpsize 64 in kernel) for HIP
+        if not HIP_ENVIRONMENT:
+            assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
+        else:
+            assert blocksize in [4096, 2048, 1024, 512, 256, 128]

         prev_device = pre_call(A.device)
         is_on_gpu([A, out, absmax])
@@ -507,12 +549,19 @@ def dequantize_4bit(
         quant_state: Optional[QuantState] = None,
         absmax: Optional[torch.Tensor] = None,
         out: Optional[torch.Tensor] = None,
-        blocksize: int = 64,
+        blocksize: Optional[int] = None,
         quant_type: Literal["fp4", "nf4"] = "fp4",
     ) -> torch.Tensor:
-        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
+        # Some AMD GPUs have warpsize 64
+        # Set default blocksize to 128 (~warpsize 64 in kernel) for HIP
+        if blocksize is None:
+            blocksize = 64 if not HIP_ENVIRONMENT else 128
+        supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64]
+        if HIP_ENVIRONMENT:
+            supported_blocksizes = supported_blocksizes[:-1]
+        if blocksize not in supported_blocksizes:
             raise ValueError(
-                f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]"
+                f"The blockwise of {blocksize} is not supported. Supported values: {supported_blocksizes}"
             )

         if quant_type not in ["fp4", "nf4"]:
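Editor's note: both 4-bit entry points now derive their default blocksize from the backend: 64 on CUDA, 128 on ROCm, since some AMD GPUs use a wavefront (warp) size of 64 and the HIP kernels therefore need at least 128 elements per block; 64 is likewise dropped from the supported list on ROCm. A standalone sketch of the rule, illustrative only:

    # Standalone sketch of the backend-dependent blocksize rules above.
    def default_4bit_blocksize(hip_environment: bool) -> int:
        # Wavefront size is 64 on some AMD GPUs, so HIP kernels need >= 128.
        return 128 if hip_environment else 64

    def supported_4bit_blocksizes(hip_environment: bool) -> list:
        sizes = [4096, 2048, 1024, 512, 256, 128, 64]
        return sizes[:-1] if hip_environment else sizes  # drop 64 on ROCm

    assert default_4bit_blocksize(False) == 64
    assert default_4bit_blocksize(True) == 128
    assert 64 not in supported_4bit_blocksizes(True)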
