44from typing import Any , Callable , Dict , List , NamedTuple , Optional , Sequence , Set
55
66import numpy as np
7+ import tensorrt as trt
78import torch
89import torch .fx
910from torch .fx .node import _get_qualified_name
2526from torch_tensorrt .fx .observer import Observer
2627from torch_tensorrt .logging import TRT_LOGGER
2728
28- import tensorrt as trt
2929from packaging import version
3030
3131_LOGGER : logging .Logger = logging .getLogger (__name__ )
@@ -313,8 +313,10 @@ def run(
313313 )
314314 timing_cache = self ._create_timing_cache (builder_config , existing_cache )
315315
316- engine = self .builder .build_serialized_network (self .ctx .net , builder_config )
317- assert engine
316+ serialized_engine = self .builder .build_serialized_network (
317+ self .ctx .net , builder_config
318+ )
319+ assert serialized_engine
318320
319321 serialized_cache = (
320322 bytearray (timing_cache .serialize ())
@@ -324,10 +326,10 @@ def run(
324326 _LOGGER .info (
325327 f"Build TRT engine elapsed time: { datetime .now () - build_engine_start_time } "
326328 )
327- _LOGGER .info (f"TRT Engine uses: { engine .nbytes } bytes of Memory" )
329+ _LOGGER .info (f"TRT Engine uses: { serialized_engine .nbytes } bytes of Memory" )
328330
329331 return TRTInterpreterResult (
330- engine , self ._input_names , self ._output_names , serialized_cache
332+ serialized_engine , self ._input_names , self ._output_names , serialized_cache
331333 )
332334
333335 def run_node (self , n : torch .fx .Node ) -> torch .fx .Node :
0 commit comments