【扩散模型第三篇】Classifier Guidance 和 Classifier Free Guidance(CFG)

参考:
[1] 张振虎博客
[2] https://www.bilibili.com/video/BV1s8411i7cU/?spm_id_from=333.788&vd_source=9e9b4b6471a6e98c3e756ce7f41eb134
[3] https://zhuanlan.zhihu.com/p/660518657
[4] https://zhuanlan.zhihu.com/p/640631667

进食顺序

  • 1 前言
  • 2 Classifier Guidance
  • 3 Classifier Free Guidance

1 前言

我们在DDPM或DDIM生成图像时是通常是不可控的,因为它是由一张随即高斯噪声一步步去噪得到生成图像。如果我们想要这个过程是可控的话,最直观的一个做法就是在生成过程中加上一个条件 y y y,既整个过程的变为:
p ( x 1 : T ∣ x 0 , y ) p(x_{1:T}|x_0,y) p(x1:Tx0,y)
接下来就是讨论加上了条件 y y y对于公式有无影响。

首先扩散模型遵循马尔科夫链性质,所以我们可以得出:
p ( x t ∣ x t − 1 , y ) : = p ( x t ∣ x t − 1 ) p(x_t|x_{t-1},y) := p(x_t|x_{t-1}) p(xtxt1,y):=p(xtxt1)
基于这一事实,我们还可以推出:
q ^ ( x t ∣ x t − 1 ) = ∫ y q ^ ( x t , y ∣ x t − 1 ) d y = ∫ y q ^ ( x t ∣ y , x t − 1 ) q ^ ( y ∣ x t − 1 ) d y = ∫ y q ^ ( x t ∣ x t − 1 ) q ^ ( y ∣ x t − 1 ) d y = q ^ ( x t ∣ x t − 1 ) = q ^ ( x t ∣ x t − 1 , y ) \begin{aligned} \hat q(x_t|x_{t-1}) &= \int_y \hat{q}(x_t,y|x_{t-1})dy\\ &=\int_y\hat q(x_t|y,x_{t-1})\hat q(y|x_{t-1})dy\\ &= \int_y \hat q(x_t|x_{t-1})\hat q(y|x_{t-1})dy\\ & = \hat q(x_t|x_{t-1}) = \hat q(x_t|x_{t-1},y) \end{aligned} q^(xtxt1)=yq^(xt,yxt1)dy=yq^(xty,xt1)q^(yxt1)dy=yq^(xtxt1)q^(yxt1)dy=q^(xtxt1)=q^(xtxt1,y)

同样用全概率公式,可以推出:
在这里插入图片描述
所以我们可以断论:加上条件 y y y 对前向过程毫无影响

2 Classifier Guidance

逆向过程有如下公式:
p ^ ( x t − 1 ∣ x t , y ) = p ^ ( x t − 1 ∣ x t ) p ^ ( y ∣ x t − 1 , x t ) p ^ ( y ∣ x t ) \hat p(x_{t-1}|x_t,y)=\frac{\hat p(x_{t-1}|x_t)\hat p(y|x_{t-1},x_t)}{\hat p(y|x_t)} p^(xt1xt,y)=p^(yxt)p^(xt1xt)p^(yxt1,xt)
其中分母和 x t − 1 x_{t-1} xt1毫无关系,所以分母可以看作是常数 C C C
而我们知道加上条件对于扩散过程是没有影响的,所以我们还已知:
q ^ ( x t ∣ x t − 1 , y ) = q ( x t ∣ x t − 1 ) q ^ ( x 0 ) = q ( x 0 ) q ^ ( x 1 : T ∣ x 0 , y ) = ∏ t = 1 T q ( x t ∣ x t − 1 , y ) \hat q(x_t|x_{t-1},y) = q(x_t|x_{t-1})\\ \hat q(x_0)=q(x_0)\\ \hat q(x_{1:T}|x_0,y) = \prod_{t=1}^Tq(x_t|x_{t-1},y) q^(xtxt1,y)=q(xtxt1)q^(x0)=q(x0)q^(x1:Tx0,y)=t=1Tq(xtxt1,y)

