【AIGC入门一】Transformers 模型结构详解及代码解析

Transformers 开启了NLP一个新时代,注意力模块目前各类大模型的重要结构。作为刚入门LLM的新手,怎么能不感受一下这个“变形金刚的魅力”呢?

目录

Transformers ——Attention is all You Need

背景介绍

模型结构

位置编码

代码实现:

Attention

Scaled Dot-product Attention

Multi-head Attention

Position-Wise Feed-Forward Networks

Encoder and Decoder

Add & Norm

mask 机制

参考链接


论文链接:Attention Is All You Need

Transformers ——Attention is all You Need

背景介绍

        在Transformer提出之前,NLP主要基于RNN、LSTM等算法解救相关问题。这些模型在处理长序列时面临梯度消失和梯度爆炸等问题,且这些模型是串行计算的,运行时间较长。

        Transformer 模型的提出是为了摆脱序列模型的顺序依赖性,引入了注意力机制,使得模型能够在不同位置上同时关注输入序列的各个部分,且支持并行计算。该模型的提出对深度学习和自然语言处理领域产生了深远的影响,成为了现代NLP模型的基础架构,并推动了attention 机制在各种任务中的应用。

模型结构

位置编码

        任何一门语言,单词在句子中的位置以及排列顺序是非常重要的。一个单词在句子的位置或者排列顺序不同,整个句子的意义就发生了偏差。举个例子:

小明小王500块

小王小明500块

顺序不同,债主关系就发生变化了😑

        当采用了Attention之后,句子中的词序信息就会丢失,模型就没法知道每个词在句子中的相对和绝对的位置信息。目前位置编码有多种方法:

(1)整型值标记位置,即第一个token标记为1, 第二个token标记为2。。。以此类推

         可能存在的问题:

  • 随着序列长度的增加,位置值会越来越大;
  • 推理的序列长度比训练时所用的序列长度更长,不利于模型的泛化

(2)用[0,1] 范围标记位置

        将位置值的范围限制在[0,1]之内,即在第一种的方法进行归一化操作(除以序列长度)。比如有4个token,那么位置信息就是[0, 0.33, 0.69, 1]。 但这样产生的问题是,当序列长度不同时,token间的相对距离是不一样的。

        因此,一个好的位置编码方法应该满足以下特性:

(1)可以表示一个token 在序列中的绝对位置;

(2)在序列长度不同的情况下,不同序列中token 的相对位置/ 距离要保持一直;

(3)可以扩展到更长的句子长度;

        Transformers 中选择的是sincos编码法,其公式如下所示:

        其中,pos 是token在sentence中的位置,i是维度。

代码实现:

        假设句子长度是 s, embedding的维度是d, 最终生成的PE的shape是(s, d)。公式的核心是计算pos /10000\tfrac{2i}{d_{model}}, 这里可以借助对数和指数的性质进行如下操作:

a = e^{^{loga}}

        所以可以转换成1/ 10000^{^{2i/d_{model}}} = e^{ - log(10000) * 2i /d_{model})}(可对照代码进行推导理解)

class Position_Encoding(nn.Module):def __init__(self, max_length, d_model):self.max_length = max_lengthself.dim = d_modelpe = torch.zeros(self.max_length, self.dim)position = torch.arange(0, self.max_length).unsqueeze(1)div_term = torch.exp( torch.arange(0, self.dim ,2) * (-1) *math.log(10000) / self.dim)pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)def forward(self, x):#  input_embedding + position_encoding#....

Attention

        Attention 是将query 和key、value映射为输出值,其中query 和 key 计算一个相似度,然后以这个相似度为权重,计算value的加权和,最终得到输出。

Scaled Dot-product Attention

        论文中用的是放缩点乘注意力(scaled dot-product attention),其公式是:

Attention(Q, K, V) = softmax(\frac{QK^{}{T}}{sqrt(d_k))})

