三维生成指数的Tensorflow SparseTensor

2024-09-29 06:28:09 发布

您现在位置:Python中文网/ 问答频道 /正文

我想填充3D tf.SparseTensor的上三角形。 为了获得正确的索引,我使用triu_indices

triu_indices(3, 1)
<tf.Tensor: shape=(3, 2), dtype=int64, numpy=
array([[0, 1],
       [0, 2],
       [1, 2]], dtype=int64)>

现在我想重复这个n次,并将索引添加到其中。 我发现了以下工作,但我想知道,是否有一个更具python风格,也许更快的方法来完成这项工作

a = tf.range(3, dtype=tf.int64)
a = tf.expand_dims(tf.repeat(tf.expand_dims(a, 1), len(b[0]), axis=1), 2)
<tf.Tensor: shape=(3, 3, 1), dtype=int64, numpy=
array([[[0],
        [0],
        [0]],

       [[1],
        [1],
        [1]],

       [[2],
        [2],
        [2]]], dtype=int64)>

然后连接得到填充SparseTensor所需的索引

tf.concat([a, b], axis=2)
<tf.Tensor: shape=(3, 3, 3), dtype=int64, numpy=
array([[[0, 0, 1],
        [0, 0, 2],
        [0, 1, 2]],

       [[1, 0, 1],
        [1, 0, 2],
        [1, 1, 2]],

       [[2, 0, 1],
        [2, 0, 2],
        [2, 1, 2]]], dtype=int64)>

Tags: numpy风格tfarrayexpandtensorshapedtype