Stable Diffusion文生图模型训练入门实战(完整代码)

Stable Diffusion 1.5(SD1.5)是由Stability AI在2022年8月22日开源的文生图模型,是SD最经典也是社区最活跃的模型之一。

以SD1.5作为预训练模型,在火影忍者数据集上微调一个火影风格的文生图模型(非Lora方式),是学习SD训练的入门任务。

在这里插入图片描述

显存要求 22GB左右

在本文中,我们会使用SD-1.5模型在火影忍者数据集上做训练,同时使用SwanLab监控训练过程、评估模型效果。

  • 代码:Github
  • 实验日志过程:SD-naruto - SwanLab
  • 模型:runwayml/stable-diffusion-v1-5
  • 数据集:lambdalabs/naruto-blip-captions
  • SwanLab:https://swanlab.cn

1.环境安装

本案例基于Python>=3.8,请在您的计算机上安装好Python;

另外,您的计算机上至少要有一张英伟达显卡(显存大约要求22GB左右)。

我们需要安装以下这几个Python库,在这之前,请确保你的环境内已安装了pytorch以及CUDA:

swanlab
diffusers
datasets
accelerate
torchvision
transformers

一键安装命令:

pip install swanlab diffusers datasets accelerate torchvision transformers

本文的代码测试于diffusers0.29.0、accelerate0.30.1、datasets2.18.0、transformers4.41.2、swanlab==0.3.11,更多库版本可查看SwanLab记录的Python环境。

2.准备数据集

本案例是用的是火影忍者数据集,该数据集主要被用于训练文生图模型。

该数据集由1200条(图像、描述)对组成,左边是火影人物的图像,右边是对它的描述:

在这里插入图片描述

我们的训练任务,便是希望训练后的SD模型能够输入提示词,生成火影风格的图像:

在这里插入图片描述


数据集的大小大约700MB左右;数据集的下载方式有两种:

  1. 如果你的网络与HuggingFace连接是通畅的,那么直接运行我下面提供的代码即可,它会直接通过HF的datasets库进行下载。
  2. 如果网络存在问题,我也把它放到百度网盘(提取码: gtk8),下载naruto-blip-captions.zip到本地解压后,运行到与训练脚本同一目录下。

3.准备模型

这里我们使用HuggingFace上Runway发布的stable-diffusion-v1-5模型。

在这里插入图片描述

模型的下载方式同样有两种:

  1. 如果你的网络与HuggingFace连接是通畅的,那么直接运行我下面提供的代码即可,它会直接通过HF的transformers库进行下载。
  2. 如果网络存在问题,我也把它放到百度网盘(提取码: gtk8),下载stable-diffusion-v1-5.zip到本地解压后,运行到与训练脚本同一目录下。

4. 配置训练可视化工具

我们使用SwanLab来监控整个训练过程,并评估最终的模型效果。

如果你是第一次使用SwanLab,那么还需要去https://swanlab.cn上注册一个账号,在用户设置页面复制你的API Key,然后在训练开始时粘贴进去即可:

在这里插入图片描述

5.开始训练

由于训练的代码比较长,所以我把它放到了Github里,请Clone里面的代码:

git clone https://github.com/Zeyi-Lin/Stable-Diffusion-Example.git

如果你与HuggingFace的网络连接通畅,那么直接运行训练:

python train_sd1-5_naruto.py \--use_ema \--resolution=512 --center_crop --random_flip \--train_batch_size=1 \--gradient_accumulation_steps=4 \--gradient_checkpointing \--max_train_steps=15000 \--learning_rate=1e-05 \--max_grad_norm=1 \--seed=42 \--lr_scheduler="constant" \--lr_warmup_steps=0 \--output_dir="sd-naruto-model"

如果你的模型或数据集用的是上面的网盘下载,那么你需要做下面的两件事:

第一步:将数据集和模型文件夹放到训练脚本同一目录下,文件结构如下:

|--- sd_config.py
|--- train_sd1-5_naruto.py
|--- stable-diffusion-v1-5
|--- naruto-blip-captions

stable-diffusion-v1-5是下载好的模型文件夹,naruto-blip-captions是下载好的数据集文件夹。

