Commit bdcee0f

fix triton kernel on the correct device (#1691)
Signed-off-by: jiqing-feng <[email protected]>
1 parent 6d0a5cd commit bdcee0f
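
The change wraps each Triton kernel launch in bitsandbytes/backends/triton/ops.py in a device context derived from the input tensor, so the kernel is compiled and launched on the device that actually holds the data rather than on the process-wide default accelerator. A minimal sketch of the pattern the commit introduces; the wrapper function and the copy used as a kernel stand-in are illustrative additions, not part of the diff:

import torch

# From the diff: resolve the active accelerator module (torch.cuda, torch.xpu, ...),
# falling back to CUDA on PyTorch builds that predate torch.accelerator.
device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
torch_accelerator_module = getattr(torch, device_type, torch.cuda)

def launch_on_tensor_device(A: torch.Tensor) -> torch.Tensor:
    # Illustrative wrapper (not in the commit): entering A's device context before
    # the launch makes the kernel target the device that owns A, even when that is
    # not the current default device.
    out = torch.empty_like(A)
    with torch_accelerator_module.device(A.device):
        out.copy_(A)  # stand-in for a triton_kernels.* launch
    return out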

File tree

  • bitsandbytes/backends/triton/ops.py

1 file changed: +39 additions, −28 deletions

bitsandbytes/backends/triton/ops.py

Lines changed: 39 additions & 28 deletions
@@ -9,6 +9,8 @@
 # from bitsandbytes.functional import get_4bit_type
 # _FP4_QUANT_TABLE = get_4bit_type("fp4", device="xpu")
 # _NF4_QUANT_TABLE = get_4bit_type("nf4", device="xpu")
+device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
+torch_accelerator_module = getattr(torch, device_type, torch.cuda)
 
 
 def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
@@ -21,7 +23,9 @@ def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> t
     absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
     out = torch.empty_like(A.flatten(), dtype=torch.uint8)
 
-    triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
+    with torch_accelerator_module.device(A.device):
+        triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
+
     out = out.reshape(A.shape)
 
     return out, absmax.float()
@@ -35,13 +39,14 @@ def dequantize_blockwise(
     # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")
 
     out = torch.empty_like(A, dtype=dtype, device=A.device)
-    triton_kernels.dequant_int8_blockwise(
-        A,
-        code,
-        absmax,
-        out,
-        blocksize,
-    )
+    with torch_accelerator_module.device(A.device):
+        triton_kernels.dequant_int8_blockwise(
+            A,
+            code,
+            absmax,
+            out,
+            blocksize,
+        )
 
     return out
 
@@ -55,13 +60,14 @@ def dequantize_blockwise_inplace(
     torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
     torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
 
-    triton_kernels.dequant_int8_blockwise(
-        A,
-        code,
-        absmax,
-        out,
-        blocksize,
-    )
+    with torch_accelerator_module.device(A.device):
+        triton_kernels.dequant_int8_blockwise(
+            A,
+            code,
+            absmax,
+            out,
+            blocksize,
+        )
 
 
 def quantize_4bit(
@@ -84,9 +90,10 @@ def quantize_4bit(
     absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype)
     out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)
 
-    triton_kernels.quantize_4bit_blockwise_triton(
-        A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
-    )
+    with torch_accelerator_module.device(A.device):
+        triton_kernels.quantize_4bit_blockwise_triton(
+            A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
+        )
     packed = out
 
     if quant_storage != torch.uint8:
@@ -119,7 +126,9 @@ def dequantize_4bit(
 
     out = torch.empty(shape, dtype=dtype, device=A.device)
 
-    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
+    with torch_accelerator_module.device(A.device):
+        triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
+
     return out
 
 
@@ -134,7 +143,8 @@ def dequantize_4bit_inplace(
 ) -> None:
     torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
     torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
-    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
+    with torch_accelerator_module.device(A.device):
+        triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
 
 
 def gemv_4bit(
@@ -150,14 +160,15 @@ def gemv_4bit(
 
     B_dq_triton = torch.empty(shapeB, dtype=A.dtype, device=A.device)
 
-    triton_kernels._dequantize_4bit_impl_passing_code(
-        B,
-        absmax,
-        blocksize,
-        code,
-        dtype=A.dtype,
-        out=B_dq_triton,
-    )
+    with torch_accelerator_module.device(A.device):
+        triton_kernels._dequantize_4bit_impl_passing_code(
+            B,
+            absmax,
+            blocksize,
+            code,
+            dtype=A.dtype,
+            out=B_dq_triton,
+        )
 
     return torch.nn.functional.linear(
         A,
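
Why the device context matters: Triton launches kernels against the backend's currently active device, so before this change a tensor resident on, say, the second GPU could have its kernel launched against device 0. A small repro-style sketch, assuming a machine with at least two CUDA devices (not taken from the commit):

import torch

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    A = torch.randn(4096, device="cuda:1")
    # The default device is still cuda:0 here; an unwrapped launch would target it.
    with torch.cuda.device(A.device):
        # Inside the context the active device matches the tensor's device,
        # so any Triton kernel launched here runs on cuda:1.
        assert torch.cuda.current_device() == 1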
