Transformer详解encoder

目录

1. Input Embedding

2. Positional Encoding

3. Multi-Head Attention

4. Add & Norm

5. Feedforward + Add & Norm

6.代码展示

(1)layer_norm

(2)encoder_layer=1


最近刚好梳理了下transformer,今天就来讲讲它~

        Transformer是谷歌大脑2017年在论文attention is all you need中提出来的seq2seq模型,它的本质就是由编码器和解码器组成,今天的主角则是其中的编码器(在BERT预训练模型中也只用到了编码器部分)如下图所示,这个模块的输入为 𝑋 (每一行代表一个句子,batchsize有多大就有多少行),我们将从输入到隐藏层按照从1到4的顺序逐层来看一下各个维度的变化。

1. Input Embedding

        所谓的Embedding其实就是查字典或者叫查表,也就是将一个句子里的每一个字转化为一个维度为embedding dimension的向量来表示,因此 𝑋 经过嵌入后变成 𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔 ,三个维度分别表示一个批次的句子数,每个句子的字数,每个字的嵌入维度。

2. Positional Encoding

        位置编码,按照字面意思理解就是给输入的位置做个标记,简单理解比如你就给一个字在句子中的位置编码1,2,3,4这样下去,高级点的比如作者用的正余弦函数

𝑃𝐸(𝑝𝑜𝑠,2𝑖)=𝑠𝑖𝑛(𝑝𝑜𝑠/100002𝑖/𝑑𝑚𝑜𝑑𝑒𝑙)

𝑃𝐸(𝑝𝑜𝑠,2𝑖+1)=𝑐𝑜𝑠(𝑝𝑜𝑠/100002𝑖/𝑑𝑚𝑜𝑑𝑒𝑙)

 

        其中pos表示字在句子中的位置,i指的词向量的维度。经过位置编码,相当于能够得到一个和输入维度完全一致的编码数组 𝑋𝑝𝑜𝑠 ,当它叠加到原来的词嵌入上得到新的词嵌入

𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔=𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔+𝑋𝑝𝑜𝑠

        此时的维度为:一个批次的句子数 × 一个句子的词数 × 一个词的嵌入维度

3. Multi-Head Attention

        注意力机制,其实可以理解为就是在计算相关性,很自然的想法就是去更多地关注那些相关更大的东西。这里首先要引入Query,Key和Value的概念,Query就是查询的意思,Key就是键用来和你要查询的Query做比较,比较得到一个分数(相关性或者相似度)再乘以Value这个值得到最终的结果。

        那么这个Q,K,V从哪里来呢,这里采用的是self-attention的方式,也就是从输入自己 𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔 来产生,即做线性映射产生Q,K,V:

𝑄=𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔∗𝑊𝑄𝐾=𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔∗𝑊𝐾𝑉=𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔∗𝑊𝑉

        这里三个权重矩阵均为维度为Embedding的方阵,也就是说Q,K,V的维度和 𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔 是一致的。

        接下来考虑什么叫做multi-head(多头)呢,本质上就是从embedding的维度上将矩阵切分为多份,每一份就是一个头,比如之前的Q,K,V切完后的维度就是一个批次的句子数 × 一个句子的词数 × 头数 × (词嵌入维度/头数)这个多头的切分体现在最后两个维度:词嵌入维度=数 × (词嵌入维度/头数)为了便于计算,通常会将第二第三维度进行转置,即最终的维度为一个批次的句子数 × 头数 × 一个句子的词数 × (词嵌入维度/头数)

        接下来说说注意力机制的计算,假设Q,K,V为切分完后的矩阵(其中一个头),根据两个向量的点积越大越相似,我们通过 𝑄𝐾𝑇 求出注意力矩阵,再根据注意力矩阵来给Value进行加权,即

𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛(𝑄,𝐾,𝑉)=𝑠𝑜𝑓𝑡𝑚𝑎𝑥(𝑄𝐾𝑇𝑑𝑘)𝑉

        其中 𝑑𝑘 是为了把注意力矩阵变成标准正态分布,softmax进行归一化,使每个字与其他字的注意力权重之和为1。这一操作使得每一个字的嵌入都包含当前句子内所有字的信息,注意Attention(Q,K,V)的维度和 𝑉 的维度保持一致。

