Skip to content

Commit 54a4d90

Browse files
Merge pull request #764 from mathLab/0.3-kan
Kolmogorov-Arnold Networks and Vectorized Splines
2 parents 77ea0c4 + 6caa873 commit 54a4d90

13 files changed

Lines changed: 1463 additions & 7 deletions

File tree

docs/source/_rst/_code.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ Models
110110
PirateNet <model/pirate_network.rst>
111111
EquivariantGraphNeuralOperator <model/equivariant_graph_neural_operator.rst>
112112
SINDy <model/sindy.rst>
113+
Vectorized Spline <model/vectorized_spline.rst>
114+
Kolmogorov-Arnold Network <model/kolmogorov_arnold_network.rst>
113115

114116
Blocks
115117
-------------
@@ -128,6 +130,7 @@ Blocks
128130
Continuous Convolution Block <model/block/convolution.rst>
129131
Orthogonal Block <model/block/orthogonal.rst>
130132
PirateNet Block <model/block/pirate_network_block.rst>
133+
KAN Block <model/block/kan_block.rst>
131134

132135
Message Passing
133136
-------------------
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
KANBlock
2+
=======================
3+
.. currentmodule:: pina.model.block.kan_block
4+
5+
.. autoclass:: pina._src.model.block.kan_block.KANBlock
6+
:members:
7+
:show-inheritance:
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
KolmogorovArnoldNetwork
2+
===========================
3+
.. currentmodule:: pina.model.kolmogorov_arnold_network
4+
5+
.. autoclass:: pina._src.model.kolmogorov_arnold_network.KolmogorovArnoldNetwork
6+
:members:
7+
:show-inheritance:
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
VectorizedSpline
2+
=======================
3+
.. currentmodule:: pina.model.vectorized_spline
4+
5+
.. autoclass:: pina._src.model.vectorized_spline.VectorizedSpline
6+
:members:
7+
:show-inheritance:

