LSTM Text Generation in TensorFlow

You can follow along with the code in this section in the Jupyter notebook ch-08b_RNN_Text_TensorFlow.

We implement the text generation LSTM in TensorFlow with the following steps:

  1. First, let's define the parameters and the placeholders for x and y:
batch_size = 128
n_x = 5 # number of input words
n_y = 1 # number of output words
n_x_vars = 1 # in case of our text, there is only 1 variable at each timestep
n_y_vars = text8.vocab_len
state_size = 128
learning_rate = 0.001
x_p = tf.placeholder(tf.float32, [None, n_x, n_x_vars], name='x_p') 
y_p = tf.placeholder(tf.float32, [None, n_y_vars], name='y_p')

For the input, we use the integer representations of the words, so n_x_vars is 1. For the output, we use one-hot encoded values, so the number of outputs equals the vocabulary length.
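
To make these shapes concrete, here is a minimal sketch of how one training pair could be prepared for the placeholders, using hypothetical toy values:

import numpy as np

vocab_len = 10                      # toy value; the real text8.vocab_len is larger
x_ids = np.array([3, 1, 4, 1, 5])   # integer ids of the 5 input words
y_id = 7                            # id of the target (sixth) word

x = x_ids.reshape(1, 5, 1).astype(np.float32)   # [batch, n_x, n_x_vars]
y = np.zeros([1, vocab_len], dtype=np.float32)  # [batch, n_y_vars]
y[0, y_id] = 1.0                                # one-hot encode the target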

  2. Next, create a list of tensors of length n_x:
x_in = tf.unstack(x_p,axis=1,name='x_in')
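
Here tf.unstack splits the [batch, n_x, n_x_vars] tensor along the time axis into a Python list of n_x tensors of shape [batch, n_x_vars], which is the input format tf.nn.static_rnn expects. A quick check, assuming it is run right after the call above:

print(len(x_in))       # 5, one tensor per timestep
print(x_in[0].shape)   # (?, 1), i.e. [batch, n_x_vars]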
  3. Next, create the LSTM cell and the static RNN network from the cell and the inputs:
cell = tf.nn.rnn_cell.LSTMCell(state_size)
rnn_outputs, final_states = tf.nn.static_rnn(cell, x_in,dtype=tf.float32)
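
tf.nn.static_rnn returns one output tensor per timestep plus the final cell state; with our settings that is a list of n_x tensors, each of shape [batch, state_size]. A quick check, assuming it is run right after the call above:

print(len(rnn_outputs))        # 5, one output per timestep
print(rnn_outputs[-1].shape)   # (?, 128), i.e. [batch, state_size]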
  4. Next, we define the weights, biases, and formula for the final layer. The final layer only needs to pick the output for the sixth word, so we take just the last timestep's output:
# output node parameters
w = tf.get_variable('w', [state_size, n_y_vars],
                    initializer=tf.random_normal_initializer())
b = tf.get_variable('b', [n_y_vars],
                    initializer=tf.constant_initializer(0.0))
# rnn_outputs[-1] is [batch, state_size], so y_out is [batch, n_y_vars]
y_out = tf.matmul(rnn_outputs[-1], w) + b
  5. Next, create the loss function and the optimizer:
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
        logits=y_out, labels=y_p))
optimizer = tf.train.AdamOptimizer(
        learning_rate=learning_rate).minimize(loss)
  6. Create the accuracy function that we can run inside the session block to check the accuracy of the trained model:
n_correct_pred = tf.equal(tf.argmax(y_out,1), tf.argmax(y_p,1))
accuracy = tf.reduce_mean(tf.cast(n_correct_pred, tf.float32))
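
As a sanity check on what these two ops compute, here is the same calculation in NumPy, with hypothetical logits for three two-class examples:

import numpy as np

logits = np.array([[2.0, 1.0], [0.5, 3.0], [1.0, 0.0]])  # hypothetical y_out
onehot = np.array([[1., 0.], [0., 1.], [0., 1.]])        # hypothetical y_p
correct = np.argmax(logits, 1) == np.argmax(onehot, 1)   # [True, True, False]
print(correct.mean())                                     # 0.666...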
  7. Finally, we train the model for 1,000 epochs and print the results every 100 epochs. Also, every 100 epochs, we print the text generated from the seed strings described earlier.

LSTM and RNN networks need to be trained on large datasets for many epochs to get better results. Try loading the complete dataset and running it on your computer for 50,000 or 80,000 epochs, and play with other hyperparameters to improve the results.

n_epochs = 1000
learning_rate = 0.001
text8.reset_index_in_epoch()
n_batches = text8.n_batches_seq(batch_size=batch_size,n_tx=n_x,n_ty=n_y)
n_epochs_display = 100

with tf.Session() as tfs:
    tf.global_variables_initializer().run()

    for epoch in range(n_epochs):
        epoch_loss = 0
        epoch_accuracy = 0
        for step in range(n_batches):
            x_batch, y_batch = text8.next_batch_seq(batch_size=batch_size,
                                n_tx=n_x,n_ty=n_y)
            y_batch = dsu.to2d(y_batch,unit_axis=1)
            y_onehot = np.zeros(shape=[batch_size,text8.vocab_len],
                        dtype=np.float32)
            # one-hot encode the target word ids
            for i in range(batch_size):
                y_onehot[i, y_batch[i]] = 1

            feed_dict = {x_p: x_batch.reshape(-1, n_x, n_x_vars), 
                         y_p: y_onehot}
            _, batch_accuracy, batch_loss = tfs.run([optimizer,accuracy,
                                            loss],feed_dict=feed_dict)
            epoch_loss += batch_loss
            epoch_accuracy += batch_accuracy

        if (epoch+1) % (n_epochs_display) == 0:
            epoch_loss = epoch_loss / n_batches
            epoch_accuracy = epoch_accuracy / n_batches
            print('\nEpoch {0:}, Average loss:{1:}, Average accuracy:{2:}'.
                    format(epoch,epoch_loss,epoch_accuracy ))

            y_pred_r5 = np.empty([10])
            y_pred_f5 = np.empty([10])

            x_test_r5 = random5.copy()
            x_test_f5 = first5.copy()
            # let us generate text of 10 words after feeding 5 words
            for i in range(10):
                for x,y in zip([x_test_r5,x_test_f5],
                               [y_pred_r5,y_pred_f5]):
                    x_input = x.copy()
                    feed_dict = {x_p: x_input.reshape(-1, n_x, n_x_vars)}
                    y_pred = tfs.run(y_out, feed_dict=feed_dict)
                    # use NumPy argmax so we don't add new ops to the
                    # graph on every iteration
                    y_pred_id = int(np.argmax(y_pred, 1))
                    y[i] = y_pred_id
                    # slide the window: drop the oldest word and append
                    # the predicted word as the new last input
                    x[:-1] = x[1:]
                    x[-1] = y_pred_id
            print(' Random 5 prediction:',id2string(y_pred_r5))
            print(' First 5 prediction:',id2string(y_pred_f5))

The results are as follows:

Epoch 99, Average loss:1.3972469369570415, Average accuracy:0.8489583333333334
  Random 5 prediction: labor warren together strongly profits strongly supported supported co without
  First 5 prediction: market own self free together strongly profits strongly supported supported

Epoch 199, Average loss:0.7894854595263799, Average accuracy:0.9186197916666666
  Random 5 prediction: syndicalists spanish class movements also also anarcho anarcho anarchist was
  First 5 prediction: five civil association class movements also anarcho anarcho anarcho anarcho

Epoch 299, Average loss:1.360412875811259, Average accuracy:0.865234375
  Random 5 prediction: anarchistic beginnings influenced true tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy
  First 5 prediction: early civil movement be for was two most most most

Epoch 399, Average loss:1.1692512730757396, Average accuracy:0.8645833333333334
  Random 5 prediction: including war than than revolutionary than than war than than
  First 5 prediction: left including including including other other other other other other

Epoch 499, Average loss:0.5921860883633295, Average accuracy:0.923828125
  Random 5 prediction: ever edited interested interested variety variety variety variety variety variety
  First 5 prediction: english market herbert strongly price interested variety variety variety variety

Epoch 599, Average loss:0.8356450994809469, Average accuracy:0.8958333333333334
  Random 5 prediction: management allow trabajo trabajo national national mag mag ricardo ricardo
  First 5 prediction: spain prior am working n war war war self self

Epoch 699, Average loss:0.7057955612738928, Average accuracy:0.8971354166666666
  Random 5 prediction: teachings can directive tend resist obey christianity author christianity christianity
  First 5 prediction: early early called social called social social social social social

Epoch 799, Average loss:0.772875706354777, Average accuracy:0.90234375
  Random 5 prediction: associated war than revolutionary revolutionary revolutionary than than revolutionary revolutionary
  First 5 prediction: political been hierarchy war than see anti anti anti anti

Epoch 899, Average loss:0.43675946692625683, Average accuracy:0.9375
  Random 5 prediction: individualist which which individualist warren warren tucker benjamin how tucker
  First 5 prediction: four at warren individualist warren published considered considered considered considered

Epoch 999, Average loss:0.23202441136042276, Average accuracy:0.9602864583333334
  Random 5 prediction: allow allow trabajo you you you you you you you
  First 5 prediction: labour spanish they they they movement movement anarcho anarcho two

Repeated words are common in the generated text, and the model needs more training. Although the model's accuracy improved to 96%, it is still not good enough to generate coherent text. Try increasing the number of LSTM cells/hidden layers, and run the model on a larger dataset for many more epochs.
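
As a starting point for deepening the network, a stacked cell can be dropped in for the single LSTMCell in step 3; here is a minimal sketch, assuming two layers of the same state_size:

n_layers = 2  # hypothetical depth; tune alongside state_size
cells = [tf.nn.rnn_cell.LSTMCell(state_size) for _ in range(n_layers)]
multi_cell = tf.nn.rnn_cell.MultiRNNCell(cells)
rnn_outputs, final_states = tf.nn.static_rnn(multi_cell, x_in,
                                             dtype=tf.float32)

The rest of the graph stays unchanged, since rnn_outputs[-1] still has shape [batch, state_size].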

Now let's build the same model in Keras:
