Image classification with a retrained Inception v3 in TensorFlow
Retraining Inception v3 differs from retraining VGG16 in that we use a softmax activation layer as the output and tf.losses.softmax_cross_entropy() as the loss function.
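To make the output layer and the loss concrete, here is a minimal pure-Python sketch of softmax and of softmax cross-entropy for one-hot labels. It averages the per-example losses over the batch, which should match the default reduction of tf.losses.softmax_cross_entropy() when no weights are passed. The function names are illustrative only and are not TensorFlow APIs.

```python
import math

def log_softmax(logits):
    """Log-probabilities for one example, computed stably by subtracting the max."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]

def softmax(logits):
    """Probabilities that are positive and sum to 1."""
    return [math.exp(v) for v in log_softmax(logits)]

def softmax_cross_entropy(onehot_labels, logits):
    """Mean over the batch of -sum_k y_k * log(softmax(z)_k)."""
    per_example = [
        -sum(y * lp for y, lp in zip(labels, log_softmax(z)))
        for labels, z in zip(onehot_labels, logits)
    ]
    return sum(per_example) / len(per_example)

# usage: a batch of two examples over three classes
labels = [[0, 1, 0], [1, 0, 0]]
logits = [[1.0, 3.0, 0.5], [2.0, 1.0, 0.1]]
probs = softmax(logits[0])
loss = softmax_cross_entropy(labels, logits)
```

Working in log space avoids taking the log of a probability that has underflowed to zero, which is also why TensorFlow asks for the raw logits rather than the softmax output.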
- First, define the placeholders:
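import os
import numpy as np
import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.contrib.slim.nets import inception

# switches dropout and batch normalization between training and inference
is_training = tf.placeholder(tf.bool, name='is_training')
# input images: (batch, height, width, RGB channels)
x_p = tf.placeholder(shape=(None, image_height, image_width, 3),
                     dtype=tf.float32,
                     name='x_p')
# one-hot labels: one column per class
y_p = tf.placeholder(shape=(None, coco.n_classes),
                     dtype=tf.int32,
                     name='y_p')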
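The label placeholder y_p expects each label as a one-hot row of length coco.n_classes (the training step later sets coco.y_onehot = True so the dataset helper produces this form). The sketch below shows what such a batch looks like; to_onehot is a hypothetical helper for illustration, not part of the book's coco dataset code.

```python
def to_onehot(class_ids, n_classes):
    """Hypothetical helper: map integer class ids to one-hot rows of length n_classes."""
    rows = []
    for c in class_ids:
        if not 0 <= c < n_classes:
            raise ValueError('class id {} out of range'.format(c))
        # exactly one 1 per row, at the position of the class id
        rows.append([1 if k == c else 0 for k in range(n_classes)])
    return rows

# a batch of two labels for a 4-class problem, shaped (batch, n_classes) like y_p
y_batch = to_onehot([2, 0], n_classes=4)
```

The rows hold integers to match the tf.int32 dtype of y_p; tf.losses.softmax_cross_entropy() casts the labels to the dtype of the logits internally.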
- Next, load the model:
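# build Inception v3 with an output layer sized for our classes;
# is_training is fed at run time so prediction uses inference mode
with slim.arg_scope(inception.inception_v3_arg_scope()):
    logits, _ = inception.inception_v3(x_p,
                                       num_classes=coco.n_classes,
                                       is_training=is_training)
# softmax turns the logits into per-class probabilities
probabilities = tf.nn.softmax(logits)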
- Next, define a function that restores all pretrained variables except those of the last layer:
# restore everything except the final classification layers,
# whose shapes depend on the number of classes
checkpoint_exclude_scopes = ["InceptionV3/Logits",
                             "InceptionV3/AuxLogits"]
exclusions = [scope.strip() for scope in checkpoint_exclude_scopes]
variables_to_restore = []
for var in slim.get_model_variables():
    excluded = False
    for exclusion in exclusions:
        if var.op.name.startswith(exclusion):
            excluded = True
            break
    if not excluded:
        variables_to_restore.append(var)
# init_fn(session) loads the selected variables from the checkpoint file
init_fn = slim.assign_from_checkpoint_fn(
    os.path.join(model_home, '{}.ckpt'.format(model_name)),
    variables_to_restore)
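The exclusion logic above is plain prefix matching on variable names. The sketch below reproduces it on strings so it can be checked without TensorFlow; the variable names are illustrative examples modeled on the InceptionV3/... naming pattern, not a listing read from the checkpoint.

```python
def split_by_scope(var_names, exclude_scopes):
    """Return (restore, skip): names starting with any excluded scope are skipped."""
    exclusions = [s.strip() for s in exclude_scopes]
    restore, skip = [], []
    for name in var_names:
        if any(name.startswith(e) for e in exclusions):
            skip.append(name)
        else:
            restore.append(name)
    return restore, skip

# illustrative names only
names = [
    "InceptionV3/Conv2d_1a_3x3/weights",
    "InceptionV3/Mixed_7c/Branch_0/Conv2d_0a_1x1/weights",
    "InceptionV3/AuxLogits/Conv2d_2b_1x1/weights",
    "InceptionV3/Logits/Conv2d_1c_1x1/weights",
]
restore, skip = split_by_scope(names, ["InceptionV3/Logits", "InceptionV3/AuxLogits"])
```

Because the match is a bare prefix, a scope such as "InceptionV3/Logits" would also catch any sibling scope whose name merely begins with the same characters; appending a trailing "/" to each exclusion is a stricter variant if that ever matters.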
- Define the loss, the optimizer, and the training operation:
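# adds the cross-entropy to the tf.GraphKeys.LOSSES collection
tf.losses.softmax_cross_entropy(onehot_labels=y_p, logits=logits)
# total loss = cross-entropy + the L2 regularization losses from the arg_scope
loss = tf.losses.get_total_loss()
learning_rate = 0.001
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# also run the batch-norm moving-average updates on every training step
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)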
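GradientDescentOptimizer applies the update w ← w − learning_rate · ∂loss/∂w to every trainable weight. As a toy sketch of that rule, the code below treats three logits themselves as the parameters (an assumption for illustration; in the real model the gradient flows back into the network weights) and uses the fact that the gradient of softmax cross-entropy with respect to the logits is softmax(z) − y. The learning rate of 0.5 is chosen only so that the change is visible in a few steps.

```python
import math

def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def cross_entropy(y, z):
    """Cross-entropy of one-hot label y against the softmax of logits z."""
    p = softmax(z)
    return -sum(yk * math.log(pk) for yk, pk in zip(y, p) if yk)

def gd_step(z, y, learning_rate):
    """One plain gradient-descent step: z <- z - lr * (softmax(z) - y)."""
    p = softmax(z)
    return [zk - learning_rate * (pk - yk) for zk, pk, yk in zip(z, p, y)]

y = [0, 1, 0]          # true class is index 1
z = [2.0, 0.0, 1.0]    # initial logits favour the wrong class
history = [cross_entropy(y, z)]
for _ in range(5):
    z = gd_step(z, y, learning_rate=0.5)
    history.append(cross_entropy(y, z))
```

Each step raises the logit of the true class and lowers the others, so the loss in history shrinks monotonically, the same trend the epoch losses show during retraining.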
- Train the model and, once training has finished, run the predictions in the same session:
n_epochs = 10
coco.y_onehot = True
coco.batch_size = 32
coco.batch_shuffle = True
total_images = len(x_train_files)
n_batches = total_images // coco.batch_size
with tf.Session() as tfs:
    # initialize all variables, then overwrite the pretrained ones from the checkpoint
    tfs.run(tf.global_variables_initializer())
    init_fn(tfs)
    for epoch in range(n_epochs):
        print('Starting epoch ', epoch)
        coco.reset_index()
        epoch_loss = 0
        for batch in range(n_batches):
            x_batch, y_batch = coco.next_batch()
            images = np.array([coco.preprocess_for_inception(x)
                               for x in x_batch])
            feed_dict = {x_p: images, y_p: y_batch, is_training: True}
            batch_loss, _ = tfs.run([loss, train_op],
                                    feed_dict=feed_dict)
            epoch_loss += batch_loss
        epoch_loss /= n_batches
        print('Train loss in epoch {}:{}'
              .format(epoch, epoch_loss))

    # now run the predictions in inference mode
    feed_dict = {x_p: images_test, is_training: False}
    probs = tfs.run([probabilities], feed_dict=feed_dict)
    probs = probs[0]
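The loop above relies on the dataset helper's reset_index() and next_batch(), whose implementation is not shown here. The class below is a hypothetical stand-in, not the book's coco code, that sketches the behaviour the loop assumes: reshuffle at the start of each epoch and hand out fixed-size batches. Since n_batches uses integer division, the final total_images % batch_size items are not visited in that epoch.

```python
import random

class BatchIterator:
    """Hypothetical stand-in for a dataset helper with reset_index()/next_batch()."""

    def __init__(self, items, batch_size, shuffle=True, seed=0):
        self.items = list(items)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)
        self.reset_index()

    def reset_index(self):
        # start a new epoch: rebuild and optionally reshuffle the visiting order
        self.order = list(range(len(self.items)))
        if self.shuffle:
            self.rng.shuffle(self.order)
        self.pos = 0

    def next_batch(self):
        # return the next batch_size items in the current order
        idx = self.order[self.pos:self.pos + self.batch_size]
        self.pos += self.batch_size
        return [self.items[i] for i in idx]

items = list(range(100))
it = BatchIterator(items, batch_size=32)
n_batches = len(items) // it.batch_size  # 3 full batches, 4 items left over
seen = []
for _ in range(n_batches):
    seen.extend(it.next_batch())
```

Reshuffling each epoch means the skipped remainder changes from epoch to epoch, so every image still contributes to training over several epochs.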
We can see that the loss decreases with every epoch:
INFO:tensorflow:Restoring parameters from /home/armando/models/inception_v3/inception_v3.ckpt
Starting epoch 0
Train loss in epoch 0:2.7896385192871094
Starting epoch 1
Train loss in epoch 1:1.6651896286010741
Starting epoch 2
Train loss in epoch 2:1.2332031989097596
Starting epoch 3
Train loss in epoch 3:0.9912329530715942
Starting epoch 4
Train loss in epoch 4:0.8110128355026245
Starting epoch 5
Train loss in epoch 5:0.7177265572547913
Starting epoch 6
Train loss in epoch 6:0.6175705575942994
Starting epoch 7
Train loss in epoch 7:0.5542363750934601
Starting epoch 8
Train loss in epoch 8:0.523461252450943
Starting epoch 9
Train loss in epoch 9:0.4923107647895813
This time the model correctly identifies the sheep, but it misclassifies the cat image as a dog:
Probability 98.84% of [zebra]
Probability 0.84% of [giraffe]
Probability 0.11% of [sheep]
Probability 0.07% of [cat]
Probability 0.06% of [dog]

Probability 95.77% of [horse]
Probability 1.34% of [dog]
Probability 0.89% of [zebra]
Probability 0.68% of [bird]
Probability 0.61% of [sheep]

Probability 94.83% of [dog]
Probability 4.53% of [cat]
Probability 0.56% of [sheep]
Probability 0.04% of [bear]
Probability 0.02% of [zebra]

Probability 42.80% of [bird]
Probability 25.64% of [cat]
Probability 15.56% of [bear]
Probability 8.77% of [giraffe]
Probability 3.39% of [sheep]

Probability 72.58% of [sheep]
Probability 8.40% of [bear]
Probability 7.64% of [giraffe]
Probability 4.02% of [horse]
Probability 3.65% of [bird]

Probability 98.03% of [bear]
Probability 0.74% of [cat]
Probability 0.54% of [sheep]
Probability 0.28% of [bird]
Probability 0.17% of [horse]

Probability 96.43% of [giraffe]
Probability 1.78% of [bird]
Probability 1.10% of [sheep]
Probability 0.32% of [zebra]
Probability 0.14% of [bear]

Probability 34.43% of [horse]
Probability 23.53% of [dog]
Probability 16.03% of [zebra]
Probability 9.76% of [cat]
Probability 9.02% of [giraffe]
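Each group of five lines above is the top-5 prediction list for one test image. The printing code is not shown in this section, so the following is only a sketch of how such lines could be produced from one row of probs; the function name top_k_lines and the small probability vector are made up for illustration and are not results from the run.

```python
def top_k_lines(probs, class_names, k=5):
    """Format the k most likely classes, highest probability first."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return ['Probability {:.2f}% of [{}]'.format(probs[i] * 100, class_names[i])
            for i in ranked]

# illustrative values only
class_names = ['bear', 'bird', 'cat', 'dog']
lines = top_k_lines([0.05, 0.10, 0.60, 0.25], class_names, k=3)
for line in lines:
    print(line)
```

Ranking indices by probability, rather than sorting the probabilities directly, keeps each score paired with its class name.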