sklearn中使用决策树

1.示例

criterion可以是信息熵,entropy,可以是基尼系数gini

# -*-coding:utf-8-*-
from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
wine=load_wine()# print ( wine.feature_names )
#(178, 13)
print(wine.data.shape)Xtrain,Xtest,Ytrain,Ytest=train_test_split(wine.data,wine.target,test_size=0.3)#random_state=30:输入任意整数,会一直长同一棵树,让模型稳定下来
clf=tree.DecisionTreeClassifier(criterion="entropy",random_state=30,splitter="best")
# clf=tree.DecisionTreeClassifier(criterion="entropy")
clf=clf.fit(Xtrain,Ytrain)
#返回预测准确度accuracy
score=clf.score(Xtest,Ytest)print( score )import graphviz
dot_data=tree.export_graphviz(clf,feature_names=wine.feature_names,class_names=["wine1","wine2","wine3"],filled=True,rounded=True)
graph=graphviz.Source(dot_data)
#生成pdf文件
graph.render(view=True, format="pdf", filename="tree_pdf")
print ( graph )
#feature_importances_:每个特征在决策树中的重要成都
print(clf.feature_importances_)
print ( [*zip(wine.feature_names,clf.feature_importances_)] )

决策树生成的pdf 

 2.示例

max_depth:这参数用来控制决策树的最大深度。以下示例,构建1~10深度的决策时,看哪个深度的决策树的精确率(score)高

# -*-coding:utf-8-*-
from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as pltplt.switch_backend("TkAgg")wine=load_wine()# print ( wine.feature_names )
#(178, 13)
print(wine.data.shape)import pandas as pd
# print (pd.concat([pd.DataFrame(wine.data),pd.DataFrame(wine.target)],axis=1))
#所有的train,test必须是二维矩阵
Xtrain,Xtest,Ytrain,Ytest=train_test_split(wine.data,wine.target,test_size=0.3)test=[]
bestScore=-1
bestClf=None
for i in range(10):clf=tree.DecisionTreeClassifier(max_depth=i+1,criterion="entropy",random_state=30,splitter="random")clf=clf.fit(Xtrain,Ytrain)score=clf.score(Xtest,Ytest)test.append(score)if score>bestScore:bestScore=scorebestClf=clf
print(test)
print(test.index(bestScore))
#predict返回每个测试样本的分类/回归结果
predicted=bestClf.predict(Xtest)
print(predicted)#返回每个测试样本的叶子节点的索引
leaf=bestClf.apply(Xtest)
print(leaf)plt.plot(range(1,11),test,color="red",label="max_depth")
plt.legend()
plt.show()

结果:

(178, 13)
[0.5555555555555556, 0.8148148148148148, 0.9444444444444444, 0.9259259259259259, 0.8518518518518519, 0.8333333333333334, 0.8333333333333334, 0.8333333333333334, 0.8333333333333334, 0.8333333333333334]
2
[0 1 0 1 2 0 1 1 1 2 2 0 0 2 0 1 1 0 0 0 0 1 1 0 2 1 0 2 2 1 2 1 1 1 1 0 12 2 0 1 1 2 0 2 1 1 0 1 1 2 1 2 2]
[12  7 12 11  3 12  7  7  4  3  3 12 12  3 12  9  7 12 12 12 12  7  9 123  9 12  3  3  4  3  4  7  7  7 12  7  3  3 12  9  9  3 12  3  7  7 127  7  3  7  3  3]

3.交叉熵验证的示例 

# -*-coding:utf-8-*-
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeRegressor
import sklearn
from sklearn.datasets import fetch_california_housinghousing=fetch_california_housing()
# print(housing)
# print(housing.data)
# print(housing.target)regressor=DecisionTreeRegressor(random_state=0)#cv=10,10次交叉验证,default:cv=5
#scoring="neg_mean_squared_error",评价指标是负的均方误差
cross_res=cross_val_score(regressor,housing.data,housing.target,scoring="neg_mean_squared_error",cv=10)
print(cross_res)
[-1.30551334 -0.78405711 -0.72809865 -0.50413232 -0.79683323 -0.83698199-0.56591889 -1.03621067 -1.02786488 -0.51371889]

4.Titanic生存者预测

数据来源:

