Writing Custom Cross-Validation Methods For Grid Search in Scikit-learn

Recently I was interested in applying Blocking Time Series Split, following this lovely post, in a Grid Search hyper-parameter tuning setting with the scikit-learn library, in order to maintain the time order and prevent information leakage. In this post, I will document some of the knowledge I built while reading through articles, documentation, and blog posts about custom cross-validation generators in Python.

It is great that scikit-learn provides a class called TimeSeriesSplit, which generates fixed-time-interval training and test sets. Here is a basic example using scikit-learn's data generators. I generate a regression dataset with 5 features and 30 samples, then create 3 splits. Each split contains at most max_train_size = 10 training examples (the first fold only has 9 available) and n_samples // (n_splits + 1) = 7 test examples:

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import TimeSeriesSplit

X_experiment, y_experiment = make_regression(
    n_samples=30, n_features=5, noise=0.2)

tscv = TimeSeriesSplit(max_train_size=10, n_splits=3)

for idx, (train_indices, test_indices) in enumerate(tscv.split(X_experiment)):
    print(f"Split number: {idx}")
    print(f"Training indices: {train_indices}")
    print(f"Test indices: {test_indices}\n")
```

The output follows a Walk-Forward Cross-Validation pattern:

```
Split number: 0
Training indices: [0 1 2 3 4 5 6 7 8]
Test indices: [ 9 10 11 12 13 14 15]

Split number: 1
Training indices: [ 6  7  8  9 10 11 12 13 14 15]
Test indices: [16 17 18 19 20 21 22]

Split number: 2
Training indices: [13 14 15 16 17 18 19 20 21 22]
Test indices: [23 24 25 26 27 28 29]
```

However, in the setting I faced, the data was indexed by dates instead of timestamps. This gave me discrete numeric values as anchor points for the cross-validation splits, rather than a continuous index. Hence, I was not able to leverage TimeSeriesSplit from scikit-learn. Instead, I wrote a simple generator with groupings for date splits to use in Grid Search.

```python
import numpy as np
import pandas as pd


class CustomCrossValidation:

    @classmethod
    def split(cls,
              X: pd.DataFrame,
              y: np.ndarray = None,
              groups: np.ndarray = None):
        """Yields grouped time series (training, test) index splits."""
        assert len(X) == len(groups), (
            "Length of the predictors is not "
            "matching with the groups.")
        # The group indices must be sorted integers so that
        # iterating over the range preserves the time order
        for group_idx in range(groups.min(), groups.max()):
            training_group = group_idx
            # The group right after the training group is the test group
            test_group = group_idx + 1
            training_indices = np.where(groups == training_group)[0]
            test_indices = np.where(groups == test_group)[0]
            if len(test_indices) > 0:
                # Yield the training and testing indices
                # for the cross-validation generator
                yield training_indices, test_indices
```

CustomCrossValidation is a simple class with one method (split) that uses X (predictors), y (target values), and groups corresponding to the date groups. Those groups can be months or quarters in your dataset; I assume they can be mapped to integers so that the time order is preserved. For example, if I have 3 quarters in the dataset with date values Q1, Q2, and Q3, I can simply map them to 0, 1, 2 to keep the order and use those integers in my validation generator class method.
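As a minimal sketch of that mapping (the quarter labels here are just an illustration, not part of the original dataset), sorting the unique labels and enumerating them keeps the time order:

```python
import pandas as pd

# Hypothetical date labels for each row, oldest to newest
quarters = pd.Series(["Q1", "Q1", "Q2", "Q2", "Q3"])

# Sorting the unique labels keeps the time order;
# enumerate assigns consecutive integers 0, 1, 2, ...
ordered_labels = sorted(quarters.unique())
label_to_int = {label: i for i, label in enumerate(ordered_labels)}

groups = quarters.map(label_to_int).to_numpy()
print(groups)  # [0 0 1 1 2]
```

Note that this relies on the labels sorting lexicographically in time order (true for Q1..Q4); for month names you would need an explicit ordering instead.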

The split method follows scikit-learn's splitter naming convention. Here, I iterate over a range of integers (groups) that preserves the date order, assign the indices of group t as training indices and those of the next group (t + 1) as validation indices, and finally yield the (training, testing) index pair, since the cv parameter of GridSearchCV accepts a generator (or any iterable) of training and testing indices.
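One practical note, assuming you pass the splitter as a plain generator rather than a splitter object: a Python generator can be iterated only once. If you want to both inspect the splits and hand them to GridSearchCV, materialize them into a list first. A tiny generic demonstration (the toy generator below is purely illustrative):

```python
def toy_splits():
    # Stands in for any (train, test) index generator
    yield [0, 1], [2, 3]
    yield [2, 3], [4, 5]

gen = toy_splits()
splits = list(gen)        # materialize once; the list is reusable
print(len(splits))        # 2
print(len(list(gen)))     # 0 -- the generator itself is now exhausted
```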

The following example displays how the custom split works with the groups. To have date groups of different sizes, I created 4 groups with 5 instances of 0s, 10 instances of 1s, 10 instances of 2s, and 5 instances of 3s:

```python
X_experiment, y_experiment = make_regression(
    n_samples=30, n_features=5, noise=0.2)

groups_experiment = np.concatenate([np.zeros(5),      # 5 0s
                                    np.ones(10),      # 10 1s
                                    2 * np.ones(10),  # 10 2s
                                    3 * np.ones(5)    # 5 3s
                                    ]).astype(int)

for idx, (train_indices, test_indices) in enumerate(
        CustomCrossValidation.split(X_experiment,
                                    y_experiment,
                                    groups_experiment)):
    print(f"Split number: {idx}")
    print(f"Training indices: {train_indices}")
    print(f"Test indices: {test_indices}\n")
```

With the groupings, the example dataset looks like this:

```
# The first 5 predictor values...
          0         1         2         3         4
0 -0.566298  0.099651  2.190456 -0.503476 -0.990536
1  0.174578  0.257550  0.404051 -0.074446  1.886186
2  0.314247 -0.908024 -0.562288 -1.412304 -1.012831
3 -1.106335 -1.196207 -0.479174  0.812526 -0.185659
4 -0.013497 -1.057711 -0.601707  0.822545  1.852278

# The first 5 target values...
            0
0   73.398681
1  195.221637
2 -139.402678
3 -124.863423
4   94.753517

# Groupings for the example dataset...
# The 0s are the oldest date anchor values, the 3s the newest...
[0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3]
```

The groups impose an order on the validation flow: first the 0s are used as the training set and the 1s as validation; then the 1s are used as training and the 2s as validation, and so on. The generated indices for the example are:

```
Split number: 0
Training indices: [0 1 2 3 4]
Test indices: [ 5  6  7  8  9 10 11 12 13 14]

Split number: 1
Training indices: [ 5  6  7  8  9 10 11 12 13 14]
Test indices: [15 16 17 18 19 20 21 22 23 24]

Split number: 2
Training indices: [15 16 17 18 19 20 21 22 23 24]
Test indices: [25 26 27 28 29]
```

For an example setup, I will use Lasso Regression and optimize alpha with Grid Search. In Lasso, a larger alpha forces more coefficients to 0, so it is very common to search for the optimum value of alpha.

```python
from sklearn import linear_model
from sklearn.model_selection import GridSearchCV

# Instantiating the Lasso estimator
reg_estimator = linear_model.Lasso()

# Parameters to search
parameters_to_search = {"alpha": [0.1, 1, 10]}

# Splitter
custom_splitter = CustomCrossValidation.split(
    X=X_experiment,
    y=y_experiment,
    groups=groups_experiment)

# Search setup
reg_search = GridSearchCV(
    estimator=reg_estimator,
    param_grid=parameters_to_search,
    scoring="neg_root_mean_squared_error",
    cv=custom_splitter)

# Fitting
best_model = reg_search.fit(
    X=X_experiment,
    y=y_experiment,
    groups=groups_experiment)
```

This outputs the best estimator found with the custom cross-validation. There are 3 splits, as we used 4 groups:

```
# Best model:
Lasso(alpha=0.1)

# Number of splits:
3
```
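Beyond the single best estimator, GridSearchCV also exposes the per-fold and per-candidate scores in its cv_results_ attribute. A self-contained sketch of inspecting them, where a hand-written list of (train, test) index pairs stands in for the custom date-group splitter so the snippet runs on its own:

```python
import numpy as np
from sklearn import linear_model
from sklearn.datasets import make_regression
from sklearn.model_selection import GridSearchCV

X, y = make_regression(n_samples=30, n_features=5, noise=0.2, random_state=0)

# A hand-written iterable of (train, test) index pairs,
# standing in for the custom date-group splitter
splits = [(np.arange(0, 10), np.arange(10, 20)),
          (np.arange(10, 20), np.arange(20, 30))]

search = GridSearchCV(linear_model.Lasso(),
                      {"alpha": [0.1, 1, 10]},
                      scoring="neg_root_mean_squared_error",
                      cv=splits).fit(X, y)

print(search.best_params_)                    # the winning alpha
print(search.n_splits_)                       # 2
print(search.cv_results_["mean_test_score"])  # one mean score per alpha
```

cv_results_ is handy for checking whether a candidate's score is stable across the time-ordered folds rather than just good on average.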

Voilà, a simple generator helped me build a custom validation flow in a Grid Search optimization. I enjoy reading the scikit-learn documentation. Besides being fun, it helps me understand some statistical implementations better and tweak them whenever necessary.

For a complete set of examples, please refer to the GitHub repository. Happy reading the documentation!
