【SHAP解释运用】基于python的树模型特征选择+随机森林回归预测+SHAP解释预测

1.导入必要的库

import pandas as pd  
import numpy as np  
import matplotlib.pyplot as plt  
import seaborn as sns  
from sklearn.model_selection import train_test_split  
from sklearn.ensemble import RandomForestRegressor  
from sklearn.tree import export_graphviz  
#from sklearn.inspection import plot_partial_dependence   
from sklearn.metrics import mean_squared_error  
import shap  
import warnings  

2.设置忽略警告与显示字体、负号

warnings.filterwarnings("ignore")  # 设置Matplotlib的字体属性  
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用于中文显示,你可以更改为其他支持中文的字体  
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号 

3.导入数据集

3.1加载数据
# 1. 加载数据  
df = pd.read_excel('train.xlsx')  
X = df.iloc[:, :-1]  # 特征  
y = df.iloc[:, -1]   # 标签  
3.2查看数据分布
1.箱线图
plt.figure(figsize=(30, 6))  
sns.boxplot(data=df)  
plt.title('Box Plots of Dataset Features', fontsize=40, color='black')  
# 如果需要设置坐标轴标签的字体大小和颜色  
plt.xlabel('X-axis Label', fontsize=20, color='red')  # 设置x轴标签的字体大小和颜色  
plt.ylabel('Y-axis Label', fontsize=20, color='green')  # 设置y轴标签的字体大小和颜色  # 还可以调整刻度线的长度、宽度等属性  
plt.tick_params(axis='x', labelsize=20, colors='blue', length=5, width=1)  # 设置x轴刻度线、刻度标签的更多属性  
plt.tick_params(axis='y', labelsize=20, colors='deepskyblue', length=5, width=1)  # 设置y轴刻度线、刻度标签的更多属性    
plt.xticks(rotation=45)  # 如果特征名很长,可以旋转x轴标签  
plt.show()

        结果如图3-1所示:

图3-1

        结果图实在丑陋,这是由数据分布不均衡造成的,这里重点不是数据清洗,就这样凑着用吧。

2.分布图
# 注意:distplot 在 seaborn 0.11.0+ 中已被移除  
# 你可以分别使用 histplot 和 kdeplot  plt.figure(figsize=(50, 10))  
for i, feature in enumerate(df.columns, 1):  plt.subplot(1, len(df.columns), i)  sns.histplot(df[feature], kde=True, bins=30, label=feature,color='blue') plt.title(f'QQ plot of {feature}', fontsize=40, color='black')  # 如果需要设置坐标轴标签的字体大小和颜色  plt.xlabel('X-axis Label', fontsize=35, color='red')  # 设置x轴标签的字体大小和颜色  plt.ylabel('Y-axis Label', fontsize=40, color='green')  # 设置y轴标签的字体大小和颜色  # 还可以调整刻度线的长度、宽度等属性  plt.tick_params(axis='x', labelsize=40, colors='blue', length=5, width=1)  # 设置x轴刻度线、刻度标签的更多属性  plt.tick_params(axis='y', labelsize=40, colors='deepskyblue', length=5, width=1)  # 设置y轴刻度线、刻度标签的更多属性 
plt.tight_layout()  
plt.show()

        结果如图3-2所示:

图3-2

3.QQ图
from scipy import stats
plt.figure(figsize=(50, 10))  
for i, feature in enumerate(df.columns, 1):  plt.subplot(1, len(df.columns), i)  stats.probplot(df[feature], dist="norm", plot=plt)  plt.title(f'QQ plot of {feature}', fontsize=40, color='black')  # 如果需要设置坐标轴标签的字体大小和颜色  plt.xlabel('X-axis Label', fontsize=35, color='red')  # 设置x轴标签的字体大小和颜色  plt.ylabel('Y-axis Label', fontsize=40, color='green')  # 设置y轴标签的字体大小和颜色  # 还可以调整刻度线的长度、宽度等属性  plt.tick_params(axis='x', labelsize=40, colors='blue', length=5, width=1)  # 设置x轴刻度线、刻度标签的更多属性  plt.tick_params(axis='y', labelsize=40, colors='deepskyblue', length=5, width=1)  # 设置y轴刻度线、刻度标签的更多属性   
plt.tight_layout()  
plt.show()

        结果如图3-3所示:

图3-3

4.树模型特征选择

# 4. 特征选择(使用随机森林的特征重要性)  
rf = RandomForestRegressor(n_estimators=100, random_state=42)  
rf.fit(X_scaled, y)  
importances = rf.feature_importances_  
indices = np.argsort(importances)[::-1]  # 可视化特征重要性  
plt.figure(figsize=(10,7))  
plt.title("Feature importances")  
plt.bar(range(X.shape[1]), importances[indices],align="center", color='cyan')
plt.xticks(range(X.shape[1]), [X.columns[i] for i in indices], rotation='vertical')  
plt.xlim([-1, X.shape[1]])  
plt.show()

        特征重要性比较如图4-1所示:

