DiT:Scalable Diffusion Models with Transformers

TOC

  • 1 前言
  • 2 方法和代码

1 前言

该论文发表之前,市面上几乎都是用卷积网络作为实际意义上的(de-facto)backbone。于是一个想法就来了:为啥不用transformer作为backbone呢?

文章说本论文的意义就在于揭示模型选择对于扩散模型的重要性,并为生成模型研究提供一个可借鉴的基准(baseline)。

本文还揭示出卷积网络的inductive bias对生成性能并没有多大的影响,所以可以使用transformer网络去替代卷积网络。文章使用Gflops和FID去分别评估模型复杂度和生成图像质量。

刚刚又去学了一下FLOPs,真是破破烂烂,缝缝补补啊……

总的来说,DiT有如下优点:

  1. 高质量:achieve a state-of-the-art result of 2.27 FID on the classconditional 256 × 256 ImageNet generation benchmark.
  2. 发现了FID和GFLOPs之间存在强相关关系,通过增加depth of transformer或者amount of patches可以增加GFLOPs
  3. 灵活性:可以挑战模型大小、patches大小和序列长度
  4. 跨领域研究:DiT架构和ViT类似,为跨领域研究提供可能

2 方法和代码

在这里插入图片描述
整体来看:

  • 使用transformer作为其主干网络,代替了原先的UNet
  • 在latent space进行训练,通过transformer处理潜在的patch
  • 输入的条件(timestep 和 text/label )的四种处理方法:
    • In-context conditioning: 将condition和input embedding合并成一个tokens(concat),不增加额外计算量
    • Cross-attention block:在transformer中插入cross attention,将condition当作是K、V,input当作是Q
    • Adaptive layer norm (adaLN) block:将timestep和 text/label相加,通过MLP去回归参数scale和shift,也不增加计算量。并且在每一次残差相加时,回归一个gate系数。
    • adaLN-Zero block:参数初始化为0,那么在训练开始时,残差模块当于identical function。
  • 整体流程:patchify -> Transfomer Block -> Linear -> Unpatchify。 注意最后输出的维度是原来维度的2倍,分别输出noise和方差。

由下图可见,adaLN-Zero最好。然后就是探索各种调参效果,此处略。
在这里插入图片描述

代码以及注释:
DiTBlock

# DIT的核心子模块
class DiTBlock(nn.Module):"""A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning."""def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):super().__init__()self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)# 此处为miltihead-self-Attentionself.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)mlp_hidden_dim = int(hidden_size * mlp_ratio)approx_gelu = lambda: nn.GELU(approximate="tanh")self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)#使用自适应归一化替换标准归一化层self.adaLN_modulation = nn.Sequential(nn.SiLU(),nn.Linear(hidden_size, 6 * hidden_size, bias=True))def forward(self, x, c):shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))return x
  • addLN_zero: 先通过SiLU,然后再通过线性层输出6个值

forward

  def forward(self, x, t, y):x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2t = self.t_embedder(t)                   # (N, D)# time step embeddingy = self.y_embedder(y, self.training)    # (N, D)c = t + y                                # (N, D)# 送入上述的DIT-Block中for block in self.blocks:x = block(x, c)                      # (N, T, D)x = self.final_layer(x, c)                # (N, T, patch_size ** 2 * out_channels)x = self.unpatchify(x)                   # (N, out_channels, H, W)return x
  • x通过embedding,与position embedding相加(固定的sin-cos位置编码)
  • t通过embedding
  • y通过embedding, t和y相加得到c
  • 遍历每一个block,传入x和c
  • 最后传入最后一层线性层,然后通过unpatchify恢复图像
