002 self-attention自注意力

目录

一、环境

二、self-attention原理

三、完整代码


一、环境

本文使用环境为:

  • Windows10
  • Python 3.9.17
  • torch 1.13.1+cu117
  • torchvision 0.14.1+cu117

二、self-attention原理

自注意力(Self-Attention)操作是基于 Transformer 的机器翻译模型的基本操作,在源语言的编
码和目标语言的生成中频繁地被使用以建模源语言、目标语言任意两个单词之间的依赖关系。给
定由单词语义嵌入及其位置编码叠加得到的输入表示 {xi ∈ Rd},为了实现对上下文语义依赖的建模,进一步引入在自注意力机制中涉及到的三个元素:查询 qi(Query),键 ki(Key),值 vi (Value)。在编码输入序列中每一个单词的表示的过程中,这三个元素用于计算上下文单词所对应的权重得分。直观地说,这些权重反映了在编码当前单词的表示时,对于上下文不同部分所需要的关注程度。具体来说,如图所示,通过三个线性变换 WQ,WK ,WV 将输入序列中的每一个单词表示 xi 转换为其对应的 qi,ki ,vi  向量。

为了得到编码单词 xi 时所需要关注的上下文信息,通过位置 i 查询向量与其他位置的键向量做点积得到匹配分数 qi · k1, qi · k2, ..., qi · kt。为了防止过大的匹配分数在后续 Softmax 计算过程中导致的梯度爆炸以及收敛效率差的问题,这些得分会除放缩因子 √d 以稳定优化。放缩后的得分经过 Softmax 归一化为概率之后,与其他位置的值向量相乘来聚合希望关注的上下文信息,并最小化不相关信息的干扰。上述计算过程可以被形式化地表述如下:

其中 Q  , K  ,V  分别表示输入序列中的不同单词的 q, k, v 向量拼接组成的矩阵,L 表示序列长度,Z 表示自注意力操作的输出。为了进一步增强自注意力机制聚合上下文信息的能力,提出了多头自注意力(Multi-head Attention)的机制,以关注上下文的不同侧面。具体来说,上下文中每一个单词的表示 xi 经过多组线性 {WQ*WK*WV } 映射到不同的表示子空间中。公式会在不同的子空间中分别计算并得到不同的上下文相关的单词序列表示{Zj}。最终,线性变换 WO 用于综合不同子空间中的上下文表示并形成自注意力层最终的输出 xi 。

三、完整代码

import torch.nn as nn
import torch
import math
import torch.nn.functional as Fclass MultiHeadAttention(nn.Module):def __init__(self, heads, d_model, dropout = 0.1):super().__init__()self.d_model = d_modelself.d_k = d_model // heads # 512 / 8 self.h = headsself.q_linear = nn.Linear(d_model, d_model)self.v_linear = nn.Linear(d_model, d_model)self.k_linear = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)self.out = nn.Linear(d_model, d_model)def attention(self, q, k, v, d_k, mask=None, dropout=None):scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) # self-attention公式# 掩盖掉那些为了填补长度增加的单元,使其通过 softmax 计算后为 0if mask is not None:mask = mask.unsqueeze(1)scores = scores.masked_fill(mask == 0, -1e9)scores = F.softmax(scores, dim=-1) # self-attention公式if dropout is not None:scores = dropout(scores)output = torch.matmul(scores, v) # self-attention公式return outputdef forward(self, q, k, v, mask=None):bs = q.size(0) # 进行线性操作划分为成 h 个头k = self.k_linear(k).view(bs, -1, self.h, self.d_k)q = self.q_linear(q).view(bs, -1, self.h, self.d_k)v = self.v_linear(v).view(bs, -1, self.h, self.d_k)# 矩阵转置k = k.transpose(1,2) q = q.transpose(1,2) v = v.transpose(1,2) # 计算 attentionscores = self.attention(q, k, v, self.d_k, mask, self.dropout)# 连接多个头并输入到最后的线性层concat = scores.transpose(1,2).contiguous().view(bs, -1, self.d_model)output = self.out(concat)return output# 准备q、k、v张量
d_model = 512
num_heads = 8
batch_size = 32
seq_len = 64q = torch.randn(batch_size, seq_len, d_model) # 64 x 512
k = torch.randn(batch_size, seq_len, d_model) # 64 x 512
v = torch.randn(batch_size, seq_len, d_model) # 64 x 512sa = MultiHeadAttention(heads = num_heads, d_model=d_model)
print(sa(q, k, v).shape) # torch.Size([32, 64, 512])
print('')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/208826.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【XILINX】记录ISE/Vivado使用过程中遇到的一些warning及解决方案

