
Commit abc3eec

jamesjwu authored and pytorchmergebot committed
First version of AOTAutogradCache (pytorch#126791)
This PR implements "V0" of AOTAutogradCache. Given an input to AOTAutograd, we calculate a cache key, then save an AOTAutogradCacheEntry. Each AOTAutogradCacheEntry has:
- A CompiledForward and optionally a CompiledBackward
- A bunch of metadata

CompiledForward and CompiledBackward each save the *key* to the FXGraphCache entry associated with the compiled object. FXGraphCache populates this key field as long as it's able to return a compiled graph given a set of inputs. We then load the same object from FXGraphCache on an AOTAutogradCache hit.

On cache miss:
- Run AOTAutograd, up to AOTAutogradDispatch.post_compile.
- Save an AOTAutogradCacheEntry to the cache after compiling the necessary portions and receiving a cache key from FXGraphCache. Here we *always* compile the backward ahead of time; the PR stacked on top of this one implements lazy backward caching, so that we only save to the cache after compiling the backward in a lazy-backward scenario.
- Return the resulting object.

On cache hit:
- Run AOTAutogradCacheEntry.post_compile() on the cache key.
- This attempts to load the forward and backward graphs from FXGraphCache.
- As long as we successfully load from FXGraphCache, it's a hit. We then rewrap the callable with post-compile wrappers using our saved metadata.

For now, we ignore the fakified-out and debug wrappers, and we only save to the cache if fakified out is turned off.

V0 guards behavior: FXGraphCache serializes the guards that are needed in the shape_env based on the symint inputs to the graph. The invariant AOTAutograd relies on here is that the sources of the symints given to it by dynamo are exactly the same as the ones it passes to inductor, for both the forward and backward passes. (This does *not* mean that the tensor values passed in are the same, only that their symints are.) That is, AOTAutograd and inductor never create new guards based on symints with *different sources* than those dynamo passed in. We don't currently store any AOTAutograd-specific guards: my hypothesis is that FXGraphCache already stores these, since any guards generated by AOTAutograd should already be in the shape_env before calling into inductor, and we don't generate new guards post-inductor. If this turns out to be needed, I'll add it in another diff.

Testing: we start with some basic unit tests; I'll be adding more complicated testing as the next step.

Pull Request resolved: pytorch#126791
Approved by: https://github.com/bdhirsh
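To make the hit/miss flow above concrete, here is a minimal, illustrative sketch of the entry structure and lookup path. This is not the code added by this PR: the dict-based caches, lookup_or_compile, and compile_fw_bw are hypothetical stand-ins, and the real implementation lives in torch._functorch._aot_autograd.autograd_cache and torch._inductor.codecache.

# Illustrative sketch only -- the names and signatures below are hypothetical,
# not the actual PyTorch internals.
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Tuple


@dataclass
class CompiledForward:
    fx_graph_cache_key: str  # key into FXGraphCache, not the compiled artifact itself


@dataclass
class CompiledBackward:
    fx_graph_cache_key: str


@dataclass
class AOTAutogradCacheEntry:
    compiled_fw: CompiledForward
    compiled_bw: Optional[CompiledBackward]
    metadata: dict  # whatever post_compile needs to rewrap the callable


def lookup_or_compile(
    cache: Dict[str, AOTAutogradCacheEntry],
    fx_graph_cache: Dict[str, Callable],
    key: str,
    compile_fw_bw: Callable[[], Tuple[str, Optional[str], dict, Callable]],
) -> Callable:
    entry = cache.get(key)
    if entry is not None:
        # Cache hit: reload the forward (and backward) from FXGraphCache by the
        # saved keys; it only counts as a hit if every load succeeds.
        fw = fx_graph_cache.get(entry.compiled_fw.fx_graph_cache_key)
        bw_needed = entry.compiled_bw is not None
        bw = fx_graph_cache.get(entry.compiled_bw.fx_graph_cache_key) if bw_needed else None
        if fw is not None and (not bw_needed or bw is not None):
            # Real code would rewrap fw/bw with the post-compile wrappers here,
            # using entry.metadata.
            return fw
    # Cache miss: compile forward (and backward), then save an entry that records
    # the FXGraphCache keys handed back by the compiler.
    fw_key, bw_key, metadata, compiled_fn = compile_fw_bw()
    cache[key] = AOTAutogradCacheEntry(
        CompiledForward(fw_key),
        CompiledBackward(bw_key) if bw_key is not None else None,
        metadata,
    )
    return compiled_fn

Storing only FXGraphCache keys, rather than the compiled artifacts themselves, keeps AOTAutogradCache a thin index: FXGraphCache stays the single owner of the compiled graphs and their serialized guards, which is why a hit is only counted once every per-graph load succeeds.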
1 parent 2e065f2 commit abc3eec

File tree

9 files changed: +741, -24 lines


test/dynamo/test_aot_autograd_cache.py

Lines changed: 255 additions & 4 deletions
@@ -1,19 +1,270 @@
 # Owner(s): ["module: dynamo"]
 
+import os
+import unittest
+
 import torch
 import torch._dynamo
 import torch._dynamo.test_case
 
 import torch._functorch._aot_autograd
+from torch._dynamo.utils import counters
 from torch._functorch import config as functorch_config
 from torch._functorch._aot_autograd.autograd_cache import (
-    autograd_cache_hash,
+    AOTAutogradCache,
+    autograd_cache_key,
     BypassAOTAutogradCache,
 )
 from torch._functorch._aot_autograd.schemas import AOTConfig
 from torch._inductor import config as inductor_config
-
-
+from torch.testing._internal.common_cuda import SM80OrLater
+from torch.testing._internal.common_device_type import largeTensorTest
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+)
+from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
+
+
+@instantiate_parametrized_tests
+class AOTAutogradCacheTests(torch._dynamo.test_case.TestCase):
+    def setUp(self):
+        """
+        Reset all counters and caches before each unit test
+        """
+        super().setUp()
+        counters.clear()
+        self._clear_all_caches()
+
+    def _clear_all_caches(self):
+        """
+        Clear every cache, including AOTAutogradCache and FXCache
+        """
+        torch._inductor.codecache.FxGraphCache.clear()
+        AOTAutogradCache.clear()
+        self._clear_dynamo_and_codecache()
+
+    def _clear_dynamo_and_codecache(self):
+        """
+        Clear unrelated caches, like dynamo and PyCodeCache
+        """
+        torch._dynamo.reset()
+        for m in torch._inductor.codecache.PyCodeCache.cache.values():
+            os.remove(m.__file__)
+        torch._inductor.codecache.PyCodeCache.cache_clear()
+
+    @inductor_config.patch("fx_graph_cache", True)
+    @functorch_config.patch({"enable_autograd_cache": True})
+    def test_basic(self):
+        """
+        Verify the interactions between FXGraphCache and AOTAutogradCache.
+        """
+
+        def fn(x, y):
+            return (x * 2, y @ y)
+
+        a = torch.rand(25)
+        b = torch.rand(5, 5)
+
+        compiled_fn = torch.compile(fn, backend="inductor")
+
+        # A first call should miss in the cache.
+        self.assertEqual(fn(a, b), compiled_fn(a, b))
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+
+        # A second call should hit. (First reset so in-memory guards
+        # don't prevent compilation).
+        self._clear_dynamo_and_codecache()
+        self.assertEqual(fn(a, b), compiled_fn(a, b))
+
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+
+    @inductor_config.patch("fx_graph_cache", True)
+    @functorch_config.patch({"enable_autograd_cache": True})
+    def test_clear_fx_graph_cache(self):
+        """
+        Verify the interactions between FXGraphCache and AOTAutogradCache.
+        """
+
+        def fn(x, y):
+            return (x * 2, y @ y)
+
+        a = torch.rand(25)
+        b = torch.rand(5, 5)
+
+        compiled_fn = torch.compile(fn, backend="inductor")
+
+        # A first call should miss in the cache.
+        self.assertEqual(fn(a, b), compiled_fn(a, b))
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+
+        # Clear FX graph cache: second call should also be a miss
+        self._clear_dynamo_and_codecache()
+        torch._inductor.codecache.FxGraphCache.clear()
+        self.assertEqual(fn(a, b), compiled_fn(a, b))
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        # We save again into the cache
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2)
+
+    @inductor_config.patch("fx_graph_cache", False)
+    @functorch_config.patch({"enable_autograd_cache": True})
+    def test_fx_graph_cache_off(self):
+        """
+        Should not use cache if FXGraphCache is not enabled
+        """
+
+        def fn(x, y):
+            return (x * 2, y @ y)
+
+        a = torch.rand(25)
+        b = torch.rand(5, 5)
+
+        compiled_fn = torch.compile(fn, backend="inductor")
+
+        # A first call should miss in the cache.
+        self.assertEqual(fn(a, b), compiled_fn(a, b))
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)
+
+        # Clear FX graph cache: second call should also be a miss
+        self._clear_dynamo_and_codecache()
+
+        self.assertEqual(fn(a, b), compiled_fn(a, b))
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 2)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)
+
+    @inductor_config.patch("fx_graph_cache", True)
+    @functorch_config.patch({"enable_autograd_cache": True})
+    def test_autograd_function(self):
+        """
+        Tests autograd cache hits
+        """
+
+        def fn(a, b):
+            return a.sin() + b
+
+        a = torch.randn(25, requires_grad=True)
+        b = torch.randn(25, requires_grad=True)
+        a2 = a.detach().clone().requires_grad_(True)
+        b2 = b.detach().clone().requires_grad_(True)
+
+        compiled_fn = torch.compile(fn, backend="inductor")
+
+        # A first call should miss in the cache.
+        self.assertEqual(fn(a, b), compiled_fn(a2, b2))
+        fn(a, b).sum().backward()
+        compiled_fn(a2, b2).sum().backward()
+        self.assertEqual(a.grad, a2.grad)
+        self.assertEqual(b.grad, b2.grad)
+
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+
+        # Reset all tensors
+        a = torch.randn(25, requires_grad=True)
+        b = torch.randn(25, requires_grad=True)
+        a2 = a.detach().clone().requires_grad_(True)
+        b2 = b.detach().clone().requires_grad_(True)
+
+        # A second call should hit. (First reset so in-memory guards
+        # don't prevent compilation).
+        self._clear_dynamo_and_codecache()
+        self.assertEqual(fn(a, b), compiled_fn(a2, b2))
+        fn(a, b).sum().backward()
+        compiled_fn(a2, b2).sum().backward()
+        self.assertEqual(a.grad, a2.grad)
+        self.assertEqual(b.grad, b2.grad)
+
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+
+    @largeTensorTest("64GB", device=GPU_TYPE)
+    @parametrize("device", (GPU_TYPE,))
+    @parametrize("dtype", (torch.float16, torch.bfloat16))
+    @inductor_config.patch("fx_graph_cache", True)
+    @functorch_config.patch({"enable_autograd_cache": True})
+    def test_autograd_inductor_guards(self, device, dtype):
+        """
+        Tests that functions that would add inductor guards are cached properly
+        """
+        if device == GPU_TYPE and not HAS_GPU:
+            raise unittest.SkipTest(f"requires {GPU_TYPE}")
+        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
+            raise unittest.SkipTest("requires CUDA SM80 or later")
+
+        def fn(x, y):
+            return (x + x, y + y)
+
+        compiled_fn = torch.compile(fn, dynamic=True)
+
+        # Iterate over different shapes, varying whether the total
+        # size is below or above int32. For each combination, we expect
+        # different guards around whether the symbolic sizes do or do
+        # not exceed int32.
+        shapes = (
+            ((5, 6), (7, 8)),
+            ((5, 6), (47000, 47001)),
+            ((47000, 47001), (5, 6)),
+        )
+        expected_hits = expected_misses = expected_saves = 0
+        for a_shape, b_shape in shapes:
+            a = torch.rand(a_shape, device=device, dtype=dtype)
+            b = torch.rand(b_shape, device=device, dtype=dtype)
+
+            # AVOID a dynamo reset here. We expect guards to have been
+            # added that will be violated with the new shape. We should
+            # see a recompilation (along with a cache miss).
+            res1 = compiled_fn(a, b)
+            # A first call should miss in the cache.
+            # NOTE: Currently, this cache miss is *not* due to guards,
+            # but instead because the AOTAutogradCache key calculation specializes on input shapes.
+            # Once we allow tensors with symints as part of the cache key calculation, it will
+            # instead cache miss because of guard failure.
+            expected_misses += 1
+            expected_saves += 1
+            self.assertEqual(
+                counters["aot_autograd"]["autograd_cache_miss"], expected_misses
+            )
+            self.assertEqual(
+                counters["aot_autograd"]["autograd_cache_hit"], expected_hits
+            )
+            self.assertEqual(
+                counters["aot_autograd"]["autograd_cache_saved"], expected_saves
+            )
+
+            # A second call should hit. (First reset so in-memory guards
+            # don't prevent compilation).
+
+            # Now clear dynamo and we should see a cache hit
+            # This should populate guards to dynamo's cache, so that a subsequent run with a different
+            # shape will still trigger a second call to autograd_cache.
+            self._clear_dynamo_and_codecache()
+            res2 = compiled_fn(a, b)
+            expected_hits += 1
+            self.assertEqual(
+                counters["aot_autograd"]["autograd_cache_miss"], expected_misses
+            )
+            self.assertEqual(
+                counters["aot_autograd"]["autograd_cache_hit"], expected_hits
+            )
+            self.assertEqual(
+                counters["aot_autograd"]["autograd_cache_saved"], expected_saves
+            )
+            self.assertEqual(res1, res2)
+
+
+@inductor_config.patch("fx_graph_cache", True)
 class AOTAutogradCachePicklerTests(torch._dynamo.test_case.TestCase):
     @property
     def device_type(self) -> str:
@@ -57,7 +308,7 @@ def gen_cache_key(self, f, config, inputs=None):
         if inputs is None:
             inputs = [torch.ones(3)]
         _, fx_g, example_inputs = self._get_dynamo_output(f, *inputs)
-        return autograd_cache_hash(fx_g, example_inputs, config)
+        return autograd_cache_key(fx_g, example_inputs, config)
 
     def test_basic_hash_key(self):
         def fn(x):
