使用tf.get_variable()
获取变量
如果使用之前定义的名称定义变量,则 TensorFlow 会抛出异常。因此,使用tf.get_variable()
函数代替tf.Variable()
很方便。函数tf.get_variable()
返回具有相同名称的现有变量(如果存在),并创建具有指定形状的变量和初始化器(如果它不存在)。例如:
w = tf.get_variable(name='w',shape=[1],dtype=tf.float32,initializer=[.3])
b = tf.get_variable(name='b',shape=[1],dtype=tf.float32,initializer=[-.3])
初始化程序可以是上面示例中显示的张量或值列表,也可以是内置初始化程序之一:
tf.constant_initializer
tf.random_normal_initializer
tf.truncated_normal_initializer
tf.random_uniform_initializer
tf.uniform_unit_scaling_initializer
tf.zeros_initializer
tf.ones_initializer
tf.orthogonal_initializer
在分布式 TensorFlow 中,我们可以跨机器运行代码,tf.get_variable()
为我们提供了全局变量。要获取局部变量,TensorFlow 具有类似签名的函数:tf.get_local_variable()
。
共享或重用变量:获取已定义的变量可促进重用。但是,如果未使用tf.variable_scope.reuse_variable()
或tf.variable.scope(reuse=True)
设置重用标志,则会抛出异常。
现在您已经学会了如何定义张量,常量,运算,占位符和变量,让我们了解 TensorFlow 中的下一级抽象,它将这些基本元素组合在一起形成一个基本的计算单元,即数据流图或计算图。