PyTorch 实现图像版多头注意力(Multi-Head Attention)和自注意力(Self-Attention)

本文提供一个适用于图像输入的多头注意力机制(Multi-Head Attention)PyTorch 实现,适用于 ViT、MAE 等视觉 Transformer 中的注意力计算。


模块说明

  • 输入支持图像格式 (B, C, H, W)
  • 内部转换为序列 (B, N, C),其中 N = H * W
  • 多头注意力计算:查询(Q)、键(K)、值(V)使用线性层投影
  • 结果 reshape 回原图维度 (B, C, H, W)

多头注意力机制代码(适用于图像输入)

import torch
import torch.nn as nnclass ImageMultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):super(ImageMultiHeadAttention, self).__init__()assert embed_dim % num_heads == 0, "embed_dim 必须能被 num_heads 整除"self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_heads# Q, K, V 的线性映射self.q_proj = nn.Linear(embed_dim, embed_dim)self.k_proj = nn.Linear(embed_dim, embed_dim)self.v_proj = nn.Linear(embed_dim, embed_dim)# 输出映射层self.out_proj = nn.Linear(embed_dim, embed_dim)self.scale = self.head_dim ** 0.5def forward(self, x):# 输入 x: (B, C, H, W),需要 reshape 为 (B, N, C)B, C, H, W = x.shapex = x.view(B, C, H * W).permute(0, 2, 1)  # (B, N, C)Q = self.q_proj(x)K = self.k_proj(x)V = self.v_proj(x)# 拆成多头 (B, num_heads, N, head_dim)Q = Q.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)K = K.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)V = V.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)# 注意力分数计算attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scaleattn_probs = torch.softmax(attn_scores, dim=-1)attn_out = torch.matmul(attn_probs, V)# 合并多头attn_out = attn_out.transpose(1, 2).contiguous().view(B, H * W, self.embed_dim)# 输出映射out = self.out_proj(attn_out)# 恢复回原图维度 (B, C, H, W)out = out.permute(0, 2, 1).view(B, C, H, W)return out# 测试示例
# 假设输入是一张 14x14 的特征图(类似 patch embedding 后)
img = torch.randn(4, 64, 14, 14)  # (B, C, H, W)mha = ImageMultiHeadAttention(embed_dim=64, num_heads=8)
out = mha(img)print(out.shape)  # 输出应为 (4, 64, 14, 14)

PyTorch 实现自注意力机制(Self-Attention)

本节补充自注意力机制(Self-Attention)的核心代码实现,适用于 ViT 等模型中 patch token 的注意力操作。

自注意力机制代码(Self-Attention)

import torch
import torch.nn as nnclass SelfAttention(nn.Module):def __init__(self, embed_dim):super(SelfAttention, self).__init__()self.embed_dim = embed_dimself.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)self.out_proj = nn.Linear(embed_dim, embed_dim)self.scale = embed_dim ** 0.5def forward(self, x):# 输入 x: (B, N, C)B, N, C = x.shape# 一次性生成 Q, K, Vqkv = self.qkv_proj(x)  # (B, N, 3C)Q, K, V = torch.chunk(qkv, chunks=3, dim=-1)  # 各自为 (B, N, C)# 计算注意力分数attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # (B, N, N)attn_probs = torch.softmax(attn_scores, dim=-1)# 得到注意力加权输出attn_out = torch.matmul(attn_probs, V)  # (B, N, C)# 映射回原维度out = self.out_proj(attn_out)  # (B, N, C)return out#  测试示例
# 假设输入为 196 个 patch,每个 patch 的嵌入维度为 64
x = torch.randn(2, 196, 64)  # (B, N, C)attn = SelfAttention(embed_dim=64)
out = attn(x)print(out.shape)  # 输出应为 (2, 196, 64)

📎 拓展说明
• 本实现为单头自注意力机制
• 可用于 NLP 中的序列特征或 ViT 图像 patch 序列
• 若需改为多头注意力,只需将 embed_dim 拆成 num_heads × head_dim 并分别计算后合并


PyTorch 实现图像输入的自注意力机制(Self-Attention)

本节介绍一种适用于图像输入 (B, C, H, W) 的自注意力机制实现,适合卷积神经网络与 Transformer 的融合模块,如 Self-Attention ConvNet、BAM、CBAM、ViT 前层等。

自注意力机制(图像维度)代码

