Pytorch实现心跳信号分类识别(支持LSTM,GRU,TCN模型)

 Pytorch实现心跳信号分类识别(支持LSTM,GRU,TCN模型)

目录

 Pytorch实现心跳信号分类识别(支持LSTM,GRU,TCN模型)

1. 项目说明

2. 数据说明

(1)心跳信号分类预测数据集

3. 模型训练

(1)项目安装

(2)准备Train和Test数据

(3)配置文件:​config.yaml​

(4)开始训练

(5)可视化训练过程

(6)一些优化建议

(7)一些运行错误处理方法

4. 模型测试效果

5. 项目源码下载


1. 项目说明

本项目将基于深度学习Pytorch,搭建一个心跳信号分类识别的训练和测试项目,实现一个简单的信号分类识别系统;项目网络模型支持LSTM,GRU和TCN等常见的模型,用户也可以自定义其他模型,进行训练和测试。采用GRU模型,在心跳信号分类预测数据集上,验证集的准确率99.3600%。

ae3b750e38a64e738927c1526778c4cb.png

模型准确率
LSTM

97.7000

TCN

96.1600

GRU

99.3600

【尊重原创,转载请注明出处】https://blog.csdn.net/guyuealian/article/details/142714205

2. 数据说明

(1)心跳信号分类预测数据集

项目数据来源于某比赛《心跳信号分类预测》,该赛题主要任务是预测心电图心跳信号类别,提供了总数超过20万心电图数据记录,主要为1列心跳信号序列数据,其中每个样本的信号序列采样频次一致,长度相等。为了保证比赛的公平性,将会从中抽取10万条作为训练集,2万条作为测试集A,2万条作为测试集B,同时会对心跳信号类别(label)信息进行脱敏。

FieldDescription
id为心跳信号分配的唯一标识
heartbeat_signals心跳信号序列
label心跳信号类别(0、1、2、3)

比赛提供的数据的都是一维vector形式的信号,且数据已被归一化至 0~1 了,数据总长度均为 205 (205 个时间节点/心跳节拍),数据非常理想,无须进行填充和异常情况处理,当然在项目训练和开发中,建议加上信号数据增强,提高模型的泛化性。

赛题提供train.csv和testA.csv数据,其中train.csv提供了label标签可用于模型训练,testA.csv用于打榜比赛,为了方面模型本地开发和调优,本项目将train.csv文件中的前5000行数据作为验证集val.csv,剩余的数据作为训练集train.csv。


3. 模型训练

(1)项目安装

整套工程基本框架结构如下:

.
├── core                 # 训练模型核心代码
├── configs              # 训练配置文件
├── data                 # 项目相关数据
├── libs                 # 项目依赖的相关库
├── demo.py              # 模型推理demo
├── README.md            # 项目工程说明文档
├── requirements.txt     # 项目相关依赖包
└── train.py             # 训练文件

   项目依赖python包请参考requirements.txt,使用pip安装即可:

# python3.8
imgaug==0.4.0
numpy==1.21.6
matplotlib==3.1.0
Pillow==9.5.0
easydict==1.9
onnx==1.14.0
onnx-simplifier==0.4.33
onnxruntime==1.15.1
onnxruntime-gpu==1.15.1
onnxsim==0.4.33
opencv-contrib-python==4.8.1.78
opencv-python==4.8.0.76
pandas==1.1.5
PyYAML==5.3.1
Pillow==9.5.0
scikit-image==0.21.0
scikit-learn==1.2.2
scipy==1.10.1
seaborn==0.12.2
tensorboard==2.13.0
tensorboardX==2.6.1
torch==1.13.1+cu117
torchvision==0.14.1+cu117
tqdm==4.55.1
xmltodict==0.12.0
basetrainer
pybaseutils

 项目安装教程请参考(初学者入门,麻烦先看完下面教程,配置好开发环境):

  • 项目开发使用教程和常见问题和解决方法
  • 视频教程:1 手把手教你安装CUDA和cuDNN(1)
  • 视频教程:2 手把手教你安装CUDA和cuDNN(2)
  • 视频教程:3 如何用Anaconda创建pycharm环境
  • 视频教程:4 如何在pycharm中使用Anaconda创建的python环境

(2)准备Train和Test数据

下载项目数据集,train.csv和val.csv

数据增强方式主要采用: 随机增加噪声,随机平移,裁剪填充,数据归一化等处理方式

