Transformer模型:Postion Embedding实现

前言

        这是对上一篇WordEmbedding的续篇PositionEmbedding。

视频链接:19、Transformer模型Encoder原理精讲及其PyTorch逐行实现_哔哩哔哩_bilibili

上一篇链接:Transformer模型:WordEmbedding实现-CSDN博客


正文

        先回顾一下原论文中对Position Embedding的计算公式:pos表示位置,i表示维度索引,d_model表示嵌入向量的维度,position分奇数列和偶数列。

        Position Embedding也是二维的,行数是训练的序列最大长度,列是d_model。首先定义position的最大长度,这里定为12,也就是训练中的长度最大值都是12。

max_position_len = 12

        这里先循环遍历得到pos,构造Pos序列,pos是从0到最大长度的遍历,决定行:

pos_mat = torch.arange(max_position_len)

        但是此时得到的是一维的,我们要将它转为二维矩阵的,也就是得到目标行数,使用.reshape()函数,这样就构造好了行矩阵

pos_mat = torch.arange(max_position_len).reshape((-1,1))

tensor([[ 0],
        [ 1],
        [ 2],
        [ 3],
        [ 4],
        [ 5],
        [ 6],
        [ 7],
        [ 8],
        [ 9],
        [10],
        [11]]) 

        接下来要构造列矩阵,构造 i 序列,首先是是2i/d_model部分,这里的8是因为我们设定的d_model=8,2是步长:

i_mat = torch.arange(0, 8, 2)/model_dim

        这时候再把分母的完整形式实现,幂次使用pow()函数:

i_mat = torch.pow(10000, torch.arange(0, 8, 2)/model_dim)

tensor([   1.,   10.,  100., 1000.]) 

         此时就得到了列向量,这时候就有疑问了为什么列只有4列,我们的d_model不是8吗,应该有8列才对啊。这是因为区分了奇数列跟偶数列的计算,所以这里才要求步长为2生成的只有4列。

        先初始化一个max_position_len*model_dim的零矩阵(12*8),然后再分别使用sin和cos填充偶数列和奇数列:

pe_embedding_table = torch.zeros(max_position_len, model_dim)pe_embedding_table[:, 0::2] = torch.sin(pos_mat/i_mat)   # 从第0列到结束,步长为2,也就是填充偶数列
pe_embedding_table[:, 1::2] = torch.cos(pos_mat/i_mat)   # 从第1列到结束,步长为2,也就是填充奇数列

        得到的就是Position Embedding的权重矩阵了: 

        这下面采用的是使用nn.Embedding()的方法,得到的跟上面的结果还是一样的,只不过这里的pe_embedding是可以传入位置的,之后的调用就是这样得到的:

pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table,requires_grad=False)

         这里就要构造位置索引了:

src_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len),0) for _ in src_len]).to(torch.int32)
tgt_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len),0) for _ in tgt_len]).to(torch.int32)

        然后传入位置索引,就得到了src跟tgt的Position Embedding:

src_pe_embedding = pe_embedding(src_pos)
tgt_pe_embedding = pe_embedding(tgt_pos)

         这里我很疑惑的点是生成的结果src_pe_embedding跟tgt_pe_embedding内容是一样的,并且单个里面的一个内容也就是position embedding,刚入门听得我还是有点不太能理解。

src_pos is:
 tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]], dtype=torch.int32)
tgt_pos is:
 tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]], dtype=torch.int32)
