NLP实战8:图解 Transformer笔记

目录

1.Transformer宏观结构

2.Transformer结构细节

2.1输入

2.2编码部分

2.3解码部分

2.4多头注意力机制

2.5线性层和softmax

2.6 损失函数

3.参考代码


🍨 本文为[🔗365天深度学习训练营]内部限免文章(版权归 *K同学啊* 所有)
🍖 作者:[K同学啊]

Transformer整体结构图,与seq2seq模型类似,Transformer模型结构中的左半部分为编码器(encoder),右半部分为解码器(decoder),接下来拆解Transformer。

1.Transformer宏观结构

Transformer模型类似于seq2seq结构,包含编码部分和解码部分。不同之处在于它能够并行计算整个序列输入,无需按时间步进行逐步处理。

其宏观结构如下:

6层编码和6层解码器

其中,每层encoder由两部分组成:

  • Self-Attention Layer
  • Feed Forward Neural Network(前馈神经网络,FFNN)

decoder在encoder的Self-Attention和FFNN中间多加了一个Encoder-Decoder Attention层。该层的作用是帮助解码器集中注意力于输入序列中最相关的部分。

单层encoder和decoder

2.Transformer结构细节

2.1输入

Transformer的数据输入与seq2seq不同。除了词向量,Transformer还需要输入位置向量,用于确定每个单词的位置特征和句子中不同单词之间的距离特征。

2.2编码部分

编码部分的输入文本序列经过处理后得到向量序列,送入第一层编码器。每层编码器输出一个向量序列,作为下一层编码器的输入。第一层编码器的输入是融合位置向量的词向量,后续每层编码器的输入则是前一层编码器的输出。

2.3解码部分

最后一个编码器输出一组序列向量,作为解码器的K、V输入。

解码阶段的每个时间步输出一个翻译后的单词。当前时间步的解码器输出作为下一个时间步解码器的输入Q,与编码器的输出K、V共同组成下一步的输入。重复此过程直到输出一个结束符。

解码器中的 Self-Attention 层,和编码器中的 Self-Attention 层的区别:

  • 在解码器里,Self-Attention 层只允许关注到输出序列中早于当前位置之前的单词。具体做法是:在 Self-Attention 分数经过 Softmax 层之前,屏蔽当前位置之后的那些位置(将Attention Score设置成-inf)。
  • 解码器 Attention层是使用前一层的输出来构造Query 矩阵,而Key矩阵和Value矩阵来自于编码器最终的输出。

2.4多头注意力机制

Transformer论文引入了多头注意力机制(多个注意力头组成),以进一步完善Self-Attention。

  • 它扩展了模型关注不同位置的能力
  • 多头注意力机制赋予Attention层多个“子表示空间”。

残差链接&Normalize: 编码器和解码器的每个子层(Self-Attention 层和 FFNN)都有一个残差连接和层标准化(layer-normalization),细节如下图

2.5线性层和softmax

Decoder最终输出一个浮点数向量。通过线性层和Softmax,将该向量转换为一个包含模型输出词汇表中每个单词分数的logits向量(假设有10000个英语单词)。Softmax将这些分数转换为概率,使其总和为1。然后选择具有最高概率的数字对应的词作为该时间步的输出单词。

2.6 损失函数

在Transformer训练过程中,解码器的输出和标签一起输入损失函数,以计算损失(loss)。最终,模型通过方向传播(backpropagation)来优化损失。

3.参考代码