4. Add & Norm

这里主要做了两个操作

  • 一个是残差连接(或者叫做短路连接),说得直白点就是把上一层的输入 𝑋 和上一层的输出加起来 𝑆𝑢𝑏𝐿𝑎𝑦𝑒𝑟(𝑋) ,即 𝑋+𝑆𝑢𝑏𝐿𝑎𝑦𝑒𝑟(𝑋) ,举例说明,比如在注意力机制前后的残差连接:

𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔+𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛(𝑄,𝐾,𝑉)

  • 一个是LayerNormalization(作用是把神经网络中隐藏层归一为标准正态分布,加速收敛),具体操作是将每一行每一个元素减去这行的均值, 再除以这行的标准差, 从而得到归一化后的数值。

5. Feedforward + Add & Norm

前馈网络也就是简单的两层线性映射再经过激活函数一下,比如

𝑋ℎ𝑖𝑑𝑑𝑒𝑛=𝑅𝑒𝑙𝑢(𝑋ℎ𝑖𝑑𝑑𝑒𝑛∗𝑊1∗𝑊2)

残差操作和层归一化同步骤3.


上述的1,2,3,4就构成Transformer中的一个encoder模块,经过1,2,3,4后得到的就是encode后的隐藏层表示,可以发现它的维度其实和输入是一致的!即:一个批次中句子数 × 一个句子的字数 × 字嵌入的维度

6.代码展示

(1)layer_norm

bs=2,seq=3,dim=5

import torchbatch_size = 2
seq = 3
fea_dim = 5
X = torch.rand(batch_size,seq,fea_dim)
layer_norm = torch.nn.LayerNorm(fea_dim)
out = layer_norm(X)
print(out)
print('-'*30)mean = torch.mean(X,dim=-1,keepdim=True)
std = torch.sqrt(torch.var(X,unbiased=False,dim=-1,keepdim=True) + 1e-5)
weight = layer_norm.state_dict()['weight']
bias = layer_norm.state_dict()['bias']
my_norm = ((X - mean)/std) * weight + bias
print(my_norm)

(2)encoder_layer=1

bs=1,seq=1,dim=6,head=1

