李沐65_注意力分数——自学笔记

Additive Attention

等价于将key和value合并起来后放入到一个隐藏大小为h输出大小为1的单隐藏层

总结

1.注意力分数是query和key的相似度,注意力权重是分数的softmax结果

2.两种常见的分数计算:

(1)将query和key合并起来进入一个单输出单隐藏层的MLP

(2)直接将query和key做内积

注意力打分函数

!pip install d2l
import math
import torch
from torch import nn
from d2l import torch as d2l

masked_softmax函数 实现了这样的掩蔽softmax操作(masked softmax operation), 其中任何超出有效长度的位置都被掩蔽并置为0。

def masked_softmax(X, valid_lens):"""通过在最后一个轴上掩蔽元素来执行softmax操作"""# X:3D张量,valid_lens:1D或2D张量if valid_lens is None:return nn.functional.softmax(X, dim=-1)else:shape = X.shapeif valid_lens.dim() == 1:valid_lens = torch.repeat_interleave(valid_lens, shape[1])else:valid_lens = valid_lens.reshape(-1)# 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,value=-1e6)return nn.functional.softmax(X.reshape(shape), dim=-1)

考虑由两个2X4
矩阵表示的样本, 这两个样本的有效长度分别为2
和3
。 经过掩蔽softmax操作,超出有效长度的值都被掩蔽为0。

masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
tensor([[[0.3505, 0.6495, 0.0000, 0.0000],[0.5069, 0.4931, 0.0000, 0.0000]],[[0.2469, 0.4668, 0.2863, 0.0000],[0.2865, 0.3008, 0.4127, 0.0000]]])

同样,也可以使用二维张量,为矩阵样本中的每一行指定有效长度。

masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],[0.2102, 0.3264, 0.4634, 0.0000]],[[0.4785, 0.5215, 0.0000, 0.0000],[0.1783, 0.1803, 0.3615, 0.2800]]])

additive attention

class AdditiveAttention(nn.Module):"""加性注意力"""def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):super(AdditiveAttention, self).__init__(**kwargs)self.W_k = nn.Linear(key_size, num_hiddens, bias=False)self.W_q = nn.Linear(query_size, num_hiddens, bias=False)self.w_v = nn.Linear(num_hiddens, 1, bias=False)self.dropout = nn.Dropout(dropout)def forward(self, queries, keys, values, valid_lens):queries, keys = self.W_q(queries), self.W_k(keys)# 在维度扩展后,# queries的形状:(batch_size,查询的个数,1,num_hidden)# key的形状:(batch_size,1,“键-值”对的个数,num_hiddens)# 使用广播方式进行求和features = queries.unsqueeze(2) + keys.unsqueeze(1)features = torch.tanh(features)# self.w_v仅有一个输出,因此从形状中移除最后那个维度。# scores的形状:(batch_size,查询的个数,“键-值”对的个数)scores = self.w_v(features).squeeze(-1)self.attention_weights = masked_softmax(scores, valid_lens)# values的形状:(batch_size,“键-值”对的个数,值的维度)return torch.bmm(self.dropout(self.attention_weights), values)

用一个小例子来演示上面的AdditiveAttention类, 其中查询、键和值的形状为(批量大小,步数或词元序列长度,特征大小), 实际输出为(2,1,20)
、(2,10,2)
和(2,10,4)
。 注意力汇聚输出的形状为(批量大小,查询的步数,值的维度)。

queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
# values的小批量,两个值矩阵是相同的
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,dropout=0.1)
attention.eval()
attention(queries, keys, values, valid_lens)
tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],[[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=<BmmBackward0>)

管加性注意力包含了可学习的参数,但由于本例子中每个键都是相同的, 所以注意力权重是均匀的,由指定的有效长度决定

d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),xlabel='Keys', ylabel='Queries')

在这里插入图片描述

缩放点积注意力

(scaled dot-product attention)评分函数

class DotProductAttention(nn.Module):"""缩放点积注意力"""def __init__(self, dropout, **kwargs):super(DotProductAttention, self).__init__(**kwargs)self.dropout = nn.Dropout(dropout)# queries的形状:(batch_size,查询的个数,d)# keys的形状:(batch_size,“键-值”对的个数,d)# values的形状:(batch_size,“键-值”对的个数,值的维度)# valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)def forward(self, queries, keys, values, valid_lens=None):d = queries.shape[-1]# 设置transpose_b=True为了交换keys的最后两个维度scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)self.attention_weights = masked_softmax(scores, valid_lens)return torch.bmm(self.dropout(self.attention_weights), values)

