Transformer算法组件详解

自2017年Google推出Transformer以来,基于其架构的语言模型便如雨后春笋般涌现,其中Bert、T5等备受瞩目,而近期风靡全球的大模型ChatGPT和LLaMa更是大放异彩。网络上关于Transformer的解析文章非常大。

前言

Transformer是谷歌在2017年的论文《Attention Is All You Need》中提出的,用于NLP的各项任务,现在是谷歌云TPU推荐的参考模型。网上有关Transformer原理的介绍很多,在本文中我们将尽量模型简化,让普通读者也能轻松理解。

1. Transformer整体结构

在机器翻译中,Transformer可以将一种语言翻译成另一种语言,如果把Transformer看成一个黑盒,那么其结构如下图所示:

将法语翻译成英语

那么拆开这个黑盒,那么可以看到Transformer由若干个编码器和解码器组成,如下图所示:

继续将Encoder和Decoder拆开,可以看到完整的结构,如下图所示:

Transformer整体结构(引自谷歌论文)

可以看到Encoder包含一个Muti-Head Attention模块,是由多个Self-Attention组成,而Decoder包含两个Muti-Head Attention。Muti-Head Attention上方还包括一个 Add & Norm 层,Add 表示残差连接 (Residual Connection) 用于防止网络退化,Norm 表示 Layer Normalization,用于对每一层的激活值进行归一化。

假设我们的输入包含两个单词,我们看一下Transformer的整体结构:

Transformer整体结构(输入两个单词的例子)

为了能够对Transformer的流程有个大致的了解,我们举一个简单的例子,还是以之前的为例,将法语"Je suis etudiant"翻译成英文。

  • 第一步:获取输入句子的每一个单词的表示向量 ,  由单词的Embedding和单词位置的Embedding 相加得到。

Transformer输入表示

  • 第二步:将单词向量矩阵传入Encoder模块,经过N个Encoder后得到句子所有单词的编码信息矩阵C,如下图。输入句子的单词向量矩阵用X\in R^{n*d}表示,其中n是单词个数,d表示向量的维度(论文中d=512)。每一个Encoder输出的矩阵维度与输入完全一致。

输入X经过Encoder输出编码矩阵C

  • 第三步:将Encoder输出的编码矩阵C传递到Decoder中,Decoder会根据当前翻译过的单词1~i翻译下一个单词i+1,如下图所示。

Transformer Decoder预测

上图Decoder接收了Encoder的编码矩阵,然后首先输入一个开始符 "<Begin>",预测第一个单词,输出为"I";然后输入翻译开始符 "<Begin>" 和单词 "I",预测第二个单词,输出为"am",以此类推。这是Transformer的大致流程,接下来介绍里面各个部分的细节。

2. Transformer的输入表示

Transformer中单词的输入表示由单词Embedding位置Embedding(Positional Encoding)相加得到。

Transformer输入表示

2.1 单词Embedding

单词的Embedding可以通过Word2vec等模型预训练得到,可以在Transformer中加入Embedding层。

2.2 位置Embedding

Transformer 中除了单词的Embedding,还需要使用位置Embedding 表示单词出现在句子中的位置。因为 Transformer不采用RNN结构,而是使用全局信息,不能利用单词的顺序信息,而这部分信息对于NLP来说非常重要。所以Transformer中使用位置Embedding保存单词在序列中的相对或绝对位置。

位置Embedding用PE表示,PE的维度与单词Embedding相同。PE可以通过训练得到,也可以使用某种公式计算得到。在Transformer中采用了后者,计算公式如下:

 其中,pos表示单词在句子中的位置,d表示PE的维度。

3. Multi-Head Attention(多头注意力机制)

Transformer内部结构

上图是Transformer的内部结构,其中红色方框内为Multi-Head Attention,是由多个Self-Attention组成,具体结构如下图:

Self-Attention和Multi-Head Attention

因为Self-Attention是Transformer的重点,所以我们重点关注 Multi-Head Attention 以及 Self-Attention,首先介绍下Self-Attention的内部逻辑。

3.1 Self-Attention结构

Self-Attention结构

上图是Self-Attention结构,最下面是Q (查询)、K(键值)、V(值)矩阵,是通过输入矩阵X和权重矩阵W^{Q}, W^{K}, W^{V}相乘得到的。

Q,K,V的计算

