A Simple Introduction to Cross-Validation
Introduction
Machine learning is the process of creating models that can learn from data and make predictions or decisions. However, machine learning models are not perfect and may not always perform well on new, unseen data. This is because they may be overfitting the training data, which means they memorize the noise and details of the training data but fail to generalize to new data.
To avoid overfitting and ensure that the model can generalize well to new data, we need to evaluate its performance on different subsets of the data. This is where cross validation comes in.
Cross validation is a technique that involves dividing the available data into multiple folds or subsets, using one of these folds as a validation set, and training the model on the remaining folds. This process is repeated multiple times, each time using a different fold as the validation set. Finally, the results from each validation step are averaged to produce a more robust estimate of the model’s performance.
In this article, we will look at what cross validation is, why it should be used, how to do it with scikit-learn, when to use it, and how it can benefit overall model performance and development.
What is Cross Validation?
Cross validation is a resampling method that uses different portions of the data to test and train a model on different iterations. It is mainly used in settings where the goal is prediction, and one wants to estimate how accurately a predictive model will perform in practice.
Cross validation provides information about how well a classifier generalizes, specifically the range of expected errors of the classifier. However, a classifier trained on a high dimensional dataset with no structure may still perform better than expected on cross validation, just by chance.
To avoid this problem, we need to choose an appropriate number of folds. The number of folds determines how many times we repeat the process of splitting, training, and testing. With more folds, each model is trained on a larger share of the data, which tends to reduce the bias of the performance estimate, but each test fold is smaller and the training sets overlap more, which tends to increase the variance and the computational cost. With fewer folds, each model is trained on less data, which tends to increase bias, but the procedure is cheaper and the estimate typically has lower variance.
There are several types of cross validation techniques that differ in how they split and use the data for testing. Some common types are:
- k-fold cross-validation: In this technique, we divide our input dataset into k smaller sets (called folds). We then use one fold as a test set and k-1 folds as training sets. We repeat this process k times (one for each fold) and average the results from each iteration.
- Leave-one-out cross-validation: In this technique, we treat each sample as its own test set and use all other samples as the training set. We repeat this process n times, once for each of the n samples, and average the results from each iteration. This is equivalent to k-fold cross-validation with k equal to the number of samples.
- Stratified k-fold cross-validation: In this technique, we divide our input dataset into k smaller sets (called folds). We then stratify our dataset by some feature (such as class label) so that each fold has approximately equal proportions of samples from different classes. We then use one fold as a test set and k-1 folds as training sets. We repeat this process k times (one for each fold) and average the results from each iteration.
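To make these differences concrete, here is a minimal sketch using scikit-learn's KFold, LeaveOneOut, and StratifiedKFold splitters; the tiny six-sample dataset and the choice of k=3 are made up purely for illustration. It prints the train/test indices each strategy produces:
# Illustrating the three splitting strategies described above
# (the toy data below is made up purely for demonstration)
import numpy as np
from sklearn.model_selection import KFold, LeaveOneOut, StratifiedKFold
X = np.arange(12).reshape(6, 2)   # 6 samples, 2 features
y = np.array([0, 0, 0, 1, 1, 1])  # two classes, 3 samples each
splitters = {
    'k-fold (k=3)': KFold(n_splits=3, shuffle=True, random_state=0),
    'leave-one-out': LeaveOneOut(),
    'stratified k-fold (k=3)': StratifiedKFold(n_splits=3, shuffle=True, random_state=0),
}
for name, splitter in splitters.items():
    print(name)
    # split() yields (train_indices, test_indices) for each fold
    for train_idx, test_idx in splitter.split(X, y):
        print(f'  train={train_idx} test={test_idx}')
Note how the stratified splitter keeps one sample from each class in every test fold, while plain k-fold does not guarantee this.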
Why Use Cross Validation?
Cross validation has several advantages over using only one split or test set:
- It helps detect and reduce overfitting: By evaluating the model on multiple held-out splits drawn from different parts of the data, we are less likely to be fooled by a model that has memorized the noise or details of one particular training set but fails to generalize to new data.
- It provides more realistic estimates: By using multiple splits or test sets from different parts of the data, we can get more information about how well our model performs on new, unseen data. We can also compare our model’s performance with other models or methods.
- It helps choose optimal hyperparameters: By using multiple splits or test sets from different parts of the data, we can evaluate how sensitive our model’s performance is to changes in its hyperparameters (such as learning rate, number of layers, number of neurons, etc.). We can also use grid search or random search techniques to find the best combination of hyperparameters that maximizes our model’s performance.
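As a small illustration of that last point, the sketch below uses scikit-learn's GridSearchCV, which runs k-fold cross validation internally for every combination of hyperparameters; the SVC model, the Iris dataset, and the parameter grid are arbitrary choices for demonstration, not recommendations.
# Hyperparameter tuning with cross validation (illustrative values only)
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
X, y = load_iris(return_X_y=True)
# Candidate hyperparameters; every combination is scored with 5-fold CV
param_grid = {'kernel': ['linear', 'rbf'], 'C': [0.1, 1, 10]}
search = GridSearchCV(SVC(), param_grid, cv=5, scoring='accuracy')
search.fit(X, y)
print('Best parameters:', search.best_params_)
print('Best cross-validation accuracy:', round(search.best_score_, 3))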
How to Do Cross Validation?
To do cross validation using scikit-learn, here is a simple example:
# Import necessary modules
from sklearn.model_selection import KFold, cross_val_score
from sklearn.datasets import load_iris
from sklearn.svm import SVC
# Load the Iris dataset
X, y = load_iris(return_X_y=True)
# Create model (Support Vector Machine)
model = SVC(kernel='linear', C=1)
# Create cross-validation object
cv = KFold(n_splits=5, shuffle=True, random_state=42)
# Perform cross-validation and get scores (using accuracy as the scoring metric)
scores = cross_val_score(model, X, y, cv=cv, scoring='accuracy')
# Calculate mean and standard deviation of scores
mean_accuracy = scores.mean()
std_accuracy = scores.std()
# Print results
print(f'Mean Accuracy: {mean_accuracy:.2f}')
print(f'Standard deviation: {std_accuracy:.2f}')
The output of this code is:
Mean Accuracy: 0.97
Standard deviation: 0.02
This means that the average accuracy of the SVC model on the Iris dataset using 5-fold cross validation is 0.97, with a standard deviation of 0.02. This gives us an idea of how well the model can classify new, unseen iris samples, and how much variation there is in its performance across folds.
Of course, this is a very simple example, but it gives an idea of how cross validation works and how it is used.
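If you want more than a single metric, scikit-learn also provides cross_validate, which works like cross_val_score but can report several scores (and fit times) per fold; the metrics below are just example choices.
# Getting multiple metrics per fold with cross_validate (illustrative choices)
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_validate
from sklearn.svm import SVC
X, y = load_iris(return_X_y=True)
model = SVC(kernel='linear', C=1)
results = cross_validate(
    model, X, y, cv=5,
    scoring=['accuracy', 'f1_macro'],  # several metrics at once
    return_train_score=True,           # also report training-fold scores
)
print('Test accuracy per fold:', results['test_accuracy'])
print('Mean test f1_macro:', results['test_f1_macro'].mean())
print('Mean train accuracy:', results['train_accuracy'].mean())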
Cross validation is not limited to scikit-learn; models built with other libraries and frameworks, such as TensorFlow, PyTorch, or Keras, can also be evaluated this way. You can find more information and examples in their documentation and tutorials.
When to Use Cross Validation?
Cross validation is a useful technique for evaluating machine learning models, especially when the available data is limited or the model is complex. However, cross validation is not always necessary or appropriate for every situation. Some factors that may influence the decision to use cross validation are:
- The size and quality of the data: If the data is large and representative enough, a single test set may be sufficient to evaluate the model’s performance. However, if the data is small or noisy, cross validation can help reduce the variance and bias of the estimates.
- The complexity and stability of the model: If the model is simple and stable, meaning that it does not change much with different training sets, cross validation may not provide much additional information. However, if the model is complex and sensitive, meaning that it changes a lot with different training sets, cross validation can help assess its generalization ability and robustness.
- The computational cost and time: Cross validation requires more computation and time than using a single test set, as it involves training and testing the model multiple times. Therefore, cross validation may not be feasible or efficient for models that are very expensive or slow to train and test.
How Can Cross Validation Benefit Overall Model Performance and Development?
Cross validation can benefit overall model performance and development in several ways, such as:
- It can help select the best model or method among different alternatives, by comparing their cross validation scores and choosing the one with the highest score or the lowest error (a small sketch of this appears after this list).
- It can help tune the hyperparameters of the model, by using grid search or random search techniques to find the optimal combination of hyperparameters that maximizes the cross validation score or minimizes the cross validation error.
- It can help prevent overfitting or underfitting, by using regularization techniques such as Lasso or Ridge, which penalize the model’s complexity and reduce its variance, and by using early stopping techniques, which stop the model’s training when the cross validation score or error stops improving or starts worsening.
- It can help improve the model’s performance, by using ensemble techniques such as bagging or boosting, which combine multiple models or methods and reduce their variance and bias, and by using feature selection or extraction techniques, which reduce the dimensionality and noise of the data and improve its quality.
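As promised above, here is a minimal sketch of the first point (model selection); the three candidate models, their settings, and the Iris dataset are arbitrary choices for illustration.
# Comparing candidate models by their cross validation scores (illustrative)
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.svm import SVC
X, y = load_iris(return_X_y=True)
candidates = {
    'svc': SVC(kernel='linear', C=1),
    'logistic_regression': LogisticRegression(max_iter=1000),
    'random_forest': RandomForestClassifier(n_estimators=100, random_state=42),
}
# Score each candidate with the same 5-fold cross validation and keep the best
mean_scores = {
    name: cross_val_score(model, X, y, cv=5, scoring='accuracy').mean()
    for name, model in candidates.items()
}
print(mean_scores)
print('Best model:', max(mean_scores, key=mean_scores.get))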
Here are brief definitions of some of the terms and techniques mentioned above:
- Overfitting: Overfitting occurs when a model learns the details and noise in the training data to the extent that it negatively impacts the performance of the model on new data. This means that the model does not generalize well to new data. You can find more info on the scikit-learn website.
- Underfitting: Underfitting occurs when a model is not able to capture the underlying pattern of the data. This means that the model does not fit the data well and has low performance on both the training and the test data. You can find more info on the scikit-learn website.
- Variance: Variance is the amount by which the model’s predictions would change if we used a different training data set. High variance means that the model is sensitive to changes in the training data and may overfit. You can find more info on the scikit-learn website.
- Bias: Bias is the difference between the average prediction of our model and the correct value we are trying to predict. High bias means that the model is not able to capture the complexity of the data and may underfit. You can find more info on the scikit-learn website.
- Hyperparameters: Hyperparameters are parameters that are not learned directly from the data, but are set by the user before the training process. They control the behavior and performance of the model, such as the learning rate, the number of layers, the number of neurons, etc. You can find more info on the scikit-learn website.
- Lasso: Lasso is a regularization technique that adds a penalty term to the loss function of the model, proportional to the absolute value of the model’s coefficients. This can shrink some coefficients all the way to zero, which reduces the number of features used by the model and helps prevent overfitting. You can find more info on the scikit-learn website.
- Ridge: Ridge is a regularization technique that adds a penalty term to the loss function of the model, proportional to the square of the model’s coefficients. This reduces the magnitude of the coefficients, which reduces the variance and helps prevent overfitting. You can find more info on the scikit-learn website.
- Bagging: Bagging is an ensemble technique that creates multiple models by randomly sampling from the original data with replacement. Each model is trained on a different bootstrap sample of the data, and the final prediction is obtained by averaging or voting over the predictions of the individual models. Bagging reduces variance and helps prevent overfitting. You can find more info on the scikit-learn website.
- Boosting: Boosting is an ensemble technique that builds models sequentially, with each new model focusing on correcting the errors of the previous ones, for example by giving more weight to the samples the previous models handled poorly. The final prediction is obtained by combining the predictions of all the models. Boosting mainly reduces bias and can improve performance. You can find more info on the scikit-learn website.
- Feature selection: Feature selection is a technique that selects a subset of the original features that are most relevant or informative for the prediction task. It reduces the dimensionality and noise of the data and can improve the quality and performance of the model. You can find more info on the scikit-learn website.
- Feature extraction: Feature extraction is a technique that transforms the original features into a new set of features that are more suitable or meaningful for the prediction task. It reduces the dimensionality and noise of the data and can improve the quality and performance of the model. You can find more info on the scikit-learn website.
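Tying a few of these ideas together, here is a hedged sketch that puts feature selection and a regularized (Ridge) model into a scikit-learn Pipeline and evaluates the whole thing with cross validation, so the feature selection is re-fit inside each training fold instead of on the full dataset; the diabetes dataset, the number of selected features, and the Ridge penalty are all arbitrary illustrative choices.
# Feature selection + Ridge regression evaluated together with cross validation
# (all specific choices below are illustrative)
from sklearn.datasets import load_diabetes
from sklearn.feature_selection import SelectKBest, f_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
X, y = load_diabetes(return_X_y=True)
# The pipeline is refit on each training fold, so the selected features
# never 'see' the corresponding test fold
pipeline = make_pipeline(
    StandardScaler(),
    SelectKBest(score_func=f_regression, k=5),
    Ridge(alpha=1.0),
)
scores = cross_val_score(pipeline, X, y, cv=5, scoring='r2')
print('R^2 per fold:', scores.round(3))
print('Mean R^2:', scores.mean().round(3))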
Conclusion
Cross validation is a technique that involves dividing the available data into multiple folds or subsets, using one of these folds as a validation set, and training the model on the remaining folds. This process is repeated multiple times, each time using a different fold as the validation set. Finally, the results from each validation step are averaged to produce a more robust estimate of the model’s performance.
Cross validation has several advantages over using only one split or test set, such as reducing overfitting, providing more realistic estimates, and helping choose optimal hyperparameters. However, cross validation also has some limitations and challenges, such as requiring more computation and time, and depending on the choice of the number and type of folds.
Cross validation can benefit overall model performance and development by helping select the best model or method, tune the hyperparameters, prevent overfitting or underfitting, and improve the model’s performance.
If you would like to see a more detailed post and examples on cross-validation, let me know in the comments and hit the follow button!
If you like this post, consider giving it a like. Wanna see more posts like this? Hit the follow button!