深度学习之超分辨率算法——FRCNN

– 对之前SRCNN算法的改进

    1. 输出层采用转置卷积层放大尺寸,这样可以直接将低分辨率图片输入模型中,解决了输入尺度问题。
    2. 改变特征维数,使用更小的卷积核和使用更多的映射层。卷积核更小,加入了更多的激活层。
    3. 共享其中的映射层,如果需要训练不同上采样倍率的模型,只需要修改最后的反卷积层大小,就可以训练出不同尺寸的图片。
  • 模型实现
  • 在这里插入图片描述
import math
from torch import nnclass FSRCNN(nn.Module):def __init__(self, scale_factor, num_channels=1, d=56, s=12, m=4):super(FSRCNN, self).__init__()self.first_part = nn.Sequential(nn.Conv2d(num_channels, d, kernel_size=5, padding=5//2),nn.PReLU(d))# 添加入多个激活层和小卷积核self.mid_part = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU(s)]for _ in range(m):self.mid_part.extend([nn.Conv2d(s, s, kernel_size=3, padding=3//2), nn.PReLU(s)])self.mid_part.extend([nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d)])self.mid_part = nn.Sequential(*self.mid_part)# 最后输出self.last_part = nn.ConvTranspose2d(d, num_channels, kernel_size=9, stride=scale_factor, padding=9//2,output_padding=scale_factor-1)self._initialize_weights()def _initialize_weights(self):# 初始化for m in self.first_part:if isinstance(m, nn.Conv2d):nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))nn.init.zeros_(m.bias.data)for m in self.mid_part:if isinstance(m, nn.Conv2d):nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))nn.init.zeros_(m.bias.data)nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001)nn.init.zeros_(self.last_part.bias.data)def forward(self, x):x = self.first_part(x)x = self.mid_part(x)x = self.last_part(x)return x

以上代码中,如起初所说,将SRCNN中给的输出修改为转置卷积,并且在中间添加了多个11卷积核和多个线性激活层。且应用了权重初始化,解决协变量偏移问题。
备注:1
1卷积核虽然在通道的像素层面上,针对一个像素进行卷积,貌似没有什么作用,但是卷积神经网络的特性,我们在利用多个卷积核对特征图进行扫描时,单个卷积核扫描后的为sum©,那么就是尽管在像素层面上无用,但是在通道层面上进行了融合,并且进一步加深了层数,使网络层数增加,网络能力增强。

  • 上代码
  • train.py

训练脚本

import argparse
import os
import copyimport torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdmfrom models import FSRCNN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnrif __name__ == '__main__':parser = argparse.ArgumentParser()# 训练文件parser.add_argument('--train-file', type=str,help="the dir of train data",default="./Train/91-image_x4.h5")# 测试集文件parser.add_argument('--eval-file', type=str,help="thr dir of test data ",default="./Test/Set5_x4.h5")# 输出的文件夹parser.add_argument('--outputs-dir',help="the output dir", type=str,default="./outputs")parser.add_argument('--weights-file', type=str)parser.add_argument('--scale', type=int, default=2)parser.add_argument('--lr', type=float, default=1e-3)parser.add_argument('--batch-size', type=int, default=16)parser.add_argument('--num-epochs', type=int, default=20)parser.add_argument('--num-workers', type=int, default=8)parser.add_argument('--seed', type=int, default=123)args = parser.parse_args()args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))if not os.path.exists(args.outputs_dir):os.makedirs(args.outputs_dir)cudnn.benchmark = Truedevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')torch.manual_seed(args.seed)model = FSRCNN(scale_factor=args.scale).to(device)criterion = nn.MSELoss()optimizer = optim.Adam([{'params': model.first_part.parameters()},{'params': model.mid_part.parameters()},{'params': model.last_part.parameters(), 'lr': args.lr * 0.1}], lr=args.lr)train_dataset = TrainDataset(args.train_file)train_dataloader = DataLoader(dataset=train_dataset,batch_size=args.batch_size,shuffle=True,num_workers=args.num_workers,pin_memory=True)eval_dataset = EvalDataset(args.eval_file)eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)best_weights = copy.deepcopy(model.state_dict())best_epoch = 0best_psnr = 0.0for epoch in range(args.num_epochs):model.train()epoch_losses = AverageMeter()with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size), ncols=80) as t:t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))for data in train_dataloader:inputs, labels = datainputs = inputs.to(device)labels = labels.to(device)preds = model(inputs)loss = criterion(preds, labels)epoch_losses.update(loss.item(), len(inputs))optimizer.zero_grad()loss.backward()optimizer.step()t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))t.update(len(inputs))torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))model.eval()epoch_psnr = AverageMeter()for data in eval_dataloader:inputs, labels = datainputs = inputs.to(device)labels = labels.to(device)with torch.no_grad():preds = model(inputs).clamp(0.0, 1.0)epoch_psnr.update(calc_psnr(preds, labels), len(inputs))print('eval psnr: {:.2f}'.format(epoch_psnr.avg))if epoch_psnr.avg > best_psnr:best_epoch = epochbest_psnr = epoch_psnr.avgbest_weights = copy.deepcopy(model.state_dict())print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))

test.py 测试脚本

import argparseimport torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_imagefrom models import FSRCNN
from utils import convert_ycbcr_to_rgb, preprocess, calc_psnrif __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--weights-file', type=str, required=True)parser.add_argument('--image-file', type=str, required=True)parser.add_argument('--scale', type=int, default=3)args = parser.parse_args()cudnn.benchmark = Truedevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')model = FSRCNN(scale_factor=args.scale).to(device)state_dict = model.state_dict()for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():if n in state_dict.keys():state_dict[n].copy_(p)else:raise KeyError(n)model.eval()image = pil_image.open(args.image_file).convert('RGB')image_width = (image.width // args.scale) * args.scaleimage_height = (image.height // args.scale) * args.scalehr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)lr = hr.resize((hr.width // args.scale, hr.height // args.scale), resample=pil_image.BICUBIC)bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)bicubic.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))lr, _ = preprocess(lr, device)hr, _ = preprocess(hr, device)_, ycbcr = preprocess(bicubic, device)with torch.no_grad():preds = model(lr).clamp(0.0, 1.0)psnr = calc_psnr(hr, preds)print('PSNR: {:.2f}'.format(psnr))preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)output = pil_image.fromarray(output)# 保存图片output.save(args.image_file.replace('.', '_fsrcnn_x{}.'.format(args.scale)))

datasets.py

数据集的读取

import h5py
import numpy as np
from torch.utils.data import Datasetclass TrainDataset(Dataset):def __init__(self, h5_file):super(TrainDataset, self).__init__()self.h5_file = h5_filedef __getitem__(self, idx):with h5py.File(self.h5_file, 'r') as f:return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)def __len__(self):with h5py.File(self.h5_file, 'r') as f:return len(f['lr'])class EvalDataset(Dataset):def __init__(self, h5_file):super(EvalDataset, self).__init__()self.h5_file = h5_filedef __getitem__(self, idx):with h5py.File(self.h5_file, 'r') as f:return np.expand_dims(f['lr'][str(idx)][:, :] / 255., 0), np.expand_dims(f['hr'][str(idx)][:, :] / 255., 0)def __len__(self):with h5py.File(self.h5_file, 'r') as f:return len(f['lr'])

工具文件utils.py

  • 主要用来测试psnr指数,图片的格式转换(悄悄说一句,opencv有直接实现~~~)
import torch
import numpy as npdef calc_patch_size(func):def wrapper(args):if args.scale == 2:args.patch_size = 10elif args.scale == 3:args.patch_size = 7elif args.scale == 4:args.patch_size = 6else:raise Exception('Scale Error', args.scale)return func(args)return wrapperdef convert_rgb_to_y(img, dim_order='hwc'):if dim_order == 'hwc':return 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.else:return 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.def convert_rgb_to_ycbcr(img, dim_order='hwc'):if dim_order == 'hwc':y = 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.cb = 128. + (-37.945 * img[..., 0] - 74.494 * img[..., 1] + 112.439 * img[..., 2]) / 256.cr = 128. + (112.439 * img[..., 0] - 94.154 * img[..., 1] - 18.285 * img[..., 2]) / 256.else:y = 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.cb = 128. + (-37.945 * img[0] - 74.494 * img[1] + 112.439 * img[2]) / 256.cr = 128. + (112.439 * img[0] - 94.154 * img[1] - 18.285 * img[2]) / 256.return np.array([y, cb, cr]).transpose([1, 2, 0])def convert_ycbcr_to_rgb(img, dim_order='hwc'):if dim_order == 'hwc':r = 298.082 * img[..., 0] / 256. + 408.583 * img[..., 2] / 256. - 222.921g = 298.082 * img[..., 0] / 256. - 100.291 * img[..., 1] / 256. - 208.120 * img[..., 2] / 256. + 135.576b = 298.082 * img[..., 0] / 256. + 516.412 * img[..., 1] / 256. - 276.836else:r = 298.082 * img[0] / 256. + 408.583 * img[2] / 256. - 222.921g = 298.082 * img[0] / 256. - 100.291 * img[1] / 256. - 208.120 * img[2] / 256. + 135.576b = 298.082 * img[0] / 256. + 516.412 * img[1] / 256. - 276.836return np.array([r, g, b]).transpose([1, 2, 0])def preprocess(img, device):img = np.array(img).astype(np.float32)ycbcr = convert_rgb_to_ycbcr(img)x = ycbcr[..., 0]x /= 255.x = torch.from_numpy(x).to(device)x = x.unsqueeze(0).unsqueeze(0)return x, ycbcrdef calc_psnr(img1, img2):return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))class AverageMeter(object):def __init__(self):self.reset()def reset(self):self.val = 0self.avg = 0self.sum = 0self.count = 0def update(self, val, n=1):self.val = valself.sum += val * nself.count += nself.avg = self.sum / self.count