Titanic - Machine Learning from Disaster | Kaggle

数据预处理

读取数据 

import pandas as pd
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
from sklearn.model_selection import GridSearchCV
#---------设置pd,在pycharm中显示完全表格-------
pd.set_option('display.max_columns', 1000)
pd.set_option('display.width', 1000)
pd.set_option('display.max_colwidth', 1000)
#----------------------------------------
data=pd.read_csv("./data.csv")
print (data.head(5))
print(data.info())
PassengerId  Survived  Pclass                                                 Name     Sex   Age  SibSp  Parch            Ticket     Fare Cabin Embarked
0            1         0       3                              Braund, Mr. Owen Harris    male  22.0      1      0         A/5 21171   7.2500   NaN        S
1            2         1       1  Cumings, Mrs. John Bradley (Florence Briggs Thayer)  female  38.0      1      0          PC 17599  71.2833   C85        C
2            3         1       3                               Heikkinen, Miss. Laina  female  26.0      0      0  STON/O2. 3101282   7.9250   NaN        S
3            4         1       1         Futrelle, Mrs. Jacques Heath (Lily May Peel)  female  35.0      1      0            113803  53.1000  C123        S
4            5         0       3                             Allen, Mr. William Henry    male  35.0      0      0            373450   8.0500   NaN        S
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):#   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  0   PassengerId  891 non-null    int64  1   Survived     891 non-null    int64  2   Pclass       891 non-null    int64  3   Name         891 non-null    object 4   Sex          891 non-null    object 5   Age          714 non-null    float646   SibSp        891 non-null    int64  7   Parch        891 non-null    int64  8   Ticket       891 non-null    object 9   Fare         891 non-null    float6410  Cabin        204 non-null    object 11  Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
NoneProcess finished with exit code 0

筛选特征

data.drop(["Cabin","Name","Ticket"],inplace=True,axis=1)
print(data.head())
print(data.info())
   PassengerId  Survived  Pclass     Sex   Age  SibSp  Parch     Fare Embarked
0            1         0       3    male  22.0      1      0   7.2500        S
1            2         1       1  female  38.0      1      0  71.2833        C
2            3         1       3  female  26.0      0      0   7.9250        S
3            4         1       1  female  35.0      1      0  53.1000        S
4            5         0       3    male  35.0      0      0   8.0500        S
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 9 columns):#   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  0   PassengerId  891 non-null    int64  1   Survived     891 non-null    int64  2   Pclass       891 non-null    int64  3   Sex          891 non-null    object 4   Age          714 non-null    float645   SibSp        891 non-null    int64  6   Parch        891 non-null    int64  7   Fare         891 non-null    float648   Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(2)
memory usage: 62.8+ KB
None

处理缺失值

#年龄用均值填补
data["Age"]=data["Age"].fillna(data["Age"].mean())
print(data.info())
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 9 columns):#   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  0   PassengerId  891 non-null    int64  1   Survived     891 non-null    int64  2   Pclass       891 non-null    int64  3   Sex          891 non-null    object 4   Age          891 non-null    float645   SibSp        891 non-null    int64  6   Parch        891 non-null    int64  7   Fare         891 non-null    float648   Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(2)
memory usage: 62.8+ KB
None
#删除有缺失值的行,Embarked缺了两行
data=data.dropna()
print(data.info())
<class 'pandas.core.frame.DataFrame'>
Int64Index: 889 entries, 0 to 890
Data columns (total 9 columns):#   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  0   PassengerId  889 non-null    int64  1   Survived     889 non-null    int64  2   Pclass       889 non-null    int64  3   Sex          889 non-null    object 4   Age          889 non-null    float645   SibSp        889 non-null    int64  6   Parch        889 non-null    int64  7   Fare         889 non-null    float648   Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(2)
memory usage: 69.5+ KB
None

处理非数值的列

查看非数值列的所有值

print(data["Embarked"].unique())
print(data["Sex"].unique())#------------结果如下----------
['S' 'C' 'Q']
['male' 'female']
labels=data["Embarked"].unique().tolist()
#x代表data[Embarked]的每一行的值,S-->0,C-->1,Q-->2
data["Embarked"]=data["Embarked"].apply(lambda x:labels.index(x))#把条件为True的转为int行
#也可以这样写:data.loc[:,"Sex"]=(data["Sex"]=="male").astype("int")
#male-->0,female-->1
data["Sex"]=(data["Sex"]=="male").astype("int")

