【深度学习】Transformer简介

 

近年来,Transformer模型在自然语言处理(NLP)领域中横扫千军,以BERT、GPT为代表的模型屡屡屠榜,目前已经成为了该领域的标准模型。同时,在计算机视觉等领域中,Transformer模型也逐渐得到了重视,越来越多的研究工作开始将这类模型引入到算法中。本文基于2017年Google发表的论文,介绍Transformer模型的原理。

 

一、为什么要引入Transformer?

最早提出的Transformer模型[1]针对的是自然语言翻译任务。在自然语言翻译任务中,既需要理解每个单词的含义,也需要利用单词的前后顺序关系。常用的自然语言模型是循环神经网络(Recurrent Neural Network,RNN)和卷积神经网络(Convolutional Neural Network,CNN)。

其中,循环神经网络模型每次读入一个单词,并基于节点当前的隐含状态和输入的单词,更新节点的隐含状态。从上述过程来看,循环神经网络在处理一个句子的时候,只能一个单词一个单词按顺序处理,必须要处理完前边的单词才能开始处理后边的单词,因此循环神经网络的计算都是串行化的,模型训练、模型推理的时间都会比较长。

另一方面,卷积神经网络把整个句子看成一个1*D维的向量(其中D是每个单词的特征的维度),通过一维的卷积对句子进行处理。在卷积神经网络中,通过堆叠卷积层,逐渐增加每一层卷积层的感受野大小,从而实现对上下文的利用。由于卷积神经网络对句子中的每一块并不加以区分,可以并行处理句子中的每一块,因此在计算时,可以很方便地将每一层的计算过程并行化,计算效率高于循环神经网络。但是卷积神经网络模型中,为了建立两个单词之间的关联,所需的网络深度与单词在句子中的距离正相关,因此通过卷积神经网络模型学习句子中长距离的关联关系的难度很大。

Transformer模型的提出就是为了解决上述两个问题:(1)可以高效计算;(2)可以准确学习到句子中长距离的关联关系。

 

二、Transformer模型介绍

如下图所示,Transformer模型采用经典的encoder-decoder结构。其中,待翻译的句子作为encoder的输入,经过encoder编码后,再输入到decoder中;decoder除了接收encoder的输出外,还需要当前step之前已经得到的输出单词;整个模型的最终输出是翻译的句子中下一个单词的概率。

【论文阅读】Transformer简介

现有方法中,encoder和decoder通常都是通过多层循环神经网络或卷积实现,而Transformer中则提出了一种新的、完全基于注意力的网络layer,用来替代现有的模块,如下图所示。图中encoder、decoder的结构类似,都是由一种模块堆叠N次构成的,但是encoder和decoder中使用的模块有一定的区别。具体来说,encoder中的基本模块包含多头注意力操作(Multi-Head Attention)、多层感知机(Feed Forward)两部分;decoder中的基本模块包含2个不同的多头注意力操作(Masked Multi-Head Attention和Multi-Head Attention)、多层感知机(Feed Forward)三部分。

【论文阅读】Transformer简介

在上述这些操作中,最核心的部分是三种不同的Multi-Head Attention操作,该操作的过程如下图所示,可以简单理解为对输入feature的一种变换,通过特征之间的关系(attention),增强或减弱特征中不同维度的强度。模型中使用的三种注意力模块如下:

  • Encoder中的Multi-Head Attention:encoder中的multi-head attention的输入只包含编码器中上一个基本模块的输出,使用上一个基本模块的输出计算注意力,并调整上一个基本模块的输出,因此是一种“自注意力”机制;
  • Decoder中的Masked Multi-Head Attention:Transformer中,decoder的输入是完整的目标句子,为了避免模型利用还没有处理到的单词,因此在decoder的基础模块中,在“自注意力”机制中加入了mask,从而屏蔽掉不应该被模型利用的信息;
  • Decoder中的Multi-Head Attention:decoder中,除了自注意力外,还要利用encoder的输出信息才能正确进行文本翻译,因此decoder中相比encoder多使用了一个multi-head attention来融合输入语句和已经翻译出来的句子的信息。这个multi-head attention结合使用decoder中前一层“自注意力”的输出和encoder的输出计算注意力,然后对encoder的输出进行变换,以变换后的encoder输出作为输出结果,相当于根据当前的翻译结果和原始的句子来确定后续应该关注的单词。

