Transformer中的自注意力是怎么实现的?

在Transformer模型中,自注意力(Self-Attention)是核心组件,用于捕捉输入序列中不同位置之间的关系。自注意力机制通过计算每个标记与其他所有标记之间的注意力权重,然后根据这些权重对输入序列进行加权求和,从而生成新的表示。下面是实现自注意力机制的代码及其详细说明。

自注意力机制的实现

1. 计算注意力得分(Scaled Dot-Product Attention)

自注意力机制的基本步骤包括以下几个部分:

  1. 线性变换:将输入序列通过三个不同的线性变换层,得到查询(Query)、键(Key)和值(Value)矩阵。
  2. 计算注意力得分:通过点积计算查询与键的相似度,再除以一个缩放因子(通常是键的维度的平方根),以稳定梯度。
  3. 应用掩码:在计算注意力得分后,应用掩码(如果有),避免未来信息泄露(用于解码器中的自注意力)。
  4. 计算注意力权重:通过softmax函数将注意力得分转换为概率分布。
  5. 加权求和:使用注意力权重对值进行加权求和,得到新的表示。
2. 多头注意力机制(Multi-Head Attention)

为了捕捉不同子空间的特征,Transformer使用多头注意力机制。通过将查询、键和值分割成多个头,每个头独立地计算注意力,然后将所有头的输出连接起来,并通过一个线性层进行组合。

自注意力机制代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F# Scaled Dot-Product Attention
def scaled_dot_product_attention(query, key, value, mask=None):d_k = query.size(-1)scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))print(f"Scores shape: {scores.shape}")  # (batch_size, num_heads, seq_length, seq_length)if mask is not None:scores = scores.masked_fill(mask == 0, float('-inf'))attention_weights = F.softmax(scores, dim=-1)print(f"Attention weights shape: {attention_weights.shape}")  # (batch_size, num_heads, seq_length, seq_length)output = torch.matmul(attention_weights, value)print(f"Output shape after attention: {output.shape}")  # (batch_size, num_heads, seq_length, d_k)return output, attention_weights# Multi-Head Attention
class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()assert d_model % num_heads == 0self.d_model = d_modelself.num_heads = num_headsself.d_k = d_model // num_headsself.linear_query = nn.Linear(d_model, d_model)self.linear_key = nn.Linear(d_model, d_model)self.linear_value = nn.Linear(d_model, d_model)self.linear_out = nn.Linear(d_model, d_model)def forward(self, query, key, value, mask=None):batch_size = query.size(0)# Linear projectionsquery = self.linear_query(query)key = self.linear_key(key)value = self.linear_value(value)print(f"Query shape after linear: {query.shape}")  # (batch_size, seq_length, d_model)print(f"Key shape after linear: {key.shape}")      # (batch_size, seq_length, d_model)print(f"Value shape after linear: {value.shape}")  # (batch_size, seq_length, d_model)# Split into num_headsquery = query.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)key = key.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)value = value.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)print(f"Query shape after split: {query.shape}")   # (batch_size, num_heads, seq_length, d_k)print(f"Key shape after split: {key.shape}")       # (batch_size, num_heads, seq_length, d_k)print(f"Value shape after split: {value.shape}")   # (batch_size, num_heads, seq_length, d_k)# Apply scaled dot-product attentionx, attention_weights = scaled_dot_product_attention(query, key, value, mask)# Concatenate headsx = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)print(f"Output shape after concatenation: {x.shape}")  # (batch_size, seq_length, d_model)# Final linear layerx = self.linear_out(x)print(f"Output shape after final linear: {x.shape}")   # (batch_size, seq_length, d_model)return x, attention_weights# 示例用法
d_model = 512
num_heads = 8
batch_size = 64
seq_length = 10# 假设输入是随机生成的张量
query = torch.rand(batch_size, seq_length, d_model)
key = torch.rand(batch_size, seq_length, d_model)
value = torch.rand(batch_size, seq_length, d_model)# 创建多头注意力层
mha = MultiHeadAttention(d_model, num_heads)
output, attention_weights = mha(query, key, value)print("最终输出形状:", output.shape)  # 最终输出形状: (batch_size, seq_length, d_model)
print("注意力权重形状:", attention_weights.shape)  # 注意力权重形状: (batch_size, num_heads, seq_length, seq_length)