提取数据

x=data.iloc[:, data.columns!="Survived"]
y=data.iloc[:,data.columns=="Survived"]#Xtrain:(622, 8)
#划分数据集和测试集
from sklearn.model_selection import train_test_split
Xtrain,Xtest,Ytrain,Ytest=train_test_split(x,y,test_size=0.3)#把索引变为从0~622
for i in [Xtrain,Xtest,Ytrain,Ytest]:i.index=range(i.shape[0])

第一种方法构建决策树

# clf=DecisionTreeClassifier(random_state=25)
# clf=clf.fit(Xtrain,Ytrain)
# score=clf.score(Xtest,Ytest)
# print(score)
from sklearn.model_selection import cross_val_score
# clf=DecisionTreeClassifier(random_state=25)
# score=cross_val_score(clf,x,y,cv=10).mean()
# print(score)tr=[]
te=[]
for i in range(10):clf=DecisionTreeClassifier(random_state=25,max_depth=i+1,criterion="entropy")clf=clf.fit(Xtrain,Ytrain)score_tr=clf.score(Xtrain,Ytrain)score_te=cross_val_score(clf,x,y,cv=10).mean()tr.append(score_tr)te.append(score_te)
print(max(te))
plt.plot(range(1,11),tr,color="red",label="train")
plt.plot(range(1,11),te,color="blue",label="test")
#1~10全部显示
plt.xticks(range(1,11))
plt.legend()
plt.show()

不同深度的决策树的测试集和训练集的表现 

 第二种方法构建决策树

import pandas as pd
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
plt.switch_backend("TkAgg")
from sklearn.model_selection import GridSearchCV
import numpy as np#---------设置pd,在pycharm中显示完全表格-------
pd.set_option('display.max_columns', 1000)
pd.set_option('display.width', 1000)
pd.set_option('display.max_colwidth', 1000)
#----------------------------------------
data=pd.read_csv("./data.csv")
# print (data.head(5))
# print(data.info())#去掉姓名、Cabin、票号的特征
data.drop(["Cabin","Name","Ticket"],inplace=True,axis=1)
# print(data.head())
# print(data.info())#处理缺失值
#年龄用均值填补
data["Age"]=data["Age"].fillna(data["Age"].mean())
# print(data.info())#删除有缺失值的行,Embarked缺了两行,所有的数据去掉不完整的行
data=data.dropna()
# print(data.info())# print(data["Embarked"].unique())
# print(data["Sex"].unique())labels=data["Embarked"].unique().tolist()
#x代表data[Embarked]的每一行的值,S-->0,C-->1,Q-->2
data["Embarked"]=data["Embarked"].apply(lambda x:labels.index(x))#把条件为True的转为int行
#也可以这样写:data.loc[:,"Sex"]=(data["Sex"]=="male").astype("int")
#male-->0,female-->1
data["Sex"]=(data["Sex"]=="male").astype("int")x=data.iloc[:, data.columns!="Survived"]
y=data.iloc[:,data.columns=="Survived"]#Xtrain:(622, 8)
#划分数据集和测试集
from sklearn.model_selection import train_test_split
Xtrain,Xtest,Ytrain,Ytest=train_test_split(x,y,test_size=0.3)#把索引变为从0~622
for i in [Xtrain,Xtest,Ytrain,Ytest]:i.index=range(i.shape[0])from sklearn.model_selection import cross_val_scoreclf=DecisionTreeClassifier(random_state=25)
#GridSearchCV:满足fit,score,交叉验证三个功能
#parameters:一串参数和这些参数对应的,我们希望网格搜索来搜索对应的参数的取值范围
parameters={"criterion":("gini","entropy"),"splitter":("best","random"),"max_depth":[*range(1,10)],"min_samples_leaf":[*range(1,50,5)],"min_impurity_decrease":[*np.linspace(0,0.5,20)]
}
GS=GridSearchCV(clf,parameters,cv=10)
gs=GS.fit(Xtrain,Ytrain)#从输入的参数和参数取值中,返回最佳组合
print(gs.best_params_)#网格搜索后的模型的评判标准
print(gs.best_score_)
{'criterion': 'entropy', 'max_depth': 3, 'min_impurity_decrease': 0.0, 'min_samples_leaf': 1, 'splitter': 'best'}
0.8297235023041475

