DreamDiffusion代码学习及复现

论文解读在这里

File path | Description
```/pretrains
┣ 📂 models
┃   ┗ 📜 config.yaml
┃   ┗ 📜 v1-5-pruned.ckpt┣ 📂 generation  
┃   ┗ 📜 checkpoint_best.pth ┣ 📂 eeg_pretain
┃   ┗ 📜 checkpoint.pth  (pre-trained EEG encoder)/datasets
┣ 📂 imageNet_images (subset of Imagenet)┗  📜 block_splits_by_image_all.pth
┗  📜 block_splits_by_image_single.pth 
┗  📜 eeg_5_95_std.pth  /code
┣ 📂 sc_mbm
┃   ┗ 📜 mae_for_eeg.py
┃   ┗ 📜 trainer.py
┃   ┗ 📜 utils.py┣ 📂 dc_ldm
┃   ┗ 📜 ldm_for_eeg.py
┃   ┗ 📜 utils.py
┃   ┣ 📂 models
┃   ┃   ┗ (adopted from LDM)
┃   ┣ 📂 modules
┃   ┃   ┗ (adopted from LDM)┗  📜 stageA1_eeg_pretrain.py   (main script for EEG pre-training)
┗  📜 eeg_ldm.py    (main script for fine-tuning stable diffusion)
┗  📜 gen_eval_eeg.py               (main script for generating images)┗  📜 dataset.py                (functions for loading datasets)
┗  📜 eval_metrics.py           (functions for evaluation metrics)
┗  📜 config.py                 (configurations for the main scripts)```

目录

dataset.py

gen_eval_eeg.py

stageA1_eeg_pretrain.py

eeg_ldm.py

gen_eval_eeg.py


dataset.py

一、基础工具函数模块

"沿时间轴进行环形填充"是一种信号处理技术,当数据长度不足时,用数据的起始部分循环填充到末尾(类似"循环播放")

  • 对比其他填充方式

    • 零填充(Zero-pad):[1,2,3] -> [1,2,3,0,0]

    • 环形填充:[1,2,3] -> [1,2,3,1,2]

  • 参数解读

    • ((0,0), (0, pad_size)):表示只在第二个维度(时间轴)右侧填充

    • 'wrap':指定环形填充模式

  • 输入

    • x.shape = (128, 500)(128个EEG通道,500个时间点)

    • patch_size = 16(每个时间块包含16个时间点)

  • 计算需要填充的长度

    • 当前时间点:500

    • 需要达到 N × patch_size 的最小长度

    • ceil(500 / 16) = 32 块 → 32×16=512

    • 需填充:512 - 500 = 12 个时间点

  • 填充操作:从每个通道的起始位置取前12个时间点,拼接到末尾

为什么选择环形填充?

填充方式优点缺点适用场景
环形填充保持信号周期性
避免边界突变
可能引入周期性假象EEG/ECG等准周期信号
零填充实现简单引入高频噪声通用场景
镜像填充平滑边界计算复杂图像处理

对于EEG信号:

  • 具有准周期性(alpha/beta波等)

  • 避免零填充导致的频谱泄漏(spectral leakage)

  • 更适合后续的块处理(patch划分)

Z-score标准化(又称标准差标准化)是一种常见的数据标准化方法,其核心是通过线性变换将原始数据转换为均值为0、标准差为1的分布。

对于一组数据 x,其标准化值 z的计算公式为:z=(x−μ)/σ

  • μ:数据的均值(平均值)

  • σ:数据的标准差(反映数据离散程度)

二、时间序列处理模块

 时间窗口

  • 定义:将连续的EEG信号按固定时长分段处理

  • 目的

    • 降低计算复杂度

    • 捕捉局部时域特征

    • 匹配后续处理(如傅里叶变换、模型输入长度)

  • 8 / 0.75 ≈ 10.67,0.75秒/帧:该数据集的时间分辨率(每帧持续时间)

