We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0191442 commit f32283fCopy full SHA for f32283f
pina/trainer.py
@@ -1,5 +1,6 @@
1
""" Trainer module. """
2
3
+import torch
4
import pytorch_lightning
5
from .utils import check_consistency
6
from .dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
@@ -63,6 +64,12 @@ def _create_or_update_loader(self):
63
64
self._loader = SamplePointLoader(
65
dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True
66
)
67
+ pb = self._model.problem
68
+ if hasattr(pb, "unknown_parameters"):
69
+ for key in pb.unknown_parameters:
70
+ pb.unknown_parameters[key] = torch.nn.Parameter(pb.unknown_parameters[key].data.to(device))
71
+
72
73
74
def train(self, **kwargs):
75
"""
0 commit comments