Skip to content

Commit c3bd828

Browse files
add value and add equation_factory tests
1 parent ef75f13 commit c3bd828

File tree

5 files changed

+88
-11
lines changed

5 files changed

+88
-11
lines changed

docs/source/_rst/equation/equation_factory.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@ Equation Factory
1414
:members:
1515
:show-inheritance:
1616

17-
.. autoclass:: Laplace
17+
.. autoclass:: FixedLaplacian
1818
:members:
1919
:show-inheritance:

pina/equation/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,14 @@
66
"FixedValue",
77
"FixedGradient",
88
"FixedFlux",
9-
"Laplace",
9+
"FixedLaplacian",
1010
]
1111

1212
from .equation import Equation
13-
from .equation_factory import FixedFlux, FixedGradient, Laplace, FixedValue
13+
from .equation_factory import (
14+
FixedFlux,
15+
FixedGradient,
16+
FixedLaplacian,
17+
FixedValue,
18+
)
1419
from .system_equation import SystemEquation

pina/equation/equation_factory.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,16 @@ def equation(input_, output_):
110110
super().__init__(equation)
111111

112112

113-
class Laplace(Equation):
113+
class FixedLaplacian(Equation):
114114
"""
115-
Equation to enforce a null laplacian for a specific condition.
115+
Equation to enforce a fixed laplacian for a specific condition.
116116
"""
117117

118-
def __init__(self, components=None, d=None):
118+
def __init__(self, value, components=None, d=None):
119119
"""
120-
Initialization of the :class:`Laplace` class.
120+
Initialization of the :class:`FixedLaplacian` class.
121121
122+
:param float value: The fixed value to be enforced to the laplacian.
122123
:param list[str] components: The name of the output variables for which
123124
the null laplace condition is applied. It should be a subset of the
124125
output labels. If ``None``, all output variables are considered.
@@ -131,7 +132,7 @@ def __init__(self, components=None, d=None):
131132

132133
def equation(input_, output_):
133134
"""
134-
Definition of the equation to enforce a null laplacian.
135+
Definition of the equation to enforce a fixed laplacian.
135136
136137
:param LabelTensor input_: Input points where the equation is
137138
evaluated.
@@ -140,6 +141,8 @@ def equation(input_, output_):
140141
:return: The computed residual of the equation.
141142
:rtype: LabelTensor
142143
"""
143-
return laplacian(output_, input_, components=components, d=d)
144+
return (
145+
laplacian(output_, input_, components=components, d=d) - value
146+
)
144147

145148
super().__init__(equation)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from pina.equation import FixedValue, FixedGradient, FixedFlux, FixedLaplacian
2+
from pina import LabelTensor
3+
import torch
4+
import pytest
5+
6+
# Define input and output values
7+
pts = LabelTensor(torch.rand(10, 3, requires_grad=True), labels=["x", "y", "z"])
8+
u = torch.pow(pts, 2)
9+
u.labels = ["u", "v", "w"]
10+
11+
12+
@pytest.mark.parametrize("value", [0, 10, -7.5])
13+
@pytest.mark.parametrize("components", [None, "u", ["u", "w"]])
14+
def test_fixed_value(value, components):
15+
16+
# Constructor
17+
equation = FixedValue(value=value, components=components)
18+
19+
# Residual
20+
residual = equation.residual(pts, u)
21+
len_c = len(components) if components is not None else u.shape[1]
22+
assert residual.shape == (pts.shape[0], len_c)
23+
24+
25+
@pytest.mark.parametrize("value", [0, 10, -7.5])
26+
@pytest.mark.parametrize("components", [None, "u", ["u", "w"]])
27+
@pytest.mark.parametrize("d", [None, "x", ["x", "z"]])
28+
def test_fixed_gradient(value, components, d):
29+
30+
# Constructor
31+
equation = FixedGradient(value=value, components=components, d=d)
32+
33+
# Residual
34+
residual = equation.residual(pts, u)
35+
len_c = len(components) if components is not None else u.shape[1]
36+
len_d = len(d) if d is not None else pts.shape[1]
37+
assert residual.shape == (pts.shape[0], len_c * len_d)
38+
39+
40+
@pytest.mark.parametrize("value", [0, 10, -7.5])
41+
@pytest.mark.parametrize("components", [None, "u", ["u", "w"]])
42+
@pytest.mark.parametrize("d", [None, "x", ["x", "z"]])
43+
def test_fixed_flux(value, components, d):
44+
45+
# Divergence requires components and d to be of the same length
46+
len_c = len(components) if components is not None else u.shape[1]
47+
len_d = len(d) if d is not None else pts.shape[1]
48+
if len_c != len_d:
49+
return
50+
51+
# Constructor
52+
equation = FixedFlux(value=value, components=components, d=d)
53+
54+
# Residual
55+
residual = equation.residual(pts, u)
56+
assert residual.shape == (pts.shape[0], 1)
57+
58+
59+
@pytest.mark.parametrize("value", [0, 10, -7.5])
60+
@pytest.mark.parametrize("components", [None, "u", ["u", "w"]])
61+
@pytest.mark.parametrize("d", [None, "x", ["x", "z"]])
62+
def test_fixed_laplacian(value, components, d):
63+
64+
# Constructor
65+
equation = FixedLaplacian(value=value, components=components, d=d)
66+
67+
# Residual
68+
residual = equation.residual(pts, u)
69+
len_c = len(components) if components is not None else u.shape[1]
70+
assert residual.shape == (pts.shape[0], len_c)

tutorials/tutorial12/tutorial.ipynb

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,7 @@
152152
"Once the equations are set as above in the problem conditions, the PINN solver will aim to minimize the residuals described in each equation during the training phase. \n",
153153
"\n",
154154
"### Available classes of equations:\n",
155-
"- `FixedGradient` and `FixedFlux`: These work analogously to the `FixedValue` class, where we can enforce a constant value on the gradient or the divergence of the solution, respectively.\n",
156-
"- `Laplace`: This class can be used to enforce that the Laplacian of the solution is zero.\n",
155+
"- `FixedGradient`, `FixedFlux`, `FixedLaplacian`: These work analogously to the `FixedValue` class, where we can enforce a constant value on the gradient, on the divergence, or on the laplacian of the solution, respectively.\n",
157156
"- `SystemEquation`: This class allows you to enforce multiple conditions on the same subdomain by passing a list of residual equations defined in the problem.\n",
158157
"\n",
159158
"## Defining a new Equation class\n",

0 commit comments

Comments
 (0)