三、数据增强模块

四、核心数据集类

1. 预训练数据集

2. 完整EEG-Image数据集
class EEGDataset(Dataset):def __init__(self, eeg_signals_path):loaded = torch.load(eeg_signals_path)  # 加载预处理数据self.data = [{'eeg': tensor,       # EEG信号 [通道, 时间]'label': int,        # 类别标签 'image': 'n01440764' # ImageNet ID}, ...]def __getitem__(self, i):# EEG处理eeg = data[i]['eeg'].t()     # 转置为[时间, 通道]eeg = eeg[20:460]            # 选择有效时间窗口eeg = interp1d(...)          # 插值到512点# 图像处理image_path = 'n01440764/n01440764_10026.JPEG'image = Image.open(path)image = processor(image)     # CLIP预处理

五、数据划分模块

class Splitter:def __init__(self, dataset, split_path):loaded = torch.load(split_path)self.split_idx = loaded['splits'][0]['train']  # 取第一个划分方案# 过滤条件:# 1. EEG长度在450-600之间# 2. 被试匹配(当subject!=0时)

六、图像处理模块

class random_crop:def __call__(self, img):if 概率p: 执行随机裁剪else: 返回原图def normalize2(img):return img * 2.0 - 1.0  # 归一化到[-1,1]

七、重要技术细节

对齐流程:

sequenceDiagramparticipant EEG_Dataparticipant ImageNetEEG_Data->>EEGDataset: 加载样本iEEGDataset->>EEG_Data: 读取self.data[i]["image"]字段EEGDataset->>ImageNet: 根据ID构造路径ImageNet-->>EEGDataset: 返回对应图像EEGDataset->>Model: 返回{'eeg':eeg, 'image':image}

gen_eval_eeg.py

基于MAE (Masked Autoencoder) 的EEG信号预训练框架,主要包含以下核心模块:

  1. 环境配置与工具函数

  2. 数据加载与预处理

  3. 模型定义与训练流程

  4. 可视化与日志记录

  5. 分布式训练支持

1. 核心模块解析

2. 关键实现细节

4. 可视化模块

代码流程图

graph TDA[初始化配置] --> B[加载数据集]B --> C[构建MAE模型]C --> D[初始化优化器]D --> E[训练循环]E --> F{达到保存点?}F -- 是 --> G[保存模型+可视化]F -- 否 --> EG --> H[完成训练]

stageA1_eeg_pretrain.py

Pre-training on EEG data

用于大量训练的数据集从MOABB上下载,还没学会,,,,

eeg_ldm.py

Finetune the Stable Diffusion with Pre-trained EEG Encoder

实现了一个基于Latent Diffusion Model (LDM) 的EEG信号到图像生成的完整流程:


一、代码整体架构

本代码是DreamDiffusion项目的第二阶段(Stage B),主要包含以下核心模块:

  1. 配置管理(Config_Generative_Model)

  2. 数据加载与预处理(create_EEG_dataset)

  3. 生成模型定义(eLDM)

  4. 训练流程控制(main函数)

  5. 图像生成与评估(generate_images)

  6. 实验日志记录(wandb集成)

二、核心组件详解

