"""Tutorial on how to create an autoencoder w/ Tensorflow.
Parag K. Mital, Jan 2016
"""
# Enable inline figures when running under IPython/Jupyter. The bare
# `%matplotlib inline` magic is a SyntaxError in a plain .py file, so it is
# invoked through the IPython API instead; outside IPython `get_ipython` is
# undefined and this is a no-op (matplotlib keeps its default backend).
try:
    get_ipython().run_line_magic('matplotlib', 'inline')  # noqa: F821
except NameError:
    pass
import tensorflow as tf
import numpy as np
import math
def autoencoder(dimensions=(784, 512, 256, 64)):
    """Build a deep autoencoder w/ tied weights.

    The encoder is a stack of fully connected tanh layers whose sizes are
    given by ``dimensions``. The decoder mirrors it, reusing the transposed
    encoder weight matrices (tied weights) but creating its own biases.

    Parameters
    ----------
    dimensions : sequence of int, optional
        The number of neurons for each layer of the autoencoder, from the
        input layer down to the inner-most latent layer. Must contain at
        least two entries (the input size and one hidden layer size).

    Returns
    -------
    dict
        'x' : Tensor
            Input placeholder to the network, shape [None, dimensions[0]].
        'z' : Tensor
            Inner-most latent representation.
        'y' : Tensor
            Output reconstruction of the input.
        'cost' : Tensor
            Overall cost to use for training: the sum of squared
            reconstruction errors over the whole batch.

    Raises
    ------
    ValueError
        If ``dimensions`` has fewer than two entries.
    """
    dimensions = list(dimensions)
    if len(dimensions) < 2:
        # With a single entry there are no layers: y would equal x, the cost
        # would be a constant 0 and any optimizer would fail later with an
        # unrelated "no variables" error. Fail early with a clear message.
        raise ValueError(
            'dimensions must contain an input size and at least one hidden '
            'layer size, got %r' % (dimensions,))

    x = tf.placeholder(tf.float32, [None, dimensions[0]], name='x')
    current_input = x

    # Encoder: one tanh layer per consecutive (n_input, n_output) pair.
    encoder = []
    for n_input, n_output in zip(dimensions[:-1], dimensions[1:]):
        # Uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)].
        bound = 1.0 / math.sqrt(n_input)
        W = tf.Variable(tf.random_uniform([n_input, n_output], -bound, bound))
        b = tf.Variable(tf.zeros([n_output]))
        encoder.append(W)
        current_input = tf.nn.tanh(tf.matmul(current_input, W) + b)

    z = current_input

    # Decoder: walk the encoder layers in reverse, reusing each weight matrix
    # transposed (tied weights); only the biases are new variables.
    for W_enc, n_output in zip(reversed(encoder), reversed(dimensions[:-1])):
        W = tf.transpose(W_enc)
        b = tf.Variable(tf.zeros([n_output]))
        current_input = tf.nn.tanh(tf.matmul(current_input, W) + b)

    y = current_input
    cost = tf.reduce_sum(tf.square(y - x))
    return {'x': x, 'z': z, 'y': y, 'cost': cost}
def test_mnist():
    """Test the autoencoder using MNIST.

    Reads MNIST into ``MNIST_data``. Trains a [784, 256, 64] autoencoder with
    Adam on mean-subtracted training images and prints the cost of the last
    mini-batch after each epoch. Then shows a figure with original test
    digits (top row) and their reconstructions (bottom row).
    """
    import tensorflow.examples.tutorials.mnist.input_data as input_data
    import matplotlib.pyplot as plt

    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    # Per-pixel mean of the training images; every input is centred on it.
    mean_img = np.mean(mnist.train.images, axis=0)

    ae = autoencoder(dimensions=[784, 256, 64])

    learning_rate = 0.001
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(ae['cost'])

    batch_size = 50
    n_epochs = 10
    n_examples = 15

    # The context manager closes the session (and frees its resources) once
    # training and inference are done.
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        for epoch_i in range(n_epochs):
            for _ in range(mnist.train.num_examples // batch_size):
                batch_xs, _labels = mnist.train.next_batch(batch_size)
                # Broadcasting subtracts the mean image from every row.
                train = batch_xs - mean_img
                sess.run(optimizer, feed_dict={ae['x']: train})
            # Cost on the final mini-batch of this epoch, not the full set.
            print(epoch_i, sess.run(ae['cost'], feed_dict={ae['x']: train}))

        test_xs, _ = mnist.test.next_batch(n_examples)
        test_xs_norm = test_xs - mean_img
        recon = sess.run(ae['y'], feed_dict={ae['x']: test_xs_norm})

    fig, axs = plt.subplots(2, n_examples, figsize=(10, 2))
    for example_i in range(n_examples):
        axs[0][example_i].imshow(np.reshape(test_xs[example_i, :], (28, 28)))
        # Add the mean back to undo the input centring before display.
        axs[1][example_i].imshow(
            np.reshape(recon[example_i, :] + mean_img, (28, 28)))
    # plt.show() works with both GUI and inline backends; fig.show() warns
    # under non-GUI backends and does not keep a script's window open.
    plt.show()
if __name__ == '__main__':
    # Executed as a script (not imported): run the MNIST training demo.
    test_mnist()
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
(0, 554.17175)
(1, 451.15411)
(2, 420.77158)
(3, 437.03281)
(4, 409.98288)
(5, 363.77893)
(6, 420.08453)
(7, 396.18784)
(8, 365.31839)
(9, 407.51379)
/home/heythisischo/anaconda2/lib/python2.7/site-packages/matplotlib/figure.py:397: UserWarning: matplotlib is currently using a non-GUI backend, so cannot show the figure
"matplotlib is currently using a non-GUI backend, "
