预测类及其相应的概率

2024-09-27 23:21:43 发布

您现在位置:Python中文网/ 问答频道 /正文

我用maxvoting(决策树、随机森林、逻辑回归)分类器建立了一个机器学习模型。我的输入是

{ “工资”:50000美元, “流动贷款”:15000, “信用评分”:616分, “请求贷款”:25000 }

当我将这些数据传递给我的模型时。它给出了如下预测:

{“状态”:批准}

但我需要像这样检索响应

{“状态”:批准,“准确性”:0.87}

任何帮助都将不胜感激


Tags: 数据模型机器决策树分类器状态森林逻辑
1条回答
网友
1楼 · 发布于 2024-09-27 23:21:43

看起来您可能正在使用sklearn的^{}。一旦安装了分类器,就可以通过属性^{}看到与每个类相关联的概率。请注意,这不是准确度,而是每个类的关联概率。因此,如果您想要测试样本属于类n的概率,您必须在相应列上对输出y_pred_prob进行索引。以下是使用sklearn的iris数据集的示例:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier, VotingClassifier

from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB

clf1 = LogisticRegression(multi_class='multinomial', random_state=1)
clf2 = RandomForestClassifier(n_estimators=50, random_state=1)
clf3 = GaussianNB()

X, y = load_iris(return_X_y=True)

X_train, X_test, y_train, y_test = train_test_split(X, y)

eclf2 = VotingClassifier(estimators=[
        ('lr', clf1), ('rf', clf2), ('gnb', clf3)],
        voting='soft')

eclf2 = eclf2.fit(X_train, y_train)

我们可以得到与第一类相关的概率,例如:

eclf2.predict_proba(X_test)[:,0].round(2)

array([0.99, 0.  , 0.  , 0.  , 0.  , 0.  , 0.01, 0.01, 0.  , 0.  , 0.  ,
       0.99, 0.  , 0.99, 0.99, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
       0.  , 0.01, 0.98, 0.  , 1.  , 0.99, 0.  , 0.  , 0.  , 0.99, 0.98,
       0.  , 0.99, 0.  , 0.01, 0.99])

最后,要获得如上所述的输出,可以使用predict返回的结果对2D概率数组进行索引,如下所示:

import pandas as pd

y_pred = eclf2.predict(X_test)
y_pred_prob = eclf2.predict_proba(X_test).round(2)
associated_prob = y_pred_prob[np.arange(len(y_test)), y_pred]
pd.DataFrame({'class':y_pred, 'Accuracy':associated_prob})

    class  Accuracy
0       0      0.99
1       2      0.84
2       2      1.00
3       1      0.95
4       2      0.99
5       2      0.91
6       1      0.98
7       1      0.98
8       1      0.93

或者,如果您喜欢将输出作为字典:

pd.DataFrame({'class':y_pred, 'Accuracy':associated_prob}).to_dict(orient='index')

 {0: {'class': 0, 'Accuracy': 0.99},
  1: {'class': 2, 'Accuracy': 0.84},
  2: {'class': 2, 'Accuracy': 1.0},
  3: {'class': 1, 'Accuracy': 0.95},
  4: {'class': 2, 'Accuracy': 0.99},

相关问题 更多 >

    热门问题