AI学习记录 - 旋转位置编码

创作不易,有用点赞,写作有利于锻炼一门新的技能,有很大一部分是我自己总结的新视角

1、前置条件:要理解旋转位置编码前,要熟悉自注意力机制,否则很难看得懂,在我的系列文章中有对自注意力机制的画图解释。

先说重要的结论(下面 q向量 和 k向量 是自注意力矩阵诞生的,不懂先去看注意力机制):

结论1:旋转位置编码本身是绝对位置编码,但是和自注意力机制中的一个qk向量结合之后,就变成了相对位置编码。因为自注意力机制中qk会计算点积,正是恰好这个内积,顺带把旋转位置编码变成了相对位置编码,所以一般说旋转位置编码既包含了绝对位置编码含义,也包含了相对位置编码含义。
结论2:假设没有位置编码这个东西,自注意力机制中,qk向量进行内积的时候,经过反向传播,会逐渐得出词汇与词汇的关联度矩阵,假设10个词汇计算内积,当两个词汇关联度越高,这两个词汇的内积(q * k)越大,重点来了:当对q 和 k叠加上旋转位置编码之后,那不仅仅是两个词汇关联度越高,内积越大,并且当两个词汇位置距离越近,内积也越大。
结论3:原来词向量跟词向量的内积大小只跟词汇的语义相关,内积越大,两个词汇的语义关联度越高。叠加上旋转位置编码后,距离相近的词向量内积也大。当一个句子中,两个词汇距离很远但是语义强相关,那他们的内积就是大;当两个词汇语义没啥关联但是距离很近,内积也是大;当两个词汇距离又近,语义有强相关,内积就是大大的。

2、经过上面的结论,其实我们知道了旋转位置编码在哪个位置起到的作用,就是得出 q 和 k 向量之后。

在说旋转位置编码怎么旋转之前,数学界已经就有了怎么对一个向量进行旋转,举个例子

在这里插入图片描述

如果你本身对位置编码不熟悉,在了解旋转位置编码之前,建议先去看我的另一篇博客,有个传统绝对位置编码的解释,旋转位置编码在没和qk叠加之前,其实和绝对位置编码差不多,你会发现他们的公式在某些地方非常的接近。如果这个所谓的旋转位置编码和传统绝对位置编码通过一样的方式叠加到词向量上面,旋转位置编码还是一个绝对位置编码,关键在于叠加方式不一样。当然传统位置编码使用旋转位置编码的叠加方式,也没有产生相对位置含义,所以旋转位置编码的计算公式和他的叠加方式是相互相成的。
传统绝对位置编码公式:

在这里插入图片描述

旋转位置编码公式:

在这里插入图片描述

3、上面知道如果向量需要旋转,其实需要一个二维向量,但是 q 和 k 都是一维向量,怎么办呢,通过如下叠加,把 q 和 k 向量都按照如下图所示变成二维向量:

在这里插入图片描述

然后把q的每一列当成(x,y)取出来,下图所示,一共有8个(x,y),所有的q向量都进行这样子的计算,计算完成之后,我们就说q叠加上了旋转位置编码。

在这里插入图片描述

然后又转换回来,这个q叠加上了旋转位置编码

在这里插入图片描述

4、我简单提供一个证明,证明在向量在旋转位置编码之后,词汇距离越近,内积就越大,假设两个token的q向量都一样。

假设两个token的初始表示为相同的向量:𝑣=[1,0,1,0]

旋转矩阵为:
在这里插入图片描述
下面我们来套用上面说到的公式计算:

当这个向量位置为 1

在这里插入图片描述

当这个向量位置为 3

在这里插入图片描述

在这里插入图片描述

5、最后代码实现,在这里我也是拿某些大佬的,我在这里写了很多print形状,从观察矩阵形状变化去理解比较好

