第15周:RNN心脏病预测

目录

前言

二、前期准备

2.1 设置GPU

2.2 导入数据

2.2.1 数据介绍

2.2.2 导入代码

2.2.3 检查数据

三、数据预处理

3.1 划分训练集与测试集

3.2 标准化

四、构建RNN模型

4.1 基本概念

4.2 搭建代码

五、编译模型

六、训练模型

七、模型评估

总结


前言

  • 🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/0dvHCaOoFnW8SCp3JpzKxg) 中的学习记录博客
  • 🍖 原作者:[K同学啊](https://mtyjkh.blog.csdn.net/)

说在前面

本周目标:本地读取并加载数据、了解循环神经网络(RNN)的构建过程、调整代码是的测试机acuuracy达到87%;拔高目标——测试集accuracy达到89%

我的环境:Python3.8、Pycharm2020、tensorflow2.4.0

数据来源:[K同学啊](https://mtyjkh.blog.csdn.net/)

代码的流程图如下:


一、RNN简介

传统神经网络结构比较简单是输入层——隐藏层——输出层,而RNN与传统神经网络最大的区别在于每次都会将前一次的输出结果,带到下一次的隐藏层中,一起训练。如下图所示,左图为传统神经网络,右图为RNN

 以一个案例具体分析RNN工作过程,用户说了一句“what time is it?”,我们的神经网络首先会将这句话分为五个基本单元(四个单词➕一个问号);然后按照顺序将5个基本单元输入RNN网络,what作为RNN的输入得到输出01,按照顺序将“time”输入RNN网络,得到输出02,这个过程中可以看到输入“time”的时候,前面“what”的输出也会对02的输出产生了影响(如下图中所示,隐藏层中有一半是黑色的),依次类推,前面所有的输入产生的结果都对后续的输出产生了印象(下图中最后的圆形中就包含了前面所有的颜色) 

当神经网络判断意图的时候,只需要最后一层的输出05,如下图所示

                               

循环神经网络(RNN)是一类用于处理序列数据的神经网络。不同于传统的前馈神经网络,RNN 能够处理序列长度变化的数据,如文本、语音等。RNN 的特点是在模型中引入循环,使得网络能够保持某种状态,从而在处理序列数据时表现出更好的性能。

上图左边简单描述 RNN 的原理,x 是输入层,o 是输出层,中间 s 是隐藏层,在 s 层进行一个循环,右边表示展开循环看到的逻辑,其实是和时间 t 相关的一个状态变化,也就是说神经网络在处理数据的时候,能看到前一时刻、后一时刻的状态,也就是常说的上下文

二、前期准备

2.1 设置GPU

代码如下:

#一、前期准备
#1.1 导入所需包和设置GPU
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # 不显示等级2以下的提示信息
import tensorflow as tf
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,LSTM,SimpleRNN
import matplotlib.pyplot as pltgpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]                                        #如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")
print(gpus)

2.2 导入数据

2.2.1 数据介绍

  • age:1)年龄
  • sex:2)性别
  • cp:3)胸痛类型(4 values)
  • trestbps:4)静息血压
  • chol:5)血清胆甾淳(mg/dl)
  • fbs:6)空腹血糖>120mg/dl
  • restecg:7)静息心电图结果(值0,1,2)
  • thalach:8)达到的最大心率
  • exang:9)运动诱发的心绞痛
  • olddpeak:10)相对静止状态,运动引起的ST段压低
  • slope:11)运动峰值ST段的斜率
  • ca:12)荧光透视着色的主要血管数量(0-3)
  • thal:13)0=正常,1=固定缺陷;2=可逆转的缺陷
  • target:14)0=心脏病发作的几率较小,1=心脏病发作的几率更大

2.2.2 导入代码

#1.2 导入数据
df = pd.read_csv('heart.csv')
print(df)

2.2.3 检查数据

检查是否存在空值

df.isnull().sum()  #检查是否有空值

数据打印显示如下

三、数据预处理

3.1 划分训练集与测试集

补充:测试集与验证集的关系——①验证集并没有参与训练中梯度下降的过程,狭义上来讲是没有参与模型的参数训练更新的;②但广义上来说,验证集存在的意义确实参与了一个“人工调参”的过程,我们根据每一个epoch训练之后的模型在vaild data上的表现来决定是否需要训练进行early stop,或者根据这个过程模型的性能变化来调整模型的超参数,如学习率,batch_size等等;③所以也可以认为,验证集也参与了训练,但是并没有使得模型去overfit验证集

代码如下:

#二、数据预处理
#2.1 数据集划分
x = df.iloc[:,:-1]
y = df.iloc[:,-1]x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.1, random_state=1)
print(x_train.shape, y_train.shape)

打印输出:(272, 13) (272,)

3.2 标准化

代码如下:

# 将每一列特征标准化为标准正态分布,注意,标准化是针对每一列而言的
sc = StandardScaler()
x_train = sc.fit_transform(x_train)
x_test = sc.transform(x_test)x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], 1)
x_test = x_test.reshape(x_test.shape[0], x_test.shape[1], 1)

