Transformer教程之多头自注意力机制

大家好,今天我们要聊一聊Transformer中的一个核心组件——多头自注意力机制。无论你是AI领域的新手,还是深度学习的老鸟,这篇文章都会帮助你更深入地理解这个关键概念。我们会从基础开始,逐步深入,最终让你对多头自注意力机制有一个全面的认识。

什么是多头自注意力机制?

在讨论多头自注意力机制之前,我们首先需要理解什么是注意力机制。注意力机制最早在机器翻译中得到应用,它的核心思想是:在处理某个词语时,模型不应该只关注固定窗口内的词,而应该能够动态地根据当前处理的词,选择最相关的上下文信息。

注意力机制

注意力机制可以用一个简单的公式来表示:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

这里, Q Q Q(Query), K K K(Key), V V V(Value)是输入的向量。这个公式表示我们对 K K K V V V进行线性变换,然后计算 Q Q Q K K K的点积,经过softmax归一化后得到注意力权重,再与 V V V相乘,得到最终的输出。

多头自注意力机制

多头自注意力机制是注意力机制的一个扩展,它通过将输入分成多个“头”(head),让模型在不同的子空间中独立计算注意力,这样可以捕捉到更多层次的特征。

具体来说,多头自注意力机制的过程如下:

  1. 将输入向量 Q Q Q K K K V V V分别线性变换成多个头,每个头的维度减小,通常是 d / h d/h d/h,其中 d d d是输入向量的维度, h h h是头的数量。
  2. 每个头独立地计算注意力机制。
  3. 将所有头的输出拼接起来,经过线性变换,得到最终的输出。

公式上表示为:

MultiHead ( Q , K , V ) = Concat ( head 1 , head 2 , … , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h)W^O MultiHead(Q,K,V)=Concat(head1,head2,,headh)WO

其中,每个头的计算过程为:

head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) headi=Attention(QWiQ,KWiK,VWiV)

这里, W i Q W_i^Q WiQ W i K W_i^K WiK W i V W_i^V WiV是每个头对应的线性变换矩阵, W O W^O WO是拼接后进行线性变换的矩阵。

为什么要使用多头自注意力?

那么,为什么我们需要多头自注意力机制呢?简单来说,多头自注意力机制有以下几个优点:

  1. 并行计算:每个头可以并行计算,提高了计算效率。
  2. 多样性:不同的头可以关注输入的不同部分,捕捉到更多层次的特征。
  3. 稳定性:多头机制可以使模型更稳定,因为它能够从多个角度看待输入,避免单一注意力机制可能出现的偏差。

多头自注意力机制的实现

接下来,我们来看一下多头自注意力机制的具体实现。我们将以PyTorch为例,逐步实现多头自注意力机制。

准备工作

首先,我们需要导入必要的库:

import torch
import torch.nn as nn
import torch.nn.functional as F

定义线性变换层

我们需要为 Q Q Q K K K V V V分别定义线性变换层:

class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()self.num_heads = num_headsself.d_model = d_modelself.depth = d_model // num_headsself.wq = nn.Linear(d_model, d_model)self.wk = nn.Linear(d_model, d_model)self.wv = nn.Linear(d_model, d_model)self.dense = nn.Linear(d_model, d_model)def split_heads(self, x, batch_size):x = x.view(batch_size, -1, self.num_heads, self.depth)return x.permute(0, 2, 1, 3)

计算注意力

接下来,我们定义一个函数来计算注意力:

def scaled_dot_product_attention(q, k, v, mask=None):matmul_qk = torch.matmul(q, k.transpose(-2, -1))dk = k.size()[-1]scaled_attention_logits = matmul_qk / torch.sqrt(torch.tensor(dk, dtype=torch.float32))if mask is not None:scaled_attention_logits += (mask * -1e9)attention_weights = F.softmax(scaled_attention_logits, dim=-1)output = torch.matmul(attention_weights, v)return output, attention_weights

组合在一起

最后,我们将这些部分组合在一起,实现多头自注意力机制:

class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()self.num_heads = num_headsself.d_model = d_modelself.depth = d_model // num_headsself.wq = nn.Linear(d_model, d_model)self.wk = nn.Linear(d_model, d_model)self.wv = nn.Linear(d_model, d_model)self.dense = nn.Linear(d_model, d_model)def split_heads(self, x, batch_size):x = x.view(batch_size, -1, self.num_heads, self.depth)return x.permute(0, 2, 1, 3)def forward(self, q, k, v, mask=None):batch_size = q.size(0)q = self.wq(q)k = self.wk(k)v = self.wv(v)q = self.split_heads(q, batch_size)k = self.split_heads(k, batch_size)v = self.split_heads(v, batch_size)scaled_attention, _ = scaled_dot_product_attention(q, k, v, mask)scaled_attention = scaled_attention.permute(0, 2, 1, 3).contiguous()concat_attention = scaled_attention.view(batch_size, -1, self.d_model)output = self.dense(concat_attention)return output