class SelfAttention(nn.Module):def __init__(self, embed_size, heads):super(SelfAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)self.fc_out = nn.Linear(heads * self.head_dim, embed_size)def forward(self, values, keys, query, mask):N =query.shape[0]value_len , key_len , query_len = values.shape[1], keys.shape[1], query.shape[1]# split embedding into self.heads piecesvalues = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)values = self.values(values)keys = self.keys(keys)queries = self.queries(queries)energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)# queries shape: (N, query_len, heads, heads_dim)# keys shape : (N, key_len, heads, heads_dim)# energy shape: (N, heads, query_len, key_len)if mask is not None:energy = energy.masked_fill(mask == 0, float("-1e20"))attention = torch.softmax(energy/ (self.embed_size ** (1/2)), dim=3)out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)# attention shape: (N, heads, query_len, key_len)# values shape: (N, value_len, heads, heads_dim)# (N, query_len, heads, head_dim)out = self.fc_out(out)return outclass TransformerBlock(nn.Module):def __init__(self, embed_size, heads, dropout, forward_expansion):super(TransformerBlock, self).__init__()self.attention = SelfAttention(embed_size, heads)self.norm1 = nn.LayerNorm(embed_size)self.norm2 = nn.LayerNorm(embed_size)self.feed_forward = nn.Sequential(nn.Linear(embed_size, forward_expansion*embed_size),nn.ReLU(),nn.Linear(forward_expansion*embed_size, embed_size))self.dropout = nn.Dropout(dropout)def forward(self, value, key, query, mask):attention = self.attention(value, key, query, mask)x = self.dropout(self.norm1(attention + query))forward = self.feed_forward(x)out = self.dropout(self.norm2(forward + x))return outclass Encoder(nn.Module):def __init__(self,src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length,):super(Encoder, self).__init__()self.embed_size = embed_sizeself.device = deviceself.word_embedding = nn.Embedding(src_vocab_size, embed_size)self.position_embedding = nn.Embedding(max_length, embed_size)self.layers = nn.ModuleList([TransformerBlock(embed_size,heads,dropout=dropout,forward_expansion=forward_expansion,)for _ in range(num_layers)])self.dropout = nn.Dropout(dropout)def forward(self, x, mask):N, seq_length = x.shapepositions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))for layer in self.layers:out = layer(out, out, out, mask)return outclass DecoderBlock(nn.Module):def __init__(self, embed_size, heads, forward_expansion, dropout, device):super(DecoderBlock, self).__init__()self.attention = SelfAttention(embed_size, heads)self.norm = nn.LayerNorm(embed_size)self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)self.dropout = nn.Dropout(dropout)def forward(self, x, value, key, src_mask, trg_mask):attention = self.attention(x, x, x, trg_mask)query = self.dropout(self.norm(attention + x))out = self.transformer_block(value, key, query, src_mask)return outclass Decoder(nn.Module):def __init__(self,trg_vocab_size,embed_size,num_layers,heads,forward_expansion,dropout,device,max_length,):super(Decoder, self).__init__()self.device = deviceself.word_embedding = nn.Embedding(trg_vocab_size, embed_size)self.position_embedding = nn.Embedding(max_length, embed_size)self.layers = nn.ModuleList([DecoderBlock(embed_size, heads, forward_expansion, dropout, device)for _ in range(num_layers)])self.fc_out = nn.Linear(embed_size, trg_vocab_size)self.dropout = nn.Dropout(dropout)def forward(self, x ,enc_out , src_mask, trg_mask):N, seq_length = x.shapepositions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))for layer in self.layers:x = layer(x, enc_out, enc_out, src_mask, trg_mask)out =self.fc_out(x)return outclass Transformer(nn.Module):def __init__(self,src_vocab_size,trg_vocab_size,src_pad_idx,trg_pad_idx,embed_size = 256,num_layers = 6,forward_expansion = 4,heads = 8,dropout = 0,device="cuda",max_length=100):super(Transformer, self).__init__()self.encoder = Encoder(src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length)self.decoder = Decoder(trg_vocab_size,embed_size,num_layers,heads,forward_expansion,dropout,device,max_length)self.src_pad_idx = src_pad_idxself.trg_pad_idx = trg_pad_idxself.device = devicedef make_src_mask(self, src):src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)# (N, 1, 1, src_len)return src_mask.to(self.device)def make_trg_mask(self, trg):N, trg_len = trg.shapetrg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len)return trg_mask.to(self.device)def forward(self, src, trg):src_mask = self.make_src_mask(src)trg_mask = self.make_trg_mask(trg)enc_src = self.encoder(src, src_mask)out = self.decoder(trg, enc_src, src_mask, trg_mask)return out

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/5679.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在外远程NAS群晖Drive - 群晖Drive挂载电脑磁盘同步备份【无需公网IP】

