Commit 701c5aa

Merge pull request #1206 from Xia-Weiwen/multi-backend-refactor-cpu-4bit

Support 4bit on CPU backend

2 parents 8561f09 + 2c489f8, commit 701c5aa

5 files changed: +361, -8 lines

bitsandbytes/autograd/_functions.py
Lines changed: 2 additions & 1 deletion

@@ -572,7 +572,8 @@ def matmul_4bit(
     bias=None,
 ):
     assert quant_state is not None
-    if A.numel() == A.shape[-1] and A.requires_grad == False:
+    if (A.numel() == A.shape[-1] or A.device.type == "cpu") and A.requires_grad == False:
+        # CPU backend does not require A to be a vector
         if A.shape[-1] % quant_state.blocksize != 0:
             warn(
                 f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",

bitsandbytes/backends/cpu.py
Lines changed: 13 additions & 3 deletions

@@ -6,9 +6,12 @@
 
 from .base import Backend
 from .cpu_xpu_common import (
+    dequantize_4bit_impl,
     double_quant_impl,
+    gemm_4bit_impl,
     igemmlt_impl,
     mm_dequant_impl,
+    quantize_4bit_impl,
 )
 
 Tensor = torch.Tensor

@@ -132,7 +135,9 @@ def quantize_4bit(
         quant_type: Literal["fp4", "nf4"] = "fp4",
         quant_storage=torch.uint8,
     ) -> Tuple[torch.Tensor, QuantState]:
-        raise NotImplementedError("Not yet implemented for CPU backend")
+        assert_on_cpu([A, absmax, out])
+        assert quant_storage == torch.uint8, "CPU backend only supports uint8 quant_storage"
+        return quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)
 
     def dequantize_4bit(
         self,

@@ -143,7 +148,8 @@ def dequantize_4bit(
         blocksize: int = 64,
         quant_type: Literal["fp4", "nf4"] = "fp4",
     ) -> torch.Tensor:
-        raise NotImplementedError("Not yet implemented for CPU backend")
+        assert_on_cpu([A, absmax, out])
+        return dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)
 
     def gemv_4bit(
         self,

@@ -154,7 +160,11 @@ def gemv_4bit(
         transposed_B=False,
         state: QuantState = None,
     ) -> torch.Tensor:
-        raise NotImplementedError("Not yet implemented for CPU backend")
+        assert_on_cpu([A, B, out])
+        if state is None:
+            raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")
+
+        return gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state)
 
     def dequantize_blockwise(
         self,

bitsandbytes/backends/cpu_xpu_common.py
Lines changed: 293 additions & 0 deletions

@@ -1,7 +1,13 @@
+from typing import Optional
 import warnings
 
 import torch
 
+from bitsandbytes.functional import (
+    QuantState,
+    get_4bit_type,
+)
+
 try:
     # to support Intel CPU/GPU (XPU) backend
     import intel_extension_for_pytorch as ipex

@@ -228,3 +234,290 @@ def mm_dequant_impl(
         out = out + bias.to(compute_dtype)
     out = out.to(output_dtype)
     return out
+
+
+NF4_QUANT_TABLE = [
+    -1.0 - 1e-2,  # 0b0000
+    -0.8480964004993439,  # 0b0001
+    -0.6106329262256622,  # 0b0010
+    -0.4599952697753906,  # 0b0011
+    -0.33967943489551544,  # 0b0100
+    -0.23460740596055984,  # 0b0101
+    -0.13791173323988914,  # 0b0110
+    -0.045525018125772476,  # 0b0111
+    0.03979014977812767,  # 0b1000
+    0.1202552504837513,  # 0b1001
+    0.2035212516784668,  # 0b1010
+    0.2920137718319893,  # 0b1011
+    0.3893125355243683,  # 0b1100
+    0.5016634166240692,  # 0b1101
+    0.6427869200706482,  # 0b1110
+    0.8614784181118011,  # 0b1111
+]
+
+
+FP4_QUANT_TABLE = {
+    0 - 1e-2: 0,  # 0b0000
+    0.00260417: 1,  # 0b0001
+    0.0859375: 6,  # 0b0110
+    0.20833333: 7,  # 0b0111
+    0.29166667: 4,  # 0b0100
+    0.4166667: 5,  # 0b0101
+    0.583333: 2,  # 0b0010
+    0.8333333: 3,  # 0b0011
+}
+
+
+# It's faster not to use torch.compile
+def quantize_4bit_impl(
+    A: Tensor,
+    absmax: Tensor = None,
+    out: Tensor = None,
+    blocksize=64,
+    compress_statistics=False,
+    quant_type="nf4",
+) -> Tensor:
+    """
+    Quantize tensor A in blocks of 4-bit values.
+
+    Quantizes tensor A by dividing it into blocks which are independently quantized to FP4.
+
+    Parameters
+    ----------
+    A : torch.Tensor
+        The input tensor.
+    absmax : torch.Tensor
+        The absmax values.
+    out : torch.Tensor
+        The output tensor (8-bit).
+    blocksize : int
+        The blocksize used in quantization.
+    quant_type : str
+        The 4-bit quantization data type {fp4, nf4}, only nf4 is supported now
+
+    Returns
+    -------
+    torch.Tensor:
+        The 8-bit tensor with packed 4-bit values.
+    tuple(torch.Tensor, torch.Size, torch.dtype, int):
+        The quantization state to undo the quantization.
+    """
+    if quant_type not in ["nf4", "fp4"]:
+        raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU.")
+    if quant_type == "fp4":
+        warnings.warn("fp4 quantization is currently slow on CPU/XPU. Please Use nf4 instead for better performance.")
+    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
+    n = A.numel()
+    input_shape = A.shape
+    blocks = n // blocksize
+    blocks += 1 if n % blocksize > 0 else 0
+
+    if absmax is None:
+        absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype)
+
+    if out is None:
+        out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device)
+
+    rem = n % blocksize
+    has_rem = rem > 0
+
+    # Scale tensor to [-1, 1]
+    A_reshaped = A.reshape(n)
+    A_com = A_reshaped[: n - rem]
+    A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
+    absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
+    scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
+    scaled_A = scaled_A.reshape(-1)
+    if has_rem:
+        absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
+        scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
+        scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
+    # map [-1, 1] to nf4/fp4
+    out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8)
+    if quant_type == "nf4":
+        for i in range(len(NF4_QUANT_TABLE)):
+            out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i
+    elif quant_type == "fp4":
+        sign = scaled_A < 0
+        abs_scaled_A = torch.abs(scaled_A)
+        for key, val in FP4_QUANT_TABLE.items():
+            out_uint8[abs_scaled_A > key] = val
+        out_uint8 += sign.to(torch.uint8) * 8
+    if out_uint8.size(-1) % 2:
+        out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
+    out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2])
+
+    code = get_4bit_type(quant_type, device=A.device)
+
+    if compress_statistics:
+        raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
+    else:
+        state = QuantState(
+            absmax=absmax,
+            shape=input_shape,
+            dtype=A.dtype,
+            blocksize=blocksize,
+            code=code,
+            quant_type=quant_type,
+        )
+
+    if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and input_shape[1] % blocksize == 0 and quant_type == "nf4":
+        # lowp_mode: lowest precision for computation
+        lowp_mode = ipex_cpu.quantization.WoqLowpMode.BF16
+        state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
+            out.reshape([input_shape[0], input_shape[1] // 2]),
+            ipex_cpu.quantization.WoqWeightDtype.NF4,
+            input_shape,  # weight shape
+            absmax.view(input_shape[0], input_shape[1] // blocksize),  # scales
+            None,  # zero_points
+            None,  # bias
+            None,  # g_idx
+            None,  # batch_size
+            blocksize,
+            int(lowp_mode),
+            -1,  # act_quant_mode. -1 means don't quant activation
+        )
+        state.absmax = torch.Tensor()
+        return torch.Tensor(), state
+
+    return out, state
+
+
+@_maybe_torch_compile
+def dequantize_4bit_impl(
+    A: Tensor,
+    quant_state=None,
+    absmax: Tensor = None,
+    out: Tensor = None,
+    blocksize: int = 64,
+    quant_type="nf4",
+) -> Tensor:
+    """
+    Dequantizes FP4 blockwise quantized values.
+
+    Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize.
+
+    Parameters
+    ----------
+    A : torch.Tensor
+        The input 8-bit tensor (packed 4-bit values).
+    quant_state : QuantState
+        object with quantisation stats, incl. absmax values, original tensor shape and original dtype.
+    absmax : torch.Tensor
+        The absmax values.
+    out : torch.Tensor
+        Dequantized output tensor.
+    blocksize : int
+        The blocksize used in quantization.
+    quant_type : str
+        The 4-bit quantization data type {fp4, nf4}, only nf4 is supported now
+
+
+    Returns
+    -------
+    torch.Tensor:
+        Dequantized tensor.
+    """
+
+    if quant_state is None:
+        assert absmax is not None and out is not None
+
+        quant_state = QuantState(
+            absmax=absmax,
+            shape=out.shape,
+            dtype=out.dtype,
+            blocksize=blocksize,
+            quant_type=quant_type,
+        )
+
+    else:
+        absmax = quant_state.absmax
+
+    if quant_type not in ["nf4", "fp4"]:
+        raise NotImplementedError(
+            f"4-bit quantization data type {quant_state.quant_type} is not implemented for CPU/XPU."
+        )
+
+    if quant_state.nested:
+        raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
+
+    if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(quant_state, "op_context"):
+        assert quant_state.op_context is not None
+        A = quant_state.op_context.to_public(quant_state.op_context.get_weight())
+        A = A.reshape(-1)
+        absmax = quant_state.op_context.get_scales().reshape(-1)
+
+    if out is None:
+        out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
+
+    n = out.numel()
+    # Map nf4 to [-1, 1]
+    out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device)
+    out_uint8[::2] = A.bitwise_and(0xF)
+    out_uint8[1::2] = A.bitwise_right_shift(4)
+    out_dq = torch.empty(out_uint8.shape).to(quant_state.dtype)
+    for i in range(len(quant_state.code)):
+        out_dq[out_uint8 == i] = quant_state.code[i]
+
+    # Apply scales
+    if out_dq.numel() != n:
+        assert out_dq.numel() == n + 1
+        out_dq = torch.narrow(out_dq, 0, 0, n)
+    blocks = n // blocksize
+    blocks += 1 if n % blocksize > 0 else 0
+    rem = n % blocksize
+    has_rem = rem > 0
+    out_reshaped = out.reshape(-1)
+    out_reshaped[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(
+        -1
+    )
+    if has_rem:
+        out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1]
+
+    # take transpose here because weight is transposed (again) for computation
+    return out.t()
+
+
+# Do not need torch.compile here as we are calling torch/ipex kernel
+def gemm_4bit_impl(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    transposed_A=False,
+    transposed_B=False,
+    state: QuantState = None,
+) -> torch.Tensor:
+    """
+    Matrix-matrix multiplication with 4-bit quantization.
+
+    Parameters
+    ----------
+    A : torch.Tensor
+        The first input tensor. Usually the activation tensor.
+    B : torch.Tensor
+        The second input tensor. Usually the weight tensor.
+    out : torch.Tensor
+        The output tensor.
+    transposed_A : bool
+        Whether A is transposed
+    transposed_B : bool
+        Whether B is transposed
+    state : QuantState
+        Contains quantization info, such as blocksize and dtype
+
+    Returns
+    -------
+    torch.Tensor:
+        GEMM output tensor.
+    """
+    if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(state, "op_context"):
+        assert state.op_context is not None
+        output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle())
+    else:
+        dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize)
+        output = torch.matmul(A, dqB)
+    if out is not None:
+        out.copy_(output)
+    else:
+        out = output
+    return out

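The quantize/dequantize pair above packs two 4-bit codes per uint8 (even-indexed code in the low nibble, odd-indexed in the high nibble) and scales each block of blocksize values by its absmax. A standalone pack/unpack sketch in plain PyTorch, with made-up code values, mirroring the bit operations in quantize_4bit_impl and dequantize_4bit_impl:

    import torch

    codes = torch.tensor([3, 15, 0, 9, 7, 2], dtype=torch.uint8)  # pretend 4-bit NF4 indices

    # Pack: odd-indexed codes in the high nibble, even-indexed in the low nibble,
    # like out_uint8[1::2] << 4 | out_uint8[::2] in quantize_4bit_impl.
    packed = codes[1::2].bitwise_left_shift(4).bitwise_or(codes[::2])

    # Unpack: low nibble to even positions, high nibble to odd positions,
    # like the first steps of dequantize_4bit_impl.
    unpacked = torch.empty(packed.numel() * 2, dtype=torch.uint8)
    unpacked[::2] = packed.bitwise_and(0xF)
    unpacked[1::2] = packed.bitwise_right_shift(4)

    assert torch.equal(unpacked, codes)

The blockwise scaling is then a plain elementwise multiply of each dequantized block by its stored absmax, with any trailing partial block handled separately.
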
bitsandbytes/nn/modules.py
Lines changed: 5 additions & 2 deletions

@@ -285,7 +285,7 @@ def from_prequantized(
         return self
 
     def _quantize(self, device):
-        w = self.data.contiguous().cuda(device)
+        w = self.data.contiguous().to(device)
         w_4bit, quant_state = bnb.functional.quantize_4bit(
             w,
             blocksize=self.blocksize,

@@ -303,6 +303,9 @@ def _quantize(self, device):
     def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
         return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)
 
+    def cpu(self, non_blocking: bool = False):
+        return self.to(device="cpu", non_blocking=non_blocking)
+
     @overload
     def to(
         self: T,

@@ -320,7 +323,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
 
-        if device is not None and device.type == "cuda" and not self.bnb_quantized:
+        if device is not None and device.type in ["cuda", "cpu"] and not self.bnb_quantized:
             return self._quantize(device)
         else:
             if self.quant_state is not None:
