Mamba is a new neural net architecture that
is better than transformers at language modelling. Yes, that’s right: after reigning supreme for 7 years, the transformer has finally been dethroned. Well, maybe: so far Mamba has only been tested
at small model sizes up to a few billion parameters, but the results so far are promising! In addition, Mamba uses less compute than
transformers. For an input sequence of n words, Mamba only
uses O(nlog(n)) compute, whereas transformers use O(n^2). So Mamba-based language models should allow
for much greater context sizes to be used. In this video we’re going to do a deep dive into the Mamba architecture: what it is, why it works so well, and how you could have
gone about designing such an architecture yourself? Usually Mamba is presented as an extension
of something called a state-space model. State-space models are another type of sequence
model that have been steadily gaining popularity over the past few years, but, to be honest,
the theory behind state-space models is massively over-complicated and uses some pretty advanced
mathematics. Fortunately, Mamba can also be understood
as an extension of recurrent neural networks, or RNNs for short, which are much easier to
understand. So in this video we will be taking the RNN
path to understanding Mamba. Now let’s get started: what is a recurrent
neural network? Given a sequence of input vectors, a convolutional
layer applies a neural net to consecutive groups of vectors. The key thing is that the neural net only
sees a small number of vectors at a time, which makes the model easy to train. The downside is that information from vectors
which are far away can’t be combined until many convolutional layers have been applied. This makes it difficult for convolutional
neural nets to understand long range dependencies in their input, and such long-range dependencies
occur all the time in natural language text. To remedy this flaw, the transformer architecture
was invented, which successfully allows a single layer to combine information from vectors
no matter how far away they are. I previously made a video explaining how and
why transformers work in detail, which you can find here. And while transformers work great, they have
a significant limitation, which is that the amount of compute they use is quadratic in
the input length. This isn’t a huge deal for small inputs,
but if you want to have a million vectors in the input, that means you need to do a
million times a million operations, which is a lot. Recurrent neural nets take a completely different
approach to improving convolutional layers. The idea is very simple: instead of applying
the neural net to two consecutive input vectors, you apply it to one input vector and the previous
output of the neural net.
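To make this concrete, here is a minimal NumPy sketch of what that recurrent loop looks like. The tiny tanh network and the sizes are just placeholders I picked for illustration, not anything from a real model:

```python
import numpy as np

def rnn_layer(xs, neural_net):
    """Apply the same neural net at every step to [previous output, current input]."""
    d = xs.shape[1]
    y = np.zeros(d)                # previous output, starts at zero
    outputs = []
    for x in xs:
        # Each new output depends on the previous output, and hence on all earlier inputs.
        y = neural_net(np.concatenate([y, x]))
        outputs.append(y)
    return np.stack(outputs)

# A stand-in "neural net": a single linear map followed by a tanh non-linearity.
rng = np.random.default_rng(0)
d = 4
W = rng.normal(size=(2 * d, d)) * 0.5
outputs = rnn_layer(rng.normal(size=(10, d)), lambda v: np.tanh(v @ W))
print(outputs.shape)  # (10, 4)
```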
This seems like a small change, but it has profound consequences: each output vector now contains information from all of the input
vectors prior to it, instead of only one previous vector. This final output vector contains information
from every vector in the input, no matter how many there are. And we have not used any more compute than
a convolutional layer. We’ve managed to incorporate long-range
information, for free. This is exactly what we want. Or at least, it would be, if it weren’t
for 2 small problems with RNNs which make them almost impossible to use in practice. The first problem is that, while a recurrent
layer uses the same amount of compute as a convolutional layer, that compute cannot be
parallelized across multiple processors. Even if you have lots of processors available,
you can’t begin evaluating the neural net on an input until all of the previous steps
have finished, because you need to feed the output from the previous step into the neural
net. Compare this to a convolutional layer, where
the neural net only needs to see the original input. You can run the neural net on all inputs at
the same time, so long as you have enough processors available. And since modern hardware, such as GPUs, is
highly specialized for parallel computation with thousands of processors, RNNs are actually
a lot slower than CNNs in practice. In fact RNNs are even slower than transformers,
despite doing less computation. And the second problem, is that RNNs are incredibly
difficult to train. While in theory, a single recurrent layer
can incorporate information from arbitrarily many inputs, in practice it doesn’t. Instead, it only learns to incorporate information
from the previous few dozen inputs at most. The idea for RNNs has been around since the
1980s, but because of these 2 problems, RNNs have fallen out of favour, with convolutional
neural nets and transformers being much more successful in practice. In fact, RNNs have hardly been used at all
in the past decade. Until now. Last year, a new paper was published showing
that linear RNNs can avoid both of these problems, and therefore linear RNNs are highly effective
long sequence models. So what is a linear recurrent neural network? Well you simply replace the neural net with
a linear function. This might seem like a bad idea, since linear
functions can only perform relatively simple transformations of their inputs, but we can
make up for it by applying a full neural net to each output vector afterwards. This is similar to how in transformers you
can replace the value neural nets with simple linear functions, and then add neural nets
in between self-attention layers to make up for the lack of non-linear processing power. So just like in a transformer, we will alternate
linear recurrent layers with element-wise neural networks. But importantly, by making the recurrent operation
purely linear it becomes possible to solve both of the RNN problems. To start with I’ll explain how a linear
recurrence applied to n vectors can be computed in parallel in just O(log(n)) time. And then I’ll explain how the training issues
that plague regular RNNs can be fixed in linear recurrences. The linear recurrence operator is given by
this formula: y_i = W_y y_(i-1) + W_x x_i. That is, to get the i’th output vector you multiply the previous, (i-1)’th, output vector with a matrix W_y, and add the i’th input vector multiplied by a different matrix
W_x. The entries in the W matrices are the parameters
which will be learned by the model, so they start off as random samples from a normal
distribution centred at 0, and are then updated with gradient descent. And since the W_x matrix is just applied to
each input independently, we can actually just think of it as being part of the previous
layer, so we can simplify our recurrence operator to just add the input x, assuming that a linear
function has already been applied to the input in the previous layer.
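As a concrete reference point, here is the simplified recurrence written out sequentially in NumPy; the sizes and the random weight matrix are just illustrative:

```python
import numpy as np

def linear_recurrence(xs, W):
    """Sequential linear recurrence: y_i = W @ y_{i-1} + x_i (W_x assumed already applied)."""
    y = np.zeros(xs.shape[1])
    outputs = []
    for x in xs:
        y = W @ y + x
        outputs.append(y)
    return np.stack(outputs)

rng = np.random.default_rng(0)
d, n = 4, 10
W = rng.normal(size=(d, d)) * 0.5
ys = linear_recurrence(rng.normal(size=(n, d)), W)
print(ys.shape)  # (10, 4)
```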
A linear recurrence is actually a special case of a more general operation called a scan, so let’s start with the simplest example
of a scan: a cumulative sum. Given a list of n numbers as input, the goal
is to compute the list of partial sums, up to each term. So the i’th item in the output list should
be the sum of of the first i items of the input list. While it is trivial to compute this by simply
adding the numbers together, one at a time, we want to do it in parallel. And it turns out we can do so as follows:
first add together each consecutive pair of numbers. Then, from the resulting list, add together
pairs of numbers which are 2 steps apart. Then 4 steps apart. And 8… and so on, each iteration doubling
the step size, until the step size is as large as the entire input list, which will be after
log(n) steps. This algorithm works because at each iteration,
the i’th output element contains the sum of the previous step size numbers. For example, in the first iteration, each
output number is the sum of the previous 2 terms. In the next iteration, each item contains
the sum of the previous 2 terms plus the sum of the previous 2 terms starting 2 away, that
is the sum of the previous 4 terms. And so on. When the step size is the size of the input,
each output contains the sum of all previous terms, as desired. It’s trivial to see that each iteration
can be computed in parallel; however, the different iterations do still need to be computed sequentially,
and there are log(n) iterations. So, if you have n processors, the total run
time of this algorithm is O(log(n)), down from the O(n) of the naive sequential version.
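Here’s a small NumPy sketch of that doubling-step algorithm. The slice update stands in for what would be one fully parallel pass per step on real hardware:

```python
import numpy as np

def parallel_cumsum(values):
    """Inclusive prefix sum via the step-doubling scheme described above: log2(n) passes."""
    y = np.array(values, dtype=float)
    step = 1
    while step < len(y):
        # In one (conceptually parallel) pass, every element adds the value `step` positions to its left.
        y[step:] = y[step:] + y[:-step]
        step *= 2
    return y

print(parallel_cumsum([1, 2, 3, 4, 5]))  # [ 1.  3.  6. 10. 15.]
print(np.cumsum([1, 2, 3, 4, 5]))        # same answer, computed sequentially
```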
And this same algorithm works for computing lists of cumulative applications of any binary operator, not just addition, so long as the
binary operator is associative. Associative means that you can change the
order of application and you’ll still end up with the same result. This is true of addition, which is why our
parallel cumulative sum algorithm works. And it’s also true of a bunch of other
operations. In particular, this binary operator is associative:
f((W1, x1), (W2, x2)) = (W2*W1, W2*x1 + x2). Note that this operator uses a pair of a matrix
and a vector as input and output, instead of just a single number like with addition. And remarkably, performing a scan with this
operator is equivalent to a linear recurrence. We first need to replace our input list of
vectors with a list of pairs, where the first element is the recurrent weight matrix and
the second element is the input vector, but then we just perform the scan as usual. You can check for yourself that this operator
is in fact associative by expanding a few terms in both orders. To summarize, we just need to do our parallel
cumulative sum algorithm with this operator in place of addition, and we get the result
of a linear recurrent layer in just O(log(n)) time.
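Here’s a small NumPy check of that claim, written with the sequential fold for clarity (the parallel version just pairs elements up exactly like the cumulative sum did). The operator composes the two affine maps, which is why it’s associative:

```python
import numpy as np

def combine(elem1, elem2):
    """Compose y -> W1 y + x1 followed by y -> W2 y + x2."""
    W1, x1 = elem1
    W2, x2 = elem2
    return (W2 @ W1, W2 @ x1 + x2)

rng = np.random.default_rng(0)
d, n = 3, 6
W = rng.normal(size=(d, d)) * 0.5
xs = [rng.normal(size=d) for _ in range(n)]

# Scan with the combine operator.
acc, scanned = (W, xs[0]), [xs[0]]
for x in xs[1:]:
    acc = combine(acc, (W, x))
    scanned.append(acc[1])

# Reference: the plain recurrence y_i = W y_{i-1} + x_i.
y, reference = np.zeros(d), []
for x in xs:
    y = W @ y + x
    reference.append(y)

print(np.allclose(scanned, reference))  # True
```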
Except for one small problem. If you look closely at this operation, the way it works is by using the first element of the tuples as a cumulative matrix, which
contains the product of all of the matrices seen so far. That’s why the first element of the output
tuple is the product of the two input matrices. But this means we’re performing a [d, d]
times [d, d] matrix multiplication in every step, where d is the dimension of the vectors. This is really slow. Note that in the original sequential RNN we
didn’t need to keep track of this cumulative matrix, and so we only ever multiply the weight
matrix with a length [d] input vector at each step, which is an O(d^2) operation. But now we have to do an O(d^3) operation in
every step. For standard model sizes, this is easily a
thousand fold increase in computation. And that’s bad. Fortunately, there is a way around this: matrix
diagonalization. You see (almost) every square matrix can be
factored into the product of an invertible matrix P, a diagonal matrix D, and P^-1, so
long as the matrix elements are allowed to be complex numbers. Here’s an example. Note that this middle matrix is diagonal,
that is, all elements off the main diagonal are 0. What’s neat about this is that when you multiply
the matrix by itself in this form, the inner P inverse and P terms cancel, and the product
of 2 diagonal matrices is just the diagonal matrix containing the element-wise products of their diagonals. That is, in order to compute D^2, all you
need to do is square the elements on the main diagonal of D, which can be done in just O(d) operations, instead of O(d^3). Much better. So then, what we can do is represent the recurrent
weight matrix in diagonalized form, which means we only need to use a complex vector
which contains the elements of the main diagonal of D. That is to say, we first apply a complex matrix
P to the input vectors, then perform the linear recurrence with a complex weight vector w,
using element-wise multiplication, and finally apply P^-1 to the output. The result of this will then be equivalent
to a linear recurrence for some real valued weight matrix W. But when computed this way,
the recurrence operator only needs to compute element-wise multiplication between two vectors
to update the cumulative weights, instead of matrix multiplication. When we plug this operator into our parallel
scan algorithm, the total compute is now just O(d*n*log(n)), and the parallel run time is
O(log(n)). Much better.
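Here’s a quick numerical check of that equivalence in NumPy. One caveat about conventions: np.linalg.eig factors W as P D P^-1, so in this sketch the inverse maps the inputs in and P maps the outputs back, which is the mirror image of how I phrased it above; either labelling works, it’s the same factorization viewed from the other side:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 4, 8
W = rng.normal(size=(d, d)) * 0.5      # real recurrent weight matrix
xs = rng.normal(size=(n, d))

# Reference: y_i = W y_{i-1} + x_i using full matrix multiplies.
y, reference = np.zeros(d), []
for x in xs:
    y = W @ y + x
    reference.append(y)

# Diagonal form: W = P diag(w) P^-1 with complex entries.
w, P = np.linalg.eig(W)
P_inv = np.linalg.inv(P)

h, outputs = np.zeros(d, dtype=complex), []
for x in xs:
    h = w * h + P_inv @ x            # element-wise update: O(d) instead of O(d^2) or O(d^3)
    outputs.append((P @ h).real)     # imaginary parts cancel, up to floating point error

print(np.allclose(reference, outputs))  # True
```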
Note that the parameters of this layer are the complex entries in the recurrent weight vector w and matrix P. In practice you would
just use two separate real numbers to represent the real and imaginary components of each
parameter, which are initialized by sampling from a normal distribution centred at 0, and
updated with gradient descent as usual. Lastly, computing matrix inverses is really
slow, so in practice we don’t bother, and instead just use 2 independent complex matrices
before and after the linear recurrence. This actually makes the model more expressive
than a real valued linear RNN, and it saves computation. But it does mean that the model is no longer
equivalent to a real valued recurrence, and the output can now be a complex number, so
we will need to take the real valued part of the output before passing it to the next
layer. Ok, so we’ve seen how to make linear RNNs
fast for modern hardware, but what about the other problem, that RNNs are very difficult
to train? Before we solve this problem, here’s a quick
recap of why training RNNs is so problematic in the first place: neural nets are trained
by subtracting the gradient of the loss function from each weight in the model. What is the gradient? Well imagine evaluating the neural net, then
increasing the value of a weight by a very small amount, and then evaluating it again. The difference between these two scores is (proportional
to) the gradient for that weight, and it tells you how to change the weight to make the neural
net better. So let’s evaluate the gradient of a linear
recurrent layer. Actually to make this a bit easier, let’s
simplify the model and suppose that every input after the first is 0, so we can just
ignore them. When we evaluate the recurrent layer, at each
step the previous output is multiplied by the weight vector, so after n steps the output
vector is equal to the recurrent weight vector raised to the power of n, times the first vector: w^n x_1. When we increase the weight by a small amount ε and evaluate it again, we get (w+ε)^n x_1. Taking the difference, we get, up to a constant
scaling factor, w^(n-1) x_1. The problem here is that as n becomes large,
this term, w^(n-1), either gets very small or very large, depending on whether the values
in w are less than or greater than 1.
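You can see how quickly this blows up or dies off with a two-line experiment; 0.99 and 1.01 here are just arbitrary weights slightly below and above 1:

```python
# How w^n behaves for recurrent weights just below and just above 1.
for w in (0.99, 1.0, 1.01):
    print(w, [w ** n for n in (10, 100, 1000)])
# 0.99 -> roughly 0.90, 0.37, 0.00004   (the gradient vanishes)
# 1.0  -> 1.0, 1.0, 1.0                 (the only stable value)
# 1.01 -> roughly 1.10, 2.7, 21000      (the gradient explodes)
```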
In either case it’s a problem: if the gradient is very large then the neural net weights change too much, and the existing functionality
already learned by the neural net gets destroyed. If the gradient is very small then the weights
don’t change enough and the neural net doesn’t learn anything at all. This is what makes training RNNs difficult: while in principle RNNs can use infinitely long context, in practice, with gradient-based
training techniques, the RNN will only learn to use context for as many steps as the gradient
remains the right size for learning. This is known as the problem of vanishing
and exploding gradients. And when we add back in non-zero inputs, this
problem only gets worse, as the additional inputs make the gradients even more unstable. And to be clear, the reason why this isn’t
a problem for regular neural nets is that they use different weights in each layer. Some layers can have weights smaller than
1, and some layers can have weights larger than 1, so long as the gradient remains about
the same size, the neural net will be able to learn. There are lots and lots of different configurations
of weights that result in stable gradients, and it’s easy to stay in stable configurations
all throughout training. But for RNNs, you’re using the same weight
in each step, so there is exactly one stable configuration which is when the weight is
1. Any deviation from 1 and you have exponentially
growing or decaying gradients. Note that when the weights are complex numbers
the same argument applies, just using the absolute value of the weights. So how can we fix vanishing and exploding
gradients? Well, we saw that the RNN gradients are stable
so long as the recurrent weights are 1 and the inputs are 0, so in the linear RNN paper
the authors propose to initialize their linear RNN in this stable state. Specifically, they parameterize the weights
in complex polar form ae^ib, where a is magnitude and b is angle. They then restrict the magnitude to be less
than 1 by running a through the function e^(-e^(a)), which always outputs a number between
0 and 1, and instead of randomly sampling a from a normal distribution centred at 0,
as we usually do, they initialize a so that the magnitude e^(-e^(a)) is uniformly distributed
between 0.999 and 1. They initialize the angle, b, uniformly between
0 and π/10 radians. This ensures that, at initialization, the
weights are all very close to 1. Finally they multiply the inputs by γ, which
is another learnable parameter initialized to sqrt(1-e^(-e^(a))). Since e^(-e^(a)) is close to one, this is some very small number. This ensures that at initialization the inputs
are all close to 0, and so they don’t interfere with the recurrence.
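Here’s what that initialization looks like as a NumPy sketch, following the description above; the state size is arbitrary, and γ is simply my label for the input scale:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 256  # recurrent state size, arbitrary for this sketch

# Choose a so that the magnitude e^(-e^a) is uniform in [0.999, 1):
# invert the parameterization, a = log(-log(magnitude)).
magnitude = rng.uniform(0.999, 1.0, size=d)
a = np.log(-np.log(magnitude))

# Angle b uniform in [0, pi/10).
b = rng.uniform(0.0, np.pi / 10, size=d)

# Complex recurrent weights, all very close to 1.
w = np.exp(-np.exp(a)) * np.exp(1j * b)

# Input scale, close to 0 so the inputs barely perturb the recurrence at first.
gamma = np.sqrt(1 - np.exp(-np.exp(a)))

print(np.abs(w).min(), np.abs(w).max())  # both between ~0.999 and 1
print(gamma.max())                       # small, around 0.03
```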
So at initialization, this model is approximately the same as the stable RNN I showed you before. After the model begins training and the weights
change, there is no guarantee that it will remain stable, but in practice it appears
that just initializing the model like this is sufficient to allow it to learn to remember
context for tens of thousands of steps. And there we have it, with these modifications,
we now have a linear RNN that is fast to compute, and learns to use extremely long context. In the linear RNN paper, they evaluate this
model on the long range arena benchmark, which is a collection of 6 synthetic tasks that
evaluate a model’s ability to perform long range reasoning. For example, in the PathX task the model must
classify images by whether or not they contain a complete dotted path between two circles. Except that the images are flattened into one
long sequence of 16 thousand pixels. The linear RNN achieved state-of-the-art performance
on the long range arena, outperforming transformers by about 33% on average across tasks. So now that we understand the linear RNN,
what’s with all the talk about state-space models? Well, it turns out that state-space models
are just linear RNNs. State space models were inspired by control
theory, and were derived from a totally different idea of trying to discretize a continuous
dynamical system, but the end result is just a linear RNN, with a slightly different initialization
scheme. The most common form of state space model
parameterizes each recurrent weight as w = e^(Δ(a+bi)), where Δ is again a learnable parameter which is initialized to a very small number, usually between 0.0001 and 0.1. Multiplying a+bi by the small number Δ makes the exponent close to 0, and when you take e to the power of something
close to 1. This again ensures that at initialization
the recurrent weights are all approximately one, so training is stable.
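As a sketch, with made-up values for a and b just to show the effect of the small Δ:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 256

a = -0.5 * np.ones(d)                 # real part (decay rate); illustrative value only
b = np.linspace(0.0, np.pi, d)        # imaginary part (rotation); illustrative values only
delta = np.exp(rng.uniform(np.log(1e-4), np.log(1e-1), size=d))  # small step sizes in [0.0001, 0.1]

w = np.exp(delta * (a + 1j * b))
print(np.abs(w).min())  # close to 1, because delta * a is close to 0
```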
State space models also multiply the inputs by (Δ(a+bi))^-1 (w-1), because that’s what’s prescribed by control theory, but empirically
you get the same performance when you just scale the inputs by a small learnable factor, as in the linear RNN setup. On the long range arena, the control theory
inspired state-space initialization performs roughly the same as the linear RNN initialization. Anyway, whenever you hear state-space model,
think linear RNN. And finally we can talk about Mamba. You see, while linear RNNs do perform really
well on the long range arena benchmark, this does not mean they are good language models. For language modelling, linear RNNs perform
way worse than transformers. Here is the performance of various state-of-the-art
language models, lower is better on this graph. As you can see, everything, including state-space
models, does significantly worse than transformers. The reason for this, as identified in the
Mamba paper, is that linear RNNs are incapable of selectively forgetting information from
the output vector. If the weights are close to 0 then the output
vector will be set to 0 after every input; effectively, the model will always immediately
forget whatever came before the current input. If the recurrent weights are close to 1 then
the output vector doesn’t change when it’s multiplied with the weights, so the output
vector will accumulate information from all inputs observed. What you want is for the model to be able
to decide when to store information and when to forget information, based on what input
it sees. Mamba proposes an elegant solution: instead
of using the same weights in each step, use different weights which depend on the input. Mamba applies a linear function to each input
vector to generate a separate weight vector for that input. Then the recurrent scan is performed using
these generated weights. This way, certain inputs can generate weights
close to 0 and thereby erase information from the output vector, but other inputs can generate
weights close to 1, thereby leaving the output vector unchanged.
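Here’s a toy version of that idea in NumPy. To keep it readable I generate the per-step weights with a sigmoid of a linear function of the input; that gating function and the parameter names are my own stand-ins, while real Mamba generates its weights through a softplus step size and an exponential, and uses a much larger hidden state:

```python
import numpy as np

def selective_recurrence(xs, W_gate, W_in):
    """Linear recurrence whose recurrent weights are recomputed from each input."""
    h = np.zeros(xs.shape[1])
    outputs = []
    for x in xs:
        # Per-step weights in (0, 1): near 1 keeps the previous state, near 0 erases it.
        w = 1.0 / (1.0 + np.exp(-(x @ W_gate)))   # sigmoid gate (illustrative choice)
        h = w * h + x @ W_in
        outputs.append(h)
    return np.stack(outputs)

rng = np.random.default_rng(0)
n, d = 8, 4
xs = rng.normal(size=(n, d))
out = selective_recurrence(xs, rng.normal(size=(d, d)), rng.normal(size=(d, d)))
print(out.shape)  # (8, 4)
```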
And I also suspect that using different weights at each step helps with vanishing and exploding gradients, since there should now be many
different stable configurations, like in feed-forward networks, although this wasn’t mentioned
in the Mamba paper. Mamba also uses one more trick, which is to
increase the size of the output vectors. In a standard RNN the output vectors are the
same size as the input vectors. Mamba expands the size of the output vectors
by a factor of 16. This allows it to store much more information
from previous inputs. The output vectors are then projected back
down to the original size before being passed to the next layer.
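In shape terms, that expansion is just an up-projection before the recurrence and a down-projection after it. The dense matrices below are only meant to show the sizes involved, not Mamba’s actual per-channel state expansion:

```python
import numpy as np

d, expand = 64, 16
rng = np.random.default_rng(0)
W_up = 0.02 * rng.normal(size=(d, d * expand))    # expand to a 16x larger recurrent state
W_down = 0.02 * rng.normal(size=(d * expand, d))  # project back down for the next layer

x = rng.normal(size=d)
h = x @ W_up      # shape (1024,): the recurrence operates at this width
y = h @ W_down    # shape (64,): same size as the input again
print(h.shape, y.shape)  # (1024,) (64,)
```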
Usually this would increase the computation time by a factor of 16, but it turns out that the major bottleneck of a Mamba layer on modern
GPUs is the time it takes to move data in and out of high-performance memory. You see, modern GPUs actually have 2 different
types of memory. Data is stored in main memory, but in order
to do computations, data first needs to be transferred into high-performance memory. For the mamba recurrence operation, it turns
out that the time taken to transfer data is actually much larger than the time it takes
to do the computation itself. Mamba therefore transfers the input vectors
and model parameters to high performance memory, then computes the whole mamba operation in
a single block, including projecting outputs back down to the smaller original size, before
writing the results back to main memory. This way you only transfer vectors of the
original size to and from high-performance memory, so the transfer time is unchanged. The actual computation takes 16 times longer,
but the computation time was so small compared to the transfer time that it doesn’t really
affect the overall time taken. You essentially get to use 16 times larger
vectors for free. And there we have it, this, along with a few
minor architecture modifications, is Mamba, the dynamic linear recurrent neural network,
which performs better than transformers at language modelling, while using only O(nlog(n))
compute, down from O(n^2). Ok, now that we’ve made it through all of
those boring technical details, we can finally talk about what really matters: the Mamba
drama. You see, the Mamba paper caused quite a bit
of controversy in the machine learning community this year. The Mamba paper was submitted to ICLR 2024,
which is one of the most prestigious machine learning conferences in the world. And in January, it was rejected by peer reviewers. But so what? Papers get rejected from top conferences all
the time, right? Well, to give some context, the Mamba pre-print
has been publicly available since last year and during this time several different groups
have re-implemented Mamba and all successfully reproduced the results claimed in the Mamba
paper, namely that Mamba performs better than transformers and uses less computation. And since transformers are all anyone has
talked about for the last 5 years, that’s kind of a big deal. Because of this, everyone in the community
was expecting the Mamba paper to be accepted, if not win a best-paper award. So then, if the Mamba architecture really
works, what glaring flaws are in the paper that caused it to be rejected? Well, the ICLR peer review is publicly available
for everyone to view. So let’s take a look. According to the meta review, Mamba wasn’t
tested on the long range arena benchmark. Remember that benchmark I talked about, where
linear RNNs performed way better than transformers? This reviewer wanted to see how well Mamba
performed on that task. Now this is a really dumb reason to reject
a paper, because the long range arena is a completely different task from language modelling,
and Mamba is specifically a language model. Keep in mind, transformers perform way worse
than linear RNNs on the long range arena, but transformers are still way better language
models. So performance on the long range arena is
in no way indicative of language modelling ability. Mamba sets a new state of the art for language
modelling; it shouldn’t be rejected because it doesn’t also solve some other, unrelated
task. The only other major criticism in the meta
review is that Mamba was only evaluated on language modelling, that is the accuracy when
predicting the next word in a piece of text. The reviewers argue that this metric isn’t
indicative of a language model’s utility, and instead Mamba should have been evaluated
on downstream tasks that measure a model’s reasoning ability. Except that... this is exactly what they did
in the Mamba paper, they pre-trained Mamba as a language model and then performed zero-shot
prompting on a bunch of standard down-stream benchmark tasks. And surprise, surprise, Mamba outperforms
all other language models. As a bonus, another reviewer said, and I quote,
“Mamba has a quadratic memory requirement during training, just like transformers”. Which… is just not true. Neither Mamba nor transformers have quadratic
memory costs. Transformers have a quadratic compute cost,
but their memory cost is linear. So is Mamba’s… I’m not sure how it’s even possible to come
to the conclusion that Mamba has a quadratic memory cost if you understand how it works
at all… So as you can imagine, this less-than-ideal
peer review sparked some debate in the machine learning community about peer reviewing practices
and whether or not Mamba should have been rejected. You can probably guess which side of the debate
I fall on, but let me know your thoughts about how broken academic peer review is in the
comments below. Or thoughts about the actual Mamba architecture
itself, I guess that’s fine too.