什么是注意力机制?注意力机制的计算规则

我们观察事物时,之所以能够快速判断一种事物(当然允许判断是错误的),是因为我们大脑能够很快把注意力放在事物最具有辨识度的部分从而作出判断,而并非是从头到尾的观察一遍事物后,才能有判断结果,正是基于这样的理论,就产生了注意力机制。

什么是注意力计算规则:

它需要三个指定的输入Q(query),K(key),V(value),然后通过计算公式得到注意力的结果,这个结果代表query在key和value作用下的注意力表示.当输入的Q=K=V时,称作自注意力计算规则。

常见的注意力计算规则:

|| ·将Q,K进行纵轴拼接,做一次线性变化,再使用softmax处理获得结果最后与V做张量乘法。

在这里插入图片描述

|| ·将Q,K进行纵轴拼接,做一次线性变化后再使用tanh函数激活,然后再进行内部求和,最后使用softmax处理获得结果再与V做张量乘法.

在这里插入图片描述

|| ·将Q与K的转置做点积运算,然后除以一个缩放系数再使用softmax处理获得结果最后与V做张量乘法。

在这里插入图片描述

说明:当注意力权重矩阵和V都是三维张量且第一维代表为batch条数时, 则做bmm运算.bmm是一种特殊的张量乘法运算。

bmm运算演示:

# 如果参数1形状是(b × n × m), 参数2形状是(b × m × p), 则输出为(b × n × p)
>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])

注意力机制的作用

在解码器端的注意力机制:能够根据模型目标有效的聚焦编码器的输出结果,当其作为解码器的输入时提升效果,改善以往编码器输出是单一定长张量,无法存储过多信息的情况。

在编码器端的注意力机制:主要解决表征问题,相当于特征提取过程,得到输入的注意力表示。般使用自注意力(self-attention)。

注意力机制实现步骤

第一步:根据注意力计算规则,对Q,K,V进行相应的计算

第二步:根据第一步采用的计算方法,如果是拼接方法,则需要将Q与第二步的计算结果再进行拼接,如果是转置点积一般是自注意力,Q与V相同,则不需要进行与Q的拼接

第三步:最后为了使整个attention机制按照指定尺寸输出,使用线性层作用在第二步的结果上做个线性变换,得到最终对Q的注意力表示

常见注意力机制的代码分析:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass Attn(nn.Module):def __init__(self, query_size, key_size, value_size1, value_size2, output_size):"""初始化函数中的参数有5个, query_size代表query的最后一维大小key_size代表key的最后一维大小, value_size1代表value的导数第二维大小, value = (1, value_size1, value_size2)value_size2代表value的倒数第一维大小, output_size输出的最后一维大小"""super(Attn, self).__init__()# 将以下参数传入类中self.query_size = query_sizeself.key_size = key_sizeself.value_size1 = value_size1self.value_size2 = value_size2self.output_size = output_size# 初始化注意力机制实现第一步中需要的线性层.self.attn = nn.Linear(self.query_size + self.key_size, value_size1)# 初始化注意力机制实现第三步中需要的线性层.self.attn_combine = nn.Linear(self.query_size + value_size2, output_size)def forward(self, Q, K, V):"""forward函数的输入参数有三个, 分别是Q, K, V, 根据模型训练常识, 输入给Attion机制的张量一般情况都是三维张量, 因此这里也假设Q, K, V都是三维张量"""# 第一步, 按照计算规则进行计算, # 我们采用常见的第一种计算规则# 将Q,K进行纵轴拼接, 做一次线性变化, 最后使用softmax处理获得结果attn_weights = F.softmax(self.attn(torch.cat((Q[0], K[0]), 1)), dim=1)# 然后进行第一步的后半部分, 将得到的权重矩阵与V做矩阵乘法计算, # 当二者都是三维张量且第一维代表为batch条数时, 则做bmm运算attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)# 之后进行第二步, 通过取[0]是用来降维, 根据第一步采用的计算方法, # 需要将Q与第一步的计算结果再进行拼接output = torch.cat((Q[0], attn_applied[0]), 1)# 最后是第三步, 使用线性层作用在第三步的结果上做一个线性变换并扩展维度,得到输出# 因为要保证输出也是三维张量, 因此使用unsqueeze(0)扩展维度output = self.attn_combine(output).unsqueeze(0)return output, attn_weights

调用:

query_size = 32
key_size = 32
value_size1 = 32
value_size2 = 64
output_size = 64
attn = Attn(query_size, key_size, value_size1, value_size2, output_size)
Q = torch.randn(1,1,32)
K = torch.randn(1,1,32)
V = torch.randn(1,32,64)
out = attn(Q, K ,V)
print(out[0])
print(out[1])

输出效果:

tensor([[[ 0.4477, -0.0500, -0.2277, -0.3168, -8.4096, -0.5982, 0.1548,
-8.8771, -8.0951. 8.1833. 8.3128. 8.1260, 8.4420. 8.8495.
-0.7774, -0.0995, 0.2629, 0.4957, 1.0922, 0.1428, 0.3024.
-0.2646, -0.0265, 0.0632, 0.3951, 0.1583, 0.1130, 0.5500,
-0.1887, -0.2816, -0.3800, -0.5741, 0.1342, 0.0244, -0.2217,
0.1544, 0.1865, -0.2019, 0.4090, -0.4762, 0.3677, -0.2553,
-0.5199, 0.2290, -0.4407, 0.0663, -8.0182, -8.2168, 0.0913,
-0.2340, 0.1924, -0.3687, 0.1508, 0.3618, -0.0113, 0.2864.
-0.1929, -0.6821, 0.0951, 0.1335, 0.3560, -0.3215
,0.6461,
0.1532]]],grad_fn=<UnsqueezeBackward0>)
tensor([[0.0395, 0.0342, 0.0200, 0.0471, 0.0177, 0.0209, 0.0244, 0.0465, 0.0346,
0.0378, 0.0282, 0.0214, 0.0135, 0.0419, 0.0926, 0.0123, 0.0177, 0.0187,
0.0166, 0.8225, 0.0234, 0.0284, 0.0151, 0.0239, 0.0132, 0.0439, 0.0507,
0.0419, 8.0352, 8.0392, 8.0546, 0.0224]], grad_fn=<SoftmaxBackward>)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/18357.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于 FFlogs API 快速实现的 logs 颜色查询小爬虫

文章目录 找到接口解析响应需要平均颜色和过本次数&#xff1f; 找到接口 首先试了一下爬虫&#xff0c;发现和wow一样官网上有暴露的 API&#xff0c;链接在&#xff1a;FFlogs v1 API 文档链接 通过查询官方提供的 API 接口得知&#xff1a; user_name 角色名字 api_key …

数据结构—栈

栈 栈的概念及结构栈的实现 栈的概念及结构 栈&#xff1a;一种特殊的线性表&#xff0c;其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端称为栈顶&#xff0c;另一端称为栈底。栈中的数据元素遵守后进先出LIFO&#xff08;Last In First Out&…

小夜灯的体势红外传感器 > 红外知识学习

红外是电磁辐射谱中的一部分&#xff0c;它位于可见光谱的红色边缘之外&#xff0c;具有较长的波长。可见光谱是人眼能够感知的电磁辐射范围&#xff0c;而红外光的波长较长&#xff0c;人眼无法感知。 生命光的范围是6~14um 红外光的波长范围一般约为0.7um~1000um&#xff08;…

IP地址转换函数

#include<string.h> #include<arpa/inet.h> #include<stdio.h>int main(void){char ip[]"1.2.3.4";//字符串struct sockaddr_in server_addr;inet_pton(AF_INET,ip,&server_addr.sin_addr.s_addr);//字符串 to 网络字节序printf("s_addr …

Kubernetes 概述

1、K8S 是什么&#xff1f; K8S 的全称为 Kubernetes (K12345678S) 作用 用于自动部署、扩展和管理“容器化&#xff08;containerized&#xff09;应用程序”的开源系统。 可以理解成 K8S 是负责自动化运维管理多个容器化程序&#xff08;比如 Docker&#xff09;的集群&#…

iOS——Block签名

首先来看block结构体对象Block_layout&#xff08;等同于clang编译出来的__Block_byref_a_0&#xff09; #define BLOCK_DESCRIPTOR_1 1 struct Block_descriptor_1 {uintptr_t reserved;uintptr_t size; };#define BLOCK_DESCRIPTOR_2 1 struct Block_descriptor_2 {// requi…

【vue】组件使用教训

组件使用 报错组件找不到 These dependencies were not found: 遇见的问题 在使用vue的时候&#xff0c;做了一个统计图的功能&#xff0c;引入了chart。 但是在运行项目的时候&#xff0c;直接报错启动不起来&#xff0c;报错内容是 告诉我依赖找不到&#xff0c;然后还试…

微信小程序iconfont真机渲染失败

解决方法&#xff1a; 1.将下载的.woff文件在transfonter转为base64&#xff0c; 2.打开网站&#xff0c;导入文件&#xff0c;开启base64按钮&#xff0c;下载转换后的文件 3. 在下载解压后的文件夹中找到stylesheet.css&#xff0c;并复制其中的base64 4. 修改index.wxss文…

从零开始学Docker(二):启动第一个Docker容器