每一步的形状解释

  1. Linear Projections

    • Query, Key, Value分别经过线性变换。
    • 形状:[batch_size, seq_length, d_model]
  2. Split into Heads

    • 将Query, Key, Value分割成多个头。
    • 形状:[batch_size, num_heads, seq_length, d_k],其中d_k = d_model // num_heads
  3. Scaled Dot-Product Attention

    • 计算注意力得分(Scores)。
    • 形状:[batch_size, num_heads, seq_length, seq_length]
    • 计算注意力权重(Attention Weights)。
    • 形状:[batch_size, num_heads, seq_length, seq_length]
    • 使用注意力权重对Value进行加权求和。
    • 形状:[batch_size, num_heads, seq_length, d_k]
  4. Concatenate Heads

    • 将所有头的输出连接起来。
    • 形状:[batch_size, seq_length, d_model]
  5. Final Linear Layer

    • 通过一个线性层将连接的输出转换为最终的输出。
    • 形状:[batch_size, seq_length, d_model]

通过这种方式,我们可以清楚地看到每一步变换后的张量形状,理解自注意力和多头注意力机制的具体实现细节。

代码说明

  • scaled_dot_product_attention:实现了缩放点积注意力机制,计算查询和键的点积,应用掩码,计算softmax,然后使用权重对值进行加权求和。
  • MultiHeadAttention:实现了多头注意力机制,包括线性变换、分割、缩放点积注意力和最后的线性变换。

多头注意力机制的细节

  • 线性变换:将输入序列通过线性层转换为查询、键和值的矩阵。
  • 分割头:将查询、键和值的矩阵分割为多个头,每个头的维度是[batch_size, num_heads, seq_length, d_k]。
  • 缩放点积注意力:对每个头分别计算缩放点积注意力。
  • 连接头:将所有头的输出连接起来,得到[batch_size, seq_length, d_model]的张量。
  • 线性变换:通过一个线性层将连接的输出转换为最终的输出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/873136.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JVM:MAT内存泄漏检测原理

文章目录 一、介绍 一、介绍 MAT提供了称为支配树(Dominator Tree)的对象图。支配树展示的是对象实例间的支配关系。在对象引用图中,所有指向对象B的路径都经过对象A,则认为对象A支配对象B。 支配树中对象本身占用的空间称之为…

Spire.PDF for .NET【文档操作】演示:如何在 C# 中切换 PDF 层的可见性

我们已经演示了如何使用 Spire.PDF在 C# 中向 PDF 文件添加多个图层以及在 PDF 中删除图层。我们还可以在 Spire.PDF 的帮助下在创建新页面图层时切换 PDF 图层的可见性。在本节中,我们将演示如何在 C# 中切换新 PDF 文档中图层的可见性。 Spire.PDF for .NET 是一…

【LabVIEW作业篇 - 1】:中途停止for和while循环

