扩散模型实战(四):从零构建扩散模型

推荐阅读列表:

扩散模型实战(一):基本原理介绍

扩散模型实战(二):扩散模型的发展

扩散模型实战(三):扩散模型的应用

本文以MNIST数据集为例,从零构建扩散模型,具体会涉及到如下知识点:

  • 退化过程(向数据中添加噪声)
  • 构建一个简单的UNet模型
  • 训练扩散模型
  • 采样过程分析

下面介绍具体的实现过程:

一、环境配置&python包的导入
     

最好有GPU环境,比如公司的GPU集群或者Google Colab,下面是代码实现:

# 安装diffusers库!pip install -q diffusers# 导入所需要的包import torchimport torchvisionfrom torch import nnfrom torch.nn import functional as Ffrom torch.utils.data import DataLoaderfrom diffusers import DDPMScheduler, UNet2DModelfrom matplotlib import pyplot as pltdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f'Using device: {device}')
# 输出Using device: cuda

此时会输出运行环境是GPU还是CPU

载MNIST数据集

       MNIST数据集是一个小数据集,存储的是0-9手写数字字体,每张图片都28X28的灰度图片,每个像素的取值范围是[0,1],下面加载该数据集,并展示部分数据:

dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor())train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)x, y = next(iter(train_dataloader))print('Input shape:', x.shape)print('Labels:', y)plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys');
# 输出Input shape: torch.Size([8, 1, 28, 28])Labels: tensor([7, 8, 4, 2, 3, 6, 0, 2])

扩散模型的退化过程

       所谓退化过程,其实就是对输入数据加入噪声的过程,由于MNIST数据集的像素范围在[0,1],那么我们加入噪声也需要保持在相同的范围,这样我们可以很容易的把输入数据与噪声进行混合,代码如下:

def corrupt(x, amount):  """Corrupt the input `x` by mixing it with noise according to `amount`"""  noise = torch.rand_like(x)  amount = amount.view(-1, 1, 1, 1) # Sort shape so broadcasting works  return x*(1-amount) + noise*amount

接下来,我们看一下逐步加噪的效果,代码如下:

# Plotting the input datafig, axs = plt.subplots(2, 1, figsize=(12, 5))axs[0].set_title('Input data')axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')# Adding noiseamount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruptionnoised_x = corrupt(x, amount)# Plottinf the noised versionaxs[1].set_title('Corrupted data (-- amount increases -->)')axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys');

       从上图可以看出,从左到右加入的噪声逐步增多,当噪声量接近1时,数据看起来像纯粹的随机噪声。

构建一个简单的UNet模型

       UNet模型与自编码器有异曲同工之妙,UNet最初是用于完成医学图像中分割任务的,网络结构如下所示:

代码如下:

class BasicUNet(nn.Module):    """A minimal UNet implementation."""    def __init__(self, in_channels=1, out_channels=1):        super().__init__()        self.down_layers = torch.nn.ModuleList([             nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),            nn.Conv2d(32, 64, kernel_size=5, padding=2),            nn.Conv2d(64, 64, kernel_size=5, padding=2),        ])        self.up_layers = torch.nn.ModuleList([            nn.Conv2d(64, 64, kernel_size=5, padding=2),            nn.Conv2d(64, 32, kernel_size=5, padding=2),            nn.Conv2d(32, out_channels, kernel_size=5, padding=2),         ])        self.act = nn.SiLU() # The activation function        self.downscale = nn.MaxPool2d(2)        self.upscale = nn.Upsample(scale_factor=2)    def forward(self, x):        h = []        for i, l in enumerate(self.down_layers):            x = self.act(l(x)) # Through the layer and the activation function            if i < 2: # For all but the third (final) down layer:              h.append(x) # Storing output for skip connection              x = self.downscale(x) # Downscale ready for the next layer                      for i, l in enumerate(self.up_layers):            if i > 0: # For all except the first up layer              x = self.upscale(x) # Upscale              x += h.pop() # Fetching stored output (skip connection)            x = self.act(l(x)) # Through the layer and the activation function                    return x

我们来检验一下模型输入输出的shape变化是否符合预期,代码如下:

net = BasicUNet()x = torch.rand(8, 1, 28, 28)net(x).shape
# 输出torch.Size([8, 1, 28, 28])

再来看一下模型的参数量,代码如下:

sum([p.numel() for p in net.parameters()])
# 输出309057

至此,已经完成数据加载和UNet模型构建,当然UNet模型的结构可以有不同的设计。

、扩散模型训练

        扩散模型应该学习什么?其实有很多不同的目标,比如学习噪声,我们先以一个简单的例子开始,输入数据为带噪声的MNIST数据,扩散模型应该输出对应的最佳数字预测,因此学习的目标是预测值与真实值的MSE,训练代码如下:

# Dataloader (you can mess with batch size)batch_size = 128train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# How many runs through the data should we do?n_epochs = 3# Create the networknet = BasicUNet()net.to(device)# Our loss finctionloss_fn = nn.MSELoss()# The optimizeropt = torch.optim.Adam(net.parameters(), lr=1e-3) # Keeping a record of the losses for later viewinglosses = []# The training loopfor epoch in range(n_epochs):    for x, y in train_dataloader:        # Get some data and prepare the corrupted version        x = x.to(device) # Data on the GPU        noise_amount = torch.rand(x.shape[0]).to(device) # Pick random noise amounts        noisy_x = corrupt(x, noise_amount) # Create our noisy x        # Get the model prediction        pred = net(noisy_x)        # Calculate the loss        loss = loss_fn(pred, x) # How close is the output to the true 'clean' x?        # Backprop and update the params:        opt.zero_grad()        loss.backward()        opt.step()        # Store the loss for later        losses.append(loss.item())    # Print our the average of the loss values for this epoch:    avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')# View the loss curveplt.plot(losses)plt.ylim(0, 0.1);
# 输出Finished epoch 0. Average loss for this epoch: 0.024689Finished epoch 1. Average loss for this epoch: 0.019226Finished epoch 2. Average loss for this epoch: 0.017939

训练过程的loss曲线如下图所示:

六、扩散模型效果评估

我们选取一部分数据来评估一下模型的预测效果,代码如下:

#@markdown Visualizing model predictions on noisy inputs:# Fetch some datax, y = next(iter(train_dataloader))x = x[:8] # Only using the first 8 for easy plotting# Corrupt with a range of amountsamount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruptionnoised_x = corrupt(x, amount)# Get the model predictionswith torch.no_grad():  preds = net(noised_x.to(device)).detach().cpu()# Plotfig, axs = plt.subplots(3, 1, figsize=(12, 7))axs[0].set_title('Input data')axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap='Greys')axs[1].set_title('Corrupted data')axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap='Greys')axs[2].set_title('Network Predictions')axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap='Greys');

从上图可以看出,对于噪声量较低的输入,模型的预测效果是很不错的,当amount=1时,模型的输出接近整个数据集的均值,这正是扩散模型的工作原理。

Note:我们的训练并不太充分,读者可以尝试不同的超参数来优化模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/41113.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

智能楼宇综合布线实训室建设方案

一、楼宇智能综合布线实训室方案概述 楼宇智能综合布线实训室方案旨在为学生提供一个真实的学习和实践环境&#xff0c;以培养他们在楼宇智能综合布线领域的实际操作能力和技能。以下是一个概述&#xff1a; 1. 培养目标&#xff1a;培养学生在楼宇智能综合布线方面的综合能力…

Shader学习(三)(片元着色器)

1、在片元着色器处理漫反射 // Upgrade NOTE: replaced _World2Object with unity_WorldToObjectShader "Custom/specularfragement" {properties{_sp("Specular",color) (1,1,1,1)_shiness("Shiness",range(1,64)) 8}SubShader{pass {tags{&…

网络通信原理应用层(第五十一课)

1)DNS:域名解析系统,端口号TCP或UDP的53 2)域名注册网站 -新网 www.xinnet.com -万网-阿里云 www.net.cn -中国互联 hulian.top 配置通过域名访问网站(NETBASE第七课)_IHOPEDREAM的博客-CSDN博客 2、FTP 1)FTP概述 -文件传输协议 -控制连接:TCP 21 <

leetcode-413. 等差数列划分(java)

等差数列划分 leetcode-413. 等差数列划分题目描述双指针 上期经典算法 leetcode-413. 等差数列划分 难度 - 中等 原题链接 - 等差数列划分 题目描述 如果一个数列 至少有三个元素 &#xff0c;并且任意两个相邻元素之差相同&#xff0c;则称该数列为等差数列。 例如&#xff0…

JMeter接口测试数据分离驱动应用

步骤&#xff1a; 创建csv文件&#xff0c;编写接口测试用例 新建线程组——创建循环控制器&#xff08;循环次数填用例总数&#xff09; 创建CSV数据文件设置&#xff0c;设置参数。&#xff08;注意&#xff1a;是否允许带引号&#xff1f;&#xff1a;一定要设置为true&a…

深度学习实战48-【未来的专家团队】基于AutoCompany模型的自动化企业概念设计与设想

大家好,我是微学AI,今天给大家介绍一下深度学习实战48-【未来的专家团队】基于AutoCompany模型的自动化企业概念设计与设想,文本将介绍AutoCompany模型的概念设计,涵盖了AI智能公司的各个角色,并结合了GPT-4接口来实现各个角色的功能,设置中央控制器,公司运作过程会生成…

JMM内存模型之happens-before阐述

文章目录 一、happens-before的定义二、happens-before的规则1. 程序顺序规则&#xff1a;2. 监视器锁规则&#xff1a;3. volatile变量规则&#xff1a;4. 传递性&#xff1a;5. start()规则&#xff1a;6. join()规则&#xff1a; 一、happens-before的定义 如果一个操作hap…

【编程二三事】ES究竟是个啥?

