Transformer学习理解

1.前言        

        本文介绍当下人工智能领域的基石与核心结构模型——Transformer,为什么说它是基石,因为以ChatGPT为代表的聊 天机器人以及各种有望通向AGI(通用人工智能)的道路上均在采用的Transformer。 Transformer也是当下NLP任务的底座,包括后续的BERT和GPT,都是Transformer架构,BERT主要由Transformer的 encoder构成,GPT则主要是decoder构成。

        本文将会通读Transformer原论文《Attention is all you need》,总结Transformer结构以及文章要点,然后采用pytorch 代码进行机器翻译实验,通过代码进一步了解Transformer在应用过程中的步骤。

2.论文阅读笔记

        Transformer的论文是《Attention is all you need》(https://arxiv.org/abs/1706.03762),由Google团队在2017提出的 一种针对机器翻译的模型结构,后续在各类NLP任务上均获取SOTA。

 Motivation:针对RNN存在的计算复杂、无法串行的问题,提出仅通过attention机制的简洁结构——Transformer,在多个序列任务、机器翻译中获得SOTA,并有运用到其他数据模态的潜力。

 模型结构: 模型由encoder和decoder两部分构成,分别采用了6个block堆叠, encoder的block有两层,multi-head attention和FC层 decoder的block有三层,处理自回归的输入masked multi-head attention,处理encoder的attention,FC层。

注意力机制:scale dot-production attention,采用QK矩阵乘法缩放后softmax充当权重,再与value进行乘法。 多头注意力机制:实验发现多个头效果好,block的输出,把多个头向量进行concat,然后加一个FC层。因此每个头的向量长度是总长度/头数,例:512/8=64,每个头是64维向量。

Transformer的三种注意力层:

(1)encoder:输入来自上一层的全部输出

(2)decoder-输入:为避免模型未卜先知,只允许看见第i步之前的信息,需要做mask操作,确保在生成序列的每个 元素时,只考虑该元素之前的元素。这里通过softmax位置设置负无穷来控制无效的连接。

(3)decoder-第二个attention:q来自上一步输出,k和v来自encoer的输出。这样解码器可以看到输入的所有序列。

FFN层:attention之后接入两个FC层,第一个FC层采用2048,第二个是512,第一个FC层采用max(0, x)作为激活函数。

embedding层的缩放:在embedding层进行尺度缩放,乘以根号d_model。

位置编码:采用正余弦函数构建位置向量,然后采用加法,融入embedding中。

实验:两个任务,450万句子(英-德),3600万句子(英-法),8*16GB显卡,分别训练12小时,84小时。10万step 时,采用4kstep预热;采用标签平滑0.1。

重点句子摘录

1. In the Transformer this is reduced to a constant number of operations, albeit at the cost of reduced effective resolution due to averaging attention-weighted positions, an effect we counteract with Multi-Head Attention as described in section 3.2.

        由于注意力机制最后的加权平均可能会序列中各位置的细粒度捕捉不足,因此引入多头注意力机制。 这里是官方对多头注意力机制引入的解释。

2. We suspect that for large values of dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients 4. To counteract this effect, we scale the dot products by 根号 dk

        在Q和K做点积时,若向量维度过长,会导致点积结果过大,再经softmax映射后,梯度会变小,不利于模型学习,因此 需要进行缩放。缩放因子为除以根号dk。

多头注意力机制的三种情况:

        i. encoder:输入来自上一层的全部输出。

        ii. decoder-输入:为避免模型未卜先知,只允许看见第i步之前的信息,需要做mask操作,确保在生成序列的每个 元素时,只考虑该元素之前的元素。这里通过softmax位置设置负无穷来控制无效的连接。

        iii. decoder-第二个attention:q来自上一步输出,k和v来自encoer的输出。这样解码器可以看到输入的所有序列。

首次阅读遗留问题

        1. 位置编码是否重新学习? 不重新学习,详见 PositionalEncoding的         self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

        2. qkv具体实现过程

                i. 通过3个w获得3个特征向量:

                ii. attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) # q.shape == (bs, n_head, len_seq, d_k/n_head) ,每个token用 一个向量表示,总向量长度是头数*每个头的向量长度。

                iii. attn = self.dropout(F.softmax(attn, dim=-1)) iv. output = torch.matmul(attn, v)

        3. decoder输入的处理细节

                 i. 训练阶段,无特殊处理,一个样本可以直接输入,根据下三角mask避免未卜先知。

                ii. 推理阶段,首先手动执行token的输入,然后for循环直至最大长度,期间输出的token拼接到输出列表中,并作为下一步decoder的输入。选择输出token时还采用了beam search来有效地平衡广度和深度。

