AIGC笔记--条件自回归Transformer的搭建

1--概述

        1. 自回归 TransFormer 规定Token只能看到自身及前面的Token,因此需生成一个符合规定的Attention Mask;(代码提供了两种方式自回归Attention Mask的定义方式);

        2. 使用Cross Attention实现条件模态和输入模态之间的模态融合,输入模态作为Query,条件模态作为Key和Value;

2--代码

import torch
import torch.nn as nnclass CrossAttention(nn.Module):def __init__(self, embed_dim: int, num_heads: int):super().__init__()self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads)def forward(self, input_x: torch.Tensor, condition: torch.Tensor, attn_mask: torch.Tensor = None):'''query: input_xkey: conditionval: condition'''input_x = self.cross_attn(input_x, condition, condition, attn_mask=attn_mask)[0]return input_xclass Cond_Autoregressive_layer(nn.Module):def __init__(self, input_dim: int, condtion_dim: int, embed_dim: int, num_heads: int):super(Cond_Autoregressive_layer, self).__init__()self.linear1 = nn.Linear(input_dim, embed_dim)self.linear2 = nn.Linear(condtion_dim, embed_dim)self.cond_multihead_attn = CrossAttention(embed_dim = embed_dim, num_heads = num_heads)def forward(self, input_x: torch.Tensor, conditon: torch.Tensor, attention_mask1: torch.Tensor, attention_mask2: torch.Tensor):# q, k, v, attention mask, here we set key and value are both condtion y1 = self.cond_multihead_attn(self.linear1(input_x), self.linear2(conditon), attn_mask = attention_mask1)y2 = self.cond_multihead_attn(self.linear1(input_x), self.linear2(conditon), attn_mask = attention_mask2)return y1, y2if __name__ == "__main__":# set sequence len, embedding dim, multi attention headseq_length = 10input_dim = 32condtion_dim = 128embed_dim = 64num_heads = 8# init input sequence and condtioninput_x = torch.randn(seq_length, 1, input_dim)condtion = torch.randn(seq_length, 1, condtion_dim)# create two attention mask (actually they have the same function)attention_mask1 = torch.triu((torch.ones((seq_length, seq_length)) == 1), diagonal=1) # bool typeattention_mask2 = attention_mask1.float() # True->1 False->0attention_mask2 = attention_mask2.masked_fill(attention_mask2 == 1, float("-inf"))  # Convert ones to -inf# init modelAG_layer = Cond_Autoregressive_layer(input_dim, condtion_dim, embed_dim, num_heads)# forwardy1, y2 = AG_layer(input_x, condtion, attention_mask1, attention_mask2)# here we demonstrate the attention_mask1 and attention_mask2 have the same functionassert(y1[0].equal(y2[0]))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/725709.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【GIT】git合并分支

假如目前我们处于dev分支 一、重点:我们在开发前必须养成pull的习惯 git pull origin dev二、开发完毕后执行以下命令,即可将代码push到远程仓库 git add . git commit -m 提交的备注信息 git push origin dev三、此时想将dev分支合并到master分支…

文件上传{session文件包含以及条件竞争、图片文件渲染绕过(gif、png、jpg)}

session文件包含以及条件竞争 条件: 知道session文件存储在哪里 一般的默认位置: /var/lib/php/sess_PHPSESSID /var/lib/php/sessions/sess_PHPSESSID /tmp/sess_PHPSESSID /tmp/sessions/sess_PHPSESSID ####在没做过设置的情况下一般都是存储在/var…

【数仓】flume软件安装及配置

相关文章 【数仓】基本概念、知识普及、核心技术【数仓】数据分层概念以及相关逻辑【数仓】Hadoop软件安装及使用(集群配置)【数仓】Hadoop集群配置常用参数说明【数仓】zookeeper软件安装及集群配置【数仓】kafka软件安装及集群配置【数仓】flume软件安…

解决WordPress更新插件或者更新版本报WordPress 需要访问您网页服务器的权限的问题

