动手学深度学习10.5. 多头注意力-笔记练习(PyTorch)

本节课程地址:多头注意力代码_哔哩哔哩_bilibili

本节教材地址:10.5. 多头注意力 — 动手学深度学习 2.0.0 documentation

本节开源代码:...>d2l-zh>pytorch>chapter_multilayer-perceptrons>multihead-attention.ipynb


多头注意力

在实践中,当给定相同的查询、键和值的集合时, 我们希望模型可以基于相同的注意力机制学习到不同的行为, 然后将不同的行为作为知识组合起来, 捕获序列内各种范围的依赖关系 (例如,短距离依赖和长距离依赖关系)。 因此,允许注意力机制组合使用查询、键和值的不同 子空间表示(representation subspaces)可能是有益的。

为此,与其只使用单独一个注意力汇聚, 我们可以用独立学习得到的 h 组不同的 线性投影(linear projections)来变换查询、键和值。 然后,这 h 组变换后的查询、键和值将并行地送到注意力汇聚中。 最后,将这 h 个注意力汇聚的输出拼接在一起, 并且通过另一个可以学习的线性投影进行变换, 以产生最终输出。 这种设计被称为多头注意力(multihead attention) (="https://zh.d2l.ai/chapter_references/zreferences.html#id174">Vaswaniet al., 2017)。 对于 h 个注意力汇聚输出,每一个注意力汇聚都被称作一个(head)。 图10.5.1 展示了使用全连接层来实现可学习的线性变换的多头注意力。

模型

在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。 给定查询 \mathbf{q} \in \mathbb{R}^{d_q} 、 键 \mathbf{k} \in \mathbb{R}^{d_k} 和 值 \mathbf{v} \in \mathbb{R}^{d_v} , 每个注意力头 \mathbf{h}_i( i = 1, \ldots, h )的计算方法为:

\mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},

其中,可学习的参数包括 \mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q} 、 \mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k} 和 \mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v} , 以及代表注意力汇聚的函数 f 。 f 可以是 10.3节 中的 加性注意力和缩放点积注意力。 多头注意力的输出需要经过另一个线性转换, 它对应着 h 个头连结后的结果,因此其可学习参数是 \mathbf W_o\in\mathbb R^{p_o\times h p_v} :

\mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.

基于这种设计,每个头都可能会关注输入的不同部分, 可以表示比简单加权平均值更复杂的函数。

import math
import torch
from torch import nn
from d2l import torch as d2l

实现

在实现过程中通常[选择缩放点积注意力作为每一个注意力头]。 为了避免计算代价和参数代价的大幅增长, 我们设定 p_q = p_k = p_v = p_o / h 。 值得注意的是,如果将查询、键和值的线性变换的输出数量设置为 p_q h = p_k h = p_v h = p_o , 则可以并行计算 h 个头。 在下面的实现中, p_o 是通过参数num_hiddens指定的。

#@save
class MultiHeadAttention(nn.Module):"""多头注意力"""def __init__(self, key_size, query_size, value_size, num_hiddens,num_heads, dropout, bias=False, **kwargs):super(MultiHeadAttention, self).__init__(**kwargs)self.num_heads = num_headsself.attention = d2l.DotProductAttention(dropout)self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)def forward(self, queries, keys, values, valid_lens):# queries,keys,values的形状:# (batch_size,查询或者“键-值”对的个数,num_hiddens)# valid_lens 的形状:# (batch_size,)或(batch_size,查询的个数)# 经过变换后,输出的queries,keys,values 的形状:# (batch_size*num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)queries = transpose_qkv(self.W_q(queries), self.num_heads)keys = transpose_qkv(self.W_k(keys), self.num_heads)values = transpose_qkv(self.W_v(values), self.num_heads)if valid_lens is not None:# 在轴0,将第一项(标量或者矢量)复制num_heads次,# 然后如此复制第二项,然后诸如此类。valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)# output的形状:(batch_size*num_heads,查询的个数,# num_hiddens/num_heads)output = self.attention(queries, keys, values, valid_lens)# output_concat的形状:(batch_size,查询的个数,num_hiddens)output_concat = transpose_output(output, self.num_heads)return self.W_o(output_concat)

为了能够[使多个头并行计算], 上面的MultiHeadAttention类将使用下面定义的两个转置函数。 具体来说,transpose_output函数反转了transpose_qkv函数的操作。

#@save
def transpose_qkv(X, num_heads):"""为了多注意力头的并行计算而变换形状"""# 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,# num_hiddens/num_heads)X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)# 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)return X.reshape(-1, X.shape[2], X.shape[3])#@save
def transpose_output(X, num_heads):"""逆转transpose_qkv函数的操作"""# 输入X的形状:(batch_size*num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)# 最终输出的形状:((batch_size,查询或者“键-值”对的个数,num_hiddens)return X.reshape(X.shape[0], X.shape[1], -1)

