Sora 基础作品之 DiT:Scalable Diffusion Models with Transformer

Paper name

Scalable Diffusion Models with Transformers (DiT)

Paper Reading Note

Paper URL: https://arxiv.org/abs/2212.09748

Project URL: https://www.wpeebles.com/DiT.html

Code URL: https://github.com/facebookresearch/DiT

TL;DR

  • 2022 年 UC Berkeley 出品的论文,将 transformer 应用于 diffusion 上实现了当时最佳的生成效果。DiT 论文作者也是 OpenAI 项目领导者之一,该论文是 Sora 的基础工作之一。

Introduction

背景

  • transformer 在自回归模型中得到了广泛应用,但在其他生成模型框架中的采用较少。例如,扩散模型已经处于近期图像级生成模型进展的前沿;然而,它们都采用了卷积 U-Net 架构作为默认的主干选择
  • 本文展示了 U-Net 的归纳偏置对扩散模型的性能并非至关重要,可以替换为 transformer

本文方案

  • 提出了基于 transformer 的 diffusion 模型 Diffusion Transformers (简称 DiTs)。该架构具有良好的可扩展性,即网络复杂度(以Gflops衡量)与样本质量(以FID衡量)之间存在强相关性。通过简单地放大 DiT 并训练一个具有高容量主干的 LDM,能够在类条件 256 × 256 ImageNet 生成基准测试上达到 2.27 FID 的最新结果
    DiT 效果可视化

Methods

整体设计思路

  • 使用 Latent diffusion models(LDM) + Classifier-free guidance + transformer + VAE (Conv) 架构设计,从下图可以看出该设计的优势,左图显示有 scaling law,右图显示 LDM 相比于 pixel space diffusion 模型 ADM 有优势,不仅精度更高,训练计算量也更低
    scaling law for DiT

Diffusion Transformers

  • DiT 基于 ViT 修改得到,整体架构如下图所示:DiT block 通过区分 condition 的添加方式分为三种设计思路,分别是通过 adaLN (或 adaLN-Zero),cross-attention 或 In-Context
    DiT 架构
DiT 前向过程的各个模块
  • Patchify:DiT 的输入是图像的空间表示 z(对于 256×256×3 的图像,z 的形状为 32×32×4)。DiT的第一层是 “patchify”,它通过将输入中的每个补丁线性嵌入,将空间输入转换成一系列 T 个 token,每个 token 的维度为 d。在执行 patchify 之后,我们对所有输入 token 应用标准的 ViT 基于频率的位置嵌入(正弦-余弦版本)。由 patchify 创建的 token 数量 T 由补丁大小超参数 p 决定。如下图所示,将 p 减半会使 T 增加四倍,因此至少使整个 transformer Gflops 增加四倍。DiT 中主要实验了 p = 2, 4, 8
    patchify

  • DiT block:如整体框架图中所示,根据 condition 加入的不同方式分为以下四种设计思路

    • 上下文条件化。我们简单地将 t 和 c 的向量嵌入作为输入序列中的两个额外 tokens 附加上去,对待它们与图像 tokens 没有区别。这与 ViT 中的 cls tokens 类似,它允许我们无需修改就使用标准的 ViT 模块。在最后一个模块之后,我们从序列中移除条件化 tokens。这种方法对模型的新 Gflops 增加可以忽略不计。
    • 交叉注意力模块。我们将 t 和 c 的嵌入 concat 成一个长度为二的序列,与图像 token 序列分开。transformer 模块被修改为在多头自注意力模块后面增加一个额外的多头交叉注意力层,类似于 Attention is All you need 中的原始设计,也类似于 LDM 用于条件化类别标签的设计。交叉注意力对模型的 Gflops 增加最多,大约增加了 15% 的开销。
    • 自适应层归一化(adaLN)模块。在 GANs 和具有 UNet 骨干的扩散模型中广泛使用自适应归一化层之后,我们探索了用自适应层归一化(adaLN)替换 transformer 模块中的标准归一化层。adaLN 并不是直接学习维度规模的缩放和偏移参数 γ 和 β,而是从 t 和 c 的嵌入向量之和中回归得到它们。在我们探索的三种模块设计中,adaLN 增加的 Gflops 最少,因此是最计算高效的。它也是唯一一个限制对所有 tokens 应用相同函数的条件化机制。
    • adaLN-Zero 模块。之前的 ResNets 工作发现,将每个残差块初始化为恒等函数是有益的。例如,在监督学习环境中,将每个块中最后的批量归一化缩放因子 γ 零初始化可以加速大规模训练。扩散 U-Net 模型使用了类似的初始化策略,在任何残差连接之前零初始化每个块中的最终卷积层。我们探索了对 adaLN DiT 模块的修改,它做了同样的事情。除了回归 γ 和 β,我们还回归了在 DiT 模块内的任何残差连接之前作用的 dimension-wise 的缩放参数 α。初始化 MLP 以输出所有 α 为零向量;这将完整的 DiT 模块初始化为恒等函数。与标准的 adaLN 模块一样,adaLNZero 对模型的 Gflops 增加可以忽略不计。
  • Model size

    • 使用四种配置:DiT-S、DiT-B、DiT-L 和 DiT-XL。它们涵盖了从 0.3 到 118.6 Gflops 不同范围的模型大小和浮点运算分配,使我们能够评估扩展性能。下表提供了配置的详细信息。
      model size
  • Transformer 解码器

    • 在最后的 DiT 模块之后,需要将图像 tokens 序列解码成输出噪声预测和输出对角协方差预测。这两个输出的形状等于原始的空间输入。我们使用标准 linear 解码器来完成这一任务;我们应用最终的层归一化(如果使用 adaLN 则为自适应的)并将每个 token 线性解码成一个 p×p×2C 的张量,其中 C 是输入到 DiT 的空间输入中的通道数。最后,我们将解码后的 tokens 重新排列成它们原始的空间布局,以得到预测的噪声和协方差。