在最近的项目中&#xff0c;总是或多或少接触到了搜索的能力。而在这些项目之中&#xff0c;或多或少都离不开一个中间件 - ElasticSearch。 今天忙里偷闲&#xff0c;就来好好了解下这个中间件是用来干什么的。 ES是什么? ​ ES全称ElasticSearch&#xff0c;是个基于Lucen…

性能优化的重要性

性能优化的重要性 性能优化的重要性摘要引言注意事项代码示例及注释性能优化的重要性 性能优化的重要性在 Java 中的体现响应速度资源利用效率扩展性与可维护性并发性能合理的锁策略线程安全的数据结构并发工具类的应用避免竞态条件和死锁 总结代码示例 博主 默语带您 Go to Ne…

一张图看懂 USDT三种类型地址 Omni、ERC20、TRC20的区别

USDT是当前实用最广泛&#xff0c;市值最高的稳定币&#xff0c;它是中心化的公司Tether发行的。在今年的4月17日之前&#xff0c;市场上存在着2种不同类型的USDT。4月17日又多了一种波场TRC20协议发行的USDT&#xff0c;它们各自有什么区别呢?哪个转账最快到账&#xff1f;哪…

谷歌推出首款量子弹性 FIDO2 安全密钥

谷歌在本周二宣布推出首个量子弹性 FIDO2 安全密钥&#xff0c;作为其 OpenSK 安全密钥计划的一部分。 Elie Bursztein和Fabian Kaczmarczyck表示&#xff1a;这一开源硬件优化的实现采用了一种新颖的ECC/Dilithium混合签名模式&#xff0c;它结合了ECC抵御标准攻击的安全性和…

[LeetCode]矩阵对角线元素的和

解题 思路 1: 循环,找到主对角线的下标和副对角线的下标,如果矩阵长或宽为奇数的时候,需要减去中间公共的那一个值,中间公共的那个数的下标为mat[mat.size()/2][mat.size()/2]副对角线的下标为 mat [i][mat.size()-i-1] class Solution { public:int diagonalSum(vector<ve…

JVM中判定对象是否回收的的方法

引用计数法 引用计数法是一种垃圾回收&#xff08;Garbage Collection&#xff09;算法&#xff0c;用于自动管理内存中的对象。在引用计数法中&#xff0c;每个对象都有一个关联的引用计数器&#xff0c;用于记录对该对象的引用数量。 当一个新的引用指向对象时&#xff0c;…

Hive底层数据存储格式

前言 在大数据领域,Hive是一种常用的数据仓库工具,用于管理和处理大规模数据集。Hive底层支持多种数据存储格式,这些格式对于数据存储、查询性能和压缩效率等方面有不同的优缺点。本文将介绍Hive底层的三种主要数据存储格式:文本文件格式、Parquet格式和ORC格式。 一、三…

SpringBoot复习:(42)WebServerCustomizer的customize方法是在哪里被调用的?

ServletWebServletAutoConfiguration类定义如下&#xff1a; 可以看到其中通过Import注解导入了其内部类BeanPostProcessorRegister。 BeanPostProcessor中定义的registerBeanDefinition方法会被Spring容器调用。 registerBeanDefinitions方法调用了RegistrySyntheticBeanIf…

Intellij IDEA SBT依赖分析插件

可分析模块和传递依赖 安装完插件后&#xff0c;由于IDEA BUG&#xff0c;会出现两个分析按钮&#xff0c;一个是gradle的&#xff0c;一般是后者是新安装的sbt。 选择需要分析的模块 只需要在project/plugins.sbt中添加代码&#xff0c;启动官方分析插件addDependencyTreeP…

1281. 整数的各位积和之差

诸神缄默不语-个人CSDN博文目录 力扣刷题笔记 文章目录 1. 简单粗暴的遍历2. 其实也是遍历&#xff0c;但是用Python内置函数只用写一行 1. 简单粗暴的遍历 Python版&#xff1a; class Solution:def subtractProductAndSum(self, n: int) -> int:he0ji1while n>1:last…

redis 数据结构(一)

Redis 为什么那么快 redis是一种内存数据库&#xff0c;所有的操作都是在内存中进行的&#xff0c;还有一种重要原因是&#xff1a;它的数据结构的设计对数据进行增删查改操作很高效。 redis的数据结构是什么 redis数据结构是对redis键值对值的数据类型的底层的实现&#xff0c…

团团代码生成器V1.0:一键生成完整的CRUD功能(提供Gitee源码)

前言&#xff1a;在日常开发的中&#xff0c;经常会需要重复写一些基础的增删改查接口&#xff0c;虽说不难&#xff0c;但是会耗费我们一些时间&#xff0c;所以我自己开发了一套纯SpringBoot实现的代码生成器&#xff0c;可以为我们生成单条数据的增删改查&#xff0c;还可以…

中远麒麟堡垒机 SQL注入漏洞复现

0x01 产品简介 中远麒麟依托自身强大的研发能力,丰富的行业经验&#xff0c;自主研发了新一代软硬件一体化统一安全运维平台一-iAudit 统一安全运维平台。该产品支持对企业运维人员在运维过程中进行统一身份认证、统一授权、统一审计、统一监控&#xff0c;消除了传统运维过程中…