15. Autoencoders
Program notes
Date: November 23, 2016
Description: example programs for autoencoders, covering:
- a single-hidden-layer autoencoder
- an autoencoder with a sparsity constraint
- a deep autoencoder with multiple hidden layers
Dataset: MNIST
Original blog post: Building Autoencoders in Keras
1. Load the Keras modules
from keras.layers import Input, Dense
from keras.models import Model
from keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
Using TensorFlow backend.
2. Build a single-hidden-layer autoencoder
# this is the size of our encoded representations
encoding_dim = 32 # 32 floats -> compression of factor 24.5, assuming the input is 784 floats
# this is our input placeholder
input_img = Input(shape=(784,))
# "encoded" is the encoded representation of the input
encoded = Dense(encoding_dim, activation='relu')(input_img)
# "decoded" is the lossy reconstruction of the input
decoded = Dense(784, activation='sigmoid')(encoded)
# this model maps an input to its reconstruction
autoencoder = Model(input=input_img, output=decoded)
Define the encoder
# this model maps an input to its encoded representation
encoder = Model(input=input_img, output=encoded)
Define the decoder
# create a placeholder for an encoded (32-dimensional) input
encoded_input = Input(shape=(encoding_dim,))
# retrieve the last layer of the autoencoder model
decoder_layer = autoencoder.layers[-1]
# create the decoder model
decoder = Model(input=encoded_input, output=decoder_layer(encoded_input))
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')
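Before training, it can help to sanity-check the wiring of the three models; `summary()` is the standard Keras call for this (its exact output format varies by version, so none is reproduced here):
# print the layer-by-layer structure of each model to verify the wiring
autoencoder.summary()
encoder.summary()
decoder.summary()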
Prepare the data
(x_train, _), (x_test, _) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
print(x_train.shape)
print(x_test.shape)
(60000, 784)
(10000, 784)
Train the model
autoencoder.fit(x_train, x_train,
                nb_epoch=50,
                batch_size=256,
                shuffle=True,
                validation_data=(x_test, x_test))
Train on 60000 samples, validate on 10000 samples
Epoch 1/50
60000/60000 [==============================] - 1s - loss: 0.3755 - val_loss: 0.2733
Epoch 2/50
60000/60000 [==============================] - 1s - loss: 0.2659 - val_loss: 0.2553
Epoch 3/50
60000/60000 [==============================] - 1s - loss: 0.2458 - val_loss: 0.2344
Epoch 4/50
60000/60000 [==============================] - 1s - loss: 0.2269 - val_loss: 0.2174
Epoch 5/50
60000/60000 [==============================] - 1s - loss: 0.2121 - val_loss: 0.2043
Epoch 6/50
60000/60000 [==============================] - 1s - loss: 0.2004 - val_loss: 0.1938
Epoch 7/50
60000/60000 [==============================] - 1s - loss: 0.1910 - val_loss: 0.1854
Epoch 8/50
60000/60000 [==============================] - 1s - loss: 0.1833 - val_loss: 0.1784
Epoch 9/50
60000/60000 [==============================] - 1s - loss: 0.1768 - val_loss: 0.1726
Epoch 10/50
60000/60000 [==============================] - 1s - loss: 0.1713 - val_loss: 0.1674
Epoch 11/50
60000/60000 [==============================] - 1s - loss: 0.1666 - val_loss: 0.1631
Epoch 12/50
60000/60000 [==============================] - 1s - loss: 0.1626 - val_loss: 0.1592
Epoch 13/50
60000/60000 [==============================] - 1s - loss: 0.1589 - val_loss: 0.1557
Epoch 14/50
60000/60000 [==============================] - 1s - loss: 0.1555 - val_loss: 0.1524
Epoch 15/50
60000/60000 [==============================] - 1s - loss: 0.1523 - val_loss: 0.1495
Epoch 16/50
60000/60000 [==============================] - 1s - loss: 0.1493 - val_loss: 0.1464
Epoch 17/50
60000/60000 [==============================] - 1s - loss: 0.1465 - val_loss: 0.1436
Epoch 18/50
60000/60000 [==============================] - 1s - loss: 0.1439 - val_loss: 0.1411
Epoch 19/50
60000/60000 [==============================] - 1s - loss: 0.1414 - val_loss: 0.1388
Epoch 20/50
60000/60000 [==============================] - 1s - loss: 0.1390 - val_loss: 0.1365
Epoch 21/50
60000/60000 [==============================] - 1s - loss: 0.1368 - val_loss: 0.1342
Epoch 22/50
60000/60000 [==============================] - 1s - loss: 0.1347 - val_loss: 0.1322
Epoch 23/50
60000/60000 [==============================] - 1s - loss: 0.1327 - val_loss: 0.1302
Epoch 24/50
60000/60000 [==============================] - 1s - loss: 0.1307 - val_loss: 0.1284
Epoch 25/50
60000/60000 [==============================] - 1s - loss: 0.1289 - val_loss: 0.1264
Epoch 26/50
60000/60000 [==============================] - 1s - loss: 0.1271 - val_loss: 0.1246
Epoch 27/50
60000/60000 [==============================] - 1s - loss: 0.1253 - val_loss: 0.1230
Epoch 28/50
60000/60000 [==============================] - 1s - loss: 0.1237 - val_loss: 0.1213
Epoch 29/50
60000/60000 [==============================] - 1s - loss: 0.1221 - val_loss: 0.1198
Epoch 30/50
60000/60000 [==============================] - 1s - loss: 0.1206 - val_loss: 0.1184
Epoch 31/50
60000/60000 [==============================] - 1s - loss: 0.1191 - val_loss: 0.1169
Epoch 32/50
60000/60000 [==============================] - 1s - loss: 0.1178 - val_loss: 0.1155
Epoch 33/50
60000/60000 [==============================] - 1s - loss: 0.1165 - val_loss: 0.1143
Epoch 34/50
60000/60000 [==============================] - 1s - loss: 0.1152 - val_loss: 0.1131
Epoch 35/50
60000/60000 [==============================] - 1s - loss: 0.1141 - val_loss: 0.1120
Epoch 36/50
60000/60000 [==============================] - 1s - loss: 0.1130 - val_loss: 0.1110
Epoch 37/50
60000/60000 [==============================] - 1s - loss: 0.1120 - val_loss: 0.1100
Epoch 38/50
60000/60000 [==============================] - 1s - loss: 0.1111 - val_loss: 0.1090
Epoch 39/50
60000/60000 [==============================] - 1s - loss: 0.1102 - val_loss: 0.1082
Epoch 40/50
60000/60000 [==============================] - 1s - loss: 0.1094 - val_loss: 0.1074
Epoch 41/50
60000/60000 [==============================] - 1s - loss: 0.1086 - val_loss: 0.1066
Epoch 42/50
60000/60000 [==============================] - 1s - loss: 0.1079 - val_loss: 0.1059
Epoch 43/50
60000/60000 [==============================] - 1s - loss: 0.1072 - val_loss: 0.1053
Epoch 44/50
60000/60000 [==============================] - 1s - loss: 0.1066 - val_loss: 0.1047
Epoch 45/50
60000/60000 [==============================] - 1s - loss: 0.1060 - val_loss: 0.1041
Epoch 46/50
60000/60000 [==============================] - 1s - loss: 0.1055 - val_loss: 0.1036
Epoch 47/50
60000/60000 [==============================] - 1s - loss: 0.1049 - val_loss: 0.1031
Epoch 48/50
60000/60000 [==============================] - 1s - loss: 0.1045 - val_loss: 0.1026
Epoch 49/50
60000/60000 [==============================] - 1s - loss: 0.1040 - val_loss: 0.1022
Epoch 50/50
60000/60000 [==============================] - 1s - loss: 0.1036 - val_loss: 0.1018
<keras.callbacks.History at 0x7fd3f40ce850>
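As a side note (not in the original notebook), the History object that fit returns records the per-epoch losses, so the curves above can be plotted instead of read from the log. A minimal sketch, assuming the fit call is assigned to a variable:
history = autoencoder.fit(x_train, x_train,
                          nb_epoch=50,
                          batch_size=256,
                          shuffle=True,
                          validation_data=(x_test, x_test))
# plot the recorded training and validation losses per epoch
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='validation loss')
plt.xlabel('epoch')
plt.ylabel('binary crossentropy')
plt.legend()
plt.show()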
Encode and decode on the test set
# encode and decode some digits
# note that we take them from the *test* set
encoded_imgs = encoder.predict(x_test)
decoded_imgs = decoder.predict(encoded_imgs)
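Beyond the visual comparison below, a quick numeric check (not part of the original post) is the mean per-pixel squared reconstruction error over the whole test set:
# average squared per-pixel error between inputs and reconstructions
mse = np.mean((x_test - decoded_imgs) ** 2)
print(mse)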
Display the results
n = 10  # how many digits we will display
plt.figure(figsize=(20, 4))
for i in range(n):
    # display original
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # display reconstruction
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
3. Build an autoencoder with a sparsity constraint on the encoded representations
The L1 activity regularizer adds a penalty proportional to the absolute values of the hidden activations to the training loss, which drives most activations toward zero and so makes the learned codes sparser.
from keras import regularizers
encoding_dim = 32
input_img = Input(shape=(784,))
# add a Dense layer with an L1 activity regularizer
encoded = Dense(encoding_dim, activation='relu',
                activity_regularizer=regularizers.activity_l1(10e-5))(input_img)
decoded = Dense(784, activation='sigmoid')(encoded)
autoencoder = Model(input=input_img, output=decoded)
Compile the model
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')
Train the model
autoencoder.fit(x_train, x_train,
                nb_epoch=100,
                batch_size=256,
                shuffle=True,
                validation_data=(x_test, x_test))
Train on 60000 samples, validate on 10000 samples
Epoch 1/100
60000/60000 [==============================] - 2s - loss: 0.3401 - val_loss: 0.2627
Epoch 2/100
60000/60000 [==============================] - 2s - loss: 0.2541 - val_loss: 0.2413
Epoch 3/100
60000/60000 [==============================] - 2s - loss: 0.2296 - val_loss: 0.2164
Epoch 4/100
60000/60000 [==============================] - 2s - loss: 0.2068 - val_loss: 0.1989
Epoch 5/100
60000/60000 [==============================] - 2s - loss: 0.1940 - val_loss: 0.1868
Epoch 6/100
60000/60000 [==============================] - 2s - loss: 0.1842 - val_loss: 0.1792
Epoch 7/100
60000/60000 [==============================] - 2s - loss: 0.1765 - val_loss: 0.1711
Epoch 8/100
60000/60000 [==============================] - 2s - loss: 0.1704 - val_loss: 0.1667
Epoch 9/100
60000/60000 [==============================] - 2s - loss: 0.1658 - val_loss: 0.1616
Epoch 10/100
60000/60000 [==============================] - 2s - loss: 0.1616 - val_loss: 0.1597
Epoch 11/100
60000/60000 [==============================] - 2s - loss: 0.1579 - val_loss: 0.1546
Epoch 12/100
60000/60000 [==============================] - 2s - loss: 0.1545 - val_loss: 0.1516
Epoch 13/100
60000/60000 [==============================] - 2s - loss: 0.1513 - val_loss: 0.1482
Epoch 14/100
60000/60000 [==============================] - 2s - loss: 0.1485 - val_loss: 0.1452
Epoch 15/100
60000/60000 [==============================] - 2s - loss: 0.1462 - val_loss: 0.1431
Epoch 16/100
60000/60000 [==============================] - 2s - loss: 0.1437 - val_loss: 0.1433
Epoch 17/100
60000/60000 [==============================] - 2s - loss: 0.1417 - val_loss: 0.1397
Epoch 18/100
60000/60000 [==============================] - 2s - loss: 0.1398 - val_loss: 0.1375
Epoch 19/100
60000/60000 [==============================] - 2s - loss: 0.1380 - val_loss: 0.1374
Epoch 20/100
60000/60000 [==============================] - 2s - loss: 0.1363 - val_loss: 0.1339
Epoch 21/100
60000/60000 [==============================] - 2s - loss: 0.1347 - val_loss: 0.1323
Epoch 22/100
60000/60000 [==============================] - 2s - loss: 0.1334 - val_loss: 0.1309
Epoch 23/100
60000/60000 [==============================] - 2s - loss: 0.1320 - val_loss: 0.1303
Epoch 24/100
60000/60000 [==============================] - 2s - loss: 0.1307 - val_loss: 0.1288
Epoch 25/100
60000/60000 [==============================] - 2s - loss: 0.1295 - val_loss: 0.1287
Epoch 26/100
60000/60000 [==============================] - 2s - loss: 0.1284 - val_loss: 0.1279
Epoch 27/100
60000/60000 [==============================] - 2s - loss: 0.1271 - val_loss: 0.1255
Epoch 28/100
60000/60000 [==============================] - 2s - loss: 0.1261 - val_loss: 0.1236
Epoch 29/100
60000/60000 [==============================] - 2s - loss: 0.1250 - val_loss: 0.1220
Epoch 30/100
60000/60000 [==============================] - 2s - loss: 0.1238 - val_loss: 0.1225
Epoch 31/100
60000/60000 [==============================] - 2s - loss: 0.1229 - val_loss: 0.1230
Epoch 32/100
60000/60000 [==============================] - 2s - loss: 0.1220 - val_loss: 0.1196
Epoch 33/100
60000/60000 [==============================] - 2s - loss: 0.1212 - val_loss: 0.1200
Epoch 34/100
60000/60000 [==============================] - 2s - loss: 0.1202 - val_loss: 0.1188
Epoch 35/100
60000/60000 [==============================] - 2s - loss: 0.1193 - val_loss: 0.1191
Epoch 36/100
60000/60000 [==============================] - 2s - loss: 0.1187 - val_loss: 0.1170
Epoch 37/100
60000/60000 [==============================] - 2s - loss: 0.1178 - val_loss: 0.1163
Epoch 38/100
60000/60000 [==============================] - 2s - loss: 0.1171 - val_loss: 0.1156
Epoch 39/100
60000/60000 [==============================] - 2s - loss: 0.1166 - val_loss: 0.1145
Epoch 40/100
60000/60000 [==============================] - 2s - loss: 0.1160 - val_loss: 0.1136
Epoch 41/100
60000/60000 [==============================] - 2s - loss: 0.1155 - val_loss: 0.1132
Epoch 42/100
60000/60000 [==============================] - 2s - loss: 0.1150 - val_loss: 0.1153
Epoch 43/100
60000/60000 [==============================] - 2s - loss: 0.1145 - val_loss: 0.1129
Epoch 44/100
60000/60000 [==============================] - 2s - loss: 0.1140 - val_loss: 0.1129
Epoch 45/100
60000/60000 [==============================] - 2s - loss: 0.1135 - val_loss: 0.1118
Epoch 46/100
60000/60000 [==============================] - 2s - loss: 0.1131 - val_loss: 0.1118
Epoch 47/100
60000/60000 [==============================] - 2s - loss: 0.1126 - val_loss: 0.1117
Epoch 48/100
60000/60000 [==============================] - 2s - loss: 0.1122 - val_loss: 0.1117
Epoch 49/100
60000/60000 [==============================] - 2s - loss: 0.1117 - val_loss: 0.1103
Epoch 50/100
60000/60000 [==============================] - 2s - loss: 0.1114 - val_loss: 0.1093
Epoch 51/100
60000/60000 [==============================] - 2s - loss: 0.1109 - val_loss: 0.1096
Epoch 52/100
60000/60000 [==============================] - 2s - loss: 0.1107 - val_loss: 0.1099
Epoch 53/100
60000/60000 [==============================] - 2s - loss: 0.1102 - val_loss: 0.1090
Epoch 54/100
60000/60000 [==============================] - 2s - loss: 0.1098 - val_loss: 0.1083
Epoch 55/100
60000/60000 [==============================] - 2s - loss: 0.1095 - val_loss: 0.1077
Epoch 56/100
60000/60000 [==============================] - 2s - loss: 0.1091 - val_loss: 0.1077
Epoch 57/100
60000/60000 [==============================] - 2s - loss: 0.1089 - val_loss: 0.1081
Epoch 58/100
60000/60000 [==============================] - 2s - loss: 0.1085 - val_loss: 0.1082
Epoch 59/100
60000/60000 [==============================] - 2s - loss: 0.1082 - val_loss: 0.1073
Epoch 60/100
60000/60000 [==============================] - 2s - loss: 0.1078 - val_loss: 0.1066
Epoch 61/100
60000/60000 [==============================] - 2s - loss: 0.1075 - val_loss: 0.1059
Epoch 62/100
60000/60000 [==============================] - 2s - loss: 0.1072 - val_loss: 0.1053
Epoch 63/100
60000/60000 [==============================] - 2s - loss: 0.1069 - val_loss: 0.1063
Epoch 64/100
60000/60000 [==============================] - 2s - loss: 0.1067 - val_loss: 0.1047
Epoch 65/100
60000/60000 [==============================] - 2s - loss: 0.1064 - val_loss: 0.1044
Epoch 66/100
60000/60000 [==============================] - 2s - loss: 0.1060 - val_loss: 0.1048
Epoch 67/100
60000/60000 [==============================] - 2s - loss: 0.1057 - val_loss: 0.1039
Epoch 68/100
60000/60000 [==============================] - 2s - loss: 0.1055 - val_loss: 0.1039
Epoch 69/100
60000/60000 [==============================] - 2s - loss: 0.1052 - val_loss: 0.1043
Epoch 70/100
60000/60000 [==============================] - 2s - loss: 0.1050 - val_loss: 0.1042
Epoch 71/100
60000/60000 [==============================] - 2s - loss: 0.1047 - val_loss: 0.1036
Epoch 72/100
60000/60000 [==============================] - 2s - loss: 0.1044 - val_loss: 0.1033
Epoch 73/100
60000/60000 [==============================] - 2s - loss: 0.1042 - val_loss: 0.1040
Epoch 74/100
60000/60000 [==============================] - 2s - loss: 0.1040 - val_loss: 0.1024
Epoch 75/100
60000/60000 [==============================] - 2s - loss: 0.1037 - val_loss: 0.1025
Epoch 76/100
60000/60000 [==============================] - 2s - loss: 0.1035 - val_loss: 0.1020
Epoch 77/100
60000/60000 [==============================] - 2s - loss: 0.1031 - val_loss: 0.1019
Epoch 78/100
60000/60000 [==============================] - 2s - loss: 0.1030 - val_loss: 0.1022
Epoch 79/100
60000/60000 [==============================] - 2s - loss: 0.1028 - val_loss: 0.1024
Epoch 80/100
60000/60000 [==============================] - 2s - loss: 0.1026 - val_loss: 0.1011
Epoch 81/100
60000/60000 [==============================] - 2s - loss: 0.1023 - val_loss: 0.1004
Epoch 82/100
60000/60000 [==============================] - 2s - loss: 0.1020 - val_loss: 0.1015
Epoch 83/100
60000/60000 [==============================] - 2s - loss: 0.1018 - val_loss: 0.1005
Epoch 84/100
60000/60000 [==============================] - 2s - loss: 0.1016 - val_loss: 0.1018
Epoch 85/100
60000/60000 [==============================] - 3s - loss: 0.1013 - val_loss: 0.0995
Epoch 86/100
60000/60000 [==============================] - 2s - loss: 0.1013 - val_loss: 0.0998
Epoch 87/100
60000/60000 [==============================] - 2s - loss: 0.1010 - val_loss: 0.0995
Epoch 88/100
60000/60000 [==============================] - 2s - loss: 0.1008 - val_loss: 0.1008
Epoch 89/100
60000/60000 [==============================] - 2s - loss: 0.1006 - val_loss: 0.0995
Epoch 90/100
60000/60000 [==============================] - 2s - loss: 0.1005 - val_loss: 0.1007
Epoch 91/100
60000/60000 [==============================] - 2s - loss: 0.1003 - val_loss: 0.0990
Epoch 92/100
60000/60000 [==============================] - 2s - loss: 0.1001 - val_loss: 0.0983
Epoch 93/100
60000/60000 [==============================] - 2s - loss: 0.0999 - val_loss: 0.0992
Epoch 94/100
60000/60000 [==============================] - 2s - loss: 0.0999 - val_loss: 0.0994
Epoch 95/100
60000/60000 [==============================] - 2s - loss: 0.0996 - val_loss: 0.0977
Epoch 96/100
60000/60000 [==============================] - 2s - loss: 0.0994 - val_loss: 0.0983
Epoch 97/100
60000/60000 [==============================] - 2s - loss: 0.0994 - val_loss: 0.0981
Epoch 98/100
60000/60000 [==============================] - 2s - loss: 0.0991 - val_loss: 0.0974
Epoch 99/100
60000/60000 [==============================] - 2s - loss: 0.0989 - val_loss: 0.0971
Epoch 100/100
60000/60000 [==============================] - 2s - loss: 0.0988 - val_loss: 0.0971
<keras.callbacks.History at 0x7fd354149b10>
Encode and decode on the test set
This model differs from the previous one mainly in the sparsity of its codes: over the 10,000 test images, encoded_imgs.mean() is about 3.33 for this model versus about 7.30 for the previous one, i.e. the sparsity constraint produces markedly sparser codes. Note that the encoder and decoder from section 2 still belong to the first model, so they are rebuilt below before encoding the test set.
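A minimal sketch of that rebuild (not shown in the original notebook), using the same layer-retrieval trick as in section 2; input_img and encoded here are the tensors defined for the sparse model above:
# rebuild the encoder and decoder around the sparse autoencoder
encoder = Model(input=input_img, output=encoded)
encoded_input = Input(shape=(encoding_dim,))
decoder_layer = autoencoder.layers[-1]
decoder = Model(input=encoded_input, output=decoder_layer(encoded_input))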
# encode and decode some digits
# note that we take them from the *test* set
encoded_imgs = encoder.predict(x_test)
decoded_imgs = decoder.predict(encoded_imgs)
Display the results
n = 10  # how many digits we will display
plt.figure(figsize=(20, 4))
for i in range(n):
    # display original
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # display reconstruction
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
4. Build a deep autoencoder with multiple hidden layers
input_img = Input(shape=(784,))
encoded = Dense(128, activation='relu')(input_img)
encoded = Dense(64, activation='relu')(encoded)
encoded = Dense(32, activation='relu')(encoded)
decoded = Dense(64, activation='relu')(encoded)
decoded = Dense(128, activation='relu')(decoded)
decoded = Dense(784, activation='sigmoid')(decoded)
autoencoder = Model(input=input_img, output=decoded)
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')
autoencoder.fit(x_train, x_train,
                nb_epoch=100,
                batch_size=256,
                shuffle=True,
                validation_data=(x_test, x_test))
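After training, the 32-dimensional deep codes can be read out without retraining anything: encoded above still refers to the innermost tensor, so an encoder model that shares the trained weights follows directly. A minimal sketch, not in the original notebook:
# build an encoder over the trained encoding layers;
# `encoded` is the 32-dimensional tensor defined above
deep_encoder = Model(input=input_img, output=encoded)
deep_codes = deep_encoder.predict(x_test)
print(deep_codes.shape)  # expected: (10000, 32)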