基于位置的前馈神经网络

目录

1、什么是前馈全连接层

2、前馈全连接层的作用

3、代码实现FFN


1、什么是前馈全连接层

在Transformers中前馈全连接层就是具有两层线性层的全连接网络

2、前馈全连接层的作用

考虑注意力机制可能对复杂过程的拟合程度不够,通过增加家两层网络来增强模型的能力

3、代码实现FFN

import torch.nn as nn
import numpy as np
from torch.autograd import Variable
from torch.autograd import Variable
import copy
from torch import tensor, softmax
import math
# 构建Embedding类来实现文本嵌入层
class Embeddings(nn.Module):def __init__(self,vocab,d_model):""":param vocab: 词表的大小:param d_model: 词嵌入的维度"""super(Embeddings,self).__init__()self.lut = nn.Embedding(vocab,d_model)self.d_model = d_modeldef forward(self,x):""":param x: 因为Embedding层是首层,所以代表输入给模型的文本通过词汇映射后的张量:return:"""return self.lut(x) * math.sqrt(self.d_model)
class PositionalEncoding(nn.Module):def __init__(self,d_model,dropout,max_len=5000):""":param d_model: 词嵌入的维度:param dropout: 随机失活,置0比率:param max_len: 每个句子的最大长度,也就是每个句子中单词的最大个数"""super(PositionalEncoding,self).__init__()self.dropout = nn.Dropout(p=dropout)pe = torch.zeros(max_len,d_model) # 初始化一个位置编码器矩阵,它是一个0矩阵,矩阵的大小是max_len * d_modelposition = torch.arange(0,max_len).unsqueeze(1) # 初始一个绝对位置矩阵 max_len * 1div_term = torch.exp(torch.arange(0,d_model,2)*-(math.log(1000.0)/d_model)) # 定义一个变换矩阵,跳跃式的初始化# 将前面定义的变换矩阵进行奇数、偶数的分别赋值pe[:,0::2] = torch.sin(position*div_term)pe[:,1::2] = torch.cos(position*div_term)pe = pe.unsqueeze(0)  # 将二维矩阵扩展为三维和embedding的输出(一个三维向量)相加self.register_buffer('pe',pe) # 把pe位置编码矩阵注册成模型的buffer,对模型是有帮助的,但是却不是模型结构中的超参数或者参数,不需要随着优化步骤进行更新的增益对象。注册之后我们就可以在模型保存后重加载时,将这个位置编码与模型参数一同加载进来def forward(self, x):""":param x: 表示文本序列的词嵌入表示:return: 最后使用self.dropout(x)对对象进行“丢弃”操作,并返回结果"""x = x + Variable(self.pe[:, :x.size(1)],requires_grad = False)   # 不需要梯度求导,而且使用切片操作,因为我们默认的max_len为5000,但是很难一个句子有5000个词汇,所以要根据传递过来的实际单词的个数对创建的位置编码矩阵进行切片操作return self.dropout(x)def subsequent_mask(size):""":param size: 生成向后遮掩的掩码张量,参数 size 是掩码张量的最后两个维度大小,它的最后两个维度形成一个方阵:return:"""attn_shape = (1,size,size) # 定义掩码张量的形状subsequent_mask = np.triu(np.ones(attn_shape),k = 1).astype('uint8') # 定义一个上三角矩阵,元素为1,再使用其中的数据类型变为无符号8位整形return torch.from_numpy(1 - subsequent_mask) # 先将numpy 类型转化为 tensor,再做三角的翻转,将位置为 0 的地方变为 1,将位置为 1 的方变为 0def attention(query, key, value, mask = None, dropout = None):""":param query: 三个张量输入:param key: 三个张量输入:param value: 三个张量输入:param mask: 掩码张量:param dropout: 传入的 dropout 实例化对象:return:"""d_model = query.size(-1)  # 得到词嵌入的维度,取 query 的最后一维大小scores = torch.matmul(query,key.transpose(-2,-1)) / math.sqrt(d_model)     # 按照注意力公式,将 query 和 key 的转置相乘,这里是将 key 的最后两个维度进行转置,再除以缩放系数,得到注意力得分张量 scoresif mask is not None:scores = torch.masked_fill(scores,mask == 0,-1e9)  # 使用 tensor 的 mask_fill 方法,将掩码张量和 scores 张量中每一个位置进行一一比较,如果掩码张量处为 0 ,则使用 -1e9 替换# scores = scores.masked_fill(mask == 0,-1e9)p_attn = softmax(scores, dim = -1) # 对 scores 的最后一维进行 softmax 操作,使用 F.softmax 方法,第一个参数是 softmax 对象,第二个参数是最后一个维度,得到注意力矩阵print('scores.shape ',scores.shape)if dropout is not None:p_attn = dropout(p_attn)return torch.matmul(p_attn,value),p_attn  # 返回注意力表示
class MultiHeadAttention(nn.Module):def __init__(self, head, embedding_dim , dropout=0.1):""":param head: 代表几个头的参数:param embedding_dim: 词向量维度:param dropout: 置零比率"""super(MultiHeadAttention, self).__init__()assert embedding_dim % head == 0     # 确认一下多头的数量可以整除词嵌入的维度 embedding_dimself.d_k = embedding_dim // head  # 每个头获得词向量的维度self.head = headself.linears = nn.ModuleList([copy.deepcopy(nn.Linear(embedding_dim, embedding_dim)) for _ in range(4)])   # 深层拷贝4个线性层,每一个层都是独立的,保证内存地址是独立的,分别是 Q、K、V以及最终的输出线性层self.attn = None   # 初始化注意力张量self.dropout = nn.Dropout(p=dropout)def forward(self, query, key, value, mask=None):""":param query: 查询query [batch size, sentence length, d_model]:param key: 待查询key [batch size, sentence length, d_model]:param value: 待查询value [batch size, sentence length, d_model]:param mask: 计算相似度得分时的掩码(设置哪些输入不计算到score中)[batch size, 1, sentence length]:return:"""if mask is not None:mask = mask.unsqueeze(1) # 将掩码张量进行维度扩充,代表多头中的第 n 个头batch_size = query.size(0)query, key, value = [l(x).view(batch_size, -1, self.head, self.d_k).transpose(1, 2) for l, x in zip(self.linears, (query, key, value))]  # 将1、2维度进行调换,目的是让句子长度维度和词向量维度靠近,这样注意力机制才能找到词义与句子之间的关系# 将每个头传递到注意力层x, self.attn = attention(query, key, value, mask=mask,dropout=self.dropout)# 得到每个头的计算结果是 4 维的张量,需要形状的转换# 前面已经将1,2两个维度进行转置了,所以这里要重新转置回来# 前面已经经历了transpose,所以要使用contiguous()方法,不然无法使用 view 方法x = x.transpose(1, 2).contiguous() \.view(batch_size, -1, self.head * self.d_k)return self.linears[-1](x)  # 在最后一个线性层中进行处理,得到最终的多头注意力结构输出
import torch.nn as nn
import torch
class PositionFeedForward(nn.Module):def __init__(self, d_model, d_ff, dropout=0.1):""":param d_model: 输入维度,词嵌入的维度:param d_ff: 第一个的输出连接第二个的输入:param dropout: 置零比率"""super(PositionFeedForward, self).__init__()self.linear1 = nn.Linear(d_model, d_ff)self.linear2 = nn.Linear(d_ff, d_model)self.dropout = nn.Dropout(p = dropout)def forward(self, x):return self.linear2(self.dropout(torch.relu(self.linear1(x))))
# 实例化参数
d_model = 512
dropout = 0.1
max_len = 60  # 句子最大长度
#-----------------------------------------词嵌入层
# 输入 x 是 Embedding层输出的张量,形状为 2 * 4 * 512
x = Variable(torch.LongTensor([[100,2,42,508],[491,998,1,221]]))
emb = Embeddings(1000,512)   # 嵌入层
embr = emb(x)
#-----------------------------------------位置编码
pe = PositionalEncoding(d_model, dropout,max_len)  # 位置编码
pe_result = pe(embr)
#-----------------------------------------多头注意力机制
head = 8
embedding_dim = 512
dropout = 0.2
query = key = value = pe_result
mask = torch.zeros(2,4,4)
mha = MultiHeadAttention(head,embedding_dim,dropout)
mha_result = mha(query,key,value,mask)
#-----------------------------------------FFN
d_model = 512
d_ff = 62
dropout = 0.2
ff = PositionFeedForward(d_model,d_ff,dropout)
x = mha_result
ff_result = ff(x)
print(ff_result)
print(ff_result.shape)

