【NeurIPS 2023】PromptIR: Prompting for All-in-One Blind Image Restoration

PromptIR: Prompting for All-in-One Blind Image Restoration, NeurIPS 2023

论文:https://arxiv.org/abs/2306.13090

代码:https://github.com/va1shn9v/promptir

解读:即插即用系列 | PromptIR:MBZUAI提出一种基于Prompt的全能图像恢复网络 - 知乎 (zhihu.com)

摘要

图像恢复是从其受损版本中恢复高质量清晰图像的过程。deep-learning方法显著提升了图像恢复性能,然而,它们在不同类型和级别的退化上的泛化能力有限。这限制了它们在实际应用中的使用,因为需要针对每种具体的退化进行单独训练模型,并了解输入图像的退化类型才能应用相应的模型。本文介绍了一种基于提示的学习方法,称为PromptIR,用于全能图像恢复,可以有效地从各种类型和级别的退化中恢复图像。具体而言,本文方法使用提示来编码退化特定信息,并动态引导恢复网络。 PromptIR提供了一个通用且高效的插件模块,只需少量轻量级提示即可用于恢复各种类型和级别的受损图像,无需事先了解图像中存在的损坏信息。

动机

图像恢复过程中,会由于各种客观因素或限制(相机设备、环境条件)出现各种退化现象。deep-learning方法虽然有用,但在特定的退化类型和程度之外缺乏泛化性。因此,迫切需要开发一种能够有效恢复各种类型和程度的退化图像的一体化方法。

AirNet,采用对比学习解决一体化恢复任务,但需要训练额外的编码器,会增加训练负担。

为此,本文提出了一种基于提示学习的方法来执行一体化图像恢复。该方法利用提示(一组可调参数),用于编码关于各种图像退化类型的重要区分信息。通过将提示与主恢复网络的特征表示相互作用,动态地增强表示,以获得具有退化特定知识的适应性,这种适应性使网络能够通过动态调整其行为有效地恢复图像。

该图显示了在PromptIR和AirNet中使用的退化嵌入的tSNE图。不同的颜色表示不同的退化类型。每个任务的嵌入更好地聚集在一起,显示了提示标记学习具有区分性的退化上下文的有效性,从而有助于恢复过程。

 

所提的Prompt 方法

本文PromptIR,提出了一个即插即用的提示模块,它隐式地预测与退化条件相关的提示,以引导具有未知退化的输入图像的恢复过程。来自提示的指导被注入到网络的多个解码阶段,其中包含少量可学习参数。

 

贡献

  • 提出了一种基于提示的一体化恢复框架PromptIR,它仅依赖于输入图像来恢复清晰图像,而不需要任何关于图像中存在的退化的先验知识。
  • 本文prompt block 是一个可轻松集成到任何现有恢复网络中的插件模块。它由提示生成模块(PGM)和提示交互模块(PIM)组成。提示块的目标是生成与输入条件相关的提示(通过PGM),这些提示具有有用的上下文信息,以指导恢复网络(通过PIM)有效地消除输入图像中的破坏。
  • 本文实验证明了PromptIR的动态适应行为,在包括图像降噪、去雨和去雾在内的各种图像恢复任务上实现了最先进的性能。

方法 PromptIR

“一体式”图像恢复的目标是,学习单个模型M,以从退化的图像中恢复干净图像,该图像已使用退化方式D退化,而没有关于D的先验信息。虽然该模型对退化方式“不可见”,可以通过提供关于退化类型的隐含上下文信息来增强其在恢复干净图像方面的性能。基于提示学习的图像恢复框架PromptIR,用于在恢复干净图像的同时用退化类型的相关知识补充模型。关键元素是 提示块prompt block。

PromptIR使用提示块来生成可学习的提示参数,并在恢复过程中利用这些提示来指导模型。框架通过逐级编码器-解码器将特征逐步转换为深层特征,并在解码器中引入提示块来辅助恢复过程。提示块在解码器的每个级别中连接,隐式地为输入特征提供关于退化类型的信息,以实现引导恢复。总体来说,PromptIR框架通过逐级编码和解码以及引入提示块的方式实现图像恢复任务。 

