决策树可视化指南
决策树是机器学习的一种经典的模型,因其泛化性能好,可解释性强而被广泛应用到实际商业预测中。通常在我们完成决策树模型搭建后,我们会进一步研究分析我们搭建好的模型,这时候模型的可视化就显得尤为重要。下面是生成的决策树可视化图像:
插件安装
scikit-learn中决策树的可视化一般需要安装graphviz。主要包括graphviz库的安装和python的graphviz插件的安装。
安装graphviz库:
- pip install graphviz
安装graphviz插件:
- 安装插件地址:http://www.graphviz.org,下载对应的插件。
- 安装插件(默认安装地址,直接一直点下一步直到完成安装)
- 环境配置: 复制安装目录的bin路径:C:\Program Files\Graphviz\bin
(a)打开我的电脑,点击属性
(b)选择高级设置
(c)选择环境变量
(d)双击选择Path进入
(e)新建,粘贴的前面复制的bin路径,点击确定
(f)重启jupter notebook即可
可视化的三种方法
搭建模型
from sklearn import tree
dtree = tree.DecisionTreeClassifier()
dtree.fit(x_train,y_train)
在搭建完决策树后,下面介绍可视化具体操作的三种方法
方法一:
简单粗暴,一行代码搞定什么都不用安装。缺点也很明显,生成的可视化图比较模糊,且不能保存图片,违背了可视化的初衷,不建议使用这种方法。
tree.plot_tree(bdtree,filled=True)
方法二:
这种方法比较常用,需要安装graphviz库和graphviz插件,安装方法上面已经介绍。这种使用这种方法得到的图像比较高清,并且还会额外生成PDF文件和一个文本文件。比较推荐使用。
import graphviz
dot_data = tree.export_graphviz(dtree,out_file=None,feature_names=feature_names,class_names=class_names,filled=True,rounded=True,special_characters=True)
graph = graphviz.Source(dot_data)
graph.render('computer')
方法三:
最后一种方法是最麻烦的,除了需要安装graphviz库和graphviz插件,还需要安装:
- pip install pydotplus
- pip install six
这种方法得到的图像比较清晰,且可以双击放大缩小,可以保存为png图片和pdf文档。这种方法根据个人需求使用吧。
# 决策树可视化
import graphviz
import pydotplus
from six import StringIO
from sklearn.tree import export_graphviz
from IPython.display import Image# 文件缓存
dot_data = StringIO()
# 将决策树导入到dot中
export_graphviz(bdtree, out_file=dot_data, filled=True, rounded=True,special_characters=True,feature_names = feature_names,class_names=class_names)
# 将生成的dot文件生成graph
print(feature_names)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
# 将结果存入到png文件中
graph.write_png('diabetes.png')
graph.write_pdf('diabetes.pdf')
# 显示
Image(graph.create_png())
注意事项
在我们做可视化的时候,需要注意中各重要的参数:
- feature_names
- class_names
在做可视化的时候需要传入这两个参数,这两个参数都需要传入一个列表类型的数据,并且两个参数需要与模型传入的数据一一对应,否则就违背了可视化的初衷了。
feature_names:
假如我们的数据格式是DataFrame格式通常可以在数据处理后通过以下方法获取(也可以手动输入):
feature_names = data.columns[:-1]
class_names:
- 这个参数如果类别数不多的话,最好手动输入。或者通过data[‘分类标签’].unique()获取后再做格式转换和顺序调整
决策树可视化实战
from sklearn import tree
import pandas as pddata = pd.read_csv('zoo.csv')
data.head()# 获取训练数据和标签
x_data = data.drop(['animal_name', 'class_type'], axis=1)
y_data = data['class_type']# 搭建模型
dtree = tree.DecisionTreeClassifier()
dtree.fit(x_data,y_data)# 获取feature_names和class_names
feature_names = data.columns[1:-1]cls_n = data.class_type.unique()
class_names = []
for i in cls_n:class_names.append(str(i))
class_names.sort()# # 方法一
# tree.plot_tree(dtree,filled=True)# # 方法二
# import graphviz# dot_data = tree.export_graphviz(dtree,
# out_file=None,
# feature_names=feature_names,
# class_names=cls_n,
# filled=True,
# rounded=True,
# special_characters=True)
# graph = graphviz.Source(dot_data)# graph.render('computer')
# graph
# 方法三
import graphviz
import pydotplus
from six import StringIO
from IPython.display import Image # 文件缓存
dot_data = StringIO()
# 将决策树导入到dot中
tree.export_graphviz(dtree, out_file=dot_data, filled=True, rounded=True,special_characters=True,feature_names = feature_names,class_names=class_names)
# 将生成的dot文件生成graph
print(feature_names)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
# 将结果存入到png文件中
graph.write_png('diabetes.png')
graph.write_pdf('diabetes.pdf')
# 显示
Image(graph.create_png())