3.数据集构建

        数据下载

        本案例数据集Tatoeba下载自这里。该项目是帮助不同语言的人学习英语,因此是英语与其它几十种语言的翻译文本。 其中就包括本案例使用的英中文本,共计29668条(Mandarin Chinese - English cmn-eng.zip (29668)) 数据以txt形式存储,一行是一对翻译文本,例如长这样:

        1.That mountain is easy to climb. 那座山很容易爬。

        2.It's too soon. 太早了。

        3.Japan is smaller than Canada. 日本比加拿大小。

        数据集划分

         对于29668条数据进行8:2划分为训练、验证,这里采用配套代码a_data_split.py进行划分,即可在统计目录下获得 train.txt和text.txt。

4.词表构建

        文本任务首要任务是为文本构建词表,这里采用与上节一样的方法,首先对文本进行分词,然后统计语料库中所有的 词,最后根据最大上限、最小词频等约束,构建词表。本部分配套代码是b_gen_vocabulary.py

        词表的构建过程中,涉及两个知识点:中文分词和特殊token。

        1. 中文分词

        对于英文,分词可以直接采用空格。而对于中文,就需要用特定的分词方法,这里采用的是jieba分词工具,以下是英文 和中文的分词代码。

        source.append(parts[0].split(' '))

        target.append(list(jieba.cut(parts[1]))) # 分词

        2. 特殊token

         由于seq2seq任务的特殊性,在解码器部分,通常需要一个token告诉模型,现在是开始,同时还需要有个token让模型 输出,以此告诉人类,模型输出完毕,不要再继续生成了。

         因此相较于文本分类,还多了bos, eos,两个特殊token,有的时候,开始token也会用start表示。

        PAD_TAG = "" # 用PAD补全句子长度

         BOS_TAG = "" # 用BOS表示开始

        EOS_TAG = "" # 用EOS表示结束

        UNK_TAG = "" # 用EOS表示结束

        PAD = 0 # PAD字符对应的数字

        BOS = 1 # BOS字符对应的数字

         EOS = 2 # EOS字符对应的数字

         UNK = 3 # UNK字符对应的数字

        运行代码后,词表字典保存到了result目录下,并得到如下输出,表明英文中有2518个词,中文有3365,但经过最大长 度3000的截断后,只剩下2996,另外4个是特殊token。

        100%|██████████| 23635/23635 [00:00<00:00, 732978.24it/s]

        原始词表长度:2518,截断后长度:2518

        2522

        保存词频统计图:vocab_en.npy_word_freq.jpg

        100%|██████████| 23635/23635 [00:00<00:00, 587040.62it/s]

        保存统计图:vocab_en.npy_length_freq.jpg

        原始词表长度:3365,截断后长度:2996

        3000

5.Dataset编写

        NMTDataset的编写逻辑与上一小节的Dataset类似,首先在类初始化的时候加载原始数据,并进行分词;在getitem迭代 时,再进行token转index操作,这里会涉及增加结束符、填充符、未知符。 核心代码如下:

def __init__(self, path_txt, vocab_path_en, vocab_path_fra, max_len=32):self.path_txt = path_txtself.vocab_path_en = vocab_path_enself.vocab_path_fra = vocab_path_fraself.max_len = max_lenself.word2index = WordToIndex()self._init_vocab()self._get_file_info()def __getitem__(self, item):
# 获取切分好的句子list,一个元素是一个词sentence_src, sentence_trg = self.source_list[item], self.target_list[item]
# 进行填充, 增加结束符,索引转换token_idx_src = self.word2index.encode(sentence_src, self.vocab_en, self.max_len)token_idx_trg = self.word2index.encode(sentence_trg, self.vocab_fra, self.max_len)str_len, trg_len = len(sentence_src) + 1, len(sentence_trg) + 1 # 有效长度, +1是填充的结束符 <eos>.return np.array(token_idx_src, dtype=np.int64), str_len, np.array(token_idx_trg,         dtype=np.int64), trg_lendef _get_file_info(self):text_raw = read_data_nmt(self.path_txt)text_clean = text_preprocess(text_raw)self.source_list, self.target_list = text_split(text_clean)

6.模型构建

        Transformer代码梳理如下图所示,大体可分为三个层级

        1. Transformer的构建,包含encoder、decoder两个模块,以及两个mask构建函数。

        2. 两个coder内部实现,包括位置编码、堆叠的block。

        3. block实现,包含多头注意力、FFN,其中多头注意力将softmax(QK)*V拆分为ScaledDotProductAttention类。

        代码整体与论文中保持一致,总结几个不同之处:

        1. layernorm使用时机前置到attention层和FFN层之前

        2. position embedding 的序列长度默认采用了200,如果需要更长的序列,则要注意配置。 具体代码实现不再赘述,把论文中图2的结果熟悉,并掌握上面的代码结构,可以快速理解各模块、组件的运算和操作步 骤,如有疑问的点,再打开代码观察具体运算过程即可。

7.模型训练

        原论文进行两个数据集的机器翻译任务,采用的数据和超参数列举如下,供参考。

        英语-德语,450万句子对,英语-法语,3600万句子对。均进行base/big两种尺寸训练,分别进行10万step和30万step训 练,耗时12小时/84小时。10万step时,采用4千step进行warmup。正则化方面采用了dropout=0.1的残差连接,0.1的标 签平滑。

        本实验中有2.3万句子对训练,只能作为Transformer的学习,性能指标仅为参考,具体任务后续由BERT、GPT、T5来 实现更为具体的项目。 采用配套代码train_transformer.py,执行训练即可:

        采用配套代码train_transformer.py,执行训练即可:

python train_transformer.py -embs_share_weight -proj_share_weight -label_smoothing -b 256 -warmup 128000 -epoch 400

        训练完成后,在result文件夹下会得到日志与模型,接下来采用配套代码c_train_curve_plot.py

        对日志数据进行绘图可视化,Loss和Accuracy分别如下,可以看出模型拟合能力非常强,性能还在提高,但受限于数据量过少,模型在200个 epoch之后就已经出现了过拟合。

        这里面的评估指标用的acc,具体是直接复用github项目,也能作为模型性能的评估指标,就没有去修改为BLUE。

8.模型推理

        Transformer的推理过程与训练时不同,值得仔细学习。

        Transformer的推理输出是典型的自回归(Auto regressive),并且需要根据多个条件综合判断何时停止,因此推理部分的逻辑值得认真学习,具体步骤如下:

        第一步:输入序列经encoder,获得特征,每个token输出一个向量,这个特征会在decoder的每个step都用到,即 decoder中各block的第二个multi-head attention。需要enc_output去计算k和v,用decoder上一层输出特征向量去计算 q,以此进行decoder的第二个attention。

        enc_output, *_ = self.model.encoder(src_seq, src_mask)

        第二步:手动执行decoder第一步,输入是这个token,输出的是一个概率向量,由分类概率向量再决定第一个输出 token。

        self.register_buffer('init_seq', torch.LongTensor([[trg_bos_idx]]))

        dec_output =  self._model_decode(self.init_seq, enc_output, src_mask)

        第三步:循环max_length次执行decoder剩余步。第i步时,将前序所有步输出的token,组装为一个序列,输入到 decoder。

         在代码中用一个gen_seq维护模型输出的token,输入给模型时,只需要gen_seq[:, :step]即可,很巧妙。 在每一步输出时,都会判断是否需要停止输出字符。