Experiments

训练配置
  • ImageNet, 256x256 或 512x512 训练
  • AdamW, no weight decay
  • constant lr: 1 x 10−4
  • batch size: 256
  • EMA: decay 0.9999
VAE/Diffusion
  • Stable Diffusion 中的 VAE,下采样倍数为 8: 256 × 256 × 3 -> 32 × 32 × 4.
  • tmax=1000
  • linear variance schedule: 1e-4 -> 2e-2
评测指标
  • FID, FID-50k,250 DDPM sampling step

DiT block 消融

  • condition 的加入方式很影响精度:adaLN-Zero 精度最佳,说明权重初始化方式也很重要(让 DiT 的 block 初始化为 identity 函数)。
  • 计算量:in-context (119.4 Gflops), cross-attention (137.6 Gflops), adaptive layer norm (adaLN, 118.6 Gflops) or adaLN-zero (118.6 Gflops)
    DiT block

scaling 分析

  • 提升模型计算量稳定涨点
    scaling 基础实验

  • 只要模型计算量接近,FID 就接近。
    scaling

训练效率
  • 更大的 DiT 模型计算效率更高。
    • 训练计算量的评估方式:Gflops · batch size · training steps · 3。因子 3 大致近似于反向传播的计算量是前向传播的两倍。发现,即使训练时间更长,小型 DiT 模型最终相对于训练步数更少的大型 DiT 模型而言,在计算上变得效率低下。同样,我们发现,除了 patch 大小不同之外,其他都相同的模型即使在控制了训练 Gflops 的情况下,也有不同的性能表现。例如,在大约 1 0 10 10^{10} 1010 Gflops 之后,XL/4 的性能被 XL/2 超越。
      训练效率
可视化效果
  • 提升模型计算量可视化效果明显提升
    visualize scaling

class condition 定量分析

  • 达到 sota 效果,比之前的 sota StyleGAN-XL 精度更高
    class2img
    class2img 512px