其中, 计算时需要用到的矩阵Q(查询),K(键值),V(值)是输入单词的embedding 变换或者 上一个Encoder block的输出 。注意Q、K、V的shape 会存在一定的联系(因为需要做矩阵乘法运算)。

        公式中会除以dk的平方根,从而避免内积过大。还有解释是说 softmax 在 绝对值较大的区域梯度较小,梯度下降的速度比较慢。因此希望softmax的点乘数值尽可能小。

        论文中解释了为什么点积会变大。假设q 和 k中的元素满足独立分布,且均值是0,方差为1。点积 q*k = \sum_{i=1}^{dk} q_i * k_i   的均值是0, 方差是dk 。

Multi-head Attention

        作者发现,相比直接在dmodel 维度上的 q、k、v进行attention计算 ,使用不同的、可学习的linear function 分别地对q、k、v 进行多次映射(映射的维度是dk, dk, dv)  , 然后对每一组映射的q、k、v进行attention 并行计算,并concat得到最终输出。后一种方法更有效。就像卷积层可以用多个卷积核生成多个通道的特征,在Transformers中可以用多组self attention 生成多组注意力的结果,从而增加特征表示。其计算公式和流程图如下:

        注意: head的数量 * 每一组head中q的维度 = dmodel (输入Q的维度)

Position-Wise Feed-Forward Networks

        前馈网络比较简单,是一个两层的全连接层,第一层的激活函数是ReLU,第二层不使用激活函数,对应的公式如下所示:

Encoder and Decoder

       Transformer 从结构上可以分为Encoder 和Decoder 两个部分,这两者结构上比较类似,但也存在一些差异。

        上图红色区域对应的是Encoder部分,可以看出是由 Input Embedding 、Position Encoding 和6层的EncoderLayer组成。 EncoderLayer 主要包括Multi-head Attention, Add&Norm, Feed Forward ,Add&Norm。

        上图绿色区域对应的是Decoder部分,相比Encoder,需要注意Decoder中的Multi-head Attention 有所不同。首先是Masked Multi-head Attention, 是为了实现串行推理;第二个Multi-head Attention输入的Q、K、V来自不同的地方,其中Q是Masked Multi-head Attention 的输出, K和V是Encoder 的输出。

Add & Norm

        这部分主要由Add 和 Norm 组成,其计算公式如下所示:

        Add 是一种残差结构,和ResNet中的是一样的,可以帮助网络收敛。Norm 是指Layer Norm。

mask 机制

        Transformers中比较重要的一个知识点就是mask设置。mask主要来源有两个:第一个是填充操作的空白字符(为了保证batch内句子的长度一样会进行padding操作);第二个是因为模拟串行推理需要用到mask(Decoder部分)。

        一般情况下, query 和 key都是一样的,但是在Decoder的第二个多头注意力层中,query 来自目标语言,key来自源语言。为了生成mask, 首先要知道query 和 key中<pad> 字符的分布情况,它们的形状为[n, seq_len]。如果某处是True, 表明这个地方的字符是<pad>。

src_pad_mask = x == pad_idx
dst_pad_mask = y == pad_idx

        为了实现串行推理,即某字符只能知道该字符以及该字符之前的内容,即一个下三角全1矩阵。mask矩阵需要取反,实现方式如下所示:

mask = 1 - torch.tril(torch.ones(mask_shape))

        最后根据<pad>字符分布情况分别将mask对应的行或者列置1。

参考链接

  1. GitHub - P3n9W31/transformer-pytorch: Transformer model for Chinese-English translation.
  2. PyTorch Transformer 英中翻译超详细教程 - 知乎
  3. Transformer模型详解(图解最完整版) - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/627759.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Agent检索增强生成

检索增强生成(RAG)设计模式通常用于在特定数据域中开发大语言模型(LLM)应用。然而&#xff0c;RAG的过往的研究重点主要在于提高检索工具的效率&#xff0c;例如嵌入搜索、混合搜索和微调嵌入&#xff0c;而忽视了智能搜索。本文介绍了一种受人类研究方法启发的新方法&#xff…

for循环判断有几个偶数

