被GradientBoostingClassifi的应用函数搞糊涂了

2024-06-01 09:41:38 发布

您现在位置:Python中文网/ 问答频道 /正文

对于apply函数,可以参考here

我的困惑更多来自this sample,我在下面的代码片段中添加了一些打印内容,以输出更多的调试信息

grd = GradientBoostingClassifier(n_estimators=n_estimator)
grd_enc = OneHotEncoder()
grd_lm = LogisticRegression()
grd.fit(X_train, y_train)
test_var = grd.apply(X_train)[:, :, 0]
print "test_var.shape", test_var.shape
print "test_var", test_var
grd_enc.fit(grd.apply(X_train)[:, :, 0])
grd_lm.fit(grd_enc.transform(grd.apply(X_train_lr)[:, :, 0]), y_train_lr)

输出如下所示,并混淆了6.3.和{}是什么意思?它们与最终的分类结果有什么关系?在

^{pr2}$

Tags: sample函数testherevartrainthisfit
1条回答
网友
1楼 · 发布于 2024-06-01 09:41:38

要了解梯度提升,首先需要了解单个树。我将展示一个小例子。

下面是设置:一个在Iris数据集上训练的小GB模型,用于预测花是否属于2类。

# import the most common dataset
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier
X, y = load_iris(return_X_y=True)
# there are 150 observations and 4 features
print(X.shape) # (150, 4)
# let's build a small model = 5 trees with depth no more than 2
model = GradientBoostingClassifier(n_estimators=5, max_depth=2, learning_rate=1.0)
model.fit(X, y==2) # predict 2nd class vs rest, for simplicity
# we can access individual trees
trees = model.estimators_.ravel()
print(len(trees)) # 5
# there are 150 observations, each is encoded by 5 trees, each tree has 1 output
applied = model.apply(X) 
print(applied.shape) # (150, 5, 1)
print(applied[0].T) # [[2. 2. 2. 5. 2.]] - a single row of the apply() result
print(X[0]) # [5.1 3.5 1.4 0.2] - the pbservation corresponding to that row
print(trees[0].apply(X[[0]])) # [2] - 2 is the result of application the 0'th tree to the sample
print(trees[3].apply(X[[0]])) # [5] - 5 is the result of application the 3'th tree to the sample

您可以看到,由^{cd2>}生成的序列^{cd1>}中的每个数字对应于单个树的输出。但是这些数字是什么意思?

通过视觉检查,我们可以很容易地分析决策树。这里有一个函数来绘制一个

^{pr2}$

enter image description here

可以看到每个节点都有一个数字(从0到6)。如果我们将单个示例推到该树中,它将首先转到节点1(因为特性^{cd3>}具有值^{cd4>}),然后转到节点#2(因为特性^{cd5>}具有值^{cd6>}。

同样,我们可以分析生成输出^{cd7>}的树3:

^{pr3}$

enter image description here

这里我们的观察首先转到节点4,然后转到节点5,因为^{{cd8>}和^{cd9>}。因此,它最终以5号结束。

<坚强的>就这么简单!由^{cd10>}生成的每个数字是相应树的节点的序号,在该树中,样本结束。

这些数字与最终分类结果的关系是通过相应树中的叶子的^{cd11>}。如果是二进制分类,所有叶子中的^{{cd11>}只是加起来,如果是正的,那么“正”会赢,否则是“负”类。在多类分类的情况下,每个类的值都会加起来,而总值最大的类将获胜。

在我们的情况下,第一棵树(与其节点#2)给出值-1.454,其他树也给出一些值,它们的总和为-4.84。它是否定的,因此,我们的示例不属于类2。

^{pr4}$

相关问题 更多 >