理解transformer

文章目录

  • 1 注意力机制
  • 2 自注意力机制
  • 3 自注意力机制加强版
  • 4 Transformer的结构
    • 4.1 input
    • 4.2 encoder
      • 4.2.1 Multi-head attention
      • 4.2.2 残差链接
      • 4.2.3 层正则化layer norm
      • 4.2.4 前馈神经网络 feed forward network
    • 4.3 decoder
    • 4.3.1 输入
      • 4.3.1 Masked Multi-head attention
      • 4.3.2 Multi-head attention
      • 4.3.3 前馈神经网络 feed forward network
    • 4.4 The Final Linear and Softmax Layer
    • 4.5 decoder总结
    • 4.4 The Loss Function
    • 4.5 代码阅读

1 注意力机制

在LSTM模型中,将encoder最后一个时间步的隐状态作为decoder的输入,这会导致一个问题:会丢失很多前面的信息。毕竟隐状态的维度是有限的,能承载的信息也是有限的。距离越远,丢的信息越多。注意力机制的提出就是用来解决这个问题的。
注意力机制是在seq2seq模型中提出的。注意力机制是seq2seq带给nlp最好的礼物。
在这里插入图片描述
我们希望h1,h2,h3,h4h_1,h_2,h_3,h_4h1,h2,h3,h4都能参与到解码器的计算中。这样就需要给他们分配一个权重,那这个权重怎么学习呢?
我们把编码器的输出变量h1,h2,h3,h4h_1,h_2,h_3,h_4h1,h2,h3,h4称为value。
把解码器的每一步输出hi′h_i'hi称为query。
使用(query,key)对计算权重。key就是h1,h2,h3,h4h_1,h_2,h_3,h_4h1,h2,h3,h4
上一步得到的权重与value(这里仍然是h1,h2,h3,h4h_1,h_2,h_3,h_4h1,h2,h3,h4)加权平均得到注意力输出。

在这里插入图片描述

