扩散模型实战(十):Stable Diffusion文本条件生成图像大模型

推荐阅读列表:

 扩散模型实战(一):基本原理介绍

扩散模型实战(二):扩散模型的发展

扩散模型实战(三):扩散模型的应用

扩散模型实战(四):从零构建扩散模型

扩散模型实战(五):采样过程

扩散模型实战(六):Diffusers DDPM初探

扩散模型实战(七):Diffusers蝴蝶图像生成实战

扩散模型实战(八):微调扩散模型

扩散模型实战(九):使用CLIP模型引导和控制扩散模型

        在AIGC时代,Stable Diffusion无疑是其中最亮的“仔”,它是一个强大的文本条件隐式扩散模型(text-conditioned latent diffusion model),可以根据文字描述(也称为Prompt)生成精美图片。

一、基本概念

1.1 隐式扩散

       对于基于transformer的大模型来说,self-attention的计算复杂度与输入数据是平方关系的,比如一张128X128像素的图片在像素数量上是64X64像素图片的4倍,内存和计算量是16倍。这正是高分辨率图像生成任务存在的普遍现象。

       为了解决这个问题,提出了隐式扩散(Latent Diffusion)方法,该方法认为图片通常包含大量冗余信息,首先使用大量图片数据训练一个Variational Auto-Encode(VAE)模型,编码器将图片映射到一个较小的隐式表示,解码器可以将较小的隐式表示映射到原始图片。Stable Diffusion中的VAE接受一张3通道图片作为输入,生成一个4通道的隐式特征,同时每一个空间维度都将减少为原来的八分之一。例如,一张512X512像素的图片可以被压缩到一个4X64X64的隐式表示。

       通过在隐式表示(而不是完整图像)上进行扩散,可以使用更少的内存也可以减少UNet层数,从而加速图片生成,极大降低了训练和推理成本。
        隐式扩散的结构,如下图所示:

1.2 以文本为生成条件

       前面的章节展示了如何将额外信息输入给UNet,以实现对生成图像的控制,这种方法称为条件生成。以文本为条件进行控制图像的生成是在推理阶段,我们可以输入期望图像的文本描述(Prompt),并把纯噪声数据作为起点,然后模型对噪声数据进行“去噪”,从而生成能够匹配文本描述的图像。那么这个过程是如何实现的呢?

      我们需要对文本进行编码表示,然后输入给UNet作为生成条件,文本嵌入表示如下图ENCODER_HIDDEN_STATES

       Stable Diffusion使用CLIP对文本描述进行编码,首先对输入文本描述进行分词,然后输入给CLIP文本编码器,从而为每个token产生一个768维(Stable Diffusion 1.x版本)或者1024维(Stable Diffusion 2.x版本)向量,为了使输入格式一致,文本描述总是被补全或者截断为77个token。

       那么,如何将这些条件信息输入给UNet进行预测呢?答案是使用交叉注意力(cross-attention)机制。UNet网络中的每个空间位置都可以与文本条件中的不同token建立注意力(在稍后的代码中可以看到具体的实现),如下图所示:

1.3 无分类器引导

       第2节我们提到可以使用CLIP编码文本描述来控制图像的生成,但是实际使用中,每个生成的图像都是按照文本描述生成的吗?当然不一定,其实是大模型的幻觉问题,原因可能是训练数据中图像与文本描述相关性弱,模型可能学着不过度依赖文本描述,而是从大量图像中学习来生成图像,最终达不到我们的期望,那如何解决呢?

       我们可以引入一个小技巧-无分类器引导(Classifier-Free Guidance,CFG)。在训练时,我们时不时把文本条件置空,强制模型去学习如何在无文字信息的情况下对图像“去噪”。在推理阶段,我们分别进行了两个预测:一个有文字条件,另一个则没有文字条件。这样我们就可以利用两者的差异来建立一个最终的预测了,并使最终结果在文本条件预测所指明的方向上依据一个缩放系数(引导尺度)更好的生成文本描述匹配的结果。从下图看到,更大的引导尺度能让生成的图像更接近文本描述。

