使用保护程序类保存和恢复所选变量

默认情况下,Saver()类将所有变量保存在图中,但您可以通过将变量列表传递给Saver()类的构造函数来选择要保存的变量:

# create saver object
saver = tf.train.Saver({'weights': w})

变量名称可以作为列表或字典传递。如果变量名称作为列表传递,则列表中的每个变量将以其自己的名称保存。变量也可以作为由键值对组成的字典传递,其中键是用于保存的名称,值是要保存的变量的名称。

以下是我们刚看到的示例的代码,但这次我们只保存w变量的权重;保存时将其命名为weights

# Saving selected variables in a graph in TensorFlow

# Assume Linear Model y = w * x + b
# Define model parameters
w = tf.Variable([.3], tf.float32)
b = tf.Variable([-.3], tf.float32)
# Define model input and output
x = tf.placeholder(tf.float32)
y = w * x + b
output = 0

# create saver object
saver = tf.train.Saver({'weights': w})

with tf.Session() as tfs:
    # initialize and print the variable y
    tfs.run(tf.global_variables_initializer())
    output = tfs.run(y,{x:[1,2,3,4]})
    saved_model_file = saver.save(tfs,
        'saved-models/weights-save-example.ckpt')
    print('Model saved in {}'.format(saved_model_file))
    print('Values of variables w,b: {}{}'
        .format(w.eval(),b.eval()))
    print('output={}'.format(output))

我们得到以下输出:

Model saved in saved-models/weights-save-example.ckpt
Values of variables w,b: [ 0.30000001][-0.30000001]
output=[ 0\.          0.30000001  0.60000002  0.90000004]

检查点文件仅保存权重而不是偏差。现在让我们将偏差和权重初始化为零,并恢复权重。此示例的代码在此处给出:

# Restoring selected variables in a graph in TensorFlow
tf.reset_default_graph()
# Assume Linear Model y = w * x + b
# Define model parameters
w = tf.Variable([0], dtype=tf.float32)
b = tf.Variable([0], dtype=tf.float32)
# Define model input and output
x = tf.placeholder(dtype=tf.float32)
y = w * x + b
output = 0

# create saver object
saver = tf.train.Saver({'weights': w})

with tf.Session() as tfs:
    b.initializer.run()
    saved_model_file = saver.restore(tfs,
        'saved-models/weights-save-example.ckpt')
    print('Values of variables w,b: {}{}'
        .format(w.eval(),b.eval()))
    output = tfs.run(y,{x:[1,2,3,4]})
    print('output={}'.format(output))

如您所见,这次我们必须使用b.initializer.run()初始化偏差。我们不使用tfs.run(tf.global_variables_initializer())因为它会初始化所有变量,并且不需要初始化权重,因为它们将从检查点文件中恢复。

我们得到以下输出,因为计算仅使用恢复的权重,而偏差设置为零:

INFO:tensorflow:Restoring parameters from saved-models/weights-save-example.ckpt
Values of variables w,b: [ 0.30000001][ 0.]
output=[ 0.30000001  0.60000002  0.90000004  1.20000005]

results matching ""

    No results matching ""