这样的模式可以推广到所有需要注意力模型的地方。
value通常是上一个阶段的输出,key和value是一样的。
query通常是另外的一个向量。
使用 (query,key) 进行加权求和,得到的值作为权重,再对value进行加权求和。
这样的话在decoder阶段,每一步的输入会有3个变量:编码器经过注意力加权后的输出h(这个时候的query是hi−1′h_{i-1}'hi1),解码器上一步的隐状态hi−1′h_{i-1}'hi1,上一步的输出yi−1y_{i-1}yi1

Bahdanau是注意力机制的一种计算方法,也是现在很多工具包中的实现方法。
在这里插入图片描述

当前处于解码器中的第i步
在解码器中上一步的隐状态si−1s_{i-1}si1,上一步的输出yi−1y_{i-1}yi1,这一步的上下文向量cic_ici
在编码器中的最后输出的隐状态为h1h_1h1,h2h_2h2,h3h_3h3,h4h_4h4

为了计算cic_ici,需要使用注意力机制来解决。
value : h1h_1h1,h2h_2h2,h3h_3h3,h4h_4h4
query : si−1s_{i-1}si1
key : h1h_1h1,h2h_2h2,h3h_3h3,h4h_4h4
我们会使用(query,key)计算value的权重。对于其中第j个权重的计算方式是这样的:
eij=a(si−1,hj)e_{ij}=a(s_{i-1,h_j})eij=a(si1,hj) # 这里是将两个向量拼接
αij=exp(eij)∑j=1Txexp(eik)\alpha_{ij}=\dfrac{exp(e_{ij})}{\sum_{j=1}^{T_x}exp(e_{ik})}αij=j=1Txexp(eik)exp(eij) #这里会保证权重和为1。这里Tx=4{T_x}=4Tx=4
ci=∑j=1Txαijhjc_i=\sum_{j=1}^{T_x}\alpha_{ij}h_jci=j=1Txαijhj #计算得到上下文向量,Tx=4{T_x}=4Tx=4

si=f(si−1,yi−1,ci)s_i=f(s_{i-1},y_{i-1},c_i)si=f(si1,yi1,ci) #得到第i步的隐状态
P(yi∣y1,y2...yi−1,X)=g(yi−1,si,ci)P(y_i|y_1,y_2...y_{i-1},X) = g(y_{i-1},s_i,c_i)P(yiy1,y2...yi1,X)=g(yi1,si,ci) #得到第i步的输出

2 自注意力机制

在上面的介绍中 (query,key)计算得到一个权重。这里query是不同于key的向量。自注意力机制中query是key的一部分。就是说key通过自己注意自己学习到权重。所以称为自注意力机制。
在这里插入图片描述

例如在学习x2x_2x2的权重参数值时候,使用x2x_2x2作为query,x1,x3,x4x_1,x_3,x_4x1,x3,x4作为key和value。
对每个位置都计算得到权重参数,然后加权平均得到y2y_2y2
同理y3,y4,y1y_3,y_4,y_1y3,y4,y1的计算也是一样。
在这里插入图片描述
dkd_kdk是embedding的维度。在归一化之前会对每一个分数除以定值(embedding的维度开根号)。这样可以让softmax的分布更加平滑。

3 自注意力机制加强版

在这里插入图片描述
增强版的自注意力机制是
1 不使用x2x_2x2作为query,而是先对x2x_2x2做线性变换:Wqx2W_qx_2Wqx2,之后的向量作query。
2 x1,x3,x4x_1,x_3,x_4x1,x3,x4不直接作为key和value,而是先做线性变换之后再做key和value。Wkx1W_kx_1Wkx1作为key,Wvx1W_vx_1Wvx1作为value。
其余步骤相同。
这样的模型有更多的参数,模型性能也更加强大。

4 Transformer的结构

以下内容会部分来自于The Illustrated Transformer【译】
了解了注意力机制的变迁之后,我们再来看transformer结构。Transformer是在"Attention is All You Need"中提出的。这是一篇刷爆朋友圈的论文。因为它的效果基于现有效果有了较大幅度的提升。
transformer与之前一些结构的不同在于:

  • 双向LSTM:一个模型想要包含当前位置的信息,前一个位置的信息,后一个位置的信息
  • CNN:一个位置包含的信息取决于kernel size大小
  • transformer:可以得到全局信息
    transformer 是由input、encoder、decoder和output四部分组成的。

在这里插入图片描述
在这里插入图片描述

encoder组件由6层首尾相连的encoder组成。decoder组件是由6层decoder组成。
在这里插入图片描述

4.1 input

transformer模型的输入由词向量以及位置编码两部分组成。
词向量是使用word-piece。数据集是英-法 WMT 2014,包含36M 句子,这些句子被分为 32000 word-piece 词汇。每个词汇使用dmodel=512d_{model}=512dmodel=512来表示。

每个位置都定义了一个encoding。 在transformer中一直在做加权平均,没有前后顺序,这就会成为bag of words。

在这里有些位置用sin,有些位置用cos,表示位置信息。每个位置的encoding是什么样子并不重要。重要的是每个位置的encoding不一样

位置信息encoding之后 与 词向量相加,也就是 embed(word) + embed(position),整体作为输入送入到encoder。embed(position)的位置也是512。
在这里插入图片描述

在这里插入图片描述

按照偶数位sin,奇数位cos的方式,得到的结果确实是i,j越接近,pm.pnp_m.p_npm.pn越大。相对位置越远,点乘的结果越⼩。
在这里插入图片描述

4.2 encoder

6个encoder结构完全相同,但是参数不共享。
在这里插入图片描述

4.2.1 Multi-head attention

多头注意力机制是transformer模型中的重要改进。这个模型使用的是自注意力机制加强版。这部分内容在前面已经介绍了。这里重点介绍一下Multi-head。

在这里插入图片描述

不是对输入做一个Attention,而是需要做多个Attention。
假如每个单词512维度,这里有h个scaled dot-product attention。每一套可以并行计算。 Q K V 做了不同的affine变换,投射到不同的空间,得到不同的维度,也就是WX+b变换。不同head的长度一样,但是映射参数是不一样的。
之后过一个scaled dot-product attention。
h个结果concat
然后再做Linear
论文中h=8,dk=dv=dmodel/h=64d_k=d_v=d_{model/h}=64dk=dv=dmodel/h=64dmodel=512d_{model}=512dmodel=512
做Attention,Q K V 形状是不会发生变化的,每个的形状还是 seq_length,x,hidden_size。

公式如下:
输入X,包含token embedding和position embedding

  • 对X做变换

Qi=QWiQQ^i=QW^Q_iQi=QWiQ,Ki=KWiKK^i=KW^K_iKi=KWiK,Vi=VWiVV^i=VW^V_iVi=VWiV
每一次映射不共享参数,每一次映射会有(WiQ,WiK,WiV)W^Q_i,W^K_i,W^V_i)WiQ,WiK,WiV)三个参数。

  • 对多头中的某一组做attention

