diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 25f0e3062..93787a0b2 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -95,6 +95,7 @@ Models MultiFeedForward ResidualFeedForward Spline + SplineSurface DeepONet MIONet KernelNeuralOperator diff --git a/docs/source/_rst/model/spline_surface.rst b/docs/source/_rst/model/spline_surface.rst new file mode 100644 index 000000000..6bbf137d8 --- /dev/null +++ b/docs/source/_rst/model/spline_surface.rst @@ -0,0 +1,7 @@ +Spline Surface +================ +.. currentmodule:: pina.model.spline_surface + +.. autoclass:: SplineSurface + :members: + :show-inheritance: \ No newline at end of file diff --git a/pina/model/__init__.py b/pina/model/__init__.py index ee343e53d..b67e3562b 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -25,6 +25,7 @@ from .average_neural_operator import AveragingNeuralOperator from .low_rank_neural_operator import LowRankNeuralOperator from .spline import Spline +from .spline_surface import SplineSurface from .graph_neural_operator import GraphNeuralOperator from .pirate_network import PirateNet from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator diff --git a/pina/model/spline.py b/pina/model/spline.py index c22c7937c..fd6df3fe7 100644 --- a/pina/model/spline.py +++ b/pina/model/spline.py @@ -1,109 +1,238 @@ -"""Module for the Spline model class.""" +"""Module for the B-Spline model class.""" +import warnings import torch -from ..utils import check_consistency +from ..utils import check_positive_integer class Spline(torch.nn.Module): - """ - Spline model class. + r""" + The univariate B-Spline curve model class. + + A univariate B-spline curve of order :math:`k` is a parametric curve defined + as a linear combination of B-spline basis functions and control points: + + .. math:: + + S(x) = \sum_{i=1}^{n} B_{i,k}(x) C_i, \quad x \in [x_1, x_m] + + where: + + - :math:`C_i \in \mathbb{R}` are the control points. These fixed points + influence the shape of the curve but are not generally interpolated, + except at the boundaries under certain knot multiplicities. + - :math:`B_{i,k}(x)` are the B-spline basis functions of order :math:`k`, + i.e., piecewise polynomials of degree :math:`k-1` with support on the + interval :math:`[x_i, x_{i+k}]`. + - :math:`X = \{ x_1, x_2, \dots, x_m \}` is the non-decreasing knot vector. + + If the first and last knots are repeated :math:`k` times, then the curve + interpolates the first and last control points. + + + .. note:: + + The curve is forced to be zero outside the interval defined by the + first and last knots. + + + :Example: + + >>> from pina.model import Spline + >>> import torch + + >>> knots1 = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0]) + >>> spline1 = Spline(order=3, knots=knots1, control_points=None) + + >>> knots2 = {"n": 7, "min": 0.0, "max": 2.0, "mode": "auto"} + >>> spline2 = Spline(order=3, knots=knots2, control_points=None) + + >>> knots3 = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0]) + >>> control_points3 = torch.tensor([0.0, 1.0, 3.0, 2.0]) + >>> spline3 = Spline(order=3, knots=knots3, control_points=control_points3) """ - def __init__(self, order=4, knots=None, control_points=None) -> None: + def __init__(self, order=4, knots=None, control_points=None): """ Initialization of the :class:`Spline` class. - :param int order: The order of the spline. Default is ``4``. - :param torch.Tensor knots: The tensor representing knots. If ``None``, - the knots will be initialized automatically. Default is ``None``. - :param torch.Tensor control_points: The control points. Default is - ``None``. - :raises ValueError: If the order is negative. - :raises ValueError: If both knots and control points are ``None``. - :raises ValueError: If the knot tensor is not one-dimensional. + :param int order: The order of the spline. The corresponding basis + functions are polynomials of degree ``order - 1``. Default is 4. + :param knots: The knots of the spline. If a tensor is provided, knots + are set directly from the tensor. If a dictionary is provided, it + must contain the keys ``"n"``, ``"min"``, ``"max"``, and ``"mode"``. + Here, ``"n"`` specifies the number of knots, ``"min"`` and ``"max"`` + define the interval, and ``"mode"`` selects the sampling strategy. + The supported modes are ``"uniform"``, where the knots are evenly + spaced over :math:`[min, max]`, and ``"auto"``, where knots are + constructed to ensure that the spline interpolates the first and + last control points. In this case, the number of knots is adjusted + if :math:`n < 2 * order`. If None is given, knots are initialized + automatically over :math:`[0, 1]` ensuring interpolation of the + first and last control points. Default is None. + :type knots: torch.Tensor | dict + :param torch.Tensor control_points: The control points of the spline. + If None, they are initialized as learnable parameters with an + initial value of zero. Default is None. + :raises AssertionError: If ``order`` is not a positive integer. + :raises ValueError: If both ``knots`` and ``control_points`` are None. + :raises ValueError: If ``knots`` is not one-dimensional. + :raises ValueError: If ``control_points`` is not one-dimensional. + :raises ValueError: If the number of ``knots`` is not equal to the sum + of ``order`` and the number of ``control_points.`` + :raises UserWarning: If the number of control points is lower than the + order, resulting in a degenerate spline. """ super().__init__() - check_consistency(order, int) + # Check consistency + check_positive_integer(value=order, strict=True) - if order < 0: - raise ValueError("Spline order cannot be negative.") + # Raise error if neither knots nor control points are provided if knots is None and control_points is None: - raise ValueError("Knots and control points cannot be both None.") + raise ValueError("knots and control_points cannot both be None.") - self.order = order - self.k = order - 1 - - if knots is not None and control_points is not None: - self.knots = knots - self.control_points = control_points + # Initialize knots if not provided + if knots is None and control_points is not None: + knots = { + "n": len(control_points) + order, + "min": 0, + "max": 1, + "mode": "auto", + } - elif knots is not None: - print("Warning: control points will be initialized automatically.") - print(" experimental feature") + # Initialization - knots and control points managed by their setters + self.order = order + self.knots = knots + self.control_points = control_points + + # Check dimensionality of knots + if self.knots.ndim > 1: + raise ValueError("knots must be one-dimensional.") + + # Check dimensionality of control points + if self.control_points.ndim > 1: + raise ValueError("control_points must be one-dimensional.") + + # Raise error if #knots != order + #control_points + if len(self.knots) != self.order + len(self.control_points): + raise ValueError( + f" The number of knots must be equal to order + number of" + f" control points. Got {len(self.knots)} knots, {self.order}" + f" order and {len(self.control_points)} control points." + ) - self.knots = knots - n = len(knots) - order - self.control_points = torch.nn.Parameter( - torch.zeros(n), requires_grad=True + # Raise warning if spline is degenerate + if len(self.control_points) < self.order: + warnings.warn( + "The number of control points is smaller than the spline order." + " This creates a degenerate spline with limited flexibility.", + UserWarning, ) - elif control_points is not None: - print("Warning: knots will be initialized automatically.") - print(" experimental feature") + # Precompute boundary interval index + self._boundary_interval_idx = self._compute_boundary_interval() - self.control_points = control_points + def _compute_boundary_interval(self): + """ + Precompute the index of the rightmost non-degenerate interval to improve + performance, eliminating the need to perform a search loop in the basis + function on each call. - n = len(self.control_points) - 1 - self.knots = { - "type": "auto", - "min": 0, - "max": 1, - "n": n + 2 + self.order, - } + :return: The index of the rightmost non-degenerate interval. + :rtype: int + """ + # Return 0 if there is a single interval + if len(self.knots) < 2: + return 0 - else: - raise ValueError("Knots and control points cannot be both None.") + # Find all indices where knots are strictly increasing + diffs = self.knots[1:] - self.knots[:-1] + valid = torch.nonzero(diffs > 0, as_tuple=False) - if self.knots.ndim != 1: - raise ValueError("Knot vector must be one-dimensional.") + # If all knots are equal, return 0 for degenerate spline + if valid.numel() == 0: + return 0 - def basis(self, x, k, i, t): + # Otherwise, return the last valid index + return int(valid[-1]) + + def basis(self, x): """ - Recursive method to compute the basis functions of the spline. + Compute the basis functions for the spline using an iterative approach. + This is a vectorized implementation based on the Cox-de Boor recursion. :param torch.Tensor x: The points to be evaluated. - :param int k: The spline degree. - :param int i: The index of the interval. - :param torch.Tensor t: The tensor of knots. - :return: The basis functions evaluated at x + :return: The basis functions evaluated at x. :rtype: torch.Tensor """ + # Add a final dimension to x + x = x.unsqueeze(-1) + + # Add an initial dimension to knots + knots = self.knots.unsqueeze(0) + + # Base case of recursion: indicator functions for the intervals + basis = (x >= knots[..., :-1]) & (x < knots[..., 1:]) + basis = basis.to(x.dtype) - if k == 0: - a = torch.where( - torch.logical_and(t[i] <= x, x < t[i + 1]), 1.0, 0.0 + # One-dimensional knots case: ensure rightmost boundary inclusion + if self._boundary_interval_idx is not None: + + # Extract left and right knots of the rightmost interval + knot_left = knots[..., self._boundary_interval_idx] + knot_right = knots[..., self._boundary_interval_idx + 1] + + # Identify points at the rightmost boundary + at_rightmost_boundary = ( + x.squeeze(-1) >= knot_left + ) & torch.isclose(x.squeeze(-1), knot_right, rtol=1e-8, atol=1e-10) + + # Ensure the correct value is set at the rightmost boundary + if torch.any(at_rightmost_boundary): + basis[..., self._boundary_interval_idx] = torch.logical_or( + basis[..., self._boundary_interval_idx].bool(), + at_rightmost_boundary, + ).to(basis.dtype) + + # Iterative case of recursion + for i in range(1, self.order): + + # Compute the denominators for both terms + denom1 = knots[..., i:-1] - knots[..., : -(i + 1)] + denom2 = knots[..., i + 1 :] - knots[..., 1:-i] + + # Ensure no division by zero + denom1 = torch.where( + torch.abs(denom1) < 1e-8, torch.ones_like(denom1), denom1 ) - if i == len(t) - self.order - 1: - a = torch.where(x == t[-1], 1.0, a) - a.requires_grad_(True) - return a - - if t[i + k] == t[i]: - c1 = torch.tensor([0.0] * len(x), requires_grad=True) - else: - c1 = (x - t[i]) / (t[i + k] - t[i]) * self.basis(x, k - 1, i, t) - - if t[i + k + 1] == t[i + 1]: - c2 = torch.tensor([0.0] * len(x), requires_grad=True) - else: - c2 = ( - (t[i + k + 1] - x) - / (t[i + k + 1] - t[i + 1]) - * self.basis(x, k - 1, i + 1, t) + denom2 = torch.where( + torch.abs(denom2) < 1e-8, torch.ones_like(denom2), denom2 ) - return c1 + c2 + # Compute the two terms of the recursion + term1 = ((x - knots[..., : -(i + 1)]) / denom1) * basis[..., :-1] + term2 = ((knots[..., i + 1 :] - x) / denom2) * basis[..., 1:] + + # Combine terms to get the new basis + basis = term1 + term2 + + return basis + + def forward(self, x): + """ + Forward pass for the :class:`Spline` model. + + :param x: The input tensor. + :type x: torch.Tensor | LabelTensor + :return: The output tensor. + :rtype: torch.Tensor + """ + return torch.einsum( + "bi, i -> b", + self.basis(x.as_subclass(torch.Tensor)).squeeze(1), + self.control_points, + ).reshape(-1, 1) @property def control_points(self): @@ -116,24 +245,42 @@ def control_points(self): return self._control_points @control_points.setter - def control_points(self, value): + def control_points(self, control_points): """ Set the control points of the spline. - :param value: The control points. - :type value: torch.Tensor | dict - :raises ValueError: If invalid value is passed. + :param torch.Tensor control_points: The control points tensor. If None, + control points are initialized to learnable parameters with zero + initial value. Default is None. + :raises ValueError: If there are not enough knots to define the control + points, due to the relation: #knots = order + #control_points. + :raises ValueError: If control_points is not a torch.Tensor. """ - if isinstance(value, dict): - if "n" not in value: - raise ValueError("Invalid value for control_points") - n = value["n"] - dim = value.get("dim", 1) - value = torch.zeros(n, dim) + # If control points are not provided, initialize them + if control_points is None: + + # Check that there are enough knots to define control points + if len(self.knots) < self.order + 1: + raise ValueError( + f"Not enough knots to define control points. Got " + f"{len(self.knots)} knots, but need at least " + f"{self.order + 1}." + ) + + # Initialize control points to zero + control_points = torch.zeros(len(self.knots) - self.order) + + # Check validity of control points + elif not isinstance(control_points, torch.Tensor): + raise ValueError( + "control_points must be a torch.Tensor," + f" got {type(control_points)}" + ) - if not isinstance(value, torch.Tensor): - raise ValueError("Invalid value for control_points") - self._control_points = torch.nn.Parameter(value, requires_grad=True) + # Set control points + self._control_points = torch.nn.Parameter( + control_points, requires_grad=True + ) @property def knots(self): @@ -150,50 +297,80 @@ def knots(self, value): """ Set the knots of the spline. - :param value: The knots. + :param value: The knots of the spline. If a tensor is provided, knots + are set directly from the tensor. If a dictionary is provided, it + must contain the keys ``"n"``, ``"min"``, ``"max"``, and ``"mode"``. + Here, ``"n"`` specifies the number of knots, ``"min"`` and ``"max"`` + define the interval, and ``"mode"`` selects the sampling strategy. + The supported modes are ``"uniform"``, where the knots are evenly + spaced over :math:`[min, max]`, and ``"auto"``, where knots are + constructed to ensure that the spline interpolates the first and + last control points. In this case, the number of knots is inferred + and the ``"n"`` key is ignored. :type value: torch.Tensor | dict - :raises ValueError: If invalid value is passed. + :raises ValueError: If value is not a torch.Tensor or a dictionary. + :raises ValueError: If a dictionary is provided but does not contain + the required keys. + :raises ValueError: If the mode specified in the dictionary is invalid. """ + # Check validity of knots + if not isinstance(value, (torch.Tensor, dict)): + raise ValueError( + "Knots must be a torch.Tensor or a dictionary," + f" got {type(value)}." + ) + + # If a dictionary is provided, initialize knots accordingly if isinstance(value, dict): - type_ = value.get("type", "auto") - min_ = value.get("min", 0) - max_ = value.get("max", 1) - n = value.get("n", 10) - - if type_ == "uniform": - value = torch.linspace(min_, max_, n + self.k + 1) - elif type_ == "auto": - initial_knots = torch.ones(self.order + 1) * min_ - final_knots = torch.ones(self.order + 1) * max_ - - if n < self.order + 1: - value = torch.concatenate((initial_knots, final_knots)) - elif n - 2 * self.order + 1 == 1: - value = torch.Tensor([(max_ + min_) / 2]) - else: - value = torch.linspace(min_, max_, n - 2 * self.order - 1) + # Check that required keys are present + required_keys = {"n", "min", "max", "mode"} + if not required_keys.issubset(value.keys()): + raise ValueError( + f"When providing knots as a dictionary, the following " + f"keys must be present: {required_keys}. Got " + f"{value.keys()}." + ) - value = torch.concatenate((initial_knots, value, final_knots)) + # Uniform sampling of knots + if value["mode"] == "uniform": + value = torch.linspace(value["min"], value["max"], value["n"]) - if not isinstance(value, torch.Tensor): - raise ValueError("Invalid value for knots") + # Automatic sampling of interpolating knots + elif value["mode"] == "auto": - self._knots = value + # Repeat the first and last knots 'order' times + initial_knots = torch.ones(self.order) * value["min"] + final_knots = torch.ones(self.order) * value["max"] - def forward(self, x): - """ - Forward pass for the :class:`Spline` model. + # Number of internal knots + n_internal = value["n"] - 2 * self.order - :param torch.Tensor x: The input tensor. - :return: The output tensor. - :rtype: torch.Tensor - """ - t = self.knots - k = self.k - c = self.control_points - - basis = map(lambda i: self.basis(x, k, i, t)[:, None], range(len(c))) - y = (torch.cat(list(basis), dim=1) * c).sum(axis=1) + # If no internal knots are needed, just concatenate boundaries + if n_internal <= 0: + value = torch.cat((initial_knots, final_knots)) - return y + # Else, sample internal knots uniformly and exclude boundaries + # Recover the correct number of internal knots when slicing by + # adding 2 to n_internal + else: + internal_knots = torch.linspace( + value["min"], value["max"], n_internal + 2 + )[1:-1] + value = torch.cat( + (initial_knots, internal_knots, final_knots) + ) + + # Raise error if mode is invalid + else: + raise ValueError( + f"Invalid mode for knots initialization. Got " + f"{value['mode']}, but expected 'uniform' or 'auto'." + ) + + # Set knots + self.register_buffer("_knots", value.sort(dim=0).values) + + # Recompute boundary interval when knots change + if hasattr(self, "_boundary_interval_idx"): + self._boundary_interval_idx = self._compute_boundary_interval() diff --git a/pina/model/spline_surface.py b/pina/model/spline_surface.py new file mode 100644 index 000000000..288b77a72 --- /dev/null +++ b/pina/model/spline_surface.py @@ -0,0 +1,140 @@ +"""Module for the bivariate B-Spline surface model class.""" + +import torch +from .spline import Spline +from ..utils import check_consistency + + +class SplineSurface(torch.nn.Module): + r""" + The bivariate B-Spline surface model class. + + A bivariate B-spline surface is a parametric surface defined as the tensor + product of two univariate B-spline curves: + + .. math:: + + S(x, y) = \sum_{i,j=1}^{n_x, n_y} B_{i,k}(x) B_{j,s}(y) C_{i,j}, + \quad x \in [x_1, x_m], y \in [y_1, y_l] + + where: + + - :math:`C_{i,j} \in \mathbb{R}^2` are the control points. These fixed + points influence the shape of the surface but are not generally + interpolated, except at the boundaries under certain knot multiplicities. + - :math:`B_{i,k}(x)` and :math:`B_{j,s}(y)` are the B-spline basis functions + defined over two orthogonal directions, with orders :math:`k` and + :math:`s`, respectively. + - :math:`X = \{ x_1, x_2, \dots, x_m \}` and + :math:`Y = \{ y_1, y_2, \dots, y_l \}` are the non-decreasing knot + vectors along the two directions. + """ + + def __init__(self, orders, knots_u, knots_v, control_points=None): + """ + Initialization of the :class:`SplineSurface` class. + + :param list[int] orders: The orders of the spline along each parametric + direction. Each order defines the degree of the corresponding basis + as ``degree = order - 1``. + :param knots_u: The knots of the spline along the first direction. + Unlike the univariate case, this must be explicitly provided. + For details on valid formats and initialization modes, see the + :class:`Spline` class. + :type knots_u: torch.Tensor | dict + :param knots_v: The knots of the spline along the second direction. + Unlike the univariate case, this must be explicitly provided. + For details on valid formats and initialization modes, see the + :class:`Spline` class. + :type knots_v: torch.Tensor | dict + :param torch.Tensor control_points: The control points defining the + surface geometry. It must be a two-dimensional tensor of shape + ``[len(knots_u) - orders[0], len(knots_v) - orders[1]]``. + If None, they are initialized as learnable parameters with zero + values. Default is None. + :raises ValueError: If ``orders`` is not a list of two elements. + :raises ValueError: If ``knots_u`` or ``knots_v`` is None. + :raises ValueError: If ``control_points`` is not a torch.Tensor when + provided. + :raises ValueError: If ``control_points`` is not of the correct shape + when provided. + """ + super().__init__() + + # Check consistency + check_consistency(orders, int) + check_consistency(control_points, (type(None), torch.Tensor)) + + # Check orders is a list of two elements + if len(orders) != 2: + raise ValueError("orders must be a list of two elements.") + + # Check knots_u and knots_v are not None + if knots_u is None or knots_v is None: + raise ValueError("knots_u and knots_v must cannot be None.") + + # Create two univariate b-splines + self.spline_u = Spline(order=orders[0], knots=knots_u) + self.spline_v = Spline(order=orders[1], knots=knots_v) + + # Delete unneeded parameters + delattr(self.spline_u, "_control_points") + delattr(self.spline_v, "_control_points") + + # Save correct shape of control points + __valid_shape = ( + len(self.spline_u.knots) - self.spline_u.order, + len(self.spline_v.knots) - self.spline_v.order, + ) + + # Initialize control points, if not provided + if control_points is None: + control_points = torch.zeros(__valid_shape) + + # Check control points + if control_points.shape != __valid_shape: + raise ValueError( + "control_points must be of the correct shape. ", + f"Expected {__valid_shape}, got {control_points.shape}.", + ) + + # Register control points as a learnable parameter + self._control_points = torch.nn.Parameter( + control_points, requires_grad=True + ) + + def forward(self, x): + """ + Forward pass for the :class:`SplineSurface` model. + + :param x: The input tensor. + :type x: torch.Tensor | LabelTensor + :return: The output tensor. + :rtype: torch.Tensor + """ + return torch.einsum( + "bi, bj, ij -> b", + self.spline_u.basis(x.as_subclass(torch.Tensor)[:, 0]), + self.spline_v.basis(x.as_subclass(torch.Tensor)[:, 1]), + self.control_points, + ).reshape(-1, 1) + + @property + def knots(self): + """ + The knots of the univariate splines defining the spline surface. + + :return: The knots. + :rtype: tuple(torch.Tensor, torch.Tensor) + """ + return self.spline_u.knots, self.spline_v.knots + + @property + def control_points(self): + """ + The control points of the spline. + + :return: The control points. + :rtype: torch.Tensor + """ + return self._control_points diff --git a/tests/test_model/test_spline.py b/tests/test_model/test_spline.py index d38b1610b..c30e54244 100644 --- a/tests/test_model/test_spline.py +++ b/tests/test_model/test_spline.py @@ -1,81 +1,171 @@ import torch import pytest - +import numpy as np +from scipy.interpolate import BSpline from pina.model import Spline +from pina import LabelTensor -data = torch.rand((20, 3)) -input_vars = 3 -output_vars = 4 -valid_args = [ - { - "knots": torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 3.0, 3.0]), - "control_points": torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0]), - "order": 3, - }, - { - "knots": torch.tensor( - [-2.0, -2.0, -2.0, -2.0, -1.0, 0.0, 1.0, 2.0, 2.0, 2.0, 2.0] - ), - "control_points": torch.tensor([0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0]), - "order": 4, - }, - # {'control_points': {'n': 5, 'dim': 1}, 'order': 2}, - # {'control_points': {'n': 7, 'dim': 1}, 'order': 3} -] +# Utility quantities for testing +order = torch.randint(1, 8, (1,)).item() +n_ctrl_pts = torch.randint(order, order + 5, (1,)).item() +n_knots = order + n_ctrl_pts + +# Input tensor +pts = LabelTensor(torch.linspace(0, 1, 100).reshape(-1, 1), ["x"]) -def scipy_check(model, x, y): - from scipy.interpolate._bsplines import BSpline - import numpy as np +# Function to compare with scipy implementation +def check_scipy_spline(model, x, output_): - spline = BSpline( + # Define scipy spline + scipy_spline = BSpline( t=model.knots.detach().numpy(), c=model.control_points.detach().numpy(), k=model.order - 1, ) - y_scipy = spline(x).flatten() - y = y.detach().numpy() - np.testing.assert_allclose(y, y_scipy, atol=1e-5) + + # Compare outputs + np.testing.assert_allclose( + output_.squeeze().detach().numpy(), + scipy_spline(x).flatten(), + atol=1e-5, + rtol=1e-5, + ) + + +# Define all possible combinations of valid arguments for the Spline class +valid_args = [ + { + "order": order, + "control_points": torch.rand(n_ctrl_pts), + "knots": torch.linspace(0, 1, n_knots), + }, + { + "order": order, + "control_points": torch.rand(n_ctrl_pts), + "knots": {"n": n_knots, "min": 0, "max": 1, "mode": "auto"}, + }, + { + "order": order, + "control_points": torch.rand(n_ctrl_pts), + "knots": {"n": n_knots, "min": 0, "max": 1, "mode": "uniform"}, + }, + { + "order": order, + "control_points": None, + "knots": torch.linspace(0, 1, n_knots), + }, + { + "order": order, + "control_points": None, + "knots": {"n": n_knots, "min": 0, "max": 1, "mode": "auto"}, + }, + { + "order": order, + "control_points": None, + "knots": {"n": n_knots, "min": 0, "max": 1, "mode": "uniform"}, + }, + { + "order": order, + "control_points": torch.rand(n_ctrl_pts), + "knots": None, + }, +] @pytest.mark.parametrize("args", valid_args) def test_constructor(args): Spline(**args) + # Should fail if order is not a positive integer + with pytest.raises(AssertionError): + Spline( + order=-1, control_points=args["control_points"], knots=args["knots"] + ) + + # Should fail if control_points is not None or a torch.Tensor + with pytest.raises(ValueError): + Spline( + order=args["order"], control_points=[1, 2, 3], knots=args["knots"] + ) + + # Should fail if knots is not None, a torch.Tensor, or a dict + with pytest.raises(ValueError): + Spline( + order=args["order"], control_points=args["control_points"], knots=5 + ) -def test_constructor_wrong(): + # Should fail if both knots and control_points are None with pytest.raises(ValueError): - Spline() + Spline(order=args["order"], control_points=None, knots=None) + + # Should fail if knots is not one-dimensional + with pytest.raises(ValueError): + Spline( + order=args["order"], + control_points=args["control_points"], + knots=torch.rand(n_knots, 4), + ) + + # Should fail if control_points is not one-dimensional + with pytest.raises(ValueError): + Spline( + order=args["order"], + control_points=torch.rand(n_ctrl_pts, 4), + knots=args["knots"], + ) + + # Should fail if the number of knots != order + number of control points + # If control points are None, they are initialized to fulfill this condition + if args["control_points"] is not None: + with pytest.raises(ValueError): + Spline( + order=args["order"], + control_points=args["control_points"], + knots=torch.linspace(0, 1, n_knots + 1), + ) + + # Should fail if the knot dict is missing required keys + with pytest.raises(ValueError): + Spline( + order=args["order"], + control_points=args["control_points"], + knots={"n": n_knots, "min": 0, "max": 1}, + ) + + # Should fail if the knot dict has invalid 'mode' key + with pytest.raises(ValueError): + Spline( + order=args["order"], + control_points=args["control_points"], + knots={"n": n_knots, "min": 0, "max": 1, "mode": "invalid"}, + ) @pytest.mark.parametrize("args", valid_args) def test_forward(args): - min_x = args["knots"][0] - max_x = args["knots"][-1] - xi = torch.linspace(min_x, max_x, 1000) + + # Define the model model = Spline(**args) - yi = model(xi).squeeze() - scipy_check(model, xi, yi) - return + + # Evaluate the model + output_ = model(pts) + assert output_.shape == (pts.shape[0], 1) + + # Compare with scipy implementation only for interpolant knots (mode: auto) + if isinstance(args["knots"], dict) and args["knots"]["mode"] == "auto": + check_scipy_spline(model, pts, output_) @pytest.mark.parametrize("args", valid_args) def test_backward(args): - min_x = args["knots"][0] - max_x = args["knots"][-1] - xi = torch.linspace(min_x, max_x, 100) + + # Define the model model = Spline(**args) - yi = model(xi) - fake_loss = torch.sum(yi) - assert model.control_points.grad is None - fake_loss.backward() - assert model.control_points.grad is not None - - # dim_in, dim_out = 3, 2 - # fnn = FeedForward(dim_in, dim_out) - # data.requires_grad = True - # output_ = fnn(data) - # l=torch.mean(output_) - # l.backward() - # assert data._grad.shape == torch.Size([20,3]) + + # Evaluate the model + output_ = model(pts) + loss = torch.mean(output_) + loss.backward() + assert model.control_points.grad.shape == model.control_points.shape diff --git a/tests/test_model/test_spline_surface.py b/tests/test_model/test_spline_surface.py new file mode 100644 index 000000000..6fc0583e3 --- /dev/null +++ b/tests/test_model/test_spline_surface.py @@ -0,0 +1,159 @@ +import torch +import random +import pytest +from pina.model import SplineSurface +from pina import LabelTensor + + +# Utility quantities for testing +orders = [random.randint(1, 8) for _ in range(2)] +n_ctrl_pts = random.randint(max(orders), max(orders) + 5) +n_knots = [orders[i] + n_ctrl_pts for i in range(2)] + +# Input tensor +x = torch.linspace(0, 1, 100).reshape(-1, 1) +y = torch.linspace(0, 1, 100).reshape(-1, 1) +pts = LabelTensor(torch.cat((x, y), dim=1), labels=["x", "y"]) + + +@pytest.mark.parametrize( + "knots_u", + [ + torch.rand(n_knots[0]), + {"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"}, + {"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"}, + ], +) +@pytest.mark.parametrize( + "knots_v", + [ + torch.rand(n_knots[1]), + {"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"}, + {"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"}, + ], +) +@pytest.mark.parametrize( + "control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None] +) +def test_constructor(knots_u, knots_v, control_points): + SplineSurface( + orders=orders, + knots_u=knots_u, + knots_v=knots_v, + control_points=control_points, + ) + + # Should fail if orders is not list of two elements + with pytest.raises(ValueError): + SplineSurface( + orders=[orders[0]], + knots_u=knots_u, + knots_v=knots_v, + control_points=control_points, + ) + + # Should fail if knots_u is None + with pytest.raises(ValueError): + SplineSurface( + orders=orders, + knots_u=None, + knots_v=knots_v, + control_points=control_points, + ) + + # Should fail if knots_v is None + with pytest.raises(ValueError): + SplineSurface( + orders=orders, + knots_u=knots_u, + knots_v=None, + control_points=control_points, + ) + + # Should fail if control_points is not a torch.Tensor when provided + with pytest.raises(ValueError): + SplineSurface( + orders=orders, + knots_u=knots_u, + knots_v=knots_v, + control_points=[[0.0] * n_ctrl_pts] * n_ctrl_pts, + ) + + # Should fail if control_points is not of the correct shape when provided + with pytest.raises(ValueError): + SplineSurface( + orders=orders, + knots_u=knots_u, + knots_v=knots_v, + control_points=torch.rand(n_ctrl_pts + 1, n_ctrl_pts), + ) + + +@pytest.mark.parametrize( + "knots_u", + [ + torch.rand(n_knots[0]), + {"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"}, + {"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"}, + ], +) +@pytest.mark.parametrize( + "knots_v", + [ + torch.rand(n_knots[1]), + {"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"}, + {"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"}, + ], +) +@pytest.mark.parametrize( + "control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None] +) +def test_forward(knots_u, knots_v, control_points): + + # Define the model + model = SplineSurface( + orders=orders, + knots_u=knots_u, + knots_v=knots_v, + control_points=control_points, + ) + + # Evaluate the model + output_ = model(pts) + assert output_.shape == (pts.shape[0], 1) + + +@pytest.mark.parametrize( + "knots_u", + [ + torch.rand(n_knots[0]), + {"n": n_knots[0], "min": 0, "max": 1, "mode": "auto"}, + {"n": n_knots[0], "min": 0, "max": 1, "mode": "uniform"}, + ], +) +@pytest.mark.parametrize( + "knots_v", + [ + torch.rand(n_knots[1]), + {"n": n_knots[1], "min": 0, "max": 1, "mode": "auto"}, + {"n": n_knots[1], "min": 0, "max": 1, "mode": "uniform"}, + ], +) +@pytest.mark.parametrize( + "control_points", [torch.rand(n_ctrl_pts, n_ctrl_pts), None] +) +def test_backward(knots_u, knots_v, control_points): + + # Define the model + model = SplineSurface( + orders=orders, + knots_u=knots_u, + knots_v=knots_v, + control_points=control_points, + ) + + # Evaluate the model + output_ = model(pts) + loss = torch.mean(output_) + loss.backward() + assert model.control_points.grad.shape == model.control_points.shape