【原创】实现ChatGPT中Transformer模型之Encoder-Decoder

作者:黑夜路人

时间:2023年7月

Transformer Block (通用块)实现

看以上整个链路图,其实我们可以很清晰看到这心其实在Encoder环节里面主要是有几个大环节,每一层主要的核心作用如下:

  • Multi-headed self Attention(注意力机制层):通过不同的注意力函数并拼接结果,提高模型的表达能力,主要计算词与词的相关性和长距离次的相关性。
  • Normalization layer(归一化层):对每个隐层神经元进行归一化,使其特征值的均值为0,方差为1,解决梯度爆炸和消失问题。通过归一化,可以将数据压缩在一个合适范围内,避免出现超大或超小值,有利于模型训练,也能改善模型的泛化能力和加速模型收敛以及减少参数量的依赖。
  • Feed forward network(前馈神经网络):对注意力输出结果进行变换。
  • Another normalization layer(另一个归一化层):Weight Normalization用于对模型中层与层之间的权重矩阵进行归一化,其主要目的是解决梯度消失问题

注意力(Attention)实现

参考其他我们了解的信息可以看到里面核心的自注意力层等等,我们把每个层剥离看看核心这一层应该如何实现。

简单看看注意力的计算过程:

这张图所表示的大致运算过程是:

对于每个token,先产生三个向量query,key,value:

query向量类比于询问。某个token问:“其余的token都和我有多大程度的相关呀?”

key向量类比于索引。某个token说:“我把每个询问内容的回答都压缩了下装在我的key里”

value向量类比于回答。某个token说:“我把我自身涵盖的信息又抽取了一层装在我的value里”

注意力计算代码:

def attention(query: Tensor,key: Tensor,value: Tensor,mask: Optional[Tensor] = None,dropout: float = 0.1):"""定义如何计算注意力得分参数:query: shape (batch_size, num_heads, seq_len, k_dim)key: shape(batch_size, num_heads, seq_len, k_dim)value: shape(batch_size, num_heads, seq_len, v_dim)mask: shape (batch_size, num_heads, seq_len, seq_len). Since our assumption, here the shape is(1, 1, seq_len, seq_len)返回:out: shape (batch_size, v_dim). 注意力头的输出。注意力分数:形状(seq_len,seq_ln)。"""k_dim = query.size(-1)# shape (seq_len ,seq_len),row: token,col: token记号的注意力得分scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(k_dim)if mask is not None:scores = scores.masked_fill(mask == 0, -1e10)attention_score = F.softmax(scores, dim=-1)if dropout is not None:attention_score = dropout(attention_score)out = torch.matmul(attention_score, value)return out, attention_score  # shape: (seq_len, v_dim), (seq_len, seq_lem)

以图中的token a2为例:

它产生一个query,每个query都去和别的token的key做“某种方式”的计算,得到的结果我们称为attention score(即为图中的$$\alpha $$)。则一共得到四个attention score。(attention score又可以被称为attention weight)。

将这四个score分别乘上每个token的value,我们会得到四个抽取信息完毕的向量。

将这四个向量相加,就是最终a2过attention模型后所产生的结果b2。

整个这一层,我们通过代码来进行这个逻辑的简单实现:

class MultiHeadedAttention(nn.Module):def __init__(self,num_heads: int,d_model: int,dropout: float = 0.1):super(MultiHeadedAttention, self).__init__()assert d_model % num_heads == 0, "d_model must be divisible by num_heads"# 假设v_dim总是等于k_dimself.k_dim = d_model // num_headsself.num_heads = num_headsself.proj_weights = clones(nn.Linear(d_model, d_model), 4)  # W^Q, W^K, W^V, W^Oself.attention_score = Noneself.dropout = nn.Dropout(p=dropout)def forward(self,query: Tensor,key: Tensor,value: Tensor,mask: Optional[Tensor] = None):"""参数:query: shape (batch_size, seq_len, d_model)key: shape (batch_size, seq_len, d_model)value: shape (batch_size, seq_len, d_model)mask: shape (batch_size, seq_len, seq_len). 由于我们假设所有数据都使用相同的掩码,因此这里的形状也等于(1,seq_len,seq-len)返回:out: shape (batch_size, seq_len, d_model). 多头注意力层的输出"""if mask is not None:mask = mask.unsqueeze(1)batch_size = query.size(0)# 1) 应用W^Q、W^K、W^V生成新的查询、键、值query, key, value \= [proj_weight(x).view(batch_size, -1, self.num_heads, self.k_dim).transpose(1, 2)for proj_weight, x in zip(self.proj_weights, [query, key, value])]  # -1 equals to seq_len# 2) 计算注意力得分和outout, self.attention_score = attention(query, key, value, mask=mask,dropout=self.dropout)# 3) "Concat" 输出out = out.transpose(1, 2).contiguous() \.view(batch_size, -1, self.num_heads * self.k_dim)# 4) 应用W^O以获得最终输出out = self.proj_weights[-1](out)return out

