我有N组C维点。每组有M个点。所以,有一个张量(N,M,C)。让我们称之为功能
我通过M维计算最大元素和索引,以找到每个C维的最大点(最大池操作),得到最大张量(N,1,C)和索引张量(N,1,C)
我有另一个形状的张量(N,M,3)存储这些N*M高维点的几何坐标。现在,我想使用每个C维中最大点的索引,来获得所有这些最大点的坐标
例如,N=2,M=4,C=6
坐标张量,其形状为(2,4,3):
[[[1, 2, 3]
[4, 5, 6]
[7, 8, 9]
[8, 7, 6]]
[11, 12, 13]
[14, 15, 16]
[17, 18, 19]
[18, 17, 16]]]
指数张量,其形状为(2,1,6):
[[[0, 1, 2, 1, 2, 3]]
[[1, 2, 3, 2, 1, 0]]]
例如,索引中的第一个元素是0,我想从坐标张量中获取[1,2,3]。对于第二个元素(1),我想把[4,5,6]拿出来。对于下一维度(3)中的第三个元素,我想把[18,17,16]拿出来
结果张量如下所示:
[[[1, 2, 3] # 0
[4, 5, 6] # 1
[7, 8, 9] # 2
[4, 5, 6] # 1
[7, 8, 9] # 2
[8, 7, 6]] # 3
[[14, 15, 16] # 1
[17, 18, 19] # 2
[18, 17, 16] # 3
[17, 18, 19] # 2
[14, 15, 16] # 1
[11, 12, 13]]]# 0
其形状为(2,6,3)
我试着用手电筒,但没法用。我编写了一个简单的算法来枚举所有N个组,但实际上它很慢,即使使用TorchScript的JIT。那么,如何在pytorch中高效地编写这篇文章呢
您可以使用integer array indexing与broadcasting semantics组合来获得结果
相关问题 更多 >
编程相关推荐