Transformer的代码实现 day03(Positional Encoding)

Positional Encoding的理论部分

  • 注意力机制是不含有位置信息,这也就表明:“我爱你”,“你爱我”这两者没有区别,而在现实世界中,这两者有区别。
  • 所以位置编码是在进行注意力计算之前,给输入加上一个位置信息,如下图:
    在这里插入图片描述
  • 位置编码的公式如下:
    • 注意,pos表示该单词在句子中的位置,i表示该单词的输入向量的第i维度
      在这里插入图片描述
  • 由此我们可以得出不同位置之间的位置编码关系:
    在这里插入图片描述

Positional Encoding代码

  • 由于位置编码的公式固定,所以对于相同位置的位置编码也固定,即“我爱你”中的我,和“你爱我”中的你的位置编码相同
  • 所以我们可以一次将所有要输入信息的位置编码都生成出来,之后需要哪个就传哪个
class PositionalEncoding(nn.Module):def __init__(self, dim, dropout, max_len=5000):super(PositionalEncoding, self).__init__()# 确保每个单词的输入维度为偶数,这样sin和cos能配对if dim % 2 != 0:raise ValueError("Cannot use sin/cos positional encoding with ""odd dim (got dim={:d})".format(dim))"""构建位置编码pepe公式为:PE(pos,2i/2i+1) = sin/cos(pos/10000^{2i/d_{model}})"""pe = torch.zeros(max_len, dim)  # max_len 是解码器生成句子的最长的长度,假设是 10,dim为单词的输入维度# 将位置序号从一维变为只有一列的二维,方便与div_term进行运算,# 如将[0, 1, 2, 3, 4]变为:#[  #  [0],  #  [1],  #  [2],  #  [3],  #  [4]  #]position = torch.arange(0, max_len).unsqueeze(1)# 这里使用a^b = e^(blna)公式,来简化运算# torch.arange(0, dim, 2, dtype=torch.float)表示从0到dim-1,步长为2的一维张量# 通过以下公式,我们可以得出全部2i的(pos/10000^2i/dim)方便接下来的pe计算div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *-(math.log(10000.0) / dim)))# 得出的div_term为从0开始,到dim-1,长度为dim/2,步长为2的一维张量# 将position与div_term做张量乘法,得到的张量形状为(max_len,dim/2)# 将结果取sin赋给pe中偶数维度,取cos赋给pe中奇数维度pe[:, 0::2] = torch.sin(position.float() * div_term)pe[:, 1::2] = torch.cos(position.float() * div_term)# 将pe的形状从(max_len,dim)变成(max_len,1,dim),在第二个维度上增加一个大小为1的新维度# 如从原始 pe 张量形状: (5, 4)  #[  # [a1, b1, c1, d1],  # [a2, b2, c2, d2],  # [a3, b3, c3, d3],  # [a4, b4, c4, d4],  # [a5, b5, c5, d5]  #]# 转换为:执行 unsqueeze(1) 后的 pe 张量形状: (5, 1, 4)  #[  # [[a1, b1, c1, d1]],  # [[a2, b2, c2, d2]],  # [[a3, b3, c3, d3]],  # [[a4, b4, c4, d4]],  # [[a5, b5, c5, d5]]  #]pe = pe.unsqueeze(1)# 将pe张量注册为模块的buffer。在PyTorch中,buffer是模型的一部分,但不包含可学习的参数(即不需要梯度)。# 这样做是因为位置编码在训练过程中是固定的,不需要更新。self.register_buffer('pe', pe)self.drop_out = nn.Dropout(p=dropout)self.dim = dimdef forward(self, emb, step=None):# 做乘法是因为在 Transformer 模型中,位置编码被加到输入张量中,而输入张量通常是词嵌入的向量,其值通常在较小的范围内。# 但是,在将位置编码添加到输入张量之前,我们希望将其值扩大到一个较大的范围,以便位置编码对输入的影响更加显著。# 注意:emb为输入张量,形状为(seq_len, dim),seq_len 表示输入的句子的长度,dim为单词的输入维度emb = emb * math.sqrt(self.dim)# 根据step来选择加入pe的哪一部分if step is None:# 如果pe的形状为(max_len, dim),那么pe[:a]表示:取pe的第0行到第a-1行的全部元素,得到的新二维张量的形状为(a, dim)# 而pe[:, a]表示:取pe的第a-1列的全部元素,得到的新一维张量的形状为(max_len)# 而pe[:, :a]表示:取pe的第0列到第a-1列的全部元素,得到的新二维张量的形状为(max_len,a)emb = emb + self.pe[:emb.size(0)]else:emb = emb + self.pe[step]emb = self.drop_out(emb)return emb

