信号处理--基于混合CNN和transfomer自注意力的多通道脑电信号的情绪分类的简单应用

目录

关于

工具

数据集

数据集简述

方法实现

数据读取

​编辑数据预处理

传统机器学习模型(逻辑回归,支持向量机,随机森林)

多层感知机模型

CNN+transfomer模型

代码获取


关于

  • 本实验利用结合了卷积神经网络 (CNN) 和 Transformer 组件的混合架构,实现基于 EEG 的有效情绪分类。
  • 尝试各种机器学习模型,包括逻辑回归、支持向量机 (SVM)、随机森林分类器和多层感知器 (MLP) 神经网络,以比较不同模型的性能。 

 图片来自于: https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9991178

工具

数据集

数据集简述

脑电图数据是从两名受试者(1 名男性、1 名女性,年龄 20-22 岁)收集的,针对特定电影剪辑引发的六种情绪状态(积极、消极、中性)中的每一种状态。该数据集包括从脑电波中收集的 324,000 个数据点,这些数据点被重新采样到 150Hz。还收集了中性脑电波数据,作为代表受试者静息情绪状态的第三类数据。从四个电极(TP9、AF7、AF8、TP10)记录 EEG 数据,并进行处理以生成通过 1 秒滑动窗口提取的统计特征数据集。

图片来源:https://www.researchgate.net/figure/This-figure-shows-the-standard-locations-for-measuring-EEG-as-per-10-20-International_fig2_358644174 

方法实现

数据读取
raw_eeg_data = pd.read_csv('../data/features_raw.csv')
raw_eeg_data.head()# plot the F8 column
plt.figure(figsize=(20, 5))
plt.plot(raw_eeg_data['F8'])
plt.title('F8 Electrode Data')
plt.ylabel('Voltage (uV)')
plt.xlabel('Time')
plt.show()# plot the F7 column
plt.figure(figsize=(20, 5))
plt.plot(raw_eeg_data['F7'])
plt.title('F7 Electrode Data')
plt.ylabel('Voltage (uV)')
plt.xlabel('Time')
plt.show()
数据预处理
X = eeg_emotions_data.drop(['label'], axis=1)
y = eeg_emotions_data['label']# Encoding categorical data
from sklearn.preprocessing import LabelEncoder, OneHotEncoderlabelencoder_emotions = LabelEncoder()
y = labelencoder_emotions.fit_transform(y)# Standardizing the features in the dataset
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()X = scaler.fit_transform(X)

传统机器学习模型(逻辑回归,支持向量机,随机森林)
from sklearn.linear_model import LogisticRegression
import pickle# Create a logistic regression classifier
model = LogisticRegression(random_state=2003, multi_class='multinomial', max_iter=1000)# Train the model
model.fit(X_train, y_train)# Evaluate the model
evaluate_model(y_test, model.predict(X_test))

from sklearn.svm import SVC# Create a model: a support vector classifier
model = SVC(kernel='rbf', gamma='auto', C=1.0, random_state=2003)# Train the model
model.fit(X_train, y_train)# Evaluate the model
evaluate_model(y_test, model.predict(X_test))

from sklearn.ensemble import RandomForestClassifier# Create a random forest Classifier.
model = RandomForestClassifier(n_estimators=100, random_state=2003)# Train the model
model.fit(X_train, y_train)# Evaluate the model
evaluate_model(y_test, model.predict(X_test))

在传统机器模型,我们可以发现随机森林的性能表现最好; 

多层感知机模型
class EEGClassifier(nn.Module):def __init__(self, input_dim, num_classes, hidden_dim=256):super(EEGClassifier, self).__init__()self.fc1 = nn.Linear(input_dim, hidden_dim)self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_dim, num_classes)def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)return xinput_dim = 2548  # Number of features in EEG signal
num_classes = 3   # Number of classes for classification
model = EEGClassifier(input_dim, num_classes)
loss = nn.CrossEntropyLoss()

