重新训练现有的 CNN 模型

从头开始训练新的图像识别需要大量的时间和计算能力。如果我们可以采用先前训练的网络并使用我们的图像重新训练它,它可以节省我们的计算时间。对于此秘籍,我们将展示如何使用预先训练的 TensorFlow 图像识别模型并对其进行微调以处理不同的图像集。

做好准备

其思想是从卷积层重用先前模型的权重和结构,并重新训练网络顶部的完全连接层。

TensorFlow 在现有 CNN 模型的基础上创建了一个关于训练的教程(请参阅下一节中的第一个要点)。在本文中,我们将说明如何对 CIFAR-10 使用相同的方法。我们将采用的 CNN 网络使用一种非常流行的架构,称为 Inception。 Inception CNN 模型由 Google 创建,在许多图像识别基准测试中表现非常出色。有关详细信息,请参阅“另请参阅”部分的第二个要点中的纸张参考。

我们将介绍的主要 Python 脚本显示如何下载 CIFAR-10 图像数据并自动分离,标记和保存图像到每个训练和测试文件夹中的十个类。之后,我们将重申如何在我们的图像上训练网络。

操作步骤

执行以下步骤:

  1. 我们首先加载必要的库来下载,解压缩和保存 CIFAR-10 图像:
import os 
import tarfile 
import _pickle as cPickle 
import numpy as np 
import urllib.request 
import scipy.misc
from imageio import imwrite
  1. 我们现在声明 CIFAR-10 数据链接并创建我们将存储数据的临时目录。我们还将在以后保存图像时声明要引用的十个类别:
cifar_link = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' 
data_dir = 'temp' 
if not os.path.isdir(data_dir): 
    os.makedirs(data_dir) 
objects = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
  1. 现在我们将下载 CIFAR-10 .tar数据文件,并解压该文件:
target_file = os.path.join(data_dir, 'cifar-10-python.tar.gz') 
if not os.path.isfile(target_file): 
    print('CIFAR-10 file not found. Downloading CIFAR data (Size = 163MB)') 
    print('This may take a few minutes, please wait.') 
    filename, headers = urllib.request.urlretrieve(cifar_link, target_file) 
# Extract into memory 
tar = tarfile.open(target_file) 
tar.extractall(path=data_dir) 
tar.close()
  1. 我们现在为训练创建必要的文件夹结构。临时目录将有两个文件夹,train_dirvalidation_dir。在每个文件夹中,我们将为每个类别创建 10 个子文件夹:
# Create train image folders 
train_folder = 'train_dir' 
if not os.path.isdir(os.path.join(data_dir, train_folder)): 
    for i in range(10): 
        folder = os.path.join(data_dir, train_folder, objects[i]) 
        os.makedirs(folder) 
# Create test image folders 
test_folder = 'validation_dir' 
if not os.path.isdir(os.path.join(data_dir, test_folder)): 
    for i in range(10): 
        folder = os.path.join(data_dir, test_folder, objects[i]) 
        os.makedirs(folder)
  1. 为了保存图像,我们将创建一个从内存加载它们并将它们存储在图像字典中的函数:
def load_batch_from_file(file): 
    file_conn = open(file, 'rb') 
    image_dictionary = cPickle.load(file_conn, encoding='latin1') 
    file_conn.close() 
    return(image_dictionary)
  1. 使用前面的字典,我们将使用以下函数将每个文件保存在正确的位置:
def save_images_from_dict(image_dict, folder='data_dir'): 

    for ix, label in enumerate(image_dict['labels']): 
        folder_path = os.path.join(data_dir, folder, objects[label]) 
        filename = image_dict['filenames'][ix] 
        #Transform image data 
        image_array = image_dict['data'][ix] 
        image_array.resize([3, 32, 32]) 
        # Save image 
        output_location = os.path.join(folder_path, filename) 
        imwrite(output_location,image_array.transpose())
  1. 使用上述函数,我们可以遍历下载的数据文件并将每个图像保存到正确的位置:
data_location = os.path.join(data_dir, 'cifar-10-batches-py') 
train_names = ['data_batch_' + str(x) for x in range(1,6)] 
test_names = ['test_batch'] 
# Sort train images 
for file in train_names: 
    print('Saving images from file: {}'.format(file)) 
    file_location = os.path.join(data_dir, 'cifar-10-batches-py', file) 
    image_dict = load_batch_from_file(file_location) 
    save_images_from_dict(image_dict, folder=train_folder) 
# Sort test images 
for file in test_names: 
    print('Saving images from file: {}'.format(file)) 
    file_location = os.path.join(data_dir, 'cifar-10-batches-py', file) 
    image_dict = load_batch_from_file(file_location) 
    save_images_from_dict(image_dict, folder=test_folder)
  1. 我们脚本的最后一部分创建了图像标签文件,这是我们需要的最后一条信息。这个文件让我们将输出解释为标签而不是数字索引:
cifar_labels_file = os.path.join(data_dir,'cifar10_labels.txt') 
print('Writing labels file, {}'.format(cifar_labels_file)) 
with open(cifar_labels_file, 'w') as labels_file: 
    for item in objects: 
        labels_file.write("{}n".format(item))
  1. 当前面的脚本运行时,它将下载图像并将它们分类到 TensorFlow 再训练教程所期望的正确文件夹结构中。完成后,我们只需按照教程进行操作即可。首先,我们应该克隆教程仓库:
git clone https://github.com/tensorflow/models/tree/master/research/inception
  1. 为了使用先前训练的模型,我们必须下载网络权重并将其应用于我们的模型。为此,您必须访问该站点: https://github.com/tensorflow/models/tree/master/research/slim ,并按照说明下载并安装 cifar10 模型架构和权重。您还将最终下载包含下面描述的构建,训练和测试脚本的数据目录。

对于此步骤,我们导航到 research / inception / inception 目录,然后执行以下命令,--train_directory--validation_directory--output_directory--labels_file的路径指向相对路径或完整路径创建的目录结构。

  1. 现在我们将图像放在正确的文件夹结构中,我们必须将它们变成TFRecords对象。我们通过运行以下命令来完成此操作:
    me@computer:~$ python3 data/build_image_data.py
    --train_directory="temp/train_dir/"
    --validation_directory="temp/validation_dir"
    --output_directory="temp/" --labels_file="temp/cifar10_labels.txt"
  1. 现在我们将使用bazel训练模型,将参数设置为true。该脚本每 10 代输出一次损失。我们可以随时终止此过程,模型输出将在temp/training_results文件夹中。我们可以从此文件夹加载模型以进行评估:
    me@computer:~$ bazel-bin/inception/flowers_train
    --train_dir="temp/training_results" --data_dir="temp/data_dir"
    --pretrained_model_checkpoint_path="model.ckpt-157585"
    --fine_tune=True --initial_learning_rate=0.001
    --input_queue_memory_factor=1
  1. 这应该导致输出类似于以下内容:
2018-06-02 11:10:10.557012: step 1290, loss = 2.02 (1.2  examples/sec; 23.771 sec/batch)
...

工作原理

关于预训练 CNN 上的训练的官方 TensorFlow 教程需要设置一个文件夹;我们从 CIFAR-10 数据创建的设置。然后我们将数据转换为所需的TFRecords格式并开始训练模型。请记住,我们正在微调模型并重新训练顶部的完全连接的层以适合我们的 10 类数据。

另见

results matching ""

    No results matching ""