李沐64_注意力机制——自学笔记

注意力机制

1.卷积、全连接和池化层都只考虑不随意线索

2.注意力机制则显示的考虑随意线索
(1)随意线索倍称之为查询(query)

(2)每个输入是一个值value,和不随意线索key的对

(3)通过注意力池化层来有偏向性的选择某些输入

总结

注意力机制中,通过query(随意线索)和key(不随意线索)来有偏向性的选择输入

代码实现:注意力汇聚:Nadaraya-Watson 核回归

!pip install d2l
import torch
from torch import nn
from d2l import torch as d2l

生成数据集

在这里生成了50个训练样本和50个测试样本。

n_train = 50  # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5)   # 排序后的训练样本def f(x):return 2 * torch.sin(x) + x**0.8y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 训练样本的输出
x_test = torch.arange(0, 5, 0.1)  # 测试样本
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)  # 测试样本数
n_test
50

下面的函数将绘制所有的训练样本(样本由圆圈表示), 不带噪声项的真实数据生成函数f
(标记为“Truth”), 以及学习得到的预测函数(标记为“Pred”)。

def plot_kernel_reg(y_hat):d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],xlim=[0, 5], ylim=[-1, 5])d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);

平均汇聚

y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)

在这里插入图片描述

非参数注意力汇聚

# X_repeat的形状:(n_test,n_train),
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)

在这里插入图片描述

现在来观察注意力的权重。 这里测试数据的输入相当于查询,而训练数据的输入相当于键。 因为两个输入都是经过排序的,因此由观察可知“查询-键”对越接近, 注意力汇聚的注意力权重就越高。

d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),xlabel='Sorted training inputs',ylabel='Sorted testing inputs')

在这里插入图片描述

批量矩阵乘法

X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
torch.bmm(X, Y).shape
torch.Size([2, 1, 6])

在注意力机制的背景中,我们可以使用小批量矩阵乘法来计算小批量数据中的加权平均值。

weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))
torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1))
tensor([[[ 4.5000]],[[14.5000]]])

使用小批量矩阵乘法, 定义Nadaraya-Watson核回归的带参数版本为:

class NWKernelRegression(nn.Module):def __init__(self, **kwargs):super().__init__(**kwargs)self.w = nn.Parameter(torch.rand((1,), requires_grad=True))def forward(self, queries, keys, values):# queries和attention_weights的形状为(查询个数,“键-值”对个数)queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))self.attention_weights = nn.functional.softmax(-((queries - keys) * self.w)**2 / 2, dim=1)# values的形状为(查询个数,“键-值”对个数)return torch.bmm(self.attention_weights.unsqueeze(1),values.unsqueeze(-1)).reshape(-1)

将训练数据集变换为键和值用于训练注意力模型。

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train, 1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train, 1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))

训练带参数的注意力汇聚模型时,使用平方损失函数和随机梯度下降。

net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])for epoch in range(5):trainer.zero_grad()l = loss(net(x_train, keys, values), y_train)l.sum().backward()trainer.step()print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')animator.add(epoch + 1, float(l.sum()))

在这里插入图片描述

如下所示,训练完带参数的注意力汇聚模型后可以发现: 在尝试拟合带噪声的训练数据时, 预测结果绘制的线不如之前非参数模型的平滑。

# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)

在这里插入图片描述

与非参数的注意力汇聚模型相比, 带参数的模型加入可学习的参数后, 曲线在注意力权重较大的区域变得更不平滑。

d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),xlabel='Sorted training inputs',ylabel='Sorted testing inputs')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/4653.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Unity动画系统】详解Root Motion动画在Unity中的应用(二)

Root Motion遇到Blend Tree 如果Root Motion动画片段的速度是1.8,那么阈值就要设置为1.8,那么在代码中的参数就可以直接反映出Root Motion的最终移动速度。 Compute Thresholds:根据Root Motion中某些数值自动计算这里的阈值。 Velocity X/…

Meilisearch 快速入门(Windows 环境) 搜索引擎 语义搜索

Meilisearch 快速入门(Windows 环境)# 简介# Meilisearch 是一个基于 rust 开发的,快速的、完全开源的轻量级搜索引擎。它的数据存储基于磁盘与内存映射,不受 RAM 限制。在一定数量级下,搜索速度不逊于 Elasticsearch。 下载# 官方服务端包下载地址:github.com/meili…

对于button按钮引发的bug

主要原因就是今天在给button按钮添加一个点击事件的时候,并没有声明button的type类型,就一直发生点击按钮但事件并不触发的问题。 触发这种问题的原因就是: 按钮默认的 type 类型是 "submit",而不是 "button"。当你不显式…

【前端】VUE项目创建

在所需文件夹中打开cmd命令行窗口,输入vue ui 进入web可视化界面选择创建新项目 根据需求依次完成下列选择,下列是参考配置,完成后点击创建项目即可 最终显示完成

(学习日记)2024.05.10:UCOSIII第六十四节:常用的结构体(os.h文件)第三部分

之前的章节都是针对某个或某些知识点进行的专项讲解,重点在功能和代码解释。 回到最初开始学μC/OS-III系统时,当时就定下了一个目标,不仅要读懂,还要读透,改造成更适合中国宝宝体质的使用方式。在学完野火的教程后&a…

从OpenJDK源码看JAVA虚拟机的创建过程

这里写目录标题 关于Java跨平台能力的理解Java Virtual Machine是怎么创建的。1. Java Launcher2. JLI_Launch 入口3. JVM-Init4. 开启新线程并继续5. 调用JavaMain6. 初始化Java虚拟机,并执行Main方法java.c中的InitializeJVM 方法 7. JNI_CreateJavaVM8. 虚拟机创…