【论文阅读】Transformer简介

除核心的Multi-Head Attention操作外,作者还采用了位置编码、残差连接、层归一化、dropout等操作将输入、注意力、多层感知机连接起来,从而构成了完整的Transformer模型。通过修改encoder和decoder中堆叠的基本模块数量、多层感知机节点数、Multi-Head Attention中的head数量等参数,即可得到BERT、GPT-3等不同的模型结构。

 

三、实验效果

实验中,作者在newstest2013和newstest2014上训练模型,并测试了模型在英语-德语、英语-法语之间的翻译精度。实验结果显示,Transformer模型达到了State-of-the-art精度,并且在训练开销上比已有方法低一到两个数量级,展现出了该方法的优越性。

与已有方法的对比实验,显示出更高的BLEU得分和更低的计算开销:

【论文阅读】Transformer简介

模块有效性验证,模型中每个单次的特征维度、多头注意力中头的数量、基本模块堆叠数量等参数对模型的精度有明显的影响:

【论文阅读】Transformer简介

参考文献

[1] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, Illia Polosukhin. Attention Is All You Need. NIPS 2017.

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/162830.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【PythonGIS】基于Python面矢量转换线矢量

今天有些不一样,发这篇文章并不是项目需要。单纯的想到有这个功能没使用Python实现,所以就去研究了一下,第一时间就和大家分享。如何使用Python的osgeo库实现面矢量数据与线矢量数据的互相转换。 一、导入所需库 import os from osgeo impor…

论文速读《DeepFusion: Lidar-Camera Deep Fusion for Multi-Modal 3D Object Detection》

概括主要内容 文章《DeepFusion: Lidar-Camera Deep Fusion for Multi-Modal 3D Object Detection》提出了两种创新技术,以改善多模态3D检测模型的性能,通过更有效地融合相机和激光雷达传感器数据来提高对象检测的准确性,尤其是在行人检测方面…

自动化提交git

1.前要 这里只是讲解如何在Windows上创建自动化脚本/程序来达到自动pull、commit、push,减少冗余的仓库更新工作,避免在多平台下合作造成版本冲突等。 2.原理 使用Windows下默认的cmd/bat脚本编写代码。 只需要在网络上查询一些相关的语法&#xff0…

2023亚太杯数学建模C题思路 - 我国新能源电动汽车的发展趋势

1 赛题 问题C 我国新能源电动汽车的发展趋势 新能源汽车是指以先进技术原理、新技术、新结构的非常规汽车燃料为动力来源( 非常规汽车燃料指汽油、柴油以外的燃料),将先进技术进行汽车动力控制和驱动相结 合的汽车。新能源汽车主要包括四种类型&#x…

【计算思维】蓝桥杯STEMA 科技素养考试真题及解析 6

1、明明买了一个扫地机器人,可以通过以下指令控制机器人运动: F:向前走 10 个单位长度 L:原地左转 90 度 R:原地右转 90 度 机器人初始方向向右,需要按顺序执行以下那条指令,才能打扫完下图中的道路 A、F-L-F-R-F-F-R-F-L-F B、F-R-F-L-F-F…

h5如何使用navigateBack回退到微信小程序页面并携带参数

前言 在h5中使用navigateBack回退到微信小程序页面很常见,但是有一种交互需要在回退之后的页面可以得到通知,拿到标识之后,进行某些操作,这样的话,由于微信官方并没有直接提供这样的api,就需要我们开动脑筋…

视频剪辑有妙招:批量置入封面,轻松提升视频效果

随着社交媒体的兴起,视频已经成为分享和交流的重要方式。无论是专业的内容创作者还是普通的社交媒体用户,都要在视频剪辑上下一番功夫,才能让视频更具吸引力。而一个吸引的封面往往能在一瞬间抓住眼球,提高点击率。还在因如何选择…

【SpringBoot】Redisson 分布式锁注解和 @Transactional 注解一起使用问题

一、前言 平时使用切面去加分布式锁,是先开启事务还是先尝试获得锁?这两者有啥区别? 业务中怎么控制切面的顺序?切面的顺序对事务的影响怎么避免? 下面程序分析: OverrideTransactionalpublic ReceiveH5…