参考文献

  1. 04 Transformer 中的位置编码的 Pytorch 实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/793908.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【RISC-V 指令集】RISC-V 向量V扩展指令集介绍(五)- 向量加载和存储

1. 引言 以下是《riscv-v-spec-1.0.pdf》文档的关键内容: 这是一份关于向量扩展的详细技术文档,内容覆盖了向量指令集的多个关键方面,如向量寄存器状态映射、向量指令格式、向量加载和存储操作、向量内存对齐约束、向量内存一致性模型、向量…

Redis -- 缓存穿透问题解决思路

缓存穿透 :缓存穿透是指客户端请求的数据在缓存中和数据库中都不存在,这样缓存永远不会生效,这些请求都会打到数据库。 常见的解决方案有两种: 缓存空对象 优点:实现简单,维护方便 缺点: 额外…

解密ChatGPT技巧:提升学术论文写作水平

ChatGPT无限次数:点击直达 解密ChatGPT技巧:提升学术论文写作水平 在当今信息爆炸的时代,学术论文写作已经成为学术研究者们日常不可或缺的任务之一。而借助人工智能技术,特别是最新的语言模型ChatGPT,可以极大地提升论文写作的效…

【JavaSE】接口 详解(上)

前言 本篇会讲到Java中接口内容,概念和注意点可能比较多,需要耐心多看几遍,我尽可能的使用经典的例子帮助大家理解~ 欢迎关注个人主页:逸狼 创造不易,可以点点赞吗~ 如有错误,欢迎指出~ 目录 前言 接口 语法…

RabbitMQ面经 手敲浓缩版

保证可靠性 生产者 本地事务完成和消息发送同时完成 通过事务消息完成 重写confirm在里面做逻辑处理 确保发送成功(不成功就放入到重试队列) MQ 打开持久化确保消息不会丢失 消费者 改成手动回应 不重复消费 生产者 保证不重复发送消息 消费者…

pta 1086 就不告诉你

1086 就不告诉你 分数 15 全屏浏览 切换布局 作者 CHEN, Yue 单位 浙江大学 做作业的时候,邻座的小盆友问你:“五乘以七等于多少?”你应该不失礼貌地围笑着告诉他:“五十三。”本题就要求你,对任何一对给定的正整数…

新手开抖店:选品过后如何有效对接达人?这些方法100%有效!

哈喽~我是电商月月 要说做抖音小店最主要的是什么?那当然是找品了 那出单最快的方法是什么?无疑是达人带货了! 但新手店铺没销量,没体验分,没好评怎么能让达人同意帮我们带货呢? 方法其实很简单&#x…

“双碳”目标下资源环境中的可计算一般均衡(CGE)模型应用

我国政府承诺在2030年实现“碳达峰”,2060年实现“碳中和”,这就是“双碳”目标。为了实现这一目标就必须应用各种二氧化碳排放量很高技术的替代技术,不仅需要考虑技术上的可靠性,也需要考虑经济上的可行性。可计算一般均衡模型&a…

ChatGPT新手指南:如何应用于学术论文撰写

ChatGPT无限次数:点击直达 html ChatGPT新手指南:如何应用于学术论文撰写 在当今信息爆炸的时代,学术论文写作是许多研究人员、学生和学者每天都要面对的任务。随着人工智能技术的不断发展,如何利用自然语言生成模型来辅助学术论文的撰写…

