Regularization. It's just another way to say desensitization. Let's check it out with Ridge Regression. StatQuest.
Hello, I'm Josh Starmer and welcome to StatQuest. Today we're going to do part one of a series of videos on regularization techniques. In this video, we're going to cover ridge regression, and it's going to be clearly explained.
This StatQuest assumes you understand the concepts of bias and variance in the context of machine learning. If not, check out Machine Learning Fundamentals: Bias and Variance. It also assumes that you are familiar with linear models. If not, check out the following StatQuests.
The links are in the description below. Lastly, if you're not already familiar with the concept of cross-validation, check out the StatQuest on cross-validation. In this StatQuest, we will: 1. Look at a simple example that shows the main ideas behind ridge regression. 2. Go into details about how ridge regression works.
Let's start by collecting weight and size measurements from a bunch of mice. Since these data look relatively linear, we will use linear regression, aka least squares, to model the relationship between weight and size. So we'll fit a line to the data using least squares. In other words, we find the line that results in the minimum sum of squared residuals.
Ultimately, we end up with this equation for the line. The line has two parameters: a y-axis intercept and a slope.
We can plug in a value for weight, for example 2.5, and do the math and get a value for size. Together, the value for weight, 2.5, and the value for size, 2.8, give us a point on the line. When we have a lot of measurements, we can be fairly confident that the least squares line accurately reflects the relationship between size and weight.
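To make this concrete, here's a small sketch in Python (the weight and size measurements below are made up, so the fitted intercept and slope are only illustrative, not the ones from the video):

```python
import numpy as np

# Made-up weight and size measurements from a bunch of mice (hypothetical data).
weight = np.array([0.8, 1.2, 1.9, 2.4, 3.1, 3.6, 4.2])
size = np.array([1.1, 1.5, 2.2, 2.7, 3.4, 3.8, 4.4])

# Least squares: find the intercept and slope that minimize
# the sum of the squared residuals.
slope, intercept = np.polyfit(weight, size, deg=1)

# Plug in a weight of 2.5 and do the math to get a predicted size.
predicted_size = intercept + slope * 2.5
print(f"intercept={intercept:.2f}, slope={slope:.2f}, size at weight 2.5 = {predicted_size:.2f}")
```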
But what if we only have two measurements? We fit a new line with least squares. Since the new line overlaps the two data points, the minimum sum of squared residuals equals zero. Ultimately, we end up with this equation for the new line.
Note, here are the original data and the original line for comparison. Let's call the two red dots the training data, and the remaining green dots the testing data. The sum of the squared residuals for just the two red points, the training data, is small. In this case, it is zero.
But the sum of the squared residuals for the green points, the testing data, is large. And that means that the new line has high variance. In machine learning lingo...
We'd say that the new line is overfit to the training data. Now let's go back to just the training data. We just saw that least squares results in a line that is overfit and has high variance.
The main idea behind ridge regression is to find a new line that doesn't fit the training data as well. In other words, we introduce a small amount of bias into how the new line is fit to the data. But in return for that small amount of bias, we get a significant drop in variance. In other words, by starting with a slightly worse fit, ridge regression can provide better long-term predictions. BAM!
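Here's a hedged sketch of that trade-off in Python with scikit-learn (the two training points and the testing points below are made up, and lambda, which scikit-learn calls alpha, is simply set to 1 for illustration):

```python
import numpy as np
from sklearn.linear_model import LinearRegression, Ridge

# Two made-up training points (the "red dots") and a few made-up testing
# points (the "green dots") that follow the underlying trend more closely.
weight_train = np.array([[1.0], [2.0]])
size_train = np.array([1.4, 3.0])
weight_test = np.array([[3.0], [4.0], [5.0]])
size_test = np.array([2.9, 3.7, 4.5])

# Least squares: overlaps the two training points exactly (zero training SSR).
ols = LinearRegression().fit(weight_train, size_train)

# Ridge: adds the penalty lambda * slope^2 (lambda is called alpha here).
ridge = Ridge(alpha=1.0).fit(weight_train, size_train)

def ssr(model, X, y):
    """Sum of squared residuals for a fitted model."""
    return np.sum((y - model.predict(X)) ** 2)

print("training SSR:", ssr(ols, weight_train, size_train), ssr(ridge, weight_train, size_train))
print("testing  SSR:", ssr(ols, weight_test, size_test), ssr(ridge, weight_test, size_test))
# On this toy data, the least squares line has zero training SSR but a much
# larger testing SSR than the ridge line, i.e. it is overfit (high variance).
```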
Now let's dive into the nitty-gritty and learn how ridge regression works. Let's go back to just the training data. When least squares determines values for the parameters in this equation, it minimizes the sum of the squared residuals. In contrast, when ridge regression determines the values for the parameters in this equation, it minimizes the sum of the squared residuals plus lambda times the slope squared. I usually try to avoid using Greek characters as much as possible.
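Written out, the whole quantity that ridge regression minimizes for this line looks like this (the lambda in front of the penalty is the one Greek character we can't avoid):

```latex
\sum_{i=1}^{n} \Bigl( \mathrm{size}_i - \bigl( \mathrm{intercept} + \mathrm{slope} \times \mathrm{weight}_i \bigr) \Bigr)^{2}
\;+\; \lambda \times \mathrm{slope}^{2}
```

The first part is the familiar sum of the squared residuals, and the second part is the ridge regression penalty.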
But if you are ever going to do ridge regression in practice, you have to know that this term is called lambda. This part of the equation adds a penalty to the traditional least squares method, and lambda determines how severe that penalty is. To get a better idea of what's going on, let's plug in some numbers. Let's start by plugging in the numbers that correspond to the least squares fit.
The sum of the squared residuals for the least squares fit is zero, because the line overlaps the data points, and the slope is 1.3. We'll talk more about lambda later, but for now let lambda equal 1. Altogether we have 0 plus 1 times 1.3 squared, and when we do the math we get 1.69. Now let's see what happens when we plug in numbers for the ridge regression line. The sum of the squared residuals is 0.3 squared for this residual plus 0.1 squared for this residual.
The slope is 0.8. And just like before, we'll let lambda equal 1. Altogether, we have 0.3 squared plus 0.1 squared plus 1 times 0.8 squared. And when we do the math, we get...
0.74. For the least squares line, the sum of squared residuals plus the ridge regression penalty is 1.69. For the ridge regression line, the sum of squared residuals plus the ridge regression penalty is 0.74.
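That arithmetic is easy to double-check; here's a tiny sketch using the residuals, slopes, and lambda quoted above:

```python
lam = 1.0  # lambda

# Least squares line: the residuals are zero and the slope is 1.3.
least_squares_total = 0.0 + lam * 1.3 ** 2          # sum of squared residuals + penalty = 1.69

# Ridge regression line: residuals of 0.3 and 0.1, and a slope of 0.8.
ridge_total = 0.3 ** 2 + 0.1 ** 2 + lam * 0.8 ** 2  # sum of squared residuals + penalty = 0.74

print(round(least_squares_total, 2), round(ridge_total, 2))  # 1.69 vs. 0.74
```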
Thus, if we wanted to minimize the sum of the squared residuals plus the ridge regression penalty, we would choose the ridge regression line over the least squares line. Without the small amount of bias that the penalty creates, the least squares fit has a large amount of variance. In contrast, the ridge regression line, which has a small amount of bias due to the penalty, has less variance.
Now, before we talk about lambda, let's talk a little bit more about the effect that the ridge regression penalty has on how the line is fit to the data. To keep things simple, let's start with a line whose slope is one. This line suggests that for every one-unit increase in weight, there is a one-unit increase in predicted size.
If the slope of the line is steeper, then for every one-unit increase in weight, the prediction for size increases by more than two units. In other words, when the slope of the line is steep, the prediction for size is very sensitive to relatively small changes in weight. When the slope is small, then for every one-unit increase in weight, the prediction for size barely increases. In other words, when the slope of the line is small, predictions for size are much less sensitive to changes in weight. Now let's go back to the least squares and ridge regression lines fit to the two data points.
The ridge regression penalty resulted in a line that has a smaller slope, which means that predictions made with the ridge regression line are less sensitive to weight than the least squares line. BAM! Now let's go back to the equation that ridge regression tries to minimize and talk about lambda. Lambda can be any value from 0 to positive infinity.
When lambda equals 0, the ridge regression penalty is also 0, and that means the ridge regression line will only minimize the sum of the squared residuals. In that case, the ridge regression line will be the same as the least squares line, because they are both minimizing the exact same thing. Now let's see what happens as we increase the value for lambda. In the example we just looked at, we set lambda equal to 1, and the ridge regression line ended up with a smaller slope than the least squares line.
When we set lambda equal to 2, the slope gets even smaller, and when we set lambda equal to 3, the slope is smaller still. The larger we make lambda, the closer the slope gets, asymptotically, to zero. So the larger lambda gets, the less sensitive our predictions for size become to weight. So how do we decide what value to give lambda? We just try a bunch of values for lambda and use cross-validation, typically 10-fold cross-validation, to determine which one results in the lowest variance.
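In practice, we'd let cross-validation pick lambda for us. Here's a hedged sketch using scikit-learn (where lambda is called alpha, and RidgeCV tries a grid of candidate values; the weight and size data below are made up):

```python
import numpy as np
from sklearn.linear_model import RidgeCV

# Made-up weight and size measurements.
rng = np.random.default_rng(0)
weight = rng.uniform(1, 5, size=30).reshape(-1, 1)
size = 0.5 + 0.8 * weight.ravel() + rng.normal(scale=0.3, size=30)

# Try a bunch of values for lambda (alpha) and use 10-fold cross-validation
# to determine which one gives the best held-out performance.
candidate_lambdas = [0.01, 0.1, 1.0, 10.0, 100.0]
model = RidgeCV(alphas=candidate_lambdas, cv=10).fit(weight, size)

print("chosen lambda:", model.alpha_)
print("slope:", model.coef_[0], "intercept:", model.intercept_)
```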
Double bam! In the previous example, we showed how ridge regression would work when we want to predict size, which is a continuous variable, using weight, which is also a continuous variable. However, ridge regression also works when we use a discrete variable, like normal diet versus high-fat diet, to predict size. In this case, the data might look like this, and the least squares fitted equation might look like this, where 1.5, the equivalent of a y-intercept, corresponds to the average size of the mice on the normal diet.
and 0.7, the equivalent of a slope, corresponds to the difference between the average size of the mice on the normal diet and the average size of the mice on the high-fat diet. From here on out, we'll refer to this distance as diet difference. In the equation, high-fat diet is either 0, for mice on the normal diet, or 1, for mice on the high-fat diet.
In other words, this term alone predicts the size of mice on the normal diet, and the sum of these two terms is the prediction for the size of mice on the high-fat diet. For the mice on the normal diet, the residuals are the distances between the mice and the normal diet mean.
And for mice on the high-fat diet, the residuals are the distances between the mice and the high-fat diet mean. When least squares determines the values for the parameters in this equation, it minimizes the sum of the squared residuals. In other words, these distances between the data and the means are minimized. When ridge regression determines values for the parameters in this equation, it minimizes the sum of the squared residuals plus lambda times diet difference squared.
Remember, diet difference simply refers to the distance between the mean of the mice on the normal diet and the mean of the mice on the high-fat diet. When lambda equals zero, this whole term ends up being zero, and we get the same equation that we got with least squares. But when lambda gets large, the only way to minimize the whole equation is to shrink diet difference down. In other words, as lambda gets larger, our prediction for the size of mice on the high-fat diet becomes less sensitive to the difference between the normal diet and the high-fat diet.
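Here's a hedged sketch of the diet example in Python (the sizes are made up so that the least squares fit roughly matches the 1.5 and 0.7 quoted above; scikit-learn's Ridge leaves the intercept out of the penalty, matching what's described here). As lambda (alpha) grows, the diet difference shrinks toward zero:

```python
import numpy as np
from sklearn.linear_model import Ridge

# 0 = normal diet, 1 = high-fat diet, with made-up size measurements.
high_fat_diet = np.array([[0], [0], [0], [0], [1], [1], [1], [1]])
size = np.array([1.3, 1.5, 1.6, 1.6, 2.0, 2.2, 2.3, 2.3])

for lam in [0.0, 1.0, 2.0, 10.0]:
    model = Ridge(alpha=lam).fit(high_fat_diet, size)
    # intercept ~ average size on the normal diet (exactly, when lambda = 0),
    # coef ~ "diet difference" between the two group averages.
    print(f"lambda={lam}: intercept={model.intercept_:.2f}, diet difference={model.coef_[0]:.2f}")
```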
And remember, the whole point of doing ridge regression is that small sample sizes like these can lead to poor least squares estimates that result in terrible machine learning predictions. BAM! Ridge regression can also be applied to logistic regression.
In this example, we are using weight to predict if a mouse is obese or not. This is the equation for this logistic regression, and ridge regression would shrink the estimate for the slope, making our prediction about whether or not a mouse is obese less sensitive to weight. When applied to logistic regression, ridge regression optimizes the sum of the likelihoods instead of the sum of the squared residuals, because logistic regression is solved using maximum likelihood. So far, we've seen simple examples of how ridge regression helps reduce variance by shrinking parameters and making our predictions less sensitive to them. But we can apply ridge regression to complicated models as well.
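Before we get to the more complicated model, here's a hedged scikit-learn sketch of that logistic case (the weights and obese/not-obese labels are made up; in LogisticRegression, the strength of the L2, i.e. ridge-style, penalty is controlled by C, which is roughly the inverse of lambda, so a smaller C means more shrinkage):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Made-up weights and obese (1) / not obese (0) labels.
weight = np.array([[1.2], [1.8], [2.1], [2.6], [3.0], [3.4], [3.9], [4.5]])
obese = np.array([0, 0, 0, 1, 0, 1, 1, 1])

# penalty="l2" is the ridge-style penalty; decreasing C shrinks the slope more.
for C in [100.0, 1.0, 0.01]:
    model = LogisticRegression(penalty="l2", C=C).fit(weight, obese)
    print(f"C={C}: slope={model.coef_[0][0]:.3f}, intercept={model.intercept_[0]:.3f}")
# The smaller C (the larger lambda), the smaller the slope, and the less
# sensitive the predicted probability of obesity is to weight.
```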
In this model, we've combined the weight measurement data from the first example with the two diets from the second example. Combining these two data sets gives us this equation, and ridge regression tries to minimize this. Now the ridge regression penalty contains the parameters for the slope and the difference between diets.
In general, the ridge regression penalty contains all of the parameters except for the y-intercept. If we had a big, huge, crazy equation with terms for astrological sign, the airspeed of a swallow, and other stuff, then the ridge regression penalty would have all those parameters squared except for the y-intercept. Every parameter except for the y-intercept is scaled by the measurements, and that's why the y-intercept is not included in the ridge regression penalty.
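In symbols (a generic sketch, where the prediction for the i-th mouse comes from an intercept plus p other parameters), the quantity ridge regression minimizes would look like:

```latex
\sum_{i=1}^{n} \bigl( y_i - \hat{y}_i \bigr)^{2} \;+\; \lambda \sum_{j=1}^{p} \beta_j^{2},
\qquad \text{where } \hat{y}_i = \beta_0 + \beta_1 x_{i1} + \cdots + \beta_p x_{ip}
```

Note that the sum in the penalty starts at j = 1, so the y-intercept, beta zero, is not penalized.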
Double bam! Okay, now the next thing we're going to talk about is going to sound totally random, but trust me, it will lead to the coolest thing about ridge regression. It's so cool it's almost like magic.
We all know that this is the equation for a line, and in order for least squares to solve for the parameters, the y-intercept and slope, we need at least two data points. These data points result in these parameters and this specific line. If we only had one data point, then we wouldn't be able to solve for these parameters, because there would be no way to tell if this line is better than this line, or this line, or any old line that goes through the one data point.
All of these lines have zero residuals, and thus all minimize the sum of the squared residuals. It's not until we have two data points that it becomes clear that this is the least squares solution. Now let's look at an equation that has three parameters to estimate.
We need to estimate a y-intercept, a slope that reflects how weight contributes to the prediction of size, and a slope that reflects how age contributes to the prediction of size. When we have three parameters to estimate, then just two data points isn't going to cut it. That's because in three dimensions, which is what we get when we add another axis to our graph for age, we have to fit a plane to the data instead of just a line. And with only two data points, there's no reason why this plane fits the data any better than this plane or this plane.
But as soon as we have three data points, we can solve for these parameters. If we have an equation with four parameters, then least squares needs at least four data points to estimate all four parameters. And if we have an equation with 10,001 parameters, then we need at least 10,001 data points to estimate all of the parameters. An equation with 10,001 parameters might sound bonkers, but it's more common than you might expect.
For example, we might use gene expression measurements from 10,000 genes to predict size. And that would mean we would need gene expression measurements from at least 10,001 mice. Unfortunately, collecting gene expression measurements from 10,001 mice is crazy expensive and time-consuming right now.
In practice, a huge dataset might have measurements from 500 mice. So what do we do if we have an equation with 10,001 parameters and only 500 data points? We use ridge regression. It turns out that by adding the ridge regression penalty, we can solve for all 10,001 parameters with only 500, or even fewer, samples.
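Here's a hedged, simulated sketch of that situation in Python (the gene expression values and mouse sizes below are random numbers invented just to match the scale of the example; RidgeCV picks lambda, which scikit-learn calls alpha, using built-in cross-validation):

```python
import numpy as np
from sklearn.linear_model import LinearRegression, RidgeCV

rng = np.random.default_rng(42)
n_mice, n_genes = 500, 10_000        # far more parameters than data points

# Simulated gene expression data; only a handful of genes truly affect size.
X = rng.normal(size=(n_mice, n_genes))
true_effects = np.zeros(n_genes)
true_effects[:10] = 1.0
size = X @ true_effects + rng.normal(scale=1.0, size=n_mice)

# Plain least squares has no unique solution here: many different coefficient
# vectors fit the 500 training mice perfectly (training R^2 of 1).
ols = LinearRegression().fit(X, size)
print("least squares training R^2:", round(ols.score(X, size), 3))

# Ridge regression, with lambda chosen by cross-validation, still returns one
# well-defined estimate for every one of the 10,000 gene coefficients.
ridge = RidgeCV(alphas=[1.0, 10.0, 100.0, 1000.0]).fit(X, size)
print("chosen lambda:", ridge.alpha_)
print("number of estimated coefficients:", ridge.coef_.shape[0])
```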
One way to think about how ridge regression can solve for parameters when there isn't enough data is to go back to our original size versus weight example. Only this time, there is only one data point in the training set. Least squares can't find a single optimal solution, since any line that goes through the dot will minimize the sum of the squared residuals.
But ridge regression can find a solution with cross-validation and the ridge regression penalty that favors smaller parameter values. Since this StatQuest is already super long, we'll save a more thorough discussion of how this works for a future StatQuest. Triple bam! In summary, when the sample sizes are relatively small, then ridge regression can improve predictions made from new data, i.e., reduce variance, by making the predictions less sensitive to the training data.
This is done by adding the ridge regression penalty to the thing that must be minimized. The ridge regression penalty itself is lambda times the sum of all of the squared parameters, except for the y-intercept, and lambda is determined using cross-validation. Lastly, even when there isn't enough data to find the least squares parameter estimates, ridge regression can still find a solution using cross-validation and the ridge regression penalty.
Hooray! We've made it to the end of another exciting StatQuest. If you like this StatQuest and want to see more, please subscribe. And if you want to support StatQuest, well, consider buying one or two of my original songs. Alright, until next time, quest on!