scores.shape  torch.Size([2, 8, 4, 4])
tensor([[[ 0.4427, -1.4212, -0.3038,  ..., -0.0478, -0.9644,  1.0128],
         [-0.1761, -1.7056,  0.2831,  ..., -0.2049, -1.4458,  0.6031],
         [-1.0499, -0.6499, -0.6668,  ..., -0.4854, -1.1480,  0.2326],
         [-0.4020, -1.5800,  0.0879,  ...,  0.4629, -0.1393,  0.6648]],

        [[ 0.9657, -1.0271, -0.8961,  ..., -2.3648, -0.1479, -0.0632],
         [ 1.1462, -0.6441, -0.6195,  ..., -1.3566, -0.8735, -0.9672],
         [ 0.7842, -0.7735, -0.1261,  ..., -0.8950, -0.1457, -0.5313],
         [ 0.2848, -0.3641,  0.3231,  ..., -3.7226,  1.0617, -0.1215]]],
       grad_fn=<AddBackward0>)
torch.Size([2, 4, 512])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/100758.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MVCC和BufferPool缓存机制

文章目录 1. MVCC多版本并发控制机制2. BufferPool缓存机制 1. MVCC多版本并发控制机制 Mysql可以在可重复读隔离级别下可以保证事务较高的隔离性&#xff0c;这个隔离性是由MVCC机制来保证的&#xff0c;对一行数据的读和写两个操作默认是不会通过加锁互斥来保证隔离性&#…