Attention(Qi,Ki,Vi)=KiQiTdkViAttention(Q_i,K_i,V_i)=\dfrac{K_iQ_i^T}{\sqrt{d_k}}V_iAttention(Qi,Ki,Vi)=dkKiQiTVi

headi=Attention(Qi,Ki,Vi)head_i=Attention(Q_i,K_i,V_i)headi=Attention(Qi,Ki,Vi)

h组并行计算

  • 拼接之后输出

MultiHead(Q,K,V)=Concat(head1,head2,...head5)MultiHead(Q,K,V)=Concat(head_1,head_2,...head_5)MultiHead(Q,K,V)=Concat(head1,head2,...head5)

经过multi-head之后,得到h1,h2,h3,h4h_1,h_2,h_3,h_4h1,h2,h3,h4
看图上怎么还有一个Linear???

4.2.2 残差链接

残差链接是这样的。
将输入x加到multi-head或者feed network的输出h上。这样可以加快训练。
这一步得到的结果记为h1′,h2′,h3′,h4′h_1',h_2',h_3',h_4'h1,h2,h3,h4

4.2.3 层正则化layer norm

层正则化,是对残差链接的结果做正则化。

h1′,h2′,h3′,h4′h_1',h_2',h_3',h_4'h1,h2,h3,h4这4个向量分别计算每个向量的均值μ\muμ和方差σ\sigmaσ
在这里插入图片描述
γ\gammaγβ\betaβ是共享的参数,在模型中需要训练。
γ\gammaγβ\betaβ可以在一定程度上抵消掉正则的操作。为什么正则了又要抵消呢?
这样做可以让每一个时间步的值更平均一些,差异不会特别大。
这一步的输出是h1′′,h2′′,h3′′,h4′′h_1'',h_2'',h_3'',h_4''h1,h2,h3,h4

4.2.4 前馈神经网络 feed forward network

对于上一步的结果加一个前馈神经网络。
FFN(x)=max(0,xW1+b1)W2+b2FFN(x) = max(0, xW_1 + b_1 )W_2 + b_2FFN(x)=max(0,xW1+b1)W2+b2
在每一个时间步会做一个y=F(x)的变化,得到另外的100维的向量。
对这一步的结果再加一个残差链接和层正则化。

这样就得到一个transformer block。
在这里插入图片描述
输入->Multi head attention ->残差链接->层正则化->Feed-forward Network->残差链接->层正则化。
在实际使用过程中层正则化会放在Multi head attention或者Feed-forward Network-前面。

4.3 decoder

在这里插入图片描述

decoder组件是由多个decoder组成的。在本模型中是6个decoder。
每一个decoder是由Masked Multi-head attention, Multi-head attention以及Feed-Forward三部分组成。

4.3.1 输入

在decoder中的第一步输入是始位置表示以及 encoder组件的输出:K和V。经过decoder之后,输出第一个单词I。第二步的输入是第一步的输出,以及K和V。

4.3.1 Masked Multi-head attention

Masked Multi-head attention的输入是decoder前一步的输出。第一个位置为起始位置表示。
接着加上位置编码。整体作为输入送入 Masked Multi-head attention。
在这里,处理第i步的时候,只能使用第1步到第i-1的向量做attention。这就是Masked含义。i位置之后的信息不可见。
之后做残差链接。
再之后做层正则化。将结果送入Multi-head attention。

