信号处理--基于混合CNN和transfomer自注意力的多通道脑电信号的情绪分类的简单应用

目录

关于

工具

数据集

数据集简述

方法实现

数据读取

​编辑数据预处理

传统机器学习模型(逻辑回归,支持向量机,随机森林)

多层感知机模型

CNN+transfomer模型

代码获取


关于

  • 本实验利用结合了卷积神经网络 (CNN) 和 Transformer 组件的混合架构,实现基于 EEG 的有效情绪分类。
  • 尝试各种机器学习模型,包括逻辑回归、支持向量机 (SVM)、随机森林分类器和多层感知器 (MLP) 神经网络,以比较不同模型的性能。 

 图片来自于: https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9991178

工具

数据集

数据集简述

脑电图数据是从两名受试者(1 名男性、1 名女性,年龄 20-22 岁)收集的,针对特定电影剪辑引发的六种情绪状态(积极、消极、中性)中的每一种状态。该数据集包括从脑电波中收集的 324,000 个数据点,这些数据点被重新采样到 150Hz。还收集了中性脑电波数据,作为代表受试者静息情绪状态的第三类数据。从四个电极(TP9、AF7、AF8、TP10)记录 EEG 数据,并进行处理以生成通过 1 秒滑动窗口提取的统计特征数据集。

图片来源:https://www.researchgate.net/figure/This-figure-shows-the-standard-locations-for-measuring-EEG-as-per-10-20-International_fig2_358644174 

方法实现

数据读取
raw_eeg_data = pd.read_csv('../data/features_raw.csv')
raw_eeg_data.head()# plot the F8 column
plt.figure(figsize=(20, 5))
plt.plot(raw_eeg_data['F8'])
plt.title('F8 Electrode Data')
plt.ylabel('Voltage (uV)')
plt.xlabel('Time')
plt.show()# plot the F7 column
plt.figure(figsize=(20, 5))
plt.plot(raw_eeg_data['F7'])
plt.title('F7 Electrode Data')
plt.ylabel('Voltage (uV)')
plt.xlabel('Time')
plt.show()
数据预处理
X = eeg_emotions_data.drop(['label'], axis=1)
y = eeg_emotions_data['label']# Encoding categorical data
from sklearn.preprocessing import LabelEncoder, OneHotEncoderlabelencoder_emotions = LabelEncoder()
y = labelencoder_emotions.fit_transform(y)# Standardizing the features in the dataset
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()X = scaler.fit_transform(X)

传统机器学习模型(逻辑回归,支持向量机,随机森林)
from sklearn.linear_model import LogisticRegression
import pickle# Create a logistic regression classifier
model = LogisticRegression(random_state=2003, multi_class='multinomial', max_iter=1000)# Train the model
model.fit(X_train, y_train)# Evaluate the model
evaluate_model(y_test, model.predict(X_test))

from sklearn.svm import SVC# Create a model: a support vector classifier
model = SVC(kernel='rbf', gamma='auto', C=1.0, random_state=2003)# Train the model
model.fit(X_train, y_train)# Evaluate the model
evaluate_model(y_test, model.predict(X_test))

from sklearn.ensemble import RandomForestClassifier# Create a random forest Classifier.
model = RandomForestClassifier(n_estimators=100, random_state=2003)# Train the model
model.fit(X_train, y_train)# Evaluate the model
evaluate_model(y_test, model.predict(X_test))

在传统机器模型,我们可以发现随机森林的性能表现最好; 

多层感知机模型
class EEGClassifier(nn.Module):def __init__(self, input_dim, num_classes, hidden_dim=256):super(EEGClassifier, self).__init__()self.fc1 = nn.Linear(input_dim, hidden_dim)self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_dim, num_classes)def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)return xinput_dim = 2548  # Number of features in EEG signal
num_classes = 3   # Number of classes for classification
model = EEGClassifier(input_dim, num_classes)
loss = nn.CrossEntropyLoss()

