我想填充3D tf.SparseTensor
的上三角形。
为了获得正确的索引,我使用triu_indices
triu_indices(3, 1)
<tf.Tensor: shape=(3, 2), dtype=int64, numpy=
array([[0, 1],
[0, 2],
[1, 2]], dtype=int64)>
现在我想重复这个n次,并将索引添加到其中。 我发现了以下工作,但我想知道,是否有一个更具python风格,也许更快的方法来完成这项工作
a = tf.range(3, dtype=tf.int64)
a = tf.expand_dims(tf.repeat(tf.expand_dims(a, 1), len(b[0]), axis=1), 2)
<tf.Tensor: shape=(3, 3, 1), dtype=int64, numpy=
array([[[0],
[0],
[0]],
[[1],
[1],
[1]],
[[2],
[2],
[2]]], dtype=int64)>
然后连接得到填充SparseTensor所需的索引
tf.concat([a, b], axis=2)
<tf.Tensor: shape=(3, 3, 3), dtype=int64, numpy=
array([[[0, 0, 1],
[0, 0, 2],
[0, 1, 2]],
[[1, 0, 1],
[1, 0, 2],
[1, 1, 2]],
[[2, 0, 1],
[2, 0, 2],
[2, 1, 2]]], dtype=int64)>
目前没有回答
相关问题 更多 >
编程相关推荐