深度学习之超分辨率算法——FRCNN

– 对之前SRCNN算法的改进

    1. 输出层采用转置卷积层放大尺寸,这样可以直接将低分辨率图片输入模型中,解决了输入尺度问题。
    2. 改变特征维数,使用更小的卷积核和使用更多的映射层。卷积核更小,加入了更多的激活层。
    3. 共享其中的映射层,如果需要训练不同上采样倍率的模型,只需要修改最后的反卷积层大小,就可以训练出不同尺寸的图片。
  • 模型实现
  • 在这里插入图片描述
import math
from torch import nnclass FSRCNN(nn.Module):def __init__(self, scale_factor, num_channels=1, d=56, s=12, m=4):super(FSRCNN, self).__init__()self.first_part = nn.Sequential(nn.Conv2d(num_channels, d, kernel_size=5, padding=5//2),nn.PReLU(d))# 添加入多个激活层和小卷积核self.mid_part = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU(s)]for _ in range(m):self.mid_part.extend([nn.Conv2d(s, s, kernel_size=3, padding=3//2), nn.PReLU(s)])self.mid_part.extend([nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d)])self.mid_part = nn.Sequential(*self.mid_part)# 最后输出self.last_part = nn.ConvTranspose2d(d, num_channels, kernel_size=9, stride=scale_factor, padding=9//2,output_padding=scale_factor-1)self._initialize_weights()def _initialize_weights(self):# 初始化for m in self.first_part:if isinstance(m, nn.Conv2d):nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))nn.init.zeros_(m.bias.data)for m in self.mid_part:if isinstance(m, nn.Conv2d):nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))nn.init.zeros_(m.bias.data)nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001)nn.init.zeros_(self.last_part.bias.data)def forward(self, x):x = self.first_part(x)x = self.mid_part(x)x = self.last_part(x)return x

以上代码中,如起初所说,将SRCNN中给的输出修改为转置卷积,并且在中间添加了多个11卷积核和多个线性激活层。且应用了权重初始化,解决协变量偏移问题。
备注:1
1卷积核虽然在通道的像素层面上,针对一个像素进行卷积,貌似没有什么作用,但是卷积神经网络的特性,我们在利用多个卷积核对特征图进行扫描时,单个卷积核扫描后的为sum©,那么就是尽管在像素层面上无用,但是在通道层面上进行了融合,并且进一步加深了层数,使网络层数增加,网络能力增强。

  • 上代码
  • train.py

训练脚本

import argparse
import os
import copyimport torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdmfrom models import FSRCNN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnrif __name__ == '__main__':parser = argparse.ArgumentParser()# 训练文件parser.add_argument('--train-file', type=str,help="the dir of train data",default="./Train/91-image_x4.h5")# 测试集文件parser.add_argument('--eval-file', type=str,help="thr dir of test data ",default="./Test/Set5_x4.h5")# 输出的文件夹parser.add_argument('--outputs-dir',help="the output dir", type=str,default="./outputs")parser.add_argument('--weights-file', type=str)parser.add_argument('--scale', type=int, default=2)parser.add_argument('--lr', type=float, default=1e-3)parser.add_argument('--batch-size', type=int, default=16)parser.add_argument('--num-epochs', type=int, default=20)parser.add_argument('--num-workers', type=int, default=8)parser.add_argument('--seed', type=int, default=123)args = parser.parse_args()args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))if not os.path.exists(args.outputs_dir):os.makedirs(args.outputs_dir)cudnn.benchmark = Truedevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')torch.manual_seed(args.seed)model = FSRCNN(scale_factor=args.scale).to(device)criterion = nn.MSELoss()optimizer = optim.Adam([{'params': model.first_part.parameters()},{'params': model.mid_part.parameters()},{'params': model.last_part.parameters(), 'lr': args.lr * 0.1}], lr=args.lr)train_dataset = TrainDataset(args.train_file)train_dataloader = DataLoader(dataset=train_dataset,batch_size=args.batch_size,shuffle=True,num_workers=args.num_workers,pin_memory=True)eval_dataset = EvalDataset(args.eval_file)eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)best_weights = copy.deepcopy(model.state_dict())best_epoch = 0best_psnr = 0.0for epoch in range(args.num_epochs):model.train()epoch_losses = AverageMeter()with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size), ncols=80) as t:t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))for data in train_dataloader:inputs, labels = datainputs = inputs.to(device)labels = labels.to(device)preds = model(inputs)loss = criterion(preds, labels)epoch_losses.update(loss.item(), len(inputs))optimizer.zero_grad()loss.backward()optimizer.step()t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))t.update(len(inputs))torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))model.eval()epoch_psnr = AverageMeter()for data in eval_dataloader:inputs, labels = datainputs = inputs.to(device)labels = labels.to(device)with torch.no_grad():preds = model(inputs).clamp(0.0, 1.0)epoch_psnr.update(calc_psnr(preds, labels), len(inputs))print('eval psnr: {:.2f}'.format(epoch_psnr.avg))if epoch_psnr.avg > best_psnr:best_epoch = epochbest_psnr = epoch_psnr.avgbest_weights = copy.deepcopy(model.state_dict())print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))

test.py 测试脚本

import argparseimport torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_imagefrom models import FSRCNN
from utils import convert_ycbcr_to_rgb, preprocess, calc_psnrif __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--weights-file', type=str, required=True)parser.add_argument('--image-file', type=str, required=True)parser.add_argument('--scale', type=int, default=3)args = parser.parse_args()cudnn.benchmark = Truedevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')model = FSRCNN(scale_factor=args.scale).to(device)state_dict = model.state_dict()for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():if n in state_dict.keys():state_dict[n].copy_(p)else:raise KeyError(n)model.eval()image = pil_image.open(args.image_file).convert('RGB')image_width = (image.width // args.scale) * args.scaleimage_height = (image.height // args.scale) * args.scalehr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)lr = hr.resize((hr.width // args.scale, hr.height // args.scale), resample=pil_image.BICUBIC)bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)bicubic.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))lr, _ = preprocess(lr, device)hr, _ = preprocess(hr, device)_, ycbcr = preprocess(bicubic, device)with torch.no_grad():preds = model(lr).clamp(0.0, 1.0)psnr = calc_psnr(hr, preds)print('PSNR: {:.2f}'.format(psnr))preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)output = pil_image.fromarray(output)# 保存图片output.save(args.image_file.replace('.', '_fsrcnn_x{}.'.format(args.scale)))

datasets.py

数据集的读取

import h5py
import numpy as np
from torch.utils.data import Datasetclass TrainDataset(Dataset):def __init__(self, h5_file):super(TrainDataset, self).__init__()self.h5_file = h5_filedef __getitem__(self, idx):with h5py.File(self.h5_file, 'r') as f:return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)def __len__(self):with h5py.File(self.h5_file, 'r') as f:return len(f['lr'])class EvalDataset(Dataset):def __init__(self, h5_file):super(EvalDataset, self).__init__()self.h5_file = h5_filedef __getitem__(self, idx):with h5py.File(self.h5_file, 'r') as f:return np.expand_dims(f['lr'][str(idx)][:, :] / 255., 0), np.expand_dims(f['hr'][str(idx)][:, :] / 255., 0)def __len__(self):with h5py.File(self.h5_file, 'r') as f:return len(f['lr'])

工具文件utils.py

  • 主要用来测试psnr指数,图片的格式转换(悄悄说一句,opencv有直接实现~~~)
import torch
import numpy as npdef calc_patch_size(func):def wrapper(args):if args.scale == 2:args.patch_size = 10elif args.scale == 3:args.patch_size = 7elif args.scale == 4:args.patch_size = 6else:raise Exception('Scale Error', args.scale)return func(args)return wrapperdef convert_rgb_to_y(img, dim_order='hwc'):if dim_order == 'hwc':return 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.else:return 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.def convert_rgb_to_ycbcr(img, dim_order='hwc'):if dim_order == 'hwc':y = 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.cb = 128. + (-37.945 * img[..., 0] - 74.494 * img[..., 1] + 112.439 * img[..., 2]) / 256.cr = 128. + (112.439 * img[..., 0] - 94.154 * img[..., 1] - 18.285 * img[..., 2]) / 256.else:y = 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.cb = 128. + (-37.945 * img[0] - 74.494 * img[1] + 112.439 * img[2]) / 256.cr = 128. + (112.439 * img[0] - 94.154 * img[1] - 18.285 * img[2]) / 256.return np.array([y, cb, cr]).transpose([1, 2, 0])def convert_ycbcr_to_rgb(img, dim_order='hwc'):if dim_order == 'hwc':r = 298.082 * img[..., 0] / 256. + 408.583 * img[..., 2] / 256. - 222.921g = 298.082 * img[..., 0] / 256. - 100.291 * img[..., 1] / 256. - 208.120 * img[..., 2] / 256. + 135.576b = 298.082 * img[..., 0] / 256. + 516.412 * img[..., 1] / 256. - 276.836else:r = 298.082 * img[0] / 256. + 408.583 * img[2] / 256. - 222.921g = 298.082 * img[0] / 256. - 100.291 * img[1] / 256. - 208.120 * img[2] / 256. + 135.576b = 298.082 * img[0] / 256. + 516.412 * img[1] / 256. - 276.836return np.array([r, g, b]).transpose([1, 2, 0])def preprocess(img, device):img = np.array(img).astype(np.float32)ycbcr = convert_rgb_to_ycbcr(img)x = ycbcr[..., 0]x /= 255.x = torch.from_numpy(x).to(device)x = x.unsqueeze(0).unsqueeze(0)return x, ycbcrdef calc_psnr(img1, img2):return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))class AverageMeter(object):def __init__(self):self.reset()def reset(self):self.val = 0self.avg = 0self.sum = 0self.count = 0def update(self, val, n=1):self.val = valself.sum += val * nself.count += nself.avg = self.sum / self.count