src_pe_embedding is:
 tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00],
         [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
           9.9920e-01,  4.0000e-03,  9.9999e-01],
         [-9.5892e-01,  2.8366e-01,  4.7943e-01,  8.7758e-01,  4.9979e-02,
           9.9875e-01,  5.0000e-03,  9.9999e-01],
         [-2.7942e-01,  9.6017e-01,  5.6464e-01,  8.2534e-01,  5.9964e-02,
           9.9820e-01,  6.0000e-03,  9.9998e-01],
         [ 6.5699e-01,  7.5390e-01,  6.4422e-01,  7.6484e-01,  6.9943e-02,
           9.9755e-01,  6.9999e-03,  9.9998e-01],
         [ 9.8936e-01, -1.4550e-01,  7.1736e-01,  6.9671e-01,  7.9915e-02,
           9.9680e-01,  7.9999e-03,  9.9997e-01],
         [ 4.1212e-01, -9.1113e-01,  7.8333e-01,  6.2161e-01,  8.9879e-02,
           9.9595e-01,  8.9999e-03,  9.9996e-01],
         [-5.4402e-01, -8.3907e-01,  8.4147e-01,  5.4030e-01,  9.9833e-02,
           9.9500e-01,  9.9998e-03,  9.9995e-01],
         [-9.9999e-01,  4.4257e-03,  8.9121e-01,  4.5360e-01,  1.0978e-01,
           9.9396e-01,  1.1000e-02,  9.9994e-01]],

        [[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00],
         [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
           9.9920e-01,  4.0000e-03,  9.9999e-01],
         [-9.5892e-01,  2.8366e-01,  4.7943e-01,  8.7758e-01,  4.9979e-02,
           9.9875e-01,  5.0000e-03,  9.9999e-01],
         [-2.7942e-01,  9.6017e-01,  5.6464e-01,  8.2534e-01,  5.9964e-02,
           9.9820e-01,  6.0000e-03,  9.9998e-01],
         [ 6.5699e-01,  7.5390e-01,  6.4422e-01,  7.6484e-01,  6.9943e-02,
           9.9755e-01,  6.9999e-03,  9.9998e-01],
         [ 9.8936e-01, -1.4550e-01,  7.1736e-01,  6.9671e-01,  7.9915e-02,
           9.9680e-01,  7.9999e-03,  9.9997e-01],
         [ 4.1212e-01, -9.1113e-01,  7.8333e-01,  6.2161e-01,  8.9879e-02,
           9.9595e-01,  8.9999e-03,  9.9996e-01],
         [-5.4402e-01, -8.3907e-01,  8.4147e-01,  5.4030e-01,  9.9833e-02,
           9.9500e-01,  9.9998e-03,  9.9995e-01],
         [-9.9999e-01,  4.4257e-03,  8.9121e-01,  4.5360e-01,  1.0978e-01,
           9.9396e-01,  1.1000e-02,  9.9994e-01]]])
tgt_pe_embedding is:
 tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00],
         [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
           9.9920e-01,  4.0000e-03,  9.9999e-01],
         [-9.5892e-01,  2.8366e-01,  4.7943e-01,  8.7758e-01,  4.9979e-02,
           9.9875e-01,  5.0000e-03,  9.9999e-01],
         [-2.7942e-01,  9.6017e-01,  5.6464e-01,  8.2534e-01,  5.9964e-02,
           9.9820e-01,  6.0000e-03,  9.9998e-01],
         [ 6.5699e-01,  7.5390e-01,  6.4422e-01,  7.6484e-01,  6.9943e-02,
           9.9755e-01,  6.9999e-03,  9.9998e-01],
         [ 9.8936e-01, -1.4550e-01,  7.1736e-01,  6.9671e-01,  7.9915e-02,
           9.9680e-01,  7.9999e-03,  9.9997e-01],
         [ 4.1212e-01, -9.1113e-01,  7.8333e-01,  6.2161e-01,  8.9879e-02,
           9.9595e-01,  8.9999e-03,  9.9996e-01],
         [-5.4402e-01, -8.3907e-01,  8.4147e-01,  5.4030e-01,  9.9833e-02,
           9.9500e-01,  9.9998e-03,  9.9995e-01],
         [-9.9999e-01,  4.4257e-03,  8.9121e-01,  4.5360e-01,  1.0978e-01,
           9.9396e-01,  1.1000e-02,  9.9994e-01]],

        [[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00],
         [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
           9.9920e-01,  4.0000e-03,  9.9999e-01],
         [-9.5892e-01,  2.8366e-01,  4.7943e-01,  8.7758e-01,  4.9979e-02,
           9.9875e-01,  5.0000e-03,  9.9999e-01],
         [-2.7942e-01,  9.6017e-01,  5.6464e-01,  8.2534e-01,  5.9964e-02,
           9.9820e-01,  6.0000e-03,  9.9998e-01],
         [ 6.5699e-01,  7.5390e-01,  6.4422e-01,  7.6484e-01,  6.9943e-02,
           9.9755e-01,  6.9999e-03,  9.9998e-01],
         [ 9.8936e-01, -1.4550e-01,  7.1736e-01,  6.9671e-01,  7.9915e-02,
           9.9680e-01,  7.9999e-03,  9.9997e-01],
         [ 4.1212e-01, -9.1113e-01,  7.8333e-01,  6.2161e-01,  8.9879e-02,
           9.9595e-01,  8.9999e-03,  9.9996e-01],
         [-5.4402e-01, -8.3907e-01,  8.4147e-01,  5.4030e-01,  9.9833e-02,
           9.9500e-01,  9.9998e-03,  9.9995e-01],
         [-9.9999e-01,  4.4257e-03,  8.9121e-01,  4.5360e-01,  1.0978e-01,
           9.9396e-01,  1.1000e-02,  9.9994e-01]]])

 代码

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F# 句子数
batch_size = 2# 单词表大小
max_num_src_words = 10
max_num_tgt_words = 10# 序列的最大长度
max_src_seg_len = 12
max_tgt_seg_len = 12
max_position_len = 12# 模型的维度
model_dim = 8# 生成固定长度的序列
src_len = torch.Tensor([11, 9]).to(torch.int32)
tgt_len = torch.Tensor([10, 11]).to(torch.int32)
print(src_len)
print(tgt_len)#单词索引构成的句子
src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L,)),(0, max_src_seg_len-L)), 0) for L in src_len])
tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_tgt_words, (L,)),(0, max_tgt_seg_len-L)), 0) for L in tgt_len])
print(src_seq)
print(tgt_seq)# 构造Word Embedding
src_embedding_table = nn.Embedding(max_num_src_words+1, model_dim)
tgt_embedding_table = nn.Embedding(max_num_tgt_words+1, model_dim)
src_embedding = src_embedding_table(src_seq)  
tgt_embedding = tgt_embedding_table(tgt_seq)
print(src_embedding_table.weight)    
print(src_embedding)    
print(tgt_embedding)# 构造Pos序列跟i序列
pos_mat = torch.arange(max_position_len).reshape((-1, 1))    
i_mat = torch.pow(10000, torch.arange(0, 8, 2)/model_dim)# 构造Position Embedding
pe_embedding_table = torch.zeros(max_position_len, model_dim)    
pe_embedding_table[:, 0::2] = torch.sin(pos_mat/i_mat)
pe_embedding_table[:, 1::2] = torch.cos(pos_mat/i_mat)
print("pe_embedding_table is:\n",pe_embedding_table)pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table,requires_grad=False)
print(pe_embedding.weight)# 构建位置索引
src_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len),0) for _ in src_len]).to(torch.int32)
tgt_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len),0) for _ in tgt_len]).to(torch.int32)
print("src_pos is:\n",src_pos)
print("tgt_pos is:\n",tgt_pos)src_pe_embedding = pe_embedding(src_pos)
tgt_pe_embedding = pe_embedding(tgt_pos)
print("src_pe_embedding is:\n",src_pe_embedding)
print("tgt_pe_embedding is:\n",tgt_pe_embedding)