Norm 归一化层实现

# 归一化层,标准化的计算公式
class NormLayer(nn.Module):def __init__(self, features, eps=1e-6):super(LayerNorm, self).__init__()self.a_2 = nn.Parameter(torch.ones(features))self.b_2 = nn.Parameter(torch.zeros(features))self.eps = epsdef forward(self, x):mean = x.mean(-1, keepdim=True)std = x.std(-1, keepdim=True)return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

前馈神经网络实现

class FeedForward(nn.Module):def __init__(self, d_model, d_ff=2048, dropout=0.1):super().__init__()# 设置 d_ff 缺省值为2048self.linear_1 = nn.Linear(d_model, d_ff)self.dropout = nn.Dropout(dropout)self.linear_2 = nn.Linear(d_ff, d_model)def forward(self, x):x = self.dropout(F.relu(self.linear_1(x)))x = self.linear_2(x)

Encoder(编码器曾)实现

Encoder 就是将前面介绍的整个链路部分,全部组装迭代起来,完成将源编码到中间编码的转换。


class EncoderLayer(nn.Module):def __init__(self, d_model, heads, dropout=0.1):super().__init__()self.norm_1 = Norm(d_model)self.norm_2 = Norm(d_model)self.attn = MultiHeadAttention(heads, d_model, dropout=dropout)self.ff = FeedForward(d_model, dropout=dropout)self.dropout_1 = nn.Dropout(dropout)self.dropout_2 = nn.Dropout(dropout)def forward(self, x, mask):x2 = self.norm_1(x)x = x + self.dropout_1(self.attn(x2, x2, x2, mask))x2 = self.norm_2(x)x = x + self.dropout_2(self.ff(x2))return x

class Encoder(nn.Module):def __init__(self, vocab_size, d_model, N, heads, dropout):super().__init__()self.N = Nself.embed = Embedder(d_model, vocab_size)self.pe = PositionalEncoder(d_model, dropout=dropout)self.layers = get_clones(EncoderLayer(d_model, heads, dropout), N)self.norm = Norm(d_model)def forward(self, src, mask):x = self.embed(src)x = self.pe(x)for i in range(self.N):x = self.layers[i](x, mask)return self.norm(x)

Decoder(解码器层)实现

Decoder部分和 Encoder 的部分非常的相似,它主要是把 Encoder 生成的中间编码,转换为目标编码。

class DecoderLayer(nn.Module):def __init__(self, d_model, heads, dropout=0.1):super().__init__()self.norm_1 = Norm(d_model)self.norm_2 = Norm(d_model)self.norm_3 = Norm(d_model)self.dropout_1 = nn.Dropout(dropout)self.dropout_2 = nn.Dropout(dropout)self.dropout_3 = nn.Dropout(dropout)self.attn_1 = MultiHeadAttention(heads, d_model, dropout=dropout)self.attn_2 = MultiHeadAttention(heads, d_model, dropout=dropout)self.ff = FeedForward(d_model, dropout=dropout)def forward(self, x, e_outputs, src_mask, trg_mask):x2 = self.norm_1(x)x = x + self.dropout_1(self.attn_1(x2, x2, x2, trg_mask))x2 = self.norm_2(x)x = x + self.dropout_2(self.attn_2(x2, e_outputs, e_outputs,src_mask))x2 = self.norm_3(x)x = x + self.dropout_3(self.ff(x2))return x

class Decoder(nn.Module):def __init__(self, vocab_size, d_model, N, heads, dropout):super().__init__()self.N = Nself.embed = Embedder(vocab_size, d_model)self.pe = PositionalEncoder(d_model, dropout=dropout)self.layers = get_clones(DecoderLayer(d_model, heads, dropout), N)self.norm = Norm(d_model)def forward(self, trg, e_outputs, src_mask, trg_mask):x = self.embed(trg)x = self.pe(x)for i in range(self.N):x = self.layers[i](x, e_outputs, src_mask, trg_mask)return self.norm(x)

