Transformer详解encoder

目录

1. Input Embedding

2. Positional Encoding

3. Multi-Head Attention

4. Add & Norm

5. Feedforward + Add & Norm

6.代码展示

(1)layer_norm

(2)encoder_layer=1


最近刚好梳理了下transformer,今天就来讲讲它~

        Transformer是谷歌大脑2017年在论文attention is all you need中提出来的seq2seq模型,它的本质就是由编码器和解码器组成,今天的主角则是其中的编码器(在BERT预训练模型中也只用到了编码器部分)如下图所示,这个模块的输入为 𝑋 (每一行代表一个句子,batchsize有多大就有多少行),我们将从输入到隐藏层按照从1到4的顺序逐层来看一下各个维度的变化。

1. Input Embedding

        所谓的Embedding其实就是查字典或者叫查表,也就是将一个句子里的每一个字转化为一个维度为embedding dimension的向量来表示,因此 𝑋 经过嵌入后变成 𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔 ,三个维度分别表示一个批次的句子数,每个句子的字数,每个字的嵌入维度。

2. Positional Encoding

        位置编码,按照字面意思理解就是给输入的位置做个标记,简单理解比如你就给一个字在句子中的位置编码1,2,3,4这样下去,高级点的比如作者用的正余弦函数

𝑃𝐸(𝑝𝑜𝑠,2𝑖)=𝑠𝑖𝑛(𝑝𝑜𝑠/100002𝑖/𝑑𝑚𝑜𝑑𝑒𝑙)

𝑃𝐸(𝑝𝑜𝑠,2𝑖+1)=𝑐𝑜𝑠(𝑝𝑜𝑠/100002𝑖/𝑑𝑚𝑜𝑑𝑒𝑙)

 

        其中pos表示字在句子中的位置,i指的词向量的维度。经过位置编码,相当于能够得到一个和输入维度完全一致的编码数组 𝑋𝑝𝑜𝑠 ,当它叠加到原来的词嵌入上得到新的词嵌入

𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔=𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔+𝑋𝑝𝑜𝑠

        此时的维度为:一个批次的句子数 × 一个句子的词数 × 一个词的嵌入维度

3. Multi-Head Attention

        注意力机制,其实可以理解为就是在计算相关性,很自然的想法就是去更多地关注那些相关更大的东西。这里首先要引入Query,Key和Value的概念,Query就是查询的意思,Key就是键用来和你要查询的Query做比较,比较得到一个分数(相关性或者相似度)再乘以Value这个值得到最终的结果。

        那么这个Q,K,V从哪里来呢,这里采用的是self-attention的方式,也就是从输入自己 𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔 来产生,即做线性映射产生Q,K,V:

𝑄=𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔∗𝑊𝑄𝐾=𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔∗𝑊𝐾𝑉=𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔∗𝑊𝑉

        这里三个权重矩阵均为维度为Embedding的方阵,也就是说Q,K,V的维度和 𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔 是一致的。

        接下来考虑什么叫做multi-head(多头)呢,本质上就是从embedding的维度上将矩阵切分为多份,每一份就是一个头,比如之前的Q,K,V切完后的维度就是一个批次的句子数 × 一个句子的词数 × 头数 × (词嵌入维度/头数)这个多头的切分体现在最后两个维度:词嵌入维度=数 × (词嵌入维度/头数)为了便于计算,通常会将第二第三维度进行转置,即最终的维度为一个批次的句子数 × 头数 × 一个句子的词数 × (词嵌入维度/头数)

        接下来说说注意力机制的计算,假设Q,K,V为切分完后的矩阵(其中一个头),根据两个向量的点积越大越相似,我们通过 𝑄𝐾𝑇 求出注意力矩阵,再根据注意力矩阵来给Value进行加权,即

𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛(𝑄,𝐾,𝑉)=𝑠𝑜𝑓𝑡𝑚𝑎𝑥(𝑄𝐾𝑇𝑑𝑘)𝑉

        其中 𝑑𝑘 是为了把注意力矩阵变成标准正态分布,softmax进行归一化,使每个字与其他字的注意力权重之和为1。这一操作使得每一个字的嵌入都包含当前句子内所有字的信息,注意Attention(Q,K,V)的维度和 𝑉 的维度保持一致。

