深入浅出 Multi-Head Attention:原理 + 例子 + PyTorch 实现

本文带你一步步理解 Transformer 中最核心的模块:多头注意力机制(Multi-Head Attention)。从原理到实现,配图 + 举例 + PyTorch 代码,一次性说清楚!


什么是 Multi-Head Attention?

简单说,多头注意力就是一种让模型在多个角度“看”一个序列的机制。

在自然语言中,一个词的含义往往依赖于上下文,比如:

“我把苹果给了她”

模型在处理“苹果”时,需要关注“我”“她”“给了”等词,多头注意力就是这样一种机制——从多个角度理解上下文关系。


Self-Attention 是什么?为什么还要多头?

在讲“多头”之前,咱们先回顾一下基础的 Self-Attention

Self-Attention(自注意力)机制的目标是:

让每个词都能“关注”整个句子里的其他词,融合上下文。

它的核心步骤是:

  1. 对每个词生成 Query、Key、Value 向量

  2. 用 Query 和所有 Key 做点积,算出每个词对其他词的关注度(打分)

  3. 用 Softmax 得到权重,对 Value 加权平均,生成当前词的新表示

这样做的好处是:词的语义表示不再是孤立的,而是上下文相关的。


Self-Attention vs Multi-Head Attention

但问题是——单头 Self-Attention 视角有限。就像一个老师只能从一种角度讲课。

于是,Multi-Head Attention 应运而生

特性Self-Attention(单头)Multi-Head Attention(多头)
输入映射矩阵一组 Q/K/V 线性变换多组 Q/K/V,每个头一组
学习角度单一视角多角度并行理解
表达能力有限更丰富、强大
结构简单并行多个头 + 合并输出

一句话总结:

Multi-Head Attention = 多个不同“视角”的 Self-Attention 并行处理 + 合并结果


 多头注意力:8个脑袋一起思考!

多头 = 多个“单头注意力”并行处理!

每个头使用不同的线性变换矩阵,所以能从不同视角处理数据:

  • 第1个头可能专注短依赖(like 动词和主语)

  • 第2个头可能专注实体关系(我 vs 她)

  • 第3个头可能关注时间顺序(“给了”前后)

  • ……共用同一个输入,学习到不同特征!

多头的步骤:

  1. 将输入向量(如512维)拆成多个头(比如8个,每个64维)

  2. 每个头独立进行 attention

  3. 所有头的输出拼接

  4. 再过一次线性变换,融合成最终输出


 PyTorch 实现(简洁版)

我们来看下 PyTorch 中的简化实现:

import torch
import torch.nn as nn
import copydef clones(module, N):return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])def attention(query, key, value, mask=None, dropout=None):d_k = query.size(-1)scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)p_attn = torch.softmax(scores, dim=-1)if dropout:p_attn = dropout(p_attn)return torch.matmul(p_attn, value), p_attnclass MultiHeadedAttention(nn.Module):def __init__(self, h, d_model, dropout=0.1):super().__init__()assert d_model % h == 0self.d_k = d_model // hself.h = hself.linears = clones(nn.Linear(d_model, d_model), 4)self.dropout = nn.Dropout(dropout)def forward(self, query, key, value, mask=None):if mask is not None:mask = mask.unsqueeze(1)nbatches = query.size(0)query, key, value = [lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)for lin, x in zip(self.linears, (query, key, value))]x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)return self.linears[-1](x)

举个例子:多头在实际模型中的作用

假设输入是句子:

"The animal didn't cross the street because it was too tired."

多头注意力的不同头可能会:

  • 🧠 头1:关注“animal”和“it”之间的指代关系;

  • 📐 头2:识别“because”和“tired”之间的因果联系;

  • 📚 头3:注意句子的结构层次……

所以说,多头注意力本质上是一个“并行注意力专家系统”!


 总结

项目解释
目的提升模型表达能力,从多个角度理解输入
核心机制将向量分头 → 每头独立 attention → 合并输出
技术关键view, transpose, matmul, softmax, 拼接线性层

