2121)
2222from torchao .quantization .quant_primitives import (
2323 MappingType ,
24+ ZeroPointDomain ,
2425)
2526
2627
@@ -74,7 +75,7 @@ def test_block_size_calc_success(self):
7475 eps = torch .finfo (torch .float32 ).eps ,
7576 scale_dtype = torch .float ,
7677 zero_point_dtype = torch .int ,
77- zero_point_domain = None ,
78+ zero_point_domain = ZeroPointDomain . NONE ,
7879 )
7980 example_inputs = [
8081 torch .randn (10 , 2048 ),
@@ -93,7 +94,7 @@ def test_block_size_calc_success(self):
9394 eps = torch .finfo (torch .float32 ).eps ,
9495 scale_dtype = torch .float ,
9596 zero_point_dtype = torch .int ,
96- zero_point_domain = None ,
97+ zero_point_domain = ZeroPointDomain . NONE ,
9798 )
9899 for example_input in example_inputs :
99100 obs (example_input )
@@ -108,7 +109,7 @@ def test_block_size_row_errors(self):
108109 eps = torch .finfo (torch .float32 ).eps ,
109110 scale_dtype = torch .float ,
110111 zero_point_dtype = torch .int ,
111- zero_point_domain = None ,
112+ zero_point_domain = ZeroPointDomain . NONE ,
112113 )
113114 example_inputs = [
114115 torch .randn (10 , 2048 ),
@@ -127,7 +128,7 @@ def test_block_size_row_errors(self):
127128 eps = torch .finfo (torch .float32 ).eps ,
128129 scale_dtype = torch .float ,
129130 zero_point_dtype = torch .int ,
130- zero_point_domain = None ,
131+ zero_point_domain = ZeroPointDomain . NONE ,
131132 )
132133 example_inputs = [
133134 torch .randn (10 , 2048 ),
@@ -155,7 +156,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
155156 eps = torch .finfo (torch .float32 ).eps ,
156157 scale_dtype = torch .float ,
157158 zero_point_dtype = torch .int ,
158- zero_point_domain = None ,
159+ zero_point_domain = ZeroPointDomain . NONE ,
159160 )
160161 if observe_weight :
161162 weight_observer = AffineQuantizedMinMaxObserver (
@@ -165,7 +166,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
165166 eps = torch .finfo (torch .float32 ).eps ,
166167 scale_dtype = torch .float ,
167168 zero_point_dtype = torch .int ,
168- zero_point_domain = None ,
169+ zero_point_domain = ZeroPointDomain . NONE ,
169170 )
170171 else :
171172 weight_observer = None
@@ -199,7 +200,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
199200 input_scale .item (),
200201 max_val / max_fp8 ,
201202 )
202- self .assertIsNotNone (input_zero_point )
203+ self .assertIsNone (input_zero_point )
203204
204205 if observe_weight :
205206 weight_observer = linear .weight .weight_observer
@@ -210,7 +211,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
210211 atol = 5e-5 ,
211212 rtol = 0.0 ,
212213 )
213- self .assertIsNotNone (weight_zero_point )
214+ self .assertIsNone (weight_zero_point )
214215 else :
215216 self .assertIsNone (linear .weight .weight_observer )
216217
0 commit comments