我这里提一下,就是你会发现代码其实有点难以看懂,这是因为涉及到批次计算,多头,导致矩阵代码中做了很多的矩阵变换,但是本质的流程还是我上面所说的,只是在实现过程中,考虑到优化导致的代码难以按照我上面所述的流程看懂,但是本质和上面一样。
import torch
import torch.nn as nn
import torch.nn.functional as F
import math# %%def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device):# batch_size = 8# nums_head = 12# max_len = 10# output_dim = 32position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1)ids = torch.arange(0, output_dim // 2, dtype=torch.float)  # 即公式里的i, i的范围是 [0,d/2]theta = torch.pow(10000, -2 * ids / output_dim)print(position) # [[0.],[1.],[2.],[3.],[4.],[5.],[6.],[7.],[8.],[9.]]print(output_dim) # 32print(theta) # tensor([1.0000e+00, 5.6234e-01, 3.1623e-01, 1.7783e-01, 1.0000e-01, 5.6234e-02,# 3.1623e-02, 1.7783e-02, 1.0000e-02, 5.6234e-03, 3.1623e-03, 1.7783e-03,# 1.0000e-03, 5.6234e-04, 3.1623e-04, 1.7783e-04])print(theta.size()) # torch.Size([16])print(position.size()) # torch.Size([10, 1])embeddings = position * theta  # 即公式里的:pos / (10000^(2i/d))print(embeddings.size()) # torch.Size([10, 16])# (max_len, output_dim//2, 2)embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)# For example:# torch.sin(embeddings) = tensor([[ 0.0000,  0.8415,  0.9093,  0.1411, -0.7568, -0.9589]])# torch.cos(embeddings) = tensor([[ 1.0000,  0.5403, -0.4161, -0.9900, -0.6536,  0.2837]])# torch.stack = tensor([[[ 0.0000,  1.0000],#                        [ 0.8415,  0.5403],#                        [ 0.9093, -0.4161],#                        [ 0.1411, -0.9900],#                        [-0.7568, -0.6536],#                        [-0.9589,  0.2837]]])print(embeddings.size()) # torch.Size([10, 16, 2])embeddings = embeddings.repeat((batch_size, nums_head, *([1] * len(embeddings.shape))))  # 在bs维度重复,其他维度都是1不重复print(embeddings.size()) # torch.Size([8, 12, 10, 16, 2])# reshape后就是:偶数sin, 奇数cos了embeddings = torch.reshape(embeddings, (batch_size, nums_head, max_len, output_dim))print(embeddings.size()) # torch.Size([8, 12, 10, 32])embeddings = embeddings.to(device)return embeddings# %%def RoPE(q, k):# q,k: (bs, head, max_len, output_dim)batch_size = q.shape[0] # batch_size = 8nums_head = q.shape[1] # nums_head = 12max_len = q.shape[2] # max_len = 10output_dim = q.shape[3] # output_dim = 32pos_emb = sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device)print(pos_emb.size()) # torch.Size([8, 12, 10, 32])# 看rope公式可知,相邻cos,sin之间是相同的,所以复制一遍。如(1,2,3)变成(1,1,2,2,3,3)cos_pos = pos_emb[...,  1::2].repeat_interleave(2, dim=-1)  # 将奇数列信息抽取出来也就是cos 拿出来并复制sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)  # 将偶数列信息抽取出来也就是sin 拿出来并复制print(cos_pos.size()) # torch.Size([8, 12, 10, 32])print(sin_pos.size()) # torch.Size([8, 12, 10, 32])q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)print(q2.size()) # torch.Size([8, 12, 10, 16, 2])q2 = q2.reshape(q.shape)  # reshape后就是正负交替了print(q2.size()) # torch.Size([8, 12, 10, 32])# 更新qw, *对应位置相乘q = q * cos_pos + q2 * sin_posprint(q.size()) # torch.Size([8, 12, 10, 32])k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1)k2 = k2.reshape(k.shape)# 更新kw, *对应位置相乘k = k * cos_pos + k2 * sin_posreturn q, k# %%def attention(q, k, v, mask=None, dropout=None, use_RoPE=True):# q.shape: (bs, head, seq_len, dk)# k.shape: (bs, head, seq_len, dk)# v.shape: (bs, head, seq_len, dk)if use_RoPE:q, k = RoPE(q, k)d_k = k.size()[-1]att_logits = torch.matmul(q, k.transpose(-2, -1))  # (bs, head, seq_len, seq_len)att_logits /= math.sqrt(d_k)if mask is not None:att_logits = att_logits.masked_fill(mask == 0, -1e9)  # mask掉为0的部分,设为无穷大att_scores = F.softmax(att_logits, dim=-1)  # (bs, head, seq_len, seq_len)if dropout is not None:att_scores = dropout(att_scores)# (bs, head, seq_len, seq_len) * (bs, head, seq_len, dk) = (bs, head, seq_len, dk)return torch.matmul(att_scores, v), att_scoresif __name__ == '__main__':# (bs, head, seq_len, dk)q = torch.randn((8, 12, 10, 32))k = torch.randn((8, 12, 10, 32))v = torch.randn((8, 12, 10, 32))res, att_scores = attention(q, k, v, mask=None, dropout=None, use_RoPE=True)# (bs, head, seq_len, dk),  (bs, head, seq_len, seq_len)print(res.shape, att_scores.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/878753.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2409d,d语言非常简单利用sqlite3库