1.4 其他类型的条件生成模型:Img2Img、Inpainting与Depth2Img模型

       其实除了使用文本描述作为条件生成图像,还有其他不同类型的条件可以控制Stable Diffusion生成图像,比如图片到图片、图片的部分掩码(mask)到图片以及深度图到图片,这些模型分别使用图片本身、图片掩码和图片深度信息作为条件来生成最终的图片。

       Img2Img是图片到图片的转换,包括多种类型,如风格转换(从照片风格转换为动漫风格)和图片超分辨率(给定一张低分辨率图片作为条件,让模型生成对应的高分辨率图片,类似Stable Diffusion Upscaler)。Inpainting又称图片修复,模型会根据掩码的区域信息和掩码之外的全局结构信息生成连贯的图片。Depth2Img采用图片的深度新作为条件,模型生成与深度图本身相似的具有全局结构的图片,如下图所示:

1.5 使用DreamBooth微调扩散模型

      DreamBooth可以微调文本到图像的生成模型,最初是为Google的Imagen Model开发的,很快被应用到Stable Diffusion中。它可以根据用户提供的一个主题3~5张图片,就可以生成与该主题相关的图像,但它对于各种设置比较敏感。

二、环境准备

安装python库

pip install -Uq diffusers ftfy acceleratepip install -Uq git+https://github.com/huggingface/transformers

数据准备

import torchimport requestsfrom PIL import Imagefrom io import BytesIOfrom matplotlib import pyplot as plt # 这次要探索的管线比较多from diffusers import (    StableDiffusionPipeline,     StableDiffusionImg2ImgPipeline,    StableDiffusionInpaintPipeline,     StableDiffusionDepth2ImgPipeline    )        # 因为要用到的展示图片较多,所以我们写了一个旨在下载图片的函数def download_image(url):    response = requests.get(url)    return Image.open(BytesIO(response.content)).convert("RGB") # Inpainting需要用到的图片img_url = "https://raw.githubusercontent.com/CompVis/latent- diffusion/main/data/inpainting_examples/overture-creations- 5sI6fQgYIuo.png"mask_url = "https://raw.githubusercontent.com/CompVis/latent- diffusion/main/data/ inpainting_examples/overture-creations- 5sI6fQgYIuo_mask.png" init_image = download_image(img_url).resize((512, 512))mask_image = download_image(mask_url).resize((512, 512)) device = (    "mps"    if torch.backends.mps.is_available()    else "cuda"    if torch.cuda.is_available()    else "cpu")

三、使用文本描述控制生成图像

       加载Stable Diffusion Pipeline,当然可以通过model_id切换Stable Diffusion版本

# 载入管线model_id = "stabilityai/stable-diffusion-2-1-base"pipe = StableDiffusionPipeline.from_pretrained(model_id).to(device)

如果GPU显存不足,可以尝试以下方法来减少GPU显存的使用

  • 降低模型的精度为FP16
pipe = StableDiffusionPipeline.from_pretrained(model_id,    revision="fp16",torch_dtype=torch.float16).to(device)
  • 开启注意力切分功能,可以降低速度来减少GPU显存的使用
pipe.enable_attention_slicing()

  • 减小生成图像的尺寸
# 给生成器设置一个随机种子,这样可以保证结果的可复现性generator = torch.Generator(device=device).manual_seed(42) # 运行这个管线pipe_output = pipe(    prompt="Palette knife painting of an autumn cityscape",    # 提示文字:哪些要生成    negative_prompt="Oversaturated, blurry, low quality",    # 提示文字:哪些不要生成    height=480, width=640,     # 定义所生成图片的尺寸    guidance_scale=8,          # 提示文字的影响程度    num_inference_steps=35,    # 定义一次生成需要多少个推理步骤    generator=generator        # 设定随机种子的生成器) # 查看生成结果,如图6-7所示pipe_output.images[0]

主要参数介绍:

width和height:用于指定生成图片的尺寸,他们必须可以被8整除,否则VAE不能整除工作;

num_inference_steps:会影响生成图片的质量,采用默认50即可,用户也可以尝试不同的值来对比一下效果;

negative_prompt:用于强调不希望生成的内容,该参数一般在无分类器引导的情况下使用。列出一些不想要的特征,以帮助模型生成更好的结果;