推荐学习路径

  • 🔹 理解 Self-Attention 的点积公式

  • 🔹 搞懂 view, transpose 等张量操作

  • 🔹 看 Transformer 整体结构,关注每层作用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/901902.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

常用 Git 命令详解

Git 是一个强大的版本控制工具,广泛用于软件开发和团队协作中。掌握 Git 命令可以帮助开发者更高效地管理代码版本和项目进度。本文将介绍一些常用的 Git 命令,并提供示例以帮助你更好地理解和应用这些命令。 目录 常用命令 git clonegit stashgit pul…

NO.96十六届蓝桥杯备战|图论基础-多源最短路|Floyd|Clear And Present Danger|灾后重建|无向图的最小环问题(C++)

多源最短路:即图中每对顶点间的最短路径 floyd算法本质是动态规划,⽤来求任意两个结点之间的最短路,也称插点法。通过不断在两点之间加⼊新的点,来更新最短路。 适⽤于任何图,不管有向⽆向,边权正负&…

电流模式控制学习

电流模式控制 电流模式控制(CMC)是开关电源中广泛使用的一种控制策略,其核心思想是通过内环电流反馈和外环电压反馈共同调节占空比。相比电压模式控制,CMC具有更快的动态响应和更好的稳定性,但也存在一些固有缺点。 …

MATLAB 控制系统设计与仿真 - 36

