论文阅读:Scalable Diffusion Models with Transformers

Scalable Diffusion Models with Transformers

论文链接

介绍

传统的扩散模型基于一个U-Net骨架,这篇文章提出了一种新的扩散模型结构,将U-Net替换为一个transformer,并将这种结构称为Diffusion Transformers (DiTs)。他们还发现,transformer的规模越大(通过Gflops衡量),生成的图片的质量越好(FID越低)。
如图2所示,DiT的规模越大,图片生成的质量越好(左图),和当前流行的扩散模型相比,DiT的计算效率也表现优异。
ImageNet generation with Diffusion Transformers (DiTs)

相关工作

  • Transformers:这篇文章研究了transformer作为扩散模型的骨架时,其规模的性质。
  • Denoising diffusion probabilistic models (DDPMs):传统的扩散模型都使用U-Net作为骨架,本文尝试使用纯transformer作为骨架。
  • Architecture complexity:在结构设计领域,Gflops是常见的衡量结构复杂度的指标。

方法(Diffusion Transformers)

预备知识

  • Diffusion formulation:扩散模型Diffusion Model(DM)在训练过程中,首先向图片中添加噪声,然后预测噪声来从图片中将噪声去除。这样,在推理过程中,首先初始化一个高斯噪声图片,然后去除预测的噪声,即可得到生成的图片。
  • Classifier-free guidance:条件扩散模型引入了额外信息 c c c(比如,类别)作为输入。而classifier-free guidance可以引导生成的图片 x x x是类别 c c c的概率 l o g ( c ∣ x ) log(c|x) log(cx)最大。
  • Latent diffusion models:扩散模型在像素空间上训练和推理的计算开销过大,Latent Diffusion Model(LDM)将像素空间替换为VAE编码得到的潜在空间 z = E ( x ) z=E(x) z=E(x),可以提高计算效率。本文提出的DiT沿用了LDM中的潜在空间,但是在预测潜在空间特征的模型上,将LDM中的U-Net替换为了纯Transformer骨架。

Diffusion Transformer Design Space

Diffusion Transformers (DiTs)是基于Vision Transformer (ViT)的模型,它的大体结构如图3所示,从左图可以看到,输入的噪音特征被分解为不同批,然后被若干个DiT块处理;右边的三张图展示了DiT块的详细结构,分别是三种不同的变体。
The Diffusion Transformer (DiT) architecture
下面对DiT的各层进行分析:
Patchify. 从图3中可以看到,DiT的第一个层是Patchify,其将输入转化为 T T T个token序列。在这之后,作者使用标准ViT中基于频率的位置嵌入处理前面的token序列。而token序列的数量是由一个超参数 p p p决定的, p p p减半导致 T T T翻四倍,并且导致整个transformer的GFlops至少翻四倍,如图4所示。
Input specifications for DiT
DiT block design. 在patchfiy层之后,几个transformer块处理输入token以及一些额外的条件信息,比如,类标签 c c c和时间步数 t t t。作者尝试了4种不同的ViT变体:

  • In-context conditioning:这种变体直接将时间步数 t t t和类标签 c c c作为额外的token添加到输入token序列后面,类似于ViT的cls tokens,因此也可以直接使用标准的ViT块。这种方式引入的Gflops可以忽略不计。
  • Cross-attention block:这种变体将条件信息拼接为一个长度为2的序列,独立于图片输入序列。然后,在transformer块的self-attention层后添加了一个cross-attention层,类似于LDM,在cross-attention层将条件信息加入图片特征中。cross-attention方案增加的Gflops最多,大概15%。
  • Adaptive layer norm (adaLN) block:这种变体将transformer块中标准的layer norm layers替换为adaptive layer
    norm (adaLN),这一技术在GAN相关的模型中被广泛采用。不同于直接学习维度放缩和偏移因子 γ \gamma γ β \beta β,该方案回归 t t t c c c的嵌入的和得到这两个参数。在目前的三种方案中,该变体额外增加的Gflops最少。
  • adaLN-Zero block:先前的工作说明,ResNet中的恒等映射是有益处的。Diffusion U-Net在残差之前,零初始化了每个块中最后一个卷积层。作者采用了和Diffusion U-Net相同的方案。此外,除了回归 γ \gamma γ β \beta β,该方案还对DiT块中残差连接上的放缩因此 α \alpha α进行了回归。对于所有的 α \alpha α,作者初始化MLP以输出零向量,这使得DiT块为一个恒等函数。和adaLN方案一样,ada-Zero方案引入的Gflops也可以忽略不计。

Model Size. 作者设置了四种规模的DiT:DiT-S, DiT-B, DiT-L and DiT-XL,结构复杂度依次增大。
Transformer decoder. 在经过最后的DiT块之后,使用tranformer decoder将输入tokens转化为和输入同等性状的噪音预测。

