动手学深度学习10.5. 多头注意力-笔记练习(PyTorch)

本节课程地址:多头注意力代码_哔哩哔哩_bilibili

本节教材地址:10.5. 多头注意力 — 动手学深度学习 2.0.0 documentation

本节开源代码:...>d2l-zh>pytorch>chapter_multilayer-perceptrons>multihead-attention.ipynb


多头注意力

在实践中,当给定相同的查询、键和值的集合时, 我们希望模型可以基于相同的注意力机制学习到不同的行为, 然后将不同的行为作为知识组合起来, 捕获序列内各种范围的依赖关系 (例如,短距离依赖和长距离依赖关系)。 因此,允许注意力机制组合使用查询、键和值的不同 子空间表示(representation subspaces)可能是有益的。

为此,与其只使用单独一个注意力汇聚, 我们可以用独立学习得到的 h 组不同的 线性投影(linear projections)来变换查询、键和值。 然后,这 h 组变换后的查询、键和值将并行地送到注意力汇聚中。 最后,将这 h 个注意力汇聚的输出拼接在一起, 并且通过另一个可以学习的线性投影进行变换, 以产生最终输出。 这种设计被称为多头注意力(multihead attention) (="https://zh.d2l.ai/chapter_references/zreferences.html#id174">Vaswaniet al., 2017)。 对于 h 个注意力汇聚输出,每一个注意力汇聚都被称作一个(head)。 图10.5.1 展示了使用全连接层来实现可学习的线性变换的多头注意力。

模型

在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。 给定查询 \mathbf{q} \in \mathbb{R}^{d_q} 、 键 \mathbf{k} \in \mathbb{R}^{d_k} 和 值 \mathbf{v} \in \mathbb{R}^{d_v} , 每个注意力头 \mathbf{h}_i( i = 1, \ldots, h )的计算方法为:

\mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},

其中,可学习的参数包括 \mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q} 、 \mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k} 和 \mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v} , 以及代表注意力汇聚的函数 f 。 f 可以是 10.3节 中的 加性注意力和缩放点积注意力。 多头注意力的输出需要经过另一个线性转换, 它对应着 h 个头连结后的结果,因此其可学习参数是 \mathbf W_o\in\mathbb R^{p_o\times h p_v} :

\mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.

基于这种设计,每个头都可能会关注输入的不同部分, 可以表示比简单加权平均值更复杂的函数。

import math
import torch
from torch import nn
from d2l import torch as d2l

实现

在实现过程中通常[选择缩放点积注意力作为每一个注意力头]。 为了避免计算代价和参数代价的大幅增长, 我们设定 p_q = p_k = p_v = p_o / h 。 值得注意的是,如果将查询、键和值的线性变换的输出数量设置为 p_q h = p_k h = p_v h = p_o , 则可以并行计算 h 个头。 在下面的实现中, p_o 是通过参数num_hiddens指定的。

#@save
class MultiHeadAttention(nn.Module):"""多头注意力"""def __init__(self, key_size, query_size, value_size, num_hiddens,num_heads, dropout, bias=False, **kwargs):super(MultiHeadAttention, self).__init__(**kwargs)self.num_heads = num_headsself.attention = d2l.DotProductAttention(dropout)self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)def forward(self, queries, keys, values, valid_lens):# queries,keys,values的形状:# (batch_size,查询或者“键-值”对的个数,num_hiddens)# valid_lens 的形状:# (batch_size,)或(batch_size,查询的个数)# 经过变换后,输出的queries,keys,values 的形状:# (batch_size*num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)queries = transpose_qkv(self.W_q(queries), self.num_heads)keys = transpose_qkv(self.W_k(keys), self.num_heads)values = transpose_qkv(self.W_v(values), self.num_heads)if valid_lens is not None:# 在轴0,将第一项(标量或者矢量)复制num_heads次,# 然后如此复制第二项,然后诸如此类。valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)# output的形状:(batch_size*num_heads,查询的个数,# num_hiddens/num_heads)output = self.attention(queries, keys, values, valid_lens)# output_concat的形状:(batch_size,查询的个数,num_hiddens)output_concat = transpose_output(output, self.num_heads)return self.W_o(output_concat)

为了能够[使多个头并行计算], 上面的MultiHeadAttention类将使用下面定义的两个转置函数。 具体来说,transpose_output函数反转了transpose_qkv函数的操作。

#@save
def transpose_qkv(X, num_heads):"""为了多注意力头的并行计算而变换形状"""# 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,# num_hiddens/num_heads)X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)# 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)return X.reshape(-1, X.shape[2], X.shape[3])#@save
def transpose_output(X, num_heads):"""逆转transpose_qkv函数的操作"""# 输入X的形状:(batch_size*num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)# 最终输出的形状:((batch_size,查询或者“键-值”对的个数,num_hiddens)return X.reshape(X.shape[0], X.shape[1], -1)

