从零开始实现大语言模型(四):简单自注意力机制

1. 前言

理解大语言模型结构的关键在于理解自注意力机制(self-attention)。自注意力机制可以判断输入文本序列中各个token与序列中所有token之间的相关性,并生成包含这种相关性信息的context向量。

本文介绍一种不包含训练参数的简化版自注意力机制——简单自注意力机制(simplified self-attention),后续三篇文章将分别介绍缩放点积注意力机制(scaled dot-product attention),因果注意力机制(causal attention),多头注意力机制(multi-head attention),并最终实现OpenAI的GPT系列大语言模型中MultiHeadAttention类。

2. 从循环神经网络到自注意力机制

解决机器翻译等多对多(many-to-many)自然语言处理任务最常用的模型是sequence-to-sequence模型。Sequence-to-sequence模型包含一个编码器(encoder)和一个解码器(decoder),编码器将输入序列信息编码成信息向量,解码器用于解码信息向量,生成输出序列。在Transformer模型出现之前,编码器和解码器一般都是一个循环神经网络(RNN, recurrent neural network)。

RNN是一种非常适合处理文本等序列数据的神经网络架构。Encoder RNN对输入序列进行处理,将输入序列信息压缩到一个向量中。状态向量 h 0 h_0 h0包含第一个token x 0 x_0 x0的信息, h 1 h_1 h1包含前两个tokens x 0 x_0 x0 x 1 x_1 x1的信息。以此类推, Encoder RNN最后一个状态 h m h_m hm是整个输入序列的概要,包含了整个输入序列的信息。Decoder RNN的初始状态等于Encoder RNN最后一个状态 h m h_m hm h m h_m hm包含了输入序列的信息,Decoder RNN可以通过 h m h_m hm知道输入序列的信息。Decoder RNN可以将 h m h_m hm中包含的信息解码,逐个元素地生成输出序列。

RNN的神经网络结构及计算方法使Encoder RNN必须用一个隐藏状态向量 h m h_m hm记住整个输入序列的全部信息。当输入序列很长时,隐藏状态向量 h m h_m hm对输入序列中前面部分的tokens的偏导数(如对 x 0 x_0 x0的偏导数 ∂ h m x 0 \frac{\partial h_m}{x_0} x0hm)会接近0。输入不同的 x 0 x_0 x0,隐藏状态向量 h m h_m hm几乎不会发生变化,即RNN会遗忘输入序列前面部分的信息。

本文不会详细介绍RNN的原理,大语言模型的神经网络中没有循环结构,RNN的原理及结构与大语言模型没有关系。对RNN的原理感兴趣读者可以参见本人的博客专栏:自然语言处理。

2014年,文章Neural Machine Translation by Jointly Learning to Align and Translate提出了一种改进sequence-to-sequence模型的方法,使Decoder每次更新状态时会查看Encoder所有状态,从而避免RNN遗忘的问题,而且可以让Decoder关注Encoder中最相关的信息,这也是attention名字的由来。

2017年,文章Attention Is All You Need指出可以剥离RNN,仅保留attention,且attention并不局限于sequence-to-sequence模型,可以直接用在输入序列数据上,构建self-attention,并提出了基于attention的sequence-to-sequence架构模型Transformer。

3. 简单自注意力机制

自注意力机制的目标是计算输入文本序列中各个token与序列中所有tokens之间的相关性,并生成包含这种相关性信息的context向量。如下图所示,简单自注意力机制生成context向量的计算步骤如下:

  1. 计算注意力分数(attention score):简单注意力机制使用向量的点积(dot product)作为注意力分数,注意力分数可以衡量两个向量的相关性;
  2. 计算注意力权重(attention weight):将注意力分数归一化得到注意力权重,序列中每个token与序列中所有tokens之间的注意力权重之和等于1;
  3. 计算context向量:简单注意力机制将所有tokens对应Embedding向量的加权和作为context向量,每个token对应Embedding向量的权重等于其相应的注意力权重。

图一

3.1 计算注意力分数

对输入文本序列 “Your journey starts with one step.” 做tokenization,将文本中每个单词分割成一个token,并转换成Embedding向量,得到 x 1 , x 2 , ⋯ , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6。自注意力机制分别计算 x i x_i xi x 1 , x 2 , ⋯ , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6的注意力权重,进而计算 x 1 , x 2 , ⋯ , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6与其相应注意力权重的加权和,得到context向量 z i z_i zi

如下图所示,将context向量 z i z_i zi对应的向量 x i x_i xi称为query向量,计算query向量 x 2 x_2 x2对应的context向量 z 2 z_2 z2的第一步是计算注意力分数。将query向量 x 2 x_2 x2分别点乘向量 x 1 , x 2 , ⋯ , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6,得到实数 ω 21 , ω 22 , ⋯ , ω 26 \omega_{21}, \omega_{22}, \cdots, \omega_{26} ω21,ω22,,ω26,其中 ω 2 i \omega_{2i} ω2i是query向量 x 2 x_2 x2与向量 x i x_i xi的注意力分数,可以衡量 x 2 x_2 x2对应token与 x i x_i xi对应token之间的相关性。