图4-1

5.随机森林回归预测

# 划分训练集和测试集  
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)  # 随机森林回归预测  
rf_final = RandomForestRegressor(n_estimators=100, random_state=42)  
rf_final.fit(X_train, y_train)  
y_pred = rf_final.predict(X_test)  
mse = mean_squared_error(y_test, y_pred)  
print(f"Mean Squared Error: {mse}")  # 预测结果输出与比对
plt.figure()
plt.plot(np.arange(21), y_test[:100], "go-", label="True value")
plt.plot(np.arange(21), y_pred[:100], "ro-", label="Predict value")
plt.title("True value And Predict value")
plt.legend()
plt.show()

        预测结果如图5-1所示:

图5-1

        由图5-1结合这里的误差Mean Squared Error: 16.092619015714185,说明预测效果很一般,不过本身数据集没有经过清洗,数据分布不合理,有这样的结果也能接受。我一般使用matlab进行数据清晰和标准化,matlab暂时打不开,先搁置,后面我会出数据标准化的文章。

5.SHAP库解释预测

5.1shap库下载安装

        这里的shap库我已经下载安装过了,没有下载安装的在pycharm终端、Anaconda Promt终端等等执行命令进行下载安装,最好带上清华镜像源,在网络信号不好时也能顺利安装且节省时间。

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple shap
5.2waterfall
shap.plots.waterfall(shap_values[0]) # For the first observation

        结果如图5-1所示:

图5-1

5.3forceplot
#相互作用图
force_plot1 = shap.force_plot(explainer.expected_value, np.mean(shap_values, axis=0), np.mean(X_test, axis=0),feature_label,matplotlib=True, show=False)
shap_interaction_values = explainer.shap_interaction_values(X_test)
shap.summary_plot(shap_interaction_values,X_test)

        结果如图5-2所示:

图5-2

5.4特征影响图
shap.plots.force(explainer.expected_value,shap_values.values,shap_values.data)

        结果如图5-3所示:

图5-3

5.5特征密度散点图:summary_plot/beeswarm
5.5.1summary_plot
# 创建SHAP解释器
explainer = shap.TreeExplainer(rf)# 计算SHAP值
shap_values = explainer.shap_values(X_test)#特征标签
feature_label=['feature1','feature2','feature3','feature4','feature5','feature6','feature7']plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = 'Times New Roman'
plt.rcParams['font.size'] = 13  # 设置字体大小为14
# 现在创建 SHAP 可视化 #配色   viridis  Spectral   coolwarm  RdYlGn  RdYlBu  RdBu  RdGy  PuOr  BrBG PRGn  PiYG 
shap.summary_plot(shap_values, X_test,feature_names=feature_label)#粉红色点:表示该特征值在这个观察中对模型预测产生了正面影响(增加预测值)
#蓝色点:表示该特征值在这个观察中对模型预测产生了负面影响(降低预测值)
#水平轴(SHAP 值)显示了影响的大小。点越远离中心线(零点),该特征对模型输出的影响越大
#图中垂直排列的特征按影响力从上到下排序。上方的特征对模型输出的总体影响更大,而下方的特征影响较小。
# 最上方的特征显示了大量的正面和负面影响,表明它在不同的观察值中对模型预测的结果有很大的不同影响。
# 中部的特征也显示出两种颜色的点,但点的分布更集中,影响相对较小。
# 底部的特征对模型的影响最小,且大部分影响较为接近零,表示这些特征对模型预测的贡献较小

        结果如图5-4所示:

图5-4


# 创建SHAP解释器
explainer = shap.TreeExplainer(rf)
# 计算SHAP值
shap_values = explainer.shap_values(X_test)
#特征标签
feature_label=['feature1','feature2','feature3','feature4','feature5','feature6','feature7']
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = 'Times New Roman'
plt.rcParams['font.size'] = 13  # 设置字体大小为14
# 现在创建 SHAP 可视化 
#配色   viridis  Spectral   coolwarm  RdYlGn  RdYlBu  RdBu  RdGy  PuOr  BrBG PRGn  PiYG 
shap.summary_plot(shap_values,X_test,feature_names=feature_label,cmap='Spectral')

使颜色丰富些如图5-5所示:

图5-5

5.5.2beeswarm
# summarize the effects of all the features
# 样本决策图
shap.initjs()
shap_values = explainer(X_test)
expected_value = explainer.expected_value
shap.plots.beeswarm(shap_values)

结果如图5-6所示:

图5-6