文章目录 前言一、原因二、解决步骤总结 前言 当对WordPress的插件或者版本进行更新时报错:要执行请求的操作,WordPress 需要访问您网页服务器的权限。 请输入您的 FTP 登录凭据以继续。 如果您忘记了您的登录凭据(如用户名、密码&#xff09…

光线追踪7 - 抗锯齿(Antialiasing)

目前为止,如果你放大渲染出的图像,可能会注意到图像边缘的明显“阶梯状”效果。这种阶梯效果通常被称为“走样”或“锯齿”。当真实相机拍摄图片时,边缘通常没有锯齿,因为边缘像素是一些前景和一些背景的混合。请考虑,…

5. 链接和加载(linker and loader)

链接和加载(linker and loader): linker即链接器,它负责将多个.c编译生成的.o文件,链接成一个可执行文件或者是库文件; loader即加载器,它原本的功能很单一只是将可执行文件的段拷贝到编译确定的内存地址即可&#x…

英福康INFICON残余气体RGA General Chinese中文培训PPT课件

英福康INFICON残余气体RGA General Chinese中文培训PPT课件

【树上倍增】【割点】 【换根法】3067. 在带权树网络中统计可连接服务器对数目

作者推荐 视频算法专题 本文涉及知识点 树上倍增 树 图论 并集查找 换根法 深度优先 割点 LeetCode3067. 在带权树网络中统计可连接服务器对数目 给你一棵无根带权树,树中总共有 n 个节点,分别表示 n 个服务器,服务器从 0 到 n - 1 编号…

Java | 在消息对话框中显示文本

首先需要导入JOptionPane类,JOptionPane类属于Swing组件中的一种,其导入方式如下: import javax.swing.JOptionPane;可以使用JOptionPane的showMessageDialog方法显示消息文本。 参数格式: JOptionPane.showMessageDialog(paren…

【C语言】指针详细解读2

1.const 修饰指针 1.1 const修饰变量 变量是可以修改的,如果把变量的地址交给⼀个指针变量,通过指针变量的也可以修改这个变量。 但是如果我们希望⼀个变量加上⼀些限制,不能被修改,怎么做呢?这就是const的作⽤。 …

AI推介-多模态视觉语言模型VLMs论文速览(arXiv方向):2024.03.01-2024.03.05

论文目录~ 1.CLEVR-POC: Reasoning-Intensive Visual Question Answering in Partially Observable Environments2.Feast Your Eyes: Mixture-of-Resolution Adaptation for Multimodal Large Language Models3.MADTP: Multimodal Alignment-Guided Dynamic Token Pruning for …

RK3568平台开发系列讲解(基础篇)注册字符设备

🚀返回专栏总目录 文章目录 一、字符设备初始化二、字符设备的注册和注销三、实验代码沉淀、分享、成长,让自己和他人都能有所收获!😄 注册字符设备可以分为两个步骤: 字符设备初始化字符设备的添加一、字符设备初始化 字符设备初始化所用到的函数为 cdev_init(…),在对…

Django面对高并发现象时处理方法

首先,我们需要使用适当的数据库引擎来处理高并发。默认情况下,Django使用的是SQLite数据库,但在高并发的情况下,它可能会变得非常慢。我们可以考虑使用更适合高并发的数据库,如MySQL或PostgreSQL。这些数据库引擎具有更…

解决QMYSQL driver not loaded问题

前言 之前都是在Qt5.51上开发,连接mysql数据库一直没有问题,换到5.15.2后一直报错 一查才发现\5.15.2\msvc2019_64\plugins\sqldrivers目录下没有qsqlmysql了,5.5.1是有的,5.15.2是要自己编译的。。。 下载源码 安装qt的时候没…

什么是IoC和AOP?

如何在实际项目中应用这些设计模式? 在实际项目中应用设计模式需要根据项目的需求和特点进行具体的选择和实现。以下是一些常见的方法和建议: 了解设计模式: 首先需要对各种设计模式有深入的了解,包括它们的原理、优缺点以及适用…

Vue tree树状结构数据转扁平数据

//数据结构可参考饿了么UItreeData: [{id: 1,label: Level one 1,type: 1,children: [{id: 4,label: Level two 1-1,type: 2,children: [{id: 9,label: Level three 1-1-1,type: 3}, {id: 10,label: Level three 1-1-2,type: 3}]}, {id: 11,label: Level three 1-2,type: 2,chi…

查看kafka消息消费堆积情况

查看主题命令 展示topic列表 ./kafka-topics.sh --list --zookeeper zookeeper_ip:2181描述topic ./kafka-topics.sh --describe --zookeeper zookeeper_ip:2181 --topic topic_name查看topic某分区偏移量最大(小)值 ./kafka-run-class.sh kafka.too…

2024-3-6 python列表的切片赋值

切片赋值 如果把切片放在赋值语句的左边,或把它作为del操作的对象,我们就可以对序列进行嫁接、切除 或就地修改操作。 >>> l [i for i in range(20)] >>> l [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 1…

小白如何快速入门计算机视觉?

去年 11 月的时候,给自己了一个目标,希望在未来的 3个月时间里,写满100篇关于从零入门AI视觉的算法、代码文字。 历经 3 个月,终于在今天 100 篇文章写完了,代码也全部调试完成,上传到 github 上开源给大家…