scikit学习中决策树的可视化

2024-05-17 09:02:08 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图使用Python中的scikit learn设计一个简单的决策树(我在Windows操作系统上使用Anaconda的Ipython笔记本和Python 2.7.3),并将其可视化如下:

from pandas import read_csv, DataFrame
from sklearn import tree
from os import system

data = read_csv('D:/training.csv')
Y = data.Y
X = data.ix[:,"X0":"X33"]

dtree = tree.DecisionTreeClassifier(criterion = "entropy")
dtree = dtree.fit(X, Y)

dotfile = open("D:/dtree2.dot", 'w')
dotfile = tree.export_graphviz(dtree, out_file = dotfile, feature_names = X.columns)
dotfile.close()
system("dot -Tpng D:.dot -o D:/dtree2.png")

但是,我得到以下错误:

AttributeError: 'NoneType' object has no attribute 'close'

我使用以下博客文章作为参考:Blogpost link

下面的stackoverflow问题似乎对我也不起作用:Question

有人能帮我在scikit learn中可视化决策树吗?


Tags: csvfromimport决策树treereaddata可视化
3条回答

这里有一行是给那些使用jupyter和sklearn(18.2+)的人的,你甚至不需要matplotlib。唯一的要求是graphviz

pip install graphviz

比运行(根据问题中的代码X是pandas数据帧)

from graphviz import Source
from sklearn import tree
Source( tree.export_graphviz(dtreg, out_file=None, feature_names=X.columns))

这将以SVG格式显示。上面的代码生成Graphviz的Source对象(source_code-不可怕),该对象将直接在jupyter中呈现。

一些你可能会用它做的事情

以点唱显示:

from IPython.display import SVG
graph = Source( tree.export_graphviz(dtreg, out_file=None, feature_names=X.columns))
SVG(graph.pipe(format='svg'))

另存为png:

graph = Source( tree.export_graphviz(dtreg, out_file=None, feature_names=X.columns))
graph.format = 'png'
graph.render('dtree_render',view=True)

获取png图像,保存并查看:

graph = Source( tree.export_graphviz(dtreg, out_file=None, feature_names=X.columns))
png_bytes = graph.pipe(format='png')
with open('dtree_pipe.png','wb') as f:
    f.write(png_bytes)

from IPython.display import Image
Image(png_bytes)

如果要使用这个库,这里有指向examplesuserguide的链接

如果像我一样,在安装graphviz时遇到问题,可以通过

  1. export_graphviz导出它,如前面的答案所示
  2. 在文本编辑器中打开.dot文件
  3. 复制代码并粘贴到@webgraphviz.com

^{}不返回任何内容,因此默认情况下返回None

通过执行dotfile = tree.export_graphviz(...)操作,将覆盖先前分配给dotfile的open file对象,因此在尝试关闭文件时会出现错误(现在是None)。

要修复它,请将代码更改为

...
dotfile = open("D:/dtree2.dot", 'w')
tree.export_graphviz(dtree, out_file = dotfile, feature_names = X.columns)
dotfile.close()
...

相关问题 更多 >