Unsupervised learning

AutoEncoders

An autoencoder is an artificial neural network used for learning efficient codings.

The aim of an autoencoder is to learn a representation (encoding) for a set of data, typically for the purpose of dimensionality reduction.
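As a rough illustration of encoding and decoding, here is a minimal sketch of a linear autoencoder written in plain NumPy rather than Keras. The toy data, the weight names `W_enc` and `W_dec`, the learning rate, and the step count are all assumptions chosen for this example; it is not the model built later in this section.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data (assumed): 200 points in 8 dimensions that lie on a 2-D subspace.
latent = rng.normal(size=(200, 2))
mixing = rng.normal(size=(2, 8))
X = latent @ mixing

code_dim = 2
W_enc = rng.normal(scale=0.1, size=(8, code_dim))   # encoder weights
W_dec = rng.normal(scale=0.1, size=(code_dim, 8))   # decoder weights


def reconstruct(X, W_enc, W_dec):
    code = X @ W_enc              # encoding: 8 -> 2 dimensions
    return code, code @ W_dec     # decoding: 2 -> 8 dimensions


def mse(X, X_hat):
    return float(np.mean((X - X_hat) ** 2))


initial_loss = mse(X, reconstruct(X, W_enc, W_dec)[1])

lr = 0.01
for _ in range(2000):
    code, X_hat = reconstruct(X, W_enc, W_dec)
    # Proportional to the gradient of the mean squared error w.r.t. X_hat.
    err = (X_hat - X) / len(X)
    grad_dec = code.T @ err
    grad_enc = X.T @ (err @ W_dec.T)
    W_dec -= lr * grad_dec
    W_enc -= lr * grad_enc

code, X_hat = reconstruct(X, W_enc, W_dec)
final_loss = mse(X, X_hat)
print(initial_loss, final_loss)
```

After training, `final_loss` should be lower than `initial_loss`, because the 8-dimensional points can be rebuilt from a 2-dimensional code.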

Unsupervised learning is a type of machine learning used to draw inferences from datasets consisting of input data without labeled responses. The most common unsupervised learning method is cluster analysis, which is used in exploratory data analysis to find hidden patterns or groupings in data.

Reference

Based on https://blog.keras.io/building-autoencoders-in-keras.html

Introducing Keras Functional API

The Keras functional API is the way to go for defining complex models, such as multi-output models, directed acyclic graphs, or models with shared layers.

The whole Functional API relies on the fact that each keras.Layer object is callable: you create a layer once and then call it on a tensor to get the output tensor.
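To make the idea of callable layers concrete, here is a small sketch in plain NumPy. `TinyDense` is a hypothetical stand-in written for this example, not the Keras `Dense` class: Keras layers are called on symbolic tensors and build a graph, whereas this toy version computes values immediately. What it shares with Keras is the pattern of creating a layer once and then calling it like a function, possibly on several inputs so that the weights are shared.

```python
import numpy as np


class TinyDense:
    """Hypothetical toy layer (not keras.layers.Dense): ReLU(x @ W + b)."""

    def __init__(self, in_dim, out_dim, seed=0):
        rng = np.random.default_rng(seed)
        self.W = rng.normal(scale=0.1, size=(in_dim, out_dim))
        self.b = np.zeros(out_dim)

    def __call__(self, x):
        # Defining __call__ is what lets us write layer(x).
        return np.maximum(x @ self.W + self.b, 0.0)


layer = TinyDense(4, 3)

# The same layer object (and therefore the same weights) applied to two inputs.
y1 = layer(np.ones((2, 4)))
y2 = layer(np.zeros((5, 4)))
print(y1.shape, y2.shape)
```

Calling one layer object on two inputs is the same mechanism that the models with shared layers mentioned above rely on.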

See 8.2 Multi-Modal Networks for further details.


from keras.layers import Input, Dense
from keras.models import Model

from keras.datasets import mnist

import numpy as np
Using TensorFlow backend.
# this is the size of our encoded representations
encoding_dim = 32  # 32 floats -> compression of factor 24.5, assuming the input is 784 floats

# this is our input placeholder
input_img = Input(shape=(784,))
# "encoded" is the encoded representation of the input
encoded = Dense(encoding_dim, activation='relu')(input_img)

# "decoded" is the lossy reconstruction of the input
decoded = Dense(784, activation='sigmoid')(encoded)

# this model maps an input to its reconstruction
autoencoder = Model(input_img, decoded)
# this model maps an input to its encoded representation
encoder = Model(input_img, encoded)
# create a placeholder for an encoded (32-dimensional) input
encoded_input = Input(shape=(encoding_dim,))
# retrieve the last layer of the autoencoder model
decoder_layer = autoencoder.layers[-1]
# create the decoder model
decoder = Model(encoded_input, decoder_layer(encoded_input))
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')
(x_train, _), (x_test, _) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
# note: the input is also the target (x_train, x_train)
autoencoder.fit(x_train, x_train,
                epochs=50,
                batch_size=256,
                shuffle=True,
                validation_data=(x_test, x_test))
Train on 60000 samples, validate on 10000 samples
Epoch 1/50
60000/60000 [==============================] - 1s - loss: 0.3830 - val_loss: 0.2731
Epoch 2/50
60000/60000 [==============================] - 1s - loss: 0.2664 - val_loss: 0.2561
Epoch 3/50
60000/60000 [==============================] - 1s - loss: 0.2463 - val_loss: 0.2336
Epoch 4/50
60000/60000 [==============================] - 1s - loss: 0.2258 - val_loss: 0.2156
Epoch 5/50
60000/60000 [==============================] - 1s - loss: 0.2105 - val_loss: 0.2030
Epoch 6/50
60000/60000 [==============================] - 1s - loss: 0.1997 - val_loss: 0.1936
Epoch 7/50
60000/60000 [==============================] - 1s - loss: 0.1914 - val_loss: 0.1863
Epoch 8/50
60000/60000 [==============================] - 1s - loss: 0.1846 - val_loss: 0.1800
Epoch 9/50
60000/60000 [==============================] - 1s - loss: 0.1789 - val_loss: 0.1749
Epoch 10/50
60000/60000 [==============================] - 1s - loss: 0.1740 - val_loss: 0.1702
Epoch 11/50
60000/60000 [==============================] - 1s - loss: 0.1697 - val_loss: 0.1660
Epoch 12/50
60000/60000 [==============================] - 1s - loss: 0.1657 - val_loss: 0.1622
Epoch 13/50
60000/60000 [==============================] - 1s - loss: 0.1620 - val_loss: 0.1587
Epoch 14/50
60000/60000 [==============================] - 1s - loss: 0.1586 - val_loss: 0.1554
Epoch 15/50
60000/60000 [==============================] - 1s - loss: 0.1554 - val_loss: 0.1524
Epoch 16/50
60000/60000 [==============================] - 1s - loss: 0.1525 - val_loss: 0.1495
Epoch 17/50
60000/60000 [==============================] - 1s - loss: 0.1497 - val_loss: 0.1468
Epoch 18/50
60000/60000 [==============================] - 1s - loss: 0.1470 - val_loss: 0.1441
Epoch 19/50
60000/60000 [==============================] - 1s - loss: 0.1444 - val_loss: 0.1415
Epoch 20/50
60000/60000 [==============================] - 1s - loss: 0.1419 - val_loss: 0.1391
Epoch 21/50
60000/60000 [==============================] - 1s - loss: 0.1395 - val_loss: 0.1367
Epoch 22/50
60000/60000 [==============================] - 1s - loss: 0.1371 - val_loss: 0.1345
Epoch 23/50
60000/60000 [==============================] - 1s - loss: 0.1349 - val_loss: 0.1323
Epoch 24/50
60000/60000 [==============================] - 1s - loss: 0.1328 - val_loss: 0.1302
Epoch 25/50
60000/60000 [==============================] - 1s - loss: 0.1308 - val_loss: 0.1283
Epoch 26/50
60000/60000 [==============================] - 1s - loss: 0.1289 - val_loss: 0.1264
Epoch 27/50
60000/60000 [==============================] - 1s - loss: 0.1271 - val_loss: 0.1247
Epoch 28/50
60000/60000 [==============================] - 1s - loss: 0.1254 - val_loss: 0.1230
Epoch 29/50
60000/60000 [==============================] - 1s - loss: 0.1238 - val_loss: 0.1215
Epoch 30/50
60000/60000 [==============================] - 1s - loss: 0.1223 - val_loss: 0.1200
Epoch 31/50
60000/60000 [==============================] - 1s - loss: 0.1208 - val_loss: 0.1186
Epoch 32/50
60000/60000 [==============================] - 1s - loss: 0.1195 - val_loss: 0.1172
Epoch 33/50
60000/60000 [==============================] - 1s - loss: 0.1182 - val_loss: 0.1160
Epoch 34/50
60000/60000 [==============================] - 1s - loss: 0.1170 - val_loss: 0.1149
Epoch 35/50
60000/60000 [==============================] - 1s - loss: 0.1158 - val_loss: 0.1137
Epoch 36/50
60000/60000 [==============================] - 1s - loss: 0.1148 - val_loss: 0.1127
Epoch 37/50
60000/60000 [==============================] - 1s - loss: 0.1138 - val_loss: 0.1117
Epoch 38/50
60000/60000 [==============================] - 1s - loss: 0.1129 - val_loss: 0.1109
Epoch 39/50
60000/60000 [==============================] - 1s - loss: 0.1120 - val_loss: 0.1100
Epoch 40/50
60000/60000 [==============================] - 1s - loss: 0.1112 - val_loss: 0.1093
Epoch 41/50
60000/60000 [==============================] - 1s - loss: 0.1105 - val_loss: 0.1085
Epoch 42/50
60000/60000 [==============================] - 1s - loss: 0.1098 - val_loss: 0.1079
Epoch 43/50
60000/60000 [==============================] - 1s - loss: 0.1092 - val_loss: 0.1072
Epoch 44/50
60000/60000 [==============================] - 1s - loss: 0.1086 - val_loss: 0.1066
Epoch 45/50
60000/60000 [==============================] - 1s - loss: 0.1080 - val_loss: 0.1061
Epoch 46/50
60000/60000 [==============================] - 1s - loss: 0.1074 - val_loss: 0.1056
Epoch 47/50
60000/60000 [==============================] - 1s - loss: 0.1069 - val_loss: 0.1051
Epoch 48/50
60000/60000 [==============================] - 1s - loss: 0.1065 - val_loss: 0.1046
Epoch 49/50
60000/60000 [==============================] - 1s - loss: 0.1060 - val_loss: 0.1042
Epoch 50/50
60000/60000 [==============================] - 1s - loss: 0.1056 - val_loss: 0.1037





<keras.callbacks.History at 0x7fd1ce5140f0>

Testing the Autoencoder

from matplotlib import pyplot as plt

%matplotlib inline
encoded_imgs = encoder.predict(x_test)
decoded_imgs = decoder.predict(encoded_imgs)

n = 10 
plt.figure(figsize=(20, 4))
for i in range(n):
    # original
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # reconstruction
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()


Sample generation with Autoencoder

# random codes drawn uniformly from [0, 1), one row per generated image
encoded_imgs = np.random.rand(10, encoding_dim)
decoded_imgs = decoder.predict(encoded_imgs)

n = 10 
plt.figure(figsize=(20, 4))
for i in range(n):
    # generation
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()



Convolutional AutoEncoder

Since our inputs are images, it makes sense to use convolutional neural networks (convnets) as encoders and decoders.

In practical settings, autoencoders applied to images are almost always convolutional autoencoders, since they simply perform much better.

The encoder will consist of a stack of Conv2D and MaxPooling2D layers (max pooling being used for spatial down-sampling), while the decoder will consist of a stack of Conv2D and UpSampling2D layers.
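The spatial sizes produced by such a stack can be traced by hand. The sketch below is plain Python, with helper names (`pool_same`, `conv_valid`, `upsample`) invented for this example. It assumes three 2x2 poolings with `padding='same'` in the encoder and a decoder in which exactly one 3x3 convolution omits padding, as in the model defined next. Convolutions with `padding='same'` keep the size unchanged, so they are left out of the trace.

```python
import math


def pool_same(size, factor=2):
    # Max pooling with padding='same' rounds the output size up.
    return math.ceil(size / factor)


def conv_valid(size, kernel=3):
    # A convolution without padding shrinks each side by kernel - 1.
    return size - kernel + 1


def upsample(size, factor=2):
    return size * factor


size = 28
encoder_sizes = [size]
for _ in range(3):
    size = pool_same(size)
    encoder_sizes.append(size)

size = upsample(size)      # 4 -> 8
size = upsample(size)      # 8 -> 16
size = conv_valid(size)    # 16 -> 14 (the 3x3 convolution without padding)
size = upsample(size)      # 14 -> 28
decoder_output = size

print(encoder_sizes, decoder_output)
```

The unpadded convolution matters here: without it the decoder would go 4, 8, 16, 32 and the output would no longer match the 28x28 input.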

from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.models import Model
from keras import backend as K

input_img = Input(shape=(28, 28, 1))  # adapt this if using `channels_first` image data format

x = Conv2D(16, (3, 3), activation='relu', padding='same')(input_img)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
encoded = MaxPooling2D((2, 2), padding='same')(x)

# at this point the representation is (4, 4, 8) i.e. 128-dimensional

x = Conv2D(8, (3, 3), activation='relu', padding='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
# no padding='same' here: this shrinks 16x16 to 14x14 so the final output is 28x28
x = Conv2D(16, (3, 3), activation='relu')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

conv_autoencoder = Model(input_img, decoded)
conv_autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')
from keras import backend as K

if K.image_data_format() == 'channels_last':
    shape_ord = (28, 28, 1)
else:
    shape_ord = (1, 28, 28)

(x_train, _), (x_test, _) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

x_train = np.reshape(x_train, ((x_train.shape[0],) + shape_ord))  
x_test = np.reshape(x_test, ((x_test.shape[0],) + shape_ord))
x_train.shape
(60000, 28, 28, 1)
from keras.callbacks import TensorBoard
batch_size = 128
# number of full batches per epoch; not used by fit() below, which takes batch_size directly
steps_per_epoch = x_train.shape[0] // batch_size
conv_autoencoder.fit(x_train, x_train, epochs=50, batch_size=batch_size,
                     shuffle=True, validation_data=(x_test, x_test),
                     callbacks=[TensorBoard(log_dir='./tf_autoencoder_logs')])
Train on 60000 samples, validate on 10000 samples
Epoch 1/50
60000/60000 [==============================] - 8s - loss: 0.2327 - val_loss: 0.1740
Epoch 2/50
60000/60000 [==============================] - 7s - loss: 0.1645 - val_loss: 0.1551
Epoch 3/50
60000/60000 [==============================] - 7s - loss: 0.1501 - val_loss: 0.1442
Epoch 4/50
60000/60000 [==============================] - 7s - loss: 0.1404 - val_loss: 0.1375
Epoch 5/50
60000/60000 [==============================] - 7s - loss: 0.1342 - val_loss: 0.1316
Epoch 6/50
60000/60000 [==============================] - 7s - loss: 0.1300 - val_loss: 0.1298
Epoch 7/50
60000/60000 [==============================] - 7s - loss: 0.1272 - val_loss: 0.1301
Epoch 8/50
60000/60000 [==============================] - 7s - loss: 0.1243 - val_loss: 0.1221
Epoch 9/50
60000/60000 [==============================] - 7s - loss: 0.1222 - val_loss: 0.1196
Epoch 10/50
60000/60000 [==============================] - 7s - loss: 0.1207 - val_loss: 0.1184
Epoch 11/50
60000/60000 [==============================] - 7s - loss: 0.1188 - val_loss: 0.1162
Epoch 12/50
60000/60000 [==============================] - 7s - loss: 0.1175 - val_loss: 0.1160
Epoch 13/50
60000/60000 [==============================] - 7s - loss: 0.1167 - val_loss: 0.1164
Epoch 14/50
60000/60000 [==============================] - 7s - loss: 0.1154 - val_loss: 0.1160
Epoch 15/50
60000/60000 [==============================] - 7s - loss: 0.1145 - val_loss: 0.1159
Epoch 16/50
60000/60000 [==============================] - 7s - loss: 0.1132 - val_loss: 0.1110
Epoch 17/50
60000/60000 [==============================] - 7s - loss: 0.1127 - val_loss: 0.1108
Epoch 18/50
60000/60000 [==============================] - 7s - loss: 0.1118 - val_loss: 0.1099
Epoch 19/50
60000/60000 [==============================] - 7s - loss: 0.1113 - val_loss: 0.1106
Epoch 20/50
60000/60000 [==============================] - 7s - loss: 0.1108 - val_loss: 0.1120
Epoch 21/50
60000/60000 [==============================] - 7s - loss: 0.1104 - val_loss: 0.1064
Epoch 22/50
60000/60000 [==============================] - 7s - loss: 0.1094 - val_loss: 0.1075
Epoch 23/50
60000/60000 [==============================] - 7s - loss: 0.1088 - val_loss: 0.1088
Epoch 24/50
60000/60000 [==============================] - 7s - loss: 0.1085 - val_loss: 0.1071
Epoch 25/50
60000/60000 [==============================] - 7s - loss: 0.1081 - val_loss: 0.1060
Epoch 26/50
60000/60000 [==============================] - 7s - loss: 0.1075 - val_loss: 0.1062
Epoch 27/50
60000/60000 [==============================] - 7s - loss: 0.1074 - val_loss: 0.1062
Epoch 28/50
60000/60000 [==============================] - 7s - loss: 0.1065 - val_loss: 0.1045
Epoch 29/50
60000/60000 [==============================] - 7s - loss: 0.1062 - val_loss: 0.1043
Epoch 30/50
60000/60000 [==============================] - 7s - loss: 0.1057 - val_loss: 0.1038
Epoch 31/50
60000/60000 [==============================] - 7s - loss: 0.1053 - val_loss: 0.1040
Epoch 32/50
60000/60000 [==============================] - 7s - loss: 0.1048 - val_loss: 0.1041
Epoch 33/50
60000/60000 [==============================] - 7s - loss: 0.1045 - val_loss: 0.1057
Epoch 34/50
60000/60000 [==============================] - 7s - loss: 0.1041 - val_loss: 0.1026
Epoch 35/50
60000/60000 [==============================] - 7s - loss: 0.1041 - val_loss: 0.1042
Epoch 36/50
60000/60000 [==============================] - 7s - loss: 0.1035 - val_loss: 0.1053
Epoch 37/50
60000/60000 [==============================] - 7s - loss: 0.1032 - val_loss: 0.1006
Epoch 38/50
60000/60000 [==============================] - 7s - loss: 0.1030 - val_loss: 0.1011
Epoch 39/50
60000/60000 [==============================] - 7s - loss: 0.1028 - val_loss: 0.1013
Epoch 40/50
60000/60000 [==============================] - 7s - loss: 0.1027 - val_loss: 0.1018
Epoch 41/50
60000/60000 [==============================] - 7s - loss: 0.1025 - val_loss: 0.1019
Epoch 42/50
60000/60000 [==============================] - 7s - loss: 0.1024 - val_loss: 0.1025
Epoch 43/50
60000/60000 [==============================] - 7s - loss: 0.1020 - val_loss: 0.1015
Epoch 44/50
60000/60000 [==============================] - 7s - loss: 0.1020 - val_loss: 0.1018
Epoch 45/50
60000/60000 [==============================] - 7s - loss: 0.1015 - val_loss: 0.1011
Epoch 46/50
60000/60000 [==============================] - 7s - loss: 0.1013 - val_loss: 0.0999
Epoch 47/50
60000/60000 [==============================] - 7s - loss: 0.1010 - val_loss: 0.0995
Epoch 48/50
60000/60000 [==============================] - 7s - loss: 0.1008 - val_loss: 0.0996
Epoch 49/50
60000/60000 [==============================] - 7s - loss: 0.1008 - val_loss: 0.0990
Epoch 50/50
60000/60000 [==============================] - 7s - loss: 0.1006 - val_loss: 0.0995





<keras.callbacks.History at 0x7fd1bebacfd0>
decoded_imgs = conv_autoencoder.predict(x_test)

n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
    # display original
    ax = plt.subplot(2, n, i+1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # display reconstruction
    ax = plt.subplot(2, n, i + n + 1)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()


We could also have a look at the 128-dimensional encoded middle representation:

conv_encoder = Model(input_img, encoded)
encoded_imgs = conv_encoder.predict(x_test)

n = 10
plt.figure(figsize=(20, 8))
for i in range(n):
    ax = plt.subplot(1, n, i+1)
    plt.imshow(encoded_imgs[i].reshape(4, 4 * 8).T)
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()


Pretraining encoders

One powerful use of autoencoders is to reuse the trained encoder to generate meaningful representations of the input feature vectors.

# Use the encoder to pretrain a classifier

Application to Image Denoising

Let's put our convolutional autoencoder to work on an image denoising problem. It's simple: we will train the autoencoder to map noisy digit images to clean digit images.

Here's how we will generate synthetic noisy digits: we add Gaussian noise to each image and clip the result to the range 0 to 1.

from keras.datasets import mnist
import numpy as np

(x_train, _), (x_test, _) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))  # adapt this if using `channels_first` image data format
x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))  # adapt this if using `channels_first` image data format

noise_factor = 0.5
x_train_noisy = x_train + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_train.shape) 
x_test_noisy = x_test + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_test.shape) 

x_train_noisy = np.clip(x_train_noisy, 0., 1.)
x_test_noisy = np.clip(x_test_noisy, 0., 1.)

Here's what the noisy digits look like:

n = 10
plt.figure(figsize=(20, 2))
for i in range(n):
    ax = plt.subplot(1, n, i+1)
    plt.imshow(x_test_noisy[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()


Question

If you squint you can still recognize them, but barely.

Can our autoencoder learn to recover the original digits? Let's find out.

To improve the quality of the reconstructed images, we'll use a slightly different model from the previous convolutional autoencoder, with more filters per layer:

from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.models import Model

from keras.callbacks import TensorBoard
input_img = Input(shape=(28, 28, 1))  # adapt this if using `channels_first` image data format

x = Conv2D(32, (3, 3), activation='relu', padding='same')(input_img)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
encoded = MaxPooling2D((2, 2), padding='same')(x)

# at this point the representation is (7, 7, 32)

x = Conv2D(32, (3, 3), activation='relu', padding='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')

Let's train the autoencoder for 100 epochs:

autoencoder.fit(x_train_noisy, x_train,
                epochs=100,
                batch_size=128,
                shuffle=True,
                validation_data=(x_test_noisy, x_test),
                callbacks=[TensorBoard(log_dir='/tmp/autoencoder_denoise', 
                                       histogram_freq=0, write_graph=False)])
Train on 60000 samples, validate on 10000 samples
Epoch 1/100
60000/60000 [==============================] - 9s - loss: 0.1901 - val_loss: 0.1255
Epoch 2/100
60000/60000 [==============================] - 8s - loss: 0.1214 - val_loss: 0.1142
Epoch 3/100
60000/60000 [==============================] - 8s - loss: 0.1135 - val_loss: 0.1085
Epoch 4/100
60000/60000 [==============================] - 8s - loss: 0.1094 - val_loss: 0.1074
Epoch 5/100
60000/60000 [==============================] - 8s - loss: 0.1071 - val_loss: 0.1052
Epoch 6/100
60000/60000 [==============================] - 8s - loss: 0.1053 - val_loss: 0.1046
Epoch 7/100
60000/60000 [==============================] - 8s - loss: 0.1040 - val_loss: 0.1020
Epoch 8/100
60000/60000 [==============================] - 8s - loss: 0.1031 - val_loss: 0.1028
Epoch 9/100
60000/60000 [==============================] - 8s - loss: 0.1023 - val_loss: 0.1009
Epoch 10/100
60000/60000 [==============================] - 8s - loss: 0.1017 - val_loss: 0.1005
Epoch 11/100
60000/60000 [==============================] - 8s - loss: 0.1009 - val_loss: 0.1003
Epoch 12/100
60000/60000 [==============================] - 8s - loss: 0.1007 - val_loss: 0.1010
Epoch 13/100
60000/60000 [==============================] - 8s - loss: 0.1002 - val_loss: 0.0989
Epoch 14/100
60000/60000 [==============================] - 8s - loss: 0.1000 - val_loss: 0.0986
Epoch 15/100
60000/60000 [==============================] - 8s - loss: 0.0998 - val_loss: 0.0983
Epoch 16/100
60000/60000 [==============================] - 8s - loss: 0.0993 - val_loss: 0.0983
Epoch 17/100
60000/60000 [==============================] - 8s - loss: 0.0991 - val_loss: 0.0979
Epoch 18/100
60000/60000 [==============================] - 8s - loss: 0.0988 - val_loss: 0.0988
Epoch 19/100
60000/60000 [==============================] - 8s - loss: 0.0986 - val_loss: 0.0976
Epoch 20/100
60000/60000 [==============================] - 8s - loss: 0.0984 - val_loss: 0.0987
Epoch 21/100
60000/60000 [==============================] - 8s - loss: 0.0983 - val_loss: 0.0973
Epoch 22/100
60000/60000 [==============================] - 8s - loss: 0.0981 - val_loss: 0.0971
Epoch 23/100
60000/60000 [==============================] - 8s - loss: 0.0979 - val_loss: 0.0978
Epoch 24/100
60000/60000 [==============================] - 8s - loss: 0.0977 - val_loss: 0.0968
Epoch 25/100
60000/60000 [==============================] - 8s - loss: 0.0975 - val_loss: 0.0976
Epoch 26/100
60000/60000 [==============================] - 8s - loss: 0.0974 - val_loss: 0.0963
Epoch 27/100
60000/60000 [==============================] - 8s - loss: 0.0973 - val_loss: 0.0963
Epoch 28/100
60000/60000 [==============================] - 8s - loss: 0.0972 - val_loss: 0.0964
Epoch 29/100
60000/60000 [==============================] - 8s - loss: 0.0970 - val_loss: 0.0961
Epoch 30/100
60000/60000 [==============================] - 8s - loss: 0.0970 - val_loss: 0.0968
Epoch 31/100
60000/60000 [==============================] - 8s - loss: 0.0969 - val_loss: 0.0959
Epoch 32/100
60000/60000 [==============================] - 8s - loss: 0.0968 - val_loss: 0.0959
Epoch 33/100
60000/60000 [==============================] - 8s - loss: 0.0967 - val_loss: 0.0957
Epoch 34/100
60000/60000 [==============================] - 8s - loss: 0.0966 - val_loss: 0.0958
Epoch 35/100
60000/60000 [==============================] - 8s - loss: 0.0965 - val_loss: 0.0956
Epoch 36/100
60000/60000 [==============================] - 8s - loss: 0.0965 - val_loss: 0.0959
Epoch 37/100
60000/60000 [==============================] - 8s - loss: 0.0964 - val_loss: 0.0963
Epoch 38/100
60000/60000 [==============================] - 8s - loss: 0.0963 - val_loss: 0.0960
Epoch 39/100
60000/60000 [==============================] - 8s - loss: 0.0963 - val_loss: 0.0963
Epoch 40/100
60000/60000 [==============================] - 8s - loss: 0.0962 - val_loss: 0.0954
Epoch 41/100
60000/60000 [==============================] - 8s - loss: 0.0961 - val_loss: 0.0955
Epoch 42/100
60000/60000 [==============================] - 8s - loss: 0.0960 - val_loss: 0.0953
Epoch 43/100
60000/60000 [==============================] - 8s - loss: 0.0960 - val_loss: 0.0952
Epoch 44/100
60000/60000 [==============================] - 8s - loss: 0.0960 - val_loss: 0.0951
Epoch 45/100
60000/60000 [==============================] - 8s - loss: 0.0959 - val_loss: 0.0951
Epoch 46/100
60000/60000 [==============================] - 8s - loss: 0.0958 - val_loss: 0.0953
Epoch 47/100
60000/60000 [==============================] - 8s - loss: 0.0957 - val_loss: 0.0952
Epoch 48/100
60000/60000 [==============================] - 8s - loss: 0.0957 - val_loss: 0.0954
Epoch 49/100
60000/60000 [==============================] - 8s - loss: 0.0957 - val_loss: 0.0954
Epoch 50/100
60000/60000 [==============================] - 8s - loss: 0.0957 - val_loss: 0.0954
Epoch 51/100
60000/60000 [==============================] - 8s - loss: 0.0955 - val_loss: 0.0948
Epoch 52/100
60000/60000 [==============================] - 8s - loss: 0.0956 - val_loss: 0.0951
Epoch 53/100
60000/60000 [==============================] - 8s - loss: 0.0955 - val_loss: 0.0951
Epoch 54/100
60000/60000 [==============================] - 8s - loss: 0.0955 - val_loss: 0.0951
Epoch 55/100
60000/60000 [==============================] - 8s - loss: 0.0955 - val_loss: 0.0948
Epoch 56/100
60000/60000 [==============================] - 8s - loss: 0.0954 - val_loss: 0.0955
Epoch 57/100
60000/60000 [==============================] - 8s - loss: 0.0954 - val_loss: 0.0950
Epoch 58/100
60000/60000 [==============================] - 8s - loss: 0.0953 - val_loss: 0.0955
Epoch 59/100
60000/60000 [==============================] - 8s - loss: 0.0952 - val_loss: 0.0947
Epoch 60/100
60000/60000 [==============================] - 8s - loss: 0.0953 - val_loss: 0.0947
Epoch 61/100
60000/60000 [==============================] - 8s - loss: 0.0952 - val_loss: 0.0947
Epoch 62/100
60000/60000 [==============================] - 8s - loss: 0.0952 - val_loss: 0.0945
Epoch 63/100
60000/60000 [==============================] - 8s - loss: 0.0952 - val_loss: 0.0945
Epoch 64/100
60000/60000 [==============================] - 8s - loss: 0.0952 - val_loss: 0.0945
Epoch 65/100
60000/60000 [==============================] - 8s - loss: 0.0950 - val_loss: 0.0954
Epoch 66/100
60000/60000 [==============================] - 8s - loss: 0.0951 - val_loss: 0.0945
Epoch 67/100
60000/60000 [==============================] - 8s - loss: 0.0951 - val_loss: 0.0946
Epoch 68/100
60000/60000 [==============================] - 8s - loss: 0.0950 - val_loss: 0.0951
Epoch 69/100
60000/60000 [==============================] - 8s - loss: 0.0950 - val_loss: 0.0952
Epoch 70/100
60000/60000 [==============================] - 8s - loss: 0.0949 - val_loss: 0.0948
Epoch 71/100
60000/60000 [==============================] - 8s - loss: 0.0949 - val_loss: 0.0958
Epoch 72/100
60000/60000 [==============================] - 8s - loss: 0.0949 - val_loss: 0.0953
Epoch 73/100
60000/60000 [==============================] - 8s - loss: 0.0949 - val_loss: 0.0942
Epoch 74/100
60000/60000 [==============================] - 8s - loss: 0.0948 - val_loss: 0.0946
Epoch 75/100
60000/60000 [==============================] - 8s - loss: 0.0948 - val_loss: 0.0942
Epoch 76/100
60000/60000 [==============================] - 8s - loss: 0.0948 - val_loss: 0.0945
Epoch 77/100
60000/60000 [==============================] - 8s - loss: 0.0948 - val_loss: 0.0944
Epoch 78/100
60000/60000 [==============================] - 8s - loss: 0.0948 - val_loss: 0.0942
Epoch 79/100
60000/60000 [==============================] - 8s - loss: 0.0947 - val_loss: 0.0944
Epoch 80/100
60000/60000 [==============================] - 8s - loss: 0.0947 - val_loss: 0.0942
Epoch 81/100
60000/60000 [==============================] - 8s - loss: 0.0946 - val_loss: 0.0943
Epoch 82/100
60000/60000 [==============================] - 8s - loss: 0.0946 - val_loss: 0.0942
Epoch 83/100
60000/60000 [==============================] - 8s - loss: 0.0946 - val_loss: 0.0941
Epoch 84/100
60000/60000 [==============================] - 8s - loss: 0.0947 - val_loss: 0.0940
Epoch 85/100
60000/60000 [==============================] - 8s - loss: 0.0946 - val_loss: 0.0941
Epoch 86/100
60000/60000 [==============================] - 8s - loss: 0.0945 - val_loss: 0.0941
Epoch 87/100
60000/60000 [==============================] - 8s - loss: 0.0946 - val_loss: 0.0945
Epoch 88/100
60000/60000 [==============================] - 8s - loss: 0.0945 - val_loss: 0.0944
Epoch 89/100
60000/60000 [==============================] - 8s - loss: 0.0945 - val_loss: 0.0944
Epoch 90/100
60000/60000 [==============================] - 8s - loss: 0.0945 - val_loss: 0.0941
Epoch 91/100
60000/60000 [==============================] - 8s - loss: 0.0945 - val_loss: 0.0939
Epoch 92/100
60000/60000 [==============================] - 8s - loss: 0.0944 - val_loss: 0.0946
Epoch 93/100
60000/60000 [==============================] - 8s - loss: 0.0944 - val_loss: 0.0941
Epoch 94/100
60000/60000 [==============================] - 8s - loss: 0.0944 - val_loss: 0.0939
Epoch 95/100
60000/60000 [==============================] - 8s - loss: 0.0944 - val_loss: 0.0941
Epoch 96/100
60000/60000 [==============================] - 8s - loss: 0.0944 - val_loss: 0.0939
Epoch 97/100
60000/60000 [==============================] - 8s - loss: 0.0944 - val_loss: 0.0939
Epoch 98/100
60000/60000 [==============================] - 8s - loss: 0.0943 - val_loss: 0.0939
Epoch 99/100
60000/60000 [==============================] - 8s - loss: 0.0944 - val_loss: 0.0941
Epoch 100/100
60000/60000 [==============================] - 8s - loss: 0.0943 - val_loss: 0.0938





<keras.callbacks.History at 0x7fb45ad95f28>

Now let's take a look at the results:

decoded_imgs = autoencoder.predict(x_test_noisy)

n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
    # display original
    ax = plt.subplot(2, n, i+1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # display reconstruction
    ax = plt.subplot(2, n, i + n + 1)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()


Variational AutoEncoder

(Reference https://blog.keras.io/building-autoencoders-in-keras.html)

Variational autoencoders are a slightly more modern and interesting take on autoencoding.

What is a variational autoencoder?

It's a type of autoencoder with added constraints on the encoded representations being learned.

More precisely, it is an autoencoder that learns a latent variable model for its input data.

So instead of letting your neural network learn an arbitrary function, you are learning the parameters of a probability distribution modeling your data.

If you sample points from this distribution, you can generate new input data samples: a VAE is a "generative model".

How does a variational autoencoder work?

First, an encoder network turns the input samples $x$ into two parameters in a latent space, which we will denote $z_{\mu}$ and $z_{\log\sigma}$.

Then, we randomly sample similar points $z$ from the latent normal distribution that is assumed to generate the data, via $z = z_{\mu} + \exp(z_{\log\sigma}) \cdot \epsilon$, where $\epsilon$ is a random normal tensor.
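The sampling step can be sketched in plain NumPy, independently of Keras. The function name `sample_latent`, the seed, and the example values for the mean and log standard deviation are assumptions made for illustration; the formula itself is the one given above.

```python
import numpy as np

rng = np.random.default_rng(42)


def sample_latent(z_mu, z_log_sigma, rng):
    # epsilon is standard normal noise with the same shape as the mean.
    epsilon = rng.standard_normal(size=z_mu.shape)
    # Shift and scale the noise: z = mu + exp(log_sigma) * epsilon.
    return z_mu + np.exp(z_log_sigma) * epsilon


# Assumed example values: every point has mean 3.0 and standard deviation 0.5.
z_mu = np.full((10000, 2), 3.0)
z_log_sigma = np.full((10000, 2), np.log(0.5))

z = sample_latent(z_mu, z_log_sigma, rng)
print(z.shape, z.mean(), z.std())
```

Because the randomness enters only through `epsilon`, the sample is a deterministic function of the mean and log standard deviation, which is what lets gradients reach the encoder during training.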

Finally, a decoder network maps these latent space points back to the original input data.

The parameters of the model are trained via two loss functions:

  • a reconstruction loss forcing the decoded samples to match the initial inputs (just like in our previous autoencoders);
  • and the KL divergence between the learned latent distribution and the prior distribution, acting as a regularization term.

You could actually get rid of this latter term entirely, although it does help in learning well-formed latent spaces and reducing overfitting to the training data.
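The two terms can be written out in NumPy as a rough sketch. It assumes that the second encoder output is the log of the standard deviation, as in the sampling formula above, and that the prior is a standard normal; under those assumptions the KL divergence has the closed form used below. The function names and the example arrays are invented for this illustration and are not the loss function defined later in the notebook.

```python
import numpy as np


def reconstruction_loss(x, x_hat, eps=1e-7):
    # Binary cross-entropy summed over the input dimensions.
    x_hat = np.clip(x_hat, eps, 1.0 - eps)
    return -np.sum(x * np.log(x_hat) + (1.0 - x) * np.log(1.0 - x_hat), axis=-1)


def kl_to_standard_normal(z_mu, z_log_sigma):
    # KL( N(mu, sigma^2) || N(0, 1) ), summed over the latent dimensions,
    # with sigma = exp(z_log_sigma).
    return -0.5 * np.sum(
        1.0 + 2.0 * z_log_sigma - z_mu ** 2 - np.exp(2.0 * z_log_sigma),
        axis=-1,
    )


# Assumed toy values: one 4-pixel "image" and a 2-dimensional latent space.
x = np.array([[0.0, 1.0, 1.0, 0.0]])
x_hat = np.array([[0.1, 0.9, 0.8, 0.2]])

recon = reconstruction_loss(x, x_hat)
kl_at_prior = kl_to_standard_normal(np.zeros((1, 2)), np.zeros((1, 2)))
kl_shifted = kl_to_standard_normal(np.array([[1.0, -1.0]]), np.zeros((1, 2)))

total = recon + kl_shifted
print(recon, kl_at_prior, kl_shifted, total)
```

The KL term is zero when the encoder outputs exactly the prior and grows as the mean moves away from zero or the standard deviation away from one, which is how it regularizes the latent space.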

Encoder Network

batch_size = 100
original_dim = 784
latent_dim = 2
intermediate_dim = 256
epochs = 50
epsilon_std = 1.0
x = Input(batch_shape=(batch_size, original_dim))
h = Dense(intermediate_dim, activation='relu')(x)
z_mean = Dense(latent_dim)(h)
z_log_sigma = Dense(latent_dim)(h)

We can use these parameters to sample new similar points from the latent space:

from keras.layers import Lambda
from keras import backend as K
def sampling(args):
    z_mean, z_log_sigma = args
    epsilon = K.random_normal(shape=(batch_size, latent_dim),
                              mean=0., stddev=epsilon_std)
    return z_mean + K.exp(z_log_sigma) * epsilon

# note that "output_shape" isn't necessary with the TensorFlow backend
# so you could write `Lambda(sampling)([z_mean, z_log_sigma])`
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_sigma])

Decoder Network

Finally, we can map these sampled latent points back to reconstructed inputs:

decoder_h = Dense(intermediate_dim, activation='relu')
decoder_mean = Dense(original_dim, activation='sigmoid')
h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)

What we've done so far allows us to instantiate 3 models:

  • an end-to-end autoencoder mapping inputs to reconstructions
  • an encoder mapping inputs to the latent space
  • a generator that takes points in the latent space and outputs the corresponding reconstructed samples.

# end-to-end autoencoder
vae = Model(x, x_decoded_mean)

# encoder, from inputs to latent space
encoder = Model(x, z_mean)

# generator, from latent space to reconstructed inputs
decoder_input = Input(shape=(latent_dim,))
_h_decoded = decoder_h(decoder_input)
_x_decoded_mean = decoder_mean(_h_decoded)
generator = Model(decoder_input, _x_decoded_mean)

Let's Visualise the VAE Model

from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot

SVG(model_to_dot(vae).create(prog='dot', format='svg'))

[Figure: layer graph of the vae model]

Exercise: Let's Do the Same for the encoder and generator Models

VAE on MNIST

We train the end-to-end model with a custom loss function: the sum of a reconstruction term and the KL divergence regularization term.

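from keras.losses import binary_crossentropy

def vae_loss(x, x_decoded_mean):
    # reconstruction term: per-sample binary cross-entropy, averaged over pixels
    xent_loss = binary_crossentropy(x, x_decoded_mean)
    # KL regularization term, averaged over the latent dimensions
    # (this expression treats z_log_sigma as a log-variance)
    kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
    return xent_loss + kl_loss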

vae.compile(optimizer='rmsprop', loss=vae_loss)
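
To make the two terms concrete, here is a NumPy sketch of a per-sample VAE loss. It is not the Keras code above: as an assumption of this sketch, both terms are summed over their dimensions rather than averaged, and `z_log_sigma` is treated as the log of the standard deviation, matching the sampling function. The helper names `bernoulli_xent`, `kl_to_standard_normal`, and `vae_loss_np` are hypothetical.

```python
import numpy as np

def bernoulli_xent(x, x_decoded, eps=1e-7):
    """Binary cross-entropy summed over pixels, one value per sample."""
    p = np.clip(x_decoded, eps, 1.0 - eps)  # avoid log(0)
    return -np.sum(x * np.log(p) + (1.0 - x) * np.log(1.0 - p), axis=-1)

def kl_to_standard_normal(z_mean, z_log_sigma):
    """KL(N(z_mean, sigma^2) || N(0, I)) with sigma = exp(z_log_sigma)."""
    return -0.5 * np.sum(
        1.0 + 2.0 * z_log_sigma - np.square(z_mean) - np.exp(2.0 * z_log_sigma),
        axis=-1)

def vae_loss_np(x, x_decoded, z_mean, z_log_sigma):
    """Sum of the reconstruction term and the KL term, per sample."""
    return bernoulli_xent(x, x_decoded) + kl_to_standard_normal(z_mean, z_log_sigma)

# toy batch: 2 samples, 4 "pixels", 2 latent dimensions
x = np.array([[0., 1., 1., 0.], [1., 0., 0., 1.]])
x_decoded = np.array([[0.1, 0.9, 0.8, 0.2], [0.5, 0.5, 0.5, 0.5]])
z_mean = np.array([[0., 0.], [1., -1.]])
z_log_sigma = np.zeros((2, 2))

# the first KL entry is zero because that posterior equals the prior
print(kl_to_standard_normal(z_mean, z_log_sigma))
print(vae_loss_np(x, x_decoded, z_mean, z_log_sigma))
```

Summing over pixels keeps the reconstruction term on a scale comparable to the KL term. When the reconstruction term is instead averaged over all 784 pixels, the KL penalty carries relatively more weight in the total.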

Training on MNIST Digits

from keras.datasets import mnist
import numpy as np

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

vae.fit(x_train, x_train,
        shuffle=True,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, x_test))
Train on 60000 samples, validate on 10000 samples
Epoch 1/50
60000/60000 [==============================] - 3s - loss: 0.2932 - val_loss: 0.2629
Epoch 2/50
60000/60000 [==============================] - 3s - loss: 0.2631 - val_loss: 0.2628
Epoch 3/50
60000/60000 [==============================] - 3s - loss: 0.2630 - val_loss: 0.2626
Epoch 4/50
60000/60000 [==============================] - 3s - loss: 0.2630 - val_loss: 0.2629
Epoch 5/50
60000/60000 [==============================] - 3s - loss: 0.2630 - val_loss: 0.2627
Epoch 6/50
60000/60000 [==============================] - 3s - loss: 0.2630 - val_loss: 0.2627
Epoch 7/50
60000/60000 [==============================] - 3s - loss: 0.2630 - val_loss: 0.2626
Epoch 8/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2627
Epoch 9/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 10/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 11/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 12/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2625
Epoch 13/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 14/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2627
Epoch 15/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2627
Epoch 16/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2627
Epoch 17/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 18/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 19/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 20/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 21/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2625
Epoch 22/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 23/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 24/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 25/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 26/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 27/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2627
Epoch 28/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2627
Epoch 29/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2627
Epoch 30/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 31/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 32/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2625
Epoch 33/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 34/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 35/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 36/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2625
Epoch 37/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2625
Epoch 38/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2627
Epoch 39/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 40/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 41/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 42/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2627
Epoch 43/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 44/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 45/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2627
Epoch 46/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 47/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626
Epoch 48/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2627
Epoch 49/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2625
Epoch 50/50
60000/60000 [==============================] - 3s - loss: 0.2629 - val_loss: 0.2626





<keras.callbacks.History at 0x7fb62fc26d30>

Because our latent space is two-dimensional, there are a few cool visualizations that can be done at this point.

One is to look at the neighborhoods of different classes on the latent 2D plane:

x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(6, 6))
plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test)
plt.colorbar()
plt.show()

[Figure: scatter plot of the encoded test digits on the 2D latent plane, colored by digit class]

Each of these colored clusters is a type of digit. Close clusters are digits that are structurally similar (i.e. digits that share information in the latent space).

Because the VAE is a generative model, we can also use it to generate new digits! Here we will scan the latent plane, sampling latent points at regular intervals, and generating the corresponding digit for each of these points. This gives us a visualization of the latent manifold that "generates" the MNIST digits.

# display a 2D manifold of the digits
n = 15  # figure with 15x15 digits
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
# we will sample n points within [-15, 15] standard deviations
grid_x = np.linspace(-15, 15, n)
grid_y = np.linspace(-15, 15, n)

# rows of the figure follow the y coordinate, columns follow the x coordinate
for i, yi in enumerate(grid_y):
    for j, xi in enumerate(grid_x):
        z_sample = np.array([[xi, yi]]) * epsilon_std
        x_decoded = generator.predict(z_sample)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure)
plt.show()

[Figure: 15x15 grid of digits generated by scanning the 2D latent plane]