5.6特征重要性SHAP值
shap.summary_plot(shap_values,X_test,feature_names=feature_label,plot_type='bar')
#主要表示绝对重要值的大小,把SHAP value 的样本取了绝对平均值

        或者:

shap.plots.bar(shap_values)

        结果如图5-7、图5-8所示,本质都是一样的:

图5-7

图5-8

5.7聚类热力图:heatmap plot

#热图
shap.initjs()
shap_values = explainer(X_test)
shap.plots.heatmap(shap_values)

        结果如图5-9所示:

图5-9

5.7层次聚类shap值
# 层次聚类 + SHAP值
clust = shap.utils.hclust(X, y, linkage="single")
shap.plots.bar(shap_values, clustering=clust, clustering_cutoff=1)

        结果如图5-10所示:

图5-10

5.8决策图
# 样本决策图
shap.initjs()
shap_values = explainer.shap_values(X_test)
expected_value = explainer.expected_value
shap.decision_plot(expected_value, shap_values,feature_label)

        结果如图5-11所示:

图5-11

变形1:由数值 -> 概率

# 样本决策图
shap.initjs()
shap_values = explainer.shap_values(X_test)
expected_value = explainer.expected_value
feature_label=['feature1','feature2','feature3','feature4','feature5','feature6','feature7']
shap.decision_plot(expected_value, shap_values, feature_label, link='logit')

        结果如图5-12所示:

图5-12

变形2:高亮某个样本线highlight

shap.decision_plot(expected_value, shap_values, feature_label, highlight=12)

        结果如图5-13所示:

图5-13

5.9特征依赖图:dependence_plot
5.9.1单个特征依赖
shap.dependence_plot('feature1', shap_values,X_test, interaction_index=None)

        结果如图5.14所示:

图5-14

5.9.2相互依赖图
shap.dependence_plot('feature3', shap_values,X_test, interaction_index='feature4')

        结果如图5-15所示:

图5-15

5.10相互作用图:summary_plot
shap.summary_plot(shap_interaction_values,X_test)

        结果如图5-16所示:

图5-16

具体的每种解释图的含义可以搜寻以下参考文章:

代码借鉴:http://t.csdnimg.cn/6JWrj

理论借鉴   

http://t.csdnimg.cn/6JWrj

http://t.csdnimg.cn/H9X0B

http://t.csdnimg.cn/zvtA8

http://t.csdnimg.cn/nygl6

http://t.csdnimg.cn/zyHy0

http://t.csdnimg.cn/rTPw2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/34868.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Stable Diffusion 3 文本生成图像 在线体验 原理分析

前言 本文分享使用Stable Diffusion 3实现文本生成图像,可以通过在线网页中免费使用的,也有API等方式访问。 同时结合论文和开源代码进行分析,理解其原理。 Stable Diffusion 3是Stability AI开发的最新、最先进的文本生成图像模型&#x…

性能工具之 MySQL OLTP Sysbench BenchMark 测试示例

