信号处理--基于单通道脑电信号EEG的睡眠分期评估

背景

睡眠对人体健康很重要。监测人体的睡眠分期对于人体健康和医疗具有重要意义。

亮点

  • 架构在第一层使用两个具有不同滤波器大小的 CNN 和双向 LSTM。 CNN 可以被训练来学习滤波器,以从原始单通道 EEG 中提取时不变特征,而双向 LSTM 可以被训练来将时间信息(例如睡眠阶段转换规则)编码到模型中。
  • 实现了一种两步训练算法,可以通过反向传播有效地端到端训练我们的模型,同时防止模型遭受大睡眠中出现的类别不平衡问题(即学习仅对大多数睡眠阶段进行分类) 数据集。
  • 在不改变模型架构和训练算法的情况下,模型可以从两个数据集的不同原始单通道脑电图自动学习睡眠阶段评分的特征,这两个数据集具有不同的属性(例如采样率)和评分标准( AASM 和 R&K)。

环境配置

  • python3.5.4
  • tensorflowgpu  1.15.2

数据

Sleep-EDF

MASS

方法

 模型主要代码:

class MyModel(DeepFeatureNet):def __init__(self, batch_size, input_dims, n_classes, seq_length,n_rnn_layers,return_last,is_train, reuse_params,use_dropout_feature, use_dropout_sequence,name="deepsleepnet"):super(self.__class__, self).__init__(batch_size=batch_size, input_dims=input_dims, n_classes=n_classes, is_train=is_train, reuse_params=reuse_params, use_dropout=use_dropout_feature, name=name)self.seq_length = seq_lengthself.n_rnn_layers = n_rnn_layersself.return_last = return_lastself.use_dropout_sequence = use_dropout_sequencedef _build_placeholder(self):# Inputname = "x_train" if self.is_train else "x_valid"self.input_var = tf.compat.v1.placeholder(tf.float32, shape=[self.batch_size*self.seq_length, self.input_dims, 1, 1],name=name + "_inputs")# Targetself.target_var = tf.compat.v1.placeholder(tf.int32, shape=[self.batch_size*self.seq_length, ],name=name + "_targets")def build_model(self, input_var):# Create a network with superclass methodnetwork = super(self.__class__, self).build_model(input_var=self.input_var)# Residual (or shortcut) connectionoutput_conns = []# Fully-connected to select some part of the output to add with the output from bi-directional LSTMname = "l{}_fc".format(self.layer_idx)with tf.compat.v1.variable_scope(name) as scope:output_tmp = fc(name="fc", input_var=network, n_hiddens=1024, bias=None, wd=0)output_tmp = batch_norm_new(name="bn", input_var=output_tmp, is_train=self.is_train)# output_tmp = leaky_relu(name="leaky_relu", input_var=output_tmp)output_tmp = tf.nn.relu(output_tmp, name="relu")self.activations.append((name, output_tmp))self.layer_idx += 1output_conns.append(output_tmp)####################################################################### Reshape the input from (batch_size * seq_length, input_dim) to# (batch_size, seq_length, input_dim)name = "l{}_reshape_seq".format(self.layer_idx)input_dim = network.get_shape()[-1].valueseq_input = tf.reshape(network,shape=[-1, self.seq_length, input_dim],name=name)assert self.batch_size == seq_input.get_shape()[0].valueself.activations.append((name, seq_input))self.layer_idx += 1# Bidirectional LSTM networkname = "l{}_bi_lstm".format(self.layer_idx)hidden_size = 512   # will output 1024 (512 forward, 512 backward)with tf.compat.v1.variable_scope(name) as scope:def lstm_cell():cell = tf.compat.v1.nn.rnn_cell.LSTMCell(hidden_size,                               use_peepholes=True,state_is_tuple=True,reuse=tf.compat.v1.get_variable_scope().reuse) if self.use_dropout_sequence:keep_prob = 0.5 if self.is_train else 1.0cell = tf.compat.v1.nn.rnn_cell.DropoutWrapper(cell,output_keep_prob=keep_prob)return cellfw_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell([lstm_cell() for _ in range(self.n_rnn_layers)], state_is_tuple = True)bw_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell([lstm_cell() for _ in range(self.n_rnn_layers)], state_is_tuple = True)# Initial state of RNNself.fw_initial_state = fw_cell.zero_state(self.batch_size, tf.float32)self.bw_initial_state = bw_cell.zero_state(self.batch_size, tf.float32)# Feedforward to MultiRNNCelllist_rnn_inputs = tf.unstack(seq_input, axis=1)#outputs, fw_state, bw_state = tf.nn.bidirectional_rnn(outputs, fw_state, bw_state = tf.compat.v1.nn.static_bidirectional_rnn(cell_fw=fw_cell,cell_bw=bw_cell,inputs=list_rnn_inputs,initial_state_fw=self.fw_initial_state,initial_state_bw=self.bw_initial_state)if self.return_last:network = outputs[-1]else:network = tf.reshape(tf.concat(axis=1, values=outputs), [-1, hidden_size*2],name=name)self.activations.append((name, network))self.layer_idx +=1self.fw_final_state = fw_stateself.bw_final_state = bw_state# Append outputoutput_conns.append(network)####################################################################### Addname = "l{}_add".format(self.layer_idx)network = tf.add_n(output_conns, name=name)self.activations.append((name, network))self.layer_idx += 1# Dropoutif self.use_dropout_sequence:name = "l{}_dropout".format(self.layer_idx)if self.is_train:network = tf.nn.dropout(network, keep_prob=0.5, name=name)else:network = tf.nn.dropout(network, keep_prob=1.0, name=name)self.activations.append((name, network))self.layer_idx += 1return networkdef init_ops(self):self._build_placeholder()# Get loss and prediction operationswith tf.compat.v1.variable_scope(self.name) as scope:# Reuse variables for validationif self.reuse_params:scope.reuse_variables()# Build modelnetwork = self.build_model(input_var=self.input_var)# Softmax linearname = "l{}_softmax_linear".format(self.layer_idx)network = fc(name=name, input_var=network, n_hiddens=self.n_classes, bias=0.0, wd=0)self.activations.append((name, network))self.layer_idx += 1# Outputs of softmax linear are logitsself.logits = network######### Compute loss ########## Weighted cross-entropy loss for a sequence of logits (per example)loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example([self.logits],[self.target_var],[tf.ones([self.batch_size * self.seq_length])],name="sequence_loss_by_example")loss = tf.reduce_sum(loss) / self.batch_size# Regularization lossregular_loss = tf.add_n(tf.compat.v1.get_collection("losses", scope=scope.name + "\/"),name="regular_loss")# print " "# print "Params to compute regularization loss:"# for p in tf.compat.v1.get_collection("losses", scope=scope.name + "\/"):#     print p.name# print " "# Total lossself.loss_op = tf.add(loss, regular_loss)# Predictionsself.pred_op = tf.argmax(self.logits, 1)