现在我们未知的是 p ^ ( x t − 1 ∣ x t ) 和 p ^ ( y ∣ x t − 1 , x t ) \hat p(x_{t-1}|x_t)和\hat p(y|x_{t-1},x_t) p^(xt1xt)p^(yxt1,xt),现在来推导这两项:

1)推导 p ^ ( x t − 1 ∣ x t ) \hat p(x_{t-1}|x_t) p^(xt1xt)
根据贝叶斯公式,我们有:
p ^ ( x t − 1 ∣ x t ) = p ^ ( x t ∣ x t − 1 ) p ^ ( x t − 1 ) p ^ ( x t ) \hat p(x_{t-1}|x_t) = \frac{\hat p(x_t|x_{t-1})\hat p(x_{t-1})}{\hat p(x_t)} p^(xt1xt)=p^(xt)p^(xtxt1)p^(xt1)
我们已知条件 y y y 对扩散过程不影响(可以通过全概率公式推出),所以我们有
p ^ ( x t ∣ x t − 1 ) = p ( x t ∣ x t − 1 ) \hat p(x_t|x_{t-1})=p(x_t|x_{t-1}) p^(xtxt1)=p(xtxt1)
我们同样可以由全概率公式推出:
p ^ ( x t ) = p ( x t ) \hat p(x_t) = p(x_t) p^(xt)=p(xt)

所以
p ^ ( x t − 1 ∣ x t ) = p ( x t ∣ x t − 1 ) p ( x t − 1 ) p ( x t ) \hat p(x_{t-1}|x_t) = \frac{p(x_t|x_{t-1})p(x_{t-1})}{p(x_t)} p^(xt1xt)=p(xt)p(xtxt1)p(xt1)

2)推导 p ^ ( y ∣ x t − 1 , x t ) \hat p(y|x_{t-1},x_t) p^(yxt1,xt)

根据贝叶斯公式,有:
p ^ ( y ∣ x t − 1 , x t ) = p ^ ( x t ∣ y , x t − 1 ) p ^ ( y ∣ x t − 1 ) p ^ ( x t ∣ x t − 1 ) \hat p(y|x_{t-1},x_t)=\frac{\hat p(x_t|y,x_{t-1})\hat p(y|x_{t-1})}{\hat p(x_t|x_{t-1})} p^(yxt1,xt)=p^(xtxt1)p^(xty,xt1)p^(yxt1)
根据马尔可夫链性质,所以约去分子的第一项和分母,所以得到
p ^ ( y ∣ x t − 1 , x t ) = p ^ ( y ∣ x t − 1 ) \hat p(y|x_{t-1},x_t) = \hat p(y|x_{t-1}) p^(yxt1,xt)=p^(yxt1)

3)终极目标

所以我们的公式此刻为:
p ^ ( x t − 1 ∣ x t , y ) = q ( x t − 1 ∣ x t ) q ( y ∣ x t − 1 ) q ( y ∣ x t ) = C ∗ q ( x t − 1 ∣ x t ) ∗ q ( y ∣ x t − 1 ) \hat p(x_{t-1}|x_t,y) = \frac{q(x_{t-1}|x_t)q(y|x_{t-1})}{q(y|x_t)} = C*q(x_{t-1}|x_t)*q(y|x_{t-1}) p^(xt1xt,y)=q(yxt)q(xt1xt)q(yxt1)=Cq(xt1xt)q(yxt1)
其中第一项是常数,第二项为DDPM的目标,第三项既为分类器输出概率(根据 x t − 1 x_{t-1} xt1输出类别标签 y y y

4)问题与进一步推导
我们此刻为 t t t时刻,我们是不可以得出 x t − 1 x_{t-1} xt1的。但是我们只是每一次从 x t x_t xt x t − 1 x_{t-1} xt1实际上只做了很微小的变化,所以我们是可以近似 x t − 1 x_{t-1} xt1的,用泰勒展开式去近似。

