使用保护程序类保存和恢复所选变量
默认情况下,Saver()
类将所有变量保存在图中,但您可以通过将变量列表传递给Saver()
类的构造函数来选择要保存的变量:
# create saver object
saver = tf.train.Saver({'weights': w})
变量名称可以作为列表或字典传递。如果变量名称作为列表传递,则列表中的每个变量将以其自己的名称保存。变量也可以作为由键值对组成的字典传递,其中键是用于保存的名称,值是要保存的变量的名称。
以下是我们刚看到的示例的代码,但这次我们只保存w
变量的权重;保存时将其命名为weights
:
# Saving selected variables in a graph in TensorFlow
# Assume Linear Model y = w * x + b
# Define model parameters
w = tf.Variable([.3], tf.float32)
b = tf.Variable([-.3], tf.float32)
# Define model input and output
x = tf.placeholder(tf.float32)
y = w * x + b
output = 0
# create saver object
saver = tf.train.Saver({'weights': w})
with tf.Session() as tfs:
# initialize and print the variable y
tfs.run(tf.global_variables_initializer())
output = tfs.run(y,{x:[1,2,3,4]})
saved_model_file = saver.save(tfs,
'saved-models/weights-save-example.ckpt')
print('Model saved in {}'.format(saved_model_file))
print('Values of variables w,b: {}{}'
.format(w.eval(),b.eval()))
print('output={}'.format(output))
我们得到以下输出:
Model saved in saved-models/weights-save-example.ckpt
Values of variables w,b: [ 0.30000001][-0.30000001]
output=[ 0\. 0.30000001 0.60000002 0.90000004]
检查点文件仅保存权重而不是偏差。现在让我们将偏差和权重初始化为零,并恢复权重。此示例的代码在此处给出:
# Restoring selected variables in a graph in TensorFlow
tf.reset_default_graph()
# Assume Linear Model y = w * x + b
# Define model parameters
w = tf.Variable([0], dtype=tf.float32)
b = tf.Variable([0], dtype=tf.float32)
# Define model input and output
x = tf.placeholder(dtype=tf.float32)
y = w * x + b
output = 0
# create saver object
saver = tf.train.Saver({'weights': w})
with tf.Session() as tfs:
b.initializer.run()
saved_model_file = saver.restore(tfs,
'saved-models/weights-save-example.ckpt')
print('Values of variables w,b: {}{}'
.format(w.eval(),b.eval()))
output = tfs.run(y,{x:[1,2,3,4]})
print('output={}'.format(output))
如您所见,这次我们必须使用b.initializer.run()
初始化偏差。我们不使用tfs.run(tf.global_variables_initializer())
因为它会初始化所有变量,并且不需要初始化权重,因为它们将从检查点文件中恢复。
我们得到以下输出,因为计算仅使用恢复的权重,而偏差设置为零:
INFO:tensorflow:Restoring parameters from saved-models/weights-save-example.ckpt
Values of variables w,b: [ 0.30000001][ 0.]
output=[ 0.30000001 0.60000002 0.90000004 1.20000005]