Transformer架构详解

文章目录

  • 引言
  • 1. Transformer架构详解
    • 1.1 编码器与解码器
      • 1.1.1 编码器(Encoder)
      • 1.1.2 解码器(Decoder)
  • 2. 核心代码
    • 2.1 自注意力(Self-Attention)机制
    • 2.2 多头注意力(Multi-Head Attention)
    • 2.3 位置编码(Positional Encoding)
    • 2.4 前馈网络(Feed-Forward Network)
    • 2.5 编码器和解码器层

引言

Transformer是一种深度学习模型,最初是由Vaswani等人在2017年的论文《Attention Is All You Need》中提出的。这种模型在自然语言处理(NLP)领域特别流行,它引入了一种新的机制——自注意力(self-attention),使得模型能够更加高效和有效地处理序列数据。

Transformer完全基于注意力机制,没有使用循环神经网络(RNN)或卷积神经网络(CNN)。核心概念是:

  1. 自注意力机制:这允许模型在序列内的任意位置间直接建立依赖,从而更好地理解数据的上下文关系。自注意力机制可以并行处理所有位置的数据,这提高了计算效率。

  2. 多头注意力:模型会同时学习数据的不同表示,每个“头”关注序列的不同部分。这种机制可以捕获序列中多种不同级别的依赖关系。

  3. 位置编码:由于Transformer不使用基于顺序的结构,因此需要通过位置编码来给模型提供关于单词在序列中位置的信息。

1. Transformer架构详解

transformer与传统模型存在较大差别,主要体现在:

1. 处理数据的方式

  • RNN:RNN(特别是其变种如LSTM和GRU)通过递归地处理序列数据,每个时间步的输出依赖于前一个时间步。这使得RNN特别适合处理时间序列数据或任何顺序敏感的数据。
  • CNN:CNN通过卷积层提取空间特征,广泛应用于图像处理。在NLP中,CNN可以用来捕捉局部依赖,如单词和短语级别的模式。
  • Transformer:Transformer使用自注意力机制,允许模型同时处理整个序列的所有元素,有效地捕捉全局依赖。这种并行处理能力使得它在处理长序列时比RNN更高效。

2. 结构和复杂性

  • RNN:RNN的递归结构使得它在理论上能够处理任意长度的序列,但在实践中,由于梯度消失或爆炸的问题,它往往难以捕捉长期依赖。
  • CNN:CNN结构简单,易于训练,但主要适用于捕捉局部特征。在处理全局上下文或长距离依赖时可能不够有效。
  • Transformer:Transformer通过堆叠多个自注意力和前馈层,能有效地处理长距离依赖。但这也使得它的模型参数通常比RNN和CNN多,需要更多的计算资源和数据来训练。

3. 应用场景

  • RNN:理想的应用场景包括语音识别、机器翻译等时间序列相关任务。
  • CNN:主要用于图像处理、图像识别等领域。在NLP中,CNN可以用于文本分类、情感分析等任务。
  • Transformer:Transformer在NLP领域取得了显著的成功,如BERT、GPT等,广泛应用于文本理解、生成等任务。Vision Transformer(ViT)则将其应用于计算机视觉领域。

1.1 编码器与解码器

transformer的本质就是由编码器(encoder)和解码器(decoder)组成。

1.1.1 编码器(Encoder)

事实上,编码器包括两个子层:self-attention 和 feed-forward,我们只需搞清楚这两个子层,那么理解编码器就不是什么问题了。
在这里插入图片描述
在每一个子层的传输过程中,都会有一个(残差网络 + 归一化),意思就是 self-attention 的输出会通过一个残差网络 + 归一化之后才会传给 feed-forward 。

编码器详解图
在这里插入图片描述

  1. 首先输入一个单词thinking - > 得到词向量X1(可以通过one-hot,word2vec得到)
  2. 然后进行positional encoding,因为self-attention的缺点在于它是没有位置信息的,所以叠加一个位置编码,给X1赋予位置属性,得到黄色的X1。也就是说,黄色的X1是拥有位置属性的one-hot编码词向量。
  3. 输入到 self-attention 子层中,做注意力机制(X1和X2拼接起来的一句话做注意力机制),得到,Z1(X1 与 X1,X2拼接起来的句子做了自注意力机制的词向量,表征的仍然是thinking),也就是说Z1拥有了位置特征,句法特征。语义特征的词向量。
  4. 残差网络(避免梯度消失),归一化(避免梯度爆炸),得到深粉色的Z1,
  5. feed-forward (前面每一步都在做线性变换,wx+b,线性变换的叠加永远都是线性变换,通过feed-forward 添加激活函数做非线性变换,这样空间变换可以无限拟合任何一种状态了) ,得到r1(thinking的新的表征)。

