【论文笔记】Scalable Diffusion Models with State Space Backbone

原文链接:https://arxiv.org/abs/2402.05608

1. 引言

主干网络是扩散模型发展的关键方面,其中基于CNN的U-Net(下采样-跳跃连接-上采样)和基于Transformer的结构(使用自注意力替换采样块)是代表性的例子。

状态空间模型(SSM)在长序列建模方面有极大潜力。本文受Mamba启发,建立基于SSM的扩散模型,称为DiS。DiS将所有输入(时间、条件和有噪声的图像patch)视为离散token。DiS中的状态空间模型使其比CNN和Transformer有更优的放缩性,且有更低的计算开销。

2. 方法

2.1 准备知识

扩散模型:扩散模型逐步向数据加入噪声,然后将此过程反过来从噪声生成数据。噪声的加入过程称为前向过程,可表达为马尔科夫链。逆过程中,使用高斯模型近似真实逆转移,其中学习相当于对噪声的预测(即使用噪声预测网络,来最小化噪声预测目标)。

条件扩散模型会将条件(如类别、文本等,通常形式为索引或连续嵌入)引入噪声预测目标中。

具体公式见扩散模型(Diffusion Model)简介 - CSDN。

状态空间主干:状态空间模型的传统定义是将 x ( t ) ∈ R N x(t)\in\mathbb R^N x(t)RN通过隐状态 h ( t ) ∈ R N h(t)\in\mathbb R^N h(t)RN映射为 y ( t ) ∈ R N y(t)\in\mathbb R^N y(t)RN的线性时不变系统:
h ′ ( t ) = A h ( t ) + B x ( t ) y ( t ) = C h ( t ) h'(t)=Ah(t)+Bx(t)\\y(t)=Ch(t) h(t)=Ah(t)+Bx(t)y(t)=Ch(t)

其中 A ∈ R N × N A\in\mathbb R^{N\times N} ARN×N为状态矩阵, B , C ∈ R N B,C\in\mathbb R^N B,CRN为输入和输出矩阵。真实世界的数据通常为离散形式,可将上式离散化为
h t = A ˉ h t − 1 + B ˉ x t y t = C h t h_t=\bar Ah_{t-1}+\bar Bx_t\\y_t=Ch_t ht=Aˉht1+Bˉxtyt=Cht

其中 A ˉ = exp ⁡ ( Δ ⋅ A ) , B ˉ = ( Δ ⋅ A ) − 1 ( exp ⁡ ( Δ ⋅ A ) − I ) ⋅ ( Δ B ) \bar A=\exp(\Delta\cdot A),\bar B=(\Delta\cdot A)^{-1}(\exp(\Delta\cdot A)-I)\cdot(\Delta B) Aˉ=exp(ΔA),Bˉ=(ΔA)1(exp(ΔA)I)(ΔB)为离散状态参数, Δ \Delta Δ为离散步长。

虽然SSM理论上性质优良,但通常有高计算量和数值不稳定性。结构状态空间模型(S4)通过强制 A A A的形式来减轻这一问题,能达到比Transformer更高的性能;Mamba则进一步通过输入依赖的选择机制和更快的硬件感知算法改进之。

2.2 模型结构设计

DiS参数化噪声预测网络 ϵ θ ( x t , t , c ) \epsilon_\theta(x_t,t,c) ϵθ(xt,t,c),以时间 t t t、条件 c c c和噪声图像 x t x_t xt,预测向 x t x_t xt加入的噪声。DiS基于双向Mamba结构,如下图所示。
在这里插入图片描述
图像patch化:DiS的第一层将输入图像 I ∈ R H × W × C I\in\mathbb R^{H\times W\times C} IRH×W×C转化为拉直的2D patch X ∈ R J × ( p 2 ⋅ C ) X\in\mathbb R^{J\times (p^2\cdot C)} XRJ×(p2C)。然后,通过对每个patch进行线性嵌入,转化为含 J J J个token的、维度为 D D D的序列。为每个输入token使用可学习位置编码。 J = H × W p 2 J=\frac{H\times W}{p^2} J=p2H×W由patch大小 p p p决定。

