AI绘画训练一个扩散模型-上集

介绍

AI绘画,其中最常见方案基于扩散模型,Stable Diffusion 在此基础上,增加了 VAE 模块和 CLIP 模块,本文搞了一个测试Demo,分为上下两集,第一集是denoising_diffusion_pytorch ,第二集是diffusers。
对于专业的算法同学而言,我更推荐使用 diffusers 来训练。原因是 diffusers 工具包在实际的 AI 绘画项目中用得更多,并且也更易于我们修改代码逻辑,实现定制化功能。
https://arxiv.org/abs/2112.10752

基础模块

  • 创建UNet模型和高斯扩散模型(Gaussian Diffusion)。

UNet是一个编码器-解码器结构的全卷积神经网络。Gaussian Diffusion用于定义噪声过程和损失函数。

  • 将模型加载到GPU上(如果有GPU)。

  • 使用随机初始化的图片进行一次训练,计算损失并反向传播。

这一步的目的是对模型进行一次预热,更新权重。

  • 使用diffusion模型采样生成图片。

这里采样1000步,也就是将噪声逐步减少,每步用UNet预测下一步的图像,最终输出生成的图片。

  • 如果图片在GPU上,将其移回到CPU。

  • 可视化第一张生成图片。

plt.imshow显示图片。

这样通过DDPM框架,可以从随机噪声生成符合数据分布的新图片。每次训练会使模型逐步逼近真实数据分布,从而产生更高质量的图片。

# 创建UNet和扩散模型from denoising_diffusion_pytorch import Unet, GaussianDiffusion
import torchmodel = Unet(dim = 64,dim_mults = (1, 2, 4, 8)
).cuda()diffusion = GaussianDiffusion(model,image_size = 128,timesteps = 1000   # number of steps
).cuda()# 使用随机初始化的图片进行一次训练
training_images = torch.randn(8, 3, 128, 128)
loss = diffusion(training_images.cuda())
loss.backward()# 采样1000步生成一张图片
sampled_images = diffusion.sample(batch_size = 4)
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import torchvision.transforms as transforms# 如果张量在 GPU上,需要移动到 CPU上
if sampled_images.is_cuda:sampled_images = sampled_images.cpu()# 检查我们生成的一张图
img = sampled_images[0].clone().detach().permute(1, 2, 0)plt.imshow(img)

数据集

  • 导入所需的库:PIL、io、datasets等。

  • 使用datasets库中的load_dataset方法加载Oxford Flowers数据集。

  • 创建一个目录来保存图片。

  • 遍历数据集的训练、验证、测试split,逐个图像获取图片bytes数据,并保存为PNG格式图片。

  • 使用PIL库的Image对象将bytes数据加载并保存为图片文件。

  • 使用tqdm显示循环进度。

# 数据集下载
from PIL import Image
from io import BytesIO
from datasets import load_dataset
import os
from tqdm import tqdmdataset = load_dataset("nelorth/oxford-flowers")# 创建一个用于保存图片的文件夹
images_dir = "./oxford-datasets/raw-images"
os.makedirs(images_dir, exist_ok=True)# 遍历所有图片并保存,针对oxford-flowers,整个过程要持续15分钟左右
for split in dataset.keys():for index, item in enumerate(tqdm(dataset[split])):image = item['image']image.save(os.path.join(images_dir, f"{split}_image_{index}.png"))

模型训练

  • 定义Unet模型架构和Gaussian Diffusion过程。

  • 加载数据,指定训练参数:

    • 训练总步数20000
    • batch size 16
    • 学习率2e-5
    • 梯度累积步数2
    • EMA指数衰减参数0.995
    • 使用混合精度训练
    • 每2000步保存一次模型
    • 创建Trainer进行模型训练。Trainer封装了训练循环逻辑。
  • 调用trainer.train()进行训练。

# 模型训练
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainermodel = Unet(dim = 64,dim_mults = (1, 2, 4, 8)
).cuda()diffusion = GaussianDiffusion(model,image_size = 128,timesteps = 1000   # 加噪总步数
).cuda()trainer = Trainer(diffusion,'./oxford-datasets/raw-images',train_batch_size = 16,train_lr = 2e-5,train_num_steps = 20000,          # 总共训练20000步gradient_accumulate_every = 2,    # 梯度累积步数ema_decay = 0.995,                # 指数滑动平均decay参数amp = True,                       # 使用混合精度训练加速calculate_fid = False,            # 我们关闭FID评测指标计算(比较耗时)。FID用于评测生成质量。save_and_sample_every = 2000      # 每隔2000步保存一次模型
)trainer.train()
# 你可以等待上面的模型训练完成后,查看生成结果from glob import globresult_images = glob(r"./results/*.png")
print(result_images)
# 可视化图像看看
from PIL import Imageimg = Image.open("./results/sample-1.png")
img

测试

https://colab.research.google.com/github/NightWalker888/ai_painting_journey/blob/main/lesson12/train_diffusion_v2.ipynb#scrollTo=8BVjfBPI93Ar

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/420912.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

WPF多线程UI更新——两种方法