Transformer 实现

把整个链路结合,包括Encoder和Decoder,最终就能够形成一个Transformer框架的基本MVP实现。

class Transformer(nn.Module):def __init__(self, src_vocab, trg_vocab, d_model, N, heads, dropout):super().__init__()self.encoder = Encoder(src_vocab, d_model, N, heads, dropout)self.decoder = Decoder(trg_vocab, d_model, N, heads, dropout)self.out = nn.Linear(d_model, trg_vocab)def forward(self, src, trg, src_mask, trg_mask):e_outputs = self.encoder(src, src_mask)d_output = self.decoder(trg, e_outputs, src_mask, trg_mask)output = self.out(d_output)return output

代码说明

如果想要学习阅读整个代码,访问 black-transformer 项目,github访问地址:

GitHub - heiyeluren/black-transformer: black-transformer 是一个轻量级模拟Transformer模型实现的概要代码,用于了解整个Transformer工作机制

取代你的不是AI,而是比你更了解AI和更会使用AI的人!

##End##

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/4687.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

中小企业部署MES管理系统需要考虑哪些问题

随着制造业的快速发展,越来越多的中小企业开始意识到数字化管理的重要性。为了提高生产效率、降低成本、提升品质及满足客户需求,部署MES生产管理系统成为了中小企业实现数字化转型的关键一步。然而,在部署MES管理系统时,中小企业…

[QT编程系列-11]:C++图形用户界面编程,QT框架快速入门培训 - 5- QT主要控件与自定义控件

目录 5. QT主要控件 5.1 预定义控件 5.2 自定义控件 5.3 用预定义容器橙子和提升自定义控件 5.3 后记 5. QT主要控件 5.1 预定义控件 在Qt中,有许多预定义的控件(Widgets)可用于创建用户界面。这些控件提供了各种常见的用户界面元素&am…

测试基础 Android 应用测试总结

目录 启动: 功能介绍,引导图,流量提示等: 权限: 文件错误 屏幕旋转: 流量: 缓存(/sdcard/data/com.your.package/cache/): 正常中断: 异…

被B站用户高赞的广告文案:暴涨900万播放

今年6月,B站公布第一季度财报数据,B站日均活跃用户达9370万,月活3.15亿。在高月活的基础上,用户日均使用时长已经到了96分钟,日均视频播放量达41亿。 来源-B站 用户属性年轻、活跃度高已经成为B站典型的平台标签&…

深入篇【C++】谈vector中的深浅拷贝与迭代器失效问题

深入篇【C】谈vector中的深浅拷贝与迭代器失效问题 Ⅰ.深浅拷贝问题1.内置类型深拷贝2.自定义类型深拷贝 Ⅱ.迭代器失效问题1.内部迭代器失效2.外部迭代器失效 Ⅰ.深浅拷贝问题 1.内置类型深拷贝 浅拷贝是什么意思?就是单纯的值拷贝。 浅拷贝的坏处: ①…

【状态估计】基于UKF法、AUKF法、EUKF法电力系统三相状态估计研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

Linux 系统编程-开发环境(二)

目录 7 压缩包管理 7.1 tar 7.2 rar 7.3 zip 8 进程管理 8.1 who 8.2 ps 8.3 jobs 8.4 fg 8.5 bg 8.6 kill 8.7 env 8.8 top 9 用户管理 9.1 创建用户 9.2 设置用户组 9.3 设置密码 9.4 切换用户 9.5 root用户 9.6 删除用户 10 网络管理 10.1 i…

那些漏洞挖掘高手都是怎么挖漏洞的?

前言 说到安全就不能不说漏洞,而说到漏洞就不可避免地会说到三座大山: 漏洞分析 漏洞利用 漏洞挖掘 从个人的感觉上来看,这三者尽管通常水乳交融、相互依赖,但难度是不尽相同的。本文就这三者分别谈谈自己的经验和想法。 漏洞分析…

236. 二叉树的最近公共祖先

