Transformer(2)--位置编码器

文章目录

  • 一、嵌入表示层
  • 二、流程详解
    • 1.初始化位置编码器
    • 2.计算位置编码
    • 3.扩维,与输入张量匹配
    • 4.添加位置编码到输入张量上
  • 三、完整代码


一、嵌入表示层

  对于输入文本序列,首先通过输入嵌入层(Input Embedding)将每个单词转换为其相对应的向量表示。通常直接对每个单词创建一个向量表示。由于 Transfomer 模型不再使用基于循环的方式建模文本输入,序列中不再有任何信息能够提示模型单词之间的相对位置关系。在送入编码器端建模其上下文语义之前,一个非常重要的操作是在词嵌入中加入位置编码(Positional Encoding)这一特征。具体来说,序列中每一个单词所在的位置都对应一个向量。这一向量会与单词表示对应相加并送入到后续模块中做进一步处理。在训练的过程当中,模型会自动地学习到如何利用这部分位置信息。
  为了得到不同位置对应的编码,Transformer 模型使用不同频率的正余弦函数如下所示:

for pos in range(max_seq_len):for i in range(0, d_model, 2):pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / d_model)))pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))

  其中,pos 表示单词所在的位置,2i 和 2i+ 1 表示位置编码向量中的对应维度,d 则对应位置编码的总维度。通过上面这种方式计算位置编码有这样几个好处:首先,正余弦函数的范围是在 [-1,+1],导出的位置编码与原词嵌入相加不会使得结果偏离过远而破坏原有单词的语义信息。其次,依据三角函数的基本性质,可以得知第 pos + k 个位置的编码是第 pos 个位置的编码的线性组合,这就意味着位置编码中蕴含着单词之间的距离信息。

二、流程详解

1.初始化位置编码器

# 初始化位置编码矩阵 pe,形状为 (max_seq_len, d_model)
pe = torch.zeros(max_seq_len, d_model)

打印初始化位置编码矩阵

(Pdb) p pe
tensor([[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]])

2.计算位置编码

for pos in range(max_seq_len):for i in range(0, d_model, 2):pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / d_model)))pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))

打印位置编码矩阵

tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,0.0000e+00,  1.0000e+00],[ 8.4147e-01,  5.6969e-01,  8.0196e-01,  ...,  1.0000e+00,1.0746e-08,  1.0000e+00],[ 9.0930e-01, -3.5090e-01,  9.5814e-01,  ...,  1.0000e+00,2.1492e-08,  1.0000e+00],...,[ 3.7961e-01,  7.8033e-01,  7.4511e-01,  ...,  1.0000e+00,1.0424e-06,  1.0000e+00],[-5.7338e-01,  9.5851e-01, -8.9752e-02,  ...,  1.0000e+00,1.0531e-06,  1.0000e+00],[-9.9921e-01,  3.1179e-01, -8.5234e-01,  ...,  1.0000e+00,1.0639e-06,  1.0000e+00]])

3.扩维,与输入张量匹配

pe = pe.unsqueeze(0)
(Pdb) p pe.shape
torch.Size([1, 100, 512])

4.添加位置编码到输入张量上