先跑他个几十轮~
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/64586.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

小程序UI自动化测试实践:Minium+PageObject !

小程序架构上分为渲染层和逻辑层,尽管各平台的运行环境十分相似,但是还是有些许的区别(如下图),比如说JavaScript 语法和 API 支持不一致,WXSS 渲染表现也有不同,所以不论是手工测试&#xff0c…

堆的深度剖析及使用

目录 1.堆的创建1.1初始化1.2销毁 2.堆的使用2.1数据插入2.2堆顶元素2.3数据删除 3.堆的难点3.1向上调整3.1.1视频分析向上调整3.1.2 代码分析 3.2向下调整3.2.1视频分析向下调整3.2.2代码分析 1.堆的创建 堆的物理储存方式其实是一个数组,而逻辑储存方式其实是一个…

Hu矩原理 | cv2中基于Hu矩计算图像轮廓相似度差异的函数cv2.matchShapes【小白记笔记】

Hu 矩(Hu Moments) 是一种用于描述轮廓形状的 不变特征。它基于图像的矩提取,经过数学变换得到 7 个不变矩,这些不变矩在图像 平移、旋转和缩放等几何变换下保持不变,适合用来衡量轮廓或形状的相似度差异。 1、图像矩…

计算无人机俯拍图像的地面采样距离(GSD)矩阵