4. Add & Norm

这里主要做了两个操作

  • 一个是残差连接(或者叫做短路连接),说得直白点就是把上一层的输入 𝑋 和上一层的输出加起来 𝑆𝑢𝑏𝐿𝑎𝑦𝑒𝑟(𝑋) ,即 𝑋+𝑆𝑢𝑏𝐿𝑎𝑦𝑒𝑟(𝑋) ,举例说明,比如在注意力机制前后的残差连接:

𝑋𝑒𝑚𝑏𝑒𝑑𝑑𝑖𝑛𝑔+𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛(𝑄,𝐾,𝑉)

  • 一个是LayerNormalization(作用是把神经网络中隐藏层归一为标准正态分布,加速收敛),具体操作是将每一行每一个元素减去这行的均值, 再除以这行的标准差, 从而得到归一化后的数值。

5. Feedforward + Add & Norm

前馈网络也就是简单的两层线性映射再经过激活函数一下,比如

𝑋ℎ𝑖𝑑𝑑𝑒𝑛=𝑅𝑒𝑙𝑢(𝑋ℎ𝑖𝑑𝑑𝑒𝑛∗𝑊1∗𝑊2)

残差操作和层归一化同步骤3.


上述的1,2,3,4就构成Transformer中的一个encoder模块,经过1,2,3,4后得到的就是encode后的隐藏层表示,可以发现它的维度其实和输入是一致的!即:一个批次中句子数 × 一个句子的字数 × 字嵌入的维度

6.代码展示

(1)layer_norm

bs=2,seq=3,dim=5

import torchbatch_size = 2
seq = 3
fea_dim = 5
X = torch.rand(batch_size,seq,fea_dim)
layer_norm = torch.nn.LayerNorm(fea_dim)
out = layer_norm(X)
print(out)
print('-'*30)mean = torch.mean(X,dim=-1,keepdim=True)
std = torch.sqrt(torch.var(X,unbiased=False,dim=-1,keepdim=True) + 1e-5)
weight = layer_norm.state_dict()['weight']
bias = layer_norm.state_dict()['bias']
my_norm = ((X - mean)/std) * weight + bias
print(my_norm)

(2)encoder_layer=1

bs=1,seq=1,dim=6,head=1