图二

两个向量的点积等于这两个向量相同位置元素的乘积之和。假如向量 x 1 = ( x 11 , x 12 , x 13 ) x_1=(x_{11}, x_{12}, x_{13}) x1=(x11,x12,x13),向量 x 2 = ( x 21 , x 22 , x 23 ) x_2=(x_{21}, x_{22}, x_{23}) x2=(x21,x22,x23),则向量 x 1 x_1 x1 x 2 x_2 x2的点积等于 x 11 × x 21 + x 12 × x 22 + x 13 × x 23 x_{11}\times x_{21} + x_{12}\times x_{22} + x_{13}\times x_{23} x11×x21+x12×x22+x13×x23

可以使用如下代码计算query向量 x 2 x_2 x2 x 1 , x 2 , ⋯ , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6的注意力分数:

import torch
inputs = torch.tensor([[0.43, 0.15, 0.89], # Your     (x^1)[0.55, 0.87, 0.66], # journey  (x^2)[0.57, 0.85, 0.64], # starts   (x^3)[0.22, 0.58, 0.33], # with     (x^4)[0.77, 0.25, 0.10], # one      (x^5)[0.05, 0.80, 0.55]] # step     (x^6)
)query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)

执行上面代码,打印结果如下:

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

3.2 计算注意力权重

如下图所示,将注意力分数 ω 21 , ω 22 , ⋯ , ω 26 \omega_{21}, \omega_{22}, \cdots, \omega_{26} ω21,ω22,,ω26归一化可得到注意力权重 α 21 , α 22 , ⋯ , α 26 \alpha_{21}, \alpha_{22}, \cdots, \alpha_{26} α21,α22,,α26。每个注意力权重 α 2 i \alpha_{2i} α2i的值均介于0到1之间,所有注意力权重的和 ∑ i α 2 i = 1 \sum_i\alpha_{2i}=1 iα2i=1。可以用注意力权重 α 2 i \alpha_{2i} α2i表示 x i x_i xi对当前context向量 z 2 z_2 z2的重要性占比,注意力权重 α 2 i \alpha_{2i} α2i越大,表示 x i x_i xi x 2 x_2 x2的相关性越强,context向量 z 2 z_2 z2 x i x_i xi的信息量比例应该越高。使用注意力权重对 x 1 , x 2 , ⋯ , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6加权求和计算context向量,可以使context向量的数值分布范围始终与 x 1 , x 2 , ⋯ , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6一致。这种数值分布范围的一致性可以使大语言模型训练过程更稳定,模型更容易收敛。

图三

可以使用softmax函数将注意力分数归一化得到注意力权重:

attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

执行上面代码,打印结果如下:

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)

3.3 计算context向量

简单注意力机制使用所有tokens对应Embedding向量的加权和作为context向量,context向量 z 2 = ∑ i α 2 i x i z_2=\sum_i\alpha_{2i}x_i z2=iα2ixi

图四

可以使用如下代码计算context向量 z 2 z_2 z2

query = inputs[1] # 2nd input token is the query
context_vec_2 = torch.zeros(query.shape)
for i,x_i in enumerate(inputs):context_vec_2 += attn_weights_2[i] * x_i
print(context_vec_2)

执行上面代码,打印结果如下:

tensor([0.4419, 0.6515, 0.5683])

3.4 计算所有tokens对应的context向量

将向量 x 2 x_2 x2作为query向量,按照3.1所述方法,可以计算出注意力分数 ω 21 , ω 22 , ⋯ , ω 26 \omega_{21}, \omega_{22}, \cdots, \omega_{26} ω21,ω22,,ω26。使用softmax函数将注意力分数 ω 21 , ω 22 , ⋯ , ω 26 \omega_{21}, \omega_{22}, \cdots, \omega_{26} ω21,ω22,,ω26归一化,可以得到注意力权重 α 21 , α 22 , ⋯ , α 26 \alpha_{21}, \alpha_{22}, \cdots, \alpha_{26} α21,α22,,α26。Context向量 z 2 z_2 z2是使用注意力权重对 x 1 , x 2 , ⋯ , x 6 x_1, x_2, \cdots, x_6 x1,x2,,x6的加权和。

计算所有tokens对应的context向量,可以使用矩阵乘法运算,分别将各个 x i x_i xi作为query向量,一次性批量计算注意力分数及注意力权重,并最终得到context向量 z i z_i zi

如下面代码所示,可以使用矩阵乘法,一次性计算出所有注意力分数:

attn_scores = inputs @ inputs.T
print(attn_scores)

执行上面代码,打印结果如下:

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],[0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],[0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],[0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],[0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],[0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

@操作符是PyTorch中的矩阵乘法运算符号,与函数torch.matmul运算逻辑相同。

将一个 n n n m m m列的矩阵 A A A与另一个 m m m n n n B B B的矩阵相乘,结果 C C C是一个 n n n n n n列的矩阵。其中矩阵 C C C i i i j j j列元素等于矩阵 A A A的第 i i i行与矩阵 B B B的第 j j j列两个向量的内积。

如下面代码所示,使用softmax函数注意力分数归一化,可以一次批量计算出所有注意力权重:

attn_weights = torch.softmax(attn_scores, dim=1)
print(attn_weights)

执行上面代码,打印结果如下:

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],[0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],[0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],[0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],[0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],[0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

可以同样使用矩阵乘法运算,一次性批量计算出所有context向量:

all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

执行上面代码,打印结果如下:

tensor([[0.4421, 0.5931, 0.5790],[0.4419, 0.6515, 0.5683],[0.4431, 0.6496, 0.5671],[0.4304, 0.6298, 0.5510],[0.4671, 0.5910, 0.5266],[0.4177, 0.6503, 0.5645]])

4. 结束语

自注意力机制是大语言模型神经网络结构中最复杂的部分。为降低自注意力机制原理的理解门槛,本文介绍了一种不带任何训练参数的简化版自注意力机制。

自注意力机制的目标是计算输入文本序列中各个token与序列中所有tokens之间的相关性,并生成包含这种相关性信息的context向量。简单自注意力机制生成context向量共3个步骤,首先计算注意力分数,然后使用softmax函数将注意力分数归一化得到注意力权重,最后使用注意力权重对所有tokens对应的Embedding向量加权求和得到context向量。

接下来,该去看看大语言模型中真正使用到的注意力机制了!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/44555.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JMH324-免费【最后一战LOL】MOBA竞技版本+单机一键端+视频教程+文本教程

资源介绍: 修改前打开【D:\ZHServer】文件夹里的【[1]一键启动.bat】,游戏不要打开,否则修改失败。 修改完以后重启架设程序才会生效。 fball_gamedb1数据库——gameuser数据表 obj_name 角色名 obj_lv 等级 obj_diamond 钻石 obj_gold 8…

ZFT9-7VE8043-Z同期脉冲发送装置100V JOSEF约瑟 柜内安装

ZFT9(PIG)同期脉冲发送装置 系列型号 ZFT9(PIG) 7VE8033同期脉冲发送装置; ZFT9(PIG) 7VE8043同期脉冲发送装置; ZFT9 7VE8033同期脉冲发送装置; ZFT9 7VE8043同期脉冲发送装置; 用途: ZFT9同期脉冲发送装置用于船舶的三相系统,根据发电机和电力系…

为什么要考国际人力资源证书?HR不能不知道!

在人力资源领域中,持有专业的人力资源证书并非铁律般的必需。但不容忽视的是,随着时代的进步和行业的不断演进,越来越多的人力资源专业人员开始重视并追求人力资源资格认证。 一张高含金量的证书让HR在求职市场上更具竞争力,更能…

JavaScript中的执行上下文和原型链

目录 一、执行上下文 1.执行上下文 2.执行上下文栈 3.闭包 1)定义 2)形成条件 3)例子 (1)例子1:简单闭包 (2)例子2:闭包与循环 (3)例子…

mes系统在新材料行业中的应用价值

万界星空科技新材料MES系统是针对新材料制造行业的特定需求而设计的制造执行系统,它集成了生产计划、过程监控、质量管理、设备管理、库存管理等多个功能模块,以支持新材料生产的高效、稳定和可控。以下是新材料MES系统的具体功能介绍: 一、生…

【算法入门-栈】逆波兰表达式求值

📖逆波兰表达式求值 ✅描述✅扩展:什么是逆波兰表达式✅题解方法一:栈✅题解方法二(数组模拟栈) 今天又刷了一道题,奥利给 刷题地址: 点击跳转 ✅描述 给定一个逆波兰表达式,求表达…

吹田电气绿色能源 未来可期

在2024年7月的上海慕尼黑电子展上,吹田电气功率分析仪成为了备受瞩目的明星产品。作为电子测试与测量领域的重要工具,功率分析仪在展会上展示了其在绿色能源和高效能量管理方面的最新应用,引发了广泛关注和热议。 领先技术,精准测…

[leetcode]kth-smallest-element-in-a-sorted-matrix 有序矩阵中第k小元素

. - 力扣&#xff08;LeetCode&#xff09; class Solution { public:bool check(vector<vector<int>>& matrix, int mid, int k, int n) {int i n - 1;int j 0;int num 0;while (i > 0 && j < n) {if (matrix[i][j] < mid) {num i 1;j;…