Prompt Block

提示块 prompt block,由两个模块组成:提示生成模块(PGM)和提示交互模块(PIM)。

  • 提示生成模块使用输入特征Fl和提示组件生成与输入条件相关的提示P。
  • 提示交互模块通过Transformer块使用生成的提示动态调整输入特征。提示与解码器特征在多个级别交互,以丰富特定于退化的上下文信息。 

给N个prompt-components P_c \in R^{N*\hat{H}*\hat{W}*\hat{C}}和 input feature F_1\in R^{\hat{H}*\hat{W}*\hat{C}}, prompt block 可表示为:

 

 其中,PGM表示提示生成模块,PIM表示提示交互模块。

Prompt Generation Module (PGM)

提示组件 P_c是一组可学习的参数,与输入特征交互,嵌入了退化信息。一种直接的特征-提示交互方法是直接使用学习到的提示来校准特征,可能会产生次优结果,因为它对输入内容是无知的。本文提出了提示生成模块(PGM),它从输入特征中动态预测基于注意力的权重,并将这些权重应用于提示组件,生成与输入条件相关的提示P。此外,PGM创建了一个共享空间,促进了提示组件之间的相关知识共享。PGM公式表达为:

Prompt Interaction Module (PIM)

PIM的主要目标是实现输入特征F_1和提示P之间的交互,以实现有指导的恢复过程。

在PIM中,沿着通道维度将生成的提示与输入特征进行拼接。接下来将拼接后的表示通过一个Transformer块进行处理,该块利用提示中编码的退化信息来转换输入特征。

​​本文的主要贡献是提示块,它是一个插件模块,与具体的架构无关。PromptIR框架中,使用了现有的Transformer块。Transformer块由两个顺序连接的子模块组成:多转置卷积头转置注意力(MDTA)和门控转置卷积前馈网络(GDFN)。MDTA在通道而不是空间维度上应用自注意操作,并具有线性复杂度。GDFN的目标是以可控的方式转换特征,即抑制信息较少的特征,只允许有用的特征在网络中传播。PIM的整体过程为:

实验

主要实验与可视化样例 

全能恢复设置下的比较:

去雾、去雨、去噪实验比较: 

 

去雾、去雨、去噪可视化比较:

消融实验 

关键代码

PGM

# https://github.com/va1shn9v/PromptIR/blob/main/net/model.py##---------- Prompt Gen Module -----------------------
class PromptGenBlock(nn.Module):def __init__(self,prompt_dim=128,prompt_len=5,prompt_size = 96,lin_dim = 192):super(PromptGenBlock,self).__init__()self.prompt_param = nn.Parameter(torch.rand(1,prompt_len,prompt_dim,prompt_size,prompt_size))self.linear_layer = nn.Linear(lin_dim,prompt_len)self.conv3x3 = nn.Conv2d(prompt_dim,prompt_dim,kernel_size=3,stride=1,padding=1,bias=False)def forward(self,x):B,C,H,W = x.shapeemb = x.mean(dim=(-2,-1))prompt_weights = F.softmax(self.linear_layer(emb),dim=1)prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.unsqueeze(0).repeat(B,1,1,1,1,1).squeeze(1)prompt = torch.sum(prompt,dim=1)prompt = F.interpolate(prompt,(H,W),mode="bilinear")prompt = self.conv3x3(prompt)return prompt

PromptIR

