注意力机制简介

为了减少计算复杂度,通过借鉴生物神经网络的一些机制,我们引入了局部连接、权重共享以及汇聚操作来简化神经网络结构。神经网络中可以存储的信息量称为网络容量。一般来讲,利用一组神经元来存储信息的容量和神经元的数量以及网络的复杂度成正比。如果要存储越多的信息,神经元数量就要越多或者网络要越复杂,进而导致神经网络的参数成倍的增加。我们人脑的生物神经网络同样存在着容量问题,人脑中的工作记忆大概只有几秒钟的时间,类似于循环神经网络中的隐状态。人脑在有限的资源下,并不能同时处理这些过载的信息。大脑就是通过两个重要机制:注意力和记忆机制,来解决信息过载问题的。

基于显著性注意力

聚焦式注意力

注意力一般分为两种:一种是自上而下的有意识的注意力,称为聚焦式(focus)注意力,如上图,我们如果主观上要去看一本书,那在当下的场景中,可能会直接去搜寻书籍;另一种是自下而上的无意识的注意力,称为基于显著性的注意力,如上图,如果我们没有任何目的,随意的一看,可能最容易吸引我们注意的就是红色的茶杯。

在神经网络中,我们可以把最大池化(max pooling),门控等机制看作是自下而上的基于显著性的注意力机制,因为这些操作并没有主动去搜索信息。而我们这里讨论的注意力机制可以看作是一种自上而下的聚焦式注意力机制。用X=[x1,x2,...,xn]表示n个输入信息,为了节省计算资源,不需要将所有的n个输入信息都输入到神经网络进行计算,只需要从X中选择一些和任务相关的信息输入给神经网络。注意力机制的计算可以分为两步:一是在所有输入信息上计算注意力分布,二是根据注意力分布来计算输入信息的加权平均。

注意力机制其实可以理解为求解相似度。这点网上有一个视频讲解的很好。

假设现在有一个从腰围到体重的映射,我们成腰围为key,体重为value。对应效果如下:

key:51——》value:40

key:56——》value:43

key:58——》value:48

那么,如果现在有一个query=57,value该怎么求?

一个最自然的想法就是,57是56和58的平均数,所以对应的value也应该是43和48的平均数,f(q)=(v2+v3)/2,这里因为57距离56和58非常近,我们会非常“注意”它们,所以我们分给56和58的注意力权重为0.5,相当于我们假设这里面存在一个映射,也就是函数f(q),假设用α(q,ki)来表示q与对应ki的注意力权重,那么value的预测值展开来就是f(q)=α(q,k1)v1 + α(q,k2)v2 + α(q,k3)v3,这里我们认为α(q,k1)=0,因为距离较远,所以我们没有考虑,而α(q,k2)和α(q,k3)都为0.5,这就得到了我的结果。

不过这种算法没有考虑到其它数据可能带来的影响,实际的情况可能远比求平均数复杂。所以,更一般的,我们应该来计算注意力权重α(q,ki),关键是怎么计算。一般来说,我们会设置一个注意力打分函数,然后对注意力打分函数的结果来进行softmax操作,得到注意力权重,用公式表示就是:

这里面的a(q,ki)就是注意力打分函数,一般有加性模型、点积模型、缩放点击模型和双线性模型,他们的公式如下:

实际应用中,我们一般选择缩放点击模型来作为打分函数。其中W,U,v为可学习的参数,d为输入信息的维度。注意力分布αi可以解释为在给定任务相关的查询q时,第i个信息受关注的程度。

如果q,k,v是多维的也是一样的,我们可以用矩阵来表示,并以Q,K,V来命名查询和键值对,采用缩放点积模型,公式如下:

当Q,K,V是同一个矩阵时,这就是自注意力机制了,Q=K=V=X,写成公式就是:

这里面有三个可以训练的参数矩阵WQ,WK,WV,写成公式就是:

这就是自注意力机制的公式。

李沐的《动手学深度学习》中有个例子比较好的说明了注意力机制作用,就是Nadaraya-Waston回归,下面我们也来实现一下:

给定成对的输入输出数据集{(x1,y1),...,(xn,yn)},学习一个函数f来预测任意新输入x的输出y_hat = f(x)

假设真实函数是:

我们给数据加一个随机扰动,生成训练数据:

n_train = 50 # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train)*5) # 排序后的训练样本,生成50个[0,1)的数据再乘以5,等于生成50个[0,5)之间的数据# 定义一个真实函数
def f(x):return 2*torch.sin(x)+x**0.8y_train = f(x_train) + torch.normal(0.0,0.5,(n_train,)) # 训练样本的输出,加上了一个扰动
x_test = torch.arange(0,5,0.1) # 测试样本
y_truth = f(x_test) # 测试样本的真实输出

定义一个画图函数:

def plot_kernel_reg(y_hat):plt.plot(x_test, y_truth)plt.plot(x_test, y_hat)plt.xlabel('x')plt.ylabel('y')plt.legend(['Truth', 'Pred'])plt.plot(x_train, y_train, 'o', alpha=0.5)plt.show()

1.直接取训练数据的平均值来预测:

y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)
plt.show()

结果肯定是不理想的,就是一条直线。

2.非参数注意力:

非参数注意力其实就是不需要训练的注意力计算,直接训练数据和测试数据直接的差距来计算注意力。

#每⼀⾏都包含着相同的测试输⼊(例如:同样的查询)
x_repeat =x_test.repeat_interleave(n_train).reshape((-1,n_train))
print(x_repeat)
print(x_repeat.shape)# 输出
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],[0.1000, 0.1000, 0.1000,  ..., 0.1000, 0.1000, 0.1000],[0.2000, 0.2000, 0.2000,  ..., 0.2000, 0.2000, 0.2000],...,[4.7000, 4.7000, 4.7000,  ..., 4.7000, 4.7000, 4.7000],[4.8000, 4.8000, 4.8000,  ..., 4.8000, 4.8000, 4.8000],[4.9000, 4.9000, 4.9000,  ..., 4.9000, 4.9000, 4.9000]])
torch.Size([50, 50])

x_repeat就是把x_test复制了50份,从一个50个数据的1维数组变成了50X50的矩阵。把x_repeat看成是query,x_train看成是key,y_train看出是value,计算注意力权重,根据注意力权重来计算输出。

#x_train包含着键。attention_weights的形状:(n_test,n_train),
#每⼀⾏都包含着要在给定的每个查询的值(y_train)之间分配的注意⼒权重
attention_weights = nn.functional.softmax(-(x_repeat-x_train)**2 / 2, dim=1)
#y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)

3.带参数的注意力机制

这就是正常的注意力机制了。

# X_tile的形状:(n_train,n_train),每⼀⾏都包含着相同的训练输⼊
x_tile = x_train.repeat((n_train,1))
# Y_tile的形状:(n_train,n_train),每⼀⾏都包含着相同的训练输出
y_tile = y_train.repeat((n_train,1))
# keys的形状:('n_train','n_train'-1)
keys = x_tile[(1-torch.eye(n_train)).type(torch.bool)].reshape((n_train,-1))
# values的形状:('n_train','n_train'-1)
values = y_tile[(1-torch.eye(n_train)).type(torch.bool)].reshape((n_train,-1))

这里有个小技巧,用来生成keys和values。

print((1-torch.eye(n_train)).type(torch.bool).shape) # 50X50的矩阵,包含True和False,每一行有49个True和1个False
x_tile[(1-torch.eye(n_train)).type(torch.bool)].shape # 用这个值作为索引后,就得到50行49列的数据,因为每一行都有一个值是False# 输出:
torch.Size([50, 50])
torch.Size([2450])

这里torch.eye(n_train)生成的是50X50的单位矩阵,1-torch.eye(n_train)生成的是50X50的矩阵:

tensor([[0., 1., 1.,  ..., 1., 1., 1.],[1., 0., 1.,  ..., 1., 1., 1.],[1., 1., 0.,  ..., 1., 1., 1.],...,[1., 1., 1.,  ..., 0., 1., 1.],[1., 1., 1.,  ..., 1., 0., 1.],[1., 1., 1.,  ..., 1., 1., 0.]])

(1-torch.eye(n_train)).type(torch.bool)生成的是50X50的布尔矩阵:

tensor([[False,  True,  True,  ...,  True,  True,  True],[ True, False,  True,  ...,  True,  True,  True],[ True,  True, False,  ...,  True,  True,  True],...,[ True,  True,  True,  ..., False,  True,  True],[ True,  True,  True,  ...,  True, False,  True],[ True,  True,  True,  ...,  True,  True, False]])

