【扩散模型 李宏毅B站教学以及基础代码运用】

李宏毅教学视频:
Link1

B站DDPM公式推导以及代码实现:
Link2

这个视频里面有论文里面的公式推导,并且1小时10分开始讲解实例代码。

文章目录

    • 扩散模型概念:
    • Diffusion Model工作原理:
    • 影像生成模型本质上的共同目标
    • B站简单示例代码讲解

扩散模型概念:

就像石头里面已经有了雕塑,只需要看我们怎么把其他多余的部分去掉。
在这里插入图片描述
注意观察,我们每一个Denoise阶段都不一样,因为每一个阶段传入的图片以及需要处理的noise都不一样,并且直接产生图片比直接产生噪音更难,所以我们通过预测noise来解决问题。
在这里插入图片描述

比如下图所示:step2是我们加的噪声,那么传入input和2的时候就希望预测出gt了,然后进行相减得到step1的图片。
在这里插入图片描述

Diffusion Model工作原理:

VAE和Diffusion的区别
在这里插入图片描述
先看整个训练过程:
在这里插入图片描述

实际结果和我们想的是不一样的。训练时通过X0和噪声得到一个图,逆向的时候输入t和生成的图来得到噪音。想象的是一点一点加入噪音,实际上是直接加进去的。在这里插入图片描述
推断时刻:theat是带有参数的网络。
在这里插入图片描述

影像生成模型本质上的共同目标

通过采样一个高深distribution生成一个图片。希望生成的图片和真实的图片的distribution很接近。
在这里插入图片描述
那么怎么衡量这两个分布的接近程度呢?多数采用的都是Maximum liklihood Estimation.
我们希望我们采样的数据能够通过theta网络计算出来的概率越大越好。 在这里插入图片描述
通过数学变换,将概率最大变为Pdata和Ptheat这两个distribution的KL散度最小。
在这里插入图片描述
VAE的下界
Ptheat(x)表示:通过theta产生x的概率。
在这里插入图片描述

在这里插入图片描述
DDPM计算Ptheta(x)的方法 下图表示产生X0的概率。
在这里插入图片描述
两者对比
在这里插入图片描述
接下来需要计算q(x1|x0)此类公式。
计算方法:X1到X2的计算方法在论文中有提及。
在这里插入图片描述
两个高斯分布都是服从N(0,1),相加的话还是一个高斯分布,并且还是服从N(0,1),只是前面系数会发生变化。系数的话是根号下面数字相加。所以相加之后均值还是为0,方差a方加b方即可,这个在另外一个视频里面有讲解。
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
经过一番推导之后得到:
在这里插入图片描述
之后计算最下面三项:
在这里插入图片描述
通过以下推导:
在这里插入图片描述
之后通过X0,Xt可以得到Xt-1的分布。
在这里插入图片描述
可以看到前面一项的mean 和 variance是固定的,第二项的variance也是固定的,因此我们需要把第二项的mean变得和第一项的接近。
在这里插入图片描述
那么怎么minimiaze这个mean呢?希望用Xt去预测出来那个mean。
在这里插入图片描述
经过推导:
在这里插入图片描述
最终得到下图:
在这里插入图片描述
里面beta可以学习,但是效果不好,所以使用线性固定。最后加上一个噪声猜测是为了增强鲁棒性,并且本身就是从噪声开始,不加噪声的话可能不会生成图片。

B站简单示例代码讲解

# 加载数据集
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve
import torchs_curve,_ = make_s_curve(10**4,noise=0.1)
print(np.shape(s_curve))
s_curve = s_curve[:,[0,2]]/10.0print("shape of s:",np.shape(s_curve))data = s_curve.Tfig,ax = plt.subplots()
ax.scatter(*data,color='blue',edgecolor='white');ax.axis('off')dataset = torch.Tensor(s_curve).float()

在这里插入图片描述

