多头注意力机制(Multi-Head Attention)

文章目录

      • 多头注意力机制的作用
      • 多头注意力机制的工作原理
      • 为什么使用多头注意力机制?
      • 代码示例

多头注意力机制(Multi-Head Attention)是Transformer架构中的一个核心组件。它在机器翻译、自然语言处理(NLP)等领域取得了显著的成功。多头注意力机制的引入是为了增强模型的能力,使其能够从不同的角度关注输入序列的不同部分,从而捕捉更多层次的信息。

多头注意力机制的作用

在多头注意力机制中,the number of heads 参数指的是“头”的数量,即注意力机制的独立并行子层的数量。每个头独立地执行注意力机制(Self-Attention 或 Attention),然后将这些头的输出连接起来,再通过线性变换得到最终的输出。

多头注意力机制的工作原理

以下是多头注意力机制的详细步骤和解释:

  1. 线性变换

    • 对输入进行线性变换,生成多个查询(Query)、键(Key)和值(Value)。
    • 每个头都有独立的线性变换,这意味着不同的头可以学到不同的特征。

    假设输入的维度是 ( d m o d e l d_{model} dmodel),头的数量是 ( h h h),每个头的维度是 ( d k = d m o d e l / h d_k = d_{model} / h dk=dmodel/h)。

    对于输入 ( X \mathbf{X} X),我们有:

    Q i = X W i Q , K i = X W i K , V i = X W i V \mathbf{Q}_i = \mathbf{X} \mathbf{W}_i^Q, \quad \mathbf{K}_i = \mathbf{X} \mathbf{W}_i^K, \quad \mathbf{V}_i = \mathbf{X} \mathbf{W}_i^V Qi=XWiQ,Ki=XWiK,Vi=XWiV

    其中 ( i i i) 表示第 (i) 个头,( W i Q , W i K , W i V \mathbf{W}_i^Q, \mathbf{W}_i^K, \mathbf{W}_i^V WiQ,WiK,WiV) 是线性变换矩阵。

  2. 计算注意力

    • 每个头独立地计算注意力(例如,使用缩放点积注意力机制)。

    缩放点积注意力的公式为:
    [
    \text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q} \mathbf{K}^T}{\sqrt{d_k}}\right) \mathbf{V}
    ]

  3. 连接(Concatenation)

    • 将所有头的输出连接起来,形成一个新的矩阵。

    如果有 (h) 个头,每个头的输出维度是 (d_k),则连接后的维度为 (h \times d_k = d_{model})。

  4. 线性变换

    • 将连接后的矩阵通过一个线性变换,得到最终的输出。

    [
    \text{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h) \mathbf{W}^O
    ]
    其中,(\mathbf{W}^O) 是输出的线性变换矩阵。

为什么使用多头注意力机制?

  1. 多样性:不同的头可以关注输入的不同部分,捕捉到更多样化的特征和模式。
  2. 稳定性:多个头的存在使得模型在学习过程中更加稳定和鲁棒。
  3. 增强模型能力:通过并行地执行多个注意力机制,模型能够更好地捕捉长程依赖关系和复杂的结构信息。

代码示例

以下是一个简单的 PyTorch 示例,展示多头注意力机制的实现:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()self.num_heads = num_headsself.d_model = d_modelself.d_k = d_model // num_headsassert d_model % num_heads == 0, "d_model must be divisible by num_heads"self.query_linear = nn.Linear(d_model, d_model)self.key_linear = nn.Linear(d_model, d_model)self.value_linear = nn.Linear(d_model, d_model)self.out_linear = nn.Linear(d_model, d_model)def forward(self, query, key, value):batch_size = query.size(0)# Linear projectionsquery = self.query_linear(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)key = self.key_linear(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)value = self.value_linear(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)# Scaled dot-product attentionscores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))attention = F.softmax(scores, dim=-1)output = torch.matmul(attention, value)# Concat and linear projectionoutput = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)output = self.out_linear(output)return output# Example usage
d_model = 512
num_heads = 8
batch_size = 64
sequence_length = 10mha = MultiHeadAttention(d_model, num_heads)
query = torch.randn(batch_size, sequence_length, d_model)
key = torch.randn(batch_size, sequence_length, d_model)
value = torch.randn(batch_size, sequence_length, d_model)output = mha(query, key, value)
print(output.shape)  # Expected output: (64, 10, 512)