num100 count0 for i in range(1,num):if i%20:print("为偶数")count1 print(f"1-100的范围内&#xff0c;有{count}个偶数") 运行结果如下&#xff1a;

代码随想录Day21 | 530.二叉搜索树的最小绝对差 501.二叉搜索树中的众数 236. 二叉树的最近公共祖先

代码随想录Day21 | 530.二叉搜索树的最小绝对差 501.二叉搜索树中的众数 236. 二叉树的最近公共祖先 二叉搜索树的最小绝对差二叉搜索树中的众数二叉树的最近公共祖先 二叉搜索树的最小绝对差 文档讲解&#xff1a;代码随想录 视频讲解&#xff1a; 二叉搜索树中&#xff0c;需…

Nginx Ingress轻松上手 | Kubernetes服务管理指南

1. 揭秘Nginx Ingress的魔力 Nginx Ingress是你Kubernetes集群中的得力助手&#xff0c;无需额外安装&#xff0c;已内置于K8s。作为基于Nginx的扩展&#xff0c;它担任负载均衡器和入口控制器的重要角色。 2. 为何选择Nginx Ingress&#xff1f; 2.1 服务曝露的便利 通过简…

Kaggle之旅1

Kaggle之旅1 文章目录 Kaggle之旅1前言一、目标&#xff1f;二、课程1 pandas1. 学和练2. 一些关键摘要 总结 前言 Kaggle是一个以数据科学竞赛为主题的在线平台。它提供了一个数据科学社区&#xff0c;让数据科学家和机器学习专家可以在这里交流、学习和竞争。Kaggle上有大量…

深度掌握 Nginx Ingress:解锁高级功能,打造 Kubernetes 中的流量掌控艺术

前言 在 Kubernetes 的世界里&#xff0c;Nginx Ingress 不仅是流量的门卫&#xff0c;更是一把强大的调控利器。我们已经领略了其基础面貌&#xff0c;现在让我们踏上深度之旅&#xff0c;揭示 Nginx Ingress 的高级功能&#xff0c;助你在 Kubernetes 中创造流量掌控的艺术。…

2024秋招,深信服测试开发工程师一面

前言 回顾一下我秋招参加的第一次线下面试 这个面试体现出了我的很多弱点&#xff0c;也为我后面的改进起着很重要的作用 时间&#xff1a;40min 平台&#xff1a;线下面试 过程 1、个人介绍 2、项目经历 3、团队项目中负责的模块&#xff0c;队友都负责哪些工作&#x…

使用免费敏捷工具Leangoo领歌管理Sprint Backlog

什么是Sprint Backlog&#xff1f; Sprint Backlog是Scrum的主要工件之一。在Scrum中&#xff0c;团队按照迭代的方式工作&#xff0c;每个迭代称为一个Sprint。在Sprint开始之前&#xff0c;PO会准备好产品Backlog&#xff0c;准备好的产品Backlog应该是经过梳理、估算和优先…

C语言编译链接

1.翻译环境和运⾏环境 在ANSI C的任何⼀种实现中&#xff0c;存在两个不同的环境。 第1种是翻译环境&#xff0c;在这个环境中源代码被转换为可执⾏的机器指令。 第2种是执⾏环境&#xff0c;它⽤于实际执⾏代码。 2. 翻译环境 翻译环境是由编译和链接两个⼤的过程组成的&…

Spring环境搭配

概述 Spring 是一个开源框架&#xff0c;Spring 是于2003 年兴起的一个轻量级的Java 开发框架&#xff0c;由 RodJohnson 在其著作 Expert One-On-One J2EE Development and Design 中阐述的部分理念和原型衍生而来。它是 为了解决企业应用开发的复杂性而创建的。框架的主要优势…

mockjs使用1

mockjs使用 1、定义 Mock.js 是一款模拟数据生成器&#xff0c;旨在帮助前端攻城师独立于后端进行开发&#xff0c;帮助编写单元测试。提供了以下模拟功能&#xff1a; 根据数据模板生成模拟数据模拟 Ajax 请求&#xff0c;生成并返回模拟数据基于 HTML 模板生成模拟数据 2…