这样,我们就完成了多头自注意力机制的实现。

多头自注意力机制在Transformer中的作用

在Transformer模型中,多头自注意力机制主要用于编码器和解码器的构建。编码器中的每一层都包含一个多头自注意力机制和一个前馈神经网络,而解码器则包含一个用于自身的多头自注意力机制和一个用于编码器-解码器交互的多头自注意力机制。

编码器中的多头自注意力

在编码器中,多头自注意力机制帮助模型捕捉输入序列中不同位置之间的关系,从而更好地理解上下文信息。每个编码器层中的多头自注意力机制能够独立地关注不同的上下文特征,然后将这些特征综合起来,生成更具代表性的编码。

解码器中的多头自注意力

在解码器中,多头自注意力机制不仅用于理解自身的序列信息,还用于理解编码器生成的编码信息。解码器中的多头自注意力机制分为两部分:一部分用于关注解码器自身的序列信息,另一部分用于关注编码器生成的序列信息。这种设计使得解码器能够更好地将输入序列的信息与当前生成的序列信息结合起来,提高生成的准确性和连贯性。

总结

多头自注意力机制是Transformer模型中的一个核心组件,通过并行计算和多样性捕捉,可以更高效、更全面地理解输入数据的特征。在实际应用中,多头自注意力机制已经证明了其强大的能力,不仅在自然语言处理领域取得了巨大的成功,还在计算机视觉等其他领域展现出了广泛的应用前景。

希望通过这篇文章,大家能够对多头自注意力机制有一个更清晰的认识。如果你有任何问题或者想进一步探讨的内容,欢迎在评论区留言,我们一起交流学习!

原文链接:Transformer教程之多头自注意力机制

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/37726.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

软考《信息系统运行管理员》-1.3信息系统运维的发展

1.3信息系统运维的发展 我国信息系统运维的发展总体现状 呈现三个“二八现象” 从时间周期看(开发流程)从信息系统效益看(消息体现为“用好”)从资金投入看(重开发,轻服务) 信息系统运维的发…

Codeforces Beta Round 32 (Div. 2, Codeforces format) D. Constellation 题解 枚举

Constellation 题目描述 A star map in Berland is a checked field n m nm nm squares. In each square there is or there is not a star. The favorite constellation of all Berland’s astronomers is the constellation of the Cross. This constellation can be for…

JAVA高级进阶13单元测试、反射、注解

第十三天、单元测试、反射、注解 单元测试 介绍 单元测试 就是针对最小的功能单元(方法),编写测试代码对其进行正确性测试 咱们之前是如何进行单元测试的? 有啥问题 ? 只能在main方法编写测试代码,去调用其他方法进行测试。 …

页面开发感想

页面开发 1、 前端预览 2、一些思路 2.1、首页自定义element-plus的走马灯 :deep(.el-carousel__arrow){border-radius: 0%;height: 10vh; }需要使用:deep(标签)才能修改样式 或者 ::v-deep 标签 2.2、整体设计思路 <template><div class"card" style&…

【ChatBI】text2sql-不需要访问数据表-超轻量Python库Vanna快速上手,对接oneapi

oneapi 准备 首先确保你有oneapi &#xff0c;然后申请 kimi的api 需要去Moonshot AI - 开放平台 然后添加一个api key 然后打开oneapi的渠道界面&#xff0c;添加kimi。 然后点击 测试&#xff0c; 如果能生成响应时间&#xff0c;就是配置正确。 然后创建令牌 http:…

Vllm Offline 启动

Vllm Offline 启动 Vllm Offline 启动&#xff0c;设置环境变量&#xff0c; TRANSFORMERS_OFFLINE1reference: https://github.com/vllm-project/vllm/discussions/1405

Linux shell编程学习笔记60:touch命令

0 前言 在csdn技能树Linux入门的练习题中&#xff0c;touch是最常见的一条命令。这次我们就来研究它的用法。 1 touch命令的功能、格式和选项说明 我们可以使用touch --help命令查看touch命令的帮助信息。 [purpleendurer bash ~ ]touch --help Usage: touch [OPTION]... …

MATLAB-NGO-CNN-SVM,基于NGO苍鹰优化算法优化卷积神经网络CNN结合支持向量机SVM数据分类(多特征输入多分类)

NGO-CNN-SVM&#xff0c;基于NGO苍鹰优化算法优化卷积神经网络CNN结合支持向量机SVM数据分类(多特征输入多分类) 1.数据均为Excel数据&#xff0c;直接替换数据就可以运行程序。 2.所有程序都经过验证&#xff0c;保证程序可以运行。 3.具有良好的编程习惯&#xff0c;程序均…

【Android面试八股文】Activity A跳转B,B跳转C,A不能直接跳转到C,A如何传递消息给C?

