在TensorF中提取子传感器

2024-10-05 12:27:21 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个2×4tensorA = [[0,1,0,1],[1,0,1,0]]。 我想从维度d中提取索引I。 在手电筒里我能做到:tensorA:select(d,i)。在

例如,tensorA:select(0,0)将返回[0,1,0,1],并且 tensorA:select(1,1)将返回[1,0]。在

如何在TensorFlow中执行此操作? 我能找到的最简单的方法是:tf.gather(tensorA, indices=[i], axis=d)

但是,使用聚集似乎有点过头了。有人知道更好的方法吗?在


Tags: 方法tftensorflowselectaxisindicesgather手电筒
2条回答

您可以使用以下配方:

用分号替换除d外的所有轴,并在d轴上输入值i,例如:

tensorA[0, :]  # same as tensorA:select(0,0)
tensorA[:, 1]  # same as tensorA:select(1,1)
tensorA[:, 0]  # same as tensorA:select(1,0)

然而,当我尝试这个的时候,我有一个语法错误:

^{pr2}$

所以我用切片代替

i = 1
selection = [slice(0,2,1), i]
tensorA[selection]  # same as tensorA:select(1,i)

此函数的作用是:

def select(t, axis, index):
    shape = K.int_shape(t)
    selection = [slice(shape[a]) if a != axis else index for a in 
                 range(len(shape))]
    return t[selection]

例如:

import numpy as np
t = K.constant(np.arange(60).reshape(2,5,6))
sub_tensor = select(t, 1, 1)
print(K.eval(sub_tensor)  

印刷品

[[6., 7., 8., 9., 10., 11.],

[36., 37., 38., 39., 40., 41.]]

您只需使用value = tensorA[d,i]。在引擎盖下,tensorflow呼叫 ^{}

相关问题 更多 >

    热门问题