AIRNet模型使用与代码分析(All-In-One Image Restoration Network)

AIRNet提出了一种较为简易的pipeline,以单一网络结构应对多种任务需求(不同类型,不同程度)。但在效果上看,ALL-In-One是不如One-By-One的,且本文方法的亮点是batch内选择patch进行对比学习。在与sota对比上,仅是Denoise任务精度占优,在Derain与Dehaze任务上,效果不如One-By-One的MPRNet方法。本博客对AIRNet的关键结构实现,loss实现,data_patch实现进行深入分析,并对模型进行推理使用。

其论文的详细可以阅读:https://blog.csdn.net/a486259/article/details/139559389?spm=1001.2014.3001.5501

项目地址:https://blog.csdn.net/a486259/article/details/139559389?spm=1001.2014.3001.5501

项目依赖:torch、mmcv-full
安装mmcv-full时,需要注意torch所对应的cuda版本,要与系统中的cuda版本一致。

1、模型结构

AirNet的网络结构如下所示,输入图像x交由CBDE提取到嵌入空间z,z与x输入到DGRN模块的DGG block中逐步优化,最终输出预测结果。
在这里插入图片描述
模型代码在net\model.py

from torch import nnfrom net.encoder import CBDE
from net.DGRN import DGRNclass AirNet(nn.Module):def __init__(self, opt):super(AirNet, self).__init__()# Encoderself.E = CBDE(opt)  #编码特征值# Restorerself.R = DGRN(opt) #特征解码def forward(self, x_query, x_key):if self.training:fea, logits, labels, inter = self.E(x_query, x_key)restored = self.R(x_query, inter)return restored, logits, labelselse:fea, inter = self.E(x_query, x_query)restored = self.R(x_query, inter)return restored

1.1 CBDE模块

CBDE模块的功能是在模块内进行对比学习,核心是MoCo. Moco论文地址:https://arxiv.org/pdf/1911.05722

class CBDE(nn.Module):def __init__(self, opt):super(CBDE, self).__init__()dim = 256# Encoderself.E = MoCo(base_encoder=ResEncoder, dim=dim, K=opt.batch_size * dim)def forward(self, x_query, x_key):if self.training:# degradation-aware represenetion learningfea, logits, labels, inter = self.E(x_query, x_key)return fea, logits, labels, interelse:# degradation-aware represenetion learningfea, inter = self.E(x_query, x_query)return fea, inter

ResEncoder所对应的网络结构如下所示
在这里插入图片描述

在AIRNet中的CBDE模块里的MoCo模块的关键代码如下,其在内部自行完成了正负样本的分配,最终输出logits, labels用于计算对比损失的loss。但其所优化的模块实际上是ResEncoder。MoCo模块只是在训练阶段起作用,在推理阶段是不起作用的。

class MoCo(nn.Module):def forward(self, im_q, im_k):"""Input:im_q: a batch of query imagesim_k: a batch of key imagesOutput:logits, targets"""if self.training:# compute query featuresembedding, q, inter = self.encoder_q(im_q)  # queries: NxCq = nn.functional.normalize(q, dim=1)# compute key featureswith torch.no_grad():  # no gradient to keysself._momentum_update_key_encoder()  # update the key encoder_, k, _ = self.encoder_k(im_k)  # keys: NxCk = nn.functional.normalize(k, dim=1)# compute logits# Einstein sum is more intuitive# positive logits: Nx1l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)# negative logits: NxKl_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])# logits: Nx(1+K)logits = torch.cat([l_pos, l_neg], dim=1)# apply temperaturelogits /= self.T# labels: positive key indicatorslabels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()# dequeue and enqueueself._dequeue_and_enqueue(k)return embedding, logits, labels, interelse:embedding, _, inter = self.encoder_q(im_q)return embedding, inter

1.2 DGRN模块

DGRN模块的实现代码如下所示,可以看到核心是DGG模块,其不断迭代优化输入图像。