文章目录 1. 使用Intent传递消息2. 使用全局单例类(Singleton)3. 使用静态变量4. 使用Application全局静态变量5. 使用 Android系统剪切板(Clipboard)6. 本地化存储方式6.1 使用SharedPreferences6.2 使用File文件存储方式传递消息6.3 使用SQLite数据库方式传递消息7. 使用广…

【Spring Boot】Java 的数据库连接模板:JDBCTemplate

Java 的数据库连接模板&#xff1a;JDBCTemplate 1.JDBCTemplate 初识1.1 JDBC1.2 JDBCTemplate 2.JDBCTemplate 实现数据的增加、删除、修改和查询2.1 配置基础依赖2.2 新建实体类2.3 操作数据2.3.1 创建数据表2.3.2 添加数据2.3.3 查询数据2.3.4 查询所有记录2.3.5 修改数据2…

【ai】tx2 nx:ubuntu18.04 yolov4-triton-tensorrt 成功部署server 运行

isarsoft / yolov4-triton-tensorrt运行发现插件未注册? 【ai】tx2 nx: jetson Triton Inference Server 部署YOLOv4 【ai】tx2 nx: jetson Triton Inference Server 运行YOLOv4 对main 进行了重新构建 【ai】tx2 nx :ubuntu查找NvInfer.h 路径及哪个包、查找符号【ai】tx2…

深度学习实战81-基于大模型的Chatlaw法律问答中的知识图谱融合思路,数据集说明、以及知识图谱对ChatLaw的影响介绍

大家好,我是微学AI,今天给大家介绍一下深度学习实战81-基于大模型的Chatlaw法律问答中的知识图谱融合思路,数据集说明、以及知识图谱对ChatLaw的影响介绍。基于大模型的Chatlaw法律问答系统融合了知识图谱,以提高法律咨询服务的可靠性和准确性。Chatlaw通过结合知识图谱与人…

AES加密算法及AES-CMAC原理白话版系统解析

本文框架 前言1. AES加密理论1.1 不同AES算法区别1.2 加密过程介绍1.2.1 加密模式和填充方案选择1.2.2 密钥扩展1.2.3分组处理1.2.4多轮加密1.2.4.1字节替换1.2.4.2行移位1.2.4.3列混淆1.2.4.4轮密钥加1.3 加密模式1.3.1ECB模式1.3.2CBC模式1.3.3CTR模式1.3.4CFB模式1.3.5 OFB模…

redis 单节点数据如何平滑迁移到集群中

目的 如何把一个redis单节点的数据迁移到 redis集群中 方案&#xff1a; 使用命令redis-cli --cluster import 导入数据至集群 --cluster-from <arg>--cluster-from-user <arg> 数据源用户--cluster-from-pass <arg> 数据源密码--cluster-from-askpass--c…

css_22_过渡动画

一.过渡 transition-property 作用&#xff1a;定义哪个属性需要过渡。结构&#xff1a; transition-property: all; 常用值&#xff1a; 1.none&#xff1a;不过渡任何属性。 2.all&#xff1a;过渡所有能过渡的属性。 3&#xff0e;具体某个属性名&#xff0c;例如&#xf…

驾校预约小程序系统的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;学员管理&#xff0c;教练管理&#xff0c;驾校信息管理&#xff0c;驾校车辆管理&#xff0c;教练预约管理&#xff0c;考试信息管理 微信端账号功能包括&#xff1a;系统首页&#xff0c;驾校信息&a…

Java基础——五、继承

五、继承 简要 1、说明 继承(Inheritance)是面向对象编程(OOP)的一个核心概念&#xff0c;它允许一个类(子类)继承另一个类(父类)的属性和方法&#xff0c;从而实现代码重用和结构化组织。通过继承&#xff0c;子类可以扩展父类的功能或者对父类的方法进行重写。 父类(超类…

基于docker安装redis服务

Redis是我们在项目中经常需要使用的缓存数据库&#xff0c;安装redis的方式也有很多&#xff0c;本文主要是给大家讲解如何基于docker进行redis服务的安装&#xff0c;主要介绍&#xff0c;如何拉取redis镜像、如何挂载redis的数据以及使用redis的配置文件和开启认证等功能&…

steam社区载入失败、加载不出来、打不开?

随着steam夏季大促的到来&#xff0c;最近steam在线用户越来越多了&#xff0c;很多玩家在自己喜欢的游戏社区里看最新的玩法、攻略和玩家的游戏心得。不过有不少玩家表示有时候会打不开游戏社区或是社区加载失败等问题。根据大家遇到的问题&#xff0c;这里总结了几种解决方法…

构建现代医疗:互联网医院系统源码与电子处方小程序开发教学

本篇文章&#xff0c;笔者将探讨互联网医院系统的源码结构和电子处方小程序的开发&#xff0c;帮助读者更好地理解和掌握这些前沿技术。 一、互联网医院系统源码结构 互联网医院系统通常由多个模块组成&#xff0c;每个模块负责不同的功能。以下是一个典型的互联网医院系统的主…