总结
事实上transformer就是在做词向量。所有的工作都是为了让这个词向量变得更加优秀,让这个词向量能够更加精准地表达这个词,这句话。总而言之,编码器就是让计算机能够合理地认识人类世界客观存在的一些东西

1.1.2 解码器(Decoder)

解码器由三部分构成,
掩码自注意力(Masked Self-Attention)
编码器-解码器注意力(Encoder-Decoder Attention)
前馈神经网络(Feed-Forward Neural Network)
解码器会接收编码器生成的词向量,然后通过这个词向量去生成翻译的结果。

解码器详解图
在这里插入图片描述
Self-Attention

  1. 解码器的 Self-Attention 在编码已经生成的单词
  2. 假如目标词 “我是一个学生” --> masked Self-Attention
  3. 训练阶段:目标词 “我是一个学生” 是已知的,然后 Self-Attention 是对 “我是一个学生” 做计算。如果不做masked,每次训练阶段,都会获得全部的信息
  4. 测试阶段:目标词未知,假设目标词是 “我是一个老师”,Self-Attention 第一次对 “我” 做计算,第二次对 “我是” 做计算,…。而测试阶段,每生成一点,获得一点

这一部分与编码器中的自注意力机制相似,但有一个关键区别:它使用掩码来防止位置 i 的注意力机制查看在位置 i 之后的输出。这种“掩码”操作确保了解码器在生成第 i 个词时只能依赖于前 i−1 个词,保持了解码过程的自回归特性。
控制信息流向的这种方法是必要的,因为在训练时模型同时看到整个目标序列,如果没有掩码,模型就可以直接“看到”正确的输出,而不是学习如何生成它。

Encoder-Decoder Attention
在这个子层中,解码器从编码器中获取信息。这里的注意力机制使得解码器的每个位置都能考虑到编码器输出的所有位置。这种机制对于将输入序列中的相关信息与输出序列正确对齐至关重要。

前馈神经网络(Feed-Forward Neural Network)
每个解码器层还包含一个前馈神经网络,这与编码器中的网络相同。它独立地处理每个位置的表示,然后将其传递到下一个层。
这个网络通常包含两个线性变换和一个激活函数。

生成词
在这里插入图片描述
总结
解码器的核心功能是结合编码器的输出和自身的历史输出(即迄今为止已生成的部分序列)来生成下一个输出元素。在每个时间步,解码器都会更新其关注点,既考虑到编码器的输出,也考虑到自身之前生成的输出。这种结构使得Transformer解码器能够高效地生成精确且上下文相关的序列。

2. 核心代码

2.1 自注意力(Self-Attention)机制

def scaled_dot_product_attention(query, key, value, mask):"""计算注意力权重。"""matmul_qk = tf.matmul(query, key, transpose_b=True)  # 缩放 matmul_qkdepth = tf.cast(tf.shape(key)[-1], tf.float32)logits = matmul_qk / tf.math.sqrt(depth)# 添加掩码以避免看到未来信息if mask is not None:logits += (mask * -1e9)# softmax 在最后一个轴(seq_len_k)上归一化,因此分数相加等于1。attention_weights = tf.nn.softmax(logits, axis=-1)output = tf.matmul(attention_weights, value)return output

2.2 多头注意力(Multi-Head Attention)

class MultiHeadAttention(tf.keras.layers.Layer):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()self.num_heads = num_headsself.d_model = d_modelassert d_model % self.num_heads == 0self.depth = d_model // self.num_headsself.wq = tf.keras.layers.Dense(d_model)self.wk = tf.keras.layers.Dense(d_model)self.wv = tf.keras.layers.Dense(d_model)self.dense = tf.keras.layers.Dense(d_model)# 省略了细节代码

2.3 位置编码(Positional Encoding)

