diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 8b5f998f9..4eeb20e02 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -93,8 +93,7 @@ def __getitem__(self, idx): class PinaGraphDataset(PinaDataset): - pass -''' + def __init__(self, conditions_dict, max_conditions_lengths, automatic_batching): super().__init__(conditions_dict, max_conditions_lengths) @@ -113,7 +112,7 @@ def fetch_from_idx_list(self, idx): to_return_dict[condition] = {k: Batch.from_data_list([v[i] for i in cond_idx]) if isinstance(v, list) - else v[cond_idx] + else v[cond_idx].reshape(-1, *v[cond_idx].shape[2:]) for k, v in data.items() } return to_return_dict @@ -132,5 +131,4 @@ def get_all_data(self): return self.fetch_from_idx_list(index) def __getitem__(self, idx): - return self._getitem_func(idx) -''' \ No newline at end of file + return self._getitem_func(idx) \ No newline at end of file diff --git a/pina/graph.py b/pina/graph.py index bde5bbf50..8c167417c 100644 --- a/pina/graph.py +++ b/pina/graph.py @@ -1,118 +1,296 @@ -""" Module for Loss class """ +from logging import warning -import logging -from torch_geometric.nn import MessagePassing, InstanceNorm, radius_graph -from torch_geometric.data import Data import torch +from . import LabelTensor +from torch_geometric.data import Data +from torch_geometric.utils import to_undirected +import inspect + + class Graph: """ - PINA Graph managing the PyG Data class. + Class for the graph construction. """ - def __init__(self, data): - self.data = data - - @staticmethod - def _build_triangulation(**kwargs): - logging.debug("Creating graph with triangulation mode.") - # check for mandatory arguments - if "nodes_coordinates" not in kwargs: - raise ValueError("Nodes coordinates must be provided in the kwargs.") - if "nodes_data" not in kwargs: - raise ValueError("Nodes data must be provided in the kwargs.") - if "triangles" not in kwargs: - raise ValueError("Triangles must be provided in the kwargs.") + def __init__( + self, + x, + pos, + edge_index, + edge_attr=None, + build_edge_attr=False, + undirected=False, + custom_build_edge_attr=None, + additional_params=None + ): + """ + Constructor for the Graph class. This object creates a list of PyTorch Geometric Data objects. + Based on the input of x and pos there could be the following cases: + 1. 1 pos, 1 x: a single graph will be created + 2. N pos, 1 x: N graphs will be created with the same node features + 3. 1 pos, N x: N graphs will be created with the same nodes but different node features + 4. N pos, N x: N graphs will be created + + :param x: Node features. Can be a single 2D tensor of shape [num_nodes, num_node_features], + or a 3D tensor of shape [n_graphs, num_nodes, num_node_features] + or a list of such 2D tensors of shape [num_nodes, num_node_features]. + :type x: torch.Tensor or list[torch.Tensor] + :param pos: Node coordinates. Can be a single 2D tensor of shape [num_nodes, num_coordinates], + or a 3D tensor of shape [n_graphs, num_nodes, num_coordinates] + or a list of such 2D tensors of shape [num_nodes, num_coordinates]. + :type pos: torch.Tensor or list[torch.Tensor] + :param edge_index: The edge index defining connections between nodes. + It should be a 2D tensor of shape [2, num_edges] + or a 3D tensor of shape [n_graphs, 2, num_edges] + or a list of such 2D tensors of shape [2, num_edges]. + :type edge_index: torch.Tensor or list[torch.Tensor] + :param edge_attr: Edge features. If provided, should have the shape [num_edges, num_edge_features] + or be a list of such tensors for multiple graphs. + :type edge_attr: torch.Tensor or list[torch.Tensor], optional + :param build_edge_attr: Whether to compute edge attributes during initialization. + :type build_edge_attr: bool, default=False + :param undirected: If True, converts the graph(s) into an undirected graph by adding reciprocal edges. + :type undirected: bool, default=False + :param custom_build_edge_attr: A user-defined function to generate edge attributes dynamically. + The function should take (x, pos, edge_index) as input and return a tensor + of shape [num_edges, num_edge_features]. + :type custom_build_edge_attr: function or callable, optional + :param additional_params: Dictionary containing extra attributes to be added to each Data object. + Keys represent attribute names, and values should be tensors or lists of tensors. + :type additional_params: dict, optional - nodes_coordinates = kwargs["nodes_coordinates"] - nodes_data = kwargs["nodes_data"] - triangles = kwargs["triangles"] + Note: if x, pos, and edge_index are both lists or 3D tensors, then len(x) == len(pos) == len(edge_index). + """ + self.data = [] + x, pos, edge_index = self._check_input_consistency(x, pos, edge_index) + # Check input dimension consistency and store the number of graphs + data_len = self._check_len_consistency(x, pos) + if inspect.isfunction(custom_build_edge_attr): + self._build_edge_attr = custom_build_edge_attr - def less_first(a, b): - return [a, b] if a < b else [b, a] + # Check consistency and initialize additional_parameters (if present) + additional_params = self._check_additional_params(additional_params, + data_len) - list_of_edges = [] + # Make the graphs undirected + if undirected: + if isinstance(edge_index, list): + edge_index = [to_undirected(e) for e in edge_index] + else: + edge_index = to_undirected(edge_index) - for triangle in triangles: - for e1, e2 in [[0, 1], [1, 2], [2, 0]]: - list_of_edges.append(less_first(triangle[e1],triangle[e2])) + # Prepare internal lists to create a graph list (same positions but + # different node features) + if isinstance(x, list) and isinstance(pos, + (torch.Tensor, LabelTensor)): + # Replicate the positions, edge_index and edge_attr + pos, edge_index = [pos] * data_len, [edge_index] * data_len + # Prepare internal lists to create a list containing a single graph + elif isinstance(x, (torch.Tensor, LabelTensor)) and isinstance(pos, ( + torch.Tensor, LabelTensor)): + # Encapsulate the input tensors into lists + x, pos, edge_index = [x], [pos], [edge_index] + # Prepare internal lists to create a list of graphs (same node features + # but different positions) + elif (isinstance(x, (torch.Tensor, LabelTensor)) + and isinstance(pos, list)): + # Replicate the node features + x = [x] * data_len + elif not isinstance(x, list) and not isinstance(pos, list): + raise TypeError("x and pos must be lists or tensors.") - array_of_edges = torch.unique(torch.Tensor(list_of_edges), dim=0) # remove duplicates - array_of_edges = array_of_edges.t().contiguous() - print(array_of_edges) + # Build the edge attributes + edge_attr = self._check_and_build_edge_attr(edge_attr, build_edge_attr, + data_len, edge_index, pos, + x) - # list_of_lengths = [] + # Perform the graph construction + self._build_graph_list(x, pos, edge_index, edge_attr, additional_params) + + def _build_graph_list(self, x, pos, edge_index, edge_attr, + additional_params): + for i, (x_, pos_, edge_index_) in enumerate(zip(x, pos, edge_index)): + if isinstance(x_, LabelTensor): + x_ = x_.tensor + add_params_local = {k: v[i] for k, v in additional_params.items()} + if edge_attr is not None: + + self.data.append(Data(x=x_, pos=pos_, edge_index=edge_index_, + edge_attr=edge_attr[i], + **add_params_local)) + else: + self.data.append(Data(x=x_, pos=pos_, edge_index=edge_index_, + **add_params_local)) + + @staticmethod + def _build_edge_attr(x, pos, edge_index): + distance = torch.abs(pos[edge_index[0]] - pos[edge_index[1]]) + return distance + + @staticmethod + def _check_len_consistency(x, pos): + if isinstance(x, list) and isinstance(pos, list): + if len(x) != len(pos): + raise ValueError("x and pos must have the same length.") + return max(len(x), len(pos)) + elif isinstance(x, list) and not isinstance(pos, list): + return len(x) + elif not isinstance(x, list) and isinstance(pos, list): + return len(pos) + else: + return 1 - # for p1,p2 in array_of_edges: - # x1, y1 = tri.points[p1] - # x2, y2 = tri.points[p2] - # list_of_lengths.append((x1-x2)**2 + (y1-y2)**2) + @staticmethod + def _check_input_consistency(x, pos, edge_index=None): + # If x is a 3D tensor, we split it into a list of 2D tensors + if isinstance(x, torch.Tensor) and x.ndim == 3: + x = [x[i] for i in range(x.shape[0])] + elif (not (isinstance(x, list) and all(t.ndim == 2 for t in x)) and + not (isinstance(x, torch.Tensor) and x.ndim == 2)): + raise TypeError("x must be either a list of 2D tensors or a 2D " + "tensor or a 3D tensor") - # array_of_lengths = np.sqrt(np.array(list_of_lengths)) + # If pos is a 3D tensor, we split it into a list of 2D tensors + if isinstance(pos, torch.Tensor) and pos.ndim == 3: + pos = [pos[i] for i in range(pos.shape[0])] + elif not (isinstance(pos, list) and all( + t.ndim == 2 for t in pos)) and not ( + isinstance(pos, torch.Tensor) and pos.ndim == 2): + raise TypeError("pos must be either a list of 2D tensors or a 2D " + "tensor or a 3D tensor") - # return array_of_edges, array_of_lengths + # If edge_index is a 3D tensor, we split it into a list of 2D tensors + if edge_index is not None: + if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3: + edge_index = [edge_index[i] for i in range(edge_index.shape[0])] + elif not (isinstance(edge_index, list) and all( + t.ndim == 2 for t in edge_index)) and not ( + isinstance(edge_index, + torch.Tensor) and edge_index.ndim == 2): + raise TypeError( + "edge_index must be either a list of 2D tensors or a 2D " + "tensor or a 3D tensor") - return Data( - x=nodes_data, - pos=nodes_coordinates.T, - - edge_index=array_of_edges, - ) + return x, pos, edge_index @staticmethod - def _build_radius(**kwargs): - logging.debug("Creating graph with radius mode.") - - # check for mandatory arguments - if "nodes_coordinates" not in kwargs: - raise ValueError("Nodes coordinates must be provided in the kwargs.") - if "nodes_data" not in kwargs: - raise ValueError("Nodes data must be provided in the kwargs.") - if "radius" not in kwargs: - raise ValueError("Radius must be provided in the kwargs.") - - nodes_coordinates = kwargs["nodes_coordinates"] - nodes_data = kwargs["nodes_data"] - radius = kwargs["radius"] - - edges_data = kwargs.get("edge_data", None) - loop = kwargs.get("loop", False) - batch = kwargs.get("batch", None) - - logging.debug(f"radius: {radius}, loop: {loop}, " - f"batch: {batch}") - - edge_index = radius_graph( - x=nodes_coordinates.tensor, - r=radius, - loop=loop, - batch=batch, - ) - - logging.debug(f"edge_index computed") - return Data( - x=nodes_data.tensor, - pos=nodes_coordinates.tensor, - edge_index=edge_index, - edge_attr=edges_data, - ) + def _check_additional_params(additional_params, data_len): + if additional_params is not None: + if not isinstance(additional_params, dict): + raise TypeError("additional_params must be a dictionary.") + for param, val in additional_params.items(): + # Check if the values are tensors or lists of tensors + if isinstance(val, torch.Tensor): + # If the tensor is 3D, we split it into a list of 2D tensors + # In this case there must be a additional parameter for each + # node + if val.ndim == 3: + additional_params[param] = [val[i] for i in + range(val.shape[0])] + # If the tensor is 2D, we replicate it for each node + elif val.ndim == 2: + additional_params[param] = [val] * data_len + # If the tensor is 1D, each graph has a scalar values as + # additional parameter + if val.ndim == 1: + if len(val) == data_len: + additional_params[param] = [val[i] for i in + range(len(val))] + else: + additional_params[param] = [val for _ in + range(data_len)] + elif not isinstance(val, list): + raise TypeError("additional_params values must be tensors " + "or lists of tensors.") + else: + additional_params = {} + return additional_params + + def _check_and_build_edge_attr(self, edge_attr, build_edge_attr, data_len, + edge_index, pos, x): + # Check if edge_attr is consistent with x and pos + if edge_attr is not None: + if build_edge_attr is True: + warning("edge_attr is not None. build_edge_attr will not be " + "considered.") + if isinstance(edge_attr, list): + if len(edge_attr) != data_len: + raise TypeError("edge_attr must have the same length as x " + "and pos.") + return [edge_attr] * data_len + + if build_edge_attr: + return [self._build_edge_attr(x, pos_, edge_index_) for + pos_, edge_index_ in zip(pos, edge_index)] + + +class RadiusGraph(Graph): + def __init__( + self, + x, + pos, + r, + **kwargs + ): + x, pos, edge_index = Graph._check_input_consistency(x, pos) + + if isinstance(pos, (torch.Tensor, LabelTensor)): + edge_index = RadiusGraph._radius_graph(pos, r) + else: + edge_index = [RadiusGraph._radius_graph(p, r) for p in pos] + + super().__init__(x=x, pos=pos, edge_index=edge_index, + **kwargs) @staticmethod - def build(mode, **kwargs): + def _radius_graph(points, r): """ - Constructor for the `Graph` class. + Implementation of the radius graph construction. + :param points: The input points. + :type points: torch.Tensor + :param r: The radius. + :type r: float + :return: The edge index. + :rtype: torch.Tensor """ - if mode == "radius": - graph = Graph._build_radius(**kwargs) - elif mode == "triangulation": - graph = Graph._build_triangulation(**kwargs) - else: - raise ValueError(f"Mode {mode} not recognized") - - return Graph(graph) + dist = torch.cdist(points, points, p=2) + edge_index = torch.nonzero(dist <= r, as_tuple=False).t() + return edge_index - def __repr__(self): - return f"Graph(data={self.data})" \ No newline at end of file +class KNNGraph(Graph): + def __init__( + self, + x, + pos, + k, + **kwargs + ): + x, pos, edge_index = Graph._check_input_consistency(x, pos) + if isinstance(pos, (torch.Tensor, LabelTensor)): + edge_index = KNNGraph._knn_graph(pos, k) + else: + edge_index = [KNNGraph._knn_graph(p, k) for p in pos] + super().__init__(x=x, pos=pos, edge_index=edge_index, + **kwargs) + + @staticmethod + def _knn_graph(points, k): + """ + Implementation of the k-nearest neighbors graph construction. + :param points: The input points. + :type points: torch.Tensor + :param k: The number of nearest neighbors. + :type k: int + :return: The edge index. + :rtype: torch.Tensor + """ + dist = torch.cdist(points, points, p=2) + knn_indices = torch.topk(dist, k=k + 1, largest=False).indices[:, 1:] + row = torch.arange(points.size(0)).repeat_interleave(k) + col = knn_indices.flatten() + edge_index = torch.stack([row, col], dim=0) + return edge_index diff --git a/pina/model/__init__.py b/pina/model/__init__.py index 3224d0af3..c75f9b658 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -10,6 +10,7 @@ "AveragingNeuralOperator", "LowRankNeuralOperator", "Spline", + "GraphNeuralOperator" ] from .feed_forward import FeedForward, ResidualFeedForward @@ -20,3 +21,4 @@ from .avno import AveragingNeuralOperator from .lno import LowRankNeuralOperator from .spline import Spline +from .gno import GraphNeuralOperator \ No newline at end of file diff --git a/pina/model/gno.py b/pina/model/gno.py new file mode 100644 index 000000000..9a8b88878 --- /dev/null +++ b/pina/model/gno.py @@ -0,0 +1,177 @@ +import torch +from torch.nn import Tanh +from .layers import GNOBlock +from .base_no import KernelNeuralOperator + + +class GraphNeuralKernel(torch.nn.Module): + """ + TODO add docstring + """ + + def __init__( + self, + width, + edge_features, + n_layers=2, + internal_n_layers=0, + internal_layers=None, + inner_size=None, + internal_func=None, + external_func=None, + shared_weights=False + ): + """ + The Graph Neural Kernel constructor. + + :param width: The width of the kernel. + :type width: int + :param edge_features: The number of edge features. + :type edge_features: int + :param n_layers: The number of kernel layers. + :type n_layers: int + :param internal_n_layers: The number of layers the FF Neural Network internal to each Kernel Layer. + :type internal_n_layers: int + :param internal_layers: Number of neurons of hidden layers(s) in the FF Neural Network inside for each Kernel Layer. + :type internal_layers: list | tuple + :param internal_func: The activation function used inside the computation of the representation of the edge features in the Graph Integral Layer. + :param external_func: The activation function applied to the output of the Graph Integral Layer. + :type external_func: torch.nn.Module + :param shared_weights: If ``True`` the weights of the Graph Integral Layers are shared. + """ + super().__init__() + if external_func is None: + external_func = Tanh + if internal_func is None: + internal_func = Tanh + + if shared_weights: + self.layers = GNOBlock( + width=width, + edges_features=edge_features, + n_layers=internal_n_layers, + layers=internal_layers, + inner_size=inner_size, + internal_func=internal_func, + external_func=external_func) + self.n_layers = n_layers + self.forward = self.forward_shared + else: + self.layers = torch.nn.ModuleList( + [GNOBlock( + width=width, + edges_features=edge_features, + n_layers=internal_n_layers, + layers=internal_layers, + inner_size=inner_size, + internal_func=internal_func, + external_func=external_func + ) + for _ in range(n_layers)] + ) + + def forward(self, x, edge_index, edge_attr): + """ + The forward pass of the Graph Neural Kernel used when the weights are not shared. + + :param x: The input batch. + :type x: torch.Tensor + :param edge_index: The edge index. + :type edge_index: torch.Tensor + :param edge_attr: The edge attributes. + :type edge_attr: torch.Tensor + """ + for layer in self.layers: + x = layer(x, edge_index, edge_attr) + return x + + def forward_shared(self, x, edge_index, edge_attr): + """ + The forward pass of the Graph Neural Kernel used when the weights are shared. + + :param x: The input batch. + :type x: torch.Tensor + :param edge_index: The edge index. + :type edge_index: torch.Tensor + :param edge_attr: The edge attributes. + :type edge_attr: torch.Tensor + """ + for _ in range(self.n_layers): + x = self.layers(x, edge_index, edge_attr) + return x + + +class GraphNeuralOperator(KernelNeuralOperator): + """ + TODO add docstring + """ + + def __init__( + self, + lifting_operator, + projection_operator, + edge_features, + n_layers=10, + internal_n_layers=0, + inner_size=None, + internal_layers=None, + internal_func=None, + external_func=None, + shared_weights=True + ): + """ + The Graph Neural Operator constructor. + + :param lifting_operator: The lifting operator mapping the node features to its hidden dimension. + :type lifting_operator: torch.nn.Module + :param projection_operator: The projection operator mapping the hidden representation of the nodes features to the output function. + :type projection_operator: torch.nn.Module + :param edge_features: Number of edge features. + :type edge_features: int + :param n_layers: The number of kernel layers. + :type n_layers: int + :param internal_n_layers: The number of layers the Feed Forward Neural Network internal to each Kernel Layer. + :type internal_n_layers: int + :param internal_layers: Number of neurons of hidden layers(s) in the FF Neural Network inside for each Kernel Layer. + :type internal_layers: list | tuple + :param internal_func: The activation function used inside the computation of the representation of the edge features in the Graph Integral Layer. + :type internal_func: torch.nn.Module + :param external_func: The activation function applied to the output of the Graph Integral Kernel. + :type external_func: torch.nn.Module + :param shared_weights: If ``True`` the weights of the Graph Integral Layers are shared. + :type shared_weights: bool + """ + + if internal_func is None: + internal_func = Tanh + if external_func is None: + external_func = Tanh + + super().__init__( + lifting_operator=lifting_operator, + integral_kernels=GraphNeuralKernel( + width=lifting_operator.out_features, + edge_features=edge_features, + internal_n_layers=internal_n_layers, + inner_size=inner_size, + internal_layers=internal_layers, + external_func=external_func, + internal_func=internal_func, + n_layers=n_layers, + shared_weights=shared_weights + ), + projection_operator=projection_operator + ) + + def forward(self, x): + """ + The forward pass of the Graph Neural Operator. + + :param x: The input batch. + :type x: torch_geometric.data.Batch + """ + x, edge_index, edge_attr = x.x, x.edge_index, x.edge_attr + x = self.lifting_operator(x) + x = self.integral_kernels(x, edge_index, edge_attr) + x = self.projection_operator(x) + return x diff --git a/pina/model/layers/__init__.py b/pina/model/layers/__init__.py index 5108522c5..3e3e71682 100644 --- a/pina/model/layers/__init__.py +++ b/pina/model/layers/__init__.py @@ -15,6 +15,7 @@ "AVNOBlock", "LowRankBlock", "RBFBlock", + "GNOBlock" ] from .convolution_2d import ContinuousConvBlock @@ -31,3 +32,4 @@ from .avno_layer import AVNOBlock from .lowrank_layer import LowRankBlock from .rbf_layer import RBFBlock +from .gno_block import GNOBlock diff --git a/pina/model/layers/gno_block.py b/pina/model/layers/gno_block.py new file mode 100644 index 000000000..34929fe89 --- /dev/null +++ b/pina/model/layers/gno_block.py @@ -0,0 +1,87 @@ +import torch +from torch_geometric.nn import MessagePassing + + +class GNOBlock(MessagePassing): + """ + TODO: Add documentation + """ + + def __init__( + self, + width, + edges_features, + n_layers=2, + layers=None, + inner_size=None, + internal_func=None, + external_func=None + ): + """ + Initialize the Graph Integral Layer, inheriting from the MessagePassing class of PyTorch Geometric. + + :param width: The width of the hidden representation of the nodes features + :type width: int + :param edges_features: The number of edge features. + :type edges_features: int + :param n_layers: The number of layers in the Feed Forward Neural Network used to compute the representation of the edges features. + :type n_layers: int + """ + from pina.model import FeedForward + super(GNOBlock, self).__init__(aggr='mean') + self.width = width + if layers is None and inner_size is None: + inner_size = width + self.dense = FeedForward(input_dimensions=edges_features, + output_dimensions=width ** 2, + n_layers=n_layers, + layers=layers, + inner_size=inner_size, + func=internal_func) + self.W = torch.nn.Linear(width, width) + self.func = external_func() + + def message(self, x_j, edge_attr): + """ + This function computes the message passed between the nodes of the graph. Overwrite the default message function defined in the MessagePassing class. + + :param x_j: The node features of the neighboring. + :type x_j: torch.Tensor + :param edge_attr: The edge features. + :type edge_attr: torch.Tensor + :return: The message passed between the nodes of the graph. + :rtype: torch.Tensor + """ + x = self.dense(edge_attr).view(-1, self.width, self.width) + return torch.einsum('bij,bj->bi', x, x_j) + + def update(self, aggr_out, x): + """ + This function updates the node features of the graph. Overwrite the default update function defined in the MessagePassing class. + + :param aggr_out: The aggregated messages. + :type aggr_out: torch.Tensor + :param x: The node features. + :type x: torch.Tensor + :return: The updated node features. + :rtype: torch.Tensor + """ + aggr_out = aggr_out + self.W(x) + return aggr_out + + def forward(self, x, edge_index, edge_attr): + """ + The forward pass of the Graph Integral Layer. + + :param x: Node features. + :type x: torch.Tensor + :param edge_index: Edge index. + :type edge_index: torch.Tensor + :param edge_attr: Edge features. + :type edge_attr: torch.Tensor + :return: Output of a single iteration over the Graph Integral Layer. + :rtype: torch.Tensor + """ + return self.func( + self.propagate(edge_index, x=x, edge_attr=edge_attr) + ) diff --git a/tests/test_collector.py b/tests/test_collector.py new file mode 100644 index 000000000..f25a01af4 --- /dev/null +++ b/tests/test_collector.py @@ -0,0 +1,125 @@ +import torch +import pytest +from pina import Condition, LabelTensor, Graph +from pina.condition import InputOutputPointsCondition, DomainEquationCondition +from pina.graph import RadiusGraph +from pina.problem import AbstractProblem, SpatialProblem +from pina.domain import CartesianDomain +from pina.equation.equation import Equation +from pina.equation.equation_factory import FixedValue +from pina.operators import laplacian + +def test_supervised_tensor_collector(): + class SupervisedProblem(AbstractProblem): + output_variables = None + conditions = { + 'data1' : Condition(input_points=torch.rand((10,2)), + output_points=torch.rand((10,2))), + 'data2' : Condition(input_points=torch.rand((20,2)), + output_points=torch.rand((20,2))), + 'data3' : Condition(input_points=torch.rand((30,2)), + output_points=torch.rand((30,2))), + } + problem = SupervisedProblem() + collector = problem.collector + for v in collector.conditions_name.values(): + assert v in problem.conditions.keys() + assert all(collector._is_conditions_ready.values()) + +def test_pinn_collector(): + def laplace_equation(input_, output_): + force_term = (torch.sin(input_.extract(['x']) * torch.pi) * + torch.sin(input_.extract(['y']) * torch.pi)) + delta_u = laplacian(output_.extract(['u']), input_) + return delta_u - force_term + + my_laplace = Equation(laplace_equation) + in_ = LabelTensor(torch.tensor([[0., 1.]], requires_grad=True), ['x', 'y']) + out_ = LabelTensor(torch.tensor([[0.]], requires_grad=True), ['u']) + class Poisson(SpatialProblem): + output_variables = ['u'] + spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]}) + + conditions = { + 'gamma1': + Condition(domain=CartesianDomain({ + 'x': [0, 1], + 'y': 1 + }), + equation=FixedValue(0.0)), + 'gamma2': + Condition(domain=CartesianDomain({ + 'x': [0, 1], + 'y': 0 + }), + equation=FixedValue(0.0)), + 'gamma3': + Condition(domain=CartesianDomain({ + 'x': 1, + 'y': [0, 1] + }), + equation=FixedValue(0.0)), + 'gamma4': + Condition(domain=CartesianDomain({ + 'x': 0, + 'y': [0, 1] + }), + equation=FixedValue(0.0)), + 'D': + Condition(domain=CartesianDomain({ + 'x': [0, 1], + 'y': [0, 1] + }), + equation=my_laplace), + 'data': + Condition(input_points=in_, output_points=out_) + } + + def poisson_sol(self, pts): + return -(torch.sin(pts.extract(['x']) * torch.pi) * + torch.sin(pts.extract(['y']) * torch.pi)) / (2 * torch.pi**2) + + truth_solution = poisson_sol + + problem = Poisson() + collector = problem.collector + for k,v in problem.conditions.items(): + if isinstance(v, InputOutputPointsCondition): + assert collector._is_conditions_ready[k] == True + assert list(collector.data_collections[k].keys()) == ['input_points', 'output_points'] + else: + assert collector._is_conditions_ready[k] == False + assert collector.data_collections[k] == {} + + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + problem.discretise_domain(10, 'grid', locations=boundaries) + problem.discretise_domain(10, 'grid', locations='D') + assert all(collector._is_conditions_ready.values()) + for k,v in problem.conditions.items(): + if isinstance(v, DomainEquationCondition): + assert list(collector.data_collections[k].keys()) == ['input_points', 'equation'] + + +def test_supervised_graph_collector(): + pos = torch.rand((100,3)) + x = [torch.rand((100,3)) for _ in range(10)] + graph_list_1 = RadiusGraph(pos=pos, x=x, build_edge_attr=True, r=.4) + out_1 = torch.rand((10,100,3)) + pos = torch.rand((50,3)) + x = [torch.rand((50,3)) for _ in range(10)] + graph_list_2 = RadiusGraph(pos=pos, x=x, build_edge_attr=True, r=.4) + out_2 = torch.rand((10,50,3)) + class SupervisedProblem(AbstractProblem): + output_variables = None + conditions = { + 'data1' : Condition(input_points=graph_list_1, + output_points=out_1), + 'data2' : Condition(input_points=graph_list_2, + output_points=out_2), + } + + problem = SupervisedProblem() + collector = problem.collector + assert all(collector._is_conditions_ready.values()) + for v in collector.conditions_name.values(): + assert v in problem.conditions.keys() diff --git a/tests/test_graph.py b/tests/test_graph.py new file mode 100644 index 000000000..660ec3428 --- /dev/null +++ b/tests/test_graph.py @@ -0,0 +1,163 @@ +import pytest +import torch +from pina.graph import RadiusGraph, KNNGraph + + +@pytest.mark.parametrize( + "x, pos", + [ + ([torch.rand(10, 2) for _ in range(3)], + [torch.rand(10, 3) for _ in range(3)]), + ([torch.rand(10, 2) for _ in range(3)], + [torch.rand(10, 3) for _ in range(3)]), + (torch.rand(3, 10, 2), torch.rand(3, 10, 3)), + (torch.rand(3, 10, 2), torch.rand(3, 10, 3)), + ] +) +def test_build_multiple_graph_multiple_val(x, pos): + graph = RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3) + assert len(graph.data) == 3 + data = graph.data + assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) + assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos)) + assert all(len(d.edge_index) == 2 for d in data) + graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3) + data = graph.data + assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) + assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos)) + assert all(len(d.edge_index) == 2 for d in data) + assert all(d.edge_attr is not None for d in data) + assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data) + + graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3) + data = graph.data + assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) + assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos)) + assert all(len(d.edge_index) == 2 for d in data) + assert all(d.edge_attr is not None for d in data) + assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data) + + +def test_build_single_graph_multiple_val(): + x = torch.rand(10, 2) + pos = torch.rand(10, 3) + graph = RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3) + assert len(graph.data) == 1 + data = graph.data + assert all(torch.isclose(d.x, x).all() for d in data) + assert all(torch.isclose(d_.pos, pos).all() for d_ in data) + assert all(len(d.edge_index) == 2 for d in data) + graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3) + data = graph.data + assert len(graph.data) == 1 + assert all(torch.isclose(d.x, x).all() for d in data) + assert all(torch.isclose(d_.pos, pos).all() for d_ in data) + assert all(len(d.edge_index) == 2 for d in data) + assert all(d.edge_attr is not None for d in data) + assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data) + + x = torch.rand(10, 2) + pos = torch.rand(10, 3) + graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3) + assert len(graph.data) == 1 + data = graph.data + assert all(torch.isclose(d.x, x).all() for d in data) + assert all(torch.isclose(d_.pos, pos).all() for d_ in data) + assert all(len(d.edge_index) == 2 for d in data) + graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3) + data = graph.data + assert len(graph.data) == 1 + assert all(torch.isclose(d.x, x).all() for d in data) + assert all(torch.isclose(d_.pos, pos).all() for d_ in data) + assert all(len(d.edge_index) == 2 for d in data) + assert all(d.edge_attr is not None for d in data) + assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data) + + +@pytest.mark.parametrize( + "pos", + [ + ([torch.rand(10, 3) for _ in range(3)]), + ([torch.rand(10, 3) for _ in range(3)]), + (torch.rand(3, 10, 3)), + (torch.rand(3, 10, 3)) + ] +) +def test_build_single_graph_single_val(pos): + x = torch.rand(10, 2) + graph = RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3) + assert len(graph.data) == 3 + data = graph.data + assert all(torch.isclose(d.x, x).all() for d in data) + assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos)) + assert all(len(d.edge_index) == 2 for d in data) + graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3) + data = graph.data + assert all(torch.isclose(d.x, x).all() for d in data) + assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos)) + assert all(len(d.edge_index) == 2 for d in data) + assert all(d.edge_attr is not None for d in data) + assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data) + x = torch.rand(10, 2) + graph = KNNGraph(x=x, pos=pos, build_edge_attr=False, k=3) + assert len(graph.data) == 3 + data = graph.data + assert all(torch.isclose(d.x, x).all() for d in data) + assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos)) + assert all(len(d.edge_index) == 2 for d in data) + graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3) + data = graph.data + assert all(torch.isclose(d.x, x).all() for d in data) + assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos)) + assert all(len(d.edge_index) == 2 for d in data) + assert all(d.edge_attr is not None for d in data) + assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data) + + +def test_additional_parameters_1(): + x = torch.rand(3, 10, 2) + pos = torch.rand(3, 10, 2) + additional_parameters = {'y': torch.ones(3)} + graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3, + additional_params=additional_parameters) + assert len(graph.data) == 3 + data = graph.data + assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) + assert all(hasattr(d, 'y') for d in data) + assert all(d_.y == 1 for d_ in data) + + +@pytest.mark.parametrize( + "additional_parameters", + [ + ({'y': torch.rand(3, 10, 1)}), + ({'y': [torch.rand(10, 1) for _ in range(3)]}), + ] +) +def test_additional_parameters_2(additional_parameters): + x = torch.rand(3, 10, 2) + pos = torch.rand(3, 10, 2) + graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3, + additional_params=additional_parameters) + assert len(graph.data) == 3 + data = graph.data + assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) + assert all(hasattr(d, 'y') for d in data) + assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) + +def test_custom_build_edge_attr_func(): + x = torch.rand(3, 10, 2) + pos = torch.rand(3, 10, 2) + + def build_edge_attr(x, pos, edge_index): + return torch.cat([pos[edge_index[0]], pos[edge_index[1]]], dim=-1) + + graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3, + custom_build_edge_attr=build_edge_attr) + assert len(graph.data) == 3 + data = graph.data + assert all(hasattr(d, 'edge_attr') for d in data) + assert all(d.edge_attr.shape[1] == 4 for d in data) + assert all(torch.isclose(d.edge_attr, + build_edge_attr(d.x, d.pos, d.edge_index)).all() + for d in data) diff --git a/tests/test_model/test_gno.py b/tests/test_model/test_gno.py new file mode 100644 index 000000000..8fb10d8e5 --- /dev/null +++ b/tests/test_model/test_gno.py @@ -0,0 +1,129 @@ +import pytest +import torch +from pina.graph import KNNGraph +from pina.model import GraphNeuralOperator +from torch_geometric.data import Batch + +x = [torch.rand(100, 6) for _ in range(10)] +pos = [torch.rand(100, 3) for _ in range(10)] +graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=6) +input_ = Batch.from_data_list(graph.data) + + +@pytest.mark.parametrize( + "shared_weights", + [ + True, + False + ] +) +def test_constructor(shared_weights): + lifting_operator = torch.nn.Linear(6, 16) + projection_operator = torch.nn.Linear(16, 3) + GraphNeuralOperator(lifting_operator=lifting_operator, + projection_operator=projection_operator, + edge_features=3, + internal_layers=[16, 16], + shared_weights=shared_weights) + + GraphNeuralOperator(lifting_operator=lifting_operator, + projection_operator=projection_operator, + edge_features=3, + inner_size=16, + internal_n_layers=10, + shared_weights=shared_weights) + + int_func = torch.nn.Softplus + ext_func = torch.nn.ReLU + + GraphNeuralOperator(lifting_operator=lifting_operator, + projection_operator=projection_operator, + edge_features=3, + internal_n_layers=10, + shared_weights=shared_weights, + internal_func=int_func, + external_func=ext_func) + + +@pytest.mark.parametrize( + "shared_weights", + [ + True, + False + ] +) +def test_forward_1(shared_weights): + lifting_operator = torch.nn.Linear(6, 16) + projection_operator = torch.nn.Linear(16, 3) + model = GraphNeuralOperator(lifting_operator=lifting_operator, + projection_operator=projection_operator, + edge_features=3, + internal_layers=[16, 16], + shared_weights=shared_weights) + output_ = model(input_) + assert output_.shape == torch.Size([1000, 3]) + + +@pytest.mark.parametrize( + "shared_weights", + [ + True, + False + ] +) +def test_forward_2(shared_weights): + lifting_operator = torch.nn.Linear(6, 16) + projection_operator = torch.nn.Linear(16, 3) + model = GraphNeuralOperator(lifting_operator=lifting_operator, + projection_operator=projection_operator, + edge_features=3, + inner_size=32, + internal_n_layers=2, + shared_weights=shared_weights) + output_ = model(input_) + assert output_.shape == torch.Size([1000, 3]) + + +@pytest.mark.parametrize( + "shared_weights", + [ + True, + False + ] +) +def test_backward(shared_weights): + lifting_operator = torch.nn.Linear(6, 16) + projection_operator = torch.nn.Linear(16, 3) + model = GraphNeuralOperator(lifting_operator=lifting_operator, + projection_operator=projection_operator, + edge_features=3, + internal_layers=[16, 16], + shared_weights=shared_weights) + input_.x.requires_grad = True + output_ = model(input_) + l = torch.mean(output_) + l.backward() + assert input_.x.grad.shape == torch.Size([1000, 6]) + + +@pytest.mark.parametrize( + "shared_weights", + [ + True, + False + ] +) +def test_backward_2(shared_weights): + lifting_operator = torch.nn.Linear(6, 16) + projection_operator = torch.nn.Linear(16, 3) + model = GraphNeuralOperator(lifting_operator=lifting_operator, + projection_operator=projection_operator, + edge_features=3, + inner_size=32, + internal_n_layers=2, + shared_weights=shared_weights) + input_.x.requires_grad = True + output_ = model(input_) + l = torch.mean(output_) + l.backward() + assert input_.x.grad.shape == torch.Size([1000, 6])