随机微分方程的分数扩散模型 (score-based diffusion model) 代码示例

随机微分方程的分数扩散模型(Score-Based Generative Modeling through Stochastic Differential Equations)

基于分数的扩散模型,是估计数据分布梯度的方法,可以在不需要对抗训练的基础上,生成与GAN一样高质量的图片。来源于文章:Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. "Score-Based Generative Modeling through Stochastic Differential Equations." Internation Conference on Learning Representations, 2021

score-based diffusion是diffusion模型大火之后,又一个里程碑式的工作,将扩散模型和分数生成模型进行了统一。原始的扩散模型也有缺点,它的采样速度慢,通常需要数千个评估步骤才能抽取一个样本。而 score-based 的扩散模型可以在较短的时间内完成采样。

网络上有很多关于score-based diffusion原理介绍,应用案例等,还有文章解读,大家可以参考。但是,提供代码简介的很少,为此这里提供了score-based diffusion 模型的简单的可运行的代码示例。

1. 定义time-dependent score-based模型

导入相关模块

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as npimport torch
import functools
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
import tqdm

1.1 将时间t嵌入的投影层

其实并没有投影层的说法,这里是为了描述将时间t (time step),随机初始化采样权重,然后使用[sin(2πωt);cos(2πωt)]生成相应的高斯随机特征向量的过程。注意,里面的参数是不可训练的。

class GaussianFourierProjection(nn.Module):"""Gaussian random features for encoding time steps."""  def __init__(self, embed_dim, scale=30.):super().__init__()# 在初始化期间随机采样权重。 这些权重是固定的 # 在优化期间并且不可训练self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)def forward(self, x):x_proj = x[:, None] * self.W[None, :] * 2 * np.pireturn torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

将时间t嵌入的投影层的出现,是因为score-based的扩散模型和正常的扩散模型的训练过程不一样。score-based的扩散模型在训练过程中,神经网络接受带有随机噪音的 x ,然后随机的时间信息 t 添加x中,然后利用x 和 t 作为输入,计算模型损失。

维度转换全连接层:

class Dense(nn.Module):"""A fully connected layer that reshapes outputs to feature maps."""def __init__(self, input_dim, output_dim):super().__init__()self.dense = nn.Linear(input_dim, output_dim)def forward(self, x):return self.dense(x)[..., None, None]

1.2 时间依赖基于分数的Unet模型

(time-dependent score-based model) 时间依赖,打分相关的Unet模型,froward函数中,输入除了x,还有时间t. 时间t经过GaussianFourierProjection嵌入后融合到模型中,然后输出marginal_prob_std正则化的结果。

class ScoreNet(nn.Module):"""初始化一个依赖时间的基于分数的Unet网络."""def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256):""".Args:marginal_prob_std: 输入时间 t 并给出扰动核的标准差的函数 p_{0t}(x(t) | x(0)).channels: 各分辨率特征图的通道数.embed_dim: 高斯随机特征嵌入的维数,与1.1中GaussianFourierProjection相同."""super().__init__()# 时间t的高斯随机特征嵌入层self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),nn.Linear(embed_dim, embed_dim))# Encoding layers where the resolution decreasesself.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)self.dense1 = Dense(embed_dim, channels[0])self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)self.dense2 = Dense(embed_dim, channels[1])self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)self.dense3 = Dense(embed_dim, channels[2])self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)self.dense4 = Dense(embed_dim, channels[3])self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])    # 分辨率增加的解码层self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)self.dense5 = Dense(embed_dim, channels[2])self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])self.tconv3 = nn.ConvTranspose2d(channels[2] + channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)    self.dense6 = Dense(embed_dim, channels[1])self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])self.tconv2 = nn.ConvTranspose2d(channels[1] + channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)    self.dense7 = Dense(embed_dim, channels[0])self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)# Swish 激活函数self.act = lambda x: x * torch.sigmoid(x)self.marginal_prob_std = marginal_prob_stddef forward(self, x, t): # 0   embed = self.act(self.embed(t))    # Encoding pathh1 = self.conv1(x)    ## 合并来自 t 的信息h1 += self.dense1(embed)## 组标准化h1 = self.gnorm1(h1)h1 = self.act(h1)h2 = self.conv2(h1)h2 += self.dense2(embed)h2 = self.gnorm2(h2)h2 = self.act(h2)h3 = self.conv3(h2)h3 += self.dense3(embed)h3 = self.gnorm3(h3)h3 = self.act(h3)h4 = self.conv4(h3)h4 += self.dense4(embed)h4 = self.gnorm4(h4)h4 = self.act(h4)# Decoding pathh = self.tconv4(h4)## 从编码路径跳过连接h += self.dense5(embed)h = self.tgnorm4(h)h = self.act(h)h = self.tconv3(torch.cat([h, h3], dim=1))h += self.dense6(embed)h = self.tgnorm3(h)h = self.act(h)h = self.tconv2(torch.cat([h, h2], dim=1))h += self.dense7(embed)h = self.tgnorm2(h)h = self.act(h)h = self.tconv1(torch.cat([h, h1], dim=1))# Normalize output 正则化输出h = h / self.marginal_prob_std(t)[:, None, None, None]return h

