Mamba is a new neural net architecture that is better than transformers at language modelling. Yes, that's right: after reigning supreme for 7 years, the transformer has finally been dethroned. Well, maybe. So far Mamba has only been tested at small model sizes, up to a few billion parameters, but the results so far are promising! In addition, Mamba uses less compute than transformers. For an input sequence of n words, Mamba only uses O(n log(n)) compute, whereas transformers use O(n^2). So Mamba-based language models should allow for much greater context sizes to be used.

In this video we're going to do a deep dive into the Mamba architecture: what it is, why it works so well, and how you could have gone about designing such an architecture yourself. Usually Mamba is presented as an extension of something called a state-space model. State-space models are another type of sequence model that have been steadily gaining popularity over the past few years, but, to be honest, the theory behind state-space models is massively over-complicated and uses some pretty advanced mathematics. Fortunately, Mamba can also be understood as an extension of recurrent neural networks, or RNNs for short, which are much easier to understand. So in this video we will be taking the RNN path to understanding Mamba.

Now let's get started: what is a recurrent neural network? It helps to first look at convolutional layers. Given a sequence of input vectors, a convolutional layer applies a neural net to consecutive groups of vectors. The key thing is that the neural net only sees a small number of vectors at a time, which makes the model easy to train. The downside is that information from vectors which are far away can't be combined until many convolutional layers have been applied. This makes it difficult for convolutional neural nets to understand long-range dependencies in their input, and such long-range dependencies occur all the time in natural language text. To remedy this flaw, the transformer architecture was invented, which successfully allows a single layer to combine information from vectors no matter how far away they are. I previously made a video explaining how and why transformers work in detail, which you can find here. And while transformers work great, they have a significant limitation: the amount of compute they use is quadratic in the input length. This isn't a huge deal for small inputs, but if you want to have a million vectors in the input, that means you need to do a million times a million operations, which is a lot.

Recurrent neural nets take a completely different approach to improving convolutional layers. The idea is very simple: instead of applying the neural net to two consecutive input vectors, you apply it to one input vector and the previous output of the neural net. This seems like a small change, but it has profound consequences: each output vector now contains information from all of the input vectors prior to it, instead of only the one previous vector. The final output vector contains information from every vector in the input, no matter how many there are. And we have not used any more compute than a convolutional layer. We've managed to incorporate long-range information, for free. This is exactly what we want.
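To make the contrast concrete, here's a minimal sketch in NumPy, with a toy one-layer function standing in for the neural net (the sizes and the `net` function are illustrative assumptions, not the actual architecture):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4                            # toy vector dimension
xs = rng.normal(size=(10, d))    # a sequence of 10 input vectors

# A tiny stand-in "neural net": one linear layer + tanh applied to two d-vectors.
W = 0.1 * rng.normal(size=(d, 2 * d))
net = lambda a, b: np.tanh(W @ np.concatenate([a, b]))

# Convolutional layer (kernel size 2): the net only ever sees two consecutive
# inputs, so every position could be computed at the same time.
conv_out = [net(xs[i - 1], xs[i]) for i in range(1, len(xs))]

# Recurrent layer: the net sees the current input and its own previous output,
# so each output carries information from every earlier input...
y = np.zeros(d)
rnn_out = []
for x in xs:
    y = net(y, x)                # ...but step i can't start until step i-1 is done
    rnn_out.append(y)
```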
Or at least, it would be, if it weren't for 2 small problems with RNNs which make them almost impossible to use in practice.

The first problem is that, while a recurrent layer uses the same amount of compute as a convolutional layer, that compute cannot be parallelized across multiple processors. Even if you have lots of processors available, you can't begin evaluating the neural net on an input until all of the previous steps have finished, because you need to feed the output from the previous step into the neural net. Compare this to a convolutional layer, where the neural net only needs to see the original input: you can run the neural net on all inputs at the same time, so long as you have enough processors available. And since modern hardware, such as GPUs, is highly specialized for parallel computation with thousands of processors, RNNs are actually a lot slower than CNNs in practice. In fact, RNNs are even slower than transformers, despite doing less computation.

And the second problem is that RNNs are incredibly difficult to train. While in theory a single recurrent layer can incorporate information from arbitrarily many inputs, in practice they don't: they only learn to incorporate information from the previous few dozen inputs at most. The idea for RNNs has been around since the 1980s, but because of these 2 problems RNNs have fallen out of favour, with convolutional neural nets and transformers being much more successful in practice. In fact, RNNs have hardly been used at all in the past decade. Until now. Last year, a new paper was published showing that linear RNNs can avoid both of these problems, and therefore linear RNNs are highly effective long-sequence models.

So what is a linear recurrent neural network? Well, you simply replace the neural net with a linear function. This might seem like a bad idea, since linear functions can only perform relatively simple transformations of their inputs, but we can make up for it by applying a full neural net to each output vector afterwards. This is similar to how in transformers you can replace the value neural nets with simple linear functions, and then add neural nets in between self-attention layers to make up for the lack of non-linear processing power. So, just like in a transformer, we will alternate linear recurrent layers with element-wise neural networks. But importantly, by making the recurrent operation purely linear it becomes possible to solve both of the RNN problems. To start with, I'll explain how a linear recurrence applied to n vectors can be computed in parallel in just O(log(n)) time. And then I'll explain how the training issues that plague regular RNNs can be fixed in linear recurrences.

The linear recurrence operator is given by the following formula: to get the i'th output vector, you multiply the previous, (i-1)'th, output vector with a matrix W_y, and add the i'th input vector multiplied by a different matrix W_x, that is, y_i = W_y y_{i-1} + W_x x_i. The entries in the W matrices are the parameters which will be learned by the model, so they start off as random samples from a normal distribution centred at 0, and are then updated with gradient descent. And since the W_x matrix is just applied to each input independently, we can actually think of it as being part of the previous layer, so we can simplify our recurrence operator to just add the input x directly, y_i = W_y y_{i-1} + x_i, assuming that a linear function has already been applied to the input in the previous layer.

A linear recurrence is actually a special case of a more general operation called a scan, so let's start with the simplest example of a scan: a cumulative sum. Given a list of n numbers as input, the goal is to compute the list of partial sums up to each term. So the i'th item in the output list should be the sum of the first i items of the input list.
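For reference, here are the two sequential computations we'll be parallelizing, written out naively (a minimal NumPy sketch; the function names and shapes are just for illustration):

```python
import numpy as np

def linear_recurrence_sequential(W_y, xs):
    """Reference sequential linear recurrent layer: y_i = W_y @ y_{i-1} + x_i
    (W_x is assumed to have been folded into the previous layer)."""
    y = np.zeros(xs.shape[1])
    ys = []
    for x in xs:
        y = W_y @ y + x          # one O(d^2) matrix-vector product per step
        ys.append(y)
    return np.stack(ys)

def cumulative_sum_sequential(nums):
    """The i'th output is the sum of the first i inputs: n sequential additions."""
    out, total = [], 0.0
    for v in nums:
        total += v
        out.append(total)
    return out
```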
While it is trivial to compute this by simply adding the numbers together one at a time, as above, we want to do it in parallel. And it turns out we can, as follows: first, add together each consecutive pair of numbers. Then, from the resulting list, add together pairs of numbers which are 2 steps apart. Then 4 steps apart. Then 8, and so on, with each iteration doubling the step size, until the step size is as large as the entire input list, which happens after log(n) iterations.

This algorithm works because at each iteration, the i'th output element contains the sum of the previous terms, as many of them as the current step size. For example, in the first iteration, each output number is the sum of the previous 2 terms. In the next iteration, each item contains the sum of the previous 2 terms plus the sum of the 2 terms before those, that is, the sum of the previous 4 terms. And so on. When the step size is the size of the input, each output contains the sum of all previous terms, as desired. Each iteration can clearly be computed in parallel, but the different iterations do still need to be computed sequentially, and there are log(n) of them. So, if you have n processors, the total run time of this algorithm is O(log(n)), down from the O(n) of the naive sequential version.

And this same algorithm works for computing the cumulative application of any binary operator, not just addition, so long as the binary operator is associative. Associative means that you can change the order of application and you'll still end up with the same result. This is true of addition, which is why our parallel cumulative sum algorithm works, and it's also true of a bunch of other operations. In particular, this binary operator on (matrix, vector) pairs is associative: f((W1, x1), (W2, x2)) = (W2 W1, W2 x1 + x2). Note that this operator uses a pair of a matrix and a vector as input and output, instead of just a single number like with addition. And remarkably, performing a scan with this operator is equivalent to a linear recurrence. We first need to replace our input list of vectors with a list of pairs, where the first element is the recurrent weight matrix and the second element is the input vector, but then we just perform the scan as usual. You can check for yourself that this operator is in fact associative by expanding a few terms in either order. To summarize, we just need to run our parallel cumulative sum algorithm with this operator in place of addition, and we get the result of a linear recurrent layer in just O(log(n)) time.

Except for one small problem. If you look closely at this operator, the way it works is by using the first element of the tuples as a cumulative matrix, which contains the product of all of the weight matrices seen so far. That's why the first element of the output tuple is the product of the two input matrices. But this means we're performing a [d, d] times [d, d] matrix multiplication in every step, where d is the dimension of the vectors, and that is really slow. Note that in the original sequential RNN we didn't need to keep track of this cumulative matrix, so we only ever multiplied the weight matrix with a length-d input vector at each step, which is an O(d^2) operation. But now we have to do an O(d^3) operation at every step. For standard model sizes, this is easily a thousand-fold increase in computation. And that's bad.
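Here's a sketch of that parallel scan applied to the linear recurrence, using the (matrix, vector) operator above. The doubling loop plays the role of the log(n) parallel iterations, and the d-by-d matrix product inside `combine` is exactly the O(d^3) cost just described. (A plain-NumPy sketch with toy sizes; a real implementation would distribute each round across processors.)

```python
import numpy as np

def combine(a, b):
    # Associative operator on (matrix, vector) pairs; `a` comes earlier in the
    # sequence than `b`. Composing y -> Wa@y + xa and then y -> Wb@y + xb
    # gives y -> (Wb@Wa)@y + (Wb@xa + xb).
    Wa, xa = a
    Wb, xb = b
    return Wb @ Wa, Wb @ xa + xb   # the [d,d] @ [d,d] product is the O(d^3) cost

def linear_recurrence_scan(W, xs):
    # Scan by doubling step sizes: log2(n) rounds of combining pairs `step` apart.
    # Within a round every position is independent, so with enough processors
    # each round is one parallel step (here it's just a list comprehension).
    elems = [(W, x) for x in xs]
    n, step = len(elems), 1
    while step < n:
        elems = [elems[i] if i < step else combine(elems[i - step], elems[i])
                 for i in range(n)]
        step *= 2
    return np.stack([x for _, x in elems])   # the vector part of each pair is y_i

# Sanity check against the sequential recurrence y_i = W @ y_{i-1} + x_i.
rng = np.random.default_rng(0)
d, n = 3, 8
W = 0.3 * rng.normal(size=(d, d))
xs = rng.normal(size=(n, d))

y, ys = np.zeros(d), []
for x in xs:
    y = W @ y + x
    ys.append(y)

assert np.allclose(np.stack(ys), linear_recurrence_scan(W, xs))
```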
Fortunately, there is a way around this: matrix diagonalization. You see, (almost) every square matrix can be factored into the product of an invertible matrix P, a diagonal matrix D, and P^-1, so long as the matrix elements are allowed to be complex numbers. Here's an example. Note that the middle matrix is diagonal, that is, all elements except those on the main diagonal are 0. What's neat about this is that when you multiply the matrix by itself in this form, the inner P^-1 and P terms cancel, and the product of 2 diagonal matrices is just the diagonal matrix whose entries are the products of the corresponding elements. That is, in order to compute D^2, all you need to do is square the elements on the main diagonal of D, which can be done in just O(d) operations instead of O(d^3). Much better.

So what we can do is represent the recurrent weight matrix in diagonalized form, which means we only need to store a complex vector containing the elements of the main diagonal of D. That is to say, we first apply a complex matrix P to the input vectors, then perform the linear recurrence with a complex weight vector w, using element-wise multiplication, and finally apply P^-1 to the output. The result is equivalent to a linear recurrence with some real-valued weight matrix W. But when computed this way, the recurrence operator only needs to perform an element-wise multiplication between two vectors to update the cumulative weights, instead of a matrix multiplication. When we plug this operator into our parallel scan algorithm, the total compute is now just O(d n log(n)), and the parallel run time is O(log(n)). Much better.

Note that the parameters of this layer are the complex entries of the recurrent weight vector w and of the matrix P. In practice you would just use two separate real numbers to represent the real and imaginary components of each parameter, which are initialized by sampling from a normal distribution centred at 0, and updated with gradient descent as usual. Lastly, computing matrix inverses is really slow, so in practice we don't bother, and instead just use 2 independent complex matrices before and after the linear recurrence. This actually makes the model more expressive than a real-valued linear RNN, and it saves computation. But it does mean that the model is no longer equivalent to a real-valued recurrence, and the output can now be a complex number, so we need to take the real part of the output before passing it to the next layer.
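In code, the diagonalized layer might look something like this minimal sketch, with random placeholders standing in for the learned parameters and a sequential loop standing in for the parallel scan:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 4, 16
xs = rng.normal(size=(n, d))     # real-valued inputs from the previous layer

# Learned parameters (random placeholders here): two independent complex
# matrices used before and after the recurrence, and a complex weight
# *vector* w playing the role of the diagonal of the recurrent matrix.
P_in  = rng.normal(size=(d, d)) + 1j * rng.normal(size=(d, d))
P_out = rng.normal(size=(d, d)) + 1j * rng.normal(size=(d, d))
w     = 0.3 * (rng.normal(size=d) + 1j * rng.normal(size=d))

h = np.zeros(d, dtype=complex)
outputs = []
for x in xs:
    u = P_in @ x                       # project the input into the diagonalized space
    h = w * h + u                      # element-wise recurrence: O(d) per step, not O(d^3)
    outputs.append((P_out @ h).real)   # project back, keep only the real part
outputs = np.stack(outputs)            # shape (n, d), ready for the element-wise neural net
```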
Ok, so we've seen how to make linear RNNs fast on modern hardware, but what about the other problem, that RNNs are very difficult to train? Before we solve this problem, here's a quick recap of why training RNNs is so problematic in the first place. Neural nets are trained by subtracting the gradient of the loss function from each weight in the model. What is the gradient? Well, imagine evaluating the loss of the neural net, then increasing the value of a weight by a very small amount, and then evaluating the loss again. The difference between these two values is (proportional to) the gradient for that weight, and it tells you how to change the weight to make the neural net better.

So let's evaluate the gradient of a linear recurrent layer. Actually, to make this a bit easier, let's simplify the model and suppose that every input after the first is 0, so we can just ignore them. When we evaluate the recurrent layer, at each step the previous output is multiplied by the weight vector, so after n steps the output vector is equal to the recurrent weight vector raised to the power of n, times the first input vector: w^n x_1. When we increase the weight by a small amount and evaluate it again, we get (w + ε)^n x_1. Taking the difference, we get, up to a constant scaling factor, w^(n-1) x_1.

The problem here is that as n becomes large, this term, w^(n-1), either gets very small or very large, depending on whether the values in w are less than or greater than 1. Either way it's a problem: if the gradient is very large, then the neural net weights change too much and the existing functionality already learned by the neural net gets destroyed; if the gradient is very small, then the weights don't change enough and the neural net doesn't learn anything at all. This is what makes training RNNs difficult: while in principle RNNs can use infinitely long context, in practice, with gradient-based training techniques, the RNN will only learn to use context for as many steps as the gradient remains the right size for learning. This is known as the problem of vanishing and exploding gradients. And when we add back in non-zero inputs, this problem only gets worse, as the additional inputs make the gradients even more unstable.

And to be clear, the reason why this isn't a problem for regular neural nets is that they use different weights in each layer. Some layers can have weights smaller than 1, and some layers can have weights larger than 1; so long as the gradient remains about the same size overall, the neural net will be able to learn. There are lots and lots of different configurations of weights that result in stable gradients, and it's easy to stay in stable configurations all throughout training. But an RNN uses the same weight at every step, so there is exactly one stable configuration: when the weight is exactly 1.
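To see the scale of the problem concretely, here's a tiny numeric check of that w^(n-1) factor for a single weight value over a 200-step sequence (the specific numbers are just illustrative):

```python
# The w^(n-1) factor from the gradient above, for a single scalar weight.
n = 200                          # sequence length
for w in (0.9, 1.0, 1.1):
    print(f"w = {w}: gradient factor w^(n-1) = {w ** (n - 1):.2e}")

# w = 0.9 -> roughly 8e-10   (vanishing: essentially no learning signal)
# w = 1.0 -> exactly 1       (the one stable value)
# w = 1.1 -> roughly 2e+08   (exploding: updates destroy the learned weights)
```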