class DGRN(nn.Module):def __init__(self, opt, conv=default_conv):super(DGRN, self).__init__()self.n_groups = 5n_blocks = 5n_feats = 64kernel_size = 3# head modulemodules_head = [conv(3, n_feats, kernel_size)]self.head = nn.Sequential(*modules_head)# bodymodules_body = [DGG(default_conv, n_feats, kernel_size, n_blocks) \for _ in range(self.n_groups)]modules_body.append(conv(n_feats, n_feats, kernel_size))self.body = nn.Sequential(*modules_body)# tailmodules_tail = [conv(n_feats, 3, kernel_size)]self.tail = nn.Sequential(*modules_tail)def forward(self, x, inter):# headx = self.head(x)# bodyres = xfor i in range(self.n_groups):res = self.body[i](res, inter)res = self.body[-1](res)res = res + x# tailx = self.tail(res)return x

在这里插入图片描述
DGG模块的结构示意如下所示
在这里插入图片描述
DGG代码实现如下所示,DGG模块内嵌DGB模块,DGB模块内嵌DGM模块,DGM模块内嵌SFT_layer模块与DCN_layer(可变性卷积)
在这里插入图片描述

2、loss实现

AIRNet中提到的loss如下所示,其中Lrec是L1 loss,Lcl是Moco模块实现的对比损失。
在这里插入图片描述
AIRNet的loss实现代码在train.py中,CE loss是针对CBDE(Moco模块)的输出进行计算,l1 loss是针对修复图像与清晰图片。

    # Network Constructionnet = AirNet(opt).cuda()net.train()# Optimizer and Lossoptimizer = optim.Adam(net.parameters(), lr=opt.lr)CE = nn.CrossEntropyLoss().cuda()l1 = nn.L1Loss().cuda()# Start trainingprint('Start training...')for epoch in range(opt.epochs):for ([clean_name, de_id], degrad_patch_1, degrad_patch_2, clean_patch_1, clean_patch_2) in tqdm(trainloader):degrad_patch_1, degrad_patch_2 = degrad_patch_1.cuda(), degrad_patch_2.cuda()clean_patch_1, clean_patch_2 = clean_patch_1.cuda(), clean_patch_2.cuda()optimizer.zero_grad()if epoch < opt.epochs_encoder:_, output, target, _ = net.E(x_query=degrad_patch_1, x_key=degrad_patch_2)contrast_loss = CE(output, target)loss = contrast_losselse:restored, output, target = net(x_query=degrad_patch_1, x_key=degrad_patch_2)contrast_loss = CE(output, target)l1_loss = l1(restored, clean_patch_1)loss = l1_loss + 0.1 * contrast_loss# backwardloss.backward()optimizer.step()

这里可以看出来,AIRNet首先是训练CBDE模块,最后才训练CBDE模块+DGRN模块。

3、TrainDataset

TrainDataset的实现代码在utils\dataset_utils.py中,首先找到__getitem__函数进行分析。以下代码为关键部分,删除了大部分在逻辑上重复的部分。TrainDataset一共支持5种数据类型,‘denoise_15’: 0, ‘denoise_25’: 1, ‘denoise_50’: 2,是不需要图像对的(在代码里面自动对图像添加噪声);‘derain’: 3, ‘dehaze’: 4是需要图像对进行训练的。

class TrainDataset(Dataset):def __init__(self, args):super(TrainDataset, self).__init__()self.args = argsself.rs_ids = []self.hazy_ids = []self.D = Degradation(args)self.de_temp = 0self.de_type = self.args.de_typeself.de_dict = {'denoise_15': 0, 'denoise_25': 1, 'denoise_50': 2, 'derain': 3, 'dehaze': 4}self._init_ids()self.crop_transform = Compose([ToPILImage(),RandomCrop(args.patch_size),])self.toTensor = ToTensor()def __getitem__(self, _):de_id = self.de_dict[self.de_type[self.de_temp]]if de_id < 3:if de_id == 0:clean_id = self.s15_ids[self.s15_counter]self.s15_counter = (self.s15_counter + 1) % self.num_cleanif self.s15_counter == 0:random.shuffle(self.s15_ids)# clean_id = random.randint(0, len(self.clean_ids) - 1)clean_img = crop_img(np.array(Image.open(clean_id).convert('RGB')), base=16)clean_patch_1, clean_patch_2 = self.crop_transform(clean_img), self.crop_transform(clean_img)clean_patch_1, clean_patch_2 = np.array(clean_patch_1), np.array(clean_patch_2)# clean_name = self.clean_ids[clean_id].split("/")[-1].split('.')[0]clean_name = clean_id.split("/")[-1].split('.')[0]clean_patch_1, clean_patch_2 = random_augmentation(clean_patch_1, clean_patch_2)degrad_patch_1, degrad_patch_2 = self.D.degrade(clean_patch_1, clean_patch_2, de_id)clean_patch_1, clean_patch_2 = self.toTensor(clean_patch_1), self.toTensor(clean_patch_2)degrad_patch_1, degrad_patch_2 = self.toTensor(degrad_patch_1), self.toTensor(degrad_patch_2)self.de_temp = (self.de_temp + 1) % len(self.de_type)if self.de_temp == 0:random.shuffle(self.de_type)return [clean_name, de_id], degrad_patch_1, degrad_patch_2, clean_patch_1, clean_patch_2

