diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 1ca088ebbc720..cf17a84ce175e 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -420,10 +420,16 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode:
 
         # single GPU data transfer
         if self.single_gpu:
-            # for single GPU put inputs on gpu manually
-            root_gpu = 0
+            if isinstance(self.data_parallel_device_ids, list):
                 root_gpu = self.data_parallel_device_ids[0]
+                root_device = (torch.device("cuda", root_gpu)
+                               if root_gpu else torch.device("cpu"))
+                torch.cuda.set_device(root_device)
+            else:
+                raise RuntimeError(
+                    'Expected `data_parallel_device_ids` as a list, cannot determine root gpu.'
+                )
             batch = self.transfer_batch_to_gpu(batch, root_gpu)
 
             args[0] = batch
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 8c5a906fc3573..3b2219ec75ff4 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -365,6 +365,9 @@ def __init__(
         self.gpus = gpus
         self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
         self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
+        root_device = (torch.device("cuda", self.root_gpu)
+                       if self.root_gpu else torch.device("cpu"))
+        torch.cuda.set_device(root_device)
 
         # tpu state flags
         self.use_tpu = False
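The snippet below restates the device-selection expression that both hunks introduce, as a self-contained sketch assuming only that `torch` is installed. `pick_root_device` is a hypothetical helper written for illustration; the patch itself inlines the same expression in `Trainer.__init__` and in the evaluation loop.

# Minimal sketch of the device-selection logic added by this patch.
# `pick_root_device` is a hypothetical helper used for illustration only.
import torch


def pick_root_device(data_parallel_device_ids):
    """Pick the first parsed GPU id as the root device, mirroring the patch."""
    if not isinstance(data_parallel_device_ids, list):
        raise RuntimeError(
            'Expected `data_parallel_device_ids` as a list, cannot determine root gpu.'
        )
    root_gpu = data_parallel_device_ids[0]
    # Same conditional as in the diff: a root gpu id of 0 is falsy in Python,
    # so this expression falls back to the CPU device for gpu 0.
    return (torch.device("cuda", root_gpu)
            if root_gpu else torch.device("cpu"))


if __name__ == "__main__":
    print(pick_root_device([1]))  # cuda:1
    print(pick_root_device([0]))  # cpu (because 0 is falsy)

Once a `cuda` device is selected, `torch.cuda.set_device(root_device)` makes it the current CUDA device for the process, so subsequent allocations that do not specify a device land on the requested GPU rather than GPU 0.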