ScikitLearn的.fit()方法如何将数据传递给.predict()?

2024-09-25 00:35:52 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图理解sklearn's .fit()方法和.predict()方法之间的关系;主要是数据(通常)是如何从一个传递到另一个的。我还没有找到另一个问题来解决这个问题,但我已经围绕它跳舞了(即here

我编写了一个自定义的估计器,使用BaseEstimator和RegressorMixin类,但是在我开始运行数据时遇到了一些“notfittedror”。有人能告诉我一个简单的线性回归,以及数据是如何通过拟合和预测方法传递的吗?不需要进入数学-我知道回归是如何工作的,以及拼图的各个部分是怎么做的。也许我忽略了显而易见的东西,让它变得比它应该的更复杂?但估计方法感觉有点像黑匣子。在


Tags: 数据方法here关系线性数学sklearnpredict
2条回答

让我们看一个玩具估计器做LinearRegression

from sklearn.base import TransformerMixin, BaseEstimator
import numpy as np

class ToyEstimator(BaseEstimator):
    def __init__(self):
        pass

    def fit(self, X, y):
        X = np.hstack((X,np.ones((len(X),1))))
        self.W = np.random.randn(X.shape[1])

        self.W = np.dot(np.dot(np.linalg.inv(np.dot(X.T,X)), X.T), y)
        self.coef_ = self.W[:-1]
        self.intercept_ = self.W[-1]
        return self


    def transform(self, X):
        X = np.hstack((X,np.ones((len(X),1))))
        return np.dot(X,self.W)

X = np.random.randn(10,3)
y = X[:,0]*1.11+X[:,1]*2.22+X[:,2]*3.33+4.44

reg = ToyEstimator()
reg.fit(X,y)
y_ = reg.transform(X)
print (reg.coef_, reg.intercept_)

输出:

^{pr2}$

那么上面的代码做了什么呢?在

  1. fit中,我们使用训练数据拟合/训练权重。这些权重是类的成员变量[这是在OOPs中学习到的]
  2. transform方法使用作为成员变量存储的训练权重对数据进行预测。在

所以在调用transform之前,您需要调用fit,因为transform使用fit期间计算的权重。在

在sklearn模块中,如果在fit之前调用transform,则会得到NotFittedError异常。在

当您在训练或使用.fit()方法之前尝试使用分类器的.predict()方法时,NotFittedError会发生。在

以scikit learn中的LinearRegression为例。在

>>> import numpy as np
>>> from sklearn.linear_model import LinearRegression
>>> X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
>>> # y = 1 * x_0 + 2 * x_1 + 3
>>> y = np.dot(X, np.array([1, 2])) + 3
>>> reg = LinearRegression().fit(X, y)
>>> reg.score(X, y)
1.0
>>> reg.coef_
array([1., 2.])
>>> reg.intercept_ 
3.0000...
>>> reg.predict(np.array([[3, 5]]))
array([16.])

所以用reg = LinearRegression().fit(X, y)行实例化类LinearRegression,然后将其与数据X和y相匹配,其中X是自变量,y是依赖变量。一旦模型在该类中被训练,线性回归的beta系数保存在类属性coef_中,您可以使用reg.coef_来访问它。这就是类如何知道何时使用.predict()类方法来预测。这个类访问这些系数,然后用简单的代数来生成预测。在

回到你的错误。如果你没有将模型拟合到你的训练数据中,那么这个类就没有进行预测所需的必要属性。希望这能澄清类内部的一些混乱,至少在fit()和{}方法如何交互方面。在

最后,就像上面所说的,这可以追溯到面向对象编程的基本原理,所以如果您想进一步了解,我将阅读Python如何在scikit学习模型遵循相同行为的情况下处理类

相关问题 更多 >