增加模型参数量 or 增加 sampling 步数
  • 扩散模型的独特之处在于可以通过增加生成图像时的采样步骤来在训练后使用额外的计算资源。
  • 研究了在使用更多采样计算的情况下,较小模型计算的 DiT 是否能够超越较大的模型。结论是增加采样计算的规模无法弥补模型计算能力的不足
    scaling model vs scaling sampling

Thoughts

  • 符合 scaling law 的简洁架构才是王道。scaling law 在 DiT 这的实验效果极佳,和 OpenAI 价值观相符,这应该是作为 Sora 基础工作之一的原因
  • condition 的加入方式可能还需要更多的 class condition 之外的消融实验,比如 image condition、text condition 等

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/787910.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

罗克韦尔AB的PLC协议和西门子PLC协议转换网关

下面是罗克韦尔(AB)的Compact系列的PLC与西门子S7-1500之间的通讯的配置,实现AB的标签数组与西门子DB数据块之间通讯。 首先在AB的PLC内建立输入和输出数组,用于接收和写入S7-1500的PLC数据,名称分别是IN_INT16、OUT_OUT16,输入80…

为“自研”的KV数据库编写JDBC驱动

一觉醒来,受到梦的启发,自研了一套K/V数据库系统,因为"客户"一直催促我提供数据库的JDBC驱动,无奈之下,只好花费一个上午的时间为用户编写一个。 我们知道,JDBC只定义一系列的接口, 具体的实现需…

WeekPaper:GraphTranslator将知识图谱与大模型对齐

GraphTranslator: 将图模型与大型语言模型对齐,用于开放式任务。 将基于图的结构和信息与大型语言模型的能力整合在一起,以提高在涉及复杂和多样数据的任务中的性能。其目标是利用图模型和大型语言模型的优势,解决需要处理和理解结构化和非结…

Python深度学习034:cuda的环境如何配置

文章目录 1.安装nvidia cuda驱动CMD中看一下cuda版本:下载并安装cuda驱动2.创建虚拟环境并安装pytorch的torch_cuda3.测试附录1.安装nvidia cuda驱动 CMD中看一下cuda版本: 注意: 红框的cuda版本,是你的显卡能装的最高的cuda版本,所以可以选择低于它的版本。比如我的是11…

Prometheus+grafana环境搭建redis(docker+二进制两种方式安装)(四)

由于所有组件写一篇幅过长,所以每个组件分一篇方便查看,前三篇 Prometheusgrafana环境搭建方法及流程两种方式(docker和源码包)(一)-CSDN博客 Prometheusgrafana环境搭建rabbitmq(docker二进制两种方式安装)(二)-CSDN博客 Prometheusgrafana环境搭建m…

HarmonyOS实战开发-一次开发,多端部署-视频应用

介绍 随着智能设备类型的不断丰富,用户可以在不同的设备上享受同样的服务,但由于设备形态不尽相同,开发者往往需要针对具体设备修改或重构代码,以实现功能完整性和界面美观性的统一。OpenHarmony为开发者提供了“一次开发&#x…

Ubuntu20.04安装MatlabR2018a

一、安装包 安装包下载链接 提取码:kve2 网上相关教程很多,此处仅作为安装软件记录,方便后续软件重装,大家按需取用。 二、安装 1. 相关文件一览 下载并解压文件后,如下图所示: 2. 挂载镜像并安装 2…

python实战之宝塔部署flask项目

一. 项目 这个demo只是提供了简单的几个api接口, 并没有前端页面 # -*- coding: utf-8 -*- import flask as fk from flask import jsonify, requestapp fk.Flask(__name__)app.route(/api/hello, methods[GET]) def get_data():return hello world# 假设我们要提供一个获取用…

rabbitmq死信交换机,死信队列使用

背景 对于核心业务需要保证消息必须正常消费,就必须考虑消费失败的场景,rabbitmq提供了以下三种消费失败处理机制 直接reject,丢弃消息(默认)返回nack,消息重新入队列将失败消息投递到指定的交换机 对于核…

每日一题 --- 右旋字符串[卡码][Go]