我们有
l o g p θ ( x t − 1 ∣ x t ) = − 1 2 ( x t − 1 − μ ) 2 Σ logp_\theta(x_{t-1}|x_t) = -\frac{1}{2}\frac{(x_{t-1} -\mu)^2}{\Sigma} logpθ(xt1xt)=21Σ(xt1μ)2
Σ \Sigma Σ是很小的,我们可以理解为 x t − 1 x_{t-1} xt1出现在 x t x_t xt的附近,而 x t x_t xt约等于期望
所以令 x t − 1 = μ x_{t-1}=\mu xt1=μ


l o g p ϕ ( y ∣ x t − 1 ) = l o g p ϕ ( y ∣ x t − 1 ) ∣ x t − 1 = μ + ( x t − 1 − μ ) ∇ x t − 1 l o g p ϕ ( y ∣ x t − 1 ) ∣ x t − 1 = μ + o ( 高阶 ) logp_\phi(y|x_{t-1}) = logp_\phi(y|x_{t-1})_{|x_{t-1}=\mu} +(x_{t-1}-\mu)\nabla_{x_{t-1}}logp_\phi(y|x_{t-1})_{|x_{t-1}=\mu} +o(高阶) logpϕ(yxt1)=logpϕ(yxt1)xt1=μ+(xt1μ)xt1logpϕ(yxt1)xt1=μ+o(高阶)
又第一项和后面的高阶项相当于常数,所以约等于
l o g p ϕ ( y ∣ x t − 1 ) = ( x t − 1 − μ ) ∇ x t − 1 l o g p ϕ ( y ∣ x t − 1 ) ∣ x t − 1 = μ logp_\phi(y|x_{t-1}) = (x_{t-1}-\mu)\nabla_{x_{t-1}}logp_\phi(y|x_{t-1})_{|x_{t-1}=\mu} logpϕ(yxt1)=(xt1μ)xt1logpϕ(yxt1)xt1=μ

我们将两个对数相加,一番推导后(我不会)可以得到
l o g p ( x t − 1 ∣ x t , y ) ∼ N ( μ + Σ ∇ l o g p ϕ ( y ∣ x t − 1 ) ∣ x t − 1 = μ ) logp(x_{t-1}|x_t,y) \sim N(\mu+\Sigma \nabla log p_\phi(y|x_{t-1})_{|x_{t-1}=\mu}) logp(xt1xt,y)N(μ+Σ∇logpϕ(yxt1)xt1=μ)
既采样时,有
x t − 1 = μ + Σ ∇ + Σ ϵ x_{t-1} = \mu+\Sigma\nabla+\Sigma\epsilon xt1=μ+Σ∇+Σϵ
其中 ∇ \nabla 为分类器的梯度,所以这么一番推导,我们只是在最后的采样公式里加了一个引导方向的梯度项。
但有个缺点就是,DDIM的 Σ = 0 \Sigma=0 Σ=0,那么不就没用了。
而且还有两个缺点:

  1. 还要预训练一个分类器模型
  2. 只能生成分类器训练集所有的类别

5)用能量函数(score-base function)做进一步泛化