引言 在无人机遥感、测绘和精细农业等领域,地面采样距离(Ground Sampling Distance,简称 GSD)是一个非常重要的指标。GSD 是指图像中每个像素在地面上实际代表的物理距离,通常以米或厘米为单位。GSD 决定了图像的空间…

浅谈怎样系统的准备前端面试

前言 创业梦碎,回归现实,7 月底毅然裸辞,苦战两个月,拿到了美团和字节跳动的 offer,这算是从业以来第一次真正意义的面试,遇到蛮多问题,比如一开始具体的面试过程我都不懂,基本一直是…

深度学习-74-大语言模型LLM之基于API与llama.cpp启动的模型进行交互

文章目录 1 大模型量化方法1.1 GPTQ(后训练量化)1.2 GGUF(支持CPU)1.3 AWQ(后训练量化)2 llama.cpp2.1 功能2.1.1 Chat(聊天)2.1.2 Completion(补全)2.2 运行开源LLM2.2.1 下载安装llama.cpp2.2.2 下载gguf格式的模型2.2.3 运行大模型3 API访问3.1 调用补全3.2 调用聊天3.3 提取…

sql server 字符集和排序

英文: Latin1_General_CI_AS 中文:Chinese_PRC_CI_AS 影响字符存储,解释用户存在单字节字符类型(char,varchar等)里面的数据 字符排序规则(是否区分大小写等) 中国的用户一定要注意…

【docker】列出与特定镜像名相关的镜像

