diff --git a/pina/model/kolmogorov_arnold_network/kan_layer.py b/pina/model/kolmogorov_arnold_network/kan_layer.py new file mode 100644 index 000000000..ddd360587 --- /dev/null +++ b/pina/model/kolmogorov_arnold_network/kan_layer.py @@ -0,0 +1,223 @@ +"""Create the infrastructure for a KAN layer""" +import torch +import numpy as np + +from pina.model.spline import Spline + + +class KAN_layer(torch.nn.Module): + """define a KAN layer using splines""" + def __init__(self, k: int, input_dimensions: int, output_dimensions: int, inner_nodes: int, num=3, grid_eps=0.1, grid_range=[-1, 1], grid_extension=True, noise_scale=0.1, base_function=torch.nn.SiLU(), scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, sparse_init=True, sp_trainable=True, sb_trainable=True) -> None: + """ + Initialize the KAN layer. + """ + super().__init__() + self.k = k + self.input_dimensions = input_dimensions + self.output_dimensions = output_dimensions + self.inner_nodes = inner_nodes + self.num = num + self.grid_eps = grid_eps + self.grid_range = grid_range + self.grid_extension = grid_extension + + if sparse_init: + self.mask = torch.nn.Parameter(self.sparse_mask(input_dimensions, output_dimensions)).requires_grad_(False) + else: + self.mask = torch.nn.Parameter(torch.ones(input_dimensions, output_dimensions)).requires_grad_(False) + + grid = torch.linspace(grid_range[0], grid_range[1], steps=self.num + 1)[None,:].expand(self.input_dimensions, self.num+1) + + if grid_extension: + h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1) + for i in range(self.k): + grid = torch.cat([grid[:, [0]] - h, grid], dim=1) + grid = torch.cat([grid, grid[:, [-1]] + h], dim=1) + + n_coef = grid.shape[1] - (self.k + 1) + + control_points = torch.nn.Parameter( + torch.randn(self.input_dimensions, self.output_dimensions, n_coef) * noise_scale + ) + + self.spline = Spline(order=self.k+1, knots=grid, control_points=control_points, grid_extension=grid_extension) + + self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(input_dimensions) + \ + scale_base_sigma * (torch.rand(input_dimensions, output_dimensions)*2-1) * 1/np.sqrt(input_dimensions), requires_grad=sb_trainable) + self.scale_spline = torch.nn.Parameter(torch.ones(input_dimensions, output_dimensions) * scale_sp * 1 / np.sqrt(input_dimensions) * self.mask, requires_grad=sp_trainable) + self.base_function = base_function + + @staticmethod + def sparse_mask(in_dimensions: int, out_dimensions: int) -> torch.Tensor: + ''' + get sparse mask + ''' + in_coord = torch.arange(in_dimensions) * 1/in_dimensions + 1/(2*in_dimensions) + out_coord = torch.arange(out_dimensions) * 1/out_dimensions + 1/(2*out_dimensions) + + dist_mat = torch.abs(out_coord[:,None] - in_coord[None,:]) + in_nearest = torch.argmin(dist_mat, dim=0) + in_connection = torch.stack([torch.arange(in_dimensions), in_nearest]).permute(1,0) + out_nearest = torch.argmin(dist_mat, dim=1) + out_connection = torch.stack([out_nearest, torch.arange(out_dimensions)]).permute(1,0) + all_connection = torch.cat([in_connection, out_connection], dim=0) + mask = torch.zeros(in_dimensions, out_dimensions) + mask[all_connection[:,0], all_connection[:,1]] = 1. + return mask + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the KAN layer. + Each input goes through: w_base*base(x) + w_spline*spline(x) + Then sum across input dimensions for each output node. + """ + if hasattr(x, 'tensor'): + x_tensor = x.tensor + else: + x_tensor = x + + base = self.base_function(x_tensor) # (batch, input_dimensions) + + basis = self.spline.basis(x_tensor, self.spline.k, self.spline.knots) + spline_out_per_input = torch.einsum("bil,iol->bio", basis, self.spline.control_points) + + base_term = self.scale_base[None, :, :] * base[:, :, None] + spline_term = self.scale_spline[None, :, :] * spline_out_per_input + combined = base_term + spline_term + combined = self.mask[None,:,:] * combined + + output = torch.sum(combined, dim=1) # (batch, output_dimensions) + + return output + + def update_grid_from_samples(self, x: torch.Tensor, mode: str = 'sample'): + """ + Update grid from input samples to better fit data distribution. + Based on PyKAN implementation but with boundary preservation. + """ + # Convert LabelTensor to regular tensor for spline operations + if hasattr(x, 'tensor'): + # This is a LabelTensor, extract the tensor part + x_tensor = x.tensor + else: + x_tensor = x + + with torch.no_grad(): + batch_size = x_tensor.shape[0] + x_sorted = torch.sort(x_tensor, dim=0)[0] # (batch_size, input_dimensions) + + # Get current number of intervals (excluding extensions) + if self.grid_extension: + num_interval = self.spline.knots.shape[1] - 1 - 2*self.k + else: + num_interval = self.spline.knots.shape[1] - 1 + + def get_grid(num_intervals: int): + """PyKAN-style grid creation with boundary preservation""" + ids = [int(batch_size * i / num_intervals) for i in range(num_intervals)] + [-1] + grid_adaptive = x_sorted[ids, :].transpose(0, 1) # (input_dimensions, num_intervals+1) + + original_min = self.grid_range[0] + original_max = self.grid_range[1] + + # Clamp adaptive grid to not shrink beyond original domain + grid_adaptive[:, 0] = torch.min(grid_adaptive[:, 0], + torch.full_like(grid_adaptive[:, 0], original_min)) + grid_adaptive[:, -1] = torch.max(grid_adaptive[:, -1], + torch.full_like(grid_adaptive[:, -1], original_max)) + + margin = 0.0 + h = (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin) / num_intervals + grid_uniform = (grid_adaptive[:, [0]] - margin + + h * torch.arange(num_intervals + 1, device=x_tensor.device, dtype=x_tensor.dtype)[None, :]) + + grid_blended = (self.grid_eps * grid_uniform + + (1 - self.grid_eps) * grid_adaptive) + + return grid_blended + + # Create augmented evaluation points: samples + boundary points + # This ensures we preserve boundary behavior while adapting to sample density + boundary_points = torch.tensor([[self.grid_range[0]], [self.grid_range[1]]], + device=x_tensor.device, dtype=x_tensor.dtype).expand(-1, self.input_dimensions) + + # Combine samples with boundary points for evaluation + x_augmented = torch.cat([x_sorted, boundary_points], dim=0) + x_augmented = torch.sort(x_augmented, dim=0)[0] # Re-sort with boundaries included + + # Evaluate current spline at augmented points (samples + boundaries) + basis = self.spline.basis(x_augmented, self.spline.k, self.spline.knots) + y_eval = torch.einsum("bil,iol->bio", basis, self.spline.control_points) + + # Create new grid + new_grid = get_grid(num_interval) + + if mode == 'grid': + # For 'grid' mode, use denser sampling + sample_grid = get_grid(2 * num_interval) + x_augmented = sample_grid.transpose(0, 1) # (batch_size, input_dimensions) + basis = self.spline.basis(x_augmented, self.spline.k, self.spline.knots) + y_eval = torch.einsum("bil,iol->bio", basis, self.spline.control_points) + + # Add grid extensions if needed + if self.grid_extension: + h = (new_grid[:, [-1]] - new_grid[:, [0]]) / (new_grid.shape[1] - 1) + for i in range(self.k): + new_grid = torch.cat([new_grid[:, [0]] - h, new_grid], dim=1) + new_grid = torch.cat([new_grid, new_grid[:, [-1]] + h], dim=1) + + # Update grid and refit coefficients + self.spline.knots = new_grid + + try: + # Refit coefficients using augmented points (preserves boundaries) + self.spline.compute_control_points(x_augmented, y_eval) + except Exception as e: + print(f"Warning: Failed to update coefficients during grid refinement: {e}") + + def update_grid_resolution(self, new_num: int): + """ + Update grid resolution to a new number of intervals. + """ + with torch.no_grad(): + # Sample the current spline function on a dense grid + x_eval = torch.linspace( + self.grid_range[0], + self.grid_range[1], + steps=2 * new_num, + device=self.spline.knots.device + ) + x_eval = x_eval.unsqueeze(1).expand(-1, self.input_dimensions) + + basis = self.spline.basis(x_eval, self.spline.k, self.spline.knots) + y_eval = torch.einsum("bil,iol->bio", basis, self.spline.control_points) + + # Update num and create a new grid + self.num = new_num + new_grid = torch.linspace( + self.grid_range[0], + self.grid_range[1], + steps=self.num + 1, + device=self.spline.knots.device + ) + new_grid = new_grid[None, :].expand(self.input_dimensions, self.num + 1) + + if self.grid_extension: + h = (new_grid[:, [-1]] - new_grid[:, [0]]) / (new_grid.shape[1] - 1) + for i in range(self.k): + new_grid = torch.cat([new_grid[:, [0]] - h, new_grid], dim=1) + new_grid = torch.cat([new_grid, new_grid[:, [-1]] + h], dim=1) + + # Update spline with the new grid and re-compute control points + self.spline.knots = new_grid + self.spline.compute_control_points(x_eval, y_eval) + + def get_grid_statistics(self): + """Get statistics about the current grid for debugging/analysis""" + return { + 'grid_shape': self.spline.knots.shape, + 'grid_min': self.spline.knots.min().item(), + 'grid_max': self.spline.knots.max().item(), + 'grid_range': (self.spline.knots.max() - self.spline.knots.min()).mean().item(), + 'num_intervals': self.spline.knots.shape[1] - 1 - (2*self.k if self.spline.grid_extension else 0) + } \ No newline at end of file diff --git a/pina/model/kolmogorov_arnold_network/kan_network.py b/pina/model/kolmogorov_arnold_network/kan_network.py new file mode 100644 index 000000000..cd94a5894 --- /dev/null +++ b/pina/model/kolmogorov_arnold_network/kan_network.py @@ -0,0 +1,194 @@ +"""Kolmogorov Arnold Network implementation""" +import torch +import torch.nn as nn +from typing import List + +try: + from .kan_layer import KAN_layer +except ImportError: + from kan_layer import KAN_layer + +class KAN_Network(torch.nn.Module): + """ + Kolmogorov Arnold Network - A neural network using KAN layers instead of traditional MLP layers. + Each layer uses learnable univariate functions (B-splines + base functions) on edges. + """ + + def __init__( + self, + layer_sizes: List[int], + k: int = 3, + num: int = 3, + grid_eps: float = 0.1, + grid_range: List[float] = [-1, 1], + grid_extension: bool = True, + noise_scale: float = 0.1, + base_function = torch.nn.SiLU(), + scale_base_mu: float = 0.0, + scale_base_sigma: float = 1.0, + scale_sp: float = 1.0, + inner_nodes: int = 5, + sparse_init: bool = False, + sp_trainable: bool = True, + sb_trainable: bool = True, + save_act: bool = True + ): + """ + Initialize the KAN network. + + Args: + layer_sizes: List of integers defining the size of each layer [input_dim, hidden1, hidden2, ..., output_dim] + k: Order of the B-spline + num: Number of grid points for B-splines + grid_eps: Epsilon for grid spacing + grid_range: Range for the grid [min, max] + grid_extension: Whether to extend the grid + noise_scale: Scale for initialization noise + base_function: Base activation function (e.g., SiLU) + scale_base_mu: Mean for base function scaling + scale_base_sigma: Std for base function scaling + scale_sp: Scale for spline functions + """ + super().__init__() + + if len(layer_sizes) < 2: + raise ValueError("Need at least input and output dimensions") + + self.layer_sizes = layer_sizes + self.num_layers = len(layer_sizes) - 1 + self.save_act = save_act + + # Create KAN layers + self.kan_layers = nn.ModuleList() + + for i in range(self.num_layers): + layer = KAN_layer( + k=k, + input_dimensions=layer_sizes[i], + output_dimensions=layer_sizes[i+1], + num=num, + grid_eps=grid_eps, + grid_range=grid_range, + grid_extension=grid_extension, + noise_scale=noise_scale, + base_function=base_function, + scale_base_mu=scale_base_mu, + scale_base_sigma=scale_base_sigma, + scale_sp=scale_sp, + inner_nodes=inner_nodes, + sparse_init=sparse_init, + sp_trainable=sp_trainable, + sb_trainable=sb_trainable + ) + self.kan_layers.append(layer) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the KAN network. + + Args: + x: Input tensor of shape (batch_size, input_dimensions) + + Returns: + Output tensor of shape (batch_size, output_dimensions) + """ + current = x + self.acts = [current] + + for i, layer in enumerate(self.kan_layers): + current = layer(current) + + if self.save_act: + self.acts.append(current.detach()) + + return current + + def get_num_parameters(self) -> int: + """Get total number of trainable parameters""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + + def update_grid_from_samples(self, x: torch.Tensor, mode: str = 'sample'): + """ + Update grid for all layers based on input samples. + This adapts the grid points to better fit the data distribution. + + Args: + x: Input samples, shape (batch_size, input_dimensions) + mode: 'sample' or 'grid' - determines sampling strategy + """ + current = x + + for i, layer in enumerate(self.kan_layers): + layer.update_grid_from_samples(current, mode=mode) + + if i < len(self.kan_layers) - 1: + with torch.no_grad(): + current = layer(current) + + def update_grid_resolution(self, new_num: int): + """ + Update the grid resolution for all layers. + This can be used for adaptive training where grid resolution increases over time. + + Args: + new_num: New number of grid points + """ + for layer in self.kan_layers: + layer.update_grid_resolution(new_num) + + def enable_sparsification(self, threshold: float = 1e-4): + """ + Enable sparsification by setting small weights to zero. + + Args: + threshold: Threshold below which weights are set to zero + """ + with torch.no_grad(): + for layer in self.kan_layers: + # Sparsify scale parameters + layer.scale_base.data[torch.abs(layer.scale_base.data) < threshold] = 0 + layer.scale_spline.data[torch.abs(layer.scale_spline.data) < threshold] = 0 + + # Update mask + layer.mask.data = ((torch.abs(layer.scale_base) >= threshold) | + (torch.abs(layer.scale_spline) >= threshold)).float() + + def get_activation_statistics(self, x: torch.Tensor): + """ + Get statistics about activations for analysis purposes. + + Args: + x: Input tensor + + Returns: + Dictionary with activation statistics + """ + stats = {} + current = x + + for i, layer in enumerate(self.kan_layers): + current = layer(current) + stats[f'layer_{i}'] = { + 'mean': current.mean().item(), + 'std': current.std().item(), + 'min': current.min().item(), + 'max': current.max().item() + } + + return stats + + + def get_network_grid_statistics(self): + """ + Get grid statistics for all layers in the network. + + Returns: + Dictionary with grid statistics for each layer + """ + stats = {} + for i, layer in enumerate(self.kan_layers): + stats[f'layer_{i}'] = layer.get_grid_statistics() + return stats + + \ No newline at end of file diff --git a/pina/model/spline.py b/pina/model/spline.py index c22c7937c..bc854ec56 100644 --- a/pina/model/spline.py +++ b/pina/model/spline.py @@ -9,7 +9,7 @@ class Spline(torch.nn.Module): Spline model class. """ - def __init__(self, order=4, knots=None, control_points=None) -> None: + def __init__(self, order=4, knots=None, control_points=None, grid_extension=True) -> None: """ Initialization of the :class:`Spline` class. @@ -33,6 +33,10 @@ def __init__(self, order=4, knots=None, control_points=None) -> None: self.order = order self.k = order - 1 + self.grid_extension = grid_extension + + # Cache for performance optimization + self._boundary_interval_idx = None if knots is not None and control_points is not None: self.knots = knots @@ -65,45 +69,123 @@ def __init__(self, order=4, knots=None, control_points=None) -> None: else: raise ValueError("Knots and control points cannot be both None.") - if self.knots.ndim != 1: - raise ValueError("Knot vector must be one-dimensional.") + if self.knots.ndim > 2: + raise ValueError("Knot vector must be one or two-dimensional.") + + # Precompute boundary interval index for performance + self._compute_boundary_interval() - def basis(self, x, k, i, t): + def _compute_boundary_interval(self): """ - Recursive method to compute the basis functions of the spline. + Precompute the rightmost non-degenerate interval index for performance. + This avoids the search loop in the basis function on every call. + """ + if not isinstance(self.knots, torch.Tensor): + self._boundary_interval_idx = None + return + + # Find the rightmost interval with positive width + knots = self.knots + + # Handle multi-dimensional knots + if knots.ndim > 1: + # For multi-dimensional knots, we'll handle boundary detection in the basis function + self._boundary_interval_idx = None + return + + # For 1D knots, find the rightmost non-degenerate interval + for i in range(len(knots) - 2, -1, -1): + if knots[i] < knots[i + 1]: # Non-degenerate interval found + self._boundary_interval_idx = i + return + + self._boundary_interval_idx = len(knots) - 2 if len(knots) > 1 else 0 + + def basis(self, x, k, knots): + """ + Compute the basis functions for the spline using an iterative approach. + This is a vectorized implementation based on the Cox-de Boor recursion. :param torch.Tensor x: The points to be evaluated. :param int k: The spline degree. - :param int i: The index of the interval. - :param torch.Tensor t: The tensor of knots. + :param torch.Tensor knots: The tensor of knots. :return: The basis functions evaluated at x :rtype: torch.Tensor """ - if k == 0: - a = torch.where( - torch.logical_and(t[i] <= x, x < t[i + 1]), 1.0, 0.0 + if x.ndim == 1: + x = x.unsqueeze(1) # (batch_size, 1) + if x.ndim == 2: + x = x.unsqueeze(2) # (batch_size, in_dim, 1) + + if knots.ndim == 1: + knots = knots.unsqueeze(0) # (1, n_knots) + if knots.ndim == 2: + knots = knots.unsqueeze(0) # (1, in_dim, n_knots) + + # Base case: k=0 + basis = (x >= knots[..., :-1]) & (x < knots[..., 1:]) + basis = basis.to(x.dtype) + + + if self._boundary_interval_idx is not None: + i = self._boundary_interval_idx + tolerance = 1e-10 + x_squeezed = x.squeeze(-1) + knot_left = knots[..., i] + knot_right = knots[..., i + 1] + + at_right_boundary = torch.abs(x_squeezed - knot_right) <= tolerance + in_rightmost_interval = (x_squeezed >= knot_left) & at_right_boundary + + if torch.any(in_rightmost_interval): + # For points at the boundary, ensure they're included in the rightmost interval + basis[..., i] = torch.logical_or(basis[..., i].bool(), in_rightmost_interval).to(basis.dtype) + + # Iterative step (Cox-de Boor recursion) + for i in range(1, k + 1): + # First term of the recursion + denom1 = knots[..., i:-1] - knots[..., : -(i + 1)] + denom1 = torch.where( + torch.abs(denom1) < 1e-8, torch.ones_like(denom1), denom1 ) - if i == len(t) - self.order - 1: - a = torch.where(x == t[-1], 1.0, a) - a.requires_grad_(True) - return a - - if t[i + k] == t[i]: - c1 = torch.tensor([0.0] * len(x), requires_grad=True) - else: - c1 = (x - t[i]) / (t[i + k] - t[i]) * self.basis(x, k - 1, i, t) + numer1 = x - knots[..., : -(i + 1)] + term1 = (numer1 / denom1) * basis[..., :-1] - if t[i + k + 1] == t[i + 1]: - c2 = torch.tensor([0.0] * len(x), requires_grad=True) - else: - c2 = ( - (t[i + k + 1] - x) - / (t[i + k + 1] - t[i + 1]) - * self.basis(x, k - 1, i + 1, t) + denom2 = knots[..., i + 1 :] - knots[..., 1:-i] + denom2 = torch.where( + torch.abs(denom2) < 1e-8, torch.ones_like(denom2), denom2 ) + numer2 = knots[..., i + 1 :] - x + term2 = (numer2 / denom2) * basis[..., 1:] + + basis = term1 + term2 + + return basis - return c1 + c2 + def compute_control_points(self, x_eval, y_eval): + """ + Compute control points from given evaluations using least squares. + This method fits the control points to match the target y_eval values. + """ + # (batch, in_dim) + A = self.basis(x_eval, self.k, self.knots) + # (batch, in_dim, n_basis) + + in_dim = A.shape[1] + out_dim = y_eval.shape[2] + n_basis = A.shape[2] + c = torch.zeros(in_dim, out_dim, n_basis).to(A.device) + + for i in range(in_dim): + # A_i is (batch, n_basis) + # y_i is (batch, out_dim) + A_i = A[:, i, :] + y_i = y_eval[:, i, :] + c_i = torch.linalg.lstsq(A_i, y_i).solution # (n_basis, out_dim) + c[i, :, :] = c_i.T # (out_dim, n_basis) + + self.control_points = torch.nn.Parameter(c) @property def control_points(self): @@ -131,9 +213,12 @@ def control_points(self, value): dim = value.get("dim", 1) value = torch.zeros(n, dim) + if not isinstance(value, torch.nn.Parameter): + value = torch.nn.Parameter(value) + if not isinstance(value, torch.Tensor): raise ValueError("Invalid value for control_points") - self._control_points = torch.nn.Parameter(value, requires_grad=True) + self._control_points = value @property def knots(self): @@ -180,6 +265,10 @@ def knots(self, value): raise ValueError("Invalid value for knots") self._knots = value + + # Recompute boundary interval when knots change + if hasattr(self, '_boundary_interval_idx'): + self._compute_boundary_interval() def forward(self, x): """ @@ -193,7 +282,21 @@ def forward(self, x): k = self.k c = self.control_points - basis = map(lambda i: self.basis(x, k, i, t)[:, None], range(len(c))) - y = (torch.cat(list(basis), dim=1) * c).sum(axis=1) + # Create the basis functions + # B will have shape (batch, in_dim, n_basis) + B = self.basis(x, k, t) + + # KAN case where control points are (in_dim, out_dim, n_basis) + if c.ndim == 3: + y_ij = torch.einsum("bil,iol->bio", B, c) # (batch, in_dim, out_dim) + # sum over input dimensions + y = torch.sum(y_ij, dim=1) # (batch, out_dim) + # Original test case + else: + B = B.squeeze(1) # (batch, n_basis) + if c.ndim == 1: + y = torch.einsum("bi,i->b", B, c) + else: + y = torch.einsum("bi,ij->bj", B, c) - return y + return y \ No newline at end of file