CNN+transfomer模型
class EEGConformer(nn.Module):def __init__(self, input_dim, num_classes):super(EEGConformer, self).__init__()# CNNself.conv1 = nn.Conv2d(1, 40, kernel_size=(1, 25), stride=(1, 1))self.conv2 = nn.Conv2d(40, 40, kernel_size=(1, input_dim), stride=(1, 1))self.batchnorm = nn.BatchNorm2d(40)# Transformerself.layernorm1 = nn.LayerNorm(40)self.multiheadattention = nn.MultiheadAttention(40, 1)self.layernorm2 = nn.LayerNorm(40)self.feedworward_block = nn.Sequential(nn.Linear(40, 32),nn.GELU(),nn.Dropout(p=0.1),nn.Linear(32, 40))# MLPself.fc1 = nn.Linear(40, 32)self.fc2 = nn.Linear(32, 32)self.fc3 = nn.Linear(32, num_classes)def forward(self, x):# CNNx = x.unsqueeze(1).unsqueeze(1)x = self.conv1(x)x = self.conv2(x)x = self.batchnorm(x)# Transformerx = x.squeeze()x = self.layernorm1(x)attn_out = self.multiheadattention(x, x, x)x = x + nn.Dropout(0.1)(attn_out[0])x = self.layernorm2(x)x = self.feedworward_block(x)x = nn.Dropout(p=0.1)(x)# MLPx = self.fc1(x)x = F.elu(x)x = nn.Dropout(p=0.5)(x)x = self.fc2(x)x = F.elu(x)x = nn.Dropout(p=0.3)(x)x = self.fc3(x)return xinput_dim = 2524  # Number of features in EEG signal
num_classes = 3   # Number of classes for classification
model = EEGConformer(input_dim, num_classes)
loss = nn.CrossEntropyLoss()

代码获取

后台私信,注明来意和文章名称;

其他问题,欢迎沟通交流。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/773366.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

五-容量管理之容量预案

容量预案(Capacity Plan)是容量管理的一个重要组成部分。 容量预警条件和措施: 类型预警条件措施应用服务器Load95分位值大于CPU核数的2倍 前一天CPU 95分位值大于90% 内存使用率95分位值大于90%增加应用服务器数据库数据库连接超过200扩容DB服务规格/ 优化SQL查询…

windows下powershell与linux下bash美化教程(使用starship)

starship美化教程 Win11 Powershell 安装 在命令行使用下面命令安装 # 安装starship winget install starship将以下内容添加到 Microsoft.PowerShell_profile.ps1,可以在 PowerShell 通过 $PROFILE 变量来查询文件的位置 Invoke-Expression (&starship i…

【HBZ分享】Kafka为什么性能非常高

Kafka性能高的原因 磁盘顺序读写:磁盘顺序读写的性能可以和内存相媲美,顺序读写不需要寻道时间,也不需要大幅旋转磁头找扇区,所以性能极高 零拷贝: 大幅降低了用户态与内核态之间的切换,从而减少了数据来回…

Chrome安装Vue插件vue-devtools

安装Vue.js开发者工具(Vue DevTools)到Google Chrome浏览器的步骤可能会随着Vue DevTools更新和Chrome政策变化而有所调整。 1.从GitHub获取源代码: 访问Vue DevTools的GitHub仓库:https://github.com/vuejs/vue-devtools 根据仓…

web学习笔记(四十五)Node.js

目录 1. Node.js 1.1 什么是Node.js 1.2 为什么要学node.js 1.3 node.js的使用场景 1.4 Node.js 环境的安装 1.5 如何查看自己安装的node.js的版本 1.6 常用终端命令 2. fs 文件系统模块 2.1引入fs核心模块 2.2 读取指定文件的内容 2.3 向文件写入指定内容 2.4 创…

sql oracle 获取当前日期的最后一天

语法 LAST_DAY 传入一个日期类型的变量&#xff0c;但会给你一个当月的最后一天的变量 LAST_DAY(TO_DATE(year || - || SUBSTR(month, -2) || -01, YYYY-MM-DD)) < ?应用实例 AssetValueSingleQT.spl 一个表中只存储的年和月&#xff0c;需要更具年月筛选小于指定日期&…

yarn按包的时候报错 ../../../package.json: No license field

运行 yarn config list 然后运行 yarn config set strict-ssl false 之后yarn就成功了

基于SpringBoot“网上选课系统”设计和实现(源码定制以及咨询!!)

博主介绍&#xff1a;✌全网粉丝10W,B站项目阿龙、csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、华为云获奖者&#xff0c;“程序员阿龙”✌ 主要内容&#xff1a;SpringBoot、Vue、SSM、HLMT、Jsp、PHP、Nodejs、Python&#xff0c;MYSQL、Hodpoo…

网络工程师软考中级考试大纲

考试要求&#xff1a; &#xff08;1&#xff09;熟悉计算机系统的基础知识&#xff1b;&#xff08;2&#xff09;熟悉网络操作系统的基础知识&#xff1b;&#xff08;3&#xff09;理解计算机应用系统的设计和开发方法&#xff1b;&#xff08;4&#xff09;熟悉数据通信的基…

