Commit 1813b05

xiaolil1, jiqing-feng, and shangerxin authored
Add SYCL Kernels for XPU backend (#1679)
* Add SYCL Kernels for XPU backend
* fix transpose
* fix log and format
* revert cpu changes
* clean ipex_xpu
* clean ipex import
* fix ipex cpu import
* fix typo
* fix comments
* refine gemv_4bit kernel
* enable FP4 for dequant_4bit and gemv_4bit
* refine FP4 dequantization performance
* remove check for better performance
* fix doc
* clean code
* fix tests
* rm comments
* fix memory issue
* fix ut failure
* adjust threshold
* fix xpu check
* change test_functional check
* fix test_module
* fix device check
* fix tests
* Enable Windows build and refine code
* fix xpu log
* remove ipex entirely
* fix cpu int8 CB
* fix lint
* fix logs (#12): fix logs; fix format
* Fix sycl lint error and tests (#13): fix sycl nd; fix tests
* skip typo check for xpu kernel codes (#14): skip test for xpu ops; fix lint; skip typo for xpu; skip; skip
* register triton kernel for quantization (#15)
* Fix version comparison issue (#18)

Description (from #18): the version comparison expression misses the `.release` property on one side of the comparison, which leads to comparing a tuple against a `Version` object.

Error message:

```
The 8-bit optimizer is not available on your device, only available on CUDA for now.
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
Traceback (most recent call last):
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/unsloth_validation/run.py", line 1, in <module>
    import unsloth
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/unsloth/__init__.py", line 235, in <module>
    from .models import *
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/unsloth/models/__init__.py", line 15, in <module>
    from .llama import FastLlamaModel
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/unsloth/models/llama.py", line 23, in <module>
    from ._utils import *
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/unsloth/models/_utils.py", line 89, in <module>
    from unsloth_zoo.patching_utils import (
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/unsloth_zoo/patching_utils.py", line 629, in <module>
    import transformers.integrations.bitsandbytes
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/transformers/integrations/bitsandbytes.py", line 20, in <module>
    import bitsandbytes as bnb
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/bitsandbytes/bitsandbytes/__init__.py", line 39, in <module>
    from .backends.xpu import ops as xpu_ops
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/bitsandbytes/bitsandbytes/backends/xpu/ops.py", line 17, in <module>
    if version.parse(torch.__version__).release >= version.parse("2.9"):
TypeError: '>=' not supported between instances of 'tuple' and 'Version'
```

Signed-off-by: jiqing-feng <[email protected]>
Co-authored-by: jiqing-feng <[email protected]>
Co-authored-by: Er-Xin (Edwin) Shang <[email protected]>
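The offending guard compares `Version.release` (a tuple) against a `Version` object. Below is a minimal sketch of a consistent comparison, assuming the `packaging` library; the exact change applied in this commit may differ in detail.

```python
import torch
from packaging import version

parsed = version.parse(torch.__version__)

# Compare like with like: release tuple against release tuple...
torch_is_2_9_plus = parsed.release >= (2, 9)

# ...or Version against Version, but never a tuple against a Version.
torch_is_2_9_plus = parsed >= version.parse("2.9")
```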
1 parent 275671b commit 1813b05

File tree: 23 files changed (+1010, −376 lines)

.github/workflows/tests.yml

Lines changed: 1 addition & 18 deletions
@@ -162,7 +162,7 @@ jobs:
       - name: Run tests
         run: pytest --durations=100

-  test-cpu-ipex:
+  test-cpu-intel:
     if: github.repository == 'bitsandbytes-foundation/bitsandbytes'
     needs: build-cpu
     runs-on: banb-aws-general-8-plus-use1-public-80
@@ -186,7 +186,6 @@ jobs:
       - name: Install dependencies
         run: |
           pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cpu
-          pip install intel_extension_for_pytorch==2.7.0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
           pip install -e ".[test]"
           pip install pytest-cov

@@ -196,9 +195,6 @@ jobs:
       - name: Show environment information
         run: python -m torch.utils.collect_env

-      - name: IPEX smoke test
-        run: python -c "import torch; import intel_extension_for_pytorch as ipex; print(torch.__version__); print(ipex.__version__);"
-
       - name: Run tests
         run: pytest --durations=100

@@ -286,15 +282,6 @@ jobs:
       fail-fast: false
       matrix:
         torch_version: ["2.7.1"] #["2.6.0", "2.7.1"]
-        ipex: [false]
-        # ipex: [true, false]
-        # include:
-        #   - torch_version: "2.6.0"
-        #     ipex: true
-        #     ipex_version: "2.6.10+xpu"
-        #   - torch_version: "2.7.1"
-        #     ipex: true
-        #     ipex_version: "2.7.10+xpu"
     runs-on:
       group: bandb-itac-bmsprpvc1550-8-1gpu
     env:
@@ -330,10 +317,6 @@ jobs:
       - name: Install PyTorch
         run: pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/xpu

-      - name: Install IPEX
-        if: matrix.ipex == true
-        run: pip install intel_extension_for_pytorch==${{ matrix.ipex_version }} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-
       - name: Install dependencies
         run: |
           pip install -e ".[test]"

CMakeLists.txt

Lines changed: 29 additions & 2 deletions
@@ -28,11 +28,12 @@ set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
 set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
 set(MPS_FILES csrc/mps_ops.mm)
 set(METAL_FILES csrc/mps_kernels.metal)
+set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp)
 # C++ sources are always included
 list(APPEND SRC_FILES ${CPP_FILES})

-set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)")
-set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps)
+set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, xpu)")
+set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps xpu)
 option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)

 if(APPLE)
@@ -64,10 +65,18 @@ elseif(${COMPUTE_BACKEND} STREQUAL "mps")
     set(BUILD_CUDA OFF)
     set(BUILD_HIP OFF)
     set(BUILD_MPS ON)
+elseif(${COMPUTE_BACKEND} STREQUAL "xpu")
+    if(APPLE)
+        message(FATAL_ERROR "XPU is not supported on macOS" )
+    endif()
+    set(BUILD_CUDA OFF)
+    set(BUILD_MPS OFF)
+    set(BUILD_XPU ON)
 else()
     set(BUILD_CUDA OFF)
     set(BUILD_HIP OFF)
     set(BUILD_MPS OFF)
+    set(BUILD_XPU OFF)
 endif()


@@ -217,6 +226,15 @@ elseif(BUILD_MPS)
         COMMENT "Compiling Metal kernels"
         VERBATIM)
     add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib")
+elseif(BUILD_XPU)
+    list(APPEND SRC_FILES ${XPU_FILES})
+    string(APPEND BNB_OUTPUT_NAME "_xpu")
+    add_compile_definitions(BUILD_XPU)
+    set(CMAKE_C_COMPILER icx)
+    set(CMAKE_CXX_COMPILER icpx)
+    if(WIN32)
+        set(CMAKE_CXX_COMPILER icx)
+    endif()
 else()
     string(APPEND BNB_OUTPUT_NAME "_cpu")
     set(GPU_SOURCES)
@@ -285,6 +303,15 @@ if(BUILD_MPS)
     add_dependencies(bitsandbytes metallib)
     target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
 endif()
+if(BUILD_XPU)
+    set(SYCL_LINK_FLAGS "-fsycl;--offload-compress;-fsycl-targets=spir64_gen,spir64;-Xs;-device pvc,xe-lpg,ats-m150 -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required'")
+    set(SYCL_COMPILE_FLAGS "-fsycl;-fhonor-nans;-fhonor-infinities;-fno-associative-math;-fno-approx-func;-fno-sycl-instrument-device-code;--offload-compress;-fsycl-targets=spir64_gen,spir64;")
+
+    set_property(TARGET bitsandbytes PROPERTY CXX_STANDARD 20)
+    target_compile_options(bitsandbytes PRIVATE ${SYCL_COMPILE_FLAGS})
+    target_link_options(bitsandbytes PRIVATE ${SYCL_LINK_FLAGS})
+
+endif()

 if(WIN32)
     set_target_properties(bitsandbytes PROPERTIES PREFIX "lib")
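Usage note, inferred from the diff above rather than stated in the commit: the XPU backend is selected like the other backends at configure time, for example `cmake -DCOMPUTE_BACKEND=xpu -S . -B build` followed by a normal build. Because the diff switches the toolchain to Intel's oneAPI DPC++/C++ compilers, `icx`/`icpx` must be available on the PATH for this configuration to succeed.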

_typos.toml

Lines changed: 7 additions & 0 deletions
@@ -1,4 +1,11 @@
 [files]
+# Skip these files in typo checks
+extend-exclude = [
+    "csrc/xpu_ops.h",
+    "csrc/xpu_ops.cpp",
+    "csrc/xpu_kernels.h",
+    "csrc/xpu_kernels.cpp"
+]

 [default]
 extend-ignore-re = [

bitsandbytes/_ops.py

Lines changed: 0 additions & 21 deletions
@@ -4,8 +4,6 @@

 import torch

-from .cextension import ipex_cpu, ipex_xpu
-
 _IS_TORCH_GTE_24 = False

 if hasattr(torch.library, "register_fake"):
@@ -331,25 +329,6 @@ def _(
     torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")


-if ipex_cpu or ipex_xpu:
-    # Register the dequantize_nf4_ipex implementation
-    torch.library.define(
-        "bitsandbytes::dequantize_nf4_ipex",
-        "(Tensor A, Tensor absmax, int blocksize, int[] shape, ScalarType dtype) -> Tensor",
-    )
-
-    @register_fake("bitsandbytes::dequantize_nf4_ipex")
-    def _(
-        A: torch.Tensor,
-        absmax: torch.Tensor,
-        blocksize: int,
-        shape: Sequence[int],
-        dtype: torch.dtype,
-    ) -> torch.Tensor:
-        torch._check_is_size(blocksize)
-        return torch.empty(shape, dtype=dtype, device=A.device)
-
-
 torch.library.define(
     "bitsandbytes::optimizer_update_32bit",
     "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros=False) -> ()",

bitsandbytes/autograd/_functions.py

Lines changed: 1 addition & 15 deletions
@@ -8,7 +8,6 @@
 from typing_extensions import deprecated

 import bitsandbytes.functional as F
-from bitsandbytes.functional import ipex_cpu

 # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
 # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
@@ -320,8 +319,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):

         CB = state.CB.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
         output = torch.nn.functional.linear(A, CB, bias)
-        # to pass the test: tests/test_modules.py::test_linear8bitlt_no_fp16_weights[2.0-xpu]
-        state.idx = False
         ctx.state = state
         ctx.dtype_A = A.dtype
         ctx.grad_shape = A.shape
@@ -426,7 +423,7 @@ def matmul(
         state.threshold = threshold
     # MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU
     if state.is_training:
-        if (A.device.type == "cpu" and ipex_cpu) or (A.device.type == "xpu"):
+        if A.device.type in ("cpu", "xpu"):
             return MatMul8bitFp.apply(A, B, out, bias, state)
     return MatMul8bitLt.apply(A, B, out, bias, state)

@@ -440,17 +437,6 @@ def matmul_4bit(
 ):
     assert quant_state is not None

-    if A.device.type in ("cpu", "xpu") and A.requires_grad == False:
-        if getattr(quant_state, "ipex", False):
-            # IPEX CPU will change weight to 4D so don't need transpose
-            B = B.t() if B.dim() == 2 else B
-        out = F.gemv_4bit(A, B, out, state=quant_state)
-        if bias is not None:
-            out += bias
-        return out
-    else:
-        return MatMul4Bit.apply(A, B, out, bias, quant_state)
-
     if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
         if A.shape[-1] % quant_state.blocksize != 0:
             warn(
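With the IPEX-specific branch removed, the 4-bit path on CPU and XPU goes through the same `gemv_4bit`/`MatMul4Bit` dispatch as other devices. A hedged usage sketch of exercising that path on an XPU device follows; it uses the public bitsandbytes API, and the `"xpu"` device string assumes a PyTorch build with XPU support.

```python
import torch
import bitsandbytes as bnb

# 4-bit NF4 linear layer; moving it to the device quantizes the weight.
layer = bnb.nn.Linear4bit(4096, 4096, bias=False, quant_type="nf4", compute_dtype=torch.bfloat16)
layer = layer.to("xpu")

x = torch.randn(1, 4096, dtype=torch.bfloat16, device="xpu")
with torch.no_grad():
    # Single-token inference typically takes the gemv_4bit fast path in matmul_4bit.
    y = layer(x)
print(y.shape)
```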

bitsandbytes/backends/cpu/ops.py

Lines changed: 76 additions & 95 deletions
@@ -1,13 +1,14 @@
-from collections.abc import Sequence
 import ctypes as ct
+import logging

 import torch

 from bitsandbytes.functional import get_ptr

 from ..._ops import register_kernel
-from ...cextension import lib
-from ..utils import ipex_cpu
+from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib
+
+logger = logging.getLogger(__name__)

 # torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+.
 # However, we can overflow if we use this without AVX512_VNNI support.
@@ -24,97 +25,77 @@ def _(A: torch.Tensor, B: torch.Tensor):
     ).reshape(*A.shape[:-1], B.shape[0])


-@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
-def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
-    torch._check_is_size(blocksize)
-
-    n = A.numel()
-
-    # Only FP32 has c++ kernrl
-    if A.dtype == torch.float32:
-        blocks = -(n // -blocksize)
-
-        absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
-        out = torch.empty_like(A, dtype=torch.uint8)
-
-        lib.cquantize_blockwise_cpu_fp32(
-            get_ptr(code),
-            get_ptr(A),
-            get_ptr(absmax),
-            get_ptr(out),
-            ct.c_longlong(blocksize),
-            ct.c_longlong(n),
-        )
-    else:
-        rem = n % blocksize
-        has_rem = rem > 0
-        blocks = n // blocksize + has_rem
-        absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
-        A_reshaped = A.reshape(n)
-        A_com = A_reshaped[: n - rem]
-        A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
-        absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
-        scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
-        scaled_A = scaled_A.reshape(-1)
-        if has_rem:
-            absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
-            scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
-            scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
-
-        diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
-        out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)
-
-    return out, absmax
-
-
-@register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
-def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
-    torch._check_is_size(blocksize)
-    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
-
-    # Only FP32 has c++ kernrl
-    if dtype == torch.float32:
-        out = torch.empty_like(A, dtype=dtype)
-
-        lib.cdequantize_blockwise_cpu_fp32(
-            get_ptr(code),
-            get_ptr(A),
-            get_ptr(absmax),
-            get_ptr(out),
-            ct.c_longlong(blocksize),
-            ct.c_longlong(A.numel()),
-        )
-    else:
-        out = code[A.reshape(-1).int()]
-        blocks = out.shape[-1] // blocksize
-        res = out.shape[-1] % blocksize
-        if res != 0:
-            out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0)
-        out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
-        out = out[: blocks * blocksize + res]
-        out = out.reshape(A.shape)
-
-    return out
-
-
-if ipex_cpu:
-    from bitsandbytes.utils import _reverse_4bit_compress_format
-
-    @register_kernel("bitsandbytes::dequantize_nf4_ipex", "cpu")
+if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):
+
+    @register_kernel("bitsandbytes::quantize_blockwise", "cpu")
+    def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
+        torch._check_is_size(blocksize)
+
+        n = A.numel()
+
+        # Only FP32 has c++ kernrl
+        if A.dtype == torch.float32:
+            blocks = -(n // -blocksize)
+
+            absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+            out = torch.empty_like(A, dtype=torch.uint8)
+
+            lib.cquantize_blockwise_cpu_fp32(
+                get_ptr(code),
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_longlong(blocksize),
+                ct.c_longlong(n),
+            )
+        else:
+            rem = n % blocksize
+            has_rem = rem > 0
+            blocks = n // blocksize + has_rem
+            absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
+            A_reshaped = A.reshape(n)
+            A_com = A_reshaped[: n - rem]
+            A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
+            absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
+            scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
+            scaled_A = scaled_A.reshape(-1)
+            if has_rem:
+                absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
+                scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
+                scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)

+            diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
+            out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)
+
+        return out, absmax
+
+    @register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
     def _(
-        A: torch.Tensor,
-        absmax: torch.Tensor,
-        blocksize: int,
-        shape: Sequence[int],
-        dtype: torch.dtype,
+        A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
     ) -> torch.Tensor:
-        ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", shape, 2)
-        A = _reverse_4bit_compress_format(ipex_weight.reshape(-1)).reshape(1, -1)
-        return torch.ops.bitsandbytes.dequantize_4bit.default(
-            A,
-            absmax,
-            blocksize,
-            "nf4",
-            shape,
-            dtype,
-        )
+        torch._check_is_size(blocksize)
+        torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
+
+        # Only FP32 has c++ kernrl
+        if dtype == torch.float32:
+            out = torch.empty_like(A, dtype=dtype)
+
+            lib.cdequantize_blockwise_cpu_fp32(
+                get_ptr(code),
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_longlong(blocksize),
+                ct.c_longlong(A.numel()),
+            )
+        else:
+            out = code[A.reshape(-1).int()]
+            blocks = out.shape[-1] // blocksize
+            res = out.shape[-1] % blocksize
+            if res != 0:
+                out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0)
+            out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
+            out = out[: blocks * blocksize + res]
+            out = out.reshape(A.shape)
+
+        return out
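The kernels above back the blockwise quantization ops on CPU; when the native library is unavailable (the `ErrorHandlerMockBNBNativeLibrary` case) they are simply not registered. A small round-trip sketch through the public functional API, assuming default settings; the reconstruction error depends on blocksize and input distribution.

```python
import torch
import bitsandbytes.functional as F

A = torch.randn(4096, dtype=torch.float32)  # fp32 input takes the C++ kernel path
quantized, quant_state = F.quantize_blockwise(A, blocksize=256)
restored = F.dequantize_blockwise(quantized, quant_state)

print("mean abs error:", (A - restored).abs().mean().item())
```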
