更有效地获取最近的cen

2024-10-01 02:40:35 发布

您现在位置:Python中文网/ 问答频道 /正文

我的数据对象是以下对象的实例:

class data_instance:
    def __init__(self, data, tlabel):
        self.data = data # 1xd numpy array
        self.true_label = tlabel # integer {1,-1}

到目前为止,在代码中,我有一个名为data_history的列表,其中包含data_istance和一组centers(形状为(k,d)的numpy数组)。你知道吗

对于给定的数据实例new_data,我想要:

  • 1/从centers(欧几里德距离)得到最接近new_data的中心,称之为Nearest_center

  • 2/迭代槽data_history和:

    • 2.1/选择最近中心为Nearest_center(1/的结果)的元素到名为neighbors的列表中。你知道吗
    • 2.2/获取neighbors中对象的标签。你知道吗

贝娄是我的代码工作,但它钢慢,我正在寻找一些更有效的。你知道吗

我的代码

对于1/

def getNearestCenter(data,centers):

    if centers.shape != (1,2):
        dist_ = np.sqrt(np.sum(np.power(data-centers,2),axis=1)) # This compute distance between data and all centers

        center = centers[np.argmin(dist_)] # this return center which have the minimum distance from data

    else:
        center=centers[0]
    return center

对于2/(优化)

def getLabel(dataPoint, C, history):

    labels = []
    cluster = getNearestCenter(dataPoint.data,C)
    for x in history:
        if  np.all(getNearestCenter(x.data,C) == cluster):
            labels.append(x.true_label)
    return labels

Tags: 数据对象实例代码selfdatalabelsreturn
2条回答

您应该使用来自scipy.spatial的优化的cdist,这比使用numpy计算效率更高

from scipy.spatial.distance import cdist

dist = cdist(data, C, metric='euclidean')
dist_idx = np.argmin(dist, axis=1)

一个更优雅的解决方案是使用scipy.spatial.cKDTree(正如@Saullo Castro在评论中指出的那样),对于大型数据集来说,这可能更快

from scipy.spatial import cKDTree

tr = cKDTree(C)
dist, dist_idx = tr.query(data, k=1) 

找到了:

dist_ = np.argmin(np.sqrt(np.sum(np.power(data[:, None]-C,2),axis=2)),axis=1)

它应该从data的每个数据点返回centers中最近中心的索引。你知道吗

相关问题 更多 >