综上,作者探索了DiT设计空间中的patch_size、transformer架构(4种,in-context,cross-attention, adaptive layer
norm and adaLN-Zero blocks)和model size(4种,DiT-S, DiT-B, DiT-L and DiT-XL)。

实验

实验设置

  • 训练:在256 × 256和512 × 512 图片分辨率的ImageNet数据集上训练。超参数设置几乎和ADM一致。
  • Diffusion:和Stable DIffusion一样使用VAE编码图片和解码特征。
  • 评估指标:主要使用Fr´echet Inception Distance (FID),还使用了Inception Score [51], sFID [34] and Precision/Recall [32]
  • 计算平台:在JAX [1]这个深度学习框架上实现了DiT,在TPU上训练模型。

实验结果

DiT block design. 四个不同的DiT块:in-context (119.4 Gflops), cross-attention (137.6 Gflops),
adaptive layer norm (adaLN, 118.6 Gflops) or adaLN-zero (118.6 Gflops)中, adaLN-zero (118.6 Gflops) 取得最低的FID。其中,adaLN-zero相较于adaptive layer norm的提升,说明了恒等映射的好处。(后续的实验除非特别说明都是在adaLN-zero上做的)

Comparing different conditioning strategies
Scaling model size and patch size. 模型size增大和patch zise减小,均会提高Gflops,降低FID。我们注意到,DiT-L 和DiT-XL的FID很接近,因为它们的Gflops也相对更接近。
Scaling the DiT model improves FID at all stages of training
DiT Gflops are critical to improving performance. 上面的图6再次说明了模型参数量的增大并不等同于DiT模型的图片质量提高,真正的关键是提高Gflops。比如,DiT S/2的表现和DiT B/4接近,因为小的batch size会增大Gflops,二者的Gflops接近,所以FID也接近。
Larger DiT models are more compute-efficient
小的DiT模型即便训练时间更长,相对于训练时间更短的大的DiT模型,其计算效率也是更差的。
这里,作者估计训练计算量的方式为model Gflops · batch size · training steps · 3。
Larger DiT models use large compute more effi-
ciently

State-of-the-Art Diffusion Models

和主流的扩散模型相比,DiT-XL/2 (即参数量最大,patch size最小的DiT)的表现最优。

Scaling Model vs. Sampling Compute

扩散模型有一个比较特殊的点,在生成图片时,它可以通过增加调整采样步数,引入额外的增加的计算量,但是,这并不能弥补训练时模型计算量的差距,即大GFlops的DiT在采样步数少的情况下,仍然能比小GFlops的DiT在采样步数多的情况下,取得更低的FID。

结论

Diffusion Transformers (DiTs)作为一种新的扩散模型,比基于U-Net的扩散模型表现更加优异。并且,其在模型复杂度提高的时候,能够有明显的性能提高,因此,使用更大规模的DiT有助于提高模型性能。此外,DiT也可以用于文生图生成任务。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/729252.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python77-Python的函数参数,个数可变参数

很多编程语言都允许定义个数可变的参数,这样可以在调用函数时传入任意多个参数。Python当然也不例外,Python 允许在形参前面添加一个星号(*),这样就意味着该参数可接收多个参数值,多个参数值被当成元组传入。下面程序定义了一个形参个数可变的函数。 # !/usr/bin/env pyth…

阿里云服务器使用教程_2024建站教程_10分钟网站搭建流程

使用阿里云服务器快速搭建网站教程,先为云服务器安装宝塔面板,然后在宝塔面板上新建站点,阿里云服务器网aliyunfuwuqi.com以搭建WordPress网站博客为例,来详细说下从阿里云服务器CPU内存配置选择、Web环境、域名解析到网站上线全流…

Vscode连接远程服务器失败解决方案

一、 could not establish connection to “XXX” 尝试使用Remote-SSH插件连接远程的服务器,但是配置显示出错,端口显示试图写入的管道不存在,弹出窗口显示could not establish connection to “XXX” 二、检查Windows的OpenSSH 1.检索是否…

Java中注解@RequestParam 和 @ApiParam详解

一、RequestParam 和 ApiParam的常用属性 RequestParam 和 ApiParam 是在 Spring MVC 控制器方法中使用的注解,它们分别服务于不同的目的: RequestParam RequestParam 是 Spring MVC 中用来处理 HTTP 请求参数的注解,主要用于绑定请求中的查…

数据分析项目[开发中]

学习进度记录: 12.7: 教程链接:智能 BI 项目教程(一) (yuque.com) 前端: Ant Design Pro:开始使用 - Ant Design Pro​​​​​​ 然后安装依赖 yarn install 去除不需要的: 移除国际化 【报错】 …

【数据结构与算法】二分查找题解(二)

这里写目录标题 一、81. 搜索旋转排序数组 II二、167. 两数之和 II - 输入有序数组三、441. 排列硬币四、374. 猜数字大小五、367. 有效的完全平方数六、69. x 的平方根 一、81. 搜索旋转排序数组 II 中等 已知存在一个按非降序排列的整数数组 nums ,数组中的值不必…