4.3.2 Multi-head attention

Multi-head attention这一部分与encoder的Multi-head attention相同。输入是encoder组件的输出K和V,以及Masked Multi-head attention的输出,三部分作为输入。
经过Multi-head attention->残差链接->层正则化,输出。

4.3.3 前馈神经网络 feed forward network

上一步的输出,经过前馈神经网络的结果作为输出。

至此,一个decoder完成。其输出作为下一个decoder的输入。

4.4 The Final Linear and Softmax Layer

解码器最后输出浮点向量,如何将它转成词?这是最后的线性层和softmax层的主要工作。
线性层是个简单的全连接层,将解码器的最后输出映射到一个非常大的logits向量上。假设模型已知有1万个单词(输出的词表)从训练集中学习得到。那么,logits向量就有1万维,每个值表示是某个词的可能倾向值。
softmax层将这些分数转换成概率值(都是正值,且加和为1),最高值对应的维上的词就是这一步的输出单词。

4.5 decoder总结

encoder组件从输入序列的处理开始,最后的encoder组件的输出被转换为K和V,它俩被每个解码器的"encoder-decoder atttention"层来使用,帮助解码器集中于输入序列的合适位置。

下面的步骤一直重复直到一个特殊符号出现表示解码器完成了翻译输出。每一步的输出被喂到下一个解码器中。正如编码器的输入所做的处理,对解码器的输入增加位置向量。

4.4 The Loss Function

如何对比两个概率分布呢?简单采用 cross-entropy或者Kullback-Leibler divergence中的一种。

4.5 代码阅读

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/423960.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第一百三十七期:一个简单的小案例带你理解MySQL中的事务

事务又叫做TCL,全称是transaction control language,意思是事务控制语言。 作者:Java的架构师技术栈 事务又叫做TCL,全称是transaction control language,意思是事务控制语言。这篇文章还是属于我的mysql基础文章&…

解决ffmpeg拉流转发频繁丢包问题max delay reached. need to consume packet

软件: 1、流媒体服务器EasyDarwin-windows-8.1.0-1901141151 2、ffmpeg-20181001-dcbd89e-win64-static 3、直播源:rtsp://192.168.1.168/0 4、流媒体服务器EasyDarwin地址rtsp://192.168.1.28/3 问题现象 [rtsp 0000000000122bc0] max delay reached. …

第一百三十八期:37 个MySQL数据库小知识,为面试做准备

无论是运维、开发、测试,还是架构师,数据库技术是一个必备加薪神器,那么,一直说学习数据库、学MySQL,到底是要学习它的哪些东西呢? 作者:芒果教你学编程 无论是运维、开发、测试,还是架构师&…

第一百三十九期:11月数据库排行:排名前三数据库分数暴跌

DB-Engines 数据库流行度排行榜 11 月更新已发布,与上期数据相比,这期排行榜最大的亮点就是排名前三数据库那引人注目的“红色”分数。 作者:局长 DB-Engines 数据库流行度排行榜 11 月更新已发布,排名前二十如下: ▲…

对话系统之NLU总结报告

文章目录1 项目介绍1.1 背景知识介绍1.2 数据集介绍1.3 评价指标2 技术方案梳理2.1 模型目标2.2 模型介绍2.3 模型实现2.3.1 数据处理2.3.2 构建dataset2.3.3 模型定义2.3.4 训练相关参数2.3.5 训练结果3 项目总结1 项目介绍 1.1 背景知识介绍 对话系统按领域分类&#xff0c…

闲聊型对话系统之NLG总结报告

文章目录1 项目介绍1.1 背景知识介绍1.2 NLG的实现方式1.2.1 基于模板1.2.2 检索式1.2.3 生成式1.3 数据集介绍2 技术方案梳理2.1 模型介绍2.2 评价指标2.3 模型实现2.3.1 数据处理2.3.2 构建dataset2.3.3 模型定义2.3.4 训练相关参数2.3.5 训练结果1 项目介绍 1.1 背景知识介…

spring mvc学习(50):java.lang.ClassNotFoundException: org.springframework.web.servlet. DispatcherSe