文章目录 一、前言二、测试环境1、服务器配置2、测试拓扑 三、测试工具安装四、测试步骤1、导入数据2、压测数据3、清理数据 五、结果解析六、最后 一、前言 做为一名性能工程师掌握对 MySQL 的性能测试是非常必要的,本文基于 Sysbench 对MySQL OLTP(联…

现在的Java面试都这么扯淡了吗?

在开始前刚好我有一些资料,是我根据网友给的问题精心整理了一份「java的资料从专业入门到高级教程」, 点个关注在评论区回复“666”之后私信回复“666”,全部无偿共享给大家!!!开发兼过半年面试官 刚开始…

网络安全等级保护测评

网络安全等级保护 《GB17859 计算机信息系统安全保护等级划分准则》 规定计算机信息系统安全保护等级共分五级 《中华人民共和国网络安全法》 “国家实行网络安全等级保护制度。 等级测评 测评机构依据国家网络安全等级保护制度规定,按照有关 管理规范和…

JVM虚拟机的组成

一、为什么要学习 JVM ? 1. “ ⾯试造⽕箭,⼯作拧螺丝” , JVM 属于⾯试官特别喜欢提问的知识点; 2. 未来在⼯作场景中,也许你会遇到以下场景: 线上系统突然宕机,系统⽆法访问,甚⾄直…

2024年虚拟现实、图像和信号处理国际学术会议(ICVISP 2024,8月2日-4)

2024年虚拟现实、图像和信号处理国际学术会议(ICVISP 2024)将于2024年8月2-4日在中国厦门召开。ICVISP 2024将围绕“虚拟现实、图像和信号处理”的最新研究领域, 为来自国内外高等院校、科学研究所、企事业单位的专家、教授、学者、工程师等提…

迁移方案详解|使用YMP从异构数据库迁移到YashanDB

数据迁移简介 01典型场景与需求 在国产化浪潮下,数据库系统的国产化替代成为了一个日益重要的议题,有助于企业降低对外依赖,提升信息安全和自主性。 以Oracle、MySQL为代表的传统关系型数据库管理系统,在企业应用中占据了重要的…

图书馆借阅表

DDL 用户表 (Users) 图书表 (Books) 图书类别表 (BookCategories) 图书与类别关联表 (BookCategoryRelations) 借阅记录表 (BorrowRecords) 供应商表 (Suppliers) 采购记录表 (PurchaseRecords) CREATE TABLE Users (user_id INT PRIMARY KEY AUTO_INCREMENT,username …

pytorch神经网络训练(VGG-19)

VGG-19 导包 import torchimport torch.nn as nnimport torch.optim as optimimport torchvisionfrom torchvision import datasets, transformsfrom torch.utils.data import DataLoaderimport matplotlib.pyplot as plt 数据预处理和增强 transform transforms.Compose(…

在 Go 中如何让结构体不可比较?

最近我在使用 Go 官方出品的结构化日志包 slog 时,看到 slog.Value 源码中有一个比较好玩的小 Tips,可以限制两个结构体之间的相等性比较,本文就来跟大家分享下。 在 Go 中结构体可以比较吗? 在 Go 中结构体可以比较吗&#xff…

鸿蒙开发HarmonyOS NEXT(一)

最近总听见大家讨论鸿蒙,前端转型的好方向?先入门学习下 目前官方版本和文档持续更新中 一、开发环境 提示:要占用的空间比较多,建议安装在剩余空间多的盘 1、下载:官网最新工具 - 下载中心 - 华为开发者联盟 (huaw…

RTL8305NB从电口模式切换为光口模式

对于RTL8305NB,要从电口模式切换为光口模式,主要操作涉及到PHY page的切换和特定寄存器的配置。以下是详细的操作步骤: PHY Page切换: 首先,需要访问PHY地址8的寄存器31。这个寄存器用于Page的切换。向PHY地址8的寄存…

从删库到还原

欢迎来到我的博客,代码的世界里,每一行都是一个故事 🎏:你只管努力,剩下的交给时间 🏠 :小破站 从删库到还原 魔法一魔法二魔法三魔法四查看是否开启binlog,且format为row执行以下命…

WAV怎么转mp3?将wav转成MP3的几种方法介绍

WAV怎么转mp3?很多情况下,我们可能需要将高质量的 WAV 文件转换为更小、更兼容的 MP3 文件。例如,你可能想要为你的音乐收藏腾出更多存储空间,或者需要将音频文件上传到联网平台,而这些平台通常对文件大小有严格限制。…

会声会影2024免费版下载无需激活码序列号

亲爱的影像爱好者们,今天我要和大家分享的是一款让我彻底着迷的软件——会声会影2024!自从用了它,我的视频编辑技能简直突飞猛进,每次上传作品到小红书都能收获满满的赞👍。接下来,就让我带你一起探索这个神…

window系统忘记密码解决方案

原理 通过命令修改粘滞键的作用打开cmd命令,通过cmd命令修改用户密码。 1.进入系统自动恢复页面 各品牌进入恢复页面各不一样,一般按住shift重启电脑即可,笔者的惠普电脑是开机按住F11键。页面如下: 之后选择 - > 疑难解答…

阿里云nginx更新证书后依旧显示旧证书

尝试的解决办法 重启nginx服务删除服务器上的旧证书清除浏览器缓存检查是否使用CDN服务 最后的解决办法 云服务器开启了WAF服务,在WAF服务中配置证书

ssm 宠物领养系统-计算机毕业设计源码08465

目 录 摘要 1 绪论 1.1课题背景及意义 1.2研究现状 1.3ssm框架介绍 1.3论文结构与章节安排 2 宠物领养系统系统分析 2.1 可行性分析 2.2 系统流程分析 2.2.1 数据流程 3.3.2 业务流程 2.3 系统功能分析 2.3.1 功能性分析 2.3.2 非功能性分析 2.4 系统用例分析 …

web开发学习(web简单入门)

前言: 从我刚接触博客没多久我就萌发了搭建一个个人博客网站的想法(用来装逼),但碍于学校屁事太多迟迟没有开始,最近学校课已经都差不多结课了,距离期末还有一段时间,我也得以抽出时间来学习我一…

js实现blockly后台解释器,可以单步执行,可以调用c/c++函数

实现原理 解析blockly语法树,使用js管理状态,实际使用lua执行,c/c++函数调用使用lua调用c/c++函数的能力 可以单行执行 已实现if功能 TODO for循环功能 函数功能 单步执行效果图 直接执行效果图 源代码 //0 暂停 1 单步执行 2 断点 //创建枚举 var AstStatus = {PAUS…