SSM块:输入token会被一组SSM块处理。SSM块的输入还包括时间 t t t与条件 c c c。本文使用双向序列建模,即SSM块的前向过程包含了前向和反向两个方向的处理。

跳跃连接:本文将 L L L个SSM块分为前半和后半两部分,每部分 ⌊ L 2 ⌋ \lfloor\frac L2\rfloor 2L个。设 h s h a l l o w , h d e e p ∈ R J × D h_{shallow},h_{deep}\in\mathbb{R}^{J\times D} hshallow,hdeepRJ×D分别为跳跃连接分支和主分支的隐状态,则通过拼接和线性投影后再送入下一个SSM块,即 L i n e a r ( C o n c a t ( h s h a l l o w , h d e e p ) ) \mathtt{Linear}(\mathtt{Concat}(h_{shallow},h_{deep})) Linear(Concat(hshallow,hdeep))

线性解码器:需要将最后一个SSM块的隐状态解码为噪声预测和对角化协方差矩阵(与原始输入尺寸相同)。本文使用线性解码器,即LayerNorm+线性层,将每个token转化为 p 2 ⋅ C p^2\cdot C p2C的张量。最后,将解码的token重排为原始大小,得到预测噪声与协方差。

条件引入:本文在输入token的序列上增加时间 t t t与条件 c c c的向量嵌入作为额外token(类似ViT中的类别token),从而无需修改SSM块。在最后一个SSM块后,从序列移除条件token。此外,还用自适应归一化层替换标准归一化层,使模型从 c c c t t t嵌入向量的和中回归缩放和偏移参数。

2.3 计算分析

对序列 X ∈ R 1 × J × D X\in\mathbb R^{1\times J\times D} XR1×J×D和状态扩维默认设置 E = 2 E=2 E=2,自注意力与SSM的计算复杂度分别为 O ( S A ) = 4 J D 2 + 2 J 2 D O(SA)=4JD^2+2J^2D O(SA)=4JD2+2J2D O ( S S M ) = 3 J ( 2 D ) N + J ( 2 D ) N 2 O(SSM)=3J(2D)N+J(2D)N^2 O(SSM)=3J(2D)N+J(2D)N2

其中自注意力的计算是序列长度 J J J的二次方,而SSM则是线性关系。注意 N N N为固定参数。这说明DiS有较强的可放缩性。

3. 实验

3.1 实验设置

数据集:仅使用水平翻转数据增广。

实施细节:本文对DiS的权重使用指数移动平均方法。

3.2 模型分析

patch大小的影响:当模型大小一致时,减小patch大小(增加token数),性能会提高。这可能是扩散模型噪声预测任务的低级特性,导致需要小型patch,而不像更高级的分类任务。对高分辨率图像,使用小尺寸patch可能会引入高计算成本,可将图像转换为低维隐式表达,然后再使用DiS处理。