for step in range(2, max_seq_len):dec_output = self._model_decode(gen_seq[:, :step], enc_output, src_mask)
略if (eos_locs.sum(1) > 0).sum(0).item() == beam_size:_, ans_idx = scores.div(seq_lens.float() ** alpha).max(0)ans_idx = ans_idx.item()break

        借助李宏毅老师2021年课件中一幅图,配合代码,可以很好的理解Transformer在推理时的流程。

         1. 输入序列经encoder,获得特征(绿色、蓝色、蓝色、橙色)

          2. decoder输入第一个序列(序列长度为1,token是),输出第一个概率向量,并通过max得到“机”。

          3. decoder输入第二个序列(序列长度为2,token是[BOS, 机]),输出得到“器”

          4. decoder输入第三个序列(序列长度为3,token是[BOS, 机,器]),输出得到“学”

          5. decoder输入第四个序列(序列长度为4,token是[BOS, 机,器,学]),输出得到“习”

          6. ...以此类推 

        在推理过程中,通常还会使用Beam Search 来最终确定当前步应该输出哪个token,此处不做展开。

        运行配套代码inference_transformer.py可以看到10条训练集的推理结果。

         从结果上看,基本像回事儿了。

        src: tom's door is open .

        trg: 湯姆的門開著 。

        pred: 汤姆的 <unk>了 。

        src:<unk> is a <unk> country .

        trg: <unk>是一個<unk>的國家 。

         pred: <unk>是一個<unk>的城市 。

         src: i can come at three .

         trg: <unk>可以 來 。

        pred: 我 可以 在 那裡

9.总结

        本文通过Transformer论文的学习,了解Transformer基础架构,并通过机器翻译案例,从代码实现的角度深入剖析 Transformer训练和推理的过程。由于Transformer是目前人工智能的核心与基石,因此需要认真、仔细的掌握其中细节。

        本文内容值得注意的知识点:

         1. 多头注意力机制的引入:官方解释为,由于注意力机制最后的加权平均可能会序列中各位置的细粒度捕捉不足,因此引入多头注意力机制。

         2. 注意力计算时的缩放因子:QK乘法后需要缩放,是因为若向量维度过长,会导致点积结果过大,再经softmax映射后,梯度会变小,不利于模型学习,因此需要进行缩放。缩放因子为除以根号dk。

        3. 多头注意力机制的三种情况:

          i. encoder:输入来自上一层的全部输出

         ii. decoder-输入:为避免模型未卜先知,只允许看见第i步之前的信息,需要做mask操作,确保在生成序列的每个元素时,只考虑该元素之前的元素。这里通过softmax位置设置负无穷来控制无效的连接。

         iii. decoder-第二个attention:q来自上一步输出,k和v来自encoer的输出。这样解码器可以看到输入的所有序列。

        4. 输入序列的msk:代码实现时,由于输入数据是通过添加来组batch的,并且为了在运算时做并行运算,因此需要对 src中是pad的token做mask,这一点在论文是不会提及的。

        5. decoder的mask:根据下三角mask避免未卜先知。

        6. 推理时会采用beam search进行搜索,确定输出的token。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/30879.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【上海交大】博士生年度进展报告模板

上海交通大学 博士生年度进展报告模板 比较不好找&#xff0c;在交我办中发起申请流程后才能看到链接&#xff0c;链接如下&#xff1a; https://www.gs.sjtu.edu.cn/xzzx/pygl/15

爬取CSDN博文到本地(包含图片,标签等信息)

