如何使用scipy.spatial.KDTree.query\u ball\u point方法返回的索引从numpy数组中删除元素

2024-10-09 20:27:58 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试使用Kdtree数据结构从数组中删除最近的点,最好不要使用for循环

import sys

import time

import scipy.spatial

class KDTree:
    """
    Nearest neighbor search class with KDTree
    """

    def __init__(self, data):
        # store kd-tree
        self.tree = scipy.spatial.cKDTree(data)

    def search(self, inp, k=1):
        """
        Search NN
        inp: input data, single frame or multi frame
        """

        if len(inp.shape) >= 2:  # multi input
            index = []
            dist = []

            for i in inp.T:
                idist, iindex = self.tree.query(i, k=k)
                index.append(iindex)
                dist.append(idist)

            return index, dist

        dist, index = self.tree.query(inp, k=k)
        return index, dist

    def search_in_distance(self, inp, r):
        """
        find points with in a distance r
        """

        index = self.tree.query_ball_point(inp, r)
        return np.asarray(index)


import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
start = time.time()
fig, ar = plt.subplots()
t = 0
R = 50.0
u = R *np.cos(t)
v = R *np.sin(t)

x = np.linspace(-100,100,51)
y = np.linspace(-100,100,51)

xx, yy = np.meshgrid(x,y)
points =np.vstack((xx.ravel(),yy.ravel())).T
Tree = KDTree(points)
ind = Tree.search_in_distance([u, v],10.0)
ar.scatter(points[:,0],points[:,1],c='k',s=1)
infected = points[ind]
ar.scatter(infected[:,0],infected[:,1],c='r',s=5)

def animate(i):
    global R,t,start,points
    ar.clear()
    u = R *np.cos(t)
    v = R *np.sin(t)
    ind = Tree.search_in_distance([u, v],10.0)
    ar.scatter(points[:,0],points[:,1],c='k',s=1)
    infected = points[ind]
    ar.scatter(infected[:,0],infected[:,1],c='r',s=5)
    #points = np.delete(points,ind)
    t+=0.01
    end = time.time()
    if end - start != 0:
        print((end - start), end="\r")
        start = end
ani = animation.FuncAnimation(fig, animate, interval=20)
plt.show()  

但无论我做什么,我都无法让np.delete处理ball_查询方法返回的索引。我错过了什么

我想让红色的点在点数组的每次迭代中消失


Tags: inimportselftreesearchindextimedist
1条回答
网友
1楼 · 发布于 2024-10-09 20:27:58

您的points数组是一个Nx2矩阵。您的ind索引是行索引的列表。您需要的是指定需要删除的轴,最终如下所示:

points = np.delete(points,ind,axis=0)

此外,一旦删除索引,请在下一次迭代/计算中注意缺少的索引。可能需要一个副本来删除点和打印,另一个副本用于不从中删除的计算

相关问题 更多 >

    热门问题