得到Q, K, V之后就可以计算出Self-Attention的输出,如下图所示:

Self-Attention输出

3.2 Multi-Head Attention输出

在上一步,我们已经知道怎么通过Self-Attention计算得到输出矩阵Z,而Multi-Head Attention是由多个Self-Attention组合形成的,下图是论文中Multi-Head Attention的结构图。

Multi-Head Attention

从上图可以看到Multi-Head Attention包含多个Self-Attention层,首先将输入X分别传递到h个不同的Self-Attention中,计算得到h个输出矩阵Z。下图是h=8的情况,此时会得到 8 个输出矩阵Z。

多个Self-Attention

得到8个输出矩阵Z_{0} \sim Z_{7}后,Multi-Head Attention将它们拼接在一起(Concat),然后传入一个Linear层,得到Multi-Head Attention最终的输出矩阵Z。

Multi-Head Attention输出

4. 编码器Encoder结构

Transformer Encoder模块

上图红色部分是Transformer的Encoder结构,N表示Encoder的个数,可以看到是由Multi-Head Attention、Add & Norm、Feed Forward、Add & Norm组成的。前面已经介绍了Multi-Head Attention的计算过程,现在了解一下Add & Norm和 Feed Forward部分。

4.1 单个Encoder输出

Add & Norm是指残差连接后使用LayerNorm,表示如下:

 其Sublayer表示经过的变换,比如第一个Add & Norm中Sublayer表示Multi-Head Attention。

Feed Forward是指全连接层,表示如下:

 因此输入矩阵X经过一个Encoder后,输出表示如下:

4.2 多个Encoder输出

通过上面的单个Encoder,输入矩阵,最后输出矩阵。通过多个Encoder叠加,最后便是编码器Encoder的输出。

5. 解码器Decoder结构

Transformer Decoder模块

上图红色部分为Transformer的Decoder结构,与Encoder相似,但是存在一些区别:

  • 包含两个Multi-Head Attention
  • 第一个Multi-Head Attention采用了Masked操作
  • 第二个Multi-Head Attention的K,V矩阵使用Encoder的编码信息矩阵C进行计算,而Q使用上一个 Decoder的输出计算
  • 最后有一个Softmax层计算下一个翻译单词的概率

5.1 第一个Multi-Head Attention

Decoder的第一个Multi-Head Attention采用了Masked操作,因为在翻译的过程中是顺序翻译的,即翻译完第i个单词,才可以翻译第i+1个单词。通过 Masked 操作可以防止第i个单词知道i+1个单词之后的信息。下面以法语"Je suis etudiant"翻译成英文"I am a student"为例,了解一下 Masked 操作。

在Decoder的时候,需要根据之前翻译的单词,预测当前最有可能翻译的单词,如下图所示。首先根据输入"<Begin>"预测出第一个单词为"I",然后根据输入"<Begin> I" 预测下一个单词 "am"。

Decoder预测(右图有问题,应该是Decoder 1)

Decoder在预测第i个输出时,需要将第i+1之后的单词掩盖住,Mask操作是在Self-Attention的Softmax之前使用的,下面以前面的"I am a student"为例。

第一步:是Decoder的输入矩阵和Mask矩阵,输入矩阵包含"<Begin> I am a student"4个单词的表示向量,Mask是一个4*4的矩阵。在Mask可以发现单词"<Begin>"只能使用单词"<Begin>"的信息,而单词"I"可以使用单词"<Begin> I"的信息,即只能使用之前的信息。

输入矩阵与Mask矩阵

第二步:接下来的操作和之前Encoder中的Self-Attention一样,只是在Softmax之前需要进行Mask操作。

Mask Self-Attention输出

第三步:通过上述步骤就可以得到一个Mask Self-Attention的输出矩阵Z_{i},然后和Encoder类似,通过Multi-Head Attention拼接多个输出Z_{i}然后计算得到第一个Multi-Head Attention的输出Z,Z与输入X维度一样。

5.2 第二个Multi-Head Attention

Decoder的第二个Multi-Head Attention变化不大,主要的区别在于其中Self-Attention的K, V矩阵不是使用上一个Multi-Head Attention的输出,而是使用Encoder的编码信息矩阵C计算的。根据Encoder的输出C计算得到K, V,根据上一个Multi-Head Attention的输出Z计算Q。这样做的好处是在Decoder的时候,每一位单词(这里是指"I am a student")都可以利用到Encoder所有单词的信息(这里是指"Je suis etudiant")。