鲁棒工具箱定义了个新的对象类ureal,可以定义在某个区间内可变的变量。 函数的调用格式为: p ureal(name,nominalvalue) % name为变量名,nominalValue为标称值,默认变化值为/-1 p ureal(name,nominalvalue,PlusMinus,plusminus) p ureal(name,nomin…

LeetCode -- Flora -- edit 2025-04-17

1.最长连续序列 128. 最长连续序列 给定一个未排序的整数数组 nums ,找出数字连续的最长序列(不要求序列元素在原数组中连续)的长度。 请你设计并实现时间复杂度为 O(n) 的算法解决此问题。 示例 1: 输入:nums [1…

Sql刷题日志(day3)

一、笔试 1、min(date_time):求最早日期 2、mysql中distinct不能与order by 连用,可以用group by去重 二、面试 1、SQL中如何利用replace函数统计给定重复字段在字符串中的出现次数 (length(all_string)-length(all_string,目标字符串,))/length(ta…

解决 Spring Boot 多数据源环境下事务管理器冲突问题(非Neo4j请求标记了 @Transactional 尝试启动Neo4j的事务管理器)

0. 写在前面 到底遇到了什么问题? 简洁版: 在 Oracle 与 Neo4j 共存的多数据源项目中,一个仅涉及 Oracle 操作的请求,却因为 Neo4j 连接失败而报错。根本原因是 Spring 的默认事务管理器错误地指向了 Neo4j,导致不相…

理解和实现RESTful API的最佳实践

理解和实现RESTful API的最佳实践 在当今数字化时代,APIs已成为软件开发的核心组件,而RESTful API以其简洁、灵活和可扩展性成为最流行的API设计风格。本文将深入探讨RESTful API的概念、特点和实施指南,帮助开发者构建高效、可靠的Web服务。…

大语言模型微调技术与实践:从原理到应用

大语言模型微调技术与实践:从原理到应用 摘要:随着大语言模型(LLM)技术的迅猛发展,预训练语言模型在各种自然语言处理任务中展现出强大的能力。然而,将这些通用的预训练模型直接应用于特定领域或任务时&am…

遨游科普:三防平板除了三防特性?还能实现什么功能?

在工业4.0浪潮席卷全球的今天,电子设备的功能边界正经历着革命性突破。三防平板电脑作为"危、急、特"场景的智能终端代表,其价值早已超越防水、防尘、防摔的基础防护属性。遨游通讯通过系统级技术创新,将三防平板打造为集通信中枢、…

前端实战:基于 Vue 与 QRCode 库实现动态二维码合成与下载功能

在现代 Web 应用开发中,二维码的应用越来越广泛,从电子票务到信息传递,它都扮演着重要角色。本文将分享如何在 Vue 项目中,结合QRCode库实现动态二维码的生成、与背景图合成以及图片下载功能,打造一个完整且实用的二维…

HAL详解

一、直通式HAL 这里使用一个案例来介绍直通式HAL,选择MTK的NFC HIDL 1.0为例,因为比较简单,代码量也比较小,其源码路径:vendor/hardware/interfaces/nfc/1.0/ 1、NFC HAL的定义 1)NFC HAL数据类型 通常定…

Vue自定义指令-防抖节流

Vue2版本 // 防抖 // <el-button v-debounce"[reset,click,300]" ></el-button> // <el-button v-debounce"[reset]" ></el-button> Vue.directive(debounce, { inserted: function (el, binding) { let [fn, event "cl…

AI知识补全(十六):A2A - 谷歌开源的agent通信协议是什么?

名人说&#xff1a;一笑出门去&#xff0c;千里落花风。——辛弃疾《水调歌头我饮不须劝》 创作者&#xff1a;Code_流苏(CSDN)&#xff08;一个喜欢古诗词和编程的Coder&#x1f60a;&#xff09; 上一篇&#xff1a;AI知识补全&#xff08;十五&#xff09;&#xff1a;AI可解…

【机器人创新创业应需明确产品定位与方向指南】

机器人领域的创新创业, 需要对公司和产品的定位和生态进行深入思考, 明确其定位与发展目标, 明确产品在是为G、为B还是为C进行服务。 本文引用地址&#xff1a;https://www.eepw.com.cn/article/202504/469401.htm 超前的、探索性的创新技术一般是面向G端, 而不是面向B端或者C…

网安加·百家讲坛 | 刘志诚:AI安全风险与未来展望

作者简介&#xff1a;刘志诚&#xff0c;乐信集团信息安全中心总监、OWASP广东区域负责人、网安加社区特聘专家。专注于企业数字化过程中网络空间安全风险治理&#xff0c;对大数据、人工智能、区块链等新技术在金融风险治理领域的应用&#xff0c;以及新技术带来的技术风险治理…

TOA与AOA联合定位的高精度算法,三维、4个基站的情况,MATLAB例程,附完整代码

本代码实现了三维空间内目标的高精度定位,结合到达角(AOA) 和到达时间(TOA) 两种测量方法,通过4个基站的协同观测,利用最小二乘法解算目标位置。代码支持噪声模拟、误差分析及三维可视化,适用于无人机导航、室内定位等场景。订阅专栏后可获得完整代码 文章目录 运行结果…

2025MathorcupC题 音频文件的高质量读写与去噪优化 保姆级教程讲解|模型讲解

2025Mathorcup数学建模挑战赛&#xff08;妈妈杯&#xff09;C题保姆级分析完整思路代码数据教学 C题&#xff1a;音频文件的高质量读写与去噪优化 随着数字媒体技术的迅速发展&#xff0c;音频处理成为信息时代的关键技术之一。在日常生活中&#xff0c;从录音设备捕捉的原始…

Deno Dep:颠覆传统的模块化未来

一、重新定义依赖管理&#xff1a;Deno Dep 的革新哲学 Deno Dep&#xff08;原Deno包管理器&#xff09;彻底重构了JavaScript/TypeScript的依赖管理方式&#xff0c;其核心突破体现在&#xff1a; 1. 浏览器优先的模块化&#xff08;URL-Centric Modules&#xff09; // 直…

欧拉系统升级openssh 9.7p1

开发的系统准备上线&#xff0c;甲方对欧拉服务器进行了扫描&#xff0c;发现openssh版本为8.2p1&#xff0c;存在漏洞&#xff0c;因此需要升级openssh至9.7p1。欧拉系统版本为20.03 SP3。 1、下载openssh 9.7p1 https://www.openssh.com/releasenotes.html&#xff0c; 将下…