文章目录 csdnToMD改进将CSDN文章转化为Markdown文档那有什么办法快速得到md文档&#xff1f;例如&#xff1a;获取单个文章markdown获取所有的文章markdown 项目中待解决的问题 csdnToMD 项目原作者&#xff1a;https://gitee.com/liushili888/csdn-is—mark-down 改进后仓库…

Z语言学习——基于通讯案例

目录 1数据类型 2初始状态 3 Alice的消息发送 4 Bob接收与发送消息 5 Alice接收消息 6消息的增删改查 6.1 删除消息 6.2查询消息 6.3修改/增加消息 7定理证明——重要目的 案例背景&#xff1a; (1)构建一个交互式的通讯方案&#xff1b; (2)攻击者控制了所有的通讯…

眼见不一定为实之MySQL中的不可见字符

目录 前言 一、问题的由来 1、需求背景 2、数据表结构 二、定位问题 1、初步的问题 2、编码是否有问题 3、依然回到字符本身 三、深入字符本身 1、回归本质 2、数据库解决之道 3、代码层解决 四、总结 前言 在开始今天的博客内容之前&#xff0c;正在看博客的您先来…

Element-UI实现el-dialog弹框拖拽功能

在实际开发中&#xff0c;会发现有些系统&#xff0c;弹框是可以在浏览器的可见区域自由拖拽的&#xff0c;这极大方便用户的操作。但在查看Element-UI中弹框&#xff08;el-dialog&#xff09;组件的文档时&#xff0c;发现并未实现这一功能。不过也无须担心&#xff0c;vue中…

Day 28:2748. 美丽下标对的数目

Leetcode 2748. 美丽下标对的数目 给你一个下标从 0 开始的整数数组 nums 。如果下标对 i、j 满足 0 ≤ i < j < nums.length &#xff0c;如果 nums[i] 的 第一个数字 和 nums[j] 的 最后一个数字 互质 &#xff0c;则认为 nums[i] 和 nums[j] 是一组 美丽下标对 。 返回…

Linux系统之ARP命令的基本使用

Linux系统之ARP命令的基本使用 一、ARP介绍二、ARP命令帮助2.1 ARP的help帮助信息2.2 ARP命令的帮助解释 三、ARP命令的基本使用3.1 查看ARP缓存3.2 显示详细信息3.3 添加静态arp映射3.4 删除指定主机的ARP条目3.5 从文件读取并添加条目3.6 清除ARP缓存 四、注意事项五、总结 一…

wins系统资源监视器任务管理器运行监控CPU、内存、磁盘、网络运行状态

目录 1.Windows系统资源监视器的详细介绍2.通过任务管理器打开资源监视器3.任务管理中总体观察观察cpu、pid、应用程序、I/O次数或者说读写字节数 4.观察CPU观察cpu核心数&#xff0c;以及哪些占用cpu频率过高 5.观察内存观察各个应用占用的内存大小和对应线程 6.观察磁盘活动观…

【前端技巧】css篇

利用counter实现计数器 counter-reset&#xff1a;为计数器设置名称&#xff0c;语法如下&#xff1a; counter-rese: <idntifier><integer>第一个参数为变量名称&#xff0c;第二个参数为初始值&#xff0c;默认为0 counter-increment&#xff1a;设置计数器增…

LabVIEW与3D相机开发高精度表面检测系统

使用LabVIEW与3D相机开发一个高精度表面检测系统。该系统能够实时获取三维图像&#xff0c;进行精细的表面分析&#xff0c;广泛应用于工业质量控制、自动化检测和科学研究等领域。通过真实案例&#xff0c;展示开发过程中的关键步骤、挑战及解决方案&#xff0c;确保系统的高性…

宕机了, redis如何保证数据不丢?

前言 如果有人问你&#xff1a;"你会把 Redis 用在什么业务场景下&#xff1f;" 我想你大概率会说&#xff1a;"我会把它当作缓存使用&#xff0c;因为它把后端数据库中的数据存储在内存中&#xff0c;然后直接从内存中读取数据&#xff0c;响应速度会非常快。…

