@@ -678,17 +678,24 @@ def compile(
     )
 
     gm = exported_program.module()
-    # Move the weights in the state_dict to CPU
     logger.debug("Input graph: " + str(gm.graph))
 
     # Apply lowering on the graph module
     gm = post_lowering(gm, settings)
     logger.debug("Lowered Input graph: " + str(gm.graph))
+
+    # Move the weights in the state_dict to CPU
     if offload_module_to_cpu:
         exported_program.module().to(CPU_DEVICE)
         logger.info(
-            "The model is offloaded to CPU during compilation. If you want to keep the model on GPU, set offload_module_to_cpu=False. "
+            "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False."
         )
+    else:
+        remaining_memory, total_memory = torch.cuda.mem_get_info()
+        if remaining_memory < total_memory // 2:
+            logger.warning(
+                "The remaining GPU memory may not be enough to compile the model, which can lead to an OOM error. Consider setting offload_module_to_cpu=True."
+            )
     trt_gm = compile_module(
         gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
     )
@@ -833,7 +840,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
                 str(name),
                 str(submodule.graph),
             )
-            submodule.to(torch.cuda.current_device())
+            submodule.to(to_torch_device(settings.device))
             continue
 
         if name not in submodule_node_dict:
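The second hunk stops hard-coding the fallback device: `torch.cuda.current_device()` returns a bare device index for whichever device happens to be current, while `to_torch_device(settings.device)` honors the device the user compiled for. A small sketch of that normalization; the import path is taken from torch_tensorrt's dynamo utilities and should be treated as an assumption:

```python
import torch
from torch_tensorrt.dynamo.utils import to_torch_device  # assumed import path

# to_torch_device normalizes the accepted device specs (str, torch.device,
# torch_tensorrt.Device) into a torch.device, so the non-accelerated
# submodule lands on the device the user selected, not the "current" one.
device = to_torch_device("cuda:0")
print(device)  # a torch.device, e.g. cuda:0
```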