今天朋友发了个maven项目给我看,问我为什么启动不了。说实话,一直用Jfinal都快不会用spring了… 还是决定看看。 接收了文件,是maven构建的,打开eclipse,导入maven项目,然后部署到tomcat,启动t…

Luogu2439 [SDOI2005]阶梯教室设备利用 (动态规划)

同上一题&#xff0c;区间改左闭右开就双倍经验了。貌似可以跑最长路。 #include <iostream> #include <cstdio> #include <cstring> #include <algorithm> #include <cmath> #define R(a,b,c) for(register int a (b); a < (c); a) #defi…

spring mvc学习(51):jsonp

引入jar包 pom.xml <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation"http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">…

什么是word2vector

原文地址&#xff1a;https://www.julyedu.com/questions/interview-detail?quesId2761&cateNLP&kp_id30 什么是 Word2vec? 在聊 Word2vec 之前&#xff0c;先聊聊 NLP (自然语言处理)。NLP 里面&#xff0c;最细粒度的是 词语&#xff0c;词语组成句子&#xff0c…

spring mvc学习(52):json数据类型提交

引入jar包 pom.xml <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation"http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">…

关注:Eclipse,转贴eclipse CDT的开发文章

致谢&#xff1a;Tinyfool的鼎立相助&#xff01; cdt是在eclipse中编写C程序的插件&#xff0c;虽然还不是很完美&#xff0c;但是是在windows中编写linux下C程序&#xff08;GNU C&#xff09;的一个好途径。按照eclipse的官方网站的要求&#xff0c;要下载如下的东东…

第三课 SVM(2)

1 线性可分的数据集 1.1 超平面 SVM的思想是找到最大间隔的分隔超平面。 在两个分类中&#xff0c;找到能够一条线&#xff0c;以最好地区分这两个分类。这样如果有了新的点&#xff0c;这条线也能很好地做出分类。 这样的线在高维样本中的时候就叫做超平面。 1.2 几何间隔与…

《C Traps and Pitfalls》 笔记

这本书短短的100多页&#xff0c;很象是一篇文章。但是指出的很多问题的确容易出现在笔试的改错题中--------------------------------------------------------------------第1章 词法陷阱1.1 和 1.3 词法分析的"贪心法则"编译器从左到右读入字符&#xff0c;每个符…

spring mvc学习(53):回顾和springmvc返回值类型总结

媒体类型 MIME媒体类型&#xff08;简称MIME类型&#xff09;是描述报文实体主体内容的一些标准化名称&#xff08;比如&#xff0c;text/html、image/jpeg&#xff09;。 因特网有数千种不同的数据类型&#xff0c;HTTP仔细地给每种要通过web传输的对象都打上了名为MIME类型的…

2019hdu多校1

1009 考虑贪心&#xff0c;暴力枚举一位。 $o(676n)$ #include<bits/stdc.h> using namespace std; const int N1e5333; int n,m,zl; int pos[26],cnt[N],t[26],az[N]; char s[N],st[N]; int l[N],r[N],nx[N],zzq[26]; int main(){ios::sync_with_stdio(0);//freopen(&qu…

关于梅花雪的js树

最近一段时间&#xff0c;为了学习java&#xff0c;天天在看别人的框架&#xff0c;为了实现一颗树&#xff0c;找到了一个改写梅花雪的js&#xff0c;下面是一个基本的结构<% page language"java" import"java.util.*" pageEncoding"GBK"%&g…

总和最大区间问题

题目和解题思路来源于吴军著作《计算之魂》。本题目是例题1.3。 文章目录1 问题描述2 解题思路2.1 三重循环2.2 两重循环2.3 分治法2.4 正反两遍扫描的方法2.5 再进一步&#xff0c;假设失效3 应用动态规划1 问题描述 总和最大区间问题&#xff1a;给定一个实数序列&#xff0…

spring mvc学习(54):简单异常处理

引入jar包 pom.xml <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation"http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">…

spring mvc学习(55):简单异常处理二

引入jar包 pom.xml <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation"http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">…