软件测试|SQLAlchemy query() 方法查询数据

简介 上一篇文章我们介绍了SQLAlchemy 的安装和基础使用&#xff0c;本文我们来详细介绍一下如何使用SQLAlchemy的query()方法来高效的查询我们的数据。 创建模型 我们可以先创建一个可供我们查询的模型&#xff0c;也可以复用上一篇文章中我们创建的模型&#xff0c;代码如…

Python字典,什么是字典、增删改查、属性操作和遍历

Python字典是一种无序的、可变的数据类型&#xff0c;它可以存储任意类型的数据&#xff0c;通过键值对的方式进行存储和访问。 字典的增&#xff1a; 使用键值对的方式&#xff0c;将数据添加到字典中。可以通过以下两种方式进行增加&#xff1a; 使用索引表达式&#xff08…

游戏开发,中小公司跳槽去大厂容易还是考研应届生校招容易?

游戏开发&#xff0c;中小公司跳槽去大厂容易还是考研应届生校招容易&#xff1f; 在之前的文章中&#xff0c;我们提到过&#xff0c;游戏开发行业首选直接进入游戏大厂。《开发者必读&#xff1a;如何选择适合的游戏开发公司&#xff1f;》因为大厂不仅能提供良好的职业发展…

接口防刷方案

1、前言 本文为描述通过Interceptor以及Redis实现接口访问防刷Demo 2、原理 通过ip地址uri拼接用以作为访问者访问接口区分 通过在Interceptor中拦截请求&#xff0c;从Redis中统计用户访问接口次数从而达到接口防刷目的 如下图所示 3、案例工程 项目地址&#xff1a; htt…

localStorage、sessionStorage、vuex区别和使用感悟

一、介绍及区别 localStorage的生命周期是永久&#xff1b;不手动在浏览器提供的UI上清除localStorage信息&#xff0c;否则这些信息将永远存在。 sessionStorage的生命周期为当前窗口或标签页&#xff0c;一旦窗口或标签页被永久关闭&#xff0c;那么所有通过sessionStorage存…

AI红娘开启约会新时代;网易云音乐Agent实践探索;微软生成式AI课程要点笔记;ComfyUI新手教程;图解RAG进阶技术 | ShowMeAI日报

&#x1f440;日报&周刊合集 | &#x1f3a1;生产力工具与行业应用大全 | &#x1f9e1; 点赞关注评论拜托啦&#xff01; &#x1f440; Perplexity 官宣 7360 万美元B轮融资&#xff0c;打造世界上最快最准确的答案平台 https://blog.perplexity.ai/blog/perplexity-rais…

uniapp中uview组件库Toast 消息提示 的使用方法

目录 #基本使用 #配置toast主题 #toast结束跳转URL #API #Props #Params #Methods 此组件表现形式类似uni的uni.showToastAPI&#xff0c;但也有不同的地方&#xff0c;具体表现在&#xff1a; uView的toast有5种主题可选可以配置toast结束后&#xff0c;跳转相应URL目…

Linux系统——yum仓库及NFS共享

目录 一、yum仓库 1.yum简介 2.yum实现过程 3.如何实现安装服务 4.yum配置文件及命令 4.1yum配置文件 4.1.1主配置文件 4.1.2仓库设置文件 4.1.3日志文件 4.2yum命令详解 4.2.1查询 4.2.2yum安装升级 4.2.3软件卸载 4.2.4操作安装历史记录 5.搭建本地yum仓库 5…

【分布式技术】分布式存储ceph部署

目录 一、存储的介绍 单机存储设备 单机存储的问题 商业存储 分布式存储 二、分布式存储 什么是分布式存储 分布式存储的类型 三、ceph简介 四、ceph的优点 五、ceph的架构 六、ceph的核心组件 七、OSD存储后端 八、Ceph 数据的存储过程 九、Ceph 版本发行生命周…