文章目录 for循环while循环如何使用帮助 for循环 在程序框图中,创建for循环结构,选择for循环,鼠标右键-条件接线端,即出现像while循环中的小红圆心,其作用与while循环相同。 运行结果如下。(若随机数>…

分布式搜索引擎ES-Elasticsearch进阶

1.head与postman基于索引的操作 引入概念: 集群健康: green 所有的主分片和副本分片都正常运行。你的集群是100%可用 yellow 所有的主分片都正常运行,但不是所有的副本分片都正常运行。 red 有主分片没能正常运行。 查询es集群健康状态&…

ExoPlayer架构详解与源码分析(15)——Renderer

系列文章目录 ExoPlayer架构详解与源码分析(1)——前言 ExoPlayer架构详解与源码分析(2)——Player ExoPlayer架构详解与源码分析(3)——Timeline ExoPlayer架构详解与源码分析(4)—…

Pytest+selenium UI自动化测试实战实例

🍅 视频学习:文末有免费的配套视频可观看 🍅 点击文末小卡片,免费获取软件测试全套资料,资料在手,涨薪更快 今天来说说pytest吧,经过几周的时间学习,有收获也有疑惑,总之…

【Python数据分析】数据分析三剑客:NumPy、SciPy、Matplotlib中常用操作汇总

文章目录 NumPy常见操作汇总SciPy常见操作汇总Matplotlib常见操作汇总官方文档链接NumPy常见操作汇总 在Python的NumPy库中,有许多常用的知识点,这里列出了一些核心功能和常见操作: 类别函数或特性描述基础操作np.array创建数组np.shape获取数组形状np.dtype查看数组数据类…

适用于618/7xx芯片平台 AT开发 远程FOTA升级指南教程

简介 AT版本的远程升级主要是对AT固件版本进行升级,实际方式为通过合宙官方IOT平台升级或者使用自己搭建的服务器进行升级服务。 该文档教程流程适用于 618/716S/718P 芯片平台的Cat.1模块 合宙IOT平台配置 升级日志 —— 如何查看 升级日志 —— 响应码列表 响应…

网络安全相关竞赛比赛

赛事日历(包含全国所有网络安全竞赛) https://datacon.qianxin.com/competition/competitions https://www.ichunqiu.com/competition/all 全国网络安全竞赛 名称链接全国大学生信息安全竞赛http://www.ciscn.cn/信息安全与对抗技术竞赛(In…

PCB(印制电路板)制造涉及的常规设备

印制电路板(PCB)的制造涉及多种设备和工艺。从设计、制作原型到批量生产,每个阶段都需要不同的专业设备。以下是一些在PCB制造过程中常见的设备: 1. 计算机辅助设计(CAD)软件: - 用于设计PC…

深度解析 containerd 中的 CNI 插件

CNI(Container Network Interface)插件是独立的可执行文件,遵循 CNI 规范。Kubernetes 通过 kubelet 调用这些插件来创建和管理容器的网络接口。CNI 插件的主要职责包括网络接口的创建和删除、IP 地址的分配和回收、以及相关网络资源的配置和…

C语言 分割链表

题目来源: 代码部分,参考官方题解的写法: // 思路: 就是把原始链表,拆分为2部分,最后再拼接一下。struct ListNode* partition(struct ListNode* head, int x) {struct ListNode* small malloc(sizeof(struct ListNode));struct ListNode*…

Memcached介绍与使用

引言 本文是笔者对Memcached这个高性能分布式缓存组件的实践案例,Memcached是一种高性能的分布式内存对象缓存系统,用于减轻数据库负载,加速动态Web应用,提高网站访问速度。它通过在内存中缓存数据和对象来减少读取数据库的次数&…

IAR嵌入式开发解决方案已全面支持芯科集成CX3288系列车规RISC-V MCU,共同推动汽车高品质应用的安全开发

中国上海,2024年7月16日 — 全球领先的嵌入式系统开发软件解决方案供应商IAR与芯科集成电路(以下简称“芯科集成”)联合宣布,最新版本IAR Embedded Workbench for RISC-V 3.30.2功能安全版已全面支持芯科集成CX3288系列车规RISC-V…

解决vue多层弹框时存在遮挡问题

本文给大家介绍vue多层弹框时存在遮挡问题,解决思路首先想到的是找到对应的遮挡层的css标签,然后修改z-index值,但是本思路只能解决首次问题,再次打开还会存在相同的问题,故该思路错误,下面给大家带来一种正…

Vue3 快速入门 (四) : 使用路由实现页面跳转

1. 本文环境 Vue版本 : 3.4.29Node.js版本 : v20.15.0系统 : Windows11 64位IDE : VsCode 2. vue 路由 Vue Router 是 Vue 官方的客户端路由解决方案。 客户端路由的作用是在单页应用 (SPA) 中将浏览器的 URL 和用户看到的内容绑定起来。当用户在应用中浏览不同页面时&#…

中间件的理解

内容来源于学习网站整理。【一看就会】什么是前端开发的中间件?_哔哩哔哩_bilibili 每日八股文~白话说mq,消息中间件_哔哩哔哩_bilibili 例如: 1)两个人打电话,中间的通信网络就是中间件。 2)菜鸟驿站&…

npm、pnpm、yarn使用淘宝镜像

文章目录 npmpnpm安装方法Windows其它 设置镜像 yarn npm # 查询当前使用的镜像源 npm get registry# 设置为淘宝镜像源 npm config set registry https://registry.npmmirror.com/# 还原为官方镜像源 npm config set registry https://registry.npmjs.org/pnpm 安装方法 Wi…

MyBatis中的优点和缺点?

优点: 管 1.基于 SQL语句编程,相当灵活,不会对应用程序或者数据库的现有设计造成任何影响,SQL单独写,解除 sq!与程序代码的耦合,便于统一。 2.与JDBC 相比,减少了 50%以上的代码量,消除了 JDB…

流式数据库 |RisingWave 的架构、容错、数据持久化

在上一篇文章中,已经为大家分享了 RisingWave 相关核心概念和术语。本文将在此基础上为大家介绍 RisingWave 的架构、容错以及数据持久化。 1. 架构 RisingWave 的架构如下图所示。它由三个主要部分组成:Meta 节点、Compute 节点和 Compactor 节点。 …