Classifier Guidance 与 Classifier-Free Guidance

Classifier Guidance 与 Classifier-Free Guidance

DDPM 终于把 diffusion 模型做 work 了,但无条件的生成在现实中应用场景不多,我们终归还是要可控的图像生成。本文简要介绍两篇关于 diffusion 模型可控生成的工作。其中 Classifier-Free Guidance 的方法还是现在多数条件生成 diffusion 模型的主流思路。

Classifier Guidance: Diffusion Models Beat GANs on Image Synthesis

Classifier-Free Guidance: Classifier-Free Diffusion Guidance

Classifier Guidance

要做可控生成,即条件生成,首先想到我们可以拿类别来作为条件,比如要指定类别猫,就生成猫的图片。也就是说要给定类别 y y y,生成图片 x x x,即 P ( x ∣ y ) P(x|y) P(xy) 。而一般分类器做的事情正好是反过来,给定图片,预测类别,即 P ( y ∣ x ) P(y|x) P(yx) 。这刚好是一对逆条件概率,应该敏锐地想到贝叶斯公式就是处理这类逆概率问题的。推导如下:
∇ log ⁡ P ( x ∣ y ) = ∇ log ⁡ P ( x ) P ( y ∣ x ) P ( y ) = ∇ log ⁡ P ( y ) + ∇ log ⁡ P ( y ∣ x ) − ∇ log ⁡ P ( y ) = ∇ log ⁡ P ( x ) + ∇ log ⁡ P ( y ∣ x ) \begin{aligned} \nabla\log P(x|y)&=\nabla\log\frac{P(x)P(y|x)}{P(y)} \\ &=\nabla\log P(y)+\nabla\log P(y|x)-\nabla\log P(y) \\ &=\nabla\log P(x)+\nabla\log P(y|x) \end{aligned} logP(xy)=logP(y)P(x)P(yx)=logP(y)+logP(yx)logP(y)=logP(x)+logP(yx)
其中 P ( y ) P(y) P(y) 是某个类别的先验概率,是一个常数,其梯度为 0,故直接丢掉。这里的 ∇ P ( x ) \nabla P(x) P(x) 实际就是 score-base model 中所谓的 score,score-based model 实际可以看作是 diffusion model 的另一种形式,这里不展开。

在结果中,第一项 ∇ log ⁡ P ( x ) \nabla\log P(x) logP(x) 就是原本无条件生成的梯度,而第二项 ∇ P ( y ∣ x ) \nabla P(y|x) P(yx) 则相当于是分类器进行图形分类的梯度。也就是说,我们只要在无条件生成的基础上,加上想要的类别的分类器梯度,作为引导(或者称为条件的梯度修正),就可以导出以类别作为条件的生成。

推导看起来并不复杂,具体怎么实现呢?怎么在生成的时候加入分类器的梯度作为引导呢?这里我们参考 OpenAI 原 Classifier Guidance 给出的代码来理解:

# https://github.com/openai/guided-diffusion/blob/main/scripts/classifier_sample.py#L54
# 核心就是这里的cond_fn函数import torch as th
import torch.nn.functional as F
classifier = ... # 加载一个(噪声)图像分类器def cond_fn(x, t, y=None):assert y is not Nonewith th.enable_grad():x_in = x.detach().requires_grad_(True)logits = classifier(x_in, t)log_probs = F.log_softmax(logits, dim=-1)selected = log_probs[range(len(logits)), y.view(-1)]return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale

这里的 t 是当前时间步,x 是当前步的去噪结果图,y 是类别索引。我们看到,计算分类器梯度的过程其实很简单:

  1. 首先把 x 和原始的梯度断开(detach),准备计算分类器的梯度
  2. 把 x_in 和 t 都输入到分类其中,得到分类器预测的类别 logits
    • 注意,这里的分类器实际上需要是一个能够分类带噪声图像的分类器,不仅需要输入图像,还要输入当前时间步 t,相当于告知分类器当前噪声的强度。所以说,在 Classifier Guidance 的方法中,我们虽然不需要重新训练 diffusion 模型,但是我们需要单独训练一个噪声图像分类器。
  3. 再把预测的类别 logits 过一下 softmax,得到各类别的概率 log_probs
  4. 从 log_probs 中取出我们指定的类别 y 对应的概率,即 selected
  5. 最后将 selected 中各个目标类别的概率值加在一起,希望该值越大越好,取该值对于 x 的梯度,即为分类器引导的梯度。

