在下面的代码中,我创建了一个函数,其参数是两个张量和一个索引数组(与张量的最后一个轴相关)。该函数将遍历张量的最后一个轴,提取一个张量切片,执行一些操作(此处不包括),最后构建一个新的张量。输入索引只是指示每次应该处理哪个张量。非常简单
代码运行良好。但是,我注意到,如果我在if-else的两个语句中都放置print(),它将始终被执行,而在输出中我可以看到只有正确的语句被执行。为什么呢
@tf.function
def build_mixed_tensor(x, x2, ixCh):
# Make a boolean mask for the channels axis (True if channel selected)
mask = tf.equal(tf.range(0,x.shape[-1]), tf.expand_dims(ixCh, 1))
mask = tf.reduce_any(mask, axis=0)
# Deal with the channels
for ii in range(0, x.shape[-1]):
if mask[ii]:
print('This is for True')
selChn = tf.gather(x, [ii], axis=-1)
# Do some operations ...
print(selChn.shape)
else:
print('This is for False')
selChn = tf.gather(x2, [ii], axis=-1)
# Do some operations ...
print(selChn.shape)
if ii == 0:
outChn = selChn
else:
outChn = tf.concat([outChn, selChn], axis=2)
print(outChn.shape)
return outChn
# Create two tensors
inp1 = tf.reshape(tf.range(3 * 4 * 5, dtype=tf.int32), [3, 4, 5])
inp2 = tf.reshape(tf.range(100, 100+(3 * 4 * 5), dtype=tf.int32), [3, 4, 5])
updated = build_mixed_tensor(inp1, inp2, [1,2])
with tf.Session() as sess:
print('This is the mixed tensor:')
print(sess.run(updated))
输出为:
This is for True
(3, 4, 1)
This is for False
(3, 4, 1)
(3, 4, 1)
This is for True
(3, 4, 1)
This is for False
(3, 4, 1)
(3, 4, 2)
This is for True
(3, 4, 1)
This is for False
(3, 4, 1)
(3, 4, 3)
This is for True
(3, 4, 1)
This is for False
(3, 4, 1)
(3, 4, 4)
This is for True
(3, 4, 1)
This is for False
(3, 4, 1)
(3, 4, 5)
This is the mixed tensor:
[[[100 1 2 103 104]
[105 6 7 108 109]
[110 11 12 113 114]
[115 16 17 118 119]]
[[120 21 22 123 124]
[125 26 27 128 129]
[130 31 32 133 134]
[135 36 37 138 139]]
[[140 41 42 143 144]
[145 46 47 148 149]
[150 51 52 153 154]
[155 56 57 158 159]]]
好问题!这是因为第一次调用带有
@tf.function
decorator的函数时,该函数是跟踪的,因此以后只需要执行跟踪图(通常更快)。但是,这意味着要跟踪两个分支,以便在运行时立即选择正确的分支只有当条件位于
Tensor
上时才会发生这种情况(在您的案例中就是这样)。如果它是一个“常规”Python bool,则只跟踪具有真实条件的分支,但这并不真正适用于您的示例,因为bool依赖于张量,因此它也将是张量我强烈建议您查看TF官方网站上的this tutorial,该网站描述了这一点以及其他一些特性
相关问题 更多 >
编程相关推荐