pina/_src/model/block/kan_block.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
"""Module for the Kolmogorov-Arnold Network block."""
2+
3+
import torch
4+
from pina._src.model.vectorized_spline import VectorizedSpline
5+
from pina._src.core.utils import check_consistency, check_positive_integer
6+
7+
8+
class KANBlock(torch.nn.Module):
    """
    The inner block of the Kolmogorov-Arnold Network (KAN).

    The block applies a spline transformation to the input, optionally combined
    with a linear transformation of a base activation function. The output is
    aggregated across input dimensions to produce the final output.

    .. seealso::

        **Original reference**:
        Liu Z., Wang Y., Vaidya S., Ruehle F., Halverson J., Soljacic M.,
        Hou T., Tegmark M. (2025).
        *KAN: Kolmogorov-Arnold Networks*.
        DOI: `arXiv preprint arXiv:2404.19756.
        <https://arxiv.org/abs/2404.19756>`_
    """

    def __init__(
        self,
        input_dimensions,
        output_dimensions,
        spline_order=3,
        n_knots=10,
        grid_range=(0, 1),
        base_function=torch.nn.SiLU,
        use_base_linear=True,
        use_bias=True,
        init_scale_spline=1e-2,
        init_scale_base=1.0,
    ):
        """
        Initialization of the :class:`KANBlock` class.

        :param int input_dimensions: The number of input features.
        :param int output_dimensions: The number of output features.
        :param int spline_order: The order of each spline basis function.
            Default is 3 (cubic splines).
        :param int n_knots: The number of knots for each spline basis function.
            The boundary values of ``grid_range`` are each repeated
            ``spline_order`` times, so the knot vector always holds at least
            ``2 * spline_order`` knots, even when ``n_knots`` is smaller.
            Default is 10.
        :param grid_range: The range for the spline knots. It must be either a
            list or a tuple of the form [min, max]. Default is (0, 1).
        :type grid_range: list | tuple.
        :param torch.nn.Module base_function: The base activation function to be
            applied to the input before the linear transformation. It must be
            passed as a class (not an instance), since it is instantiated
            without arguments. Default is :class:`torch.nn.SiLU`.
        :param bool use_base_linear: Whether to include a linear transformation
            of the base function output. Default is True.
        :param bool use_bias: Whether to include a bias term in the output.
            Default is True.
        :param init_scale_spline: The scale for initializing each spline
            control points. Default is 1e-2.
        :type init_scale_spline: float | int.
        :param init_scale_base: The scale for initializing the base linear
            weights. Default is 1.0.
        :type init_scale_base: float | int.
        :raises ValueError: If ``grid_range`` is not a list or tuple of
            length 2.
        """
        super().__init__()

        # Check consistency
        check_consistency(base_function, torch.nn.Module, subclass=True)
        check_positive_integer(input_dimensions, strict=True)
        check_positive_integer(output_dimensions, strict=True)
        check_positive_integer(spline_order, strict=True)
        check_positive_integer(n_knots, strict=True)
        check_consistency(use_base_linear, bool)
        check_consistency(use_bias, bool)
        check_consistency(init_scale_spline, (int, float))
        check_consistency(init_scale_base, (int, float))
        check_consistency(grid_range, (int, float))

        # Raise error if grid_range is not valid. The container type is checked
        # explicitly so that a scalar fails with the documented ValueError
        # instead of a TypeError raised by len().
        if not isinstance(grid_range, (list, tuple)) or len(grid_range) != 2:
            raise ValueError("Grid must be a list or tuple with two elements.")

        # NOTE(review): the ordering grid_range[0] < grid_range[1] is not
        # enforced here -- confirm how VectorizedSpline treats such knots.
        grid_min, grid_max = grid_range

        # Boundary knots: each end of the range is repeated spline_order times
        initial_knots = torch.ones(spline_order) * grid_min
        final_knots = torch.ones(spline_order) * grid_max

        # Number of internal knots (never negative, see n_knots documentation)
        n_internal = max(0, n_knots - 2 * spline_order)

        # Internal knots are uniformly spaced in the grid range; the two
        # endpoints are dropped because the boundary knots already cover them
        internal_knots = torch.linspace(grid_min, grid_max, n_internal + 2)[
            1:-1
        ]

        # Define the knots: the same knot vector is used for every input
        knots = torch.cat((initial_knots, internal_knots, final_knots))
        knots = knots.unsqueeze(0).repeat(input_dimensions, 1)

        # Define the control points for the spline basis functions: one set of
        # (number of knots - order) coefficients per (input, output) pair
        control_points = (
            torch.randn(
                input_dimensions,
                output_dimensions,
                knots.shape[-1] - spline_order,
            )
            * init_scale_spline
        )

        # Define the vectorized spline module
        self.spline = VectorizedSpline(
            order=spline_order, knots=knots, control_points=control_points
        )

        # Initialize the base function
        self.base_function = base_function()

        # Initialize the base linear weights if needed
        if use_base_linear:
            self.base_weight = torch.nn.Parameter(
                torch.randn(output_dimensions, input_dimensions)
                * (init_scale_base / (input_dimensions**0.5))
            )
        else:
            # Registered as None so that the attribute always exists
            self.register_parameter("base_weight", None)

        # Initialize the bias term if needed
        if use_bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_dimensions))
        else:
            # Registered as None so that the attribute always exists
            self.register_parameter("bias", None)

    def forward(self, x):
        """
        Forward pass of the Kolmogorov-Arnold block. The input is passed through
        the spline transformation, optionally combined with a linear
        transformation of the base function output, and then aggregated across
        input dimensions to produce the final output.

        :param x: The input tensor for the model. When the base linear term is
            enabled it must be two-dimensional, of shape
            ``(batch, input_dimensions)``.
        :type x: torch.Tensor | LabelTensor
        :return: The output tensor of the model.
        :rtype: torch.Tensor | LabelTensor
        """
        # Presumably of shape (batch, input_dimensions, output_dimensions),
        # matching the base term below -- confirm against VectorizedSpline
        y = self.spline(x)

        if self.base_weight is not None:
            base_x = self.base_function(x)
            # Per-input contribution, shape (batch, input, output): the sum
            # over inputs is deferred so it is shared with the spline term
            base_out = torch.einsum("bi,oi->bio", base_x, self.base_weight)
            y = y + base_out

        # aggregate contributions from all input dimensions
        y = y.sum(dim=1)

        if self.bias is not None:
            y = y + self.bias

        return y
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import torch
2+
from pina._src.model.block.kan_block import KANBlock
3+
from pina._src.core.utils import check_consistency
4+
5+
6+
class KolmogorovArnoldNetwork(torch.nn.Module):
    """
    Implementation of Kolmogorov-Arnold Network (KAN).

    The model consists of a sequence of KAN blocks, where each block applies a
    spline transformation to the input, optionally combined with a linear
    transformation of a base activation function.

    .. seealso::

        **Original reference**:
        Liu Z., Wang Y., Vaidya S., Ruehle F., Halverson J., Soljacic M.,
        Hou T., Tegmark M. (2025).
        *KAN: Kolmogorov-Arnold Networks*.
        DOI: `arXiv preprint arXiv:2404.19756.
        <https://arxiv.org/abs/2404.19756>`_
    """

    def __init__(
        self,
        layers,
        spline_order=3,
        n_knots=10,
        grid_range=(-1, 1),
        base_function=torch.nn.SiLU,
        use_base_linear=True,
        use_bias=True,
        init_scale_spline=1e-2,
        init_scale_base=1.0,
    ):
        """
        Initialization of the :class:`KolmogorovArnoldNetwork` class.

        :param layers: A list of integers specifying the sizes of each layer,
            including input and output dimensions. One KAN block is created
            for every pair of consecutive entries.
        :type layers: list | tuple.
        :param int spline_order: The order of each spline basis function.
            Default is 3 (cubic splines).
        :param int n_knots: The number of knots for each spline basis function.
            Default is 10.
        :param grid_range: The range for the spline knots. It must be either a
            list or a tuple of the form [min, max]. Default is (-1, 1).
        :type grid_range: list | tuple.
        :param torch.nn.Module base_function: The base activation function to be
            applied to the input before the linear transformation. Default is
            :class:`torch.nn.SiLU`.
        :param bool use_base_linear: Whether to include a linear transformation
            of the base function output. Default is True.
        :param bool use_bias: Whether to include a bias term in the output.
            Default is True.
        :param init_scale_spline: The scale for initializing each spline
            control points. Default is 1e-2.
        :type init_scale_spline: float | int.
        :param init_scale_base: The scale for initializing the base linear
            weights. Default is 1.0.
        :type init_scale_base: float | int.
        :raises ValueError: If ``layers`` is not a list or tuple with at least
            two elements.
        :raises ValueError: If ``grid_range`` is not of length 2.
        """
        super().__init__()

        # Check consistency -- all other checks are performed in KANBlock
        check_consistency(layers, int)

        # The container type is checked explicitly so that a single integer
        # fails with a ValueError instead of a TypeError raised by len()
        if not isinstance(layers, (list, tuple)) or len(layers) < 2:
            raise ValueError(
                "Provide at least two elements for layers (input and output)."
            )

        # Initialize KAN blocks, one per pair of consecutive layer sizes
        self.kan_layers = torch.nn.ModuleList(
            [
                KANBlock(
                    input_dimensions=n_inputs,
                    output_dimensions=n_outputs,
                    spline_order=spline_order,
                    n_knots=n_knots,
                    grid_range=grid_range,
                    base_function=base_function,
                    use_base_linear=use_base_linear,
                    use_bias=use_bias,
                    init_scale_spline=init_scale_spline,
                    init_scale_base=init_scale_base,
                )
                for n_inputs, n_outputs in zip(layers[:-1], layers[1:])
            ]
        )

    def forward(self, x):
        """
        Forward pass of the KolmogorovArnoldNetwork model. It passes the input
        through each KAN block in the network and returns the final output.

        :param x: The input tensor for the model.
        :type x: torch.Tensor | LabelTensor
        :return: The output tensor of the model.
        :rtype: torch.Tensor | LabelTensor
        """
        # Blocks are applied in the order in which they were created
        for layer in self.kan_layers:
            x = layer(x)

        return x

0 commit comments

Comments
 (0)