【扩散模型第三篇】Classifier Guidance 和 Classifier Free Guidance(CFG)

参考:
[1] 张振虎博客
[2] https://www.bilibili.com/video/BV1s8411i7cU/?spm_id_from=333.788&vd_source=9e9b4b6471a6e98c3e756ce7f41eb134
[3] https://zhuanlan.zhihu.com/p/660518657
[4] https://zhuanlan.zhihu.com/p/640631667

进食顺序

  • 1 前言
  • 2 Classifier Guidance
  • 3 Classifier Free Guidance

1 前言

我们在DDPM或DDIM生成图像时是通常是不可控的,因为它是由一张随即高斯噪声一步步去噪得到生成图像。如果我们想要这个过程是可控的话,最直观的一个做法就是在生成过程中加上一个条件 y y y,既整个过程的变为:
p ( x 1 : T ∣ x 0 , y ) p(x_{1:T}|x_0,y) p(x1:Tx0,y)
接下来就是讨论加上了条件 y y y对于公式有无影响。

首先扩散模型遵循马尔科夫链性质,所以我们可以得出:
p ( x t ∣ x t − 1 , y ) : = p ( x t ∣ x t − 1 ) p(x_t|x_{t-1},y) := p(x_t|x_{t-1}) p(xtxt1,y):=p(xtxt1)
基于这一事实,我们还可以推出:
q ^ ( x t ∣ x t − 1 ) = ∫ y q ^ ( x t , y ∣ x t − 1 ) d y = ∫ y q ^ ( x t ∣ y , x t − 1 ) q ^ ( y ∣ x t − 1 ) d y = ∫ y q ^ ( x t ∣ x t − 1 ) q ^ ( y ∣ x t − 1 ) d y = q ^ ( x t ∣ x t − 1 ) = q ^ ( x t ∣ x t − 1 , y ) \begin{aligned} \hat q(x_t|x_{t-1}) &= \int_y \hat{q}(x_t,y|x_{t-1})dy\\ &=\int_y\hat q(x_t|y,x_{t-1})\hat q(y|x_{t-1})dy\\ &= \int_y \hat q(x_t|x_{t-1})\hat q(y|x_{t-1})dy\\ & = \hat q(x_t|x_{t-1}) = \hat q(x_t|x_{t-1},y) \end{aligned} q^(xtxt1)=yq^(xt,yxt1)dy=yq^(xty,xt1)q^(yxt1)dy=yq^(xtxt1)q^(yxt1)dy=q^(xtxt1)=q^(xtxt1,y)

同样用全概率公式,可以推出:
在这里插入图片描述
所以我们可以断论:加上条件 y y y 对前向过程毫无影响

2 Classifier Guidance

逆向过程有如下公式:
p ^ ( x t − 1 ∣ x t , y ) = p ^ ( x t − 1 ∣ x t ) p ^ ( y ∣ x t − 1 , x t ) p ^ ( y ∣ x t ) \hat p(x_{t-1}|x_t,y)=\frac{\hat p(x_{t-1}|x_t)\hat p(y|x_{t-1},x_t)}{\hat p(y|x_t)} p^(xt1xt,y)=p^(yxt)p^(xt1xt)p^(yxt1,xt)
其中分母和 x t − 1 x_{t-1} xt1毫无关系,所以分母可以看作是常数 C C C
而我们知道加上条件对于扩散过程是没有影响的,所以我们还已知:
q ^ ( x t ∣ x t − 1 , y ) = q ( x t ∣ x t − 1 ) q ^ ( x 0 ) = q ( x 0 ) q ^ ( x 1 : T ∣ x 0 , y ) = ∏ t = 1 T q ( x t ∣ x t − 1 , y ) \hat q(x_t|x_{t-1},y) = q(x_t|x_{t-1})\\ \hat q(x_0)=q(x_0)\\ \hat q(x_{1:T}|x_0,y) = \prod_{t=1}^Tq(x_t|x_{t-1},y) q^(xtxt1,y)=q(xtxt1)q^(x0)=q(x0)q^(x1:Tx0,y)=t=1Tq(xtxt1,y)