import torchseq = 1
dim = 6
heads = 1
batch_size = 1
value = torch.rand(batch_size,seq,dim)encoder_layer = torch.nn.TransformerEncoderLayer(dim,heads,dropout=0.0,batch_first=True)
out = encoder_layer(value)
print(out)# 多头自注意力
def my_scaled_dot_product(query,key,value):qk_T = torch.mm(query,key.T)qk_T_scale = qk_T / torch.sqrt(torch.tensor(value.shape[1]))qk_exp = torch.exp(qk_T_scale)qk_exp_sum = torch.sum(qk_exp,dim=1,keepdim=True)qk_softmax = qk_exp / qk_exp_sumv_attn = torch.mm(qk_softmax,value)return v_attn,qk_softmaxin_proj_weight = encoder_layer.state_dict()['self_attn.in_proj_weight']
in_proj_bias = encoder_layer.state_dict()['self_attn.in_proj_bias']out_proj_weight = encoder_layer.state_dict()['self_attn.out_proj.weight']
out_proj_bias = encoder_layer.state_dict()['self_attn.out_proj.bias']batch_V_output = torch.empty(batch_size,seq,dim)
for i in range(batch_size):in_proj = torch.mm(value[i],in_proj_weight.T) + in_proj_biasQs,Ks,Vs = torch.split(in_proj,dim,dim=-1)head_Vs = []attn_weight = torch.zeros(seq,seq)for Q,K,V in zip(torch.split(Qs,dim//heads,dim=-1),torch.split(Ks,dim//heads,dim=-1),torch.split(Vs,dim//heads,dim=-1)):head_v,_ = my_scaled_dot_product(Q,K,V)head_Vs.append(head_v)V_cat = torch.cat(head_Vs,dim=-1)V_ouput = torch.mm(V_cat,out_proj_weight.T) + out_proj_biasbatch_V_output[i] = V_ouput# 第一次加
first_Add = value + batch_V_output# 第一次layer_norm
norm1_mean = torch.mean(first_Add,dim=-1,keepdim=True)
norm1_std = torch.sqrt(torch.var(first_Add,unbiased=False,dim=-1,keepdim=True) + 1e-5)
norm1_weight = encoder_layer.state_dict()['norm1.weight']
norm1_bias = encoder_layer.state_dict()['norm1.bias']
norm1 = ((first_Add - norm1_mean)/norm1_std) * norm1_weight + norm1_bias# feed forward
linear1_weight = encoder_layer.state_dict()['linear1.weight']
linear1_bias = encoder_layer.state_dict()['linear1.bias']
linear2_weight = encoder_layer.state_dict()['linear2.weight']
linear2_bias = encoder_layer.state_dict()['linear2.bias']
linear1 = torch.matmul(norm1,linear1_weight.T) + linear1_bias
linear1_relu = torch.nn.functional.relu(linear1)
linear2 = torch.matmul(linear1_relu,linear2_weight.T) + linear2_bias# 第二次加
second_Add = norm1 + linear2# 第二次layer_norm
norm2_mean = torch.mean(second_Add,dim=-1,keepdim=True)
norm2_std = torch.sqrt(torch.var(second_Add,unbiased=False,dim=-1,keepdim=True) + 1e-5)
norm2_weight = encoder_layer.state_dict()['norm2.weight']
norm2_bias = encoder_layer.state_dict()['norm2.bias']
norm2 = ((second_Add - norm2_mean)/norm2_std) * norm2_weight + norm2_bias
print(norm2)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/38198.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【VScode】常用配置

1.indenticator 增加白色竖条,显示方法范围 2.Git Graph 给git变换分支增添颜色区分 3.Vue 系列 vue 系列:给纯白色代码添加 颜色区分 3.eslint eslint警告,比如{ } 只写了半个会标红提示错误 等错误信息提示 需要配置js等页面 非下…

1.linux操作系统CPU负载

目录 概述CPU平均负载查看平均负载结束 概述 CPU 使用率 和CPU 平均使用率。 CPU平均负载 单位时间内系统处于 [可运行状态] 和 [不可中断状态] 的平均进程数,就是平均活跃进程数,和CPU使用率并没有直接关系 可运行状态 正在使用CPU或者正等待CPU的进…

【Elasticsearch】linux使用supervisor常驻Elasticsearch,centos6.10安装 supervisor

背景: linux服务器,CentOS 6操作系统,默认版本python2.6.6,避免安装过多的依赖不升级python 在网上查的资料python2.6.6兼容supervisor版本 3.1.3 安装supervisor 手动在python官网下载supervisor,并上传到服务器 下…

Linux_动、静态库

目录 一、静态库 1、静态库的概念 2、制作静态库的指令 3、制作静态库 4、链接静态库 二、动态库 1、动态库的概念 2、制作动态库的指令 3、制作动态库 4、链接动态库 5、动态库的加载 三、静态库与动态库的区别 结语 前言: 在Linux下大部分程序进…

【JavaEE】多线程代码案例(1)

🎏🎏🎏个人主页🎏🎏🎏 🎏🎏🎏JavaEE专栏🎏🎏🎏 🎏🎏🎏上一篇文章:多线程(2…

Android Studio环境搭建(4.03)和报错解决记录

1.本地SDK包导入 安装好IDE以及下好SDK包后,先不要管IDE的引导配置,直接新建一个新工程,进到开发界面。 SDK路径配置:File---->>Other Settings---->>Default Project Structure 拷贝你SDK解压的路径来这,…

ros笔记01--初次体验ros2

ros笔记01--初次体验ros2 介绍安装ros2测试验证ros2说明 介绍 机器人操作系统(ROS)是一组用于构建机器人应用程序的软件库和工具。从驱动程序和最先进的算法到强大的开发者工具,ROS拥有我们下一个机器人项目所需的开源工具。 当前ros已经应用到各类机器人项目开发中…

操作符详解(下) (C语言)

操作符详解下 操作符的属性1.优先级2.结合级 表达式求值1.整型提升2.如何进行整形提升呢?3.算术转换4.问题表达式解析 操作符的属性 C语言的操作符有2个重要的属性:优先级、结合性,这两个属性决定了表达式求值的计算顺序。 1.优先级 优先级…

问题:第一次世界大战的起止时间是 #其他#学习方法#微信

问题:第一次世界大战的起止时间是 A.1913 ~1918 年 B.1913 ~1918 年 C.1914 ~1918 年 D.1914 ~1919 年 参考答案如图所示

RabbitMQ 进程内流控(Flow Control) 源码解析

1. 概述 1.1 为什么要流控? 流控主要是为了防止生产者生产消息速度过快,超过 Broker 可以处理的速度。这时需要暂时限制生产者的生产速度,让 Broker 的处理能够跟上生产速度。 Erlang进程之间不共享内存,每个进程都有自己的进程邮…

42.HOOK引擎核心代码

上一个内容:41.HOOK引擎设计原理 以 40.设计HOOK引擎的好处 它的代码为基础进行修改 主要做的是读写寄存器 效果图 添加一个类 htdHook.h文件中的实现 #pragma once class htdHook { public:htdHook(); };htdHook.cpp文件中的实现: #include "…

独辟蹊径:我是如何用Java自创一套工作流引擎的(下)

作者:后端小肥肠 创作不易,未经允许严禁转载。 姊妹篇:独辟蹊径:我是如何用Java自创一套工作流引擎的(上)_java工作流引擎-CSDN博客 1. 前言 在上一篇博客中,我们详细介绍了如何利用Java语言从…

LC437.路径总和Ⅲ、LC207.课程表

给定一个二叉树的根节点 root ,和一个整数 targetSum ,求该二叉树里节点值之和等于 targetSum 的 路径 的数目。 路径 不需要从根节点开始,也不需要在叶子节点结束,但是路径方向必须是向下的(只能从父节点到子节点&am…

【每日刷题】Day77

【每日刷题】Day77 🥕个人主页:开敲🍉 🔥所属专栏:每日刷题🍍 🌼文章目录🌼 1. LCR 159. 库存管理 III - 力扣(LeetCode) 2. LCR 075. 数组的相对排序 - 力…

Linux——passwd文件,grep,cut

/etc/passwd文件含义 作用 - 记录用户账户信息:共分为7段,使用冒号分割 含义 - 文件内容意义:账户名:密码代号x:UID:GID:注释:家目录:SHELL - 第7列/sbin/nologin&#x…

c++(三)

1. STL 1.1. 迭代器 迭代器是访问容器中元素的通用方法。如果使用迭代器,不同的容器,访问元素的方法是相同的;迭代器支持的基本操作:赋值()、解引用(*)、比较(和!&…

隐私计算实训营第二期第11讲组件介绍与自定义开发

首先给大家推荐一个博主的笔记:隐语课程学习笔记11- 组件介绍与自定义开发-CSDN博客 官方文档:隐语开放标准介绍 | 开放标准 v1.0.dev240328 | 隐语 SecretFlow 01 隐语开放标准 隐语开放标准是为隐私保护应用设计的协议栈。 目前,隐语开…

mysql8.0-学习

文章目录 mysql8.0基础知识-学习安装mysql_8.0登录mysql8.0的体系结构与管理体系结构图连接mysqlmysql8.0的 “新姿势” mysql的日常管理用户安全权限练习查看用户的权限回收:revoke角色 mysql的多种连接方式socket显示系统中当前运行的所有线程 tcp/ip客户端工具基于SSL的安全…

Wp-scan一键扫描wordpress网页(KALI工具系列三十二)

目录 1、KALI LINUX 简介 2、Wp-scan工具简介 3、信息收集 3.1 目标IP(服务器) 3.2kali的IP 4、操作实例 4.1 基本扫描 4.2 扫描已知漏洞 4.3 扫描目标主题 4.4 列出用户 4.5 输出扫描文件 4.6 输出详细结果 5、总结 1、KALI LINUX 简介 Kali Linux 是一…

STM32CubeIDE使用标准库

以STM32F030为例使用标准库文件。 1、新建工程,选择需要使用配置的型号。 2、在工程选型中,选择新建空工程。 3、新建空工程 新建完成,系统只有main函数和启动文件.s有用。 4、启动文件 启动文件,同样是一个汇编文件&#xf…