目录 1. 说明2. 列出所有镜像3. 使用镜像名过滤4. 列出特定标签的镜像5. 结合多个过滤条件6. 使用 JSON 格式和 jq 工具 1. 说明 1.在 Docker 中,如果你想列出与特定镜像名相关的镜像,可以使用 docker images 命令并结合过滤选项(如 --filte…

Elasticsearch 实战应用:开启数据搜索与分析新征程

在当今信息爆炸的时代,高效的数据搜索与分析能力成为众多企业和开发者追求的目标。Elasticsearch 作为一款强大的分布式搜索和分析引擎,正逐渐成为数据处理领域的核心工具之一。在我们的教学过程中,旨在让学生深入理解并熟练掌握 Elasticsear…

Navicat 17 功能简介 | SQL 美化

SQL美化 本期,我们将深入挖掘 Navicat 的实用的SQL代码美化功能。你只需简单地点击“SQL 美化”按钮,即可轻松完成 SQL 的格式化。 随着 17 版本的发布,Navicat 也带来了众多的新特性,包括兼容更多数据库、全新的模型设计、可视化…

2009 ~ 2019 年 408【数据结构】大题解析

2009 年 讲解视频推荐:【BOK408真题讲解-2009年(催更就退网版)】 1. 图的应用(10’) 带权图(权值非负, 表示边连接的两顶点间的距离)的最短路径问题是找出从初始顶点到目标顶点之间…

时空AI赋能低空智能科技创新

随着人工智能技术的不断进步,时空人工智能(Spatio-Temporal AI,简称时空AI)正在逐渐成为推动低空经济发展的新引擎。时空AI结合了地理空间智能、城市空间智能和时空大数据智能,为低空智能科技创新提供了强大的数据支持…

SamOut 任意长度推理空间不变

项目地址 import numpy as np import pandas as pd import torch from tqdm import tqdmfrom infer_model import SamOutdef load_model_and_voc(device"cpu"):voc pd.read_pickle("total_voc.pkl")net SamOut(len(voc["voc"]), 1024 512, 64…

17.springcloud_openfeign之扩展组件一

文章目录 一、前言二、默认约定配置FeignAutoConfigurationCachingCapabilityFeignCachingInvocationHandlerFactoryFeignJacksonConfiguration熔断器配置FeignCircuitBreakerTargeterFeignCircuitBreaker.Builder FeignClientsConfigurationCircuitBreakerFactory 总结 一、前…

Python读取Excel批量写入到PPT生成词卡

一、问题的提出 有网友想把Excel表中的三列数据,分别是:单词、音标和释义分别写入到PPT当中,每一张PPT写一个单词的内容。这种批量操作是python的强项,尤其是在办公领域,它能较好地解放双手,读取Excel表后…

Proteus(8.15)仿真下载安装过程(附详细安装过程图)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 目录 前言 一、Proteus是什么? 二、下载链接 三、下安装步骤 1.解压,有键管理员运行 2.点击Next,进行下一步 3.勾选I accept…&#…

防止私接小路由器

电脑获取到IP地址不是DHCP服务器的IP地址段,导致整个公司网络瘫痪,这些故障现象通常80%原因是私接小路由器导致的,以下防止私接小路由器措施。 一、交换机配置DHCP Sooping DHCP snooping是一种DHCP安全特性,用于防止非法设备获…

动态导出word文件支持转pdf

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、功能说明二、使用步骤1.controller2.工具类 DocumentUtil 导出样式 前言 提示:这里可以添加本文要记录的大概内容: 例如&#xff…

紧固件设计之——开槽六角头防脱出杆螺栓仿真APP

按照产品形态分类,紧固件通常包括以下12类:螺栓、螺柱、螺钉、螺母、自攻螺钉、木螺钉、垫圈、挡圈、销、铆钉、焊钉、组合件与连接副,是一类用于连接和固定各种构件和零部件的重要机械零件,可确保机械装置或设备结构的牢固和稳定…

【Python装饰器】编写一个装饰器,并将其放到适当的位置,目的是让代码 1 秒钟打印一个结果

import timedef fib():back1, back2 0, 1def func():nonlocal back1, back2back1, back2 back2, back1 back2print(back1, end )return funcdef get_fib(n):f fib()for i in range(n):f()n int(input("请输入需要获取的斐波那契数:"))get_fib(n) imp…