2121
2222
2323class TestQuantPasses (unittest .TestCase ):
24- def replace_q_dq_patterns_with_quantized_linear_ops_pass (self ):
25- # setattr(torch.ops.pt2e_quant, "dequantize_affine", None)
26- layers = [
27- torch .nn .Linear (256 , 128 , bias = True ),
28- torch .nn .Linear (128 , 64 , bias = False ),
29- torch .nn .Linear (64 , 32 , bias = True ),
30- ]
24+ def test_replace_q_dq_patterns_with_quantized_linear_ops_pass (self ):
25+ layers = []
26+ layer_to_weight_dtype = {}
27+ layer_to_has_weight_zeros = {}
28+ for weight_dtype in [getattr (torch , f"int{ i } " ) for i in range (1 , 9 )]:
29+ for has_weight_zeros in [True , False ]:
30+ for has_bias in [True , False ]:
31+ idx = len (layers )
32+ layer_to_weight_dtype [idx ] = weight_dtype
33+ layer_to_has_weight_zeros [idx ] = has_weight_zeros
34+ layers .append (torch .nn .Linear (64 , 64 , bias = has_bias ))
35+ activations = torch .randn (2 , 1 , 64 , dtype = torch .float32 )
36+
3137 model = torch .nn .Sequential (* layers )
32- activations = torch .randn (2 , 1 , 256 , dtype = torch .float32 )
33- quantize_ (
34- model ,
35- Int8DynamicActivationIntxWeightConfig (
36- weight_dtype = torch .int4 ,
37- granularity = PerGroup (64 ),
38- has_weight_zeros = True ,
39- layout = QDQLayout (),
40- ),
41- lambda m , fqn : fqn == "0" ,
42- )
43- quantize_ (
44- model ,
45- Int8DynamicActivationIntxWeightConfig (
46- weight_dtype = torch .int3 ,
47- granularity = PerRow (),
48- has_weight_zeros = False ,
49- layout = QDQLayout (),
50- ),
51- lambda m , fqn : fqn == "1" ,
52- )
53- quantize_ (
54- model ,
55- Int8DynamicActivationIntxWeightConfig (
56- weight_dtype = torch .int5 ,
57- granularity = PerGroup (32 ),
58- has_weight_zeros = False ,
59- layout = QDQLayout (),
60- ),
61- lambda m , fqn : fqn == "2" ,
62- )
38+ for idx in range (len (layers )):
39+ quantize_ (
40+ model ,
41+ Int8DynamicActivationIntxWeightConfig (
42+ weight_dtype = layer_to_weight_dtype [idx ],
43+ # Alternate granularity by idx parity: even -> PerGroup(32), odd -> PerRow(). NOTE: parity coincides with has_bias (innermost loop).
44+ granularity = PerGroup (32 ) if idx % 2 == 0 else PerRow (),
45+ has_weight_zeros = layer_to_has_weight_zeros [idx ],
46+ layout = QDQLayout (),
47+ ),
48+ lambda m , fqn : fqn == str (idx ),
49+ )
6350
6451 eager_results = model (activations )
6552 exported = torch .export .export (model , (activations ,), strict = True )
@@ -70,9 +57,9 @@ def replace_q_dq_patterns_with_quantized_linear_ops_pass(self):
7057 exported .graph_module .code
7158 )
7259
73- # We should find 3 torchao linear ops
60+ # We should find len(layers) torchao linear ops
7461 FileCheck ().check_count (
75- "torch.ops.torchao._linear_8bit_act_" , count = 3 , exactly = True
62+ "torch.ops.torchao._linear_8bit_act_" , count = len ( layers ) , exactly = True
7663 ).run (exported .graph_module .code )
7764
7865 # We should not find Q/DQ ops
0 commit comments