# -*-coding: utf-8 -*-
"""@Author : Pan@E-mail : 390737991@qq.com@Date   : 2021-08-02 14:33:33
"""
import numbers
import random
import numpy as np
from core.transforms import data_augmentdef data_transform(seq_size, dim_size, trans_type="train"):"""x's shape (batch_size, seq_size(序列长度), dim_size(序列中每个数据的长度)):param seq_size::param dim_size::param trans_type::return::"""if trans_type == "train":transform = data_augment.Compose([data_augment.RandomNoise(low=0.0, high=1.0, w=0.03),data_augment.RandomShift(shift=(-20, 5), low=0.0, high=1.0, w=0.001),data_augment.CropPadding(size=seq_size),data_augment.Normalize(),])elif trans_type == "val" or trans_type == "test":transform = data_augment.Compose([data_augment.CropPadding(size=seq_size),data_augment.Normalize(),])else:raise Exception("transform_type ERROR:{}".format(trans_type))return transform

修改配置文件数据路径:​config.yaml​

data_type: "signal"
# 训练数据集,可支持多个数据集
train_data: "/home/user/to/心跳信号分类预测/train.csv"
# 测试数据集
test_data: "/home/user/to/心跳信号分类预测/val.csv"
# 类别文件
class_name: [ 0,1,2,3 ]

(3)配置文件:​config.yaml​

  • 模型支持,LSTM,GRU和TCN等模型,用户也可以自定义模型,进行模型训练和测试。
  • 训练参数可以通过(configs/config.yaml)配置文件进行设置
  • 损失函数支持交叉熵CrossEntropy,LabelSmoothing以及FocalLoss等损失函数

 配置文件:​config.yaml​说明如下:

data_type: "signal"
# 训练数据集,可支持多个数据集
train_data: "data/train.csv"
# 测试数据集
test_data: "data/val.csv"
# 类别文件
class_name: [ 0,1,2,3 ]
train_transform: "train"       # 训练使用的数据增强方法
test_transform: "val"          # 测试使用的数据增强方法
work_dir: "work_space/"        # 保存输出模型的目录
net_type: "GRU"               # 骨干网络,支持:TCN,GRU,LSTM
batch_size: 256                # 训练batch-size
seq_size: 205                  # 模型输入序列长度
dim_size: 1                    # 模型输入特征数据维度
lr: 0.001                      # 初始学习率
optim_type: "AdamW"            # 选择优化器,SGD,Adam
#loss_type: "CrossEntropyLoss"  # 选择损失函数:支持CrossEntropyLoss
loss_type: "LabelSmoothingCrossEntropy"  # 选择损失函数:支持CrossEntropyLoss
momentum: 0.9                  # SGD momentum
num_epochs: 120                # 训练循环次数
num_workers: 8                 # 加载数据工作进程数
weight_decay: 0.0005           # weight_decay,默认5e-4
scheduler: "multi-step"        # 学习率调整策略
milestones: [ 40,80 ]          # 下调学习率方式
gpu_id: [ 0 ]                  # GPU ID
pretrained: True               # 是否使用pretrained模型
finetune: False                # 是否进行finetune

(4)开始训练

整套训练代码非常简单操作,用户只需要将相同类别的数据放在同一个目录下,并填写好对应的数据路径,即可开始训练了。

python train.py -c configs/config.yaml 

8abb6a7894984f169cfd5b3989aa61b1.png

训练完成后,在心跳信号分类预测数据集上,验证集的Accuracy在99%左右,下表给出LSTM,GRU和TCN等常用模型验证集的准确率:

模型准确率
LSTM

97.7000

TCN

96.1600

GRU

99.3600

(5)可视化训练过程

训练过程可视化工具是使用Tensorboard,使用方法,可参考这里:项目开发使用教程和常见问题和解决方法
在终端输入:
# 基本方法
tensorboard --logdir=path/to/log/
# 例如
tensorboard --logdir=work_space/GRU_LabelSmoothingCrossEntropy_20241004_223428_9301/log

可视化效果 

9cf309d5b40446c5801025091456c3c0.png

684d69473ccf4fa9958bd89783db9aa4.png

70e31395b8c742bda6f3dbf55cc4e0d5.png

3633e060393c418296daab267f20a688.png

8ceb953f365f4650b0a08f32d9fdaed5.png

(6)一些优化建议

如果想进一步提高准确率,可以尝试:

  1. 样本均衡: 建议进行样本均衡处理,避免长尾问题
  2. 调超参: 比如学习率调整策略,优化器(SGD,Adam等)
  3. 损失函数: 目前训练代码已经支持:交叉熵CrossEntropy,LabelSmoothing,可以尝试FocalLoss等损失函数

(7)一些运行错误处理方法

  • 项目不要出现含有中文字符的目录文件或路径,否则可能会出现很多异常!!!!!!!!