题目描述: 主要思路: 利用dfs遍历树,依次判断节点是否为公共祖先节点。 class Solution { public:TreeNode* ans;bool dfs(TreeNode* root, TreeNode* p, TreeNode* q){if(!root)return false;bool ldfs(root->left,p,q);bool rdfs(root…

【PostgreSQL内核学习(二)—— 查询分析】

查询分析 查询处理查询分析查询处理与查询分析的关系查询分析执行流程Lex和YaccLex:Yacc:词法分析工具Lex语法分析工具Yacc使用Lex和Yacc的案例 词法和语法分析以SELECT语句为例讲解 PostgreSQL中查询语句如何被解析并生成分析树。 语义分析 声明&#x…

【ArcGIS微课1000例】0070:制作宾馆酒店分布热度热力图

本文讲解在ArcGIS中,基于长沙市酒店宾馆分布矢量点数据(POI数据)绘制酒店分布热力图。 相关阅读: 【GeoDa实用技巧100例】004:绘制长沙市宾馆热度图 【ArcGIS Pro微课1000例】0028:绘制酒店分布热力图(POI数据) 文章目录 一、加载宾馆分布数据二、绘制热度图一、加载宾…

stm32(HAL库)使用printf函数打印到串口

目录 1、简介 2.1 基础配置 2.1.1 SYS配置 2.1.2 RCC配置 2.2 串口外设配置 2.3 项目生成 3、KEIL端程序整合 4、效果测试 1、简介 在HAL库中,常用的printf函数是无法使用的。本文通过重映射实现在HAL库中进行printf函数。 2.1 基础配置 2.1.1 SYS配置 2.1.2 …

ceph集群的维护

ceph集群的维护 1、ceph集群常用命令 1.1查看集群的状态 rootceph-mon1:~#ceph -s#或者 rootceph-mon1:~#ceph health detail #显示集群状态的详细信息1.2查看所有存储池的列表 rootceph-mon1:~# ceph osd pool ls1.3查看所有存储池的编号 rootceph-mon1:~# ceph osd ls…

Linux gdb汇编调试

文章目录 一、示例代码二、gdb汇编指令2.1 step/stepi2.2 next/nexti2.3 info registers2.4 set2.5 x2.6 rsp寄存器2.7 rip 寄存器 参考资料 一、示例代码 &#xff08;1&#xff09; #include <stdio.h>int add(int a, int b) {return a b; }int main() {int a 3;in…

【Python】数据可视化利器PyCharts在测试工作中的应用

点击跳转原文&#xff1a;【Python】数据可视化利器PyCharts在测试工作中的应用 实际应用&#xff1a;常态化性能压测数据统计 import random from pyecharts.charts import Line, Bar, Grid, Pie, Page from pyecharts import options as opts # 查询过去 8 次数据 time_rang…

MVVM 实现记录文本

1. MVVM 框架说明: Model - 数据层 View - 视图层 ViewModel - 管理模型的视图 2. 资源文件 2.1 启动图标: AppIconhttps://img-blog.csdnimg.cn/8fa1031489f544ef9757b6b3ab0eddbe.png 2.2 Display Name: Do Stuff 2.2 颜色图: 2.3 项目结构图: 3. Model 层实现&a…

组合(力扣)dfs + 回溯 + 剪枝 JAVA

给定两个整数 n 和 k&#xff0c;返回范围 [1, n] 中所有可能的 k 个数的组合。 你可以按 任何顺序 返回答案。 示例 1&#xff1a; 输入&#xff1a;n 4, k 2 输出&#xff1a; [ [2,4], [3,4], [2,3], [1,2], [1,3], [1,4], ] 示例 2&#xff1a; 输入&#xff1a;n 1, …

鲸鱼优化算法MATLAB代码

论文 Seyedali Mirjalili,Andrew Lewis. The Whale Optimization Algorithm[J]. Advances in Engineering Software,2016,95.func_plot.m % This function draw the benchmark functionsfunction func_plot(func_name)[lb,ub,dim,fobj]Get_Functions_details(func_name);switch…

数据结构(王道)——线性表之静态链表顺序表和链表的比较

一、静态链表 定义&#xff1a; 代码实现&#xff1a; 如何定义一个静态链表 静态链表的基本操作思路&#xff1a; 初始化静态链表&#xff1a; 静态链表的查找、插入、删除 静态链表总结&#xff1a; 二、顺序表和链表的比较 逻辑结构对比&#xff1a; 存储结构对比&#xff…

vue3 引入dataV 报错,使用patch-package记录插件包 node_modeule 修改记录。 vite 版DataV

开发数字大屏功能&#xff0c;引用dataV UI组件库比较好用&#xff0c;目前分为Vue2 和 Vue3 两个版本。 Vue2 --DataV版本 yarn add jiaminghi/data-viewVue3 --DataV版本 yarn add dataview/datav-vue3vite – --DataV版本 //不想动手改的&#xff0c;也可以使用此版本&a…