Kolmogorov-Arnold Networks: Mathematical Foundations
The Kolmogorov-Arnold Network has recently gained significant attention in the AI community as a promising alternative to the Multi-Layer Perceptron (MLP). Built on the Kolmogorov-Arnold representation theorem, this architecture can, in principle, represent any continuous multivariate function on a bounded domain, with potential applications across machine learning and data science. Its resurgence in research promises new efficiencies and capabilities in complex computational tasks.
Before diving into how the Kolmogorov-Arnold Network, or KAN, works, we must understand some mathematical concepts that underlie its architecture. The idea behind KAN is not new, but only recently have researchers been able to put it to use for learning. This article will discuss fitting a curve to n points, Bézier curves, B-splines, and the Kolmogorov-Arnold representation theorem.
<h2><center><b>Data Fitting</b></center></h2>
Before discussing data fitting, it is worth noting why the topic matters. Machine learning is, at its core, about approximating the function that most closely fits the "points", i.e., the data in the available training dataset. All models share this goal, although their approaches differ.
Let's say we have n points, which we can plot in multi-dimensional space. Suppose we want to approximate a function that passes through all these points. We can assume a polynomial of degree n-1 that <i>passes through all these points</i> and then solve for its coefficients by substituting the points one by one, since each of them satisfies the polynomial. For example,
<a href="https://ibb.co/QN77wfp" target="_blank" rel="noopener noreferrer"><img src="https://i.ibb.co/0Mppbcm/data-fitting.png" alt="data-fitting" border="0"></a>
But as you can guess, there are reasons why this method is not generally used for approximating functions. As n, the number of points, increases, the computational cost of solving for the coefficients grows quickly. For complex tasks like image classification, which might require thousands of training images, this method becomes impractical. Another problem is that the approximated function tends to behave "weirdly" near the extremities of the curve, with values fluctuating rapidly, as seen in the image above.
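To make this concrete, here is a minimal sketch in Python (using NumPy, with a hypothetical set of five points) of solving for the coefficients of the degree n-1 polynomial passing through n points:
<pre><code>import numpy as np

# Five hypothetical sample points
x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
y = np.array([1.0, 3.0, 2.0, 5.0, 4.0])

# Substituting each point into the degree-(n-1) polynomial gives one
# linear equation per point: the Vandermonde system V @ c = y.
V = np.vander(x, increasing=True)
c = np.linalg.solve(V, y)

# The resulting polynomial passes through every point exactly.
poly = np.polynomial.Polynomial(c)
print(np.allclose(poly(x), y))  # True
</code></pre>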
<h2><center><b>Bézier Curves</b></center></h2>
Bézier curves are parametric curves with many applications in computer graphics. They approximate a function based on points in space that lie on that function, called <b>control points</b>. Since Bézier curves provide approximations, we might not recover the exact function the points came from; the curve depends heavily on the control points, so a different set of control points yields a "similar" but different curve. However, a Bézier curve always passes through the points at the extremities of the dataset.
To construct the Bézier curve, we pick two points or curves at a time and interpolate between them recursively to get one smooth curve. If the parameter is <i>t</i>, the Bézier curve is given by B(t), where t ∈ [0, 1]. The parameter <i>t</i> can be thought of as the fraction of the curve covered from the start point or, more loosely, as time. To understand how Bézier curves are created, let's take an example. Assume we have two points ℙ<sub>0</sub> and ℙ<sub>1</sub>, which can have any dimensionality. Since we have only two points, all we can approximate is a linear function. B(t), in this case, will be
<center>
<span>B(t) = ℙ<sub>0</sub> + t(ℙ<sub>1</sub> - ℙ<sub>0</sub>)</span>
<span>B(t) = (1-t)ℙ<sub>0</sub> + t·ℙ<sub>1</sub></span>
</center>
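This recursive pairwise interpolation is exactly the de Casteljau construction mentioned later in this article. A minimal Python sketch, with made-up control points for illustration:
<pre><code>import numpy as np

def de_casteljau(points, t):
    """Evaluate a Bezier curve at parameter t by repeatedly
    interpolating between adjacent points until one remains."""
    pts = np.asarray(points, dtype=float)
    while len(pts) > 1:
        pts = (1 - t) * pts[:-1] + t * pts[1:]
    return pts[0]

# With two control points this reduces to the linear formula above.
P0, P1 = [0.0, 0.0], [2.0, 1.0]
print(de_casteljau([P0, P1], 0.5))  # [1.  0.5], the midpoint
</code></pre>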
For a general case where we have n+1 control points, we can find an n-degree Bezier curve using the following formula:
<a href="https://ibb.co/C6yjCZ2" target="_blank" rel="noopener noreferrer"><img src="https://i.ibb.co/m0KnZw8/bezier-curve-formula.png" alt="bezier-curve-formula" border="0"></a>
This can be written in a shorter form:
<a href="https://ibb.co/WyHgjyf" target="_blank" rel="noopener noreferrer"><img src="https://i.ibb.co/pz1J5zd/bazier-curve-basis-function.png" alt="bazier-curve-basis-function" border="0"></a>
Here, <i>b<sub>i,n</sub></i> are called <b>Bernstein basis polynomials</b>, or just <b>basis polynomials</b> for short.
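A short sketch of evaluating a Bézier curve directly from this Bernstein form (again with made-up control points):
<pre><code>import numpy as np
from math import comb

def bernstein(i, n, t):
    """Bernstein basis polynomial b_{i,n}(t)."""
    return comb(n, i) * t**i * (1 - t)**(n - i)

def bezier(points, t):
    """B(t): control points weighted by their basis polynomials."""
    n = len(points) - 1
    return sum(bernstein(i, n, t) * np.asarray(p, dtype=float)
               for i, p in enumerate(points))

pts = [(0, 0), (1, 2), (3, 3), (4, 0)]    # 4 control points: a cubic
print(bezier(pts, 0.0), bezier(pts, 1.0))  # endpoints P0 and P3
</code></pre>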
<h2><center><b>Basis Polynomial Interpretation</b></center></h2>
Basis polynomials give the weight, or contribution, of each control point for a given value of the parameter <i>t</i>. Since <i>t</i> can be seen as time, or the fraction of the "distance" covered from the first control point, the basis polynomials tell us the contribution of a point ℙ at any given time.
Assuming the number of points (n+1) is 4, then <i>n</i> is 3. Using the formula above, we can plot the basis polynomials for all values of <i>i</i> from 0 to <i>n</i>, with <i>t</i> ranging from 0 to 1. Plotting them gives the following graph.
<a href="https://ibb.co/37pzkkV" target="_blank" rel="noopener noreferrer"><img src="https://i.ibb.co/c16LxxR/basis-function-interpretation.png" alt="basis-function-interpretation" border="0"></a>
One thing to remember is that b<sub>0,3</sub>(t) quantifies the contribution of ℙ<sub>0</sub>, b<sub>1,3</sub>(t) quantifies the contribution of ℙ<sub>1</sub>, and so on up to ℙ<sub>3</sub>. As you can see, at t = 0 only b<sub>0,3</sub>(t) is non-zero, and similarly at t = 1 only b<sub>3,3</sub>(t) is non-zero. That is why the Bézier curve always passes through the extreme points. We can also see that as <i>t</i> moves away from 0, the contribution of ℙ<sub>0</sub> decreases while the contributions of the other points increase.
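You can verify these endpoint properties numerically; a quick check for n = 3, reusing the bernstein helper from the sketch above:
<pre><code>from math import comb

def bernstein(i, n, t):
    return comb(n, i) * t**i * (1 - t)**(n - i)

n = 3
for t in (0.0, 0.5, 1.0):
    print(t, [round(bernstein(i, n, t), 3) for i in range(n + 1)])
# 0.0 [1.0, 0.0, 0.0, 0.0]          -> only P0 contributes
# 0.5 [0.125, 0.375, 0.375, 0.125]
# 1.0 [0.0, 0.0, 0.0, 1.0]          -> only P3 contributes
</code></pre>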
<h2><center><b>B-Splines</b></center></h2>
Though Bézier curves solve some problems of the first method we saw, i.e., there are no wild fluctuations at the extremities, they are still computationally expensive and can prove inefficient for a large number of control points. The idea behind B-splines is similar to that of Bézier curves, but with a change that makes them more efficient. In a B-spline, we do not consider all points at once. We pick, say, <i>k</i> points at a time and fit a Bézier curve to them; we then slide the window one point to the right, keeping its size <i>k</i>, and fit a Bézier curve to that window of points. We repeat this until we reach the last point, and then stitch all the Bézier segments together into a single smooth curve, the B-spline.
The window size <i>k</i> is called the <b>order of the B-spline curve</b>; it determines the degree (k-1) of the Bézier segment fitted to each window. The points at which these segments are stitched together are called <b>knots</b>. A B-spline curve <b><i>r(t)</i></b> is defined as a linear combination of the control points ℙ<sub>i</sub> and the B-spline basis functions <b><i>N<sub>i,k</sub>(t)</i></b>, given by:
<a href="https://ibb.co/9w5g9T3" target="_blank" rel="noopener noreferrer"><img src="https://i.ibb.co/CJG857s/b-spline.png" alt="b-spline" border="0"></a>
Here, <i>n+1</i> is the number of control points, and <i>k</i> is the order of the B-spline curve. As you might have noticed, this equation is very similar to the formula for the Bézier curve. The major difference is in the formulation of the B-spline basis function, though its interpretation is the same as that of the Bernstein basis polynomial. You can take a look at the general formula for N<sub>i,k</sub>(t) below. It might seem a bit complicated; this is because a control point can participate in multiple Bézier segments, whose contributions are then "summed" together.
<a href="https://ibb.co/qpK4NhV" target="_blank" rel="noopener noreferrer"><img src="https://i.ibb.co/m4kY9dg/b-spline-basis-function.png" alt="b-spline-basis-function" border="0"></a>
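This recursion (the Cox-de Boor recursion) can be implemented directly. The sketch below favors clarity over efficiency; the control points and the clamped knot vector are made-up examples:
<pre><code>import numpy as np

def N(i, k, t, knots):
    """B-spline basis N_{i,k}(t) via the Cox-de Boor recursion.
    k is the order (degree + 1), as in the formula above."""
    if k == 1:
        return 1.0 if knots[i] <= t < knots[i + 1] else 0.0
    left = right = 0.0
    if knots[i + k - 1] != knots[i]:
        left = (t - knots[i]) / (knots[i + k - 1] - knots[i]) \
               * N(i, k - 1, t, knots)
    if knots[i + k] != knots[i + 1]:
        right = (knots[i + k] - t) / (knots[i + k] - knots[i + 1]) \
                * N(i + 1, k - 1, t, knots)
    return left + right

def bspline(points, k, knots, t):
    """r(t) = sum_i N_{i,k}(t) * P_i"""
    return sum(N(i, k, t, knots) * np.asarray(p, dtype=float)
               for i in range(len(points)))

# 4 control points, order k = 3 (quadratic segments), clamped knots
pts = [(0, 0), (1, 2), (3, 2), (4, 0)]
knots = [0, 0, 0, 1, 2, 2, 2]
print(bspline(pts, 3, knots, 1.0))  # [2. 2.]
</code></pre>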
The following is an example of a B-spline. Control points are shown as black circles and knots as red points. There are 6 control points, and the spline has degree 3.
<a href="https://ibb.co/3MHytdh" target="_blank" rel="noopener noreferrer"><img src="https://i.ibb.co/VQ8WX9B/B-spline-wolfram.png" alt="B-spline-wolfram" border="0"></a>
<h2><center><b>Advantages of B-splines Over Bézier Curves</b></center></h2>
Firstly, B-splines perform much better computationally than Bézier curves. Bézier curves are typically evaluated using the <b>de Casteljau algorithm</b>, which has a complexity of O(n<sup>2</sup>) for a curve of degree n, and the degree grows with the number of control points. B-splines, on the other hand, are most commonly evaluated with <b>De Boor's algorithm</b>, which is efficient and numerically stable, with a complexity of O(p<sup>2</sup>), where p is the degree of the B-spline. Since p is tied to the order of the B-spline rather than the number of control points, it is typically much smaller, making B-spline evaluation computationally cheaper than evaluating a Bézier curve with the same number of control points.
The second, and very important, advantage is locality: since a B-spline is built from multiple smaller, independent Bézier segments, changing or adding a control point only affects the curve in the neighbourhood of that point, so we don't need to re-evaluate the whole B-spline. A Bézier curve, by contrast, is computed from all the points together, so the whole curve must be evaluated again.
<h2><center><b>Kolmogorov-Arnold Representation Theorem</b></center></h2>
The work of Vladimir Arnold and Andrey Kolmogorov established that if <i>f is a multivariate continuous function, then f can be written as a finite composition of continuous functions of a single variable and the binary operation of addition</i>, in the form:
<a href="https://ibb.co/SxpXsGR" target="_blank" rel="noopener noreferrer"><img src="https://i.ibb.co/N3z6VbZ/Kolmogrov-Arnold-Representation-Theorem.png" alt="Kolmogrov-Arnold-Representation-Theorem" border="0"></a>
This means that if the above formulation can be represented as a neural-network-like architecture, we can learn any multivariate continuous function on a bounded domain. The theorem was proved in the 1950s, but it is non-constructive: it does not tell us how to find the representing functions from a given set of points. That is the gap Kolmogorov-Arnold Networks address. We will discuss KANs in the <a href="https://www.s-tronomic.in/post/119" target="_blank" rel="noopener noreferrer">next blog</a>.
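To make the structure of the theorem concrete, here is a small numeric sketch of the form f(x) = Σ<sub>q</sub> Φ<sub>q</sub>(Σ<sub>p</sub> φ<sub>q,p</sub>(x<sub>p</sub>)). The inner and outer functions below are arbitrary placeholders; the theorem guarantees that suitable functions exist for any continuous f, but it does not tell us what they are:
<pre><code>import math

n = 2  # number of input variables x1, x2

# Placeholder inner functions phi_{q,p} and outer functions Phi_q.
# These are NOT the true representing functions for any given f;
# the theorem only guarantees that suitable ones exist.
inner = [[lambda x, q=q, p=p: math.sin((q + 1) * x + p)
          for p in range(n)] for q in range(2 * n + 1)]
outer = [lambda s, q=q: (q + 1) * math.tanh(s)
         for q in range(2 * n + 1)]

def ka_form(x):
    """f(x) = sum_{q=0}^{2n} Phi_q( sum_{p=1}^{n} phi_{q,p}(x_p) )"""
    return sum(outer[q](sum(inner[q][p](x[p]) for p in range(n)))
               for q in range(2 * n + 1))

print(ka_form([0.3, 0.7]))  # a continuous function of two variables
</code></pre>
Note the fixed width of the construction: for n inputs, only 2n + 1 outer functions are needed, regardless of how complicated f is.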
- Ojas Srivastava, 10:57 PM, 13 Jun, 2024