先跑他个几十轮~
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/64586.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

小程序UI自动化测试实践:Minium+PageObject !

小程序架构上分为渲染层和逻辑层,尽管各平台的运行环境十分相似,但是还是有些许的区别(如下图),比如说JavaScript 语法和 API 支持不一致,WXSS 渲染表现也有不同,所以不论是手工测试&#xff0c…

堆的深度剖析及使用

目录 1.堆的创建1.1初始化1.2销毁 2.堆的使用2.1数据插入2.2堆顶元素2.3数据删除 3.堆的难点3.1向上调整3.1.1视频分析向上调整3.1.2 代码分析 3.2向下调整3.2.1视频分析向下调整3.2.2代码分析 1.堆的创建 堆的物理储存方式其实是一个数组,而逻辑储存方式其实是一个…

Hu矩原理 | cv2中基于Hu矩计算图像轮廓相似度差异的函数cv2.matchShapes【小白记笔记】

Hu 矩(Hu Moments) 是一种用于描述轮廓形状的 不变特征。它基于图像的矩提取,经过数学变换得到 7 个不变矩,这些不变矩在图像 平移、旋转和缩放等几何变换下保持不变,适合用来衡量轮廓或形状的相似度差异。 1、图像矩…

计算无人机俯拍图像的地面采样距离(GSD)矩阵