4. 模型测试效果

 demo.py文件用于推理和测试模型的效果,填写好配置文件,模型文件以及测试数据即可运行测试了

def get_parser():# 配置文件config_file = "work_space/GRU_LabelSmoothingCrossEntropy_20241004_223428_9301/config.yaml"# 模型文件model_file = "work_space/GRU_LabelSmoothingCrossEntropy_20241004_223428_9301/model/best_model_093_99.3600.pth"# 测试数据data_file = "data/val.csv"parser = argparse.ArgumentParser(description="Inference Argument")parser.add_argument("-c", "--config_file", help="configs file", default=config_file, type=str)parser.add_argument("-m", "--model_file", help="model_file", default=model_file, type=str)parser.add_argument("--data_file", help="data file", default=data_file, type=str)parser.add_argument("--device", help="cuda device id", default="cuda:0", type=str)return parser
#!/usr/bin/env bash
# Usage:
# python demo.py  -c "path/to/config.yaml" -m "path/to/model.pth" --data_file "path/to/data_file"python demo.py -c work_space/GRU_LabelSmoothingCrossEntropy_20241004_223428_9301/config.yaml -m work_space/GRU_LabelSmoothingCrossEntropy_20241004_223428_9301/model/best_model_093_99.3600.pth --data_file data/val.csv

运行测试结果: 

e487644f47a8490c9dc670b579f1e18b.png
f50ecf577be448c4b94671b7c3d54f1b.png
f1e081cd91464241951297194295ebea.png

61399cad72de4258be7ac2185104d801.png


5. 项目源码下载

【源码下载】请关注【AI吃大瓜】,回复关键字【心跳信号】

  • 项目提供心跳信号分类预测数数据集: 赛题提供train.csv和testA.csv数据,其中train.csv提供了label标签可用于模型训练,testA.csv用于打榜比赛,为了方面模型本地开发和调优,本项目将train.csv文件中的前5000行数据作为验证集val.csv,剩余的数据作为训练集train.csv。
  • 项目提供网络模型支持LSTM,GRU和TCN等常见的模型
  • 项目提供训练代码,损失函数支持交叉熵CrossEntropy,LabelSmoothing以及FocalLoss等损失函数
  • 项目提供已经训练好的模型,无需重新训练,即可运行demo.py测试效果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/62913.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

十,[极客大挑战 2019]Secret File1

点击进入靶场 查看源代码 有个显眼的紫色文件夹,点击 点击secret看看 既然这样,那就回去查看源代码吧 好像没什么用 抓个包 得到一个文件名 404 如果包含"../"、"tp"、"input"或"data",则输出"…

Windows远程桌面连接到Linux

我的电脑是一台瘦客户端,公司设置的不能安装其他软件,里面只有几个软件,还好有一个远程桌面(Remote Desktop Connection),我想连接到另一台Linux的电脑上。 在Linux上安装xrdp: sudo apt insta…

视觉处理基础1

目录 一、CNN 1. 概述 1.1 与传统网络的区别 1.2 全连接的局限性 1.3 卷积思想 1.4 卷积的概念 1.4.1 概念 1.4.2 局部连接 1.4.3 权重共享 2. 卷积层 2.1 卷积核 2.2 卷积计算 2.3 边缘填充 2.4 步长Stride 2.5 多通道卷积计算 2.7 特征图大小计算方法 2…

泛化调用 :在没有接口的情况下进行RPC调用

什么是泛化调用? 在RPC调用的过程中,调用端向服务端发起请求,首先要通过动态代理,动态代理可以屏蔽RPC处理流程,使得发起远程调用就像调用本地一样。 RPC调用本质:调用端向服务端发送一条请求消息&#x…

C++ 之弦上舞:string 类与多样字符串操作的优雅旋律

string 类的重要性及与 C 语言字符串对比 在 C 语言中,字符串是以 \0 结尾的字符集合,操作字符串需借助 C 标准库的 str 系列函数,但这些函数与字符串分离,不符合 OOP 思想,且底层空间管理易出错。而在 C 中&#xff0…

【大数据学习 | Spark调优篇】Spark之内存调优

1. 内存的花费 1)每个Java对象,都有一个对象头,会占用16个字节,主要是包括了一些对象的元信息,比如指向它的类的指针。如果一个对象本身很小,比如就包括了一个int类型的field,那么它的对象头实…

使用OpenCV和卡尔曼滤波器进行实时活体检测

引言 在现代计算机视觉应用中,实时检测和跟踪物体是一项重要的任务。本文将详细介绍如何使用OpenCV库和卡尔曼滤波器来实现一个实时的活体检测系统。该系统能够通过摄像头捕捉视频流,并使用YOLOv3模型来检测目标对象(例如人)&…