def positional_encoding(position, d_model):angle_rates = 1 / np.power(10000, (2 * (np.arange(d_model)[np.newaxis, :] // 2)) / np.float32(d_model))angle_rads = np.arange(position)[:, np.newaxis] * angle_rates# 将sine应用于数组中的偶数索引(indices); 2isines = np.sin(angle_rads[:, 0::2])# 将cosine应用于数组中的奇数索引; 2i+1cosines = np.cos(angle_rads[:, 1::2])pos_encoding = np.concatenate([sines, cosines], axis=-1)pos_encoding = pos_encoding[np.newaxis, ...]return tf.cast(pos_encoding, dtype=tf.float32)

2.4 前馈网络(Feed-Forward Network)

def point_wise_feed_forward_network(d_model, dff):return tf.keras.Sequential([tf.keras.layers.Dense(dff, activation='relu'),  # (batch_size, seq_len, dff)tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)])

2.5 编码器和解码器层

编码器

import torch
import torch.nn as nnclass EncoderLayer(nn.Module):def __init__(self, d_model, num_heads, dff, dropout_rate):super(EncoderLayer, self).__init__()self.multi_head_attention = nn.MultiheadAttention(d_model, num_heads)self.feed_forward = nn.Sequential(nn.Linear(d_model, dff),nn.ReLU(),nn.Linear(dff, d_model))self.layernorm1 = nn.LayerNorm(d_model)self.layernorm2 = nn.LayerNorm(d_model)self.dropout1 = nn.Dropout(dropout_rate)self.dropout2 = nn.Dropout(dropout_rate)def forward(self, x, mask):attn_output, _ = self.multi_head_attention(x, x, x, attn_mask=mask)out1 = self.layernorm1(x + self.dropout1(attn_output))ff_output = self.feed_forward(out1)out2 = self.layernorm2(out1 + self.dropout2(ff_output))return out2

解码器

class DecoderLayer(nn.Module):def __init__(self, d_model, num_heads, dff, dropout_rate):super(DecoderLayer, self).__init__()self.masked_attention = nn.MultiheadAttention(d_model, num_heads)self.multi_head_attention = nn.MultiheadAttention(d_model, num_heads)self.feed_forward = nn.Sequential(nn.Linear(d_model, dff),nn.ReLU(),nn.Linear(dff, d_model))self.layernorm1 = nn.LayerNorm(d_model)self.layernorm2 = nn.LayerNorm(d_model)self.layernorm3 = nn.LayerNorm(d_model)self.dropout1 = nn.Dropout(dropout_rate)self.dropout2 = nn.Dropout(dropout_rate)self.dropout3 = nn.Dropout(dropout_rate)def forward(self, x, enc_output, src_mask, tgt_mask):attn_output, _ = self.masked_attention(x, x, x, attn_mask=tgt_mask)out1 = self.layernorm1(x + self.dropout1(attn_output))attn_output, _ = self.multi_head_attention(out1, enc_output, enc_output, attn_mask=src_mask)out2 = self.layernorm2(out1 + self.dropout2(attn_output))ff_output = self.feed_forward(out2)out3 = self.layernorm3(out2 + self.dropout3(ff_output))return out3

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/230156.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python PIP安装pycorrector、kemln报错

本来想装个pycorrector用一下,结果在安装其依赖包kemln的时候疯狂报错,报错关键词包括但不限于Bash、Cmake,C啥的,搜了很多文章,终于摸索到了安装的办法。 1、安装bash 去官网https://gitforwindows.org/下载bash&am…

消费者痛点怎么分析,不同行业如何营销

消费者的痛点是品牌营销中的关键因素,因为准确把握消费者的痛点,可以为品牌带来更大的市场机会。今天和大家探讨下消费者痛点怎么分析,不同行业如何营销? 今天我们会从分类、洞察、场景分析、分级与评判以及不同行业细分的角度来进…

neuq-acm预备队训练week 9 P8604 [蓝桥杯 2013 国 C] 危险系数

题目背景 抗日战争时期,冀中平原的地道战曾发挥重要作用。 题目限制 题目描述 地道的多个站点间有通道连接,形成了庞大的网络。但也有隐患,当敌人发现了某个站点后,其它站点间可能因此会失去联系。 我们来定义一个危险系数 DF…

Android动画(四)——属性动画ValueAnimator的妙用

目录 介绍 效果图 代码实现 xml文件 介绍 ValueAnimator是ObjectAnimator的父类,它继承自Animator。ValueAnimaotor同样提供了ofInt、ofFloat、ofObject等静态方法,传入的参数是动画过程的开始值、中间值、结束值来构造动画对象。可以将ValueAnimator看…

环境搭建及源码运行_java环境搭建_maven

书到用时方恨少、觉知此时要躬行;拥有技术,成就未来,抖音视频教学地址:​​​​​​​ ​​​​​​​ 1、介绍 1)管理项目依赖和版本 统一的项目依赖和版本管理 ​​​​​​​​​​​ 2)Maven支持多模块…

From Human Attention to Computational Attention (1)

”is the taking possession by the mind, in clear and vivid form, of one out of what seem several simultaneously possible objects or trains of thought. It implies withdrawal from some things in order to deal effectively with others“,William Jame…

6.任务分配与执行总体设计实现

1.设计 执行任务找一个落地场景:连接设备采集参数。设备有不同的协议,如:modbus rtu、modbus tcp、opc ua、simens s7等。协议多种多样,需要的参数也不同,连接及任务执行参数存放在t_job表的link_spec中,任…

Jenkins配置代理节点时遇到的坑和解决办法

需求:服务器太满了,需要找个比较空闲的机器分担一下,看上了同网络的某开会用的笔记本,把这个本本利用起来能跑一个算一个。 但配置起来并不容易,遇到的问题有些网上也几乎找不到答案。这里记录一下能救一个是一个&…

python处理数据内存不够,python处理数据安全吗

大家好,小编为大家解答python处理数据索引的常见方法的问题。很多人还不知道python处理数据内存不够,现在让我们一起来看看吧! 学 目录 1.数据表的基本信息查看 2.查看数据表的大小 3.数据格式的查看 4、查看具体的数据分布 二、缺失值处理 …

sap table 获取 valuation class MBEW 查表获取

参考 https://www.tcodesearch.com/sap-tables/search?qvaluationclass

FastAPI访问/docs接口文档显示空白、js/css无法加载

如图: 原因是FastAPI的接口文档默认使用https://cdn.jsdelivr.net/npm/swagger-ui-dist5.9.0/swagger-ui.css 和https://cdn.jsdelivr.net/npm/swagger-ui-dist5.9.0/swagger-ui-bundle.js 来渲染页面,而这两个URL是外网的CDN,在国内响应超…

Text2SQL学习整理(二) WikiSQL数据集介绍

导语 上篇博客中,我们已经了解到Text2SQL任务的基本定义,本篇博客将对近年来该领域第一个大型数据集WikiSQL做简要介绍。 WikiSQL数据集概述 基本统计特性 WikiSQL数据集是一个多数据库、单表、单轮查询的Text-to-SQL数据集。它是Salesforce在2017年…

python之双链表

双链表简单讲解 双向链表(doubly linked list)是一种链式数据结构,它的每个节点包含两个指针,一个指向前一个节点,一个指向后一个节点。与单向链表相比,双向链表可以在任何位置进行插入和删除操作&#xf…

PDF转为图片

PDF转为图片 背景pdf展示目标效果 发展过程最终解决方案:python PDF转图片pdf2image注意:poppler 安装 背景 最近接了一项目,主要的需求就是本地的文联单位,需要做一个电子刊物阅览的网站,将民族的刊物发布到网站上供…

字节开源的netPoll多路复用器源码解析

字节开源的netPoll多路复用器源码解析 引言NetPollepoll API原生网络库实现netpoll 设计思路netpoll 对比 go net数据结构 源码解析多路复用池初始化Epoll相关API可读事件处理server启动accept 事件客户端连接初始化客户端连接建立 可读事件等待读取数据 可写事件处理客户端启动…

word增加引用-endnote使用

使用软件: web of science https://webofscience.clarivate.cn/wos/alldb/basic-search; Pub Med等数据库endnote20 链接: https://pan.baidu.com/s/1VQMEsgFY3kcpCNfIyqEjtQ?pwdy1mz 提取码: y1mz 复制这段内容后打开百度网盘手机App,操作更方便哦 --…

信号与线性系统翻转课堂笔记4——连续LTI系统的微分方程模型与求解

信号与线性系统翻转课堂笔记4——连续LTI系统的微分方程模型与求解 The Flipped Classroom4 of Signals and Linear Systems 对应教材:《信号与线性系统分析(第五版)》高等教育出版社,吴大正著 一、要点 (1&#x…

探索 Coinbase 二层链 Base 的潜力与风险

作者:lesleyfootprint.network 在不断变化的加密货币领域,Coinbase 已经确立了自己领先中心化交易所(CEX)的地位。然而,Coinbase 坚信去中心化是创造一个开放、全球范围内对每个人都可访问的加密经济的关键&#xff0…

python学习3

大家好,今天又来更新python学习篇了。本次的内容比较简单,时描述性统计代码,直接给出所有代码,如下: import pandas as pd from scipy.stats import fisher_exact from fuzzywuzzy import fuzz from fuzzywuzzy impor…

高性能计算HPC与统一存储

高性能计算(HPC)广泛应用于处理大量数据的复杂计算,提供更精确高效的计算结果,在石油勘探、基因分析、气象预测等领域,是企业科研机构进行研发的有效手段。为了分析复杂和大量的数据,存储方案需要响应更快&…