pytorch 实现 Restormer 主要模块(多头通道自注意力机制和门控制结构)

        前面的博文读论文:Restormer: Efficient Transformer for High-Resolution Image Restoration 介绍了 Restormer 网络结构的网络技术特点,本文用 pytorch 实现其中的主要网络结构模块。

1. MDTA(Multi-Dconv Head Transposed Attention:多头注意力机制

## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):def __init__(self, dim, num_heads, bias):super(Attention, self).__init__()self.num_heads = num_heads  # 注意力头的个数self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))  # 可学习系数# 1*1 升维self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)# 3*3 分组卷积self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)# 1*1 卷积self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)def forward(self, x):b,c,h,w = x.shape  # 输入的结构 batch 数,通道数和高宽qkv = self.qkv_dwconv(self.qkv(x))q,k,v = qkv.chunk(3, dim=1)  #  第 1 个维度方向切分成 3 块# 改变 q, k, v 的结构为 b head c (h w),将每个二维 plane 展平q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)q = torch.nn.functional.normalize(q, dim=-1)  # C 维度标准化,这里的 C 与通道维度略有不同k = torch.nn.functional.normalize(k, dim=-1)attn = (q @ k.transpose(-2, -1)) * self.temperature # @ 是矩阵乘attn = attn.softmax(dim=-1)out = (attn @ v)  # 注意力图(严格来说不算图)# 将展平后的注意力图恢复out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)# 真正的注意力图out = self.project_out(out)return out

2. GDFN( Gated-Dconv Feed-Forward Network) 

## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):def __init__(self, dim, ffn_expansion_factor, bias):super(FeedForward, self).__init__()# 隐藏层特征维度等于输入维度乘以扩张因子hidden_features = int(dim*ffn_expansion_factor)# 1*1 升维self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)# 3*3 分组卷积self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)# 1*1 降维self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)def forward(self, x):x = self.project_in(x)x1, x2 = self.dwconv(x).chunk(2, dim=1)  # 第 1 个维度方向切分成 2 块x = F.gelu(x1) * x2  # gelu 相当于 relu+dropoutx = self.project_out(x)return x

3. TransformerBlock

## 就是标准的 Transformer 架构
class TransformerBlock(nn.Module):def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):super(TransformerBlock, self).__init__()self.norm1 = LayerNorm(dim, LayerNorm_type)  # 层标准化self.attn = Attention(dim, num_heads, bias)  # 自注意力self.norm2 = LayerNorm(dim, LayerNorm_type)  # 层表转化self.ffn = FeedForward(dim, ffn_expansion_factor, bias)  # FFNdef forward(self, x):x = x + self.attn(self.norm1(x))  # 残差x = x + self.ffn(self.norm2(x))  # 残差return x

4. 测试样例

model = Restormer()
print(model)  # 打印网络结构x = torch.randn((1, 3, 512, 512))  #随机生成输入图像
x = model(x)  # 送入网络
print(x.shape) # 打印网络输入的图像结构

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/536813.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

选择免费的SSL证书,还是付费的?

作为一个互联网文章作者,我会根据具体的使用场景和需求来选择SSL证书。通常情况下,如果是用于个人网站或者小型项目,我会倾向于选择免费的SSL证书,比如 JoySSL提供的免费证书。这样可以在不增加额外费用的情况下为网站提供安全的加…

静态HTTP与CDN:如何优化内容分发

大家好,今天我们来聊聊静态HTTP和CDN这对“黄金搭档”。没错,就是那个让你的网站内容像闪电一样传遍全球的CDN! 首先,我们来了解一下静态HTTP。它就像是那个老实可靠的邮差,每次都按时按点地把你的内容送到用户手中。…

第二十一章博客

计算机应用实现了多台计算机间的互联,使得它们彼此之间能够进行数据交流。网络应用程序就是在已连接的不同计算机上运行的程序,这些程序借助于网络协议,相互之间可以交换数据。编写网络应用程序前,首先必须明确所要使用的网络协议…

Node.js中处理特殊字符的文件名,安全稳妥的方案

在Node.js中,通过path模块提供的basename方法,我们可以轻松地从文件路径中提取文件名。然而,这个方法在处理特殊字符时存在一些问题,因为它会对这些字符进行转义,导致在不同操作系统上的兼容性问题。在这篇文章中&…

C++ boost planner_cond_.wait(lock) 报错1225

