多头注意力(Multi‑Head Attention)

1. 多头注意力(Multi‑Head Attention)原理

设输入序列表示为矩阵 X ∈ R B × L × d model X\in\mathbb{R}^{B\times L\times d_{\text{model}}} XRB×L×dmodel,其中

  • B B B:批大小(batch size),
  • L L L:序列长度(sequence length),
  • d model d_{\text{model}} dmodel:模型隐层维度(model dimension)。

多头注意力基于对缩放点乘注意力的并行化扩展,引入了 h h h 个注意力头(heads),每个头在不同子空间中学习不同的表示。

1.1 线性映射与切分

我们首先为每个头定义三组可学习权重:
W i Q ∈ R d model × d k , W i K ∈ R d model × d k , W i V ∈ R d model × d v , i = 1 , … , h W_i^Q \in \mathbb{R}^{d_{\text{model}}\times d_k},\quad W_i^K \in \mathbb{R}^{d_{\text{model}}\times d_k},\quad W_i^V \in \mathbb{R}^{d_{\text{model}}\times d_v}, \quad i=1,\dots,h WiQRdmodel×dk,WiKRdmodel×dk,WiVRdmodel×dv,i=1,,h
其中

  • h h h:头数(number of heads),
  • d k d_k dk:每个头中 Query/Key 的维度(key/query dimension),
  • d v d_v dv:每个头中 Value 的维度(value dimension),
  • 通常 d model = h × d k d_{\text{model}}=h\times d_k dmodel=h×dk 且取 d v = d k d_v = d_k dv=dk

对输入 X X X 进行投影,得到第 i i i 个头的查询、键、值:
Q i = X W i Q , K i = X W i K , V i = X W i V Q_i = X\,W_i^Q,\quad K_i = X\,W_i^K,\quad V_i = X\,W_i^V Qi=XWiQ,Ki=XWiK,Vi=XWiV
其中

  • Q i ∈ R B × L × d k Q_i \in \mathbb{R}^{B\times L\times d_k} QiRB×L×dk
  • K i ∈ R B × L × d k K_i \in \mathbb{R}^{B\times L\times d_k} KiRB×L×dk
  • V i ∈ R B × L × d v V_i \in \mathbb{R}^{B\times L\times d_v} ViRB×L×dv

1.2 缩放点乘注意力

对第 i i i 个头,分别对所有位置做点积注意力:

  1. 打分矩阵
    S c o r e i = Q i K i ⊤ ∈ R B × L × L \mathrm{Score}_i = Q_i\,K_i^\top \quad\in\mathbb{R}^{B\times L\times L} Scorei=QiKiRB×L×L
  2. 缩放
    S c o r e ~ i = S c o r e i d k \widetilde{\mathrm{Score}}_i = \frac{\mathrm{Score}_i}{\sqrt{d_k}} Score i=dk Scorei
  3. Softmax 归一化
    A i = s o f t m a x ( S c o r e ~ i ) ∈ R B × L × L A_i = \mathrm{softmax}\bigl(\widetilde{\mathrm{Score}}_i\bigr) \quad\in\mathbb{R}^{B\times L\times L} Ai=softmax(Score i)RB×L×L
  4. 加权求和
    h e a d i = A i V i ∈ R B × L × d v \mathrm{head}_i = A_i\,V_i \quad\in\mathbb{R}^{B\times L\times d_v} headi=AiViRB×L×dv

1.3 拼接与线性变换

将所有头的输出在最后一维拼接,再做一次线性投影:
C o n c a t = [ h e a d 1 , … , h e a d h ] ∈ R B × L × ( h d v ) \mathrm{Concat} = \bigl[\mathrm{head}_1,\dots,\mathrm{head}_h\bigr] \quad\in\mathbb{R}^{B\times L\times (h\,d_v)} Concat=[head1,,headh]RB×L×(hdv)
定义输出权重矩阵
W O ∈ R ( h d v ) × d model W^O\in\mathbb{R}^{(h\,d_v)\times d_{\text{model}}} WOR(hdv)×dmodel
最终输出:
M u l t i H e a d ( X ) = C o n c a t W O ∈ R B × L × d model \mathrm{MultiHead}(X) = \mathrm{Concat}\;W^O \quad\in\mathbb{R}^{B\times L\times d_{\text{model}}} MultiHead(X)=ConcatWORB×L×dmodel


