【深度学习】StableDiffusion的组件解析,运行一些基础组件效果

文章目录

  • 前言
  • vae
  • clip
  • UNet
  • unet训练
  • 帮助、问询

前言

看了篇文:

https://zhuanlan.zhihu.com/p/617134893

运行一些组件试试效果。

vae

代码:

import torch
from diffusers import AutoencoderKL
import numpy as np
from PIL import Image# 加载模型: autoencoder可以通过SD权重指定subfolder来单独加载
autoencoder = AutoencoderKL.from_pretrained("/ssd/xiedong/src_data/eff_train/Stable-diffusion/majicmixRealistic_v7_diffusers", subfolder="vae")
autoencoder.to("cuda", dtype=torch.float16)# 读取图像并预处理
raw_image = Image.open("girl.png").convert("RGB").resize((512, 512))
image = np.array(raw_image).astype(np.float32) / 127.5 - 1.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)# 压缩图像为latent并重建
with torch.inference_mode():# latentx形状是 (B, C, H, W) 的张量  (1,4,64,64)latentx = autoencoder.encode(image.to("cuda", dtype=torch.float16)).latent_dist.sample()  # 压缩# 保存 latent 为 PNG 图像latent = latentx.permute(0, 2, 3, 1)  # 将 latent 重新排列为 (B, H, W, C) 格式latent = latent.cpu().numpy()  # 将 latent 转换为 NumPy 数组latent = (latent * 127.5 + 127.5).astype('uint8')  # 将值缩放到 [0, 255] 范围内,并转换为 uint8 类型latent = latent.squeeze(0)  # 去掉批次维度latent_image = Image.fromarray(latent)  # 将 NumPy 数组转换为 PIL Imagelatent_image.save("latent.png")  # 保存为 PNG 图像# shapeprint(latentx.shape)rec_image = autoencoder.decode(latentx).samplerec_image = (rec_image / 2 + 0.5).clamp(0, 1)rec_image = rec_image.cpu().permute(0, 2, 3, 1).numpy()rec_image = (rec_image * 255).round().astype("uint8")rec_image = Image.fromarray(rec_image[0])# save
rec_image.save("demo.png")

原图:

在这里插入图片描述

encoder之后:

在这里插入图片描述
decoder之后:

在这里插入图片描述

clip

在自然语言处理任务中,tokenizer和text_encoder是两个重要的组件,用于将文本转换为模型可以理解的数值表示形式。

  1. Tokenizer:

Tokenizer的作用是将一个文本序列(如句子或段落)分割成一系列的token(通常是单词或子词)。它将文本映射到一个词表(vocabulary),每个token对应词表中的一个索引(index)。在您的代码示例中,tokenizer将输入的文本"a photograph of an astronaut riding a horse"转换为对应的token id序列。

  1. Text Encoder:

Text Encoder(文本编码器)是一个预训练的语言模型,它可以将token id序列编码为对应的向量表示(embeddings)。这些向量表示捕获了单词及其上下文的语义信息。

在您的代码中:

  • tokenizer(prompt) 将文本"a photograph of an astronaut riding a horse"转换为对应的token id序列(如[1, 27, 38, 61, …]等)。
  • text_encoder(text_input_ids) 将这个token id序列输入到文本编码器模型中,得到对应的向量表示text_embeddings

最终得到的text_embeddings的形状是 torch.Size([1, 77, 768])。其中:

  • 1 表示批次大小(batch size),对于单个输入就是1。
  • 77 表示输入token的数量(由于padding,长度被填充到模型的最大长度)。
  • 768 是每个token的向量表示的维度。

通过这种编码方式,原始的文本被转换为了具有丰富语义信息的数值向量表示,方便被深度学习模型进一步处理。这种编码过程是自然语言处理中的常见做法。