参考

Python的reshape的用法:reshape(1,-1)-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/871928.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[Windows] 号称最快免费小巧的远程桌面 AnyDesk v8.0.11单文件版

描述 对于经常在互联网上进行操作的学生,白领等! 一款好用的软件总是能得心应手,事半功倍。 今天给大家带了一款高科技软件 虽然 QQ 拥有远程协助功能,但很多时候连接并不够流畅,而且被控电脑那方也必须要有人操作才行…

电脑关机被阻止

1. winR输入regedit进入注册表 2. 选择HKEY_USERS-》.DEFAULT-》Control Panel-》Desktop 3. 右键DeskTop新建字符串值,命名为AutoEndTasks,数值设置为1

C++中链表的底层迭代器实现

大家都知道在C的学习中迭代器是必不可少的,今天我们学习的是C中的链表的底层迭代器的实现,首先我们应该先知道链表的底层迭代器和顺序表的底层迭代器在实现上有什么区别,为什么顺序表的底层迭代器更加容易实现,而链表的底层迭代器…

不会编程怎么办?量化交易不会编程可以使用吗?

量化交易使用计算机模型程序代替人工进行交易,一般需要投资者自己编写程序建模,然后回测无误之后再进行实盘交易,那么不会编程的投资者能使用量化软件进行量化交易吗? 不会编程使用量化软件有两种方法 一种是请人代写代码&#x…

Java软件设计模式-单例设计模式

目录 1.软件设计模式的概念 2.设计模式分类 2.1 创建型模式 2.2 结构型模式 2.3 行为型模式 3.单例设计模式 3.1 单例模式的结构 3.2 单例模式的实现 3.2.1 饿汉式-方式1(静态变量方式) 3.2.2 懒汉式-方式1(线程不安全) 3.…

职场新人感受

互联网职场感受 阶段介绍 24届6月底毕业生,之前从未实习过。 岗位是后端开发(JAVA),目前已经上班三周(前两周看文档和做了半个简单需求,第三周脱产新人培训)。 职场体验 职场和想象中的工作…

c++ 网络编程udp协议 poco模块

官网资料(需要梯子):https://pocoproject.org/slides/200-Network.pdf 1、poco是在原生socket之上的封装,底层还是socket,性能低于socket,安全性和实用性高于socket,即使用简便,接口简单 2、udp协议是&…

办公灯多普勒雷达模组感应开关,飞睿智能24G毫米波雷达超低功耗uA级,节能LED灯新搭档