# 2确定超参数的值
num_steps = 100
#制定每一步的beta
betas = torch.linspace(-6,6,num_steps)
betas = torch.sigmoid(betas)*(0.5e-2 - 1e-5)+1e-5#计算alpha、alpha_prod、alpha_prod_previous、alpha_bar_sqrt等变量的值
alphas = 1-betas
alphas_prod = torch.cumprod(alphas,0)
# print(alphas_prod)
alphas_prod_p = torch.cat([torch.tensor([1]).float(),alphas_prod[:-1]],0)
# print(alphas_prod_p)
alphas_bar_sqrt = torch.sqrt(alphas_prod)
one_minus_alphas_bar_log = torch.log(1 - alphas_prod)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)assert alphas.shape==alphas_prod.shape==alphas_prod_p.shape==\
alphas_bar_sqrt.shape==one_minus_alphas_bar_log.shape\
==one_minus_alphas_bar_sqrt.shape
print("all the same shape",betas.shape)、确定扩散过程任意时刻的采样值#3 计算任意时刻的x采样值,基于x_0和重参数化
def q_x(x_0,t):"""可以基于x[0]得到任意时刻t的x[t]"""noise = torch.randn_like(x_0)alphas_t = alphas_bar_sqrt[t]alphas_1_m_t = one_minus_alphas_bar_sqrt[t]return (alphas_t * x_0 + alphas_1_m_t * noise)#在x[0]的基础上添加噪声
j
# 4 演示原始数据分布加噪100步后的结果num_shows = 20
fig,axs = plt.subplots(2,10,figsize=(28,3))
plt.rc('text',color='black')#共有10000个点,每个点包含两个坐标
#生成100步以内每隔5步加噪声后的图像
for i in range(num_shows):j = i//10k = i%10q_i = q_x(dataset,torch.tensor([i*num_steps//num_shows]))#生成t时刻的采样数据axs[j,k].scatter(q_i[:,0],q_i[:,1],color='red',edgecolor='white')axs[j,k].set_axis_off()axs[j,k].set_title('$q(\mathbf{x}_{'+str(i*num_steps//num_shows)+'})$')

在这里插入图片描述

# 5 编写拟合逆扩散过程高斯分布的模型import torch
import torch.nn as nn
​
class MLPDiffusion(nn.Module):def __init__(self,n_steps,num_units=128):super(MLPDiffusion,self).__init__()self.linears = nn.ModuleList([nn.Linear(2,num_units),nn.ReLU(),nn.Linear(num_units,num_units),nn.ReLU(),nn.Linear(num_units,num_units),nn.ReLU(),nn.Linear(num_units,2),])self.step_embeddings = nn.ModuleList([nn.Embedding(n_steps,num_units),nn.Embedding(n_steps,num_units),nn.Embedding(n_steps,num_units),])def forward(self,x,t):
#         x = x_0for idx,embedding_layer in enumerate(self.step_embeddings):t_embedding = embedding_layer(t)x = self.linears[2*idx](x)x += t_embeddingx = self.linears[2*idx+1](x)x = self.linears[-1](x)return x

loss_fn 就是Lsimple得表达式。通过传入参数,生成一个随机噪声,并且送入model里面,那么上面也讲了,model的作用是根据X0,和t预测出我们应该减去的噪声,所以损失函数就是用生成的噪声减去预测的噪声。
在这里插入图片描述

# 6 编写训练的误差函数
def diffusion_loss_fn(model,x_0,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,n_steps):"""对任意时刻t进行采样计算loss"""batch_size = x_0.shape[0]#对一个batchsize样本生成随机的时刻tt = torch.randint(0,n_steps,size=(batch_size//2,))t = torch.cat([t,n_steps-1-t],dim=0)t = t.unsqueeze(-1)#x0的系数a = alphas_bar_sqrt[t]#eps的系数aml = one_minus_alphas_bar_sqrt[t]#生成随机噪音epse = torch.randn_like(x_0)#构造模型的输入x = x_0*a+e*aml#送入模型,得到t时刻的随机噪声预测值output = model(x,t.squeeze(-1))#与真实噪声一起计算误差,求平均值return torch.pow((e - output),2).mean()
# 7、编写逆扩散采样函数(inference)def p_sample_loop(model,shape,n_steps,betas,one_minus_alphas_bar_sqrt):"""从x[T]恢复x[T-1]、x[T-2]|...x[0]"""cur_x = torch.randn(shape)x_seq = [cur_x]for i in reversed(range(n_steps)):cur_x = p_sample(model,cur_x,i,betas,one_minus_alphas_bar_sqrt)x_seq.append(cur_x)return x_seq
​
def p_sample(model,x,t,betas,one_minus_alphas_bar_sqrt):"""从x[T]采样t时刻的重构值"""t = torch.tensor([t])coeff = betas[t] / one_minus_alphas_bar_sqrt[t]eps_theta = model(x,t)mean = (1/(1-betas[t]).sqrt())*(x-(coeff*eps_theta))z = torch.randn_like(x)sigma_t = betas[t].sqrt()sample = mean + sigma_t * zreturn (sample)
# 8、开始训练模型,打印loss及中间重构效果seed = 1234class EMA():"""构建一个参数平滑器"""def __init__(self,mu=0.01):self.mu = muself.shadow = {}def register(self,name,val):self.shadow[name] = val.clone()def __call__(self,name,x):assert name in self.shadownew_average = self.mu * x + (1.0-self.mu)*self.shadow[name]self.shadow[name] = new_average.clone()return new_averageprint('Training model...')
batch_size = 128
dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True)
num_epoch = 4000
plt.rc('text',color='blue')
​
model = MLPDiffusion(num_steps)#输出维度是2,输入是x和step
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)for t in range(num_epoch):for idx,batch_x in enumerate(dataloader):loss = diffusion_loss_fn(model,batch_x,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,num_steps)optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(),1.)optimizer.step()if(t%100==0):print(loss)x_seq = p_sample_loop(model,dataset.shape,num_steps,betas,one_minus_alphas_bar_sqrt)fig,axs = plt.subplots(1,10,figsize=(28,3))for i in range(1,11):cur_x = x_seq[i*10].detach()axs[i-1].scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white');axs[i-1].set_axis_off();axs[i-1].set_title('$q(\mathbf{x}_{'+str(i*10)+'})$')

