Practical Guide to Cross-Validation in Machine Learning
 With an example using Python sklearn

Lianne & Justin


In this guide, we’ll explain everything about cross-validation in machine learning.

When building machine learning models for production, it’s critical to know how well the results of the statistical analysis will generalize to independent datasets. Cross-validation is one of the simplest and most commonly used techniques for validating models against this criterion.

Following this tutorial, you’ll learn:

  • What is cross-validation in machine learning.
  • What is the k-fold cross-validation method.
  • How to use k-fold cross-validation.
  • How to implement cross-validation with Python sklearn, with an example.

If you want to validate your predictive model’s performance before applying it, cross-validation can be critical and handy.

Let’s get started!


If you are new to Python, please take our FREE Python crash course for data science to get a good foundation. This tutorial assumes you have basic knowledge of Python.



What is Cross-Validation in Machine Learning?

Before introducing cross-validation, let’s review some basic concepts.

Overfitting and Generalization

When we create a machine learning model, we want it to predict well on data outside the sample it was trained on. It’s critical to assess how well our model’s performance generalizes to independent datasets.

Why is generalization so important?

In (supervised) machine learning, the algorithms only aim to fit models to the training dataset. After building a model, we may measure its performance on the same training data, but that estimate is usually too optimistic. The model might fail to provide reliable predictions on new data.

This problem is called overfitting.

But we are hoping to apply the same model to new datasets to make predictions in production. So we need the model’s performance to generalize well to new datasets. We need to know whether the same model works on data outside of the training set.

So how do we validate our model for this?

Training, Validation, Test sets

The best practice for selecting and assessing models is to randomly divide the original dataset into three subsets: training, validation, and test datasets. We can:

  • fit the model using the training set
  • select the model based on the models’ performance on the validation set
  • assess the final chosen model on the test set to see whether it generalizes well.

That sounds like a lot of partitioning of the dataset, which is often limited in size. These divisions reduce the number of samples available for training. Plus, the result might be biased due to the particular random selection of training, validation, and test sets.

That’s why cross-validation is a powerful and useful technique! Let’s see how it works.

Cross-Validation

Instead of splitting into three partitions, we only (randomly) split into training and test sets. We can perform “cross” validation using the training dataset.

Note that an independent test set is still necessary. We need a dataset that hasn’t been touched to assess the final selected model’s performance. So we lock this test set away and only use it at the very end.

The basic cross-validation approach repeatedly partitions the training dataset further into sub-training and sub-validation sets. The model is fitted using the sub-training set and evaluated using the sub-validation (or sub-test) set. This procedure is repeated several times using different partitions. In the end, the metrics calculated from all the partitions are combined (e.g., by averaging) to estimate the model’s performance.

This might sound a little confusing, but don’t worry, you’ll see examples soon.

We save data by cross-validating on sub-partitions of the training dataset instead of holding out a separate validation set. CV takes more computation, but it uses the data more efficiently.

With the general principle of cross-validation, let’s dive into details of the most basic method, the k-fold cross-validation.

K-fold Cross-Validation and its variations

As mentioned earlier, we first split the data into training and test sets. And then, we perform the cross-validation method using the training set.

In a k-fold CV, we further randomly partition the training dataset into k roughly equal-sized smaller sets (folds). Then, for the ith fold (i = 1, …, k), we repeat the following procedure:

  • train the model using the k-1 folds other than the ith one.
  • calculate a model performance measure (e.g., prediction error) of the fitted model using the validation fold (the ith fold).

We then combine the measures from all k model fits into a single metric, such as their average, and use this single measure to judge the model’s performance.

Note that every observation in the training set is assigned to an individual fold and stays in that fold for the entire validation process. In this way, each observation has the opportunity to be used in the validation fold once and also be used to train the model k – 1 times.

For example, the chart below shows the process of a 5-fold cross-validation. Model one uses fold 1 for evaluation and folds 2–5 for training. Model two uses fold 2 for evaluation and the remaining folds for training, and so on.

Demonstration of the k-fold cross-validation method (source: scikit-learn documentation)

So how do we choose the value of k?

Let’s think about the extreme case first.

When k = n, the number of observations, it becomes leave-one-out cross-validation (LOOCV). In this case, each observation is its own fold. But the training sets then become very similar to each other, so there’s high variance in the prediction error measured by the CV: the estimate may change a lot depending on the data used to fit the model. Plus, the computational effort is large, since we need to apply the fitting and validation process n times.

As k gets smaller, the training folds become smaller relative to the validation folds. For example, if k = 2, we fit the model on half of the data while validating it on the other half. The prediction measure calculated this way tends to be more biased, i.e., it doesn’t represent the true value well.

The common practice is to pick k = 5 or 10, because 5-fold or 10-fold cross-validation has been shown to produce error estimates with a good balance between bias and variance.

Besides partitioning into plain k-folds, we could also do stratified or repeated partitions (a minimal setup sketch follows this list):

  • Stratified k-fold cross-validation: the folds are stratified, i.e., they contain roughly the same percentage of observations for each target class as the complete dataset.
    It’s a good practice to use this method when the target classes are imbalanced.
  • Repeated k-fold cross-validation: the k-fold validation process is repeated multiple times. Each repetition produces different splits of the dataset.
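
Here is a minimal sketch of how these two variations could be set up with sklearn; the numbers of splits and repeats below are arbitrary choices for illustration, and X and y stand in for your own features and target classes:

from sklearn.model_selection import StratifiedKFold, RepeatedKFold

# stratified: each fold keeps roughly the same class proportions as the full dataset
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# repeated: the whole 5-fold split is repeated 3 times with different shuffles
rkf = RepeatedKFold(n_splits=5, n_repeats=3, random_state=42)

# both expose the same split generator as KFold, e.g.:
# for train_index, val_index in skf.split(X, y):
#     ...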

If you are interested in learning about other types of CV, check out the Wikipedia CV types.

When and How to apply K-fold Cross-Validation?

Now that we have the details of cross-validation, where does this technique fit into the bigger picture of the model-building process?

Cross-validation is a useful technique for evaluating and selecting machine learning algorithms/models. This includes helping with tuning the hyperparameters of a particular model.

Assume we want to find the best-performing model among different algorithms:

  • We can pick the algorithm that produces the model with the best CV measure/score.
  • Then we can refit it on the entire training dataset and use it on the test dataset for the final assessment.
  • If we are satisfied with the test results, we can retrain the model on the entire dataset (train + test) and put it into production.

The flowchart below shows this process as well.

Flowchart of the cross-validation application in the overall model-building process
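
As a rough sketch of this selection step, sklearn’s cross_val_score can compute the CV score for each candidate algorithm. The toy data and the two candidate models below are only placeholders for whatever you are actually comparing:

import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor

# toy training data, only for illustration
rng = np.random.default_rng(0)
X_train = rng.uniform(0, 10, size=(80, 1))
y_train = 3 * X_train.ravel() + rng.normal(0, 2, size=80)

candidates = {
    "linear regression": LinearRegression(),
    "random forest": RandomForestRegressor(random_state=0),
}

for name, model in candidates.items():
    # 5-fold CV; sklearn reports MSE as a negative score so that higher is always better
    scores = cross_val_score(model, X_train, y_train, cv=5,
                             scoring="neg_mean_squared_error")
    print(f"{name}: mean CV MSE = {-scores.mean():.2f}")

The algorithm with the lowest mean CV MSE would then be refitted on the whole training set and assessed on the test set.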

Further Reading: Hyperparameter Tuning with Python: Complete Step-by-Step Guide or Hyperparameter Tuning with Python: Keras Step-by-Step Guide
If you want to learn more about using k-fold cross-validation to select hyperparameter values, read these practical guides.

Common model evaluation measures include:

  • For regression problems, the popular ones are mean squared error (MSE) and mean absolute error (MAE).
  • For classification problems, we can use metrics such as accuracy, area under the ROC curve (AUC), or F1 score.
    The Python scikit-learn (sklearn) library has many different metrics available. Please check them out in the metrics and scoring documentation.

Further Reading: 8 popular Evaluation Metrics for Machine Learning Models

And before we move on to the example, one last note on applying k-fold cross-validation.

Since we are rotating the data between validation and training folds, we need to make sure that no knowledge of the validation folds “leaks” into the training process. This can happen when we apply something to the entire training dataset before doing the cross-validation. For example, we might select features based on the whole training set, which would make the CV technique less effective. So we should select the features based on the training “folds” rather than the entire training set.
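
One way to keep such preprocessing inside each training fold is to wrap it in an sklearn Pipeline, so every step is re-fitted on the training folds only during CV. The scaler, feature selector, and toy data below are just illustrative assumptions, not part of the original example:

import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import SelectKBest, f_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score

# toy data with 10 candidate features, only for illustration
rng = np.random.default_rng(1)
X = rng.normal(size=(100, 10))
y = 2 * X[:, 0] + rng.normal(scale=0.5, size=100)

# scaling and feature selection happen inside each CV split,
# fitted on the training folds only, so nothing leaks from the validation fold
pipe = Pipeline([
    ("scale", StandardScaler()),
    ("select", SelectKBest(score_func=f_regression, k=3)),
    ("model", LinearRegression()),
])

scores = cross_val_score(pipe, X, y, cv=5, scoring="neg_mean_squared_error")
print("mean CV MSE:", -scores.mean())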

Cross-Validation Example with Python sklearn

Finally, we are ready to see an example in Python. We’ll use the scikit-learn (sklearn) library, which provides useful functions for cross-validation.

First, we’ll import the necessary packages/functions.
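
The exact import cell isn’t shown here; based on what the example uses (NumPy for data generation, matplotlib for plotting, and sklearn for splitting, fitting, and scoring), it would look roughly like this:

import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split, KFold
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error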

Next, we generate a random dataset of size 20 with NumPy’s random number generator. The target y has a linear relationship with the input variable X, plus a random error.
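
The original generation code isn’t reproduced here; a sketch that matches the description could be the following, where the range, slope, intercept, and noise level are all assumptions made for illustration:

rng = np.random.default_rng()              # no seed, consistent with random_state = None later
X = rng.uniform(0, 10, size=20)            # 20 input values (assumed range)
y = 2 * X + 1 + rng.normal(0, 3, size=20)  # assumed linear relationship plus random error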

Before implementing the cross-validation method, we split the whole dataset into training and test sets for both input and target variables: X_train, X_test, y_train, and y_test.

With the function train_test_split, we can split off 20% as the test set, i.e., keep 80% as the training set. Note that there’s no single optimal proportion for splitting the dataset; we just picked 80% in this example.
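
In code, this split is one call to train_test_split:

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
print(len(X_train), len(X_test))  # 16 training and 4 test observations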

To take a look at the datasets, we can plot the training sets and the test sets.
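
A simple scatter plot of both sets, assuming matplotlib, could look like this:

plt.scatter(X_train, y_train, label="training set")
plt.scatter(X_test, y_test, label="test set")
plt.xlabel("X")
plt.ylabel("y")
plt.legend()
plt.show()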

Now let’s set aside the test set and focus on the training set for cross-validation.

Let’s use k = 5 for this example, so we need to split the training data into five folds. Since there are 16 (= 20 * 0.8) observations in the training set, the folds can’t all be exactly the same size, but they will still be roughly equal: 4, 3, 3, 3, 3.

Using the KFold cross-validator below, we can generate the indices to split the data into five folds with shuffling.

Then we can apply the split method to the training dataset X_train. Looping over the result, split returns the training and validation indices for each of the five splits.
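
The original snippet isn’t shown, but a sketch consistent with the description is:

kf = KFold(n_splits=5, shuffle=True, random_state=None)

for i, (train_index, val_index) in enumerate(kf.split(X_train)):
    print(f"split {i}:")
    print("training indices:", train_index)
    print("validation indices:", val_index)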

For example, the five random folds represented by indices could be:

  • Fold 1: [ 0 4 8 11]
  • Fold 2: [1 3 9]
  • Fold 3: [ 5 6 14]
  • Fold 4: [ 7 12 15]
  • Fold 5: [ 2 10 13]

These indices can be used later to retrieve elements from the dataset. Note that you would most likely have generated different indices since we didn’t set the seed of the random number generators (random_state = None).

Five models will be trained and evaluated, with each fold having a chance of being the holdout set. The list below shows the indices of training and validation folds of each model.

split 0:
training indices: [ 1  2  3  5  6  7  9 10 12 13 14 15]
validation indices: [ 0  4  8 11]
split 1:
training indices: [ 0  2  4  5  6  7  8 10 11 12 13 14 15]
validation indices: [1 3 9]
split 2:
training indices: [ 0  1  2  3  4  7  8  9 10 11 12 13 15]
validation indices: [ 5  6 14]
split 3:
training indices: [ 0  1  2  3  4  5  6  8  9 10 11 13 14]
validation indices: [ 7 12 15]
split 4:
training indices: [ 0  1  3  4  5  6  7  8  9 11 12 14 15]
validation indices: [ 2 10 13]

What model should we fit using these datasets?

We can define a function get_features to create features from the input variable(s). For simplicity, we don’t create any transformed features; we only reshape the input variable into the form needed to fit a linear regression model.
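
The original function isn’t shown; a minimal version consistent with “just reshaping the input” could be:

def get_features(X):
    # no feature engineering: just reshape the 1-D input into the
    # 2-D (n_samples, n_features) array that sklearn estimators expect
    return np.array(X).reshape(-1, 1)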

Further Reading: Linear Regression in Machine Learning: Practical Python Tutorial
If you are not familiar with fitting linear regression models in Python sklearn, take a look at this step-by-step tutorial.

We also define another function get_mse to calculate the Mean Squared Error (MSE), the prediction error measure for our CV.
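
Again as a sketch, get_mse could simply wrap sklearn’s mean_squared_error around the model’s predictions:

def get_mse(model, X, y):
    # predict with the fitted model and compare against the true targets
    y_pred = model.predict(get_features(X))
    return mean_squared_error(y, y_pred)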

Now we are ready to implement the 5-fold cross-validation.

For each of the five splits, we print out the MSE measure for both the training and validation folds.
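
Putting the pieces above together, a sketch of the CV loop (reusing the kf splitter and the two helper functions defined earlier) is:

val_errors = []

for i, (train_index, val_index) in enumerate(kf.split(X_train)):
    X_tr, X_val = X_train[train_index], X_train[val_index]
    y_tr, y_val = y_train[train_index], y_train[val_index]

    # fit a fresh linear regression on the training folds
    model = LinearRegression().fit(get_features(X_tr), y_tr)

    val_mse = get_mse(model, X_val, y_val)
    val_errors.append(val_mse)

    print(f"split {i}")
    print("training error:", get_mse(model, X_tr, y_tr))
    print("validation error:", val_mse)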

Again, you would most likely have generated different results since random_state = None.

split 0
training error: 103.99274412000705
validation error: 285.6098699374511
split 1
training error: 123.64131378552102
validation error: 253.93417358218412
split 2
training error: 141.88683949717839
validation error: 218.51257520792115
split 3
training error: 178.48538697818125
validation error: 14.362610323043755
split 4
training error: 171.54676696788482
validation error: 56.92373988863912

To combine the results from the 5-fold CV, we can look at the average, standard deviation, and confidence interval of the MSE scores.
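
The aggregation code isn’t shown; one sketch that matches the printed numbers below (where the confidence interval is taken as the mean plus or minus two standard deviations of the five fold scores) is:

cv_mean = np.mean(val_errors)
cv_std = np.std(val_errors)
# a rough confidence interval: mean +/- 2 standard deviations of the fold scores
cv_ci = (cv_mean - 2 * cv_std, cv_mean + 2 * cv_std)

print("K Fold CV Avg.:", cv_mean)
print("K Fold CV Standard Dev.:", cv_std)
print("K Fold CV CI.:", cv_ci)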

K Fold CV Avg.: 165.86859378784786 
K Fold CV Standard Dev.: 109.25928112061015 
K Fold CV CI.: (-52.64996845337245, 384.38715602906814)

That’s it!

We just got the result of a 5-fold CV.


To summarize, you’ve learned about cross-validation in machine learning, with a focus on the k-fold CV.

Will you try out the k-fold CV in your next data science project?

Leave a comment for any questions you may have or anything else!


Related “Break into Data Science” resources:

Python crash course: breaking into Data Science
A FREE Python online course, beginner-friendly tutorial. Start your successful data science career journey.

Python NumPy Tutorial: Practical Basics for Data Science
This is a beginner-friendly tutorial of Python NumPy (arrays) basics for data science. Learn this essential library with examples.

Learn Python Pandas for Data Science: Quick Tutorial
This is a quick tutorial to learn Python pandas for data science and machine learning. Learn how to better manipulate and analyze data with this guide.
