边写代码边学习之TF Attention

1. Attention 背景介绍

通常注意力机制可以使得网络资源集中到某几个需要关注的部分上,和有选择性的弱化对网络结果不重要的部分。网络的注意力机制来源于人的视觉注意力,因为人的精力有限,不能注意到所有的细节,而是有选择性的弱化不需要的或不重要的部分和集中到重要的部分上面,通过这种方式感知外部世界后,指挥人的眼睛移动[19]。注意力机制网络(Attention Mechanism)在一个网络任务中通过关注部分有用方式从而很大的降低了网络的复杂性。2014年,Mnih等人提出了视觉注意力机制的循环神经网络模型(Recurrent models of visual attention),发现加入注意力机制后的循环神经网络能够更好识别出图像中重要的部分,不仅使模型具有了高并发的网络计算能力而且还提高模型识别图像准确率[21]。2015年,Bahdanau D等人首次在自然语言处理领域加入了把注意力机制并且把它用到机器翻译和对齐任务,与传统神经网络模型相比,在准确率上有很大幅度的提升[22]。随后同年,Luong等人在Bahdanau的论文中的注意力机制基础上提出了两种注意力机制即全局注意力模型(Global attention model)和局部注意力模型(Local attention model),进一步提升了seq2seq这种自然语言处理任务的效果[23]。这两种注意力机制主要区别是全局注意力模型会利用query中每个token的向量值,而局部注意力模型仅仅利用其中某一部分token的向量值。2016年,Yin W等人提出基于卷机的注意力机制(attention based CNN, ABCNN),首次在卷积神经网络模型中引入了注意力机制对句子进行建模。在答案选择数据集(Answer Selection, AS)实验发现,分类的准确率有了很大的提升[24]。2017年,Vaswani等人通过单纯利用自注意力机制(Self-Attention)和全连接网络来训练语言翻译模型,与原有的神经网络模型相比,准确率更高[10]。

[21] Mnih V. Heess V. Graves A. Recurrent models of visual attention[C]//Proceedings of Advances in neural information processing systems 2014: 2204-2212

[24] Yin W , H Schütze,  Xiang B , et al. ABCNN: Attention-Based Convolutional Neural Network for Modeling Sentence Pairs[J]. Computer Science, 2015.

[23] Luong M T, Pham H, Manning C D. Effective approaches to attention-based neural machine translation[C]//Proceedings of the 2015 conference on empirical methods in natural language processing. Lisbon, Portugal: Association for Computational Linguistics, 2015: 1412–1421.

[22] Bahdanau D ,  Cho K ,  Bengio Y . Neural Machine Translation by Jointly Learning to Align and Translate[J]. Computer Science, 2014.

[10] Vaswani A, Shazeer N, Parmar N, et al. Attention is all you need[C]//Advances in neural information processing systems. 2017: 5998-6008.

2. 什么是Attention

注意力机制(Attention Mechanism)是机器学习和人工智能领域中的一个重要概念,用于模拟人类视觉或听觉等感知过程中的关注机制。注意力机制的目标是让模型能够在处理信息时,更加关注与任务相关的部分,忽略与任务无关的信息。这种机制最初是受到人类大脑对信息处理的启发而提出的。

注意力机制的基本原理如下:

  1. 输入信息:首先,注意力机制接收输入信息,这可以是序列数据、图像、语音等。

  2. 查询、键和值:对于每个输入,注意力机制引入了三个部分:查询(query)、键(key)、值(value)。这些部分通常是通过神经网络学习得到的。查询用于表示要关注的内容,键用于表示输入信息中的特征,值则是与每个键相关的信息。

  3. 权重分配:注意力机制根据查询和键之间的关系来计算权重,这些权重决定了每个值在最终输出中的贡献程度。通常使用某种形式的相似度度量(如点积、缩放点积等)来计算权重。

  4. 加权求和:将计算得到的权重与对应的值相乘,然后将它们加权求和,得到最终的输出。这个输出通常包含了模型在处理输入信息时关注的部分。

  5. 重复:上述过程通常会被重复多次,以便模型可以在不同的上下文中动态地调整注意力。

注意力机制的核心思想是让模型能够自动地确定在处理输入信息时要关注哪些部分,从而提高了模型在各种任务中的性能。它在自然语言处理、计算机视觉和语音处理等领域都有广泛的应用,如在机器翻译中的Transformer模型、图像分割中的U-Net模型以及语音识别中的Listen, Attend and Spell(LAS)模型等。