最后的演示

动画演示扩散过程和逆扩散过程import io
from PIL import Image
​
imgs = []
for i in range(100):plt.clf()q_i = q_x(dataset,torch.tensor([i]))plt.scatter(q_i[:,0],q_i[:,1],color='red',edgecolor='white',s=5);plt.axis('off');img_buf = io.BytesIO()plt.savefig(img_buf,format='png')img = Image.open(img_buf)imgs.append(img)
mg = Image.open(img_buf)reverse.append(img)
reverse = []
for i in range(100):plt.clf()cur_x = x_seq[i].detach()plt.scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white',s=5);plt.axis('off')img_buf = io.BytesIO()plt.savefig(img_buf,format='png')img = Image.open(img_buf)reverse.append(img)
​
imgs = imgs +reverse
imgs[0].save("diffusion.gif",format='GIF',append_images=imgs,save_all=True,duration=100,loop=0)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/70558.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

算法——组合程序算法解析

组合就是从m个元素的数组中求n个元素的所有组合&#xff0c;代码如下&#xff1a; #include <iostream> #include <vector> using namespace std; // 递归求解组合 void combinations(vector<int>& nums, vector<int>& combination, int star…

Linux 安装 JDK

要在Linux上安装JDK 1&#xff0c;按照以下步骤进行操作&#xff1a; 1. 下载JDK安装文件&#xff1a;首先&#xff0c;你需要找到适用于你操作系统的JDK安装文件&#xff08;tar.gz或tar.bz2格式&#xff09;。你可以从Oracle官方网站或其他可信的来源下载该文件。 2. 解压…

Ansible自动化运维

目录 前言 一、概述 常见的开源自动化运维工具比较 二、ansible环境搭建 三、ansible模块 &#xff08;一&#xff09;、hostname模块 &#xff08;二&#xff09;、file模块 &#xff08;三&#xff09;、copy模块 &#xff08;四&#xff09;、fetch模块 &#xff…

借助各大模型的优点生成原创视频(真人人声)Plus

【技术背景】 众所周知&#xff0c;组成视频的3大元素&#xff0c;即文本语音图片。接着小编逐一介绍生成原创视频的过程。 【文本生成】 天工AI搜索&#xff08;thttp://iangong.cn&#xff09; 直接手机短信验证就可以使用&#xff0c;该大模型已经接入互联网&#xff0c…

什么是IIFE(Immediately Invoked Function Expression)?它有什么作用?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐IIFE 的基本语法⭐IIFE 的主要作用⭐如何使用 IIFE 来创建私有变量和模块封装⭐ 写在最后 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 记得点击上方或者右侧链接订阅本专栏哦 几何带你启航前端之旅 欢迎来到前端入门之旅…

GOOGLE SRE 运维模式解读

一、SRE核心是什么 我总结下来是&#xff1a;通过软件工程的方式开发&#xff08;GOOGLE规定SRE团队必须将50%的精力花在真实的开发工作上&#xff09;一些自动化的工具系统来解放传统运维工程师大量重复和手工操作&#xff0c;从而让新生代的SRE工程师有更多的时间&#xff1…

五种定时任务方案(Timer+ScheduleExecutorService+spring task+多线程执行+quartz)

方案一&#xff1a;Timer (1)Timer.schedule(TimerTask task,Date time)安排在制定的时间执行指定的任务。 (2)Timer.schedule(TimerTask task,Date firstTime ,long period)安排指定的任务在指定的时间开始进行重复的固定延迟执行&#xff0e; (3)Timer.schedule(TimerTask…

YashanDB:潜心实干,数据库核心技术突破没有捷径可走

都说数据库是三大基础软件中的一块硬骨头&#xff0c;技术门槛高、研发周期长、工程要求高&#xff0c;市场长期被几大巨头所把持。 因此&#xff0c;实现突破一直是中国数据库产业的夙愿。自上个世纪80年代起&#xff0c;中国数据库产业走过艰辛坎坷的四十余载&#xff0c;终…

【数据结构】二叉搜索树——二叉搜索树的概念和介绍、二叉搜索树的简单实现、二叉搜索树的增删查改