6. Softmax预测输出

Softmax预测输出

编码器Decoder最后的部分是利用 Softmax 预测下一个单词,在Softmax之前,会经过Linear变换,将维度转换为词表的个数。

假设我们的词表只有6个单词,表示如下:

词表

因此,最后的输出可以表示如下:

Softmax预测输出示例

总结

Transformer由于可并行、效果好等特点,如今已经成为机器翻译、特征抽取等任务的基础模块,目前ChatGPT特征抽取的模块用的就是Transformer,这对于后面理解ChatGPT的原理做了好的铺垫。

代码实现

绝密伏击:OPenAI ChatGPT(一):Tensorflow实现Transformer

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/5556.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【webrtc】MessageHandler 7: 基于线程的消息处理:切换main线程向observer发出通知

以当前线程作为main线程 RemoteAudioSource 作为一个handler 仅实现一个退出清理的功能 首先on message的处理会切换到main 线程 :main_thread_其次,这里在main 线程对sink_ 做清理再次,在main 线程做出状态改变,并能通知给所有的observer 做出on changed 行为。对接mediac…

AC+AP三层组网实验(华为)

一&#xff0c;技术简介 APAC架构是一种常见的无线局域网&#xff08;WLAN&#xff09;组网方式&#xff0c;主要由接入点&#xff08;Access Point&#xff0c;简称AP&#xff09;和接入控制器&#xff08;Access Controller&#xff0c;简称AC&#xff09;组成。 在APAC架构…

AI大模型系列:自然语言处理,从规则到统计的演变

AI大模型系列文章目录 文明基石&#xff0c;文字与数字的起源与演变自然语言处理&#xff0c;从规则到统计的演变AI魔法师&#xff0c;提示工程的力量 自然语言处理&#xff0c;从规则到统计的演变 自然语言处理&#xff08;Natural Language Processing&#xff0c;NLP&…

前端 CSS

目录 选择器 复合选择器 伪类-超链接 结构伪装选择器 伪元素选择器 画盒子 字体属性 CSS三大属性 Emmet写法 背景属性 显示模式 盒子模型 盒子模型-组成 盒子模型-向外溢出 盒子模型-圆角 盒子模型-阴影 flex position定位 CSS小精灵 字体图标 垂直对齐方式…

数据库(MySQL)—— DML语句

数据库&#xff08;MySQL&#xff09;—— DML语句 什么是DML语句添加数据给全部字段添加数据批量添加数据 修改数据删除数据 什么是DML语句 在MySQL中&#xff0c;DML&#xff08;Data Manipulation Language&#xff0c;数据操纵语言&#xff09;语句主要用于对数据库中的数…

基础安全:CSRF攻击原理与防范

CSRF的概念 CSRF(Cross-Site Request Forgery)中文名为“跨站请求伪造”。这是一种常见的网络攻击手段,攻击者通过构造恶意请求,诱骗已登录的合法用户在不知情的情况下执行非本意的操作。这种攻击方式利用了Web应用程序中用户身份验证的漏洞,即浏览器在用户完成登录后会自…

eclipse导入工程提示Project has no explicit encoding set

eclipse导入工程提示Project has no explicit encoding set 文章目录 eclipse导入工程提示Project has no explicit encoding set一、Eclipse的工程导入二、可能的问题1.在工程名下有黄色叹号 一、Eclipse的工程导入 用Eclipse的导入可以将原有工程导入到新环境中 具体方法是&…

3.C++动态内存管理(超全)

目录 1 .C/C 内存分布 2. C语言中动态内存管理方式&#xff1a;malloc/calloc/realloc/free 3. C内存管理方式 3.1 new/delete操作内置类型 3.2 new和delete操作自定义类型 3.3 operator new函数 3.4 定位new表达式(placement-new) &#xff08;了解&#xff09; 4. 常…

Linux搭建靶场

提前准备&#xff1a; 文章中所使用到的Linux系统&#xff1a;Ubantu20.4sqlilabs靶场下载地址&#xff1a;GitHub - Audi-1/sqli-labs: SQLI labs to test error based, Blind boolean based, Time based. 一. 安装phpstudy phpstudy安装命令&#xff1a;wget -O install.sh h…

