11import  copy 
2+ import  io 
23import  logging 
34import  os 
45import  pickle 
6+ import  pickletools 
57import  shutil 
68from  abc  import  ABC , abstractmethod 
7- from  typing  import  Any , Dict , List , Optional , Tuple , cast 
9+ from  typing  import  Any , Dict , List , Optional , Sequence ,  Tuple , cast 
810
911import  torch 
10- from  torch ._inductor .codecache  import  FxGraphCachePickler 
12+ from  torch ._inductor .codecache  import  FxGraphCachePickler ,  sha256_hash 
1113from  torch .fx .experimental .proxy_tensor  import  unset_fake_temporarily 
14+ from  torch_tensorrt ._Input  import  Input 
15+ from  torch_tensorrt .dynamo ._settings  import  (
16+     _SETTINGS_TO_BE_ENGINE_INVARIANT ,
17+     CompilationSettings ,
18+ )
1219
# Module-level logger for engine-cache diagnostics.
_LOGGER: logging.Logger = logging.getLogger(__name__)

# Layout of one fully unpacked engine-cache entry, as produced by
# BaseEngineCache.unpack():
# (serialized_engine, input_names, output_names, input_specs,
#  compilation_settings, weight_name_map).
UnpackedCacheHit = Tuple[
    bytes,
    List[str],
    List[str],
    Sequence[Input],
    CompilationSettings,
    Optional[Dict[str, Any]],
]
1531
1632class  BaseEngineCache (ABC ):
1733
@@ -24,7 +40,11 @@ def __init__(
2440        pass 
2541
2642    @staticmethod  
27-     def  get_hash (gm : torch .fx .GraphModule ) ->  str :
43+     def  get_hash (
44+         gm : torch .fx .GraphModule ,
45+         input_specs : Sequence [Input ],
46+         settings : CompilationSettings ,
47+     ) ->  str :
2848        """Get the hash value of the GraphModule 
2949
3050        Args: 
@@ -39,7 +59,23 @@ def get_hash(gm: torch.fx.GraphModule) -> str:
3959            for  name , param  in  new_gm .named_parameters ():
4060                param .data .zero_ ()
4161
42-             hash_val  =  cast (str , FxGraphCachePickler .get_hash (new_gm ))
62+             graph_hash_val  =  cast (str , FxGraphCachePickler .get_hash (new_gm ))
63+ 
64+         input_spec_strs  =  [str (i ) for  i  in  input_specs ]
65+         with  io .BytesIO () as  stream :
66+             input_specs_data  =  pickle .dumps (input_spec_strs )
67+             input_specs_data  =  pickletools .optimize (input_specs_data )
68+         input_specs_hash  =  sha256_hash (input_specs_data )
69+ 
70+         invariant_engine_specs  =  [
71+             str (getattr (settings , field )) for  field  in  _SETTINGS_TO_BE_ENGINE_INVARIANT 
72+         ]
73+         with  io .BytesIO () as  stream :
74+             engine_specs_data  =  pickle .dumps (invariant_engine_specs )
75+             engine_specs_data  =  pickletools .optimize (engine_specs_data )
76+         engine_specs_hash  =  sha256_hash (engine_specs_data )
77+ 
78+         hash_val : str  =  graph_hash_val  +  input_specs_hash  +  engine_specs_hash 
4379
4480        return  hash_val 
4581
@@ -48,6 +84,8 @@ def pack(
4884        serialized_engine : bytes ,
4985        input_names : List [str ],
5086        output_names : List [str ],
87+         input_specs : Sequence [Input ],
88+         compilation_settings : CompilationSettings ,
5189        weight_name_map : Optional [Dict [Any , Any ]],
5290    ) ->  bytes :
5391        """Pack serialized engine, input names, output names, and weight map into a single blob 
@@ -56,40 +94,83 @@ def pack(
5694            serialized_engine (bytes): serialized TRT engine 
5795            input_names (List[str]): input names of TRT engine 
5896            output_names (List[str]): output names of TRT engine 
97+             input_specs (Sequence[Input]): input specs of TRT engine 
98+             compilation_settings (CompilationSettings): compilation settings of TRT engine 
5999            weight_name_map (Optional[Dict[Any, Any]]): weight name map for refitting 
60100
61101        Returns: 
62102            bytes: packed blob 
63103        """ 
104+ 
105+         settings  =  copy .deepcopy (compilation_settings )
64106        return  pickle .dumps (
65107            {
66108                "serialized_engine" : bytes (serialized_engine ),
67109                "input_names" : input_names ,
68110                "output_names" : output_names ,
111+                 "input_specs" : input_specs ,
112+                 "compilation_settings" : settings ,
69113                "weight_name_map" : weight_name_map ,
70114            }
71115        )
72116
73117    @staticmethod  
74-     def  unpack (
75-         packed_obj : bytes ,
76-     ) ->  Tuple [bytes , List [str ], List [str ], Optional [Dict [Any , Any ]]]:
118+     def  unpack (packed_obj : bytes ) ->  UnpackedCacheHit :
77119        """Unpack packed blob into serialized engine, input names, output names, and weight map 
78120
79121        Args: 
80122            packed_obj (bytes): packed blob 
81123
82124        Returns: 
83-             Tuple[bytes, List[str], List[str], Optional[Dict[str, Any]]]: serialized engine, input names, output names, weight name map 
125+             Tuple[bytes, List[str], List[str], Sequence[Input], CompilationSettings,  Optional[Dict[str, Any]]]: serialized engine, input names, output names, input specs, CompilationSettings , weight name map 
84126        """ 
85127        unpacked  =  pickle .loads (packed_obj )
86128        return  (
87129            unpacked ["serialized_engine" ],
88130            unpacked ["input_names" ],
89131            unpacked ["output_names" ],
132+             unpacked ["input_specs" ],
133+             unpacked ["compilation_settings" ],
90134            unpacked ["weight_name_map" ],
91135        )
92136
def insert(
    self, hash: str, entry: UnpackedCacheHit, *args: Any, **kwargs: Any
) -> None:
    """
    Insert a cache entry into the engine cache.

    Args:
        hash (str): The hash value of the GraphModule.
        entry (UnpackedCacheHit): The cache entry to be inserted.
        *args: Variable length argument list passed to ``save``.
        **kwargs: Arbitrary keyword arguments passed to ``save``.

    Returns:
        None
    """
    # NOTE: the parameter name `hash` shadows the builtin; kept as-is for
    # caller compatibility.
    blob = BaseEngineCache.pack(*entry)
    return self.save(hash, blob, *args, **kwargs)
154+ 
def check(self, hash: str, *args: Any, **kwargs: Any) -> Optional[UnpackedCacheHit]:
    """
    Check if a cache entry exists for the given hash.

    Args:
        hash (str): The hash value of the GraphModule.
        *args: Variable length argument list passed to ``load``.
        **kwargs: Arbitrary keyword arguments passed to ``load``.

    Returns:
        Optional[UnpackedCacheHit]: The unpacked cache entry if found, None otherwise.
    """
    blob = self.load(hash, *args, **kwargs)
    # Truthiness check (not `is None`) matches the original behavior: an
    # empty blob is also treated as a cache miss.
    return BaseEngineCache.unpack(blob) if blob else None
173+ 
93174    @abstractmethod  
94175    def  save (self , hash : str , blob : bytes , * args : Any , ** kwargs : Any ) ->  None :
95176        """Store blob in cache 
@@ -203,11 +284,7 @@ def LRU() -> None:
203284        else :
204285            LRU ()
205286
206-     def  save (
207-         self ,
208-         hash : str ,
209-         blob : bytes ,
210-     ) ->  None :
287+     def  save (self , hash : str , blob : bytes , * args : Any , ** kwargs : Any ) ->  None :
211288        blob_size  =  len (blob )
212289        if  blob_size  >  self .total_engine_cache_size :
213290            _LOGGER .warning (
@@ -244,7 +321,7 @@ def save(
244321                f"The size { blob_size }   is still larger than the available cache size { self .available_engine_cache_size }  ." 
245322            )
246323
247-     def  load (self , hash : str ) ->  Optional [bytes ]:
324+     def  load (self , hash : str ,  * args :  Any ,  ** kwargs :  Any ) ->  Optional [bytes ]:
248325        directory  =  os .path .join (self .engine_cache_dir , hash )
249326        if  os .path .exists (directory ):
250327            blob_path  =  os .path .join (directory , "blob.bin" )
0 commit comments