from transformers import CLIPTextModel, CLIPTokenizertext_encoder = CLIPTextModel.from_pretrained("/ssd/xiedong/src_data/eff_train/Stable-diffusion/majicmixRealistic_v7_diffusers", subfolder="text_encoder").to("cuda")
# text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda")
tokenizer = CLIPTokenizer.from_pretrained("/ssd/xiedong/src_data/eff_train/Stable-diffusion/majicmixRealistic_v7_diffusers", subfolder="tokenizer")
# tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")# 对输入的text进行tokenize,得到对应的token ids
prompt = "a photograph of an astronaut riding a horse"
text_input_ids = tokenizer(prompt,padding="max_length",max_length=tokenizer.model_max_length,truncation=True,return_tensors="pt"
).input_idsprint(text_input_ids.shape)  # torch.Size([1, 77]
# 将token ids送入text model得到77x768的特征
text_embeddings = text_encoder(text_input_ids.to("cuda"))[0]
print(text_embeddings.shape)  # torch.Size([1, 77, 768])

UNet

稳定扩散(Stable Diffusion,SD)是一种基于扩散模型的生成式人工智能模型,用于创建图像。它的核心是一个860万参数的UNet结构,负责将文本提示转化为图像。

UNet结构主要由编码器(Encoder)和解码器(Decoder)两部分组成,两部分通过跳跃连接(Skip Connection)相连。编码器用于捕获输入的潜在表征(Latent Representation),解码器则根据这些表征生成最终图像。

具体而言,编码器包含:

  • 3个CrossAttnDownBlock2D模块,每个模块会对输入进行2倍下采样
  • 1个DownBlock2D模块,不进行下采样

中间是一个UNetMidBlock2DCrossAttn模块,用于连接编码器和解码器。

解码器包含:

  • 1个UpBlock2D模块
  • 3个CrossAttnUpBlock2D模块

编码器和解码器的模块数量和结构是对应的,通过跳跃连接融合不同尺度的特征。

在这里插入图片描述

CrossAttnDownBlock2D 和 CrossAttnUpBlock2D 是 Stable Diffusion 中 UNet 架构的关键模块,用于将文本条件(text condition)融入到图像生成的过程中。它们利用了自注意力(Self-Attention)机制,将文本嵌入与图像特征进行交叉注意力(Cross-Attention)操作。

以 CrossAttnDownBlock2D 为例,其核心是一个 CrossAttention 模块,该模块的工作原理如下:

  1. 将文本条件(如"一只可爱的小狗")编码为文本嵌入(text embeddings),记为 E t e x t \mathbf{E}_{text} Etext

  2. 从 UNet 的中间层获取图像特征,记为 F i m a g e \mathbf{F}_{image} Fimage

  3. 在 CrossAttention 中,将 F i m a g e \mathbf{F}_{image} Fimage 作为 Query,而 E t e x t \mathbf{E}_{text} Etext 作为 Key 和 Value,计算注意力权重:

    Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

    其中 Q = F i m a g e Q=\mathbf{F}_{image} Q=Fimage, K = E t e x t K=\mathbf{E}_{text} K=Etext, V = E t e x t V=\mathbf{E}_{text} V=Etext, d k d_k dk 是缩放因子。

  4. 将注意力权重与 F i m a g e \mathbf{F}_{image} Fimage 相加,得到融合了文本条件的新特征表示。

这个过程可以用以下伪代码表示:

class CrossAttention(nn.Module):def forward(self, x, context):"""x: 图像特征 [batch, channels, height, width]context: 文本嵌入 [batch, text_len, text_dim]"""q = x.view(x.size(0), -1, x.size(-1))  # [batch, channels*height*width, dim]k = context.permute(0, 2, 1)  # [batch, text_dim, text_len]v = context.permute(0, 2, 1)  # [batch, text_dim, text_len]attn = torch.bmm(q, k)  # [batch, channels*height*width, text_len]attn = attn / sqrt(k.size(-1))attn = softmax(attn, dim=-1)x = torch.bmm(attn, v)  # [batch, channels*height*width, text_dim]x = x.view(x.size(0), -1, x.size(-1))  # [batch, channels, height, width]return x