【Linux】iftop命令详解

目录 一、iftop简介 二、安装iftop命令 2.1 命令查看测试环境系统信息 2.2 查看iftop版本与命令帮助 三、iftop的基本使用 3.1 直接使用iftop命令 3.2 iftop的显示说明 3.3 指定监控某块网卡 3.4 显示某个网段进出封包流量 3.5 按照流量排序 3.6 过滤显示连接 3.7 …

米酒生产加工污水处理需要哪些工艺设备

米酒生产加工过程中产生的污水是一项重要的环境问题,需要采用适当的工艺设备进行处理。下面将介绍一些常用的污水处理工艺设备。 首先,生产过程中的污水需要进行初级处理,常见的设备包括格栅和砂池。格栅用于去除污水中的大颗粒杂质&#xff…

金相显微镜(金相镜)主要用于材料金相分析 我国市场集中度较低

金相显微镜(金相镜)主要用于材料金相分析 我国市场集中度较低 金相显微镜又称为金相镜,是指通过光学放大,对材料显微组织、低倍组织和断口组织等进行分析研究和表征的光学显微镜。金相显微镜通常由目镜、物镜、照明系统、旋转台等…

视频监控平台EasyCVR+4G/5G应急布控球远程视频监控方案

随着科技的不断发展,应急布控球远程视频监控方案在公共安全、交通管理、城市管理等领域的应用越来越广泛。这种方案通过在现场部署应急布控球,实现对特定区域的实时监控,有助于及时发现问题、快速响应,提高管理效率。 智慧安防视…

window vscode安装node.js

window vscode安装node.js 官网下好vscode 和nodejs 点这个安装 等待安装完毕 在VScode运行 npm -v node -v 显示版本号则成功

【剑指offer】70--矩形覆盖[C++]

目录 1. 题目描述 2. 解题思路 3.【C解法】 4. 输出结果 1. 题目描述 我们可以用 2*1 的小矩形横着或者竖着去覆盖更大的矩形。请问用 n 个 2*1 的小矩形无重叠地覆盖一个 2*n 的大矩形,总共有多少种方法?注意:约定 n 0 时,输…

StarCoder 2:GitHub Copilot本地开源LLM替代方案

GitHub CoPilot拥有超过130万付费用户,部署在5万多个组织中,是世界上部署最广泛的人工智能开发工具。使用LLM进行编程辅助工作不仅提高了生产力,而且正在永久性地改变数字原住民开发软件的方式,我也是它的付费用户之一。 低代码/…

JVM-对象创建与内存分配机制深度剖析 3

JVM对象创建过程详解 类加载检查 虚拟机遇到一条new指令时,首先将去检查这个指令的参数是否能在常量池中定位到一个类的符号引用,并且检查这个 符号引用代表的类是否已被加载、解析和初始化过。如果没有,那必须先执行相应的类加载过程。 new…

【小黑送书—第十一期】>>如何阅读“计算机界三大神书”之一 ——SICP(文末送书)

《计算机程序的构造和解释》(Structure and Interpretation of Computer Programs,简记为SICP)是MIT的基础课教材,出版后引起计算机教育界的广泛关注,对推动全世界大学计算机科学技术教育的发展和成熟产生了很大影响。…

园区内网配置——华为

部署思路 表1 设备登录安全策略部署 部署建议 功能描述 应用场景 配置本地Console口登录的安全功能 配置Console口用户界面的认证方式和用户级别等。 本地Console口登录设备时建议配置此功能,提升本地设备登录的安全性。 配置远程STelnet登录的安全功能 配置…

c1-第三周

文章目录 1月份2.定义一个整形数组arr2.定义整形栈s3.输入一个字符串包括大小写和数字,将其中的大写英文字母改为小写,并且输出数字个数4.根据下面数据,编程实现要求功能: 9月1.编写程序实现以下功能或问题3.完成以下功能4.对运算…

数据结构之单链表详解(C语言手撕)

​ 🎉个人名片:🐼作者简介:一名乐于分享在学习道路上收获的大二在校生 🙈个人主页🎉:GOTXX 🐼个人WeChat:ILXOXVJE 🐼本文由GOTXX原创,首发CSDN…

移动执法远程视频监控方案:视频监控系统EasyCVR+4G/5G移动执法仪

一、背景需求 在现代城市管理中,移动执法仪视频监控方案正逐渐成为一种高效、便捷的管理工具。该方案通过结合移动执法仪和视频监控技术,实现了对城市管理现场的实时监控和取证,有效提升了城市管理水平和效率。 移动执法仪作为现场执法的重…

深入浅出(二)MVVM

MVVM 1. 简介2. 示例 1. 简介 2. 示例 示例下载地址:https://download.csdn.net/download/qq_43572400/88925141 创建C# WPF应用(.NET Framework)工程,WpfApp1 添加程序集 GalaSoft.MvvmLight 创建ViewModel文件夹,并创建MainWindowV…