使用tf.get_variable()获取变量

如果使用之前定义的名称定义变量,则 TensorFlow 会抛出异常。因此,使用tf.get_variable()函数代替tf.Variable()很方便。函数tf.get_variable()返回具有相同名称的现有变量(如果存在),并创建具有指定形状的变量和初始化器(如果它不存在)。例如:

w = tf.get_variable(name='w',shape=[1],dtype=tf.float32,initializer=[.3])
b = tf.get_variable(name='b',shape=[1],dtype=tf.float32,initializer=[-.3])

初始化程序可以是上面示例中显示的张量或值列表,也可以是内置初始化程序之一:

  • tf.constant_initializer
  • tf.random_normal_initializer
  • tf.truncated_normal_initializer
  • tf.random_uniform_initializer
  • tf.uniform_unit_scaling_initializer
  • tf.zeros_initializer
  • tf.ones_initializer
  • tf.orthogonal_initializer

在分布式 TensorFlow 中,我们可以跨机器运行代码,tf.get_variable()为我们提供了全局变量。要获取局部变量,TensorFlow 具有类似签名的函数:tf.get_local_variable()

共享或重用变量:获取已定义的变量可促进重用。但是,如果未使用tf.variable_scope.reuse_variable()tf.variable.scope(reuse=True)设置重用标志,则会抛出异常。

现在您已经学会了如何定义张量,常量,运算,占位符和变量,让我们了解 TensorFlow 中的下一级抽象,它将这些基本元素组合在一起形成一个基本的计算单元,即数据流图或计算图。

results matching ""

    No results matching ""