import torch
import torch.nn as nn
import torch.nn.functional as Fclass ImageSelfAttention(nn.Module):def __init__(self, in_channels):super(ImageSelfAttention, self).__init__()self.in_channels = in_channelsself.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)self.key_conv   = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)self.gamma = nn.Parameter(torch.zeros(1))  # 可学习缩放因子def forward(self, x):# 输入 x: (B, C, H, W)B, C, H, W = x.size()# 生成 Q, K, Vproj_query = self.query_conv(x).view(B, -1, H * W).permute(0, 2, 1)  # (B, N, C//8)proj_key   = self.key_conv(x).view(B, -1, H * W)                      # (B, C//8, N)proj_value = self.value_conv(x).view(B, -1, H * W)                    # (B, C, N)# 注意力矩阵:Q * K^Tenergy = torch.bmm(proj_query, proj_key)         # (B, N, N)attention = F.softmax(energy, dim=-1)             # (B, N, N)# 加权求和 Vout = torch.bmm(proj_value, attention.permute(0, 2, 1))  # (B, C, N)out = out.view(B, C, H, W)# 残差连接 + 缩放因子out = self.gamma * out + xreturn out#测试用例
x = torch.randn(2, 64, 32, 32)  # 输入一张图像:B=2, C=64, H=W=32
self_attn = ImageSelfAttention(in_channels=64)
out = self_attn(x)print(out.shape)  # 输出形状应为 (2, 64, 32, 32)

• 本模块基于图像 (B, C, H, W) 进行自注意力计算
• 使用卷积进行 Q/K/V 提取,保持局部感知力
• gamma 是可学习缩放因子,用于残差连接控制注意力贡献度


自注意力中**缩放因子(scale factor)的处理,在序列维度(如 ViT)和图片维度(如 Self-Attention Conv)**中有点不一样。下面我们来详细解释一下原因,并对两种写法做一个统一和对比分析

两种缩放因子的区别
  1. 序列维度的缩放因子
scale = head_dim ** 0.5  # 或者 embed_dim ** 0.5
attn = (Q @ K.T) / scale

• 来源:Transformer 原始论文(Attention is All You Need)
• 原因:在高维向量内积中,为了避免 dot product 的结果数值过大导致梯度不稳定,需要除以 sqrt(d_k)
• 使用场景:多头注意力机制,输入是 (B, N, C),应用在 NLP、ViT 等序列结构

  1. 图片维度(C, H, W)的注意力机制中没有缩放,或者使用 softmax 平衡
attn = softmax(Q @ K.T)   # 无 scale,或者手动调节

• 来源:Non-local Net、Self-Attention Conv、BAM 等 CNN + Attention 融合方法
• 原因:Q 和 K 都通过 1x1 conv 压缩成 C//8 或更小的维度,内积的值本身不会太大;同时图像 attention 主要用 softmax 控制权重范围
• 缩放因子的控制通常用 γ(gamma)作为残差通道缩放,不是 QK 内部的数值缩放


💬 如果你觉得这篇整理有帮助,欢迎点赞收藏!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/900364.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

每日一题(小白)字符串娱乐篇16

分析题意可以了解到本题要求在一串字符串中找到所有组合起来排序递增的字符串。我们可以默认所有字符在字符串中的上升序列是1,从第一个字符开始找,如果后面的字符大于前面的字符就说明这是一个上序列那么后面字符所在的数组加一,如果连接不上…

Ubuntu 22 Linux上部署DeepSeek R1保姆式操作详解(Xinference方式)

一、安装步骤 1.基础环境安装 安装显卡驱动、cuda,根据自己硬件情况查找相应编号,本篇不介绍这部分内容,只给出参考指令,详情请读者自行查阅互联网其它参考资料。 sudo apt install nvidia-utils-565-server sudo apt install…

Immutable.js 完全指南:不可变数据的艺术与实践

引言 在现代前端开发中,状态管理是一个核心挑战。随着应用复杂度增加,如何高效、安全地管理应用状态变得至关重要。Immutable.js 是 Facebook 推出的一个 JavaScript 库,它提供了持久化不可变数据结构,可以帮助开发者更好地管理应…

字符串数据类型的基本运算

任务描述 本关任务:从后台输入任意三个字符串,求最大的字符串。 相关知识 字符串本身是存放在一块连续的内存空间中,并以’\0’作为字符串的结束标记。 字符指针变量本身是一个变量,用于存放字符串的第 1 个字符的地址。 字符数…

Ubuntu 22.04 一键部署openManus

openManus 前言 OpenManus-RL,这是一个专注于基于强化学习(RL,例如 GRPO)的方法来优化大语言模型(LLM)智能体的开源项目,由来自UIUC 和 OpenManus 的研究人员合作开发。 前提要求 安装deepseek docker方式安装 ,windows 方式安装,Linux安装方式

PDF 转图片,一行代码搞定!批量支持已上线!

大家好,我是程序员晚枫。今天我要给大家带来一个超实用的功能——popdf 现在支持 PDF 转图片了,而且还能批量操作!是不是很激动?别急,我来手把手教你玩转这个功能。 1. 一行代码搞定单文件转换 popdf 的核心就是简单暴…

《比特城的机密邮件:加密、签名与防篡改的守护之战》

点击下面图片带您领略全新的嵌入式学习路线 🔥爆款热榜 88万阅读 1.6万收藏 第一章:风暴前的密令 比特城的议会大厅内,首席长老艾德文握着一卷足有半人高的羊皮纸,眉头紧锁。纸上是即将颁布的《新纪元法典》——这份文件不仅内…

8.用户管理专栏主页面开发

