我正在尝试使用cross_val_score
和一个定制的估计器。重要的是,这个估计器接收一个成员变量,稍后可以在fit
函数中使用。但似乎在cross_val_score
内部,成员变量被破坏了(或者正在创建一个新的估计器实例)。
以下是可以重现错误的最小代码:
from sklearn.model_selection import cross_val_score
from sklearn.base import BaseEstimator
class MyEstimator(BaseEstimator):
def __init__(self, member):
self._member = member
def fit(self, X, y):
if self._member is None:
raise Exception('member is None.')
X = np.array([[1, 1, 1], [2 ,2 , 2]])
y = np.array([1, 2])
score_values = cross_val_score(
MyEstimator('some value'),
X,
y,
cv=2,
scoring='r2'
)
在上面的代码中,总是引发异常。 有没有办法解决这个问题?在
Sklearn在内部克隆估计器,以创建估计器的多个副本。Reference;使用
clone
函数。在clone只从对象复制构造函数参数值。在
解决方案:
使您的构造函数参数和对象属性一致,因此从下划线开始,或删除所有地方的下划线!在
^{pr2}$相关问题 更多 >
编程相关推荐