使用tf.Print()打印张量值

为调试目的打印值的另一个选项是使用tf.Print()。当执行包含tf.Print()节点的路径时,您可以在tf.Print()中包含张量以在标准错误控制台中打印其值。 tf.Print()函数具有以下签名:

tf.Print(
     input_,
     data,
     message=None,
     first_n=None,
     summarize=None,
     name=None
    )

该函数的参数如下:

  • input_是一个从函数返回的张量,没有任何操作
  • data是要打印的张量列表
  • message是一个字符串,它作为打印输出的前缀打印出来
  • first_n表示打印输出的步骤数;如果此值为负,则只要执行路径,就始终打印该值
  • summarize表示从张量打印的元素数量;默认情况下,仅打印三个元素

您可以按照 Jupyter 笔记本中的代码ch-18_TensorFlow_Debugging

让我们修改之前创建的 MNIST MLP 模型来添加 print 语句:

model = tf.Print(input_=model,
                 data=[tf.argmax(model,1)],
                 message='y_hat=',
                 summarize=10,
                 first_n=5
  )

当我们运行代码时,我们在 Jupyter 的控制台中获得以下内容:

I tensorflow/core/kernels/logging_ops.cc:79] y_hat=[0 0 0 7 0 0 0 0 0 0...]
I tensorflow/core/kernels/logging_ops.cc:79] y_hat=[0 7 7 1 8 7 2 7 7 0...]
I tensorflow/core/kernels/logging_ops.cc:79] y_hat=[4 8 0 6 1 8 1 0 7 0...]
I tensorflow/core/kernels/logging_ops.cc:79] y_hat=[0 0 1 0 0 0 0 5 7 5...]
I tensorflow/core/kernels/logging_ops.cc:79] y_hat=[9 2 2 8 8 6 6 1 7 7...]

使用tf.Print()的唯一缺点是该函数提供了有限的格式化功能。

results matching ""

    No results matching ""