用于 MNIST 数据的 Keras 中的 RNN
虽然 RNN 主要用于序列数据,但它也可用于图像数据。我们知道图像具有最小的两个维度 - 高度和宽度。现在将其中一个维度视为时间步长,将其他维度视为特征。对于 MNIST,图像大小为 28 x 28 像素,因此我们可以将 MNIST 图像视为具有 28 个时间步长,每个时间步长具有 28 个特征。
我们将在下一章中提供时间序列和文本数据的示例,但让我们为 Keras 中的 MNIST 构建和训练 RNN,以快速浏览构建和训练 RNN 模型的过程。
您可以按照 Jupyter 笔记本中的代码ch-06_RNN_MNIST_Keras
。
导入所需的模块:
import keras
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.layers.recurrent import SimpleRNN
from keras.optimizers import RMSprop
from keras.optimizers import SGD
获取 MNIST 数据并将数据从 1-D 中的 784 像素转换为 2-D 中的 28 x 28 像素:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(os.path.join(datasetslib.datasets_root,
'mnist'),
one_hot=True)
X_train = mnist.train.images
X_test = mnist.test.images
Y_train = mnist.train.labels
Y_test = mnist.test.labels
n_classes = 10
n_classes = 10
X_train = X_train.reshape(-1,28,28)
X_test = X_test.reshape(-1,28,28)
在 Keras 构建 SimpleRNN 模型:
# create and fit the SimpleRNN model
model = Sequential()
model.add(SimpleRNN(units=16, activation='relu', input_shape=(28,28)))
model.add(Dense(n_classes))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy',
optimizer=RMSprop(lr=0.01),
metrics=['accuracy'])
model.summary()
该模型如下:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
simple_rnn_1 (SimpleRNN) (None, 16) 720
_________________________________________________________________
dense_1 (Dense) (None, 10) 170
_________________________________________________________________
activation_1 (Activation) (None, 10) 0
=================================================================
Total params: 890
Trainable params: 890
Non-trainable params: 0
_________________________________________________________________
训练模型并打印测试数据集的准确性:
model.fit(X_train, Y_train,
batch_size=100, epochs=20)
score = model.evaluate(X_test, Y_test)
print('\nTest loss:', score[0])
print('Test accuracy:', score[1])
我们得到以下结果:
Test loss: 0.520945608187
Test accuracy: 0.8379