<p>让我们看一个玩具估计器做<code>LinearRegression</code></p>
<pre><code>from sklearn.base import TransformerMixin, BaseEstimator
import numpy as np
class ToyEstimator(BaseEstimator):
def __init__(self):
pass
def fit(self, X, y):
X = np.hstack((X,np.ones((len(X),1))))
self.W = np.random.randn(X.shape[1])
self.W = np.dot(np.dot(np.linalg.inv(np.dot(X.T,X)), X.T), y)
self.coef_ = self.W[:-1]
self.intercept_ = self.W[-1]
return self
def transform(self, X):
X = np.hstack((X,np.ones((len(X),1))))
return np.dot(X,self.W)
X = np.random.randn(10,3)
y = X[:,0]*1.11+X[:,1]*2.22+X[:,2]*3.33+4.44
reg = ToyEstimator()
reg.fit(X,y)
y_ = reg.transform(X)
print (reg.coef_, reg.intercept_)
</code></pre>
<p>输出:</p>
^{pr2}$
<p>那么上面的代码做了什么呢?在</p>
<ol>
<li>在<code>fit</code>中,我们使用训练数据拟合/训练权重。这些权重是类的成员变量[这是在OOPs中学习到的]</li>
<li><code>transform</code>方法使用作为成员变量存储的训练权重对数据进行预测。在</li>
</ol>
<p>所以在调用<code>transform</code>之前,您需要调用<code>fit</code>,因为<code>transform</code>使用fit期间计算的权重。在</p>
<p>在sklearn模块中,如果在<code>fit</code>之前调用<code>transform</code>,则会得到<code>NotFittedError</code>异常。在</p>