第二步:修改sd_config.py的代码,将pretrained_model_name_or_pathdataset_name的default值分别改为下面这样:

    parser.add_argument("--pretrained_model_name_or_path",type=str,default="./stable-diffusion-v1-5",)parser.add_argument("--dataset_name",type=str,default="./naruto-blip-captions",)

然后运行启动命令即可。


看到下面的进度条即代表训练开始:

在这里插入图片描述

6. 训练结果演示

我们在SwanLab上查看最终的训练结果:

在这里插入图片描述

可以看到SD训练的特点是loss一直在震荡,随着epoch的增加,loss在最初下降后,后续的变化其实并不大:

在这里插入图片描述

我们来看看主观生成的图像,第一个epoch的图像长这样:

在这里插入图片描述

可以看到詹姆斯还是非常的“原生态”,迈克尔杰克逊生成的也怪怪的。。。

再看一下中间的状态:

在这里插入图片描述

在这里插入图片描述

经过比较长时间的训练后,效果就好了不少。

比较有意思的是,比尔盖茨生成出来的形象总是感觉非常邪恶。。。

详细训练过程看这里:SD-Naruto - SwanLab

至此,你已经完成了SD模型在火影忍者数据集上的训练。

相关链接

  • 代码:Github
  • 实验日志过程:SD-naruto - SwanLab
  • 模型:runwayml/stable-diffusion-v1-5
  • 数据集:lambdalabs/naruto-blip-captions
  • SwanLab:https://swanlab.cn

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/29191.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python | Leetcode Python题解之第162题寻找峰值

题目: 题解: class Solution:def findPeakElement(self, nums: List[int]) -> int:n len(nums)# 辅助函数,输入下标 i,返回 nums[i] 的值# 方便处理 nums[-1] 以及 nums[n] 的边界情况def get(i: int) -> int:if i -1 or…

数实融合创新发展 隆道分享企业级AI应用

在数字化浪潮的推动下,人工智能(AI)技术正以前所未有的速度改变着各行各业,重塑企业的运营管理模式和创新发展路径。6月14日,数实融合全国行(潍坊站)暨 AI 企业级应用专题会在山东潍坊成功召开。…

STM32单片机DMA存储器详解

文章目录 1. DMA概述 2. 存储器映像 3. DMA框架图 4. DMA请求 5. 数据宽度与对齐 6. DMA数据转运 7. ADC扫描模式和DMA 8. 代码示例 1. DMA概述 DMA(Direct Memory Access)可以直接访问STM32内部的存储器,DMA是一种技术,…

【随性】学习感想

这几天的作息时间:13:00-23:00,稍显疲惫。为了继续调整作战,明天改变生物钟,尝试新作息时间:9:00-20:00!最近已经学完了C与QT,又重回Linux的怀抱,以下为适应Linux下C编程的代码&…

【 ARMv8/ARMv9 硬件加速系列 3.5.1 -- SVE 谓词寄存器有多少位?】