机器学习与模式识别作业----决策树属性划分计算

文章目录 1.决策树划分原理1.1.特征选择1--信息增益1.2.特征选择2--信息增益比1.3.特征选择3--基尼系数 2.决策树属性划分计算题2.1.信息增益计算2.2.1.属性1的信息增益计算2.2.2.属性2的信息增益计算2.2.3.属性信息增益比较 2.2.信息增益比计算2.3.基尼系数计算 1.决策树划分原…

设计模式 - 解释器模式

目录 一. 前言 二. 实现 三. 优缺点 一. 前言 解释器模式&#xff08;Interpreter Pattern&#xff09;指给定一门语言&#xff0c;定义它的文法的一种表示&#xff0c;并定义一个解释器&#xff0c;该解释器使用该表示来解释语言中的句子&#xff0c;属于行为型设计模式。是…

AIGC|利用大语言模型实现智能私域问答助手

随着ChatGPT的爆火&#xff0c;最近大家开始关注到大语言模型&#xff08;LLM&#xff09;这个领域。像雨后春笋一样&#xff0c;国内外涌现出了很多LLM。作为开发者&#xff0c;我们通常会关注LLM各自擅长的领域和能力&#xff0c;然后思考如何利用它们的能力来解决某个场景或…

Table ‘mysql.proc‘ doesn‘t exist

使用workbench 同步model 報錯 "Table ‘mysql.proc‘ doesn‘t exist" 爲什麽會出現這個錯誤&#xff1f; 原因&#xff1a;误删了mysql数据库 解决办法如下&#xff1a; 1、在服务列表里找到mysql&#xff0c;停止服务 2、把mysql文件夹下的data文件夹备份&…

GaussDB向量数据库为盘古大模型再添助力

在今年7月7日的华为开发者大会2023(Cloud)期间,华为云盘古大模型3.0正式发布。目前盘古大模型已在政务、金融、制造、医药研发、气象等诸多行业发挥巨大价值。此次华为云发布的GaussDB向量数据库,具备一站式部署、全栈自主创新优势,不仅如此,它的ANN算法在行业排名第一,…

华为、小鹏大定爆单,智驾苦尽甘来,车主终于愿意买单

‍作者|德新 编辑|王博 国庆假期结束&#xff0c;车圈的最大热点事件&#xff0c;当属问界M7卖爆&#xff0c;上市不到一个月时间内&#xff0c;狂揽5万张大定订单。 在华为手机强势回归&#xff0c;改款问界M7大热的高光之下&#xff0c;还有一个重要趋势值得关注&#xff1…

[MySQL]基础篇

文章目录 1. MySQL基本使用1.1 MySQL的启动和登录1.1.1 MySQL的启动1.1.2 MySQL的客户端连接 1.2 数据模型 2. SQL2.1 SQL类型2.1.1 数值类型2.1.2 字符串类型2.1.3 日期类型 2.2 DDL2.2.1 数据库操作2.2.2 表操作 - 查询2.2.3 表操作 - 创建表2.2.4 表操作 - 修改 2.3 DML2.3.…

C++语言实现网络爬虫详细代码

当然&#xff01;下面是一个用C语言实现的基本网络爬虫的详细代码示例&#xff1a; #include <iostream> #include <string> #include <curl/curl.h> size_t writeCallback(void* contents, size_t size, size_t nmemb, std::string* output) {size_t totalS…

linux系统配置Samba实现与Windows系统的文件共享