WPS的JS宏如何设置Word文档的表格的单元格文字重新编号

希望对Word文档中的表格进行统一处理,表格内的编号,有时候会出现紊乱,下一个表格的编号承接了上一个表格的编号,实际需要重新编号。 当表格比较多时,手动更改非常麻烦,而且更改一遍并不能完成,…

使用 XHbuilder 编辑器 uniapp开发 app 中使用手机本相机可直接拍摄照片进行上传,也可以选择相册进行上传

学习目标: 使用 XHbuilder 编辑器 uniapp开发 app 中使用手机本相机可直接拍摄照片进行上传,也可以选择相册进行上传 学习内容: 相关内容 上传图片上传时调用的相关方法配置的相关模块需要配置的相关权限 知识小结: 总结&#…

最大连续1的个数 ||| ---- 滑动窗口

题目链接 题目: 分析: 题目中说可以将最多k个0翻转成1, 如果我们真的这样算就会十分麻烦, 所以我们可以换一种思路: 找到一个最长的子数组, 最多有k个0解法一: 暴力解法: 找到所有的最多有k个0的子字符串, 返回最长的解法二: 找到最长的子数组, 我们可以想到"滑动窗口算…

【win10相关】更新后出现未连接到互联网的问题及解决

问题背景 在win10更新完系统之后,第二天电脑开机后,发现无法上网,尝试打开百度,但是出现以下图片: 经过检查,发现手机是可以上网的,说明网络本身并没有问题,对防火墙进行了一些设置…

C++/BOOST filesystem fs::directory_iterator一个滑稽的错误

错误来源于,用 fs::directory_iterator iter(folderPath), end; 然后for循环 for (; iter ! iter_end; iter) {} 最开始没问题,后来说加个进度条,统计一下所有文件数量,用了std::distance, int totalFiles std::…

XYCTF2024 部分w

RE 1. 聪明的信使 基础爆破 #include<stdio.h> #include<string.h> int main() {char enc[] "oujp{H0d_TwXf_Lahyc0_14_e3ah_Rvy0acwc!}";char flag[41] {0};int i, j;for (i 0; i < strlen(enc); i){for (j 33; j < 127; j){if ((j < 9…

Skill Check: Fundamentals of Large Language Models

Skill Check: Fundamentals of Large Language Models 完结&#xff01;

Vue项目中引入高德地图步骤详解,附示例代码

vue中如何使用高德地图&#xff0c;下面为您详解。 步骤一&#xff1a;安装高德地图的JavaScript API 在Vue项目的根目录下打开终端&#xff0c;执行以下命令安装高德地图的JavaScript API&#xff1a; npm install amap/amap-jsapi-loader --save 步骤二&#xff1a;创建地…

什么?你还不懂文件系统和软硬链接?

文章目录 序言认识磁盘磁盘在系统中的管理熟悉磁盘各个分区 软硬链接软链接硬链接 序言 首先熟悉一下一些专有名词(了解即可,但必须有一个概念认识) 固态:SSD,笔记本中常装的,台式机中也可以装,常见的对应接口M.2和SATA接口 磁盘:90年代常用的数据存储设备,或是现在企业级数据…

IPv4 NAT(含Cisco配置)

IPv4 NAT&#xff08;含Cisco配置&#xff09; IPv4私有空间地址 类RFC 1918 内部地址范围前缀A10.0.0.0 - 10.255.255.25510.0.0.0/8B172.16.0.0 - 172.31.255.255172.16.0.0/12C192.168.0.0 - 192.168.255.255192.168.0.0/16 这些私有地址可在企业或站点内使用&#xff0c…

从零开始的软件测试学习之旅(四)web项目工作流程介绍

WEB手工项目 项目介绍项目技术分析项目学习准备工作如何快速熟悉项目举例熟悉TPshop项目 总体系统项目介绍项目与数据库测试流程什么是软件需求需求评审 测试计划测试方案测试计划和测试方案的区别 项目介绍 满足经典三层架构:前端 后端 数据库 前端:运行在用户端的浏览器和客…

同仁堂医养拟赴港上市,养老产业的盈利难题有了答案?

提及银发经济&#xff0c;大众可能最先想到的就是养老产业&#xff0c;在市场需求推动下&#xff0c;这一细分赛道的增长已势不可挡。单从入局者的积极性就可以把握到赛道前景之广阔。 天眼查专业版数据显示&#xff0c;截至目前&#xff0c;我国拥有养老相关企业36.2万家&…

线上办理离婚快速离婚,无需双方见面异地可办

现在离婚有两种方式 一种是协议离婚&#xff0c;双方都同意的情况下&#xff0c;可以去民政局协议离婚&#xff0c;有30天冷静期&#xff0c;冷静期过后需要双方再次去民政局办理离婚手续。 另一种是诉讼离婚&#xff0c;一方不同意离婚&#xff0c;可以选择诉讼离婚。可以全…

【PPT设计】颜色对比、渐变填充、简化框线、放大镜效果、渐变形状配图、线条的使用

目录 图表颜色对比、渐变填充、简化框线放大镜效果渐变形状配图 线条的使用区分标题与说明信息区分标题与正文,区分不同含义的内容**聚焦****引导****注解****装饰** 图表 颜色对比、渐变填充、简化框线 小米汽车正式亮相&#xff01;你们都在讨论价格&#xff0c;我全程只关…