我的数据对象是以下对象的实例:
class data_instance:
def __init__(self, data, tlabel):
self.data = data # 1xd numpy array
self.true_label = tlabel # integer {1,-1}
到目前为止,在代码中,我有一个名为data_history
的列表,其中包含data_istance
和一组centers
(形状为(k,d)的numpy数组)。你知道吗
对于给定的数据实例new_data
,我想要:
1/从centers
(欧几里德距离)得到最接近new_data
的中心,称之为Nearest_center
。
2/迭代槽data_history
和:
Nearest_center
(1/的结果)的元素到名为neighbors
的列表中。你知道吗neighbors
中对象的标签。你知道吗贝娄是我的代码工作,但它钢慢,我正在寻找一些更有效的。你知道吗
我的代码
对于1/
def getNearestCenter(data,centers):
if centers.shape != (1,2):
dist_ = np.sqrt(np.sum(np.power(data-centers,2),axis=1)) # This compute distance between data and all centers
center = centers[np.argmin(dist_)] # this return center which have the minimum distance from data
else:
center=centers[0]
return center
对于2/(优化)
def getLabel(dataPoint, C, history):
labels = []
cluster = getNearestCenter(dataPoint.data,C)
for x in history:
if np.all(getNearestCenter(x.data,C) == cluster):
labels.append(x.true_label)
return labels
您应该使用来自
scipy.spatial
的优化的cdist
,这比使用numpy计算效率更高一个更优雅的解决方案是使用
scipy.spatial.cKDTree
(正如@Saullo Castro在评论中指出的那样),对于大型数据集来说,这可能更快找到了:
它应该从
data
的每个数据点返回centers
中最近中心的索引。你知道吗相关问题 更多 >
编程相关推荐