1.linux系统下载安装Samba sudo apt install samba 2.在linux文件系统中创建一个共享目录(通常在用户目录下面创建一个名为share的目录) mkdir share 3.修改samba配置文件 sudo vim /etc/samba/smb.conf 添加配置信息(path share路径,需要修改) ,保存修改 [Share]comm…

HarmonyOS/OpenHarmony原生应用-ArkTS万能卡片组件Stack

堆叠容器&#xff0c;子组件按照顺序依次入栈&#xff0c;后一个子组件覆盖前一个子组件。该组件从API Version 7开始支持。可以包含子组件。 一、接口 Stack(value?: { alignContent?: Alignment }) 从API version 9开始&#xff0c;该接口支持在ArkTS卡片中使用。 二、…

交流回馈老化测试负载的应用

交流回馈老化测试负载的应用非常重要&#xff0c;老化测试是一种对产品进行长时间运行和负载测试的方法&#xff0c;旨在模拟产品在实际使用中的长期稳定性和可靠性。在老化测试过程中&#xff0c;负载是指对产品施加的工作负荷&#xff0c;可以是CPU、内存、硬盘等资源的使用情…

Bias and Fairness in Large Language Models: A Survey

本文是LLM系列文章&#xff0c;针对《Bias and Fairness in Large Language Models: A Survey》的翻译。 大型语言模型中的偏见与公平性研究 摘要1 引言2 LLM偏见与公平的形式化3 偏见评价指标的分类4 偏见评价数据集的分类5 缓解偏见的技术分类6 开放问题和挑战7 结论 摘要 …

docker 部署lnmp

目录 1、部署nginx\ 1.1、vim Dockerfile 1.2、 1.3、vim nginx.conf 2、部署mysql&#xff08;容器IP 为 172.18.0.20&#xff09; 2.1、vim Dockerfile 2.2、vim my.cnf 2.3、 3、部署php&#xff08;容器IP 为 172.18.0.30&#xff09; 3.1、 vim Dockerfile 3.2、…

Docker 的数据管理

目录 绪论 1&#xff0e;数据卷 2&#xff0e;数据卷容器 2.1 端口映射 2.2 容器互联&#xff08;使用centos镜像&#xff09; 2.3 Docker 镜像的创建 2.4 镜像加载原理 2.5 为什么Docker里的centos的大小才200M&#xff1f; 3.Dockerfile 3.1 Docker 镜像结构的分层…

基于FPGA的视频接口之千兆网口(五应用)

简介 相信网络上对于FPGA驱动网口的开发板、博客、论坛数不胜数,为何博主需要重新手敲一遍呢,而不是做一个文抄君呢!因为目前博主感觉网络上描述的多为应用层上的开发,非从底层开始说明,本博主的思虑还是按照老规矩,按照硬件、底层、应用等关系,使用三~四篇文章,来详细…

常见的Web安全漏洞(2021年9月的OWASP TOP 10)

聊Web安全漏洞&#xff0c;就不得不提到OWASP TOP10。开放式Web应用程序安全项目&#xff08;OpenWeb Application Security Project&#xff0c;OWASP&#xff09;是一个开源的、非营利的组织&#xff0c;主要提供有关Web应用程序的实际可行、公正透明、有社会效益的信息&…

【办公自动化】在Excel中按条件筛选数据并存入新的表2.0(文末送书)

&#x1f935;‍♂️ 个人主页&#xff1a;艾派森的个人主页 ✍&#x1f3fb;作者简介&#xff1a;Python学习者 &#x1f40b; 希望大家多多支持&#xff0c;我们一起进步&#xff01;&#x1f604; 如果文章对你有帮助的话&#xff0c; 欢迎评论 &#x1f4ac;点赞&#x1f4…

mysql面试题29:大表查询的优化方案

该文章专注于面试&#xff0c;面试只要回答关键点即可&#xff0c;不需要对框架有非常深入的回答&#xff0c;如果你想应付面试&#xff0c;是足够了&#xff0c;抓住关键点 面试官&#xff1a;说一下大表查询的优化方案 以下是几种常见的大表优化方案&#xff1a; 分区&…

数据治理的核心是什么?_光点科技

数据治理是当今数字化时代中企业管理的关键组成部分。在信息爆炸的时代&#xff0c;企业积累了大量的数据&#xff0c;这些数据不仅是企业宝贵的资产&#xff0c;也是推动业务决策和创新的重要驱动力。数据治理的核心在于建立有效的框架和流程&#xff0c;以确保数据的质量、安…