用户管理专栏主页面开发 写在前面用户权限控制用户列表接口设计主页面开发前端account/Index.vuelangs/zh.jsstore.js 后端Paginator概述基本用法代码示例属性与方法 urls.pyviews.py 运行效果 总结 欢迎加入Gerapy二次开发教程专栏! 本专栏专为新手开发者精心策划了…

http://noi.openjudge.cn/_2.5基本算法之搜索_1804:小游戏

文章目录 题目深搜代码宽搜代码深搜数据演示图总结 题目 1804:小游戏 总时间限制: 1000ms 内存限制: 65536kB 描述 一天早上,你起床的时候想:“我编程序这么牛,为什么不能靠这个赚点小钱呢?”因此你决定编写一个小游戏。 游戏在一…

发生梯度消失, 梯度爆炸问题的原因,怎么解决?

目录 一、梯度消失的原因 二、梯度爆炸的原因 三、共同的结构性原因 四、解决办法 五、补充知识 一、梯度消失的原因 梯度消失指的是在反向传播过程中,梯度随着层数的增加指数级减小(趋近于0),导致浅层网络的权重几乎无法更新…

【USRP】srsRAN 开源 4G 软件无线电套件

srsRAN 是SRS开发的开源 4G 软件无线电套件。 srsRAN套件包括: srsUE - 具有原型 5G 功能的全栈 SDR 4G UE 应用程序srsENB - 全栈 SDR 4G eNodeB 应用程序srsEPC——具有 MME、HSS 和 S/P-GW 的轻量级 4G 核心网络实现 安装系统 Ubuntu 20.04 USRP B210 sudo …

ChatGPT 4:解锁AI文案、绘画与视频创作新纪元

文章目录 一、ChatGPT 4的技术革新二、AI文案创作:精准生成与个性化定制三、AI绘画艺术:从文字到图像的神奇转化四、AI视频制作:自动化剪辑与创意实现五、知识库与ChatGPT 4的深度融合六、全新的变革和机遇《ChatGPT 4 应用详解:A…

在js中数组相关用法讲解

数组 uniqueArray 简单数组去重 /*** 简单数组去重* param arr* returns*/ export const uniqueArray <T>(arr: T[]) > [...new Set(arr)];const arr1 [1,1,1,1 2, 3];uniqueArray(arr); // [1,2,3]uniqueArrayByKey 根据 key 数组去重 /*** 根据key数组去重* …

RT-Thread ulog 日志组件深度分析

一、ulog 组件核心功能解析 轻量化与实时性 • 资源占用&#xff1a;ulog 核心代码仅需 ROM<1KB&#xff0c;RAM<0.2KB&#xff0c;支持在资源受限的MCU&#xff08;如STM32F103&#xff09;中运行。 • 异步/同步模式&#xff1a;默认采用异步环形缓冲区&#xff08;rt_…

T113s3远程部署Qt应用(dropbear)

T113-S3 是一款先进的应用处理器&#xff0c;专为汽车和工业控制市场而设计。 它集成了双核CortexTM-A7 CPU和单核HiFi4 DSP&#xff0c;提供高效的计算能力。 T113-S3 支持 H.265、H.264、MPEG-1/2/4、JPEG、VC1 等全格式解码。 独立的硬件编码器可以编码为 JPEG 或 MJPEG。 集…

12.青龙面板自动化我的生活

安装 docker方式 docker run -dit \ -v /root/ql:/ql/data \ -p 5700:5700 \ -e ENABLE_HANGUPtrue \ -e ENABLE_WEB_PANELtrue \ --name qinglong \ --hostname qinglong \ --restart always \ whyour/qinglongk8s方式 https://truecharts.org/charts/stable/qinglong/ he…

Maven 远程仓库推送方法

步骤 1&#xff1a;配置 pom.xml 中的远程仓库地址 在项目的 pom.xml 文件中添加 distributionManagement 配置&#xff0c;指定远程仓库的 URL。 xml 复制 <project>...<distributionManagement><!-- 快照版本仓库 --><snapshotRepository><id…

Spring Boot 日志 配置 SLF4J 和 Logback

文章目录 一、前言二、案例一&#xff1a;初识日志三、案例二&#xff1a;使用Lombok输出日志四、案例三&#xff1a;配置Logback 一、前言 在开发 Java 应用时&#xff0c;日志记录是不可或缺的一部分。日志可以记录应用的运行状态、错误信息和调试信息&#xff0c;帮助开发者…

JS API 事件监听

焦点事件案例&#xff1a;搜索框激活下拉菜单 事件对象 事件对象存储事件触发时的相关信息 可以判断用户按键&#xff0c;点击元素等内容 如何获取 事件绑定的回调函数中的第一个形参就是事件对象 一般命名为e,event 事件对象常用属性 type类型 click mouseenter client…

DDD与MVC扩展能力对比

一、架构设计理念的差异二、扩展性差异的具体表现三、DDD扩展性优势的深层原因四、MVC扩展性不足的典型场景五、总结&#xff1a;架构的本质与选择六、例子1&#xff09;场景描述2&#xff09;MVC实现示例&#xff08;三层架构&#xff09;3&#xff09;DDD实现示例&#xff08…