import torchseq = 1
dim = 6
heads = 1
batch_size = 1
value = torch.rand(batch_size,seq,dim)encoder_layer = torch.nn.TransformerEncoderLayer(dim,heads,dropout=0.0,batch_first=True)
out = encoder_layer(value)
print(out)# 多头自注意力
def my_scaled_dot_product(query,key,value):qk_T = torch.mm(query,key.T)qk_T_scale = qk_T / torch.sqrt(torch.tensor(value.shape[1]))qk_exp = torch.exp(qk_T_scale)qk_exp_sum = torch.sum(qk_exp,dim=1,keepdim=True)qk_softmax = qk_exp / qk_exp_sumv_attn = torch.mm(qk_softmax,value)return v_attn,qk_softmaxin_proj_weight = encoder_layer.state_dict()['self_attn.in_proj_weight']
in_proj_bias = encoder_layer.state_dict()['self_attn.in_proj_bias']out_proj_weight = encoder_layer.state_dict()['self_attn.out_proj.weight']
out_proj_bias = encoder_layer.state_dict()['self_attn.out_proj.bias']batch_V_output = torch.empty(batch_size,seq,dim)
for i in range(batch_size):in_proj = torch.mm(value[i],in_proj_weight.T) + in_proj_biasQs,Ks,Vs = torch.split(in_proj,dim,dim=-1)head_Vs = []attn_weight = torch.zeros(seq,seq)for Q,K,V in zip(torch.split(Qs,dim//heads,dim=-1),torch.split(Ks,dim//heads,dim=-1),torch.split(Vs,dim//heads,dim=-1)):head_v,_ = my_scaled_dot_product(Q,K,V)head_Vs.append(head_v)V_cat = torch.cat(head_Vs,dim=-1)V_ouput = torch.mm(V_cat,out_proj_weight.T) + out_proj_biasbatch_V_output[i] = V_ouput# 第一次加
first_Add = value + batch_V_output# 第一次layer_norm
norm1_mean = torch.mean(first_Add,dim=-1,keepdim=True)
norm1_std = torch.sqrt(torch.var(first_Add,unbiased=False,dim=-1,keepdim=True) + 1e-5)
norm1_weight = encoder_layer.state_dict()['norm1.weight']
norm1_bias = encoder_layer.state_dict()['norm1.bias']
norm1 = ((first_Add - norm1_mean)/norm1_std) * norm1_weight + norm1_bias# feed forward
linear1_weight = encoder_layer.state_dict()['linear1.weight']
linear1_bias = encoder_layer.state_dict()['linear1.bias']
linear2_weight = encoder_layer.state_dict()['linear2.weight']
linear2_bias = encoder_layer.state_dict()['linear2.bias']
linear1 = torch.matmul(norm1,linear1_weight.T) + linear1_bias
linear1_relu = torch.nn.functional.relu(linear1)
linear2 = torch.matmul(linear1_relu,linear2_weight.T) + linear2_bias# 第二次加
second_Add = norm1 + linear2# 第二次layer_norm
norm2_mean = torch.mean(second_Add,dim=-1,keepdim=True)
norm2_std = torch.sqrt(torch.var(second_Add,unbiased=False,dim=-1,keepdim=True) + 1e-5)
norm2_weight = encoder_layer.state_dict()['norm2.weight']
norm2_bias = encoder_layer.state_dict()['norm2.bias']
norm2 = ((second_Add - norm2_mean)/norm2_std) * norm2_weight + norm2_bias
print(norm2)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/38198.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【VScode】常用配置

1.indenticator 增加白色竖条,显示方法范围 2.Git Graph 给git变换分支增添颜色区分 3.Vue 系列 vue 系列:给纯白色代码添加 颜色区分 3.eslint eslint警告,比如{ } 只写了半个会标红提示错误 等错误信息提示 需要配置js等页面 非下…

1.linux操作系统CPU负载

目录 概述CPU平均负载查看平均负载结束 概述 CPU 使用率 和CPU 平均使用率。 CPU平均负载 单位时间内系统处于 [可运行状态] 和 [不可中断状态] 的平均进程数,就是平均活跃进程数,和CPU使用率并没有直接关系 可运行状态 正在使用CPU或者正等待CPU的进…

【Elasticsearch】linux使用supervisor常驻Elasticsearch,centos6.10安装 supervisor

背景: linux服务器,CentOS 6操作系统,默认版本python2.6.6,避免安装过多的依赖不升级python 在网上查的资料python2.6.6兼容supervisor版本 3.1.3 安装supervisor 手动在python官网下载supervisor,并上传到服务器 下…

量化交易心法——如何建立自己的算法交易事业

量化交易,也称算法交易,是严格按照将计算机算法程序给出的买卖决策进行的证券交易。 一、 什么人适合成为量化交易员 做量化交易并不一定需要特别高的学历,只要具备一定的金融学以及统计学知识,有一定的经济基础,不需要用交易的收益来维持日常生活,因为并不是很快就能找…

Linux_动、静态库

目录 一、静态库 1、静态库的概念 2、制作静态库的指令 3、制作静态库 4、链接静态库 二、动态库 1、动态库的概念 2、制作动态库的指令 3、制作动态库 4、链接动态库 5、动态库的加载 三、静态库与动态库的区别 结语 前言: 在Linux下大部分程序进…

第2章 数据存储篇

目录 2.1 MongoDB:面向文档的灵活存储 2.1.1 MongoDB基础与架构 2.1.1.1基本概念 2.1.1.2MongoDB安装与配置 1)安装MongoDB-Linux安装示例(以Ubuntu为例) 2)更新包列表并安装MongoDB 3)启动MongoDB服…

利用OPT算法解决最短访问次数问题

一、题目 数据库缓存,模拟访问规则如下: 当查询值在缓存中,直接访问缓存,不访问数据库。否则,访问数据库,并将值放入缓存。 若缓存已满,则必须删除一个缓存。 给定缓存大小和训练数据&#xff…

对代理模式和动态代理以及AOP的一些理解