1,在sqlite3d.c中 #include "sqlite3.h"2,直接使用了: import std.conv; import std.stdio; import std.string;import sqlite3d;extern(C) {static int callback(void *NotUsed, int argc, char **argv, char **azColName){int i;for(i0; i<argc; i){printf(&q…

OpenFeign请求拦截器,注入配置属性类(@ConfigurationProperties),添加配置文件(yml)中的token到请求头

一、需求 OpenFeign请求拦截器&#xff0c;注入配置属性类&#xff08;ConfigurationProperties&#xff09;&#xff0c;添加配置文件&#xff08;yml&#xff09;中的token到请求头 在使用Spring Boot结合OpenFeign进行微服务间调用时&#xff0c;需要在发起HTTP请求时添加一…

MLLM(二)| 阿里开源视频理解大模型:Qwen2-VL

2024年8月29日&#xff0c;阿里发布了 Qwen2-VL&#xff01;Qwen2-VL 是基于 Qwen2 的最新视觉语言大模型。与 Qwen-VL 相比&#xff0c;Qwen2-VL 具有以下能力&#xff1a; SoTA对各种分辨率和比例的图像的理解&#xff1a;Qwen2-VL在视觉理解基准上达到了最先进的性能&#…

kafka单机安装

kafka单机安装 下载地址 官网&#xff1a;https://kafka.apache.org/最新版本下载页面&#xff1a;https://kafka.apache.org/downloads 说明 版本选择&#xff1a;3.0.0&#xff0c;kafka_2.12-3.0.0.tgz下载地址&#xff1a;https://archive.apache.org/dist/kafka/3.0.0…

018、二级Java操作题汇总最新版100000+字

目录 1.基本操作&#xff08;源代码&#xff09;&#xff1a; 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9 1.10 1.11 1.12 1.13 1.14 1.15 1.16 1.17 1.18 1.19 1.20 1.21 1.22 1.23 1.24 1.25 1.26 1.27 1.28 1.29 1.30 1.31 1.32 1.33 1.34 1.…

ClickHouse 的安装与基本配置

ClickHouse 是一款高性能的列式数据库管理系统&#xff0c;特别适合用于大数据分析。以下是 ClickHouse 的安装与基本配置步骤&#xff0c;涵盖了在常见平台&#xff08;如 Linux&#xff09;上的安装和基础配置。 1. 安装 ClickHouse 在 Linux (Ubuntu/Debian) 上安装 在 U…

Apache Guacamole 安装及配置VNC远程桌面控制

文章目录 官网简介支持多种协议无插件浏览器访问配置和管理应用场景 Podman 部署 Apache Guacamole拉取 docker 镜像docker-compose.yml部署 PostgreSQL生成 initdb.sql 脚本部署 guacamole Guacamole 基本用法配置 VNC 连接 Mac 电脑开启自带的 VNC 服务 官网 https://guacam…

Gmtracker安装中存在的问题

Gmtracker安装中存在的问题 GMtracker安装问题该如何解决&#xff1f; 使用用服务器&#xff0c;在云服务器中使用conda环境 python 3.6的版本环境. pip install -r requirements.txt 在网上查找资料&#xff1a;opencv安装失败卡在这里是因为没有使用高版本的python环境 切换…

MySQL灾难恢复策略:构建稳健的备份与恢复机制

在现代企业环境中&#xff0c;数据的安全性和可靠性至关重要。灾难恢复计划&#xff08;Disaster Recovery Plan, DRP&#xff09;是确保在发生灾难性事件后&#xff0c;能够迅速恢复业务的关键策略。对于依赖MySQL数据库的系统&#xff0c;实现有效的灾难恢复计划尤为重要。本…

流式域套接字通信