CNN+transfomer模型
class EEGConformer(nn.Module):def __init__(self, input_dim, num_classes):super(EEGConformer, self).__init__()# CNNself.conv1 = nn.Conv2d(1, 40, kernel_size=(1, 25), stride=(1, 1))self.conv2 = nn.Conv2d(40, 40, kernel_size=(1, input_dim), stride=(1, 1))self.batchnorm = nn.BatchNorm2d(40)# Transformerself.layernorm1 = nn.LayerNorm(40)self.multiheadattention = nn.MultiheadAttention(40, 1)self.layernorm2 = nn.LayerNorm(40)self.feedworward_block = nn.Sequential(nn.Linear(40, 32),nn.GELU(),nn.Dropout(p=0.1),nn.Linear(32, 40))# MLPself.fc1 = nn.Linear(40, 32)self.fc2 = nn.Linear(32, 32)self.fc3 = nn.Linear(32, num_classes)def forward(self, x):# CNNx = x.unsqueeze(1).unsqueeze(1)x = self.conv1(x)x = self.conv2(x)x = self.batchnorm(x)# Transformerx = x.squeeze()x = self.layernorm1(x)attn_out = self.multiheadattention(x, x, x)x = x + nn.Dropout(0.1)(attn_out[0])x = self.layernorm2(x)x = self.feedworward_block(x)x = nn.Dropout(p=0.1)(x)# MLPx = self.fc1(x)x = F.elu(x)x = nn.Dropout(p=0.5)(x)x = self.fc2(x)x = F.elu(x)x = nn.Dropout(p=0.3)(x)x = self.fc3(x)return xinput_dim = 2524  # Number of features in EEG signal
num_classes = 3   # Number of classes for classification
model = EEGConformer(input_dim, num_classes)
loss = nn.CrossEntropyLoss()

代码获取

后台私信,注明来意和文章名称;

其他问题,欢迎沟通交流。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/773366.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

windows下powershell与linux下bash美化教程(使用starship)

starship美化教程 Win11 Powershell 安装 在命令行使用下面命令安装 # 安装starship winget install starship将以下内容添加到 Microsoft.PowerShell_profile.ps1,可以在 PowerShell 通过 $PROFILE 变量来查询文件的位置 Invoke-Expression (&starship i…

web学习笔记(四十五)Node.js

目录 1. Node.js 1.1 什么是Node.js 1.2 为什么要学node.js 1.3 node.js的使用场景 1.4 Node.js 环境的安装 1.5 如何查看自己安装的node.js的版本 1.6 常用终端命令 2. fs 文件系统模块 2.1引入fs核心模块 2.2 读取指定文件的内容 2.3 向文件写入指定内容 2.4 创…

yarn按包的时候报错 ../../../package.json: No license field

运行 yarn config list 然后运行 yarn config set strict-ssl false 之后yarn就成功了

基于SpringBoot“网上选课系统”设计和实现(源码定制以及咨询!!)

博主介绍:✌全网粉丝10W,B站项目阿龙、csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、华为云获奖者,“程序员阿龙”✌ 主要内容:SpringBoot、Vue、SSM、HLMT、Jsp、PHP、Nodejs、Python,MYSQL、Hodpoo…

SpringBoot2.6.3 + knife4j-openapi3

1.引入项目依赖&#xff1a; <dependency><groupId>com.github.xiaoymin</groupId><artifactId>knife4j-openapi3-spring-boot-starter</artifactId><version>4.5.0</version> </dependency> 2.新增配置文件 import io.swag…

Docker搭建LNMP环境实战(05):CentOS环境安装Docker-CE

前面几篇文章讲了那么多似乎和Docker无关的实战操作&#xff0c;本篇总算开始说到Docker了。 1、关于Docker 1.1、什么是Docker Docker概念就是大概了解一下就可以&#xff0c;还是引用一下百度百科吧&#xff1a; Docker 是一个开源的应用容器引擎&#xff0c;让开发者可以…

【机器学习之---数学】随机游走

every blog every motto: You can do more than you think. https://blog.csdn.net/weixin_39190382?typeblog 0. 前言 随机游走 1. 概念 1.1 例1 在你的饮食俱乐部度过了一个富有成效的晚上后&#xff0c;你在不太清醒的状态下离开了。因此&#xff0c;你会醉醺醺地在展…

数据结构(五)单链表专题

在开始之前&#xff0c;我先来给大家讲一下顺序表与链表的区别&#xff1a; 它们在堆上存储的差异&#xff1a; 我们可以很容易的知道&#xff0c;循序表是连续的有序的&#xff0c;但链表是杂乱的&#xff0c;它们通过地址彼此联系起来。 1. 链表的概念及结构 概念&#xff1…

智慧交通(代码实现案例)

1.项目简介 目标: 了解智慧交通项目的架构知道智慧交通项目中的模块能够完成智慧交通项目的环境搭建 该项目是智慧交通项目&#xff0c;通过该项目掌握计算机视觉的方法在交通领域的相关应用&#xff0c;包括车道线检测的方法&#xff0c;多目标车辆追踪及流量统计方法&#…

Linux Tomcat的服务器如何查看接口请求方式?

