【SHAP解释运用】基于python的树模型特征选择+随机森林回归预测+SHAP解释预测

1.导入必要的库

import pandas as pd  
import numpy as np  
import matplotlib.pyplot as plt  
import seaborn as sns  
from sklearn.model_selection import train_test_split  
from sklearn.ensemble import RandomForestRegressor  
from sklearn.tree import export_graphviz  
#from sklearn.inspection import plot_partial_dependence   
from sklearn.metrics import mean_squared_error  
import shap  
import warnings  

2.设置忽略警告与显示字体、负号

warnings.filterwarnings("ignore")  # 设置Matplotlib的字体属性  
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用于中文显示,你可以更改为其他支持中文的字体  
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号 

3.导入数据集

3.1加载数据
# 1. 加载数据  
df = pd.read_excel('train.xlsx')  
X = df.iloc[:, :-1]  # 特征  
y = df.iloc[:, -1]   # 标签  
3.2查看数据分布
1.箱线图
plt.figure(figsize=(30, 6))  
sns.boxplot(data=df)  
plt.title('Box Plots of Dataset Features', fontsize=40, color='black')  
# 如果需要设置坐标轴标签的字体大小和颜色  
plt.xlabel('X-axis Label', fontsize=20, color='red')  # 设置x轴标签的字体大小和颜色  
plt.ylabel('Y-axis Label', fontsize=20, color='green')  # 设置y轴标签的字体大小和颜色  # 还可以调整刻度线的长度、宽度等属性  
plt.tick_params(axis='x', labelsize=20, colors='blue', length=5, width=1)  # 设置x轴刻度线、刻度标签的更多属性  
plt.tick_params(axis='y', labelsize=20, colors='deepskyblue', length=5, width=1)  # 设置y轴刻度线、刻度标签的更多属性    
plt.xticks(rotation=45)  # 如果特征名很长,可以旋转x轴标签  
plt.show()

        结果如图3-1所示:

图3-1

        结果图实在丑陋,这是由数据分布不均衡造成的,这里重点不是数据清洗,就这样凑着用吧。

2.分布图
# 注意:distplot 在 seaborn 0.11.0+ 中已被移除  
# 你可以分别使用 histplot 和 kdeplot  plt.figure(figsize=(50, 10))  
for i, feature in enumerate(df.columns, 1):  plt.subplot(1, len(df.columns), i)  sns.histplot(df[feature], kde=True, bins=30, label=feature,color='blue') plt.title(f'QQ plot of {feature}', fontsize=40, color='black')  # 如果需要设置坐标轴标签的字体大小和颜色  plt.xlabel('X-axis Label', fontsize=35, color='red')  # 设置x轴标签的字体大小和颜色  plt.ylabel('Y-axis Label', fontsize=40, color='green')  # 设置y轴标签的字体大小和颜色  # 还可以调整刻度线的长度、宽度等属性  plt.tick_params(axis='x', labelsize=40, colors='blue', length=5, width=1)  # 设置x轴刻度线、刻度标签的更多属性  plt.tick_params(axis='y', labelsize=40, colors='deepskyblue', length=5, width=1)  # 设置y轴刻度线、刻度标签的更多属性 
plt.tight_layout()  
plt.show()

        结果如图3-2所示:

图3-2

3.QQ图
from scipy import stats
plt.figure(figsize=(50, 10))  
for i, feature in enumerate(df.columns, 1):  plt.subplot(1, len(df.columns), i)  stats.probplot(df[feature], dist="norm", plot=plt)  plt.title(f'QQ plot of {feature}', fontsize=40, color='black')  # 如果需要设置坐标轴标签的字体大小和颜色  plt.xlabel('X-axis Label', fontsize=35, color='red')  # 设置x轴标签的字体大小和颜色  plt.ylabel('Y-axis Label', fontsize=40, color='green')  # 设置y轴标签的字体大小和颜色  # 还可以调整刻度线的长度、宽度等属性  plt.tick_params(axis='x', labelsize=40, colors='blue', length=5, width=1)  # 设置x轴刻度线、刻度标签的更多属性  plt.tick_params(axis='y', labelsize=40, colors='deepskyblue', length=5, width=1)  # 设置y轴刻度线、刻度标签的更多属性   
plt.tight_layout()  
plt.show()

        结果如图3-3所示:

图3-3

4.树模型特征选择

# 4. 特征选择(使用随机森林的特征重要性)  
rf = RandomForestRegressor(n_estimators=100, random_state=42)  
rf.fit(X_scaled, y)  
importances = rf.feature_importances_  
indices = np.argsort(importances)[::-1]  # 可视化特征重要性  
plt.figure(figsize=(10,7))  
plt.title("Feature importances")  
plt.bar(range(X.shape[1]), importances[indices],align="center", color='cyan')
plt.xticks(range(X.shape[1]), [X.columns[i] for i in indices], rotation='vertical')  
plt.xlim([-1, X.shape[1]])  
plt.show()

        特征重要性比较如图4-1所示:

图4-1

5.随机森林回归预测

# 划分训练集和测试集  
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)  # 随机森林回归预测  
rf_final = RandomForestRegressor(n_estimators=100, random_state=42)  
rf_final.fit(X_train, y_train)  
y_pred = rf_final.predict(X_test)  
mse = mean_squared_error(y_test, y_pred)  
print(f"Mean Squared Error: {mse}")  # 预测结果输出与比对
plt.figure()
plt.plot(np.arange(21), y_test[:100], "go-", label="True value")
plt.plot(np.arange(21), y_pred[:100], "ro-", label="Predict value")
plt.title("True value And Predict value")
plt.legend()
plt.show()

        预测结果如图5-1所示:

图5-1

        由图5-1结合这里的误差Mean Squared Error: 16.092619015714185,说明预测效果很一般,不过本身数据集没有经过清洗,数据分布不合理,有这样的结果也能接受。我一般使用matlab进行数据清晰和标准化,matlab暂时打不开,先搁置,后面我会出数据标准化的文章。

5.SHAP库解释预测

5.1shap库下载安装

        这里的shap库我已经下载安装过了,没有下载安装的在pycharm终端、Anaconda Promt终端等等执行命令进行下载安装,最好带上清华镜像源,在网络信号不好时也能顺利安装且节省时间。

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple shap
5.2waterfall
shap.plots.waterfall(shap_values[0]) # For the first observation

        结果如图5-1所示:

图5-1

5.3forceplot
#相互作用图
force_plot1 = shap.force_plot(explainer.expected_value, np.mean(shap_values, axis=0), np.mean(X_test, axis=0),feature_label,matplotlib=True, show=False)
shap_interaction_values = explainer.shap_interaction_values(X_test)
shap.summary_plot(shap_interaction_values,X_test)

        结果如图5-2所示:

图5-2

5.4特征影响图
shap.plots.force(explainer.expected_value,shap_values.values,shap_values.data)

        结果如图5-3所示:

图5-3

5.5特征密度散点图:summary_plot/beeswarm
5.5.1summary_plot
# 创建SHAP解释器
explainer = shap.TreeExplainer(rf)# 计算SHAP值
shap_values = explainer.shap_values(X_test)#特征标签
feature_label=['feature1','feature2','feature3','feature4','feature5','feature6','feature7']plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = 'Times New Roman'
plt.rcParams['font.size'] = 13  # 设置字体大小为14
# 现在创建 SHAP 可视化 #配色   viridis  Spectral   coolwarm  RdYlGn  RdYlBu  RdBu  RdGy  PuOr  BrBG PRGn  PiYG 
shap.summary_plot(shap_values, X_test,feature_names=feature_label)#粉红色点:表示该特征值在这个观察中对模型预测产生了正面影响(增加预测值)
#蓝色点:表示该特征值在这个观察中对模型预测产生了负面影响(降低预测值)
#水平轴(SHAP 值)显示了影响的大小。点越远离中心线(零点),该特征对模型输出的影响越大
#图中垂直排列的特征按影响力从上到下排序。上方的特征对模型输出的总体影响更大,而下方的特征影响较小。
# 最上方的特征显示了大量的正面和负面影响,表明它在不同的观察值中对模型预测的结果有很大的不同影响。
# 中部的特征也显示出两种颜色的点,但点的分布更集中,影响相对较小。
# 底部的特征对模型的影响最小,且大部分影响较为接近零,表示这些特征对模型预测的贡献较小

        结果如图5-4所示:

图5-4


# 创建SHAP解释器
explainer = shap.TreeExplainer(rf)
# 计算SHAP值
shap_values = explainer.shap_values(X_test)
#特征标签
feature_label=['feature1','feature2','feature3','feature4','feature5','feature6','feature7']
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = 'Times New Roman'
plt.rcParams['font.size'] = 13  # 设置字体大小为14
# 现在创建 SHAP 可视化 
#配色   viridis  Spectral   coolwarm  RdYlGn  RdYlBu  RdBu  RdGy  PuOr  BrBG PRGn  PiYG 
shap.summary_plot(shap_values,X_test,feature_names=feature_label,cmap='Spectral')

使颜色丰富些如图5-5所示:

图5-5

5.5.2beeswarm
# summarize the effects of all the features
# 样本决策图
shap.initjs()
shap_values = explainer(X_test)
expected_value = explainer.expected_value
shap.plots.beeswarm(shap_values)

结果如图5-6所示:

图5-6

5.6特征重要性SHAP值
shap.summary_plot(shap_values,X_test,feature_names=feature_label,plot_type='bar')
#主要表示绝对重要值的大小,把SHAP value 的样本取了绝对平均值

        或者:

shap.plots.bar(shap_values)

        结果如图5-7、图5-8所示,本质都是一样的:

图5-7

图5-8

5.7聚类热力图:heatmap plot

#热图
shap.initjs()
shap_values = explainer(X_test)
shap.plots.heatmap(shap_values)

        结果如图5-9所示:

图5-9

5.7层次聚类shap值
# 层次聚类 + SHAP值
clust = shap.utils.hclust(X, y, linkage="single")
shap.plots.bar(shap_values, clustering=clust, clustering_cutoff=1)

        结果如图5-10所示:

图5-10

5.8决策图
# 样本决策图
shap.initjs()
shap_values = explainer.shap_values(X_test)
expected_value = explainer.expected_value
shap.decision_plot(expected_value, shap_values,feature_label)

        结果如图5-11所示:

图5-11

变形1:由数值 -> 概率

# 样本决策图
shap.initjs()
shap_values = explainer.shap_values(X_test)
expected_value = explainer.expected_value
feature_label=['feature1','feature2','feature3','feature4','feature5','feature6','feature7']
shap.decision_plot(expected_value, shap_values, feature_label, link='logit')

        结果如图5-12所示:

图5-12

变形2:高亮某个样本线highlight

shap.decision_plot(expected_value, shap_values, feature_label, highlight=12)

        结果如图5-13所示:

图5-13

5.9特征依赖图:dependence_plot
5.9.1单个特征依赖
shap.dependence_plot('feature1', shap_values,X_test, interaction_index=None)

        结果如图5.14所示:

图5-14

5.9.2相互依赖图
shap.dependence_plot('feature3', shap_values,X_test, interaction_index='feature4')

        结果如图5-15所示:

图5-15

5.10相互作用图:summary_plot
shap.summary_plot(shap_interaction_values,X_test)

        结果如图5-16所示:

图5-16

具体的每种解释图的含义可以搜寻以下参考文章:

代码借鉴:http://t.csdnimg.cn/6JWrj

理论借鉴   

http://t.csdnimg.cn/6JWrj

http://t.csdnimg.cn/H9X0B

http://t.csdnimg.cn/zvtA8

http://t.csdnimg.cn/nygl6

http://t.csdnimg.cn/zyHy0

http://t.csdnimg.cn/rTPw2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/34868.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CesiumJS【Basic】- #013添加点线面(Primitive方式)

添加点线面(Primitive方式) 1 目标 使用图元方式添加点线面 - 绘制点 - 贴图点 - 标签 - 线 - 贴地线 - 面 - 贴地面 - 带洞的面 2 实现 2.1 PrimitiveGeometryManager.ts // src/PrimitiveGeometryManager.tsimport * as Cesium from "cesium";

Stable Diffusion 3 文本生成图像 在线体验 原理分析

前言 本文分享使用Stable Diffusion 3实现文本生成图像,可以通过在线网页中免费使用的,也有API等方式访问。 同时结合论文和开源代码进行分析,理解其原理。 Stable Diffusion 3是Stability AI开发的最新、最先进的文本生成图像模型&#x…

性能工具之 MySQL OLTP Sysbench BenchMark 测试示例