总的来说,注意力机制可以帮助模型更好地理解和利用输入信息,提高了模型的表现和泛化能力。

3. Why Attention

由于LSTM和GRU只在一定程度上改进了循环神经网络的长句子依赖问题,并且信息的记忆能力也不是很强和计算能力有限。如果模型要记住很多信息,不得不设计的更复杂,为了解决这些问题,注意力机制出现了,它即能从大量信息中选择重要的信息来缓解神经网络模型的复杂度,而且能高效的并行运算。注意力机制的计算是一个匹配的过程,即通过一个查询(Query)向量到键(Key)和值(Value)对数据对来映射输出值.

注意力的计算一般有三个阶段。第一阶段是计算查询向量Q和每个输入的K的相关性或相似度,得到注意力权重系数S_i :

S_i=f(Q,K_i)

第二阶段是使用SoftMax函数对第一阶段得出的权重系数进行尺度缩放,即把它归一化为概率分布 ai ,分子是把神经元的当前输出映射到(0,+∞),分母是所有输出结果值的总和,公式如下:

a _i=softmax (S_i ) = e^{S_i }/(\sum e^{S_j})

第三阶段:将第二阶段得出的权重与value值加权求和,得到最终需要的Attention数值:

Attention(Q,K,V)=\sum a_i V_i

4. TF attention api 介绍

Attention class

tf.keras.layers.Attention(use_scale=False, score_mode="dot", **kwargs)

Dot-product attention layer, a.k.a. Luong-style attention.

Inputs are query tensor of shape [batch_size, Tq, dim]value tensor of shape [batch_size, Tv, dim] and key tensor of shape [batch_size, Tv, dim]. The calculation follows the steps:

  1. Calculate scores with shape [batch_size, Tq, Tv] as a query-key dot product: scores = tf.matmul(query, key, transpose_b=True).
  2. Use scores to calculate a distribution with shape [batch_size, Tq, Tv]distribution = tf.nn.softmax(scores).
  3. Use distribution to create a linear combination of value with shape [batch_size, Tq, dim]return tf.matmul(distribution, value).

5. 实验代码

5.1.  验证并理解TF attention方法,只输入query和value矩阵。

def softmax(t):s_value = np.exp(t) / np.sum(np.exp(t), axis=-1, keepdims=True)# print('softmax value: ', s_value)return s_valuedef numpy_attention(inputs,mask=None,training=None,return_attention_scores=False,use_causal_mask=False):query = inputs[0]value = inputs[1]key = inputs[2] if len(inputs) > 2 else valuescore = np.matmul(query, key.transpose())attention_score_np = softmax(score)result = np.matmul(attention_score_np, value)print('attention score in numpy =', attention_score_np)print('result in numpy = ', result)def verify_logic_in_attention_with_query_value():query_data = np.array([[1, 0.0, 1],[2, 3, 1]])value_data = np.array([[2, 1.0, 1],[1, 4, 2 ]])print(query_data.shape)numpy_attention([query_data, value_data], return_attention_scores=True)print("=============following is keras attention output================")attention_layer= tf.keras.layers.Attention()result, attention_scores = attention_layer([query_data, value_data], return_attention_scores=True)print('attention_scores = ', attention_scores)print('result=', result);
if __name__ == '__main__':verify_logic_in_attention_with_query_value()

运行结果

(2, 3)
attention score in numpy = [[5.0000000e-01 5.0000000e-01][3.3535013e-04 9.9966465e-01]]
result in numpy =  [[1.5        2.5        1.5       ][1.00033535 3.99899395 1.99966465]]
=============following is keras attention output================
attention_scores =  tf.Tensor(
[[5.0000000e-01 5.0000000e-01][3.3535014e-04 9.9966466e-01]], shape=(2, 2), dtype=float32)
result= tf.Tensor(
[[1.5       2.5       1.5      ][1.0003353 3.998994  1.9996647]], shape=(2, 3), dtype=float32)

5.2.  验证并理解TF attention方法,输入query, key, value矩阵。

def verify_logic_in_attention_with_query_key_value():query_data = np.array([[1, 0.0, 1],[2, 3, 1]])value_data = np.array([[2, 1.0, 1],[1, 4, 2 ]])key_data = np.array([[1, 2.0, 2], [3, 1, 0.1]])print(query_data.shape)numpy_attention([query_data, value_data, key_data], return_attention_scores=True)print("=============following is keras attention output================")attention_layer= tf.keras.layers.Attention()result, attention_scores = attention_layer([query_data, value_data, key_data], return_attention_scores=True)print(attention_layer.get_weights())print('attention_scores = ', attention_scores)print('result=', result);
if __name__ == '__main__':verify_logic_in_attention_with_query_key_value()

结果

(2, 3)
attention score in numpy = [[0.47502081 0.52497919][0.7109495  0.2890505 ]]
result in numpy =  [[1.47502081 2.57493756 1.52497919][1.7109495  1.86715149 1.2890505 ]]
=============following is keras attention output================
[]
attention_scores =  tf.Tensor(
[[0.47502086 0.52497923][0.7109495  0.28905058]], shape=(2, 2), dtype=float32)
result= tf.Tensor(
[[1.4750209 2.5749378 1.5249794][1.7109495 1.8671517 1.2890506]], shape=(2, 3), dtype=float32)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/71167.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端Vue自定义得分构成水平柱形图组件 可用于系统专业门类得分评估分析

引入Vue自定义得分构成水平柱形图组件:cc-horBarChart 随着技术的发展,传统的开发方式使得系统的复杂度越来越高,一个小小的改动或小功能的增加可能会导致整体逻辑的修改,造成牵一发而动全身的情况。为了解决这个问题&#xff0c…

当AI遇到IoT:开启智能生活的无限可能

文章目录 1. AI和IoT的融合1.1 什么是人工智能(AI)?1.2 什么是物联网(IoT)?1.3 AI和IoT的融合 2. 智能家居2.1 智能家居安全2.2 智能家居自动化 3. 医疗保健3.1 远程监护3.2 个性化医疗 4. 智能交通4.1 交通…

c高级 day2

写一个1.sh脚本,将以下内容放到脚本中:在家目录下创建目录文件,dir 在dir下创建dir1和dir2 把当前目录下的所有文件拷贝到dir1中,把当前目录下的所有脚本文件拷贝到dir2中把dir2打包并压缩为dir2.tar.xz 再把dir2.tar.xz移动到dir1中解压d…

华为云云服务器评测|华为云耀云L搭建zerotier服务测试

0. 环境 - Win10 - 云耀云L服务器 1. 安装docker 检查yum源,本EulerOS的源在这里: cd /etc/yum.repos.d 更新源 yum makecache 安装 yum install -y docker-engine 运行测试 docker run hello-world 2. 运行docker镜像 默认配…

软件架构设计(四) 基于服务的架构(SOA)

前面我们了解到了层次架构中表示层的架构分层,分为了MVC,MVP,MVVM等架构风格,下面我们了解一下SOA架构与微服务架构。 什么是服务? 服务是一种为了满足某项业务需求的操作,规则等的逻辑组合,它包含了一系列有序活动的交互,为实现用户目标提供支持。 SOA的起源 前面…