1.如下程序段 boost unique_lock doesn’t own the mutex: Operation not permitted 问题: 其中makePlan是一个线程。这里的unlock导致错误这个报错 boost unique_lock doesn’t own the mutex: Operation not permitted bool navigation::makePlan(){ //cv::named…

MySQL中如何快速定位占用CPU过高的SQL

作为DBA工作中都会遇到过数据库服务器CPU飙升的场景,我们该如何快速定位问题?又该如何快速找到具体是哪个SQL引发的CPU异常呢?下面我们说两个方法。聊聊MySQL中如何快速定位占用CPU过高的SQL。 技术人人都可以磨炼,但处理问题的思…

华为OD机试 - 多段线数据压缩(Java JS Python C)

在线OJ刷题 题目详情 - 多段线数据压缩 - Hydro 题目描述 下图中,每个方块代表一个像素,每个像素用其行号和列号表示。 为简化处理,多线段的走向只能是水平、竖直、斜向45度。 上图中的多线段可以用下面的坐标串表示:(2,8),(3,7),(3,6),(3,5),(4,4),(5,3),(6,2),(7,3),(…

042、序列模型

之——从时序中获取信息 目录 之——从时序中获取信息 杂谈 正文 1.建模 2.方案A-马尔科夫假设 3.方案B-潜变量模型 4.简单实现 杂谈 很多连续的数据都是有前后的时间相关性的,并不是每一个单独的数据是随机出现的。在时序中会蕴含一些空间结构的变化信息、…

【数据科学】一文彻底理清数据、数据类型、数据结构的概念

一、什么是数据? 入门数据学科,首先第一步要认识数据什么,可能大多数人都无法对数据做一个准确的定义,在我们印象中,提到数据首先头脑浮现的是数据表格,是一堆堆数字,那么数据就是数字吗&#x…

SpringBoot 2.0 中默认 HikariCP 数据库连接池原理解析

作为后台服务开发,在日常工作中我们天天都在跟数据库打交道,一直在进行各种CRUD操作,都会使用到数据库连接池。按照发展历程,业界知名的数据库连接池有以下几种:c3p0、DBCP、Tomcat JDBC Connection Pool、Druid 等&am…

阿里云服务器记录

阿里云服务器记录 CentOS 8.4 64位 SCC版 CentOS 7.9 64位 SCC版 CentOS 7.9 64位 CentOS 7.9 64位 UEFI版 Alibaba Cloud Linux Anolis OS CentOS Windows Server Ubuntu Debian Fedora OpenSUSE Rocky Linux CentOS Stream AlmaLinux 阿里云服务器有个scc版,这个…

Flask+Mysql项目docker-compose部署(Pythondocker-compose详细步骤)

一、前言 环境: Linux、docker、docker-compose、python(Flask)、Mysql 简介: 简单使用Flask框架写的查询Mysql数据接口,使用docker部署,shell脚本启动 优势: 采用docker方式部署更加便于维护,更加简单快…

如何在Go中使用模板

引言 您是否需要以格式良好的输出、文本报告或HTML页面呈现一些数据?你可以使用Go模板来做到这一点。任何Go程序都可以使用text/template或html/template包(两者都包含在Go标准库中)来整齐地显示数据。 这两个包都允许你编写文本模板并将数据传递给它们,以按你喜欢的格式呈…

“C语言“——scanf()、getchar() 、putchar()、之间的关系

scanf函数说明 scanf函数是对来自于标准输入流的输入数据作格式转换,并将转换结果保存至format后面的实参所指向的对象。 而const char*format 指向的字符串为格式控制字符串,它指定了可输入的字符串以及赋值时转换方法。 简单来说给一个打印格式(输入…

【并发编程篇】源码分析,手动创建线程池

文章目录 🛸前言🌹Executors的三大方法 🍔简述线程池🎆手动创建线程池⭐源码分析✨代码实现,手动创建线程池🎈CallerRunsPolicy()🎈AbortPolicy()🎈DiscardPolicy()🎈Dis…

LNPMariadb数据库分离|web服务器集群

LNP&Mariadb数据库分离|web服务器集群 网站架构演变单机版LNMP独立数据库服务器web服务器集群与Session保持 LNP与数据库分离1. 准备一台独立的服务器,安装数据库软件包2. 将之前的LNMP网站中的数据库迁移到新的数据库服务器3. 修改wordpress网站配置…

2023.12.24 关于 Redis 中 String 类型内部编码 及 应用场景

目录 String 类型内部编码 3 种内部编码方式 String 类型应用场景 Cache 缓存 键名命名规则 计数(Counter) 共享会话(Session ) 手机验证码 总结 String 类型内部编码 3 种内部编码方式 int:用来表示 64 位 —…

vue3菜单权限管理实现

前提 你的菜单是根据路由动态生成的,具体可以参考这篇博客对el-menu组件进行递归封装(根据路由配置动态生成) 描述 首先将路由分为常量路由constantRoute(所有用户都有的路由)和异步路由asyncRoute(需要动…

Gradle 插件

自定义Gradle插件 - 简书

小天使的小难题:新生儿疝气的关注与温馨呵护

引言: 新生儿疝气是一种在出生后可能出现的常见情况,虽然通常不会造成长期影响,但对于家长而言,了解如何正确应对新生儿疝气是至关重要的。本文将深入探讨新生儿疝气的原因、症状,以及家长在面对这一问题时应该采取的…