【Diffusion实战】训练一个类别引导diffusion模型(Pytorch代码详解)

  又学习了一种方法,类别引导diffusion模型,使用mnist数据集,记录一下它的用法吧。


Diffusion实战篇:
  【Diffusion实战】训练一个diffusion模型生成S曲线(Pytorch代码详解)
  【Diffusion实战】训练一个diffusion模型生成蝴蝶图像(Pytorch代码详解)
  【Diffusion实战】引导一个diffusion模型根据文字生成图像(Pytorch代码详解)
Diffusion综述篇:
  【Diffusion综述】医学图像分析中的扩散模型(一)
  【Diffusion综述】医学图像分析中的扩散模型(二)


1、数据集装载

  使用mnist数据集来训练类别引导diffusion模型,因为其比较简单清晰:

import torch
import torchvision
from torchvision import transforms
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
from PIL import Image
import numpy as npdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=False, transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)# 查看MNIST数据集样本
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')
plt.axis('off')
plt.show()

  看一看我们朴素的样本:
在这里插入图片描述


2、创建条件扩散模型

  创建了一个名为ClassConditionedUnet的条件扩散模型,定义了一个可学习的嵌入层,用以将数字类别映射到特征向量上,将类别嵌入与原始输入拼接之后,送入常规的UNet网络即可。

  知识传送:【python函数】torch.nn.Embedding函数用法图解

class ClassConditionedUnet(nn.Module):def __init__(self, num_classes=10, class_emb_size=4):super().__init__()# 嵌入层将数字类别映射到特征向量上self.class_emb = nn.Embedding(num_classes, class_emb_size)# 一个常规的UNet网络self.model = UNet2DModel(sample_size=28,           # 图像尺寸in_channels=1 + class_emb_size, # 增加一个通道, 用于条件生成out_channels=1,           # 输出通道layers_per_block=2,       # 残差连接层数目block_out_channels=(32, 64, 64), down_block_types=( "DownBlock2D",        # a regular ResNet downsampling block"AttnDownBlock2D",    # a ResNet downsampling block with spatial self-attention"AttnDownBlock2D",), up_block_types=("AttnUpBlock2D", "AttnUpBlock2D",      # a ResNet upsampling block with spatial self-attention"UpBlock2D",          # a regular ResNet upsampling block),)def forward(self, x, t, class_labels):bs, ch, w, h = x.shape  # [8, 1, 28, 28] # 类别条件以额外通道的形式输入class_cond = self.class_emb(class_labels)  # [8, 4]class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)  # [8, 4, 28, 28]# 拼接原始输入与类别条件映射net_input = torch.cat((x, class_cond), 1)   # (8, 5, 28, 28)# 模型预测return self.model(net_input, t).sample  # (8, 1, 28, 28)noisy_xb = torch.randn(8, 1, 28, 28).to(device)
timesteps = torch.linspace(0, 999, 8).long().to(device)
y = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1]).to(device)
model = ClassConditionedUnet().to(device)
with torch.no_grad():model_prediction = model(noisy_xb, timesteps, y)
model_prediction.shape  # 验证输出与输出尺寸相同

3、模型训练

  训练过程就跟之前的一样啦~

# 创建调度器
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)n_epochs = 10
net = ClassConditionedUnet().to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3) losses = []
for epoch in range(n_epochs):for x, y in tqdm(train_dataloader):# 获取数据并添加噪声x = x.to(device) * 2 - 1  # 归一化到[-1, 1]y = y.to(device)noise = torch.randn_like(x)timesteps = torch.randint(0, 999, (x.shape[0],)).long().to(device)# 前向加噪noisy_x = noise_scheduler.add_noise(x, noise, timesteps)# 获得模型预测结果pred = net(noisy_x, timesteps, y)  # 此处传入了类别标签# 损失计算loss = loss_fn(pred, noise) # 损失回传, 参数更新opt.zero_grad()loss.backward()opt.step()# 损失保存losses.append(loss.item())# 输出损失avg_loss = sum(losses[-100:])/100print(f'Finished epoch {epoch}. Average of the last 100 loss values: {avg_loss:05f}')# 查看损失曲线
plt.figure(dpi=300)
plt.plot(losses)
plt.show()

  输出损失曲线为:

在这里插入图片描述


4、模型推理

  进行采样循环,用类别标签引导图像生成:

x = torch.randn(80, 1, 28, 28).to(device)  # 随机噪声
y = torch.tensor([[i]*8 for i in range(10)]).flatten().to(device)  # 类别标签# 采样循环
for i, t in tqdm(enumerate(noise_scheduler.timesteps)):# 模型预测结果with torch.no_grad():residual = net(x, t, y)# 根据预测噪声和时间步更新图像x = noise_scheduler.step(residual, t, x).prev_sample# 结果可视化
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], 'Greys')
ax.axis('off')

  类别引导效果如下,效果还是挺好的哩:

在这里插入图片描述


5、代码汇总

import torch
import torchvision
from torchvision import transforms
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
from PIL import Image
import numpy as npdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')# -----------------------------------------------------------------------------
# 1、数据集装载
dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=False, transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)# 查看MNIST数据集样本
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')
plt.axis('off')
plt.show()
# -----------------------------------------------------------------------------# -----------------------------------------------------------------------------
# 2、创建条件扩散模型
class ClassConditionedUnet(nn.Module):def __init__(self, num_classes=10, class_emb_size=4):super().__init__()# 嵌入层将数字类别映射到特征向量上self.class_emb = nn.Embedding(num_classes, class_emb_size)# 一个常规的UNet网络self.model = UNet2DModel(sample_size=28,           # 图像尺寸in_channels=1 + class_emb_size, # 增加一个通道, 用于条件生成out_channels=1,           # 输出通道layers_per_block=2,       # 残差连接层数目block_out_channels=(32, 64, 64), down_block_types=( "DownBlock2D",        # a regular ResNet downsampling block"AttnDownBlock2D",    # a ResNet downsampling block with spatial self-attention"AttnDownBlock2D",), up_block_types=("AttnUpBlock2D", "AttnUpBlock2D",      # a ResNet upsampling block with spatial self-attention"UpBlock2D",          # a regular ResNet upsampling block),)def forward(self, x, t, class_labels):bs, ch, w, h = x.shape  # [8, 1, 28, 28] # 类别条件以额外通道的形式输入class_cond = self.class_emb(class_labels)  # [8, 4]class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)  # [8, 4, 28, 28]# 拼接原始输入与类别条件映射net_input = torch.cat((x, class_cond), 1)   # (8, 5, 28, 28)# 模型预测return self.model(net_input, t).sample  # (8, 1, 28, 28)noisy_xb = torch.randn(8, 1, 28, 28).to(device)
timesteps = torch.linspace(0, 999, 8).long().to(device)
y = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1]).to(device)
model = ClassConditionedUnet().to(device)
with torch.no_grad():model_prediction = model(noisy_xb, timesteps, y)
model_prediction.shape  # 验证输出与输出尺寸相同
# -----------------------------------------------------------------------------# -----------------------------------------------------------------------------
# 3、模型训练
# 创建调度器
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)n_epochs = 10
net = ClassConditionedUnet().to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3) losses = []
for epoch in range(n_epochs):for x, y in tqdm(train_dataloader):# 获取数据并添加噪声x = x.to(device) * 2 - 1  # 归一化到[-1, 1]y = y.to(device)noise = torch.randn_like(x)timesteps = torch.randint(0, 999, (x.shape[0],)).long().to(device)# 前向加噪noisy_x = noise_scheduler.add_noise(x, noise, timesteps)# 获得模型预测结果pred = net(noisy_x, timesteps, y)  # 此处传入了类别标签# 损失计算loss = loss_fn(pred, noise) # 损失回传, 参数更新opt.zero_grad()loss.backward()opt.step()# 损失保存losses.append(loss.item())# 输出损失avg_loss = sum(losses[-100:])/100print(f'Finished epoch {epoch}. Average of the last 100 loss values: {avg_loss:05f}')# 查看损失曲线
plt.figure(dpi=300)
plt.plot(losses)
plt.show()
# -----------------------------------------------------------------------------# -----------------------------------------------------------------------------
# 4、模型推理
x = torch.randn(80, 1, 28, 28).to(device)  # 随机噪声
y = torch.tensor([[i]*8 for i in range(10)]).flatten().to(device)  # 类别标签# 采样循环
for i, t in tqdm(enumerate(noise_scheduler.timesteps)):# 模型预测结果with torch.no_grad():residual = net(x, t, y)# 根据预测噪声和时间步更新图像x = noise_scheduler.step(residual, t, x).prev_sample# 结果可视化
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], 'Greys')
ax.axis('off')
# -----------------------------------------------------------------------------

  diffusion的修炼境界又提升了一级~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/7713.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Rust Course学习(编写测试)

