【Diffusion实战】训练一个类别引导diffusion模型(Pytorch代码详解)

  又学习了一种方法,类别引导diffusion模型,使用mnist数据集,记录一下它的用法吧。


Diffusion实战篇:
  【Diffusion实战】训练一个diffusion模型生成S曲线(Pytorch代码详解)
  【Diffusion实战】训练一个diffusion模型生成蝴蝶图像(Pytorch代码详解)
  【Diffusion实战】引导一个diffusion模型根据文字生成图像(Pytorch代码详解)
Diffusion综述篇:
  【Diffusion综述】医学图像分析中的扩散模型(一)
  【Diffusion综述】医学图像分析中的扩散模型(二)


1、数据集装载

  使用mnist数据集来训练类别引导diffusion模型,因为其比较简单清晰:

import torch
import torchvision
from torchvision import transforms
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
from PIL import Image
import numpy as npdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=False, transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)# 查看MNIST数据集样本
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')
plt.axis('off')
plt.show()

  看一看我们朴素的样本:
在这里插入图片描述


2、创建条件扩散模型

  创建了一个名为ClassConditionedUnet的条件扩散模型,定义了一个可学习的嵌入层,用以将数字类别映射到特征向量上,将类别嵌入与原始输入拼接之后,送入常规的UNet网络即可。

  知识传送:【python函数】torch.nn.Embedding函数用法图解

class ClassConditionedUnet(nn.Module):def __init__(self, num_classes=10, class_emb_size=4):super().__init__()# 嵌入层将数字类别映射到特征向量上self.class_emb = nn.Embedding(num_classes, class_emb_size)# 一个常规的UNet网络self.model = UNet2DModel(sample_size=28,           # 图像尺寸in_channels=1 + class_emb_size, # 增加一个通道, 用于条件生成out_channels=1,           # 输出通道layers_per_block=2,       # 残差连接层数目block_out_channels=(32, 64, 64), down_block_types=( "DownBlock2D",        # a regular ResNet downsampling block"AttnDownBlock2D",    # a ResNet downsampling block with spatial self-attention"AttnDownBlock2D",), up_block_types=("AttnUpBlock2D", "AttnUpBlock2D",      # a ResNet upsampling block with spatial self-attention"UpBlock2D",          # a regular ResNet upsampling block),)def forward(self, x, t, class_labels):bs, ch, w, h = x.shape  # [8, 1, 28, 28] # 类别条件以额外通道的形式输入class_cond = self.class_emb(class_labels)  # [8, 4]class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)  # [8, 4, 28, 28]# 拼接原始输入与类别条件映射net_input = torch.cat((x, class_cond), 1)   # (8, 5, 28, 28)# 模型预测return self.model(net_input, t).sample  # (8, 1, 28, 28)noisy_xb = torch.randn(8, 1, 28, 28).to(device)
timesteps = torch.linspace(0, 999, 8).long().to(device)
y = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1]).to(device)
model = ClassConditionedUnet().to(device)
with torch.no_grad():model_prediction = model(noisy_xb, timesteps, y)
model_prediction.shape  # 验证输出与输出尺寸相同

3、模型训练

  训练过程就跟之前的一样啦~

# 创建调度器
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)n_epochs = 10
net = ClassConditionedUnet().to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3) losses = []
for epoch in range(n_epochs):for x, y in tqdm(train_dataloader):# 获取数据并添加噪声x = x.to(device) * 2 - 1  # 归一化到[-1, 1]y = y.to(device)noise = torch.randn_like(x)timesteps = torch.randint(0, 999, (x.shape[0],)).long().to(device)# 前向加噪noisy_x = noise_scheduler.add_noise(x, noise, timesteps)# 获得模型预测结果pred = net(noisy_x, timesteps, y)  # 此处传入了类别标签# 损失计算loss = loss_fn(pred, noise) # 损失回传, 参数更新opt.zero_grad()loss.backward()opt.step()# 损失保存losses.append(loss.item())# 输出损失avg_loss = sum(losses[-100:])/100print(f'Finished epoch {epoch}. Average of the last 100 loss values: {avg_loss:05f}')# 查看损失曲线
plt.figure(dpi=300)
plt.plot(losses)
plt.show()

  输出损失曲线为:

在这里插入图片描述


4、模型推理

  进行采样循环,用类别标签引导图像生成:

x = torch.randn(80, 1, 28, 28).to(device)  # 随机噪声
y = torch.tensor([[i]*8 for i in range(10)]).flatten().to(device)  # 类别标签# 采样循环
for i, t in tqdm(enumerate(noise_scheduler.timesteps)):# 模型预测结果with torch.no_grad():residual = net(x, t, y)# 根据预测噪声和时间步更新图像x = noise_scheduler.step(residual, t, x).prev_sample# 结果可视化
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], 'Greys')
ax.axis('off')

  类别引导效果如下,效果还是挺好的哩:

在这里插入图片描述


5、代码汇总

import torch
import torchvision
from torchvision import transforms
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
from PIL import Image
import numpy as npdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')# -----------------------------------------------------------------------------
# 1、数据集装载
dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=False, transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)# 查看MNIST数据集样本
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')
plt.axis('off')
plt.show()
# -----------------------------------------------------------------------------# -----------------------------------------------------------------------------
# 2、创建条件扩散模型
class ClassConditionedUnet(nn.Module):def __init__(self, num_classes=10, class_emb_size=4):super().__init__()# 嵌入层将数字类别映射到特征向量上self.class_emb = nn.Embedding(num_classes, class_emb_size)# 一个常规的UNet网络self.model = UNet2DModel(sample_size=28,           # 图像尺寸in_channels=1 + class_emb_size, # 增加一个通道, 用于条件生成out_channels=1,           # 输出通道layers_per_block=2,       # 残差连接层数目block_out_channels=(32, 64, 64), down_block_types=( "DownBlock2D",        # a regular ResNet downsampling block"AttnDownBlock2D",    # a ResNet downsampling block with spatial self-attention"AttnDownBlock2D",), up_block_types=("AttnUpBlock2D", "AttnUpBlock2D",      # a ResNet upsampling block with spatial self-attention"UpBlock2D",          # a regular ResNet upsampling block),)def forward(self, x, t, class_labels):bs, ch, w, h = x.shape  # [8, 1, 28, 28] # 类别条件以额外通道的形式输入class_cond = self.class_emb(class_labels)  # [8, 4]class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)  # [8, 4, 28, 28]# 拼接原始输入与类别条件映射net_input = torch.cat((x, class_cond), 1)   # (8, 5, 28, 28)# 模型预测return self.model(net_input, t).sample  # (8, 1, 28, 28)noisy_xb = torch.randn(8, 1, 28, 28).to(device)
timesteps = torch.linspace(0, 999, 8).long().to(device)
y = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1]).to(device)
model = ClassConditionedUnet().to(device)
with torch.no_grad():model_prediction = model(noisy_xb, timesteps, y)
model_prediction.shape  # 验证输出与输出尺寸相同
# -----------------------------------------------------------------------------# -----------------------------------------------------------------------------
# 3、模型训练
# 创建调度器
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)n_epochs = 10
net = ClassConditionedUnet().to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3) losses = []
for epoch in range(n_epochs):for x, y in tqdm(train_dataloader):# 获取数据并添加噪声x = x.to(device) * 2 - 1  # 归一化到[-1, 1]y = y.to(device)noise = torch.randn_like(x)timesteps = torch.randint(0, 999, (x.shape[0],)).long().to(device)# 前向加噪noisy_x = noise_scheduler.add_noise(x, noise, timesteps)# 获得模型预测结果pred = net(noisy_x, timesteps, y)  # 此处传入了类别标签# 损失计算loss = loss_fn(pred, noise) # 损失回传, 参数更新opt.zero_grad()loss.backward()opt.step()# 损失保存losses.append(loss.item())# 输出损失avg_loss = sum(losses[-100:])/100print(f'Finished epoch {epoch}. Average of the last 100 loss values: {avg_loss:05f}')# 查看损失曲线
plt.figure(dpi=300)
plt.plot(losses)
plt.show()
# -----------------------------------------------------------------------------# -----------------------------------------------------------------------------
# 4、模型推理
x = torch.randn(80, 1, 28, 28).to(device)  # 随机噪声
y = torch.tensor([[i]*8 for i in range(10)]).flatten().to(device)  # 类别标签# 采样循环
for i, t in tqdm(enumerate(noise_scheduler.timesteps)):# 模型预测结果with torch.no_grad():residual = net(x, t, y)# 根据预测噪声和时间步更新图像x = noise_scheduler.step(residual, t, x).prev_sample# 结果可视化
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], 'Greys')
ax.axis('off')
# -----------------------------------------------------------------------------

  diffusion的修炼境界又提升了一级~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/7713.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