增强大模型高效检索:基于LlamaIndex ,构建一个轻量级带有记忆的 ColBERT 检索 Agent

在自然语言处理领域&#xff0c;高效检索相关信息的能力至关重要。将对话式记忆集成到文档检索系统中已经成为增强信息检索代理效果的强大技术。 在文中&#xff0c;我们专为 LlamaIndex 量身定制&#xff0c;将深入探讨构建一个轻量级的带有记忆的 ColBERT 检索代理&#xff…

后端方案设计文档结构模板可参考

文章目录 1 方案设计文档整体结构2 方案详细设计2.1 概要设计2.2 详细设计方案2.2.1 需求分析2.2.2 业务流程设计2.2.3 抽象类&#xff1a;实体对象建模2.2.4 接口设计2.2.5 存储设计 1 方案设计文档整体结构 一&#xff0c;现状&#xff1a;把项目的基本情况和背景都说清楚&a…

iA Writer for Mac:简洁强大的写作软件

在追求高效写作的今天&#xff0c;iA Writer for Mac凭借其简洁而强大的功能&#xff0c;成为了许多作家、记者和学生的首选工具。这款专为Mac用户打造的写作软件&#xff0c;以其独特的设计理念和实用功能&#xff0c;助你轻松打造高质量的文章。 iA Writer for Mac v7.1.2中文…

第2节:UIOTOS前端零代码应用 蓝图连线 信号值变化小示例02

目标 通过连线&#xff0c;实现信号值随机变化。 最终效果 实现过程 步骤1&#xff1a;接11节&#xff0c;选中底板设置其背景颜色 步骤2&#xff1a;拖入普通按钮V2组件&#xff0c;设置“text”值为“1”&#xff0c;并做form绑定 步骤3&#xff1a;选中按钮对输入框进行交…

Java 写一个死锁的例子

public class DeadLock {public static void main(String[] args) {Object lock1 new Object();Object lock2 new Object();new Thread(new A(lock1,lock2),"线程A").start();new Thread(new B(lock1,lock2),"线程B").start();} }class A implements Run…

Microsoft Universal Print 与 SAP 集成教程

引言 从 SAP 环境打印是许多客户的要求。例如数据列表打印、批量打印或标签打印。此类生产和批量打印方案通常使用专用硬件、驱动程序和打印解决方案来解决。 Microsoft Universal Print 是一种基于云的打印解决方案&#xff0c;它允许组织以集中化的方式管理打印机和打印机驱…

2023年蓝桥杯C++A组第三题:更小的数(双指针解法)

题目描述 小蓝有一个长度均为 n 且仅由数字字符 0 ∼ 9 组成的字符串&#xff0c;下标从 0 到 n − 1&#xff0c;你可以将其视作是一个具有 n 位的十进制数字 num&#xff0c;小蓝可以从 num 中选出一段连续的子串并将子串进行反转&#xff0c;最多反转一次。小蓝想要将选出的…

神经网络参数初始化

一、引入 在深度学习和机器学习的世界中&#xff0c;神经网络是构建智能系统的重要基石&#xff0c;参数初始化是神经网络训练过程中的一个重要步骤。在构建神经网络时&#xff0c;我们需要为权重和偏置等参数赋予初始值。对于偏置&#xff0c;通常可以将其初始化为0或者较小的…

【基础算法总结】滑动窗口一

滑动窗口 1.长度最小的字数组2.无重复字符的最长子串3.最大连续1的个数 III4.将 x 减到 0 的最小操作数 点赞&#x1f44d;&#x1f44d;收藏&#x1f31f;&#x1f31f;关注&#x1f496;&#x1f496; 你的支持是对我最大的鼓励&#xff0c;我们一起努力吧!&#x1f603;&…

工作任务管理平台B端实战项目作品集+WebApp项目源文件 figma格式

首先&#xff0c;作品集是什么&#xff1f;通常应该包含什么内容&#xff1f;为什么大家都在做自己的作品集呢&#xff1f; 作品集是个人或公司展示其过往工作成果的集合&#xff0c;通常包括各种专案、作品或成就的范例&#xff0c;用以展示创建者的技能、经验和专业水平。 …

Mysql--创建数据库

一、创建一个数据库 “db_classes” mysql> create database db_classes; mysql> show databases; -------------------- | Database | -------------------- | db_classes | | information_schema | | mysql | | performance_schema | |…