引言 在无人机遥感、测绘和精细农业等领域,地面采样距离(Ground Sampling Distance,简称 GSD)是一个非常重要的指标。GSD 是指图像中每个像素在地面上实际代表的物理距离,通常以米或厘米为单位。GSD 决定了图像的空间…

浅谈怎样系统的准备前端面试

前言 创业梦碎,回归现实,7 月底毅然裸辞,苦战两个月,拿到了美团和字节跳动的 offer,这算是从业以来第一次真正意义的面试,遇到蛮多问题,比如一开始具体的面试过程我都不懂,基本一直是…

2009 ~ 2019 年 408【数据结构】大题解析

2009 年 讲解视频推荐:【BOK408真题讲解-2009年(催更就退网版)】 1. 图的应用(10’) 带权图(权值非负, 表示边连接的两顶点间的距离)的最短路径问题是找出从初始顶点到目标顶点之间…

时空AI赋能低空智能科技创新

随着人工智能技术的不断进步,时空人工智能(Spatio-Temporal AI,简称时空AI)正在逐渐成为推动低空经济发展的新引擎。时空AI结合了地理空间智能、城市空间智能和时空大数据智能,为低空智能科技创新提供了强大的数据支持…

Python读取Excel批量写入到PPT生成词卡

一、问题的提出 有网友想把Excel表中的三列数据,分别是:单词、音标和释义分别写入到PPT当中,每一张PPT写一个单词的内容。这种批量操作是python的强项,尤其是在办公领域,它能较好地解放双手,读取Excel表后…

