16. Convolutional Autoencoder
Program notes
Date: November 23, 2016
Description: An example of a convolutional autoencoder.
Dataset: MNIST
Original blog post: Building Autoencoders in Keras
1. Load the Keras modules
from keras.layers import Input, Dense, Convolution2D, MaxPooling2D, UpSampling2D
from keras.models import Model
from keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
Using TensorFlow backend.
2. Convolutional autoencoder
The encoder stacks three Convolution2D + MaxPooling2D blocks to compress each 28×28 grayscale image into a 4×4×8 (128-dimensional) code; the decoder mirrors it with Convolution2D + UpSampling2D blocks to reconstruct the 28×28 input.
input_img = Input(shape=(28, 28, 1))
x = Convolution2D(16, 3, 3, activation='relu', border_mode='same')(input_img)
x = MaxPooling2D((2, 2), border_mode='same')(x)
x = Convolution2D(8, 3, 3, activation='relu', border_mode='same')(x)
x = MaxPooling2D((2, 2), border_mode='same')(x)
x = Convolution2D(8, 3, 3, activation='relu', border_mode='same')(x)
encoded = MaxPooling2D((2, 2), border_mode='same')(x)
# at this point the representation is (4, 4, 8) i.e. 128-dimensional
x = Convolution2D(8, 3, 3, activation='relu', border_mode='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Convolution2D(8, 3, 3, activation='relu', border_mode='same')(x)
x = UpSampling2D((2, 2))(x)
x = Convolution2D(16, 3, 3, activation='relu')(x)  # default border_mode='valid': 16x16 shrinks to 14x14
x = UpSampling2D((2, 2))(x)
decoded = Convolution2D(1, 3, 3, activation='sigmoid', border_mode='same')(x)
autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')
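The layer and argument names above follow the Keras 1 API that was current in late 2016 (Convolution2D, border_mode, and later nb_epoch). For readers on Keras 2, the same model can be written as below; this is a sketch under the assumption of a Keras >= 2.0 install, where the names become Conv2D, padding, and epochs.
# Keras 2 equivalent of the model above (assumes Keras >= 2.0).
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D
from keras.models import Model

input_img = Input(shape=(28, 28, 1))
x = Conv2D(16, (3, 3), activation='relu', padding='same')(input_img)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
encoded = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
x = Conv2D(16, (3, 3), activation='relu')(x)  # 'valid' padding: 16x16 -> 14x14
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)
autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')
# In fit(), use epochs=50 instead of nb_epoch=50.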
(x_train, _), (x_test, _) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))
x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))
autoencoder.fit(x_train, x_train,
                nb_epoch=50,
                batch_size=128,
                shuffle=True,
                validation_data=(x_test, x_test))
Train on 60000 samples, validate on 10000 samples
Epoch 1/50
60000/60000 [==============================] - 8s - loss: 0.2372 - val_loss: 0.1694
Epoch 2/50
60000/60000 [==============================] - 7s - loss: 0.1588 - val_loss: 0.1502
Epoch 3/50
60000/60000 [==============================] - 7s - loss: 0.1441 - val_loss: 0.1346
Epoch 4/50
60000/60000 [==============================] - 7s - loss: 0.1353 - val_loss: 0.1313
Epoch 5/50
60000/60000 [==============================] - 7s - loss: 0.1293 - val_loss: 0.1266
Epoch 6/50
60000/60000 [==============================] - 7s - loss: 0.1252 - val_loss: 0.1229
Epoch 7/50
60000/60000 [==============================] - 7s - loss: 0.1222 - val_loss: 0.1193
Epoch 8/50
60000/60000 [==============================] - 7s - loss: 0.1202 - val_loss: 0.1179
Epoch 9/50
60000/60000 [==============================] - 7s - loss: 0.1184 - val_loss: 0.1185
Epoch 10/50
60000/60000 [==============================] - 7s - loss: 0.1169 - val_loss: 0.1166
Epoch 11/50
60000/60000 [==============================] - 7s - loss: 0.1155 - val_loss: 0.1132
Epoch 12/50
60000/60000 [==============================] - 7s - loss: 0.1143 - val_loss: 0.1141
Epoch 13/50
60000/60000 [==============================] - 7s - loss: 0.1134 - val_loss: 0.1136
Epoch 14/50
60000/60000 [==============================] - 7s - loss: 0.1125 - val_loss: 0.1124
Epoch 15/50
60000/60000 [==============================] - 7s - loss: 0.1119 - val_loss: 0.1091
Epoch 16/50
60000/60000 [==============================] - 7s - loss: 0.1111 - val_loss: 0.1090
Epoch 17/50
60000/60000 [==============================] - 7s - loss: 0.1103 - val_loss: 0.1120
Epoch 18/50
60000/60000 [==============================] - 7s - loss: 0.1094 - val_loss: 0.1074
Epoch 19/50
60000/60000 [==============================] - 7s - loss: 0.1091 - val_loss: 0.1070
Epoch 20/50
60000/60000 [==============================] - 7s - loss: 0.1087 - val_loss: 0.1085
Epoch 21/50
60000/60000 [==============================] - 7s - loss: 0.1082 - val_loss: 0.1069
Epoch 22/50
60000/60000 [==============================] - 7s - loss: 0.1079 - val_loss: 0.1058
Epoch 23/50
60000/60000 [==============================] - 7s - loss: 0.1075 - val_loss: 0.1045
Epoch 24/50
60000/60000 [==============================] - 7s - loss: 0.1067 - val_loss: 0.1055
Epoch 25/50
60000/60000 [==============================] - 7s - loss: 0.1067 - val_loss: 0.1070
Epoch 26/50
60000/60000 [==============================] - 7s - loss: 0.1063 - val_loss: 0.1032
Epoch 27/50
60000/60000 [==============================] - 7s - loss: 0.1058 - val_loss: 0.1060
Epoch 28/50
60000/60000 [==============================] - 7s - loss: 0.1055 - val_loss: 0.1069
Epoch 29/50
60000/60000 [==============================] - 7s - loss: 0.1053 - val_loss: 0.1037
Epoch 30/50
60000/60000 [==============================] - 7s - loss: 0.1051 - val_loss: 0.1055
Epoch 31/50
60000/60000 [==============================] - 7s - loss: 0.1048 - val_loss: 0.1038
Epoch 32/50
60000/60000 [==============================] - 7s - loss: 0.1045 - val_loss: 0.1024
Epoch 33/50
60000/60000 [==============================] - 7s - loss: 0.1043 - val_loss: 0.1022
Epoch 34/50
60000/60000 [==============================] - 7s - loss: 0.1041 - val_loss: 0.1023
Epoch 35/50
60000/60000 [==============================] - 7s - loss: 0.1040 - val_loss: 0.1042
Epoch 36/50
60000/60000 [==============================] - 7s - loss: 0.1038 - val_loss: 0.1030
Epoch 37/50
60000/60000 [==============================] - 7s - loss: 0.1036 - val_loss: 0.1020
Epoch 38/50
60000/60000 [==============================] - 7s - loss: 0.1033 - val_loss: 0.1015
Epoch 39/50
60000/60000 [==============================] - 7s - loss: 0.1032 - val_loss: 0.1010
Epoch 40/50
60000/60000 [==============================] - 7s - loss: 0.1027 - val_loss: 0.1023
Epoch 41/50
60000/60000 [==============================] - 7s - loss: 0.1025 - val_loss: 0.1000
Epoch 42/50
60000/60000 [==============================] - 7s - loss: 0.1024 - val_loss: 0.1016
Epoch 43/50
60000/60000 [==============================] - 7s - loss: 0.1018 - val_loss: 0.1001
Epoch 44/50
60000/60000 [==============================] - 7s - loss: 0.1016 - val_loss: 0.1000
Epoch 45/50
60000/60000 [==============================] - 7s - loss: 0.1016 - val_loss: 0.1004
Epoch 46/50
60000/60000 [==============================] - 7s - loss: 0.1013 - val_loss: 0.0993
Epoch 47/50
60000/60000 [==============================] - 7s - loss: 0.1010 - val_loss: 0.1006
Epoch 48/50
60000/60000 [==============================] - 7s - loss: 0.1006 - val_loss: 0.1001
Epoch 49/50
60000/60000 [==============================] - 7s - loss: 0.1008 - val_loss: 0.0994
Epoch 50/50
60000/60000 [==============================] - 7s - loss: 0.1006 - val_loss: 0.0998
<keras.callbacks.History at 0x7f0244252a50>
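fit() returns the History object shown above. If you bind it to a name, the per-epoch losses can be plotted instead of read off the log; a minimal sketch, assuming the call above is changed to history = autoencoder.fit(...):
# Plot the training curves (assumes history = autoencoder.fit(...) above).
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='validation loss')
plt.xlabel('epoch')
plt.ylabel('binary crossentropy')
plt.legend()
plt.show()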
decoded_imgs = autoencoder.predict(x_test)
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
    # display original
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    # display reconstruction
    ax = plt.subplot(2, n, i + n + 1)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
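The side-by-side plot is a qualitative check; a simple quantitative complement is a per-image reconstruction error. A sketch using mean squared error (the model was trained on binary crossentropy, but MSE is an easy proxy; the variable names are illustrative):
# Per-image mean squared reconstruction error over the test set.
mse = np.mean((x_test - decoded_imgs) ** 2, axis=(1, 2, 3))
print('mean MSE over the test set: %.4f' % mse.mean())
print('indices of the 5 worst reconstructions:', np.argsort(mse)[-5:])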
autoencoder.summary()
____________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
====================================================================================================
input_1 (InputLayer) (None, 28, 28, 1) 0
____________________________________________________________________________________________________
convolution2d_1 (Convolution2D) (None, 28, 28, 16) 160 input_1[0][0]
____________________________________________________________________________________________________
maxpooling2d_1 (MaxPooling2D) (None, 14, 14, 16) 0 convolution2d_1[0][0]
____________________________________________________________________________________________________
convolution2d_2 (Convolution2D) (None, 14, 14, 8) 1160 maxpooling2d_1[0][0]
____________________________________________________________________________________________________
maxpooling2d_2 (MaxPooling2D) (None, 7, 7, 8) 0 convolution2d_2[0][0]
____________________________________________________________________________________________________
convolution2d_3 (Convolution2D) (None, 7, 7, 8) 584 maxpooling2d_2[0][0]
____________________________________________________________________________________________________
maxpooling2d_3 (MaxPooling2D) (None, 4, 4, 8) 0 convolution2d_3[0][0]
____________________________________________________________________________________________________
convolution2d_4 (Convolution2D) (None, 4, 4, 8) 584 maxpooling2d_3[0][0]
____________________________________________________________________________________________________
upsampling2d_1 (UpSampling2D) (None, 8, 8, 8) 0 convolution2d_4[0][0]
____________________________________________________________________________________________________
convolution2d_5 (Convolution2D) (None, 8, 8, 8) 584 upsampling2d_1[0][0]
____________________________________________________________________________________________________
upsampling2d_2 (UpSampling2D) (None, 16, 16, 8) 0 convolution2d_5[0][0]
____________________________________________________________________________________________________
convolution2d_6 (Convolution2D) (None, 14, 14, 16) 1168 upsampling2d_2[0][0]
____________________________________________________________________________________________________
upsampling2d_3 (UpSampling2D) (None, 28, 28, 16) 0 convolution2d_6[0][0]
____________________________________________________________________________________________________
convolution2d_7 (Convolution2D) (None, 28, 28, 1) 145 upsampling2d_3[0][0]
====================================================================================================
Total params: 4385
____________________________________________________________________________________________________
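Each Convolution2D layer's parameter count in the summary is kernel_height * kernel_width * input_channels * filters, plus one bias per filter: e.g. the first layer has 3*3*1*16 + 16 = 160 parameters and the second 3*3*16*8 + 8 = 1160. A quick check that these add up to the reported total:
# Recompute the per-layer parameter counts from the summary above.
def conv_params(kh, kw, in_ch, filters):
    return kh * kw * in_ch * filters + filters  # weights + biases

channel_pairs = [(1, 16), (16, 8), (8, 8), (8, 8), (8, 8), (8, 16), (16, 1)]
total = sum(conv_params(3, 3, i, f) for i, f in channel_pairs)
print(total)  # 4385, matching "Total params" above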
3. Visualizing the intermediate encoded representations
Extract the output of the bottleneck layer ('maxpooling2d_3') with a second Model and display the 4×4×8 codes.
model_extractfeatures = Model(input=autoencoder.input, output=autoencoder.get_layer('maxpooling2d_3').output)
encoded_imgs = model_extractfeatures.predict(x_test)
print(encoded_imgs.shape)
n = 10
plt.figure(figsize=(20, 8))
for i in range(n):
    ax = plt.subplot(1, n, i + 1)
    plt.imshow(encoded_imgs[i].reshape(4, 4 * 8).T)
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
(10000, 4, 4, 8)
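The reshape(4, 4 * 8).T above tiles all 8 feature maps of a digit into one 32×4 strip. To look at the channels individually instead, each 4×4 map can get its own panel; a sketch for the first test digit:
# Display the 8 encoded 4x4 feature maps of one digit in separate panels.
plt.figure(figsize=(16, 2))
for j in range(8):
    ax = plt.subplot(1, 8, j + 1)
    plt.imshow(encoded_imgs[0, :, :, j])
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()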