右旋字符串 题目:55. 右旋字符串(第八期模拟笔试) (kamacoder.com) 题目描述 字符串的右旋转操作是把字符串尾部的若干个字符转移到字符串的前面。给定一个字符串 s 和一个正整数 k,请编写一个函数,将字符串中的后面…

HarmonyOS 应用开发之同步任务开发指导 (TaskPool和Worker)

同步任务是指在多个线程之间协调执行的任务,其目的是确保多个任务按照一定的顺序和规则执行,例如使用锁来防止数据竞争。 同步任务的实现需要考虑多个线程之间的协作和同步,以确保数据的正确性和程序的正确执行。 由于TaskPool偏向于单个独…

scRNA+bulk+MR:动脉粥样硬化五个GEO数据集+GWAS,工作量十分到位

今天给大家分享一篇JCR一区,单细胞bulkMR的文章:An integrative analysis of single-cell and bulk transcriptome and bidirectional mendelian randomization analysis identified C1Q as a novel stimulated risk gene for Atherosclerosis 标题&…

rtph264depay插件分析笔记

1、rtp协议头 2、rtp可以基于TCP或者UDP 其中基于TCP需要加4个字节的RTP标志 3、rtph264depay定义解析函数gst_rtp_h264_depay_process,通过RFC 3984文档实现。 static void gst_rtp_h264_depay_class_init (GstRtpH264DepayClass * klass) {GObjectClass *gobject…

AI资讯2024-04-02 | 前微软副总裁姜大昕携「阶跃星辰」入场,出手即万亿参数大模型!

关注文章底部公众号获取每日AI新闻,以及各种好玩的黑科技,如AI换脸,AI数字人,AI生成视频等工具 阶跃星辰发布万亿参数大模型 终于!国内大模型创业公司最后一位强实力玩家入场——阶跃星辰。它是由微软前全球副总裁姜大昕所创办,公司名称也来源于,发了三个大模型:Step-…

当msvcp120.dll文件找不到了要怎么解决?教你靠谱的3种修复msvcp120.dll方法

当出现msvcp120.dll文件丢失的问题时,不用担心,这是一个常见的情况。在日常使用电脑时,误删或受到计算机病毒影响都可能导致这个问题。为了解决这个问题,今天我们将向大家介绍正确的msvcp120.dll修复方法。 一.msvcp120.dll文件是…

体验OceanBase 的binlog service

OceanBase对MySQL具备很好的兼容性。目前,已经发布了开源版的binlog service工具,该工具能够将OceanBase特有的clog模式转换成binlog模式,以便下游工具如canal、flink cdc等使用。今天,我们就来简单体验一下这个binlog service的功…

RA8889/RA8876显示自定义ASCII字符方法

本文介绍用户自己生成的ASCII字库如何通过RA8889/RA8876显示到液晶屏上。 先上一张实例效果图: 再上程序代码: int main(void) {unsigned short x,y;/* System Clocks Configuration */RCC_Configuration(); delay_init(72); GPIO_Configuration(); …

转圈游戏(acwing)

题目描述: n 个小伙伴(编号从 0 到 n−1)围坐一圈玩游戏。 按照顺时针方向给 n 个位置编号,从 0 到 n−1。 最初,第 0 号小伙伴在第 0 号位置,第 1 号小伙伴在第 1 号位置,…

前端学习<二>CSS基础——17-CSS3的常见边框汇总

CSS3 常见边框汇总 <!DOCTYPE html><html lang"en"><head><meta charset"UTF-8"><title>CSS3 边框</title><style>body, ul, li, dl, dt, dd, h1, h2, h3, h4, h5 {margin: 0;padding: 0;}​body {background-c…

治愈风景视频素材在哪找?日落风景、伤感风景、江南风景这里都有

在这个视频内容为王的时代&#xff0c;做个爆款视频好比烹饪一道米其林三星级大餐&#xff0c;少了那么一点儿神秘的调料&#xff0c;总觉得差了点味道。我&#xff0c;一个在视频剪辑战场上摸爬滚打多年的老兵&#xff0c;今天就来跟大家分享几个找素材的秘密武器&#xff0c;…