如何探索使用scikit learn构建的决策树

2024-05-17 07:34:22 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在用

clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, Y_train)

一切正常。然而,我该如何探索决策树呢?

例如,如何查找X_train中的哪些条目出现在特定的叶中?


Tags: 决策树tree条目trainfitclfdecisiontreeclassifier
3条回答

下面的代码将生成十大功能的图表:

import numpy as np
import matplotlib.pyplot as plt

importances = clf.feature_importances_
std = np.std(clf.feature_importances_,axis=0)
indices = np.argsort(importances)[::-1]

# Print the feature ranking
print("Feature ranking:")

for f in range(10):
    print("%d. feature %d (%f)" % (f + 1, indices[f], importances[indices[f]]))

# Plot the feature importances of the forest
plt.figure()
plt.title("Feature importances")
plt.bar(range(10), importances[indices],
       color="r", yerr=std[indices], align="center")
plt.xticks(range(10), indices)
plt.xlim([-1, 10])
plt.show()

取自here并稍加修改以适应DecisionTreeClassifier

这并不完全有助于你探索这棵树,但它确实告诉你关于这棵树的事情。

你需要使用预测方法。

在对树进行训练之后,输入X值以预测其输出。

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()
tree = clf.fit(iris.data, iris.target)
tree.predict(iris.data) 

输出:

>>> tree.predict(iris.data)
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

要获取树结构的详细信息,可以使用tree_.__getstate__()

树结构转换为“ASCII艺术”图片

              0  
        _____________
        1           2
               ______________
               3            12
            _______      _______
            4     7      13   16
           ___   ______        _____
           5 6   8    9        14 15
                      _____
                      10 11

作为数组的树结构。

In [38]: tree.tree_.__getstate__()['nodes']
Out[38]: 
array([(1, 2, 3, 0.800000011920929, 0.6666666666666667, 150, 150.0),
       (-1, -1, -2, -2.0, 0.0, 50, 50.0),
       (3, 12, 3, 1.75, 0.5, 100, 100.0),
       (4, 7, 2, 4.949999809265137, 0.16803840877914955, 54, 54.0),
       (5, 6, 3, 1.6500000953674316, 0.04079861111111116, 48, 48.0),
       (-1, -1, -2, -2.0, 0.0, 47, 47.0), 
       (-1, -1, -2, -2.0, 0.0, 1, 1.0),
       (8, 9, 3, 1.5499999523162842, 0.4444444444444444, 6, 6.0),
       (-1, -1, -2, -2.0, 0.0, 3, 3.0),
       (10, 11, 2, 5.449999809265137, 0.4444444444444444, 3, 3.0),
       (-1, -1, -2, -2.0, 0.0, 2, 2.0), 
       (-1, -1, -2, -2.0, 0.0, 1, 1.0),
       (13, 16, 2, 4.850000381469727, 0.042533081285444196, 46, 46.0),
       (14, 15, 1, 3.0999999046325684, 0.4444444444444444, 3, 3.0),
       (-1, -1, -2, -2.0, 0.0, 2, 2.0), 
       (-1, -1, -2, -2.0, 0.0, 1, 1.0),
       (-1, -1, -2, -2.0, 0.0, 43, 43.0)], 
      dtype=[('left_child', '<i8'), ('right_child', '<i8'), 
             ('feature', '<i8'), ('threshold', '<f8'), 
             ('impurity', '<f8'), ('n_node_samples', '<i8'), 
             ('weighted_n_node_samples', '<f8')])

其中:

  • 第一个节点[0]是根节点。
  • 内部节点的左/右子节点指的是值为正值且大于当前节点的节点。
  • 左和右子节点的叶具有-1值。
  • 节1、5、6、8、10、11、14、15、16为叶。
  • 采用深度优先搜索算法建立节点结构。
  • feature字段告诉我们节点中使用了哪个iris.data特性来确定此示例的路径。
  • 阈值告诉我们用于根据特征计算方向的值。
  • 杂质在叶子上达到0。。。因为一旦你拿到叶子,所有的样本都在同一个类中。
  • 节点样本告诉我们每个叶子有多少个样本。

使用这些信息,我们可以按照脚本上的分类规则和阈值,简单地跟踪每个样本X到它最终到达的叶。此外,n_node_samples允许我们执行单元测试,确保每个节点获得正确的样本数,然后使用tree.predict的输出,我们可以将每个叶映射到关联的类。

注意:这不是一个答案,只是对可能的解决方案的提示。

我最近在我的项目中遇到了类似的问题。我的目标是为一些特定的样本提取相应的决策链。我认为你的问题是我的一个子集,因为你只需要记录下决策链的最后一步。

到目前为止,似乎唯一可行的解决方案是用Python编写一个定制的predict方法来跟踪一路上的决策。原因是scikit learn提供的predict方法不能在开箱即用(据我所知)。更糟糕的是,它是C实现的包装器,很难定制。

定制对我的问题很好,因为我处理的是一个不平衡的数据集,我关心的样本(积极的)很少。所以我可以先使用sklearn predict将它们过滤掉,然后使用我的定制获得决策链。

但是,如果您有一个大型数据集,这可能对您不起作用。因为如果您解析树并使用Python进行预测,那么它在Python中的运行速度会很慢,并且不会(容易)扩展。您可能不得不回退到自定义C实现。

相关问题 更多 >