2. PyTorch 代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import mathclass MultiHeadAttention(nn.Module):def __init__(self, d_model: int, h: int):"""d_model: 模型维度 d_modelh: 注意力头数 h"""super().__init__()assert d_model % h == 0, "d_model 必须能被 h 整除"self.d_model = d_model      # d_modelself.h = h                  # hself.d_k = d_model // h     # d_k = d_model / hself.d_v = self.d_k         # d_v 通常等于 d_k# 投影矩阵 W_i^Q, W_i^K, W_i^V,实际上合并为一个大矩阵后在 forward 再切分self.w_q = nn.Linear(d_model, d_model)  # 同时生成 h 个 Q 投影self.w_k = nn.Linear(d_model, d_model)  # 同时生成 h 个 K 投影self.w_v = nn.Linear(d_model, d_model)  # 同时生成 h 个 V 投影# 输出线性变换 W^Oself.w_o = nn.Linear(d_model, d_model)def forward(self, X: torch.Tensor, mask: torch.Tensor = None):"""X: 输入张量,形状 (B, L, d_model)mask: 可选掩码,形状 (B, 1, L, L) 或 (B, L, L)"""B, L, _ = X.size()# 1. 线性映射到 Q, K, V,然后切分 h 头#    先得到 (B, L, h*d_k),再 view/transpose 为 (B, h, L, d_k)Q = self.w_q(X).view(B, L, self.h, self.d_k).transpose(1, 2)K = self.w_k(X).view(B, L, self.h, self.d_k).transpose(1, 2)V = self.w_v(X).view(B, L, self.h, self.d_k).transpose(1, 2)# 此时 Q, K, V 形状均为 (B, h, L, d_k)# 2. 计算点积注意力#    scores = Q @ K^T  -> (B, h, L, L)scores = torch.matmul(Q, K.transpose(-2, -1))  #    缩放:除以 sqrt(d_k)scores = scores / math.sqrt(self.d_k)#    可选掩码:将被屏蔽位置设为 -inf if mask is not None:scores = scores.masked_fill(mask == 0, float('-inf'))#    Softmax 归一化 -> (B, h, L, L)A = F.softmax(scores, dim=-1)#    加权求和 -> head_i 形状 (B, h, L, d_k)heads = torch.matmul(A, V)# 3. 拼接 h 个头:transpose 回 (B, L, h, d_k) 再 reshapeconcat = heads.transpose(1, 2).contiguous().view(B, L, self.h * self.d_k)#    concat 形状 (B, L, h*d_k) == (B, L, d_model)# 4. 最后一次线性变换 W^Ooutput = self.w_o(concat)  # -> (B, L, d_model)return output, A  # 返回输出及注意力权重 A

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/903070.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

系列位置效应——AI与思维模型【80】

一、定义 系列位置效应思维模型是指在一系列事物或信息的呈现过程中,人们对于处于系列开头和结尾部分的项目的记忆效果优于中间部分项目的现象。具体而言,开头部分的记忆优势被称为首因效应,结尾部分的记忆优势被称为近因效应。这种效应反映…

MyBatis XML 配置完整示例(含所有核心配置项)

MyBatis XML 配置完整示例&#xff08;含所有核心配置项&#xff09; 1. 完整 mybatis-config.xml 配置文件 <?xml version"1.0" encoding"UTF-8" ?> <!DOCTYPE configurationPUBLIC "-//mybatis.org//DTD Config 3.0//EN""htt…

电商数据中台架构:淘宝 API 实时采集与多源数据融合技术拆解

引言 在当今竞争激烈的电商领域&#xff0c;数据已成为企业决策和业务发展的核心驱动力。电商数据中台能够整合和管理企业内外部的各种数据&#xff0c;为业务提供有力支持。其中&#xff0c;淘宝 API 实时采集与多源数据融合技术是数据中台架构中的关键部分。本文将深入探讨这…

ubuntu22.04部署Snipe-IT

文章目录 参考链接一、写在前二、安装操作系统三、安装 PHP四、下载 Snipe-IT五、安装依赖六、安装数据库并创建用户七、安装 Snipe-IT八、安装 Nginx九、Web 继续安装 Snipe-IT补充&#xff1a;20250427补充&#xff1a; 最后 参考链接 How to Install Snipe-IT on Ubuntu 22…

图论---Bellman-Ford算法

适用场景&#xff1a;有边数限制 ->&#xff08;有负环也就没影响了&#xff09;&#xff0c;存在负权边&#xff0c;O( n * m )&#xff1b; 有负权回路时有的点距离会是负无穷&#xff0c;因此最短路存在的话就说明没有负权回路。 从1号点经过不超过k条边到每个点的距离…

A. Ideal Generator

