使用保护程序类保存和恢复所有图变量
我们进行如下:
- 要使用
saver
类,首先要创建此类的对象:
saver = tf.train.Saver()
- 保存图中所有变量的最简单方法是使用以下两个参数调用
save()
方法:会话对象和磁盘上保存变量的文件的路径:
with tf.Session() as tfs:
...
saver.save(tfs,"saved-models/model.ckpt")
- 要恢复变量,调用
restore()
方法:
with tf.Session() as tfs:
saver.restore(tfs,"saved-models/model.ckpt")
...
- 让我们重温一下第 1 章,TensorFlow 101 的例子,在简单的例子中保存变量的代码如下:
# Assume Linear Model y = w * x + b
# Define model parameters
w = tf.Variable([.3], tf.float32)
b = tf.Variable([-.3], tf.float32)
# Define model input and output
x = tf.placeholder(tf.float32)
y = w * x + b
output = 0
# create saver object
saver = tf.train.Saver()
with tf.Session() as tfs:
# initialize and print the variable y
tfs.run(tf.global_variables_initializer())
output = tfs.run(y,{x:[1,2,3,4]})
saved_model_file = saver.save(tfs,
'saved-models/full-graph-save-example.ckpt')
print('Model saved in {}'.format(saved_model_file))
print('Values of variables w,b: {}{}'
.format(w.eval(),b.eval()))
print('output={}'.format(output))
我们得到以下输出:
Model saved in saved-models/full-graph-save-example.ckpt
Values of variables w,b: [ 0.30000001][-0.30000001]
output=[ 0\. 0.30000001 0.60000002 0.90000004]
- 现在让我们从刚刚创建的检查点文件中恢复变量:
# Assume Linear Model y = w * x + b
# Define model parameters
w = tf.Variable([0], dtype=tf.float32)
b = tf.Variable([0], dtype=tf.float32)
# Define model input and output
x = tf.placeholder(dtype=tf.float32)
y = w * x + b
output = 0
# create saver object
saver = tf.train.Saver()
with tf.Session() as tfs:
saved_model_file = saver.restore(tfs,
'saved-models/full-graph-save-example.ckpt')
print('Values of variables w,b: {}{}'
.format(w.eval(),b.eval()))
output = tfs.run(y,{x:[1,2,3,4]})
print('output={}'.format(output))
您会注意到在恢复代码中我们没有调用tf.global_variables_initializer()
,因为不需要初始化变量,因为它们将从文件中恢复。我们得到以下输出,它是根据恢复的变量计算得出的:
INFO:tensorflow:Restoring parameters from saved-models/full-graph-save-example.ckpt
Values of variables w,b: [ 0.30000001][-0.30000001]
output=[ 0\. 0.30000001 0.60000002 0.90000004]