Kolmogorov Arnold Network & Multi-Layered Perceptron
The Kolmogorov-Arnold Network (KAN) has recently gained significant attention in the AI community for its potential as an alternative to the Multi-Layered Perceptron (MLP). Built on the Kolmogorov-Arnold representation theorem, this innovative neural network model can represent any multivariate continuous function on a bounded domain, with applications across machine learning and data science. Its resurgence in research promises to unlock new efficiencies and capabilities in complex computational tasks, positioning it at the forefront of modern AI advancements.
In the <a href="https://www.s-tronomic.in/post/118" target="_blank" rel="noopener noreferrer">last blog</a>, we discussed the mathematics required to understand Kolmogorov-Arnold Networks, which will help you follow how KANs work in this article. We also discussed the Kolmogorov-Arnold Representation theorem, which explains why KANs should work. In <a href="https://www.s-tronomic.in/post/33" target="_blank" rel="noopener noreferrer">earlier articles</a>, we have also discussed how MLPs, or Artificial Neural Networks, work. In this blog, we will discuss the Universal Approximation Theorem, briefly cover the working of MLPs, and then compare MLPs to KANs. We will also look at some advantages and possible problems with KANs. While writing this blog, we referred heavily to the original paper on <a href="https://arxiv.org/abs/2404.19756" target="_blank" rel="noopener noreferrer">KAN</a>.
<h2><b><center>Universal Approximation Theorem</center></b></h2>
Before discussing how a Multi-Layered Perceptron works, we should discuss why it works. One of the biggest shortcomings of MLPs is that they are not interpretable. Still, it is well known that a neural network can learn the distribution of any data if an appropriate architecture is selected and enough data is present. But how can we say that?
The Universal Approximation Theorem is a fundamental result in neural network theory, stating that a neural network with at least one hidden layer, given sufficient neurons and an appropriate activation function, can approximate any continuous function on a compact subset of R<sup>n</sup> to any desired degree of accuracy.
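To make this concrete, here is a minimal sketch of the theorem in action: a single-hidden-layer network fit to sin(x) on the compact interval [-π, π]. PyTorch is assumed here, and the layer width, activation, and training settings are illustrative choices, not prescribed by the theorem.
<pre><code>
# A single-hidden-layer MLP approximating a continuous function on a compact interval.
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.linspace(-math.pi, math.pi, 256).unsqueeze(1)   # inputs on a compact set
y = torch.sin(x)                                          # continuous target function

# One hidden layer, enough neurons, and a non-linear activation.
model = nn.Sequential(nn.Linear(1, 64), nn.Tanh(), nn.Linear(64, 1))
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

for step in range(2000):
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()

print(f"final MSE: {loss.item():.6f}")   # becomes very small on [-pi, pi]
</code></pre>
With enough hidden neurons, the fit on the compact interval can be made arbitrarily good, which is exactly what the theorem guarantees.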
<h2><b><center>Working of MLPs</center></b></h2>
In this article, we will not discuss the working of MLPs in depth, as we have already done so in other blogs, which you can read <a href="https://www.s-tronomic.in/post/33" target="_blank" rel="noopener noreferrer">here from part 1</a>. A Multi-Layered Perceptron can have any number of layers, but the minimum is 3 (input, one hidden, and output). Each layer transforms its input and passes the output to the next layer. The output of a node is a linear combination of its inputs: each input is multiplied by a learnable weight, the results are summed, and usually a bias term is added.
The linear combination is then passed through an activation function, which is generally non-linear and bounded. <b>The activation functions in MLPs are hard-coded in the architecture and cannot be altered.</b> We generally store the weights between any two layers as a matrix, as this helps in faster computation.
<a href="https://ibb.co/yszb6p7" target="_blank" rel="noopener noreferrer"><img src="https://i.ibb.co/nkX57Bt/Architecture-of-a-MLP-NN-model.png" alt="Architecture-of-a-MLP-NN-model" border="0"></a>
Activation functions are essential for neural networks to learn complex patterns and relationships in data. Without them, neural networks would be limited to modeling linear relationships between inputs and outputs.
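As a quick illustration, here is a sketch of a single MLP layer in matrix form (NumPy is assumed; the sizes and values are made up for the example). The weight matrix and bias are learnable, while the activation function is fixed by the architecture:
<pre><code>
# One MLP layer: y = sigma(W x + b), with a fixed sigmoid non-linearity.
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
W = rng.normal(size=(5, 2))     # learnable weight matrix between a 2-node layer and a 5-node layer
b = np.zeros(5)                 # learnable bias vector

x = np.array([0.3, -1.2])       # input vector
h = sigmoid(W @ x + b)          # linear combination followed by the fixed non-linearity
print(h.shape)                  # (5,) -- one value per node in the next layer
</code></pre>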
<h2><b><center>Kolmogorov-Arnold Networks</center></b></h2>
The Kolmogorov-Arnold Networks look very similar to Multi-Layered Perceptrons. However, unlike MLPs, KANs do not have learnable weights; they have learnable activation functions. So the activation functions are not fixed in KANs; rather, they can be altered. And instead of a weight matrix, KANs have a matrix of activation functions.
So, how do we get from learnable weights to learnable activation functions? As discussed in the previous article, Kolmogorov-Arnold Networks are based on the Kolmogorov-Arnold representation theorem. The theorem says that if <i>f</i> is a multivariate continuous function on a bounded domain, it can be written as a finite composition of continuous functions of a single variable and the binary operation of addition. It was formulated as:
<a href="https://ibb.co/SxpXsGR" target="_blank" rel="noopener noreferrer"><img src="https://i.ibb.co/N3z6VbZ/Kolmogrov-Arnold-Representation-Theorem.png" alt="Kolmogrov-Arnold-Representation-Theorem" border="0"></a>
In the formulation, you can see that the inner summation contains Φ<sub>q,p</sub>(x<sub>p</sub>), and the summed output is the input to the outer function Φ<sub>q</sub>. So, as we do for MLPs, we can represent this in neural network form. Assume there are 2 inputs (x<sub>1</sub> and x<sub>2</sub>). Hence, the input size (= n) is 2. So <i>p</i> ranges from 1 to 2, and <i>q</i> ranges from 1 to 5 (because the outer sum has 2n+1 terms).
The inner functions sit between the input layer and the first hidden layer. Since we have Φ<sub>q,p</sub>(x<sub>p</sub>), each x<sub>p</sub> will pass through five functions Φ<sub>q,p</sub> (since q ranges from 1 to 5). We can form the first two layers like this:
<a href="https://ibb.co/b6fWD3c" target="_blank" rel="noopener noreferrer"><img src="https://i.ibb.co/GdhxNJz/KAN-1st-layer.png" alt="KAN-1st-layer" border="0"></a>
As you can see, we are passing the inputs directly through functions instead of multiplying them by weights. Intuitively, we can say these functions identify the essential features of the inputs by passing them through different "filters" in the form of the functions Φ<sub>q,p</sub>. These features make up the next layer.
We can call this second layer a hidden layer. All the values from the hidden layer (there are five nodes in it) are passed through the next set of activation functions, given by Φ<sub>q</sub> in the Kolmogorov-Arnold Representation theorem. All the outputs are then summed to get the single output in the output layer. The complete network can be seen below:
<a href="https://ibb.co/Tmj0R0D" target="_blank" rel="noopener noreferrer"><img src="https://i.ibb.co/BqW6j6H/KAN-complete.png" alt="KAN-complete" border="0"></a>
<h2><b><center>Learning Activation Functions</center></b></h2>
In the previous section, we understood how a KAN works, but how do we learn the activation functions? Right now, we have inputs (x<sub>1</sub> and x<sub>2</sub>) and 5 nodes in the hidden layer (h<sub>1</sub> to h<sub>5</sub>). We want to learn the functions that map the inputs to the values of the nodes in the hidden layer. This is where we need B-splines, which we discussed in the <a href="https://www.s-tronomic.in/post/118" target="_blank" rel="noopener noreferrer">previous article</a>.
So far, we have represented activation functions with the symbol <i>Φ(x)</i>, and they can be learned using B-splines. Here is how Φ(x) is formulated in KANs:
<center>
Φ(x) = w( b(x) + spline(x) )
b(x) = SiLU(x) = x / (1 + e<sup>-x</sup>)
spline(x) = Σ<sub>i</sub> c<sub>i</sub>B<sub>i</sub>(x)
</center>
In the above expression, w is a factor used to control the overall magnitude of the activation function; it is in principle redundant, since it can be absorbed into b(x) and spline(x). The function b(x) provides a non-zero starting point for the activation and acts as a residual function. Coming to the most important part of the expression, spline(x): since we <b>do not know the control points of the function</b>, those are what we need to learn, and they are given by the coefficients c<sub>i</sub>. <b>B<sub>i</sub>(x) is the B-spline basis function</b> and is fixed (formula given below). Hence, only the values of c<sub>i</sub> are altered and learned.
<a href="https://ibb.co/qpK4NhV" target="_blank" rel="noopener noreferrer"><img src="https://i.ibb.co/m4kY9dg/b-spline-basis-function.png" alt="b-spline-basis-function" border="0"></a>
We use backpropagation to learn good values of the control points (c<sub>i</sub>). We do so by defining a loss function and then iteratively adjusting the control points so that the loss is reduced. You can learn about backpropagation in depth <a href="https://www.s-tronomic.in/post/36">here</a>. <b>As the control points are adjusted to reduce the loss, the B-spline they define approximates the desired activation function more and more closely</b>.
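As a small sketch of the mechanics (PyTorch assumed; the basis matrix below is filled with random values purely for brevity, whereas in practice it would hold the fixed B<sub>i</sub>(x) values from the previous sketch): once the basis values are precomputed, the spline output is linear in the coefficients, so autograd can fit the c<sub>i</sub> directly.
<pre><code>
# Learning the control points c_i by gradient descent on a precomputed basis matrix.
import torch

torch.manual_seed(0)
num_points, num_basis = 100, 8
x = torch.linspace(-2.0, 2.0, num_points)

# Stand-in for the fixed design matrix of basis values B_i(x_j).
# (Random here only for brevity; in practice it is computed from the B-spline recursion.)
B = torch.rand(num_points, num_basis)

c = torch.zeros(num_basis, requires_grad=True)   # the learnable control points c_i
target = torch.sin(x)                            # the activation shape we want phi to take on
opt = torch.optim.Adam([c], lr=1e-1)

for step in range(500):
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(B @ c, target)
    loss.backward()
    opt.step()

print(loss.item())                               # the loss decreases as the c_i are adjusted
</code></pre>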
<h2><b><center>KAN vs MLP</center></b></h2>
The Universal Approximation Theorem backs the MLP, while the Kolmogorov-Arnold Representation theorem backs the KAN. The Universal Approximation Theorem was developed after neural networks had already proven useful at complex tasks, in order to explain why they work so well, whereas KANs were designed around the Kolmogorov-Arnold Representation theorem from the start.
Though in the above example we have shown a KAN with only one hidden layer, which can be written as the composition (Φ<sub>1</sub>∘Φ<sub>2</sub>)(x), we can have multi-layered KANs of the form (Φ<sub>1</sub>∘Φ<sub>2</sub>∘Φ<sub>3</sub>)(x). Assuming an MLP has <i>L</i> layers and each layer has <i>N</i> nodes, the number of learnable parameters for such a network is of the order N<sup>2</sup>L, i.e. <i>O(N<sup>2</sup>L)</i>.
In KANs, since we are learning B-splines, then apart from the number of layers and the number of nodes per layer we also have the grid size (number of spline intervals) of each spline, which we take to be <i>G</i>, and the degree of each spline, <i>k</i>. The number of learnable parameters is then <i>O(N<sup>2</sup>L(G+k))</i> ≈ <i>O(N<sup>2</sup>L·G)</i>. Hence, for networks of the same size, <b>an MLP has fewer learnable parameters than a KAN, but the paper claims that KANs generally require fewer nodes per layer than MLPs for the same tasks and achieve better generalization</b>.
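As a rough, back-of-the-envelope illustration of the scaling (the values of N, L, G, and k below are arbitrary, chosen only for the example):
<pre><code>
# N = nodes per layer, L = number of layers, G = spline grid size, k = spline degree.
N, L, G, k = 10, 3, 5, 3

mlp_params = N * N * L               # O(N^2 L): one N x N weight matrix per layer
kan_params = N * N * L * (G + k)     # O(N^2 L (G + k)): each edge carries about G + k spline coefficients

print(mlp_params, kan_params)        # 300 vs 2400 -- roughly a (G + k) factor more per edge
</code></pre>
This extra (G + k) factor per edge is why, for the same node count, a KAN carries more parameters, and why the paper favours narrower KANs.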
KANs have some properties missing in MLPs, such as <b>grid extension, interpretability, and continual learning</b>. These will be covered in more detail in another article, as this one is already getting long. Grid extension means that, without making any changes to the architecture, we can make a KAN learn more complex functions by increasing the number of control points of its splines. Continual learning means not forgetting what has been learned from previous data before moving on to the next phase or task. MLPs tend to forget what they learned in a previous phase when moved to the next one, but this is not the case with KANs, as shown in the image below.
<a href="https://ibb.co/ZTtZgXS" target="_blank" rel="noopener noreferrer"><img src="https://i.ibb.co/x5P4gLY/continual-learning.png" alt="continual-learning" border="0"></a>
<h2><b><center>Possible Shortcomings</center></b></h2>
One of the biggest criticisms received by the <a href="https://arxiv.org/abs/2404.19756" target="_blank" rel="noopener noreferrer">paper</a> is that it didn't show the performance of KANs on standard datasets like MNIST, so we don't know how KANs will scale to large datasets. A more general shortcoming is the limited domain of splines: a spline is learned over a bounded grid, and outside that grid it will not give a correct approximation of the function. For example, the sine function is defined on the domain (-∞, ∞), but a B-spline fitted on a finite interval cannot approximate it over that whole domain. This can prove to be a problem.
Since B-splines cannot distinguish between stable and unstable functions, they may learn unstable ones when applied to real-world problems, which can be another issue. One limitation stated in the paper itself is that KANs are about ten times slower to train than MLPs with the same number of parameters. MLPs have been researched for a long time, and we have learned a lot about them, such as how to scale them efficiently and what types of activation functions to use. More work on KANs will be required to make them applicable for real-world usage.
- Ojas Srivastava, 03:10 PM, 15 Jun, 2024