guidance_scale:决定了无分类器引导的影响强度。增大这个参数可以使生成的内容更接近给出的文本描述,但是参数值过大,则可能导致结果过于饱和,不美观,如下图所示:

cfg_scales = [1.1, 8, 12] prompt = "A collie with a pink hat" fig, axs = plt.subplots(1, len(cfg_scales), figsize=(16, 5))for i, ax in enumerate(axs):    im = pipe(prompt, height=480, width=480,        guidance_scale=cfg_scales[i], num_inference_steps=35,        generator=torch.Generator(device=device).manual_seed(42)).            images[0]     ax.imshow(im); ax.set_title(f'CFG Scale {cfg_scales[i]}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/151709.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【推荐】智元兔AI:一款集写作、问答、绘画于一体的全能工具!

在当今技术飞速发展的时代,越来越多的领域开始应用人工智能(Artificial Intelligence,简称AI)。其中,AI写作工具备受瞩目,备受推崇。在众多的选择中,智元兔AI是一款在笔者使用过程中非常有帮助的…

Halcon Solution Guide I basics(2): Image Acquisition(图像加载)

文章目录 文章专栏前言文章解读文章开头流程图算子介绍案例自主练习读取一张图片读取多张图片 文章专栏 Halcon开发 Halcon学习 练习项目gitee仓库 前言 今天来看Halcon的第二章,图像获取。在第二章之后,后面文章就会提供案例了。到时候我会尽量完成每一…

场景交互与场景漫游-交运算与对象选取(8-1)

交运算与对象选取 在面对大规模的场景管理时,场景图形的交运算和图形对象的拾取变成了一项基本工作。OSG作为一个场景管理系统,自然也实现了场景图形的交运算,交运算主要封装在osgUtil 工具中在OSG中,osgUtil是一个非常强有力的工…

【Python】给定一个长度为n的数列,将这个数列按从小到大的顺序排列。1<=n<=200

2、问题描述 给定一个长度为n的数列&#xff0c;将这个数列按从小到大的顺序排列。1<n<200 样例输入 5 8 3 6 4 9 样例输出 3 4 6 8 9 n int(input()) a list(map(int,input().split())) a.sort() for i in a:print(i,end ) 运行结果&#xff1a;

毕业设计JSP 2384网上diy蛋糕店管理系统【程序源码+讲解视频+调试运行】

一、摘要 本文将介绍一个功能全面、易于使用的网上DIY蛋糕店管理系统。该系统包括用户和管理员两种用户&#xff0c;每种用户都有相应的功能模块。系统实现了网站首页、用户注册/登录、蛋糕展示、综合排行、购物车、蛋糕DIY和用户中心等功能&#xff0c;同时管理员还可以进行管…

庖丁解牛:NIO核心概念与机制详解 01 _ 入门篇

文章目录 Pre输入/输出Why NIO流与块的比较通道和缓冲区概述什么是缓冲区&#xff1f;缓冲区类型什么是通道&#xff1f;通道类型 NIO 中的读和写概述Demo : 从文件中读取1. 从FileInputStream中获取Channel2. 创建ByteBuffer缓冲区3. 将数据从Channle读取到Buffer中 Demo : 写…

算法-二叉树-简单-二叉树的最大和最小深度

记录一下算法题的学习7 二叉树的最大深度 题目&#xff1a;给定一个二叉树 root &#xff0c;返回其最大深度。 二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 输入&#xff1a;root [3,9,20,null,null,15,7] 输出&#xff1a;3 示例分析&#xff…

MATLAB 状态空间设计 —— LQG/LQR 和极点配置算法

系列文章目录 文章目录 系列文章目录前言一、相关函数 —— LQG/LQR 和极点配置算法1.1 LQR —— lqr 函数1.1.1 函数用法1.1.2 举例1.1.2.1 倒摆模型的 LQR 控制 1.2 LQG —— lqg() 函数1.2.1 函数用法1.2.2 举例 1.3 极点配置 —— place() 函数1.3.1 函数用法1.3.2 示例1.3…

Selenium安装WebDriver最新Chrome驱动(含116/117/118/119)

&#x1f4e2;专注于分享软件测试干货内容&#xff0c;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; 如有错误敬请指正&#xff01;&#x1f4e2;交流讨论&#xff1a;欢迎加入我们一起学习&#xff01;&#x1f4e2;资源分享&#xff1a;耗时200小时精选的「软件测试」资…

如何在虚拟机的Ubuntu22.04中设置静态IP地址

为了让Linux系统的IP地址在重新启动电脑之后IP地址不进行变更&#xff0c;所以将其IP地址设置为静态IP地址。 查看虚拟机中虚拟网络编辑器获取当前的子网IP端 修改文件/etc/netplan/00-installer-config.yaml文件&#xff0c;打开你会看到以下内容 # This is the network conf…

面向开发者的Android

Developerhttps://developer.android.google.cn/?hlzh-cn SDK 平台工具版本说明https://developer.android.google.cn/studio/releases/platform-tools?hlzh-cn#revisions Android SDK Platform-Tools 是 Android SDK 的一个组件。它包含与 Android 平台进行交互的工具…

【Redis】springboot整合redis(模拟短信注册)

要保证redis的服务器处于打开状态 上一篇&#xff1a; 基于session的模拟短信注册 https://blog.csdn.net/m0_67930426/article/details/134420531 整个流程是&#xff0c;前端点击获取验证码这个按钮&#xff0c;后端拿到这个请求&#xff0c;通过RandomUtil 工具类的方法生…

Labview中for循环“无法终止”问题?即使添加了条线接线端,达到终止条件后,仍在持续运行?

关键&#xff1a; 搞清楚“运行”和“连续运行”两种运行模式的区别。 出现题目中所述问题&#xff0c;大概率是因为代码运行在“连续运行“模式下。 可以通过添加 探针 的方式&#xff0c;加深理解&#xff01;

拼图游游戏代码

一.创建新项目 二.插入图片 三.游戏的主界面 1.代码 package com.itheima.ui;import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.awt.event.KeyEvent; import java.awt.event.KeyListener; import java.util.Random;import javax.swing…

pnpm : 无法加载文件 E:\Soft\PromSoft\nodejs\node_global\pnpm.ps1,

pnpm : 无法加载文件 E:\Soft\PromSoft\nodejs\node_global\pnpm.ps1&#xff0c;因为在此系统上禁止运行脚本。有关详细信息&#xff0c;请参阅 https:/go.microsoft.com/fwlink/?LinkID135170 中 的 about_Execution_Policies。 所在位置 行:1 字符: 1pnpm -v~~~~ CategoryI…

Django 入门学习总结6 - 测试

1、介绍自动化测试 测试的主要工作是检查代码的运行情况。测试有全覆盖和部分覆盖。 自动测试表示测试工作由系统自动完成。 在大型系统中&#xff0c;有许多组件有很复杂的交互。一个小的变化可能会带来意想不到的后果 测试能发现问题&#xff0c;并以此解决问题。 测试驱…

FPGA实现平衡小车(文末开源!!)

FPGA平衡小车 一. 硬件介绍 底板资源: TB6612电机驱动芯片 * 2 MPU6050陀螺仪 WS2812 RGB彩色灯 * 4 红外接收头 ESP-01S WIFI 核心板 微相 A7_Lite Artix-7 FPGA开发板 电机采用的是平衡小车之家的MG310(GMR编码器)电机。底板上有两个TB6612芯片&#xff0c;可以驱动…

C++设计模式——单例模式

单例设计模式 应用场景特点设计模式分类懒汉设计模式饿汉设计模式使用编写的测试代码运行结果 应用场景 当多个类都需要调用某一个类的一些公共接口&#xff0c;同时不想创建多个该类的对象&#xff0c;可以考虑将该类封装为一个单例模式。 特点 单例模式的特点&#xff1a;…

UnitTest框架

目标&#xff1a; 1.掌握UnitTest框架的基本使用方法 2.掌握断言的使用方法 3.掌握如何实现参数化 4.掌握测试报告的生成 1.定义 &#xff08;1&#xff09;框架(framework)&#xff1a;为解决一类事情的功能集合。&#xff08;需要按照框架的规定(套路) 去书写代码&…