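12. Parameter search with the sklearn wrapper

Program description

Name: Parameter search with the sklearn wrapper

Date: November 17, 2016

Description: Build a simple convolutional model and use sklearn's GridSearchCV to find the best model.

Dataset: MNIST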

1. Load the Keras modules

from __future__ import print_function
import numpy as np
np.random.seed(1337)  # for reproducibility

from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D
from keras.utils import np_utils
from keras.wrappers.scikit_learn import KerasClassifier
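from sklearn.grid_search import GridSearchCV  # deprecated in 0.18, removed in 0.20; newer versions: from sklearn.model_selection import GridSearchCV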
Using TensorFlow backend.
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/cross_validation.py:44: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.
  "This module will be removed in 0.20.", DeprecationWarning)
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/grid_search.py:43: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. This module will be removed in 0.20.
  DeprecationWarning)

2. Initialize variables

nb_classes = 10

# input image dimensions
img_rows, img_cols = 28, 28

3. Prepare the data

# load training data and do basic data normalization
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255

Convert the class labels

# convert class vectors to binary class matrices
y_train = np_utils.to_categorical(y_train, nb_classes)
y_test = np_utils.to_categorical(y_test, nb_classes)
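To make the effect of to_categorical concrete, here is a minimal plain-Python sketch of one-hot encoding. The helper one_hot is hypothetical and returns nested lists; the actual Keras utility returns a NumPy array of shape (n_samples, nb_classes).

```python
def one_hot(labels, num_classes):
    """Return one one-hot row per integer label (hypothetical helper)."""
    rows = []
    for label in labels:
        if not 0 <= label < num_classes:
            raise ValueError("label %d out of range" % label)
        row = [0.0] * num_classes
        row[label] = 1.0  # mark the position of the class
        rows.append(row)
    return rows


# Three digit labels encoded for nb_classes = 10
encoded = one_hot([5, 0, 4], 10)
for row in encoded:
    print(row)
```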

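4. Build the model

Using Sequential()

Build a model made of two convolutional layers followed by a number of fully connected (dense) layers; how many dense layers there are is set by a parameter.

dense_layer_sizes: a list of layer sizes, with one number (the number of units) for each dense layer.

nb_filters: the number of filters in each convolutional layer

nb_conv: the size of the convolution kernel

nb_pool: the size of the pooling area for max pooling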

def make_model(dense_layer_sizes, nb_filters, nb_conv, nb_pool):
    '''Creates model comprised of 2 convolutional layers followed by dense layers
    dense_layer_sizes: List of layer sizes. This list has one number for each layer
    nb_filters: Number of convolutional filters in each convolutional layer
    nb_conv: Convolutional kernel size
    nb_pool: Size of pooling area for max pooling
    '''

    model = Sequential()

    model.add(Convolution2D(nb_filters, nb_conv, nb_conv,
                            border_mode='valid',
                            input_shape=(img_rows, img_cols, 1)))
    model.add(Activation('relu'))
    model.add(Convolution2D(nb_filters, nb_conv, nb_conv))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(nb_pool, nb_pool)))
    model.add(Dropout(0.25))

    model.add(Flatten())
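    for layer_size in dense_layer_sizes:
        model.add(Dense(layer_size))
        # ReLU after every dense layer; outside the loop, consecutive dense layers would have no nonlinearity between them
        model.add(Activation('relu'))
    model.add(Dropout(0.5))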
    model.add(Dense(nb_classes))
    model.add(Activation('softmax'))

    model.compile(loss='categorical_crossentropy',
                  optimizer='adadelta',
                  metrics=['accuracy'])

    return model
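With the values used later in the grid (nb_filters=8, nb_conv=3, nb_pool=2), the tensor sizes inside make_model can be worked out by hand. The sketch below is a plain-Python calculation with hypothetical helper names; it assumes stride-1 'valid' convolutions (the Keras 1 default border_mode, which the second convolution relies on) and non-overlapping pooling, and counts weights plus biases of the dense stack including the 10-way output layer.

```python
def conv_valid(size, kernel):
    # stride-1 'valid' convolution: no padding, so each side shrinks by kernel - 1
    return size - kernel + 1


def max_pool(size, pool):
    # non-overlapping pooling (stride == pool size); leftover rows/cols are dropped
    return size // pool


def flattened_size(img_size, nb_filters, nb_conv, nb_pool):
    size = conv_valid(img_size, nb_conv)   # first Convolution2D
    size = conv_valid(size, nb_conv)       # second Convolution2D
    size = max_pool(size, nb_pool)         # MaxPooling2D
    return size, size * size * nb_filters  # spatial side, length after Flatten


def dense_params(flat, dense_layer_sizes, nb_classes=10):
    # weights + biases of the hidden dense layers plus the softmax output layer
    total, fan_in = 0, flat
    for units in list(dense_layer_sizes) + [nb_classes]:
        total += fan_in * units + units
        fan_in = units
    return total


side, flat = flattened_size(28, nb_filters=8, nb_conv=3, nb_pool=2)
print(side, flat)                      # 12 1152
print(dense_params(flat, [32]))        # 37226
print(dense_params(flat, [64, 64]))    # 78602
```

The first dense layer dominates the count because it connects to all 1152 flattened features, so the candidates in dense_size_candidates differ mostly in that layer's width.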

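5. The sklearn interface

KerasClassifier() implements the sklearn classifier interface

keras.wrappers.scikit_learn.KerasClassifier(build_fn=None, **sk_params)

build_fn: a callable function or class instance that constructs, compiles, and returns a Keras model

sk_params: model parameters (arguments of build_fn) and fitting parameters (for example batch_size and nb_epoch)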
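The wrapper has to decide which entries of sk_params go to build_fn and which go to the training call; Keras filters them by argument name. The sketch below imitates that idea with inspect.signature. The names split_sk_params, make_model_stub and FIT_ARGS are hypothetical, and this is not the Keras implementation. The same routing is what lets nb_epoch appear in the grid even though make_model does not accept it.

```python
import inspect


def split_sk_params(build_fn, fit_arg_names, sk_params):
    """Hypothetical helper: route each sk_param to build_fn or to fit by argument name."""
    build_names = set(inspect.signature(build_fn).parameters)
    build_kwargs = {k: v for k, v in sk_params.items() if k in build_names}
    fit_kwargs = {k: v for k, v in sk_params.items() if k in fit_arg_names}
    return build_kwargs, fit_kwargs


def make_model_stub(dense_layer_sizes, nb_filters, nb_conv, nb_pool):
    # Same signature as make_model above; builds nothing, only its arguments matter here
    return None


# Assumed subset of the argument names accepted by the Keras 1.x Sequential.fit
FIT_ARGS = {'batch_size', 'nb_epoch', 'verbose'}

params = {'dense_layer_sizes': [64, 64], 'nb_filters': 8, 'nb_conv': 3,
          'nb_pool': 2, 'batch_size': 32, 'nb_epoch': 6}
build_kwargs, fit_kwargs = split_sk_params(make_model_stub, FIT_ARGS, params)
print(sorted(build_kwargs))  # ['dense_layer_sizes', 'nb_conv', 'nb_filters', 'nb_pool']
print(sorted(fit_kwargs))    # ['batch_size', 'nb_epoch']
```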

dense_size_candidates = [[32], [64], [32, 32], [64, 64]]
my_classifier = KerasClassifier(make_model, batch_size=32)

GridSearchCV in sklearn

Description: performs an exhaustive search over specified parameter values for an estimator.
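Exhaustive search means every combination of the listed values is tried. A minimal stdlib sketch of that expansion, using the same grid values as the GridSearchCV call that follows; expand_grid is a hypothetical helper, similar in spirit to (but not) sklearn's ParameterGrid.

```python
import itertools

# Same grid as in the GridSearchCV call
param_grid = {'dense_layer_sizes': [[32], [64], [32, 32], [64, 64]],
              'nb_epoch': [3, 6],
              'nb_filters': [8],
              'nb_conv': [3],
              'nb_pool': [2]}


def expand_grid(grid):
    """Yield every combination of the grid's values as a dict of parameters."""
    keys = sorted(grid)  # fixed key order so the enumeration is reproducible
    for values in itertools.product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


candidates = list(expand_grid(param_grid))
print(len(candidates))  # 8
for candidate in candidates:
    print(candidate)
```

Only dense_layer_sizes (4 values) and nb_epoch (2 values) vary, so the search has 4 × 2 = 8 candidates.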

validator = GridSearchCV(my_classifier,
                         param_grid={'dense_layer_sizes': dense_size_candidates,
                                     # nb_epoch is available for tuning even though it is not an argument of the model-building function
                                     'nb_epoch': [3, 6],
                                     'nb_filters': [8],
                                     'nb_conv': [3],
                                     'nb_pool': [2]},
                         scoring='log_loss',  # renamed 'neg_log_loss' in sklearn 0.18
                         n_jobs=1)
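scoring='log_loss' rates each candidate by the cross-entropy of its predicted class probabilities on the held-out fold. Below is a simplified sketch of that metric on toy data; unlike sklearn's log_loss, which clips and renormalizes all probabilities, it clips only the true-class probability. Because a lower loss is better, the scorer reports the negated value, which is why sklearn 0.18 renamed it neg_log_loss; GridSearchCV then keeps the candidate with the highest mean score.

```python
import math


def log_loss(y_true, y_prob, eps=1e-15):
    """Mean cross-entropy of one-hot labels under predicted class probabilities."""
    total = 0.0
    for truth, probs in zip(y_true, y_prob):
        k = truth.index(1)                     # index of the true class
        p = min(max(probs[k], eps), 1 - eps)   # clip so log(0) cannot occur
        total -= math.log(p)
    return total / len(y_true)


y_true = [[1, 0, 0], [0, 1, 0]]               # two one-hot labels (toy data)
y_prob = [[0.8, 0.1, 0.1], [0.3, 0.6, 0.1]]   # toy predicted probabilities
loss = log_loss(y_true, y_prob)
score = -loss  # a "neg_log_loss"-style score: higher (closer to 0) is better
print(round(loss, 4))  # 0.367
```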

Start fitting

validator.fit(X_train, y_train)
Epoch 1/3
40000/40000 [==============================] - 12s - loss: 0.8605 - acc: 0.7147    
Epoch 2/3
40000/40000 [==============================] - 11s - loss: 0.5645 - acc: 0.8208    
Epoch 3/3
40000/40000 [==============================] - 12s - loss: 0.4642 - acc: 0.8525    
 1536/20000 [=>............................] - ETA: 2s

/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20.
  sample_weight=sample_weight)


19968/20000 [============================>.] - ETA: 0sEpoch 1/3
40000/40000 [==============================] - 12s - loss: 0.8284 - acc: 0.7265    
Epoch 2/3
40000/40000 [==============================] - 12s - loss: 0.5357 - acc: 0.8283    
Epoch 3/3
40000/40000 [==============================] - 12s - loss: 0.4524 - acc: 0.8563    
 1280/20000 [>.............................] - ETA: 2s

19968/20000 [============================>.] - ETA: 0sEpoch 1/3
40000/40000 [==============================] - 12s - loss: 0.8130 - acc: 0.7311    
Epoch 2/3
40000/40000 [==============================] - 12s - loss: 0.5159 - acc: 0.8359    
Epoch 3/3
40000/40000 [==============================] - 12s - loss: 0.4416 - acc: 0.8602    
 1152/20000 [>.............................] - ETA: 3s

19968/20000 [============================>.] - ETA: 0sEpoch 1/6
40000/40000 [==============================] - 12s - loss: 0.8093 - acc: 0.7304    
Epoch 2/6
40000/40000 [==============================] - 12s - loss: 0.4811 - acc: 0.8459    
Epoch 3/6
40000/40000 [==============================] - 12s - loss: 0.4099 - acc: 0.8723    
Epoch 4/6
40000/40000 [==============================] - 11s - loss: 0.3624 - acc: 0.8859    
Epoch 5/6
40000/40000 [==============================] - 11s - loss: 0.3331 - acc: 0.8956    
Epoch 6/6
40000/40000 [==============================] - 12s - loss: 0.3093 - acc: 0.9030    
  928/20000 [>.............................] - ETA: 3s

19936/20000 [============================>.] - ETA: 0sEpoch 1/6
40000/40000 [==============================] - 12s - loss: 0.7886 - acc: 0.7393    
Epoch 2/6
40000/40000 [==============================] - 12s - loss: 0.4860 - acc: 0.8451    
Epoch 3/6
40000/40000 [==============================] - 12s - loss: 0.4136 - acc: 0.8712    
Epoch 4/6
40000/40000 [==============================] - 12s - loss: 0.3739 - acc: 0.8827    
Epoch 5/6
40000/40000 [==============================] - 11s - loss: 0.3499 - acc: 0.8924    
Epoch 6/6
40000/40000 [==============================] - 12s - loss: 0.3297 - acc: 0.8989    
  800/20000 [>.............................] - ETA: 4s

19936/20000 [============================>.] - ETA: 0sEpoch 1/6
40000/40000 [==============================] - 12s - loss: 0.9260 - acc: 0.6871    
Epoch 2/6
40000/40000 [==============================] - 11s - loss: 0.6032 - acc: 0.8043    
Epoch 3/6
40000/40000 [==============================] - 11s - loss: 0.5158 - acc: 0.8342    
Epoch 4/6
40000/40000 [==============================] - 12s - loss: 0.4425 - acc: 0.8599    
Epoch 5/6
40000/40000 [==============================] - 11s - loss: 0.4088 - acc: 0.8709    
Epoch 6/6
40000/40000 [==============================] - 11s - loss: 0.3644 - acc: 0.8848    
  544/20000 [..............................] - ETA: 6s

20000/20000 [==============================] - 2s     
Epoch 1/3
40000/40000 [==============================] - 11s - loss: 0.6009 - acc: 0.8104    
Epoch 2/3
40000/40000 [==============================] - 11s - loss: 0.3410 - acc: 0.8968    
Epoch 3/3
40000/40000 [==============================] - 12s - loss: 0.2770 - acc: 0.9162    
  256/20000 [..............................] - ETA: 14s

19904/20000 [============================>.] - ETA: 0sEpoch 1/3
40000/40000 [==============================] - 12s - loss: 0.6185 - acc: 0.8061    
Epoch 2/3
40000/40000 [==============================] - 12s - loss: 0.3376 - acc: 0.8999    
Epoch 3/3
40000/40000 [==============================] - 12s - loss: 0.2741 - acc: 0.9193    
   32/20000 [..............................] - ETA: 119s

19936/20000 [============================>.] - ETA: 0sEpoch 1/3
40000/40000 [==============================] - 12s - loss: 0.6259 - acc: 0.7990    
Epoch 2/3
40000/40000 [==============================] - 12s - loss: 0.3257 - acc: 0.9015    
Epoch 3/3
40000/40000 [==============================] - 12s - loss: 0.2599 - acc: 0.9230    


19936/20000 [============================>.] - ETA: 0sEpoch 1/6
40000/40000 [==============================] - 12s - loss: 0.6295 - acc: 0.7993    
Epoch 2/6
40000/40000 [==============================] - 12s - loss: 0.3693 - acc: 0.8871    
Epoch 3/6
40000/40000 [==============================] - 12s - loss: 0.2988 - acc: 0.9092    
Epoch 4/6
40000/40000 [==============================] - 11s - loss: 0.2542 - acc: 0.9238    
Epoch 5/6
40000/40000 [==============================] - 12s - loss: 0.2246 - acc: 0.9343    
Epoch 6/6
40000/40000 [==============================] - 11s - loss: 0.2026 - acc: 0.9413    


19968/20000 [============================>.] - ETA: 0sEpoch 1/6
40000/40000 [==============================] - 12s - loss: 0.5739 - acc: 0.8182    
Epoch 2/6
40000/40000 [==============================] - 12s - loss: 0.3139 - acc: 0.9077    
Epoch 3/6
40000/40000 [==============================] - 12s - loss: 0.2565 - acc: 0.9245    
Epoch 4/6
40000/40000 [==============================] - 12s - loss: 0.2306 - acc: 0.9316    
Epoch 5/6
40000/40000 [==============================] - 11s - loss: 0.2072 - acc: 0.9398    
Epoch 6/6
40000/40000 [==============================] - 12s - loss: 0.1947 - acc: 0.9416    


20000/20000 [==============================] - 2s     
Epoch 1/6
40000/40000 [==============================] - 12s - loss: 0.6035 - acc: 0.8089    
Epoch 2/6
40000/40000 [==============================] - 12s - loss: 0.3363 - acc: 0.8993    
Epoch 3/6
40000/40000 [==============================] - 12s - loss: 0.2729 - acc: 0.9181    
Epoch 4/6
40000/40000 [==============================] - 12s - loss: 0.2380 - acc: 0.9298    
Epoch 5/6
40000/40000 [==============================] - 12s - loss: 0.2114 - acc: 0.9376    
Epoch 6/6
40000/40000 [==============================] - 12s - loss: 0.1930 - acc: 0.9442    


19904/20000 [============================>.] - ETA: 0sEpoch 1/3
40000/40000 [==============================] - 13s - loss: 0.7216 - acc: 0.7599    
Epoch 2/3
40000/40000 [==============================] - 13s - loss: 0.4140 - acc: 0.8687    
Epoch 3/3
40000/40000 [==============================] - 13s - loss: 0.3545 - acc: 0.8897    


19968/20000 [============================>.] - ETA: 0sEpoch 1/3
40000/40000 [==============================] - 13s - loss: 0.8014 - acc: 0.7343    
Epoch 2/3
40000/40000 [==============================] - 13s - loss: 0.4586 - acc: 0.8549    
Epoch 3/3
40000/40000 [==============================] - 13s - loss: 0.3886 - acc: 0.8797    


20000/20000 [==============================] - 2s     
Epoch 1/3
40000/40000 [==============================] - 14s - loss: 0.8124 - acc: 0.7284    
Epoch 2/3
40000/40000 [==============================] - 13s - loss: 0.4838 - acc: 0.8477    
Epoch 3/3
40000/40000 [==============================] - 13s - loss: 0.4148 - acc: 0.8705    


19936/20000 [============================>.] - ETA: 0sEpoch 1/6
40000/40000 [==============================] - 13s - loss: 0.7192 - acc: 0.7608    
Epoch 2/6
40000/40000 [==============================] - 13s - loss: 0.4043 - acc: 0.8712    
Epoch 3/6
40000/40000 [==============================] - 13s - loss: 0.3514 - acc: 0.8902    
Epoch 4/6
40000/40000 [==============================] - 13s - loss: 0.3170 - acc: 0.9009    
Epoch 5/6
40000/40000 [==============================] - 13s - loss: 0.2986 - acc: 0.9079    
Epoch 6/6
40000/40000 [==============================] - 13s - loss: 0.2777 - acc: 0.9138    


20000/20000 [==============================] - 2s     
Epoch 1/6
40000/40000 [==============================] - 13s - loss: 0.7651 - acc: 0.7428    
Epoch 2/6
40000/40000 [==============================] - 13s - loss: 0.4377 - acc: 0.8626    
Epoch 3/6
40000/40000 [==============================] - 12s - loss: 0.3688 - acc: 0.8846    
Epoch 4/6
40000/40000 [==============================] - 13s - loss: 0.3298 - acc: 0.8983    
Epoch 5/6
40000/40000 [==============================] - 13s - loss: 0.3050 - acc: 0.9052    
Epoch 6/6
40000/40000 [==============================] - 13s - loss: 0.2945 - acc: 0.9091    


19968/20000 [============================>.] - ETA: 0sEpoch 1/6
40000/40000 [==============================] - 13s - loss: 0.8654 - acc: 0.7107    
Epoch 2/6
40000/40000 [==============================] - 13s - loss: 0.5192 - acc: 0.8338    
Epoch 3/6
40000/40000 [==============================] - 13s - loss: 0.4300 - acc: 0.8638    
Epoch 4/6
40000/40000 [==============================] - 13s - loss: 0.3788 - acc: 0.8795    
Epoch 5/6
40000/40000 [==============================] - 13s - loss: 0.3477 - acc: 0.8908    
Epoch 6/6
40000/40000 [==============================] - 13s - loss: 0.3197 - acc: 0.8999    


19968/20000 [============================>.] - ETA: 0sEpoch 1/3
40000/40000 [==============================] - 13s - loss: 0.5614 - acc: 0.8237    
Epoch 2/3
40000/40000 [==============================] - 13s - loss: 0.2812 - acc: 0.9163    
Epoch 3/3
40000/40000 [==============================] - 13s - loss: 0.2251 - acc: 0.9347    


19904/20000 [============================>.] - ETA: 0sEpoch 1/3
40000/40000 [==============================] - 13s - loss: 0.5107 - acc: 0.8401    
Epoch 2/3
40000/40000 [==============================] - 13s - loss: 0.2421 - acc: 0.9307    
Epoch 3/3
40000/40000 [==============================] - 13s - loss: 0.1988 - acc: 0.9424    


19936/20000 [============================>.] - ETA: 0sEpoch 1/3
40000/40000 [==============================] - 13s - loss: 0.5245 - acc: 0.8351    
Epoch 2/3
40000/40000 [==============================] - 13s - loss: 0.2639 - acc: 0.9222    
Epoch 3/3
40000/40000 [==============================] - 13s - loss: 0.2173 - acc: 0.9356    


19904/20000 [============================>.] - ETA: 0sEpoch 1/6
40000/40000 [==============================] - 13s - loss: 0.5514 - acc: 0.8266    
Epoch 2/6
40000/40000 [==============================] - 13s - loss: 0.2738 - acc: 0.9178    
Epoch 3/6
40000/40000 [==============================] - 12s - loss: 0.2165 - acc: 0.9365    
Epoch 4/6
40000/40000 [==============================] - 13s - loss: 0.1909 - acc: 0.9453    
Epoch 5/6
40000/40000 [==============================] - 13s - loss: 0.1734 - acc: 0.9492    
Epoch 6/6
40000/40000 [==============================] - 13s - loss: 0.1621 - acc: 0.9533    


20000/20000 [==============================] - 2s     
Epoch 1/6
40000/40000 [==============================] - 11s - loss: 0.5373 - acc: 0.8282    
Epoch 2/6
40000/40000 [==============================] - 11s - loss: 0.2628 - acc: 0.9222    
Epoch 3/6
40000/40000 [==============================] - 11s - loss: 0.2104 - acc: 0.9392    
Epoch 4/6
40000/40000 [==============================] - 11s - loss: 0.1844 - acc: 0.9455    
Epoch 5/6
40000/40000 [==============================] - 10s - loss: 0.1657 - acc: 0.9530    
Epoch 6/6
40000/40000 [==============================] - 11s - loss: 0.1482 - acc: 0.9576    


19936/20000 [============================>.] - ETA: 0sEpoch 1/6
40000/40000 [==============================] - 11s - loss: 0.5453 - acc: 0.8316    
Epoch 2/6
40000/40000 [==============================] - 11s - loss: 0.2769 - acc: 0.9198    
Epoch 3/6
40000/40000 [==============================] - 11s - loss: 0.2206 - acc: 0.9356    
Epoch 4/6
40000/40000 [==============================] - 11s - loss: 0.1952 - acc: 0.9447    
Epoch 5/6
40000/40000 [==============================] - 11s - loss: 0.1756 - acc: 0.9485    
Epoch 6/6
40000/40000 [==============================] - 11s - loss: 0.1650 - acc: 0.9511    


20000/20000 [==============================] - 2s     
Epoch 1/6
60000/60000 [==============================] - 17s - loss: 0.4784 - acc: 0.8494    
Epoch 2/6
60000/60000 [==============================] - 16s - loss: 0.2399 - acc: 0.9295    
Epoch 3/6
60000/60000 [==============================] - 16s - loss: 0.1875 - acc: 0.9451    
Epoch 4/6
60000/60000 [==============================] - 16s - loss: 0.1602 - acc: 0.9521    
Epoch 5/6
60000/60000 [==============================] - 16s - loss: 0.1445 - acc: 0.9584    
Epoch 6/6
60000/60000 [==============================] - 16s - loss: 0.1357 - acc: 0.9610    





GridSearchCV(cv=None, error_score='raise',
       estimator=<keras.wrappers.scikit_learn.KerasClassifier object at 0x7f42703d3e10>,
       fit_params={}, iid=True, n_jobs=1,
       param_grid={'dense_layer_sizes': [[32], [64], [32, 32], [64, 64]], 'nb_epoch': [3, 6], 'nb_pool': [2], 'nb_conv': [3], 'nb_filters': [8]},
       pre_dispatch='2*n_jobs', refit=True, scoring='log_loss', verbose=0)
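The shape of the training log follows from the CV settings: cv=None falls back to 3 folds in sklearn 0.18, and because y_train is one-hot encoded (a multilabel-indicator target) the folds appear to come from plain, unshuffled KFold rather than StratifiedKFold. A stdlib sketch of the fold sizes and the number of training runs; kfold_sizes is a hypothetical helper.

```python
def kfold_sizes(n_samples, n_folds=3):
    """(train, validation) sizes per fold for unshuffled K-fold splitting."""
    base, extra = divmod(n_samples, n_folds)
    # the first `extra` folds get one extra sample, as sklearn's KFold does
    val_sizes = [base + 1 if i < extra else base for i in range(n_folds)]
    return [(n_samples - v, v) for v in val_sizes]


folds = kfold_sizes(60000, 3)           # MNIST training set, default 3 folds
n_candidates = 4 * 2                    # dense_layer_sizes x nb_epoch; the rest are fixed
n_fits = n_candidates * len(folds) + 1  # +1 for the refit on all data (refit=True)
print(folds)   # [(40000, 20000), (40000, 20000), (40000, 20000)]
print(n_fits)  # 25
```

This matches the 24 runs on 40000 samples, each followed by an evaluation on 20000 samples, and the final run on all 60000 samples triggered by refit=True.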

Print the parameters of the best model

print('The parameters of the best model are: ')
print(validator.best_params_)
The parameters of the best model are: 
{'dense_layer_sizes': [64, 64], 'nb_conv': 3, 'nb_pool': 2, 'nb_epoch': 6, 'nb_filters': 8}

Retrieve the model

validator.best_estimator_ returns the sklearn-wrapped version of the best model

validator.best_estimator_.model returns the (unwrapped) Keras model
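The two attributes differ because the wrapper builds a fresh Keras model inside fit and keeps it on .model. A hypothetical, heavily simplified wrapper illustrates the relationship; SketchWrapper and its dict "model" are stand-ins, not the Keras class.

```python
class SketchWrapper(object):
    """Hypothetical, simplified stand-in for an sklearn-style wrapper such as KerasClassifier."""

    def __init__(self, build_fn, **sk_params):
        self.build_fn = build_fn
        self.sk_params = sk_params

    def fit(self, X, y):
        # A fresh model is built on every fit and kept on the .model attribute
        self.model = self.build_fn(**self.sk_params)
        # (a real wrapper would now train self.model on X, y)
        return self


def build_fn(units):
    # stand-in "model": a dict describing what would be built
    return {'dense_units': units}


wrapped = SketchWrapper(build_fn, units=64).fit(X=None, y=None)
print(type(wrapped).__name__)  # SketchWrapper (analogous to best_estimator_)
print(wrapped.model)           # {'dense_units': 64} (analogous to best_estimator_.model)
```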

best_model = validator.best_estimator_.model
metric_names = best_model.metrics_names
metric_values = best_model.evaluate(X_test, y_test)
print('\n')
for metric, value in zip(metric_names, metric_values):
    print(metric, ': ', value)
10000/10000 [==============================] - 1s     


loss :  0.0535527251991
acc :  0.9825