文章目录 一、前言二、测试环境1、服务器配置2、测试拓扑 三、测试工具安装四、测试步骤1、导入数据2、压测数据3、清理数据 五、结果解析六、最后 一、前言 做为一名性能工程师掌握对 MySQL 的性能测试是非常必要的,本文基于 Sysbench 对MySQL OLTP(联…

现在的Java面试都这么扯淡了吗?

在开始前刚好我有一些资料,是我根据网友给的问题精心整理了一份「java的资料从专业入门到高级教程」, 点个关注在评论区回复“666”之后私信回复“666”,全部无偿共享给大家!!!开发兼过半年面试官 刚开始…

使用js实现input框的模糊搜索

使用简单的js代码就可以实现模糊搜索功能,使用indexOf属性。 json假数据:使用数组包对象的方法,在json中设置了三列数据,分别是:名称,性别和交易额。 [{"name": "虚拟星辰","de…

【杂记-浅谈OSPF协议之STUB、NSSA区域】

OSPF协议之STUB、NSSA区域】 一、STUB区域1、STUB区域概述2、STUB区域的特点3、STUB区域的优缺点 二、NSSA区域1、NSSA区域概述2、NSSA区域的特点3、NSSA区域的配置和使用 一、STUB区域 1、STUB区域概述 STUB区域是OSPF协议中的一个特殊区域类型,主要用于网络设计…

学习java第一百一十天

请解释Spring中的事务管理? 事务管理是确保数据完整性和一致性的重要机制。在Spring框架中,事务管理可以通过声明式事务管理或编程式事务管理来实现。声明式事务管理允许我们将事务管理逻辑与业务逻辑分离,让Spring容器自动处理事务的开启、提…

Selenium与PyAutoGUI的联动:一种创新的Web自动化测试方案

在当今的软件开发周期中,自动化测试是确保软件质量和效率的关键步骤。Selenium是广泛使用的Web应用程序自动化测试工具之一,它支持多种编程语言并且能够模拟用户对浏览器的操作。然而,有些测试场景可能超出了Selenium的处理范围,例…

网络安全等级保护测评

网络安全等级保护 《GB17859 计算机信息系统安全保护等级划分准则》 规定计算机信息系统安全保护等级共分五级 《中华人民共和国网络安全法》 “国家实行网络安全等级保护制度。 等级测评 测评机构依据国家网络安全等级保护制度规定,按照有关 管理规范和…

JVM虚拟机的组成

一、为什么要学习 JVM ? 1. “ ⾯试造⽕箭,⼯作拧螺丝” , JVM 属于⾯试官特别喜欢提问的知识点; 2. 未来在⼯作场景中,也许你会遇到以下场景: 线上系统突然宕机,系统⽆法访问,甚⾄直…

2024年虚拟现实、图像和信号处理国际学术会议(ICVISP 2024,8月2日-4)

2024年虚拟现实、图像和信号处理国际学术会议(ICVISP 2024)将于2024年8月2-4日在中国厦门召开。ICVISP 2024将围绕“虚拟现实、图像和信号处理”的最新研究领域, 为来自国内外高等院校、科学研究所、企事业单位的专家、教授、学者、工程师等提…

STM32+HAL+FreeRTOS,已经修改了系统时钟为定时器,为什么还卡死在HAL_Delay()

问题 使用CubeMX创建了STM32的工程,启用了FreeRTOS,使用的是HAL库,运行后发现卡死在HAL_Delay(),修改了Timebase Source后正常了,后来加入了USB,又卡死了,参考这篇文章解决,后来我又…

迁移方案详解|使用YMP从异构数据库迁移到YashanDB

数据迁移简介 01典型场景与需求 在国产化浪潮下,数据库系统的国产化替代成为了一个日益重要的议题,有助于企业降低对外依赖,提升信息安全和自主性。 以Oracle、MySQL为代表的传统关系型数据库管理系统,在企业应用中占据了重要的…

通用VS垂直,个人观点分析。

摘要:随着人工智能技术的飞速发展,大模型的应用场景越来越广泛。在这个背景下,通用大模型和垂直大模型之间的竞争日趋激烈。本文将围绕这两个方向,探讨它们在第一个赛点中的优劣,并给出个人观点。  一、通用大模型 …

SpringMvcの拦截器全局异常处理

一、拦截器 我们在网上发贴子的时候如果没有登录,点击发送按钮会提示未进行登录,跳转到登录页面。这样的功能是如何实现的。 1、 拦截器的作用 Spring MVC 的处理器拦截器类似于Servlet开发中的过滤器Filter,用于对处理器进行预处理和后处理…

服务器卡的情况下,一般会出现什么表现状况?

1、服务器严重丢包,正常的服务器丢包率为0%,若丢包率高于1%则会出现卡的情况。 2、部分用户卡,部分用户不卡,可能由于硬件防火墙造成,部分链路堵塞。 3、另外,上述情况也可能是互联网节点故障造成。

Elasticsearch:has_child 和 has_parent 查询——父子关系查询详解

在 Elasticsearch 中,父子关系查询是一种特殊的查询类型,它允许我们在具有父子关系的文档之间进行关联查询。这种关系在树形结构或者层次化数据模型中尤为常见。Elasticsearch 提供了 has_child 和 has_parent 两种查询类型,用于在这种关系中…

掌握 Postman 监控功能:自动化测试与性能监控的秘诀

掌握 Postman 监控功能:自动化测试与性能监控的秘诀 引言 在现代软件开发中,API 的稳定性和性能至关重要。Postman,作为最受欢迎的 API 开发工具之一,提供了强大的监控功能,帮助开发者自动化测试和监控 API 的运行状…

图书馆借阅表

DDL 用户表 (Users) 图书表 (Books) 图书类别表 (BookCategories) 图书与类别关联表 (BookCategoryRelations) 借阅记录表 (BorrowRecords) 供应商表 (Suppliers) 采购记录表 (PurchaseRecords) CREATE TABLE Users (user_id INT PRIMARY KEY AUTO_INCREMENT,username …

pytorch神经网络训练(VGG-19)

VGG-19 导包 import torchimport torch.nn as nnimport torch.optim as optimimport torchvisionfrom torchvision import datasets, transformsfrom torch.utils.data import DataLoaderimport matplotlib.pyplot as plt 数据预处理和增强 transform transforms.Compose(…