可以看出TrainDataset返回的数据有:degrad_patch_1, degrad_patch_2, clean_patch_1, clean_patch_2。

3.1 clean_patch分析

通过以下代码可以看出 clean_patch_1, clean_patch_2是来自于同一个图片,然后基于crop_transform变化,变成了2个对象

            clean_img = crop_img(np.array(Image.open(clean_id).convert('RGB')), base=16)clean_patch_1, clean_patch_2 = self.crop_transform(clean_img), self.crop_transform(clean_img)# clean_name = self.clean_ids[clean_id].split("/")[-1].split('.')[0]clean_name = clean_id.split("/")[-1].split('.')[0]clean_patch_1, clean_patch_2 = random_augmentation(clean_patch_1, clean_patch_2)

crop_transform的定义如下,可见是随机进行crop

crop_transform = Compose([ToPILImage(),RandomCrop(args.patch_size),])

random_augmentation的实现代码如下,可以看到只是随机对图像进行翻转或旋转,其目的是尽可能使随机crop得到clean_patch_1, clean_patch_2差异更大,避免裁剪出高度相似的patch。

def random_augmentation(*args):out = []flag_aug = random.randint(1, 7)for data in args:out.append(data_augmentation(data, flag_aug).copy())return out
def data_augmentation(image, mode):if mode == 0:# originalout = image.numpy()elif mode == 1:# flip up and downout = np.flipud(image)elif mode == 2:# rotate counterwise 90 degreeout = np.rot90(image)elif mode == 3:# rotate 90 degree and flip up and downout = np.rot90(image)out = np.flipud(out)elif mode == 4:# rotate 180 degreeout = np.rot90(image, k=2)elif mode == 5:# rotate 180 degree and flipout = np.rot90(image, k=2)out = np.flipud(out)elif mode == 6:# rotate 270 degreeout = np.rot90(image, k=3)elif mode == 7:# rotate 270 degree and flipout = np.rot90(image, k=3)out = np.flipud(out)else:raise Exception('Invalid choice of image transformation')return out

3.2 degrad_patch分析

degrad_patch来自于clean_patch,可以看到是通过D.degrade进行转换的。

degrad_patch_1, degrad_patch_2 = self.D.degrade(clean_patch_1, clean_patch_2, de_id)

D.degrade相关的代码如下,可以看到只是对图像添加噪声。难怪AIRNet在图像去噪上效果最好。

class Degradation(object):def __init__(self, args):super(Degradation, self).__init__()self.args = argsself.toTensor = ToTensor()self.crop_transform = Compose([ToPILImage(),RandomCrop(args.patch_size),])def _add_gaussian_noise(self, clean_patch, sigma):# noise = torch.randn(*(clean_patch.shape))# clean_patch = self.toTensor(clean_patch)noise = np.random.randn(*clean_patch.shape)noisy_patch = np.clip(clean_patch + noise * sigma, 0, 255).astype(np.uint8)# noisy_patch = torch.clamp(clean_patch + noise * sigma, 0, 255).type(torch.int32)return noisy_patch, clean_patchdef _degrade_by_type(self, clean_patch, degrade_type):if degrade_type == 0:# denoise sigma=15degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=15)elif degrade_type == 1:# denoise sigma=25degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=25)elif degrade_type == 2:# denoise sigma=50degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=50)return degraded_patch, clean_patchdef degrade(self, clean_patch_1, clean_patch_2, degrade_type=None):if degrade_type == None:degrade_type = random.randint(0, 3)else:degrade_type = degrade_typedegrad_patch_1, _ = self._degrade_by_type(clean_patch_1, degrade_type)degrad_patch_2, _ = self._degrade_by_type(clean_patch_2, degrade_type)return degrad_patch_1, degrad_patch_2