在这个示例中:

  • d_model 是输入和输出的特征维度。
  • num_heads 是头的数量。
  • d_k 是每个头的维度。
  • 输入 querykeyvalue 的形状为 (batch_size, sequence_length, d_model)
  • 输出的形状为 (batch_size, sequence_length, d_model)

多头注意力机制通过将注意力机制并行化,并应用多个独立的注意力头,从而增强了模型的表示能力和学习能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/32811.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何打造稳定、好用的 Android LayoutInspector?

速度极慢,遇到复杂的布局经常超时 某些情况无法选中指定的 View 本文将围绕 LayoutInspector 的痛点,分析问题并修复,最终将 LayoutInspector 变成一个稳定、好用的插件。 二、加速 Dump View Hierarchy 2.1 问题描述 开发复杂业务的同学…

Spring Boot + WebSocket 实现 IM 即时通讯

文章目录 1. 项目环境准备2. 配置WebSocket3. 创建消息处理器4. 创建消息类5. 创建前端页面6. 启动应用并测试7. 分析与扩展结论 🎉欢迎来到SpringBoot框架学习专栏~ ☆* o(≧▽≦)o *☆嗨~我是IT陈寒🍹✨博客主页:IT陈寒的博客🎈…

Go语言之基础入门

网站:http://hardyfish.top/ 免费书籍分享: 资料链接:https://url81.ctfile.com/d/57345181-61545511-81795b?p3899 访问密码:3899 免费专栏分享: MySQL是怎样运行的从根儿上理解MySQL 课程链接:https:/…

nn.Embedding 根据索引生成的向量有权重吗

import torch import torch.nn as nn 假设有一个大小为 10x3 的 Embedding 层,其中有 10 个单词,每个单词用一个长度为 3 的向量表示 num_words 10 embedding_dim 3 创建 Embedding 层 embedding_layer nn.Embedding(num_words, embedding_dim) p…

LeetCode 算法:翻转二叉树 c++

原题链接🔗:翻转二叉树 难度:简单⭐️ 题目 给你一棵二叉树的根节点 root ,翻转这棵二叉树,并返回其根节点。 示例 1: 输入:root [4,2,7,1,3,6,9] 输出:[4,7,2,9,6,3,1] 示例 …

【Python】已解决:安装python-Levenshtein包时遇到的subprocess-exited-with-error问题

文章目录 一、分析问题背景二、可能出错的原因三、错误代码示例四、正确代码示例及解决方案五、注意事项 已解决:安装python-Levenshtein包时遇到的subprocess-exited-with-error问题 一、分析问题背景 在安装python-Levenshtein这个Python包时,有时会…

godot所有2D节点介绍

五十个2D节点介绍 2D节点介绍 前言一、Node2D二、sprite2D三、AnimatedSprite2D四、Camera2D五、PhysicsBody2D六、 RigidBody2D七、CharacterBody2D八、StaticBody2D九、joint2D十、DampedSpringJoint2D十一、GrooveJoint2D十二、PinJoint2D十三、Area2D十四、AnimatableBody2…

HTML(21)——CSS精灵

CSS精灵,也叫CSS Sprites,是一种网页图片应用处理方式。把网页中一些背景图片整合到一张图片的文件中,再background-position精确定位出背景图片的位置。 优点:减少服务器被请求的次数,减轻服务器的压力,提高页面加载…

智能优化算法改进策略之局部搜索算子(三)—二次插值法