SpringBoot2.6.3 + knife4j-openapi3

1.引入项目依赖&#xff1a; <dependency><groupId>com.github.xiaoymin</groupId><artifactId>knife4j-openapi3-spring-boot-starter</artifactId><version>4.5.0</version> </dependency> 2.新增配置文件 import io.swag…

Docker搭建LNMP环境实战(05):CentOS环境安装Docker-CE

前面几篇文章讲了那么多似乎和Docker无关的实战操作&#xff0c;本篇总算开始说到Docker了。 1、关于Docker 1.1、什么是Docker Docker概念就是大概了解一下就可以&#xff0c;还是引用一下百度百科吧&#xff1a; Docker 是一个开源的应用容器引擎&#xff0c;让开发者可以…

【机器学习之---数学】随机游走

every blog every motto: You can do more than you think. https://blog.csdn.net/weixin_39190382?typeblog 0. 前言 随机游走 1. 概念 1.1 例1 在你的饮食俱乐部度过了一个富有成效的晚上后&#xff0c;你在不太清醒的状态下离开了。因此&#xff0c;你会醉醺醺地在展…

【OpenStack】在本地OpenStack中创建一个应用及其网络

【OpenStack】在本地OpenStack中创建一个应用及其网络 目录 【OpenStack】在本地OpenStack中创建一个应用及其网络应用程序设计创建项目创建一个新项目使用新项目创建网络创建我们的应用程序VNET创建路由器更新安全组创建浮动IP结论推荐超级课程: Docker快速入门到精通Kuberne…

计算机视觉中的NMS非极大值抑制

NMS 是“非极大抑制”&#xff08;Non-Maximum Suppression&#xff09;的缩写&#xff0c;是一种在目标检测算法中广泛使用的技术。它的主要目的是减少目标检测过程中的多余的边界框&#xff0c;以便只保留最佳的一个边界框。 在目标检测任务中&#xff0c;算法会对图像中可能…

数据结构(五)单链表专题

在开始之前&#xff0c;我先来给大家讲一下顺序表与链表的区别&#xff1a; 它们在堆上存储的差异&#xff1a; 我们可以很容易的知道&#xff0c;循序表是连续的有序的&#xff0c;但链表是杂乱的&#xff0c;它们通过地址彼此联系起来。 1. 链表的概念及结构 概念&#xff1…

智慧交通(代码实现案例)

1.项目简介 目标: 了解智慧交通项目的架构知道智慧交通项目中的模块能够完成智慧交通项目的环境搭建 该项目是智慧交通项目&#xff0c;通过该项目掌握计算机视觉的方法在交通领域的相关应用&#xff0c;包括车道线检测的方法&#xff0c;多目标车辆追踪及流量统计方法&#…

t检验原理

t检验是一种常用的统计方法&#xff0c;用于比较两个样本均值是否有显著差异。它的基本原理是通过计算样本均值之间的差异&#xff0c;以及这种差异相对于样本误差的大小来判断差异是否显著。 t检验的基本步骤如下&#xff1a; 1. 假设两个样本是独立、随机抽取的&#xff0c;…

Linux Tomcat的服务器如何查看接口请求方式?

问题描述 最近在和安卓开发对接接口&#xff0c;遇到一个接口总是报405错误&#xff0c;有对接经验的开发应该都知道是请求方式不对&#xff0c;假如接口定义为POST请求的&#xff0c;但是客户端却用GET请求&#xff0c;这时候就会报这个错误。Android客户端那边使用xUtils框架…

Python 中处理JSON文件的方法

用 Python 读取、写入和操作 JSON 文件 JSON&#xff08;JavaScript Object Notation&#xff09;是一种流行的数据交换格式&#xff0c;易于人类阅读和编写。在编程领域中&#xff0c;与网络API或HTTP请求交互时经常会用到 JSON。Python 通过 json 模块提供了对 JSON 文件的内…

六-容量管理之相关工具

容量管理是一种综合性的事项&#xff0c;其中涉及多种相关工具和技术&#xff0c;常用的工具有&#xff1a; 压测平台&#xff08;压力/脚本/数据&#xff09;监控平台&#xff08;Log/Trace/Metrics&#xff09;发布平台 (CI/CD/弹性伸缩)预案平台 &#xff08;限流/降级/容量…