99# from bitsandbytes.functional import get_4bit_type
1010# _FP4_QUANT_TABLE = get_4bit_type("fp4", device="xpu")
1111# _NF4_QUANT_TABLE = get_4bit_type("nf4", device="xpu")
# Resolve the active accelerator backend once, at import time.
#
# ``torch.accelerator`` only exists on newer PyTorch releases, and
# ``current_accelerator()`` returns ``None`` when no accelerator is available
# (e.g. a CPU-only build). Dereferencing ``.type`` on that ``None`` would raise
# AttributeError during import, so both cases fall back to "cuda".
_current_accelerator = torch.accelerator.current_accelerator() if hasattr(torch, "accelerator") else None
device_type = _current_accelerator.type if _current_accelerator is not None else "cuda"

# Backend module (``torch.cuda``, ``torch.xpu``, ...) whose ``device(...)``
# context manager is used below to select the device of the input tensor before
# each Triton kernel launch. Unknown backend names fall back to ``torch.cuda``.
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
1214
1315
1416def quantize_blockwise (A : torch .Tensor , code : torch .Tensor , blocksize : int ) -> tuple [torch .Tensor , torch .Tensor ]:
@@ -21,7 +23,9 @@ def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> t
2123 absmax = torch .empty ((blocks ,), device = A .device , dtype = A .dtype )
2224 out = torch .empty_like (A .flatten (), dtype = torch .uint8 )
2325
24- triton_kernels .quantize_blockwise_triton (A , blocksize , code , blocks , absmax , out )
26+ with torch_accelerator_module .device (A .device ):
27+ triton_kernels .quantize_blockwise_triton (A , blocksize , code , blocks , absmax , out )
28+
2529 out = out .reshape (A .shape )
2630
2731 return out , absmax .float ()
@@ -35,13 +39,14 @@ def dequantize_blockwise(
3539 # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")
3640
3741 out = torch .empty_like (A , dtype = dtype , device = A .device )
38- triton_kernels .dequant_int8_blockwise (
39- A ,
40- code ,
41- absmax ,
42- out ,
43- blocksize ,
44- )
42+ with torch_accelerator_module .device (A .device ):
43+ triton_kernels .dequant_int8_blockwise (
44+ A ,
45+ code ,
46+ absmax ,
47+ out ,
48+ blocksize ,
49+ )
4550
4651 return out
4752
@@ -55,13 +60,14 @@ def dequantize_blockwise_inplace(
5560 torch ._check (out .device == A .device , lambda : f"Expected out.device == { A .device } , got { out .device } " )
5661 torch ._check (out .dtype == dtype , lambda : f"Expected out.dtype == { dtype } , got { out .dtype } " )
5762
58- triton_kernels .dequant_int8_blockwise (
59- A ,
60- code ,
61- absmax ,
62- out ,
63- blocksize ,
64- )
63+ with torch_accelerator_module .device (A .device ):
64+ triton_kernels .dequant_int8_blockwise (
65+ A ,
66+ code ,
67+ absmax ,
68+ out ,
69+ blocksize ,
70+ )
6571
6672
6773def quantize_4bit (
@@ -84,9 +90,10 @@ def quantize_4bit(
8490 absmax = torch .empty ((blocks * 2 ,), device = A .device , dtype = A .dtype )
8591 out = torch .empty ((n // 2 , 1 ), device = A .device , dtype = torch .uint8 )
8692
87- triton_kernels .quantize_4bit_blockwise_triton (
88- A , blocksize , quant_type , blocks , absmax , num_elements = n , quantized_out = out
89- )
93+ with torch_accelerator_module .device (A .device ):
94+ triton_kernels .quantize_4bit_blockwise_triton (
95+ A , blocksize , quant_type , blocks , absmax , num_elements = n , quantized_out = out
96+ )
9097 packed = out
9198
9299 if quant_storage != torch .uint8 :
@@ -119,7 +126,9 @@ def dequantize_4bit(
119126
120127 out = torch .empty (shape , dtype = dtype , device = A .device )
121128
122- triton_kernels ._dequantize_4bit_impl (A , absmax , blocksize , quant_type , dtype , out = out )
129+ with torch_accelerator_module .device (A .device ):
130+ triton_kernels ._dequantize_4bit_impl (A , absmax , blocksize , quant_type , dtype , out = out )
131+
123132 return out
124133
125134
@@ -134,7 +143,8 @@ def dequantize_4bit_inplace(
134143) -> None :
135144 torch ._check (out .shape == shape , lambda : f"Expected out.shape == { shape } , got { out .shape } " )
136145 torch ._check (out .dtype == dtype , lambda : f"Expected out.dtype == { dtype } , got { out .dtype } " )
137- triton_kernels ._dequantize_4bit_impl (A , absmax , blocksize , quant_type , dtype , out = out )
146+ with torch_accelerator_module .device (A .device ):
147+ triton_kernels ._dequantize_4bit_impl (A , absmax , blocksize , quant_type , dtype , out = out )
138148
139149
140150def gemv_4bit (
@@ -150,14 +160,15 @@ def gemv_4bit(
150160
151161 B_dq_triton = torch .empty (shapeB , dtype = A .dtype , device = A .device )
152162
153- triton_kernels ._dequantize_4bit_impl_passing_code (
154- B ,
155- absmax ,
156- blocksize ,
157- code ,
158- dtype = A .dtype ,
159- out = B_dq_triton ,
160- )
163+ with torch_accelerator_module .device (A .device ):
164+ triton_kernels ._dequantize_4bit_impl_passing_code (
165+ B ,
166+ absmax ,
167+ blocksize ,
168+ code ,
169+ dtype = A .dtype ,
170+ out = B_dq_triton ,
171+ )
161172
162173 return torch .nn .functional .linear (
163174 A ,
0 commit comments