@@ -422,6 +422,7 @@ def compile(
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
+    offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -666,6 +667,7 @@ def compile(
         "enable_weight_streaming": enable_weight_streaming,
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
+        "offload_module_to_cpu": offload_module_to_cpu,
     }

     settings = CompilationSettings(**compilation_options)
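Because compilation_options is unpacked directly into CompilationSettings, the new dict key only works if the settings object exposes a field of the same name. A minimal sketch of that pattern, for illustration only: the field name comes from this diff, but the dataclass shape and the False default are assumptions, not the library's actual definition.

import dataclasses

@dataclasses.dataclass
class CompilationSettings:
    # Assumed field mirroring the new compile() parameter; the real
    # class carries many more options than shown here.
    offload_module_to_cpu: bool = False

# Dict keys map one-to-one onto dataclass fields, so ** unpacking works.
settings = CompilationSettings(**{"offload_module_to_cpu": True})
assert settings.offload_module_to_cpu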
@@ -677,16 +679,16 @@ def compile(

     gm = exported_program.module()
     # Move the weights in the state_dict to CPU
-    logger.info(
-        "The model is moved to CPU during compilation. If you want to keep the model on GPU, call module.to('cuda') on the model after compilation."
-    )
     logger.debug("Input graph: " + str(gm.graph))

     # Apply lowering on the graph module
     gm = post_lowering(gm, settings)
     logger.debug("Lowered Input graph: " + str(gm.graph))
-
-    exported_program.module().to(CPU_DEVICE)
+    if offload_module_to_cpu:
+        exported_program.module().to(CPU_DEVICE)
+        logger.info(
+            "The model is offloaded to CPU during compilation. If you want to keep the model on GPU, set offload_module_to_cpu=False."
+        )
     trt_gm = compile_module(
         gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
     )
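Taken together, the hunks turn the unconditional CPU offload into an opt-in flag that frees GPU memory while TensorRT builds the engine. A hedged usage sketch follows; it assumes a CUDA device and a Torch-TensorRT build containing this change, the arg_inputs keyword mirrors the trt_arg_inputs naming visible above, and the toy model is illustrative only.

import torch
import torch_tensorrt

# Toy network; any model exportable via torch.export works the same way.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
).eval().cuda()

inputs = (torch.randn(8, 64, device="cuda"),)
exported_program = torch.export.export(model, inputs)

# offload_module_to_cpu=True moves the source module's weights to CPU
# while the TensorRT engine is built, lowering peak GPU memory; pass
# False to keep the module resident on the GPU during compilation.
trt_gm = torch_tensorrt.dynamo.compile(
    exported_program,
    arg_inputs=inputs,
    offload_module_to_cpu=True,
)

print(trt_gm(*inputs))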