Windows Server 系统各版本及授权说明(附下载地址

本文为Windows Server系统各版本差异对比及授权说明。 会对相关目前仍主流使用的相关Windows Server系统版本和相关授权进行对比和功能说明。 WindowsServer2012 R2 Windows Server 2012 R2授权方式是按照物理CPU数量进行授权,比如物理服务器CPU插槽数量2&#xff…

部署Django报错-requires SQLite 3.8.3 or higher

记一次CentOS7部署Django项目时的报错 问题出现 在部署测试环境时,有需要用到一个python的后端服务,要部署到测试环境中去 心想这不是so easy吗,把本地调试时使用的python版本及Django版本在服务器上对应下载好,然后直接执行命…

1065 A+B and C (64bit)

题&#xff1a;点我 题目大意&#xff1a; 这题虽然看着像签到&#xff0c;然鹅签不过去。 因为我最初写的沙雕代码是&#xff1a; #include<iostream> #include<cstdio> using namespace std; int main(void) {int t;scanf("%d", &t);for (int i …

【云计算网络安全】解析DDoS攻击:工作原理、识别和防御策略 | 文末送书

文章目录 一、前言二、什么是 DDoS 攻击&#xff1f;三、DDoS 攻击的工作原理四、如何识别 DDoS 攻击五、常见的 DDoS 攻击有哪几类&#xff1f;5.1 应用程序层攻击5.1.1 攻击目标5.1.2 应用程序层攻击示例5.1.3 HTTP 洪水 5.2 协议攻击5.2.1 攻击目标5.2.2 协议攻击示例5.2.3 …

IDEA中Run/Debug Configurations添加VM options和Program arguments

1. 现象描述 我在我的IDEA当中打开配置模板后&#xff0c;发现没有VM options和Program arguments&#xff0c;也就是虚拟机选项和程序实参这两项&#xff0c;导致我不能配置系统属性参数和命令行参数&#xff01;&#xff01;&#xff01;&#xff01;&#xff01;&#xff0…

最强的AI视频去码图片修复模型:CodeFormer

目录 1 CodeFormer介绍 1.1 CodeFormer解决的问题 1.2 人脸复原的挑战 1.3 方法动机 1.4 模型实现 1.5 实验结果 2 CodeFormer部署与运行 2.1 conda环境安装 2.2 运行环境构建 2.3 模型下载 2.4 运行 2.4.1 人脸复原 ​编辑​编辑 2.4.2 全图片增强 2.4.3 人脸颜色…

Android逆向学习(二)vscode进行双开与图标修改

Android逆向学习&#xff08;二&#xff09;vscode进行双开与图标修改 写在前面 这其实应该还是吾爱的第一个作业&#xff0c;但是写完上一个博客的时候已经比较晚了&#xff0c;如果继续敲机械键盘吵到室友&#xff0c;我怕我看不到明天的太阳&#xff0c;所以我决定分成两篇…

类ChatGPT大模型LLaMA及其微调模型

1.LLaMA LLaMA的模型架构:RMSNorm/SwiGLU/RoPE/Transfor mer/1-1.4T tokens 1.1对transformer子层的输入归一化 对每个transformer子层的输入使用RMSNorm进行归一化&#xff0c;计算如下&#xff1a; 1.2使用SwiGLU替换ReLU 【Relu激活函数】Relu(x) max(0,x) 。 【GLU激…

Unity ProBuilder(自己创建斜面、拐角)

目录 基础操作 下载 打开面板 新增对象 材质保存 1.斜面实例 2.拐角实例 3.切割实例 4.单独面赋值 基础操作 下载 打开面板 新增对象 选中想创建的块体后&#xff0c;在编辑器见面拉出块体 材质保存 打开材质编辑器后&#xff0c;将材质赋值&#xff0c;之后&am…

【开发】视频云存储/安防监控/AI分析/视频AI智能分析网关:垃圾满溢算法

随着我国科技的发展和城市化进程加快&#xff0c;大家对于生活环境以及空气质量更加重视&#xff0c;要求越来越严格。城市街道垃圾以及生活区垃圾满溢已经成为城市之痛。乱扔垃圾&#xff0c;垃圾不入桶这些行为已经严重影响到了城市的美化问题。特别是炎热的夏日和雨水季节&a…

在iPhone上构建自定义数据采集完整指南

在iPhone上构建自定义数据采集工具可以帮助我们更好地满足特定需求&#xff0c;提高数据采集的灵活性和准确性。本文将为您提供一份完整的指南和示例代码&#xff0c;教您如何在iPhone上构建自定义数据采集工具。 自定义数据采集工具的核心组件 a、数据模型 数据模型是数据采…

开开心心带你学习MySQL数据库之第六篇上

​ &#x1f4ae; &#x1f4ae;&#x1f4ae; 只要路是对的&#xff0c;就不害怕遥远! &#x1f4ae; &#x1f4ae;&#x1f4ae; &#x1f386;&#x1f386;&#x1f386;窗台是风景&#xff0c;笔下有前途&#xff0c;低头是题海&#xff0c;抬头是未来&#x1f386;&…

【BI看板】Superset时间过滤控件二次开发

有没有人发觉Superset时间过滤组件非常高级&#xff0c;&#x1f61f;但又有点复杂&#xff0c;没有选择时间区间的快捷方式。 Superset的时间过滤控件可以通过在代码中进行二次开发来进行定制。以下是一些可能有用的提示&#xff1a; 查找源代码&#xff1a;可以在Superset的源…

Redis之bigkey问题解读

目录 什么是bigkey&#xff1f; bigkey引发的问题 如何查找bigkey redis-cli --bigkeys MEMORY USAGE bigKey如何删除 渐进式删除 unlink bigKey生产调优 什么是bigkey&#xff1f; bigkey简单来说就是存储本身的key值空间太大&#xff0c;或者hash&#xff0c;list&…

意向客户的信息获取到底是怎样的,快来get一下

客户信息获取技术真的可以为企业提供精准客源吗&#xff1f;这个渠道到底安不安全&#xff0c;技术到底成不成熟&#xff1f;效果到底如何&#xff1f;下面简单的和大家分析一下。 客户信息获取技术是怎样的 手机采集引流方面&#xff0c;上量不精准&#xff0c;精准不上量的说…