store内路由跳转router.push

选择action还是mutation 选择action mutation 是用来改变state的,不应该包含路由相关操作mutation是同步执行的,不应该包含异步操作,而路由是异步操作 action中进行路由跳转 因为vuex中没有this,所以不能用this.$router&#…

Rust Course学习(编写测试)

如果友友你的计算机上没有安装Rust,可以直接安装:Rust 程序设计语言 (rust-lang.org)https://www.rust-lang.org/zh-CN/ Introduce 介绍 Testing in Rust involves writing code specifically designed to verify that other code works as expected. It…

2024.5.7

槽函数声明 private slots:void on_ed_textChanged();void on_pushButton_clicked(); }; 槽函数定义 void Widget::on_ed_textChanged()//文本框 {if(ui->ed1->text().length()>5&&ui->ed2->text().length()>5){ui->pushButton->setStyleSh…

Xinstall广告效果监测,助力广告主优化投放策略

在移动互联网时代,APP推广已成为企业营销的重要手段。然而,如何衡量推广效果,了解用户来源,优化投放策略,一直是广告主和开发者面临的难题。这时,Xinstall作为国内专业的App全渠道统计服务商,以…

实名认证的接口方式、PHP身份证实名认证接口集成

身份证实名认证接口用于验证用户提交身份信息的真实性和有效性,开发者可以下载开发者示例快速的将身份证实名认证接口功能集成到自己的应用中,以此来保障用户的身份信息不被泄露和滥用。 翔云身份证实名认证接口可以通过身份证号、姓名、证件人像、现场…

Docker 安装部署 postgres

Docker 安装部署 postgres 1、拉取 postgres 镜像文件 [rootiZbp19a67kznq0h0rgosuxZ ~]# docker pull postgres:latest latest: Pulling from library/postgres b0a0cf830b12: Pull complete dda3d8fbd5ed: Pull complete 283a477db7bb: Pull complete 91d2729fa4d5: Pul…

RT-DETR-20240507周更说明|更新Inner-IoU、Focal-IoU、Focaler-IoU等数十种IoU计算方式

RT-DETR改进专栏|包含主干、模块、注意力、损失函数等改进 专栏介绍 本专栏包含模块、卷积、检测头、损失等深度学习前沿改进,目前已有改进点70!每周更新。 20240507更新说明: ⭐⭐ 更新CIoU、DIoU、MDPIoU、GIoU、EIoU、SIoU、ShapeIou、PowerfulIoU、…

数据结构-自定义栈、队列、二分查找树、双向链表

