Tensorflow:十的交叉索引切片

2024-09-29 23:16:09 发布

您现在位置:Python中文网/ 问答频道 /正文

我有两个张量,形状如下:

tensor1 => shape(10, 99, 106)
tensor2 => shape(10, 99)

tensor2包含从0 - 105tensor1的值,我希望使用这些值来剖切tensor1的最后一个维度并获得形状的tensor3

tensor3 => shape(10, 99, 99)

我试过使用:

tensor4 = tf.gather(tensor1, tensor2)
# this causes tensor4 to be of shape (10, 99, 99, 106)

另外,使用

tensor4 = tf.gather_nd(tensor1, tensor2)
# gives the error: last dimension of tensor2 (which is 99) must be 
# less than the rank of the tensor1 (which is 3).

我要找的东西类似于numpy的交叉索引


Tags: ofthewhichistfbethis形状
1条回答
网友
1楼 · 发布于 2024-09-29 23:16:09

您可以使用^{}

 tensor3 = tf.map_fn(lambda u: tf.gather(u[0],u[1],axis=1),[tensor1,tensor2],dtype=tensor1.dtype)

你可以把这一行看作是一个循环,它在tensor1tensor2的第一个维度上运行,对于它们的第一个维度中的每个索引i,它在tensor1[i,:,:]tensor2[i,:]上应用tf.gather

相关问题 更多 >

    热门问题