结果

睡眠分期效果图

MASS数据集分类表

代码获取

后台私信 1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/726810.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据库-DDL

show databases; 查询所有数据库 select database(); 查询当前数据库 use 数据库名; 使用数据库 creat database[if not exists] 数据库名…

vue2中如何实现添加一个空标签的效果,</>

前言&#xff1a; 众所周知&#xff0c;vue3突破了每一个vue文件中只能有一个根标签的限制&#xff0c;但是我们还有很多项目都是vue2的项目&#xff0c;如果让vue2中实现一个类似</>的效果呢&#xff0c;像react的16.2.0的版本之后&#xff0c;可以直接用<></&…

2024 年 AI 辅助研发趋势

随着人工智能技术的持续发展与突破&#xff0c;2024年AI辅助研发正成为科技界和工业界瞩目的焦点。从医药研发到汽车设计&#xff0c;从软件开发到材料科学&#xff0c;AI正逐渐渗透到研发的各个环节&#xff0c;变革着传统的研发模式。在这一背景下&#xff0c;AI辅助研发不仅…

【动态规划】【数论】【区间合并】3041. 修改数组后最大化数组中的连续元素数目

作者推荐 视频算法专题 本文涉及知识点 动态规划汇总 数论 区间合并 LeetCode3041. 修改数组后最大化数组中的连续元素数目 给你一个下标从 0 开始只包含 正 整数的数组 nums 。 一开始&#xff0c;你可以将数组中 任意数量 元素增加 至多 1 。 修改后&#xff0c;你可以从…

Spring Boot 3核心技术与最佳实践

&#x1f482; 个人网站:【 海拥】【神级代码资源网站】【办公神器】&#x1f91f; 基于Web端打造的&#xff1a;&#x1f449;轻量化工具创作平台&#x1f485; 想寻找共同学习交流的小伙伴&#xff0c;请点击【全栈技术交流群】 highlight: a11y-dark 引言 Spring Boot作为…

企业财务分析该怎么做?重点分析哪些财务指标?

在企业经营管理的过程中&#xff0c;财务分析是评估当前企业或特定部门财务状况和绩效的过程&#xff0c;这一过程通常涉及对财务报表&#xff08;如资产负债表、利润表和现金流量表&#xff09;进行定量和定性的评估&#xff0c;以便为盈利能力、偿债能力、现金流动性和资金稳…

解决 RuntimeError: “LayerNormKernelImpl“ not implemented for ‘Half‘

解决 RuntimeError: “LayerNormKernelImpl” not implemented for ‘Half’。 错误类似如下&#xff1a; Traceback (most recent call last): File “cli_demo.py”, line 21, in for results in webglm.stream_query(question): File “/root/WebGLM/model/modeling_webgl…

<C++>【继承篇】

​ ✨前言✨ &#x1f393;作者&#xff1a;【 教主 】 &#x1f4dc;文章推荐&#xff1a; ☕博主水平有限&#xff0c;如有错误&#xff0c;恳请斧正。 &#x1f4cc;机会总是留给有准备的人&#xff0c;越努力&#xff0c;越幸运&#xff01; &#x1f4a6;导航助手&#x1…