x = x + self.pe[:, :seq_len].detach().to(x.device)
示例 1: 在 CPU 上运行
位置编码后的张量 (CPU): tensor([[[ -4.4866,  -3.6170,   7.9131,  ...,  -5.4459,  15.9657,   4.2406],[-47.0210, -13.7024, -40.5477,  ...,  34.5023,   0.4545, -32.0102],[ 16.6810,  12.8272,  40.9043,  ..., -12.4140,  70.6676, -14.0449],...,[ -8.1882,   1.9146,  25.2393,  ...,  16.1251, -24.0830, -25.0094],[ 35.0248,  -0.2711, -40.9559,  ...,  -3.2930,  29.2630,  13.0763],[ -2.8143, -10.6067,  43.7963,  ...,   8.7323,   7.0742,  -8.5050]],

三、完整代码

import math
import torch
import torch.nn as nnclass PositionalEncoder(nn.Module):def __init__(self, d_model, max_seq_len=100):"""初始化位置编码器。参数:- d_model: 每个位置的嵌入维度。- max_seq_len: 支持的最大序列长度。"""super(PositionalEncoder, self).__init__()self.d_model = d_model# 初始化位置编码矩阵 pe,形状为 (max_seq_len, d_model)pe = torch.zeros(max_seq_len, d_model)# 计算位置编码值for pos in range(max_seq_len):for i in range(0, d_model, 2):pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / d_model)))pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))# 增加批次维度,形状变为 (1, max_seq_len, d_model)pe = pe.unsqueeze(0)# 注册位置编码矩阵为缓冲区,确保其不会作为模型参数被更新self.register_buffer('pe', pe)def forward(self, x):"""前向传播方法,将位置编码添加到输入张量 x 上。参数:- x: 输入张量,形状为 (batch_size, seq_len, d_model)返回:- 带有位置编码的输入张量"""# 使得单词嵌入表示相对大一些x = x * math.sqrt(self.d_model)# 获取输入序列长度seq_len = x.size(1)# 检查输入序列长度是否超过最大序列长度if seq_len > self.pe.size(1):raise ValueError(f"Input sequence length ({seq_len}) exceeds maximum sequence length ({self.pe.size(1)}) for positional encoding.")# 添加位置编码到输入张量上,并确保张量在同一个设备上x = x + self.pe[:, :seq_len].detach().to(x.device)"""[:, :seq_len] 表示对一个张量(或数组)进行切片操作,其中 : 表示对第一个维度(通常是行)进行完整切片,而 :seq_len 表示对第二个维度(通常是列)进行从第0列到第 seq_len - 1 列的切片。detach() 是一个函数调用,用于创建一个新的张量,与原始张量共享相同的数据,但不进行梯度追踪。.detach() 的目的是将切片操作的结果从计算图中分离出来,以便后续的计算不会影响到原始张量的梯度计算。to(x.device) 是一个张量的方法,用于将张量移动到指定的计算设备上。其中 x.device 表示张量 x 当前所在的计算设备。这个操作的目的是将切片结果转移到与张量 x 相同的设备上,以便后续的计算能够在相同的设备上进行。"""return x# 使用示例
d_model = 512  # 每个位置的嵌入维度
seq_len = 100  # 输入序列的长度
batch_size = 32  # 批次大小# 初始化位置编码器,确保 max_seq_len >= seq_len
pos_encoder = PositionalEncoder(d_model, max_seq_len=seq_len)# 创建一个随机张量作为输入,形状为 (batch_size, seq_len, d_model)
x = torch.randn(batch_size, seq_len, d_model)# 示例 1: 在 CPU 上运行
print("示例 1: 在 CPU 上运行")
x_cpu = x  # 确保张量在 CPU 上
pos_encoder_cpu = pos_encoder  # 确保位置编码器在 CPU 上
x_encoded_cpu = pos_encoder_cpu(x_cpu)  # 添加位置编码
print("位置编码后的张量 (CPU):", x_encoded_cpu)# 示例 2: 在 GPU 上运行(如果可用)
if torch.cuda.is_available():print("示例 2: 在 GPU 上运行")device = torch.device("cuda")x_gpu = x.to(device)  # 将张量移动到 GPUpos_encoder_gpu = pos_encoder.to(device)  # 将位置编码器移动到 GPUx_encoded_gpu = pos_encoder_gpu(x_gpu)  # 添加位置编码print("位置编码后的张量 (GPU):", x_encoded_gpu)
else:print("GPU 不可用,跳过 GPU 示例")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/13116.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Oracle数据库查询各表空间的占用比例

目录 1、查询各表数据记录和数据大小 2、查询数据库已有表空间的大小 3、查询某表空间下各表占用突然间的大小 1、查询各表数据记录和数据大小 select a.table_name "表名",a.num_rows "数据记录",b.total "总大小(MB)" from us…

【前端】CSS基础(4)

文章目录 前言1、CSS常用属性1.1 文本属性1.1.1 文本对齐1.1.2 文本装饰1.1.3 文本缩进1.1.5 行高 前言 这篇博客仅仅是对CSS的基本结构进行了一些说明,关于CSS的更多讲解以及HTML、Javascript部分的讲解可以关注一下下面的专栏,会持续更新的。 链接&…

golang http请求返回 io.ReadCloser 数据读取和编码转换为utf8 注意事项

在go语言中我们发送一个http请求后, 我们需要通过resp返回体中的Body对象(是一个 io.ReadCloser对象)来对请求返回的数据进行读取。 对于这类Reader的数据读取我们需要先定义一个byte切片, 然后通过循环来对reader中的数据进行读取&#xff…

Day_5

1. Apache ECharts Apache ECharts 是一款基于 Javascript 的数据可视化图表库,提供直观,生动,可交互,可个性化定制的数据可视化图表 官网地址:https://echarts.apache.org/zh/index.html 入门案例 快速入门&#x…

记录一下-排查免密登录过程

过程记录 2024-05-15 18:15:15 在本地机器上生成新的密钥对: ssh-keygen -t rsa -b 2048 -m PEM -f ~/.ssh/id_rsa_new2024-05-15 18:25:37 将新生成的公钥复制到服务器: ssh-copy-id -i ~/.ssh/id_rsa_new.pub xaykt10.24.17.52024-05-15 18:10:58 执…

企业计算机服务器中了faust勒索病毒如何处理,faust勒索病毒解密恢复

随着网络技术的不断发展与应用,越来越多的企业利用网络走向了数字化办公模式,网络也极大地方便了企业生产运营,大大提高了企业生产效率,但对于众多企业来说,企业的数据安全一直是大家关心的主要话题,保护好…

fastjson2使用

说明:fastjson2是一个性能极致并且简单易用的Java JSON库(官方语),本文介绍在Spring Boot项目中如何使用fastjson2。 创建项目 首先,创建一个Maven项目,引入fastjson2依赖,如下: …

战网国际服注册教程 暴雪战网国际服账号注册一站式教程分享

战网国际版,也即Battle.net环球版,是由暴雪娱乐操刀的全球化游戏交流枢纽,它突破地理限制,拥抱全世界的游戏玩家。与仅限特定地区的版本不同,国际版为玩家开辟了无障碍通道,让他们得以自由探索暴雪庞大游戏…

Python使用fastdfs-client与FastDFS交互

1. 安装(要求Python3.10) pip install fastdfs-client 注:Python3.8和Python3.9可以用这个GitHub - waketzheng/fastdfs-client-python at 1.0.1 2. 使用 from pathlib import Path from fastdfs_client import FastdfsClientclient Fas…

如何使用JMeter测试导入接口/导出接口?

🍅 视频学习:文末有免费的配套视频可观看 🍅 关注公众号:互联网杂货铺,回复1 ,免费获取软件测试全套资料,资料在手,涨薪更快 今天上班,被开发问了一个问题:JM…

opencv 轮廓区域检测

直线检测 void LineDetect(const cv::Mat &binaryImage) {cv::Mat xImage,yImage,binaryImage1,binaryImage2;// 形态学变化,闭操作 先膨胀,再腐蚀 可以填充小洞,填充小的噪点cv::Mat element cv::getStructuringElement(cv::MORPH_RE…

最小质数对-第12届蓝桥杯国赛Python真题解析

[导读]:超平老师的Scratch蓝桥杯真题解读系列在推出之后,受到了广大老师和家长的好评,非常感谢各位的认可和厚爱。作为回馈,超平老师计划推出《Python蓝桥杯真题解析100讲》,这是解读系列的第63讲。 最小质数对&#…

Flutter 中的 Icon 小部件:全面指南

Flutter 中的 Icon 小部件:全面指南 Flutter 提供了多种方式来展示图标,其中 Icon 是最常用的小部件之一。它不仅用于展示简单的图标,还可以与文本、按钮和其他小部件组合使用,以增强用户界面的交互性。本篇文章将详细介绍 Icon …

Windows内核函数 - ANSI_STRING字符串与UNICODE_STRING字符串

DDK不鼓励程序员使用C语言的字符串,主要是因为:标准C的字符串处理函数容易导致缓冲区溢出等错误。如果程序员不对字符串的长度进行检查,很容易导致这个错误,从而导致整个操作系统的崩溃。DDK鼓励程序员使用DDK自定义的字符串&…

基于SSM的“羽毛球馆管理系统”的设计与实现(源码+数据库+文档)

基于SSM的“羽毛球馆管理系统”的设计与实现(源码数据库文档) 开发语言:Java 数据库:MySQL 技术:SSM 工具:IDEA/Ecilpse、Navicat、Maven 系统展示 系统结构图 登录界面 后台用户添加 后台用户管理 球场添加 球场…

英特尔处理器-----ERMS

ERMS,全称为Enhanced REP MOVSB/STOSB,是英特尔处理器的一种特性。它增强了使用REP MOVSB和REP STOSB指令进行内存操作的效率 section .datasrc db Hello,World! ; 源数据dst times 12 db 0 ; 目标缓冲区section .textglobal _start _start:mov es…

vj题单 Color the ball c 差分

题目链接&#xff1a;Problem - 1556 (hdu.edu.cn) 笔者思路&#xff1a;利用一维差分数组进行区间同时1的操作&#xff0c;然后还原为一维前缀和数组 笔者答案&#xff1a; #include<stdio.h> int cut[100010];int main() {long N,a,b,i,k1,j;scanf("%ld",…

Leetcode 3148. Maximum Difference Score in a Grid

Leetcode 3148. Maximum Difference Score in a Grid 1. 解题思路2. 代码实现 题目链接&#xff1a;3148. Maximum Difference Score in a Grid 1. 解题思路 这一题的话算是一个脑筋急转弯的题目吧&#xff0c;本质上就是求各个坐标下其右下方矩阵当中除自己外最大的元素是多…

Linux 第三十三章

&#x1f436;博主主页&#xff1a;ᰔᩚ. 一怀明月ꦿ ❤️‍&#x1f525;专栏系列&#xff1a;线性代数&#xff0c;C初学者入门训练&#xff0c;题解C&#xff0c;C的使用文章&#xff0c;「初学」C&#xff0c;linux &#x1f525;座右铭&#xff1a;“不要等到什么都没有了…

大模型学习笔记九:模型微调

文章目录 一、什么时候需要Fine-Tuning二、用Hugging Face根据电影评论输出来对电影进行情感分类1)安装依赖2)操作流程3)名字解释4)代码导入库和加载模型、加载数据库、加载tokenlizer5)其他相关公共变量赋值(随机种子、标签集评价、标签转token_Id)6)处理数据集:转成…