下面使用键和值相同的小例子来[测试]我们编写的MultiHeadAttention类。 多头注意力输出的形状是(batch_sizenum_queriesnum_hiddens)。

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention((attention): DotProductAttention((dropout): Dropout(p=0.5, inplace=False))(W_q): Linear(in_features=100, out_features=100, bias=False)(W_k): Linear(in_features=100, out_features=100, bias=False)(W_v): Linear(in_features=100, out_features=100, bias=False)(W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens =  6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
torch.Size([2, 4, 100])

小结

  • 多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。
  • 基于适当的张量操作,可以实现多头注意力的并行计算。

练习

  1. 分别可视化这个实验中的多个头的注意力权重。

解:
代码如下:

attention.attention.attention_weights.shape
# (batch_size*num_heads,查询的个数,“键-值”对的个数)

输出结果:
torch.Size([10, 4, 6])

d2l.show_heatmaps(attention.attention.attention_weights.reshape((2,5,4,6)), xlabel='Key positions', ylabel='Query positions', titles=['Head %d' % i for i in range(1, 6)],figsize=(8, 3.5))

输出结果:

2. 假设有一个完成训练的基于多头注意力的模型,现在希望修剪最不重要的注意力头以提高预测速度。如何设计实验来衡量注意力头的重要性呢?

解:
首先定义评判注意力头重要性的指标,比如预测速度等;
然后采用单一变量法,修剪某一个头或某几个头的组合,重新训练模型,并在验证集上评估重要性指标的变化; 最后根据重要性指标的变化,判断最不重要的一个或几个注意力头,并修剪。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/888366.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

故障诊断 | Transformer-LSTM组合模型的故障诊断(Matlab)

效果一览 文章概述 故障诊断 | Transformer-LSTM组合模型的故障诊断(Matlab) 源码设计 %% 初始化 clear close all clc disp(此程序务必用2023b及其以上版本的MATLAB!否则会报错!) warning off %

用html+jq实现元素的拖动效果——js基础积累

用htmljq实现元素的拖动效果 效果图如下&#xff1a; 将【item10】拖动到【item1】前面 直接上代码&#xff1a; html部分 <ul id"sortableList"><li id"item1" class"w1" draggable"true">Item 1</li><li …

单片机学习笔记 12. 定时/计数器_定时

更多单片机学习笔记&#xff1a;单片机学习笔记 1. 点亮一个LED灯单片机学习笔记 2. LED灯闪烁单片机学习笔记 3. LED灯流水灯单片机学习笔记 4. 蜂鸣器滴~滴~滴~单片机学习笔记 5. 数码管静态显示单片机学习笔记 6. 数码管动态显示单片机学习笔记 7. 独立键盘单片机学习笔记 8…

计算机视觉硬件知识点整理六:工业相机选型

文章目录 前言一、工业数字相机的分类二、相机的主要参数三、工业数字摄像机主要接口类型四、选择工业相机的考量因素六、实例分析 前言 随着科技的不断进步&#xff0c;工业自动化领域正经历着前所未有的变革。作为工业自动化的重要组成部分&#xff0c;工业相机在工业检测、…

如何使用brew安装phpredis扩展?

如何使用brew安装phpredis扩展&#xff1f; phpredis扩展是一个用于PHP语言的Redis客户端扩展&#xff0c;它提供了一组PHP函数&#xff0c;用于与Redis服务器进行交互。 1、cd到php某一版本的bin下 /usr/local/opt/php8.1/bin 2、下载 phpredis git clone https://githu…

硬件看门狗工作原理

硬件看门狗是什么&#xff1f; 硬件看门狗&#xff08;Hardware Watchdog&#xff09;是一种用于监控系统运行状态的硬件设备或电路。它的主要功能是检测系统是否正常运行&#xff0c;并在系统出现故障或无响应时自动重启或采取其他恢复措施。 工作原理与引脚 硬件看门狗一般…

Linux -初识 与基础指令1

博客主页&#xff1a;【夜泉_ly】 本文专栏&#xff1a;【Linux】 欢迎点赞&#x1f44d;收藏⭐关注❤️ 文章目录 &#x1f4da; 前言&#x1f5a5;️ 初识&#x1f510; 登录 root用户&#x1f465; 两种用户➕ 添加用户&#x1f9d1;‍&#x1f4bb; 登录 普通用户⚙️ 常见…

Elasticsearch在liunx 中单机部署

下载配置 1、下载 官网下载地址 2、上传解压 tar -zxvf elasticsearch-XXX.tar.gz 3、新建组和用户 &#xff08;elasticsearch 默认不允许root账户&#xff09; #创建组 es groupadd es #新建用户 useradd ryzhang -g es 4、更改文件夹的用户权限 chown -R ryzhang …

Refit 使用详解

Git官网&#xff1a;https://github.com/reactiveui/refit Refit 是一个针对 .NET 应用程序的 REST API 客户端库&#xff0c;它通过接口定义 API 调用&#xff0c;从而简化与 RESTful 服务的交互。其核心理念是利用声明性编程的方式来创建 HttpClient 客户端&#xff0c;使得…

Ubuntu24.04配置DINO-Tracker

一、引言 记录 Ubuntu 配置的第一个代码过程 二、更改conda虚拟环境的默认安装路径 鉴于不久前由于磁盘空间不足引发的重装系统的惨痛经历&#xff0c;在新系统装好后当然要先更改虚拟环境的默认安装路径。 输入指令&#xff1a; conda info可能因为我原本就没有把 Anacod…

vulnhub靶场【哈利波特】三部曲之Aragog

前言 使用virtual box虚拟机 靶机&#xff1a;Aragog : 192.168.1.101 攻击&#xff1a;kali : 192.168.1.16 主机发现 使用arp-scan -l扫描&#xff0c;在同一虚拟网卡下 信息收集 使用nmap扫描 发现22端口SSH服务&#xff0c;openssh 80端口HTTP服务&#xff0c;Apach…

顶刊算法 | 鱼鹰算法OOA-BiTCN-BiGRU-Attention多输入单输出回归预测(Maltab)

顶刊算法 | 鱼鹰算法OOA-BiTCN-BiGRU-Attention多输入单输出回归预测&#xff08;Maltab&#xff09; 目录 顶刊算法 | 鱼鹰算法OOA-BiTCN-BiGRU-Attention多输入单输出回归预测&#xff08;Maltab&#xff09;效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.Matlab实…

getchar()

getchar():从计算机终端&#xff08;一般是键盘&#xff09;输入一个字符 1、getchar返回的是字符的ASCII码值&#xff08;整数&#xff09;。 2、getchar在读取结束或者失败的时候&#xff0c;会返回EOF 输入密码并确认&#xff1a; scanf读取\n之前的内容即12345678 回车符…

动态规划-----路径问题

动态规划-----路径问题 下降最小路径和1&#xff1a;状态表示2&#xff1a;状态转移方程3 初始化4 填表顺序5 返回值6 代码实现 总结&#xff1a; 下降最小路径和 1&#xff1a;状态表示 假设&#xff1a;用dp[i][j]表示&#xff1a;到达[i,j]的最小路径 2&#xff1a;状态转…

实现PDF文档加密,访问需要密码

01. 背景 今天下午老板神秘兮兮的来问我&#xff0c;能不能做个文档加密功能&#xff0c;就是那种用户下载打开需要密码才能打开的那种效果。boss都发话了&#xff0c;那必须可以。 需求&#xff1a;将 pdf 文档经过加密处理&#xff0c;客户下载pdf文档&#xff0c;打开文档需…

HarmonyOS Next 模拟器安装与探索

HarmonyOS 5 也发布了有一段时间了&#xff0c;不知道大家实际使用的时候有没有发现一些惊喜。当然随着HarmonyOS 5的更新也带来了很多新特性&#xff0c;尤其是 HarmonyOS Next 模拟器。今天&#xff0c;我们就来探索一下这个模拟器&#xff0c;看看它能给我们的开发过程带来什…

深入探索进程间通信:System V IPC的机制与应用

目录 1、System V概述 2.共享内存&#xff08;shm&#xff09; 2.1 shmget — 创建共享内存 2.1.2 ftok&#xff08;为shmmat创建key值&#xff09; 2.1.3 为什么一块共享内存的标志信息需要用户来传递 2.2 shmat — 进程挂接共享内存 2.3 shmdt — 断开共享内存连接 2.4…

Rust : 生成日历管理markdown文件的小工具

需求&#xff1a; 拟生成以下markdown管理小工具&#xff0c;这也是我日常工作日程表。 可以输入任意时间段&#xff0c;运行后就可以生成以上的markdown文件。 一、toml [package] name "rust-workfile" version "0.1.0" edition "2021"[d…

mean,median,mode,var,std,min,max函数

剩余的函数都放在这篇里面吧 m e a n mean mean函数可以求平均值 a a a为向量时&#xff0c; m e a n ( a ) mean(a) mean(a)求向量中元素的平均值 a a a为矩阵时&#xff0c; m e a n ( a , 1 ) mean(a,1) mean(a,1)求矩阵中各列元素的平均值&#xff1b; m e a n ( a , 2 )…

Android studio 签名加固后的apk文件

Android studio打包时&#xff0c;可以选择签名类型v1和v2&#xff0c;但是在经过加固后&#xff0c;签名就不在了&#xff0c;或者只有v1签名&#xff0c;这样是不安全的。 操作流程&#xff1a; 1、Android studio 对项目进行打包&#xff0c;生成有签名的apk文件&#xff…