44# This source code is licensed under the license found in the
55# LICENSE file in the root directory of this source tree.
66
7- from enum import Enum
7+ from enum import Enum , auto
88from typing import List , Optional , Tuple , Dict
99import torch
1010
1111from torchao .kernel .intmm import int_scaled_matmul
1212from torchao .kernel .intmm import safe_int_mm
13- from torchao .utils import TORCH_VERSION_AFTER_2_3
13+ from torchao .utils import (
14+ TORCH_VERSION_AFTER_2_3 ,
15+ TORCH_VERSION_AFTER_2_5 ,
16+ )
1417
1518
1619__all__ = [
@@ -34,17 +37,17 @@ class MappingType(Enum):
3437 based on this mapping
3538 e.g. scale = (10.2 - (-3.5)) / (7 - (-8))
3639 """
37- SYMMETRIC = 0
38- ASYMMETRIC = 1
40+ SYMMETRIC = auto ()
41+ ASYMMETRIC = auto ()
3942
class ZeroPointDomain(Enum):
    """Domain in which the ``zero_point`` quantization parameter lives.

    INT: zero_point is an integer added after scaling —
        quantized_val = (float_val / scale) (integer) + zero_point (integer)
    FLOAT: zero_point is a float offset applied before scaling —
        quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
    """
    INT = auto()
    FLOAT = auto()
4952"""
5053Map from dtype to the bound value of integers
@@ -69,6 +72,20 @@ class ZeroPointDomain(Enum):
6972 })
7073
7174
def register_custom_op(name: str):
    """Decorator factory that registers a function as a torch custom op.

    On torch versions after 2.5 the decorated function is registered under
    ``name`` via ``torch.library.custom_op`` (with no mutated args), its own
    implementation is reused as the fake (meta) kernel, and it is added as an
    inductor decomposition. On older torch versions the decorator is a no-op
    and the plain function is returned, so call sites stay version-agnostic.
    """
    from torch._inductor.decomposition import register_decomposition

    def decorator(fn):
        # Guard clause: no custom-op library support before 2.5.
        if not TORCH_VERSION_AFTER_2_5:
            return fn
        opdef = torch.library.custom_op(name, mutates_args=())(fn)
        # The eager implementation doubles as the fake (meta) implementation
        # and as the inductor decomposition.
        opdef.register_fake(fn)
        register_decomposition([opdef._opoverload])(fn)
        return opdef

    return decorator