Qt/QML学习-PathView

QML学习 PathView例程视频讲解代码 main.qml import QtQuick 2.15 import QtQuick.Window 2.15Window {width: 640height: 480visible: truetitle: qsTr("Hello World")color: "black"PathView {id: pathViewanchors.fill: parentmodel: ListModel {List…

电厂数字孪生能源数据可视化运维平台开发炫酷且性价比更高

3D数据可视化大屏平台是我们为工厂车间提供的线上展示自定义工具&#xff0c;深度融合了web3D开发建模、AI和图形图像技术&#xff0c;完美还原车间产线布局&#xff0c;让复杂的生产流程和设备运行数据在大屏上直观呈现。 3D可视化数据大屏采用全景3D视角和虚拟现实技术&#…

快速测试electron环境是否安装成功

快速测试electron环境是否安装成功 测试代码正确运行的效果运行错误的效果v22.4.1 版本无法使用v20.15.1版本无法使用v18.20.4 版本无法使用 终极解决办法 测试代码 1.npx create-electron-app my-electron-app 2.cd my-electron-app 3.npm start 正确运行的效果 环境没问题…

springboot高校讲座预约管理系统-计算机毕业设计源码21634

摘 要 本系统旨在设计和实现一个基于Android平台的高校讲座预约管理系统&#xff0c;以提供管理员和普通用户便捷的讲座预约服务和全面的管理功能。系统将包括在线讲座发布、讲座预约、座位安排、签到信息记录等功能模块&#xff0c;旨在提高高校讲座活动的组织效率和用户体验。…

【三维向量旋转】基于Matlab的三维坐标旋转

一、问题描述 若空间中存在三个点A,B,C&#xff0c;其中A点是不动点&#xff0c;B点是当前方向向量上的一个点&#xff0c;C是目标方向上的一个点。如果要让AB向量沿着BC方向进行旋转&#xff0c;使得AB最终旋转到AC。这个过程就是三维向量的旋转过程。我们关注的是这个过程&am…

MT3047 区间最大值

思路&#xff1a; 使用哈希表map和set&#xff08;去重&#xff09;维护序列 代码&#xff1a; #include <bits/stdc.h> using namespace std; const int N 1e5 10; int n, k, A[N]; map<int, int> mp; // 元素出现的次数 set<int> s; // 维护出现…

【案例】python集成OCR识别工具调研

目录 一、前言二、Tesseract_OCR2.1、安装过程2.2、python代码使用三、PaddleOCR3.1、安装过程3.2、python代码使用四、EasyOCR五、ddddOCR六、CnOCR一、前言 因项目需要OCR识别能力,且要支持私有化部署。本文将对比市场一些开源的OCR识别工具,从中选择适合项目需要的OCR,且…

Win10屏幕录制,这3种方法分享给你

数字化时代里&#xff0c;电脑的屏幕录制功能已经不再是简单的工具&#xff0c;而是成为我们表达、学习和交流的重要媒介。Win10系统依然是大部分人使用的电脑系统&#xff0c;那么关于Win10屏幕录制&#xff0c;有哪些好用高效的录制软件&#xff0c;能够帮助我们更加深入地捕…

美国商超入驻Homedepot,会成为传统家织厂家跨境赛道吗?

近年来&#xff0c;随着全球化步伐的加快和电子商务的蓬勃发展&#xff0c;越来越多的企业开始寻求跨境拓展的机会。在这样的背景下&#xff0c;美国知名的家居用品零售商超——Homedepot成为了许多国内外家织厂家关注的焦点。那么&#xff0c;美国商超入驻Homedepot究竟如何呢…

短视频剪辑软件-剪映必备快捷键大全 沈阳短视频剪辑培训

对于用剪映电脑版的朋友来说 快捷键是很重要的 那么剪映专业版有哪些快捷键呢 今天总结了一下快捷键大全 赶快收藏吧 1、基础功能 复制&#xff1a;Ctrl&#xff0b;C 粘贴&#xff1a;Ctrl&#xff0b;v 分割&#xff1a;Ctrl B 删除&#xff1a;Back 新建草稿&…

15.x86游戏实战-汇编指令jmp call ret

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 本次游戏没法给 内容参考于&#xff1a;微尘网络安全 工具下载&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1rEEJnt85npn7N38Ai0_F2Q?pwd6tw3 提…

webGL可用的14种3D文件格式,但要具体问题具体分析。

hello&#xff0c;我威斯数据&#xff0c;你在网上看到的各种炫酷的3d交互效果&#xff0c;背后都必须有三维文件支撑&#xff0c;就好比你网页的时候&#xff0c;得有设计稿源文件一样。WebGL是一种基于OpenGL ES 2.0标准的3D图形库&#xff0c;可以在网页上实现硬件加速的3D图…