如果友友你的计算机上没有安装Rust,可以直接安装:Rust 程序设计语言 (rust-lang.org)https://www.rust-lang.org/zh-CN/ Introduce 介绍 Testing in Rust involves writing code specifically designed to verify that other code works as expected. It…

2024.5.7

槽函数声明 private slots:void on_ed_textChanged();void on_pushButton_clicked(); }; 槽函数定义 void Widget::on_ed_textChanged()//文本框 {if(ui->ed1->text().length()>5&&ui->ed2->text().length()>5){ui->pushButton->setStyleSh…

Xinstall广告效果监测,助力广告主优化投放策略

在移动互联网时代,APP推广已成为企业营销的重要手段。然而,如何衡量推广效果,了解用户来源,优化投放策略,一直是广告主和开发者面临的难题。这时,Xinstall作为国内专业的App全渠道统计服务商,以…

Docker 安装部署 postgres

Docker 安装部署 postgres 1、拉取 postgres 镜像文件 [rootiZbp19a67kznq0h0rgosuxZ ~]# docker pull postgres:latest latest: Pulling from library/postgres b0a0cf830b12: Pull complete dda3d8fbd5ed: Pull complete 283a477db7bb: Pull complete 91d2729fa4d5: Pul…

RT-DETR-20240507周更说明|更新Inner-IoU、Focal-IoU、Focaler-IoU等数十种IoU计算方式

RT-DETR改进专栏|包含主干、模块、注意力、损失函数等改进 专栏介绍 本专栏包含模块、卷积、检测头、损失等深度学习前沿改进,目前已有改进点70!每周更新。 20240507更新说明: ⭐⭐ 更新CIoU、DIoU、MDPIoU、GIoU、EIoU、SIoU、ShapeIou、PowerfulIoU、…

8.MyBatis 操作数据库(进阶)

文章目录 1.动态SQL插入1.1使用注解方式插入数据1.2使用xml方式插入数据1.3何时用注解何时用xml?1.4使用SQL查询中有多个and时,如何自动去除多余and1.4.1方法一:删除and之后的代码如图所示,再次运行1.4.2方法二:加上tr…

书生·浦语大模型实战营之 OpenCompass大模型评测

书生浦语大模型实战营之 OpenCompass :是骡子是马,拉出来溜溜 为什么要研究大模型的评测? 百家争鸣,百花齐放。 首先,研究评测对于我们全面了解大型语言模型的优势和限制至关重要。尽管许多研究表明大型语言模型在多…

Linux cmake 初窥【2】

1.开发背景 基于上一篇的基础上,再次升级 2.开发需求 基于 cmake 指定源文件目录可以是多个文件夹,多层目录 3.开发环境 ubuntu 20.04 cmake-3.23.1 4.实现步骤 4.1 准备源码文件 工程目录如下 顶层脚本 compile.sh 负责执行 cmake 操作&#xff0…