4、推理演示

项目中默认包含了All.pth,要单独任务的模型可以到预训练模型下载地址: Google Drive and Baidu Netdisk (password: cr7d). 下载模型放到 ckpt/ 目录下

打开demo.py,将 subprocess.check_output(['mkdir', '-p', opt.output_path]) 替换为os.makedirs(opt.output_path,exist_ok=True),避免在window上报错,具体修改如下所示
在这里插入图片描述

demo.py默认从test\demo目录下读取图片进行测试,可见原始图像如下
在这里插入图片描述
代码运行后的输出结果默认保存在 output\demo目录下,可见对于去雨,去雾,去噪声效果都比较好。
在这里插入图片描述
模型推理时间如下所示,可以看到对一张320, 480的图片,要0.54s
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/25512.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JAVA-学习-2

一、类 1、类的定义 把相似的对象划分了一个类。 类指的就是一种模板&#xff0c;定义了一种特定类型的所有对象的属性和行为 在一个.java的问题件中&#xff0c;可以有多个class&#xff0c;但是智能有一个class是用public的class。被声明的public的class&#xff0c;必须和文…

Pytorch 实现目标检测一(Pytorch 23)

一 目标检测和边界框 在图像分类任务中&#xff0c;我们假设图像中只有一个主要物体对象&#xff0c;我们只关注如何识别其类别。然而&#xff0c;很多时候图像里有多个我们感兴趣的目标&#xff0c;我们不仅想知 道它们的类别&#xff0c;还想得到它们在图像中的具体位置。在…

【前端】响应式布局笔记——自适应布局

自适应布局 自适应布局是不同设备对应不同的html(局部自适应)&#xff0c;通过判断设备的类型或控制局部的变化。 1、获取设备是移动端还是pc端 // 获取设备的信息let userAgent navigator.userAgent.toLowerCase();// 使用正则表达式来判断类型let device /ipad|iphone|m…

数据并非都是正态分布:三种常见的统计分布及其应用

你有没有过这样的经历&#xff1f;使用一款减肥app&#xff0c;通过它的图表来监控自己的体重变化&#xff0c;并预测何时能达到理想体重。这款app预测我需要八年时间才能恢复到大学时的体重&#xff0c;这种不切实际的预测是因为应用使用了简单的线性模型来进行体重预测。这个…

服务部署:.NET项目使用Docker构建镜像与部署

前提条件 安装Docker&#xff1a;确保你的Linux系统上已经安装了Docker。如果没有&#xff0c;请参考官方文档进行安装。 步骤一&#xff1a;准备项目文件 将你的.NET项目从Windows系统复制到Linux系统。你可以使用Git、SCP等工具来完成这个操作。如何是使用virtualbox虚拟电…

国产操作系统上给virtualbox中win7虚拟机安装增强工具 _ 统信 _ 麒麟 _ 中科方德

原文链接&#xff1a;国产操作系统上给virtualbox中win7虚拟机安装增强工具 | 统信 | 麒麟 | 中科方德 Hello&#xff0c;大家好啊&#xff01;今天给大家带来一篇在国产操作系统上给win7虚拟机安装virtualbox增强工具的文章。VirtualBox增强工具&#xff08;Guest Additions&a…

Liunx环境下redis主从集群搭建(保姆级教学)02

Redis在linux下的主从集群配置 本次演示使用三个节点实例一个主节点&#xff0c;两个从节点&#xff1a;7000端口&#xff08;主&#xff09;&#xff0c;7001端口&#xff08;从&#xff09;&#xff0c;7002端口&#xff08;从&#xff09;&#xff1b; 主节点负责写数据&a…

Rust-02-变量与可变性

在Rust中&#xff0c;变量和可变性是两个重要的概念。 变量&#xff1a;变量是用于存储数据的标识符。在Rust中&#xff0c;变量需要声明其类型&#xff0c;例如&#xff1a; let x: i32 5; // 声明一个名为x的变量&#xff0c;类型为i32&#xff08;整数&#xff09;&#…

安装MySQL Sample Database

本文安装的示例数据库为官方的Employees Sample Database。 操作过程参考其安装部分。 在安装前&#xff0c;MySQL已安装完成&#xff0c;环境为Linux。 克隆github项目&#xff1a; $ git clone https://github.com/datacharmer/test_db.git Cloning into test_db... remo…