x_tile[(1-torch.eye(n_train)).type(torch.bool)],代表的是根据bool矩阵取出tensor中对应位置元素,如果是False的元素就不取,取出对应位置为True的元素。

所以keys和values的形状就是[50,49]。下面定义Nadaraya-Waston模型:

class NWKernelRegression(nn.Module):def __init__(self,**kwargs):super().__init__(**kwargs)self.w = nn.Parameter(torch.rand((1,), requires_grad=True)) # 初始化权重参数def forward(self, queries, keys, values):# queries和attenion_weights的形状为(查询个数,“键值对”个数)queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))self.attention_weights = nn.functional.softmax(-((queries-keys)*self.w)**2/2,dim=1)return torch.bmm(self.attention_weights.unsqueeze(1), values.unsqueeze(-1)).reshape(-1)

模型训练:

net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)for epoch in range(5):trainer.zero_grad()l = loss(net(x_train, keys, values), y_train)l.sum().backward()trainer.step()print(f'epoch {epoch+1}, loss {float(l.sum()):.6f}')
# 输出:
epoch 1, loss 27.291159
epoch 2, loss 6.903591
epoch 3, loss 6.815420
epoch 4, loss 6.729923
epoch 5, loss 6.646971

可以看到,预测结果已经越来越接近真实数据了,只是曲线还不是很平滑。其实我还尝试了一下用缩放点积模型来构建模型,但是效果并不好,下一篇来看一下多头注意力和自注意力机制。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/30026.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

表面声波滤波器——工艺 (5)

制作工艺流程 声表面波器件制作采用半导体集成电路的平面工艺,首先在压电衬底上通过光刻、镀膜、剥离或刻蚀等工艺制备出叉指换能器,然后经过划片、粘片、压丝、封焊等后续封装工艺得到最后的器件。 整个工艺过程中需要操作使用各种机台 清洗机光刻机涂胶显影台全…

专业和学校到底怎么选,兴趣和知名度到底哪个重要?

前言 2024高考已经落下帷幕,再过不久就到了激动人心的查分和填报志愿的时刻,在那天到来,小伙伴们就要根据自己的分数选取院校和专业,接下来我就以参加22年(破防年)河南高考的大二生来讲述一下我自己对于如何选取院校和专业的看法以…

香港电讯高可用网络助力企业变革金融计算

客户背景 客户是一家金融行业知名的量化私募对冲基金公司,专注于股票、期权、期货、债券等主要投资市场,在量化私募管理深耕多年,目前资管规模已达数百亿级,在国内多个城市均设有办公地点。 客户需求 由于客户业务倚重量化技术…

从“野人饭”走红,探索品牌户外化营销趋势丨小红书内容分析

wildeat,户外是人的天性的回归 近来,“wildeat(户外野吃)”的风潮在小红书逐渐兴起。越来越多的人选择到户外吃一顿,做一次“野人”,主打一个只要氛围到了,就地开饭,不愁吃什么&…

屏蔽房是做什么用的?为什么需要定期检测?

屏蔽房对于不了解的人来说,可能光看名字不知道是做什么的,但是对于一些企业或者机构,却是再熟悉不过的了。和名字一样,屏蔽房是对空间内的信号以及一些外界环境条件进行隔绝,在一些有特殊要求的企业机构中,…

睿治数据治理平台焕新升级,推出全新建模与调度平台

在数据治理的浩瀚征途中,企业常常面临着数据冗余如同连绵山峦,使得关键信息的获取变得困难重重;在数据检索的海洋中,有时迷失方向,消耗大量时间精力,严重影响了运营效率;特别是在处理大规模数据…

assertJ-db 科普

前言 今日我们看看 java 大名鼎鼎的 assertj 是怎么做断言的 数据库断言 在实际的测试中我们总是跟业务打交道的。跟业务打交道一般很难避免验证数据库中的东西。尤其在接口测试中,一个常见的例子是你测试一个下单的接口。 接口返回可能就是成功过或者失败。你无…

安装 Fedora CoreOS 操作系统

首发日期 2024-06-16, 以下为原文内容: 有一台吃灰几年的 e5-26v3 古老机器, 最近翻出来用一下. 首先从安装操作系统开始. 目录 1 FCOS 简介2 安装过程 2.1 下载 iso 镜像文件并制作安装 U 盘2.2 编写安装配置文件2.3 编译安装配置文件2.4 从 U 盘启动并安装 3 SSH 连接并测试…