这种方法构建的决策树的准确率比第一种的还低

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/27037.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【2.3】Java微服务:sentinel服务哨兵

✅作者简介&#xff1a;大家好&#xff0c;我是 Meteors., 向往着更加简洁高效的代码写法与编程方式&#xff0c;持续分享Java技术内容。 &#x1f34e;个人主页&#xff1a;Meteors.的博客 &#x1f49e;当前专栏&#xff1a;Java微服务 ✨特色专栏&#xff1a; 知识分享 &…

css-3:什么是响应式设计?响应式的原理是什么?如何做?

1、响应式设计是什么&#xff1f; 响应式网站设计&#xff08;Responsive WEB desgin&#xff09;是一个网络页面设计布局&#xff0c;页面的设计与开发应当根据用户行为以及设备环境&#xff08;系统平台、屏幕尺寸、屏幕定向等&#xff09;进行相应的相应和调整。 描述响应式…

ensp与虚拟机搭建测试环境

1.虚拟机配置 ①首先确定VMnet8 IP地址&#xff0c;若要修改IP地址&#xff0c;保证在启动Ensp前操作 ②尽量保证NAT模式 2.ensp配置 (1)拓扑结构 (2)Cloud配置 ①首先点击 绑定信息 UDP → 增加 ②然后点击 绑定信息 VMware ... → 增加 ③最后在 端口映射设置上点击双向通…

Hive创建外部表详细步骤

① 在hive中执行HDFS命令&#xff1a;创建/data目录 hive命令终端输入&#xff1a; hive> dfs -mkdir -p /data; 或者在linux命令终端输入&#xff1a; hdfs dfs -mkdir -p /data; ② 在hive中执行HDFS命令&#xff1a;上传/emp.txt至HDFS的data目录下&#xff0c;并命名为…

jmeter工具测试和压测websocket协议【杭州多测师_王sir】

一、安装JDK配置好环境变量&#xff0c;安装好jmeter 二、下载WebSocketSampler发送请求用的&#xff0c;地址&#xff1a;https://bitbucket.org/pjtr/jmeter-websocket-samplers/downloads/?spma2c4g.11186623.2.15.363f211bH03KeI 下载解压后的jar包放到D:\JMeter\apache-j…

2.Flink应用

2.1 数据流 DataStream&#xff1a;DataStream是Flink数据流的核心抽象&#xff0c;其上定义了对数据流的一系列操作DataStreamSource&#xff1a;DataStreamSource 是 DataStream 的 起 点 &#xff0c; DataStreamSource 在StreamExecutionEnvironment 中 创 建 &#xff0c;…

init_pg_dir 的大小及作用

init_pg_dir 的大小 vmlinux.lds.S 中 在vmlinux.lds.S 中&#xff0c;有 init_pg_dir .; . INIT_DIR_SIZE; init_pg_end .;/*include/asm/kernel-pgtable.h*/ #define EARLY_ENTRIES(vstart, vend, shift) \ ((((vend) - 1) >&g…

基于 CentOS 7 构建 LVS-DR 群集

文章目录 前言1、LVS集群2、DR模式的工作流程图 一、LVS DR模式的配置二、配置步骤总结 前言 什么是LVS集群&#xff1f;DR模式&#xff1f; 1、LVS集群 LVS采用的是合入内核模块&#xff0c;先把对于nginx来说要稳定很多&#xff0c;性能和稳定都在一定层度上占据优势&…

【ChatGPT 指令大全】怎么使用ChatGPT写履历和通过面试

目录 怎么使用ChatGPT写履历 寻求履历的反馈 为履历加上量化数据 把经历修精简 为不同公司客制化撰写履历 怎么使用ChatGPT通过面试 汇整面试题目 给予回馈 提供追问的问题 用 STAR 原则回答面试问题 感谢面试官的 email 总结 在职场竞争激烈的今天&#xff0c;写一…

linux网络编程--线程池UDP