现在我们未知的是 p ^ ( x t − 1 ∣ x t ) 和 p ^ ( y ∣ x t − 1 , x t ) \hat p(x_{t-1}|x_t)和\hat p(y|x_{t-1},x_t) p^(xt1xt)p^(yxt1,xt),现在来推导这两项:

1)推导 p ^ ( x t − 1 ∣ x t ) \hat p(x_{t-1}|x_t) p^(xt1xt)
根据贝叶斯公式,我们有:
p ^ ( x t − 1 ∣ x t ) = p ^ ( x t ∣ x t − 1 ) p ^ ( x t − 1 ) p ^ ( x t ) \hat p(x_{t-1}|x_t) = \frac{\hat p(x_t|x_{t-1})\hat p(x_{t-1})}{\hat p(x_t)} p^(xt1xt)=p^(xt)p^(xtxt1)p^(xt1)
我们已知条件 y y y 对扩散过程不影响(可以通过全概率公式推出),所以我们有
p ^ ( x t ∣ x t − 1 ) = p ( x t ∣ x t − 1 ) \hat p(x_t|x_{t-1})=p(x_t|x_{t-1}) p^(xtxt1)=p(xtxt1)
我们同样可以由全概率公式推出:
p ^ ( x t ) = p ( x t ) \hat p(x_t) = p(x_t) p^(xt)=p(xt)

所以
p ^ ( x t − 1 ∣ x t ) = p ( x t ∣ x t − 1 ) p ( x t − 1 ) p ( x t ) \hat p(x_{t-1}|x_t) = \frac{p(x_t|x_{t-1})p(x_{t-1})}{p(x_t)} p^(xt1xt)=p(xt)p(xtxt1)p(xt1)

2)推导 p ^ ( y ∣ x t − 1 , x t ) \hat p(y|x_{t-1},x_t) p^(yxt1,xt)

根据贝叶斯公式,有:
p ^ ( y ∣ x t − 1 , x t ) = p ^ ( x t ∣ y , x t − 1 ) p ^ ( y ∣ x t − 1 ) p ^ ( x t ∣ x t − 1 ) \hat p(y|x_{t-1},x_t)=\frac{\hat p(x_t|y,x_{t-1})\hat p(y|x_{t-1})}{\hat p(x_t|x_{t-1})} p^(yxt1,xt)=p^(xtxt1)p^(xty,xt1)p^(yxt1)
根据马尔可夫链性质,所以约去分子的第一项和分母,所以得到
p ^ ( y ∣ x t − 1 , x t ) = p ^ ( y ∣ x t − 1 ) \hat p(y|x_{t-1},x_t) = \hat p(y|x_{t-1}) p^(yxt1,xt)=p^(yxt1)

3)终极目标

所以我们的公式此刻为:
p ^ ( x t − 1 ∣ x t , y ) = q ( x t − 1 ∣ x t ) q ( y ∣ x t − 1 ) q ( y ∣ x t ) = C ∗ q ( x t − 1 ∣ x t ) ∗ q ( y ∣ x t − 1 ) \hat p(x_{t-1}|x_t,y) = \frac{q(x_{t-1}|x_t)q(y|x_{t-1})}{q(y|x_t)} = C*q(x_{t-1}|x_t)*q(y|x_{t-1}) p^(xt1xt,y)=q(yxt)q(xt1xt)q(yxt1)=Cq(xt1xt)q(yxt1)
其中第一项是常数,第二项为DDPM的目标,第三项既为分类器输出概率(根据 x t − 1 x_{t-1} xt1输出类别标签 y y y

4)问题与进一步推导
我们此刻为 t t t时刻,我们是不可以得出 x t − 1 x_{t-1} xt1的。但是我们只是每一次从 x t x_t xt x t − 1 x_{t-1} xt1实际上只做了很微小的变化,所以我们是可以近似 x t − 1 x_{t-1} xt1的,用泰勒展开式去近似。

