Sickit learn中用于一对一的优化解算器

2024-09-22 16:36:16 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图用逻辑回归来解决一个多类分类问题。我的数据集有3个不同的类,每个数据点只属于一个类。以下是样本训练数据; enter image description here

这里的第一列是我作为偏差项添加的向量。并且目标列已经使用标签binarize的概念进行了二值化,如sickit-learn中所述

然后我得到了如下目标:

array([[1, 0, 0],
   [1, 0, 0],
   [0, 1, 0],
   [1, 0, 0],
   [1, 0, 0]])

接下来,我将使用“一对一”的概念对其进行培训,即一次培训一名学员。样本代码


for i in range(label_train.shape[1]):
    clf = LogisticRegression(random_state=0,multi_class='ovr', solver='liblinear',fit_intercept=True).\
 fit(train_data_copy, label_train[:,i])
    #print(clf.coef_.shape)

如您所见,我总共培训了3个分类器,每个标签对应一个分类器。我这里有两个问题

第一个问题:根据sickit学习文档

multi_class{‘auto’, ‘ovr’, ‘multinomial’}, default=’auto’ If the option chosen is ‘ovr’, then a binary problem is fit for each label. For ‘multinomial’ the loss minimised is the multinomial loss fit across the entire probability distribution, even when the data is binary. ‘multinomial’ is unavailable when solver=’liblinear’. ‘auto’ selects ‘ovr’ if the data is binary, or if solver=’liblinear’, and otherwise selects ‘multinomial’.

我的问题是,既然我选择了解算器作为liblinear(作为o.v.r问题),那么我选择multi_class作为auto还是ovr有关系吗

第二个问题是关于截距(或偏差)项。文档中说,如果fit_intercept=True,则会在决策函数中添加一个偏差项。但我注意到,当我没有将向量1添加到我的数据矩阵中时,系数θ向量中的参数数量与特征数量相同,尽管fit_intercept=True。我的问题是,我们是否必须将1的向量添加到数据矩阵中,以及启用拟合截距,以便将偏差项添加到决策函数中


Tags: the数据autoistrain向量multilabel
1条回答
网友
1楼 · 发布于 2024-09-22 16:36:16
  1. 没关系;正如您可能看到的here,无论是选择multi_class='auto'还是multi_class='ovr',只要solver='liblinear',都会得到相同的结果
  2. solver='liblinear'的情况下,将使用等于1的默认偏差项,并通过intercept_scaling属性将其附加到X(这反过来只有在fit_intercept=True时才有用),如您所见here。拟合后intercept_将返回拟合偏差(维度(n_classes,))(如果fit_intercept=False,则为零值)。拟合系数通过coef_返回(维度(n_classes, n_features)和非(n_classes, n_features + 1)-拆分完成here

下面是一个示例,考虑Iris数据集(具有3个类和4个特征):

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
X, y = load_iris(return_X_y=True)

clf = LogisticRegression(random_state=0, fit_intercept=True, multi_class='ovr', solver='liblinear')
clf.fit(X, y)
clf.intercept_, clf.coef_
################################
(array([ 0.26421853,  1.09392467, -1.21470917]),
 array([[ 0.41021713,  1.46416217, -2.26003266, -1.02103509],
        [ 0.4275087 , -1.61211605,  0.5758173 , -1.40617325],
        [-1.70751526, -1.53427768,  2.47096755,  2.55537041]]))

相关问题 更多 >