 # Owner(s): ["module: dynamo"]
 
+import os
+import unittest
+
 import torch
 import torch._dynamo
 import torch._dynamo.test_case
 
 import torch._functorch._aot_autograd
+from torch._dynamo.utils import counters
 from torch._functorch import config as functorch_config
 from torch._functorch._aot_autograd.autograd_cache import (
-    autograd_cache_hash,
+    AOTAutogradCache,
+    autograd_cache_key,
     BypassAOTAutogradCache,
 )
 from torch._functorch._aot_autograd.schemas import AOTConfig
 from torch._inductor import config as inductor_config
-
-
+from torch.testing._internal.common_cuda import SM80OrLater
+from torch.testing._internal.common_device_type import largeTensorTest
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+)
+from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
+
+
+@instantiate_parametrized_tests
+class AOTAutogradCacheTests(torch._dynamo.test_case.TestCase):
+    def setUp(self):
+        """
+        Reset all counters and caches before each unit test
+        """
+        super().setUp()
+        counters.clear()
+        self._clear_all_caches()
+
+    def _clear_all_caches(self):
+        """
+        Clear every cache, including AOTAutogradCache and FxGraphCache
+        """
+        torch._inductor.codecache.FxGraphCache.clear()
+        AOTAutogradCache.clear()
+        self._clear_dynamo_and_codecache()
+
+    def _clear_dynamo_and_codecache(self):
+        """
+        Clear the surrounding caches (dynamo and PyCodeCache) while leaving
+        AOTAutogradCache and FxGraphCache intact
+        """
+        torch._dynamo.reset()
+        for m in torch._inductor.codecache.PyCodeCache.cache.values():
+            os.remove(m.__file__)
+        torch._inductor.codecache.PyCodeCache.cache_clear()
+
+    @inductor_config.patch("fx_graph_cache", True)
+    @functorch_config.patch({"enable_autograd_cache": True})
+    def test_basic(self):
+        """
+        Verify that a compiled function misses the AOTAutogradCache on the
+        first call and hits on the second.
+        """
+
+        def fn(x, y):
+            return (x * 2, y @ y)
+
+        a = torch.rand(25)
+        b = torch.rand(5, 5)
+
+        compiled_fn = torch.compile(fn, backend="inductor")
+
+        # A first call should miss in the cache.
+        self.assertEqual(fn(a, b), compiled_fn(a, b))
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+
+        # A second call should hit. (First reset so in-memory guards
+        # don't prevent compilation.)
+        self._clear_dynamo_and_codecache()
+        self.assertEqual(fn(a, b), compiled_fn(a, b))
+
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+
+    @inductor_config.patch("fx_graph_cache", True)
+    @functorch_config.patch({"enable_autograd_cache": True})
+    def test_clear_fx_graph_cache(self):
+        """
+        Verify that clearing the FxGraphCache forces the AOTAutogradCache
+        to miss and re-save.
+        """
+
+        def fn(x, y):
+            return (x * 2, y @ y)
+
+        a = torch.rand(25)
+        b = torch.rand(5, 5)
+
+        compiled_fn = torch.compile(fn, backend="inductor")
+
+        # A first call should miss in the cache.
+        self.assertEqual(fn(a, b), compiled_fn(a, b))
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+
+        # Clear the FX graph cache: the second call should also be a miss.
+        self._clear_dynamo_and_codecache()
+        torch._inductor.codecache.FxGraphCache.clear()
+        self.assertEqual(fn(a, b), compiled_fn(a, b))
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        # We save again into the cache.
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2)
+
+    @inductor_config.patch("fx_graph_cache", False)
+    @functorch_config.patch({"enable_autograd_cache": True})
+    def test_fx_graph_cache_off(self):
+        """
+        The AOTAutogradCache should be bypassed when FxGraphCache is disabled.
+        """
+
+        def fn(x, y):
+            return (x * 2, y @ y)
+
+        a = torch.rand(25)
+        b = torch.rand(5, 5)
+
+        compiled_fn = torch.compile(fn, backend="inductor")
+
+        # A first call should bypass the cache.
+        self.assertEqual(fn(a, b), compiled_fn(a, b))
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)
+
+        # After a dynamo reset, the second call should bypass again.
+        self._clear_dynamo_and_codecache()
+
+        self.assertEqual(fn(a, b), compiled_fn(a, b))
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 2)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)
+
+    @inductor_config.patch("fx_graph_cache", True)
+    @functorch_config.patch({"enable_autograd_cache": True})
+    def test_autograd_function(self):
+        """
+        Tests autograd cache hits for a function whose inputs require grad.
+        """
+
+        def fn(a, b):
+            return a.sin() + b
+
+        a = torch.randn(25, requires_grad=True)
+        b = torch.randn(25, requires_grad=True)
+        a2 = a.detach().clone().requires_grad_(True)
+        b2 = b.detach().clone().requires_grad_(True)
+
+        compiled_fn = torch.compile(fn, backend="inductor")
+
+        # A first call should miss in the cache.
+        self.assertEqual(fn(a, b), compiled_fn(a2, b2))
+        fn(a, b).sum().backward()
+        compiled_fn(a2, b2).sum().backward()
+        self.assertEqual(a.grad, a2.grad)
+        self.assertEqual(b.grad, b2.grad)
+
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+
+        # Reset all tensors
+        a = torch.randn(25, requires_grad=True)
+        b = torch.randn(25, requires_grad=True)
+        a2 = a.detach().clone().requires_grad_(True)
+        b2 = b.detach().clone().requires_grad_(True)
+
+        # A second call should hit. (First reset so in-memory guards
+        # don't prevent compilation.)
+        self._clear_dynamo_and_codecache()
+        self.assertEqual(fn(a, b), compiled_fn(a2, b2))
+        fn(a, b).sum().backward()
+        compiled_fn(a2, b2).sum().backward()
+        self.assertEqual(a.grad, a2.grad)
+        self.assertEqual(b.grad, b2.grad)
+
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+
+    @largeTensorTest("64GB", device=GPU_TYPE)
+    @parametrize("device", (GPU_TYPE,))
+    @parametrize("dtype", (torch.float16, torch.bfloat16))
+    @inductor_config.patch("fx_graph_cache", True)
+    @functorch_config.patch({"enable_autograd_cache": True})
+    def test_autograd_inductor_guards(self, device, dtype):
+        """
+        Tests that functions that would add inductor guards are cached properly
+        """
+        if device == GPU_TYPE and not HAS_GPU:
+            raise unittest.SkipTest(f"requires {GPU_TYPE}")
+        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
+            raise unittest.SkipTest("requires CUDA SM80 or later")
+
+        def fn(x, y):
+            return (x + x, y + y)
+
+        compiled_fn = torch.compile(fn, dynamic=True)
+
+        # Iterate over different shapes, varying whether the total size is
+        # below or above the int32 limit. For each combination, we expect
+        # different guards around whether the symbolic sizes do or do not
+        # exceed int32.
+        shapes = (
+            ((5, 6), (7, 8)),
+            ((5, 6), (47000, 47001)),
+            ((47000, 47001), (5, 6)),
+        )
+        expected_hits = expected_misses = expected_saves = 0
+        for a_shape, b_shape in shapes:
+            a = torch.rand(a_shape, device=device, dtype=dtype)
+            b = torch.rand(b_shape, device=device, dtype=dtype)
+
+            # AVOID a dynamo reset here. We expect guards to have been
+            # added that will be violated with the new shape. We should
+            # see a recompilation (along with a cache miss).
+            res1 = compiled_fn(a, b)
+            # A first call should miss in the cache.
+            # NOTE: Currently, this cache miss is *not* due to guards, but
+            # because the AOTAutogradCache key calculation specializes on input
+            # shapes. Once we allow tensors with symints as part of the cache
+            # key calculation, it will instead miss because of a guard failure.
+            expected_misses += 1
+            expected_saves += 1
+            self.assertEqual(
+                counters["aot_autograd"]["autograd_cache_miss"], expected_misses
+            )
+            self.assertEqual(
+                counters["aot_autograd"]["autograd_cache_hit"], expected_hits
+            )
+            self.assertEqual(
+                counters["aot_autograd"]["autograd_cache_saved"], expected_saves
+            )
+
+            # A second call should hit. (First clear dynamo so in-memory guards
+            # don't prevent recompilation.) Re-running also repopulates dynamo's
+            # guards, so a subsequent run with a different shape still triggers
+            # a call into the autograd cache.
+            self._clear_dynamo_and_codecache()
+            res2 = compiled_fn(a, b)
+            expected_hits += 1
+            self.assertEqual(
+                counters["aot_autograd"]["autograd_cache_miss"], expected_misses
+            )
+            self.assertEqual(
+                counters["aot_autograd"]["autograd_cache_hit"], expected_hits
+            )
+            self.assertEqual(
+                counters["aot_autograd"]["autograd_cache_saved"], expected_saves
+            )
+            self.assertEqual(res1, res2)
+
+
+@inductor_config.patch("fx_graph_cache", True)
 class AOTAutogradCachePicklerTests(torch._dynamo.test_case.TestCase):
     @property
     def device_type(self) -> str:
@@ -57,7 +308,7 @@ def gen_cache_key(self, f, config, inputs=None):
         if inputs is None:
             inputs = [torch.ones(3)]
         _, fx_g, example_inputs = self._get_dynamo_output(f, *inputs)
-        return autograd_cache_hash(fx_g, example_inputs, config)
+        return autograd_cache_key(fx_g, example_inputs, config)
 
     def test_basic_hash_key(self):
         def fn(x):