我们有
l o g p θ ( x t − 1 ∣ x t ) = − 1 2 ( x t − 1 − μ ) 2 Σ logp_\theta(x_{t-1}|x_t) = -\frac{1}{2}\frac{(x_{t-1} -\mu)^2}{\Sigma} logpθ(xt1xt)=21Σ(xt1μ)2
Σ \Sigma Σ是很小的,我们可以理解为 x t − 1 x_{t-1} xt1出现在 x t x_t xt的附近,而 x t x_t xt约等于期望
所以令 x t − 1 = μ x_{t-1}=\mu xt1=μ


l o g p ϕ ( y ∣ x t − 1 ) = l o g p ϕ ( y ∣ x t − 1 ) ∣ x t − 1 = μ + ( x t − 1 − μ ) ∇ x t − 1 l o g p ϕ ( y ∣ x t − 1 ) ∣ x t − 1 = μ + o ( 高阶 ) logp_\phi(y|x_{t-1}) = logp_\phi(y|x_{t-1})_{|x_{t-1}=\mu} +(x_{t-1}-\mu)\nabla_{x_{t-1}}logp_\phi(y|x_{t-1})_{|x_{t-1}=\mu} +o(高阶) logpϕ(yxt1)=logpϕ(yxt1)xt1=μ+(xt1μ)xt1logpϕ(yxt1)xt1=μ+o(高阶)
又第一项和后面的高阶项相当于常数,所以约等于
l o g p ϕ ( y ∣ x t − 1 ) = ( x t − 1 − μ ) ∇ x t − 1 l o g p ϕ ( y ∣ x t − 1 ) ∣ x t − 1 = μ logp_\phi(y|x_{t-1}) = (x_{t-1}-\mu)\nabla_{x_{t-1}}logp_\phi(y|x_{t-1})_{|x_{t-1}=\mu} logpϕ(yxt1)=(xt1μ)xt1logpϕ(yxt1)xt1=μ

我们将两个对数相加,一番推导后(我不会)可以得到
l o g p ( x t − 1 ∣ x t , y ) ∼ N ( μ + Σ ∇ l o g p ϕ ( y ∣ x t − 1 ) ∣ x t − 1 = μ ) logp(x_{t-1}|x_t,y) \sim N(\mu+\Sigma \nabla log p_\phi(y|x_{t-1})_{|x_{t-1}=\mu}) logp(xt1xt,y)N(μ+Σ∇logpϕ(yxt1)xt1=μ)
既采样时,有
x t − 1 = μ + Σ ∇ + Σ ϵ x_{t-1} = \mu+\Sigma\nabla+\Sigma\epsilon xt1=μ+Σ∇+Σϵ
其中 ∇ \nabla 为分类器的梯度,所以这么一番推导,我们只是在最后的采样公式里加了一个引导方向的梯度项。
但有个缺点就是,DDIM的 Σ = 0 \Sigma=0 Σ=0,那么不就没用了。
而且还有两个缺点:

  1. 还要预训练一个分类器模型
  2. 只能生成分类器训练集所有的类别

5)用能量函数(score-base function)做进一步泛化

已知
s = ∇ x t l o g p ( x t ) s = \nabla_{x_t}log p(x_t) s=xtlogp(xt)
我们已知梯度和噪声的关系为:
∇ x t l o g p ( x t ) = − ϵ 1 − α ˉ t \nabla_{x_t}log p(x_t) = \frac{-\epsilon}{\sqrt{1-\bar\alpha_t}} xtlogp(xt)=1αˉt ϵ
如果没有classifier guidance,那么我们的神经网络想要预测的就是
∇ x t l o g p θ ( x t ) = − ϵ θ 1 − α ˉ t \nabla_{x_t}log p_\theta(x_t) = \frac{-\epsilon_\theta}{\sqrt{1-\bar\alpha_t}} xtlogpθ(xt)=1αˉt ϵθ
现在加上classifier guidance,也就是加上了 ∇ l o g p θ ( y ∣ x t ) \nabla logp_\theta(y|x_t) logpθ(yxt),假设其值为 g g g(为了方便,他就是分类器梯度)