【closerAI ComfyUI】物体转移术之图案转移,Flux三重控制万物一致性生图,实现LOGO和图案的精准迁移

更多AI前沿科技资讯,请关注我们:closerAI-一个深入探索前沿人工智能与AIGC领域的资讯平台 closerAIGCcloserAI,一个深入探索前沿人工智能与AIGC领域的资讯平台,我们旨在让AIGC渗入我们的工作与生活中,让我们一起探索AIGC的无限可能性! 【closerAI ComfyUI】物体转移术之图…

2025软考高级《系统架构设计师》案例模拟题合集

首先分享一下系统架构设计师资料合集,有历年真题、自学打卡表、精华知识点等,需要的留邮,打包分享! 1、在设计基于混合云的安全生产管理系统中,需要重点考虑5个方面的安全问题。设备安全、网络安全、控制安全、应用安全…

rpm包转deb包或deb包转rpm包

Debian系(Ubuntu、Deepin、麒麟Destop等)用的安装包是deb的,Red Hat系(CentOS、欧拉、麒麟Server等)用的安装包是rpm的。 如果需要在Ubuntu上安装rpm,或需要在CentOS上安装deb,需要安装alien s…

【C语言】递归的内存占用过程

递归 递归是函数调用自身的一种编程技术。在C语言中,递归的实现会占用内存栈(Call Stack),每次递归调用都会在栈上分配一个新的 “栈帧(Stack Frame)”,用于存储本次调用的函数局部变量、返回地…

数据仓库的概念

先用大白话讲一下,数据仓库的主要目的就是存储和分析大量结构化数据的。 > 那么它的核心目的是:支持商业智能(BI)和决策支持系统,也就是说,它不仅仅是为了存储,更重要的是为了分析提供便利。…

LeetCode 438.找到字符串中所有字母异位词

LeetCode 438.找到字符串中所有字母异位词 思路🧐: 需要找到子串异位词,也就是只看该子串是否有相同字母而不管位置是否相同。分析题目发现只需要单调向前找异位词,则可以使用滑动窗口求解,注意这里每当左右边框长度大…

算法刷题Day8:BM30 二叉搜索树与双向链表

题目 牛客网题目传送门 思路 对二叉搜索树进行中序遍历,结果就是按序数组。因此想办法把前面遍历过的节点给记下来,记作pre。当遍历到某个节点node的时候,令前驱指向pre,然后让pre的后驱指向node。 代码 class TreeNode:def…

1.Git安装与常用命令

前言 Git中会用到的一些基本的Linux命令 ls/ll 查看文件目录 (ll可以看隐藏文件)cat 查看文件内容touch 创建文件vi vi编辑器 1.下载与安装 安装成功后鼠标右键会出现Git Bash和Git GUI Git GUI:GUI图形化界面 Git Bash:Git提供的命令行工具 当安装…

ultralytics-YOLOv11的目标检测解析

1. Python的调用 from ultralytics import YOLO import os def detect_predict():model YOLO(../weights/yolo11n.pt)print(model)results model(../ultralytics/assets/bus.jpg)if not os.path.exists(results[0].save_dir):os.makedirs(results[0].save_dir)for result in…

【docker】docker compose多容器部署

Docker Compose 的详细讲解与实际应用 什么是 Docker Compose? Docker Compose 是一个工具,用于定义和运行多容器 Docker 应用。 通过一个 docker-compose.yml 文件,可以同时启动多个服务,简化多容器管理。 Docker Compose 的核心…

【AI系统】CANN 算子类型

CANN 算子类型 算子是编程和数学中的重要概念,它们是用于执行特定操作的符号或函数,以便处理输入值并生成输出值。本文将会介绍 CANN 算子类型及其在 AI 编程和神经网络中的应用,以及华为 CANN 算子在 AI CPU 的详细架构和开发要求。 算子基…

C++:特殊类设计及类型转换

目录 一.常见特殊类的设计方式 1.请设计一个类,不能被拷贝 2.请设计一个类,只能在堆上创建对象 3.请设计一个类,只能在栈上创建对象 4.请设计一个类,不能被继承 5.请设计一个类,只能创建一个对象(单例模式) 二.C…

GPT打字机效果—— fetchEventSouce进行sse流式请求

EventStream基本用法 与 WebSocket 不同的是,服务器发送事件是单向的。数据消息只能从服务端到发送到客户端(如用户的浏览器)。这使其成为不需要从客户端往服务器发送消息的情况下的最佳选择。 const evtSource new EventSource(“/api/v1/…