名称:使用sklearn wrapper做参数搜索
时间:2016年11月17日
说明:构建一个简单的卷积模型,并使用sklearn的GridSearchCV搜索最优的模型参数。
数据集:MNIST
from __future__ import print_function
import numpy as np
np.random.seed(1337) # for reproducibility
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D
from keras.utils import np_utils
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.grid_search import GridSearchCV
Using TensorFlow backend. /root/Util/miniconda/lib/python2.7/site-packages/sklearn/cross_validation.py:44: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20. "This module will be removed in 0.20.", DeprecationWarning) /root/Util/miniconda/lib/python2.7/site-packages/sklearn/grid_search.py:43: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. This module will be removed in 0.20. DeprecationWarning)
# Number of target classes used for the one-hot encoding and the output layer.
nb_classes = 10

# Input image dimensions (rows, cols).
img_rows, img_cols = 28, 28

# Load the train/test split shipped with keras.datasets.mnist.
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Append a singleton last axis (matching the input_shape used by make_model),
# cast to float32 and divide by 255.
# NOTE(review): presumably the raw pixels are 8-bit values, so this maps them
# into [0, 1] -- confirm against the dataset.
X_train = X_train.reshape(len(X_train), img_rows, img_cols, 1).astype('float32') / 255
X_test = X_test.reshape(len(X_test), img_rows, img_cols, 1).astype('float32') / 255

# Convert the integer class vectors to binary (one-hot) class matrices.
y_train = np_utils.to_categorical(y_train, nb_classes)
y_test = np_utils.to_categorical(y_test, nb_classes)
def make_model(dense_layer_sizes, nb_filters, nb_conv, nb_pool):
    """Build and compile a small convolutional classifier.

    The network is two convolutional layers, a max-pooling layer and a
    dropout layer, followed by one fully connected ReLU layer per entry of
    ``dense_layer_sizes`` and a softmax output over ``nb_classes``.

    Args:
        dense_layer_sizes: List of layer sizes; one number for each hidden
            dense layer.
        nb_filters: Number of convolutional filters in each convolutional
            layer.
        nb_conv: Convolutional kernel size (used for both kernel dimensions).
        nb_pool: Size of the pooling area for max pooling (used for both
            dimensions).

    Returns:
        The compiled ``Sequential`` model (categorical cross-entropy loss,
        'adadelta' optimizer, accuracy metric).
    """
    input_shape = (img_rows, img_cols, 1)

    # Convolutional feature extractor.
    layers = [
        Convolution2D(nb_filters, nb_conv, nb_conv,
                      border_mode='valid',
                      input_shape=input_shape),
        Activation('relu'),
        Convolution2D(nb_filters, nb_conv, nb_conv),
        Activation('relu'),
        MaxPooling2D(pool_size=(nb_pool, nb_pool)),
        Dropout(0.25),
        Flatten(),
    ]

    # One Dense + ReLU pair per requested hidden layer size.
    for units in dense_layer_sizes:
        layers.append(Dense(units))
        layers.append(Activation('relu'))

    # Classification head.
    layers.extend([
        Dropout(0.5),
        Dense(nb_classes),
        Activation('softmax'),
    ])

    # Layers are added in exactly the order they were created above.
    model = Sequential()
    for layer in layers:
        model.add(layer)

    model.compile(loss='categorical_crossentropy',
                  optimizer='adadelta',
                  metrics=['accuracy'])
    return model
KerasClassifier()实现了sklearn的分类器接口
keras.wrappers.scikit_learn.KerasClassifier(build_fn=None, **sk_params)
build_fn:可调用的函数或类对象
sk_params:模型参数和训练参数
# Candidate hidden dense-layer configurations for the search:
# one or two layers of 32 or 64 units each.
dense_size_candidates = [[32], [64], [32, 32], [64, 64]]

