# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

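"""Tests for UIntxWeightOnlyLinearQuantizer on the MPS backend.

Covers torch.export lowering to the custom lowbit ops, output device and
shape for 2D and 3D activations, group-size validation, and numerical
accuracy against a float reference, for 1- through 7-bit weights.
"""
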
import copy
import itertools
import unittest

import torch

# Side-effect import: loads the extension that registers the custom
# lowbit MPS ops used by the quantized modules below.
import torchao_mps_ops

from parameterized import parameterized
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer, _quantize


class TestUIntxWeightOnlyLinearQuantizer(unittest.TestCase):
    BITWIDTHS = range(1, 8)  # the lowbit kernels cover 1- through 7-bit weights
    GROUPSIZES = [32, 64, 128, 256]

    # Currently, the quantization code in quant_api.py only supports values of K
    # that are a multiple of group_size.
    # TODO(mcandales): Generalize the code in quant_api.py and add tests to
    # cover values of K not multiple of group_size.
    def _model_setup(self):
        # A three-layer MLP whose K dimensions (96, 224, 160) are all
        # multiples of group_size.
        group_size = 32
        k0 = 96
        k1 = 224
        k2 = 160
        n = 47
        layers = [
            torch.nn.Linear(k0, k1, bias=False),
            torch.nn.Linear(k1, k2, bias=False),
            torch.nn.Linear(k2, n, bias=False),
        ]
        model = torch.nn.Sequential(*layers)
        return model, group_size, k0, n

    def _quantize_model(self, model, precision, nbit, group_size):
        quantizer = UIntxWeightOnlyLinearQuantizer(
            device="mps",
            precision=precision,
            bitwidth=nbit,
            groupsize=group_size,
        )
        quantized_model = copy.deepcopy(model)
        quantized_model = quantizer.quantize(quantized_model)
        return quantized_model

    @parameterized.expand(BITWIDTHS)
    def test_export(self, nbit):
        model, group_size, k0, n = self._model_setup()
        m = 3
        activations = torch.randn(m, k0, dtype=torch.float32, device="mps")

        quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
        exported = torch.export.export(quantized_model, (activations,))

        # Every call_function node in the exported graph should be the custom
        # lowbit linear op for this bitwidth.
        for node in exported.graph.nodes:
            if node.op == "call_function":
                self.assertEqual(
                    str(node.target),
                    f"torchao._linear_fp_act_{nbit}bit_weight.default",
                )

    @parameterized.expand(BITWIDTHS)
    def test_2d_output_device_and_shape(self, nbit):
        model, group_size, k0, n = self._model_setup()
        m = 3
        activations = torch.randn(m, k0, dtype=torch.float32, device="mps")

        quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
        result = quantized_model(activations)
        self.assertTrue(result.is_mps)
        self.assertEqual(result.shape, (m, n))

    @parameterized.expand(BITWIDTHS)
    def test_3d_output_device_and_shape(self, nbit):
        model, group_size, k0, n = self._model_setup()
        leading_shape = (3, 5)
        activations = torch.randn(*leading_shape, k0, dtype=torch.float32, device="mps")

        quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
        result = quantized_model(activations)
        self.assertTrue(result.is_mps)
        self.assertEqual(result.shape, (*leading_shape, n))

    @parameterized.expand(itertools.product(BITWIDTHS, GROUPSIZES))
    def test_valid_groupsizes(self, nbit, group_size):
        k0 = 3 * group_size
        k1 = 7 * group_size
        n = 47
        layers = [
            torch.nn.Linear(k0, k1, bias=False),
            torch.nn.Linear(k1, n, bias=False),
        ]
        model = torch.nn.Sequential(*layers)
        m = 5
        activations = torch.randn(m, k0, dtype=torch.float32, device="mps")

        quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
        result = quantized_model(activations)
        self.assertTrue(result.is_mps)
        self.assertEqual(result.shape, (m, n))

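    # A group size of 16 is not supported, so quantization should fail with a
    # ValueError rather than produce a silently wrong model.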
    @parameterized.expand(BITWIDTHS)
    def test_invalid_groupsizes(self, nbit):
        group_size = 16
        k0 = 3 * group_size
        k1 = 7 * group_size
        n = 47
        layers = [
            torch.nn.Linear(k0, k1, bias=False),
            torch.nn.Linear(k1, n, bias=False),
        ]
        model = torch.nn.Sequential(*layers)

        with self.assertRaises(ValueError):
            self._quantize_model(model, torch.float32, nbit, group_size)

    # TODO(mcandales): Consolidate with the reference impl in test_lowbit.py
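    # Reference: computes A @ dequant(W).T with per-group affine dequantization,
    #   dequant(W)[n, k] = S[k // group_size, n] * W[n, k] + Z[k // group_size, n],
    # where S and Z hold per-group scales and (pre-scaled) zeros of shape
    # (K // group_size, N).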
    def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z):
        N = W.shape[0]
        K = W.shape[1]
        W = W.to(torch.float32)
        # Expand the per-group scales and zeros to per-element (N, K) maps.
        scales = S.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K]
        zeros = Z.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K]
        W = scales * W + zeros
        return torch.mm(A, W.t())

    @parameterized.expand(BITWIDTHS)
    def test_accuracy(self, nbit):
        group_size = 32
        m = 3
        n = 7
        k = 64
        with torch.no_grad():
            activations = torch.rand(m, k, dtype=torch.float32, device="mps")
            model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
            quantized_model = self._quantize_model(
                model, torch.float32, nbit, group_size
            )
            result = quantized_model(activations)

            # Compute the expected result on CPU.
            weight_cpu = model[0].weight.data
            weight_qvals_cpu, weight_scales_cpu, weight_zeros_cpu = _quantize(
                weight_cpu, group_size, nbit, True, torch.uint8
            )
            # Convert from zero-point form, S * (q - z), to the affine form
            # S * q + Z expected by the reference implementation.
            weight_scales_cpu = weight_scales_cpu.t()
            weight_zeros_cpu = -weight_zeros_cpu.t() * weight_scales_cpu
            expected = self._reference_linear_lowbit_quant_weights(
                activations.cpu(),
                weight_qvals_cpu,
                group_size,
                weight_scales_cpu,
                weight_zeros_cpu,
            )

            # Compare the MPS result against the CPU reference.
            torch.testing.assert_close(result.cpu(), expected, rtol=0.001, atol=0.001)


if __name__ == "__main__":
    unittest.main()