从0搭建ECG深度学习网络

本篇博客介绍使用Python语言的深度学习网络,从零搭建一个ECG深度学习网络。

任务

本次入门的任务是,筛选出MIT-BIH数据集中注释为[‘N’, ‘A’, ‘V’, ‘L’, ‘R’]的数据作为本次数据集,然后按照8:2的比例划分为训练集,验证集。最后送入RCNN模型进行训练。

1. 数据集介绍

本次使用大名鼎鼎的MIT-BIH Arrhythmia Database数据集。下载地址:https://physionet.org/content/mitdb/1.0.0/

MIT系列有很多数据集,都可以在生理网:https://physionet.org/about/database/ 上找到下载地址。本次使用的MT-BIH心律失常数据库拥有48条心电记录,且每个记录的时长是30分钟。这些记录来自于47名研究对象。这些研究对象包括25名男性和22名女性,其年龄介于23到89岁(其中记录201与202来自于同一个人)。信号的采样率为360赫兹,AD分辨率为11比特。对于每条记录来说,均包含两个通道的信号。第一个通道一般为MLⅡ导联(记录102和104为V5导联);第二个通道一般为V1导联(有些为V2导联或V5导联,其中记录124号为Ⅴ4导联)。为了保持导联的一致性,往往在研究中采用MLⅡ导联。

在生理网:https://physionet.org/about/database/上,我们可以看到数据集更加详细的说明。比如:

MIT-BIH 数据集每个单独病人的说明:https://www.physionet.org/physiobank/database/html/mitdbdir/mitdbdir.htm

MIT-BIH 数据集每个单独病人的整个数据以及注释的可视化:https://www.physionet.org/physiobank/database/html/mitdbdir/mitdbdir.htm

下载MIT-BIH 数据集之后,我们需要知晓以下几点:

  1. 从100-234不连续号码,一共48个病人。每个病人有三个文件(.dat,.atr,*.hea),包含有两路心电信号,一个注释。
  2. 有专门库读取MIT-BIH 数据集,叫做 wfdb。所以不要担心文件后缀的陌生感。
  3. 对心电图的标注样式如上图,“A"代表心房早搏,”."代表正常。整个数据集标注有40多种符号,表示40多种心拍状态。

2. 提取数据集

提取之前,先安装必要的库wfdb。wfdb详细介绍

pip install wfdb

这个库非常强大,打印数据信息,读取数据,绘制心电波形图,都可以靠它完成。
现在我们的划分步骤是:

  1. 提取出所有心电图数据点,心电图注释点
  2. 筛选出所有心电图注释点中仅为[‘N’, ‘A’, ‘V’, ‘L’, ‘R’]某一类的注释点
  3. 截取心电图数据中标记为[‘N’, ‘A’, ‘V’, ‘L’, ‘R’]某一类的点,在点周围长度为300的数据
  4. 将得到的数据进行维度处理,送入DataLoader()函数,完成模型对数据的认可。

3. 定义模型

本次使用的模型是输入大小为300,3层循环,隐藏层大小50。

'''
模型搭建
'''
class RnnModel(nn.Module):def __init__(self):super(RnnModel, self).__init__()'''参数解释:(输入维度,隐藏层维度,网络层数)'''self.rnn = nn.RNN(300, 50, 3, nonlinearity='tanh')self.linear = nn.Linear(50, 5)def forward(self, x):r_out, h_state = self.rnn(x)output = self.linear(r_out[:,-1,:])  # 将 RNN 层的输出 r_out 在最后一个时间步上的输出(隐藏状态)传递给线性层return outputmodel = RnnModel()

4. 全部训练代码

