44# This source code is licensed under the license found in the
55# LICENSE file in the root directory of this source tree.
66
7- from enum import Enum
7+ from enum import Enum , auto
88from typing import List , Optional , Tuple , Dict
99import torch
1010
1111from torchao .kernel .intmm import int_scaled_matmul
1212from torchao .kernel .intmm import safe_int_mm
13- from torchao .utils import TORCH_VERSION_AFTER_2_3
13+ from torchao .utils import (
14+ TORCH_VERSION_AFTER_2_3 ,
15+ TORCH_VERSION_AFTER_2_5 ,
16+ )
1417
1518
1619__all__ = [
@@ -34,17 +37,17 @@ class MappingType(Enum):
3437 based on this mapping
3538 e.g. scale = (10.2 - (-3.5)) / (7 - (-8))
3639 """
37- SYMMETRIC = 0
38- ASYMMETRIC = 1
40+ SYMMETRIC = auto ()
41+ ASYMMETRIC = auto ()
3942
4043class ZeroPointDomain (Enum ):
4144 """Enum that indicate whether zero_point is in integer domain or floating point domain
4245
4346 integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
4447 float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
4548 """
46- INT = 0
47- FLOAT = 1
49+ INT = auto ()
50+ FLOAT = auto ()
4851
4952"""
5053Map from dtype to the bound value of integers
@@ -69,6 +72,54 @@ class ZeroPointDomain(Enum):
6972 })
7073
7174
75+ quant_lib = torch .library .Library ("quant" , "FRAGMENT" )
76+
77+ def register_custom_op (lib ):
78+ """This decorator is used to preserve some high level operators for torch.export.export
79+ while still allow them to be decomposed for inductor path
80+
81+ requirement: make sure `fn.__name__[1:]` is the operator name you want to register
82+
83+ NOTE: This should be applied at the top, after all other decorators have been applied
84+ NOTE: We haven't tested the case when `fn` accepts tensor subclass instance as input,
85+ e.g. uint4 tensor subclass instance, and we'll probably need to figure out what would make
86+ sense for downstream system (like executorch) to accept as well
87+
88+ Example:
89+ lib = torch.library.Library("my_namespace', "FRAGMENT")
90+ @register_custom_op(lib)
91+ def _the_op_that_needs_to_be_preserved(...)
92+ ...
93+
94+ # after this, `_the_op_that_needs_to_be_preserved` will be preserved as
95+ # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
96+ # torch.export.export / torch._export.capture_pre_autograd_graph
97+
98+ """
99+ from torch ._inductor .decomposition import register_decomposition
100+
101+ def decorator (fn ):
102+ if TORCH_VERSION_AFTER_2_5 :
103+ from torch ._library .infer_schema import infer_schema
104+
105+ # expecting fn.__name__ starts with `_` and we want to take the rest
106+ # to be the name of the custom op
107+ assert fn .__name__ [0 ] == "_" , f"Expecting function name starts with `_`, got { fn .__name__ } "
108+ op_name = fn .__name__ [1 :]
109+ schema = op_name + infer_schema (fn )
110+ lib .define (schema )
111+ lib .impl (op_name , fn , "CompositeImplicitAutograd" )
112+
113+ lib_namespace = lib .ns
114+ op = getattr (getattr (torch .ops , lib_namespace ), op_name )
115+ register_decomposition ([op ])(fn )
116+ return op
117+ else :
118+ return fn
119+
120+ return decorator
121+
122+
72123# TODO: decide on if we want to allow custom quant_min/quant_max here
73124def _get_and_check_qmin_qmax (dtype , quant_min , quant_max ):
74125 """Get quant_min and quant_max args based on dtype and also
@@ -140,7 +191,7 @@ def quantize_affine(
140191 quant_min : Optional [int ] = None ,
141192 quant_max : Optional [int ] = None ,
142193 zero_point_domain : ZeroPointDomain = ZeroPointDomain .INT ,
143- ):
194+ ) -> torch . Tensor :
144195 """
145196 Args:
146197 input (torch.Tensor): original float32, float16 or bfloat16 Tensor
@@ -174,6 +225,31 @@ def quantize_affine(
174225 Output:
175226 quantized tensor with requested dtype
176227 """
228+ return _quantize_affine (
229+ input ,
230+ block_size ,
231+ scale ,
232+ zero_point ,
233+ output_dtype ,
234+ quant_min ,
235+ quant_max ,
236+ zero_point_domain .name ,
237+ )
238+
239+
240+ @register_custom_op (quant_lib )
241+ def _quantize_affine (
242+ input : torch .Tensor ,
243+ block_size : List [int ],
244+ scale : torch .Tensor ,
245+ zero_point : Optional [torch .Tensor ],
246+ output_dtype : torch .dtype ,
247+ quant_min : Optional [int ] = None ,
248+ quant_max : Optional [int ] = None ,
249+ zero_point_domain : str = "INT" ,
250+ ) -> torch .Tensor :
251+ """op definition that has compatible signatures with custom op library
252+ """
177253 # TODO: validations
178254 # TODO: validate scale/zero_point dimensions are compatible with block_size
179255 assert input .dtype in [torch .float32 , torch .float16 , torch .bfloat16 ], f"Unsupported input dtype: { input .dtype } "
@@ -188,12 +264,12 @@ def quantize_affine(
188264 if zero_point is not None :
189265 zero_point = zero_point .view (shape_after_reduction )
190266
191- if zero_point_domain == ZeroPointDomain .INT :
267+ if zero_point_domain == ZeroPointDomain .INT . name :
192268 quant = torch .clamp (
193269 torch .round (input * (1.0 / scale )) + zero_point , quant_min , quant_max
194270 ).to (output_dtype )
195271 else :
196- assert zero_point_domain == ZeroPointDomain .FLOAT
272+ assert zero_point_domain == ZeroPointDomain .FLOAT . name
197273 mid_point = (quant_max + quant_min + 1 ) / 2
198274 min_val = zero_point - scale * mid_point
199275 quant = (
@@ -216,7 +292,7 @@ def dequantize_affine(
216292 zero_point_domain : ZeroPointDomain = ZeroPointDomain .INT ,
217293 * ,
218294 output_dtype : torch .dtype = torch .float32 ,
219- ):
295+ ) -> torch . Tensor :
220296 """
221297 Args:
222298 input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument
@@ -238,6 +314,34 @@ def dequantize_affine(
238314 Output:
239315 dequantized Tensor, with requested dtype or fp32
240316 """
317+ return _dequantize_affine (
318+ input ,
319+ block_size ,
320+ scale ,
321+ zero_point ,
322+ input_dtype ,
323+ quant_min ,
324+ quant_max ,
325+ zero_point_domain .name ,
326+ output_dtype = output_dtype ,
327+ )
328+
329+
330+ # @register_custom_op(quant_lib, 'dequantize_affine(Tensor input, int[] block_size, Tensor scale, Tensor zero_point, ScalarType input_dtype, int? quant_min=None, int? quant_max=None, str zero_point_domain="INT", ScalarType output_dtype=float) -> Tensor')
331+ @register_custom_op (quant_lib )
332+ def _dequantize_affine (
333+ input : torch .Tensor ,
334+ block_size : List [int ],
335+ scale : torch .Tensor ,
336+ zero_point : Optional [torch .Tensor ],
337+ input_dtype : torch .dtype ,
338+ quant_min : Optional [int ] = None ,
339+ quant_max : Optional [int ] = None ,
340+ zero_point_domain : str = "INT" ,
341+ output_dtype : torch .dtype = torch .float32 ,
342+ ) -> torch .Tensor :
343+ """op definition that has compatible signatures with custom op library
344+ """
241345
242346 # TODO: validations
243347 # TODO: validate scale/zero_point dimensions are compatible with block_size
@@ -255,16 +359,16 @@ def dequantize_affine(
255359 if zero_point is not None :
256360 zero_point = zero_point .view (shape_after_reduction )
257361
258- if zero_point_domain == ZeroPointDomain .INT :
362+ if zero_point_domain == ZeroPointDomain .INT . name :
259363 # Force a copy to avoid input modification due
260364 # to upcoming in-place operations.
261365 dequant = input .to (torch .int32 , copy = True )
262366 if zero_point is not None :
263- dequant -= zero_point .to (torch .int32 )
367+ dequant = dequant - zero_point .to (torch .int32 )
264368 dequant = dequant .to (output_dtype )
265- dequant *= scale
369+ dequant = dequant * scale
266370 else :
267- assert zero_point_domain == ZeroPointDomain .FLOAT , f"Unexpected zero point domain: { zero_point_domain } "
371+ assert zero_point_domain == ZeroPointDomain .FLOAT . name , f"Unexpected zero point domain: { zero_point_domain } "
268372 mid_point = (quant_max + quant_min + 1 ) / 2
269373 # This should allocate new memory and avoid input modification
270374 dequant = input - mid_point
@@ -320,8 +424,39 @@ def choose_qparams_affine(
320424 Output:
321425 Tuple of scales and zero_points Tensor with requested dtype
322426 """
427+ return _choose_qparams_affine (
428+ input ,
429+ mapping_type .name ,
430+ block_size ,
431+ target_dtype ,
432+ quant_min ,
433+ quant_max ,
434+ eps ,
435+ scale_dtype ,
436+ zero_point_dtype ,
437+ preserve_zero ,
438+ zero_point_domain .name
439+ )
440+
441+ # @register_custom_op(quant_lib, 'choose_qparams_affine(Tensor input, str mapping_type, int[] block_size, ScalarType target_dtype, int? quant_min=None, int? quant_max=None, float? eps=None, ScalarType? scale_dtype=None, ScalarType? zero_point_dtype=None, bool preserve_zero=True, str zero_point_domain="INT") -> (Tensor, Tensor)')
442+ @register_custom_op (quant_lib )
443+ def _choose_qparams_affine (
444+ input : torch .Tensor ,
445+ mapping_type : str ,
446+ block_size : List [int ],
447+ target_dtype : torch .dtype ,
448+ quant_min : Optional [int ] = None ,
449+ quant_max : Optional [int ] = None ,
450+ eps : Optional [float ] = None ,
451+ scale_dtype : Optional [torch .dtype ] = None ,
452+ zero_point_dtype : Optional [torch .dtype ] = None ,
453+ preserve_zero : bool = True ,
454+ zero_point_domain : str = "INT" ,
455+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
456+ """op definition that has compatible signatures with custom op library
457+ """
323458 quant_min , quant_max = _get_and_check_qmin_qmax (target_dtype , quant_min , quant_max )
324- assert mapping_type in [MappingType .SYMMETRIC , MappingType .ASYMMETRIC ], f"Unsupported mapping type: { mapping_type } "
459+ assert mapping_type in [MappingType .SYMMETRIC . name , MappingType .ASYMMETRIC . name ], f"Unsupported mapping type: { mapping_type } "
325460
326461 if scale_dtype is None :
327462 scale_dtype = input .dtype
@@ -342,21 +477,22 @@ def choose_qparams_affine(
342477 min_val_neg = min_val
343478 max_val_pos = max_val
344479
345- if mapping_type == MappingType .SYMMETRIC :
480+ if mapping_type == MappingType .SYMMETRIC . name :
346481 max_val_pos = torch .max (- min_val_neg , max_val_pos )
347482 scale = max_val_pos / (float (quant_max - quant_min ) / 2 )
348483 if not preserve_zero :
349484 raise ValueError ("preserve_zero == False is not supported for symmetric quantization" )
350- if zero_point_domain != ZeroPointDomain .INT :
485+ if zero_point_domain != ZeroPointDomain .INT . name :
351486 raise ValueError ("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization" )
352487 zero_point = torch .full_like (scale , int ((quant_max + quant_min + 1 ) / 2 ))
353488 else :
489+ assert mapping_type == MappingType .ASYMMETRIC .name
354490 scale = (max_val_pos - min_val_neg ) / float (quant_max - quant_min )
355491 if preserve_zero :
356492 zero_point = quant_min - torch .round (min_val_neg / scale )
357493 zero_point = torch .clamp (zero_point , quant_min , quant_max )
358494 else :
359- assert zero_point_domain == ZeroPointDomain .FLOAT , "if not preserve_zero, zero_point must be in FLOAT domain"
495+ assert zero_point_domain == ZeroPointDomain .FLOAT . name , "if not preserve_zero, zero_point must be in FLOAT domain"
360496 mid_point = (quant_max + quant_min + 1 ) / 2
361497 zero_point = min_val_neg + scale * mid_point
362498
0 commit comments