|
| 1 | +"""Module for the Kolmogorov-Arnold Network block.""" |
| 2 | + |
| 3 | +import torch |
| 4 | +from pina._src.model.vectorized_spline import VectorizedSpline |
| 5 | +from pina._src.core.utils import check_consistency, check_positive_integer |
| 6 | + |
| 7 | + |
| 8 | +class KANBlock(torch.nn.Module): |
| 9 | + """ |
| 10 | + The inner block of the Kolmogorov-Arnold Network (KAN). |
| 11 | +
|
| 12 | + The block applies a spline transformation to the input, optionally combined |
| 13 | + with a linear transformation of a base activation function. The output is |
| 14 | + aggregated across input dimensions to produce the final output. |
| 15 | +
|
| 16 | + .. seealso:: |
| 17 | +
|
| 18 | + **Original reference**: |
| 19 | + Liu Z., Wang Y., Vaidya S., Ruehle F., Halverson J., Soljacic M., |
| 20 | + Hou T., Tegmark M. (2025). |
| 21 | + *KAN: Kolmogorov-Arnold Networks*. |
| 22 | + DOI: `arXiv preprint arXiv:2404.19756. |
| 23 | + <https://arxiv.org/abs/2404.19756>`_ |
| 24 | + """ |
| 25 | + |
| 26 | + def __init__( |
| 27 | + self, |
| 28 | + input_dimensions, |
| 29 | + output_dimensions, |
| 30 | + spline_order=3, |
| 31 | + n_knots=10, |
| 32 | + grid_range=[0, 1], |
| 33 | + base_function=torch.nn.SiLU, |
| 34 | + use_base_linear=True, |
| 35 | + use_bias=True, |
| 36 | + init_scale_spline=1e-2, |
| 37 | + init_scale_base=1.0, |
| 38 | + ): |
| 39 | + """ |
| 40 | + Initialization of the :class:`KANBlock` class. |
| 41 | +
|
| 42 | + :param int input_dimensions: The number of input features. |
| 43 | + :param int output_dimensions: The number of output features. |
| 44 | + :param int spline_order: The order of each spline basis function. |
| 45 | + Default is 3 (cubic splines). |
| 46 | + :param int n_knots: The number of knots for each spline basis function. |
| 47 | + Default is 10. |
| 48 | + :param grid_range: The range for the spline knots. It must be either a |
| 49 | + list or a tuple of the form [min, max]. Default is [0, 1]. |
| 50 | + :type grid_range: list | tuple. |
| 51 | + :param torch.nn.Module base_function: The base activation function to be |
| 52 | + applied to the input before the linear transformation. Default is |
| 53 | + :class:`torch.nn.SiLU`. |
| 54 | + :param bool use_base_linear: Whether to include a linear transformation |
| 55 | + of the base function output. Default is True. |
| 56 | + :param bool use_bias: Whether to include a bias term in the output. |
| 57 | + Default is True. |
| 58 | + :param init_scale_spline: The scale for initializing each spline |
| 59 | + control points. Default is 1e-2. |
| 60 | + :type init_scale_spline: float | int. |
| 61 | + :param init_scale_base: The scale for initializing the base linear |
| 62 | + weights. Default is 1.0. |
| 63 | + :type init_scale_base: float | int. |
| 64 | + :raises ValueError: If ``grid_range`` is not of length 2. |
| 65 | + """ |
| 66 | + super().__init__() |
| 67 | + |
| 68 | + # Check consistency |
| 69 | + check_consistency(base_function, torch.nn.Module, subclass=True) |
| 70 | + check_positive_integer(input_dimensions, strict=True) |
| 71 | + check_positive_integer(output_dimensions, strict=True) |
| 72 | + check_positive_integer(spline_order, strict=True) |
| 73 | + check_positive_integer(n_knots, strict=True) |
| 74 | + check_consistency(use_base_linear, bool) |
| 75 | + check_consistency(use_bias, bool) |
| 76 | + check_consistency(init_scale_spline, (int, float)) |
| 77 | + check_consistency(init_scale_base, (int, float)) |
| 78 | + check_consistency(grid_range, (int, float)) |
| 79 | + |
| 80 | + # Raise error if grid_range is not valid |
| 81 | + if len(grid_range) != 2: |
| 82 | + raise ValueError("Grid must be a list or tuple with two elements.") |
| 83 | + |
| 84 | + # Knots for the spline basis functions |
| 85 | + initial_knots = torch.ones(spline_order) * grid_range[0] |
| 86 | + final_knots = torch.ones(spline_order) * grid_range[1] |
| 87 | + |
| 88 | + # Number of internal knots |
| 89 | + n_internal = max(0, n_knots - 2 * spline_order) |
| 90 | + |
| 91 | + # Internal knots are uniformly spaced in the grid range |
| 92 | + internal_knots = torch.linspace( |
| 93 | + grid_range[0], grid_range[1], n_internal + 2 |
| 94 | + )[1:-1] |
| 95 | + |
| 96 | + # Define the knots |
| 97 | + knots = torch.cat((initial_knots, internal_knots, final_knots)) |
| 98 | + knots = knots.unsqueeze(0).repeat(input_dimensions, 1) |
| 99 | + |
| 100 | + # Define the control points for the spline basis functions |
| 101 | + control_points = ( |
| 102 | + torch.randn( |
| 103 | + input_dimensions, |
| 104 | + output_dimensions, |
| 105 | + knots.shape[-1] - spline_order, |
| 106 | + ) |
| 107 | + * init_scale_spline |
| 108 | + ) |
| 109 | + |
| 110 | + # Define the vectorized spline module |
| 111 | + self.spline = VectorizedSpline( |
| 112 | + order=spline_order, knots=knots, control_points=control_points |
| 113 | + ) |
| 114 | + |
| 115 | + # Initialize the base function |
| 116 | + self.base_function = base_function() |
| 117 | + |
| 118 | + # Initialize the base linear weights if needed |
| 119 | + if use_base_linear: |
| 120 | + self.base_weight = torch.nn.Parameter( |
| 121 | + torch.randn(output_dimensions, input_dimensions) |
| 122 | + * (init_scale_base / (input_dimensions**0.5)) |
| 123 | + ) |
| 124 | + else: |
| 125 | + self.register_parameter("base_weight", None) |
| 126 | + |
| 127 | + # Initialize the bias term if needed |
| 128 | + if use_bias: |
| 129 | + self.bias = torch.nn.Parameter(torch.zeros(output_dimensions)) |
| 130 | + else: |
| 131 | + self.register_parameter("bias", None) |
| 132 | + |
| 133 | + def forward(self, x): |
| 134 | + """ |
| 135 | + Forward pass of the Kolmogorov-Arnold block. The input is passed through |
| 136 | + the spline transformation, optionally combined with a linear |
| 137 | + transformation of the base function output, and then aggregated across |
| 138 | + input dimensions to produce the final output. |
| 139 | +
|
| 140 | + :param x: The input tensor for the model. |
| 141 | + :type x: torch.Tensor | LabelTensor |
| 142 | + :return: The output tensor of the model. |
| 143 | + :rtype: torch.Tensor | LabelTensor |
| 144 | + """ |
| 145 | + y = self.spline(x) |
| 146 | + |
| 147 | + if self.base_weight is not None: |
| 148 | + base_x = self.base_function(x) |
| 149 | + base_out = torch.einsum("bi,oi->bio", base_x, self.base_weight) |
| 150 | + y = y + base_out |
| 151 | + |
| 152 | + # aggregate contributions from all input dimensions |
| 153 | + y = y.sum(dim=1) |
| 154 | + |
| 155 | + if self.bias is not None: |
| 156 | + y = y + self.bias |
| 157 | + |
| 158 | + return y |
0 commit comments