Classifier-Free Guidance

Classifier Guidance 的方法虽然不需要重新训练 diffusion 模型,但是需要额外的训练一个噪声图像分类器,并且在采样时需要额外的梯度引导。最关键的是,其作为可控生成的方法,对结果的控制能力十分有限,仅能够支持分类器所认识的有限类别,这无疑是不能满足我们多种多样的使用需求的。我们想要的肯定是 zeroshot 的,能够直接理解自然语言的可控生成。所幸在 CLIP 之后,图像文本特征已经能够在一定程度上对齐,CV 各个方向都基于 CLIP 实现了 zeroshot / open-vocab,图像生成也不例外。

Classifier-Free Guidance 的方法训练额外的分类器,并且,可以实现各种条件的引导生成。以最火爆的文生图为例,只要结合 CLIP 文本编码器提取 prompt 的文本特征 embedding,输入到 diffusion 模型中作为条件,即可实现。目前,Classifier-Free Guidance 已经成为条件生成的主流思路。

Classifier-Free Guidance 的想法是这样的:同时训练无条件生成模型和条件生成模型(实际上这俩是一个模型,只是训练时有概率输入是有条件的,有概率是无条件的),在推理时,同时 forward 带输入条件的生成和无条件的生成吗,然后把俩结果进行线性组合外推,得到最终的条件生成结果。

直接来看一下伪代码(参考 diffusers 的 API):

unet = ... # 加载unet去噪模型
clip_model = ...  # 加载CLIP模型text = "a cat"  # 文本条件
text_embeddings = clip_model.text_encode(text)  # 编码条件文本,cond
empty_embeddings = clip_model.text_encode("")  # 编码空文本,uncond
text_embeddings = torch.cat(empty_embeddings, text_embeddings)  # concat到一起,只做一次forwardinput = torch.randn((1, 3, sample_size, sample_size), device="cuda") # 采样初始噪声for t in scheduler.timesteps:# 用 unet 推理,预测噪声with torch.no_grad():# 这里同时预测出了有文本的和空文本的图像噪声noise_pred = unet(input, t, encoder_hidden_states=text_embeddings).samplenoise_pred_uncond, noise_pred_text = noise_pred.chunk(2)  # 拆成无条件和有条件的噪声# Classifier-Free Guidance 引导 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)# 用预测出的 noise_pred 和 x_t 计算得到 x_t-1latents = scheduler.step(noise_pred, t, latents).prev_sample