在科技日新月异的今天,节能、环保已经成为我们生活和工作中不可或缺的一部分。作为新时代的办公人,我们不仅要追求高效的工作方式,更要关注我们所使用的设备是否足够环保、节能。今天,我们就来聊聊一个令人兴奋的创新——飞睿智能…

COMX-P2020、COMX-P1022 vxWorks系统开发主机

提供vxworks6.9开发环境和BSP源码,具有千兆以太网,调试串口,4个PCIe插槽,支持PCIe 1.0a和msi中断,底板板载一块Xilinx CPLD XC95144,提供ISE14.7安装包和verilog源码。可定制开发基于PCIe接口和fpga的拓展接…

多语言环境大师:在PyCharm中管理多个Python解释器

多语言环境大师:在PyCharm中管理多个Python解释器 PyCharm作为业界领先的Python集成开发环境(IDE),支持多种Python解释器的配置和管理,使得开发者可以针对不同项目使用不同的Python环境。本文将详细介绍如何在PyCharm…

如何30分钟下载完368G的Android系统源码?

如何30分钟下载完368G的Android系统源码? Android系统开发的一个痛点问题就是Android系统源码庞大,小则100G,大则,三四百G。如标题所言,本文介绍通过局域网高速网速下载源码的方法。 制作源码mirror 从源码git服务器A&#xff0c…

推荐系统:从协同过滤到深度学习

目录 一、协同过滤(Collaborative Filtering, CF)1. 基于用户的协同过滤2. 基于物品的协同过滤 二、深度学习在推荐系统中的应用1. 深度学习模型的优势2. 深度学习在推荐系统中的应用实例 三、总结与展望 推荐系统是现代信息处理和传播中不可或缺的技术&…

【话题】破茧而出:打破AI“信息茧房”,捍卫信息自由与多样性

目录 AI发展下的伦理挑战,应当如何应对? 方向一:构建可靠的AI隐私保护机制 方向二:确保AI算法的公正性和透明度 方向三:管控深度伪造技术 AI发展下的伦理挑战,应当如何应对? 在人工智能&…

Tita的OKR:高端制造行业的OKR案例

高端设备制造行业的发展趋势: 产业规模持续扩大:在高技术制造业方面,航空、航天器及设备制造业、电子工业专用设备制造等保持较快增长。新能源汽车保持产销双增,新材料新产品生产也高速增长。 标志性装备不断突破:例如…

数据结构第27节 优先队列

优先队列(Priority Queue)是在计算机科学中一种非常有用的抽象数据类型,它与标准队列的主要区别在于元素的出队顺序不是先进先出(FIFO),而是基于每个元素的优先级。具有较高优先级的元素会比低优先级的元素…

论文写作经验-摘要1

小王搬运工 时序课堂 2024年07月15日 13:10 新疆 本人菜鸡一名,最近几篇论文实验跑的比较顺利,结果也很不错,奈何于自己写作能力巨差,导致文章屡屡被拒。当前正在跟一位非常牛的老师学习写作技巧,我将一些心得体会和技…

MySQL教程 | 笔记 (包含数据库、表设计,数据库的增删改查操作;数据库优化等知识点)

SQL简介 一门操作关系型数据库的编程语言,定义操作所有关系型数据库的统一标准 通用语法: 可以单行或者多行书写,以分号结尾; 可以使用空格 / SQL语句可以使用空格/缩进来增强语句的可读性。 MySQL数据库的SOL语句不区分大小…

Flink Window 窗口【更新中】

Flink Window 窗口 在Flink流式计算中,最重要的转换就是窗口转换Window,在DataStream转换图中,可以发现处处都可以对DataStream进行窗口Window计算。 窗口(window)就是从 Streaming 到 Batch 的一个桥梁。窗口将无界流…

C#+GDAL影像处理笔记09:创建多边形、多部件图形、合并相邻的多边形

使用GDAL创建多边形、多部件要素、相邻面合并、以及shape文件创建的完整过程 1. 创建一个多边形 多边形必须闭合 // 创建第一个多边形几何对象Geometry polygon1 = new Geometry(wkbGeometryType.wkbPolygon);Geometry ring1 = new Geometry(wkbGeometryType.wkbLinearRing);…

银河麒麟如何部署QtMqtt(入门案例教程)

QtMqtt是一个基于Qt的MQTT客户端库,提供了使用MQTT协议与 MQTT broker 进行通信的功能。silver-linix是一个基于Linux的操作系统,用于嵌入式系统和物联网设备。下面将教您如何在silver-linix上部署QtMqtt。 1. 安装QtMqtt 1.1 安装QtMqtt依赖项 QtMqtt依赖于Qt和QtNetwork…