Comprehensive Guide on Grid Search
What is grid search?
Grid search is a brute-force technique for finding the optimal hyper-parameters of a model. Hyper-parameter tuning is an extremely important task in machine learning because a model's final performance depends largely on its hyper-parameters. Grid search simply trains and evaluates a model on every combination of the chosen hyper-parameter values, and then selects the combination that performs the best.
Simple example of using grid search for hyper-parameter tuning
As an example, consider a Random Forest classifier, which has the following two hyper-parameters:
- `max_depth`: the maximum depth that a decision tree is allowed to reach.
- `max_features`: the maximum number of features that can be selected at random at each split.
For grid search, we supply values for these hyper-parameters that we want to test. For instance, suppose we wanted to test out the following values:
- `max_depth`: [2, 3]
- `max_features`: [1, 2, 3]
Grid search will then select every combination of the hyper-parameters (like a grid) and build the model for each combination using cross validation to obtain its performance. In this case, grid search will test out the following 6 combinations of hyper-parameters:
| Combination | max_depth | max_features |
|---|---|---|
| 1 | 2 | 1 |
| 2 | 2 | 2 |
| 3 | 2 | 3 |
| 4 | 3 | 1 |
| 5 | 3 | 2 |
| 6 | 3 | 3 |
For each model built, cross validation will return a performance metric like so:
| Combination | max_depth | max_features | Accuracy |
|---|---|---|---|
| 1 | 2 | 1 | 80% |
| 2 | 2 | 2 | 77% |
| 3 | 2 | 3 | 85% |
| 4 | 3 | 1 | 88% |
| 5 | 3 | 2 | 85% |
| 6 | 3 | 3 | 83% |
Here, we are assuming that we are tackling a classification problem, and the metric we are after is simply the classification accuracy. If our model were a regression model, then we might compute metrics such as mean squared error or mean absolute error instead. Grid search will return the combination of hyper-parameters with the best performance, which in this case is combination 4 (88% accuracy).
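To make the procedure concrete, here is a minimal hand-rolled sketch of grid search, using the same hyper-parameter grid as above, a Random Forest, and sklearn's `cross_val_score` (in practice you would use `GridSearchCV`, covered below):

```python
from itertools import product

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

X, y = load_iris(return_X_y=True)

# The grid of hyper-parameter values to test
grid = {"max_depth": [2, 3], "max_features": [1, 2, 3]}

best_score, best_params = -1.0, None
# Enumerate all 6 combinations and score each with 5-fold cross validation
for max_depth, max_features in product(grid["max_depth"], grid["max_features"]):
    model = RandomForestClassifier(max_depth=max_depth,
                                   max_features=max_features,
                                   random_state=42)
    score = cross_val_score(model, X, y, cv=5, scoring="accuracy").mean()
    if score > best_score:
        best_score = score
        best_params = {"max_depth": max_depth, "max_features": max_features}

print(best_params, best_score)
```

The `scoring` argument is also where you would swap in a regression metric such as `"neg_mean_squared_error"` when tuning a regression model.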
Grid search tests every combination of the given hyper-parameter values. This means, for example, that if you have 3 hyper-parameters with 10 candidate values each, then you will be testing 10*10*10=1000 different combinations. This is computationally expensive, especially for large datasets.
In such cases, we recommend that you perform random search, which randomly selects and tests a number of different combinations to get a general sense of what hyper-parameters work well. Afterwards, you can use grid search to perform granular testing around those hyper-parameter values.
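sklearn implements random search with `RandomizedSearchCV`, which samples a fixed number of combinations instead of testing them all. A minimal sketch, where the value ranges below are illustrative assumptions rather than values from this guide:

```python
from scipy.stats import randint
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

X, y = load_iris(return_X_y=True)

# Sample hyper-parameter values from distributions instead of a fixed grid
param_distributions = {
    "max_depth": randint(2, 10),    # integers in [2, 10)
    "max_features": randint(1, 4),  # integers in [1, 4)
}

random_search = RandomizedSearchCV(
    RandomForestClassifier(random_state=42),
    param_distributions,
    n_iter=10,  # test only 10 randomly sampled combinations
    cv=5,
    scoring="accuracy",
    random_state=42,
)
random_search.fit(X, y)
print(random_search.best_params_)
```

You could then run a regular grid search over a narrow grid centred on `random_search.best_params_`.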
Whether to retrain the model on the combined training and validation set
After we have obtained our optimal hyper-parameter values, we have two options - we can either:
- combine the training and validation set and retrain the random forest model with the optimal hyper-parameters, or
- use the best-performing model obtained during the grid search process.
There are trade-offs between the two approaches:
- By retraining on the combined training + validation set, the model can be trained on more data, which typically means it will generalize better. If you have a small dataset, then it makes sense to use as many data points as possible for training.
- Some data scientists prefer not to retrain the model after the validation step, that is, they keep the best model trained using only the training set without the validation set. This is because the optimal hyper-parameters obtained using grid search are only optimal for the specific validation set used to measure performance; if you combine the training and validation set, then the chosen hyper-parameters may no longer be optimal.
Using Python's sklearn to implement grid search
Suppose we wanted to classify the type of an iris given four features (e.g. sepal length) using a Random Forest classifier. As explained above, we can use grid search to tune the two hyper-parameters: `max_depth` and `max_features`.
We begin by importing the relevant modules:
```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report
from sklearn import datasets
import pandas as pd
import numpy as np
```
We then read the Iris dataset and convert it into a Pandas DataFrame.
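A minimal sketch of this step, assuming the standard `datasets.load_iris()` loader (the `target` column is cast to float to match the output shown below):

```python
# Load the bundled Iris dataset and convert it into a DataFrame
iris = datasets.load_iris()
df = pd.DataFrame(iris["data"], columns=iris["feature_names"])
df["target"] = iris["target"].astype(float)
print(df.head())
```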
```
   sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  target
0                5.1               3.5                1.4               0.2     0.0
1                4.9               3.0                1.4               0.2     0.0
2                4.7               3.2                1.3               0.2     0.0
3                4.6               3.1                1.5               0.2     0.0
4                5.0               3.6                1.4               0.2     0.0
```
We then split the data into training and testing sets:
```python
# Break into X (features) and y (target)
X = df.drop(columns=["target"])
y = df["target"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=2000)
```
Checking the number of rows in each split:

```
Number of rows of X_train: 120
Number of rows of y_train: 120
Number of rows of X_test: 30
Number of rows of y_test: 30
```
We then use the training set to perform grid search:
```python
param_grid = {
    "max_depth": [2, 3],
    "max_features": [1, 2, 3]
}
```
```python
# random_state is like the seed - this is for reproducible results
model = RandomForestClassifier(random_state=42)
grid_search = GridSearchCV(model, param_grid, cv=5, scoring="accuracy")
grid_search.fit(X_train, y_train)

print("Best hyper-parameters:")
print(grid_search.best_params_)
print("Best score:")
print(grid_search.best_score_)
```
```
Best hyper-parameters:
{'max_depth': 2, 'max_features': 1}
Best score:
0.9833333333333334
```
Here, note the following:
- Just like in our previous example, we are testing out 6 different combinations of hyper-parameters. Make sure that the keys of `param_grid` match the keyword arguments of the model - in this case, `RandomForestClassifier` takes the keyword arguments `max_depth` and `max_features`.
- Since `GridSearchCV` uses cross validation to obtain the performance metric, we need to specify the number of folds with `cv`.
- The results of the grid search tell us that the best combination of hyper-parameters is `max_depth=2` and `max_features=1`. With these hyper-parameters, the cross-validated classification accuracy is over 0.98.
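If you want to inspect the score of every combination rather than just the best one, `GridSearchCV` stores the full results in its `cv_results_` attribute, which can be loaded into a DataFrame - a quick sketch:

```python
# Mean cross-validated accuracy of all 6 combinations
results = pd.DataFrame(grid_search.cv_results_)
print(results[["param_max_depth", "param_max_features", "mean_test_score"]])
```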
Now that we have obtained the optimal hyper-parameters, we can either:
- use the best-performing model found during the grid search, or
- retrain the model using the combined training + validation set.
For the first approach, you can obtain the best-performing model like so (note that by default, `GridSearchCV` refits the best model on the whole of `X_train` once the search is done):
```python
model_optimal = grid_search.best_estimator_
```
For the second approach, we can build a fresh model with the optimal hyper-parameters and fit it on the combined training + validation set (here, simply `X_train`, since cross validation carved the validation folds out of it) like so:
```python
model_optimal = RandomForestClassifier(max_depth=2, max_features=1, random_state=42)
model_optimal.fit(X_train, y_train)
```
Note that you could also pass in the optimal hyper-parameter values using the `**` unpacking syntax:
```python
model_optimal = RandomForestClassifier(**grid_search.best_params_, random_state=42)
```
To measure the accuracy using our testing set:
```python
y_test_predicted = model_optimal.predict(X_test)
print(classification_report(y_test, y_test_predicted))
```
```
              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00         8
         1.0       0.88      0.70      0.78        10
         2.0       0.79      0.92      0.85        12

    accuracy                           0.87        30
   macro avg       0.89      0.87      0.87        30
weighted avg       0.87      0.87      0.86        30
```
We see that the classification accuracy based on the testing set is 0.87.