问题描述 最近在和安卓开发对接接口&#xff0c;遇到一个接口总是报405错误&#xff0c;有对接经验的开发应该都知道是请求方式不对&#xff0c;假如接口定义为POST请求的&#xff0c;但是客户端却用GET请求&#xff0c;这时候就会报这个错误。Android客户端那边使用xUtils框架…

【小白入门篇3】还是GPT4更香

上一节文章《【小白入门篇2】总有一款AI工具适合你》介绍了很多ai产品给大家&#xff0c;有同学私信我&#xff0c;国内工具还是比较差&#xff0c;还是想用gpt4模型。这个章节介绍一些gpt4工具给大家, 其中大部分都只有一些免费的次数, 而且都需要kx上网才能访问。 OpenAI ch…

浙大版《C语言程序设计(第4版)》题目集-练习4-7 求e的近似值

自然常数 e 可以用级数 1 1 / 1 ! 1 / 2 ! ⋯ 1 / n ! ⋯ 11/1!1/2!⋯1/n!⋯ 11/1!1/2!⋯1/n!⋯来近似计算。本题要求对给定的非负整数 n&#xff0c;求该级数的前 n1 项和。 输入格式: 输入第一行中给出非负整数 n&#xff08;≤1000&#xff09;。 输出格式: 在一行…

记录一下安装ubuntu子系统的pycharm遇到的问题

sudo su #切换为root用户获取管理员权限用于新建用户 adduser username #新建用户&#xff08;例如用户名为username&#xff09; adduser username sudo #将用户添加到 sudo 组同样遇到这个问题&#xff0c;解决方法是&#xff1a;先新建一个用户名&#xff0c;然后再切换到这…

Android 性能优化实例分享-内存优化 兼顾效率与性能

背景 项目上线一段时间后,回顾重要页面 保证更好用户体验及生产效率&#xff0c;做了内存优化和下载导出优化&#xff0c;具体效果如最后的一节的表格所示。 下面针对拍摄流程的两个页面 预览页 导出页优化实例进行介绍&#xff1a; 一.拍摄前预览页面优化 预览效果问题 存在…

Quartus II仿真出现错误

ModelSim executable not found in D:/intelFPGA/18.0/quartus/bin64/modelsim_ase/win32aloem/ Error. 找不到modelsim地址&#xff0c;原来是我下载了.exe,但没有双击启动安装ase文件夹呀&#xff01;&#xff01;&#xff01;&#xff01;晕&#xff0c;服了我自己

Python7:接口自动化学习1 RPC

API&#xff08;Application Programmming Interface&#xff09; 应用编程接口&#xff0c;简称“接口” 接口&#xff1a;程序之间约定的通信方法 特点&#xff1a;约定了调用方法&#xff0c;以及预期的行为&#xff0c;但是不透露具体细节 意义&#xff1a;程序能解耦&…

【No.20】蓝桥杯简单数论下|寻找整数|素数的判断|笨小猴|最大最小公倍数|素数筛|埃氏筛|欧氏线性筛|质数|分解质因子(C++)

寻找整数 【题目描述】 有一个不超过 1 0 1 7 10^17 1017的正整数n&#xff0c;知道这个数除以2至49后的余数如下表所示&#xff0c;求这个正整数最小是多少 解法一&#xff1a;模拟 暴力法&#xff1a;一个个检验 1 … 1 0 17 1\dots 10^{17} 1…1017的每个数 由于这个数n…

【pytest、playwright】构建POM项目,以及解决登录问题,allure环境问题

目录 前言 1、文件目录 2、安装依赖 3、POM项目实战-案例&#xff1a;打开指定页面 目录结构&#xff1a; pages中的代码&#xff1a; cases中的代码&#xff1a; 4、解决登录问题 问题&#xff1a; 解决方案&#xff1a; 获取登录的用户信息&#xff08;cookie&a…

DasViewer电脑客户端打开文件夹时,一直显示崩溃,该怎么解决?

问题如图 如若用的是DasViewer V3.2.4Beta版本&#xff0c;可以换回3.2.1版本进行尝试。 DasViewer是由大势智慧自主研发的免费的实景三维模型浏览器,采用多细节层次模型逐步自适应加载技术,让用户在极低的电脑配置下,也能流畅的加载较大规模实景三维模型,提供方便快捷的数据浏…

发送请求- header配置

请求头里是客户端的要求&#xff0c;把你的诉求告诉服务端&#xff0c;服务端按照你的要求返回数据 &#xff0c; 请求header需要严格全配置&#xff0c;把请求header全部传入&#xff0c;不能频繁访问&#xff0c;让后端知道它是正常请求 一般只配置User-Agent和Content Typ…