
Commit ca5e688

Commit message: up
1 parent 4afcea8 commit ca5e688

2 files changed: 28 additions & 40 deletions


.github/workflows/torchao_experimental_test.yml

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ jobs:
           conda activate venv
           pytest torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
           python torchao/experimental/tests/test_embedding_xbit_quantizer.py
+          python torchao/experimental/tests/test_quant_passes.py
       - name: Run kernels/cpu/aarch64/tests
         run: |
           conda activate venv

torchao/experimental/tests/test_quant_passes.py

Lines changed: 27 additions & 40 deletions
@@ -21,45 +21,32 @@
 
 
 class TestQuantPasses(unittest.TestCase):
-    def replace_q_dq_patterns_with_quantized_linear_ops_pass(self):
-        # setattr(torch.ops.pt2e_quant, "dequantize_affine", None)
-        layers = [
-            torch.nn.Linear(256, 128, bias=True),
-            torch.nn.Linear(128, 64, bias=False),
-            torch.nn.Linear(64, 32, bias=True),
-        ]
+    def test_replace_q_dq_patterns_with_quantized_linear_ops_pass(self):
+        layers = []
+        layer_to_weight_dtype = {}
+        layer_to_has_weight_zeros = {}
+        for weight_dtype in [getattr(torch, f"int{i}") for i in range(1, 9)]:
+            for has_weight_zeros in [True, False]:
+                for has_bias in [True, False]:
+                    idx = len(layers)
+                    layer_to_weight_dtype[idx] = weight_dtype
+                    layer_to_has_weight_zeros[idx] = has_weight_zeros
+                    layers.append(torch.nn.Linear(64, 64, bias=has_bias))
+        activations = torch.randn(2, 1, 64, dtype=torch.float32)
+
         model = torch.nn.Sequential(*layers)
-        activations = torch.randn(2, 1, 256, dtype=torch.float32)
-        quantize_(
-            model,
-            Int8DynamicActivationIntxWeightConfig(
-                weight_dtype=torch.int4,
-                granularity=PerGroup(64),
-                has_weight_zeros=True,
-                layout=QDQLayout(),
-            ),
-            lambda m, fqn: fqn == "0",
-        )
-        quantize_(
-            model,
-            Int8DynamicActivationIntxWeightConfig(
-                weight_dtype=torch.int3,
-                granularity=PerRow(),
-                has_weight_zeros=False,
-                layout=QDQLayout(),
-            ),
-            lambda m, fqn: fqn == "1",
-        )
-        quantize_(
-            model,
-            Int8DynamicActivationIntxWeightConfig(
-                weight_dtype=torch.int5,
-                granularity=PerGroup(32),
-                has_weight_zeros=False,
-                layout=QDQLayout(),
-            ),
-            lambda m, fqn: fqn == "2",
-        )
+        for idx in range(len(layers)):
+            quantize_(
+                model,
+                Int8DynamicActivationIntxWeightConfig(
+                    weight_dtype=layer_to_weight_dtype[idx],
+                    # Test out different granularities
+                    granularity=PerGroup(32) if idx % 2 == 0 else PerRow(),
+                    has_weight_zeros=layer_to_has_weight_zeros[idx],
+                    layout=QDQLayout(),
+                ),
+                lambda m, fqn: fqn == str(idx),
+            )
 
         eager_results = model(activations)
         exported = torch.export.export(model, (activations,), strict=True)
@@ -70,9 +57,9 @@ def replace_q_dq_patterns_with_quantized_linear_ops_pass(self):
             exported.graph_module.code
         )
 
-        # We should find 3 torchao linear ops
+        # We should find len(layers) torchao linear ops
         FileCheck().check_count(
-            "torch.ops.torchao._linear_8bit_act_", count=3, exactly=True
+            "torch.ops.torchao._linear_8bit_act_", count=len(layers), exactly=True
         ).run(exported.graph_module.code)
 
         # We should not find Q/DQ ops
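
For context, the flow this test exercises is: quantize Linear layers with Int8DynamicActivationIntxWeightConfig using QDQLayout (so the exported graph contains explicit quantize/dequantize ops), export with torch.export, run the replace_q_dq_patterns_with_quantized_linear_ops_pass graph pass, and use FileCheck to confirm the Q/DQ patterns were fused into torchao linear ops. A minimal single-layer sketch follows; the import paths and the pass's call signature are not shown in this diff and are assumptions here.

# Minimal sketch, not part of the commit; import paths and the pass's
# signature are assumptions inferred from the identifiers in the test above.
import torch
from torch.testing import FileCheck

from torchao.experimental.q_dq_layout import QDQLayout  # assumed path
from torchao.experimental.quant_api import Int8DynamicActivationIntxWeightConfig  # assumed path
from torchao.experimental.quant_passes import (  # assumed path
    replace_q_dq_patterns_with_quantized_linear_ops_pass,
)
from torchao.quantization.granularity import PerGroup  # assumed path
from torchao.quantization.quant_api import quantize_

model = torch.nn.Sequential(torch.nn.Linear(64, 64, bias=False))
activations = torch.randn(2, 1, 64, dtype=torch.float32)

# QDQLayout keeps weights in a decomposed quantize/dequantize form,
# so the exported graph shows Q/DQ ops instead of fused kernels.
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        granularity=PerGroup(32),
        has_weight_zeros=True,
        layout=QDQLayout(),
    ),
)

exported = torch.export.export(model, (activations,), strict=True)

# Assumed signature: the pass takes and returns an exported program,
# rewriting Q/DQ + linear patterns into torchao's fused linear ops.
exported = replace_q_dq_patterns_with_quantized_linear_ops_pass(exported)

# One fused torchao linear op should now appear in the graph code.
FileCheck().check_count(
    "torch.ops.torchao._linear_8bit_act_", count=1, exactly=True
).run(exported.graph_module.code)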

0 commit comments