前言 XILINX/AMD是大家常用的FPGA,但是在使用其开发工具ISE/Vivado时免不了会遇到很多warning,(大家是不是发现程序越大warning越多?),并且还有很多warning根据消除不了,看着特心烦? 我这里汇总一些我遇到的…

http和https区别

http和https区别 HTTP(Hypertext Transfer Protocol)和HTTPS(Hypertext Transfer Protocol Secure)是用于在网络上传输数据的两种协议。它们之间的主要区别在于安全性和数据传输方式: 安全性:HTTP是明文传…

华清远见嵌入式学习——QT——作业2

作业要求&#xff1a; 代码运行效果图&#xff1a; 登录失败 和 最小化 和 取消登录 登录成功 和 X号退出 代码&#xff1a; ①&#xff1a;头文件 #ifndef LOGIN_H #define LOGIN_H#include <QMainWindow> #include <QLineEdit> //行编辑器类 #include…

如何在centos8上配置一个ca证书颁发机构并且颁发一个自签名证书【超详细!!!】

在CentOS 8上配置CA证书颁发机构并颁发自签名证书的步骤如下&#xff1a; 1. 安装OpenSSL sudo dnf install openssl 2. 创建CA证书目录 sudo mkdir /etc/pki/CA/ sudo chmod 0700 /etc/pki/CA/ 3. 创建CA证书数据库 sudo touch /etc/pki/CA/index.txt sudo echo 1000 >…

Java Spring + SpringMVC + MyBatis(SSM)期末作业项目

本系统是一个图书管理系统&#xff0c;比较适合当作期末作业主要技术栈如下&#xff1a; - 数据库&#xff1a;MySQL - 开发工具&#xff1a;IDEA - 数据连接池&#xff1a;Druid - Web容器&#xff1a;Apache Tomcat - 项目管理工具&#xff1a;Maven - 版本控制工具&#xf…

探索人工智能领域——每日20个名词详解【day12】

目录 前言 正文 总结 &#x1f308;嗨&#xff01;我是Filotimo__&#x1f308;。很高兴与大家相识&#xff0c;希望我的博客能对你有所帮助。 &#x1f4a1;本文由Filotimo__✍️原创&#xff0c;首发于CSDN&#x1f4da;。 &#x1f4e3;如需转载&#xff0c;请事先与我联系以…

学习JVM

java虚拟机 流程&#xff1a;helloworld.java----(javac编译)----helloworld.class-------(java运行)——JVM——机器码JVM功能 *解释和运行 *内存管理 *即时编译&#xff08;跨平台-慢一点&#xff09;jit &#xff08;反复用到的代码 解释保存再内存里面&#xff09;…

进程、线程、线程池状态

线程几种状态和状态转换 进程主要写明三种基本状态&#xff1a; 线程池的几种状态&#xff1a;

STM32的BKP与RTC简介

芯片的供电引脚 引脚表橙色的是芯片的供电引脚&#xff0c;其中VSS/VDD是芯片内部数字部分的供电&#xff0c;VSSA/VDDA是芯片内部模拟部分的供电&#xff0c;这4组以VDD开头的供电都是系统的主电源&#xff0c;正常使用时&#xff0c;全部都要接3.3V的电源上&#xff0c;VBAT是…

Leetcode2477. 到达首都的最少油耗

Every day a Leetcode 题目来源&#xff1a;2477. 到达首都的最少油耗 解法1&#xff1a;贪心 深度优先搜索 题目等价于给出了一棵以节点 0 为根结点的树&#xff0c;并且初始树上的每一个节点上都有一个人&#xff0c;现在所有人都需要通过「车子」向结点 0 移动。 对于…

