GridSearch のパラメタチューニング

前回の記事は「モデル評価」を話しました。

今回の記事はGridSearchでのパラメータチューニングを解説します。

Pythonの機械学習ライブラリscikit-learnにはモデルのパラメタをチューニングする方法としてGridSearchCVが用意されています。

digitsのデータセットをGridSearchCVを使ってパラメタチューニングをしてみます。

GridSearch とは

グリッドサーチとは、モデルの精度を向上させるために用いられる手法です。全てのパラメータの組み合わせを試してみる方法のことです。パラメータの組み合わせを全て試し、最も評価精度の良いものを探索する方法です。パラメータを変更することで予測精度は飛躍的に変わります。

ライブラリ

In [1]:
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report
from sklearn.svm import SVC
import numpy as np
import matplotlib.pyplot as plt

データセットのロード

0-9の数字の画像コードデータをロード

In [2]:
# sklernのライブラリからDigits dataset
digits = datasets.load_digits()

# X, y 作成
n_samples = len(digits.images)
X = digits.images.reshape((n_samples, -1))
y = digits.target
X, y
Out[2]:
(array([[ 0.,  0.,  5., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ..., 10.,  0.,  0.],
        [ 0.,  0.,  0., ..., 16.,  9.,  0.],
        ...,
        [ 0.,  0.,  1., ...,  6.,  0.,  0.],
        [ 0.,  0.,  2., ..., 12.,  0.,  0.],
        [ 0.,  0., 10., ..., 12.,  1.,  0.]]), array([0, 1, 2, ..., 8, 9, 8]))

データセットの可視化

In [3]:
data_train = digits.images
label_train = digits.target
mean_images = np.zeros((10,8,8))
fig = plt.figure(figsize=(10,5))
for i in range(10):
    mean_images[i] = data_train[label_train==i].mean(axis=0)
    ax = fig.add_subplot(2, 5, i+1)
    ax.axis('off')
    ax.set_title('{0} (n={1})'.format(i, len(data_train[label_train==i])))
    ax.imshow(mean_images[i],cmap=plt.cm.gray_r, interpolation='nearest')
plt.show()

学習データとテストデータの割合

半分の学習データとテストデータに分けます。

In [4]:
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=0)

ハイパーパラメータの調整

交差検定によってパラメータを設定します。

GridSearchCVで kernel、gamma、Cのパラメータを調整し、

最も評価のprecisionとrecallのモデルを探索します。

最も評価のprecisionとrecallのパラメータは C = 10, gamma = 0.001, kernel = rbfです。

In [5]:
# 交差検定によってパラメータを設定
tuned_parameters = [{'kernel': ['rbf'], 
                     'gamma': [1e-3, 1e-4],
                     'C': [1, 10, 100, 1000]},
                    {'kernel': ['linear'], 'C': [1, 10, 100, 1000]}]

scores = ['precision', 'recall']

for score in scores:
    print("# ハイパーパラメータの調整: %s" % score)
    print()

    clf = GridSearchCV(SVC(), tuned_parameters, cv=5,
                       scoring='%s_macro' % score)
    clf.fit(X_train, y_train)

    print("最良のパラメータセット:")
    print()
    print(clf.best_params_)
    print()
    print("グリッドスコア: %s" % score)
    print()
    means = clf.cv_results_['mean_test_score']
    stds = clf.cv_results_['std_test_score']
    for mean, std, params in zip(means, stds, clf.cv_results_['params']):
        print("%0.3f (+/-%0.03f) for %r"
              % (mean, std * 2, params))
    print()

    print("詳細レポート:")
    print()
    print("学習データでモデルを作成")
    print("テストデータの評価スコアを作成")
    print()
    y_true, y_pred = y_test, clf.predict(X_test)
    print(classification_report(y_true, y_pred))
    print()
    
# ハイパーパラメータの調整: precision

最良のパラメータセット:

{'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}

グリッドスコア: precision

0.986 (+/-0.016) for {'C': 1, 'gamma': 0.001, 'kernel': 'rbf'}
0.959 (+/-0.029) for {'C': 1, 'gamma': 0.0001, 'kernel': 'rbf'}
0.988 (+/-0.017) for {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}
0.982 (+/-0.026) for {'C': 10, 'gamma': 0.0001, 'kernel': 'rbf'}
0.988 (+/-0.017) for {'C': 100, 'gamma': 0.001, 'kernel': 'rbf'}
0.982 (+/-0.025) for {'C': 100, 'gamma': 0.0001, 'kernel': 'rbf'}
0.988 (+/-0.017) for {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}
0.982 (+/-0.025) for {'C': 1000, 'gamma': 0.0001, 'kernel': 'rbf'}
0.975 (+/-0.014) for {'C': 1, 'kernel': 'linear'}
0.975 (+/-0.014) for {'C': 10, 'kernel': 'linear'}
0.975 (+/-0.014) for {'C': 100, 'kernel': 'linear'}
0.975 (+/-0.014) for {'C': 1000, 'kernel': 'linear'}

詳細レポート:

学習データでモデルを作成
テストデータの評価スコアを作成

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        89
           1       0.97      1.00      0.98        90
           2       0.99      0.98      0.98        92
           3       1.00      0.99      0.99        93
           4       1.00      1.00      1.00        76
           5       0.99      0.98      0.99       108
           6       0.99      1.00      0.99        89
           7       0.99      1.00      0.99        78
           8       1.00      0.98      0.99        92
           9       0.99      0.99      0.99        92

   micro avg       0.99      0.99      0.99       899
   macro avg       0.99      0.99      0.99       899
weighted avg       0.99      0.99      0.99       899


# ハイパーパラメータの調整: recall

最良のパラメータセット:

{'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}

グリッドスコア: recall

0.986 (+/-0.019) for {'C': 1, 'gamma': 0.001, 'kernel': 'rbf'}
0.957 (+/-0.029) for {'C': 1, 'gamma': 0.0001, 'kernel': 'rbf'}
0.987 (+/-0.019) for {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}
0.981 (+/-0.028) for {'C': 10, 'gamma': 0.0001, 'kernel': 'rbf'}
0.987 (+/-0.019) for {'C': 100, 'gamma': 0.001, 'kernel': 'rbf'}
0.981 (+/-0.026) for {'C': 100, 'gamma': 0.0001, 'kernel': 'rbf'}
0.987 (+/-0.019) for {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}
0.981 (+/-0.026) for {'C': 1000, 'gamma': 0.0001, 'kernel': 'rbf'}
0.972 (+/-0.012) for {'C': 1, 'kernel': 'linear'}
0.972 (+/-0.012) for {'C': 10, 'kernel': 'linear'}
0.972 (+/-0.012) for {'C': 100, 'kernel': 'linear'}
0.972 (+/-0.012) for {'C': 1000, 'kernel': 'linear'}

詳細レポート:

学習データでモデルを作成
テストデータの評価スコアを作成

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        89
           1       0.97      1.00      0.98        90
           2       0.99      0.98      0.98        92
           3       1.00      0.99      0.99        93
           4       1.00      1.00      1.00        76
           5       0.99      0.98      0.99       108
           6       0.99      1.00      0.99        89
           7       0.99      1.00      0.99        78
           8       1.00      0.98      0.99        92
           9       0.99      0.99      0.99        92

   micro avg       0.99      0.99      0.99       899
   macro avg       0.99      0.99      0.99       899
weighted avg       0.99      0.99      0.99       899