文章目录 二叉搜索树1. 二叉搜索树的概念和介绍2. 二叉搜索树的简单实现2.1二叉搜索树的插入2.2二叉搜索树的查找2.3二叉搜索树的遍历2.4二叉搜索树的删除2.5完整代码和测试 二叉搜索树 1. 二叉搜索树的概念和介绍 二叉搜索树又称二叉排序树&#xff0c;它或者是一棵空树&…

【Spring 事务和事务传播机制】

目录 1 事务概述 1.1 为什么需要事务 1.2 事务的特性 1.3 Spring 中事务的实现 2 Spring 声明式事务 2.1 Transactional 2.2 Transactional 的作用范围 2.3 Transactional 的各种参数 2.3.1 ioslation 2.4 事务发生了异常&#xff0c;也不回滚的情况 异常被捕获时 3 事务的传…

通过 Blob 对二进制流文件下载实现文件保存下载

原理&#xff1a;前端将二进制文件做转换实现下载: 请求后端接口->接收后端返回的二进制流(通过二进制流&#xff08;Blob&#xff09;下载,把后端返回的二进制文件放在 Blob 里面)->再通过file-saver插件保存 页面上使用&#xff1a; <span click"downloadFil…

作为产品经理,有必要考PMP或者NPDP么?

产品经理的核心竞争力是什么? 三点&#xff1a;知识、能力和决策 懂得越多&#xff0c;能力越强&#xff0c;决策越正确&#xff0c;核心竞争力越强。一般来说&#xff0c;看的越多&#xff0c;做的越多&#xff0c;实践出经验才是王道&#xff0c;但是&#xff0c;总有看不…

智慧物流发展的重要推动力量:北斗卫星导航系统

随着经济的快速发展和电商的普及&#xff0c;物流行业的规模不断扩大&#xff0c;对物流运输的效率和安全性也提出了更高的要求。传统的物流运输方式存在着效率低下、信息不对称、安全隐患等问题&#xff0c;因此发展智慧物流已经成为物流行业的必然趋势。智慧物流可以通过先进…

立晶半导体Cubic Lattice Inc 专攻音频ADC,音频DAC,音频CODEC,音频CLASS D等CL7016

概述&#xff1a; CL7016是一款高保真USB Type-C兼容音频编解码芯片。可以录制和回放有24比特音乐和声音。内置回放通路信号动态压缩&#xff0c; 最大42db录音通路增益&#xff0c;PDM数字麦克风&#xff0c;和立体声无需电容耳机驱动放大器。 5V单电源供电。兼容USB 2.0全速工…

深度学习面试八股文(2023.9.06持续更新)

一、优化器 1、SGD是什么&#xff1f; 批梯度下降&#xff08;Batch gradient descent&#xff09;&#xff1a;遍历全部数据集算一次损失函数&#xff0c;计算量开销大&#xff0c;计算速度慢&#xff0c;不支持在线学习。随机梯度下降&#xff08;Stochastic gradient desc…

基于vue-cli创建后台管理系统前端页面——element-ui,axios,跨域配置,布局初步,导航栏

目录 引出安装npm install安装element-ui安装axios 进行配置main.js中引入添加jwt前端跨域配置 进行初始布局HomeView.vueApp.vue 新增页面和引入home页面导航栏总结 引出 1.vue-cli创建前端工程&#xff0c;安装element-ui&#xff0c;axios和配置&#xff1b; 2.前端跨域的配…

记录学习--字节码解析try catch

1.示例代码 Testpublic void someTest() {String s "111";try {s "222";int i 1/0;} catch (Exception e){e.printStackTrace();System.out.println(s);}System.out.println(s);}2.示例代码对应的字节码 0 ldc #2 <111>2 astore_13 ldc #3 <22…

“深入理解SpringMVC的注解驱动开发“

目录 引言1. SpringMVC的常用注解2. SpringMVC的参数传递3. SpringMVC的返回值4. SpringMVC页面跳转总结 引言 在现代的Web开发中&#xff0c;SpringMVC已经成为了一个非常流行和强大的框架。它提供了许多注解来简化开发过程&#xff0c;使得我们能够更加专注于业务逻辑的实现…

【python】TCP socket服务器 Demo

目录 一、单线程服务器 二、多线程服务器 三、多线程服务器&#xff08;发送和接收分离&#xff09; 一、单线程服务器 说明&#xff1a;只能连接一个客户端 import socket,binascii# 创建一个 TCP 套接字 server_socket socket.socket(socket.AF_INET, socket.SOCK_STRE…

nas汇编程序的调试排错方法

nas汇编程序的调试排错方法&#xff1a; 1、查找是哪一步错了 2、查看对应的*.lst文件&#xff0c;本例中是"asmhead.lst" 3、根据*.lst文件的[ERROR #002]提示查看源码&#xff0c;改错。 4、重新运行编译&#xff0c;OK 1、查找是哪一步错了&#xff1a; nask.ex…