LSTM Text Generation in Keras
You can follow along with the code for this section in the Jupyter notebook ch-08b_RNN_Text_Keras.
We implement the text-generation LSTM in Keras with the following steps:
- First, we convert all of the data into two tensors: the tensor x has five columns, since we feed five words at a time, and the tensor y has a single output column. We convert y, the label tensor, into a one-hot encoded representation. Remember that in practice, with a large dataset, you would use word2vec embeddings rather than a one-hot representation; a hedged sketch of that alternative follows the code below.
# get the data
x_train, y_train = text8.seq_to_xy(seq=text8.part['train'], n_tx=n_x, n_ty=n_y)
# reshape input to be [samples, time steps, features]
x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], 1)
# one-hot encode the labels
y_onehot = np.zeros(shape=[y_train.shape[0], text8.vocab_len], dtype=np.float32)
for i in range(y_train.shape[0]):
    y_onehot[i, y_train[i]] = 1
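To make the embedding note concrete, here is a minimal sketch (my illustration, not the notebook's code) of mapping integer word ids through a learned embedding layer on the input side; embedding_dim and the pretrained matrix w2v_matrix are assumed names:

from keras.models import Sequential
from keras.layers import Embedding

embedding_dim = 128  # assumed embedding size
model = Sequential()
# inputs would be [samples, n_x] integer word ids, with no reshape to one feature
model.add(Embedding(input_dim=text8.vocab_len,
                    output_dim=embedding_dim,
                    input_length=n_x))
# hypothetical: initialize from a pretrained word2vec matrix w2v_matrix
# of shape [vocab_len, embedding_dim]
# model.layers[0].set_weights([w2v_matrix])

Along the same lines, compiling with loss='sparse_categorical_crossentropy' would let model.fit consume the integer y_train directly, making the y_onehot loop above unnecessary.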
- Next, define the LSTM model with only one hidden LSTM layer. Since our output is not a sequence, we also set return_sequences to False:
n_epochs = 1000
batch_size = 128
state_size = 128
n_epochs_display = 100

# create and fit the LSTM model
model = Sequential()
model.add(LSTM(units=state_size,
               input_shape=(x_train.shape[1], x_train.shape[2]),
               return_sequences=False
               )
          )
model.add(Dense(text8.vocab_len))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.summary()
The model looks as follows:
Layer (type)                 Output Shape              Param #
=================================================================
lstm_1 (LSTM)                (None, 128)               66560
_________________________________________________________________
dense_1 (Dense)              (None, 1457)              187953
_________________________________________________________________
activation_1 (Activation)    (None, 1457)              0
=================================================================
Total params: 254,513
Trainable params: 254,513
Non-trainable params: 0
_________________________________________________________________
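As a sanity check on these parameter counts (my own arithmetic, not from the notebook), each of the LSTM's four gates carries input weights, recurrent weights, and a bias, which reproduces the numbers above:

state_size, n_features, vocab_len = 128, 1, 1457
# 4 gates x (input weights + recurrent weights + biases)
lstm_params = 4 * (state_size * n_features + state_size * state_size + state_size)
dense_params = state_size * vocab_len + vocab_len  # weights + biases
print(lstm_params)                 # 66560
print(dense_params)                # 187953
print(lstm_params + dense_params)  # 254513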
- For Keras, we run a loop ten times; in each iteration we train the model for 100 epochs and print the text generation results. Here is the complete code for training the model and generating the text (the sliding-window update at the end of the inner loop is illustrated after the code):
for j in range(n_epochs // n_epochs_display):
    model.fit(x_train, y_onehot, epochs=n_epochs_display,
              batch_size=batch_size, verbose=0)
    # generate text
    y_pred_r5 = np.empty([10])
    y_pred_f5 = np.empty([10])
    x_test_r5 = random5.copy()
    x_test_f5 = first5.copy()
    # let us generate text of 10 words after feeding 5 words
    for i in range(10):
        for x, y in zip([x_test_r5, x_test_f5],
                        [y_pred_r5, y_pred_f5]):
            x_input = x.copy()
            x_input = x_input.reshape(-1, n_x, n_x_vars)
            y_pred = model.predict(x_input)[0]
            y_pred_id = np.argmax(y_pred)
            y[i] = y_pred_id
            # slide the window: drop the oldest word, append the prediction
            x[:-1] = x[1:]
            x[-1] = y_pred_id
    print('Epoch: ', ((j + 1) * n_epochs_display) - 1)
    print('  Random5 prediction:', id2string(y_pred_r5))
    print('  First5 prediction:', id2string(y_pred_f5))
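The last two assignments in the inner loop implement a sliding window over the input ids: the oldest word is dropped and the newly predicted word id is appended. A standalone illustration with made-up values:

import numpy as np

window = np.array([10, 11, 12, 13, 14])  # five most recent word ids
new_word_id = 42                          # e.g. the argmax of y_pred
window[:-1] = window[1:]                  # shift left, dropping the oldest id
window[-1] = new_word_id                  # append the predicted id
print(window)                             # [11 12 13 14 42]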
- The output is not surprising: it starts out by repeating words, and the model improves over the iterations, but it can be improved further with more LSTM layers, more data, more training iterations, and other hyperparameter tuning.
Random 5 words: free bolshevik be n another
First 5 words: anarchism originated as a term
The predicted output is as follows:
Epoch: 99
Random5 prediction: anarchistic anarchistic wrote wrote wrote wrote wrote wrote wrote wrote
First5 prediction: right philosophy than than than than than than than than
Epoch: 199
Random5 prediction: anarchistic anarchistic wrote wrote wrote wrote wrote wrote wrote wrote
First5 prediction: term i revolutionary than war war french french french french
Epoch: 299
Random5 prediction: anarchistic anarchistic wrote wrote wrote wrote wrote wrote wrote wrote
First5 prediction: term i revolutionary revolutionary revolutionary revolutionary revolutionary revolutionary revolutionary revolutionary
Epoch: 399
Random5 prediction: anarchistic anarchistic wrote wrote wrote wrote wrote wrote wrote wrote
First5 prediction: term i revolutionary labor had had french french french french
Epoch: 499
Random5 prediction: anarchistic anarchistic amongst wrote wrote wrote wrote wrote wrote wrote
First5 prediction: term i revolutionary labor individualist had had french french french
Epoch: 599
Random5 prediction: tolstoy wrote tolstoy wrote wrote wrote wrote wrote wrote wrote
First5 prediction: term i revolutionary labor individualist had had had had had
Epoch: 699
Random5 prediction: tolstoy wrote tolstoy wrote wrote wrote wrote wrote wrote wrote
First5 prediction: term i revolutionary labor individualist had had had had had
Epoch: 799
Random5 prediction: tolstoy wrote tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy
First5 prediction: term i revolutionary labor individualist had had had had had
Epoch: 899
Random5 prediction: tolstoy wrote tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy
First5 prediction: term i revolutionary labor should warren warren warren warren warren
Epoch: 999
Random5 prediction: tolstoy wrote tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy
First5 prediction: term i individualist labor should warren warren warren warren warren
You may notice that we get repeated words in the text generated by the LSTM model. Although hyperparameter and network tuning can remove some of the repetition, there are other ways to address this problem. The reason we get repeated words is that the model always picks the word with the highest probability from the probability distribution over the words. This can be changed so that the next word is sampled from the distribution instead, for example, to introduce greater variability between consecutive words; a sketch of one such approach follows.
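One common approach, shown here as a hedged sketch rather than the notebook's code, is to sample the next word id from the predicted distribution, with a temperature parameter that controls how much variability is introduced; the function name sample_word_id is my own:

import numpy as np

def sample_word_id(probs, temperature=1.0):
    # rescale the softmax output: temperature < 1 approaches greedy argmax,
    # temperature > 1 increases variability between consecutive words
    logits = np.log(np.asarray(probs, dtype=np.float64) + 1e-8) / temperature
    scaled = np.exp(logits) / np.sum(np.exp(logits))
    return np.random.choice(len(scaled), p=scaled)

# in the generation loop, replace np.argmax(y_pred) with:
# y_pred_id = sample_word_id(y_pred, temperature=0.8)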