文章目录 前言1.群晖Synology Drive套件的安装1.1 安装Synology Drive套件1.2 设置Synology Drive套件1.3 局域网内电脑测试和使用 2.使用cpolar远程访问内网Synology Drive2.1 Cpolar云端设置2.2 Cpolar本地设置2.3 测试和使用 3. 结语 前言 群晖作为专业的数据存储中心&…

【Hadoop 01】简介

目录 1 Hadoop 简介 2 下载并配置Hadoop 2.1 修改/etc/profile 2.2 修改hadoop-env.sh 2.3 修改core-site.xml 2.4 修改hdfs-site.xml 2.5 修改mapred-site.xml 2.6 修改yarn-site.xml 2.7 修改workers 2.8 修改start-dfs.sh、stop-dfs.sh 2.9 修改start-yarn.sh、s…

Elemui表单合并

原代码形式 <template><el-table:data"tableData"borderstyle"width: 100%"><el-table-columnprop"date"label"日期"width"180"></el-table-column><el-table-columnprop"name"label…

WebGL 概念和基础入门

WebGL 概念和基础入门 WebGL 是什么 对于 WebGL 百度百科给出的解释是 WebGL 是一种 3D 绘图协议&#xff0c;而对此维基百科给出的解释却是一种 JavaScript API。由于 WebGL 技术旨在帮助我们在不使用插件的情况下在任何兼容的网页浏览器中开发交互式 2D 和 3D 网页效果&…

react 实现小球加入购物车动画

