Python中欧氏距离平方的矢量化掩模

2024-09-29 23:22:15 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在运行代码来生成一个掩码,其中B中的位置比a中的某个距离D更近

N = [[0 for j in range(length_B)] for i in range(length_A)]    
dSquared = D*D

for i in range(length_A):
    for j in range(length_B):
        if ((A[j][0]-B[i][0])**2 + (A[j][1]-B[i][1])**2) <= dSquared:
            N[i][j] = 1

对于包含数万个位置的A和B列表,此代码需要一段时间。我很确定有一种方法可以将它矢量化,使它运行得更快。谢谢您。在


Tags: 方法代码in距离列表forifrange
3条回答

您可以使用^{},它对于这种距离计算是非常有效的,比如-

from scipy.spatial.distance import cdist

N = (cdist(A,B,'sqeuclidean') <= dSquared).astype(int)

正如^{}中所建议的那样,也可以使用broadcasting。现在,从问题中发布的代码来看,我们处理的是Nx2形状的数组。因此,我们基本上可以对第一列和第二列进行切片,并分别对它们执行广播减法。这样做的好处是,我们不会继续使用3D,因此可以保持它的内存效率,这也可以转化为性能提升。因此,平方欧几里德距离的计算如下-

^{pr2}$

让我们来计算欧几里德距离平方的这三种方法。在

运行时测试-

In [75]: # Input arrays
    ...: A = np.random.rand(200,2)
    ...: B = np.random.rand(200,2)
    ...: 

In [76]: %timeit ((A[:,None,:] - B[None,:,:])**2).sum(axis=-1) # @hpaulj's solution
1000 loops, best of 3: 1.9 ms per loop

In [77]: %timeit (A[:,None,0] - B[:,0])**2 + (A[:,None,1] - B[:,1])**2
1000 loops, best of 3: 401 µs per loop

In [78]: %timeit cdist(A,B,'sqeuclidean')
1000 loops, best of 3: 249 µs per loop

使用二维数组索引可以更容易地可视化此代码:

for j in range(length_A):
    for i in range(length_B):
        dist = (A[j,0] - B[i,0])**2 + (A[j,1] - B[i,1])**2 
        if dist <= dSquared:
            N[i, j] = 1

这个dist表达式看起来像

^{pr2}$

对于2个元素,这个数组表达式可能不会更快,但它应该有助于我们重新思考问题。在

我们可以用广播来执行i,joutter问题

A[:,None,:] - B[None,:,:]  # 3d difference array

dist=((A[:,None,:] - B[None,:,:])**2).sum(axis=-1)  # (lengthA,lengthB) array

将其与dSquared进行比较,并使用生成的布尔数组作为掩码,将N的元素设置为1:

N = np.zeros((lengthA,lengthB))
N[dist <= dSquared] = 1

我还没有测试过这段代码,所以可能有一些bug,但我认为基本思想已经存在了。也许你的思考过程足够让你为其他的案例找出细节。在

我支持上面使用Numpy的建议。循环代码在a中的索引也比它需要的多得多。您可以使用类似于:

import numpy as np

dimension = 10000
A = np.random.rand(dimension, 2) + 0.0
B = np.random.rand(dimension, 2) + 1.0
N = []
d = 1.0

for i in range(len(A)):
    distances = np.linalg.norm(B - A[i,:], axis=1)
    for j in range(len(distances)):
        if distances[j] <= d:
            N.append((i,j))

print(len(N))

要想从纯Python中获得良好的性能是相当困难的。我也会指出,更维的阵列解决方案将需要一个。。。很多。。。记忆。在

相关问题 更多 >

    热门问题