已知
s = ∇ x t l o g p ( x t ) s = \nabla_{x_t}log p(x_t) s=xtlogp(xt)
我们已知梯度和噪声的关系为:
∇ x t l o g p ( x t ) = − ϵ 1 − α ˉ t \nabla_{x_t}log p(x_t) = \frac{-\epsilon}{\sqrt{1-\bar\alpha_t}} xtlogp(xt)=1αˉt ϵ
如果没有classifier guidance,那么我们的神经网络想要预测的就是
∇ x t l o g p θ ( x t ) = − ϵ θ 1 − α ˉ t \nabla_{x_t}log p_\theta(x_t) = \frac{-\epsilon_\theta}{\sqrt{1-\bar\alpha_t}} xtlogpθ(xt)=1αˉt ϵθ
现在加上classifier guidance,也就是加上了 ∇ l o g p θ ( y ∣ x t ) \nabla logp_\theta(y|x_t) logpθ(yxt),假设其值为 g g g(为了方便,他就是分类器梯度)

其实我们神经网络实际上是预测:
− ϵ θ 1 − α ˉ t + g \frac{-\epsilon_\theta}{\sqrt{1-\bar\alpha_t}} +g 1αˉt ϵθ+g

我们设其在预测
− ϵ ′ 1 − α ˉ t \frac{-\epsilon'}{\sqrt{1-\bar\alpha_t}} 1αˉt ϵ

将其取等式,然后做变换,再加上一个强度因子 w w w,得到
ϵ ′ = ϵ θ − w 1 − α ˉ t ∇ l o g p ϕ ( y ∣ x t ) \epsilon' = \epsilon_\theta -w\sqrt{1-\bar\alpha_t}\nabla log p_\phi(y|x_{t}) ϵ=ϵθw1αˉt logpϕ(yxt)
也就是说,我们只需要在预测的噪声上加上一点扰动即可。而扰动项为分类器的梯度。

6)纯能量函数角度推导
在这里插入图片描述
原论文的伪代码的两个算法也是我们推导的:
在这里插入图片描述

7)不严谨代码理解

classifier_model = ...  # 加载一个训好的图像分类模型
y = 1  # 生成类别为 1 的图像,假设类别 1 对应“狗”这个类
guidance_scale = 7.5  # 控制类别引导的强弱,越大越强
input = get_noise(...)  # 从高斯分布随机取一个跟输出图像一样 shape 的噪声图for t in tqdm(scheduler.timesteps):# 用 unet 推理,预测噪声with torch.no_grad():noise_pred = unet(input, t).sample# 用 input 和预测出的 noise_pred 和 x_t 计算得到 x_t-1input = scheduler.step(noise_pred, t, input).prev_sample# classifier guidance 步骤class_guidance = classifier_model.get_class_guidance(input, y)input += class_guidance * guidance_scals  # 把梯度加上去

3 Classifier Free Guidance

我们知道对于classifier guidance最主要的限制就是分类器!CFG的方法就是直接将条件 y y y也加到模型中直接训练,而不用在训练一个分类器了,这样相当于训练了一个隐式的分类器,也就是训练了无条件生成模型和有条件生成模型,只不过这两个模型融合在同一个生成模型里。数学推导如下:

在这里插入图片描述
不严谨代码理解

clip_model = ...  # 加载一个官方的 clip 模型text = "一只狗"  # 输入文本
text_embeddings = clip_model.text_encode(text)  # 编码条件文本
empty_embeddings = clip_model.text_encode("")  # 编码空文本
text_embeddings = torch.cat(empty_embeddings, text_embeddings)  # 把它俩 concate 到一起作为条件input = get_noise(...)  # 从高斯分布随机取一个跟输出图像一样 shape 的噪声图for t in tqdm(scheduler.timesteps):# 用 unet 推理,预测噪声with torch.no_grad():# 这里同时预测出了有文本的和空文本的图像噪声noise_pred = unet(input, t, encoder_hidden_states=text_embeddings).sample# Classifier-Free Guidance 引导noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)  # 拆成无条件和有条件的噪声# 把【“无条件噪声”指向“有条件噪声”】看做一个向量,根据 guidance_scale 的值放大这个向量# (当 guidance_scale = 1 时,下面这个式子退化成 noise_pred_text)noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)# 用预测出的 noise_pred 和 x_t 计算得到 x_t-1input = scheduler.step(noise_pred, t, input).prev_sample

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/709034.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【c++】stack和queue模拟实现