1. 配置管理
class Config_Generative_Model:def __init__(self):# 项目参数self.seed = 2022self.root_path = '.'self.eeg_signals_path = 'datasets/eeg_5_95_std.pth'# 模型参数self.pretrain_mbm_path = 'pretrains/generation/checkpoint.pth'self.pretrain_gm_path = 'pretrains/stable-diffusion-v1-5'# 训练参数self.batch_size = 25self.lr = 5.3e-5self.num_epoch = 500
2. 数据加载
  1. 加载EEG信号和对应的ImageNet图像路径

  2. 应用两种图像变换:

    • 训练集:随机裁剪+归一化(img_transform_train

    • 测试集:仅归一化(img_transform_test

  3. 返回包含EEG-图像对的数据集

3. 生成模型(eLDM)
  • 双条件机制:同时接受EEG特征和CLIP文本特征

  • 基于Latent Diffusion架构

  • 支持从检查点恢复训练

5. 图像生成与评估
def generate_images(generative_model, dataset, num_samples, ddim_steps):grid, samples = generative_model.generate(dataset, num_samples, ddim_steps)# 保存图像网格Image.fromarray(grid).save('samples.png')# 计算评估指标metrics = get_eval_metric(samples)return metrics

评估指标

  • 像素级:MSE, PCC, SSIM

  • 语义级:Top-1分类准确率

三、关键技术细节

1. 条件扩散模型
graph LRA[EEG信号] --> B[EEG编码器]C[CLIP文本编码] --> D[LDM UNet]B --> DD --> E[图像生成]
2. 双阶段训练策略
  1. 阶段A:预训练EEG编码器(MAE架构)

  2. 阶段B:微调扩散模型(本代码)

3. 图像变换流水线
img_transform_train = transforms.Compose([normalize,                     # 归一化到[-1,1]transforms.Resize(512),        # 调整大小random_crop(448, p=0.5),       # 随机裁剪(数据增强)transforms.Resize(512),        # 再次调整channel_last                   # 通道顺序转换
])

gen_eval_eeg.py

Generating Images with Trained Checkpoints

实现了EEG信号到图像生成的评估流程:

一、代码整体架构

这段代码是DreamDiffusion项目的评估部分,主要功能是加载预训练好的生成模型,对EEG信号进行图像生成并保存结果。核心模块包括:

  1. 配置加载:从检查点恢复实验配置

  2. 数据准备:加载EEG测试数据集

  3. 模型初始化:构建条件扩散模型(eLDM)

  4. 图像生成:使用训练好的模型生成图像

  5. 结果保存:存储生成的图像网格


二、核心组件详解

图像变换流程
img_transform_test = transforms.Compose([normalize,                  # 归一化到[-1,1]transforms.Resize((512,512)), # 调整尺寸channel_last                # 通道顺序转换 (C,H,W)->(H,W,C)
])
  • 数据规格

    • 输入EEG形状:(num_samples, 128通道, 512时间点)

    • 输出图像尺寸:512×512

3. 模型初始化
generative_model = eLDM(pretrain_mbm_metafile,   # EEG编码器配置num_voxels,              # 输入维度=EEG特征长度device=device,           # 计算设备pretrain_root=config.pretrain_gm_path,  # SD权重路径ddim_steps=config.ddim_steps  # 扩散步数(默认250)
)
generative_model.model.load_state_dict(sd['model_state_dict'])  # 加载训练权重

模型架构特点

  • 双条件机制:EEG特征 + CLIP文本特征

  • 基于Latent Diffusion架构

  • 使用DDIM采样方法

4. 图像生成
# 生成训练集样本(10个实例)
grid, _ = generative_model.generate(dataset_train, num_samples=config.num_samples,ddim_steps=config.ddim_steps,HW=config.HW,  # 图像尺寸limit=10
)# 生成测试集样本
grid, samples = generative_model.generate(dataset_test,num_samples=config.num_samples,ddim_steps=config.ddim_steps,state=sd['state']  # 随机状态恢复
)

生成参数

参数含义典型值
num_samples每样本生成数量5
ddim_steps扩散采样步数250
HW图像高宽[512,512]
limit最大生成样本数10

三、关键技术细节

1. 条件生成流程
sequenceDiagramparticipant EEGparticipant Modelparticipant ImageEEG->>Model: 输入EEG信号(128ch×512t)Model->>Model: 通过EEG编码器提取特征Model->>Model: 扩散模型条件生成Model->>Image: 输出512×512图像

这个生成代码很有问题啊,一直报错,类似这样,很多人都出现了,但目前无法解决,,,,

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/75508.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

用Python实现TCP代理

依旧是Python黑帽子这本书 先附上代码,我在原书代码上加了注释,更好理解 import sys import socket import threading#生成可打印字符映射 HEX_FILTER.join([(len(repr(chr(i)))3) and chr(i) or . for i in range(256)])#接收bytes或string类型的输入…

Pyinstaller 打包flask_socketio为exe程序后出现:ValueError: Invalid async_mode specified

Pyinstaller 打包flask_socketio为exe程序后出现&#xff1a;ValueError: Invalid async_mode specified 一、详细描述问题描述 Traceback (most recent call last): File "app_3.py", line 22, in <module> File "flask_socketio\__init__.py"…

django REST framework(DRF)教程

Django DRF API Django 基本使用Django DRF序列化器Django DRF视图Django DRF常用功能Django 基本使用 前后端分离开发模式认识RestFulAPI回顾Django开发模式Django REST Framework初探前后端分离开发模式 前后端分离前:前端页面看到的效果都是由后端控制,即后端渲染HTML页面…

【Linux】Orin NX + Ubuntu22.04配置国内源

1、获取源 清华源 arm 系统的源,可以在如下地址获取到 https://mirror.tuna.tsinghua.edu.cn/help/ubuntu-ports/ 选择HTTPS,否则可能报错: 明文签署文件不可用,结果为‘NOSPLIT’(您的网络需要认证吗?)查看Orin NX系统版本 选择jammy的源 2、更新源 1)备份原配…

【含文档+PPT+源码】基于微信小程序的社交摄影约拍平台的设计与实现

项目介绍 本课程演示的是一款基于微信小程序的社交摄影约拍平台的设计与实现&#xff0c;主要针对计算机相关专业的正在做毕设的学生与需要项目实战练习的 Java 学习者。 1.包含&#xff1a;项目源码、项目文档、数据库脚本、软件工具等所有资料 2.带你从零开始部署运行本套系…

JDBC常用的接口

一、什么是JDBC JDBC是Java语言连接数据库的接口规范。 二、JDBC的体系 1、Java官方提供一个操作数据库的抽象接口 抽象接口有很多的接口和抽象类。 例如&#xff1a;Driver、Connection、Statement。 2、各个数据库厂商提供各自的Java实现类 需要各自实现具体的细节。 例如&am…

容器适配器-stack栈

C标准库不只是包含了顺序容器&#xff0c;还包含一些为满足特殊需求而设计的容器&#xff0c;它们提供简单的接口。 这些容器可被归类为容器适配器(container adapter)&#xff0c;它们是改造别的标准顺序容器&#xff0c;使之满足特殊需求的新容器。 适配器:也称配置器,把一…

[250403] HuggingFace 新增检查模型与电脑兼容性的功能 | Firefox 发布137.0 支持标签组

目录 Hugging Face 让寻找兼容的 AI 模型变得更容易Firefox 137 版本更新摘要 Hugging Face 让寻找兼容的 AI 模型变得更容易 Hugging Face 是一个流行的在线平台&#xff0c;用于访问开源人工智能 (AI) 工具和模型。该平台推出了一项有用的新功能&#xff0c;允许个人轻松检查…

.NET 创建MCP使用大模型对话二:调用远程MCP服务

在上一篇文章.NET 创建MCP使用大模型对话-CSDN博客中&#xff0c;我们简述了如何使用mcp client使用StdIo模式调用本地mcp server。本次实例将会展示如何使用mcp client模式调用远程mcp server。 一&#xff1a;创建mcp server 我们创建一个天气服务。 新建WebApi项目&#x…

Redis 中 Set(例如标签) 和 ZSet(例如排行榜) 的详细对比,涵盖定义、特性、命令、适用场景及总结表格

以下是 Redis 中 Set 和 ZSet 的详细对比&#xff0c;涵盖定义、特性、命令、适用场景及总结表格&#xff1a; 1. 核心定义 数据类型SetZSet&#xff08;Sorted Set&#xff09;定义无序的、唯一的字符串集合&#xff0c;元素不重复。有序的、唯一的字符串集合&#xff0c;每个…

解决Spring参数解析异常:Name for argument of type XXX not specified

前言 在开发 Spring Boot 应用时&#xff0c;我们常遇到类似 java.lang.IllegalArgumentException: Name for argument not specified 的报错。这类问题通常与方法参数名称的解析机制相关&#xff0c;尤其在使用 RequestParam、PathVariable 等注解时更为常见。 一、问题现象与…

刚刚,OpenAI开源PaperBench,重塑顶级AI Agent评测

今天凌晨1点&#xff0c;OpenAI开源了一个全新的AI Agent评测基准——PaperBench。 这个基准主要考核智能体的搜索、整合、执行等能力&#xff0c;需要对2024年国际机器学习大会上顶尖论文的复现&#xff0c;包括对论文内容的理解、代码编写以及实验执行等方面的能力。 根据O…

Golang封装Consul 服务发现库

以下是一个经过生产验证的 Consul 服务发现封装库,支持注册/注销、健康检查、智能发现等核心功能,可直接集成到项目中: package consulimport ("context""fmt""log""math/rand""net""os""sync"&quo…

自适应信号处理任务(过滤,预测,重建,分类)

自适应滤波 # signals creation: u, v, d N = 5000 n = 10 u = np.sin(np.arange(0, N/10., N/50000

PyTorch深度学习框架 的基础知识

目录 1.pyTorch检查是否安装成功 2.PyTorch的张量tensor 基础创建方式&#xff08;三种&#xff09; 2.2用列表创建tensor 2.2使用元组创建 tensor 2.3使用ndarray创建创建 tensor 2.4 快速创建tensor的常用方法 3.pyTorch中的张量tensor的常用属性 4. tensor中的基础数据…

MySQL学习集--DDL

DDL 数据库操作 查询所有数据库 SHOW DATABASES;查询当前数据库 SELECT DATABASE();创建 CREATE DATABASE[IF NOT EXISTS]数据库名[DEFAULT CHARSET 字符集][COLLATE 排序规则];删除 DROR DATABASE[IF EXISTS]数据库名;使用 USE 数据库名;表操作 创建表格 CREATE TABL…

Vue 3 中按照某个字段将数组分成多个数组

方法一&#xff1a;使用 reduce 方法 const originalArray [{ id: 1, category: A, name: Item 1 },{ id: 2, category: B, name: Item 2 },{ id: 3, category: A, name: Item 3 },{ id: 4, category: C, name: Item 4 },{ id: 5, category: B, name: Item 5 }, ];const grou…

LeetCode刷题 -- 48. 旋转图像

题目 算法题解&#xff1a;顺时针旋转矩阵&#xff08;90度&#xff09; 1. 算法描述 给定一个 n n 的二维矩阵&#xff0c;请将矩阵顺时针旋转 90 度。 例如&#xff1a; 输入&#xff1a; [[1,2,3],[4,5,6],[7,8,9] ]输出&#xff1a; [[7,4,1],[8,5,2],[9,6,3] ]2. 思…

Vulkan进阶系列1 - Vulkan应用程序结构(完整代码)

一: 概述 在前面的20多篇文章中,我们了解了Vulkan的基础知识,和相关API的使用,接下来我们要从零开始写一套完整Vulkan应用程序,在这个过程中加深对Vulkan中的各种概念的理解。 Vulkan 应用程序一般遵循 初始化 -> 运行循环 -> 资源清理 的结构,本实例也基本遵循了…

VTK的两种显示刷新方式

在类中先声明vtk的显示对象 vtkRenderer out_render; vtkVertexGlyphFilter glyphFilter; vtkPolyDataMapper mapper; // 新建制图器 vtkActor actor; // 新建角色 然后在init中先初始化一下&#xff1a; out_rend…