Proteus(8.15)仿真下载安装过程(附详细安装过程图)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 目录 前言 一、Proteus是什么? 二、下载链接 三、下安装步骤 1.解压,有键管理员运行 2.点击Next,进行下一步 3.勾选I accept…&#…

动态导出word文件支持转pdf

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、功能说明二、使用步骤1.controller2.工具类 DocumentUtil 导出样式 前言 提示:这里可以添加本文要记录的大概内容: 例如&#xff…

紧固件设计之——开槽六角头防脱出杆螺栓仿真APP

按照产品形态分类,紧固件通常包括以下12类:螺栓、螺柱、螺钉、螺母、自攻螺钉、木螺钉、垫圈、挡圈、销、铆钉、焊钉、组合件与连接副,是一类用于连接和固定各种构件和零部件的重要机械零件,可确保机械装置或设备结构的牢固和稳定…

mysql中与并发相关的问题?

今天我们来聊聊 MySQL 中与并发相关的一些问题。作为一名资深 Python 开发工程师,我觉得这些问题不仅关乎数据库的稳定性和数据的一致性,更与我们的代码实现和业务逻辑密切相关。 尤其是在高并发环境下,如何保证数据的一致性,如何…

使用k6进行kafka负载测试

1.安装环境 kafka环境 参考Docker搭建kafka环境-CSDN博客 xk6-kafka环境 ./xk6 build --with github.com/mostafa/xk6-kafkalatest 查看安装情况 2.编写脚本 test_kafka.js // Either import the module object import * as kafka from "k6/x/kafka";// Or in…

[机器学习]XGBoost(3)——确定树的结构

XGBoost的目标函数详见[机器学习]XGBoost(2)——目标函数(公式详解) 确定树的结构 之前在关于目标函数的计算中,均假设树的结构是确定的,但实际上,当划分条件不同时,叶子节点包含的…

springboot444新冠物资管理系统的设计与实现(论文+源码)_kaic

摘 要 传统办法管理信息首先需要花费的时间比较多,其次数据出错率比较高,而且对错误的数据进行更改也比较困难,最后,检索数据费事费力。因此,在计算机上安装新冠物资管理系统软件来发挥其高效地信息处理的作用&#x…

Javascript-web API-day02

文章目录 01-事件监听02-点击关闭广告03-随机点名案例04-鼠标经过或离开事件05-可点击的轮播图06-小米搜索框07-键盘类型事件08-键盘事件-发布评论案例09-focus选择器10-评论回车发布11-事件对象12-trim方法13-环境对象14-回调函数15-tab栏切换 01-事件监听 <!DOCTYPE html…

使用xjar 对Spring-Boot JAR 包加密运行

1 Xjar 介绍 Spring Boot JAR 安全加密运行工具&#xff0c;同时支持的原生JAR。 基于对JAR包内资源的加密以及拓展ClassLoader来构建的一套程序加密启动&#xff0c;动态解密运行的方案&#xff0c;避免源码泄露或反编译。 功能特性 无需侵入代码&#xff0c;只需要把编译好的…

深度学习的下一站:解锁人工智能的新边界

引言&#xff1a;新边界的呼唤 深度学习的诞生&#xff0c;犹如人工智能领域的一次革命&#xff0c;激发了语音助手、自动驾驶、智能医疗等前沿技术的飞速发展。然而&#xff0c;面对现实世界的复杂性&#xff0c;现有的深度学习模型仍然存在数据依赖、可解释性差、环境适应力不…

基于DockerCompose搭建Redis主从哨兵模式

linux目录结构 内网配置 哨兵配置文件如下&#xff0c;创建3个哨兵配置文件 # sentinel26379.conf sentinel26380.conf sentinel26381.conf 内容如下 protected-mode no sentinel monitor mymaster redis-master 6379 2 sentinel down-after-milliseconds mymaster 60000 s…

upload-labs靶场1-19关

第 1 关&#xff08;删除前端js校验&#xff09; 点击第一关&#xff0c;我们可以看到页面上传区可以上传一个图片&#xff0c;我们要上传一个 webshell&#xff0c;这里我们上传一句话木马的 php 点击上传 显示文件不支持上传&#xff0c;这时我们查看源码 查看代码后发现&am…