import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tensorflow as tf
import glob
%matplotlib inline
print ("Packages loaded")
Packages loaded
Load dataset
dirpath = "./data/iccv09Data/images/"
height = 240
width = 320
resize_ratio = 3
nr_img = 0
fileList = glob.glob(dirpath + '*.jpg')
for i, file in enumerate(fileList):
    img = Image.open(file)
    array = np.array(img)
    # Keep only RGB images with the expected resolution
    if array.shape[0] == height and array.shape[1] == width:
        nr_img = nr_img + 1
        rgb = array.reshape(1, height, width, 3)
        # Low-resolution input: bicubic downsample, then upsample back to full size
        imglow = img.resize((int(width/resize_ratio),
                             int(height/resize_ratio)), Image.BICUBIC)
        imglow = imglow.resize((width, height), Image.BICUBIC)
        rgblow = np.array(np.float32(imglow)/255.)
        rgblow = rgblow.reshape(1, height, width, 3)
        # Flatten each image to a single row vector. Note that the target 'rgb'
        # stays in [0, 255] while the input 'rgblow' is scaled to [0, 1].
        rgb = np.reshape(rgb, [1, -1])
        rgblow = np.reshape(rgblow, [1, -1])
        if nr_img == 1:
            data = rgb
            datalow = rgblow
        else:
            data = np.concatenate((data, rgb), axis=0)
            datalow = np.concatenate((datalow, rgblow), axis=0)
print ("nr_img is %d" % (nr_img))
print ("Shape of 'data' is %s" % (data.shape,))
print ("Shape of 'datalow' is %s" % (datalow.shape,))
nr_img is 531
Shape of 'data' is (531, 230400)
Shape of 'datalow' is (531, 230400)
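Growing data and datalow with np.concatenate copies the whole array on every iteration; the sketch below builds the same pairs (same filtering and resizing as above) but preallocates the buffers once.
# Sketch: same (low-res, high-res) pairs as above, with preallocated arrays
# instead of growing 'data'/'datalow' inside the loop.
valid = [f for f in fileList
         if np.array(Image.open(f)).shape == (height, width, 3)]
data = np.zeros((len(valid), height*width*3), dtype=np.uint8)
datalow = np.zeros((len(valid), height*width*3), dtype=np.float32)
for i, f in enumerate(valid):
    img = Image.open(f)
    data[i, :] = np.array(img).reshape(-1)
    imglow = img.resize((int(width/resize_ratio),
                         int(height/resize_ratio)), Image.BICUBIC)
    imglow = imglow.resize((width, height), Image.BICUBIC)
    datalow[i, :] = (np.asarray(imglow, dtype=np.float32)/255.).reshape(-1)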
Divide into two sets: (xtrain, ytrain) and (xtest, ytest)
randidx = np.random.permutation(nr_img)
nrtrain = int(nr_img*0.7)
nrtest = nr_img - nrtrain
xtrain = datalow[randidx[0:nrtrain], :]
ytrain = data[randidx[0:nrtrain], :]
xtest = datalow[randidx[nrtrain:nr_img], :]
ytest = data[randidx[nrtrain:nr_img], :]
print ("Shape of 'xtrain' is %s" % (xtrain.shape,))
print ("Shape of 'ytrain' is %s" % (ytrain.shape,))
print ("Shape of 'xtest' is %s" % (xtest.shape,))
print ("Shape of 'ytest' is %s" % (ytest.shape,))
Shape of 'xtrain' is (371, 230400)
Shape of 'ytrain' is (371, 230400)
Shape of 'xtest' is (160, 230400)
Shape of 'ytest' is (160, 230400)
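Note that the split above changes on every run; if a reproducible split is preferred, the permutation can be drawn from a seeded generator (the seed 0 is an arbitrary choice), for example:
# Sketch: reproducible train/test split with a fixed seed.
rng = np.random.RandomState(0)
randidx = rng.permutation(nr_img)
nrtrain = int(nr_img*0.7)
xtrain, ytrain = datalow[randidx[:nrtrain], :], data[randidx[:nrtrain], :]
xtest, ytest = datalow[randidx[nrtrain:], :], data[randidx[nrtrain:], :]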
Plot some images
randidx = np.random.randint(nrtrain)
currx = xtrain[randidx, :]
currx = np.reshape(currx, [height, width, 3])
plt.imshow(currx)
plt.title("Train input image")
plt.show()
curry = ytrain[randidx, :]
curry = np.reshape(curry, [height, width, 3])
plt.imshow(curry)
plt.title("Train output image")
plt.show()
randidx = np.random.randint(nrtest)
currx = xtest[randidx, :]
currx = np.reshape(currx, [height, width, 3])
plt.imshow(currx)
plt.title("Test input image")
plt.show()
curry = ytest[randidx, :]
curry = np.reshape(curry, [height, width, 3])
plt.imshow(curry)
plt.title("Test output image")
plt.show()
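The four plotting blocks above repeat the same reshape/imshow/title/show pattern; a small helper (the name show_img is hypothetical) expresses it once:
# Sketch: helper that plots one flattened image with a title.
def show_img(flat, title):
    plt.imshow(np.reshape(flat, [height, width, 3]))
    plt.title(title)
    plt.show()

randidx = np.random.randint(nrtrain)
show_img(xtrain[randidx, :], "Train input image")
show_img(ytrain[randidx, :], "Train output image")
randidx = np.random.randint(nrtest)
show_img(xtest[randidx, :], "Test input image")
show_img(ytest[randidx, :], "Test output image")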
Define network
n1 = 32
n2 = 64
n3 = 64
n4 = 64
n5 = 64
n6 = 3
ksize = 3
weights = {
    'ce1': tf.Variable(tf.random_normal([ksize, ksize, 3, n1], stddev=0.01)),
    'ce2': tf.Variable(tf.random_normal([ksize, ksize, n1, n2], stddev=0.01)),
    'ce3': tf.Variable(tf.random_normal([ksize, ksize, n2, n3], stddev=0.01)),
    'ce4': tf.Variable(tf.random_normal([ksize, ksize, n3, n4], stddev=0.01)),
    'ce5': tf.Variable(tf.random_normal([ksize, ksize, n4, n5], stddev=0.01)),
    'ce6': tf.Variable(tf.random_normal([ksize, ksize, n5, n6], stddev=0.01))
}
biases = {
    'be1': tf.Variable(tf.random_normal([n1], stddev=0.01)),
    'be2': tf.Variable(tf.random_normal([n2], stddev=0.01)),
    'be3': tf.Variable(tf.random_normal([n3], stddev=0.01)),
    'be4': tf.Variable(tf.random_normal([n4], stddev=0.01)),
    'be5': tf.Variable(tf.random_normal([n5], stddev=0.01)),
    'be6': tf.Variable(tf.random_normal([n6], stddev=0.01))
}
def srn(_X, _W, _b, _keepprob):
    _input_r = tf.reshape(_X, shape=[-1, height, width, 3])
    # Six 3x3 'SAME' convolutions with ReLU, dropout between layers
    _ce1 = tf.nn.relu(tf.add(tf.nn.conv2d(_input_r, _W['ce1'],
                strides=[1, 1, 1, 1], padding='SAME'), _b['be1']))
    _ce1 = tf.nn.dropout(_ce1, _keepprob)
    _ce2 = tf.nn.relu(tf.add(tf.nn.conv2d(_ce1, _W['ce2'],
                strides=[1, 1, 1, 1], padding='SAME'), _b['be2']))
    _ce2 = tf.nn.dropout(_ce2, _keepprob)
    _ce3 = tf.nn.relu(tf.add(tf.nn.conv2d(_ce2, _W['ce3'],
                strides=[1, 1, 1, 1], padding='SAME'), _b['be3']))
    _ce3 = tf.nn.dropout(_ce3, _keepprob)
    _ce4 = tf.nn.relu(tf.add(tf.nn.conv2d(_ce3, _W['ce4'],
                strides=[1, 1, 1, 1], padding='SAME'), _b['be4']))
    _ce4 = tf.nn.dropout(_ce4, _keepprob)
    _ce5 = tf.nn.relu(tf.add(tf.nn.conv2d(_ce4, _W['ce5'],
                strides=[1, 1, 1, 1], padding='SAME'), _b['be5']))
    _ce5 = tf.nn.dropout(_ce5, _keepprob)
    _ce6 = tf.nn.relu(tf.add(tf.nn.conv2d(_ce5, _W['ce6'],
                strides=[1, 1, 1, 1], padding='SAME'), _b['be6']))
    # Global residual (skip) connection: the last layer predicts a correction
    # that is added back onto the low-resolution input
    _out = _ce6 + _input_r
    return {'input_r': _input_r, 'ce1': _ce1, 'ce2': _ce2, 'ce3': _ce3,
            'ce4': _ce4, 'ce5': _ce5, 'ce6': _ce6,
            'layers': (_input_r, _ce1, _ce2, _ce3, _ce4, _ce5, _ce6),
            'out': _out}
print ("Network ready")
Network ready
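For reference, the six 3x3 convolutions above make a fairly small model; a quick sketch of the trainable parameter count (filter weights plus biases, using the layer widths n1..n6 defined above):
# Sketch: count trainable parameters of the six 3x3 conv layers.
channels = [3, n1, n2, n3, n4, n5, n6]   # 3 -> 32 -> 64 -> 64 -> 64 -> 64 -> 3
total = 0
for cin, cout in zip(channels[:-1], channels[1:]):
    total += ksize*ksize*cin*cout + cout  # filter weights + bias
print("Total parameters: %d" % total)    # 131,907 with the sizes above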
Define functions
dim = height*width*3
x = tf.placeholder(tf.float32, [None, dim])
y = tf.placeholder(tf.float32, [None, dim])
keepprob = tf.placeholder(tf.float32)
pred = srn(x, weights, biases, keepprob)['out']
# Mean squared error against the high-resolution target (reuse 'pred' instead
# of building the network a second time)
cost = tf.reduce_mean(tf.square(pred - tf.reshape(y, shape=[-1, height, width, 3])))
learning_rate = 0.001
optm = tf.train.AdamOptimizer(learning_rate, 0.9).minimize(cost)
init = tf.initialize_all_variables()
print ("Functions ready")
Functions ready
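The cost above is a plain MSE against targets that stay in the [0, 255] range; for reporting, it can be converted to PSNR. A minimal sketch, assuming a peak pixel value of 255:
# Sketch: PSNR from mean squared error, assuming pixel values in [0, 255].
def psnr(mse, peak=255.):
    return 10. * np.log10(peak*peak / mse)

# Example usage on a held-out batch, once the session is trained:
# mse = sess.run(cost, feed_dict={x: xtest[:8], y: ytest[:8], keepprob: 1.})
# print("PSNR: %.2f dB" % psnr(mse))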
Run
sess = tf.Session()
sess.run(init)
batch_size = 16
n_epochs = 100000
print("Start training..")
for epoch_i in range(n_epochs):
    for batch_i in range(nrtrain // batch_size):
        # Sample a random mini-batch and take one optimizer step
        randidx = np.random.randint(nrtrain, size=batch_size)
        batch_xs = xtrain[randidx, :]
        batch_ys = ytrain[randidx, :]
        sess.run(optm, feed_dict={x: batch_xs, y: batch_ys, keepprob: 0.7})
    if (epoch_i % 10) == 0:
        print ("[%02d/%02d] cost: %.4f" % (epoch_i, n_epochs,
               sess.run(cost, feed_dict={x: batch_xs, y: batch_ys, keepprob: 1.})))
    if (epoch_i % 100) == 0:
        n_examples = 2
        print ("Training dataset")
        randidx = np.random.randint(nrtrain, size=n_examples)
        train_xs = xtrain[randidx, :]
        train_ys = ytrain[randidx, :]
        recon = sess.run(pred, feed_dict={x: train_xs, keepprob: 1.})
        # Rows: input (top), reconstruction (middle), target (bottom)
        fig, axs = plt.subplots(3, n_examples, figsize=(15, 20))
        for example_i in range(n_examples):
            axs[0][example_i].imshow(np.reshape(
                train_xs[example_i, :], (height, width, 3)))
            axs[1][example_i].imshow(np.reshape(
                recon[example_i, :], (height, width, 3)))
            axs[2][example_i].imshow(np.reshape(
                train_ys[example_i, :], (height, width, 3)))
        plt.show()
        print ("Test dataset")
        randidx = np.random.randint(nrtest, size=n_examples)
        test_xs = xtest[randidx, :]
        test_ys = ytest[randidx, :]
        recon = sess.run(pred, feed_dict={x: test_xs, keepprob: 1.})
        fig, axs = plt.subplots(3, n_examples, figsize=(15, 20))
        for example_i in range(n_examples):
            axs[0][example_i].imshow(np.reshape(
                test_xs[example_i, :], (height, width, 3)))
            axs[1][example_i].imshow(np.reshape(
                recon[example_i, :], (height, width, 3)))
            axs[2][example_i].imshow(np.reshape(
                test_ys[example_i, :], (height, width, 3)))
        plt.show()
print("Training done. ")
Start training..
[00/100000] cost: 3772.8538
Training dataset
Test dataset
[10/100000] cost: 702.8953
[20/100000] cost: 506.8691
[30/100000] cost: 532.7783
[40/100000] cost: 459.9507
[50/100000] cost: 392.8480
[60/100000] cost: 415.6570
[70/100000] cost: 354.2756
[80/100000] cost: 357.1342
[90/100000] cost: 320.3797
[100/100000] cost: 440.0826
Training dataset
Test dataset
[110/100000] cost: 324.7620
[120/100000] cost: 326.6705
[130/100000] cost: 434.0658
[140/100000] cost: 446.6097
[150/100000] cost: 411.9198
[160/100000] cost: 501.3135
[170/100000] cost: 433.7086
[180/100000] cost: 400.1469
[190/100000] cost: 330.6881
[200/100000] cost: 522.5967
Training dataset
Test dataset
[210/100000] cost: 486.0645
[220/100000] cost: 410.2568
[230/100000] cost: 428.8757
[240/100000] cost: 552.0220
[250/100000] cost: 422.1021
[260/100000] cost: 377.5060
[270/100000] cost: 269.3975
[280/100000] cost: 219.0566
[290/100000] cost: 216.6913
[300/100000] cost: 248.9276
Training dataset
Test dataset
[310/100000] cost: 230.1701
[320/100000] cost: 254.1594
[330/100000] cost: 233.3717
[340/100000] cost: 177.7164
[350/100000] cost: 203.6235
[360/100000] cost: 256.6276
[370/100000] cost: 230.0090
[380/100000] cost: 219.8041
[390/100000] cost: 215.0632
[400/100000] cost: 212.2820
Training dataset
Test dataset
[410/100000] cost: 198.7617
[420/100000] cost: 196.8098
[430/100000] cost: 219.6303
[440/100000] cost: 171.0177
[450/100000] cost: 215.5296
[460/100000] cost: 177.6324
[470/100000] cost: 190.9361
[480/100000] cost: 201.9544
[490/100000] cost: 191.6286
[500/100000] cost: 200.1959
Training dataset
Test dataset
[510/100000] cost: 199.5182
[520/100000] cost: 170.4744
[530/100000] cost: 197.0104
[540/100000] cost: 202.3129
[550/100000] cost: 225.0957
[560/100000] cost: 179.6629
[570/100000] cost: 190.9138
[580/100000] cost: 177.1373
[590/100000] cost: 230.7214
[600/100000] cost: 199.4890
Training dataset
Test dataset
[610/100000] cost: 150.3460
[620/100000] cost: 142.2733
[630/100000] cost: 184.0573
[640/100000] cost: 220.4450
[650/100000] cost: 165.4367
[660/100000] cost: 253.1218
[670/100000] cost: 158.5751
[680/100000] cost: 227.8497
[690/100000] cost: 167.6208
[700/100000] cost: 196.1220
Training dataset
Test dataset
[710/100000] cost: 197.5616
[720/100000] cost: 198.3330
[730/100000] cost: 200.2553
[740/100000] cost: 204.8952
[750/100000] cost: 209.1677
[760/100000] cost: 205.5825
[770/100000] cost: 158.0817
[780/100000] cost: 205.6169
[790/100000] cost: 188.0784
[800/100000] cost: 196.0623
Training dataset
Test dataset
[810/100000] cost: 210.1583
[820/100000] cost: 221.4599
[830/100000] cost: 237.4969
[840/100000] cost: 192.0751
[850/100000] cost: 213.0567
[860/100000] cost: 174.9142
[870/100000] cost: 179.2275
[880/100000] cost: 211.8621
[890/100000] cost: 212.8904
[900/100000] cost: 257.3144
Training dataset
Test dataset
[910/100000] cost: 279.0889
[920/100000] cost: 275.9886
[930/100000] cost: 294.9624
[940/100000] cost: 218.6249
[950/100000] cost: 200.8324
[960/100000] cost: 287.3336
[970/100000] cost: 245.9705
[980/100000] cost: 240.9357
[990/100000] cost: 223.2813
[1000/100000] cost: 258.9540
Training dataset
Test dataset
[1010/100000] cost: 270.5289
[1020/100000] cost: 271.8137
[1030/100000] cost: 246.7184
[1040/100000] cost: 192.9518
[1050/100000] cost: 222.2798
[1060/100000] cost: 248.5685
[1070/100000] cost: 234.6805
[1080/100000] cost: 203.7659
[1090/100000] cost: 253.0008
[1100/100000] cost: 269.9266
Training dataset
Test dataset
[1110/100000] cost: 349.5590
[1120/100000] cost: 268.0766
[1130/100000] cost: 302.1344
[1140/100000] cost: 208.0664
[1150/100000] cost: 264.6624
[1160/100000] cost: 222.6232
[1170/100000] cost: 272.3559
[1180/100000] cost: 233.2416
[1190/100000] cost: 231.7469
[1200/100000] cost: 236.6369
Training dataset
Test dataset
[1210/100000] cost: 243.6459
[1220/100000] cost: 202.2522
[1230/100000] cost: 251.7291
[1240/100000] cost: 246.2129
[1250/100000] cost: 259.5334
[1260/100000] cost: 272.7007
[1270/100000] cost: 299.9875
[1280/100000] cost: 260.2490
[1290/100000] cost: 291.4149
[1300/100000] cost: 278.5886
Training dataset
Test dataset
[1310/100000] cost: 260.8284
[1320/100000] cost: 246.8349
[1330/100000] cost: 355.7622
[1340/100000] cost: 287.9314
[1350/100000] cost: 253.2013
[1360/100000] cost: 294.0295
[1370/100000] cost: 249.6499
[1380/100000] cost: 270.4777
[1390/100000] cost: 253.6567
[1400/100000] cost: 258.3598
Training dataset
Test dataset
[1410/100000] cost: 252.8865
[1420/100000] cost: 226.4185
[1430/100000] cost: 274.6122
[1440/100000] cost: 284.6050
[1450/100000] cost: 275.3493
[1460/100000] cost: 286.1541
[1470/100000] cost: 300.7496
[1480/100000] cost: 258.0696
[1490/100000] cost: 302.6696
[1500/100000] cost: 321.3855
Training dataset
Test dataset
[1510/100000] cost: 268.9408
[1520/100000] cost: 305.9114
[1530/100000] cost: 301.6736
[1540/100000] cost: 267.8618
[1550/100000] cost: 323.6226
[1560/100000] cost: 265.2032
[1570/100000] cost: 318.1612
[1580/100000] cost: 256.5219
[1590/100000] cost: 300.4944
[1600/100000] cost: 305.1118
Training dataset
Test dataset
[1610/100000] cost: 272.6064
[1620/100000] cost: 284.0279
[1630/100000] cost: 317.0608
[1640/100000] cost: 278.9991
[1650/100000] cost: 305.5137
[1660/100000] cost: 271.6007
[1670/100000] cost: 287.9285
[1680/100000] cost: 299.2344
[1690/100000] cost: 284.0301
[1700/100000] cost: 248.4221
Training dataset
Test dataset
[1710/100000] cost: 280.1856
[1720/100000] cost: 337.9356
[1730/100000] cost: 267.5448
[1740/100000] cost: 334.8813
[1750/100000] cost: 302.4286
[1760/100000] cost: 324.7914
[1770/100000] cost: 258.5010
[1780/100000] cost: 327.7728
(Training was stopped by hand after epoch 1780 with a KeyboardInterrupt.)
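The run above was stopped by hand and nothing was checkpointed; a minimal sketch using tf.train.Saver (the path ./srn_model.ckpt is an arbitrary choice) would let the learned weights be saved and restored later:
# Sketch: checkpoint the weights so an interrupted run is not lost.
saver = tf.train.Saver()
save_path = saver.save(sess, "./srn_model.ckpt")
print("Model saved to %s" % save_path)

# Later, in a fresh session:
# sess = tf.Session()
# saver.restore(sess, "./srn_model.ckpt")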