AI预测福彩3D第26弹【2024年4月4日预测--第4套算法重新开始计算第11次测试】

今天清明节假日,一会要外出,可能要晚点回来。咱们尽早先把预测数据跑完,把结果发出来供各位彩友参考。合并下算法,3D的预测以后将重点测试本套算法,因为本套算法的命中率较高。以后有时间的话会在第二篇文章中发布排列…

JDK版本发布历史

以下是Java Development Kit(JDK)主要版本的发布历史: JDK 1.0:1996年1月23日发布,是Java的首个正式版本。JDK 1.1:1997年2月19日发布,引入了内部类、反射、JAR文件等新特性。JDK 1.2&#xff…

UTONMOS:AI+Web3+元宇宙数字化“三位一体”将触发经济新爆点

人工智能、元宇宙、Web3,被称为数字化的“三位一体”,如何看待这三大技术所扮演的角色? 3月24日,2024全球开发者先锋大会“数字化的三位一体——人工智能、元宇宙、Web3.0”论坛在上海漕河泾开发区举行,首次提出&…

深入探索MySQL:成本模型解析与查询性能优化,及未来深度学习与AI模型的应用展望

码到三十五 : 个人主页 在数据库管理系统中,查询优化器是一个至关重要的组件,它负责将用户提交的SQL查询转换为高效的执行计划。在MySQL中,查询优化器使用了一个称为“成本模型”的机制来评估不同执行计划的优劣,并选择…

网络安全 | 什么是负载均衡器?

关注WX: CodingTechWork 介绍 负载均衡是在多个服务器之间有效分配网络流量的过程。负载均衡的目的是优化应用程序的可用性,并确保良好的终端用户体验。负载均衡可协助高流量网站和云计算应用程序应对数百万个用户请求,从而保证客户请求不会…

2012年认证杯SPSSPRO杯数学建模C题(第二阶段)碎片化趋势下的奥运会商业模式全过程文档及程序

2012年认证杯SPSSPRO杯数学建模 C题 碎片化趋势下的奥运会商业模式 原题再现: 从 1984 年的美国洛杉矶奥运会开始,奥运会就不在成为一个“非卖品”,它在向观众诠释更高更快更强的体育精神的同时,也在攫取着巨大的商业价值&#…

颜色空间/模型(RGB, YUV,CMY/CMYK, HSI, HSV等)

什么是颜色 颜色是通过眼、脑和我们的生活经验所产生的对光的视觉感受,我们肉眼所见到的光线,是由波长范围很窄的电磁波产生的,不同波长的电磁波表现为不同的颜色,对色彩的辨认是肉眼受到电磁波辐射能刺激后所引起的视觉神经感觉…

51单片机实验02- P0口流水灯实验

目录 一、实验的背景和意义 二、实验目的 三、实验步骤 四、实验仪器 五、实验任务及要求 1,从led4开始右移 1)思路 ①起始灯 (led4) ②右移 2)效果 3)代码 2,从其他小灯并向右依次…

面向C++程序员的Rust教程(二)

先序文章请看: 面向C程序员的Rust教程(一) 所有权与移动语义 要说Rust语言跟其他语言最大的区别,那笔者觉得非数这个所有权和移动语义莫属。 深浅复制 对于绝大多数语言来说,变量/对象之间的赋值通常都是复制语义。…

微信开发工具——进行网页授权

微信开发工具——进行网页授权 微信公众平台设置 1.在首页创建好自己的订阅号 网站:https://mp.weixin.qq.com/ 点击立即注册,在选择订阅号(个人创建使用) 之后按流程填写后,点击设置与开发-------->基本配置,这…

JAVA八股--redis

JAVA八股--redis 如何保证Redis和数据库数据一致性redisson实现的分布式锁的主从一致性Redis脑裂现象及解决方案介绍I/O多路复用模型undo log 和 redo log(没掌握MyISAM 和 InnoDB 有什么区别? 如何保证Redis和数据库数据一致性 关于异步通知中消息队列…