下面使用键和值相同的小例子来[测试]我们编写的MultiHeadAttention类。 多头注意力输出的形状是(batch_sizenum_queriesnum_hiddens)。

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention((attention): DotProductAttention((dropout): Dropout(p=0.5, inplace=False))(W_q): Linear(in_features=100, out_features=100, bias=False)(W_k): Linear(in_features=100, out_features=100, bias=False)(W_v): Linear(in_features=100, out_features=100, bias=False)(W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens =  6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
torch.Size([2, 4, 100])

小结

  • 多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。
  • 基于适当的张量操作,可以实现多头注意力的并行计算。

练习

  1. 分别可视化这个实验中的多个头的注意力权重。

解:
代码如下:

attention.attention.attention_weights.shape
# (batch_size*num_heads,查询的个数,“键-值”对的个数)

输出结果:
torch.Size([10, 4, 6])

d2l.show_heatmaps(attention.attention.attention_weights.reshape((2,5,4,6)), xlabel='Key positions', ylabel='Query positions', titles=['Head %d' % i for i in range(1, 6)],figsize=(8, 3.5))

输出结果:

2. 假设有一个完成训练的基于多头注意力的模型,现在希望修剪最不重要的注意力头以提高预测速度。如何设计实验来衡量注意力头的重要性呢?

解:
首先定义评判注意力头重要性的指标,比如预测速度等;
然后采用单一变量法,修剪某一个头或某几个头的组合,重新训练模型,并在验证集上评估重要性指标的变化; 最后根据重要性指标的变化,判断最不重要的一个或几个注意力头,并修剪。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/888366.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

故障诊断 | Transformer-LSTM组合模型的故障诊断(Matlab)

效果一览 文章概述 故障诊断 | Transformer-LSTM组合模型的故障诊断(Matlab) 源码设计 %% 初始化 clear close all clc disp(此程序务必用2023b及其以上版本的MATLAB!否则会报错!) warning off %

用html+jq实现元素的拖动效果——js基础积累

用htmljq实现元素的拖动效果 效果图如下&#xff1a; 将【item10】拖动到【item1】前面 直接上代码&#xff1a; html部分 <ul id"sortableList"><li id"item1" class"w1" draggable"true">Item 1</li><li …

单片机学习笔记 12. 定时/计数器_定时

更多单片机学习笔记&#xff1a;单片机学习笔记 1. 点亮一个LED灯单片机学习笔记 2. LED灯闪烁单片机学习笔记 3. LED灯流水灯单片机学习笔记 4. 蜂鸣器滴~滴~滴~单片机学习笔记 5. 数码管静态显示单片机学习笔记 6. 数码管动态显示单片机学习笔记 7. 独立键盘单片机学习笔记 8…

【后端开发】Go语言编程实践,Goroutines和Channels,基于共享变量的并发,反射与底层编程

【后端开发】Go语言编程实践&#xff0c;Goroutines和Channels&#xff0c;基于共享变量的并发&#xff0c;反射与底层编程 【后端开发】Go语言高级编程&#xff0c;CGO、Go汇编语言、RPC实现、Web框架实现、分布式系统 文章目录 1、并发基础, Goroutines和Channels2、基于共享…

人机交互革命,为智能座舱市场激战注入一针「催化剂」

从AIGC到AGI赋能&#xff0c;智能座舱人机交互体验迎来新范式。 不断训练、迭代的大模型&#xff0c;为智能座舱带来了更全面的感知能力、更准确的认知理解&#xff0c;以及更丰富的交互模态&#xff0c;显著提升了其智能化水平。 “AI大模型的快速应用与迭代&#xff0c;推动…

计算机视觉硬件知识点整理六:工业相机选型

文章目录 前言一、工业数字相机的分类二、相机的主要参数三、工业数字摄像机主要接口类型四、选择工业相机的考量因素六、实例分析 前言 随着科技的不断进步&#xff0c;工业自动化领域正经历着前所未有的变革。作为工业自动化的重要组成部分&#xff0c;工业相机在工业检测、…

如何使用brew安装phpredis扩展?

如何使用brew安装phpredis扩展&#xff1f; phpredis扩展是一个用于PHP语言的Redis客户端扩展&#xff0c;它提供了一组PHP函数&#xff0c;用于与Redis服务器进行交互。 1、cd到php某一版本的bin下 /usr/local/opt/php8.1/bin 2、下载 phpredis git clone https://githu…

硬件看门狗工作原理

硬件看门狗是什么&#xff1f; 硬件看门狗&#xff08;Hardware Watchdog&#xff09;是一种用于监控系统运行状态的硬件设备或电路。它的主要功能是检测系统是否正常运行&#xff0c;并在系统出现故障或无响应时自动重启或采取其他恢复措施。 工作原理与引脚 硬件看门狗一般…

神经网络中的优化方法(一)

目录 1. 与纯优化的区别1.1 经验风险最小化1.2 代理损失函数1.3 批量算法和小批量算法 2. 神经网络中优化的挑战2.1 病态2.2 局部极小值2.3 高原、鞍点和其他平坦区域2.4 悬崖和梯度爆炸2.5 长期依赖2.6 非精确梯度2.7 局部和全局结构间的弱对应 3. 基本算法3.1 随机梯度下降(小…

海康gige工业相机无驱动取像突破(c#实现,最后更新,你也可以移植到linux下去用)

买了3个海康的相机&#xff0c;最初测试成功的是500万相机。 然后写了一个通用版&#xff0c;害怕有问题&#xff0c;又买了600万的相机&#xff0c;测试果然不及格&#xff0c;花了九牛二虎之力找到一个小问题&#xff0c;就这个 if (changdu > 1000)&#xff1b; 最后又…

Linux -初识 与基础指令1

博客主页&#xff1a;【夜泉_ly】 本文专栏&#xff1a;【Linux】 欢迎点赞&#x1f44d;收藏⭐关注❤️ 文章目录 &#x1f4da; 前言&#x1f5a5;️ 初识&#x1f510; 登录 root用户&#x1f465; 两种用户➕ 添加用户&#x1f9d1;‍&#x1f4bb; 登录 普通用户⚙️ 常见…

Elasticsearch在liunx 中单机部署

下载配置 1、下载 官网下载地址 2、上传解压 tar -zxvf elasticsearch-XXX.tar.gz 3、新建组和用户 &#xff08;elasticsearch 默认不允许root账户&#xff09; #创建组 es groupadd es #新建用户 useradd ryzhang -g es 4、更改文件夹的用户权限 chown -R ryzhang …

Refit 使用详解

Git官网&#xff1a;https://github.com/reactiveui/refit Refit 是一个针对 .NET 应用程序的 REST API 客户端库&#xff0c;它通过接口定义 API 调用&#xff0c;从而简化与 RESTful 服务的交互。其核心理念是利用声明性编程的方式来创建 HttpClient 客户端&#xff0c;使得…

《山海经》:北山

《山海经》&#xff1a;北山 北山一经单狐山求如山&#xff08;水马&#xff1a;形状与马相似&#xff0c;滑鱼&#xff1a;背部红色&#xff09;带山&#xff08;䑏疏&#xff1a;似马&#xff0c;一只角&#xff0c;鵸鵌&#xff1a;状乌鸦五彩斑斓&#xff0c;儵鱼&#xff…

使用gemini-1.5-pro-002做视频检测

使用Google Cloud Video Intelligence API做视频检测最大的缺陷是无法自定义规则&#xff0c;若使用gemini-1.5-pro-002多模拟模型&#xff0c;则可以自定义检测的规则&#xff0c;具有更好的灵活性。 安装SDK pip install --upgrade google-cloud-aiplatform gcloud auth ap…

动态规划——子序列问题

文章目录 目录 文章目录 前言 一、动态规划思路简介 二、具体实现 1. 字符串问题 1.1 最长公共字符串 1.2.0 最长回文子串 1.2.1 最长回文子序列 2.编辑距离问题 2.1 判断子序列 2.2 编辑距离 总结 前言 上文提到动态规划的背包问题&#xff0c;本文继续介绍动态…

Ubuntu24.04配置DINO-Tracker

一、引言 记录 Ubuntu 配置的第一个代码过程 二、更改conda虚拟环境的默认安装路径 鉴于不久前由于磁盘空间不足引发的重装系统的惨痛经历&#xff0c;在新系统装好后当然要先更改虚拟环境的默认安装路径。 输入指令&#xff1a; conda info可能因为我原本就没有把 Anacod…

vulnhub靶场【哈利波特】三部曲之Aragog

前言 使用virtual box虚拟机 靶机&#xff1a;Aragog : 192.168.1.101 攻击&#xff1a;kali : 192.168.1.16 主机发现 使用arp-scan -l扫描&#xff0c;在同一虚拟网卡下 信息收集 使用nmap扫描 发现22端口SSH服务&#xff0c;openssh 80端口HTTP服务&#xff0c;Apach…

顶刊算法 | 鱼鹰算法OOA-BiTCN-BiGRU-Attention多输入单输出回归预测(Maltab)

顶刊算法 | 鱼鹰算法OOA-BiTCN-BiGRU-Attention多输入单输出回归预测&#xff08;Maltab&#xff09; 目录 顶刊算法 | 鱼鹰算法OOA-BiTCN-BiGRU-Attention多输入单输出回归预测&#xff08;Maltab&#xff09;效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.Matlab实…

getchar()

getchar():从计算机终端&#xff08;一般是键盘&#xff09;输入一个字符 1、getchar返回的是字符的ASCII码值&#xff08;整数&#xff09;。 2、getchar在读取结束或者失败的时候&#xff0c;会返回EOF 输入密码并确认&#xff1a; scanf读取\n之前的内容即12345678 回车符…