预测时如何处理onehotencoding后的类别不匹配?

2024-09-30 06:20:26 发布

您现在位置:Python中文网/ 问答频道 /正文

很抱歉,如果题目不太清楚,我不能用一句话来概括问题。在

以下是简化的数据集以供解释。基本上,训练集中的类别数远远大于测试集中的类别数,这是因为一次热编码后测试集和训练集中的列数存在差异。我怎么处理这个问题?在

训练集

+-------+----------+
| Value | Category |
+-------+----------+
| 100   | SE1      |
+-------+----------+
| 200   | SE2      |
+-------+----------+
| 300   | SE3      |
+-------+----------+

OneHotEncoding后的训练集

^{pr2}$

测试集

+-------+----------+
| Value | Category |
+-------+----------+
| 100   | SE1      |
+-------+----------+
| 200   | SE1      |
+-------+----------+
| 300   | SE2      |
+-------+----------+

OneHotEncoding后的测试集

+-------+-----------+-----------+
| Value | DummyCat1 | DummyCat2 |
+-------+-----------+-----------+
| 100   | 1         | 0         |
+-------+-----------+-----------+
| 200   | 1         | 0         |
+-------+-----------+-----------+
| 300   | 0         | 1         |
+-------+-----------+-----------+

您可以注意到,OneHotEncoding之后的训练集是(3,4)的形状,而OneHotEncoding之后的测试集是(3,3)的形状。 因此,当我执行以下代码时(y_train是形状(3,)的列向量)

from sklearn.linear_model import LinearRegression
regressor = LinearRegression()
regressor.fit(x_train, y_train)

x_pred = regressor.predict(x_test)

我得到了预测函数的误差。如您所见,与基本示例不同,错误中的维度相当大。在

  Traceback (most recent call last):

  File "<ipython-input-2-5bac76b24742>", line 30, in <module>
    x_pred = regressor.predict(x_test)

  File "/Users/parthapratimneog/anaconda3/lib/python3.6/site-packages/sklearn/linear_model/base.py", line 256, in predict
    return self._decision_function(X)

  File "/Users/parthapratimneog/anaconda3/lib/python3.6/site-packages/sklearn/linear_model/base.py", line 241, in _decision_function
    dense_output=True) + self.intercept_

  File "/Users/parthapratimneog/anaconda3/lib/python3.6/site-packages/sklearn/utils/extmath.py", line 140, in safe_sparse_dot
    return np.dot(a, b)

ValueError: shapes (4801,2236) and (4033,) not aligned: 2236 (dim 1) != 4033 (dim 0)

Tags: inmodelvaluelinetrainsklearnuserspredict
1条回答
网友
1楼 · 发布于 2024-09-30 06:20:26

您必须按照x峎u train的转换方式来转换x_test。在

x_test = onehotencoder.transform(x_test)
x_pred = regressor.predict(x_test)

请确保使用的onehotencoder对象与x}列车上的fit()相同。在

我假设您当前正在对测试数据使用fit_transform()。 做fit()fit_transform()会忘记先前学习的数据,并重新拟合oneHotEncoder。它现在认为列中只有两个不同的值,因此会改变输出的形状。在

相关问题 更多 >

    热门问题