33import  logging 
44from  typing  import  Any , List , Optional , Sequence 
55
6+ import  tensorrt  as  trt 
67import  torch 
78from  torch .fx .experimental .proxy_tensor  import  unset_fake_temporarily 
89from  torch_tensorrt ._Device  import  Device 
1718from  torch_tensorrt .dynamo .runtime  import  PythonTorchTensorRTModule , TorchTensorRTModule 
1819from  torch_tensorrt .dynamo .utils  import  get_torch_inputs 
1920
20- import  tensorrt  as  trt 
21- 
2221logger  =  logging .getLogger (__name__ )
2322
2423
@@ -131,13 +130,13 @@ def convert_module(
131130    from  torch_tensorrt .dynamo ._refit  import  _refit_single_trt_engine_with_gm 
132131    from  torch_tensorrt .logging  import  TRT_LOGGER 
133132
134-     runtime  =  trt .Runtime (TRT_LOGGER )
135-     refit_test_engine  =  runtime .deserialize_cuda_engine (
136-         interpreter_result .serialized_engine 
137-     )
138133    weight_name_map : Any  =  None 
139134    # Do the test refit with cached map if make_refitable is enabled 
140135    if  settings .make_refitable :
136+         runtime  =  trt .Runtime (TRT_LOGGER )
137+         refit_test_engine  =  runtime .deserialize_cuda_engine (
138+             interpreter_result .serialized_engine 
139+         )
141140        weight_name_map  =  interpreter_result .weight_name_map 
142141        try :
143142            _refit_single_trt_engine_with_gm (
@@ -150,6 +149,9 @@ def convert_module(
150149        except  AssertionError :
151150            logger .warning ("Fast refit test failed. Removing the weight map caching." )
152151
152+         del  refit_test_engine 
153+         torch .cuda .empty_cache ()
154+ 
153155    rt_cls  =  PythonTorchTensorRTModule 
154156
155157    if  ENABLED_FEATURES .torch_tensorrt_runtime  and  not  settings .use_python_runtime :
0 commit comments