Vue+SpringBoot打造校园疫情防控管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 学生2.2 老师2.3 学校管理部门 三、系统展示四、核心代码4.1 新增健康情况上报4.2 查询健康咨询4.3 新增离返校申请4.4 查询防疫物资4.5 查询防控宣传数据 五、免责说明 一、摘要 1.1 项目介绍 基于JAVAVueSpringBoot…

入门版式设计:设计小白的必备知识!

你曾经被一本华丽的杂志、一张引人注目的海报或一个优雅的网站设计所吸引吗&#xff1f;这些都是版式设计的魅力所在。作为一个设计小白&#xff0c;我们可能不熟悉版式设计&#xff0c;但事实上&#xff0c;它无处不在&#xff0c;深深影响着我们的生活。那么&#xff0c;什么…

大型网站架构演化总结

本文图解大型网站架构演化。 目录 1、单一应用服务阶段 2、应用与数据服务分离阶段 3、利用缓存提高性能阶段 4、应用服务集群阶段 5、数据库读写分离阶段 6、反向代理与CDN加速阶段 7、分布式数据库阶段 8、 NoSQL与搜索引擎阶段 9、业务拆分阶段 10、分布式服务阶…

Leetcode刷题(三十八)

旋转矩阵&#xff08;Medium&#xff09; 给定一个 n n 的二维矩阵 matrix 表示一个图像。请你将图像顺时针旋转 90 度。你必须在 原地 旋转图像&#xff0c;这意味着你需要直接修改输入的二维矩阵。请不要 使用另一个矩阵来旋转图像。示例 1&#xff1a;输入&#xff1a;mat…

基于springboot+vue的医疗挂号管理系统

博主主页&#xff1a;猫头鹰源码 博主简介&#xff1a;Java领域优质创作者、CSDN博客专家、阿里云专家博主、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战&#xff0c;欢迎高校老师\讲师\同行交流合作 ​主要内容&#xff1a;毕业设计(Javaweb项目|小程序|Pyt…

vcruntime140.dll丢失的修复办法详细介绍以及详细步骤

当电脑丢失vcruntime140.dll文件时&#xff0c;电脑会出现关于vcruntime140.dll丢失的错误提示&#xff0c;vcruntime140.dll文件包含许多重要的函数和资源&#xff0c;若缺少或丢失该文件&#xff0c;可能会导致电脑出现异常状况。今天就来和大家说说如果电脑出现关于vcruntim…

获取别人店铺的所有商品API接口

使用淘宝淘口令接口的步骤通常包括&#xff1a; 注册成为淘宝开放平台的开发者&#xff1a;在淘宝开放平台网站上注册账号并完成认证。 创建应用以获取API密钥&#xff1a;在您的开发者控制台中创建一个应用&#xff0c;并获取用于API调用的密钥&#xff0c;如Client ID和Clie…

【JavaEE初阶 -- 计算机核心工作机制】

这里写目录标题 1.冯诺依曼体系2.CPU是怎么构成的3.指令表4.CPU执行代码的方式5.CPU小结&#xff1a;6.编程语言和操作系统7. 进程/任务&#xff08;Process/Task&#xff09;8.进程在系统中是如何管理的9. CPU分配 -- 进程调度10.内存分配 -- 内存管理11.进程间通信 1.冯诺依曼…

javaweb学习(day07-手动实现tomcat)

一、引入案例 1 小案例 引出对 Tomcat 底层实现思考 1.1 完成小案例 1.1.1 运行效果 1.2 maven简要介绍 我们准备使用 Maven 来 创建一个 WEB 项目 , 先 简单给小伙伴介绍一下 Maven 是 什 么 , 更加详细的使用&#xff0c;我们还会细讲 , 现在先使用一把 1.3 创…

多个变量指向同一个数组

多个变量中的内存地址是一样的&#xff0c;都是指向当前的数组&#xff0c;存储当前数组对象的地址&#xff0c;因此修改是对当前数组的值进行修改 数组中存储的是null&#xff0c;那么他将不会指向任何数组对象 System.out.println(arr) 输出结果为null&#xff0c;里面没有…

Vue+OpenLayers7入门到实战:webgl图层叠加大量Icon图片到地图,解决叠加超大数据量图片导致浏览器卡住变慢的问题

返回《Vue+OpenLayers7》专栏目录:Vue+OpenLayers7 前言 之前已经讲了如何地图中如何添加大量点到webgl图层优化大量点浏览器页面卡顿的问题。本章介绍补充一下叠加大量图片图标要素到地图的情况下的问题。 二、依赖和使用 "ol": "7.5.2"使用npm安装依…

Vue+OpenLayers7入门到实战:OpenLayers7如何使用gifler库来实现gif动态图图片叠加到地图上

返回《Vue+OpenLayers7》专栏目录:Vue+OpenLayers7 前言 OpenLayers7本身不支持gif图片作为图标要素显示到地图上,所以需要通过其他办法来实现支持gif图片。 本章介绍如何使用OpenLayers7在地图上使用gifler库先生成canvas画板,然后通过canvas画板的重绘事件来重新渲染地图…