2. 设置SDE

SDE用于将P_0扰动到P_T, 其中,包含两个重要函数:之前提到的marginal_prob_std和扩散系数diffusion_coeff marginal_prob_std,计算 p_{0t}(x(t) | x(0)) 的平均值和标准差; diffusion_coeff,计算SDE的扩散系数.

device = 'cuda' #@param ['cuda', 'cpu'] {'type':'string'}def marginal_prob_std(t, sigma):"""计算p_{0t}(x(t) | x(0))的平均值和标准差.Args:    t: A vector of time steps.sigma: The $\sigma$ in our SDE.  Returns:标准差."""    t = torch.tensor(t, device=device)return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))def diffusion_coeff(t, sigma):"""计算SDE的扩散系数.Args:t: A vector of time steps.sigma: The $\sigma$ in our SDE.Returns:扩散系数向量."""return torch.tensor(sigma**t, device=device)sigma =  25.0 #@param {'type':'number'}
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/130996.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Cube MX 开发高精度电流源跳坑过程/SPI连接ADS1255/1256系列问题总结/STM32 硬件SPI开发过程

文章目录 概要整体架构流程技术名词解释技术细节小结 概要 1.使用STM32F系列开发一款高精度恒流电源,用到了24位高精度采样芯片ADS1255/ADS1256系列。 2.使用时发现很多的坑,详细介绍了每个坑的具体情况和实际的解决办法。 坑1:波特率设置…

如何使用Ruby 多线程爬取数据

现在比较主流的爬虫应该是用python,之前也写了很多关于python的文章。今天在这里我们主要说说ruby。我觉得ruby也是ok的,我试试看写了一个爬虫的小程序,并作出相应的解析。 Ruby中实现网页抓取,一般用的是mechanize,使…

Pytorch从零开始实战08

Pytorch从零开始实战——YOLOv5-C3模块实现 本系列来源于365天深度学习训练营 原作者K同学 文章目录 Pytorch从零开始实战——YOLOv5-C3模块实现环境准备数据集模型选择开始训练可视化模型预测总结 环境准备 本文基于Jupyter notebook,使用Python3.8&#xff0c…

webJS基础-----制作一个时间倒计时

1,可以使用以下两个方式制作 方式1:setTimeout ()定时器是在指定的时间后执行某些代码,代码执行一次就会自动停止; 方式2:setInterval ()定时器是按照指定的周期来重复执行某些代码,该定时器不会自动停止…

DL Homework 6

目录 一、概念 (1)卷积 (2)卷积核 (3)特征图 (4)特征选择 (5)步长 (6)填充 (7)感受野 二、探究不同卷…

【开题报告】基于uniapp的在线考试小程序的设计与实现

1.研究背景 随着社会的发展和科技的进步,网络技术被广泛应用于教育领域。在线教育已成为当今发展趋势之一,其中在线考试更是具有重要的意义。传统的考试方式不仅耗费大量人力物力,而且存在考试成果的保密问题。而在线考试可以使考试过程更加…

JVM运行时数据区-堆

目录 一、堆的核心概述 (一)概述 (二)堆空间细分 (三)jvisualvm工具 二、设置堆内存的大小与OOM 三、年轻代与老年代 四、图解对象分配一般过程 五、对象分配特殊过程 六、常用调优工具 七、Mino…

手搓一个ubuntu自动安装python3.9的sh脚本

#!/bin/bash# Step 1: 更新系统软件包 sudo apt update sudo apt upgrade -y sudo apt install -y software-properties-common# Step 2: 安装Python 3.9的依赖项 sudo apt install -y build-essential zlib1g-dev libncurses5-dev libgdbm-dev libnss3-dev libssl-dev libread…

