8.2 RNN in TensorFlow for Text Data (NLP)
import os
import numpy as np
np.random.seed(123)
print("NumPy:{}".format(np.__version__))
import tensorflow as tf
tf.set_random_seed(123)
print("TensorFlow:{}".format(tf.__version__))
NumPy:1.13.1
TensorFlow:1.4.1
DATASETSLIB_HOME = os.path.join(os.path.expanduser('~'),'dl-ts','datasetslib')
import sys
if DATASETSLIB_HOME not in sys.path:
    sys.path.append(DATASETSLIB_HOME)
%reload_ext autoreload
%autoreload 2
import datasetslib
from datasetslib import util as dsu
from datasetslib import nputil
datasetslib.datasets_root = os.path.join(os.path.expanduser('~'),'datasets')
Text Generation with Text8 Data in TensorFlow
Load and Prepare Text8 Data
from datasetslib.text8 import Text8
text8 = Text8()
text8.load_data(clip_at=5000)
print('Train:', text8.part['train'][0:5])
print('Vocabulary Length = ',text8.vocab_len)
Already exists: /home/armando/datasets/text8/text8.zip
Train: [ 8 497 7 5 116]
Vocabulary Length = 1457
def id2string(ids):
    return ' '.join([text8.id2word[x_i] for x_i in ids])
print(id2string(text8.part['train'][0:100]))
anarchism originated as a term of abuse first used against early working class radicals including the diggers of the english revolution and the sans culottes of the french revolution whilst the term is still used in a pejorative way to describe any act that used violent means to destroy the organization of society it has also been taken up as a positive label by self defined anarchists the word anarchism is derived from the greek without archons ruler chief king anarchism as a political philosophy is the belief that rulers are unnecessary and should be abolished although there are differing
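The cells above work with integer word ids and map them back to words through `text8.id2word`. The internals of datasetslib are not shown in this section, so the following is only a minimal sketch of how such a mapping could be built, under the assumption that ids are assigned by descending word frequency after clipping the corpus. The helper `build_vocab` and the sample sentence are hypothetical and are not part of datasetslib.

```python
from collections import Counter

def build_vocab(words):
    """Map each distinct word to an integer id, most frequent word first.

    Hypothetical helper for illustration; not part of datasetslib.
    """
    counts = Counter(words)
    # most_common() sorts by count; words with equal counts keep first-seen order
    id2word = [word for word, _ in counts.most_common()]
    word2id = {word: i for i, word in enumerate(id2word)}
    return word2id, id2word

text = "the term is still used in the english revolution and the french revolution"
words = text.split()[:10]   # clip the corpus (assumed to be similar in spirit to clip_at)
word2id, id2word = build_vocab(words)

ids = [word2id[w] for w in words]               # words -> ids
decoded = ' '.join(id2word[i] for i in ids)     # ids -> words, like id2string
print(ids)
print(decoded)
```

Decoding the ids reproduces the clipped text, which is the same round trip that `id2string` performs on `text8.part['train']`.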
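tf.reset_default_graph()
batch_size = 128
n_x = 5                      # number of input time steps (words per input sequence)
n_y = 1                      # number of output time steps (words to predict)
n_x_vars = 1                 # features per input time step (the word id)
n_y_vars = text8.vocab_len   # one output class per vocabulary word
state_size = 128             # number of units in the LSTM cell
learning_rate = 0.001
x_p = tf.placeholder(tf.float32, [None, n_x, n_x_vars], name='x_p')
y_p = tf.placeholder(tf.float32, [None, n_y_vars], name='y_p')
# static_rnn expects a list of n_x tensors, each of shape [batch, n_x_vars]
x_in = tf.unstack(x_p,axis=1,name='x_in')
cell = tf.nn.rnn_cell.LSTMCell(state_size)
rnn_outputs, final_states = tf.nn.static_rnn(cell, x_in,dtype=tf.float32)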
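The `tf.unstack` call turns the `[batch, n_x, n_x_vars]` input into a list with one tensor per time step, which is the input format `tf.nn.static_rnn` takes. As a rough illustration of that reshaping only, here is a NumPy sketch; the batch size of 4 is an arbitrary assumption, while `n_x` and `n_x_vars` match the values used in this section.

```python
import numpy as np

batch, n_x, n_x_vars = 4, 5, 1   # batch size is illustrative only
x = np.arange(batch * n_x * n_x_vars, dtype=np.float32).reshape(batch, n_x, n_x_vars)

# NumPy counterpart of tf.unstack(x, axis=1): one [batch, n_x_vars] array per time step
x_steps = [x[:, t, :] for t in range(n_x)]

print(len(x_steps))        # 5 time steps
print(x_steps[0].shape)    # (4, 1)
```

Stacking the list back along axis 1 recovers the original array, so no information is lost by the split.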
random5 = np.random.choice(n_x * 50, n_x, replace=False)
print('Random 5 words: ',id2string(random5))
first5 = text8.part['train'][0:n_x].copy()
print('First 5 words: ',id2string(first5))
Random 5 words: free bolshevik be n another
First 5 words: anarchism originated as a term
# Output layer: project the last RNN output to one logit per vocabulary word
w = tf.get_variable('w', [state_size, n_y_vars], initializer=tf.random_normal_initializer())
b = tf.get_variable('b', [n_y_vars], initializer=tf.constant_initializer(0.0))
y_out = tf.matmul(rnn_outputs[-1], w) + b
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_out, labels=y_p))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
# Accuracy: fraction of examples whose highest-scoring word matches the label
n_correct_pred = tf.equal(tf.argmax(y_out,1), tf.argmax(y_p,1))
accuracy = tf.reduce_mean(tf.cast(n_correct_pred, tf.float32))
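To make the loss and accuracy definitions more concrete, the following NumPy sketch computes the same two quantities by hand for a tiny batch. It is an illustration under assumed values, not TensorFlow's implementation: the function name `softmax_cross_entropy`, the logits, and the label ids are all made up for the example.

```python
import numpy as np

def softmax_cross_entropy(logits, labels_onehot):
    """Per-example cross-entropy between softmax(logits) and one-hot labels."""
    # Subtract the row maximum before exponentiating for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(labels_onehot * log_probs).sum(axis=1)

# Two examples over a three-word vocabulary (illustrative values)
logits = np.array([[2.0, 0.5, 0.1],
                   [0.2, 0.3, 3.0]])
label_ids = np.array([0, 1])
labels_onehot = np.eye(logits.shape[1])[label_ids]   # one-hot rows from integer ids

loss = softmax_cross_entropy(logits, labels_onehot).mean()
accuracy = (logits.argmax(axis=1) == labels_onehot.argmax(axis=1)).mean()
print(loss, accuracy)
```

The first example is classified correctly and the second is not, so the accuracy here is 0.5. Indexing `np.eye` with the label ids is also a compact way to build one-hot targets.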
n_epochs = 1000
text8.reset_index()
n_batches = text8.n_batches_seq(n_tx=n_x,n_ty=n_y)
n_epochs_display = 100
with tf.Session() as tfs:
    tf.global_variables_initializer().run()
    for epoch in range(n_epochs):
        epoch_loss = 0
        epoch_accuracy = 0
        for step in range(n_batches):
            x_batch, y_batch = text8.next_batch_seq(n_tx=n_x,n_ty=n_y)
            y_batch = nputil.to2d(y_batch,unit_axis=1)
            # One-hot encode the target word ids
            y_onehot = np.zeros(shape=[batch_size,text8.vocab_len],dtype=np.float32)
            for i in range(batch_size):
                y_onehot[i,y_batch[i]]=1
            feed_dict = {x_p: x_batch.reshape(-1, n_x, n_x_vars), y_p: y_onehot}
            _, batch_accuracy, batch_loss = tfs.run([optimizer, accuracy, loss], feed_dict=feed_dict)
            epoch_loss += batch_loss
            epoch_accuracy += batch_accuracy
        if (epoch+1) % (n_epochs_display) == 0:
            epoch_loss = epoch_loss / n_batches
            epoch_accuracy = epoch_accuracy / n_batches
            print('\nEpoch {0:}, Average loss:{1:}, Average accuracy:{2:}'.
                  format(epoch,epoch_loss,epoch_accuracy ))
            # Generate 10 words from each seed by feeding predictions back as input
            y_pred_r5 = np.empty([10], dtype=np.int32)
            y_pred_f5 = np.empty([10], dtype=np.int32)
            x_test_r5 = random5.copy()
            x_test_f5 = first5.copy()
            for i in range(10):
                for x,y in zip([x_test_r5,x_test_f5],[y_pred_r5,y_pred_f5]):
                    x_input = x.copy()
                    feed_dict = {x_p: x_input.reshape(-1, n_x, n_x_vars)}
                    y_pred = tfs.run(y_out, feed_dict=feed_dict)
                    # np.argmax avoids adding a new tf.argmax op to the graph on every call
                    y_pred_id = int(np.argmax(y_pred, axis=1)[0])
                    y[i]=y_pred_id
                    # Slide the window: drop the oldest word, append the prediction
                    x[:-1] = x[1:]
                    x[-1] = y_pred_id
            print(' Random 5 prediction:',id2string(y_pred_r5))
            print(' First 5 prediction:',id2string(y_pred_f5))
Epoch 99, Average loss:1.2983208000659943, Average accuracy:0.84375
Random 5 prediction: century warren own supported supported without without without without strongly
First 5 prediction: market argued individualist warren without without without without strongly strongly
Epoch 199, Average loss:0.5452078034480413, Average accuracy:0.939453125
Random 5 prediction: been cnt also also also called syndicalism syndicalist operation syndicalists
First 5 prediction: spain force like politics ricardo key mag mag mag mag
Epoch 299, Average loss:1.193410764137904, Average accuracy:0.8717447916666666
Random 5 prediction: tolstoy goods aiming anarchistic anarchistic anarchistic anarchistic anarchistic anarchistic anarchistic
First 5 prediction: their social groups groups authoritarian authoritarian authoritarian authoritarian authoritarian authoritarian
Epoch 399, Average loss:1.2231902281443279, Average accuracy:0.8704427083333334
Random 5 prediction: long long associated associated anti anti left left authoritarian left
First 5 prediction: has movement anarchy post post post post post post post
Epoch 499, Average loss:0.7656367868185043, Average accuracy:0.9140625
Random 5 prediction: noted mutual stirner warren warren tucker tucker tucker tucker tucker
First 5 prediction: her liberty noted own warren warren tucker tucker tucker tucker
Epoch 599, Average loss:1.107410545150439, Average accuracy:0.8756510416666666
Random 5 prediction: syndicalists syndicalists propaganda propaganda propaganda ricardo national national mag mag
First 5 prediction: spanish force working syndicalists syndicalists propaganda propaganda propaganda ricardo national
Epoch 699, Average loss:0.9093838532765707, Average accuracy:0.8854166666666666
Random 5 prediction: teachings jesus directive directive antifa antifa official relying relying relying
First 5 prediction: right who within communities tolstoy communities nonviolent official directive christianity
Epoch 799, Average loss:0.752622996767362, Average accuracy:0.890625
Random 5 prediction: important generalizations bob hakim hakim hakim hakim hakim hakim hakim
First 5 prediction: individual include associated important important important bey bey bey bey
Epoch 899, Average loss:0.41430705537398654, Average accuracy:0.9440104166666666
Random 5 prediction: benjamin egoism tucker tucker tucker tucker tucker tucker tucker tucker
First 5 prediction: self century tucker warren tucker tucker tucker tucker tucker tucker
Epoch 999, Average loss:0.33439325789610547, Average accuracy:0.9485677083333334
Random 5 prediction: syndicalists syndicalists t syndicalists syndicalists unity t syndicalists syndicalists management
First 5 prediction: spain century syndicalist self spain syndicalist french spain propaganda management
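The predictions above come from a sliding-window loop: the model predicts one word from the last five ids, and that prediction is then appended to the window for the next step. The sketch below isolates that loop so it can be run without TensorFlow. The function `generate` and the stand-in predictor `toy_predict` are hypothetical; in the notebook the role of `predict_next` is played by running `y_out` in the session and taking the argmax.

```python
import numpy as np

def generate(seed_ids, predict_next, n_words=10):
    """Greedy generation with a sliding window over the last len(seed_ids) ids."""
    window = np.array(seed_ids).copy()
    generated = []
    for _ in range(n_words):
        next_id = int(predict_next(window))
        generated.append(next_id)
        window[:-1] = window[1:]   # drop the oldest id
        window[-1] = next_id       # append the prediction
    return generated

# Stand-in predictor (an assumption for illustration): last id plus one, modulo 7
toy_predict = lambda window: (window[-1] + 1) % 7

# Seed with the first five training ids shown earlier in this section
print(generate([8, 497, 7, 5, 116], toy_predict, n_words=10))
```

Because each prediction becomes part of the next input, a greedy choice that maps a window back to a similar window can repeat indefinitely, which is one plausible reading of the runs of repeated words in the output above.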