     FP8_TYPES,
     MappingType,
     ZeroPointDomain,
+    _choose_qparams_affine_dont_preserve_zero,
+    _choose_qparams_affine_float8,
+    _choose_qparams_affine_floatx,
+    _choose_qparams_affine_tinygemm,
+    _choose_qparams_and_quantize_affine_hqq,
+    _dequantize_affine_float8,
+    _dequantize_affine_floatx,
+    _dequantize_affine_no_zero_point,
+    _dequantize_affine_tinygemm,
+    _quantize_affine_float8,
+    _quantize_affine_floatx,
+    _quantize_affine_no_zero_point,
+    _quantize_affine_tinygemm,
     choose_qparams_affine,
-    choose_qparams_affine_dont_preserve_zero,
-    choose_qparams_affine_float8,
-    choose_qparams_affine_floatx,
-    choose_qparams_affine_tinygemm,
-    choose_qparams_and_quantize_affine_hqq,
     dequantize_affine,
-    dequantize_affine_float8,
-    dequantize_affine_floatx,
-    dequantize_affine_no_zero_point,
-    dequantize_affine_tinygemm,
     quantize_affine,
-    quantize_affine_float8,
-    quantize_affine_floatx,
-    quantize_affine_no_zero_point,
-    quantize_affine_tinygemm,
 )
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
@@ -142,7 +142,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
 
         if isinstance(self._layout, FloatxTensorCoreLayout):
             int_data, scale = self.tensor_impl.get_plain()
-            return dequantize_affine_floatx(
+            return _dequantize_affine_floatx(
                 int_data,
                 scale,
                 self._layout.ebits,
@@ -151,11 +151,11 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
             )
         elif isinstance(self._layout, Float8Layout):
             data, scale, _ = self.tensor_impl.get_plain()
-            return dequantize_affine_float8(data, scale, output_dtype)
+            return _dequantize_affine_float8(data, scale, output_dtype)
         else:
             data, scale, zero_point = self.tensor_impl.get_plain()
             if self.zero_point_domain == ZeroPointDomain.FLOAT:
-                dq = dequantize_affine_tinygemm(
+                dq = _dequantize_affine_tinygemm(
                     data,
                     self.block_size,
                     scale,
@@ -166,7 +166,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
                     output_dtype=output_dtype,
                 )
             elif self.zero_point_domain == ZeroPointDomain.NONE:
-                dq = dequantize_affine_no_zero_point(
+                dq = _dequantize_affine_no_zero_point(
                     data,
                     self.block_size,
                     scale,
@@ -270,7 +270,7 @@ def from_hp_to_intx(
             from torchao.dtypes import Int4CPULayout
             from torchao.dtypes.uintx import TensorCoreTiledLayout
 
-            data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq(
+            data, scale, zero_point, _ = _choose_qparams_and_quantize_affine_hqq(
                 input_float,
                 nbits=nbits,
                 group_size=group_size,
@@ -291,7 +291,7 @@ def from_hp_to_intx(
             data = data.to(target_dtype)
         else:
             if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero:
-                scale, zero_point = choose_qparams_affine_tinygemm(
+                scale, zero_point = _choose_qparams_affine_tinygemm(
                     input_float,
                     mapping_type,
                     block_size,
@@ -303,7 +303,7 @@ def from_hp_to_intx(
                     zero_point_dtype,
                 )
             elif zero_point_domain == ZeroPointDomain.INT and not preserve_zero:
-                scale, zero_point = choose_qparams_affine_dont_preserve_zero(
+                scale, zero_point = _choose_qparams_affine_dont_preserve_zero(
                     input_float,
                     mapping_type,
                     block_size,
@@ -329,7 +329,7 @@ def from_hp_to_intx(
             # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None
             if zero_point_domain == ZeroPointDomain.NONE:
                 zero_point = None
-                data = quantize_affine_no_zero_point(
+                data = _quantize_affine_no_zero_point(
                     input_float,
                     block_size,
                     scale,
@@ -339,7 +339,7 @@ def from_hp_to_intx(
                     quant_max,
                 )
             elif zero_point_domain == ZeroPointDomain.FLOAT:
-                data = quantize_affine_tinygemm(
+                data = _quantize_affine_tinygemm(
                     input_float,
                     block_size,
                     scale,
@@ -400,7 +400,7 @@ def from_hp_to_intx_static(
 
         if zero_point_domain == ZeroPointDomain.NONE:
             zero_point = None
-            int_data = quantize_affine_no_zero_point(
+            int_data = _quantize_affine_no_zero_point(
                 input_float,
                 block_size,
                 scale,
@@ -410,7 +410,7 @@ def from_hp_to_intx_static(
                 quant_max,
             )
         elif zero_point_domain == ZeroPointDomain.FLOAT:
-            int_data = quantize_affine_tinygemm(
+            int_data = _quantize_affine_tinygemm(
                 input_float,
                 block_size,
                 scale,
@@ -462,10 +462,10 @@ def from_hp_to_floatx(
         if target_dtype in FP8_TYPES:
             original_shape = input_float.shape
             input_float = _layout.pre_process(input_float)
-            scale = choose_qparams_affine_float8(
+            scale = _choose_qparams_affine_float8(
                 input_float, float8_dtype=target_dtype, block_size=block_size
             )
-            data = quantize_affine_float8(input_float, scale, target_dtype)
+            data = _quantize_affine_float8(input_float, scale, target_dtype)
             data, scale, zero_point = _layout.post_process(
                 data, scale, None, block_size
             )
@@ -499,7 +499,7 @@ def from_hp_to_floatx_static(
                 input_float, scale, ZeroPointDomain.NONE, block_size
             )
 
-            data = quantize_affine_float8(
+            data = _quantize_affine_float8(
                 input_float,
                 scale,
                 target_dtype,
@@ -545,8 +545,8 @@ def from_hp_to_fpx(
 
         ebits, mbits = _layout.ebits, _layout.mbits
         # Note: these ops are hardcoded to have per axis quantization (axis=1) right now
-        scale = choose_qparams_affine_floatx(input_float, ebits, mbits)
-        floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits)
+        scale = _choose_qparams_affine_floatx(input_float, ebits, mbits)
+        floatx_unpacked = _quantize_affine_floatx(input_float, scale, ebits, mbits)
         floatx_packed, scale, _ = _layout.post_process(
             floatx_unpacked, scale, None, block_size
         )
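Migration note (not part of the commit): every primitive touched here keeps its signature and only gains a leading underscore, which by Python convention marks it as private API. A minimal before/after sketch for downstream code, using the float8 call shapes visible in the hunks above; the example tensor and the per-tensor block_size choice are illustrative assumptions, not taken from this commit:

    import torch

    # Old public names (removed by this commit):
    #   from torchao.quantization.quant_primitives import (
    #       choose_qparams_affine_float8,
    #       dequantize_affine_float8,
    #       quantize_affine_float8,
    #   )

    # New underscore-prefixed (private) names:
    from torchao.quantization.quant_primitives import (
        _choose_qparams_affine_float8,
        _dequantize_affine_float8,
        _quantize_affine_float8,
    )

    x = torch.randn(128, 256, dtype=torch.bfloat16)

    # A block_size spanning the full shape gives one scale for the whole
    # tensor (assumption based on torchao's usual block_size convention).
    scale = _choose_qparams_affine_float8(
        x, float8_dtype=torch.float8_e4m3fn, block_size=tuple(x.shape)
    )
    data = _quantize_affine_float8(x, scale, torch.float8_e4m3fn)   # fp8 payload
    x_dq = _dequantize_affine_float8(data, scale, torch.bfloat16)   # back to bf16

Because the underscore names are private, third-party code that imported the old public names has no stable replacement here; pinning the previous torchao release, or using the still-public choose_qparams_affine / quantize_affine / dequantize_affine entry points, are the conservative options.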