超分辨率(2)--基于EDSR网络实现图像超分辨率重建

目录

一.项目介绍

二.项目流程详解

2.1.构建网络模型

2.2.数据集处理

2.3.训练模块

2.4.测试模块

三.测试网络


一.项目介绍

EDSR全称Enhanced Deep Residual Networks,是SRResnet的升级版,其对网络结构进行了优化(去除了BN层),省下来的空间可以用于提升模型的size来增强表现力。

为什么要去除BN层:

Batch Norm是深度学习中非常重要的技术,不仅可以使训练更深的网络变容易,加速收敛,还有一定正则化的效果,可以防止模型过拟合。

但对于图像超分辨率来说,网络输出的图像在色彩、对比度、亮度上要求和输入一致,改变的仅仅是分辨率和一些细节,而Batch Norm,对图像来说类似于一种对比度的拉伸,任何图像经过Batch Norm后,其色彩的分布都会被归一化,也就是说,它破坏了图像原本的对比度信息,所以Batch Norm的加入反而影响了网络输出的质量。

网络结构及对比:

移除BN层后,模型更加轻量,BN层所消耗的存储空间等同于上一层CNN层所消耗的,作者指出相比于SRResNet,EDSR去掉BN层之后节约了40%的存储资源。

同时在BN腾出来的空间下插入更多的类似于残差块等CNN-based子网络来增加模型的表现力。

论文地址:

[1707.02921] Enhanced Deep Residual Networks for Single Image Super-Resolution (arxiv.org)icon-default.png?t=N7T8https://arxiv.org/abs/1707.02921源码地址:

developer0hye/EDAR: PyTorch implementation of Deep Convolution Networks based on EDSR for Compression(Jpeg) Artifacts Reduction (github.com)icon-default.png?t=N7T8https://github.com/developer0hye/EDAR

二.项目流程详解

2.1.构建网络模型

