Skip to content

Commit 9c93b60

Browse files
committed
Fix bugs in 0.2 (#344)
* Fix some bugs
1 parent 5084747 commit 9c93b60

File tree

11 files changed

+112
-55
lines changed

11 files changed

+112
-55
lines changed

examples/problems/stokes.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,19 @@ def wall(input_, output_):
4949
value = 0.0
5050
return output_.extract(['ux', 'uy']) - value
5151

52+
domains = {
53+
'gamma_top': CartesianDomain({'x': [-2, 2], 'y': 1}),
54+
'gamma_bot': CartesianDomain({'x': [-2, 2], 'y': -1}),
55+
'gamma_out': CartesianDomain({'x': 2, 'y': [-1, 1]}),
56+
'gamma_in': CartesianDomain({'x': -2, 'y': [-1, 1]}),
57+
'D': CartesianDomain({'x': [-2, 2], 'y': [-1, 1]})
58+
}
59+
5260
# problem condition statement
5361
conditions = {
54-
'gamma_top': Condition(location=CartesianDomain({'x': [-2, 2], 'y': 1}), equation=Equation(wall)),
55-
'gamma_bot': Condition(location=CartesianDomain({'x': [-2, 2], 'y': -1}), equation=Equation(wall)),
56-
'gamma_out': Condition(location=CartesianDomain({'x': 2, 'y': [-1, 1]}), equation=Equation(outlet)),
57-
'gamma_in': Condition(location=CartesianDomain({'x': -2, 'y': [-1, 1]}), equation=Equation(inlet)),
58-
'D': Condition(location=CartesianDomain({'x': [-2, 2], 'y': [-1, 1]}), equation=SystemEquation([momentum, continuity]))
62+
'gamma_top': Condition(domain='gamma_top', equation=Equation(wall)),
63+
'gamma_bot': Condition(domain='gamma_bot', equation=Equation(wall)),
64+
'gamma_out': Condition(domain='gamma_out', equation=Equation(outlet)),
65+
'gamma_in': Condition(domain='gamma_in', equation=Equation(inlet)),
66+
'D': Condition(domain='D', equation=SystemEquation([momentum, continuity]))
5967
}

examples/run_stokes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
# create problem and discretise domain
1919
stokes_problem = Stokes()
20-
stokes_problem.discretise_domain(n=1000, locations=['gamma_top', 'gamma_bot', 'gamma_in', 'gamma_out'])
21-
stokes_problem.discretise_domain(n=2000, locations=['D'])
20+
stokes_problem.discretise_domain(n=1000, domains=['gamma_top', 'gamma_bot', 'gamma_in', 'gamma_out'])
21+
stokes_problem.discretise_domain(n=2000, domains=['D'])
2222

2323
# make the model
2424
model = FeedForward(

pina/condition/condition.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,15 @@ def __new__(cls, *args, **kwargs):
8484
return DomainEquationCondition(**kwargs)
8585
else:
8686
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
87-
87+
# TODO: remove, not used anymore
88+
'''
8889
if (
8990
sorted(kwargs.keys()) != sorted(["input_points", "output_points"])
9091
and sorted(kwargs.keys()) != sorted(["location", "equation"])
9192
and sorted(kwargs.keys()) != sorted(["input_points", "equation"])
9293
):
9394
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
94-
95+
# TODO: remove, not used anymore
9596
if not self._dictvalue_isinstance(kwargs, "input_points", LabelTensor):
9697
raise TypeError("`input_points` must be a torch.Tensor.")
9798
if not self._dictvalue_isinstance(kwargs, "output_points", LabelTensor):
@@ -103,3 +104,4 @@ def __new__(cls, *args, **kwargs):
103104
104105
for key, value in kwargs.items():
105106
setattr(self, key, value)
107+
'''

pina/condition/condition_interface.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,7 @@ def residual(self, model):
1515
:param model: The model to evaluate the condition.
1616
:return: The residual of the condition.
1717
"""
18-
pass
18+
pass
19+
20+
def set_problem(self, problem):
21+
self._problem = problem

pina/condition/domain_equation_condition.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,20 @@ def __init__(self, domain, equation):
1515
self.domain = domain
1616
self.equation = equation
1717

18+
def residual(self, model):
19+
"""
20+
Compute the residual of the condition.
21+
"""
22+
self.batch_residual(model, self.domain, self.equation)
23+
1824
@staticmethod
1925
def batch_residual(model, input_pts, equation):
2026
"""
2127
Compute the residual of the condition for a single batch. Input and
2228
output points are provided as arguments.
2329
2430
:param torch.nn.Module model: The model to evaluate the condition.
25-
:param torch.Tensor input_points: The input points.
26-
:param torch.Tensor output_points: The output points.
31+
:param torch.Tensor input_pts: The input points.
32+
:param torch.Tensor equation: The output points.
2733
"""
28-
return equation.residual(model(input_pts))
34+
return equation.residual(input_pts, model(input_pts))

pina/condition/domain_output_condition.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,5 @@ def batch_residual(model, input_points, output_points):
4040
:param torch.Tensor input_points: The input points.
4141
:param torch.Tensor output_points: The output points.
4242
"""
43+
4344
return output_points - model(input_points)

pina/domain/cartesian.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
import torch
23

34
from .domain_interface import DomainInterface
45
from ..label_tensor import LabelTensor

pina/label_tensor.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from torch import Tensor
66

77

8-
98
# class LabelTensor(torch.Tensor):
109
# """Torch tensor with a label for any column."""
1110

@@ -307,13 +306,13 @@
307306
# s = "no labels\n"
308307
# s += super().__str__()
309308
# return s
310-
311309
def issubset(a, b):
312310
"""
313311
Check if a is a subset of b.
314312
"""
315313
return set(a).issubset(set(b))
316314

315+
317316
class LabelTensor(torch.Tensor):
318317
"""Torch tensor with a label for any column."""
319318

@@ -403,6 +402,10 @@ def extract(self, label_to_extract):
403402
return LabelTensor(new_tensor, label_to_extract)
404403

405404
def __str__(self):
405+
"""
406+
returns a string with the representation of the class
407+
"""
408+
406409
s = ''
407410
for key, value in self.labels.items():
408411
s += f"{key}: {value}\n"
@@ -431,4 +434,32 @@ def requires_grad_(self, mode=True):
431434

432435
@property
433436
def dtype(self):
434-
return super().dtype
437+
return super().dtype
438+
439+
440+
def to(self, *args, **kwargs):
441+
"""
442+
Performs Tensor dtype and/or device conversion. For more details, see
443+
:meth:`torch.Tensor.to`.
444+
"""
445+
tmp = super().to(*args, **kwargs)
446+
new = self.__class__.clone(self)
447+
new.data = tmp.data
448+
return new
449+
450+
451+
def clone(self, *args, **kwargs):
452+
"""
453+
Clone the LabelTensor. For more details, see
454+
:meth:`torch.Tensor.clone`.
455+
456+
:return: A copy of the tensor.
457+
:rtype: LabelTensor
458+
"""
459+
# # used before merging
460+
# try:
461+
# out = LabelTensor(super().clone(*args, **kwargs), self.labels)
462+
# except:
463+
# out = super().clone(*args, **kwargs)
464+
out = LabelTensor(super().clone(*args, **kwargs), self.labels)
465+
return out

pina/problem/abstract_problem.py

Lines changed: 41 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,26 @@ class AbstractProblem(metaclass=ABCMeta):
2020

2121
def __init__(self):
2222

23-
2423
self._discretized_domains = {}
2524

2625
for name, domain in self.domains.items():
2726
if isinstance(domain, (torch.Tensor, LabelTensor)):
2827
self._discretized_domains[name] = domain
2928

3029
for condition_name in self.conditions:
31-
self.conditions[condition_name]._problem = self
30+
self.conditions[condition_name].set_problem(self)
31+
3232
# # variable storing all points
33-
# self.input_pts = {}
33+
self.input_pts = {}
3434

3535
# # varible to check if sampling is done. If no location
3636
# # element is presented in Condition this variable is set to true
3737
# self._have_sampled_points = {}
38-
# for condition_name in self.conditions:
39-
# self._have_sampled_points[condition_name] = False
38+
for condition_name in self.conditions:
39+
self._discretized_domains[condition_name] = False
4040

4141
# # put in self.input_pts all the points that we don't need to sample
42-
# self._span_condition_points()
42+
self._span_condition_points()
4343

4444
def __deepcopy__(self, memo):
4545
"""
@@ -125,7 +125,7 @@ def _span_condition_points(self):
125125
if hasattr(condition, "input_points"):
126126
samples = condition.input_points
127127
self.input_pts[condition_name] = samples
128-
self._have_sampled_points[condition_name] = True
128+
self._discretized_domains[condition_name] = True
129129
if hasattr(self, "unknown_parameter_domain"):
130130
# initialize the unknown parameters of the inverse problem given
131131
# the domain the user gives
@@ -141,7 +141,7 @@ def _span_condition_points(self):
141141
)
142142

143143
def discretise_domain(
144-
self, n, mode="random", variables="all", locations="all"
144+
self, n, mode="random", variables="all", domains="all"
145145
):
146146
"""
147147
Generate a set of points to span the `Location` of all the conditions of
@@ -192,31 +192,37 @@ def discretise_domain(
192192
f"should be in {self.input_variables}.",
193193
)
194194

195-
# check consistency location
196-
locations_to_sample = [
197-
condition
198-
for condition in self.conditions
199-
if hasattr(self.conditions[condition], "location")
200-
]
201-
if locations == "all":
202-
# only locations that can be sampled
203-
locations = locations_to_sample
195+
# # check consistency location # TODO: check if this is needed (from 0.1)
196+
# locations_to_sample = [
197+
# condition
198+
# for condition in self.conditions
199+
# if hasattr(self.conditions[condition], "location")
200+
# ]
201+
# if locations == "all":
202+
# # only locations that can be sampled
203+
# locations = locations_to_sample
204+
# else:
205+
# check_consistency(locations, str)
206+
207+
# if sorted(locations) != sorted(locations_to_sample):
208+
if domains == "all":
209+
domains = [condition for condition in self.conditions]
204210
else:
205-
check_consistency(locations, str)
206-
207-
if sorted(locations) != sorted(locations_to_sample):
211+
check_consistency(domains, str)
212+
print(domains)
213+
if sorted(domains) != sorted(self.conditions):
208214
TypeError(
209215
f"Wrong locations for sampling. Location ",
210216
f"should be in {locations_to_sample}.",
211217
)
212218

213219
# sampling
214-
for location in locations:
215-
condition = self.conditions[location]
220+
for d in domains:
221+
condition = self.conditions[d]
216222

217223
# we try to check if we have already sampled
218224
try:
219-
already_sampled = [self.input_pts[location]]
225+
already_sampled = [self.input_pts[d]]
220226
# if we have not sampled, a key error is thrown
221227
except KeyError:
222228
already_sampled = []
@@ -225,25 +231,27 @@ def discretise_domain(
225231
# but we want to sample again we set already_sampled
226232
# to an empty list since we need to sample again, and
227233
# self._have_sampled_points to False.
228-
if self._have_sampled_points[location]:
234+
if self._discretized_domains[d]:
229235
already_sampled = []
230-
self._have_sampled_points[location] = False
231-
236+
self._discretized_domains[d] = False
237+
print(condition.domain)
238+
print(d)
232239
# build samples
233240
samples = [
234-
condition.location.sample(n=n, mode=mode, variables=variables)
241+
self.domains[d].sample(n=n, mode=mode, variables=variables)
235242
] + already_sampled
236243
pts = merge_tensors(samples)
237-
self.input_pts[location] = pts
244+
self.input_pts[d] = pts
238245

239246
# the condition is sampled if input_pts contains all labels
240-
if sorted(self.input_pts[location].labels) == sorted(
247+
if sorted(self.input_pts[d].labels) == sorted(
241248
self.input_variables
242249
):
243-
self._have_sampled_points[location] = True
244-
self.input_pts[location] = self.input_pts[location].extract(
245-
sorted(self.input_variables)
246-
)
250+
# self._have_sampled_points[location] = True
251+
# self.input_pts[location] = self.input_pts[location].extract(
252+
# sorted(self.input_variables)
253+
# )
254+
self._have_sampled_points[d] = True
247255

248256
def add_points(self, new_points):
249257
"""

pina/solvers/supervised.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,6 @@ def training_step(self, batch, batch_idx):
134134
condition = self.problem.conditions[condition_name]
135135
pts = batch.input
136136
out = batch.output
137-
print(out)
138-
print(pts)
139137

140138
if condition_name not in self.problem.conditions:
141139
raise RuntimeError("Something wrong happened.")

0 commit comments

Comments
 (0)