直播平台美颜技术分析:视频美颜SDK功能实现原理

本篇文章,笔者将深入分析视频美颜SDK的功能实现原理,探讨其在直播平台中的应用。 一、视频美颜技术概述 通过这些功能,用户可以在直播过程中呈现更加理想的自己,从而提高观众的观看体验和互动积极性。 二、视频美颜SDK的功能 1…

不“卷”低价,品牌如何让客户愿意“留下”?

天猫取消预售制度,满减力度大于往年,京东直接将“又便宜又好”定为大促主题。今年的618,离不开两大关键词:拼低价 和 回归用户。 价格“内卷”,消费者可以花更少的钱买到商品,但对商家来说,意味…

6月19日(周三)A股行情总结:A股震荡收跌,恒生科技指数大涨3%,10年期国债期货转涨续创新高

内容提要 车路云概念延续昨日涨势,华铭智能20CM 3连板。贵金属及PEEK材料概念全日走强;港股有色金属及能源股走强,紫金矿业涨超3%,中石油涨超3%。国债期货午后全线转涨,10年期主力合约涨0.05%报104.925元,…

6.17继承

面向对象的特征:封装,继承,多态 使用背景:比如说在动物类底下可以有带毛的动物,带毛的动物符合所有的动物的特征,只是在这个基础上再继续添加一些特征 命名:原有类型称为“基类”或“父类”&a…

计算机网络(1) OSI七层模型与TCP/IP四层模型

一.OSI七层模型 OSI 七层模型是国际标准化组织ISO提出的一个网络分层模型,它的目的是使各种不同的计算机和网络在世界范围内按照相同的标准框架实现互联。OSI 模型把网络通信的工作分为 7 层,从下到上分别是物理层、数据链路层、网络层、传输层、会话层、…

类加载器、反射、注解

1、类加载器 1.1类加载器作用 负责将.class文件(存储的物理文件)加载在到内存中。 1.2类加载的过程 类加载时机 创建类的实例(对象)调用类的类方法访问类或者接口的类变量,或者为该类变量赋值使用反射方式来强制创建…

Java学习 (一) 环境安装及入门程序

一、安装java环境 1、获取软件包 https://www.oracle.com/java/technologies/downloads/ .exe 文件一路装过去就行,最好别装c盘 ,我这里演示的时候是云主机只有C盘 2、配置环境变量 我的电脑--右键属性--高级系统设置--环境变量 在环境变量中添加如下配…

数字孪生涉及到的9大技术栈,都是难啃骨头呀。

数字孪生涉及到多个技术栈,包括但不限于以下几个方面: 数据采集和传感器技术: 数字孪生需要实时获取物理世界的数据,因此需要使用各种传感器技术(如温度传感器、压力传感器、运动传感器等)来采集数据&…

【猫狗分类】Pytorch VGG16 实现猫狗分类1-数据清洗+制作标签文件

Pytorch 猫狗分类 用Pytorch框架,实现分类问题,好像是学习了一些基础知识后的一个小项目阶段,通过这个分类问题,可以知道整个pytorch的工作流程是什么,会了一个分类,那就可以解决其他的分类问题&#xff0…

第6章 设备驱动程序(3)

目录 6.5 块设备操作 6.5.1 块设备的表示 6.5.2 数据结构 6.5.3 向系统添加磁盘和分区 6.5.4 打开块设备文件 本专栏文章将有70篇左右,欢迎关注,查看后续文章。 6.5 块设备操作 特点: 随机访问任意位置。 固定块大小的传输。 块设备在内…

手机网站制作软件是哪些

手机网站制作软件是一种用于设计、开发和创建适用于移动设备的网站的软件工具。随着移动互联网时代的到来,越来越多的用户开始使用手机浏览网页和进行在线交流,因此,手机网站制作软件也逐渐成为了市场上的热门工具。 1. Adobe Dreamweaver&am…

天翼云8080、80端口用不了的问题

天翼云8080、80端口用不了的问题 前言:前段时间天翼云搞了活动,原来公司用的华为云老板说太贵了也快到期了,就换了天翼云的服务器。 排查: 安全组开放 80 8080 防火墙查看 没有问题 nginx nacos dcoker等停了 查看监听端口 发现…