1- from collections .abc import Sequence
21import ctypes as ct
2+ import logging
33
44import torch
55
66from bitsandbytes .functional import get_ptr
77
88from ..._ops import register_kernel
9- from ...cextension import lib
10- from ..utils import ipex_cpu
9+ from ...cextension import ErrorHandlerMockBNBNativeLibrary , lib
10+
11+ logger = logging .getLogger (__name__ )
1112
1213# torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+.
1314# However, we can overflow if we use this without AVX512_VNNI support.
@@ -24,97 +25,77 @@ def _(A: torch.Tensor, B: torch.Tensor):
2425 ).reshape (* A .shape [:- 1 ], B .shape [0 ])
2526
2627
@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize `A` in blocks of `blocksize` against the `code` lookup table.

    Returns a uint8 tensor of codebook indices (same shape as A) and the
    per-block float32 absmax scales.
    """
    torch._check_is_size(blocksize)

    n = A.numel()

    # The native kernel only covers float32; every other dtype is emulated in torch.
    if A.dtype == torch.float32:
        num_blocks = -(n // -blocksize)  # ceil division, counts a partial tail block

        absmax = torch.empty((num_blocks,), device=A.device, dtype=torch.float32)
        out = torch.empty_like(A, dtype=torch.uint8)

        lib.cquantize_blockwise_cpu_fp32(
            get_ptr(code),
            get_ptr(A),
            get_ptr(absmax),
            get_ptr(out),
            ct.c_longlong(blocksize),
            ct.c_longlong(n),
        )
    else:
        full, tail = divmod(n, blocksize)
        has_tail = tail > 0
        num_blocks = full + has_tail
        absmax = torch.zeros((num_blocks,), device=A.device, dtype=torch.float32)
        flat = A.reshape(n)
        # Full blocks: per-block absmax, then scale each block into [-1, 1].
        body = flat[: n - tail].reshape(full, blocksize)
        absmax[:full] = body.abs().max(dim=-1)[0]
        scaled = torch.clamp(body * (1 / absmax[:full].view(-1, 1)), -1, 1).reshape(-1)
        if has_tail:
            # The partial tail block gets its own scale.
            absmax[-1] = flat[n - tail :].abs().max()
            scaled = torch.cat(
                [scaled, torch.clamp(flat[n - tail :] * (1 / absmax[-1]), -1, 1)], dim=0
            )

        # Nearest codebook entry by absolute distance (code need not be sorted).
        diff = torch.abs(scaled.unsqueeze(-1) - code.to(scaled.device))
        out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled.device).reshape(A.shape)

    return out, absmax
68-
69-
@register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
    """Invert blockwise quantization: map uint8 indices through `code`, rescale per block."""
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")

    # The native kernel only covers float32; every other dtype is emulated in torch.
    if dtype == torch.float32:
        out = torch.empty_like(A, dtype=dtype)

        lib.cdequantize_blockwise_cpu_fp32(
            get_ptr(code),
            get_ptr(A),
            get_ptr(absmax),
            get_ptr(out),
            ct.c_longlong(blocksize),
            ct.c_longlong(A.numel()),
        )
    else:
        # Codebook lookup for every element, then per-block rescale by absmax.
        values = code[A.reshape(-1).int()]
        full, tail = divmod(values.shape[-1], blocksize)
        if tail:
            # Pad the partial tail so the (blocks, blocksize) view below is valid.
            values = torch.nn.functional.pad(values, (0, blocksize - tail), mode="constant", value=0)
        values = (values.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
        # Drop the padding again and restore the input shape.
        out = values[: full * blocksize + tail].reshape(A.shape)

    return out
98-
99-
100- if ipex_cpu :
101- from bitsandbytes .utils import _reverse_4bit_compress_format
102-
103- @register_kernel ("bitsandbytes::dequantize_nf4_ipex" , "cpu" )
if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):

    @register_kernel("bitsandbytes::quantize_blockwise", "cpu")
    def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize `A` blockwise against the 256-entry `code` table.

        Returns:
            out: uint8 codebook indices with A's shape.
            absmax: float32 tensor with one absmax scale per block of
                `blocksize` elements (last block may be partial).
        """
        torch._check_is_size(blocksize)

        n = A.numel()

        # Only FP32 has a C++ kernel; other dtypes use a pure-torch fallback.
        if A.dtype == torch.float32:
            # Ceil-divide so a partial tail block is counted.
            blocks = -(n // -blocksize)

            absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
            out = torch.empty_like(A, dtype=torch.uint8)

            lib.cquantize_blockwise_cpu_fp32(
                get_ptr(code),
                get_ptr(A),
                get_ptr(absmax),
                get_ptr(out),
                ct.c_longlong(blocksize),
                ct.c_longlong(n),
            )
        else:
            rem = n % blocksize
            has_rem = rem > 0
            blocks = n // blocksize + has_rem
            absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
            A_flat = A.reshape(n)
            # Full blocks: per-block absmax, then scale each block into [-1, 1].
            A_full = A_flat[: n - rem].reshape(n // blocksize, blocksize)
            absmax[: blocks - has_rem] = torch.abs(A_full).max(dim=-1)[0]
            scaled = torch.clamp(A_full * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
            scaled = scaled.reshape(-1)
            if has_rem:
                # The partial tail block gets its own scale.
                absmax[-1] = torch.abs(A_flat[n - rem :]).max()
                scaled_rem = torch.clamp(A_flat[n - rem :] * (1 / absmax[-1]), -1, 1)
                scaled = torch.cat([scaled, scaled_rem], dim=0)

            # Nearest codebook entry by absolute distance (code need not be sorted).
            diff = torch.abs(scaled.unsqueeze(-1) - code.to(scaled.device))
            # argmin already yields a tensor on scaled's device; no extra .to(device) needed.
            out = torch.argmin(diff, dim=-1).to(torch.uint8).reshape(A.shape)

        return out, absmax
71+
72+ @register_kernel ("bitsandbytes::dequantize_blockwise" , "cpu" )
10473 def _ (
105- A : torch .Tensor ,
106- absmax : torch .Tensor ,
107- blocksize : int ,
108- shape : Sequence [int ],
109- dtype : torch .dtype ,
74+ A : torch .Tensor , absmax : torch .Tensor , code : torch .Tensor , blocksize : int , dtype : torch .dtype
11075 ) -> torch .Tensor :
111- ipex_weight = torch .ops .ipex_prepack .woq_linear_unpack_weight (A , "nf4" , shape , 2 )
112- A = _reverse_4bit_compress_format (ipex_weight .reshape (- 1 )).reshape (1 , - 1 )
113- return torch .ops .bitsandbytes .dequantize_4bit .default (
114- A ,
115- absmax ,
116- blocksize ,
117- "nf4" ,
118- shape ,
119- dtype ,
120- )
76+ torch ._check_is_size (blocksize )
77+ torch ._check (A .dtype == torch .uint8 , lambda : f"A must be uint8, got { A .dtype } " )
78+
79+ # Only FP32 has c++ kernrl
80+ if dtype == torch .float32 :
81+ out = torch .empty_like (A , dtype = dtype )
82+
83+ lib .cdequantize_blockwise_cpu_fp32 (
84+ get_ptr (code ),
85+ get_ptr (A ),
86+ get_ptr (absmax ),
87+ get_ptr (out ),
88+ ct .c_longlong (blocksize ),
89+ ct .c_longlong (A .numel ()),
90+ )
91+ else :
92+ out = code [A .reshape (- 1 ).int ()]
93+ blocks = out .shape [- 1 ] // blocksize
94+ res = out .shape [- 1 ] % blocksize
95+ if res != 0 :
96+ out = torch .nn .functional .pad (out , (0 , blocksize - res ), mode = "constant" , value = 0 )
97+ out = (out .view (- 1 , blocksize ) * absmax .view (- 1 , 1 )).to (dtype ).reshape (- 1 )
98+ out = out [: blocks * blocksize + res ]
99+ out = out .reshape (A .shape )
100+
101+ return out
0 commit comments