time limit per test 1 second memory limit per test 256 megabytes We call an array aa, consisting of kk positive integers, palindromic if [a1,a2,…,ak][ak,ak−1,…,a1][a1,a2,…,ak][ak,ak−1,…,a1]. For example, the arrays [1,2,1][1,2,1] and [5,1,1,5][5,…

[详细无套路]MDI Jade6.5安装包下载安装教程

目录 1. 软件包获取 2. 下载安装 3. 启动 4. 问题记录 写在前面: 垂死病中惊坐起,JAVA博主居然开始更博客了~ 最近忙项目了, 没啥更新的动力,见谅~见谅~. 这次博主的化工友友突然让帮安装JADE6.5软件,本来以为不就一个软件,直接拿捏. 不料竟然翻了个小车, 反被拿捏了. 既…

Serverless 在云原生后端的实践与演化:从函数到平台的革新

📝个人主页🌹:慌ZHANG-CSDN博客 🌹🌹期待您的关注 🌹🌹 一、引言:从服务器到“无服务器”的后端演变 在传统后端开发中,我们需要为服务配置并维护服务器资源,无论是物理机、虚拟机还是容器化服务,都需要: 管理系统运行环境 监控负载与扩缩容 保证高可用与安…

【专题三】二分查找(2)

&#x1f4dd;前言说明&#xff1a; 本专栏主要记录本人的基础算法学习以及LeetCode刷题记录&#xff0c;按专题划分每题主要记录&#xff1a;&#xff08;1&#xff09;本人解法 本人屎山代码&#xff1b;&#xff08;2&#xff09;优质解法 优质代码&#xff1b;&#xff…

MySQL 详解之函数:数据处理与计算的利器

在 MySQL 中,函数可以接受零个或多个输入参数,并返回一个值。这些函数可以在 SELECT 语句的字段列表、WHERE 子句、HAVING 子句、ORDER BY 子句以及 UPDATE 和 INSERT 语句中使用。合理利用函数,可以简化 SQL 语句,提高开发效率。 MySQL 提供了大量的内置函数 (Built-in F…

探索具身智能协作机器人:技术、应用与未来

具身智能协作机器人&#xff1a;概念与特点 具身智能协作机器人&#xff0c;简单来说&#xff0c;就是将人工智能技术与机器人实体相结合&#xff0c;使其能够在与人类共享的空间中进行安全、高效协作的智能设备。它打破了传统机器人只能在预设环境中执行固定任务的局限&#…

基于物联网的园林防火监测系统

标题:基于物联网的园林防火监测系统 内容:1.摘要 随着全球气候变化和人类活动影响&#xff0c;园林火灾发生频率呈上升趋势&#xff0c;给生态环境和人类生命财产造成巨大损失。为有效预防和应对园林火灾&#xff0c;本文提出基于物联网的园林防火监测系统。该系统综合运用传感…

JAVA多线程(8.0)

目录 线程池 为什么使用线程池 线程池的使用 工厂类Executors&#xff08;工厂模式&#xff09; submit 实现一个线程池 线程池 为什么使用线程池 在前面我们都是通过new Thread() 来创建线程的&#xff0c;虽然在java中对线程的创建、中断、销毁、等值等功能提供了支持…

用go从零构建写一个RPC(仿gRPC,tRPC)--- 版本1

希望借助手写这个go的中间件项目&#xff0c;能够理解go语言的特性以及用go写中间件的优势之处&#xff0c;同时也是为了更好的使用和优化公司用到的trpc&#xff0c;并且作者之前也使用过grpc并有一定的兴趣&#xff0c;所以打算从0构建一个rpc系统&#xff0c;对于生产环境已…

【学习笔记】Stata

一、Stata简介 Stata 是一种用于数据分析、数据管理和图形生成的统计软件包&#xff0c;广泛应用于经济学、社会学、政治科学等社会科学领域。 二、Stata基础语法 2.1 数据管理 Stata 支持多种数据格式的导入&#xff0c;包括 Excel、CSV、文本文件等。 从 Excel 文件导入…

Redis数据结构SDS,IntSet,Dict

目录 1.字符串&#xff1a;SDS 1.1.为什么叫做动态字符串 2.IntSet 2.1.inset如何保存大于当前编码的最大数字&#xff1f; 3.Dict 3.1Dict的扩容 3.2Dict的收缩 3.3.rehash 1.字符串&#xff1a;SDS SDS的底层是C语言编写的构建的一种简单动态字符串 简称SDS&#xff…

Maven的聚合工程与继承

目录 一、为什么需要使用Maven工程 二、聚合工程的结构 三、聚合工程实现步骤 四、父工程统一管理版本 五、编译打包 大家好&#xff0c;我是jstart千语。想着平时开发项目似乎都是用maven来管理的&#xff0c;并且大多都是聚合工程。而且在maven的聚合工程中&#xff0c…

前端职业发展:如何规划前端工程师的成长路径?

前端职业发展:如何规划前端工程师的成长路径? 大家好,我是全栈老李。今天咱们聊聊前端工程师的职业发展路径,这个话题看似简单,实则暗藏玄机。就像打游戏升级一样,你得知道下一关是什么,才能提前准备装备和技能点。 前端之路 一般我们从一个新手到大神,普遍需要经过…

【星海出品】分布式存储数据库etcd

etcd 数据库由 CoreOS 公司创建。 https://github.com/etcd-io/etcd api信息 https://etcd.io/docs/v3.5/dev-guide/api_reference_v3/ etcdctl --help etcd 最初由 CoreOS 公司开发&#xff0c;作为其核心项目之一。 CoreOS 成立于 2013 年&#xff0c;专注于容器化技术&#…

2025新版修复蛇年运势测试风水起名系统源码

2025新版修复蛇年运势测试风水起名系统源码 通过网盘分享的文件&#xff1a;2025xbfsysweb.rar 链接: https://pan.baidu.com/s/1r1MOkJJJMj9s9nQX_GzI3Q 提取码: 9weh 备用下载地址&#xff1a;http://pan.1234f.com:5212/s/JK1uw