【MATLAB第112期】基于MATLAB的SHAP可解释神经网络回归模型(敏感性分析方法)

【MATLAB第112期】基于MATLAB的SHAP可解释神经网络回归模型(敏感性分析方法)

引言

该文章实现了一个可解释的神经网络回归模型,使用BP神经网络(BPNN)来预测特征输出。该模型利用七个变量参数作为输入特征进行训练。为了提高可解释性,应用了SHapley Additive exPlanations(SHAP),去深入了解每个参数对模型预测的贡献。

一、案例数据

1、导入数据

res = xlsread('数据集.xlsx');  %103行样本,7输入,1输出
x = res (:,1:end-1); %   
y = res(:,end); % 最后一列为输出

2、数据标准化
该部分使用mapminmax函数对输入和输出数据进行标准化,将数据缩放到[-1, 1]范围内。

% 输入数据归一化
[x_norm, x_settings] = mapminmax(x',-1,1);
% 输出数据归一化
[y_norm, y_settings] = mapminmax(y',-1,1);normalization_x = x_settings;
save ('normalization_x.mat', 'x_settings');
normalization_y = y_settings;
save ('normalization_y.mat', 'y_settings');x_norm_t = x_norm';
y_norm_t = y_norm';

输入数据标准化:输入特征被标准化,标准化设置(x_settings)保存在名为normalization_x.mat的文件中,以便后续使用或反转标准化.
输出数据标准化:同样,输出数据被标准化,标准化设置(y_settings)保存在名为normalization_y.mat的文件中. 标准化后的数据被转置回原始方向,以保持模型进一步处理的一致性. 此步骤确保输入和输出数据适当缩放,以便于神经网络训练,从而有助于提高模型性能和收敛速度.

二、交叉验证和模型评估

该部分执行5折交叉验证以评估基于优化超参数构建的模型性能.

1、交叉验证设置
脚本使用K折交叉验证,numFolds = 5,将数据分成5个子集(折)。在每次迭代中,一个子集用于测试,其余子集用于训练模型.

2、模型训练和测试
对于每个折,使用cvpartition生成的索引将训练和测试数据分开. 使用BP神经网络(BPNN)训练模型,超参数设置:

 % 训练模型neuron = 5;%%  创建网络net = newff(trainData', trainLabels', neuron);%%  设置训练参数net.trainParam.epochs = 1000;     % 迭代次数 net.trainParam.goal = 1e-6;       % 误差阈值net.trainParam.lr = 0.01;         % 学习率

3、解释交叉验证结果
RMSE:交叉验证后,可以计算所有折的平均RMSE。如果所有折的RMSE值一致且相对较低,则表明模型对未见数据具有良好的泛化能力。如果RMSE值在各折之间变化较大,可能表明模型对训练数据敏感,这可能是过拟合的迹象.
R²(决定系数):R²值也应在所有折之间进行平均,以评估模型的拟合优度。较高的R²值表明模型能够解释目标变量的更大比例的方差。如果R²值较低,则可能表明模型未能很好地捕捉输入特征与目标之间的关系.
最终模型选择:完成交叉验证后,可以通过所有折的平均RMSE和R²总结模型的整体性能。这有助于选择在偏差和方差之间具有最佳权衡的模型.
在这里插入图片描述

在这里插入图片描述

通过训练数据集评估模型的预测性能
选择最优数据集进行可视化(折数=3)
在这里插入图片描述

在这里插入图片描述

三、SHAP分析

1、生成随机数据
在本部分,生成一组合成输入数据用于SHAP分析。这种合成数据允许在受控和一致的方式下评估模型的特征贡献。步骤包括:

样本数量:脚本设置生成的合成样本数量为80(numSamples = 80).
特征范围: 定义操作参数在特定范围内,选择训练数据中各个输入变量的最大值和最小值

VarMin =  [137.0000         0         0  160.0000    4.7000  708.0000  640.6000]
VarMax =  1.0e+03 *[    0.3660    0.1930    0.2600    0.2400    0.0190    1.0495    0.9020]

随机数据生成: 使用rand函数在定义的范围内为每个特征生成随机值,创建80个样本.

for i=1:size(x,2)
x_shap(:,i)=VarMin(i)+ (VarMax(i) - VarMin(i)) * rand(numSamples, 1);
end

此生成数据用于评估SHAP值并分析每个特征如何影响模型的预测。生成随机输入数据确保了SHAP分析中特征值的广泛范围,便于更全面地评估特征重要性.

2、计算SHAP值
该代码计算神经网络模型的SHapley Additive exPlanations(SHAP)值。SHAP值量化了每个特征对模型预测的贡献。该过程包括:

  1. 预分配SHAP值矩阵:初始化一个矩阵以存储所有输入样本和特征的SHAP值.
    2.计算参考值:将参考值计算为所有输入特征的平均值,用于在排除或包含特征时进行比较.
    3.计算SHAP值:对于每个输入样本,使用自定义的shapley_ann函数计算SHAP值,该函数迭代所有可能的特征组合以确定每个特征对预测的贡献.
    4.自定义的shapley函数接受一个训练好的神经网络(net)、当前输入样本和参考值来计算每个特征的SHAP值。该方法提供了对单个特征如何影响模型输出的洞察.
% ------------------------------------
function shapValues = shapley(net, x_shap, refValue) % 假设您有一个名为'net'的训练好的网络% 使用Shapley公式计算SHAP值如果有7个特征,则依次分析每个特征的累计贡献值当分析第1个特征时,排除当前特征,即 1  0  0  0  0  0  0迭代所有可能的特征组合 for i=12^(D-1)xt1: 每个样本的特征变量输入值(处理后)   1*7xt2: 计算的每个样本平均值(处理后)       1*7xt3: 当分析不同特征时,将该特征值替换为平均值。  1*7shapValues=shapValues+net(xt3)-net(xt2)   end

3、可视化
------蜂群图:为每个特征创建散点图(蜂群图),显示所有样本的SHAP值。特征值被标准化并颜色编码以提高可解释性.
包括轴标签、网格、框以提高清晰度以及带有操作参数标签的颜色条. 此SHAP摘要图有助于理解哪些特征对模型的预测影响最大以及特征在样本中的变化情况.显示每个特征对模型预测的贡献。
在这里插入图片描述

-----条形图
计算平均绝对SHAP值:计算每个特征的绝对SHAP值的平均值,以量化每个特征的整体重要性.
条形图可视化:创建一个水平条形图,特征按其平均绝对SHAP值排序。这提供了模型中特征重要性的清晰、排序表示. 结果的SHAP摘要条形图有助于识别哪些特征对模型的预测影响最大.
在这里插入图片描述

四、代码获取

1.阅读首页置顶文章
2.关注CSDN
3.根据自动回复消息,私信回复“112期”以及相应指令,即可获取对应下载方式。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/892330.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

iOS 本地新项目上传git仓库,并使用sourceTree管理

此文记录的场景描述: iOS前期开发时,在本地创建项目,直至开发一段时间,初期编码及框架已完善后,才拿到git仓库的地址。此时需要将本地代码上传到git仓库。 上传至git仓库,可以使用终端,键入命令…

为深度学习引入张量

为深度学习引入张量 什么是张量? 神经网络中的输入、输出和转换都是使用张量表示的,因此,神经网络编程大量使用张量。 张量是神经网络使用的主要数据结构。 张量的概念是其他更具体概念的数学概括。让我们看看一些张量的具体实例。 张量…

欧拉公式和傅里叶变换

注:英文引文机翻,未校。 中文引文未整理去重,如有异常,请看原文。 Euler’s Formula and Fourier Transform Posted byczxttkl October 7, 2018 Euler’s formula states that e i x cos ⁡ x i sin ⁡ x e^{ix} \cos{x} i …

HDFS编程 - 使用HDFS Java API进行文件操作

文章目录 前言一、创建hdfs-demo项目1. 在idea上创建maven项目2. 导入hadoop相关依赖 二、常用 HDFS Java API1. 简介2. 获取文件系统实例3. 创建目录4. 创建文件4.1 创建文件并写入数据4.2 创建新空白文件 5. 查看文件内容6. 查看目录下的文件或目录信息6.1 查看指定目录下的文…

高德地图调用geoserver提供WMTS服务报错Unknown TILEMATRIX问题

1. 高德地图JSAPI要求WMTS必须是EPSG:3857坐标系 2. 高德调用WMTS服务时参数 TileMatrix中未带有坐标系字段,需要修改geoserver源码兼容一下,修改JSAPI也可以,如你用都用离线的话 leaflet加载geoserver的WMTS服务时TILEMATRIX字段 TILEMATR…

C语言——文件IO 【文件IO和标准IO区别,操作文件IO】open,write,read,dup2,access,stat

1.思维导图 2.练习 1:使用C语言编写一个简易的界面,界面如下 1:标准输出流 2:标准错误流 3:文件流 要求:按1的时候,通过printf输出数据,按2的时候,通过p…

C++实现图书管理系统(Qt C++ GUI界面版)

前瞻 本项目基于【C】图书管理系统(完整版) 图书管理系统功能概览: 登录,注册学生,老师借书,查看自己当前借书情况,还书。管理员增加书,查看当前借阅情况,查看当前所有借阅人,图书信息。 效果…

使用 NestJS 构建高效且模块化的 Node.js 应用程序,从安装到第一个 API 端点:一步一步指南

一、安装 NestJS 要开始构建一个基于 NestJS 的应用,首先需要安装一系列依赖包。以下是必要的安装命令: npm i --save nestjs/core nestjs/common rxjs reflect-metadata nestjs/platform-express npm install -g ts-node包名介绍nestjs/coreNestJS 框…

鸿蒙面试 2025-01-09

鸿蒙分布式理念?(个人认为理解就好) 鸿蒙操作系统的分布式理念主要体现在其独特的“流转”能力和相关的分布式操作上。在鸿蒙系统中,“流转”是指涉多端的分布式操作,它打破了设备之间的界限,实现了多设备…

Mysql--基础篇--SQL(DDL,DML,窗口函数,CET,视图,存储过程,触发器等)

SQL(Structured Query Language,结构化查询语言)是用于管理和操作关系型数据库的标准语言。它允许用户定义、查询、更新和管理数据库中的数据。SQL是一种声明性语言,用户只需要指定想要执行的操作,而不需要详细说明如何…

SQL 幂运算 — POW() and POWER()函数用法详解

POW() and POWER()函数用法详解 POW() 和 POWER() —计算幂运算(即一个数的指定次方)的函数。 这两个函数是等价的,功能完全相同,只是名字不同。 POW(base, exponent); POWER(base, exponent); base:底数。exponen…

Elasticsearch:聚合操作

这里写目录标题 一、聚合的概述二、聚合的分类1、指标聚合(Metric Aggregation)2、桶聚合(Bucket Aggregation)3、管道聚合(Pipeline Aggregation) 三、ES聚合分析不精准原因分析四、聚合性能优化1、ES聚合…

Ubuntu 磁盘修复

Ubuntu 磁盘修复 在 ubuntu 文件系统变成只读模式,该处理呢? 文件系统内部的错误,如索引错误、元数据损坏等,也可能导致系统进入只读状态。磁盘坏道或硬件故障也可能引发文件系统只读的问题。/etc/fstab配置错误,可能…

重新整理机器学习和神经网络框架

本篇重新梳理了人工智能(AI)、机器学习(ML)、神经网络(NN)和深度学习(DL)之间存在一定的包含关系,以下是它们的关系及各自内容,以及人工智能领域中深度学习分支对比整理。…

LabVIEW瞬变电磁接收系统

利用LabVIEW软件与USB4432采集卡开发瞬变电磁接收系统。系统通过改进硬件配置与软件编程,解决了传统仪器在信噪比低和抗干扰能力差的问题,实现了高精度的数据采集和处理,特别适用于地质勘探等领域。 ​ 项目背景: 瞬变电磁法是探…

Redis 优化秒杀(异步秒杀)

目录 为什么需要异步秒杀 异步优化的核心逻辑是什么? 阻塞队列的特点是什么? Lua脚本在这里的作用是什么? 异步调用创建订单的具体逻辑是什么? 为什么要用代理对象proxy调用createVoucherOrder方法? 对于代码的详细…

C++笔记之`size_t`辨析

C++笔记之size_t辨析 code review! 文章目录 C++笔记之`size_t`辨析一.什么是 `size_t`?二.`size_t` 的来源和设计目的三.`size_t` 的应用场景四.`size_t` 的优点五.`size_t` 的缺点和注意事项六.`size_t` 和其他类型的比较七.总结与建议在 C/C++ 中,size_t 是一个非常重要的…

MySQL表的增删查改(下)——Update(更新),Delete(删除)

文章目录 Update将孙悟空同学的数学成绩修改为80分将曹孟德同学的数学成绩变更为 60 分,语文成绩变更为 70 分将总成绩倒数前三的 3 位同学的数学成绩加上 30 分将所有同学的语文成绩更新为原来的 2 倍 Delete删除数据删除孙悟空同学的考试成绩删除整张表数据 截断表…

大语言模型训练的数据集从哪里来?

继续上篇文章的内容说说大语言模型预训练的数据集从哪里来以及为什么互联网上的数据已经被耗尽这个说法并不专业,再谈谈大语言模型预训练数据集的优化思路。 1. GPT2使用的数据集是WebText,该数据集大概40GB,由OpenAI创建,主要内…

【hadoop学习遇见的小问题】clone克隆完之后网络连接不上问题解决

vi /etc/udev/rules.d/70-persistent-net.rules注释掉第一行 第二行的eth1 改为eth0 由上图也可以看到物理地址 记录下来在网卡中修改物理地址 vi /etc/sysconfig/network-scripts/ifcfg-eth0修改完之后 重启reboot 即可