 import torch

-from bitsandbytes.cextension import lib
+from bitsandbytes.cextension import HIP_ENVIRONMENT, lib
 from bitsandbytes.functional import (
     CUBLAS_Context,
     coo_zeros,
     get_ptr,
     get_transform_buffer,
     is_on_gpu,
+    nvidia_transform,
     post_call,
     pre_call,
     prod,
@@ -184,6 +185,11 @@ def transform(
         state: Optional[Tuple[torch.Size, str]] = None,
         ld=None,
     ):
+        if HIP_ENVIRONMENT:
+            # transform kernel formats (col32/col_turing/col_ampere) are not applicable to ROCm
+            # Use nvidia_transform instead
+            return nvidia_transform(A, to_order, from_order, out, transpose, state, ld)
+
         prev_device = pre_call(A.device)
         if state is None:
             state = (A.shape, from_order)
@@ -266,19 +272,33 @@ def igemmlt(
             return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)

         if dimsA == 2 and out is None:
-            out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row")
+            if HIP_ENVIRONMENT:
+                # Use col format for HIP
+                out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col", "row")
+            else:
+                out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row")
         elif dimsA == 3 and out is None:
-            out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row")
+            if HIP_ENVIRONMENT:
+                # Use col format for HIP
+                out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col", "row")
+            else:
+                out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row")

         assert dimsB != 3, "len(B.shape)==3 not supported"
         assert A.device.type == "cuda"
         assert B.device.type == "cuda"
         assert A.dtype == torch.int8
         assert B.dtype == torch.int8
         assert out.dtype == dtype
-        assert SA[1] == "col32"
-        assert SB[1] in ["col_turing", "col_ampere"]
-        assert Sout[1] == "col32"
+        if HIP_ENVIRONMENT:
+            # Use col format for HIP
+            assert SA[1] == "col"
+            assert SB[1] == "col"
+            assert Sout[1] == "col"
+        else:
+            assert SA[1] == "col32"
+            assert SB[1] in ["col_turing", "col_ampere"]
+            assert Sout[1] == "col32"
         assert (
             shapeA[-1] == shapeB[-1]
         ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}"
@@ -293,17 +313,23 @@ def igemmlt(
         ptrC = get_ptr(out)

         k = shapeA[-1]
-        lda = ct.c_int32(m * 32)
-        if formatB == "col_turing":
-            # turing: tiles with rows filled up to multiple of 8 rows by 32 columns
-            # n = rows
-            ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
+        if HIP_ENVIRONMENT:
+            # Set ld values for col format
+            lda = ct.c_int32(m)
+            ldb = ct.c_int32(shapeB[0])
+            ldc = ct.c_int32(m)
         else:
-            # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
-            # n = rows
-            ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)
+            lda = ct.c_int32(m * 32)
+            if formatB == "col_turing":
+                # turing: tiles with rows filled up to multiple of 8 rows by 32 columns
+                # n = rows
+                ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
+            else:
+                # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
+                # n = rows
+                ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)

-        ldc = ct.c_int32(m * 32)
+            ldc = ct.c_int32(m * 32)
         m = ct.c_int32(m)
         n = ct.c_int32(n)
         k = ct.c_int32(k)
@@ -312,7 +338,7 @@ def igemmlt(
         ptrRowScale = get_ptr(None)
         is_on_gpu([A, B, out])

-        if formatB == "col_turing":
+        if formatB == "col_turing" or HIP_ENVIRONMENT:
             if dtype == torch.int32:
                 has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
             else:
@@ -324,7 +350,7 @@ def igemmlt(
             else:
                 has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)

-        if has_error == 100:  # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
+        if has_error == 100:  # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`, `ops.hip`
             raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)")

         if has_error:
@@ -348,6 +374,9 @@ def mm_dequant(
         new_col_stats: Optional[torch.Tensor] = None,
         bias: Optional[torch.Tensor] = None,
     ):
+        if HIP_ENVIRONMENT:
+            # HIP kernel requires 'row' format
+            A, quant_state = nvidia_transform(A, "row", state=quant_state)
         assert A.dtype == torch.int32
         if bias is not None:
             assert bias.dtype == torch.float16
@@ -386,7 +415,11 @@ def mm_dequant(
     def extract_outliers(self, A: torch.Tensor, SA: Tuple[torch.Size, str], idx: torch.Tensor):
         shapeA = SA[0]
         formatA = SA[1]
-        assert formatA in ["col_turing", "col_ampere"]
+        if not HIP_ENVIRONMENT:
+            assert formatA in ["col_turing", "col_ampere"]
+        else:
+            # HIP uses col format
+            assert formatA in ["col"]
         assert A.device.type == "cuda"

         out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
@@ -400,7 +433,7 @@ def extract_outliers(self, A: torch.Tensor, SA: Tuple[torch.Size, str], idx: tor

         prev_device = pre_call(A.device)

-        if formatA == "col_turing":
+        if formatA == "col_turing" or HIP_ENVIRONMENT:
            lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
         elif formatA == "col_ampere":
            lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
@@ -414,11 +447,15 @@ def quantize_4bit(
         A: torch.Tensor,
         absmax: Optional[torch.Tensor] = None,
         out: Optional[torch.Tensor] = None,
-        blocksize=64,
+        blocksize: Optional[int] = None,
         compress_statistics=False,
         quant_type: Literal["fp4", "nf4"] = "fp4",
         quant_storage=torch.uint8,
     ) -> Tuple[torch.Tensor, QuantState]:
+        if blocksize is None:
+            # Some AMD GPUs have warpsize 64
+            # Set default blocksize to 128 (~warpsize 64 in kernel) for HIP
+            blocksize = 64 if not HIP_ENVIRONMENT else 128
         if A.device.type != "cuda":
             raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}")
         if quant_type not in ["fp4", "nf4"]:
@@ -436,7 +473,12 @@ def quantize_4bit(
             mod = dtype2bytes[quant_storage] * 2
             out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device)

-        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
+        # Some AMD GPUs have warpsize 64
+        # Set min blocksize to 128 (~warpsize 64 in kernel) for HIP
+        if not HIP_ENVIRONMENT:
+            assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
+        else:
+            assert blocksize in [4096, 2048, 1024, 512, 256, 128]

         prev_device = pre_call(A.device)
         is_on_gpu([A, out, absmax])
@@ -507,12 +549,19 @@ def dequantize_4bit(
         quant_state: Optional[QuantState] = None,
         absmax: Optional[torch.Tensor] = None,
         out: Optional[torch.Tensor] = None,
-        blocksize: int = 64,
+        blocksize: Optional[int] = None,
         quant_type: Literal["fp4", "nf4"] = "fp4",
     ) -> torch.Tensor:
-        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
+        # Some AMD GPUs have warpsize 64
+        # Set default blocksize to 128 (~warpsize 64 in kernel) for HIP
+        if blocksize is None:
+            blocksize = 64 if not HIP_ENVIRONMENT else 128
+        supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64]
+        if HIP_ENVIRONMENT:
+            supported_blocksizes = supported_blocksizes[:-1]
+        if blocksize not in supported_blocksizes:
             raise ValueError(
-                f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]"
+                f"The blockwise of {blocksize} is not supported. Supported values: {supported_blocksizes}"
             )

         if quant_type not in ["fp4", "nf4"]:
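
For reference, a minimal usage sketch of the new blocksize defaults from the caller's side. This assumes a ROCm build where HIP_ENVIRONMENT is True and that the top-level bitsandbytes.functional wrappers forward an unset blocksize down to this backend; tensor shapes and the nf4 choice are illustrative only, not part of this change.

import torch

import bitsandbytes.functional as F
from bitsandbytes.cextension import HIP_ENVIRONMENT

# Leaving blocksize unspecified lets the backend pick the default:
# 64 on CUDA, 128 on HIP, because some AMD GPUs use a wavefront
# (warp) size of 64 and the kernels assume a minimum blocksize of 128.
A = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
packed, quant_state = F.quantize_4bit(A, quant_type="nf4")
restored = F.dequantize_4bit(packed, quant_state, quant_type="nf4")

print(HIP_ENVIRONMENT, quant_state.blocksize)  # expected: 128 under HIP, 64 otherwise

Explicitly passing blocksize=64 on a HIP build would now trip quantize_4bit's blocksize assertion or dequantize_4bit's ValueError, since 64 is dropped from the supported list there.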