> 作者简介:დ旧言~,目前大二,现在学习Java,c,c,Python等 > 座右铭:松树千年终是朽,槿花一日自为荣。 > 目标:能手撕stack和queue模拟 > 毒鸡汤:…

DataGrip2023配置连接Mssqlserver、Mysql、Oracle若干问题解决方案

1、Mssqlserver连接 本人连的是Sql2008,默认添加时,地址、端口、实例、账号、密码后,测试连接出现错误。 Use SSL:不要勾选 VM option:填写,"-Djdk.tls.disabledAlgorithmsSSLv3, RC4, DES, MD5withR…

【Redis | 第五篇】一篇文章看懂布谷鸟过滤器

文章目录 5.布谷鸟过滤器5.1起源介绍5.2原理5.2.1演示步骤(1)保存元素(两个位置均为空)(2)保存元素(其中一个位置被占)(3)保存元素(两个位置都被占…

Linux小项目:在线词典开发

在线词典介绍 流程图如下: 项目的功能介绍 在线英英词典项目功能描述用户注册和登录验证服务器端将用户信息和历史记录保存在数据中。客户端输入用户和密码,服务器端在数据库中查找、匹配,返回结果单词在线翻译根据客户端输入输入的单词在字…

轻松玩转Git

轻松玩转Git 快速入门什么是Git为什么要做版本控制安装git Git实战单枪匹马开始干拓展新功能小结 紧急修复bug分支紧急修复bug方案命令总结工作流 上传GitHub第一天上班前在家上传代码初次在公司新电脑下载代码下班回到家继续写代码到公司继续开发在公司约妹子忘记提交代码回家…

CDH6.3.1离线安装

一、从官方文档整体认识CDH 官方文档地址如下: CDH Overview | 6.3.x | Cloudera Documentation CDH是Apache Hadoop和相关项目中最完整、测试最全面、最受欢迎的发行版。CDH提供Hadoop的核心元素、可扩展存储和分布式计算,以及基于Web的用户界面和重…

蓝桥杯备战刷题three(自用)

1.合法日期 #include <iostream> #include <map> #include <string> using namespace std; int main() {map<string,int>mp;int days[13]{0,31,28,31,30,31,30,31,31,30,31,30,31};for(int i1;i<12;i){for(int j1;j<days[i];j){string sto_strin…

Vivado Vitis 2023.2 环境配置 Git TCL工程管理 MicroBlaze和HLS点灯测试

文章目录 本篇概要Vivado Vitis 环境搭建Vivado 免费标准版 vs 企业版Vivado Windows 安装Vivado 安装更新 Vivado 工程操作GUI 创建工程打开已有工程从已有工程创建, 重命名工程GUI导出TCL, TCL复原工程TCL命令 Vivado 版本控制BlinkTcl脚本新建导出重建工程纯Verilog BlinkTc…

[CISCN2019 华北赛区 Day2 Web1]Hack World 1 题目分析与详解

一、分析判断 进入靶机&#xff0c;主页面如图&#xff1a; 主页面提供给我们一条关键信息&#xff1a; flag值在 表flag 中的 flag列 中。 接着我们尝试输入不同的id&#xff0c;情况分别如图&#xff1a; 当id1时&#xff1a; 当id2时&#xff1a; 当id3时&#xff1a; 我…

YOLOv8改进涨点,添加GSConv+Slim Neck,有效提升目标检测效果,代码改进(超详细)

目录 摘要 主要想法 GSConv GSConv代码实现 slim-neck slim-neck代码实现 yaml文件 完整代码分享 总结 摘要 目标检测是计算机视觉中重要的下游任务。对于车载边缘计算平台来说&#xff0c;巨大的模型很难达到实时检测的要求。而且&#xff0c;由大量深度可分离卷积层构…

C++:常量表达式

C11开始constexpr作为一种声明&#xff0c;为编译器提供了在编译期间确认结果的优化建议&#xff0c;满足部分编译期特性的需求 constexpr和const区别 int b10; const int ab; //运行成功 constexpr int cb; //编译器报错&#xff0c;b的值在编译期间不能确定 const int size1…

面试笔记系列二之java基础+集合知识点整理及常见面试题

目录 Java面向对象有哪些特征&#xff0c;如何应用 Java基本数据类型及所占字节 Java中重写和重载有哪些区别 jdk1.8的新特性有哪些 内部类 1. 成员内部类&#xff08;Member Inner Class&#xff09;&#xff1a; 2. 静态内部类&#xff08;Static Nested Class&#…

Vue 组件和插件:探索细节与差异

查看本专栏目录 关于作者 还是大剑师兰特&#xff1a;曾是美国某知名大学计算机专业研究生&#xff0c;现为航空航海领域高级前端工程师&#xff1b;CSDN知名博主&#xff0c;GIS领域优质创作者&#xff0c;深耕openlayers、leaflet、mapbox、cesium&#xff0c;canvas&#x…

Linux查看进程占用句柄

ps -ef |grep python # 查找工具执行PID python pid 11287lsof -p 11287 |wc -l 查看进程占用句柄设置句柄上限IOError: [Errno 24] Too many open files:解决方法

阿里云短信验证笔记

1.了解阿里云的权限操作 进入AccessKey管理 选择子用户 创建用户组和用户 先创建用户组&#xff0c;建好再进行权限分配 添加短信管理权限 创建用户 创建好后的id和密码在此处下载可以得到 2.开通阿里云短信服务 进行申请&#xff0c;配置短信模板 阿里云短信API文档 短信服务…

逆向案例三:动态xhr包中AES解密的一般步骤,以精灵数据为例

补充知识&#xff1a;进行AES解密需要知道四个关键字&#xff0c;即密钥key,向量iv,模式mode,填充方式pad 一般网页AES都是16位的&#xff0c;m3u8视频加密一般是AES-128格式 网页链接:https://www.jinglingshuju.com/articles 进行抓包结果返回的是密文&#xff1a; 一般思…

【算法大家庭】分治算法

目录 &#x1f953;1.简单介绍 &#x1f9c8;2.汉诺塔问题 1.简单介绍 分治算法是解决问题的一种思想&#xff0c;它将一个大问题分解成若干个小问题&#xff0c;然后分别解决这些小问题&#xff0c;最后将小问题的解合并起来得到原问题的解。 分解&#xff1a;将原问题分解…

Mazing官方 2.17.17版新i功能介绍

iMazing官方 2.17.17版是一款管理苹果设备的软件&#xff0c;是一款帮助用户管理 IOS 手机的PC端应用程序&#xff0c;能力远超 iTunes 提供的终极 iOS 设备管理器。在iMazing官方版上与苹果设备连接后&#xff0c;可以轻松传输文件&#xff0c;浏览保存信息等&#xff0c;功能…

SD-WAN对云服务的影响

近年来&#xff0c;随着企业对云服务的依赖不断增加&#xff0c;SD-WAN技术成为提升连接性能的热门选择。SD-WAN通过简化云集成和连接&#xff0c;以及提升应用程序性能&#xff0c;为企业带来显著的业务优势。这种云连接的改进使企业能够更轻松地接触全球劳动力和潜在客户。 首…

语文专刊《中学语文》是什么级别的刊物?

语文专刊《中学语文》是什么级别的刊物&#xff1f; 《中学语文》创刊于1958年&#xff0c;由国家新闻出版总署批准&#xff0c;经湖北省教育厅主管的省级学术期刊。 《中学语文》是由湖北大学文学院主办、国内外公开发行的学术期刊&#xff0c;主要面向中学语文教师和语文教…