 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, List, Optional, Tuple
+from typing import Any, Optional

 import torch
 import torch.nn.functional as F
-from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
-from torch.library import impl

 from torchao.quantization.GPTQ import (
     _check_linear_int4_k,
     _replace_linear_int4,
     _replace_linear_8da4w,
     get_groupwise_affine_qparams,
     groupwise_affine_quantize_tensor,
     Int8DynActInt4WeightLinear,
     WeightOnlyInt4Linear,
 )
-from torchao.quantization.quant_primitives import (
-    fake_quantize_affine_cachemask,
-    ZeroPointDomain,
-)
+from torchao.quantization.quant_primitives import ZeroPointDomain
 from torchao.quantization.unified import TwoStepQuantizer
-from torchao.quantization.utils import (
-    _get_per_token_block_size,
-    get_group_qparams_symmetric,
+from torchao.quantization.utils import get_group_qparams_symmetric
+from .utils import (
+    _choose_qparams_per_token_asymmetric,
+    _fake_quantize_per_channel_group,
+    _fake_quantize_per_token,
 )


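The fake-quant helpers now imported from `.utils` are the same ones deleted at the bottom of this diff, renamed with a leading underscore. A minimal sketch of calling them directly, mirroring the sequence the `forward` hunk below uses; the absolute import path and tensor shapes are illustrative assumptions, and the helpers are assumed to keep the signatures of the removed definitions:

```python
import torch

# Hypothetical absolute path; the diff only shows the relative `.utils` import.
from torchao.quantization.prototype.qat.utils import (
    _choose_qparams_per_token_asymmetric,
    _fake_quantize_per_token,
)

x = torch.randn(16, 512)
# One (scale, zero_point) pair per token (per row here), as in the QAT forward.
scales, zero_points = _choose_qparams_per_token_asymmetric(
    x, torch.float32, torch.float32
)
x_fq = _fake_quantize_per_token(x, scales, zero_points, -128, 127)
assert x_fq.shape == x.shape
```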
@@ -163,7 +160,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 x, self.scales_precision, self.zero_points_precision,
             )
             (act_qmin, act_qmax) = self._get_qmin_qmax(8)
-            x_fq = fake_quantize_per_token(
+            x_fq = _fake_quantize_per_token(
                 x, act_scales, act_zp, act_qmin, act_qmax,
             )
         else:
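For context, `self._get_qmin_qmax(8)` supplies the signed 8-bit range that the per-token primitive below also hardcodes (-128, 127). A plausible sketch of that helper, assuming a standard signed two's-complement range (the method body itself is not shown in this diff):

```python
# Assumed implementation: signed n-bit range, e.g. 8 -> (-128, 127)
# for activations and 4 -> (-8, 7) for the grouped weight path.
def _get_qmin_qmax(n_bit: int) -> tuple:
    qmin = -(2 ** (n_bit - 1))
    qmax = 2 ** (n_bit - 1) - 1
    return (qmin, qmax)
```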
@@ -177,7 +174,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
             weight_zp = weight_zp.to(self.zero_points_precision)
             (weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
-            w_fq = fake_quantize_per_channel_group(
+            w_fq = _fake_quantize_per_channel_group(
                 self.weight,
                 weight_scales,
                 weight_zp,
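For intuition, here is a hedged reference implementation of the per-channel-group math the renamed helper performs in the integer zero-point domain: each contiguous block of `group_size` elements in a row of the 2D weight shares one (scale, zero_point) pair, i.e. quantization `block_size == (1, group_size)` as in the removed helper at the bottom of this diff. This is a sketch of the math, not the library's code path (which routes through `fake_quantize_affine_cachemask`):

```python
import torch

def fake_quant_group_reference(w, scales, zps, qmin, qmax, group_size):
    # w: [out, in]; scales/zps: [out, in // group_size]
    g = w.reshape(w.shape[0], -1, group_size)
    s = scales.reshape(w.shape[0], -1, 1)
    zp = zps.reshape(w.shape[0], -1, 1)
    q = torch.clamp(torch.round(g / s) + zp, qmin, qmax)  # quantize per group
    return ((q - zp) * s).reshape_as(w)                   # dequantize back
```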
@@ -349,7 +346,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             scales, zero_points = get_groupwise_affine_qparams(
                 self.weight, n_bit, self.groupsize, self.scales_precision,
             )
-            w_fq = fake_quantize_per_channel_group(
+            w_fq = _fake_quantize_per_channel_group(
                 self.weight,
                 scales,
                 zero_points,
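The enable/disable hooks like `disable_4w_fake_quant` below are typically broadcast over a model with `nn.Module.apply`. A usage sketch; the prepared `model` and the `enable_4w_fake_quant` counterpart are assumptions based on the disable hook shown in this diff:

```python
# Assumes `model` was prepared so its linears are Int4WeightOnlyQATLinear,
# and that an `enable_4w_fake_quant` counterpart to `disable_4w_fake_quant`
# exists in this module.
model.apply(disable_4w_fake_quant)  # e.g. evaluate without fake quant
model.apply(enable_4w_fake_quant)   # restore fake quant for QAT training
```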
@@ -373,135 +370,3 @@ def disable_4w_fake_quant(mod: torch.nn.Module):
373370 """
374371 if isinstance (mod , Int4WeightOnlyQATLinear ):
375372 mod .disable_fake_quant ()
376-
377-
378- # ========================
379- # | QUANT PRIMITIVES |
380- # ========================
381-
382- class _GenericFakeQuantize (torch .autograd .Function ):
383- """
384- Implementation of generic fake quantize with backward STE.
385-
386- With the appropriate input tensor shape, this can be used to express
387- grouped per channel fake quantize or per token fake quantize.
388- """
389-
390- @staticmethod
391- def forward (
392- ctx : torch .autograd .function .FunctionCtx ,
393- input : torch .Tensor ,
394- scales : torch .Tensor ,
395- zero_points : torch .Tensor ,
396- quant_min : int ,
397- quant_max : int ,
398- block_size : List [int ],
399- zero_point_domain : ZeroPointDomain = ZeroPointDomain .INT ,
400- ) -> torch .Tensor :
401- # Note: for bf16 inputs, casting them to fp32 has the unexpected
402- # side effect of reducing memory footprint significantly, presumably
403- # because bf16 * fp32 kernels are not as memory efficient
404- assert input .dtype == torch .float32
405- assert scales .dtype == torch .float32
406- assert zero_points .dtype == torch .int32
407-
408- (fq , mask ) = fake_quantize_affine_cachemask (
409- input ,
410- block_size ,
411- scales ,
412- zero_points ,
413- torch .int32 ,
414- quant_min ,
415- quant_max ,
416- zero_point_domain ,
417- )
418-
419- ctx .save_for_backward (mask )
420- return fq
421-
422- @staticmethod
423- def backward (ctx , gy ):
424- (mask ,) = ctx .saved_tensors
425- return gy * mask , None , None , None , None , None , None
426-
427- def fake_quantize_per_channel_group (
428- input : torch .Tensor ,
429- scales : torch .Tensor ,
430- zero_points : torch .Tensor ,
431- quant_min : int ,
432- quant_max : int ,
433- group_size : int ,
434- zero_point_domain : ZeroPointDomain = ZeroPointDomain .INT ,
435- ) -> torch .Tensor :
436- assert group_size > 1
437- assert input .shape [- 1 ] % group_size == 0
438- assert input .dim () == 2
439- block_size = (1 , group_size )
440- return _GenericFakeQuantize .apply (
441- input , scales , zero_points , quant_min , quant_max , block_size , zero_point_domain ,
442- )
443-
444- def fake_quantize_per_token (
445- input : torch .Tensor ,
446- scales : torch .Tensor ,
447- zero_points : torch .Tensor ,
448- quant_min : int ,
449- quant_max : int ,
450- ) -> torch .Tensor :
451- from torch .ao .quantization .fx ._decomposed import _per_token_quant_qparam_dim_check
452-
453- _per_token_quant_qparam_dim_check (input , scales , zero_points )
454- block_size = _get_per_token_block_size (input )
455- fq_input = input .to (torch .float32 )
456- fq = _GenericFakeQuantize .apply (
457- fq_input , scales , zero_points , quant_min , quant_max , block_size ,
458- )
459- return fq .reshape_as (input ).to (input .dtype )
460-
461- # TODO: This is copied from torch/ao/quantization/fx/_decomposed.py.
462- # The version in pytorch does not have backward support yet so we add
463- # it here for now until https://github.com/pytorch/pytorch/pull/123452
464- # is landed.
465- def _choose_qparams_per_token_asymmetric (
466- input : torch .Tensor ,
467- scales_precision : torch .dtype = torch .float32 ,
468- zero_points_precision : torch .dtype = torch .float32 ,
469- ) -> Tuple [torch .Tensor , torch .Tensor ]:
470- """Choose quantization parameters for per token quantization. This means for a N dimension Tensor
471- (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
472- every N elements with the same quantization parameter. The dimension for scales/zero_points
473- will be (M1 * M2 ... * Mn)
474-
475- Args:
476- input (torch.Tensor): original float32/float16 Tensor
477- scales_precision (torch.dtype): precision of returned scales
478- zero_points_precision (torch.dtype): precision of returned zero points
479-
480- Returns:
481- scales and zero_points, both float32 Tensors
482- """
483- # Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18
484- qmin , qmax = - 128 , 127
485- min_val = torch .amin (input , dim = - 1 , keepdim = True )
486- max_val = torch .amax (input , dim = - 1 , keepdim = True )
487- min_val_neg = torch .min (min_val , torch .zeros_like (min_val ))
488- max_val_pos = torch .max (max_val , torch .zeros_like (max_val ))
489- eps = torch .finfo (torch .float32 ).eps # use xnnpack eps?
490-
491- # scale
492- scale = (max_val_pos - min_val_neg ) / float (qmax - qmin )
493- scale = scale .clamp (min = eps )
494-
495- # zero point
496- descaled_min = min_val_neg / scale
497- descaled_max = max_val_pos / scale
498- zero_point_from_min_error = qmin + descaled_min
499- zero_point_from_max_error = qmax + descaled_max
500- zero_point = torch .where (
501- zero_point_from_min_error + zero_point_from_max_error > 0 ,
502- qmin - descaled_min ,
503- qmax - descaled_max ,
504- )
505- zero_point = torch .clamp (zero_point , qmin , qmax ).round ()
506-
507- return scale .to (scales_precision ), zero_point .to (zero_points_precision )
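The removed `_GenericFakeQuantize` is a straight-through estimator (STE): the forward pass fake-quantizes, and the backward pass lets gradients through only where values were not clamped, using the cached `mask`. A self-contained sketch of the same pattern, independent of the torchao primitives (per-tensor quantization here, purely for illustration):

```python
import torch

class FakeQuantSTE(torch.autograd.Function):
    """STE fake quantize: round/clamp in forward, masked identity in backward."""

    @staticmethod
    def forward(ctx, x, scale, zero_point, qmin, qmax):
        q = torch.round(x / scale) + zero_point
        mask = (q >= qmin) & (q <= qmax)  # analogous to the cached mask above
        q = torch.clamp(q, qmin, qmax)
        ctx.save_for_backward(mask)
        return (q - zero_point) * scale   # dequantize

    @staticmethod
    def backward(ctx, gy):
        (mask,) = ctx.saved_tensors
        # Gradient flows only through elements that were not clamped;
        # scale, zero_point, qmin, qmax receive no gradient.
        return gy * mask, None, None, None, None

x = torch.randn(4, 8, requires_grad=True)
y = FakeQuantSTE.apply(x, torch.tensor(0.1), torch.tensor(0.0), -128, 127)
y.sum().backward()  # x.grad is 1 where values stayed in range, 0 where clamped
```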