我试图理解sklearn's .fit()
方法和.predict()
方法之间的关系;主要是数据(通常)是如何从一个传递到另一个的。我还没有找到另一个问题来解决这个问题,但我已经围绕它跳舞了(即here)
我编写了一个自定义的估计器,使用BaseEstimator和RegressorMixin类,但是在我开始运行数据时遇到了一些“notfittedror”。有人能告诉我一个简单的线性回归,以及数据是如何通过拟合和预测方法传递的吗?不需要进入数学-我知道回归是如何工作的,以及拼图的各个部分是怎么做的。也许我忽略了显而易见的东西,让它变得比它应该的更复杂?但估计方法感觉有点像黑匣子。在
让我们看一个玩具估计器做
LinearRegression
输出:
^{pr2}$那么上面的代码做了什么呢?在
fit
中,我们使用训练数据拟合/训练权重。这些权重是类的成员变量[这是在OOPs中学习到的]transform
方法使用作为成员变量存储的训练权重对数据进行预测。在所以在调用
transform
之前,您需要调用fit
,因为transform
使用fit期间计算的权重。在在sklearn模块中,如果在
fit
之前调用transform
,则会得到NotFittedError
异常。在当您在训练或使用
.fit()
方法之前尝试使用分类器的.predict()
方法时,NotFittedError
会发生。在以scikit learn中的LinearRegression为例。在
所以用
reg = LinearRegression().fit(X, y)
行实例化类LinearRegression
,然后将其与数据X和y相匹配,其中X是自变量,y是依赖变量。一旦模型在该类中被训练,线性回归的beta系数保存在类属性coef_
中,您可以使用reg.coef_
来访问它。这就是类如何知道何时使用.predict()
类方法来预测。这个类访问这些系数,然后用简单的代数来生成预测。在回到你的错误。如果你没有将模型拟合到你的训练数据中,那么这个类就没有进行预测所需的必要属性。希望这能澄清类内部的一些混乱,至少在}方法如何交互方面。在
fit()
和{最后,就像上面所说的,这可以追溯到面向对象编程的基本原理,所以如果您想进一步了解,我将阅读Python如何在scikit学习模型遵循相同行为的情况下处理类
相关问题 更多 >
编程相关推荐