WPF多线程UI更新——两种方法 前言 在WPF中,在使用多线程在后台进行计算限制的异步操作的时候,如果在后台线程中对UI进行了修改,则会出现一个错误:(调用线程无法访问此对象,因为另一个线程拥有该对象。&…

二叉树的递归遍历

二叉树遍历一,什么是二叉树二,递归实现2.1 结点类描述2.2 三种递归2.2 测试一,什么是二叉树 在计算机科学中,二叉树是每个结点最多有两个子树的树结构。通常子树被称作"左子树"(left subtree)和&…

概率论的公理结构

样本点 一个随机事件出现的可能的结果叫做样本点。 类比平面几何,线、面、体也是由点组成的集合,研究的是点线面关系及性质,同样样本点也是组成事件(集合)的材料,是集合的基本元素,把这些样本…

python词云的简单使用

词云的生成所需库代码实现wordclod参数说明具体实现效果展示所需库 wordcloud, jieba, imageiowordcloud 词云库,用来统计文本文档里面出现的高频词汇,或者句子,以图片可视化的方式显示出来jieba库,分割中文的库,把较…

(一)Neo4j在Centos7虚拟机上的安装

1、什么是图数据库? 图数据库是基于数学里图论的思想和算法而实现的高效处理复杂关系网络的新型数据库系统。图形数据库善于高效处理大量的、复杂的、互连的、多变的数据。其计算效率远远高于传统的关系型数据库。图形数据库在社交网络、实时推荐、征信系统、人工智…

(二)Cypher语言常用方法举例

1、概述 “Cypher”是一个描述性的类Sql的图操作语言。相当于关系数据库的Sql,可见其重要性!其语法针对图的特点而设计,非常方便和灵活。没有Join,是一大特点!学好Cypher是学好Neo4j的关键,也是核心所在&a…

github 人像卡通化探索项目

把项目下载到本地 下载地址 https://github.com/minivision-ai/photo2cartoon安装依赖库 python 3.7 # 3.x版本都可 pytorch 1.4 tensorflow-gpu 1.14 # tesorflow 得是1.0版本,2.0版本语法部分改变,不然项目运行会出错 face-alignment dlibpytorch …

CVE-2013-3897漏洞成因与利用分析

CVE-2013-3897漏洞成因与利用分析 1. 简介 此漏洞是UAF(Use After Free)类漏洞,即引用了已经释放的内存。攻击者可以利用此类漏洞实现远程代码执行。UAF漏洞的根源源于对对象引用计数的处理不当,比如在编写程序时忘记AddRef或者多…

(三)Neo4j自带northwind案例--Cypher语言应用

0、概述 通过该案例,应用Cypher查询语言,感受Neo4j套路。官方的用此案例的用意: The Northwind Graph demonstrates how to migrate(迁移) from a relational database to Neo4j(把一个负责的多表关系数据…

RDIFramework.NET 中多表关联查询分页实例

RDIFramework.NET 中多表关联查询分页实例 RDIFramework.NET 中多表关联查询分页实例 RDIFramework.NET,基于.NET的快速信息化系统开发、整合框架,给用户和开发者最佳的.Net框架部署方案。该框架以SOA范式作为指导思想,作为异质系统整合与互操…

(六)Neo4j综合项目

0、概述 本文以热播电视剧《人民的名义》中的人物关系为数据基础,抛开案例本身的内容,本项目的意义在于指出使用Neo4j数据库的一般流程是什么?包括数据的导入、操作、查询、展示,从而体会出与传统数据库相比Neo4j在处理图数据的巨…

过滤器filter,监听器listener

目录1. filter过滤器1.1 原理1.2 配置1.3 过滤掉脏话demo2. listener监听器2.1 作用2.2 ServletContextListener demo1. filter过滤器 作用:过滤servlet,jsp,js,css,图片对象,以及一切在服务器,客户端想访…

(一)elasticsearch6.1.1安装详细过程

1、配置java环境 检查java环境 满足elasticsearch6.1.1java环境要求; 2、安装ElasticSearch6.1.1 ①为es新生成用户、用户组 su root groupadd esgroup useradd ela -g esgroup -p 5tgbhu8[rootlocalhost fibonacci]# su ela Attempting to create directory /h…

使用jdk DOM,SAX和第三方jar包DOM4J创建,解析xml文件

xml的创建,解析1. 什么是xml文件1.1 什么是xml文件1.2 解析xml的方式,优缺点2. 使用dom操作xml文件2.1 使用dom创建xml文件2.2 使用dom解析xml文件2.3 使用dom对xml文件增删改3. 使用SAX解析xml文件4. 使用DOM4J操作xml文件4.1 使用DOM4J创建xml文件4.2 …

(二)ElasticSearch6.1.1 Python API

0、准备开启数据库 ① 关闭Linux防火墙,这个很重要,否则API总是报错连不上。 # 查看防火墙状态 firewall-cmd --state# 关闭防护墙 systemctl stop firewalld.service# 开启防火墙 systemctl start firewalld.service# 重启防火墙 systemctl restart f…

sqlite3数据库使用

SQLite简介 SQLite是一个软件库,实现了自给自足的、无服务器的、零配置的、事务性的 SQL 数据库引擎。SQLite是一个增长最快的数据库引擎,这是在普及方面的增长,与它的尺寸大小无关。SQLite 源代码不受版权限制。 什么是sqlite SQLite是一…

(三)ElasticSearch的基本概念

0、面向文档 应用中的对象很少只是简单的键值列表,更多时候它拥有复杂的数据结构,比如包含日期、地理位置、另一个对象或者数组。 总有一天你会想到把这些对象存储到数据库中。将这些数据保存到由行和列组成的关系数据库中,就好像是把一个丰…

ajax下拉框省市级联动

目录效果sql数据前后台代码实现效果 初始访问页面 选中省会,自动刷新页面 sql数据 -- 省市联动数据CREATE TABLE PROVINCE (PID NUMBER PRIMARY KEY,PNAME VARCHAR(20) NOT NULL )SELECT * FROM PROVINCEINSERT INTO province VALUES (1, 北京市); INSERT I…