【西瓜书】6.支持向量机

目录&#xff1a; 1.分类问题SVM 1.1.线性可分 1.2.非线性可分——核函数 2.回归问题SVR 3.软间隔——松弛变量 3.1.分类问题&#xff1a;0/1损失函数、hinge损失、指数损失、对率损失 3.2.回归问题&#xff1a;不敏感损失函数、平方 4.正则化

计算机组成原理之指令格式

1、指令的定义 零地址指令&#xff1a; 1、不需要操作数&#xff0c;如空操作、停机、关中断等指令。 2、堆栈计算机&#xff0c;两个操作数隐藏在栈顶和此栈顶&#xff0c;取两个操作数&#xff0c;并运算的结果后重新压回栈顶。 一地址指令&#xff1a; 二、三地址指令 四…

配置免密登录秘钥报错

移除秘钥&#xff0c;执行 ssh-keygen -R cdh2即可 参考&#xff1a;ECDSA主机密钥已更改,您已请求严格检查。 - 简书

Python爬虫入门与登录验证自动化思路

1、pytyon爬虫 1.1、爬虫简介 Python爬虫是使用Python编写的程序&#xff0c;可以自动访问网页并提取其中的信息。爬虫可以模拟浏览器的行为&#xff0c;自动点击链接、填写表单、进行登录等操作&#xff0c;从而获取网页中的数据。 使用Python编写爬虫的好处是&#xff0c;…

【数据结构】十二、八种常用的排序算法讲解及代码分享

目录 一、插入排序 1)算法思想 2&#xff09;代码 二、希尔排序 1&#xff09;算法思想 2&#xff09;代码 三、选择排序 1&#xff09;算法思想 2&#xff09;代码 四、堆排序 1&#xff09;什么是最大堆 2&#xff09;如何创建最大堆 3&#xff09;算法思想 4&a…

C# Excel操作类EPPlus

摘要 EPPlus 是一个流行的用于操作 Excel 文件的开源库&#xff0c;适用于 C# 和 .NET 环境。它提供了丰富的功能&#xff0c;能够轻松地读取、写入和格式化 Excel 文件&#xff0c;使得在 C# 中进行 Excel 文件处理变得更加简单和高效。EPPlus 不需要安装 Microsoft Office 或…

知乎网站只让知乎用户看文章,普通人看不了

知乎默认不显示全部文章&#xff0c;需要点击展开阅读全文 然而点击后却要登录&#xff0c;这意味着普通人看不了博主写的文章&#xff0c;只有成为知乎用户才有权力查看文章。我想这不是知乎创作者希望的情况&#xff0c;他们写文章肯定是希望所有人都能看到。 这个网站篡改…

应用商店如何检测在架应用内容是否违规?

&#x1f3c6;本文收录于「Bug调优」专栏&#xff0c;主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案&#xff0c;希望能够助你一臂之力&#xff0c;帮你早日登顶实现财富自由&#x1f680;&#xff1b;同时&#xff0c;欢迎大家关注&&收藏&&…

线性代数|机器学习-P11方程Ax=b求解研究

文章目录 1. 变量数和约束条件数大小分类2. 最小二乘法和Gram-schmidt变换2.1 Gram-schmidt变换2.2 最小二乘法2.2.1 损失函数-Lasso 和regression2.2.2 损失函数-Lasso2.2.3 损失函数-regression2.2.4 Regression岭回归-矩阵验证2.2.5 Regression岭回归-导数验证 3. 迭代和随机…

【数据结构】队列——循环队列(详解)

目录 0 循环队列 1 特定条件下循环队列队/空队满判断条件 1.1 队列为空的条件 1.2 队列为满的条件 2 循环队列的实现 3 示例 4 注意事项 0 循环队列 循环队列&#xff08;Circular Queue&#xff09;是队列的一种实现方式&#xff0c;它通过将队列存储空间的最后一…

MySQL之多表查询—行子查询

一、引言 上篇博客学习了列子查询。 接下来学习子查询中的第三种——行子查询。 行子查询 1、概念 子查询返回的结果是一行&#xff08;当然可以是多列)&#xff0c;这种子查询称为行子查询。 2、常用的操作符 、 <> (不等于) 、IN 、NOT IN 接下来通过一个需求去演示和…