# https://github.com/va1shn9v/PromptIR/blob/main/net/model.pyclass PromptIR(nn.Module):def __init__(self, inp_channels=3, out_channels=3, dim = 48,num_blocks = [4,6,6,8], num_refinement_blocks = 4,heads = [1,2,4,8],ffn_expansion_factor = 2.66,bias = False,LayerNorm_type = 'WithBias',   ## Other option 'BiasFree'decoder = False,):super(PromptIR, self).__init__()self.patch_embed = OverlapPatchEmbed(inp_channels, dim)self.decoder = decoderif self.decoder:self.prompt1 = PromptGenBlock(prompt_dim=64,prompt_len=5,prompt_size = 64,lin_dim = 96)self.prompt2 = PromptGenBlock(prompt_dim=128,prompt_len=5,prompt_size = 32,lin_dim = 192)self.prompt3 = PromptGenBlock(prompt_dim=320,prompt_len=5,prompt_size = 16,lin_dim = 384)self.chnl_reduce1 = nn.Conv2d(64,64,kernel_size=1,bias=bias)self.chnl_reduce2 = nn.Conv2d(128,128,kernel_size=1,bias=bias)self.chnl_reduce3 = nn.Conv2d(320,256,kernel_size=1,bias=bias)self.reduce_noise_channel_1 = nn.Conv2d(dim + 64,dim,kernel_size=1,bias=bias)self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])self.down1_2 = Downsample(dim) ## From Level 1 to Level 2self.reduce_noise_channel_2 = nn.Conv2d(int(dim*2**1) + 128,int(dim*2**1),kernel_size=1,bias=bias)self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3self.reduce_noise_channel_3 = nn.Conv2d(int(dim*2**2) + 256,int(dim*2**2),kernel_size=1,bias=bias)self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])self.up4_3 = Upsample(int(dim*2**2)) ## From Level 4 to Level 3self.reduce_chan_level3 = nn.Conv2d(int(dim*2**1)+192, int(dim*2**2), kernel_size=1, bias=bias)self.noise_level3 = TransformerBlock(dim=int(dim*2**2) + 512, num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)self.reduce_noise_level3 = nn.Conv2d(int(dim*2**2)+512,int(dim*2**2),kernel_size=1,bias=bias)self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)self.noise_level2 = TransformerBlock(dim=int(dim*2**1) + 224, num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)self.reduce_noise_level2 = nn.Conv2d(int(dim*2**1)+224,int(dim*2**2),kernel_size=1,bias=bias)self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])self.up2_1 = Upsample(int(dim*2**1))  ## From Level 2 to Level 1  (NO 1x1 conv to reduce channels)self.noise_level1 = TransformerBlock(dim=int(dim*2**1)+64, num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)self.reduce_noise_level1 = nn.Conv2d(int(dim*2**1)+64,int(dim*2**1),kernel_size=1,bias=bias)self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)def forward(self, inp_img,noise_emb = None):inp_enc_level1 = self.patch_embed(inp_img)out_enc_level1 = self.encoder_level1(inp_enc_level1)inp_enc_level2 = self.down1_2(out_enc_level1)out_enc_level2 = self.encoder_level2(inp_enc_level2)inp_enc_level3 = self.down2_3(out_enc_level2)out_enc_level3 = self.encoder_level3(inp_enc_level3) inp_enc_level4 = self.down3_4(out_enc_level3)        latent = self.latent(inp_enc_level4)if self.decoder:dec3_param = self.prompt3(latent)latent = torch.cat([latent, dec3_param], 1)latent = self.noise_level3(latent)latent = self.reduce_noise_level3(latent)inp_dec_level3 = self.up4_3(latent)inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)out_dec_level3 = self.decoder_level3(inp_dec_level3) if self.decoder:dec2_param = self.prompt2(out_dec_level3)out_dec_level3 = torch.cat([out_dec_level3, dec2_param], 1)out_dec_level3 = self.noise_level2(out_dec_level3)out_dec_level3 = self.reduce_noise_level2(out_dec_level3)inp_dec_level2 = self.up3_2(out_dec_level3)inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)out_dec_level2 = self.decoder_level2(inp_dec_level2)if self.decoder:dec1_param = self.prompt1(out_dec_level2)out_dec_level2 = torch.cat([out_dec_level2, dec1_param], 1)out_dec_level2 = self.noise_level1(out_dec_level2)out_dec_level2 = self.reduce_noise_level1(out_dec_level2)inp_dec_level1 = self.up2_1(out_dec_level2)inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)out_dec_level1 = self.decoder_level1(inp_dec_level1)out_dec_level1 = self.refinement(out_dec_level1)out_dec_level1 = self.output(out_dec_level1) + inp_imgreturn out_dec_level1

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/190698.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