/*** 底层是数组*/ public class MyStack {private long [] arr; // 底层是数组private int top -1; // 核心【栈顶的索引(指针)】public MyStack() {super();arr new long[10];}public MyStack(int capacity) {super();arr new long[capacity]; // 自…

leetcode刷题:884、977

884.比较含退格的字符串 给定 s 和 t 两个字符串&#xff0c;当它们分别被输入到空白的文本编辑器后&#xff0c;如果两者相等&#xff0c;返回 true 。# 代表退格字符。 注意&#xff1a;如果对空文本输入退格字符&#xff0c;文本继续为空。 方法一、用栈 #include <i…

8.MyBatis 操作数据库(进阶)

文章目录 1.动态SQL插入1.1使用注解方式插入数据1.2使用xml方式插入数据1.3何时用注解何时用xml&#xff1f;1.4使用SQL查询中有多个and时&#xff0c;如何自动去除多余and1.4.1方法一&#xff1a;删除and之后的代码如图所示&#xff0c;再次运行1.4.2方法二&#xff1a;加上tr…

Linux服务器上网络端口测试

在使用telnet 111.22.345.66 80在Linux主机上尝试连接目标IP地址111.22.345.66的80端口时&#xff0c;会看到以下四行返回信息的含义解释&#xff1a; Trying 111.22.345.66...&#xff1a; 这一行指示telnet正在尝试与IP地址为111.22.345.66的主机建立连接。这表明telnet正尝…

elasticsearch安装配置注意事项

安装Elasticsearch时&#xff0c;需要注意以下几个重要事项&#xff1a; 1、版本选择&#xff1a;选择与你系统和其他组件&#xff08;如Logstash、Kibana&#xff09;兼容的Elasticsearch版本。 2、Java环境&#xff1a;Elasticsearch是基于Java构建的&#xff0c;因此确保已…

书生·浦语大模型实战营之 OpenCompass大模型评测

书生浦语大模型实战营之 OpenCompass &#xff1a;是骡子是马&#xff0c;拉出来溜溜 为什么要研究大模型的评测&#xff1f; 百家争鸣&#xff0c;百花齐放。 首先&#xff0c;研究评测对于我们全面了解大型语言模型的优势和限制至关重要。尽管许多研究表明大型语言模型在多…

Linux cmake 初窥【2】

1.开发背景 基于上一篇的基础上&#xff0c;再次升级 2.开发需求 基于 cmake 指定源文件目录可以是多个文件夹&#xff0c;多层目录 3.开发环境 ubuntu 20.04 cmake-3.23.1 4.实现步骤 4.1 准备源码文件 工程目录如下 顶层脚本 compile.sh 负责执行 cmake 操作&#xff0…

FSC森林认证是什么?

FSC森林认证&#xff0c;又称木材认证&#xff0c;是一种运用市场机制来促进森林可持续经营&#xff0c;实现生态、社会和经济目标的工具。FSC森林认证包括森林经营认证&#xff08;Forest Management, FM&#xff09;和产销监管链认证&#xff08;Chain of Custody, COC&#…

微搭低代码入门06分页查询

目录 1 创建自定义代码2 编写分页代码3 创建页面4 创建变量5 配置数据列表总结 我们在数据模型章节介绍了微搭后端服务编写的三种方式&#xff0c;包括Http请求、自定义代码、云函数。本篇我们详细讲解一下利用自定义代码开发分页查询的功能。 1 创建自定义代码 打开控制台&am…

Qt——入门基础

目录 Qt入门第一个应用程序 main.cpp widget.h widget.cpp widget.ui .pro Hello World程序 对象树 编辑框 按钮 Qt 窗口坐标系 Qt入门第一个应用程序 main.cpp 这就像一开始学语言时都会打印一个“Hello World”一样&#xff0c;我们先来看看创建好一个项目后&…

LeetCode 难题解析 —— 正则表达式匹配 (动态规划)

10. 正则表达式匹配 思路解析 这道题虽然看起来不难理解&#xff0c;但却存在多种可能&#xff0c;当然这种可能的数量是有限的&#xff0c;且其规律对于每一次判别都使用&#xff0c;所以自然而然就想到用 动态规划 的方法啦 接下来逐步分析可能的情况&#xff1a; &#x…

栈(使用顺序表构建)

P. S.&#xff1a;以下代码均在VS2019环境下测试&#xff0c;不代表所有编译器均可通过。 P. S.&#xff1a;测试代码均未展示头文件stdio.h的声明&#xff0c;使用时请自行添加。 目录 1、栈的概念2、栈的数组构建方法2.1 前言2.2 正文2.2.1 栈的初始化2.2.2 栈的销毁2.2.3 压…

Jupyter notebook和 Jupyter lab内核死亡问题的原因和解决方案

写在前面&#xff1a;之前也遇到过几次内核死亡的问题&#xff0c;也一直没有想解决办法。这里总结一下并提出几个解决办法。 首先明确一下jupyter出现内核死亡的原因&#xff1a;jupyter lab 或者 jupyter notebook 本身是一个web服务&#xff0c; 无法支持高并发和频繁的计算…