长跳跃的影响:比较拼接( L i n e a r ( C o n c a t ( h s h a l l o w , h d e e p ) ) \mathtt{Linear}(\mathtt{Concat}(h_{shallow},h_{deep})) Linear(Concat(hshallow,hdeep))
、求和( h s h a l l o w + h d e e p h_{shallow}+h_{deep} hshallow+hdeep)和无跳跃连接三种方式。实验表明,求和不会带来明显的性能提升,因为SSM自身可以通过线性方式保留一些浅层信息。而使用拼接和可学习的线性投影可以大幅增加性能。

条件组合:比较两种引入时间 t t t的方案:(1)将 t t t视为token,与图像patch一同处理;(2)将 t t t的嵌入整合到SSM块的层归一化中,类似U-Net中的自适应分组归一化,得到自适应层归一化: A d a L N ( h , s ) = y s L a y e r N o r m ( h ) + y b AdaLN(h,s)=y_s\mathtt{LayerNorm}(h)+y_b AdaLN(h,s)=ysLayerNorm(h)+yb,其中 h h h为SSM的隐状态, y s , y b y_s,y_b ys,yb为时间嵌入的线性投影。实验表明前者的性能优于后者。

缩放模型大小:增大模型深度(SSM块层数)和宽度(隐状态维度)均能提高性能。

3.3 主要结果

无条件图像生成:DiS与基于U-Net或Transformer的扩散模型有相当的性能,但参数量更少。

以类别为条件的图像生成:本文的方法可以超过其余方法的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/732091.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据结构(二)——线性表(单链表)

2.3线性表的链式表示 顺序表的优缺点: 优点:可随机存储,存储密度高 缺点:要求大片连续空间,且改变容量不方便 2.3.1 单链表的基本概念 单链表:用链式存储实现了线性结构。一个结点存储一个数据元素&…

【IEEE列表会议】IEEE第三届信息与通信工程国际会议国际会议(JCICE 2024)

会议简介 Brief Introduction 2024年第三届信息与通信工程国际会议国际会议 (JCICE 2024) 会议时间:2024年5月10日-12日 召开地点:中国福州 大会官网:JCICE 2024-2024 International Joint Conference on Information and Communication Engi…

高电平复位电路工作原理详解

单片机复位电路的作用是:使单片机恢复到起始状态,让单片机的程序从头开始执行,运行时钟处于稳定状态、各种寄存器、端口处于初始化状态等等。目的是让单片机能够稳定、正确的从头开始执行程序。一共分为:高电平复位,低…

程序员失业,被迫开启 PlanB——成为自由职业/独立开发者的第 0 天

程序员失业,被迫开启 PlanB——成为自由职业/独立开发者的第 0 天 今天在逛V2EX的时候看到的一个帖子,程序员中年被裁,被迫开启独立开发这条路。 原贴如下: lastday, 失业啦 公司年前通知我合同到期不续签,今天是我…

docker学习进阶

一、dockerfile解析 官方文档: Dockerfile reference | Docker Docs 1.1、dockfile是什么? dockerfile是用来构建docker镜像的文本文件,由一条条构建镜像所需的指令和参数构成的脚本。 之前我们介绍过通过具体容器反射构建镜像(docker comm…

JavaWeb - 3 - JavaScript(JS)

JavaScript(JS)官方参考文档:JavaScript 教程 JavaScript(简称:JS)是一门跨平台、面向对象的脚本语言,是用来控制网页行为的,它能使网页可交互(脚本语言就不需要编译,直接通过浏览器…

Luajit 2023移动版本编译 v2.1.ROLLING

文章顶部有编好的 2.1.ROLLING 2023/08/21版本源码 Android 64 和 iOS 64 luajit 目前最新的源码tag版本为 v2.1.ROLLING on Aug 21, 2023应该是修正了很多bug, 我是出现下面问题才编的. cocos2dx-lua 游戏 黑屏 并报错: [LUA ERROR] bad light userdata pointer 编…

【校园导航小程序】2.0版本 静态/云开发项目 升级日志

演示视频 【校园导航小程序】2.0版本 静态/云开发项目 演示 首页 重做了首页,界面更加高效和美观 校园指南页 新增了 “校园指南” 功能,可以搜索和浏览校园生活指南 地图页 ①弃用路线规划插件,改用SDK开发包。可以无阻通过审核并发布…

seq2seq翻译实战-Pytorch复现

🍨 本文为[🔗365天深度学习训练营学习记录博客 🍦 参考文章:365天深度学习训练营 🍖 原作者:[K同学啊 | 接辅导、项目定制]\n🚀 文章来源:[K同学的学习圈子](https://www.yuque.com/…

前端使用Ant Design中的Modal框实现长按顶部拖动弹框需求

需求:需要Ant Design中的一个Modal弹框,并且可以让用户按住顶部实现拖动操作 实现:在Ant Design的Modal框的基础上,在title中添加一个onMouseDown去记录拖拽的坐标,然后将其赋值给Modal的style属性 代码部分&#xff…

【JavaEE进阶】 @Transactional详解

文章目录 🍃前言🌲rollbackFor(异常回滚属性)🎄事务隔离级别🚩MySQL事务隔离级别🚩Spring事务隔离级别 🎋Spring事务传播机制🚩什么是事务传播机制🚩事务有哪…

【SpringMVC】RESTFul风格设计和实战 第三期

文章目录 概述一、 RESTFul风格简介二、 RESTFul风格特点三、 RESTFul风格设计规范3.1 HTTP协议请求方式要求3.2 URL路径风格要求 实战一、需求分析二、RESTFul风格接口设计三、后台接口实现总结模糊查询 有范围 返回多数据用户 添加 与 更新 请求参数接收数据显示用户详情/删除…

进腾讯工作一个月,我想辞职了......

前几天,我在网上看到一个微博。 一个应届的校招生,目前入职腾讯,工作了一个月。这一个月给他的感受是大量的写测试用例,自己写测试用例的能力熟练了不少,测试技能倒是没有多大的提高,真正需要技术的工作却…

软考71-上午题-【面向对象技术2-UML】-UML中的图2

一、用例图 上午题,考的少;下午题,考的多。 1-1、用例图的定义 用例图展现了一组用例、参与者以及它们之间的关系。 用例图用于对系统的静态用例图进行建模。 可以用下列两种方式来使用用例图: 1、对系统的语境建模&#xff1b…

ChatGPT 升级出现「我们未能验证您的支付方式/we are unable to authenticate」怎么办?

ChatGPT 升级出现「我们未能验证您的支付方式/we are unable to authenticate」怎么办? 在订阅 ChatGPT Plus 时,有时候会出现以下报错 : We are unable to authenticate your payment method. 我们未能验证您的支付方式。 出现 unable to a…

低密度奇偶校验码LDPC(十)——LDPC码的密度进化

一、密度进化的概念 二、规则LDPC码的密度进化算法(SPA算法) 算法变量表 VN更新的密度进化 CN更新的密度进化 算法总结 程序仿真 参考文献 [1] 白宝明 孙韶辉 王加庆. 5G 移动通信中的信道编码[M]. 北京: 电子工业出版社, 2018. [2] William E. Ryan, Shu Lin. Channel Co…

优牛企讯司法涉诉维度全解析,了解这些小白也可以变专家!

在商业的海洋中,信息的掌握就如同舵手对风向的了解。每一条信息都可能成为引领航船前行的关键,尤其是在法律风险的管理上,准确而及时的信息更是企业稳健航行的保障。 优牛企讯,一款专业的企业司法涉诉监控查询工具,它…

SpringMVC03、HelloSpring

3、HelloSpring 3.1、配置版 新建一个Moudle &#xff0c; springmvc-02-hello &#xff0c; 添加web的支持&#xff01; 确定导入了SpringMVC 的依赖&#xff01; 配置web.xml &#xff0c; 注册DispatcherServlet <?xml version"1.0" encoding"UTF-8…

微调模型(Machine Learning 研习之十二)

现在正处于百模乱战的时期&#xff0c;对于模型微调&#xff0c;想必您是有所了解了&#xff0c;毕竟国外的大语言模型一开源&#xff0c;国内便纷纷基于该模型进行微调&#xff0c;从而开始宣称领先于某某、超越了谁。可到头来&#xff0c;却让人发现他们套壳了国外大语言模型…

Linux Ubuntu 部署SVN

最近需要在ubuntu server上部署一个svn&#xff0c;记录 不需要特定版本 如果不需要特定版本&#xff0c;这样安装就是最简单的 sudo apt update然后开始安装 sudo apt install subversion等到安装完成之后执行查看版本命令&#xff0c;如果正常输出那就没问题了 svnadmin …