代码里的写法是 ϵ ˉ = ϵ u + s ( ϵ c − ϵ u ) \bar\epsilon=\epsilon_u+s(\epsilon_c-\epsilon_u) ϵˉ=ϵu+s(ϵcϵu), 论文里的公式是 ϵ ˉ = ( 1 + w ) ϵ c − w ϵ u \bar{\epsilon}=(1+w)\epsilon_c-w\epsilon_u ϵˉ=(1+w)ϵcwϵu,二者是等价的,只是做了下变换 s = 1 + w s=1+w s=1+w。个人感觉代码里这个形式更好理解一点: ϵ c − ϵ w \epsilon_c-\epsilon_w ϵcϵw 表示从无条件到目标条件的一个方向, w w w 是多大程度上考虑条件的系数,在无条件 ϵ u \epsilon_u ϵu 的基础上,再朝目标类别移动一定距离,即: ϵ ˉ = ϵ u + s ( ϵ c − ϵ u ) \bar\epsilon=\epsilon_u+s(\epsilon_c-\epsilon_u) ϵˉ=ϵu+s(ϵcϵu)。(也可能是作者大佬的思路我没领悟到

Classifier-Free Guidance 的做法看起来并不复杂,但有几个问题值得讨论(笔者自己也很不明白,希望有大佬指点一下):

  1. 为什么不像 cvae 一样直接把 embedding 丢进去做条件生成,而是非要同时训练无条件生成的情况,再做一个线性作何外推呢?
    • 可能是因为采样时现需要有一个无条件的基准,然后像目标条件的方向再修正?
  2. 关于空门大佬提到的 Classifier-Free Guidance 破坏了 Neural Diffusion Operator 的准线性性质。
    • 提到在 w w w 很大时,采样结果会崩掉,实践中确实这种现象。但笔者目前尚在学习,还无法完全理解。贴一下大佬的文章链接:Classifer-free Guidance 是万恶之源 。

这些问题有大佬了解,可以指点一下,或者介绍下应该去补充哪些理论知识来深化理解,感激不尽。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/221535.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一文了解java中volatile关键字

认识volatile volatile关键字的作用有两个:变量修改对其他线程立即可见、禁止指令重排。 第二个作用我们后面再讲,先主要讲一下第一个作用。通俗点来说,就是我在一个线程对一个变量进行了修改,那么其他线程马上就可以知道我修改…

jquery 实现倒计时60秒

jquery 实现倒计时60秒 <!DOCTYPE html> <html><head><meta http-equiv"content-type" content"text/html; charsetUTF-8"><meta content"widthdevice-width,initial-scale1.0,maximum-scale1.0,user-scalableno" i…

树莓派zero w入坑指南

树莓派zero w入坑指南 入坑契机 说起创客不得不提到开源硬件Raspberry Pi(树莓派)。它是一款基于ARM的微型电脑主板&#xff0c;以MicroSD卡为硬盘&#xff0c;提供HDMI和USB等外部接口&#xff0c;可连接显示器和键鼠。以上部件全部整合在一张仅比信用卡稍大的主板上&#x…

pytorch_lightning 安装

在安装pytorch-lightning时一定注意自己的torch是pip安装还是conda安装&#xff0c;pytorch_lightning 安装方式要与torch的安装方式保持一致&#xff0c;否则也会导致你的torch版本被替换。 正确安装方式&#xff1a; pip方式&#xff1a; pip install pytorch-lightning版本…

issue unit

The Issue Unit issue queue用来hold住&#xff0c;已经dispatched&#xff0c;但是还没有执行的uops&#xff1b; 当一条uop的所有的operands已经ready之后&#xff0c;request请求会被拉起来&#xff1b;然后issue select logic将会从request bit 1的slot中&#xff0c;选择…

第十二章 React 路由配置,路由参数获取

一、专栏介绍 &#x1f436;&#x1f436; 欢迎加入本专栏&#xff01;本专栏将引领您快速上手React&#xff0c;让我们一起放弃放弃的念头&#xff0c;开始学习之旅吧&#xff01;我们将从搭建React项目开始&#xff0c;逐步深入讲解最核心的hooks&#xff0c;以及React路由、…

0基础学java-day19(IO流)

一、文件 1 什么是文件 2.文件流 3.常用的文件操作 3.1 创建文件对象相关构造器和方法 package com.hspedu.file;import org.junit.jupiter.api.DynamicTest; import org.junit.jupiter.api.Test;import java.io.File; import java.io.IOException;/*** author 林然* vers…

js根据数组对象中的某个值去重

原理&#xff1a;利用对象key-value进行去重 去重方法&#xff1a; // 数组对象根据某一个值去重 filterList(list[], key) {let obj {};list?.forEach(item>{obj[item[key]]item;});return Object.values(obj); }, 用法&#xff1a; let list [{id: 1, name: 1},{id…

【TI毫米波雷达入门-11】毫米波速度相关计算

知识回顾 傅里叶变换 信号用复数表示&#xff0c;A :振幅&#xff0c; Q &#xff1a;相位 中频 信号 中频信号的相位 中频信号的表达公式 频率和相位的表达方式 使用两个Chirp 实现单个目标的测量 两个连续的chirp &#xff0c;检测目标的相位差&#xff0c;通过速度和时间的关…

7+乳酸化+分型+实验,怎么贴合热点开展实验,这篇文章给你思路

今天给同学们分享一篇生信文章“Identification of lactylation related model to predict prognostic, tumor infiltrating immunocytes and response of immunotherapy in gastric cancer”&#xff0c;这篇文章发表在Front Immunol期刊上&#xff0c;影响因子为7.3。 结果解…

基于java的小型超市管理系统论文

摘 要 使用旧方法对超市信息进行系统化管理已经不再让人们信赖了&#xff0c;把现在的网络信息技术运用在超市信息的管理上面可以解决许多信息管理上面的难题&#xff0c;比如处理数据时间很长&#xff0c;数据存在错误不能及时纠正等问题。 这次开发的小型超市管理系统有管理…

uniapp 数组添加不重复元素

一、效果图 二、代码 //点击事件rightBtn(sub, index) {console.log(sub, index)//uniapp 数组添加不重复元素if (this.selectList.includes(sub.type)) {this.selectList this.selectList.filter((item) > {return item ! sub.type;});} else {this.selectList.push(sub.t…

Java实现ZIP算法压缩和解压操作

ZIP是一种流行的文件压缩格式&#xff0c;它可以将多个文件打包成一个文件&#xff0c;以减小文件大小并方便传输。ZIP文件可以在大多数操作系统和软件中解压缩&#xff0c;例如Windows、Mac和Linux系统上的许多文件管理器和解压缩工具。ZIP压缩可以使用许多不同的压缩算法&…

【DEBUG】plt.cm.hot 的归一化问题

可视化时调用 # Matplotlib有很多内置的colormap&#xff0c;比如jet, viridis, hot等colormap plt.cm.hot # 选择一个colormapimg img.astype(np.float64) # 为了进行归一化&#xff0c;自动转换时float32norm plt.Normalize(vminimg.min(), vmaximg.max()) # 标准化灰度…

365锦鲤助手 砍价小程序源码 流量主引流裂变

源码介绍 修改版365锦鲤 助手&#xff0c; 砍价小程序源码 流量主引流裂变 拼多多商品快速丰富产品内容满足广大用户需求&#xff1b;流量矩阵让流量都进你的圈子飞起来&#xff1b;长期盈利、项目稳定 1.后台安装微擎 2安装应用 后台打包上传

23.12.10日总结

周总结 这周三的晚自习&#xff0c;学姐讲了一下git的合作开发&#xff0c;还有懒加载&#xff0c;防抖&#xff0c;节流 答辩的时候问了几个问题&#xff1a; 为什么在js中0.10.2!0.3? 在js中进行属性运算时&#xff0c;会出现0.10.20.300000000000000004js遵循IEEE754标…

【有限元仿真】or【流体仿真】

流体和刚体的关系&#xff1f; 刚体仿真关注刚性物体的运动和力学行为。刚体是指在外力作用下保持形状和结构不变的物体&#xff0c;不受弯曲或拉伸的影响。刚体仿真基于刚体力学原理和刚体运动学方程&#xff0c;模拟刚体的运动、转动、碰撞等行为。它可以用于模拟刚体之间的…

Mysql进阶-InnoDB引擎事务原理及MVCC

事务原理 事务基础 事务是一组操作的集合&#xff0c;它是一个不可分割的工作单位&#xff0c;事务会把所有的操作作为一个整体一起向系 统提交或撤销操作请求&#xff0c;即这些操作要么同时成功&#xff0c;要么同时失败。 事务的四大特性&#xff1a; 原子性&#xff08;A…

记录 | docker启动权限问题Get Permission Denied

docker 启动报错权限问题&#xff1a; Got permission denied while trying to connect to the Docker daemon socket at unix:///var/run/docker.sock: Get “http://%2Fvar%2Frun%2Fdocker.sock/v1.24/images/json”: dial unix /var/run/docker.sock: connect: permission d…

el-upload添加FormData参数,自定义上传接口

添加 :http-request"selfUpload"&#xff1a; <el-upload:disabled"saveLoading"class"upload-demo":limit"1":on-exceed"handleExceed":before-upload"beforeAvatarUpload":file-list"fileList":a…