其实我们神经网络实际上是预测:
− ϵ θ 1 − α ˉ t + g \frac{-\epsilon_\theta}{\sqrt{1-\bar\alpha_t}} +g 1αˉt ϵθ+g

我们设其在预测
− ϵ ′ 1 − α ˉ t \frac{-\epsilon'}{\sqrt{1-\bar\alpha_t}} 1αˉt ϵ

将其取等式,然后做变换,再加上一个强度因子 w w w,得到
ϵ ′ = ϵ θ − w 1 − α ˉ t ∇ l o g p ϕ ( y ∣ x t ) \epsilon' = \epsilon_\theta -w\sqrt{1-\bar\alpha_t}\nabla log p_\phi(y|x_{t}) ϵ=ϵθw1αˉt logpϕ(yxt)
也就是说,我们只需要在预测的噪声上加上一点扰动即可。而扰动项为分类器的梯度。

6)纯能量函数角度推导
在这里插入图片描述
原论文的伪代码的两个算法也是我们推导的:
在这里插入图片描述

7)不严谨代码理解

classifier_model = ...  # 加载一个训好的图像分类模型
y = 1  # 生成类别为 1 的图像,假设类别 1 对应“狗”这个类
guidance_scale = 7.5  # 控制类别引导的强弱,越大越强
input = get_noise(...)  # 从高斯分布随机取一个跟输出图像一样 shape 的噪声图for t in tqdm(scheduler.timesteps):# 用 unet 推理,预测噪声with torch.no_grad():noise_pred = unet(input, t).sample# 用 input 和预测出的 noise_pred 和 x_t 计算得到 x_t-1input = scheduler.step(noise_pred, t, input).prev_sample# classifier guidance 步骤class_guidance = classifier_model.get_class_guidance(input, y)input += class_guidance * guidance_scals  # 把梯度加上去

3 Classifier Free Guidance

我们知道对于classifier guidance最主要的限制就是分类器!CFG的方法就是直接将条件 y y y也加到模型中直接训练,而不用在训练一个分类器了,这样相当于训练了一个隐式的分类器,也就是训练了无条件生成模型和有条件生成模型,只不过这两个模型融合在同一个生成模型里。数学推导如下:

在这里插入图片描述
不严谨代码理解

clip_model = ...  # 加载一个官方的 clip 模型text = "一只狗"  # 输入文本
text_embeddings = clip_model.text_encode(text)  # 编码条件文本
empty_embeddings = clip_model.text_encode("")  # 编码空文本
text_embeddings = torch.cat(empty_embeddings, text_embeddings)  # 把它俩 concate 到一起作为条件input = get_noise(...)  # 从高斯分布随机取一个跟输出图像一样 shape 的噪声图for t in tqdm(scheduler.timesteps):# 用 unet 推理,预测噪声with torch.no_grad():# 这里同时预测出了有文本的和空文本的图像噪声noise_pred = unet(input, t, encoder_hidden_states=text_embeddings).sample# Classifier-Free Guidance 引导noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)  # 拆成无条件和有条件的噪声# 把【“无条件噪声”指向“有条件噪声”】看做一个向量,根据 guidance_scale 的值放大这个向量# (当 guidance_scale = 1 时,下面这个式子退化成 noise_pred_text)noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)# 用预测出的 noise_pred 和 x_t 计算得到 x_t-1input = scheduler.step(noise_pred, t, input).prev_sample

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/709034.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【c++】stack和queue模拟实现

