在scikitlearn中创建添加群集标签的自定义转换器

2024-05-03 01:38:47 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在用scikit learn编写一个自定义转换器,它使用stock KMeans将集群标签作为一个新列添加到dataframe中。 自定义转换器应适合现有数据,然后通过添加索引名为“Cluster”的新列来转换看不见的数据,并在不修改原始数据帧的情况下返回带有附加列的新数据帧。 下面是我提出的代码:

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.cluster import KMeans

class AddClustersFeature(BaseEstimator, TransformerMixin):
    def __init__(self, clusters = 10): 
        self.clusters = clusters
        self.model = KMeans(n_clusters = self.clusters)
           
    def fit(self, X):
        self.X=X
        self.model.fit (self.X)
        return self.model
       
    def transform(self, X):
        self.X=X
        X_=X.copy() # avoiding modification of the original df
        
        X_['Clusters'] = self.model.transform(self.X_).labels_
        
        return X_

cluster_enc_tr_data = AddClustersFeature().fit_transform(enc_tr_data)
cluster_enc_tr_data

不幸的是,代码工作正常。结果是一个数据帧,其中簇号作为列索引,行号和未知值。 任何帮助或提示都将不胜感激

6月21日第23版更新: 在实施Ben的修订意见后,请参见下面的代码。 它现在工作得很好

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.cluster import KMeans

class AddClustersFeature(BaseEstimator, TransformerMixin):
    def __init__(self, clusters = 10): 
        self.clusters = clusters
        
           
    def fit(self, X):
        self.X=X
        self.model = KMeans(n_clusters = self.clusters)
        self.model.fit (self.X)
        return self
       
    def transform(self, X):
        self.X=X
        X_=X.copy() # avoiding modification of the original df
        X_['Clusters'] = self.model.predict(X_)
        return X_

cluster_enc_tr_data = AddClustersFeature().fit_transform(enc_tr_data)

Tags: 数据fromimportselfdatamodeldeftransform
1条回答
网友
1楼 · 发布于 2024-05-03 01:38:47

fit方法必须始终返回self

这里的问题是fit_transform(X, y)继承自TransformerMixin,只是fit(X, y).transform(X);您的fit现在返回底层的KMeans转换器,并且使用来转换X,而不是您的transform

不过,还有一些注意事项:

  1. KMeans.transform给出了簇距离矩阵,但您需要簇标签。改用predict。并删除labels_,所以只需X_['Clusters'] = self.model.predict(X_)。)

  2. __init__应该只设置出现在其签名中的属性,以便克隆工作(例如超参数搜索所需)。您可以在fit时间定义self.model

  3. transform中,使用self.X_,但从未定义它;我猜你的意思是X_。也没有真正的理由在适当的时候保存Xself.X从来就不是真正需要的吗

  4. 这只适用于数据帧;这对你来说可能不是问题,但要记住。(在内置sklearn转换器之后,不能将此作为管道中的一个步骤,因为这些转换器将返回numpy数组。)

相关问题 更多 >