LSTM Text Generation in Keras

You can follow along with the code for this section in the Jupyter notebook ch-08b_RNN_Text_Keras.

We implement text-generation LSTM in Keras with the following steps:

  1. First, we convert all of the data into two tensors: the tensor x has five columns, because we feed in five words at a time, and the tensor y has a single output column. We convert y, the label tensor, into a one-hot encoded representation.

Remember that in practice, with large datasets, you would use word2vec embeddings instead of a one-hot representation.
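As an illustration of that note, here is a minimal sketch of an embedding-based variant in Keras (an assumption, not this notebook's code; the model actually built in step 2 below feeds raw word ids as a single feature instead). The embedding size of 64 is an arbitrary choice, and the layer's weights argument could be used to initialize it with pretrained word2vec vectors:

from keras.models import Sequential
from keras.layers import Embedding, LSTM, Dense

# hypothetical embedding-based variant (an assumption): inputs are integer
# word ids of shape (samples, n_x), learned embeddings replace one-hot input
model_emb = Sequential()
model_emb.add(Embedding(input_dim=text8.vocab_len, output_dim=64,
                        input_length=n_x))
model_emb.add(LSTM(units=128))
model_emb.add(Dense(text8.vocab_len, activation='softmax'))
# sparse_categorical_crossentropy accepts the integer labels directly,
# so no one-hot conversion of y would be needed
model_emb.compile(loss='sparse_categorical_crossentropy', optimizer='adam')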

import numpy as np

# get the data: inputs of n_x word ids, labels of n_y word ids
x_train, y_train = text8.seq_to_xy(seq=text8.part['train'], n_tx=n_x, n_ty=n_y)
# reshape input to be [samples, time steps, features]
x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], 1)
# convert the labels to a one-hot encoded representation
y_onehot = np.zeros(shape=[y_train.shape[0], text8.vocab_len], dtype=np.float32)
for i in range(y_train.shape[0]):
    y_onehot[i, y_train[i]] = 1
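As an aside, Keras ships a utility that performs the same one-hot conversion in one call; a minimal equivalent, assuming the Keras 2 num_classes keyword:

from keras.utils import to_categorical

# equivalent one-hot conversion using the built-in Keras utility;
# the cast just matches the dtype used above
y_onehot = to_categorical(y_train, num_classes=text8.vocab_len).astype(np.float32)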
  2. Next, define the LSTM model with just one hidden LSTM layer. Since our output is not a sequence, we also set return_sequences to False:
from keras.models import Sequential
from keras.layers import LSTM, Dense, Activation

n_epochs = 1000
batch_size = 128
state_size = 128
n_epochs_display = 100

# create the LSTM model: one hidden LSTM layer followed by a
# softmax over the vocabulary
model = Sequential()
model.add(LSTM(units=state_size,
               input_shape=(x_train.shape[1], x_train.shape[2]),
               return_sequences=False
               )
          )
model.add(Dense(text8.vocab_len))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.summary()

The model looks as follows:

Layer (type)                 Output Shape              Param #   
=================================================================
lstm_1 (LSTM)                (None, 128)               66560     
_________________________________________________________________
dense_1 (Dense)              (None, 1457)              187953    
_________________________________________________________________
activation_1 (Activation)    (None, 1457)              0         
=================================================================
Total params: 254,513
Trainable params: 254,513
Non-trainable params: 0
_________________________________________________________________
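As a sanity check on these numbers: an LSTM layer has 4 × (n_features × n_units + n_units × n_units + n_units) parameters, which here is 4 × (1 × 128 + 128 × 128 + 128) = 66,560, and the dense layer has 128 × 1457 + 1457 = 187,953 parameters, giving the 254,513 total shown (1,457 is the vocabulary size).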
  3. For Keras, we run a loop that executes 10 times (n_epochs // n_epochs_display); in each iteration it trains the model for 100 epochs and prints the results of text generation. Here is the full code for training the model and generating text:
for j in range(n_epochs // n_epochs_display):
    model.fit(x_train, y_onehot, epochs=n_epochs_display,
              batch_size=batch_size, verbose=0)
    # generate text
    y_pred_r5 = np.empty([10])
    y_pred_f5 = np.empty([10])
    x_test_r5 = random5.copy()
    x_test_f5 = first5.copy()
    # let us generate text of 10 words after feeding 5 words
    for i in range(10):
        for x, y in zip([x_test_r5, x_test_f5],
                        [y_pred_r5, y_pred_f5]):
            x_input = x.copy()
            x_input = x_input.reshape(-1, n_x, n_x_vars)
            y_pred = model.predict(x_input)[0]
            y_pred_id = np.argmax(y_pred)
            y[i] = y_pred_id
            # slide the window: drop the oldest word, append the prediction
            x[:-1] = x[1:]
            x[-1] = y_pred_id
    print('Epoch: ', ((j + 1) * n_epochs_display) - 1)
    print(' Random5 prediction:', id2string(y_pred_r5))
    print(' First5 prediction:', id2string(y_pred_f5))
  4. The output is not surprising: it starts out with repeated words, and the model improves somewhat, but it can be improved further with more LSTM layers, more data, more training iterations, and further hyperparameter tuning. The two five-word seeds were as follows:
Random 5 words: free bolshevik be n another 
First 5 words: anarchism originated as a term
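The seed variables random5 and first5 are defined earlier in the notebook; a minimal sketch of how such seeds could be built (the construction below is an assumption, only the variable names come from the original code):

# hypothetical construction of the seeds (an assumption):
# five random word ids, and the ids of the first five training words
random5 = np.random.choice(text8.vocab_len, size=n_x).astype(np.float32)
first5 = x_train[0, :, 0].copy()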

The predicted output is as follows:

Epoch: 99 
    Random5 prediction: anarchistic anarchistic wrote wrote wrote wrote wrote wrote wrote wrote 
    First5 prediction: right philosophy than than than than than than than than 

Epoch: 199 
    Random5 prediction: anarchistic anarchistic wrote wrote wrote wrote wrote wrote wrote wrote 
    First5 prediction: term i revolutionary than war war french french french french 

Epoch: 299 
    Random5 prediction: anarchistic anarchistic wrote wrote wrote wrote wrote wrote wrote wrote 
    First5 prediction: term i revolutionary revolutionary revolutionary revolutionary revolutionary revolutionary revolutionary revolutionary 

Epoch: 399 
    Random5 prediction: anarchistic anarchistic wrote wrote wrote wrote wrote wrote wrote wrote 
    First5 prediction: term i revolutionary labor had had french french french french 

Epoch: 499 
    Random5 prediction: anarchistic anarchistic amongst wrote wrote wrote wrote wrote wrote wrote 
    First5 prediction: term i revolutionary labor individualist had had french french french 

Epoch: 599 
    Random5 prediction: tolstoy wrote tolstoy wrote wrote wrote wrote wrote wrote wrote 
    First5 prediction: term i revolutionary labor individualist had had had had had 

Epoch: 699 
    Random5 prediction: tolstoy wrote tolstoy wrote wrote wrote wrote wrote wrote wrote 
    First5 prediction: term i revolutionary labor individualist had had had had had 

Epoch: 799 
    Random5 prediction: tolstoy wrote tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy 
    First5 prediction: term i revolutionary labor individualist had had had had had 

Epoch: 899 
    Random5 prediction: tolstoy wrote tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy 
    First5 prediction: term i revolutionary labor should warren warren warren warren warren 

Epoch: 999 
    Random5 prediction: tolstoy wrote tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy 
    First5 prediction: term i individualist labor should warren warren warren warren warren

As you may have noticed, the output of the LSTM model contains repeated words in the generated text. While hyperparameter and network tuning can eliminate some of the repetition, there are other ways to address the problem as well. The reason we get repeated words is that the model always picks the word with the highest probability from the predicted probability distribution over words. This can be changed to sample words in a way that introduces more variability between consecutive words.
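One common alternative, sketched below, is to draw the next word from the predicted distribution instead of taking the argmax, with a temperature parameter that controls the variability (this sampler is an illustration under assumed defaults, not the notebook's code):

# hypothetical temperature-based sampler (an assumption, not the book's code)
def sample_word_id(probs, temperature=1.0):
    # rescale the predicted distribution; lower temperature is greedier
    logits = np.log(np.asarray(probs, dtype=np.float64) + 1e-10) / temperature
    probs = np.exp(logits) / np.sum(np.exp(logits))
    # draw a single word id from the rescaled distribution
    return np.argmax(np.random.multinomial(1, probs))

# drop-in replacement for the greedy line in the generation loop:
# y_pred_id = sample_word_id(y_pred, temperature=0.7)

Sampling this way, together with the stacked LSTM layers (return_sequences=True on all but the last layer) and the extra training mentioned in step 4, typically breaks up the repetition.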
