Ekta Aggarwal

Grid Search Explained

The performance of many algorithms, such as Random Forests, Decision Trees and XGBoost, is highly dependent on the values of the parameters passed to them.

For instance, the value of 'number of trees' in Random Forests can strongly affect the results. Not only that, a model can fit the training set extremely well and report high accuracy there, yet perform poorly on the test set; this is overfitting. Moreover, these algorithms can have more than one hyperparameter to optimise.


To choose hyperparameter values while guarding against overfitting, we use K-fold cross validation together with grid search.

To learn how to implement Grid Search in Python you can refer to this article: Grid Search in Python


K-Fold cross validation Explained


Suppose we have 1000 data points, and we want to run K-fold cross validation for K = 5. We divide the data into 5 equal, non-overlapping parts, so each part has 200 points. Let us call the folds F1, F2, F3, F4 and F5.

Let us keep the first fold of 200 points (F1) aside and build the model on F2, F3, F4 and F5. Now, we use F1 to measure the accuracy. Let us call this accuracy "Acc1".

Similarly, we keep F2 aside and build the same model with the same hyperparameters on F1, F3, F4 and F5. Now, we use F2 to measure the accuracy. Let us call this accuracy "Acc2".

Repeating this for the remaining folds gives Acc3, Acc4 and Acc5. The average of Acc1, Acc2, Acc3, Acc4 and Acc5 is then calculated.
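The procedure above can be sketched in plain Python. This is only an illustration under stated assumptions: the labels are randomly generated, the helper names `k_fold_indices` and `majority_class` are hypothetical, and the "model" is a stand-in that always predicts the most common training label rather than a real algorithm.

```python
import random

def k_fold_indices(n_points, k, seed=0):
    """Shuffle the indices 0..n_points-1 and cut them into k non-overlapping folds."""
    indices = list(range(n_points))
    random.Random(seed).shuffle(indices)
    fold_size = n_points // k  # assumes n_points is divisible by k
    return [indices[i * fold_size:(i + 1) * fold_size] for i in range(k)]

def majority_class(labels):
    """Stand-in 'model': always predict the most common training label."""
    return max(set(labels), key=labels.count)

# Made-up labels for 1000 points (assumption: roughly 70% belong to class 1).
rng = random.Random(42)
y = [1 if rng.random() < 0.7 else 0 for _ in range(1000)]

folds = k_fold_indices(len(y), k=5)  # F1..F5, 200 points each
accuracies = []                      # will hold Acc1..Acc5
for i, test_idx in enumerate(folds):
    # Build the model on the other four folds, evaluate on the held-out fold.
    train_idx = [j for f, fold in enumerate(folds) if f != i for j in fold]
    prediction = majority_class([y[j] for j in train_idx])
    acc = sum(y[j] == prediction for j in test_idx) / len(test_idx)
    accuracies.append(acc)

mean_accuracy = sum(accuracies) / len(accuracies)
print(accuracies, mean_accuracy)
```

Every point is used for testing exactly once and for training four times, which is why the averaged accuracy is a steadier estimate than a single train/test split.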


The figure below illustrates the procedure for K-fold cross validation for K = 5.


To learn more about k-fold cross validation in detail, you can refer to this article: K-fold cross validation explanation made easy!


Working of Grid Search

Suppose we want to test several values for the number of trees (say, ntree), such as 100, 250 and 500, and for the number of variables, say 4, 6 and 8, in Random Forests.


In grid search we take every combination of these parameter values (3 x 3 = 9 combinations in this case) and run K-fold cross validation to get an average accuracy for each combination.


The combination of parameters that leads to the highest average performance is taken as the best choice.
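As a rough sketch of how this could look in code, the snippet below uses scikit-learn's `GridSearchCV` with a `RandomForestClassifier`. It is not necessarily how the linked article implements it. The assumptions here: the data is synthetic (generated with `make_classification`), and scikit-learn's `n_estimators` and `max_features` play the roles of ntree and number of variables.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Synthetic data purely for illustration (assumption, not from the article).
X, y = make_classification(
    n_samples=200, n_features=10, n_informative=5, random_state=0
)

param_grid = {
    "n_estimators": [100, 250, 500],  # number of trees ("ntree")
    "max_features": [4, 6, 8],        # number of variables tried at each split
}

search = GridSearchCV(
    estimator=RandomForestClassifier(random_state=0),
    param_grid=param_grid,
    cv=5,                # 5-fold cross validation for each of the 9 combinations
    scoring="accuracy",  # average accuracy across folds decides the winner
)
search.fit(X, y)

print(search.best_params_)  # combination with the highest mean accuracy
print(search.best_score_)   # its mean cross-validated accuracy
```

The per-combination averages are available in `search.cv_results_["mean_test_score"]`, alongside the matching entries in `search.cv_results_["params"]`.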

Grid Search is a time-consuming technique.

Grid Search takes every possible combination of the parameter values you have defined and runs K-fold cross validation on each resulting model. For example, if you have ntree = 100, 250 and 500, and max_features = 4, 6 and 8, then it will take 9 possible combinations:

  1. (ntree = 100, max_features = 4)

  2. (ntree = 100, max_features = 6)

  3. (ntree = 100, max_features = 8)

  4. (ntree = 250, max_features = 4)

  5. (ntree = 250, max_features = 6)

  6. (ntree = 250, max_features = 8)

  7. (ntree = 500, max_features = 4)

  8. (ntree = 500, max_features = 6)

  9. (ntree = 500, max_features = 8)

Then it will run K-fold cross validation on all these 9 possible models.

If you want to tune another parameter, say max_depth = 3, 4 and 5, then it will create 3 x 3 x 3 = 27 different combinations of these 3 parameters and run K-fold cross validation on all of them. Thus, we need to be careful about how many parameters we tune and how many values we check for each parameter, as the number of models to fit, and with it the run time, grows multiplicatively.
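The growth in the number of models can be checked with a few lines of Python. This is a small counting sketch: the dictionary name `grid` and the variables `combos_2` and `combos_3` are illustrative, and K = 5 is assumed as in the earlier cross validation example.

```python
from itertools import product

# Two hyperparameters with three candidate values each.
grid = {
    "ntree": [100, 250, 500],
    "max_features": [4, 6, 8],
}
combos_2 = list(product(*grid.values()))

# Adding a third hyperparameter multiplies the number of combinations by 3.
grid["max_depth"] = [3, 4, 5]
combos_3 = list(product(*grid.values()))

k = 5  # number of folds (assumption: same K as in the example above)
print(len(combos_2), "combinations,", len(combos_2) * k, "model fits")
print(len(combos_3), "combinations,", len(combos_3) * k, "model fits")
```

With 5 folds, going from two to three tuned parameters raises the work from 45 to 135 model fits, before counting any final refit on the full data.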


An example


The following example illustrates the application of grid search:


Let us assume that 5-fold cross validation was run for each combination of ntree = 100, 250 and 500 and number of variables = 4, 6 and 8 in Random Forests.

That is, K-fold cross validation was run for all 9 possible models with these parameters, and the average accuracy over the 5 folds was calculated for each combination.

In this example, the model with 250 trees and 6 variables achieved the highest average 5-fold cross validation accuracy, 83%, while the other 8 combinations had lower accuracy. Thus, ntree = 250 and max_features = 6 is the most suitable combination of parameters for the data.


Important Note about performance measure!


In this tutorial we have assumed that 'accuracy' is the performance measure used to tune our models. You can use any performance measure, e.g. F1-score, precision, recall, AUC, R-squared, Mean Absolute Percentage Error (MAPE), Mean Percentage Error (MPE) or Root Mean Square Error (RMSE), depending upon the problem.
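One practical consequence of switching measures is the direction of "best": scores such as accuracy or F1-score should be maximised, whereas error measures such as RMSE or MAPE should be minimised. The sketch below shows this with a hypothetical helper, `best_combination`. The score values are invented purely to exercise the function and are not results from any real model.

```python
def best_combination(mean_scores, higher_is_better=True):
    """Return the parameter combination with the best mean cross validation score."""
    pick = max if higher_is_better else min
    return pick(mean_scores, key=mean_scores.get)

# Hypothetical mean 5-fold scores keyed by (ntree, max_features).
# These numbers are made up for illustration only.
mean_accuracy = {(100, 4): 0.80, (250, 6): 0.83, (500, 8): 0.81}
mean_rmse = {(100, 4): 2.9, (250, 6): 2.4, (500, 8): 2.6}

# Accuracy: larger is better.
best_by_accuracy = best_combination(mean_accuracy)
# RMSE: smaller is better.
best_by_rmse = best_combination(mean_rmse, higher_is_better=False)

print(best_by_accuracy, best_by_rmse)
```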


To learn how to implement Grid Search in Python you can refer to this article: Grid Search in Python
