Transformer的代码实现 day03(Positional Encoding)

Positional Encoding的理论部分

  • 注意力机制是不含有位置信息,这也就表明:“我爱你”,“你爱我”这两者没有区别,而在现实世界中,这两者有区别。
  • 所以位置编码是在进行注意力计算之前,给输入加上一个位置信息,如下图:
    在这里插入图片描述
  • 位置编码的公式如下:
    • 注意,pos表示该单词在句子中的位置,i表示该单词的输入向量的第i维度
      在这里插入图片描述
  • 由此我们可以得出不同位置之间的位置编码关系:
    在这里插入图片描述

Positional Encoding代码

  • 由于位置编码的公式固定,所以对于相同位置的位置编码也固定,即“我爱你”中的我,和“你爱我”中的你的位置编码相同
  • 所以我们可以一次将所有要输入信息的位置编码都生成出来,之后需要哪个就传哪个
class PositionalEncoding(nn.Module):def __init__(self, dim, dropout, max_len=5000):super(PositionalEncoding, self).__init__()# 确保每个单词的输入维度为偶数,这样sin和cos能配对if dim % 2 != 0:raise ValueError("Cannot use sin/cos positional encoding with ""odd dim (got dim={:d})".format(dim))"""构建位置编码pepe公式为:PE(pos,2i/2i+1) = sin/cos(pos/10000^{2i/d_{model}})"""pe = torch.zeros(max_len, dim)  # max_len 是解码器生成句子的最长的长度,假设是 10,dim为单词的输入维度# 将位置序号从一维变为只有一列的二维,方便与div_term进行运算,# 如将[0, 1, 2, 3, 4]变为:#[  #  [0],  #  [1],  #  [2],  #  [3],  #  [4]  #]position = torch.arange(0, max_len).unsqueeze(1)# 这里使用a^b = e^(blna)公式,来简化运算# torch.arange(0, dim, 2, dtype=torch.float)表示从0到dim-1,步长为2的一维张量# 通过以下公式,我们可以得出全部2i的(pos/10000^2i/dim)方便接下来的pe计算div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *-(math.log(10000.0) / dim)))# 得出的div_term为从0开始,到dim-1,长度为dim/2,步长为2的一维张量# 将position与div_term做张量乘法,得到的张量形状为(max_len,dim/2)# 将结果取sin赋给pe中偶数维度,取cos赋给pe中奇数维度pe[:, 0::2] = torch.sin(position.float() * div_term)pe[:, 1::2] = torch.cos(position.float() * div_term)# 将pe的形状从(max_len,dim)变成(max_len,1,dim),在第二个维度上增加一个大小为1的新维度# 如从原始 pe 张量形状: (5, 4)  #[  # [a1, b1, c1, d1],  # [a2, b2, c2, d2],  # [a3, b3, c3, d3],  # [a4, b4, c4, d4],  # [a5, b5, c5, d5]  #]# 转换为:执行 unsqueeze(1) 后的 pe 张量形状: (5, 1, 4)  #[  # [[a1, b1, c1, d1]],  # [[a2, b2, c2, d2]],  # [[a3, b3, c3, d3]],  # [[a4, b4, c4, d4]],  # [[a5, b5, c5, d5]]  #]pe = pe.unsqueeze(1)# 将pe张量注册为模块的buffer。在PyTorch中,buffer是模型的一部分,但不包含可学习的参数(即不需要梯度)。# 这样做是因为位置编码在训练过程中是固定的,不需要更新。self.register_buffer('pe', pe)self.drop_out = nn.Dropout(p=dropout)self.dim = dimdef forward(self, emb, step=None):# 做乘法是因为在 Transformer 模型中,位置编码被加到输入张量中,而输入张量通常是词嵌入的向量,其值通常在较小的范围内。# 但是,在将位置编码添加到输入张量之前,我们希望将其值扩大到一个较大的范围,以便位置编码对输入的影响更加显著。# 注意:emb为输入张量,形状为(seq_len, dim),seq_len 表示输入的句子的长度,dim为单词的输入维度emb = emb * math.sqrt(self.dim)# 根据step来选择加入pe的哪一部分if step is None:# 如果pe的形状为(max_len, dim),那么pe[:a]表示:取pe的第0行到第a-1行的全部元素,得到的新二维张量的形状为(a, dim)# 而pe[:, a]表示:取pe的第a-1列的全部元素,得到的新一维张量的形状为(max_len)# 而pe[:, :a]表示:取pe的第0列到第a-1列的全部元素,得到的新二维张量的形状为(max_len,a)emb = emb + self.pe[:emb.size(0)]else:emb = emb + self.pe[step]emb = self.drop_out(emb)return emb