文章目录 SVE 谓词寄存器(predicate registers)简介SVE 谓词寄存器的位数SVE 谓词寄存器对向量寄存器的控制SVE 谓词寄存器位数计算SVE 谓词寄存器小结SVE 谓词寄存器(predicate registers)简介 ARMv9的Scalable Vector Extension (SVE) 引入了谓词寄存器(Predicate Register…

打造工业操作系统开源开放体系

我国制造业具有细分行业、领域众多,产品丰富,制造模式多样等特点,围绕以工业操作系统为核心的工业软件赋能体系建设,离不开平台运营商、工业软件开发商、系统服务商、科研机构、工业企业等多方联合参与。聚众同行、聚力创新&#…

【数据库系统概论复习】关系数据库与关系代数笔记

文章目录 基本概念数据库基本概念关系数据结构完整性约束 关系代数关系代数练习课堂练习 语法树 基本概念 数据库基本概念 DB 数据库, 为了存用户的各种数据,我们要建很多关系(二维表),所以把相关的关系(二…

创建型模式--抽象工厂模式

产品族创建–抽象工厂模式 工厂方法模式通过引入工厂等级结构,解决了简单工厂模式中工厂类职责太重的问题。 但由于工厂方法模式中的每个工厂只生产一类产品,可能会导致系统中存在大量的工厂类,势必会增加系统的开销。此时,可以考虑将一些相关的产品组成一个“产品族”,…

Java基础16(集合 List)

目录 一、什么是集合? 二、集合接口 三、List集合 1. ArrayList容器类 1.1 常用方法 1.1.1 增加 1.1.2 查找 int size() E get(int index) int indexOf(Object c) boolean contains(Object c) boolean isEmpty() List SubList(int fromindex,i…

运行SpringBoot项目失败?代码出现爆红横线,提示“No beans of ‘UserService‘ type found”让我来看看~

今天在做实验运行项目的时候,发现userService: 一直在提示“No beans of UserService type found”,回去翻了Service业务层的代码,Service注解我也加了呀,奇了怪了。 运行项目,出现了这样的提示&#xff1…

求最小公倍数 、小球走过路程计算 题目

题目 JAVA11 求最小公倍数分析:代码:大佬代码: JAVA12 小球走过路程计算分析:代码: JAVA11 求最小公倍数 描述 编写一个方法,该方法的返回值是两个不大于100的正整数的最小公倍数。 输入描述:…

向mysql发送一个请求的时候,mysql到底做了什么

当 MySQL 接收到一个请求时,它会经过多个步骤来处理该请求,具体包括解析、优化、执行以及返回结果。这一过程涉及 MySQL 的多个组件和机制。 1. 客户端/服务器通信 当客户端发送一个请求到 MySQL 服务器时,通信是通过 TCP/IP 协议、Unix 套接字或其他支持的协议进行的。 …

immich上传库中删除本地照片后前台页面仍显示照片的问题解决方法

最近用了immich来管理照片,感觉很好用。 由于刚上手不了解使用方法,遇到了在上传库(upload库)删除本地照片后前台页面仍显示照片的问题。看了官方文档后有了解决方法,遂进行记录。 事件背景: immich有两…

LeetCode 310 最小高度树

题目信息 LeetoCode地址: . - 力扣(LeetCode) 题目理解 可以通过归纳法证明出一棵树的最小高度树可以通过从最外面度为1的叶子节点一层一层向内遍历得到,可以使用一种称为 **中心缩减法**(或称为 **剥洋葱法**)的方…

StarRocks分区表历史数据删除与管理

一、背景介绍 在使用 StarRocks 时,可能会遇到需要删除大批量数据的情况。然而,StarRocks 对 DELETE 操作的支持并不理想,主要存在以下问题: 不建议执行高频的 DELETE 操作:删除的数据会标记为“Deleted”&#xff0…

判断一组数据哪些是素数,并统计一个数组中元素的出现频率

import java.util.HashMap; import java.util.Map; public class Test_A26 {//判断一个数是不是素数public static boolean isPrime(int num){if(num<1){return false;}for(int i2;i<Math.sqrt(num);i){if(num%i0){return false;}}return true;}//统计数组中出现的频率 p…

python安装目录文件说明----Dlls文件夹

在Python的安装目录下&#xff0c;通常会有一个DLLs文件夹&#xff0c;它是Python标准库的一部分。这个文件夹包含了一些动态链接库&#xff08;Dynamic Link Libraries&#xff0c;DLL&#xff09;&#xff0c;这些库提供了Python解释器和标准库的一些关键功能。以下是对这个文…

模拟自动滚动并展开所有评论列表以及回复内容(如:抖音、b站等平台)

由于各大视频平台的回复内容排序不都是按照时间顺序&#xff0c;而且想看最新的评论回复讨论内容还需逐个点击展开&#xff0c;真的很蛋疼&#xff0c;尤其是热评很多的情况&#xff0c;还需要多次点击展开&#xff0c;太麻烦&#xff01; 于是写了一个自动化展开所有评论回复…

Kaggle比赛:成人人口收入分类

拿到数据首先查看数据信息和描述 import pandas as pd import seaborn as sns import matplotlib.pyplot as plt # 加载数据&#xff08;保留原路径&#xff0c;但在实际应用中建议使用相对路径或环境变量&#xff09; data pd.read_csv(r"C:\Users\11794\Desk…

嵌入式技术学习——c51——串口

一、串口介绍。 串口是一个 通讯接口。成本低&#xff0c;容易使用&#xff0c;通信线路简单&#xff0c;可实现两个设备的相互通信 单片机的串口可以实现单片机于单片机&#xff0c;单片机与电脑&#xff0c;单片机与其他模块相互通信。 51单片机内部自带UART&#xff0c;通…