宿主机环境&#xff1a;RockyLinux 9 这个章节不小心搞成命令学习了&#xff0c;后面在整理成原理吧 Docker生命周期 拉取并启动Nginx容器 # 查找镜像 例如&#xff1a;nginx [root192 ~]# docker search nginx 我们可以看到&#xff0c;第一个时官方认证构建的nginx # 拉…

如何能够高效实现表格中的分权限编辑功能

摘要&#xff1a;本文由葡萄城技术团队于CSDN原创并首发。转载请注明出处&#xff1a;葡萄城官网&#xff0c;葡萄城为开发者提供专业的开发工具、解决方案和服务&#xff0c;赋能开发者。 在表格类填报需求中&#xff0c;根据当前登录用户的不同等级&#xff0c;能填报的区域会…

Delphi 开发不一样的窗体标题栏:TTitleBarPanel

目录 TTitleBarPanel 的使用 TTitleBarPanel 的使用进阶 一、设置标题栏高度、颜色 二、个性化标题栏的关闭等按键 我们在用Delphi开发程序的时候&#xff0c;窗体的标题栏一般都是标准的windows标题栏&#xff0c;上面包括&#xff1a;程序图标、标题、最小化、最大化、关闭…

Python爬虫时遇到连接超时解决方案

在进行Python爬虫任务时&#xff0c;经常会遇到连接超时&#xff08;TimeoutError&#xff09;错误。连接超时意味着爬虫无法在规定的时间内建立与目标服务器的连接&#xff0c;导致请求失败。为了帮助您解决这个常见的问题&#xff0c;本文将提供一些解决办法&#xff0c;并提…

用合成数据训练托盘检测模型【机器学习】

想象一下&#xff0c;你是一名机器人或机器学习 (ML) 工程师&#xff0c;负责开发一个模型来检测托盘&#xff0c;以便叉车可以操纵它们。 ‌你熟悉传统的深度学习流程&#xff0c;已经整理了手动标注的数据集&#xff0c;并且已经训练了成功的模型。 推荐&#xff1a;用 NSDT设…

【LeetCode】88. 合并两个有序数组

这道题我总共想了三种解法。 1.将nums2中的元素依次放入nums1有效元素的后面&#xff0c;再总体进行排序。 import java.util.*; class Solution {public void merge(int[] nums1, int m, int[] nums2, int n) {int j 0;for(int i m;i<mn;i){nums1[i] nums2[j];j;}Arrays…

搭建网站 --- 快速WordPress个人博客并内网穿透发布到互联网

文章目录 快速WordPress个人博客并内网穿透发布到互联网 快速WordPress个人博客并内网穿透发布到互联网 我们能够通过cpolar完整的搭建起一个属于自己的网站&#xff0c;并且通过cpolar建立的数据隧道&#xff0c;从而让我们存放在本地电脑上的网站&#xff0c;能够为公众互联…

vue3单选选择全部传all,否则可以多选

<el-form-item label"发布范围-单位选择"><el-radio-group v-model"formData.unitRadio" change"getUnit"><el-radio label"ALL" click.prevent"radioChange(ALL)">全部</el-radio><el-radio la…

java 企业工程管理系统软件源码+Spring Cloud + Spring Boot +二次开发+ MybatisPlus + Redis em

&#xfeff; 工程项目管理软件&#xff08;工程项目管理系统&#xff09;对建设工程项目管理组织建设、项目策划决策、规划设计、施工建设到竣工交付、总结评估、运维运营&#xff0c;全过程、全方位的对项目进行综合管理 工程项目各模块及其功能点清单 一、系统管理 1、数据…

【C++】STL——list的模拟实现、构造函数、迭代器类的实现、运算符重载、增删查改

文章目录 1.模拟实现list1.1构造函数1.2迭代器类的实现1.3运算符重载1.4增删查改 1.模拟实现list list使用文章 1.1构造函数 析构函数 在定义了一个类模板list时。我们让该类模板包含了一个内部结构体_list_node&#xff0c;用于表示链表的节点。该结构体包含了指向前一个节点…

深度学习入门 ---- 张量(Tensor)

文章目录 张量张量在深度学习领域的定义张量的基本属性使用PyTorch安装PyTorch查看安装版本 创建张量常用函数四种创建张量的方式和区别 四则运算 张量 张量在深度学习领域的定义 张量&#xff08;tensor&#xff09;是多维数组&#xff0c;目的是把向量、矩阵推向更高的维度。…

uniapp 微信小程序:v-model双向绑定问题(自定义 props 名无效)

uniapp 微信小程序&#xff1a;v-model双向绑定问题&#xff08;自定义 props 名无效&#xff09; 前言问题双向绑定示例使用 v-model使用 v-bind v-on使用 sync 修饰符 参考资料 前言 VUE中父子组件传递数据的基本套路&#xff1a; 父传子 props子传父 this.$emit(事件名, …