使用tf.Print()
打印张量值
为调试目的打印值的另一个选项是使用tf.Print()
。当执行包含tf.Print()
节点的路径时,您可以在tf.Print()
中包含张量以在标准错误控制台中打印其值。 tf.Print()
函数具有以下签名:
tf.Print(
input_,
data,
message=None,
first_n=None,
summarize=None,
name=None
)
该函数的参数如下:
input_
是一个从函数返回的张量,没有任何操作data
是要打印的张量列表message
是一个字符串,它作为打印输出的前缀打印出来first_n
表示打印输出的步骤数;如果此值为负,则只要执行路径,就始终打印该值summarize
表示从张量打印的元素数量;默认情况下,仅打印三个元素
您可以按照 Jupyter 笔记本中的代码ch-18_TensorFlow_Debugging
。
让我们修改之前创建的 MNIST MLP 模型来添加 print 语句:
model = tf.Print(input_=model,
data=[tf.argmax(model,1)],
message='y_hat=',
summarize=10,
first_n=5
)
当我们运行代码时,我们在 Jupyter 的控制台中获得以下内容:
I tensorflow/core/kernels/logging_ops.cc:79] y_hat=[0 0 0 7 0 0 0 0 0 0...]
I tensorflow/core/kernels/logging_ops.cc:79] y_hat=[0 7 7 1 8 7 2 7 7 0...]
I tensorflow/core/kernels/logging_ops.cc:79] y_hat=[4 8 0 6 1 8 1 0 7 0...]
I tensorflow/core/kernels/logging_ops.cc:79] y_hat=[0 0 1 0 0 0 0 5 7 5...]
I tensorflow/core/kernels/logging_ops.cc:79] y_hat=[9 2 2 8 8 6 6 1 7 7...]
使用tf.Print()
的唯一缺点是该函数提供了有限的格式化功能。