从阻抗匹配看拥塞控制

先来理解阻抗匹配&#xff0c;但我不按传统方式解释&#xff0c;因为传统方案你要先理解如何定义阻抗&#xff0c;然后再学习什么是输入阻抗和输出阻抗&#xff0c;最后再看如何让它们匹配&#xff0c;而让它们匹配的目标仅仅是信号不反射&#xff0c;以最大能效被负载接收。 …

面试宝典之自我介绍

听人劝、吃饱饭,奉劝各位小伙伴,不要订阅该文所属专栏。 如需要项目实战或者是体系化资源,文末名片加V! 作者:哈哥撩编程,工作十余年, 从事过全栈研发、产品经理等工作,目前在公司担任研发部门CTO。荣誉:2022年度博客之星Top4、2023年度超级个体得主、谷歌与亚马逊开发…

Amazon CodeWhisperer 开箱初体验

文章作者&#xff1a;Coder9527 科技的进步日新月异&#xff0c;正当人工智能发展如火如荼的时候&#xff0c;各大厂商在“解放”码农的道路上不断创造出各种 Coding 利器&#xff0c;今天在下就带大家开箱体验一个 Coding 利器&#xff1a; Amazon CodeWhisperer。 亚马逊云科…

99基于matlab的小波分解和小波能量熵函数

基于matlab的小波分解和小波能量熵函数&#xff0c;通过GUI界面导入西储大学轴承故障数据&#xff0c;以可视化的图对结果进行展现。数据可更换自己的&#xff0c;程序已调通&#xff0c;可直接运行。 99小波分解和小波能量熵函数 (xiaohongshu.com)https://www.xiaohongshu.co…

【LeetCode每日一题合集】2023.11.27-2023.12.3 (⭐)

文章目录 907. 子数组的最小值之和&#xff08;单调栈贡献法&#xff09;1670. 设计前中后队列⭐&#xff08;设计数据结构&#xff09;解法1——双向链表解法2——两个双端队列 2336. 无限集中的最小数字解法1——维护最小变量mn 和 哈希表维护已经去掉的数字解法2——维护原本…

二分查找|前缀和|滑动窗口|2302:统计得分小于 K 的子数组数目

作者推荐 贪心算法LeetCode2071:你可以安排的最多任务数目 本文涉及的基础知识点 二分查找算法合集 题目 一个数组的 分数 定义为数组之和 乘以 数组的长度。 比方说&#xff0c;[1, 2, 3, 4, 5] 的分数为 (1 2 3 4 5) * 5 75 。 给你一个正整数数组 nums 和一个整数…

response应用及重定向和request转发

请求和转发&#xff1a; response说明一、response文件下载二、response验证码实现1.前置知识&#xff1a;2.具体实现&#xff1a;3.知识总结 三、response重定向四、request转发五、重定向和转发的区别 response说明 response是指HttpServletResponse,该响应有很多的应用&…

JavaScript 一些少见多怪的玩意

$$() [].forEach.call($$("*"), function (a) {a.style.outline "1px solid #" (~~(Math.random() * (1 << 24))).toString(16);}); 直接复制到控制台&#xff0c;页面效果就是页面中不同的HTML结构被不同颜色的框圈着。 原理&#xff1a; $$函数…

力扣面试150题 | 轮转数组

力扣面试150题 &#xff5c; 轮转数组 题目描述解题思路代码实现 题目描述 189.轮转数组 给定一个整数数组 nums&#xff0c;将数组中的元素向右轮转 k 个位置&#xff0c;其中 k 是非负数。 示例 1: 输入: nums [1,2,3,4,5,6,7], k 3 输出: [5,6,7,1,2,3,4] 解释: 向右轮…

Kafka在微服务架构中的应用:实现高效通信与数据流动

微服务架构的兴起带来了分布式系统的复杂性&#xff0c;而Kafka作为一款强大的分布式消息系统&#xff0c;为微服务之间的通信和数据流动提供了理想的解决方案。本文将深入探讨Kafka在微服务架构中的应用&#xff0c;并通过丰富的示例代码&#xff0c;帮助大家更全面地理解和应…