【Linux从入门到放弃】进程地址空间

&#x1f9d1;‍&#x1f4bb;作者&#xff1a; 情话0.0 &#x1f4dd;专栏&#xff1a;《Linux从入门到放弃》 &#x1f466;个人简介&#xff1a;一名双非编程菜鸟&#xff0c;在这里分享自己的编程学习笔记&#xff0c;欢迎大家的指正与点赞&#xff0c;谢谢&#xff01; 进…

如何更换OpenHarmony SDK API 10

OpenHarmony社区已经发布OpenHarmony SDK API 10 beta版本&#xff0c;有些 Sample案例 也有需要API10。那么如何替换使用新的OpenHarmony SDK API 10呢&#xff1f;本文做个记录。 1、如何获取OpenHarmony SDK 1.1 每日构建流水线 可以从OpenHarmony每日构建站点获取最新的…

【网络安全的神秘世界】已解决Failed to start proxy service on 127.0.0.1:8080

&#x1f31d;博客主页&#xff1a;泥菩萨 &#x1f496;专栏&#xff1a;Linux探索之旅 | 网络安全的神秘世界 | 专接本 | 每天学会一个渗透测试工具 解决burpsuite无法在 127.0.0.1&#xff1a;8080 上启动代理服务端口被占用以及抓不到本地包的问题 Burpsuite无法启动proxy…

定个小目标之刷LeetCode热题(25)

这道题采用的解法是桶排序&#xff0c;画草图如下 代码如下 //基于桶排序求解「前 K 个高频元素」 class Solution {public int[] topKFrequent(int[] nums, int k) {HashMap<Integer, Integer> map new HashMap();for (int num : nums) {if (map.containsKey(num)) {m…

【安防天下】模拟视频监控系统——模拟监控系统的构成视频采集设备

文章目录 1 模拟监控系统的构成2 视频采集设备2.1 摄像机相关技术2.1.1 摄像机的工作原理2.1.2 摄像机的分类2.1.3 摄像机的主要参数 2.2 镜头相关介绍2.2.1 镜头的主要分类2.2.2 镜头的主要参数 1 模拟监控系统的构成 模拟视频监控系统又称闭路电视监控系统&#xff0c; 一般…

htb_Blurry

端口扫描 80 按照教程注册安装clear ml 加载configuration的时候会报错 将json里的API&#xff0c;File Store的host都添加到/etc/hosts中 即可成功初始化 查找clear ml漏洞 发现一个cve-2024-24590 下面是一个利用脚本&#xff0c;但不能直接用 ClearML-vulnerability-…

好用的linux一键换源脚本

最近发现一个好用的linux一键换源脚本&#xff0c;记录一下 官方链接 大陆使用 bash <(curl -sSL https://linuxmirrors.cn/main.sh)# github地址 bash <(curl -sSL https://raw.githubusercontent.com/SuperManito/LinuxMirrors/main/ChangeMirrors.sh) # gitee地址 …

Linux基础命令大全(详解版)

Linux基础命令&#xff08;详解版&#xff09; 文章目录 Linux基础命令&#xff08;详解版&#xff09;1.Linux的目录结构**2.Linux路径的描述方式**3.Linux命令基础格式4.ls命令 隐藏文件、文件夹5.pwd命令6.cd命令 特殊路径符7.mkdir命令 文件操作命令8.touch命令9.cat命令10…

英伟达中国特供芯片降价背后:巨头与市场的较量

英伟达&#xff0c;这家曾经在人工智能芯片领域独领风骚的巨头&#xff0c;近期在中国市场遭遇了一些挑战。为了应对来自华为等中国本土企业的竞争&#xff0c;英伟达不得不采取降价策略&#xff0c;调整其专为中国市场打造的H20芯片价格&#xff0c;甚至低于华为的同类产品。这…