| 1 | +from typing import Optional |
1 | 2 | import warnings |
2 | 3 | |
3 | 4 | import torch |
4 | 5 | |
| 6 | +from bitsandbytes.functional import ( |
| 7 | + QuantState, |
| 8 | + get_4bit_type, |
| 9 | +) |
| 10 | + |
5 | 11 | try: |
6 | 12 | # to support Intel CPU/GPU (XPU) backend |
7 | 13 | import intel_extension_for_pytorch as ipex |
@@ -228,3 +234,290 @@ def mm_dequant_impl( |
228 | 234 | out = out + bias.to(compute_dtype) |
229 | 235 | out = out.to(output_dtype) |
230 | 236 | return out |
| 237 | + |
| 238 | + |
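| | +# Decision thresholds for NF4: each entry is the midpoint between two adjacent NF4 code values, |
| | +# so taking the largest threshold an element exceeds (see the loop below) rounds it to the nearest |
| | +# NF4 value. The first entry sits just below -1.0 so every element clears at least one threshold. |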
| 239 | +NF4_QUANT_TABLE = [ |
| 240 | + -1.0 - 1e-2, # 0b0000 |
| 241 | + -0.8480964004993439, # 0b0001 |
| 242 | + -0.6106329262256622, # 0b0010 |
| 243 | + -0.4599952697753906, # 0b0011 |
| 244 | + -0.33967943489551544, # 0b0100 |
| 245 | + -0.23460740596055984, # 0b0101 |
| 246 | + -0.13791173323988914, # 0b0110 |
| 247 | + -0.045525018125772476, # 0b0111 |
| 248 | + 0.03979014977812767, # 0b1000 |
| 249 | + 0.1202552504837513, # 0b1001 |
| 250 | + 0.2035212516784668, # 0b1010 |
| 251 | + 0.2920137718319893, # 0b1011 |
| 252 | + 0.3893125355243683, # 0b1100 |
| 253 | + 0.5016634166240692, # 0b1101 |
| 254 | + 0.6427869200706482, # 0b1110 |
| 255 | + 0.8614784181118011, # 0b1111 |
| 256 | +] |
| 257 | + |
| 258 | + |
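| | +# FP4 decision thresholds: keys are midpoints between adjacent representable FP4 magnitudes |
| | +# (in ascending order) and values are the matching 3-bit FP4 magnitude codes; the sign bit |
| | +# (0b1000) is added separately during quantization. |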
| 259 | +FP4_QUANT_TABLE = { |
| 260 | + 0 - 1e-2: 0, # 0b0000 |
| 261 | + 0.00260417: 1, # 0b0001 |
| 262 | + 0.0859375: 6, # 0b0110 |
| 263 | + 0.20833333: 7, # 0b0111 |
| 264 | + 0.29166667: 4, # 0b0100 |
| 265 | + 0.4166667: 5, # 0b0101 |
| 266 | + 0.583333: 2, # 0b0010 |
| 267 | + 0.8333333: 3, # 0b0011 |
| 268 | +} |
| 269 | + |
| 270 | + |
| 271 | +# It's faster not to use torch.compile |
| 272 | +def quantize_4bit_impl( |
| 273 | + A: Tensor, |
| 274 | + absmax: Tensor = None, |
| 275 | + out: Tensor = None, |
| 276 | + blocksize=64, |
| 277 | + compress_statistics=False, |
| 278 | + quant_type="nf4", |
| 279 | +) -> Tensor: |
| 280 | + """ |
| 281 | + Quantize tensor A in blocks of 4-bit values. |
| 282 | + |
| 283 | + Quantizes tensor A by dividing it into blocks which are independently quantized to 4-bit values (NF4 or FP4). |
| 284 | + |
| 285 | + Parameters |
| 286 | + ---------- |
| 287 | + A : torch.Tensor |
| 288 | + The input tensor. |
| 289 | + absmax : torch.Tensor |
| 290 | + The absmax values. |
| 291 | + out : torch.Tensor |
| 292 | + The output tensor (8-bit). |
| 293 | + blocksize : int |
| 294 | + The blocksize used in quantization. |
| 295 | + quant_type : str |
| 296 | + The 4-bit quantization data type {fp4, nf4}. Both are supported, but fp4 is currently slow on CPU/XPU, so nf4 is recommended. |
| 297 | + |
| 298 | + Returns |
| 299 | + ------- |
| 300 | + torch.Tensor: |
| 301 | + The 8-bit tensor with packed 4-bit values. |
| 302 | + QuantState: |
| 303 | + The quantization state to undo the quantization. |
| 304 | + """ |
| 305 | + if quant_type not in ["nf4", "fp4"]: |
| 306 | + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU.") |
| 307 | + if quant_type == "fp4": |
| 308 | + warnings.warn("fp4 quantization is currently slow on CPU/XPU. Please use nf4 instead for better performance.") |
| 309 | + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] |
| 310 | + n = A.numel() |
| 311 | + input_shape = A.shape |
| 312 | + blocks = n // blocksize |
| 313 | + blocks += 1 if n % blocksize > 0 else 0 |
| 314 | + |
| 315 | + if absmax is None: |
| 316 | + absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) |
| 317 | + |
| 318 | + if out is None: |
| 319 | + out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device) |
| 320 | + |
| 321 | + rem = n % blocksize |
| 322 | + has_rem = rem > 0 |
| 323 | + |
| 324 | + # Scale tensor to [-1, 1] |
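| | + # Full blocks are scaled by their own absmax; a shorter trailing block (rem) is handled separately below. |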
| 325 | + A_reshaped = A.reshape(n) |
| 326 | + A_com = A_reshaped[: n - rem] |
| 327 | + A_com_reshaped = A_com.reshape(n // blocksize, blocksize) |
| 328 | + absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] |
| 329 | + scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) |
| 330 | + scaled_A = scaled_A.reshape(-1) |
| 331 | + if has_rem: |
| 332 | + absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() |
| 333 | + scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) |
| 334 | + scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) |
| 335 | + # map [-1, 1] to nf4/fp4 |
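| | + # Each element receives the index of the largest table threshold it exceeds, i.e. the code of the nearest 4-bit value. |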
| 336 | + out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8, device=A.device) |
| 337 | + if quant_type == "nf4": |
| 338 | + for i in range(len(NF4_QUANT_TABLE)): |
| 339 | + out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i |
| 340 | + elif quant_type == "fp4": |
| 341 | + sign = scaled_A < 0 |
| 342 | + abs_scaled_A = torch.abs(scaled_A) |
| 343 | + for key, val in FP4_QUANT_TABLE.items(): |
| 344 | + out_uint8[abs_scaled_A > key] = val |
| 345 | + out_uint8 += sign.to(torch.uint8) * 8 |
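| | + # Pack two 4-bit codes per byte: odd positions go to the high nibble, even positions to the low nibble (pad if the count is odd). |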
| 346 | + if out_uint8.size(-1) % 2: |
| 347 | + out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0) |
| 348 | + out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2]) |
| 349 | + |
| 350 | + code = get_4bit_type(quant_type, device=A.device) |
| 351 | + |
| 352 | + if compress_statistics: |
| 353 | + raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU") |
| 354 | + else: |
| 355 | + state = QuantState( |
| 356 | + absmax=absmax, |
| 357 | + shape=input_shape, |
| 358 | + dtype=A.dtype, |
| 359 | + blocksize=blocksize, |
| 360 | + code=code, |
| 361 | + quant_type=quant_type, |
| 362 | + ) |
| 363 | + |
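| | + # With IPEX >= 2.3 on CPU, prepack the packed NF4 weight into an IPEX weight-only-quantization |
| | + # op context so gemm_4bit_impl can call the fused kernel; the packed tensor and absmax are then |
| | + # kept inside the op context instead of being returned directly. |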
| 364 | + if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and input_shape[1] % blocksize == 0 and quant_type == "nf4": |
| 365 | + # lowp_mode: lowest precision for computation |
| 366 | + lowp_mode = ipex_cpu.quantization.WoqLowpMode.BF16 |
| 367 | + state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack( |
| 368 | + out.reshape([input_shape[0], input_shape[1] // 2]), |
| 369 | + ipex_cpu.quantization.WoqWeightDtype.NF4, |
| 370 | + input_shape, # weight shape |
| 371 | + absmax.view(input_shape[0], input_shape[1] // blocksize), # scales |
| 372 | + None, # zero_points |
| 373 | + None, # bias |
| 374 | + None, # g_idx |
| 375 | + None, # batch_size |
| 376 | + blocksize, |
| 377 | + int(lowp_mode), |
| 378 | + -1, # act_quant_mode. -1 means don't quant activation |
| 379 | + ) |
| 380 | + state.absmax = torch.Tensor() |
| 381 | + return torch.Tensor(), state |
| 382 | + |
| 383 | + return out, state |
| 384 | + |
| 385 | + |
| 386 | +@_maybe_torch_compile |
| 387 | +def dequantize_4bit_impl( |
| 388 | + A: Tensor, |
| 389 | + quant_state=None, |
| 390 | + absmax: Tensor = None, |
| 391 | + out: Tensor = None, |
| 392 | + blocksize: int = 64, |
| 393 | + quant_type="nf4", |
| 394 | +) -> Tensor: |
| 395 | + """ |
| 396 | + Dequantizes blockwise quantized 4-bit values (NF4 or FP4). |
| 397 | + |
| 398 | + Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. |
| 399 | + |
| 400 | + Parameters |
| 401 | + ---------- |
| 402 | + A : torch.Tensor |
| 403 | + The input 8-bit tensor (packed 4-bit values). |
| 404 | + quant_state : QuantState |
| 405 | + Object with quantization stats, incl. absmax values, original tensor shape and original dtype. |
| 406 | + absmax : torch.Tensor |
| 407 | + The absmax values. |
| 408 | + out : torch.Tensor |
| 409 | + Dequantized output tensor. |
| 410 | + blocksize : int |
| 411 | + The blocksize used in quantization. |
| 412 | + quant_type : str |
| 413 | + The 4-bit quantization data type {fp4, nf4}. Both nf4 and fp4 are supported. |
| 414 | + |
| 415 | + |
| 416 | + Returns |
| 417 | + ------- |
| 418 | + torch.Tensor: |
| 419 | + Dequantized tensor. |
| 420 | + """ |
| 421 | + |
| 422 | + if quant_state is None: |
| 423 | + assert absmax is not None and out is not None |
| 424 | + |
| 425 | + quant_state = QuantState( |
| 426 | + absmax=absmax, |
| 427 | + shape=out.shape, |
| 428 | + dtype=out.dtype, |
| 429 | + blocksize=blocksize, |
| 430 | + quant_type=quant_type, |
| | + code=get_4bit_type(quant_type, device=A.device), |
| 431 | + ) |
| 432 | + |
| 433 | + else: |
| 434 | + absmax = quant_state.absmax |
| | + blocksize = quant_state.blocksize |
| 435 | + |
| 436 | + if quant_state.quant_type not in ["nf4", "fp4"]: |
| 437 | + raise NotImplementedError( |
| 438 | + f"4-bit quantization data type {quant_state.quant_type} is not implemented for CPU/XPU." |
| 439 | + ) |
| 440 | + |
| 441 | + if quant_state.nested: |
| 442 | + raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU") |
| 443 | + |
| 444 | + if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(quant_state, "op_context"): |
| 445 | + assert quant_state.op_context is not None |
| 446 | + A = quant_state.op_context.to_public(quant_state.op_context.get_weight()) |
| 447 | + A = A.reshape(-1) |
| 448 | + absmax = quant_state.op_context.get_scales().reshape(-1) |
| 449 | + |
| 450 | + if out is None: |
| 451 | + out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) |
| 452 | + |
| 453 | + n = out.numel() |
| 454 | + # Map nf4 to [-1, 1] |
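| | + # Unpack two 4-bit codes from each byte (low nibble -> even positions, high nibble -> odd), mirroring the packing in quantize_4bit_impl. |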
| 455 | + out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device) |
| 456 | + out_uint8[::2] = A.bitwise_and(0xF) |
| 457 | + out_uint8[1::2] = A.bitwise_right_shift(4) |
| 458 | + out_dq = torch.empty(out_uint8.shape, dtype=quant_state.dtype, device=A.device) |
| 459 | + for i in range(len(quant_state.code)): |
| 460 | + out_dq[out_uint8 == i] = quant_state.code[i] |
| 461 | + |
| 462 | + # Apply scales |
| 463 | + if out_dq.numel() != n: |
| 464 | + assert out_dq.numel() == n + 1 |
| 465 | + out_dq = torch.narrow(out_dq, 0, 0, n) |
| 466 | + blocks = n // blocksize |
| 467 | + blocks += 1 if n % blocksize > 0 else 0 |
| 468 | + rem = n % blocksize |
| 469 | + has_rem = rem > 0 |
| 470 | + out_reshaped = out.reshape(-1) |
| 471 | + out_reshaped[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape( |
| 472 | + -1 |
| 473 | + ) |
| 474 | + if has_rem: |
| 475 | + out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1] |
| 476 | + |
| 477 | + # take transpose here because weight is transposed (again) for computation |
| 478 | + return out.t() |
| 479 | + |
| 480 | + |
| 481 | +# No need for torch.compile here, as this just calls torch/ipex kernels |
| 482 | +def gemm_4bit_impl( |
| 483 | + A: torch.Tensor, |
| 484 | + B: torch.Tensor, |
| 485 | + out: Optional[torch.Tensor] = None, |
| 486 | + transposed_A=False, |
| 487 | + transposed_B=False, |
| 488 | + state: QuantState = None, |
| 489 | +) -> torch.Tensor: |
| 490 | + """ |
| 491 | + Matrix-matrix multiplication with 4-bit quantization. |
| 492 | + |
| 493 | + Parameters |
| 494 | + ---------- |
| 495 | + A : torch.Tensor |
| 496 | + The first input tensor. Usually the activation tensor. |
| 497 | + B : torch.Tensor |
| 498 | + The second input tensor. Usually the weight tensor. |
| 499 | + out : torch.Tensor |
| 500 | + The output tensor. |
| 501 | + transposed_A : bool |
| 502 | + Whether A is transposed |
| 503 | + transposed_B : bool |
| 504 | + Whether B is transposed |
| 505 | + state : QuantState |
| 506 | + Contains quantization info, such as blocksize and dtype |
| 507 | + |
| 508 | + Returns |
| 509 | + ------- |
| 510 | + torch.Tensor: |
| 511 | + GEMM output tensor. |
| 512 | + """ |
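| | + # Fast path: use the fused IPEX weight-only-quantization kernel when the weight was prepacked on CPU; |
| | + # otherwise dequantize B and fall back to a regular matmul. |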
| 513 | + if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(state, "op_context"): |
| 514 | + assert state.op_context is not None |
| 515 | + output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle()) |
| 516 | + else: |
| 517 | + dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize) |
| 518 | + output = torch.matmul(A, dqB) |
| 519 | + if out is not None: |
| 520 | + out.copy_(output) |
| 521 | + else: |
| 522 | + out = output |
| 523 | + return out |