代码 import React, { useRef } from react;const ProductLayout () > {const box useRef(null);const createBall (left, top) > {const ball document.createElement(div);ball.style.position absolute;ball.style.left left - 10 px;ball.style.top top - 1…

【机器学习】了解 AUC - ROC 曲线

一、说明 在机器学习中&#xff0c;性能测量是一项基本任务。因此&#xff0c;当涉及到分类问题时&#xff0c;我们可以依靠AUC - ROC曲线。当我们需要检查或可视化多类分类问题的性能时&#xff0c;我们使用AUC&#xff08;曲线下面积&#xff09;ROC&#xff08;接收器工作特…

使用 Vue 创建一个简单的 Loading 动画

使用 Vue 创建一个简单的 Loading 动画 1. 开始之前 确保 正确安装了 Vue 3知道如何启动一个新的 Vue 项目&#xff08;或在项目中使用Vue&#xff09;了解 Vue 3 的 Composition API&#xff08;本文将使用&#xff09; 2. 设计组件 该组件应该包含三个部分 控制逻辑旋转…

win10 安装 langchain-chatglm 遇到的问题

win10 安装 langchain-chatglm 避坑指南&#xff08;2023年6月21日最新版本&#xff09;_憶的博客-CSDN博客官网看起来安装很简单&#xff0c;网上教程也是&#xff0c;但实际上我耗费了两天时间&#xff0c;查阅了当前网络上所有可查阅的资料&#xff0c;重复「安装-配置-卸载…

Spring Security 构建基于 JWT 的登录认证

一言以蔽之&#xff0c;JWT 可以携带非敏感信息&#xff0c;并具有不可篡改性。可以通过验证是否被篡改&#xff0c;以及读取信息内容&#xff0c;完成网络认证的三个问题&#xff1a;“你是谁”、“你有哪些权限”、“是不是冒充的”。 为了安全&#xff0c;使用它需要采用 …

HideSeeker论文阅读

文章目录 3.1 Overview of Our System HideSeeker3.2 Visual Information Extraction3.3 Relation Graph Learning3.4 Hidden Object Inference 4 EVALUATIONS4.7 Summary 6 DISCUSSIONS AND CONCLUSION 3.1 Overview of Our System HideSeeker 我们设计了一种名为“HideSeeke…

个人博客系统(SSM版 前端+后端)

前言 在学习Servlet的时候,也写了一个博客系统,主要的就是使用servelet加Tomcat进行实现的,而这个项目 仅仅适合去学习Web项目开发的思想,并不满足当下企业使用框架的思想,进行学习过Spring,Spring Boot,Spring MVC以及MyBatis之后,我们就可以对之前的项目使用SSM框架的形式进行…

react+redux异步操作数据

reactredux异步操作数据 redux中操作异步方法&#xff0c;主要是&#xff1a; 1、借助createAsyncThunk()封装异步方法&#xff1b;2、通过extraReducers处理异步方法触发后的具体逻辑&#xff0c;操作派生的state 1、异步操作的slice import { createSlice, createAsyncThunk…

uniapp 之 微信小程序、支付宝小程序 对于自定义导航栏的不同

目录 前言 微信小程序 代码 支付宝小程序 首页配置文件 二级菜单页面 配置 总结 不同 相同 前言 小程序都是 uni-app 写的 不是原生 微信小程序 代码 pages.json文件中配置 重点&#xff1a; "navigationStyle": "custom", // 导航栏样式…

ChatGPT开放自定义系统级别的指令,可设置偏好变成专属助理

OpenAI官方消息https://openai.com/blog/custom-instructions-for-chatgpt OpenAI为其大型语言模型接口ChatGPT引入了自定义指令&#xff0c;旨在为用户提供更加量身定制和个性化的体验&#xff0c;可以设置您的偏好&#xff0c;ChatGPT将在未来的所有对话中记住它们。 该功…

Python—数据结构(一)

先放一张自己学习和整理归纳的思维导图&#xff0c;以便让大家都知道我自己的整体学习路线。 数据结构的学习路上内容枯燥&#xff0c;但坚持下来一定有很大的收获&#xff01;加油&#x1f4aa;&#x1f3fb;&#xff01; 数据结构 数据的概念数据元素&#xff1a; 若干基本…

音视频开发-ffmpeg介绍-系列二

目录 一、FFmpeg核心结构体 二、解码流程 三、FFmpeg解码实现 四、FFmpeg编码实现 五、FFmpeg转码实现 一、FFmpeg核心结构体 AVFormatContext&#xff1a;解封装功能的结构体&#xff0c;包含文件名、音视频流、时长、比特率等信息&#xff1b; AVCodecContext&#xf…

【算法基础:数学知识】4.3 欧拉函数

文章目录 欧拉函数定义性质 例题列表873. 欧拉函数&#xff08;使用质因数分解求一个数的欧拉函数&#xff09;原理讲解&#xff08;公式推导&#xff09;⭐解法代码 874. 筛法求欧拉函数&#xff08;求 1 ~ n 中所有数字的欧拉函数&#xff09;⭐ 欧拉函数 https://oi-wiki.o…

[数据结构 -- 手撕排序算法第六篇] 递归实现快速排序(集霍尔版本,挖坑法,前后指针法为一篇的实现方法,很能打)

目录 1、常见的排序算法 1.1 交换排序基本思想 2、快速排序的实现方法 2.1 基本思想 3 hoare&#xff08;霍尔&#xff09;版本 3.1 实现思路 3.2 思路图解 3.3 为什么实现思路的步骤2、3不能交换 3.4 hoare版本代码实现 3.5 hoare版本代码测试 4、挖坑法 4.1 实现…

【手撕排序算法】---基数排序

个人主页&#xff1a;平行线也会相交 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 平行线也会相交 原创 收录于专栏【数据结构初阶&#xff08;C实现&#xff09;】 我们直到一般的排序都是通过关键字的比较和移动这两种操作来进行排序的。 而今天介绍的…

​MySQL高阶语句(三)

目录 1、内连接 2、左连接 3、右连接&#xff1a; 二、存储过程⭐⭐⭐ 4. 调用存储过程 5.查看存储过程 5.1 查看存储过程 5.2查看指定存储过程信息 三. 存储过程的参数 3.1存储过程的参数 3.2修改存储过程 四.删除存储过程 MySQL 的连接查询&#xff0c;通常都是将来…