 logger = init_logger(__name__)


+@dataclasses.dataclass
+class InductorArtifact:
+    hash_str: str = ""
+    file_path: str = ""
+
+
 class InductorHashCache:
     """
     Disk format: a Python list of tuples, each tuple is
-    (runtime_shape, graph_index, hash_str)
+    (runtime_shape, graph_index, hash_str, file_path)
     We use list of tuple for readability.

     In-memory format: a defaultdict of dict, where the key is
-    runtime_shape, and the value is a dict of graph_index to hash_str.
+    runtime_shape, and the value is a dict of graph_index to InductorArtifact.

-    The data is essentially `Dict[Optional[int], Dict[int, str]]`,
+    The data is essentially `Dict[Optional[int], Dict[int, InductorArtifact]]`,
     we don't use json here because json doesn't support int as key.

     TODO: better off-the-shelf solution to serialize the data?
     """

     def __init__(self, cache_dir: str, disabled: bool = False):
-        self.cache: defaultdict = defaultdict(dict)
+        self.cache: Dict[Optional[int],
+                         Dict[int, InductorArtifact]] = defaultdict(dict)
         self.disabled = disabled
         self.cache_dir = cache_dir
         self.cache_file_path = os.path.join(cache_dir,
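
A rough sketch of the in-memory layout the docstring describes (hash values and paths below are invented for illustration):

    import dataclasses
    from collections import defaultdict

    @dataclasses.dataclass
    class InductorArtifact:  # mirrors the dataclass added above
        hash_str: str = ""
        file_path: str = ""

    # runtime_shape (None = general shape) -> {graph_index: InductorArtifact}
    cache = defaultdict(dict)
    cache[None][0] = InductorArtifact(hash_str="fabc123")
    cache[16][0] = InductorArtifact(hash_str="fdef456",
                                    file_path="/tmp/torchinductor/xy/code.py")
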
@@ -66,14 +73,25 @@ def deserialize(self, data: str):
         # because it is a safe way to parse Python literals.
         # do not use eval(), it is unsafe.
         list_data = ast.literal_eval(data)
-        for runtime_shape, graph_index, hash_str in list_data:
-            self.cache[runtime_shape][graph_index] = hash_str
+        for item in list_data:
+            runtime_shape = item[0]
+            graph_index = item[1]
+            hash_str = item[2]
+            # for compatibility with the old format,
+            # which does not have file_path.
+            # NOTE: after running the new code, the file_path
+            # will be updated.
+            file_path = "" if len(item) == 3 else item[3]
+            self.cache[runtime_shape][graph_index] = InductorArtifact(
+                hash_str=hash_str, file_path=file_path)

     def serialize(self) -> str:
         data = []
-        for runtime_shape, graph_index_to_hash_str in self.cache.items():
-            for graph_index, hash_str in graph_index_to_hash_str.items():
-                data.append((runtime_shape, graph_index, hash_str))
+        for runtime_shape, value in self.cache.items():
+            for graph_index, inductor_artifact in value.items():
+                data.append(
+                    (runtime_shape, graph_index, inductor_artifact.hash_str,
+                     inductor_artifact.file_path))
         printer = pprint.PrettyPrinter(indent=4)
         return printer.pformat(data)

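
A hedged round-trip illustration of the disk handling above, including how an entry in the old 3-tuple format still parses (all values invented):

    import ast

    # serialize() pformat's a list of tuples; older files may lack file_path
    data = "[(None, 0, 'fabc123', '/tmp/code.py'), (16, 0, 'fdef456')]"
    # ast.literal_eval only parses literals, so untrusted text cannot run code
    for item in ast.literal_eval(data):
        runtime_shape, graph_index, hash_str = item[0], item[1], item[2]
        file_path = "" if len(item) == 3 else item[3]
        print(runtime_shape, graph_index, hash_str, repr(file_path))
    # None 0 fabc123 '/tmp/code.py'
    # 16 0 fdef456 ''
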
@@ -90,13 +108,14 @@ def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
         return runtime_shape in self.cache and graph_index in self.cache[
             runtime_shape]

-    def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
+    def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact:
         if self.disabled:
             raise KeyError("cannot read from disabled cache")
         runtime_shape, graph_index = key
         return self.cache[runtime_shape][graph_index]

-    def __setitem__(self, key: Tuple[Optional[int], int], value: str):
+    def __setitem__(self, key: Tuple[Optional[int], int],
+                    value: InductorArtifact):
         # setitem for disabled cache is fine, because we
         # don't actually write to the disk
         runtime_shape, graph_index = key
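
With these dunder methods the cache behaves like a mapping keyed by (runtime_shape, graph_index) tuples; roughly, given some populated instance `cache_data` (a hypothetical usage, not code from the patch):

    key = (None, 0)  # (runtime_shape, graph_index)
    if key in cache_data:                     # __contains__
        artifact = cache_data[key]            # __getitem__ -> InductorArtifact
        print(artifact.hash_str, artifact.file_path)
    cache_data[key] = InductorArtifact(hash_str="fabc123")  # __setitem__
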
@@ -181,7 +200,8 @@ def wrap_inductor(graph: fx.GraphModule,
     if (runtime_shape, graph_index) in cache_data:
         # we compiled this graph before
         # so we can directly lookup the compiled graph via hash
-        hash_str = cache_data[(runtime_shape, graph_index)]
+        inductor_artifact = cache_data[(runtime_shape, graph_index)]
+        hash_str = inductor_artifact.hash_str
         if graph_index == 0:
             # adds some info logging for the first graph
             logger.info(
@@ -199,6 +219,7 @@ def wrap_inductor(graph: fx.GraphModule,
199219 "Inductor cache lookup failed. Please remove"
200220 f"the cache file { cache_data .cache_file_path } and try again." # noqa
201221 )
222+ inductor_artifact .file_path = inductor_compiled_graph .current_callable .__code__ .co_filename # noqa
202223
203224 # Inductor calling convention (function signature):
204225 # f(list) -> tuple
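
The file_path recorded above is just the standard CPython code-object attribute; for an Inductor-compiled callable it points at the generated source file in the Inductor cache directory. A tiny self-contained illustration:

    def f():
        pass

    # co_filename is the source file the function's code was compiled from
    print(f.__code__.co_filename)
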
@@ -224,19 +245,20 @@ def compiled_graph(*args):
         # the assumption is that we don't have nested Inductor compilation.
         # compiled_fx_graph_hash will only be called once, and we can hook
         # it to get the hash of the compiled graph directly.
-        from torch._inductor.codecache import compiled_fx_graph_hash
+
+        inductor_artifact = InductorArtifact()
+        from torch._inductor.codecache import (FxGraphCache,
+                                               compiled_fx_graph_hash)
+        original_load = FxGraphCache.load
+
+        def hijack_load(*args, **kwargs):
+            inductor_compiled_graph = original_load(*args, **kwargs)
+            inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa
+            return inductor_compiled_graph

         def hijack_compiled_fx_graph_hash(*args, **kwargs):
             out = compiled_fx_graph_hash(*args, **kwargs)
-            # store the hash in the cache
-            nonlocal cache_data
-            cache_data[(runtime_shape, graph_index)] = out[0]
-            if graph_index == 0:
-                # adds some info logging for the first graph
-                logger.info("Cache the graph of shape %s for later use",
-                            str(runtime_shape))
-            logger.debug("store the %s-th graph for shape %s via hash %s",
-                         graph_index, str(runtime_shape), out[0])
+            inductor_artifact.hash_str = out[0]
             return out

         def _check_can_cache(*args, **kwargs):
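
Both hijack_* helpers follow the same wrapper-closure pattern: call through to the original function and record a side effect on the shared inductor_artifact. In isolation, with toy stand-ins (all names here hypothetical):

    class Artifact:
        value = None

    artifact = Artifact()

    def original(x):
        return x * 2

    def hijack(*args, **kwargs):
        out = original(*args, **kwargs)  # delegate to the real function
        artifact.value = out             # record the result as a side effect
        return out                       # callers see unchanged behavior

    assert hijack(21) == 42 and artifact.value == 42
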
@@ -255,6 +277,11 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
             if not cache_data.disabled:
                 # compilation cache is enabled, patch several functions

+                # hijack to get the compiled graph itself
+                stack.enter_context(
+                    patch("torch._inductor.codecache.FxGraphCache.load",
+                          hijack_load))
+
                 # for hijacking the hash of the compiled graph
                 stack.enter_context(
                     patch("torch._inductor.codecache.compiled_fx_graph_hash",
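
`stack.enter_context(patch(...))` keeps each monkeypatch active until the surrounding ExitStack unwinds, so the hooks live only for this one compilation. The general pattern, with a toy target instead of the Inductor internals:

    import contextlib
    import math
    from unittest.mock import patch

    with contextlib.ExitStack() as stack:
        stack.enter_context(patch("math.sqrt", lambda x: -1.0))
        assert math.sqrt(4) == -1.0  # patched while the stack is open
    assert math.sqrt(4) == 2.0       # automatically restored on exit
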
@@ -275,7 +302,16 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
             compiled_graph = compile_fx(graph,
                                         example_inputs,
                                         config_patches=current_config)
-
+        # store the inductor_artifact in the cache
+        cache_data[(runtime_shape, graph_index)] = inductor_artifact
+        if graph_index == 0:
+            # adds some info logging for the first graph
+            logger.info("Cache the graph of shape %s for later use",
+                        str(runtime_shape))
+        logger.debug(
+            "store the %s-th graph for shape %s via hash %s from file %s",
+            graph_index, str(runtime_shape), inductor_artifact.hash_str,
+            inductor_artifact.file_path)
     # after compiling the last graph, record the end time
     if graph_index == num_graphs - 1:
         now = time.time()