CrossAttnUpBlock2D 的工作方式与 CrossAttnDownBlock2D 类似,只是它位于 UNet 的解码器部分。通过这种交叉注意力机制,Stable Diffusion 能够将文本条件与图像特征进行融合,从而根据提示生成所需的图像。

CrossAttention 模块是 Stable Diffusion 中实现文本到图像生成的关键部分,它利用注意力机制将文本信息注入到图像特征中,使生成的图像能够匹配给定的文本描述。

unet训练

SD训练过程原理:

SD模型的训练过程可以概括为以下几个步骤:

  1. 将图像编码到潜在空间,获得潜在向量表示。
  2. 使用CLIP文本编码器获取文本的嵌入向量表示。
  3. 在潜在空间中采样噪声,并将其添加到潜在向量中,进行扩散过程。
  4. U-Net模型接收扩散后的噪声潜在向量和文本嵌入,预测原始的噪声向量。
  5. 计算预测噪声和实际噪声之间的均方误差作为损失函数。
  6. 通过反向传播优化U-Net模型参数,使其能够更好地预测噪声。

整个过程可以用以下公式表示:

L = E x 0 , ϵ , t [ ∥ ϵ − ϵ θ ( x t , t , y ) ∥ 2 ] \mathcal{L} = \mathbb{E}_{x_0, \epsilon, t} \left[\left\| \epsilon - \epsilon_\theta(x_t, t, y) \right\|^2\right] L=Ex0,ϵ,t[ϵϵθ(xt,t,y)2]

其中:

  • x 0 x_0 x0是原始的潜在向量
  • ϵ \epsilon ϵ是添加到潜在向量上的噪声
  • t t t是随机采样的时间步长
  • x t x_t xt是扩散后的噪声潜在向量
  • y y y是文本嵌入向量
  • ϵ θ \epsilon_\theta ϵθ是U-Net模型预测的噪声

通过最小化这个损失函数,U-Net模型可以学习到从噪声潜在向量和文本嵌入中预测原始噪声的能力。在生成图像时,可以通过从随机噪声开始,逐步去噪,并根据文本嵌入对每一步的去噪过程进行条件控制,最终得到符合文本描述的图像。

Classifier-Free Guidance (CFG)是一种在训练和生成过程中提高文本条件控制的技术。它的核心思想是在训练时,以一定概率随机丢弃文本嵌入,这样模型就可以同时学习条件预测和无条件预测。在生成时,将条件预测和无条件预测的结果进行加权融合,从而增强文本条件的影响。CFG可以用以下公式表示:

ϵ θ ( C F G ) = ϵ θ ( x t , t , y ) + s ⋅ ( ϵ θ ( x t , t , ∅ ) − ϵ θ ( x t , t , y ) ) \epsilon_\theta^{(CFG)} = \epsilon_\theta(x_t, t, y) + s \cdot (\epsilon_\theta(x_t, t, \emptyset) - \epsilon_\theta(x_t, t, y)) ϵθ(CFG)=ϵθ(xt,t,y)+s(ϵθ(xt,t,)ϵθ(xt,t,y))

其中:

  • ϵ θ ( x t , t , y ) \epsilon_\theta(x_t, t, y) ϵθ(xt,t,y)是条件预测的噪声
  • ϵ θ ( x t , t , ∅ ) \epsilon_\theta(x_t, t, \emptyset) ϵθ(xt,t,)是无条件预测的噪声
  • s s s是一个scale参数,用于控制条件和无条件预测的权重

通过CFG,可以有效地提高生成图像与输入文本的一致性。

帮助、问询

https://docs.qq.com/sheet/DUEdqZ2lmbmR6UVdU?tab=BB08J2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/799059.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redis分布式锁误删情况说明

