从scikitlearn/numpy中的群集中心计算标签?

2024-10-03 23:24:02 发布

您现在位置:Python中文网/ 问答频道 /正文

我需要从没有原始clusterer对象的不同数据集中的另一个clusterer生成的聚类中心计算标签

我知道我可以像这样用python硬编码

def compute_labels(centers,datapoints):
    ans=[]
    for point in datapoints:
        ans.append(
            min(
                ((i,np.linalg.norm(point-center)) for i,center in enumerate(centers)),
                key=lambda t:t[1]
            )[0]
        )
    return ans

但是,对于我的应用程序来说,它会很慢,我需要一个较低级别的实现,所以我想知道是否可以只使用scikit learn或numpy

我尝试的是:

from sklearn.cluster import KMeans
import numpy as np

np.random.seed(42)

datapoints1=np.random.rand(200,38)
datapoints2=np.random.rand(200,38)

kmeans1=KMeans(
    init="k-means++",
    random_state=42,
    n_init=100
 )
kmeans1=kmeans1.fit(datapoints1)

kmeans2=KMeans(
    init=kmeans1.cluster_centers_,
    max_iter=1,
    n_init=1
)
kmeans2.predict(datapoints2)
print((kmeans1.cluster_centers_==kmeans2.cluster_centers_).all())

但是它会在{}中引发{}异常,我尝试在kmeans kwargs中设置{},但它也会引发异常


Tags: inforinitnprandompointcentercluster
2条回答

sklearn正在检查一个名为check_is_fitted的函数,该函数正在查看模型的属性。在您的例子中,由于您从未调用过fit,因此某些属性不存在,这会触发错误。您可以通过手动创建它们来伪造它,例如:

from sklearn.cluster import KMeans
import numpy as np

np.random.seed(42)

datapoints1=np.random.rand(200,38)
datapoints2=np.random.rand(200,38)

kmeans1=KMeans(
    init="k-means++",
    random_state=42,
    n_init=100
 )
kmeans1=kmeans1.fit(datapoints1)

kmeans2=KMeans(
    init=kmeans1.cluster_centers_,
    max_iter=1,
    n_init=1
)

kmeans2.cluster_centers_ = kmeans1.cluster_centers_                # you have it
kmeans2.labels_ = kmeans1.labels_                                  # to test if required, no difference found
print([v for v in vars(kmeans2)
       if v.endswith("_") and not v.startswith("__")])             # if this list is empty, the model if not fitted, you can compare it to kmeans1

pred = kmeans2.predict(datapoints2)
print(pred)                                                        # [3 7 1 ... 2]
print((kmeans1.cluster_centers_== kmeans2.cluster_centers_).all()) # True

只是对Nicolas M. answer的一个补充

广义函数(带有虚拟静态变量):

def compute_labels(centers,datapoints):
    compute_labels.dummy.cluster_centers_=centers
    return compute_labels.dummy.predict(datapoints)
compute_labels.dummy=KMeans()

相关问题 更多 >