重新训练现有的 CNN 模型
从头开始训练新的图像识别需要大量的时间和计算能力。如果我们可以采用先前训练的网络并使用我们的图像重新训练它,它可以节省我们的计算时间。对于此秘籍,我们将展示如何使用预先训练的 TensorFlow 图像识别模型并对其进行微调以处理不同的图像集。
做好准备
其思想是从卷积层重用先前模型的权重和结构,并重新训练网络顶部的完全连接层。
TensorFlow 在现有 CNN 模型的基础上创建了一个关于训练的教程(请参阅下一节中的第一个要点)。在本文中,我们将说明如何对 CIFAR-10 使用相同的方法。我们将采用的 CNN 网络使用一种非常流行的架构,称为 Inception。 Inception CNN 模型由 Google 创建,在许多图像识别基准测试中表现非常出色。有关详细信息,请参阅“另请参阅”部分的第二个要点中的纸张参考。
我们将介绍的主要 Python 脚本显示如何下载 CIFAR-10 图像数据并自动分离,标记和保存图像到每个训练和测试文件夹中的十个类。之后,我们将重申如何在我们的图像上训练网络。
操作步骤
执行以下步骤:
- 我们首先加载必要的库来下载,解压缩和保存 CIFAR-10 图像:
import os
import tarfile
import _pickle as cPickle
import numpy as np
import urllib.request
import scipy.misc
from imageio import imwrite
- 我们现在声明 CIFAR-10 数据链接并创建我们将存储数据的临时目录。我们还将在以后保存图像时声明要引用的十个类别:
cifar_link = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
data_dir = 'temp'
if not os.path.isdir(data_dir):
os.makedirs(data_dir)
objects = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
- 现在我们将下载 CIFAR-10
.tar
数据文件,并解压该文件:
target_file = os.path.join(data_dir, 'cifar-10-python.tar.gz')
if not os.path.isfile(target_file):
print('CIFAR-10 file not found. Downloading CIFAR data (Size = 163MB)')
print('This may take a few minutes, please wait.')
filename, headers = urllib.request.urlretrieve(cifar_link, target_file)
# Extract into memory
tar = tarfile.open(target_file)
tar.extractall(path=data_dir)
tar.close()
- 我们现在为训练创建必要的文件夹结构。临时目录将有两个文件夹,
train_dir
和validation_dir
。在每个文件夹中,我们将为每个类别创建 10 个子文件夹:
# Create train image folders
train_folder = 'train_dir'
if not os.path.isdir(os.path.join(data_dir, train_folder)):
for i in range(10):
folder = os.path.join(data_dir, train_folder, objects[i])
os.makedirs(folder)
# Create test image folders
test_folder = 'validation_dir'
if not os.path.isdir(os.path.join(data_dir, test_folder)):
for i in range(10):
folder = os.path.join(data_dir, test_folder, objects[i])
os.makedirs(folder)
- 为了保存图像,我们将创建一个从内存加载它们并将它们存储在图像字典中的函数:
def load_batch_from_file(file):
file_conn = open(file, 'rb')
image_dictionary = cPickle.load(file_conn, encoding='latin1')
file_conn.close()
return(image_dictionary)
- 使用前面的字典,我们将使用以下函数将每个文件保存在正确的位置:
def save_images_from_dict(image_dict, folder='data_dir'):
for ix, label in enumerate(image_dict['labels']):
folder_path = os.path.join(data_dir, folder, objects[label])
filename = image_dict['filenames'][ix]
#Transform image data
image_array = image_dict['data'][ix]
image_array.resize([3, 32, 32])
# Save image
output_location = os.path.join(folder_path, filename)
imwrite(output_location,image_array.transpose())
- 使用上述函数,我们可以遍历下载的数据文件并将每个图像保存到正确的位置:
data_location = os.path.join(data_dir, 'cifar-10-batches-py')
train_names = ['data_batch_' + str(x) for x in range(1,6)]
test_names = ['test_batch']
# Sort train images
for file in train_names:
print('Saving images from file: {}'.format(file))
file_location = os.path.join(data_dir, 'cifar-10-batches-py', file)
image_dict = load_batch_from_file(file_location)
save_images_from_dict(image_dict, folder=train_folder)
# Sort test images
for file in test_names:
print('Saving images from file: {}'.format(file))
file_location = os.path.join(data_dir, 'cifar-10-batches-py', file)
image_dict = load_batch_from_file(file_location)
save_images_from_dict(image_dict, folder=test_folder)
- 我们脚本的最后一部分创建了图像标签文件,这是我们需要的最后一条信息。这个文件让我们将输出解释为标签而不是数字索引:
cifar_labels_file = os.path.join(data_dir,'cifar10_labels.txt')
print('Writing labels file, {}'.format(cifar_labels_file))
with open(cifar_labels_file, 'w') as labels_file:
for item in objects:
labels_file.write("{}n".format(item))
- 当前面的脚本运行时,它将下载图像并将它们分类到 TensorFlow 再训练教程所期望的正确文件夹结构中。完成后,我们只需按照教程进行操作即可。首先,我们应该克隆教程仓库:
git clone https://github.com/tensorflow/models/tree/master/research/inception
- 为了使用先前训练的模型,我们必须下载网络权重并将其应用于我们的模型。为此,您必须访问该站点: https://github.com/tensorflow/models/tree/master/research/slim ,并按照说明下载并安装 cifar10 模型架构和权重。您还将最终下载包含下面描述的构建,训练和测试脚本的数据目录。
对于此步骤,我们导航到 research / inception / inception 目录,然后执行以下命令,
--train_directory
,--validation_directory
,--output_directory
和--labels_file
的路径指向相对路径或完整路径创建的目录结构。
- 现在我们将图像放在正确的文件夹结构中,我们必须将它们变成
TFRecords
对象。我们通过运行以下命令来完成此操作:
me@computer:~$ python3 data/build_image_data.py
--train_directory="temp/train_dir/"
--validation_directory="temp/validation_dir"
--output_directory="temp/" --labels_file="temp/cifar10_labels.txt"
- 现在我们将使用
bazel
训练模型,将参数设置为true
。该脚本每 10 代输出一次损失。我们可以随时终止此过程,模型输出将在temp/training_results
文件夹中。我们可以从此文件夹加载模型以进行评估:
me@computer:~$ bazel-bin/inception/flowers_train
--train_dir="temp/training_results" --data_dir="temp/data_dir"
--pretrained_model_checkpoint_path="model.ckpt-157585"
--fine_tune=True --initial_learning_rate=0.001
--input_queue_memory_factor=1
- 这应该导致输出类似于以下内容:
2018-06-02 11:10:10.557012: step 1290, loss = 2.02 (1.2 examples/sec; 23.771 sec/batch)
...
工作原理
关于预训练 CNN 上的训练的官方 TensorFlow 教程需要设置一个文件夹;我们从 CIFAR-10 数据创建的设置。然后我们将数据转换为所需的TFRecords
格式并开始训练模型。请记住,我们正在微调模型并重新训练顶部的完全连接的层以适合我们的 10 类数据。
另见
- 官方 Tensorflow Inception-v3 教程: https://www.tensorflow.org/tutoriaimg/image_recognition
- Googlenet Inception-v3 文件: https://arxiv.org/abs/1512.00567