【Week-R1】RNN实现心脏病预测,基于tensorflow框架

文章目录

  • 一、什么是RNN?
  • 二、准备环境和数据
    • 2.1 导入数据
  • 三、构建模型
  • 四、训练和预测
  • 五、其他
    • (1)sklearn模块导入报错:`ModuleNotFoundError: No module named 'sklearn'`
    • (2)优化器改为SGD,accuracy=25.81%
    • (3)使用训练acc最高的模型进行预测

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

一、什么是RNN?

RNN:Recurrent Neural Network,用于处理序列数据。和传统神经网络不同的点在于,当前一层的输出会被当做输入,带到下一个隐藏层中,进行训练,于是除了第一层,RNN中每一个隐藏层的输入都包含两个部分【上一层的输出和当前层的输入】,如教案中给出的简易示意图,每个单词用一种颜色表示,01~05为不同的隐藏层,到达最后一层得到的输出为 05,也是神经网络需要判断的层。
在这里插入图片描述

二、准备环境和数据

环境:tensorflow框架,py312,cpu
编译:VSCode

使用CPU进行编译,就无需再设置GPU

2.1 导入数据

根据给出的数据集文件,分析如下:
在这里插入图片描述
在这里插入图片描述

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]tf.config.experimental.set_memory_growth(gpu0,true)tf.config.set_visible_devices([gpu0],"GPU")print("GPU: ",gpus)
else:print("CPU:")# 2.1 导入数据
import numpy as np
import pandas as pddf = pd.read_csv("D:\\jupyter notebook\\DL-100-days\\RNN\\heart.csv")
print("df: ", df)# 2.2 检查是否有空值
df.isnull().sum()
#3.1 划分训练集与测试集
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_splitx = df.iloc[:, :-1]
y = df.iloc[:, -1]
x_train, x_test, y_train, y_test = train_test_split(x,y,test_size=0.1,random_state=1)
print("x_train.shape: ", x_train.shape)
print("y_train.shape: ", y_train.shape)# 3.2 标准化: 针对每一列进行标准化
sc = StandardScaler()
x_train = sc.fit_transform(x_train)
x_test = sc.transform(x_test)x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], 1)
x_test = x_test.reshape(x_test.shape[0], x_test.shape[1], 1)

在这里插入图片描述

三、构建模型

tf官方教程 Keras中的循环神经网络RNN 一文中有提到:
在这里插入图片描述
在这里插入图片描述

本次学习使用的是SimpleRNN内置层,其关键参数说明:

● units: 正整数,输出空间的维度。● activation: 要使用的激活函数。 默认:双曲正切(tanh)。 如果传入 None,则不使用激活函数 (即 线性激活:a(x) = x)。● use_bias: 布尔值,该层是否使用偏置向量。● kernel_initializer: kernel 权值矩阵的初始化器, 用于输入的线性转换 (详见 initializers)。● recurrent_initializer: recurrent_kernel 权值矩阵 的初始化器,用于循环层状态的线性转换 (详见 initializers)。● bias_initializer:偏置向量的初始化器 (详见initializers).● dropout: 在 0 和 1 之间的浮点数。 单元的丢弃比例,用于输入的线性转换。
# 4.1 构建RNN模型
import keras
from keras.models import Sequential
from keras.layers import Dense,LSTM,SimpleRNNmodel = Sequential()
model.add(SimpleRNN(200, input_shape=(13,1), activation='relu'))
model.add(Dense(100, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
print(model.summary())

在这里插入图片描述

四、训练和预测

#4.2 编译模型
opt = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(loss='binary_crossentropy',optimizer = opt,metrics="accuracy")# 4.3 训练模型 
epochs = 100
history = model.fit(x_train, y_train,epochs=epochs,batch_size=128,validation_data=(x_test,y_test),verbose=1)
#4.4 评估模型 
import matplotlib.pyplot as plt
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(14,4))
plt.subplot(1,2,1)
plt.plot(epochs_range, acc, label='Training Acuuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training & Validation Accuracy')plt.subplot(1,2,2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Losss')
plt.legend(loc='upper right')
plt.title('Training & Validation Loss')plt.savefig("D:\\jupyter notebook\\DL-100-days\\RNN\\result.png")
plt.show()scores = model.evaluate(x_test,y_test,verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))

在这里插入图片描述
请添加图片描述

五、其他

(1)sklearn模块导入报错:ModuleNotFoundError: No module named 'sklearn'

解决办法:在VSCode终端中,cd到.venv/Script路径下,执行.\pip install scikit-learn,等待安装完成,如下:
在这里插入图片描述

(2)优化器改为SGD,accuracy=25.81%

请添加图片描述
在这里插入图片描述

(3)使用训练acc最高的模型进行预测

观察100个epoch输出的训练结果,可以看到最高的val_accuracy=0.9032,可以把这次的模型保存出来,作为预测模型。
在这里插入图片描述
修改代码如下:
在这里插入图片描述
得到结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/16073.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux系统之GoAccess实时Web日志分析工具的基本使用

Linux系统之GoAccess实时Web日志分析工具的基本使用 一、GoAccess介绍1.1 GoAccess简介1.2 GoAccess功能1.3 Web日志格式 二、本地环境介绍2.1 本地环境规划2.2 本次实践介绍 三、检查本地环境3.1 检查本地操作系统版本3.2 检查系统内核版本3.3 检查系统镜像源3.4 更新软件列表…

JavaFX安装与使用

前言 最近学习了javafx,开始时在配置环境和导包时遇到了一些麻烦,关于网上很多方法都尝试过了,现在问题都解决了,和大家分享一下我是怎么实现javafx的配置,希望大家可以通过这个方法实现自己的环境配置! 🙈个人主页: 心.c 🔥文章专题:javafx &#x1f49…

计算机网络-Traffic-Filter流量过滤策略

一、概述 为提高网络安全性,管理人员需要控制进入网络的流量,将不信任的报文丢弃在网络边界。所谓的不信任报文是指对用户来说存在安全隐患或者不愿意接收的报文。同时保证数据访问安全性,企业网络中经常会要求一些部门之间不能相互访问。 背…

服务器数据恢复—同友存储raid5阵列上层虚拟机数据恢复案例

服务器数据恢复环境: 某市教育局同友存储,存储中有一组由数块磁盘组建的raid5阵列,存储空间划分若干lun。每个lun中有若干台虚拟机,其中有数台linux操作系统的虚拟机为重要数据。 存储结构: 服务器故障: r…

slam14讲(第9,10讲 后端)

slam14讲(第9,10讲 后端) 后端分类基于滤波器的后端线性系统和卡尔曼滤波非线性系统和扩展卡尔曼滤波 BA优化H矩阵的稀疏性和边缘化H矩阵求解的总结 位姿图优化公式推导 基于滑动窗口的后端个人见解旧关键帧的边缘化 后端分类 基于滤波器的后…

AtCoder Beginner Contest 355 A~F

A.Who Ate the Cake?(思维) 题意 已知有三个嫌疑人,有两个证人,每个证人可以指出其中一个嫌疑人不是罪犯,如果可以排除两个嫌疑人来确定犯人,输出犯人的身份,如果无法确定,输出"-1"。 分析 …

springboot + Vue前后端项目(第十一记)

项目实战第十一记 1.写在前面2. 文件上传和下载后端2.1 数据库编写2.2 工具类CodeGenerator生成代码2.2.1 FileController2.2.2 application.yml2.2.3 拦截器InterceptorConfig 放行 3 文件上传和下载前端3.1 File.vue页面编写3.2 路由配置3.3 Aside.vue 最终效果图总结写在最后…

TabAttention:基于表格数据的条件注意力学习

文章目录 TabAttention: Learning Attention Conditionally on Tabular Data摘要方法实验结果 TabAttention: Learning Attention Conditionally on Tabular Data 摘要 医疗数据分析通常结合成像数据和表格数据处理,使用机器学习算法。尽管先前的研究探讨了注意力…

Hudi 多表摄取工具 HoodieMultiTableStreamer 配置方法与示例

博主历时三年精心创作的《大数据平台架构与原型实现:数据中台建设实战》一书现已由知名IT图书品牌电子工业出版社博文视点出版发行,点击《重磅推荐:建大数据平台太难了!给我发个工程原型吧!》了解图书详情,…

vue3添加收藏网站页面

结构与样式 <template><div class"web_view"><ul><li v-for"web in webList" :key"web.title"><a :href"web.src" :title"web.title" target"_blank"><img :src"web.img&…

微信小程序基础 -- 小程序UI组件(5)

小程序UI组件 1.小程序UI组件概述 开发文档&#xff1a;https://developers.weixin.qq.com/miniprogram/dev/framework/view/component.html 什么是组件&#xff1a; 组件是视图层的基本组成单元。 组件自带一些功能与微信风格一致的样式。 一个组件通常包括 开始标签 和 结…

Cyber Weekly #8

赛博新闻 1、微软召开年度发布会Microsoft Build 2024 本周&#xff08;5.22&#xff09;微软召开了年度发布会&#xff0c;Microsoft Build 2024&#xff0c;发布了包括大杀器 Copilot Studio 在内的 50 项更新。主要包括&#xff1a; 硬件层面&#xff1a;与英伟达 & A…

3D牙科网格分割使用基于语义的特征学习与图变换器

文章目录 3D Dental Mesh Segmentation Using Semantics-Based Feature Learning with Graph-Transformer摘要方法实验结果 3D Dental Mesh Segmentation Using Semantics-Based Feature Learning with Graph-Transformer 摘要 本文提出了一种新颖的基于语义的牙科网格分割方…

民国漫画杂志《时代漫画》第16期.PDF

时代漫画16.PDF: https://url03.ctfile.com/f/1779803-1248612470-6a05f0?p9586 (访问密码: 9586) 《时代漫画》的杂志在1934年诞生了&#xff0c;截止1937年6月战争来临被迫停刊共发行了39期。 ps:资源来源网络&#xff01;

【C++】二分查找:在排序数组中查找元素的第一个和最后一个位置

1.题目 难点&#xff1a;要求时间复杂度度为O(logn)。 2.算法思路 需要找到左边界和右边界就可以解决问题。 题目中的数组具有“二段性”&#xff0c;所以可以通过二分查找的思想进行解题。 代码&#xff1a; class Solution { public:vector<int> searchRange(vect…

Camunda BPM主要组件

Camunda BPM是使用java开发的,核心流程引擎运行在JVM里,纯java库,不依赖其他库或者底层操作系统。可以完美地与其他java框架融合,比如Spring。除了核心流程引擎外,还提供了一系列的管理,操作和监控工具。 1,工作流引擎 既适用于服务或者微服务编排,也适用于人工任务管…

Leetcode42题:接雨水

1.题目描述 给定 n 个非负整数表示每个宽度为 1 的柱子的高度图&#xff0c;计算按此排列的柱子&#xff0c;下雨之后能接多少雨水。 示例1&#xff1a; 输入&#xff1a;height [0,1,0,2,1,0,1,3,2,1,2,1] 输出&#xff1a;6 解释&#xff1a;上面是由数组 [0,1,0,2,1,0,1,…

【C语言】二叉树的实现

文章目录 前言⭐一、二叉树的定义&#x1f6b2;二、创建二叉树&#x1f3a1;三、二叉树的销毁&#x1f389;四、遍历二叉树1. 前序遍历2. 中序遍历3. 后序遍历4. 层序遍历 &#x1f332;五、二叉树的计算1. 计算二叉树结点个数2. 计算二叉树叶子结点的个数3. 计算二叉树的深度4…

一、Elasticsearch介绍与部署

目录 一、什么是Elasticsearch 二、安装Elasticsearch 三、配置es 四、启动es 1、下载安装elasticsearch的插件head 2、在浏览器&#xff0c;加载扩展程序 3、运行扩展程序 4、输入es地址就可以了 五、Elasticsearch 创建、查看、删除索引、创建、查看、修改、删除文档…