参考文献

  1. 04 Transformer 中的位置编码的 Pytorch 实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/793908.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【RISC-V 指令集】RISC-V 向量V扩展指令集介绍(五)- 向量加载和存储

1. 引言 以下是《riscv-v-spec-1.0.pdf》文档的关键内容: 这是一份关于向量扩展的详细技术文档,内容覆盖了向量指令集的多个关键方面,如向量寄存器状态映射、向量指令格式、向量加载和存储操作、向量内存对齐约束、向量内存一致性模型、向量…

Redis -- 缓存穿透问题解决思路

缓存穿透 :缓存穿透是指客户端请求的数据在缓存中和数据库中都不存在,这样缓存永远不会生效,这些请求都会打到数据库。 常见的解决方案有两种: 缓存空对象 优点:实现简单,维护方便 缺点: 额外…

【JavaSE】接口 详解(上)

前言 本篇会讲到Java中接口内容,概念和注意点可能比较多,需要耐心多看几遍,我尽可能的使用经典的例子帮助大家理解~ 欢迎关注个人主页:逸狼 创造不易,可以点点赞吗~ 如有错误,欢迎指出~ 目录 前言 接口 语法…

pta 1086 就不告诉你

1086 就不告诉你 分数 15 全屏浏览 切换布局 作者 CHEN, Yue 单位 浙江大学 做作业的时候,邻座的小盆友问你:“五乘以七等于多少?”你应该不失礼貌地围笑着告诉他:“五十三。”本题就要求你,对任何一对给定的正整数…

新手开抖店:选品过后如何有效对接达人?这些方法100%有效!

哈喽~我是电商月月 要说做抖音小店最主要的是什么?那当然是找品了 那出单最快的方法是什么?无疑是达人带货了! 但新手店铺没销量,没体验分,没好评怎么能让达人同意帮我们带货呢? 方法其实很简单&#x…

“双碳”目标下资源环境中的可计算一般均衡(CGE)模型应用

我国政府承诺在2030年实现“碳达峰”,2060年实现“碳中和”,这就是“双碳”目标。为了实现这一目标就必须应用各种二氧化碳排放量很高技术的替代技术,不仅需要考虑技术上的可靠性,也需要考虑经济上的可行性。可计算一般均衡模型&a…

AI预测福彩3D第26弹【2024年4月4日预测--第4套算法重新开始计算第11次测试】

今天清明节假日,一会要外出,可能要晚点回来。咱们尽早先把预测数据跑完,把结果发出来供各位彩友参考。合并下算法,3D的预测以后将重点测试本套算法,因为本套算法的命中率较高。以后有时间的话会在第二篇文章中发布排列…

UTONMOS:AI+Web3+元宇宙数字化“三位一体”将触发经济新爆点

人工智能、元宇宙、Web3,被称为数字化的“三位一体”,如何看待这三大技术所扮演的角色? 3月24日,2024全球开发者先锋大会“数字化的三位一体——人工智能、元宇宙、Web3.0”论坛在上海漕河泾开发区举行,首次提出&…

深入探索MySQL:成本模型解析与查询性能优化,及未来深度学习与AI模型的应用展望

码到三十五 : 个人主页 在数据库管理系统中,查询优化器是一个至关重要的组件,它负责将用户提交的SQL查询转换为高效的执行计划。在MySQL中,查询优化器使用了一个称为“成本模型”的机制来评估不同执行计划的优劣,并选择…

网络安全 | 什么是负载均衡器?