文件操作--IO

目录 ♫什么是文件 ♫文件路径 ♫文件类型 ♫文件的管理 ♪File的构造方法 ♪File的常用方法 ♫文件的内容操作 ♪InputStream ♪OutputStream ♪字符流读写文件 ♫Scanner与流对象 ♫什么是文件 文件在计算机里可以指“狭义”的文件(指硬盘上的文件和目录&…

c语言详解牛顿迭代法以及求解倒数和平方根

Newtons iteration method 是在实数域和复数域利用切线不断逼近方程根的一种求高次曲线方程的方法,区别于梯度下降法,它是二阶导,收敛速度比较快,对于非凸函数,牛顿法容易受到鞍点或者最大值点的吸引。由于牛顿迭代法是…

产品学习之路(一)

在做好开发的同时,还需要熟悉产品业务逻辑,不能为了功能而做功能,要从产品经理的角度去看待每个需求和客户痛点所在,这样针对产品设计出来的东西自己也有发言权; 目前作为一名前端开发人员,也在自学产品知识…

xxl-job 分布式任务调度框架

文章目录 分布式任务调度XXL-Job 简介XXL-Job 环境搭建XXL-Job (源码说明)配置部署调度中心docker安装 Bean模式任务(方法形式)-入门案例任务详解任务详解-执行器任务详解-基础配置任务详解-调度配置任务详解-基础配置任务详解-阻塞处理策略任务详解-路由策略 路由策略路由策略…

Redis数据结构之压缩列表

压缩列表是Redis为节约内存而开发的,是由一系列特殊编码的连续内存块组成的顺序型数据结构。一个压缩列表可以包含任意多个节点,每个节点可以保存一个字节数组或者整数值。 压缩列表构成 zlbytes: 记录整个压缩列表占用的内存字节数,对压缩列…

LD_PRELOAD劫持、ngixn临时文件、无需临时文件rce

LD_PRELOAD劫持 <1> LD_PRELOAD简介 LD_PRELOAD 是linux下的一个环境变量。用于动态链接库的加载&#xff0c;在动态链接库的过程中他的优先级是最高的。类似于 .user.ini 中的 auto_prepend_file&#xff0c;那么我们就可以在自己定义的动态链接库中装入恶意函数。 也…

Java数据结构之《折半查找》题目

一、前言&#xff1a; 这是怀化学院的&#xff1a;Java数据结构中的一道难度中等的一道编程题(此方法为博主自己研究&#xff0c;问题基本解决&#xff0c;若有bug欢迎下方评论提出意见&#xff0c;我会第一时间改进代码&#xff0c;谢谢&#xff01;) 后面其他编程题只要我写完…

自定义函数中的(int*a,int*b)与(int*a,int n)

事实上第一种更安全&#xff0c;不会因越界发生占位&#xff0c;从而导致错误。

C++的类和对象(一)

目录 1、面向过程和面向对象初认识 2、为什么要有类 3、类的定义 类的两种定义方式 4、类的访问限定符 5、类的作用域 5.1 为什么要有作用域&#xff1f; 5.2类作用域 6、类的实例化 6.1类的实例化的定义 6.2类的实例化的实现 6.3经典面试题 7、类对象 7.1类对…

计算机体系结构补充篇----静态超标量流水线及循环展开(一)

本文仅供学习&#xff0c;不作任何商业用途&#xff0c;严禁转载。部分资料取自----计算机系统结构教程(第二版)张晨曦等。部分资料来自----国科大计算机体系结构课程PPT–张科、刘珂、高婉玲 计算机体系结构----静态超标量流水线及循环展开&#xff08;一&#xff09; 摘要静…

FreeRTOS第2天:

1. 二值信号量简介&#xff08;386.11&#xff09; 什么是信号量&#xff1f; 信号量&#xff08;Semaphore&#xff09;&#xff0c;是在多任务环境下使用的一种机制&#xff0c;是可以用来保证两个或多个关键代码段不被并 发调用。信号量这个名字&#xff0c;我们可以把它拆…

【蓝桥杯软件赛 零基础备赛20周】第5周——高精度大数运算与队列

文章目录 1. 数组的应用–高精度大数运算1.1 Java和Python计算大数1.2 C/C高精度计算大数1.2.1 高精度加法1.2.2 高精度减法 2. 队列2.1 手写队列2.1.1 C/C手写队列2.1.2 Java手写队列2.1.3 Python手写队列 2.2 C STL队列queue2.3 Java队列Queue2.4 Python队列Queue和deque2.5 …

边缘数据中心和5G的融合彻底改变数据传输和物联网

伴随着数字化时代的飞速发展&#xff0c;边缘数据中心和5G技术的联袂崛起&#xff0c;正深刻塑造着人们对数据的创造、传输和处理方式。据Gartner公司的预测&#xff0c;到2025年&#xff0c;企业数据的三分之二将在边缘计算设施中涌现&#xff0c;而非传统的集中式数据中心。这…

134. 加油站(贪心算法)

根据题解 这道题使用贪心算法&#xff0c;找到当前可解决问题的状态即可 「贪心算法」的问题需要满足的条件&#xff1a; 最优子结构&#xff1a;规模较大的问题的解由规模较小的子问题的解组成&#xff0c;规模较大的问题的解只由其中一个规模较小的子问题的解决定&#xff…

分享88个节日PPT,总有一款适合您

分享88个节日PPT&#xff0c;总有一款适合您 88个节日PPT下载链接&#xff1a;https://pan.baidu.com/s/1mfLrdlB9Y1jqz2vkVIwBNA?pwd6666 提取码&#xff1a;6666 Python采集代码下载链接&#xff1a;采集代码.zip - 蓝奏云 学习知识费力气&#xff0c;收集整理更不易…

Markdown学习测试

这里写自定义目录标题 欢迎使用Markdown编辑器新的改变功能快捷键合理的创建标题&#xff0c;有助于目录的生成如何改变文本的样式插入链接与图片如何插入一段漂亮的代码片生成一个适合你的列表创建一个表格设定内容居中、居左、居右SmartyPants 创建一个自定义列表如何创建一个…

国产Type-C接口逻辑协议芯片:Type-C显示器芯片方案

产品介绍 双Type-C盲插选型&#xff1a; LDR6282 PD3.0认证协议芯片&#xff0c;USB-IF TID号&#xff1a;212 支持iic&#xff0c;USB转UART&#xff0c;CC升级方式&#xff0c;多年市场验证&#xff0c;显示器市场出货量&#xff0c;显示器大厂采用兼容性NO.1。采用QFN32 5…

系列十五、SpringBoot的启动原理分析

一、概述 所谓SpringBoot的启动原理&#xff0c;翻译成大白话就是"当我们在主启动类上运行run方法时&#xff0c;SpringBoot底层到底做了什么事情&#xff0c;能够帮助我们启动一个Spring的web应用"&#xff0c;上边用大白话解释了一下什么是SpringBoot的启动原理&am…

vue 解决响应大数据表格渲染崩溃问题

如果可以实现记得点赞分享&#xff0c;谢谢老铁&#xff5e; 1.场景描述 发起请求获取上万条数据&#xff0c;进行表格渲染&#xff0c;使浏览器卡顿&#xff0c;导致网页崩溃。 2.分析原因 1.大量数据加载&#xff0c;过多操作Dom&#xff0c;消耗性能。 2.表格中包含其他…