java在tensorflow中沿轴选择随机数目的点
我有一个大小为(X,Y,N)的掩码的三维数组,每个值为1或0。当值为1时,我们希望在最后一个维度中收集索引
X, Y, N = 10, 10, 10
# Each point in ar is 1 or 0
ar = tf.random.uniform((X, Y, N), maxval=2, seed=1, name=None, dtype=tf.int32)
# We now want to collect 4 point indices along last dimension if the corresponding value is 1
当相应的值为1时,我想沿着第三维对n(=4)个索引进行采样。我如何在tensorflow中做到这一点?我的函数的输出应该是形状(X,Y,4)
if output[x, y] = [n1, n2, n3, n4] then
ar[x,y, n1] = 1
ar[x,y, n2] = 1
ar[x,y, n3] = 1
ar[x,y, n4] = 1
...
...
...
# 1 楼答案
我发现你可能在找^{} 。当值为
1
时,保持概率相等,当值为0
时,保持概率为tf.log(0.0)
你还需要使用
tf.map_fn
来包装tf.multinomial
,因为tf.multinomial
的对数概率需要二维张量请注意
tf.multinomial
将在新版本中删除。更新说明:使用tf。随机的而是直截了当的