class FinalLayer(nn.Module):"""The final layer of DiT."""def __init__(self, hidden_size, patch_size, out_channels):super().__init__()self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)self.adaLN_modulation = nn.Sequential(nn.SiLU(),nn.Linear(hidden_size, 2 * hidden_size, bias=True))nn.init.constant_(self.adaLN_modulation[-1].weight, 0)nn.init.constant_(self.adaLN_modulation[-1].bias, 0)nn.init.constant_(self.linear.weight, 0)nn.init.constant_(self.linear.bias, 0)def forward(self, x, c):shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)x = modulate(self.norm_final(x), shift, scale)x = self.linear(x)return x
  • 同样引入adpLN_zero,并且让输出维度为p*p*2c,是特征维度原来大小的2倍,分别预测noise和方差

最后unpatchify

    def unpatchify(self, x):x: (N, T, patch_size**2 * C)imgs: (N, H, W, C)"""c = self.out_channelsp = self.x_embedder.patch_size[0]h = w = int(x.shape[1] ** 0.5)assert h * w == x.shape[1]x = x.reshape(shape=(x.shape[0], h, w, p, p, c))x = torch.einsum('nhwpqc->nchpwq', x)imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))return imgs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/729873.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

用python写一个自动进程守护,带UI

功能是指定程序关闭后自动重启,并点击1作为启动 原来的想法是群成员说的某软件打包后,软件进程被杀后,界面白屏。所以写了个计算器重启demo进行进程守护 import subprocess import time import pyautogui import psutil #用计算器做演示。 d…

WiFi模块助力敏捷办公:现代办公室的关键角色

随着信息技术的飞速发展,现代办公室正经历着一场数字化和智能化的变革。在这一变革过程中,WiFi模块作为无线通信技术的核心组成部分,扮演着关键的角色,为敏捷办公提供了强大的支持。本文将深入探讨WiFi模块在现代办公室中的关键角…

Spring Boot工作原理

Spring Boot Spring Boot 基于 Spring 开发,Spirng Boot 本身并不提供 Spring 框架的核心特性以及扩展功能,只是用于快速、敏捷地开发新一代基于 Spring 框架的应用程序。也就是说,它并不是用来替代 Spring 的解决方案,而是和 Spr…

安康杯安全知识竞赛上的讲话稿

各位领导、同志们: 经过近半个月时间的准备,南五十家子镇平泉首届安康杯安全生产知识竞赛初赛在今天圆满落下帏幕,经过紧张激烈的角逐, 代表队、 代表队和 代表队分别获得本次竞赛的第一、二、三名让我们以热烈的掌声表示祝…

使用插件vue-seamless-scroll 完成内容持续动态

1、安装插件 npm install vue-seamless-scroll --save 2、项目中引入 //单独引入import vueSeamlessScroll from vue-seamless-scrollexport default {components: { vueSeamlessScroll},}//或者在main.js引入import scroll from vue-seamless-scrollVue.use(scroll)3、页面使…

SRS服务器ffmpeg 推流rtmp超时中断

ffmpeg错误显示 failed to update header with correct duration failed to update header with correct filesize. Error writing trailer of rtmp://----- broken pipe SRS日志错误显示 serve error code2056 kickoffforidle : service cycle : rtmp stream service: timeou…

基于Pytorch搭建分布式训练环境

Pytorch系列 文章目录 Pytorch系列前言一、DDP是什么二、DPP原理terms、nodes 和 ranks等相关术语解读DDP 的局限性为什么要选择 DDP 而不是 DP代码演示1. 在一个单 GPU 的 Node 上进行训练(baseline)2. 在一个多 GPU 的 Node 上进行训练临门一脚&#x…

【深度学习笔记】稠密连接网络(DenseNet)

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图 5.12 稠密连接网络(DenseNet) ResNet中的跨层连接设计引申出了数个后续工作。本节我们介绍其中的一个&#xf…

5个实用的PyCharm插件

大家好,本文向大家推荐五个顶级插件,帮助开发人员提升PyCharm工作流程,将生产力飞升到新高度。 1.CodiumAI 安装链接:https://plugins.jetbrains.com/plugin/21206-codiumate--code-test-and-review-with-confidence--by-codium…

Windows上基于名称快速定位文件和文件夹的免费工具Everything

在Windows上搜索文件时,使用windows上内置搜索会很慢,这里推荐使用Everything工具进行搜索。 "Everything"是Windows上一款搜索引擎,它能够基于文件名快速定位文件和文件夹位置。不像Windows内置搜索,"Everything&…

容器:Docker部署

docker 是容器,可以将项目的环境(比如 java、nginx)和项目的代码一起打包成镜像,所有同学都能下载镜像,更容易分发和移植。 再启动项目时,不需要敲一大堆命令,而是直接下载镜像、启动镜像就可以…

echarts x轴名称过长tip显示全称

xAxis的axisLabel的内容如下: axisLabel: { rotate: -45, color: document.body.className.indexOf(custom-f4c46d) > -1 ? #fff : #343434, // 显示省略号操作(第一步) formatter: function (value) { var val if (value.length >…

NTP协议介绍

知识改变命运,技术就是要分享,有问题随时联系,免费答疑,欢迎联系! 网络时间协议NTP(Network Time Protocol)是TCP/IP协议族里面的一个应用层协议,用来使客户端和服务器之间进行时…

C while 循环

只要给定的条件为真,C 语言中的 while 循环语句会重复执行一个目标语句。 语法 C 语言中 while 循环的语法: while(condition) {statement(s); }在这里,statement(s) 可以是一个单独的语句,也可以是几个语句组成的代码块。 co…

IOS开发0基础入门UIkit-1cocoapod安装、更新和使用 , 安装中出现的错误及解决方案 M1或者M2安装cocoapods

cocoapod是ios开发时常用的包管理工具 1.M1或者是M2系统安装cocoapods先操作一下两个设置 1、打开访达->应用->实用工具->终端->右键点击终端->显示简介->勾选使用 Rosetta 打开,关闭终端,重新打开。 2、打开访达->应用->Xcod…

ApiPost设置预执行脚本获取token,并设置给请求头

ApiPost设置预执行脚本获取token,并设置给请求头 预执行脚本 这个地方获取字段为 {"msg": "操作成功","code": 200,"token": "eyJhbGciOixMiJ9.123-NQQPPKGr4Yxa1_H_JIrUXJQ" }修改head 里面参数

OpenAI劲敌吹新风! Claude 3正式发布,Claude3使用指南

Claude 3是什么? 是Anthropic 实验室近期推出的 Claude 3 大规模语言模型(Large Language Model,LLM)系列,代表了人工智能技术的一个显著飞跃。 该系列包括三个不同定位的子模型:Claude 3 Haiku、Claude 3…

BUUCTF-Misc3

LSB1 1.打开附件 得到一张图片,像是某个大学的校徽 2.Stegsolve工具 根据标题LSB,可能是LSB隐写 放到Stegsolve中,点Analyse在点Data Extract 数据提取 因为是LSB隐写,发现含以.png结尾的图片 3.保存图片 4.得到flag 扫描二维…

一招教你优化TCP提高大文件传输效率

在当今企业的数据传输实践中,传统的传输控制协议(TCP)在处理大型文件传输时,其固有的可靠性和复杂性有时会导致效率不足。为了提升大文件传输的效率,对TCP进行优化成为了一个关键任务。 TCP传输的可靠性是其核心优势&a…

UnityShader常用算法笔记(颜色叠加混合、RGB-HSV-HSL的转换、重映射、UV序列帧动画采样等,持续更新中)

一.颜色叠加混合 1.Blend混合 // 正常,透明度混合 Normal Blend SrcAlpha OneMinusSrcAlpha //柔和叠加 Soft Additive Blend OneMinusDstColor One //正片叠底 相乘 Multiply Blend DstColor Zero //两倍叠加 相加 2x Multiply Blend DstColor SrcColor //变暗…