代理模式: 代理模式,也叫做静态代理,是一种结构型设计模式,它为其他对象提供了一种代理,以控制对这个对象的访问。 代理模式可以在不修改原有类的情况下,对其功能进行扩展,编译时就确定了代理…

【JavaEE】多线程代码案例(1)

🎏🎏🎏个人主页🎏🎏🎏 🎏🎏🎏JavaEE专栏🎏🎏🎏 🎏🎏🎏上一篇文章:多线程(2…

leetcode每日一练:顺序表OJ题

第一题:移除元素 题目要求:给一个数组nums和一个值val,你需要 原地 移除所有所有数值等于val的元素,并返回移除后数组的新长度。 不要使用额外的数组空间,你必须仅使用0(1)的额外空间并 原地 修改输入数组。 元素的顺序…

【Tools】AIGC:人工智能生成内容的新时代

那年夏天我和你躲在 这一大片宁静的海 直到后来我们都还在 对这个世界充满期待 今年冬天你已经不在 我的心空出了一块 很高兴遇见你 让我终究明白 回忆比真实精彩 🎵 王心凌《那年夏天宁静的海》 随着人工智能(AI)技术的…

三生随记——午夜咖啡馆

在城市的边缘,隐藏着一间古老的咖啡馆——“午夜咖啡馆”。它的外观不起眼,却总能在夜晚吸引那些寻找安宁或寻求刺激的顾客。据说,咖啡馆的老板是一位年长的绅士,他的脸上总是挂着神秘莫测的微笑。 艾米是一名作家,常常…

基于weixin小程序智慧物业系统的设计

管理员账户功能包括:系统首页,个人中心,管理员管理,用户管理,员工管理,房屋管理,缴费管理,车位管理,报修管理 工作人员账号功能包括:系统首页,维…

使用electron打包Vue前端项目的详细流程

使用electron打包Vue前端项目的详细流程 需要更改的东西 路由模式的修改 # 修改前:url不带#mode: history# 修改后:url带#mode: hash全局修改Cookies为localStorage 由于打包成exe或deb这类可执行文件后,本地是没有 Cookies 全局搜索Cooki…

Android Studio环境搭建(4.03)和报错解决记录

1.本地SDK包导入 安装好IDE以及下好SDK包后,先不要管IDE的引导配置,直接新建一个新工程,进到开发界面。 SDK路径配置:File---->>Other Settings---->>Default Project Structure 拷贝你SDK解压的路径来这,…

ros笔记01--初次体验ros2

ros笔记01--初次体验ros2 介绍安装ros2测试验证ros2说明 介绍 机器人操作系统(ROS)是一组用于构建机器人应用程序的软件库和工具。从驱动程序和最先进的算法到强大的开发者工具,ROS拥有我们下一个机器人项目所需的开源工具。 当前ros已经应用到各类机器人项目开发中…

ORACLE 、达梦 数据库查询指定库指定表的索引信息

在Oracle数据库中,索引是一种关键的性能优化工具,通过它可以加快数据检索速度。在本文中,我们将深入探讨如何详细查询指定表的索引信息,以及如何利用系统视图和SQL查询来获取这些信息。 索引在数据库中的重要性 索引是一种数据结…

操作符详解(下) (C语言)

操作符详解下 操作符的属性1.优先级2.结合级 表达式求值1.整型提升2.如何进行整形提升呢?3.算术转换4.问题表达式解析 操作符的属性 C语言的操作符有2个重要的属性:优先级、结合性,这两个属性决定了表达式求值的计算顺序。 1.优先级 优先级…

问题:第一次世界大战的起止时间是 #其他#学习方法#微信

问题:第一次世界大战的起止时间是 A.1913 ~1918 年 B.1913 ~1918 年 C.1914 ~1918 年 D.1914 ~1919 年 参考答案如图所示

RabbitMQ 进程内流控(Flow Control) 源码解析

1. 概述 1.1 为什么要流控? 流控主要是为了防止生产者生产消息速度过快,超过 Broker 可以处理的速度。这时需要暂时限制生产者的生产速度,让 Broker 的处理能够跟上生产速度。 Erlang进程之间不共享内存,每个进程都有自己的进程邮…