> 作者简介:დ旧言~,目前大二,现在学习Java,c,c,Python等 > 座右铭:松树千年终是朽,槿花一日自为荣。 > 目标:能手撕stack和queue模拟 > 毒鸡汤:…

DataGrip2023配置连接Mssqlserver、Mysql、Oracle若干问题解决方案

1、Mssqlserver连接 本人连的是Sql2008,默认添加时,地址、端口、实例、账号、密码后,测试连接出现错误。 Use SSL:不要勾选 VM option:填写,"-Djdk.tls.disabledAlgorithmsSSLv3, RC4, DES, MD5withR…

【Redis | 第五篇】一篇文章看懂布谷鸟过滤器

文章目录 5.布谷鸟过滤器5.1起源介绍5.2原理5.2.1演示步骤(1)保存元素(两个位置均为空)(2)保存元素(其中一个位置被占)(3)保存元素(两个位置都被占…

Linux小项目:在线词典开发

在线词典介绍 流程图如下: 项目的功能介绍 在线英英词典项目功能描述用户注册和登录验证服务器端将用户信息和历史记录保存在数据中。客户端输入用户和密码,服务器端在数据库中查找、匹配,返回结果单词在线翻译根据客户端输入输入的单词在字…

轻松玩转Git

轻松玩转Git 快速入门什么是Git为什么要做版本控制安装git Git实战单枪匹马开始干拓展新功能小结 紧急修复bug分支紧急修复bug方案命令总结工作流 上传GitHub第一天上班前在家上传代码初次在公司新电脑下载代码下班回到家继续写代码到公司继续开发在公司约妹子忘记提交代码回家…

CDH6.3.1离线安装

一、从官方文档整体认识CDH 官方文档地址如下: CDH Overview | 6.3.x | Cloudera Documentation CDH是Apache Hadoop和相关项目中最完整、测试最全面、最受欢迎的发行版。CDH提供Hadoop的核心元素、可扩展存储和分布式计算,以及基于Web的用户界面和重…

使用Rust 实现文件批量下载

1.概述 Rust 是一种高效的系统编程语法,具有安全、并发和实用性的特点。本篇文章将通过实例详细介绍如何使用Rust来实现文件的批量下载功能,并提供示例以帮助读者理解。我们的实例将分为四个部分来详细描述:准备环境、创建主函数、实现下载函…

Mysql8.0 数据类型介绍

1,数值类型 1.1 整数类型 TINYINT:微整数,1字节 SMALLINT:小整数,2字节 MEDIUMINT:中等整数,3字节 INT:整数,4字节 BIGINT:大整数,8字节 如…

蓝桥杯备战刷题three(自用)

1.合法日期 #include <iostream> #include <map> #include <string> using namespace std; int main() {map<string,int>mp;int days[13]{0,31,28,31,30,31,30,31,31,30,31,30,31};for(int i1;i<12;i){for(int j1;j<days[i];j){string sto_strin…

江苏双线服务器租用的优势有哪些?

随着互联网的快速发展&#xff0c;服务器也随着科技的发展变得多种多样&#xff0c;其中双线服务器租用格外受大家关注&#xff0c;那么江苏双线服务器租用到底有哪些优势呢&#xff1f; 1.网络环境稳定 江苏双线服务器租用的主要优点就是有着高速稳定的网络环境&#xff0c;双…

P4198 楼房重建题解(线段树, 分治)

题目描述 题面 简要题意&#xff1a; 给你一个长度为 n n n 的序列 a i a_i ai​ ( n ≤ 1 0 5 n \leq 10^5 n≤105)&#xff0c;要求进行 m m m 次操作 ( m ≤ 1 0 5 m \leq 10^5 m≤105) 。操作分两种&#xff1a; 1.单点修改。 2.查询整个序列中有多少个位置 x x x 满…

动态规划(背包理论)-算法题

416. 分割等和子集 题目 给你一个 只包含正整数 的 非空 数组 nums 。请你判断是否可以将这个数组分割成两个子集&#xff0c;使得两个子集的元素和相等。 示例 1&#xff1a; 输入&#xff1a;nums [1,5,11,5] 输出&#xff1a;true 解释&#xff1a;数组可以分割成 [1, …

Vivado Vitis 2023.2 环境配置 Git TCL工程管理 MicroBlaze和HLS点灯测试

文章目录 本篇概要Vivado Vitis 环境搭建Vivado 免费标准版 vs 企业版Vivado Windows 安装Vivado 安装更新 Vivado 工程操作GUI 创建工程打开已有工程从已有工程创建, 重命名工程GUI导出TCL, TCL复原工程TCL命令 Vivado 版本控制BlinkTcl脚本新建导出重建工程纯Verilog BlinkTc…

js处理IOS虚拟键盘弹出后输入框被遮住

​ JS IOS 前言 在项目开发的过程中&#xff0c;在IOS手机端系统下&#xff0c;当对输入框&#xff08;input/textarea&#xff09;进行focus操作时&#xff0c;键盘弹起遮住输入框。 问题描述 从页面底部focus输入框失败从页面中间focus输入框失败 原因 造成上述问题的&…

【MySQL】_自连接与子查询

目录 1. 自连接 2. 子查询&#xff08;嵌套查询&#xff09; 2.1 子查询分类 2.2 单行子查询示例1&#xff1a;查询不想毕业同学的同班同学 2.3 多行子查询示例2&#xff1a;查询语文或英语课程的信息成绩 3. 合并查询 3.1 示例1&#xff1a;查询id3或者名字为英文的课程…

Flutter 处理异步操作并根据异步操作状态动态构建界面的方法FutureBuilder

概述 当界面的内容需要依靠网络请求的数据&#xff0c;就需要处理苦恼的&#xff0c;状态是空&#xff0c;非空的逻辑了&#xff0c;不然页面构建可能会报错&#xff0c;而FutureBuilder提供了一个非常好的解决方法&#xff0c;直接看代码 代码 异步操作函数 即网络请求函数…

[CISCN2019 华北赛区 Day2 Web1]Hack World 1 题目分析与详解

一、分析判断 进入靶机&#xff0c;主页面如图&#xff1a; 主页面提供给我们一条关键信息&#xff1a; flag值在 表flag 中的 flag列 中。 接着我们尝试输入不同的id&#xff0c;情况分别如图&#xff1a; 当id1时&#xff1a; 当id2时&#xff1a; 当id3时&#xff1a; 我…

YOLOv8改进涨点,添加GSConv+Slim Neck,有效提升目标检测效果,代码改进(超详细)

目录 摘要 主要想法 GSConv GSConv代码实现 slim-neck slim-neck代码实现 yaml文件 完整代码分享 总结 摘要 目标检测是计算机视觉中重要的下游任务。对于车载边缘计算平台来说&#xff0c;巨大的模型很难达到实时检测的要求。而且&#xff0c;由大量深度可分离卷积层构…

【Redis | 第三篇】Springboot整合Redis

文章目录 3.Springboot整合Redis3.1Spring Data Redis介绍3.2整合步骤3.2.1导入依赖3.2.2配置redis数据源3.2.3使用RedisTemplate进行操作&#xff08;1&#xff09;创建RedisTemplate Bean&#xff08;2&#xff09;注入RedisTemplate&#xff08;3&#xff09;执行Redis操作&…

C++:常量表达式

C11开始constexpr作为一种声明&#xff0c;为编译器提供了在编译期间确认结果的优化建议&#xff0c;满足部分编译期特性的需求 constexpr和const区别 int b10; const int ab; //运行成功 constexpr int cb; //编译器报错&#xff0c;b的值在编译期间不能确定 const int size1…