-
Notifications
You must be signed in to change notification settings - Fork 83
Add Graph class #403
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Graph class #403
Changes from all commits
1f32ace
05faaaa
6eba0cf
d591aee
dc87615
d79017e
a68b711
9b7cdbf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
import torch | ||
from torch.nn import Tanh | ||
from .layers import GNOBlock | ||
from .base_no import KernelNeuralOperator | ||
|
||
|
||
class GraphNeuralKernel(torch.nn.Module): | ||
""" | ||
TODO add docstring | ||
""" | ||
|
||
def __init__( | ||
self, | ||
width, | ||
edge_features, | ||
n_layers=2, | ||
internal_n_layers=0, | ||
internal_layers=None, | ||
inner_size=None, | ||
internal_func=None, | ||
external_func=None, | ||
shared_weights=False | ||
): | ||
""" | ||
The Graph Neural Kernel constructor. | ||
|
||
:param width: The width of the kernel. | ||
:type width: int | ||
:param edge_features: The number of edge features. | ||
:type edge_features: int | ||
:param n_layers: The number of kernel layers. | ||
:type n_layers: int | ||
:param internal_n_layers: The number of layers the FF Neural Network internal to each Kernel Layer. | ||
:type internal_n_layers: int | ||
:param internal_layers: Number of neurons of hidden layers(s) in the FF Neural Network inside for each Kernel Layer. | ||
:type internal_layers: list | tuple | ||
:param internal_func: The activation function used inside the computation of the representation of the edge features in the Graph Integral Layer. | ||
:param external_func: The activation function applied to the output of the Graph Integral Layer. | ||
:type external_func: torch.nn.Module | ||
:param shared_weights: If ``True`` the weights of the Graph Integral Layers are shared. | ||
""" | ||
super().__init__() | ||
if external_func is None: | ||
external_func = Tanh | ||
if internal_func is None: | ||
internal_func = Tanh | ||
|
||
if shared_weights: | ||
self.layers = GNOBlock( | ||
width=width, | ||
edges_features=edge_features, | ||
n_layers=internal_n_layers, | ||
layers=internal_layers, | ||
inner_size=inner_size, | ||
internal_func=internal_func, | ||
external_func=external_func) | ||
self.n_layers = n_layers | ||
self.forward = self.forward_shared | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't like a lot this forward separation, is there a way to combine the two? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In order to have an efficient way to store parameters (avoid to use torch.nn.ModuleList with the same model repeated n_layer times), another possible solution is using an if in the forward. Otherwise I can define another 2 classes: one for the shared_weights and one for the non shared_weights. Let me know what how to proceed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can keep it like this for the moment, maybe two classes is the best but for a single model I would not care a lot |
||
else: | ||
self.layers = torch.nn.ModuleList( | ||
[GNOBlock( | ||
width=width, | ||
edges_features=edge_features, | ||
n_layers=internal_n_layers, | ||
layers=internal_layers, | ||
inner_size=inner_size, | ||
internal_func=internal_func, | ||
external_func=external_func | ||
) | ||
for _ in range(n_layers)] | ||
) | ||
|
||
def forward(self, x, edge_index, edge_attr): | ||
""" | ||
The forward pass of the Graph Neural Kernel used when the weights are not shared. | ||
|
||
:param x: The input batch. | ||
:type x: torch.Tensor | ||
:param edge_index: The edge index. | ||
:type edge_index: torch.Tensor | ||
:param edge_attr: The edge attributes. | ||
:type edge_attr: torch.Tensor | ||
""" | ||
for layer in self.layers: | ||
x = layer(x, edge_index, edge_attr) | ||
return x | ||
|
||
def forward_shared(self, x, edge_index, edge_attr): | ||
""" | ||
The forward pass of the Graph Neural Kernel used when the weights are shared. | ||
|
||
:param x: The input batch. | ||
:type x: torch.Tensor | ||
:param edge_index: The edge index. | ||
:type edge_index: torch.Tensor | ||
:param edge_attr: The edge attributes. | ||
:type edge_attr: torch.Tensor | ||
""" | ||
for _ in range(self.n_layers): | ||
x = self.layers(x, edge_index, edge_attr) | ||
return x | ||
|
||
|
||
class GraphNeuralOperator(KernelNeuralOperator): | ||
""" | ||
TODO add docstring | ||
""" | ||
|
||
def __init__( | ||
self, | ||
lifting_operator, | ||
projection_operator, | ||
edge_features, | ||
n_layers=10, | ||
internal_n_layers=0, | ||
inner_size=None, | ||
internal_layers=None, | ||
internal_func=None, | ||
external_func=None, | ||
shared_weights=True | ||
): | ||
""" | ||
The Graph Neural Operator constructor. | ||
|
||
:param lifting_operator: The lifting operator mapping the node features to its hidden dimension. | ||
:type lifting_operator: torch.nn.Module | ||
:param projection_operator: The projection operator mapping the hidden representation of the nodes features to the output function. | ||
:type projection_operator: torch.nn.Module | ||
:param edge_features: Number of edge features. | ||
:type edge_features: int | ||
:param n_layers: The number of kernel layers. | ||
:type n_layers: int | ||
:param internal_n_layers: The number of layers the Feed Forward Neural Network internal to each Kernel Layer. | ||
:type internal_n_layers: int | ||
:param internal_layers: Number of neurons of hidden layers(s) in the FF Neural Network inside for each Kernel Layer. | ||
:type internal_layers: list | tuple | ||
:param internal_func: The activation function used inside the computation of the representation of the edge features in the Graph Integral Layer. | ||
:type internal_func: torch.nn.Module | ||
:param external_func: The activation function applied to the output of the Graph Integral Kernel. | ||
:type external_func: torch.nn.Module | ||
:param shared_weights: If ``True`` the weights of the Graph Integral Layers are shared. | ||
:type shared_weights: bool | ||
""" | ||
|
||
if internal_func is None: | ||
internal_func = Tanh | ||
if external_func is None: | ||
external_func = Tanh | ||
|
||
super().__init__( | ||
lifting_operator=lifting_operator, | ||
integral_kernels=GraphNeuralKernel( | ||
width=lifting_operator.out_features, | ||
edge_features=edge_features, | ||
internal_n_layers=internal_n_layers, | ||
inner_size=inner_size, | ||
internal_layers=internal_layers, | ||
external_func=external_func, | ||
internal_func=internal_func, | ||
n_layers=n_layers, | ||
shared_weights=shared_weights | ||
), | ||
projection_operator=projection_operator | ||
) | ||
|
||
def forward(self, x): | ||
""" | ||
The forward pass of the Graph Neural Operator. | ||
|
||
:param x: The input batch. | ||
:type x: torch_geometric.data.Batch | ||
""" | ||
x, edge_index, edge_attr = x.x, x.edge_index, x.edge_attr | ||
x = self.lifting_operator(x) | ||
x = self.integral_kernels(x, edge_index, edge_attr) | ||
x = self.projection_operator(x) | ||
return x |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import torch | ||
from torch_geometric.nn import MessagePassing | ||
|
||
|
||
class GNOBlock(MessagePassing): | ||
""" | ||
TODO: Add documentation | ||
""" | ||
|
||
def __init__( | ||
self, | ||
width, | ||
edges_features, | ||
n_layers=2, | ||
layers=None, | ||
inner_size=None, | ||
internal_func=None, | ||
external_func=None | ||
): | ||
""" | ||
Initialize the Graph Integral Layer, inheriting from the MessagePassing class of PyTorch Geometric. | ||
|
||
:param width: The width of the hidden representation of the nodes features | ||
:type width: int | ||
:param edges_features: The number of edge features. | ||
:type edges_features: int | ||
:param n_layers: The number of layers in the Feed Forward Neural Network used to compute the representation of the edges features. | ||
:type n_layers: int | ||
""" | ||
from pina.model import FeedForward | ||
super(GNOBlock, self).__init__(aggr='mean') | ||
self.width = width | ||
if layers is None and inner_size is None: | ||
inner_size = width | ||
self.dense = FeedForward(input_dimensions=edges_features, | ||
output_dimensions=width ** 2, | ||
n_layers=n_layers, | ||
layers=layers, | ||
inner_size=inner_size, | ||
func=internal_func) | ||
self.W = torch.nn.Linear(width, width) | ||
self.func = external_func() | ||
|
||
def message(self, x_j, edge_attr): | ||
""" | ||
This function computes the message passed between the nodes of the graph. Overwrite the default message function defined in the MessagePassing class. | ||
|
||
:param x_j: The node features of the neighboring. | ||
:type x_j: torch.Tensor | ||
:param edge_attr: The edge features. | ||
:type edge_attr: torch.Tensor | ||
:return: The message passed between the nodes of the graph. | ||
:rtype: torch.Tensor | ||
""" | ||
x = self.dense(edge_attr).view(-1, self.width, self.width) | ||
return torch.einsum('bij,bj->bi', x, x_j) | ||
|
||
def update(self, aggr_out, x): | ||
""" | ||
This function updates the node features of the graph. Overwrite the default update function defined in the MessagePassing class. | ||
|
||
:param aggr_out: The aggregated messages. | ||
:type aggr_out: torch.Tensor | ||
:param x: The node features. | ||
:type x: torch.Tensor | ||
:return: The updated node features. | ||
:rtype: torch.Tensor | ||
""" | ||
aggr_out = aggr_out + self.W(x) | ||
return aggr_out | ||
|
||
def forward(self, x, edge_index, edge_attr): | ||
""" | ||
The forward pass of the Graph Integral Layer. | ||
|
||
:param x: Node features. | ||
:type x: torch.Tensor | ||
:param edge_index: Edge index. | ||
:type edge_index: torch.Tensor | ||
:param edge_attr: Edge features. | ||
:type edge_attr: torch.Tensor | ||
:return: Output of a single iteration over the Graph Integral Layer. | ||
:rtype: torch.Tensor | ||
""" | ||
return self.func( | ||
self.propagate(edge_index, x=x, edge_attr=edge_attr) | ||
) |
Uh oh!
There was an error while loading. Please reload this page.