论文解读在这里
File path | Description
```/pretrains
┣ 📂 models
┃ ┗ 📜 config.yaml
┃ ┗ 📜 v1-5-pruned.ckpt┣ 📂 generation
┃ ┗ 📜 checkpoint_best.pth ┣ 📂 eeg_pretain
┃ ┗ 📜 checkpoint.pth (pre-trained EEG encoder)/datasets
┣ 📂 imageNet_images (subset of Imagenet)┗ 📜 block_splits_by_image_all.pth
┗ 📜 block_splits_by_image_single.pth
┗ 📜 eeg_5_95_std.pth /code
┣ 📂 sc_mbm
┃ ┗ 📜 mae_for_eeg.py
┃ ┗ 📜 trainer.py
┃ ┗ 📜 utils.py┣ 📂 dc_ldm
┃ ┗ 📜 ldm_for_eeg.py
┃ ┗ 📜 utils.py
┃ ┣ 📂 models
┃ ┃ ┗ (adopted from LDM)
┃ ┣ 📂 modules
┃ ┃ ┗ (adopted from LDM)┗ 📜 stageA1_eeg_pretrain.py (main script for EEG pre-training)
┗ 📜 eeg_ldm.py (main script for fine-tuning stable diffusion)
┗ 📜 gen_eval_eeg.py (main script for generating images)┗ 📜 dataset.py (functions for loading datasets)
┗ 📜 eval_metrics.py (functions for evaluation metrics)
┗ 📜 config.py (configurations for the main scripts)```
目录
dataset.py
gen_eval_eeg.py
stageA1_eeg_pretrain.py
eeg_ldm.py
gen_eval_eeg.py
dataset.py
一、基础工具函数模块
"沿时间轴进行环形填充"是一种信号处理技术,当数据长度不足时,用数据的起始部分循环填充到末尾(类似"循环播放")
对比其他填充方式:
零填充(Zero-pad):
[1,2,3] -> [1,2,3,0,0]
环形填充:
[1,2,3] -> [1,2,3,1,2]
参数解读:
((0,0), (0, pad_size))
:表示只在第二个维度(时间轴)右侧填充
'wrap'
:指定环形填充模式输入:
x.shape = (128, 500)
(128个EEG通道,500个时间点)
patch_size = 16
(每个时间块包含16个时间点)计算需要填充的长度:
当前时间点:500
需要达到
N × patch_size
的最小长度
ceil(500 / 16) = 32
块 →32×16=512
需填充:
512 - 500 = 12
个时间点填充操作:从每个通道的起始位置取前12个时间点,拼接到末尾
为什么选择环形填充?
填充方式 优点 缺点 适用场景 环形填充 保持信号周期性
避免边界突变可能引入周期性假象 EEG/ECG等准周期信号 零填充 实现简单 引入高频噪声 通用场景 镜像填充 平滑边界 计算复杂 图像处理 对于EEG信号:
具有准周期性(alpha/beta波等)
避免零填充导致的频谱泄漏(spectral leakage)
更适合后续的块处理(patch划分)
Z-score标准化(又称标准差标准化)是一种常见的数据标准化方法,其核心是通过线性变换将原始数据转换为均值为0、标准差为1的分布。
对于一组数据 x,其标准化值 z的计算公式为:z=(x−μ)/σ
μ:数据的均值(平均值)
σ:数据的标准差(反映数据离散程度)
二、时间序列处理模块
时间窗口
定义:将连续的EEG信号按固定时长分段处理
目的:
降低计算复杂度
捕捉局部时域特征
匹配后续处理(如傅里叶变换、模型输入长度)
8 / 0.75 ≈ 10.67,0.75秒/帧:该数据集的时间分辨率(每帧持续时间)
三、数据增强模块
四、核心数据集类
1. 预训练数据集
2. 完整EEG-Image数据集
class EEGDataset(Dataset):def __init__(self, eeg_signals_path):loaded = torch.load(eeg_signals_path) # 加载预处理数据self.data = [{'eeg': tensor, # EEG信号 [通道, 时间]'label': int, # 类别标签 'image': 'n01440764' # ImageNet ID}, ...]def __getitem__(self, i):# EEG处理eeg = data[i]['eeg'].t() # 转置为[时间, 通道]eeg = eeg[20:460] # 选择有效时间窗口eeg = interp1d(...) # 插值到512点# 图像处理image_path = 'n01440764/n01440764_10026.JPEG'image = Image.open(path)image = processor(image) # CLIP预处理
五、数据划分模块
class Splitter:def __init__(self, dataset, split_path):loaded = torch.load(split_path)self.split_idx = loaded['splits'][0]['train'] # 取第一个划分方案# 过滤条件:# 1. EEG长度在450-600之间# 2. 被试匹配(当subject!=0时)
六、图像处理模块
class random_crop:def __call__(self, img):if 概率p: 执行随机裁剪else: 返回原图def normalize2(img):return img * 2.0 - 1.0 # 归一化到[-1,1]
七、重要技术细节
对齐流程:
sequenceDiagramparticipant EEG_Dataparticipant ImageNetEEG_Data->>EEGDataset: 加载样本iEEGDataset->>EEG_Data: 读取self.data[i]["image"]字段EEGDataset->>ImageNet: 根据ID构造路径ImageNet-->>EEGDataset: 返回对应图像EEGDataset->>Model: 返回{'eeg':eeg, 'image':image}
gen_eval_eeg.py
基于MAE (Masked Autoencoder) 的EEG信号预训练框架,主要包含以下核心模块:
-
环境配置与工具函数
-
数据加载与预处理
-
模型定义与训练流程
-
可视化与日志记录
-
分布式训练支持
1. 核心模块解析
2. 关键实现细节
4. 可视化模块
代码流程图
graph TDA[初始化配置] --> B[加载数据集]B --> C[构建MAE模型]C --> D[初始化优化器]D --> E[训练循环]E --> F{达到保存点?}F -- 是 --> G[保存模型+可视化]F -- 否 --> EG --> H[完成训练]
stageA1_eeg_pretrain.py
Pre-training on EEG data
用于大量训练的数据集从MOABB上下载,还没学会,,,,
eeg_ldm.py
Finetune the Stable Diffusion with Pre-trained EEG Encoder
实现了一个基于Latent Diffusion Model (LDM) 的EEG信号到图像生成的完整流程:
一、代码整体架构
本代码是DreamDiffusion项目的第二阶段(Stage B),主要包含以下核心模块:
-
配置管理(Config_Generative_Model)
-
数据加载与预处理(create_EEG_dataset)
-
生成模型定义(eLDM)
-
训练流程控制(main函数)
-
图像生成与评估(generate_images)
-
实验日志记录(wandb集成)
二、核心组件详解
1. 配置管理
class Config_Generative_Model:def __init__(self):# 项目参数self.seed = 2022self.root_path = '.'self.eeg_signals_path = 'datasets/eeg_5_95_std.pth'# 模型参数self.pretrain_mbm_path = 'pretrains/generation/checkpoint.pth'self.pretrain_gm_path = 'pretrains/stable-diffusion-v1-5'# 训练参数self.batch_size = 25self.lr = 5.3e-5self.num_epoch = 500
2. 数据加载
-
加载EEG信号和对应的ImageNet图像路径
-
应用两种图像变换:
-
训练集:随机裁剪+归一化(
img_transform_train
) -
测试集:仅归一化(
img_transform_test
)
-
-
返回包含EEG-图像对的数据集
3. 生成模型(eLDM)
-
双条件机制:同时接受EEG特征和CLIP文本特征
-
基于Latent Diffusion架构
-
支持从检查点恢复训练
5. 图像生成与评估
def generate_images(generative_model, dataset, num_samples, ddim_steps):grid, samples = generative_model.generate(dataset, num_samples, ddim_steps)# 保存图像网格Image.fromarray(grid).save('samples.png')# 计算评估指标metrics = get_eval_metric(samples)return metrics
评估指标:
-
像素级:MSE, PCC, SSIM
-
语义级:Top-1分类准确率
三、关键技术细节
1. 条件扩散模型
graph LRA[EEG信号] --> B[EEG编码器]C[CLIP文本编码] --> D[LDM UNet]B --> DD --> E[图像生成]
2. 双阶段训练策略
-
阶段A:预训练EEG编码器(MAE架构)
-
阶段B:微调扩散模型(本代码)
3. 图像变换流水线
img_transform_train = transforms.Compose([normalize, # 归一化到[-1,1]transforms.Resize(512), # 调整大小random_crop(448, p=0.5), # 随机裁剪(数据增强)transforms.Resize(512), # 再次调整channel_last # 通道顺序转换
])
gen_eval_eeg.py
Generating Images with Trained Checkpoints
实现了EEG信号到图像生成的评估流程:
一、代码整体架构
这段代码是DreamDiffusion项目的评估部分,主要功能是加载预训练好的生成模型,对EEG信号进行图像生成并保存结果。核心模块包括:
-
配置加载:从检查点恢复实验配置
-
数据准备:加载EEG测试数据集
-
模型初始化:构建条件扩散模型(eLDM)
-
图像生成:使用训练好的模型生成图像
-
结果保存:存储生成的图像网格
二、核心组件详解
图像变换流程:
img_transform_test = transforms.Compose([normalize, # 归一化到[-1,1]transforms.Resize((512,512)), # 调整尺寸channel_last # 通道顺序转换 (C,H,W)->(H,W,C)
])
-
数据规格:
-
输入EEG形状:
(num_samples, 128通道, 512时间点)
-
输出图像尺寸:512×512
-
3. 模型初始化
generative_model = eLDM(pretrain_mbm_metafile, # EEG编码器配置num_voxels, # 输入维度=EEG特征长度device=device, # 计算设备pretrain_root=config.pretrain_gm_path, # SD权重路径ddim_steps=config.ddim_steps # 扩散步数(默认250)
)
generative_model.model.load_state_dict(sd['model_state_dict']) # 加载训练权重
模型架构特点:
-
双条件机制:EEG特征 + CLIP文本特征
-
基于Latent Diffusion架构
-
使用DDIM采样方法
4. 图像生成
# 生成训练集样本(10个实例)
grid, _ = generative_model.generate(dataset_train, num_samples=config.num_samples,ddim_steps=config.ddim_steps,HW=config.HW, # 图像尺寸limit=10
)# 生成测试集样本
grid, samples = generative_model.generate(dataset_test,num_samples=config.num_samples,ddim_steps=config.ddim_steps,state=sd['state'] # 随机状态恢复
)
生成参数:
参数 | 含义 | 典型值 |
---|---|---|
num_samples | 每样本生成数量 | 5 |
ddim_steps | 扩散采样步数 | 250 |
HW | 图像高宽 | [512,512] |
limit | 最大生成样本数 | 10 |
三、关键技术细节
1. 条件生成流程
sequenceDiagramparticipant EEGparticipant Modelparticipant ImageEEG->>Model: 输入EEG信号(128ch×512t)Model->>Model: 通过EEG编码器提取特征Model->>Model: 扩散模型条件生成Model->>Image: 输出512×512图像
这个生成代码很有问题啊,一直报错,类似这样,很多人都出现了,但目前无法解决,,,,