目录 学习目标 1线程池 2.UDP通信 3本地socket通信 学习目标 了解线程池模型的设计思想能看懂线程池实现源码掌握tcp和udp的优缺点和使用场景说出udp服务器通信流程说出udp客户端通信流程独立实现udp服务器代码独立实现udp客户端代码熟练掌握本地套接字进行本地进程通信 1…

FreeRTOS源码分析-10 互斥信号量

目录 1 事件标志组概念及其应用 1.1 事件标志组定义 1.2 FreeRTOS事件标志组介绍 1.3 FreeRTOS事件标志组工作原理 2 事件标志组应用 2.1 功能需求 2.2 API 2.3 功能实现 3 事件标志组原理 3.1 事件标志组控制块 3.2 事件标志组获取标志位 3.3 等待事件标志触发 3.4…

小程序的api使用 以及一些weui组件实列获取头像 扫码等

今日目标 响应式单位rpx小程序的生命周期 【重点】20%小程序框架 weui 【重点】 50%内置API 【重点】30%综合练习 1. 响应式rpx 1.1 rpx单位 rpx是微信小程序提出的一个尺寸单位&#xff0c;将整个手机屏幕宽度分为750份&#xff0c;1rpx 就是 1/750&#xff0c;避免不同手…

QT自带PDF库的使用

QT自带PDF库可以方便的打开PDF文件&#xff0c;并将文件解析为QImage&#xff0c;相比网上提供的开源库&#xff0c;QT自带PDF库使用更方便&#xff0c;也更加可靠&#xff0c;然而&#xff0c;QT自带PDF库的使用却不同于其他通用库的使用&#xff0c;具备一定的技巧。 1. 安装…

以太网DHCP协议(十)

目录 一、工作原理 二、DHCP报文 2.1 DHCP报文类型 2.2 DHCP报文格式 当网络内部的主机设备数量过多是&#xff0c;IP地址的手动设置是一件非常繁琐的事情。为了实现自动设置IP地址、统一管理IP地址分配&#xff0c;TCPIP协议栈中引入了DHCP协议。 一、工作原理 使用DHCP之…

通向架构师的道路之weblogic与apache的整合与调优

一、BEAWeblogic的历史 BEA WebLogic是用于开发、集成、部署和管理大型分布式Web应用、 网络应用和数据库应 用的Java应用服务器。将Java的动态功能和Java Enterprise标准的安全性引入大型网络应用的 开发、集成、部署和管理之中。 BEA WebLogic Server拥有处理关键Web应…

pytorch求导

pytorch求导的初步认识 requires_grad tensor(data, dtypeNone, deviceNone, requires_gradFalse)requires_grad是torch.tensor类的一个属性。如果设置为True&#xff0c;它会告诉PyTorch跟踪对该张量的操作&#xff0c;允许在反向传播期间计算梯度。 x.requires_grad 判…

TM4C123库函数学习(1)--- 点亮LED+TM4C123的ROM函数简介+keil开发环境搭建

前言 &#xff08;1&#xff09; 首先&#xff0c;我们需要知道TM4C123是M4的内核。对于绝大多数人而言&#xff0c;入门都是学习STM32F103&#xff0c;这款芯片是采用的M3的内核。所以想必各位对M3内核还是有一定的了解。M4内核就是M3内核的升级版本&#xff0c;他继承了M3的的…

【力扣每日一题】2023.8.5 合并两个有序链表

目录 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 代码&#xff1a; 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 题目给我们两个有序的链表&#xff0c;要我们保持升序的状态合并它们。 我们可以马上想要把两个链表都遍历一遍&#xff0c;把所有节点的…

1-搭建一个最简单的验证平台UVM,已用Questasim实现波形!

UVM-搭建一个最简单的验证平台&#xff0c;已用Questasim实现波形 1&#xff0c;背景知识2&#xff0c;".sv"文件搭建的UVM验证平台&#xff0c;包括代码块分享3&#xff0c;Questasim仿真输出&#xff08;1&#xff09;compile all&#xff0c;成功&#xff01;&…

【力扣每日一题】2023.8.8 任意子数组和的绝对值的最大值

目录 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 代码&#xff1a; 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 题目给我们一个数组&#xff0c;让我们找出它的绝对值最大的子数组的和。 这边的子数组是要求连续的&#xff0c;让我们找出一个元素之和…