'''
导入相关包
'''
import wfdb
import pywt
import seaborn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import torch
import torch.utils.data as Data
from torch import nn'''
加载数据集
'''# 测试集在数据集中所占的比例
RATIO = 0.2# 小波去噪预处理
def denoise(data):# 小波变换coeffs = pywt.wavedec(data=data, wavelet='db5', level=9)cA9, cD9, cD8, cD7, cD6, cD5, cD4, cD3, cD2, cD1 = coeffs# 阈值去噪threshold = (np.median(np.abs(cD1)) / 0.6745) * (np.sqrt(2 * np.log(len(cD1))))cD1.fill(0)cD2.fill(0)for i in range(1, len(coeffs) - 2):coeffs[i] = pywt.threshold(coeffs[i], threshold)# 小波反变换,获取去噪后的信号rdata = pywt.waverec(coeffs=coeffs, wavelet='db5')return rdata# 读取心电数据和对应标签,并对数据进行小波去噪
def getDataSet(number, X_data, Y_data):ecgClassSet = ['N', 'A', 'V', 'L', 'R']# 读取心电数据记录# print("正在读取 " + number + " 号心电数据...")# 读取MLII导联的数据record = wfdb.rdrecord('C:/mycode/dataset_make/mit-bih-arrhythmia-database-1.0.0/' + number, channel_names=['MLII'])data = record.p_signal.flatten()rdata = denoise(data=data)# 获取心电数据记录中R波的位置和对应的标签annotation = wfdb.rdann('C:/mycode/dataset_make/mit-bih-arrhythmia-database-1.0.0/' + number, 'atr')Rlocation = annotation.sampleRclass = annotation.symbol# 去掉前后的不稳定数据start = 10end = 5i = startj = len(annotation.symbol) - end# 因为只选择NAVLR五种心电类型,所以要选出该条记录中所需要的那些带有特定标签的数据,舍弃其余标签的点# X_data在R波前后截取长度为300的数据点# Y_data将NAVLR按顺序转换为01234while i < j:try:# Rclass[i] 是标签lable = ecgClassSet.index(Rclass[i])  # 这一步就是相当于抛弃了不在ecgClassSet的索引# 基于经验值,基于R峰向前取100个点,向后取200个点x_train = rdata[Rlocation[i] - 100:Rlocation[i] + 200]X_data.append(x_train)Y_data.append(lable)i += 1except ValueError:i += 1return# 加载数据集并进行预处理
def loadData():numberSet = ['100', '101', '103', '105', '106', '107', '108', '109', '111', '112', '113', '114', '115','116', '117', '119', '121', '122', '123', '124', '200', '201', '202', '203', '205', '208','210', '212', '213', '214', '215', '217', '219', '220', '221', '222', '223', '228', '230','231', '232', '233', '234']dataSet = []lableSet = []for n in numberSet:getDataSet(n, dataSet, lableSet)# 转numpy数组,打乱顺序dataSet = np.array(dataSet).reshape(-1, 300)  # 转化为二维,一行有 300 个,行数需要计算lableSet = np.array(lableSet).reshape(-1, 1)  # 转化为二维,一行有   1 个,行数需要计算train_ds = np.hstack((dataSet, lableSet))  # 将数据集和标签集水平堆叠在一起,(92192, 300) (92192, 1) (92192, 301)# print(dataSet.shape, lableSet.shape, train_ds.shape)  # (92192, 300) (92192, 1) (92192, 301)np.random.shuffle(train_ds)# 数据集及其标签集X = train_ds[:, :300].reshape(-1, 1, 300)  # (92192, 1, 300)Y = train_ds[:, 300]  # (92192)# 测试集及其标签集shuffle_index = np.random.permutation(len(X))  # 生成0-(X-1)的随机索引数组# 设定测试集的大小 RATIO是测试集在数据集中所占的比例test_length = int(RATIO * len(shuffle_index))# 测试集的长度test_index = shuffle_index[:test_length]# 训练集的长度train_index = shuffle_index[test_length:]X_test, Y_test = X[test_index], Y[test_index]X_train, Y_train = X[train_index], Y[train_index]return X_train, Y_train, X_test, Y_testX_train, Y_train, X_test, Y_test = loadData()'''
数据处理
'''
train_Data = Data.TensorDataset(torch.Tensor(X_train), torch.Tensor(Y_train)) # 返回结果为一个个元组,每一个元组存放数据和标签
train_loader = Data.DataLoader(dataset=train_Data, batch_size=128)
test_Data = Data.TensorDataset(torch.Tensor(X_test), torch.Tensor(Y_test)) # 返回结果为一个个元组,每一个元组存放数据和标签
test_loader = Data.DataLoader(dataset=test_Data, batch_size=128)'''
模型搭建
'''
class RnnModel(nn.Module):def __init__(self):super(RnnModel, self).__init__()'''参数解释:(输入维度,隐藏层维度,网络层数)'''self.rnn = nn.RNN(300, 50, 3, nonlinearity='tanh')self.linear = nn.Linear(50, 5)def forward(self, x):r_out, h_state = self.rnn(x)output = self.linear(r_out[:,-1,:])  # 将 RNN 层的输出 r_out 在最后一个时间步上的输出(隐藏状态)传递给线性层return outputmodel = RnnModel()'''
设置损失函数和参数优化方法
'''
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)'''
模型训练
'''
EPOCHS = 5
for epoch in range(EPOCHS):running_loss = 0for i, data in enumerate(train_loader):inputs, label = datay_predict = model(inputs)loss = criterion(y_predict, label.long())optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()# 预测correct = 0total = 0with torch.no_grad():for data in test_loader:inputs, label = datay_pred = model(inputs)_, predicted = torch.max(y_pred.data, dim=1)total += label.size(0)correct += (predicted == label).sum().item()print(f'Epoch: {epoch + 1}, ACC on test: {correct / total}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/40390.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

什么是DNS服务器的层次化和分布式?

DNS (Domain Name System) 的结构是层次化的&#xff0c;意味着它是由多个级别的服务器组成&#xff0c;每个级别负责不同的部分。以下是 DNS 结构的层次&#xff1a; 根域服务器&#xff08;Root Servers&#xff09;&#xff1a; 这是 DNS 层次结构的最高级别。全球有13组根域…

【云原生】Docker 详解(二):Docker 架构及工作原理

Docker 详解&#xff08;二&#xff09;&#xff1a;Docker 架构及工作原理 Docker 在运行时分为 Docker 引擎&#xff08;服务端守护进程&#xff09; 和 客户端工具&#xff0c;我们日常使用各种 docker 命令&#xff0c;其实就是在使用 客户端工具 与 Docker 引擎 进行交互。…

[oneAPI] 手写数字识别-LSTM

[oneAPI] 手写数字识别-LSTM 手写数字识别参数与包加载数据模型训练过程结果 oneAPI 比赛&#xff1a;https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517 Intel DevCloud for oneAPI&#xff1a;https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolk…

Curson 编辑器

Curson 汉化与vacode一样 Curson 自带chat功能 1、快捷键ctrlk(代码中编辑) 2、快捷键ctrll 右侧打开窗口

小程序项目组件的基本应用

宿主环境&#xff1a;程序运行必须依赖的环境 小程序的宿主环境 ---->手机微信(定位、扫码、支付等) 小程序的通信模型&#xff1a; 渲染层和逻辑层之间的通信(微信客户端转发)逻辑层和第三方服务器之间的通信(微信客户端转发) 小程序的运行机制&#xff1a; 启动&#xff1…

vue基础知识五:请描述下你对vue生命周期的理解?在created和mounted这两个生命周期中请求数据有什么区别呢?

一、生命周期是什么 生命周期&#xff08;Life Cycle&#xff09;的概念应用很广泛&#xff0c;特别是在政治、经济、环境、技术、社会等诸多领域经常出现&#xff0c;其基本涵义可以通俗地理解为“从摇篮到坟墓”&#xff08;Cradle-to-Grave&#xff09;的整个过程在Vue中实…

Python opennsfw/opennsfw2 图片/视频 鉴黄 笔记

nsfw&#xff08; Not Suitable for Work&#xff09;直接翻译就是 工作的时候不适合看&#xff0c;真文雅 nsfw效果&#xff0c;注意底部的分数 大体流程&#xff0c;输入图片/视频&#xff0c;输出0-1之间的数字&#xff0c;一般情况下&#xff0c;Scores < 0.2 认为是非…

7zip分卷压缩

前言 有些项目上传文件大小有限制 压缩包大了之后传输也会比较慢 解决方案 我们可以利用7zip压缩工具对文件进行分卷压缩 利用7zip压缩工具进行分卷压缩 查看待压缩文件大小 压缩完成之后有300多M&#xff0c;我们用100M去进行分卷压缩 选择待压缩的文件夹&#xff0c;右…

网络安全 Day30-运维安全项目-容器架构上

容器架构上 1. 什么是容器2. 容器 vs 虚拟机(化) :star::star:3. Docker极速上手指南1&#xff09;使用rpm包安装docker2) docker下载镜像加速的配置3) 载入镜像大礼包&#xff08;老师资料包中有&#xff09; 4. Docker使用案例1&#xff09; 案例01&#xff1a;:star::star::…

《内网穿透》无需公网IP,公网SSH远程访问家中的树莓派

文章目录 前言 如何通过 SSH 连接到树莓派步骤1. 在 Raspberry Pi 上启用 SSH步骤2. 查找树莓派的 IP 地址步骤3. SSH 到你的树莓派步骤 4. 在任何地点访问家中的树莓派4.1 安装 Cpolar内网穿透4.2 cpolar进行token认证4.3 配置cpolar服务开机自启动4.4 查看映射到公网的隧道地…

【JavaEE基础学习打卡02】是时候了解Java EE了!

目录 前言一、为什么要学习Java EE二、Java EE规范介绍1.什么是规范&#xff1f;2.什么是Java EE规范&#xff1f;3.Java EE版本 三、Java EE应用程序模型1.模型前置说明2.模型具体说明 总结 前言 &#x1f4dc; 本系列教程适用于 Java Web 初学者、爱好者&#xff0c;小白白。…

java接口导出csv

1、背景介绍 项目中需要导出数据质检结果&#xff0c;本来使用Excel&#xff0c;但是质检结果数据行数过多&#xff0c;导致用hutool报错&#xff0c;因此转为导出csv格式数据。 2、参考文档 https://blog.csdn.net/ityqing/article/details/127879556 工程环境&#xff1a;…

Redis-分布式锁!

分布式锁&#xff0c;顾名思义&#xff0c;分布式锁就是分布式场景下的锁&#xff0c;比如多台不同机器上的进程&#xff0c;去竞争同一项资源&#xff0c;就是分布式锁。 分布式锁特性 互斥性:锁的目的是获取资源的使用权&#xff0c;所以只让一个竞争者持有锁&#xff0c;这…

【算法】排序+双指针——leetcode三数之和、四数之和

三数之和 &#xff08;1&#xff09;排序双指针 算法思路&#xff1a; 和之前的两数之和类似&#xff0c;我们对暴力枚举进行了一些优化&#xff0c;利用了排序双指针的思路&#xff1a; 我们先排序&#xff0c;然后固定⼀个数 a &#xff0c;接着我们就可以在这个数后面的区间…

OpenCV实例(九)基于深度学习的运动目标检测(一)YOLO运动目标检测算法

基于深度学习的运动目标检测&#xff08;一&#xff09; 1.YOLO算法检测流程2.YOLO算法网络架构3.网络训练模型3.1 训练策略3.2 代价函数的设定 2012年&#xff0c;随着深度学习技术的不断突破&#xff0c;开始兴起基于深度学习的目标检测算法的研究浪潮。 2014年&#xff0c;…

Davinci 报表工具 0.3.0-rc release 文本框模糊查询不生效问题

背景: 在使用过程中发现davinci 的控制器配置中, 取值配置的对应关系设置 包含 或 不包含时 不生效, 不能实现模糊匹配效果, 只能精确查询; 问题分析: 通过跟踪接口及相应代码, 发现在sql 拼接时没有对 like 和 not like 类型的值两侧添加百分号, 导致模糊查询失败 调用过程…

CentOS系统环境搭建(七)——Centos7安装MySQL

centos系统环境搭建专栏&#x1f517;点击跳转 坦诚地说&#xff0c;本文中百分之九十的内容都来自于该文章&#x1f517;Linux&#xff1a;CentOS7安装MySQL8&#xff08;详&#xff09;&#xff0c;十分佩服大佬文章结构合理&#xff0c;文笔清晰&#xff0c;我曾经在这篇文章…

Kotlin 使用 View Binding

解决的问题&#xff1a; 《第一行代码——Android》第三版 郭霖 P277 视图绑定的问题 描述&#xff1a; kotlin-android-extensions 插件已经弃用 butter knife 已经弃用 解决办法 推荐使用 View Binding 来代替 findViewById 使用方法 1、配置 build.gradle 2、在act…

绝对值函数的可导性

绝对值函数的可导性 声明&#xff1a;下面截图来自《考研数学常考题型解题方法技巧归纳》

利用Figlet工具创建酷炫Linux Centos8服务器-登录欢迎界面-SHELL自动化编译安装代码

因为我们需要生成需要的特定字符,所以需要在当前服务器中安装Figlet,默认没有安装包的,其实如果我们也只要在一台环境中安装,然后需要什么字符只要复制到需要的服务器中,并不需要所有都安装。同样的,我们也可以利用此生成的字符用到脚本运行的开始起头部分,用ECHO分行标…