4.4 Redis分布式锁误删情况说明 逻辑说明: 持有锁的线程在锁的内部出现了阻塞,导致他的锁自动释放,这时其他线程,线程2来尝试获得锁,就拿到了这把锁,然后线程2在持有锁执行过程中,线程1反应过…

Linux 命令完全手册(1),受益匪浅

第一列返回的是行数,第二列是字数,第三列则是比特数。 我们可以让它只计算行数: wc -l test.txt 或者只计算字数: wc -w test.txt 或者只计算比特数: wc -c test.txt 在 ASCII 字符集中,比特数等于字…

python爬虫学习第十六天--------URLError和HTTPError、cookie登录、Handler处理器

🎈🎈作者主页: 喔的嘛呀🎈🎈 🎈🎈所属专栏:python爬虫学习🎈🎈 ✨✨谢谢大家捧场,祝屏幕前的小伙伴们每天都有好运相伴左右,一定要天天…

多轴机械臂/正逆解/轨迹规划/机器人运动学/Matlab/DH法 学习记录01——数学基础

系列文章目录 本科毕设正在做多轴机械臂相关的内容,这里是一个学习机械臂运动学课程的相关记录。 如有任何问题,可发邮件至layraliufoxmail.com问询。 1. 数学基础 文章目录 系列文章目录一、空间位置、姿态描述二、旋转矩阵(Rotation matri…

线程池的方式爬虫

<!--爬虫仅支持1.8版本的jdk--> <!-- 爬虫需要的依赖--> <dependency><groupId>org.apache.httpcomponents</groupId><artifactId>httpclient</artifactId><version>4.5.2</version> </dependency><!-- 爬虫需…

mysql修改密码提示: Your password does not satisfy the current policy requirements

1、问题概述&#xff1f; 环境说明&#xff1a; Red Hat Enterprise Linux7mysql5.7.10 执行如下语句报错&#xff1a; set password for rootlocalhost password(123456); ERROR 1819 (HY000): Your password does not satisfy the current policy requirements意思就是&a…

摄影杂记二

一、相机操作指南 ⑴按键说明&#xff1a; 除了常规的几个模式&#xff0c;里面就特殊场景可以看一下&#xff0c;有全景&#xff0c;支持摇摄。 lock&#xff1a;多功能锁。可以锁定控制按钮和控制环。在设置中找到多功能锁&#xff0c;可以设置锁定什么。 m-fn&#xff1a;多…

Go数据结构的底层原理(图文详解)

空结构体的底层原理 基本类型的字节数 fmt.Println(unsafe.Sizeof(0)) // 8 fmt.Println(unsafe.Sizeof(uint(0))) // 8 a : 0 b : &a fmt.Println(unsafe.Sizeof(b)) // 8int大小跟随系统字长指针的大小也是系统字长 空结构体 a : struct { }{} b : struct {…

国内ChatGPT大数据模型

在中国&#xff0c;随着人工智能技术的迅猛发展&#xff0c;多个科技公司和研究机构已经开发出了与OpenAI的ChatGPT类似的大型语言模型。这些模型通常基于深度学习技术&#xff0c;尤其是Transformer架构&#xff0c;它们在大量的文本数据上进行训练&#xff0c;以理解和生成自…

每天五分钟掌握深度学习框架pytorch:本专栏说明

专栏大纲 专栏计划更新章节在100章左右&#xff0c;之后还会不断更新&#xff0c;都会配备代码实现。以下是专栏大纲 部分代码实现 代码获取 为了方便用户浏览代码&#xff0c;本专栏将代码同步更新到github中&#xff0c;所有用户可以读完专栏内容和代码解析之后&#xff0c…

Struts2:Action类的写法,推荐使用继承ActionSupport类的方法

文章目录 方法一&#xff1a;Action类是一个POJO类&#xff08;简单的Java类&#xff09;ActionDemo2.javastruts_demo2.xmlstruts.xml运行结果其他strutsz_demo1.xml 方法二&#xff1a;实现一个Action的接口ActionDemo2_2.javastruts_demo2.xml运行结果 推荐&#xff01;&…

SiteSpace 使用方法笔记

目录 介绍下载及安装准备工作知网 CNKI 文献分析数据准备数据转换新建项目图形处理 介绍 CiteSpace 是一个用于可视化和分析科学文献的工具。它可以从科学文献库中提取关键词、作者、机构和引用关系等信息&#xff0c;并将其可视化为图形网络。 一些使用案例 下载及安装 下载…

Redis从入门到精通(九)Redis实战(六)基于Redis队列实现异步秒杀下单

文章目录 前言4.5 分布式锁-Redisson4.5.4 Redission锁重试4.5.5 WatchDog机制4.5.5 MutiLock原理 4.6 秒杀优化4.6.1 优化方案4.6.2 完成秒杀优化 4.7 Redis消息队列4.7.1 基于List实现消息队列4.7.2 基于PubSub的消息队列4.7.3 基于Stream的消息队列4.7.4 基于Stream的消息队…

Golang单元测试和压力测试

一.单元测试 1.1 go test工具 go语言中的测试依赖go test命令。编写测试代码和编写普通的Go代码过程类似&#xff0c;并不需要学习新的语法&#xff0c;规则和工具。 go test命令是一个按照一定约定和组织的测试代码的驱动程序。在包目录内&#xff0c;所有以_test.go为后缀名的…

零代码编程:用kimichat打造一个最简单的window程序

用kimichat可以非常方便的自动生成程序代码&#xff0c;有些小程序可能会频繁使用&#xff0c;如果每次都在vscode中执行就会很麻烦。常用的Python代码&#xff0c;可以直接做成一个window程序&#xff0c;点击就可以打开使用&#xff0c;方便很多。 首先&#xff0c;把kimich…

Tokenize Anything via Prompting

SAM的延续&#xff0c;把SAM输出的token序列用来进行分类&#xff0c;分割和一个自然语言的decoder处理&#xff0c;但其实现在多模态的图像的tokenizer也几乎都是用VIT来实现的。一开始认为这篇文章可能是关于tokenize的&#xff0c;tokenize还是很重要的&#xff0c;后来看完…

JVM虚拟机(一)介绍、JVM组成、堆、栈、方法区/元空间、直接内存

目录 一、JVM 介绍1.1 为什么要学 JVM&#xff1f;1.2 JVM 是什么&#xff1f; 二、JVM 组成2.1 程序计数器2.2 Java堆1&#xff09;JVM 内存结构2&#xff09;Java 1.7 和 1.8 中堆的区别 2.3 Java虚拟机栈1&#xff09;虚拟机栈 和 栈帧2&#xff09;常见面试题 2.4 方法区/元…

搜索二维矩阵2 合并两个有序链表

240. 搜索二维矩阵 II - 力扣&#xff08;LeetCode&#xff09; class Solution { public:bool searchMatrix(vector<vector<int>>& matrix, int target) {int i matrix.size() - 1, j 0;while(i > 0 && j < matrix[0].size()){if(matrix[i][j…

基于wsl的Ubuntu20.04上安装桌面环境

在子系统Ubuntu20.04上安装桌面环境 1. 更换软件源 由于Ubuntu默认的软件源在国外&#xff0c;有时候后可能会造成下载软件卡顿&#xff0c;这里我们更换为国内的阿里云源&#xff0c;其他国内源亦可。 双击打开Ubuntu20.04 LTS图标&#xff0c;在命令行中输入 # 备份原来的软…

Java(二)面向对象进阶

目录 面向对象 多态性 向下转型 Object equals() toString() clone() finalize() Static 单例模式 代码块 final 抽象类与抽象方法(或abstract关键字&#xff09; 接口 接口的多态性 接口的默认方法 内部类 成员内部类 局部内部类 枚举类 实现接口的枚举类 …