<p>你需要使用预测方法。</p>
<p>在对树进行训练之后,输入X值以预测其输出。</p>
<pre><code>from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()
tree = clf.fit(iris.data, iris.target)
tree.predict(iris.data)
</code></pre>
<p>输出:</p>
<pre><code>>>> tree.predict(iris.data)
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
</code></pre>
<p>要获取树结构的详细信息,可以使用<code>tree_.__getstate__()</code></p>
<p>树结构转换为“ASCII艺术”图片</p>
<pre><code> 0
_____________
1 2
______________
3 12
_______ _______
4 7 13 16
___ ______ _____
5 6 8 9 14 15
_____
10 11
</code></pre>
<p>作为数组的树结构。</p>
<pre><code>In [38]: tree.tree_.__getstate__()['nodes']
Out[38]:
array([(1, 2, 3, 0.800000011920929, 0.6666666666666667, 150, 150.0),
(-1, -1, -2, -2.0, 0.0, 50, 50.0),
(3, 12, 3, 1.75, 0.5, 100, 100.0),
(4, 7, 2, 4.949999809265137, 0.16803840877914955, 54, 54.0),
(5, 6, 3, 1.6500000953674316, 0.04079861111111116, 48, 48.0),
(-1, -1, -2, -2.0, 0.0, 47, 47.0),
(-1, -1, -2, -2.0, 0.0, 1, 1.0),
(8, 9, 3, 1.5499999523162842, 0.4444444444444444, 6, 6.0),
(-1, -1, -2, -2.0, 0.0, 3, 3.0),
(10, 11, 2, 5.449999809265137, 0.4444444444444444, 3, 3.0),
(-1, -1, -2, -2.0, 0.0, 2, 2.0),
(-1, -1, -2, -2.0, 0.0, 1, 1.0),
(13, 16, 2, 4.850000381469727, 0.042533081285444196, 46, 46.0),
(14, 15, 1, 3.0999999046325684, 0.4444444444444444, 3, 3.0),
(-1, -1, -2, -2.0, 0.0, 2, 2.0),
(-1, -1, -2, -2.0, 0.0, 1, 1.0),
(-1, -1, -2, -2.0, 0.0, 43, 43.0)],
dtype=[('left_child', '<i8'), ('right_child', '<i8'),
('feature', '<i8'), ('threshold', '<f8'),
('impurity', '<f8'), ('n_node_samples', '<i8'),
('weighted_n_node_samples', '<f8')])
</code></pre>
<p>其中:</p>
<ul>
<li>第一个节点[0]是根节点。</li>
<li>内部节点的左/右子节点指的是值为正值且大于当前节点的节点。</li>
<li>左和右子节点的叶具有-1值。</li>
<li>节1、5、6、8、10、11、14、15、16为叶。</li>
<li>采用深度优先搜索算法建立节点结构。</li>
<li>feature字段告诉我们节点中使用了哪个iris.data特性来确定此示例的路径。</li>
<li>阈值告诉我们用于根据特征计算方向的值。</li>
<li>杂质在叶子上达到0。。。因为一旦你拿到叶子,所有的样本都在同一个类中。</li>
<li>节点样本告诉我们每个叶子有多少个样本。</li>
</ul>
<p>使用这些信息,我们可以按照脚本上的分类规则和阈值,简单地跟踪每个样本X到它最终到达的叶。此外,n_node_samples允许我们执行单元测试,确保每个节点获得正确的样本数,然后使用tree.predict的输出,我们可以将每个叶映射到关联的类。</p>