def default_conv(in_channels, out_channels, kernel_size, bias=True):return nn.Conv2d(in_channels, out_channels, kernel_size,padding=(kernel_size//2), bias=bias)class MeanShift(nn.Conv2d):def __init__(self, rgb_mean, rgb_std, sign=-1):super(MeanShift, self).__init__(3, 3, kernel_size=1)std = torch.Tensor(rgb_std)self.weight.data = torch.eye(3).view(3, 3, 1, 1)self.weight.data.div_(std.view(3, 1, 1, 1))self.bias.data = sign * torch.Tensor(rgb_mean)self.bias.data.div_(std)self.requires_grad = Falseclass ResBlock(nn.Module):def __init__(self, conv, n_feat, kernel_size,bias=True, act=nn.ReLU(True)):super(ResBlock, self).__init__()m = []for i in range(2):m.append(conv(n_feat, n_feat, kernel_size, bias=bias))if i == 0: m.append(act)# m是设置好的conv层# 设置网络内部层次结构为bodyself.body = nn.Sequential(*m)def forward(self, x):# 获取当前的结果res = self.body(x)# 当前得到的网络和最初的网络融合res += xreturn res

class EDAR(nn.Module):def __init__(self, conv=common.default_conv):super(EDAR, self).__init__()# 参数设置n_resblock = 8  # resnet长度n_feats = 64kernel_size = 3  # 卷积核大小#DIV 2K meanrgb_mean = (0.4488, 0.4371, 0.4040)rgb_std = (1.0, 1.0, 1.0)self.sub_mean = common.MeanShift(rgb_mean, rgb_std)# define head module# 经过卷积,特征图数由3->n_featsm_head = [conv(3, n_feats, kernel_size)]# define body module# Residual Block设置m_body = [common.ResBlock(conv, n_feats, kernel_size) for _ in range(n_resblock)]m_body.append(conv(n_feats, n_feats, kernel_size))# define tail module# 经过卷积,特征图数由n_feats->3m_tail = [conv(n_feats, 3, kernel_size)]self.add_mean = common.MeanShift(rgb_mean, rgb_std, 1)# 设置网络的三个层次self.head = nn.Sequential(*m_head)self.body = nn.Sequential(*m_body)self.tail = nn.Sequential(*m_tail)

前向传播过程:

    def forward(self, x):x = self.sub_mean(x)x = self.head(x)res = self.body(x)res += xx = self.tail(res)x = self.add_mean(x)# 将输入input张量每个元素的范围限制到区间 [min,max],返回结果到一个新张量。# 及输出一个新张量值x,并限制他的值在0~1之间return torch.clamp(x,0.0,1.0)

2.2.数据集处理

import os
import io
import random
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = Trueclass Dataset(object):def __init__(self, images_dir, patch_size=48, jpeg_quality=40, transforms=None):self.images = os.walk(images_dir).__next__()[2]self.images_path = []for img_file in self.images:if img_file.endswith((".ppm")):try:#print(os.path.join(images_dir, img_file))label = Image.open(os.path.join(images_dir, img_file))self.images_path.append(os.path.join(images_dir, img_file))except:print(f"Image {os.path.join(images_dir, img_file)} didn't get loaded")self.patch_size = patch_sizeself.jpeg_quality = jpeg_qualityself.transforms = transformsself.random_rotate = [0, 90, 180, 270]def __getitem__(self, idx):label = Image.open(self.images_path[idx]).convert('RGB')label = label.rotate(self.random_rotate[random.randrange(0,4)])# randomly crop patch from training setcrop_x = random.randint(0, label.width - self.patch_size)crop_y = random.randint(0, label.height - self.patch_size)# 使用crop函数对图片进行裁剪label = label.crop((crop_x, crop_y, crop_x + self.patch_size, crop_y + self.patch_size))# additive jpeg noisebuffer = io.BytesIO()label.save(buffer, format='jpeg', quality=random.randrange(self.jpeg_quality+1))input = Image.open(buffer).convert('RGB')if self.transforms is not None:input = self.transforms(input)label = self.transforms(label)#print("Image transformed")return input, labeldef __len__(self):return len(self.images_path)

2.3.训练模块

import argparse
import osfrom dataset import Dataset
from edar import EDARimport torch
from torch import nn
from torch.utils.data.dataloader import DataLoader
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torchvision import transforms
from torchvision.models.vgg import vgg16from utils import AverageMeter
from tqdm import tqdmif __name__ == '__main__':'''It enables benchmark mode in cudnn.benchmark mode is good whenever your input sizes for your network do not vary. This way, cudnn will look for the optimal set of algorithms for that particular configuration (which takes some time). This usually leads to faster runtime.But if your input sizes changes at each iteration, then cudnn will benchmark every time a new size appears, possibly leading to worse runtime performances.'''cudnn.benchmark = Truedevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 参数设置parser = argparse.ArgumentParser()# required为true的参数则是必须要设置的参数parser.add_argument('--images_dir', type=str, required=True)parser.add_argument('--outputs_dir', type=str, required=True)parser.add_argument('--jpeg_quality', type=int, default=40)parser.add_argument('--patch_size', type=int, default=48)parser.add_argument('--batch_size', type=int, default=16)parser.add_argument('--num_epochs', type=int, default=400)parser.add_argument('--lr', type=float, default=1e-4)parser.add_argument('--threads', type=int, default=1)parser.add_argument('--seed', type=int, default=123)parser.add_argument('--resume', default='', type=str, metavar='PATH',help='path to latest checkpoint (default: none)')opt = parser.parse_args()# 如果输出文件夹不存在,则自动创建一个文件夹if not os.path.exists(opt.outputs_dir):os.makedirs(opt.outputs_dir)torch.manual_seed(opt.seed)transforms_train = transforms.Compose([transforms.ToTensor()])# 模型设置model = EDAR().to(device)print("Model loaded")if opt.resume:if os.path.isfile(opt.resume):state_dict = model.state_dict()for n, p in torch.load(opt.resume, map_location=lambda storage, loc: storage).items():if n in state_dict.keys():state_dict[n].copy_(p)else:raise KeyError(n)# 损失函数设置criterion = nn.L1Loss()# 优化器设置optimizer = optim.Adam(model.parameters(), lr=opt.lr)print("Data processing started")# 数据集设置dataset = Dataset(opt.images_dir, opt.patch_size, opt.jpeg_quality,transforms=transforms_train)dataloader = DataLoader(dataset=dataset,batch_size=opt.batch_size,shuffle=True,num_workers=opt.threads,pin_memory=True,drop_last=True)print("Data loading completed")#vgg = vgg16(pretrained=True).cuda()#loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
#     for param in loss_network.parameters():
#         param.requires_grad = False# 开始训练for epoch in range(opt.num_epochs):epoch_losses = AverageMeter()print("Length of the dataset is", len(dataset))with tqdm(total=(len(dataset) - len(dataset) % opt.batch_size)) as _tqdm:_tqdm.set_description('epoch: {}/{}'.format(epoch + 1, opt.num_epochs))# 按照dataloader的格式取出datafor data in dataloader:inputs, labels = datainputs = inputs.to(device)labels = labels.to(device)#print(inputs.size(), labels.size())outs = model(inputs)# 损失值计算,参数是预测值和实际值loss = criterion(outs, labels)#perception_loss = criterion(loss_network(outs), loss_network(labels))#loss = loss + perception_loss*0.06epoch_losses.update(loss.item(), len(inputs))# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 更新参数optimizer.step()_tqdm.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))_tqdm.update(len(inputs))torch.save(model.state_dict(), os.path.join(opt.outputs_dir, '{}_epoch_{}.pth'.format("EDAR_", epoch)))

2.4.测试模块

import argparse
import os
import io
import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms
import PIL.Image as pil_image
import globfrom edar import EDARcudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")if __name__ == '__main__':# 参数设置parser = argparse.ArgumentParser()parser.add_argument('--weights_path', type=str, required=True)parser.add_argument('--image_path', type=str, required=True)parser.add_argument('--outputs_dir', type=str, required=True)parser.add_argument('--jpeg_quality', type=int, default=40)parser.add_argument('--input_dir', type=str, required=False)opt, unknown = parser.parse_known_args()model = EDAR()state_dict = model.state_dict()# 参数获取for n, p in torch.load(opt.weights_path, map_location=lambda storage, loc: storage).items():if n in state_dict.keys():state_dict[n].copy_(p)else:raise KeyError(n)model = model.to(device)print(device)model.eval()if opt.input_dir:filenames = [os.path.join(opt.input_dir, file) for file in os.listdir(opt.input_dir) if file.endswith(("ppm", "jpeg", "png", "jpg"))]print(filenames)else:filenames = opt.image_pathif not os.path.exists(opt.outputs_dir):os.makedirs(opt.outputs_dir)# 处理单个测试图片时使用:filename = filenamesprint("file is", filename)input = pil_image.open(filename).convert('RGB')print("Input size:", input.size)print("file is", filename)input = pil_image.open(filename).convert('RGB')print("Input size:", input.size)#buffer = io.BytesIO()#input.save(buffer, format='jpeg', quality=opt.jpeg_quality)#input = pil_image.open(buffer)#input.save(os.path.join(opt.outputs_dir, '{}_jpeg_q{}.png'.format(filename, opt.jpeg_quality)))input = transforms.ToTensor()(input).unsqueeze(0).to(device)output_path = os.path.join(opt.outputs_dir, '{}-{}.jpeg'.format(filename.split("/")[-1].split(".")[0], "edar"))if not os.path.exists(output_path):with torch.no_grad():pred = model(input)[-1]pred = pred.mul_(255.0).clamp_(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()output = pil_image.fromarray(pred, mode='RGB')print("Output size", output.size)print("Output dir is", opt.outputs_dir)output.save(output_path)#print(os.path.join(opt.outputs_dir, '{}_{}.png'.format(filename, "EDAR")))#print("Output saved")'''处理多个测试图片时使用:for filename in filenames:print("file is", filename)input = pil_image.open(filename).convert('RGB')print("Input size:", input.size)# buffer = io.BytesIO()# input.save(buffer, format='jpeg', quality=opt.jpeg_quality)# input = pil_image.open(buffer)# input.save(os.path.join(opt.outputs_dir, '{}_jpeg_q{}.png'.format(filename, opt.jpeg_quality)))input = transforms.ToTensor()(input).unsqueeze(0).to(device)output_path = os.path.join(opt.outputs_dir, '{}-{}.jpeg'.format(filename.split("/")[-1].split(".")[0], "edar"))if not os.path.exists(output_path):with torch.no_grad():pred = model(input)[-1]pred = pred.mul_(255.0).clamp_(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()output = pil_image.fromarray(pred, mode='RGB')print("Output size", output.size)print("Output dir is", opt.outputs_dir)output.save(output_path)# print(os.path.join(opt.outputs_dir, '{}_{}.png'.format(filename, "EDAR")))# print("Output saved")'''

三.测试网络

参数设置:

输入图片:

输出图片:

输入图片:

输出图片:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/746353.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

避免阻塞主线程 —— Web Worker 示例项目

前期回顾 迄今为止易用 —— 的 “盲水印“ 实现方案-CSDN博客https://blog.csdn.net/m0_57904695/article/details/136720192?spm1001.2014.3001.5501 目录 CSDN 彩色之外 📝 前言 🚩 技术栈 🛠️ 功能 🤖 如何运行 ♻️ …

《JAVA与模式》之工厂方法模式

系列文章目录 文章目录 系列文章目录前言一、工厂方法模式二、工厂方法模式的活动序列图三、工厂方法模式和简单工厂模式前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码…

【个人记录】CentOS7安装MySQL 5.7和libmysqlclient.so.20

记录 之前使用MariaDB 发现使用的libmysqlclient.so是18版本的,一些程序需要20版本的库,查了一下需要安装5.7以上版本的才有libmysqlclient.so.20,这里简单记录一下怎么安装。 安装MySQL 5.7 Yum源 yum install -y https://repo.mysql.com…

如何用SMU数字源表测试apd管的暗电流

01 APD工作原理 APD雪崩光电二极管的工作原理是基于光电效应和雪崩效应,当光子被吸收时,会产生电子空穴对,空穴向P区移动,电子向N区移动,由于电场的作用,电子与空穴相遇时会产生二次电子,形成雪…

串行通信——IIC总结

一.什么是IIC? IIC(Inter-Integrated Circuit)也称I2C,中文叫集成电路总线。是一个多主从的串行总线,由飞利浦公司发明的通讯总线,属于半双工同步传输类总线,仅由两条线就能完成多机通讯&#…

【解读】区块链和分布式记账技术标准体系建设指南

大家好,这里是苏泽。一个从业Java后端的区块链技术爱好者。 今天带大家来解读这份三部门印发的行业建设指南《区块链和分布式记账技术标准体系建设指南》 原文件可查看P020240112840724196854.pdf (www.gov.cn) 以下是个人解读,如有纰漏请指正&#xff…

【系统架构师】-第16章-嵌入式系统架构设计理论与实践

1、嵌入式系统发展 第一阶段:单片微型计算机 (SCM) 阶段,即单片机时代,五操作系统 第二阶段:微控制器 (MUC) 阶段,有简单操作系统 第三阶段:片上系统 (SoC),兼容各种微处理器 第四阶段&…

软件测试 —— 测试用例设计报告

写出测试网站的测试用例,测试网站具体内容可看团购网站系统需求说明书1.2.doc 一、流程1:注册→登录 图1:注册->登录流程图 1、 使用场景设计法设计测试用例 1) 找出基本流和备选流 基本流注册用户-成功登录系统备选流1注册…

Jenkins cron定时构建触发器

from: https://www.jenkins.io/doc/book/pipeline/syntax/#cron-syntax 以下内容为根据Jenkins官方文档cron表达式部分翻译过来,使用机翻加个人理解补充内容,包括举例。 目录 介绍举例:设置方法方法一:方法二&#xf…

3.2_1 虚拟内存的基本概念

3.2_1 虚拟内存的基本概念 虚拟存储技术也是存储空间扩充的一种技术,它比交换、覆盖技术要更先进一些。 (一)传统存储管理方式的特征、缺点 对于这种传统的存储管理方案,很多暂时用不到的数据也会长期占用内存,导致内存…

【数据结构和算法初阶(C语言)】栈的概念和实现(后进先出---后来者居上的神奇线性结构带来的惊喜体验)

目录 1.栈 1.1栈的概念及结构 2.栈的实现 3.栈结构对数据的处理方式 3.1对栈进行初始化 3.2 从栈顶添加元素 3.3 打印栈元素 3.4移除栈顶元素 3.5获取栈顶元素 3.6获取栈中的有效个数 3.7 判断链表是否为空 3.9 销毁栈空间 4.结语及整个源码 1.栈 1.1栈的概念及结构 栈&am…

高压辊磨机(辊压机)在矿物加工领域应用广泛 目前本土企业处于向高端转型阶段

高压辊磨机(辊压机)在矿物加工领域应用广泛 目前本土企业处于向高端转型阶段 高压辊磨机又称为辊压机、挤压磨,是基于料层粉碎原理设计的一种干式辊磨设备。高压辊磨机结构形式多样,但原理基本相似,主要由机架、高压工…

浅谈C++绑定器bind1st、bind2nd和函数对象function

今天我们先来谈谈C 标准库里面的绑定器bind1st,bind2nd 和函数对象function C 绑定器和函数对象 一、绑定器二、函数对象 一、绑定器 虽然在C11标准中这两个绑定函数已经被弃用,但仍然值得我们深入思考其底层原理。从字面上理解,“绑定” 这…

Explain

Explain EXPLAIN是MySQL提供的一种用于分析SQL查询执行计划的工具,通过它我们可以深入了解数据库如何执行一条SQL语句,以及优化器在选择索引、访问表和排序数据等方面的决策。 我整理了一份思维导图方便更好查看各个参数的意义,红色表示比较…

RabbitMq踩坑记录

1、连接报错:Broker not available; cannot force queue declarations during start: java.io.IOException 2.1、原因:端口不对 2.2、解决方案: 检查你的连接配置,很可能是你的yml里面的端口配置的是15672,更改为5672即…

css超出部分显示省略号

目录 前言 一、CSS单行实现 二、CSS多行实现(CSS3出的,兼容性需要注意) 三、微信小程序超过2行出现省略号实现 四、JavaScript脚本实现 前言 CSS文本溢出就显示省略号,就是在样式中指定了盒子的宽度与高度,有可能出现某些内…

LLM - 大语言模型(LLM) 概述

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://blog.csdn.net/caroline_wendy/article/details/136617643 大语言模型(LLM, Large Language Model)的发展和应用是一个非常广泛的领域,涉及从早期的统计模型到现代基于深度学…

【AI+CAD】(二)LLM和VLM生成结构化数据结构(PPT/CAD/DXF)

当前LLM和VLM在PPT生成任务上已经小有成效,如ChatPPT。 @TOC 1. PPT-LLM LLM根据用户的instruction生成规范的绘制ppt的API语句:即使是最强的GPT-4 + CoT也只能达到20-30%的内容准确度。 LLM输入:User_instruction(当前+过去)、PPT_content、PPT_reader_API。其中 PPT_rea…

面试经典150题——随机链表的复制

​前两天断更了两天有点事情🤗 1. 题目描述 2. 题目分析与解析 2.1 思路一 开始还是没什么思路,没思路那就先把题目解决不管方法的好坏。如果不考虑复杂度,该怎么解决? 可以有这样的一种思路: 首先复制链表的所有节…

【python绘图】turle 绘图基本案例

文章目录 0. 基础知识1. 蟒蛇绘制2. 正方形绘制3. 六边形绘制4. 叠边形绘制5. 风轮绘制 0. 基础知识 资料来自中国mooc北京理工大学python课程 1. 蟒蛇绘制 import turtle turtle.setup(650, 350, 200, 200) turtle.penup() turtle.fd(-250) turtle.pendown() turtle.pen…