1- import ast
21import copy
32import logging
43import os
4+ import pickle
55import shutil
6+ import sys
67from abc import ABC , abstractmethod
78from typing import Any , Dict , List , Optional , Tuple , cast
89
@@ -50,6 +51,7 @@ def save(
5051 serialized_engine : bytes ,
5152 input_names : List [str ],
5253 output_names : List [str ],
54+ weight_name_map : Optional [Dict [str , Any ]] = None ,
5355 ) -> bool :
5456 """Save the serialized engine to hard disk
5557
@@ -58,21 +60,24 @@ def save(
5860 serialized_engine (bytes): serialized TRT engine
5961 input_names (List[str]): input names of TRT engine
6062 output_names (List[str]): output names of TRT engine
63+ weight_name_map (Optional[Dict[str, Any]]): weight name map for refitting
6164
6265 Returns:
6366 bool: whether the serialized engine is saved successfully
6467 """
6568 pass
6669
6770 @abstractmethod
68- def load (self , hash : str ) -> Tuple [Optional [bytes ], List [str ], List [str ]]:
71+ def load (
72+ self , hash : str
73+ ) -> Tuple [Optional [bytes ], List [str ], List [str ], Optional [Dict [str , Any ]]]:
6974 """Load the serialized engine from hard disk
7075
7176 Args:
7277 hash (str): hash value of the GraphModule
7378
7479 Returns:
75- Sequence[Optional[bytes], List[str], List[str]] : serialized TRT engine, input names of TRT Engine , output names of TRT Engine
80+ Sequence[Optional[bytes], List[str], List[str], Optional[Dict[str, Any]]] : serialized engine, input names, output names, weight name map
7681 """
7782 pass
7883
@@ -89,16 +94,16 @@ def __init__(
8994 self .engine_cache_dir = engine_cache_dir
9095 self .hash2size_map : Dict [str , int ] = {}
9196
92- def has_available_cache_size (self , serialized_engine : bytes ) -> bool :
97+ def has_available_cache_size (self , needed_size : int ) -> bool :
9398 """Check if the cache has available space for saving the serialized engine
9499
95100 Args:
96- serialized_engine (bytes ): serialized TRT engine
101+ needed_size (int ): needed size for erialized TRT engine and/or weight_name_map
97102
98103 Returns:
99104 bool: whether the cache has available size for the serialized engine
100105 """
101- return int ( serialized_engine . nbytes ) <= self .available_engine_cache_size
106+ return needed_size <= self .available_engine_cache_size
102107
103108 def clear_cache (self , needed_min_size : int ) -> bool :
104109 """Clear the cache to make sure at least `needed_min_size` bytes are available, if possible
@@ -154,36 +159,75 @@ def save(
154159 serialized_engine : bytes ,
155160 input_names : List [str ],
156161 output_names : List [str ],
162+ weight_name_map : Optional [Dict [str , Any ]] = None ,
157163 ) -> bool :
158164 serialized_engine_size = int (serialized_engine .nbytes )
165+ if weight_name_map is not None :
166+ serialized_engine_size += sum (
167+ sys .getsizeof (v ) for v in weight_name_map .values ()
168+ )
159169 if serialized_engine_size > self .total_engine_cache_size :
160170 _LOGGER .warning (
161171 f"The serialized engine cannot be saved because the size of the engine { serialized_engine_size } is larger than the total cache size { self .total_engine_cache_size } ."
162172 )
163173 return False
164174
165- # Check if there is enough available cache size for the serialized engine
166- if not self .has_available_cache_size (serialized_engine ):
175+ # Check if there is enough available cache size for the serialized engine and/or weight_name_map
176+ if not self .has_available_cache_size (serialized_engine_size ):
167177 self .clear_cache (serialized_engine_size )
168178
169179 # Save the serialized engine to the cache directory
170- if self .has_available_cache_size (serialized_engine ):
171- path = os .path .join (
172- self .engine_cache_dir ,
173- f"{ hash } /engine--{ input_names } --{ output_names } .trt" ,
180+ if self .has_available_cache_size (serialized_engine_size ):
181+ self .hash2size_map [hash ] = serialized_engine_size
182+ self .available_engine_cache_size -= serialized_engine_size
183+ directory = os .path .join (self .engine_cache_dir , hash )
184+
185+ engine_path = os .path .join (
186+ directory ,
187+ "engine.trt" ,
188+ )
189+ io_names_path = os .path .join (
190+ directory ,
191+ "io_names.pkl" ,
174192 )
175193 try :
176- os .makedirs (os .path .dirname (path ), exist_ok = True )
177- with open (path , "wb" ) as f :
194+ os .makedirs (os .path .dirname (engine_path ), exist_ok = True )
195+ with open (engine_path , "wb" ) as f :
178196 f .write (serialized_engine )
179- self .hash2size_map [hash ] = serialized_engine_size
180- self .available_engine_cache_size -= serialized_engine_size
181- _LOGGER .info (f"A TRT engine was cached to { path } " )
182-
197+ os .makedirs (os .path .dirname (io_names_path ), exist_ok = True )
198+ with open (io_names_path , "wb" ) as f :
199+ pickle .dump (
200+ {"input_names" : input_names , "output_names" : output_names }, f
201+ )
202+ _LOGGER .info (f"The TRT engine was saved to { engine_path } " )
183203 except Exception as e :
184- _LOGGER .warning (f"Failed to save the TRT engine to { path } : { e } " )
204+ del self .hash2size_map [hash ]
205+ self .available_engine_cache_size += serialized_engine_size
206+ shutil .rmtree (directory )
207+ _LOGGER .warning (f"Failed to save the TRT engine to { engine_path } : { e } " )
185208 return False
186209
210+ if weight_name_map is not None :
211+ weight_name_map_path = os .path .join (
212+ directory ,
213+ "weight_name_map.pkl" ,
214+ )
215+ try :
216+ os .makedirs (os .path .dirname (weight_name_map_path ), exist_ok = True )
217+ with open (weight_name_map_path , "wb" ) as f :
218+ pickle .dump (weight_name_map , f )
219+ _LOGGER .info (
220+ f"The weight_name_map was saved to { weight_name_map_path } "
221+ )
222+ except Exception as e :
223+ del self .hash2size_map [hash ]
224+ self .available_engine_cache_size += serialized_engine_size
225+ shutil .rmtree (directory )
226+ _LOGGER .warning (
227+ f"Failed to save the weight_name_map to { weight_name_map_path } : { e } "
228+ )
229+ return False
230+
187231 return True
188232
189233 else :
@@ -192,21 +236,33 @@ def save(
192236 )
193237 return False
194238
195- def load (self , hash : str ) -> Tuple [Optional [bytes ], List [str ], List [str ]]:
239+ def load (
240+ self , hash : str
241+ ) -> Tuple [Optional [bytes ], List [str ], List [str ], Optional [Dict [str , Any ]]]:
196242 directory = os .path .join (self .engine_cache_dir , hash )
197243 if os .path .exists (directory ):
198- engine_list = os .listdir (directory )
199- assert (
200- len (engine_list ) == 1
201- ), f"There are more than one engine { engine_list } under { directory } ."
202- path = os .path .join (directory , engine_list [0 ])
203- input_names_str , output_names_str = (
204- engine_list [0 ].split (".trt" )[0 ].split ("--" )[1 :]
205- )
206- input_names = ast .literal_eval (input_names_str )
207- output_names = ast .literal_eval (output_names_str )
208- with open (path , "rb" ) as f :
209- serialized_engine = f .read ()
210- return serialized_engine , input_names , output_names
244+ # load engine
245+ serialized_engine = None
246+ engine_path = os .path .join (directory , "engine.trt" )
247+ if os .path .exists (engine_path ):
248+ with open (engine_path , "rb" ) as f :
249+ serialized_engine = f .read ()
250+
251+ input_names = []
252+ output_names = []
253+ io_names_path = os .path .join (directory , "io_names.pkl" )
254+ if os .path .exists (io_names_path ):
255+ with open (io_names_path , "rb" ) as f :
256+ io_names = pickle .load (f )
257+ input_names = io_names ["input_names" ]
258+ output_names = io_names ["output_names" ]
259+
260+ # load weight_name_map
261+ weight_name_map = None
262+ weight_name_map_path = os .path .join (directory , "weight_name_map.pkl" )
263+ if os .path .exists (weight_name_map_path ):
264+ with open (weight_name_map_path , "rb" ) as f :
265+ weight_name_map = pickle .load (f )
266+ return serialized_engine , input_names , output_names , weight_name_map
211267 else :
212- return None , [], []
268+ return None , [], [], {}
0 commit comments