7289# TODO: decide on if we want to allow custom quant_min/quant_max here
7390def _get_and_check_qmin_qmax (dtype , quant_min , quant_max ):
7491 """Get quant_min and quant_max args based on dtype and also
@@ -140,7 +157,7 @@ def quantize_affine(
140157 quant_min : Optional [int ] = None ,
141158 quant_max : Optional [int ] = None ,
142159 zero_point_domain : ZeroPointDomain = ZeroPointDomain .INT ,
143- ):
160+ ) -> torch . Tensor :
144161 """
145162 Args:
146163 input (torch.Tensor): original float32, float16 or bfloat16 Tensor
@@ -174,6 +191,31 @@ def quantize_affine(
174191 Output:
175192 quantized tensor with requested dtype
176193 """
194+ return _quantize_affine (
195+ input ,
196+ block_size ,
197+ scale ,
198+ zero_point ,
199+ output_dtype ,
200+ quant_min ,
201+ quant_max ,
202+ zero_point_domain .name ,
203+ )
204+
205+
206+ @register_custom_op ("quant::quantize_affine" )
207+ def _quantize_affine (
208+ input : torch .Tensor ,
209+ block_size : List [int ],
210+ scale : torch .Tensor ,
211+ zero_point : Optional [torch .Tensor ],
212+ output_dtype : torch .dtype ,
213+ quant_min : Optional [int ] = None ,
214+ quant_max : Optional [int ] = None ,
215+ zero_point_domain : str = "INT" ,
216+ ) -> torch .Tensor :
217+ """op definition that has compatible signatures with custom op library
218+ """
177219 # TODO: validations
178220 # TODO: validate scale/zero_point dimensions are compatible with block_size
179221 assert input .dtype in [torch .float32 , torch .float16 , torch .bfloat16 ], f"Unsupported input dtype: { input .dtype } "
@@ -188,12 +230,12 @@ def quantize_affine(
188230 if zero_point is not None :
189231 zero_point = zero_point .view (shape_after_reduction )
190232
191- if zero_point_domain == ZeroPointDomain .INT :
233+ if zero_point_domain == ZeroPointDomain .INT . name :
192234 quant = torch .clamp (
193235 torch .round (input * (1.0 / scale )) + zero_point , quant_min , quant_max
194236 ).to (output_dtype )
195237 else :
196- assert zero_point_domain == ZeroPointDomain .FLOAT
238+ assert zero_point_domain == ZeroPointDomain .FLOAT . name
197239 mid_point = (quant_max + quant_min + 1 ) / 2
198240 min_val = zero_point - scale * mid_point
199241 quant = (
@@ -216,7 +258,7 @@ def dequantize_affine(
216258 zero_point_domain : ZeroPointDomain = ZeroPointDomain .INT ,
217259 * ,
218260 output_dtype : torch .dtype = torch .float32 ,
219- ):
261+ ) -> torch . Tensor :
220262 """
221263 Args:
222264 input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument
@@ -238,6 +280,34 @@ def dequantize_affine(
238280 Output:
239281 dequantized Tensor, with requested dtype or fp32
240282 """
283+ return _dequantize_affine (
284+ input ,
285+ block_size ,
286+ scale ,
287+ zero_point ,
288+ input_dtype ,
289+ quant_min ,
290+ quant_max ,
291+ zero_point_domain .name ,
292+ output_dtype = output_dtype ,
293+ )
294+
295+
296+ @register_custom_op ("quant::dequantize_affine" )
297+ def _dequantize_affine (
298+ input : torch .Tensor ,
299+ block_size : List [int ],
300+ scale : torch .Tensor ,
301+ zero_point : Optional [torch .Tensor ],
302+ input_dtype : torch .dtype ,
303+ quant_min : Optional [int ] = None ,
304+ quant_max : Optional [int ] = None ,
305+ zero_point_domain : str = "INT" ,
306+ * ,
307+ output_dtype : torch .dtype = torch .float32 ,
308+ ) -> torch .Tensor :
309+ """op definition that has compatible signatures with custom op library
310+ """
241311
242312 # TODO: validations
243313 # TODO: validate scale/zero_point dimensions are compatible with block_size
@@ -255,16 +325,16 @@ def dequantize_affine(
255325 if zero_point is not None :
256326 zero_point = zero_point .view (shape_after_reduction )
257327
258- if zero_point_domain == ZeroPointDomain .INT :
328+ if zero_point_domain == ZeroPointDomain .INT . name :
259329 # Force a copy to avoid input modification due
260330 # to upcoming in-place operations.
261331 dequant = input .to (torch .int32 , copy = True )
262332 if zero_point is not None :
263- dequant -= zero_point .to (torch .int32 )
333+ dequant = dequant - zero_point .to (torch .int32 )
264334 dequant = dequant .to (output_dtype )
265- dequant *= scale
335+ dequant = dequant * scale
266336 else :
267- assert zero_point_domain == ZeroPointDomain .FLOAT , f"Unexpected zero point domain: { zero_point_domain } "
337+ assert zero_point_domain == ZeroPointDomain .FLOAT . name , f"Unexpected zero point domain: { zero_point_domain } "
268338 mid_point = (quant_max + quant_min + 1 ) / 2
269339 # This should allocate new memory and avoid input modification
270340 dequant = input - mid_point
@@ -320,8 +390,38 @@ def choose_qparams_affine(
320390 Output:
321391 Tuple of scales and zero_points Tensor with requested dtype
322392 """
393+ return _choose_qparams_affine (
394+ input ,
395+ mapping_type .name ,
396+ block_size ,
397+ target_dtype ,
398+ quant_min ,
399+ quant_max ,
400+ eps ,
401+ scale_dtype ,
402+ zero_point_dtype ,
403+ preserve_zero ,
404+ zero_point_domain .name
405+ )
406+
407+ @register_custom_op ("quant::choose_qparams_affine" )
408+ def _choose_qparams_affine (
409+ input : torch .Tensor ,
410+ mapping_type : str ,
411+ block_size : List [int ],
412+ target_dtype : torch .dtype ,
413+ quant_min : Optional [int ] = None ,
414+ quant_max : Optional [int ] = None ,
415+ eps : Optional [float ] = None ,
416+ scale_dtype : Optional [torch .dtype ] = None ,
417+ zero_point_dtype : Optional [torch .dtype ] = None ,
418+ preserve_zero : bool = True ,
419+ zero_point_domain : str = "INT" ,
420+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
421+ """op definition that has compatible signatures with custom op library
422+ """
323423 quant_min , quant_max = _get_and_check_qmin_qmax (target_dtype , quant_min , quant_max )
324- assert mapping_type in [MappingType .SYMMETRIC , MappingType .ASYMMETRIC ], f"Unsupported mapping type: { mapping_type } "
424+ assert mapping_type in [MappingType .SYMMETRIC . name , MappingType .ASYMMETRIC . name ], f"Unsupported mapping type: { mapping_type } "
325425
326426 if scale_dtype is None :
327427 scale_dtype = input .dtype
@@ -342,21 +442,22 @@ def choose_qparams_affine(
342442 min_val_neg = min_val
343443 max_val_pos = max_val
344444
345- if mapping_type == MappingType .SYMMETRIC :
445+ if mapping_type == MappingType .SYMMETRIC . name :
346446 max_val_pos = torch .max (- min_val_neg , max_val_pos )
347447 scale = max_val_pos / (float (quant_max - quant_min ) / 2 )
348448 if not preserve_zero :
349449 raise ValueError ("preserve_zero == False is not supported for symmetric quantization" )
350- if zero_point_domain != ZeroPointDomain .INT :
450+ if zero_point_domain != ZeroPointDomain .INT . name :
351451 raise ValueError ("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization" )
352452 zero_point = torch .full_like (scale , int ((quant_max + quant_min + 1 ) / 2 ))
353453 else :
454+ assert mapping_type == MappingType .ASYMMETRIC .name
354455 scale = (max_val_pos - min_val_neg ) / float (quant_max - quant_min )
355456 if preserve_zero :
356457 zero_point = quant_min - torch .round (min_val_neg / scale )
357458 zero_point = torch .clamp (zero_point , quant_min , quant_max )
358459 else :
359- assert zero_point_domain == ZeroPointDomain .FLOAT , "if not preserve_zero, zero_point must be in FLOAT domain"
460+ assert zero_point_domain == ZeroPointDomain .FLOAT . name , "if not preserve_zero, zero_point must be in FLOAT domain"
360461 mid_point = (quant_max + quant_min + 1 ) / 2
361462 zero_point = min_val_neg + scale * mid_point
362463
0 commit comments