A Detailed Introduction To Cross-Validation in Machine Learning

October 11, 2018 By Pascal Schmidt Machine Learning

In this blog post, I will be giving a detailed introduction to cross-validation in machine learning. If I had to describe cross-validation in machine learning in one sentence, I'd say that:

Cross-validation in machine learning makes sure that our trained model performs well on independent data.

But clearly, there is more to it.

For example, there is the bias-variance trade-off and how it relates to overfitting. In this post, we mentioned ways to avoid overfitting with regularization methods and variable selection methods, which can help us find the right amount of bias and variance.

Besides these tools, cross-validation is another method that helps us avoid overfitting and, hence, find the right amount of bias and variance for our model.

In this blog post, I will be giving a detailed introduction to cross-validation in the context of machine learning. I will go over basics such as training and testing/validation data sets, describe various methods of cross-validation, and explain how to stop overfitting. Here is a quick outline of what we will be discussing:

  • What is cross-validation in machine learning?
  • The concept of training data and testing/validation data
  • Hold-out method
  • Leave-one-out cross-validation
  • K-fold cross-validation
  • When cross-validation in machine learning is not helpful
  • Cross-validation alternatives

What is Cross-Validation in Machine Learning?

Cross-validation in machine learning is a method to control for overfitting and for finding a good bias-variance trade-off. It is an important tool for predictive models and is a way to estimate the test set error very accurately. So, in short, cross-validation is a model evaluation method.

Therefore, when we have a number of different machine learning algorithms (logistic regression, random forest, etc.), cross-validation lets us evaluate which of these algorithms is best suited for a particular data set. In addition, cross-validation is not only used to pick the most promising algorithm, it is also used to find good tuning parameters. Examples are lambda in a lasso/ridge regression model, the number of neighbors k in the k-nearest neighbors algorithm, or the degree of the polynomial in a linear/logistic regression model.

All of these procedures (model evaluation, tuning parameters, number of neighbors, or degree of polynomial) are aimed at achieving the lowest test set error, and you carry them out by applying cross-validation methods.
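
To make these two uses more concrete, here is a minimal sketch in Python with scikit-learn. The data set is simulated, and the candidate models and the grid of k values are just assumptions for the example, not a recommendation:

```python
# A minimal sketch of cross-validation for model evaluation and parameter tuning.
# The data is simulated purely for illustration.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score, GridSearchCV

X, y = make_classification(n_samples=500, n_features=10, random_state=42)

# 1) Model evaluation: compare different algorithms with 5-fold cross-validation.
for name, model in [("logistic regression", LogisticRegression(max_iter=1000)),
                    ("random forest", RandomForestClassifier(random_state=42))]:
    scores = cross_val_score(model, X, y, cv=5)
    print(f"{name}: mean CV accuracy = {scores.mean():.3f}")

# 2) Parameter tuning: find a good k for k-nearest neighbors with cross-validation.
grid = GridSearchCV(KNeighborsClassifier(),
                    param_grid={"n_neighbors": [1, 3, 5, 7, 9, 11]},
                    cv=5)
grid.fit(X, y)
print("best k:", grid.best_params_["n_neighbors"])
```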

The Concept of Training Set and Testing/Validation Set

In the above paragraph, I introduced the term test set error. But what does it really mean in the context of cross-validation in machine learning?

The aim of machine learning in data science is to learn from data. However, sometimes too much learning is going on. So, what do I mean by that?

Cross-Validation Analogy Part 1

Have you ever studied really hard for a test and, despite your huge effort, still bombed it? Don't worry, it happened to me too. Afterwards, I would beat myself up and ask myself what I could have done differently. Maybe I should have looked at the book more closely or focused more on my professor's slides. What went wrong?

In some cases, I followed the noise too much. I wanted to study everything about the topic and lost the bigger picture. What I should have done was go to my professor and ask him what he thinks is important. Sometimes, I was too focused on little details and did not consider the fundamentals.

In conclusion, there was too much studying going on in the wrong direction. I was trying to optimize my grade by studying topics that would never have been tested on an exam because they are so minor and unimportant.

So, how does this relate to training and testing/validation data sets and too much learning in statistical learning or machine learning? Well, in this scenario, the training data set is the material given out by the professor (slides, books etc.) and the testing/validation data set is the actual exam from the professor.

To continue this thought, the training data set is used to learn from data. However, we can learn too much from the training data set. There is a lot of irrelevant information in the data, and when we follow the noise in the data too much, we are overfitting. Meaning, we keep learning and optimizing for details in the training data that are irrelevant when it comes to the testing/validation data (the actual exam).

We learned so many things from the data we trained on, but all of this information is not included in the (future) test/validation data. Therefore, we get a high accuracy when we test our model on the training data (we feel confident for our exam) but a low accuracy on our test/validation data (we bombed the actual exam).

Hence, it is very crucial in machine learning to have two data sets: one we can learn from and another one on which we can test whether what we learned was sufficient, whether we have to learn more, or whether we even learned too much.

The question of how much learning from the training data is sufficient is a very fine line, and sometimes even experts make big mistakes.

In conclusion, a training data set is there for learning and a test/validation data set is supposed to simulate future unknown data. The algorithm you are using should produce a low error on the test/validation set. This error is then a good estimate for the true test set error.

Cross-Validation Analogy Part 2

If you still haven't quite gotten the concept, here is another little analogy.

Sometimes it happens that professors do not teach the concepts of a topic properly and instead use homework questions, or questions answered in class, just with different numbers, on their exams. Ever had such a professor? Me too. So instead of learning the concepts, you are learning how to do the questions.

However, your knowledge about a topic should generalize across the exams of different professors. So if you get exams from other professors with different questions, you are lost because you don't understand the concepts. That's how it works with cross-validation: the knowledge acquired from the training data should generalize across all data sets, not just the one you trained on.

So, you might ask yourself now: “How can we get a testing data set?”.

The Validation Set Approach or Hold-out Method

An easy way to simulate test/validation data is to divide the entire data set into a training set and a validation set. You can do that with a 50/50 split, an 80/20 split, or any other split (depending on how much data you have). However, always use at least as much data for training as for testing. The training data is there for learning and for coming up with an appropriate model based only on the training data. This model is then used to predict the responses for the observations in the validation set. Afterwards, the predicted responses are compared to the actual responses in the validation set. The resulting error is a "good" estimate of the true test set error (which is unknown). You could also call it the population test set error.
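
If you want to see what this looks like in code, here is a minimal sketch of the hold-out method in Python with scikit-learn; the simulated data and the 80/20 split are just assumptions for the example:

```python
# A minimal sketch of the hold-out (validation set) method.
# The data is simulated and the 80/20 split is just one possible choice.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

X, y = make_classification(n_samples=1000, n_features=10, random_state=1)

# 80% of the observations for training, 20% held out for validation.
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=1
)

# Fit the model on the training data only.
model = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# Predict the responses in the validation set and compare them to the truth.
val_error = 1 - accuracy_score(y_val, model.predict(X_val))
print(f"estimated test set error: {val_error:.3f}")
```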

Drawbacks of the Validation Set Approach

  1. The variance of the estimated test set error is very high. It depends heavily on which observations end up in the training set and which end up in the validation set (a small sketch of this is shown after the list).
  2. If you split the data, the observations used in the validation set cannot be used for training. This means that you have fewer observations available to train your algorithm on, which tends to lead to an overestimation of the test set error.
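
To illustrate the first drawback, here is a small sketch (again on simulated data) that repeats the same 80/20 hold-out split with different random seeds. The spread of the resulting error estimates shows how much the estimate depends on the particular split:

```python
# A small sketch illustrating the variance of the hold-out estimate:
# repeating the same 80/20 split with different random seeds gives
# noticeably different error estimates. Data is simulated for illustration.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=300, n_features=10, random_state=0)

errors = []
for seed in range(20):
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.2, random_state=seed
    )
    model = LogisticRegression(max_iter=1000).fit(X_train, y_train)
    errors.append(1 - model.score(X_val, y_val))

# The estimates vary quite a bit depending on which observations
# happen to land in the validation set.
print(f"mean error: {np.mean(errors):.3f}, std: {np.std(errors):.3f}")
```

On a small data set, the spread of these estimates is typically not negligible, which is exactly the variance problem described above.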

One important thing to note:

When you evaluate your algorithm, which you have developed from your training data, on the validation set DO NOT GO BACK AND TRY TO IMPROVE YOUR ALGORITHM ANY FURTHER.

This is very, very crucial. Here is a blog post from a Kaggler where it went horribly wrong.

Why shouldn’t you do that?

You should not do that because then you are optimizing your algorithm for the validation set. When you go back to your training data to improve your algorithm, you will include information seen in the validation set. This will bias the estimated test set error. What you should do instead is create a new random split into training and validation sets and then try to improve your algorithm. Then test your algorithm on the new validation set and see whether your estimated test set error has improved. If you are still not satisfied, create a new random split and start again.

Leave-One-Out Cross-Validation

For this method, we also divide our data set into 2 parts, a training set and a validation set, but a little bit differently. This time, only a single observation is used for validation and the remaining (n-1) observations are used for training.

So, when we have 10,000 observations, we train with 9,999 observations and predict the 10,000th observation. We do this until we have predicted every single observation.

Then we sum up all 10,000 errors and divide by 10,000 to get the average estimated test set error.
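
In code, leave-one-out cross-validation could look like the following minimal sketch (simulated data, scikit-learn's LeaveOneOut splitter, and mean squared error as the error measure are assumptions for the example):

```python
# A minimal sketch of leave-one-out cross-validation (LOOCV).
# Each iteration trains on n - 1 observations and predicts the one left out;
# the n squared errors are then averaged. Data is simulated for illustration.
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import LeaveOneOut

X, y = make_regression(n_samples=200, n_features=3, noise=10, random_state=2)

errors = []
for train_idx, test_idx in LeaveOneOut().split(X):
    model = LinearRegression().fit(X[train_idx], y[train_idx])
    pred = model.predict(X[test_idx])
    errors.append((y[test_idx][0] - pred[0]) ** 2)

# Average the n squared errors to get the LOOCV estimate of the test error.
print(f"LOOCV mean squared error: {np.mean(errors):.2f}")
```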

The estimated test set error will be approximately unbiased (it has low bias) because we are using almost all observations for training and only one for validation. However, the estimated test set error will also have high variance. This is because we are only using one observation at a time for predictions in our validation set. Because we also end up predicting outliers and other single data points that aren't representative of the training data, the individual errors are spread out widely. Hence, a high variance. In addition, only one single observation is exchanged between the training and validation sets from one fit to the next. That means the individual test set error estimates are highly correlated with each other, and averaging correlated outputs leads to a higher variance than averaging uncorrelated outputs.

What we are striving for in machine learning is obtaining a perfect mix of bias and variance. Check out my blog post about the bias-variance trade-off.

Our next method will give us a better balance of bias and variance.

K-Fold Cross-Validation

Another method for dividing the data set is k-fold cross-validation. It is called k-fold cross-validation because the data is divided into k folds. Usually, one performs cross-validation with k = 5 or k = 10.

So, for k = 5, we divide the data into 5 folds, train on 4 folds, and test on the remaining fold. We repeat this process until every single fold has been used for testing.
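
Here is a rough sketch of the same procedure in code, again on simulated data and with k = 5 as an example:

```python
# A minimal sketch of k-fold cross-validation with k = 5.
# The data is split into 5 folds; each fold serves once as the test fold
# while the other 4 folds are used for training. Data is simulated.
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold

X, y = make_regression(n_samples=200, n_features=3, noise=10, random_state=3)

fold_errors = []
for train_idx, test_idx in KFold(n_splits=5, shuffle=True, random_state=3).split(X):
    model = LinearRegression().fit(X[train_idx], y[train_idx])
    preds = model.predict(X[test_idx])
    fold_errors.append(np.mean((y[test_idx] - preds) ** 2))

# Average the 5 fold errors to get the k-fold estimate of the test error.
print(f"5-fold CV mean squared error: {np.mean(fold_errors):.2f}")
```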

The advantage over LOOCV is that instead of k = n, we are using k = 5 or k = 10, which makes it far less computationally intensive. In addition, the bias-variance trade-off is generally handled better by k-fold cross-validation. The bias increases a little because each model is trained on only 80-90% of the data instead of n - 1 observations. On the other hand, k-fold cross-validation has lower variance because the outputs are less correlated: each time, 10-20% of the data is swapped out, so the training data looks a bit different in every fold. This leads to less correlated outputs and, hence, to a lower variance.

As always which method is best depends on what you are interested in and what your data set looks like. There is no single best answer.

When Cross-Validation in Machine Learning is not Helpful

So, as we know, cross-validation in machine learning tries to find the statistical learning method/machine learning algorithm, tuning parameter, degree of polynomial, or number of neighbors that minimizes the estimated test set error.

The methods described in the above paragraphs generally give a good estimate of the test set error. However, cross-validation in machine learning has no power when we are dealing with a biased data set. In that case, no matter how sophisticated the machine learning algorithm or how carefully the parameters are chosen by cross-validation, the estimated test set error will never be close to the true test set error.

Sometimes, it can happen that the data set at hand has a completely different distribution in comparison to future data sets. In this case, cross-validation can’t give us accurate measures. In fact, no statistical method in the world can deal with that. The key to every good statistical analysis is clean and representative data. If this is not the case, it is important as a statistician/machine learning engineer to know the limitations of the data and communicate them accordingly.

Cross-Validation Alternatives

There are a few alternatives to cross-validation in machine learning. For example, we can use AIC or BIC values to avoid overfitting. In fact, AIC and LOOCV are asymptotically equivalent (as our sample size gets large).

We can illustrate that with an example and a simulated data set here.

In the above link, you can see that AIC, BIC and cross-validation (LOOCV) are choosing the same degree of polynomial.
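
This is not the linked example, but as a rough sketch of what such a comparison could look like, the following code picks the degree of a polynomial once by AIC/BIC (via statsmodels) and once by LOOCV (via scikit-learn) on simulated data where the true degree is 2:

```python
# A rough sketch (not the linked example) comparing AIC, BIC, and LOOCV
# for choosing the degree of a polynomial on simulated data.
import numpy as np
import statsmodels.api as sm
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score, LeaveOneOut

rng = np.random.default_rng(4)
x = rng.uniform(-3, 3, 150)
y = 1 + 2 * x - 0.5 * x ** 2 + rng.normal(0, 1, 150)  # true degree is 2

results = []
for degree in range(1, 6):
    X_poly = PolynomialFeatures(degree, include_bias=False).fit_transform(x.reshape(-1, 1))

    # AIC and BIC from an ordinary least squares fit.
    ols = sm.OLS(y, sm.add_constant(X_poly)).fit()

    # LOOCV mean squared error for the same polynomial.
    mse = -cross_val_score(LinearRegression(), X_poly, y,
                           cv=LeaveOneOut(),
                           scoring="neg_mean_squared_error").mean()
    results.append((degree, ols.aic, ols.bic, mse))

# All three criteria should point to (roughly) the same degree here.
for degree, aic, bic, mse in results:
    print(f"degree {degree}: AIC={aic:.1f}, BIC={bic:.1f}, LOOCV MSE={mse:.2f}")
```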

Feature selection can also be done with AIC, BIC, or cross-validation. However, cross-validation is not as straightforward as AIC and BIC.

Other Resources

  • If you want to find out more about how to do feature selection, then check out this blog post. There, I show how cross-validation in machine learning can go incredibly wrong when not done properly.
  • Another related post is the one where I explain the bias-variance trade-off in machine learning.
  • You might also like my post about parsimony versus accuracy, where I explain what the difference is and which models are desirable in certain situations.

I hope you have enjoyed the detailed introduction to cross-validation in machine learning. If you have any questions or suggestions, please let me know in the comment section below.
