[DL]深度学习_扩散模型正弦时间编码

1  扩散模型时间步嵌入

1.1  时间步正弦编码

        在扩散模型按时间步 t 进行加噪去噪过程时,需要包括反映噪声水平的时间步长 t 作为噪声预测器的额外输入。但是最初与图像配套的时间步 t 是数字,需要将代表时间步 t 的数字编码为向量嵌入。嵌入时间向量的宽度dim是按照输出设定的。

def timestep_embedding(timesteps, dim, max_period=10000):"""Create sinusoidal timestep embeddings.:param timesteps: a 1-D Tensor of N indices, one per batch element.These may be fractional.:param dim: the dimension of the output.:param max_period: controls the minimum frequency of the embeddings.:return: an [N x dim] Tensor of positional embeddings."""half = dim // 2freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to(device=timesteps.device)args = timesteps[:, None].float() * freqs[None]embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)if dim % 2:embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)return embedding

参数定义

  • timesteps--时间步 t ,包含batch中每个元素的时间步信息,是一个一维张量,形状为[N];
  • dim--编码后时间嵌入的最后一维大小,即输出多宽的向量;
  • max_period--控制嵌入的最小频率,值为10000。 

        假设此时的batch中有4个图像,则此时4个图像有4个对应的扩散时间步,所需编码的时间嵌入宽度是8。timesteps的形状为[4] 。

timesteps = [t1, t2, t3, t4], dim = 8

编码过程

    half = dim // 2

        计算所需输出向量的一半尺寸,以供后续将正弦嵌入和余弦嵌入按照最后维度concat拼接。

    freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to(device=timesteps.device)

        用了固定化编码方式,计算出的频率向量,且该频率向量的维度是由dim的一半half决定的。arange(start=0,end=half)则代表arange(0,4),若以i为代表,则此时 i = 0,1,2,3。公式如下:

freqs=e^{(-\frac{i\times log(maxperiod)}{half})} 

则此时计算出的 freqs = [freqs1,freqs2,freqs3,freqs4],是一个一维张量,形状为[4]。

    args = timesteps[:, None].float() * freqs[None]

        首先,timesteps[:, None]作用是将timesteps的形状从一维张量[4]扩展为2维张量[4x1],之前是 timesteps = [ t1, t2, t3, t4 ] 扩展为 timesteps = [ [ t1 ], [ t2 ], [ t3 ], [ t4 ] ]。

        其次将freqs的形状从一维张量[4]扩展为2维张量[1x4],之前是 freqs = [freqs1,freqs2,freqs3,freqs4] 扩展为 freqs = [ [freqs1,freqs2,freqs3,freqs4] ]。

        这样做的目的是为了将时间步 timesteps 代表的batch内每个图片的时间步与频率 freqs 做乘法,将时间步广播进去。得出args的形状是 [4x4] ,即 [N x half] ,每个时间步对应得到half列,有N个时间步。

    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)

        将得到的args求cos值和sin值,此时得到的 cos(args) 与 sin(args) 的形状依旧都是 [N x half],按照最后一维拼接 cos(args) 与 sin(args) ,则拼接后得到的embedding的形状是 [4 x 8], [N x 2half] ,即 [N x dim]。对应开始定义的 dim 代表的是编码后的最后维度大小。 

    if dim % 2:embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)return embedding

        最后加一个判断,当 dim % 2 ==1 即dim是奇数时,embedding的形状是 [N x dim-1],需要用零向量补充1个维度,变为 [N x dim] 。

