Scikit学习随机森林分类器:如何根据树的数量生成OOB错误图

2024-10-05 17:39:14 发布

您现在位置:Python中文网/ 问答频道 /正文


编辑2: 现在在sklearn documentation中有一个很好的例子。


为了了解在我的森林中需要多少棵树,我想绘制OOB错误图,因为在森林中使用的树的数量增加了。我正在使用Python中的sklearn.ensemble.RandomForestClassifier,但是我找不到如何使用森林中的树子集进行预测。我可以通过在每次迭代中创建一个新的随机林来实现这一点,同时增加树的数量,但是这太昂贵了。

对于使用staged_decision_function方法的渐变增强对象,似乎也可以执行类似的任务。见this example

这在R中是一个非常简单的过程,只需调用plot(randomForestObject)Random Forest OOB error against Trees


--编辑-- 我现在看到RandomForestClassifier对象有一个属性estimators_,它返回列表中的所有DecisionTreeClassifier对象。所以为了解决这个问题,我可以遍历这个列表,预测每个树的结果,并取“累积平均值”。然而,是否真的没有更简单的方法来实现这一点呢?


Tags: 对象方法编辑列表数量documentation错误森林
1条回答
网友
1楼 · 发布于 2024-10-05 17:39:14

本期有一个讨论和代码: https://github.com/scikit-learn/scikit-learn/issues/4273

可以像这样逐个添加树:

n_estimators = 100
forest = RandomForestClassifier(warm_start=True, oob_score=True)

for i in range(1, n_estimators + 1):
    forest.set_params(n_estimators=i)
    forest.fit(X, y)
    print i, forest.oob_score_

您建议的解决方案还需要获取每个树的oob索引,因为您不想计算所有培训数据的分数。

我仍然觉得这是一件很奇怪的事情,因为森林里的树木实在没有自然的顺序。 你能解释一下你的用例是什么吗?是否要找到给定精度的最小树数以减少预测时间?如果你想要快速的预测时间,我建议使用GradientBoostingClassifier,它通常要快得多。

相关问题 更多 >