Skip to content

Commit f32283f

Browse files
annaivagnesDario Coscia
authored andcommitted
fix GPU training in inverse problem (#283)
1 parent 0191442 commit f32283f

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

pina/trainer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
""" Trainer module. """
22

3+
import torch
34
import pytorch_lightning
45
from .utils import check_consistency
56
from .dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
@@ -63,6 +64,12 @@ def _create_or_update_loader(self):
6364
self._loader = SamplePointLoader(
6465
dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True
6566
)
67+
pb = self._model.problem
68+
if hasattr(pb, "unknown_parameters"):
69+
for key in pb.unknown_parameters:
70+
pb.unknown_parameters[key] = torch.nn.Parameter(pb.unknown_parameters[key].data.to(device))
71+
72+
6673

6774
def train(self, **kwargs):
6875
"""

0 commit comments

Comments
 (0)