1.2  馈入多层感知机 

        当对时间步长执行正弦编码之后,需要将其馈送到多层感知机中获得隐式时间嵌入。 

        time_embed_dim = model_channels * 4self.time_embed = nn.Sequential(linear(model_channels, time_embed_dim),nn.SiLU(),linear(time_embed_dim, time_embed_dim),)emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        多层感知机由两个具有倒置瓶颈结构的线性层和一个SiLU激活函数组成。其中dim = channels = 8,time_embed_dim = model_channels * 4 = 32。需要注意是 N 为 batch 大小。

  • linear(model_channels, time_embed_dim)中,输入形状为 [4 x 8] ,[N x dim] 即 [N x model_channels] ,输出形状为 [4 x 32] , [N x time_embed_dim]。
  • 激活函数不会改变形状。
  • linear(time_embed_dim, time_embed_dim)中,输入形状为 [4 x 32] ,[N x time_embed_dim],输出形状为 [4 x 32],[N x time_embed_dim]。

        多层感知机主要作用是将输入的时间步正弦编码向量变得更宽。通过激活函数引入非线性。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/62160.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Golang context 的作用和实现原理

context的作用 Go语言中的context包提供了一种在进程中跨API和跨进程边界传递取消信号、超时和其他请求范围的值的方式。context的主要作用包括: 取消信号(Cancellation): 当一个操作需要取消时,可以通过context传递…

【超全】目标检测模型分类对比与综述:单阶段、双阶段、有无锚点、DETR、旋转框

《------往期经典推荐------》 一、AI应用软件开发实战专栏【链接】 项目名称项目名称1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】3.【手势识别系统开发】4.【人脸面部活体检测系统开发】5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】7.【…

Llmcad: Fast and scalable on-device large language model inference

题目:Llmcad: Fast and scalable on-device large language model inference 发表于2023.09 链接:https://arxiv.org/pdf/2309.04255 声称是第一篇speculative decoding边缘设备的论文(不一定是绝对的第一篇),不开源…

Edge浏览器保留数据,无损降级退回老版本+禁止更新教程(适用于Chrome)

3 个月前阿虚就已经写文章告警过大家,Chromium 内核的浏览器将在 127 以上版本开始限制仍在使用 Manifest V2 规范的扩展:https://mp.weixin.qq.com/s/v1gINxg5vMh86kdOOmqc6A 像是 IDM、油猴脚本管理器、uBblock 等扩展都会受到影响,后续将无…

人工智能零基础入门学习笔记

学习视频:人工智能零基础入门教程 文章目录 1.简介2.应用3.演进4.机器学习5.深度学习6.强化学习7.图像识别8.自然语言9.Python10.Python开发环境11.机器学习算法1.多元线性回归项自实战:糖尿病回归预测 2.逻辑回归3.Softmax回归项目实战:鸢尾…

Spring Boot 3 集成 Spring Security(3)数据管理

文章目录 准备工作新建项目引入MyBatis-Plus依赖创建表结构生成基础代码 逻辑实现application.yml配置SecurityConfig 配置自定义 UserDetailsService创建测试 启动测试 在前面的文章中我们介绍了 《Spring Boot 3 集成 Spring Security(1)认证》和 《…

Wireshark抓取HTTPS流量技巧

