我有一个numpy数组indices
:
array([[ 0, 0, 0],
[ 0, 0, 0],
[ 2, 0, 2],
[ 0, 0, 0],
[ 2, 0, 2],
[95, 71, 95]])
我有另一个相同长度的数组叫做distances
:
array([ 0.98713981, 1.04705992, 1.42340327, 74.0139111 ,
74.4285216 , 74.84623217])
indices
中的所有行在distances
数组中都有匹配项。问题是,indices
数组中存在重复项,它们在相应的distances
数组中具有不同的值。我想得到所有三元组索引的最小距离,并丢弃其他索引。因此,通过上面的输入,我想要输出:
indicesOUT =
array([[ 0, 0, 0],
[ 2, 0, 2],
[95, 71, 95]])
distancesOUT=
array([ 0.98713981, 1.42340327, 74.84623217])
我目前的策略如下:
import numpy as np
indicesOUT = []
distancesOUT = []
for i in range(6):
for j in range(6):
for k in range(6):
if len([s for s in indicesOUT if [i,j,k] == s]) == 0:
current = np.array([i, j, k])
ind = np.where((indices == current).all(-1) == True)[0]
currentDistances = distances[ind]
dist = np.amin(distances)
indicesOUT.append([i, j, k])
distancesOUT.append(dist)
问题是,实际的数组每个都有大约400万个元素,所以这种方法太慢了。最有效的方法是什么?你知道吗
这本质上是一个分组操作,而NumPy并没有为此进行很好的优化。幸运的是,Pandas包有一些非常快速的工具,可以适应这个确切的问题。 利用以上数据,我们可以做到:
数据的输出是
我的基准测试显示,对于4000000个元素,这应该在大约一秒钟内运行:
如上所述,索引的输入顺序不一定会被保留;保持原始顺序需要更多的考虑。你知道吗
相关问题 更多 >
编程相关推荐