流式域套接字服务器端实现&#xff08;TCP&#xff09; #include <myhead.h>#define BACKLOG 5 int main(int argc, const char *argv[]) {int oldfdsocket(AF_UNIX,SOCK_STREAM,0);if(oldfd-1){perror("socket");return -1;}if(access("./server",…

pdf在线转换成word免费版,一键免费转换

在日常的学习和办公中&#xff0c;PDF文件和Word文档是我们离不开的两种最常见的文件&#xff0c;而PDF与Word文档之间的转换成为了我们日常工作中不可或缺的一部分。无论是为了编辑、修改还是共享文件&#xff0c;掌握多种PDF转Word的方法都显得尤为重要。很多小伙伴关心能不能…

摄像头的ISP和SOC的GPU有区别吗?

摄像头的主芯片必须包含ISP&#xff0c;也就是图像处理器核心。而SOC的GPU或者说显卡也包含图像处理器也就是GPU。两者并无本质区别&#xff0c;都是实现数字图像处理算法。同样的用FPGA做内窥镜图像处理和用FPGA做显示图像处理器本质上也是一样的。 当然两者存在一些细微差别…

Flask中 blinker 是什么

在Flask框架中&#xff0c;blinker 是一个非常重要的组件&#xff0c;它作为信号处理的库&#xff0c;为Flask应用提供了一种灵活而强大的事件处理机制。以下是对Flask中blinker的详细阐述&#xff0c;考虑到篇幅限制&#xff0c;无法直接达到5000字&#xff0c;但会尽量全面而…

SpringSecurity Oauth2 - 密码模式完成身份认证获取令牌 [自定义UserDetailsService]

文章目录 1. 授权服务器2. 授权类型1. Password (密码模式)2. Refresh Token&#xff08;刷新令牌&#xff09;3. Client Credentials&#xff08;客户端凭证模式&#xff09; 3. AuthorizationServerConfigurerAdapter4. 自定义 TokenStore 管理令牌1. TokenStore 的作用2. Cu…

Ajax 2024/3/31

Ajax 异步的Javascript和XML 作用&#xff1a; 数据交换&#xff1a;通过Ajax可以给服务器发送请求&#xff0c;并获取服务器响应的数据。 异步交互&#xff1a;可以在不重新加载整个页面的情况下&#xff0c;与服务器交换数据并更新部分网页的技术。 原生Ajax 1.准备数据…

看demo学算法之 贝叶斯网络

大家好&#xff0c;这里是小琳AI课堂&#xff01;今天我们一起来学习贝叶斯网络&#xff0c;这是一种非常酷的图形模型&#xff0c;它能帮助我们理解和处理变量之间的条件依赖关系。&#x1f3a8;&#x1f4ca; 贝叶斯网络基础 首先&#xff0c;贝叶斯网络是基于贝叶斯定理的…

springweb获取请求数据、spring中拦截器

SpringWeb获取请求数据 springWeb支持多种类型的请求参数进行封装 1、使用HttpServletRequest对象接收 PostMapping(path "/login")//post请求//spring自动注入public String login(HttpServletRequest request){ System.out.println(request.getParameter("…

C++基础知识(五)

struct VS class 特性structclass默认访问修饰符publicprivate成员访问权限成员默认是 public成员默认是 private继承方式默认继承方式为 public默认继承方式为 private用途通常用于简单的数据结构或记录通常用于复杂的数据类型和封装成员函数可以有成员函数可以有成员函数构造…

J.U.C Review - CopyOnWrite容器

文章目录 什么是CopyOnWrite容器CopyOnWriteArrayList优点缺点源码示例 仿写&#xff1a;CopyOnWriteMap的实现注意事项 什么是CopyOnWrite容器 CopyOnWrite容器是一种实现了写时复制&#xff08;Copy-On-Write&#xff0c;COW&#xff09;机制的并发容器。在并发场景中&#…

请解释一下 JDBC 的作用,并给出一个简单的使用 JDBC 查询数据库的例子?

JDBC (Java Database Connectivity) 是 Java 编程语言中用于连接和操作关系型数据库的标准 API。 它的主要作用是为 Java 应用程序提供了一种标准的方式来访问和处理数据库中的数据&#xff0c;而不需要关心底层具体的数据库系统&#xff08;如 MySQL, Oracle, PostgreSQL 等&…