生成具有最小N值位置reset p的掩码数组

2024-10-01 07:12:03 发布

您现在位置:Python中文网/ 问答频道 /正文

给定一个二维距离数组,使用argsort生成一个索引数组,其中第一个元素是行中最低值的索引。使用索引仅选择前K列,例如K=3

position = np.random.randint(100, size=(5, 5))
array([[36, 63,  3, 78, 98],
   [75, 86, 63, 61, 79],
   [21, 12, 72, 27, 23],
   [38, 16, 17, 88, 29],
   [93, 37, 48, 88, 10]])
idx = position.argsort()
array([[2, 0, 1, 3, 4],
   [3, 2, 0, 4, 1],
   [1, 0, 4, 3, 2],
   [1, 2, 4, 0, 3],
   [4, 1, 2, 3, 0]])
idx[:,0:3]
array([[2, 0, 1],
   [3, 2, 0],
   [1, 0, 4],
   [1, 2, 4],
   [4, 1, 2]])

然后我想做的是创建一个屏蔽数组,当应用到原始位置数组时,它只返回产生k个最短距离的索引

我基于我在一维数组上找到的一些代码来实现这种方法

# https://glowingpython.blogspot.co.uk/2012/04/k-nearest-neighbor-search.html

from numpy import random, argsort, sqrt
from matplotlib import pyplot as plt    

def knn_search(x, D, K):
    """ find K nearest neighbours of data among D """
    ndata = D.shape[1]
    K = K if K < ndata else ndata
    # euclidean distances from the other points
    sqd = sqrt(((D - x[:, :ndata]) ** 2).sum(axis=0))
    idx = argsort(sqd)  # sorting
    # return the indexes of K nearest neighbours
    return idx[:K]

# knn_search test
data = random.rand(2, 5)  # random dataset
x = random.rand(2, 1)  # query point

# performing the search
neig_idx = knn_search(x, data, 2)

figure = plt.figure()
plt.scatter(data[0,:], data[1,:])
plt.scatter(x[0], x[1], c='g')
plt.scatter(data[0, neig_idx], data[1, neig_idx], c='r', marker = 'o')
plt.show()

Tags: thefromsearchdatapltrandom数组array
1条回答
网友
1楼 · 发布于 2024-10-01 07:12:03

有一个办法-

N = 3 # number of points to be set as False per row

# Slice out the first N cols per row
k_idx = idx[:,:N]

# Initialize output array
out = np.ones(position.shape, dtype=bool)

# Index into output with k_idx as col indices to reset
out[np.arange(k_idx.shape[0])[:,None], k_idx] = 0

最后一步涉及advanced-indexing,如果您是NumPy新手,这可能是一个很大的步骤,但基本上这里我们使用k_idx索引到列中,并且我们使用np.arange(k_idx.shape[0])[:,None]范围数组形成索引元组索引到行中。有关^{}的详细信息

我们可以通过使用^{}而不是argsort来提高性能,就像这样-

k_idx = np.argpartition(position, N)[:,:N]

示例输入、输出,用于将每行的最低3元素设置为False-

In [227]: position
Out[227]: 
array([[36, 63,  3, 78, 98],
       [75, 86, 63, 61, 79],
       [21, 12, 72, 27, 23],
       [38, 16, 17, 88, 29],
       [93, 37, 48, 88, 10]])

In [228]: out
Out[228]: 
array([[False, False, False,  True,  True],
       [False,  True, False, False,  True],
       [False, False,  True,  True, False],
       [ True, False, False,  True, False],
       [ True, False, False,  True, False]], dtype=bool)

相关问题 更多 >