四、构建RNN模型

4.1 基本概念

函数原型:tf.keras.layers.SimpleRNN(units,activation='tanh',use_bias=True,kernel_initializer='glorot_uniform',recurrent_initializer='orthogonal',bias_initializer='zeros',kernel_regularizer=Noe,recurrent_regularizer=Noe,bias_regularizer=None,activity_regularizer=None,keenel_constraint=None,recurrent_constraint=None,bias_constraint=None,dropout=0.0,recurrent_dropout=0.0,return_sequences=False,return_state=False,go_backwards=False,stateful=False,unroll=False,**kwargs)

关键参数说明:

  • units——正整数,输出空间的维度
  • activation——要使用的激活函数,默认为双曲正切(tanh),如果传入None,则不使用激活函数(即线性激活a(x)=x)
  • use_bias——布尔值,该层是否使用偏置向量
  • kernel_initializer——kernel权值矩阵的初始化器,用于输入的线性转换
  • recurrent_initializer——recurrent_kernel权值矩阵的初始化器,用于循环层状态的线性转换
  • bias_initializer——偏置向量的初始化器
  • dropout:在-0和1之间的浮点数,单元的丢弃比例,用于输入的线性转换

4.2 搭建代码

#三、构建RNN模型model = Sequential()
model.add(SimpleRNN(128, input_shape= (13,1),return_sequences=True,activation='relu'))
model.add(SimpleRNN(64,return_sequences=True, activation='relu'))
model.add(SimpleRNN(32, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.summary()

模型输出如下:

五、编译模型

代码如下:

#四、编译模型
opt = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(loss='binary_crossentropy', optimizer=opt,metrics=['accuracy'])

六、训练模型

代码如下:

#五、训练模型
epochs = 100
history = model.fit(x_train, y_train,epochs=epochs,batch_size=128,validation_data=(x_test, y_test),verbose=1)

训练过程:

七、模型评估

代码如下

#六、模型评估
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()scores = model.evaluate(x_test,y_test,verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))

打印结果:

accuracy: 90.32%


总结

RNN实战应用,是一种用于处理序列数据的神经网络,了解了基于Tensorflow搭建RNN的过程;学习了对于文本类数据,是怎么将其数字化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/35269.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

直播怎么录制视频?直播视频,3种录制方法

“今晚我最喜欢的游戏博主要进行直播,但我可能还要加班。怎么办,不想错过直播的内容!电脑怎么才能进行直播录制视频啊?谁能教教我?” 在数字化的今天,直播已经成为人们获取信息和娱乐的重要途径。有时&…

执行yum命令报错Could not resolve host: mirrors.cloud.aliyuncs.com; Unknown error

执行yum命令报错 [Errno 14] curl#6 - "Could not resolve host: mirrors.cloud.aliyuncs.com; Unknown error 修改图中所示两个文件: vim epel.repo vim CentOS-Base.repo 将所有的http://mirrors.cloud.aliyuncs.com 修改为http://mirrors.aliyun.com。 修改…

趣测系统搭建APP源码开发,娱乐丰富生活的选择!

文章目录 前言一、pandas是什么?二、使用步骤 1.引入库2.读入数据总结 前言 趣测系统提供了一个集合多种有趣测试的平台,如心理测试和星座测试等,这些测试内容富有趣味性和娱乐性,能够帮助大众在忙碌的生活中找到放松和娱乐的时刻…

yolov5驾驶员不规范行为检测

1 项目介绍 1.1 摘要 随着汽车工业的迅速发展和交通拥堵的加剧,驾驶员在行车过程中的不规范行为成为了导致交通事故频发的重要因素之一。为了减少交通事故的发生,保障道路安全,提高驾驶员的行车安全意识,本研究致力于实现驾驶员…

5. PyTorch+NCCL源码编译

系列文章 第1章 多机多卡运行nccl-tests 和channel获取第2章 多机多卡nccl-tests 对比分析第3章 使用tcpdump抓取rdma数据包第5章 PyTorchNCCL源码编译 目录 系列文章前言一、本地环境二、安装cudnn三、使用pytorch自带NCCL库进行编译安装1. 源码编译2. 查看版本和all_reduce测…

【机器学习】机器学习重要方法——迁移学习:理论、方法与实践

文章目录 迁移学习:理论、方法与实践引言第一章 迁移学习的基本概念1.1 什么是迁移学习1.2 迁移学习的类型1.3 迁移学习的优势 第二章 迁移学习的核心方法2.1 特征重用(Feature Reuse)2.2 微调(Fine-Tuning)2.3 领域适…

【启明智显分享】典型的HMI应用实现方案:帮你更好地主控选型!

HMI是操作者与机器/系统间资讯传递和交换的主要桥梁。HMI系统通常能提供丰富的资讯,例如温度、压力、制造流程步骤以及材料的计量数据。还能显示设备中物料的确切位置或储存槽内的液位数据等讯息。无论是在工业自动化还是医疗、商业等重要行业领域,HMI都…

【前端项目笔记】6 参数管理

参数管理 效果展示: 在开发功能之前先创建分支goods_params cls 清空终端 git branch 查看所有分支 git checkout -b goods_params 新建分支goods_params git push -u origin goods_params 把本地的新分支推送到云端origin并命名为goods_params 参数管理需要维…

一个易于使用、与Android系统良好整合的多合一游戏模拟器

大家好,今天给大家分享的是一个易于使用、与Android系统良好整合的多合一游戏模拟器 Lemuroid。 Lemuroid 是一个专为Android平台设计的开源游戏模拟器项目,它基于强大的Libretro框架,旨在提供广泛的兼容性和卓越的用户体验。 项目介绍 Lem…

如何安装多版本CUDA?

首先聊一个题外话:前几天在csdn上看到的一个话题”安装pytorch一定要去nvidia官网下载安装cuda和cudnn吗?“ 我相信任何一个刚开始接触或者从事深度学习的炼丹者都会从安装cuda开始,现在网上随便一搜如何安装pytorch,蹦出来教程提…

pd虚拟机 Parallels Desktop 19 for Mac 破解版小白安装使用指南

Parallels Desktop 19 for Mac 乃是一款适配于 Mac 的虚拟化软件。它能让您在 Mac 计算机上同时运行多个操作系统。您可借此创建虚拟机,并于其中装设不同的操作系统,如 Windows、Linux 或 macOS。使用 Parallels Desktop 19 mac 版时,您可在 …

无线麦克风推荐哪些品牌,一文揭秘无线麦克风领夹哪个牌子好!

​究竟该如何选择麦克风呢?又该如何挑选无线麦克呢?询问我关于麦克风选择问题的人着实不少。对于那些仅仅是想要简单地自我娱乐的朋友而言,着实没必要去折腾,直接使用手机自带的麦克风便可以了。 但若是处于想要直播、拍摄短视频…

【Termius】详细说明MacOS中的SSH的客户端利器Termius

希望文章能给到你启发和灵感~ 如果觉得有帮助的话,点赞+关注+收藏支持一下博主哦~ 阅读指南 开篇说明一、基础环境说明1.1 硬件环境1.2 软件环境二、软件的安装2.1 Termius界面介绍2.1.1 Hosts 主机列表2.1.2 SFTP 文件传输2.1.3 Port ForWarding 端口转发2.1.4 Snippets 片…

为什么带货主播,他突然就不吃香了?

为什么带货主播他突然就不吃香了?工资骤降50%。 相比 2023 年初主播的平均薪资降了50%,那不管你是头部主播还是腰部主播,全部都降薪了。那尾部主播就更不用说了,有的主播他的时薪已经低到 20 块钱一个小时,还不如大学…

U-boot相关基础知识

U-boot和Bootloader之间的关系 U-Boot是Bootloader的一种实现,它专门用于嵌入式系统,特别是那些基于ARM、MIPS等处理器的系统。U-Boot提供了丰富的硬件支持和功能,使得开发者能够轻松地初始化硬件、加载操作系统内核,并进行一些基…

【漏洞复现】安美数字酒店宽带运营系统——命令执行漏洞(CNVD-2021-37784)

声明:本文档或演示材料仅供教育和教学目的使用,任何个人或组织使用本文档中的信息进行非法活动,均与本文档的作者无关。 文章目录 漏洞描述漏洞复现测试工具 漏洞描述 安美数字酒店宽带运营系统 server_ping.php 存在远程命令执行漏洞&#…

一文带你了解什么是【点击劫持】

点击劫持,意思就是你点击网页的时候,有人劫持你,对没错,劫持你的信息,甚至劫持你的马内,劫持你的理想,劫持你的肉体,劫持你的灵魂。就是这么可怕。 目录 1 如何实现假网站 1.1 if…

“未来独角兽” | 安全狗入选福建省数字经济核心产业创新企业名单

近日,福建省数据管理局公布了入选2024年度全省数字经济核心产业创新企业名单。 作为国内云原生安全领导厂商,安全狗凭借自身在云安全领域的卓越表现和创新实力入选,获得“未来独角兽”称号。 据悉,此次对“未来独角兽”的评选条件…

计算机视觉(CV)技术:优势、挑战与前景

摘要 计算机视觉作为人工智能的关键领域之一,正迅速改变我们的生活和工作方式。本文将探讨CV技术的主要优势、面临的挑战以及未来的发展方向。 关键词 计算机视觉, 人工智能, 数据处理, 自动化, 伦理问题 目录 引言计算机视觉技术的优势计算机视觉技术的挑战实…

内网穿透小工具

内网穿透小工具 前言 当在本地或者虚拟机,内网搭建了项目,数据库。可是在外网无法访问。下面的两款小工具可以暂时实现内网穿透能力。(不支持自定义域名,但是不限制隧道数量!且免费!免费!免费…