uni-app - 弹出框

目录 1.基本介绍 2.原生uinapp 通过uni.showActionSheet实现 3.使用组件 Popup 弹出层 ③效果展示 1.基本介绍 弹出框让我们在需要时在屏幕底部弹出一个菜单,它通常用于在各种应用程序中进行选择操作。Uniapp为我们提供了基本的底部弹出框组件,但它也有…

OpenSearch开发环境安装Docker和Docker-Compose两种方式

文章目录 简介常用请求创建映射写入数据查询数据其他 安装Docker方式安装OpenSearch安装OpenSearchDashboard Docker-Compose方式Docker-Compose安装1.设置主机环境2.下载docker-compose.yml文件3.启动docker-compose4.验证 问题问题1:IPv4 forwarding is disabled.…

如何搭建Zblog网站并通过内网穿透将个人博客发布到公网

文章目录 1. 前言2. Z-blog网站搭建2.1 XAMPP环境设置2.2 Z-blog安装2.3 Z-blog网页测试2.4 Cpolar安装和注册 3. 本地网页发布3.1. Cpolar云端设置3.2 Cpolar本地设置 4. 公网访问测试5. 结语 1. 前言 想要成为一个合格的技术宅或程序员,自己搭建网站制作网页是绕…

Altium Designer学习笔记11

画一个LED的封装: 使用这个SMD5050的封装。 我们先看下这个芯片的功能说明: 5050贴片式发光二极管: XL-5050 是单线传输的三通道LED驱动控制芯片,采用的是单极性归零码协议。 数据再生模块的功能,自动将级联输出的数…

CSGO搬砖干货,全网最详细教学!

CSGO游戏搬砖全套操作流程及注意事项(第一课) 在电竞游戏中,CSGO(Counter-Strike: Global Offensive)被广大玩家誉为经典之作。然而,除了在游戏中展现个人实力和团队合作外,有些玩家还将CSGO作为…

Java之API(上)

前言: 这一次内容主要是围绕Java开发中的一些常用类,然后主要是去学习这些类里面的方法。 一、高级API: (1)介绍:API指的是应用程序编程接口,API可以让编程变得更加方便简单。Java也提供了大量API供程序开发者使用&…

如何使用Google My Business来提升您的内容和SEO?

如果您的企业有实体店,那么使用Google My Business(GMB)来改善您的本地SEO并增强您的在线形象至关重要。Google My Business (GMB) 是 Google 提供的补充工具,使企业能够控制其在 Google 搜索和地图上的数字…

大数据基础设施搭建 - Flume

文章目录 一、上传压缩包二、解压压缩包三、监控本地文件(file to kafka)3.1 编写配置文件3.2 自定义拦截器3.2.1 开发拦截器jar包(1)创建maven项目(2)开发拦截器类(3)开发pom文件&a…

【数字化转型方法论读书笔记】-数据中台角色解读

一千个读者,就有一千个哈姆雷特。同样,数据中台对于企业内部不同角色的价值也不同,下面分别从董事长、CEO、 CTO/CIO、IT 架构师、数据分析师这 5 个角色的视角详细解读数据中台。 1、董事长视角下的数据中台 在数字经济时代,企业…

RTT打印在分区跳转后无法打印问题

场景: RTT打印仅占用JLINK的带宽,比串口传输更快更简洁,同时RTT可以使用jscope对代码里面的变量实时绘图显示波形,而采用串口打印波形无法实时打印。同时可以保存原始数据到本地进行分析,RTT在各方面完胜串口。 问题描…

PTA-城市间紧急救援

作为一个城市的应急救援队伍的负责人,你有一张特殊的全国地图。在地图上显示有多个分散的城市和一些连接城市的快速道路。每个城市的救援队数量和每一条连接两个城市的快速道路长度都标在地图上。当其他城市有紧急求助电话给你的时候,你的任务是带领你的…

采样概率 假设检验推导数组最大值的方法与可行性

当需要寻找大量数据中的最大值的时候,比如从 2G 个 float16 中寻找其中的最大值,是一件耗时的操作。 现计划通过小样本来发掘数据的规律,对最大值进行预测。 方案: step1,从2G个float16 中截取64段float16&#xff…