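"""Inductor pattern-matcher tests for torchao's int8 SDPA fusion.

The tests build a small self-attention module, quantize it with the PT2E x86
Inductor quantizer, compile it with torch.compile (freezing enabled), and check
that the quantized attention subgraph is rewritten into the fused
torchao.scaled_dot_product_int8 op.
"""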
import itertools

import pytest
import torch
import torch.utils.checkpoint
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.test_case import TestCase, run_tests
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CPU
from torch.utils.cpp_extension import IS_WINDOWS

import torchao
from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import _int8_sdpa_init
from torchao.utils import TORCH_VERSION_AT_LEAST_2_7

class SelfAttnLikeModule(torch.nn.Module):
    """Minimal self-attention block whose quantized graph should be matched by the int8 SDPA fusion pass."""

    def __init__(
        self,
        input_dim,
        has_mask,
        num_attention_heads=None,
        attention_head_size=None,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
        self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
        self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
        self.softmax = torch.nn.Softmax(dim=-1)
        assert num_attention_heads is not None
        assert attention_head_size is not None
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = attention_head_size
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size)
        self.dropout = torch.nn.Dropout(0)
        self.has_mask = has_mask

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, all_head_size) -> (batch, num_heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(new_x_shape)
        return x.permute([0, 2, 1, 3])

    def forward(self, x, mask):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q = self.transpose_for_scores(q)
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)
        scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5)
        # Only add the mask when its dtype differs from the scores' dtype
        # (e.g. an fp32 mask with bf16 scores under autocast).
        if self.has_mask and mask.dtype != scores.dtype:
            scores = scores + mask
        attention = self.softmax(scores)
        attention = self.dropout(attention)
        context_layer = torch.matmul(attention, v)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(
            context_layer.size()[:-2] + (self.all_head_size,)
        )
        return self.dense(context_layer)

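
# Illustrative reference only, not used by the tests below: a minimal sketch of the
# attention computed by SelfAttnLikeModule, expressed with the stock
# torch.nn.functional.scaled_dot_product_attention. Note that SelfAttnLikeModule
# scales the scores by input_dim**-0.5 (the full hidden size) rather than the
# default head_dim**-0.5, so the scale is passed explicitly here.
def _reference_attention(q, k, v, mask=None, input_dim=None):
    scale = 1.0 / (input_dim**0.5) if input_dim is not None else None
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, scale=scale
    )
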
class TestSDPAPatternRewriterTemplate(TestCase):
    def _clone_inputs(self, inputs):
        def clone(x):
            if not isinstance(x, torch.Tensor):
                return x
            return x.clone()

        return [clone(x) for x in inputs]

    def _check_common(
        self,
        dot_prod_attention,
        args1=None,
        contains=True,
        atol=1e-5,
        has_fuse_pattern=True,
        has_dropout=False,
        check_train=True,
        override_check_equal=False,
        dtype=torch.float,
        rtol=1.3e-6,
    ):
        """Run `dot_prod_attention` eagerly and compiled, then check that the
        fusion counter fired, the fused op appears in the generated code, and
        the outputs (and gradients, when training) match."""
        if args1 is None:
            tensor_shape = (4, 2, 16, 32)
            args1 = [
                torch.randn(tensor_shape, device=self.device, dtype=dtype),
                torch.randn(tensor_shape, device=self.device, dtype=dtype),
                torch.randn(tensor_shape, device=self.device, dtype=dtype),
            ]
        else:
            args1 = list(args1)
        args2 = self._clone_inputs(args1)

        for training in [False, True] if check_train else [False]:
            for x in itertools.chain(args1[:], args2[:]):
                if isinstance(x, torch.Tensor) and x.is_floating_point():
                    x.requires_grad = training

            dropout_arg = [training] if has_dropout else []
            torch.manual_seed(1234)
            result1 = dot_prod_attention(*(args1 + dropout_arg))

            counters.clear()
            torch.manual_seed(1234)
            compiled_model = torch.compile(dot_prod_attention, fullgraph=True)
            result2, source_code = run_and_get_code(
                compiled_model,
                *(args2 + dropout_arg),
            )
            source_code = "\n".join(source_code)
            if has_fuse_pattern:
                self.assertGreaterEqual(counters["inductor"]["int8_fuse_attention"], 1)
            if contains:
                # many of the patterns get re-expanded in the dispatcher
                self.assertIn(
                    "torchao.scaled_dot_product_int8",
                    source_code,
                )

            # some tests are configured with very low dropout where we still want to check equality
            if not has_dropout or override_check_equal:
                self.assertEqual(result1, result2, atol=atol, rtol=rtol)

            if training:
                result1.sum().backward()
                result2.sum().backward()
                for arg1, arg2 in zip(args1, args2):
                    if (
                        isinstance(arg1, torch.Tensor)
                        and arg1.is_floating_point()
                        and (not has_dropout or override_check_equal)
                    ):
                        self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol)

    @skipIfRocm
    @pytest.mark.skipif(
        not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later"
    )
    @pytest.mark.skipif(IS_WINDOWS, reason="int8 sdpa does not support windows yet")
    @config.patch({"freezing": True})
    def _test_sdpa_int8_rewriter(self):
        from torch.export import export_for_training

        import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
        from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
        from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
            X86InductorQuantizer,
        )

        # the matched pattern is different for bs=1, so cover both bs=56 and bs=1
        torch.manual_seed(1234)
        for dtype, has_mask, bs in itertools.product(
            [torch.float32, torch.bfloat16], [True, False], [56, 1]
        ):
            seqlen, numhead, headsize = 197, 16, 64
            mod = SelfAttnLikeModule(
                input_dim=headsize * numhead,
                has_mask=has_mask,
                num_attention_heads=numhead,
                attention_head_size=headsize,
            ).eval()
            inputs = (
                torch.randn(
                    (bs, seqlen, headsize * numhead), device=self.device, dtype=dtype
                ),
                torch.randn((bs, 1, 1, seqlen), device=self.device)
                if has_mask
                else None,
            )
            enable_autocast = dtype == torch.bfloat16
            with (
                torch.no_grad(),
                torch.amp.autocast(
                    self.device, enabled=enable_autocast, dtype=torch.bfloat16
                ),
            ):
                _int8_sdpa_init()
                # Quantize with the PT2E x86 Inductor quantizer; matmul is also
                # annotated so the whole SDPA subgraph is quantized.
                quantizer = X86InductorQuantizer()
                quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
                quantizer.set_function_type_qconfig(
                    torch.matmul, quantizer.get_global_quantization_config()
                )
                # Export, calibrate with one batch, convert, and move to eval
                # before checking the compiled graph for the fused op.
                export_model = export_for_training(
                    mod,
                    inputs,
                    strict=True,
                ).module()
                prepare_model = prepare_pt2e(export_model, quantizer)
                prepare_model(*inputs)
                convert_model = convert_pt2e(prepare_model)
                torchao.quantization.pt2e.move_exported_model_to_eval(convert_model)
                self._check_common(
                    convert_model, args1=inputs, check_train=False, atol=1.0
                )

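# The template method is bound onto a concrete per-device class below so the test
# runner discovers it as test_sdpa_int8_rewriter_cpu. Only CPU is covered here;
# a variant for another device would follow the same pattern, e.g. this
# hypothetical sketch:
#
#     class SDPAPatternRewriterXpuTests(TestSDPAPatternRewriterTemplate):
#         device = "xpu"
#         test_sdpa_int8_rewriter_xpu = (
#             TestSDPAPatternRewriterTemplate._test_sdpa_int8_rewriter
#         )
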
if HAS_CPU:

    class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate):
        device = "cpu"
        test_sdpa_int8_rewriter_cpu = (
            TestSDPAPatternRewriterTemplate._test_sdpa_int8_rewriter
        )


if __name__ == "__main__":
    if IS_LINUX:
        run_tests()