1、原理介绍 多项式是逼近函数的一种常用工具。在寻求函数极小点的区间(即寻查区间)上,我们可以利用在若干点处的函数值来构成低次插值多项式,用它作为求极小点的函数的近似表达式,并用这个多项式的极小点作为原函数极…

Java --- 面试题

一、Redis应用场景 1.1、缓存 热点数据(高频查询,但不经常修改和删除的数据)首选redis作为缓存,性能优秀。 案例:如仓储业务中的商品信息,用户从redis的查询商品信息,没有在去数据库中查询。 1.2、分布式锁 在多线程环境下,对共享资源访问的线程问题,需要通过锁的…

快速业务建模

一句话故事 培训学院进行新季度招生工作,出计划后教务处审批,教学秘书下发计划,班主任手机名单审核后完成计划 用户故事 角色 时间线 动作为动名词 业务建模 多次建模,模型是否能完成业务

高考填报志愿(选专业),怎样找准自己的兴趣?

在很多的高考报考指南中,第一要点,都会建议我们根据自己的兴趣来选择自己的专业。很多人虽然是依据这条规则,选择了自己大学的专业。却依然在学习的过程中发现,好像自己对这个专业并不是那么的有兴趣。 甚至对专业学习深入了解之…

构建健壮的Java应用:错误处理与日志管理

构建健壮的Java应用:错误处理与日志管理 大家好,我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿! 在Java应用程序的开发过程中,错误处理和日志管…

docker部署ClamAV集成java和python实现文件病毒扫描

介绍 官方文档:https://docs.clamav.net/manual/Signatures/DatabaseInfo.html ClamAV 是一个开源的反病毒引擎,它由多个模块组成,负责不同的任务处理。以下是 ClamAV 的主要模块和它们的功能: clamd:clamd 是 Clam…

java通过 notify和 wait 实现线程间的通信

你好,我是 shengjk1,多年大厂经验,努力构建 通俗易懂的、好玩的编程语言教程。 欢迎关注!你会有如下收益: 了解大厂经验拥有和大厂相匹配的技术等 希望看什么,评论或者私信告诉我! 文章目录 一…

【专业英语 复习】第2章 The Internet, the Web, and Electronic Commerce

1. 单选题 (1分) "Wiki" comes from the Hawaiian word for ________.____ A fast B social C small D changeable 正确答案:A 翻译:Wiki来源于夏威夷语中的________。 2. 单选题 (1分) This type of e-commerce often resembles the elec…

WHAT - 高性能和内存安全的 Rust(一)

目录 一、介绍1.1 示例代码1.2 关键特性内存安全零成本抽象:高效性能示例代码:使用迭代器的零成本抽象示例代码:泛型和单态化总结 并发编程:防止数据竞争Rust 并发编程示例Rust 的所有权系统防止数据竞争总结 丰富的类型系统包管理…

2024.06.11校招 实习 内推 面经

绿*泡*泡VX: neituijunsir 交流*裙 ,内推/实习/校招汇总表格 1、校招 | 美团2025届北斗计划正式启动(内推) 校招 | 美团2025届北斗计划正式启动(内推) 2、实习 | 沃尔沃汽车 Open Day & 实习招聘 …

医学记录 --- 腋下异味

逻辑图地址 症状 病因 汗液分泌旺盛:由于天气炎热、活动出汗、肥胖等因素导致汗液分泌旺盛,可引起腋下有异味表现。在这种情况下,建议保持身体清洁,特别是在炎热和潮湿的环境下。可以使用抗菌洗液、喷雾或霜剂来帮助减少细菌滋…

(done) 关于 GNU/Linux API setenv 的实验

写一个下面的代码来验证 #include <stdlib.h> #include <stdio.h> #include <unistd.h> #include <sys/types.h>int main() {// 设置环境变量 MY_VAR 的值为 "hello_world"if (setenv("MY_VAR", "hello_world", 1) ! 0…