leetCode 416.分割等和子集 + 01背包 + 动态规划 + 记忆化搜索 + 递推 + 空间优化

关于此题我的往期文章: LeetCode 416.分割等和子集(动态规划【0-1背包问题】采用一维数组dp:滚动数组)_呵呵哒( ̄▽ ̄)"的博客-CSDN博客https://heheda.blog.csdn.net/article/details/133212716看本期文章时&…

使用udevdm查询蓝牙模块的信息

1.首先查询蓝牙设备在系统中的设备路径 udevadm info --querypath -n /dev/ttyS1 2.查询蓝牙设备的所有信息包括父设备信息 EMUELEC:~ # udevadm info -ap /devices/platform/ffd24000.serial/tty/ttyS1 备注:查询设备所有信息 udevadm info --queryall -n /dev…

关于JADX和JEB的小问题

关于JADX和JEB的小问题 很久没水过技术文啦,最近也刚好遇到点小问题,特此记录 第一个问题 在处理app加密逻辑的时候一直拿不到正确的密文,反复看了反编译出来的代码(如下图) public static string n(String str, Stri…

基础课22——云服务(SaaS、Pass、laas、AIaas)

1.云服务概念和类型 云服务是一种基于互联网的计算模式,通过云计算技术将计算、存储、网络等资源以服务的形式提供给用户,用户可以通过网络按需使用这些资源,无需购买、安装和维护硬件设备。云服务具有灵活扩展、按需使用、随时随地访问等优…

linux 查看当前目录下每个文件夹大小

要在 Linux 中查看当前目录下每个文件夹的大小,可以使用 du 命令(磁盘使用情况)结合其他一些选项。下面是几个常用的命令示例: 显示当前目录下每个文件夹的大小——只显示一层文件夹: du -h --max-depth1该命令会以人…

2023年内衣行业分析:京东大数据平台-服饰内衣市场解析

如今,女性消费力的提升正在推动国内女性内衣市场份额逐年提升。而今年,内衣市场更是进入了存量之战,增长趋势明显减弱。 根据鲸参谋数据显示,今年1月至9月,京东平台内衣(文胸)累计销量约500万件…

【数智化案例展】某国际高端酒店品牌——呼叫中心培训数智化转型项目

‍ 维音案例 本项目案例由维音投递并参与数据猿与上海大数据联盟联合推出的《2023中国数智化转型升级创新服务企业》榜单/奖项”评选。 大数据产业创新服务媒体 ——聚焦数据 改变商业 培训是呼叫中心管理的重要环节,由于员工流动性强、培训需求多样、考核流程繁琐…

[Emuelec]独立模拟器自动映射手柄按键脚本研究

在Emuelec中,对独立模拟器配置手柄按键是个非常头疼的事,难点在于emuelec的按钮配置映射到模拟器所需的按钮配置,更头疼的是,每个模拟器所需的配置都不相同,此时就需要花大把时间了解每个模拟器的配置上。好在&#xf…

2003 - Can‘t connect to MysQL server on ‘39.108.169.0‘ (10060 “Unknown error“)

问题描述 某天和往常一样启动java项目,发现数据库出问题了,然后打开navicat,发现数据库的链接都连接不上, 一点击就会弹出报错框: 然后就各种上网搜索。 解决方案 上网查了一些解决方案,大部分都是说看…

hivesql,sql 函数总结:

1、NVL函数与Coalesce差异 -- select nvl(null,8); -- 结果是 8 -- select nvl(,7); -- 结果是"" -- select coalesce(null,null,9); -- 结果是 9 -- select coalesce("",null,9); -- 结果是 "" 1.2、 NVL函数与Coalesce差异 …

DB-GPT介绍

DB-GPT介绍 引言DB-GPT项目简介DB-GPT架构关键特性私域问答&数据处理多数据源&可视化自动化微调Multi-Agents&Plugins多模型支持与管理隐私安全支持数据源 子模块DB-GPT-Hub微调参考文献 引言 随着数据量的不断增长和数据分析的需求日益增多,将自然语言…

Technology strategy Pattern 学习笔记4 - Creating the Strategy-Corporate Context

Creating the Strategy-Corporate Context 1 •. Stakeholder Alignment 1.1 要成功,要尽可能获得powerful leader的支持 1.2 也需要获得最高执行层的支持 1.3 Determining(确定) Stakeholders 需要建立360度组织图,确认三类人…