FSC森林认证是什么?

FSC森林认证,又称木材认证,是一种运用市场机制来促进森林可持续经营,实现生态、社会和经济目标的工具。FSC森林认证包括森林经营认证(Forest Management, FM)和产销监管链认证(Chain of Custody, COC&#…

微搭低代码入门06分页查询

目录 1 创建自定义代码2 编写分页代码3 创建页面4 创建变量5 配置数据列表总结 我们在数据模型章节介绍了微搭后端服务编写的三种方式,包括Http请求、自定义代码、云函数。本篇我们详细讲解一下利用自定义代码开发分页查询的功能。 1 创建自定义代码 打开控制台&am…

Qt——入门基础

目录 Qt入门第一个应用程序 main.cpp widget.h widget.cpp widget.ui .pro Hello World程序 对象树 编辑框 按钮 Qt 窗口坐标系 Qt入门第一个应用程序 main.cpp 这就像一开始学语言时都会打印一个“Hello World”一样,我们先来看看创建好一个项目后&…

LeetCode 难题解析 —— 正则表达式匹配 (动态规划)

10. 正则表达式匹配 思路解析 这道题虽然看起来不难理解,但却存在多种可能,当然这种可能的数量是有限的,且其规律对于每一次判别都使用,所以自然而然就想到用 动态规划 的方法啦 接下来逐步分析可能的情况: &#x…

栈(使用顺序表构建)

P. S.:以下代码均在VS2019环境下测试,不代表所有编译器均可通过。 P. S.:测试代码均未展示头文件stdio.h的声明,使用时请自行添加。 目录 1、栈的概念2、栈的数组构建方法2.1 前言2.2 正文2.2.1 栈的初始化2.2.2 栈的销毁2.2.3 压…

栈与队列(包括例题一道)

栈 栈的概念 栈:一种特殊的线性表,其只允许在固定的一端进行插入和删除元素操作。 进行数据插入和删除操作的一端 称为栈顶,另一端称为栈底。 栈中的数据元素遵守后进先出 LIFO ( Last In First Out )的原则。 压栈&…

AI去衣技术在动画制作中的应用

随着科技的发展,人工智能(AI)已经在各个领域中发挥了重要作用,其中包括动画制作。在动画制作中,AI去衣技术是一个重要的工具,它可以帮助动画师们更加高效地完成工作。 AI去衣技术是一种基于人工智能的图像…

神经网络怎么把隐含层变量融合到损失函数中?

🏆本文收录于「Bug调优」专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收藏&&…

【工具分享】Amnesia2勒索病毒解密工具

前言 Amnesia 勒索软件于 2017 年 4 月 26 日开始出现。Amnesia 主要通过 RDP(远程桌面服务)暴力攻击进行传播,允许恶意软件作者登录受害者的服务器并执行勒索行为。 特征 Amnesia 是一种用 Delphi 编程语言编写的勒索软件,它使…

程序员的实用神器:助力软件开发的利器 ️

程序员的实用神器:助力软件开发的利器 🛠️ 程序员的实用神器:助力软件开发的利器 🛠️引言摘要自动化测试工具:保障代码质量的利剑 🗡️编写高效测试用例 持续集成/持续部署工具:加速交付的利器…

ASP.NET通用作业批改系统设计

摘  要 该系统采用B/S结构,以浏览器方式登陆系统,用ASP.NET作为开发语言,数据库则使用Microsoft SQL Server 2000实现。《通用作业批改系统》包括了学生子系统、教师子系统、管理员子系统三大模块,该系统主要完成学生&#xff…

基于C语言的贪吃蛇小游戏(简易版)

这篇博客会是对学习C语言成果的检测,为了实现贪吃蛇小游戏,我们用到的“工具”有:C语言函数、枚举、结构体、动态内存管理、预处理指令、链表、Win32 API等。 目录 1.简易版游戏效果 1.1欢迎界面 1.2游戏规则提示页面 1.3游戏进行页面 …