# Wrap the model-building function as a scikit-learn classifier; batch_size
# is passed through as a fixed parameter.
my_classifier = KerasClassifier(build_fn=make_model, batch_size=32)
说明:GridSearchCV对估计器的指定参数值进行穷举搜索。
# Exhaustively search the parameter grid below with cross-validation and fit
# the search on the training data.
validator = GridSearchCV(my_classifier,
                         param_grid={'dense_layer_sizes': dense_size_candidates,
                                     # nb_epoch is available for tuning even
                                     # though it is not an argument of the
                                     # model-building function
                                     'nb_epoch': [3, 6],
                                     'nb_filters': [8],
                                     'nb_conv': [3],
                                     'nb_pool': [2]},
                         # The 'log_loss' scorer name was renamed to
                         # 'neg_log_loss' in scikit-learn 0.18 and is slated
                         # for removal in 0.20; use the current name.
                         scoring='neg_log_loss',
                         n_jobs=1)
validator.fit(X_train, y_train)
Epoch 1/3 40000/40000 [==============================] - 12s - loss: 0.8605 - acc: 0.7147 Epoch 2/3 40000/40000 [==============================] - 11s - loss: 0.5645 - acc: 0.8208 Epoch 3/3 40000/40000 [==============================] - 12s - loss: 0.4642 - acc: 0.8525 1536/20000 [=>............................] - ETA: 2s
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
19968/20000 [============================>.] - ETA: 0sEpoch 1/3 40000/40000 [==============================] - 12s - loss: 0.8284 - acc: 0.7265 Epoch 2/3 40000/40000 [==============================] - 12s - loss: 0.5357 - acc: 0.8283 Epoch 3/3 40000/40000 [==============================] - 12s - loss: 0.4524 - acc: 0.8563 1280/20000 [>.............................] - ETA: 2s
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
19968/20000 [============================>.] - ETA: 0sEpoch 1/3 40000/40000 [==============================] - 12s - loss: 0.8130 - acc: 0.7311 Epoch 2/3 40000/40000 [==============================] - 12s - loss: 0.5159 - acc: 0.8359 Epoch 3/3 40000/40000 [==============================] - 12s - loss: 0.4416 - acc: 0.8602 1152/20000 [>.............................] - ETA: 3s
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
19968/20000 [============================>.] - ETA: 0sEpoch 1/6 40000/40000 [==============================] - 12s - loss: 0.8093 - acc: 0.7304 Epoch 2/6 40000/40000 [==============================] - 12s - loss: 0.4811 - acc: 0.8459 Epoch 3/6 40000/40000 [==============================] - 12s - loss: 0.4099 - acc: 0.8723 Epoch 4/6 40000/40000 [==============================] - 11s - loss: 0.3624 - acc: 0.8859 Epoch 5/6 40000/40000 [==============================] - 11s - loss: 0.3331 - acc: 0.8956 Epoch 6/6 40000/40000 [==============================] - 12s - loss: 0.3093 - acc: 0.9030 928/20000 [>.............................] - ETA: 3s
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
19936/20000 [============================>.] - ETA: 0sEpoch 1/6 40000/40000 [==============================] - 12s - loss: 0.7886 - acc: 0.7393 Epoch 2/6 40000/40000 [==============================] - 12s - loss: 0.4860 - acc: 0.8451 Epoch 3/6 40000/40000 [==============================] - 12s - loss: 0.4136 - acc: 0.8712 Epoch 4/6 40000/40000 [==============================] - 12s - loss: 0.3739 - acc: 0.8827 Epoch 5/6 40000/40000 [==============================] - 11s - loss: 0.3499 - acc: 0.8924 Epoch 6/6 40000/40000 [==============================] - 12s - loss: 0.3297 - acc: 0.8989 800/20000 [>.............................] - ETA: 4s
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
19936/20000 [============================>.] - ETA: 0sEpoch 1/6 40000/40000 [==============================] - 12s - loss: 0.9260 - acc: 0.6871 Epoch 2/6 40000/40000 [==============================] - 11s - loss: 0.6032 - acc: 0.8043 Epoch 3/6 40000/40000 [==============================] - 11s - loss: 0.5158 - acc: 0.8342 Epoch 4/6 40000/40000 [==============================] - 12s - loss: 0.4425 - acc: 0.8599 Epoch 5/6 40000/40000 [==============================] - 11s - loss: 0.4088 - acc: 0.8709 Epoch 6/6 40000/40000 [==============================] - 11s - loss: 0.3644 - acc: 0.8848 544/20000 [..............................] - ETA: 6s
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
20000/20000 [==============================] - 2s Epoch 1/3 40000/40000 [==============================] - 11s - loss: 0.6009 - acc: 0.8104 Epoch 2/3 40000/40000 [==============================] - 11s - loss: 0.3410 - acc: 0.8968 Epoch 3/3 40000/40000 [==============================] - 12s - loss: 0.2770 - acc: 0.9162 256/20000 [..............................] - ETA: 14s
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
19904/20000 [============================>.] - ETA: 0sEpoch 1/3 40000/40000 [==============================] - 12s - loss: 0.6185 - acc: 0.8061 Epoch 2/3 40000/40000 [==============================] - 12s - loss: 0.3376 - acc: 0.8999 Epoch 3/3 40000/40000 [==============================] - 12s - loss: 0.2741 - acc: 0.9193 32/20000 [..............................] - ETA: 119s
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
19936/20000 [============================>.] - ETA: 0sEpoch 1/3 40000/40000 [==============================] - 12s - loss: 0.6259 - acc: 0.7990 Epoch 2/3 40000/40000 [==============================] - 12s - loss: 0.3257 - acc: 0.9015 Epoch 3/3 40000/40000 [==============================] - 12s - loss: 0.2599 - acc: 0.9230
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
19936/20000 [============================>.] - ETA: 0sEpoch 1/6 40000/40000 [==============================] - 12s - loss: 0.6295 - acc: 0.7993 Epoch 2/6 40000/40000 [==============================] - 12s - loss: 0.3693 - acc: 0.8871 Epoch 3/6 40000/40000 [==============================] - 12s - loss: 0.2988 - acc: 0.9092 Epoch 4/6 40000/40000 [==============================] - 11s - loss: 0.2542 - acc: 0.9238 Epoch 5/6 40000/40000 [==============================] - 12s - loss: 0.2246 - acc: 0.9343 Epoch 6/6 40000/40000 [==============================] - 11s - loss: 0.2026 - acc: 0.9413
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
19968/20000 [============================>.] - ETA: 0sEpoch 1/6 40000/40000 [==============================] - 12s - loss: 0.5739 - acc: 0.8182 Epoch 2/6 40000/40000 [==============================] - 12s - loss: 0.3139 - acc: 0.9077 Epoch 3/6 40000/40000 [==============================] - 12s - loss: 0.2565 - acc: 0.9245 Epoch 4/6 40000/40000 [==============================] - 12s - loss: 0.2306 - acc: 0.9316 Epoch 5/6 40000/40000 [==============================] - 11s - loss: 0.2072 - acc: 0.9398 Epoch 6/6 40000/40000 [==============================] - 12s - loss: 0.1947 - acc: 0.9416
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
20000/20000 [==============================] - 2s Epoch 1/6 40000/40000 [==============================] - 12s - loss: 0.6035 - acc: 0.8089 Epoch 2/6 40000/40000 [==============================] - 12s - loss: 0.3363 - acc: 0.8993 Epoch 3/6 40000/40000 [==============================] - 12s - loss: 0.2729 - acc: 0.9181 Epoch 4/6 40000/40000 [==============================] - 12s - loss: 0.2380 - acc: 0.9298 Epoch 5/6 40000/40000 [==============================] - 12s - loss: 0.2114 - acc: 0.9376 Epoch 6/6 40000/40000 [==============================] - 12s - loss: 0.1930 - acc: 0.9442
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
19904/20000 [============================>.] - ETA: 0sEpoch 1/3 40000/40000 [==============================] - 13s - loss: 0.7216 - acc: 0.7599 Epoch 2/3 40000/40000 [==============================] - 13s - loss: 0.4140 - acc: 0.8687 Epoch 3/3 40000/40000 [==============================] - 13s - loss: 0.3545 - acc: 0.8897
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
19968/20000 [============================>.] - ETA: 0sEpoch 1/3 40000/40000 [==============================] - 13s - loss: 0.8014 - acc: 0.7343 Epoch 2/3 40000/40000 [==============================] - 13s - loss: 0.4586 - acc: 0.8549 Epoch 3/3 40000/40000 [==============================] - 13s - loss: 0.3886 - acc: 0.8797
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
20000/20000 [==============================] - 2s Epoch 1/3 40000/40000 [==============================] - 14s - loss: 0.8124 - acc: 0.7284 Epoch 2/3 40000/40000 [==============================] - 13s - loss: 0.4838 - acc: 0.8477 Epoch 3/3 40000/40000 [==============================] - 13s - loss: 0.4148 - acc: 0.8705
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
19936/20000 [============================>.] - ETA: 0sEpoch 1/6 40000/40000 [==============================] - 13s - loss: 0.7192 - acc: 0.7608 Epoch 2/6 40000/40000 [==============================] - 13s - loss: 0.4043 - acc: 0.8712 Epoch 3/6 40000/40000 [==============================] - 13s - loss: 0.3514 - acc: 0.8902 Epoch 4/6 40000/40000 [==============================] - 13s - loss: 0.3170 - acc: 0.9009 Epoch 5/6 40000/40000 [==============================] - 13s - loss: 0.2986 - acc: 0.9079 Epoch 6/6 40000/40000 [==============================] - 13s - loss: 0.2777 - acc: 0.9138
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
20000/20000 [==============================] - 2s Epoch 1/6 40000/40000 [==============================] - 13s - loss: 0.7651 - acc: 0.7428 Epoch 2/6 40000/40000 [==============================] - 13s - loss: 0.4377 - acc: 0.8626 Epoch 3/6 40000/40000 [==============================] - 12s - loss: 0.3688 - acc: 0.8846 Epoch 4/6 40000/40000 [==============================] - 13s - loss: 0.3298 - acc: 0.8983 Epoch 5/6 40000/40000 [==============================] - 13s - loss: 0.3050 - acc: 0.9052 Epoch 6/6 40000/40000 [==============================] - 13s - loss: 0.2945 - acc: 0.9091
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
19968/20000 [============================>.] - ETA: 0sEpoch 1/6 40000/40000 [==============================] - 13s - loss: 0.8654 - acc: 0.7107 Epoch 2/6 40000/40000 [==============================] - 13s - loss: 0.5192 - acc: 0.8338 Epoch 3/6 40000/40000 [==============================] - 13s - loss: 0.4300 - acc: 0.8638 Epoch 4/6 40000/40000 [==============================] - 13s - loss: 0.3788 - acc: 0.8795 Epoch 5/6 40000/40000 [==============================] - 13s - loss: 0.3477 - acc: 0.8908 Epoch 6/6 40000/40000 [==============================] - 13s - loss: 0.3197 - acc: 0.8999
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
19968/20000 [============================>.] - ETA: 0sEpoch 1/3 40000/40000 [==============================] - 13s - loss: 0.5614 - acc: 0.8237 Epoch 2/3 40000/40000 [==============================] - 13s - loss: 0.2812 - acc: 0.9163 Epoch 3/3 40000/40000 [==============================] - 13s - loss: 0.2251 - acc: 0.9347
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
19904/20000 [============================>.] - ETA: 0sEpoch 1/3 40000/40000 [==============================] - 13s - loss: 0.5107 - acc: 0.8401 Epoch 2/3 40000/40000 [==============================] - 13s - loss: 0.2421 - acc: 0.9307 Epoch 3/3 40000/40000 [==============================] - 13s - loss: 0.1988 - acc: 0.9424
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
19936/20000 [============================>.] - ETA: 0sEpoch 1/3 40000/40000 [==============================] - 13s - loss: 0.5245 - acc: 0.8351 Epoch 2/3 40000/40000 [==============================] - 13s - loss: 0.2639 - acc: 0.9222 Epoch 3/3 40000/40000 [==============================] - 13s - loss: 0.2173 - acc: 0.9356
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
19904/20000 [============================>.] - ETA: 0sEpoch 1/6 40000/40000 [==============================] - 13s - loss: 0.5514 - acc: 0.8266 Epoch 2/6 40000/40000 [==============================] - 13s - loss: 0.2738 - acc: 0.9178 Epoch 3/6 40000/40000 [==============================] - 12s - loss: 0.2165 - acc: 0.9365 Epoch 4/6 40000/40000 [==============================] - 13s - loss: 0.1909 - acc: 0.9453 Epoch 5/6 40000/40000 [==============================] - 13s - loss: 0.1734 - acc: 0.9492 Epoch 6/6 40000/40000 [==============================] - 13s - loss: 0.1621 - acc: 0.9533
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
20000/20000 [==============================] - 2s Epoch 1/6 40000/40000 [==============================] - 11s - loss: 0.5373 - acc: 0.8282 Epoch 2/6 40000/40000 [==============================] - 11s - loss: 0.2628 - acc: 0.9222 Epoch 3/6 40000/40000 [==============================] - 11s - loss: 0.2104 - acc: 0.9392 Epoch 4/6 40000/40000 [==============================] - 11s - loss: 0.1844 - acc: 0.9455 Epoch 5/6 40000/40000 [==============================] - 10s - loss: 0.1657 - acc: 0.9530 Epoch 6/6 40000/40000 [==============================] - 11s - loss: 0.1482 - acc: 0.9576
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
19936/20000 [============================>.] - ETA: 0sEpoch 1/6 40000/40000 [==============================] - 11s - loss: 0.5453 - acc: 0.8316 Epoch 2/6 40000/40000 [==============================] - 11s - loss: 0.2769 - acc: 0.9198 Epoch 3/6 40000/40000 [==============================] - 11s - loss: 0.2206 - acc: 0.9356 Epoch 4/6 40000/40000 [==============================] - 11s - loss: 0.1952 - acc: 0.9447 Epoch 5/6 40000/40000 [==============================] - 11s - loss: 0.1756 - acc: 0.9485 Epoch 6/6 40000/40000 [==============================] - 11s - loss: 0.1650 - acc: 0.9511
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20. sample_weight=sample_weight)
20000/20000 [==============================] - 2s Epoch 1/6 60000/60000 [==============================] - 17s - loss: 0.4784 - acc: 0.8494 Epoch 2/6 60000/60000 [==============================] - 16s - loss: 0.2399 - acc: 0.9295 Epoch 3/6 60000/60000 [==============================] - 16s - loss: 0.1875 - acc: 0.9451 Epoch 4/6 60000/60000 [==============================] - 16s - loss: 0.1602 - acc: 0.9521 Epoch 5/6 60000/60000 [==============================] - 16s - loss: 0.1445 - acc: 0.9584 Epoch 6/6 60000/60000 [==============================] - 16s - loss: 0.1357 - acc: 0.9610
GridSearchCV(cv=None, error_score='raise', estimator=<keras.wrappers.scikit_learn.KerasClassifier object at 0x7f42703d3e10>, fit_params={}, iid=True, n_jobs=1, param_grid={'dense_layer_sizes': [[32], [64], [32, 32], [64, 64]], 'nb_epoch': [3, 6], 'nb_pool': [2], 'nb_conv': [3], 'nb_filters': [8]}, pre_dispatch='2*n_jobs', refit=True, scoring='log_loss', verbose=0)
# Report the parameter combination selected by the grid search; the newline
# separator puts the header and the parameter dict on separate lines.
print('The parameters of the best model are: ', validator.best_params_, sep='\n')
The parameters of the best model are: {'dense_layer_sizes': [64, 64], 'nb_conv': 3, 'nb_pool': 2, 'nb_epoch': 6, 'nb_filters': 8}
validator.best_estimator_ 返回sklearn-wrapped版本的最好模型
validator.best_estimator_.model 返回(unwrapped)keras模型
# validator.best_estimator_ is the sklearn-wrapped best model; its .model
# attribute is the underlying (unwrapped) Keras model.
best_model = validator.best_estimator_.model

# Evaluate on the held-out test set and pair each metric name with its value.
metric_names = best_model.metrics_names
metric_values = best_model.evaluate(X_test, y_test)

print('\n')
for name, score in zip(metric_names, metric_values):
    print(name, score, sep=' :  ')
10000/10000 [==============================] - 1s loss : 0.0535527251991 acc : 0.9825