tf.Assert()
调试 TensorFlow 模型的另一种方法是插入条件断言。tf.Assert()
函数需要一个条件,如果条件为假,则打印给定张量的列表并抛出 tf.errors.InvalidArgumentError
。
tf.Assert()
函数具有以下特征:
tf.Assert(
condition,
data,
summarize=None,
name=None
)
- 断言操作不会像
tf.Print()
函数那样落入图的路径中。为了确保tf.Assert()
操作得到执行,我们需要将它添加到依赖项中。例如,让我们定义一个断言来检查所有输入是否为正:
assert_op = tf.Assert(tf.reduce_all(tf.greater_equal(x,0)),[x])
- 在定义模型时将
assert_op
添加到依赖项,如下所示:
with tf.control_dependencies([assert_op]):
# x is input layer
layer = x
# add hidden layers
for i in range(num_layers):
layer = tf.nn.relu(tf.matmul(layer, w[i]) + b[i])
# add output layer
layer = tf.matmul(layer, w[num_layers]) + b[num_layers]
- 为了测试这段代码,我们在第 5 周期之后引入了一个杂质,如下:
if epoch > 5:
X_batch = np.copy(X_batch)
X_batch[0,0]=-2
- 代码运行正常五个周期,然后抛出错误:
epoch: 0000 loss = 6.975991
epoch: 0001 loss = 2.246228
epoch: 0002 loss = 1.924571
epoch: 0003 loss = 1.745509
epoch: 0004 loss = 1.616791
epoch: 0005 loss = 1.520804
-----------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
...
InvalidArgumentError: assertion failed: [[-2 0 0]...]
...
除了tf.Assert()
函数,它可以采用任何有效的条件表达式,TensorFlow 提供以下断言操作,检查特定条件并具有简单的语法:
assert_equal
assert_greater
assert_greater_equal
assert_integer
assert_less
assert_less_equal
assert_negative
assert_none_equal
assert_non_negative
assert_non_positive
assert_positive
assert_proper_iterable
assert_rank
assert_rank_at_least
assert_rank_in
assert_same_float_dtype
assert_scalar
assert_type
assert_variables_initialized
作为示例,前面提到的示例断言操作也可以写成如下:
assert_op = tf.assert_greater_equal(x,0)