一、工具准备 首先安装wireshark工具,官方链接:Wireshark Go Deep 二、环境变量配置 TLS 加密的核心是会话密钥。这些密钥由客户端和服务器协商生成,用于对通信流量进行对称加密。如果能通过 SSL/TLS 日志文件(例如包含密钥的…

Redis(概念、IO模型、多路选择算法、安装和启停)

一、概念 关系型数据库是典型的行存储数据库,存在的问题是,按行存储的数据在物理层面占用的是连续存储空间,不适合海量数据存储。 Redis在生产中使用的最多的是用作数据缓存。 服务器先在缓存中查询数据,查到则返回,…

Cobalt Strike 4.8 用户指南-第十一节 C2扩展

11.1、概述 Beacon 的 HTTP 指标由 Malleable Command and Control (Malleable C2) 配置文件控制。Malleable C2 配置文件是一个简单的程序,它指定如何转换数据并将其存储在事务中。转换和存储数据的同一程序(向后解释&#xff0…

哪里能找到好用的动物视频素材 优质网站推荐

想让你的短视频增添些活泼生动的动物元素?无论是搞笑的宠物瞬间,还是野外猛兽的雄姿,这些素材都能让视频更具吸引力。今天就为大家推荐几个超实用的动物视频素材网站,不论你是短视频新手还是老手,都能在这些网站找到心…

力扣_876. 链表的中间结点

力扣_876. 链表的中间结点 给你单链表的头结点 head ,请你找出并返回链表的中间结点。 如果有两个中间结点,则返回第二个中间结点。 输入:head [1,2,3,4,5] 输出:[3,4,5] 解释:链表只有一个中间结点,值为…

HarmonyOS ArkTS 基于CommonDialog实现自定义AlertDialog

在鸿蒙系统(HarmonyOS)中,CommonDialog 是一个用于显示对话框的组件,类似于 Android 的 AlertDialog。如果你想在鸿蒙系统中使用 ArkTS 自定义一个 AlertDialog,你可以基于 CommonDialog 实现。 步骤 1:创…

c++ 主函数里的return 0写不写的区别是什么?

在 C 中,main 函数是程序的入口点。main 函数的标准定义如下: int main() {// ... 代码 ... } 或者可以带参数: int main(int argc, char* argv[]) {// ... 代码 ... } main 函数的返回类型是 int,这意味着它应该返回一个整数…

mysql之慢查询设置及日志分析

mysql之慢查询日志分析 1.临时开启慢查询日志2.永久开启慢查询日志 慢查询是指mysql提供的日志记录功能,用来记录执行时间超过设置阈值的sql语句,并将信息写入到日志文件中; 1.临时开启慢查询日志 注意: 1.以下命令需要连接进入到…

代码随想录算法训练营第五十九天|Day59 图论

Bellman_ford 算法精讲 https://www.programmercarl.com/kamacoder/0094.%E5%9F%8E%E5%B8%82%E9%97%B4%E8%B4%A7%E7%89%A9%E8%BF%90%E8%BE%93I.html 思路 #include <stdio.h> #include <stdlib.h> #include <limits.h>#define MAXM 10000 // 假设最大边数为1…

【快速入门 LVGL】-- 1、STM32 工程移植 LVGL

目录 一、LVGL 简述 二、复制一个STM32工程 三、下载 LVGL 四、裁剪 源文件 五、工程添加 LVGL 文件 六、注册 显示 七、注册 触摸屏 八、LVGL 心跳、任务刷新 九、开跑 LVGL 十、控件的事件添加、响应处理 十 一、几个好玩小事情 十 二、显示中文 ~~ 约定 ~~ 在…

小程序租赁系统开发的优势与应用解析

内容概要 随着科技的迅猛发展&#xff0c;小程序租赁系统应运而生&#xff0c;成为许多企业优化业务的重要工具。首先&#xff0c;它提升了用户体验。想象一下&#xff0c;用户只需轻轻一点&#xff0c;就能够浏览和租赁心仪的商品&#xff0c;这种便捷的过程使繁琐的操作大大…

LLM应用-prompt提示:RAG query重写、相似query生成 加强检索准确率

参考&#xff1a; https://zhuanlan.zhihu.com/p/719510286 1、query重写 你是一名AI助手&#xff0c;负责在RAG&#xff08;知识库&#xff09;系统中通过重构用户查询来提高检索效果。根据原始查询&#xff0c;将其重写得更具体、详细&#xff0c;以便更有可能检索到相关信…

什么是域名监控?

域名监控是持续跟踪全球域名系统&#xff08;DNS&#xff09;中变化以发现恶意活动迹象的过程。组织可以对其拥有的域名进行监控&#xff0c;以判断是否有威胁行为者试图入侵其网络。他们还可以对客户的域名使用这种技术以执行类似的检查。 你可以将域名监控比作跟踪与自己实物…

Spring Boot 3 集成 Spring Security(2)授权

文章目录 授权配置 SecurityFilterChain基于注解的授权控制自定义权限决策 在《Spring Boot 3 集成 Spring Security&#xff08;1&#xff09;》中&#xff0c;我们简单实现了 Spring Security 的认证功能&#xff0c;通过实现用户身份验证来确保系统的安全性。Spring Securit…