关注WX: CodingTechWork 介绍 负载均衡是在多个服务器之间有效分配网络流量的过程。负载均衡的目的是优化应用程序的可用性,并确保良好的终端用户体验。负载均衡可协助高流量网站和云计算应用程序应对数百万个用户请求,从而保证客户请求不会…

2012年认证杯SPSSPRO杯数学建模C题(第二阶段)碎片化趋势下的奥运会商业模式全过程文档及程序

2012年认证杯SPSSPRO杯数学建模 C题 碎片化趋势下的奥运会商业模式 原题再现: 从 1984 年的美国洛杉矶奥运会开始,奥运会就不在成为一个“非卖品”,它在向观众诠释更高更快更强的体育精神的同时,也在攫取着巨大的商业价值&#…

颜色空间/模型(RGB, YUV,CMY/CMYK, HSI, HSV等)

什么是颜色 颜色是通过眼、脑和我们的生活经验所产生的对光的视觉感受,我们肉眼所见到的光线,是由波长范围很窄的电磁波产生的,不同波长的电磁波表现为不同的颜色,对色彩的辨认是肉眼受到电磁波辐射能刺激后所引起的视觉神经感觉…

51单片机实验02- P0口流水灯实验

目录 一、实验的背景和意义 二、实验目的 三、实验步骤 四、实验仪器 五、实验任务及要求 1,从led4开始右移 1)思路 ①起始灯 (led4) ②右移 2)效果 3)代码 2,从其他小灯并向右依次…

面向C++程序员的Rust教程(二)

先序文章请看: 面向C程序员的Rust教程(一) 所有权与移动语义 要说Rust语言跟其他语言最大的区别,那笔者觉得非数这个所有权和移动语义莫属。 深浅复制 对于绝大多数语言来说,变量/对象之间的赋值通常都是复制语义。…

微信开发工具——进行网页授权

微信开发工具——进行网页授权 微信公众平台设置 1.在首页创建好自己的订阅号 网站:https://mp.weixin.qq.com/ 点击立即注册,在选择订阅号(个人创建使用) 之后按流程填写后,点击设置与开发-------->基本配置,这…

JAVA八股--redis

JAVA八股--redis 如何保证Redis和数据库数据一致性redisson实现的分布式锁的主从一致性Redis脑裂现象及解决方案介绍I/O多路复用模型undo log 和 redo log(没掌握MyISAM 和 InnoDB 有什么区别? 如何保证Redis和数据库数据一致性 关于异步通知中消息队列…

Kubernetes(k8s):精通 Pod 操作的关键命令

Kubernetes(k8s):精通 Pod 操作的关键命令 1、查看 Pod 列表2、 查看 Pod 的详细信息3、创建 Pod4、删除 Pod5、获取 Pod 日志6、进入 Pod 执行命令7、暂停和启动 Pod8、改变 Pod 副本数量9、查看当前部署中使用的镜像版本10、滚动更新 Pod11…

基于Java+SpringBoot+Mybaties+layui+Vue+elememt 实习管理系统 的设计与实现

一.项目介绍 前台功能:用户进入系统可以实现首页,系统公告,个人中心,后台管理等功能进行操作 后台由管理员,实习单位,教师和学生,主要功能包括首页,个人中心,班级管理&am…

【C++学习】哈希的应用—位图与布隆过滤器

目录 1.位图1.1位图的概念1.2位图的实现3.位图的应用 2.布隆过滤器2.1 布隆过滤器提出2.2布隆过滤器概念2.3如何选择哈希函数个数和布隆过滤器长度2.4布隆过滤器的实现2.4.1布隆过滤器插入操作2.4.2布隆过滤器查找操作2.4.3 布隆过滤器删除 2.5 布隆过滤器优点2.6布隆过滤器缺陷…

Linux:make/makefile的使用

一、什么是makefile/make 会不会写makefile,从一个侧面说明了一个人是否具备完成大型工程的能力 一个工程中的源文件不计数,其按类型、功能、模块分别放在若干个目录中,makefile定义了一系列的 规则来指定,哪些文件需要先编译&am…