diff --git a/bitsandbytes/backends/xpu/ops.py b/bitsandbytes/backends/xpu/ops.py index 1c1422c35..6e877cff8 100644 --- a/bitsandbytes/backends/xpu/ops.py +++ b/bitsandbytes/backends/xpu/ops.py @@ -139,6 +139,13 @@ def _gemv_4bit_impl( if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary): logger.info("Register sycl bitsandbytes kernels for XPU") + # TODO: Remove the triton register when quantization sycl kernel is ready. + if triton_available: + from ..triton import ops as triton_ops + + register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise) + register_kernel("bitsandbytes::quantize_4bit", "xpu")(triton_ops.quantize_4bit) + @register_kernel("bitsandbytes::dequantize_4bit", "xpu") def _( A: torch.Tensor,