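12. Parameter search with the sklearn wrapper
Program description
Name: Parameter search with the sklearn wrapper
Date: November 17, 2016
Description: Build a simple convolutional model and use sklearn's GridSearchCV to find the best model.
Dataset: MNIST
1. Load the Keras modules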
from __future__ import print_function
import numpy as np
np.random.seed(1337) # for reproducibility
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D
from keras.utils import np_utils
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.grid_search import GridSearchCV
Using TensorFlow backend.
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/cross_validation.py:44: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.
"This module will be removed in 0.20.", DeprecationWarning)
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/grid_search.py:43: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. This module will be removed in 0.20.
DeprecationWarning)
2. Initialize variables
nb_classes = 10
# input image dimensions
img_rows, img_cols = 28, 28
3. Prepare the data
# load training data and do basic data normalization
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
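The sketch below applies the two preprocessing steps above to a tiny stand-in batch. It uses two random 28x28 uint8 images instead of the real MNIST arrays, an assumption made so the snippet runs offline.

- The reshape appends a single channel axis. This gives the channels-last layout that input_shape=(img_rows, img_cols, 1) expects.
- Dividing by 255 moves pixel values from [0, 255] into [0, 1].

```python
import numpy as np

img_rows, img_cols = 28, 28

# Stand-in for mnist.load_data(): two random 28x28 grayscale images (uint8 pixels 0..255)
rng = np.random.RandomState(0)
X_fake = rng.randint(0, 256, size=(2, img_rows, img_cols)).astype('uint8')

# (N, 28, 28) -> (N, 28, 28, 1): add a trailing channel axis (channels-last layout)
X_fake = X_fake.reshape(X_fake.shape[0], img_rows, img_cols, 1)

# Cast to float32 before dividing, then scale pixel values into [0, 1]
X_fake = X_fake.astype('float32')
X_fake /= 255

print(X_fake.shape, X_fake.dtype, X_fake.min(), X_fake.max())
```

The cast comes first for a reason. Current NumPy refuses an in-place true division on a uint8 array, because the float result cannot be written back into integer storage.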
Convert the class labels
# convert class vectors to binary class matrices
y_train = np_utils.to_categorical(y_train, nb_classes)
y_test = np_utils.to_categorical(y_test, nb_classes)
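np_utils.to_categorical turns each integer label (0 to 9) into a length-10 indicator (one-hot) vector. This is the target format that categorical_crossentropy expects. For integer labels, the same result can be sketched in NumPy by indexing rows of an identity matrix. The helper name one_hot is ours, not part of Keras.

```python
import numpy as np

def one_hot(labels, nb_classes):
    """Map integer class labels to rows of an identity matrix (one-hot vectors)."""
    labels = np.asarray(labels, dtype='int64')
    return np.eye(nb_classes, dtype='float32')[labels]

y_example = one_hot([5, 0, 4], 10)
print(y_example.shape)   # 3 samples x 10 classes
print(y_example[0])      # 1.0 at index 5, zeros elsewhere
```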
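4. Build the model
Use Sequential()
Build a model made of two convolutional layers followed by one or more fully connected (dense) layers. The number of dense layers is controlled by a parameter.
dense_layer_sizes: list of dense layer sizes, with one number (the number of units) for each dense layer.
nb_filters: number of filters in each convolutional layer.
nb_conv: size of the convolution kernel (nb_conv x nb_conv).
nb_pool: size of the max-pooling window (nb_pool x nb_pool).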
def make_model(dense_layer_sizes, nb_filters, nb_conv, nb_pool):
    '''Creates a model composed of 2 convolutional layers followed by dense layers
    dense_layer_sizes: List of layer sizes. This list has one number for each layer
    nb_filters: Number of convolutional filters in each convolutional layer
    nb_conv: Convolutional kernel size
    nb_pool: Size of pooling area for max pooling
    '''
    model = Sequential()
    # Two convolutional layers ('valid' = no padding), each followed by ReLU
    model.add(Convolution2D(nb_filters, nb_conv, nb_conv,
                            border_mode='valid',
                            input_shape=(img_rows, img_cols, 1)))
    model.add(Activation('relu'))
    model.add(Convolution2D(nb_filters, nb_conv, nb_conv))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(nb_pool, nb_pool)))
    model.add(Dropout(0.25))
    model.add(Flatten())
    # One dense layer with ReLU for each entry in dense_layer_sizes
    for layer_size in dense_layer_sizes:
        model.add(Dense(layer_size))
        model.add(Activation('relu'))
    model.add(Dropout(0.5))
    # Output layer: one softmax unit per class
    model.add(Dense(nb_classes))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy',
                  optimizer='adadelta',
                  metrics=['accuracy'])
    return model
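To check what make_model builds, the shapes can be traced by hand. The sketch below rests on these assumptions:

- It uses the values searched later: nb_filters=8, nb_conv=3, nb_pool=2.
- It uses the Keras 1 defaults: stride 1 for Convolution2D, and stride equal to pool_size for MaxPooling2D.
- describe_model is a hypothetical helper, not a Keras function.

Two 'valid' 3x3 convolutions shrink 28x28 to 26x26 and then to 24x24. Pooling halves that to 12x12, so Flatten() produces 12 x 12 x 8 = 1152 features for the first dense layer.

```python
def conv_valid(size, kernel):
    """Output length of a 'valid' (no padding, stride 1) convolution along one axis."""
    return size - kernel + 1

def describe_model(dense_layer_sizes, nb_filters=8, nb_conv=3, nb_pool=2,
                   img_rows=28, img_cols=28, nb_classes=10):
    """Trace shapes and count trainable weights for make_model's architecture (hypothetical helper)."""
    params = 0
    # Conv 1: one input channel; each filter has nb_conv*nb_conv weights plus a bias
    rows, cols = conv_valid(img_rows, nb_conv), conv_valid(img_cols, nb_conv)
    params += (nb_conv * nb_conv * 1 + 1) * nb_filters
    # Conv 2: nb_filters input channels
    rows, cols = conv_valid(rows, nb_conv), conv_valid(cols, nb_conv)
    params += (nb_conv * nb_conv * nb_filters + 1) * nb_filters
    # Max pooling with stride == pool size (a leftover row/column would be dropped)
    rows, cols = rows // nb_pool, cols // nb_pool
    flat = rows * cols * nb_filters          # length of the Flatten() output
    # Dense layers: weights + biases; Activation, Dropout and Flatten have no weights
    fan_in = flat
    for units in list(dense_layer_sizes) + [nb_classes]:
        params += fan_in * units + units
        fan_in = units
    return flat, params

for sizes in ([32], [64], [32, 32], [64, 64]):
    flat, params = describe_model(sizes)
    print(sizes, 'flatten:', flat, 'params:', params)
```

For dense_layer_sizes=[64, 64], this arithmetic gives 79266 trainable weights. Most of them (1152 x 64 + 64 = 73792) sit in the first dense layer. Calling model.summary() on a real model remains the authoritative check.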
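5. The sklearn interface
KerasClassifier() implements the sklearn classifier interface.
keras.wrappers.scikit_learn.KerasClassifier(build_fn=None, **sk_params)
build_fn: a callable function or class instance that constructs, compiles and returns a Keras model.
sk_params: model parameters and fitting parameters. Legal names are the arguments of build_fn and of the model's fit, predict and related methods.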
dense_size_candidates = [[32], [64], [32, 32], [64, 64]]
my_classifier = KerasClassifier(make_model, batch_size=32)
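The wrapper decides which sk_params go to build_fn and which go to the model's fit method, roughly by checking argument names. That is why batch_size (passed above) and nb_epoch (searched in the grid below) are accepted even though make_model does not take them.

Below is a simplified, hypothetical sketch of that routing, not the wrapper's actual code. The stub functions only mimic the argument names of make_model and of a Keras 1 style fit().

```python
import inspect

def split_sk_params(build_fn, fit_fn, sk_params):
    """Hypothetical helper: send each parameter to the function whose signature names it."""
    build_args = inspect.signature(build_fn).parameters
    fit_args = inspect.signature(fit_fn).parameters
    to_build, to_fit = {}, {}
    for name, value in sk_params.items():
        if name in build_args:
            to_build[name] = value
        elif name in fit_args:
            to_fit[name] = value
        else:
            raise ValueError('{} is not a legal parameter'.format(name))
    return to_build, to_fit

# Stubs that only copy argument names (assumed, for illustration)
def make_model_stub(dense_layer_sizes, nb_filters, nb_conv, nb_pool):
    pass

def fit_stub(x, y, batch_size=32, nb_epoch=10, verbose=1):
    pass

params = {'dense_layer_sizes': [64, 64], 'nb_filters': 8, 'nb_conv': 3,
          'nb_pool': 2, 'nb_epoch': 6, 'batch_size': 32}
build_kwargs, fit_kwargs = split_sk_params(make_model_stub, fit_stub, params)
print(build_kwargs)
print(fit_kwargs)
```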
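The GridSearchCV class in sklearn
Description: performs an exhaustive search over specified parameter values for an estimator, scoring each combination by cross-validation.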
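An exhaustive search simply tries every combination in param_grid. The sketch below mimics that enumeration with itertools.product. The helper name parameter_grid is ours; sklearn's own version is ParameterGrid, which also sorts the keys.

The grid used here gives 4 x 2 x 1 x 1 x 1 = 8 candidates. With 3 folds, that means 24 training runs plus one final refit. The 3 folds are assumed from the sklearn 0.18 default for cv=None, which matches the 40000/20000 split in the log.

```python
import itertools

def parameter_grid(param_grid):
    """Yield every combination of a dict of candidate lists, iterating keys in sorted order."""
    names = sorted(param_grid)
    for values in itertools.product(*(param_grid[name] for name in names)):
        yield dict(zip(names, values))

grid = {'dense_layer_sizes': [[32], [64], [32, 32], [64, 64]],
        'nb_epoch': [3, 6],
        'nb_filters': [8],
        'nb_conv': [3],
        'nb_pool': [2]}
candidates = list(parameter_grid(grid))
n_folds = 3                                  # assumed cv=None default of sklearn 0.18
n_fits = len(candidates) * n_folds + 1       # +1 for the refit on all training data (refit=True)
print(len(candidates), 'candidates,', n_fits, 'training runs')
```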
validator = GridSearchCV(my_classifier,
                         param_grid={'dense_layer_sizes': dense_size_candidates,
                                     # nb_epoch can be tuned even though it is not an argument of the model-building function
                                     'nb_epoch': [3, 6],
                                     'nb_filters': [8],
                                     'nb_conv': [3],
                                     'nb_pool': [2]},
                         scoring='log_loss',
                         n_jobs=1)
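scoring='log_loss' ranks candidates by cross-entropy on the held-out fold, computed from predicted probabilities. sklearn scorers follow a "greater is better" convention, so the score is the negated loss. That is why sklearn 0.18 renamed the scorer neg_log_loss, as the DeprecationWarning in the log says; newer sklearn versions need scoring='neg_log_loss'.

Below is a simplified pure-Python sketch of the metric. The helper names are ours, and sklearn's sklearn.metrics.log_loss accepts more input formats.

```python
import math

def log_loss(y_true_onehot, y_prob, eps=1e-15):
    """Mean multi-class cross-entropy: average of -log(probability given to the true class)."""
    total = 0.0
    for onehot, probs in zip(y_true_onehot, y_prob):
        k = onehot.index(1)                       # index of the true class
        p = min(max(probs[k], eps), 1 - eps)      # clip so log(0) never happens
        total += -math.log(p)
    return total / len(y_true_onehot)

def neg_log_loss_score(y_true_onehot, y_prob):
    """Scorer convention: greater is better, so the loss is negated."""
    return -log_loss(y_true_onehot, y_prob)

y_true = [[1, 0, 0], [0, 1, 0]]
confident = [[0.9, 0.05, 0.05], [0.1, 0.8, 0.1]]
hesitant = [[0.4, 0.3, 0.3], [0.3, 0.4, 0.3]]
print(neg_log_loss_score(y_true, confident), neg_log_loss_score(y_true, hesitant))
```

Both predictions pick the right class. Only the confident one earns the higher (less negative) score.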
Start fitting
validator.fit(X_train, y_train)
Epoch 1/3
40000/40000 [==============================] - 12s - loss: 0.8605 - acc: 0.7147
Epoch 2/3
40000/40000 [==============================] - 11s - loss: 0.5645 - acc: 0.8208
Epoch 3/3
40000/40000 [==============================] - 12s - loss: 0.4642 - acc: 0.8525
/root/Util/miniconda/lib/python2.7/site-packages/sklearn/metrics/scorer.py:127: DeprecationWarning: Scoring method log_loss was renamed to neg_log_loss in version 0.18 and will be removed in 0.20.
sample_weight=sample_weight)
Epoch 1/3
40000/40000 [==============================] - 12s - loss: 0.8284 - acc: 0.7265
Epoch 2/3
40000/40000 [==============================] - 12s - loss: 0.5357 - acc: 0.8283
Epoch 3/3
40000/40000 [==============================] - 12s - loss: 0.4524 - acc: 0.8563
Epoch 1/3
40000/40000 [==============================] - 12s - loss: 0.8130 - acc: 0.7311
Epoch 2/3
40000/40000 [==============================] - 12s - loss: 0.5159 - acc: 0.8359
Epoch 3/3
40000/40000 [==============================] - 12s - loss: 0.4416 - acc: 0.8602
Epoch 1/6
40000/40000 [==============================] - 12s - loss: 0.8093 - acc: 0.7304
Epoch 2/6
40000/40000 [==============================] - 12s - loss: 0.4811 - acc: 0.8459
Epoch 3/6
40000/40000 [==============================] - 12s - loss: 0.4099 - acc: 0.8723
Epoch 4/6
40000/40000 [==============================] - 11s - loss: 0.3624 - acc: 0.8859
Epoch 5/6
40000/40000 [==============================] - 11s - loss: 0.3331 - acc: 0.8956
Epoch 6/6
40000/40000 [==============================] - 12s - loss: 0.3093 - acc: 0.9030
Epoch 1/6
40000/40000 [==============================] - 12s - loss: 0.7886 - acc: 0.7393
Epoch 2/6
40000/40000 [==============================] - 12s - loss: 0.4860 - acc: 0.8451
Epoch 3/6
40000/40000 [==============================] - 12s - loss: 0.4136 - acc: 0.8712
Epoch 4/6
40000/40000 [==============================] - 12s - loss: 0.3739 - acc: 0.8827
Epoch 5/6
40000/40000 [==============================] - 11s - loss: 0.3499 - acc: 0.8924
Epoch 6/6
40000/40000 [==============================] - 12s - loss: 0.3297 - acc: 0.8989
Epoch 1/6
40000/40000 [==============================] - 12s - loss: 0.9260 - acc: 0.6871
Epoch 2/6
40000/40000 [==============================] - 11s - loss: 0.6032 - acc: 0.8043
Epoch 3/6
40000/40000 [==============================] - 11s - loss: 0.5158 - acc: 0.8342
Epoch 4/6
40000/40000 [==============================] - 12s - loss: 0.4425 - acc: 0.8599
Epoch 5/6
40000/40000 [==============================] - 11s - loss: 0.4088 - acc: 0.8709
Epoch 6/6
40000/40000 [==============================] - 11s - loss: 0.3644 - acc: 0.8848
20000/20000 [==============================] - 2s
Epoch 1/3
40000/40000 [==============================] - 11s - loss: 0.6009 - acc: 0.8104
Epoch 2/3
40000/40000 [==============================] - 11s - loss: 0.3410 - acc: 0.8968
Epoch 3/3
40000/40000 [==============================] - 12s - loss: 0.2770 - acc: 0.9162
Epoch 1/3
40000/40000 [==============================] - 12s - loss: 0.6185 - acc: 0.8061
Epoch 2/3
40000/40000 [==============================] - 12s - loss: 0.3376 - acc: 0.8999
Epoch 3/3
40000/40000 [==============================] - 12s - loss: 0.2741 - acc: 0.9193
Epoch 1/3
40000/40000 [==============================] - 12s - loss: 0.6259 - acc: 0.7990
Epoch 2/3
40000/40000 [==============================] - 12s - loss: 0.3257 - acc: 0.9015
Epoch 3/3
40000/40000 [==============================] - 12s - loss: 0.2599 - acc: 0.9230
Epoch 1/6
40000/40000 [==============================] - 12s - loss: 0.6295 - acc: 0.7993
Epoch 2/6
40000/40000 [==============================] - 12s - loss: 0.3693 - acc: 0.8871
Epoch 3/6
40000/40000 [==============================] - 12s - loss: 0.2988 - acc: 0.9092
Epoch 4/6
40000/40000 [==============================] - 11s - loss: 0.2542 - acc: 0.9238
Epoch 5/6
40000/40000 [==============================] - 12s - loss: 0.2246 - acc: 0.9343
Epoch 6/6
40000/40000 [==============================] - 11s - loss: 0.2026 - acc: 0.9413
Epoch 1/6
40000/40000 [==============================] - 12s - loss: 0.5739 - acc: 0.8182
Epoch 2/6
40000/40000 [==============================] - 12s - loss: 0.3139 - acc: 0.9077
Epoch 3/6
40000/40000 [==============================] - 12s - loss: 0.2565 - acc: 0.9245
Epoch 4/6
40000/40000 [==============================] - 12s - loss: 0.2306 - acc: 0.9316
Epoch 5/6
40000/40000 [==============================] - 11s - loss: 0.2072 - acc: 0.9398
Epoch 6/6
40000/40000 [==============================] - 12s - loss: 0.1947 - acc: 0.9416
20000/20000 [==============================] - 2s
Epoch 1/6
40000/40000 [==============================] - 12s - loss: 0.6035 - acc: 0.8089
Epoch 2/6
40000/40000 [==============================] - 12s - loss: 0.3363 - acc: 0.8993
Epoch 3/6
40000/40000 [==============================] - 12s - loss: 0.2729 - acc: 0.9181
Epoch 4/6
40000/40000 [==============================] - 12s - loss: 0.2380 - acc: 0.9298
Epoch 5/6
40000/40000 [==============================] - 12s - loss: 0.2114 - acc: 0.9376
Epoch 6/6
40000/40000 [==============================] - 12s - loss: 0.1930 - acc: 0.9442
Epoch 1/3
40000/40000 [==============================] - 13s - loss: 0.7216 - acc: 0.7599
Epoch 2/3
40000/40000 [==============================] - 13s - loss: 0.4140 - acc: 0.8687
Epoch 3/3
40000/40000 [==============================] - 13s - loss: 0.3545 - acc: 0.8897
Epoch 1/3
40000/40000 [==============================] - 13s - loss: 0.8014 - acc: 0.7343
Epoch 2/3
40000/40000 [==============================] - 13s - loss: 0.4586 - acc: 0.8549
Epoch 3/3
40000/40000 [==============================] - 13s - loss: 0.3886 - acc: 0.8797
20000/20000 [==============================] - 2s
Epoch 1/3
40000/40000 [==============================] - 14s - loss: 0.8124 - acc: 0.7284
Epoch 2/3
40000/40000 [==============================] - 13s - loss: 0.4838 - acc: 0.8477
Epoch 3/3
40000/40000 [==============================] - 13s - loss: 0.4148 - acc: 0.8705
Epoch 1/6
40000/40000 [==============================] - 13s - loss: 0.7192 - acc: 0.7608
Epoch 2/6
40000/40000 [==============================] - 13s - loss: 0.4043 - acc: 0.8712
Epoch 3/6
40000/40000 [==============================] - 13s - loss: 0.3514 - acc: 0.8902
Epoch 4/6
40000/40000 [==============================] - 13s - loss: 0.3170 - acc: 0.9009
Epoch 5/6
40000/40000 [==============================] - 13s - loss: 0.2986 - acc: 0.9079
Epoch 6/6
40000/40000 [==============================] - 13s - loss: 0.2777 - acc: 0.9138
20000/20000 [==============================] - 2s
Epoch 1/6
40000/40000 [==============================] - 13s - loss: 0.7651 - acc: 0.7428
Epoch 2/6
40000/40000 [==============================] - 13s - loss: 0.4377 - acc: 0.8626
Epoch 3/6
40000/40000 [==============================] - 12s - loss: 0.3688 - acc: 0.8846
Epoch 4/6
40000/40000 [==============================] - 13s - loss: 0.3298 - acc: 0.8983
Epoch 5/6
40000/40000 [==============================] - 13s - loss: 0.3050 - acc: 0.9052
Epoch 6/6
40000/40000 [==============================] - 13s - loss: 0.2945 - acc: 0.9091
Epoch 1/6
40000/40000 [==============================] - 13s - loss: 0.8654 - acc: 0.7107
Epoch 2/6
40000/40000 [==============================] - 13s - loss: 0.5192 - acc: 0.8338
Epoch 3/6
40000/40000 [==============================] - 13s - loss: 0.4300 - acc: 0.8638
Epoch 4/6
40000/40000 [==============================] - 13s - loss: 0.3788 - acc: 0.8795
Epoch 5/6
40000/40000 [==============================] - 13s - loss: 0.3477 - acc: 0.8908
Epoch 6/6
40000/40000 [==============================] - 13s - loss: 0.3197 - acc: 0.8999
Epoch 1/3
40000/40000 [==============================] - 13s - loss: 0.5614 - acc: 0.8237
Epoch 2/3
40000/40000 [==============================] - 13s - loss: 0.2812 - acc: 0.9163
Epoch 3/3
40000/40000 [==============================] - 13s - loss: 0.2251 - acc: 0.9347
Epoch 1/3
40000/40000 [==============================] - 13s - loss: 0.5107 - acc: 0.8401
Epoch 2/3
40000/40000 [==============================] - 13s - loss: 0.2421 - acc: 0.9307
Epoch 3/3
40000/40000 [==============================] - 13s - loss: 0.1988 - acc: 0.9424
Epoch 1/3
40000/40000 [==============================] - 13s - loss: 0.5245 - acc: 0.8351
Epoch 2/3
40000/40000 [==============================] - 13s - loss: 0.2639 - acc: 0.9222
Epoch 3/3
40000/40000 [==============================] - 13s - loss: 0.2173 - acc: 0.9356
Epoch 1/6
40000/40000 [==============================] - 13s - loss: 0.5514 - acc: 0.8266
Epoch 2/6
40000/40000 [==============================] - 13s - loss: 0.2738 - acc: 0.9178
Epoch 3/6
40000/40000 [==============================] - 12s - loss: 0.2165 - acc: 0.9365
Epoch 4/6
40000/40000 [==============================] - 13s - loss: 0.1909 - acc: 0.9453
Epoch 5/6
40000/40000 [==============================] - 13s - loss: 0.1734 - acc: 0.9492
Epoch 6/6
40000/40000 [==============================] - 13s - loss: 0.1621 - acc: 0.9533
20000/20000 [==============================] - 2s
Epoch 1/6
40000/40000 [==============================] - 11s - loss: 0.5373 - acc: 0.8282
Epoch 2/6
40000/40000 [==============================] - 11s - loss: 0.2628 - acc: 0.9222
Epoch 3/6
40000/40000 [==============================] - 11s - loss: 0.2104 - acc: 0.9392
Epoch 4/6
40000/40000 [==============================] - 11s - loss: 0.1844 - acc: 0.9455
Epoch 5/6
40000/40000 [==============================] - 10s - loss: 0.1657 - acc: 0.9530
Epoch 6/6
40000/40000 [==============================] - 11s - loss: 0.1482 - acc: 0.9576
Epoch 1/6
40000/40000 [==============================] - 11s - loss: 0.5453 - acc: 0.8316
Epoch 2/6
40000/40000 [==============================] - 11s - loss: 0.2769 - acc: 0.9198
Epoch 3/6
40000/40000 [==============================] - 11s - loss: 0.2206 - acc: 0.9356
Epoch 4/6
40000/40000 [==============================] - 11s - loss: 0.1952 - acc: 0.9447
Epoch 5/6
40000/40000 [==============================] - 11s - loss: 0.1756 - acc: 0.9485
Epoch 6/6
40000/40000 [==============================] - 11s - loss: 0.1650 - acc: 0.9511
20000/20000 [==============================] - 2s
Epoch 1/6
60000/60000 [==============================] - 17s - loss: 0.4784 - acc: 0.8494
Epoch 2/6
60000/60000 [==============================] - 16s - loss: 0.2399 - acc: 0.9295
Epoch 3/6
60000/60000 [==============================] - 16s - loss: 0.1875 - acc: 0.9451
Epoch 4/6
60000/60000 [==============================] - 16s - loss: 0.1602 - acc: 0.9521
Epoch 5/6
60000/60000 [==============================] - 16s - loss: 0.1445 - acc: 0.9584
Epoch 6/6
60000/60000 [==============================] - 16s - loss: 0.1357 - acc: 0.9610
GridSearchCV(cv=None, error_score='raise',
estimator=<keras.wrappers.scikit_learn.KerasClassifier object at 0x7f42703d3e10>,
fit_params={}, iid=True, n_jobs=1,
param_grid={'dense_layer_sizes': [[32], [64], [32, 32], [64, 64]], 'nb_epoch': [3, 6], 'nb_pool': [2], 'nb_conv': [3], 'nb_filters': [8]},
pre_dispatch='2*n_jobs', refit=True, scoring='log_loss', verbose=0)
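In the log, every cross-validation run trains on 40000 images and then scores the remaining 20000. The last run trains on all 60000 images; that is the refit=True step. This is exactly what a 3-fold split of the 60000 training images produces.

Below is a minimal, unshuffled K-fold sketch. kfold_indices is a hypothetical helper; sklearn's KFold and StratifiedKFold do this for real, and which one GridSearchCV picks depends on the target type.

```python
def kfold_indices(n_samples, n_folds):
    """Split range(n_samples) into n_folds contiguous validation folds (no shuffling)."""
    fold_sizes = [n_samples // n_folds] * n_folds
    for i in range(n_samples % n_folds):        # spread any remainder over the first folds
        fold_sizes[i] += 1
    start = 0
    for size in fold_sizes:
        val = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_samples))
        yield train, val
        start += size

for fold, (train_idx, val_idx) in enumerate(kfold_indices(60000, 3)):
    print('fold', fold, 'train:', len(train_idx), 'validate:', len(val_idx))
```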
Print the parameters of the best model
print('The parameters of the best model are: ')
print(validator.best_params_)
The parameters of the best model are:
{'dense_layer_sizes': [64, 64], 'nb_conv': 3, 'nb_pool': 2, 'nb_epoch': 6, 'nb_filters': 8}
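best_params_ is the candidate whose mean score over the validation folds is highest. Here the score is the negated log loss, so higher is better. With iid=True (shown in the GridSearchCV repr above), each fold's score is weighted by its number of validation samples. All three folds here hold 20000 images, so this reduces to a plain average.

Below is a hypothetical sketch of that selection rule, not sklearn's code.

```python
def best_candidate(results):
    """Pick the best parameter setting from cross-validation results.

    results: list of (params, [(fold_score, n_validation_samples), ...]).
    Returns the params with the highest sample-weighted mean score (greater is better).
    """
    def weighted_mean(folds):
        total = sum(n for _, n in folds)
        return sum(score * n for score, n in folds) / float(total)
    return max(results, key=lambda item: weighted_mean(item[1]))[0]
```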
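Retrieve the best model
validator.best_estimator_ returns the sklearn-wrapped version of the best model.
validator.best_estimator_.model returns the underlying (unwrapped) Keras model, which is evaluated on the held-out test set below.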
best_model = validator.best_estimator_.model
metric_names = best_model.metrics_names
metric_values = best_model.evaluate(X_test, y_test)
print('\n')
for metric, value in zip(metric_names, metric_values):
    print(metric, ': ', value)
10000/10000 [==============================] - 1s
loss : 0.0535527251991
acc : 0.9825
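evaluate returns values in the same order as metrics_names: the loss first, then the accuracy. For a softmax output trained with categorical_crossentropy, Keras computes 'accuracy' as categorical accuracy. That is the fraction of samples whose highest-probability class matches the one-hot target. So acc : 0.9825 means 9825 of the 10000 test images were classified correctly.

Below is a small NumPy sketch of that rule. The helper and the toy arrays are illustrative only.

```python
import numpy as np

def categorical_accuracy(y_true_onehot, y_prob):
    """Fraction of rows where the predicted argmax equals the one-hot target's argmax."""
    y_true_onehot = np.asarray(y_true_onehot)
    y_prob = np.asarray(y_prob)
    return float(np.mean(y_true_onehot.argmax(axis=1) == y_prob.argmax(axis=1)))

y_true = np.eye(3)[[0, 2, 1, 2]]            # four one-hot targets
y_prob = np.array([[0.7, 0.2, 0.1],         # correct
                   [0.1, 0.3, 0.6],         # correct
                   [0.5, 0.4, 0.1],         # wrong: predicts class 0
                   [0.2, 0.2, 0.6]])        # correct
print(categorical_accuracy(y_true, y_prob))
```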