为了演示上述的DotProductAttention类, 我们使用与先前加性注意力例子中相同的键、值和有效长度。 对于点积操作,我们令查询的特征维度与键的特征维度大小相同。

queries = torch.normal(0, 1, (2, 1, 2))
attention = DotProductAttention(dropout=0.5)
attention.eval()
attention(queries, keys, values, valid_lens)
tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],[[10.0000, 11.0000, 12.0000, 13.0000]]])

与加性注意力演示相同,由于键包含的是相同的元素, 而这些元素无法通过任何查询进行区分,因此获得了均匀的注意力权重。

d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),xlabel='Keys', ylabel='Queries')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/2658.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

模块三:二分——852.山脉数组的峰顶索引

文章目录 题目描述算法原理解法一&#xff1a;暴力查找解法二&#xff1a;二分查找 代码实现暴力查找二分——C二分——Java 题目描述 题目链接&#xff1a;852.山脉数组的峰顶索引 算法原理 解法一&#xff1a;暴力查找 峰顶&#xff1a;比左右区间都大 遍历整个数组&…

谷歌搜索SEO优化需要做什么?

最基本的要求&#xff0c;网站基础要优化好&#xff0c;让你的网站更加友好地服务于用户和搜索引擎&#xff0c;首先你要保证你的网站也适配手机端&#xff0c;现在手机端&#xff0c;如果你的网站在手机上打开慢&#xff0c;或者没有适配手机端&#xff0c;让用户用手机看着电…

笔记:VMware之性能优化

目标&#xff1a;通过调整VMware设置&#xff0c;提高VMware中虚拟机性能 版本&#xff1a;16.2.2 build-19200509 一、首选项 针对所有虚拟机设置&#xff0c;对所有虚拟机都有效 1.1 设置路径&#xff1a;主页->编辑->首选项->更新 软件更新&#xff0c;取消“…

线程池嵌套导致的死锁问题

1、背景 有一个报告功能&#xff0c;报告需要生成1个word&#xff0c;6个excel附件&#xff0c;总共7个文件&#xff0c;需要记录报告生成进度&#xff0c;进度字段jd初始化是0&#xff0c;每个文件生成成功进度加1&#xff0c;生成失败就把生成状态置为失败。 更新进度语句&…

Win11系统变量打不开解决方法

Q&#xff1a; 下图所框选部分&#xff0c;变为灰色&#xff0c;点击不了 A: 1.可能是用户权限过低&#xff0c;升为管理员身份 按win R 调出运行&#xff0c;输入netplwiz 或 control userpasswords2效果都一样分别有两个组User和Administarations选中你的用户对应的组 …

3A开关降压型单节充电管理芯片CS5308D

CS5308D是一款30V耐压&#xff0c;单节锂电池或锂离子聚合物电池的降压型充电管理IC。集成功率MOS&#xff0c;芯片采用同步开关架构&#xff0c;使其在应用时仅需极少的外围器件&#xff0c;可有效减少整体方案尺寸&#xff0c;降低BOM成本。具有最大3A的充电电流能力&#xf…

华为云实验 -- 对云硬盘数据盘进行备份

文章目录 备份Linux系统备份1.购买Linux操作系统的ESC(云服务器)2.挂载数据盘--初始化--分区--格式化2.1.点击"远程登录"a.查看/dev/vdb数据盘b.新建主分区/dev/vdb1 2.2.查看新建分区大小,分区格式信息a.确定之前的分区操作是否正确b.确认完成后&#xff0c;将分区结…

Rust腐蚀服务器搭建架设教程ubuntu系统

Rust腐蚀服务器搭建架设教程ubuntu系统 大家好我是艾西一个做服务器租用的网络架构师。Rust腐蚀游戏对于服务器的配置有一定的要求很多小伙伴就思考用linux系统搭建的话占用会不会小一点&#xff0c;有一定电脑基础的小伙伴都知道Linux系统和windows系统相比较linux因为是面板…

小程序变更主体公证怎么做?

小程序迁移变更主体有什么作用&#xff1f;好多朋友都想做小程序迁移变更主体&#xff0c;但是又不太清楚具体有啥用&#xff0c;今天我就来详细说说。首先&#xff0c;小程序迁移变更主体最重要的作用就是可以修改主体。比如你的小程序原来是 A 公司的&#xff0c;现在 A 公司…

STM32G030F6P6TR 芯片TSSOP20 MCU单片机微控制器芯片

STM32G030F6P6TR 在物联网&#xff08;IoT&#xff09;设备中的典型应用案例包括但不限于以下几个方面&#xff1a; 1. 环境监测系统&#xff1a; 使用传感器来监测温度、湿度、气压等环境因素&#xff0c;并通过无线通信模块将数据发送到中央服务器或云端平台进行分析和监控。…

探索RadSystems:低代码开发的新选择(三)

系列文章目录 探索RadSystems&#xff1a;低代码开发的新选择&#xff08;一&#xff09;&#x1f6aa; 探索RadSystems&#xff1a;低代码开发的新选择&#xff08;二&#xff09;&#x1f6aa; 文章目录 系列文章目录前言一、RadSystems Studio是什么&#xff1f;二、操作日…

机器学习理论基础—神经网络算法公式学习

机器学习理论基础—神经网络公式学习 M-P神经元 M-P神经元&#xff08;一个用来模拟生物行为的数学模型&#xff09;&#xff1a;接收n个输入(通常是来自其他神经 元)&#xff0c;并给各个输入赋予权重计算加权和&#xff0c;然后和自身特有的阈值进行比较 (作减法&#xff0…

​「Python绘图」绘制小猪佩奇

python 绘制小猪佩奇 一、预期结果 二、核心代码 import turtle print("开始绘制小猪佩奇") pen turtle.Turtle() pen.pensize(4) #pen.hideturtle()pen.speed(1000)pen.color("#ff9bc0","pink") pen.setheading(-30) pen.pu() pen.goto(-100,…

LLM学习笔记-2

在未标记数据上进行预训练 本章概要 在上节的笔记中&#xff0c;因为训练出的效果&#xff0c;并不是特别理想&#xff0c;在本节中&#xff0c;会用数据进行训练&#xff0c;使得模型更加的好&#xff1b; 计算文本生成损失 inputs torch.tensor([[16833, 3626, 6100],…

ARP 攻击神器:ARP Spoof 保姆级教程

一、介绍 arpspoof是一种网络工具&#xff0c;用于进行ARP欺骗攻击。它允许攻击者伪造网络设备的MAC地址&#xff0c;以欺骗其他设备&#xff0c;并截获其通信。arpspoof工具通常用于网络渗透测试和安全评估&#xff0c;以测试网络的安全性和漏洞。 以下是arpspoof工具的一些…

TensorRT plugins and ONNX parser编译

https://github.com/NVIDIA/TensorRT是TensorRT plugins and ONNX parser&#xff0c;并不包含TensorRT的nvinfer库&#xff08;libinfer.so、nvinfer.dll&#xff09;&#xff0c;此部分并未开源&#xff0c;只能使用官方支持的平台、环境https://developer.nvidia.com/tensor…

科技改变视听4K 120HZ高刷新率的投影、电视、电影终有用武之地

早在1888年&#xff0c;法国生理学家埃蒂安朱尔马莱就发明了一套盒式摄像机&#xff0c;能以120帧/s的速度在一条纸膜上曝光照片&#xff0c;但是当时没有相匹配的放映设备。而马莱的另一套拍摄设备是60帧/s的规格&#xff0c;并且图像质量非常好。 受此启发&#xff0c;雷诺的…

【软件测试基础】黑盒测试(知识点 + 习题 + 答案)

《 软件测试基础持续更新中》 对于黑盒测试这一章&#xff0c;等价类划分、边界值测试、决策表、场景法&#xff0c;这四种是最容易出大题的&#xff0c;其他几种考察频率很低。下述的一些例题只是经典例题&#xff0c;掌握方法后&#xff0c;还要多加练习&#xff01; 目录 3…

极快!宝藏EI,2-4周录用,接受范围广!

本周投稿推荐 SSCI • 2/4区经管类&#xff0c;2.5-3.0&#xff08;录用率99%&#xff09; SCIE&#xff08;CCF推荐&#xff09; • 计算机类&#xff0c;2.0-3.0&#xff08;最快18天录用&#xff09; SCIE&#xff08;CCF-C类&#xff09; • IEEE旗下&#xff0c;1/2…

短信视频提取批量工具,免COOKIE,博主视频下载抓取,爬虫

痛点&#xff1a;关于看了好多市面的软件&#xff0c;必须要先登录自己的Dy号才能 然后找到自己的COOKIE 放入软件才可以继续搜索&#xff0c;并且无法避免长时间使用 会导致无法正常显示页面的问题。 有没有一种方法 直接可以使用软件&#xff0c;不用设置的COOKIE的方法呢 …