深度学习之超分辨率算法——SRGAN

  • 更新版本

  • 实现了生成对抗网络在超分辨率上的使用

  • 更新了损失函数,增加先验函数
    在这里插入图片描述

  • SRresnet实现

import torch
import torchvision
from torch import nnclass ConvBlock(nn.Module):def __init__(self, kernel_size=3, stride=1, n_inchannels=64):super(ConvBlock, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=n_inchannels, out_channels=n_inchannels, kernel_size=(kernel_size, kernel_size),stride=(stride, stride), bias=False, padding=(1, 1)),nn.BatchNorm2d(n_inchannels),nn.PReLU(),nn.Conv2d(in_channels=n_inchannels, out_channels=n_inchannels, kernel_size=(kernel_size, kernel_size),stride=(stride, stride), bias=False, padding=(1, 1)),nn.BatchNorm2d(n_inchannels),nn.PReLU(),)def forward(self, x):redisious = xout = self.sequential(x)return redisious + outclass Head_Conv(nn.Module):def __init__(self):super(Head_Conv, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(9, 9), stride=(1, 1), padding=(9 // 2, 9 // 2)),nn.PReLU(),)def forward(self, x):return self.sequential(x)class PixelShuffle(nn.Module):def __init__(self, n_channels=64, upscale_factor=2):super(PixelShuffle, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=n_channels, out_channels=n_channels * (upscale_factor ** 2), kernel_size=(3, 3),stride=(1, 1), padding=(3 // 2, 3 // 2)),nn.BatchNorm2d(n_channels * (upscale_factor ** 2)),nn.PixelShuffle(upscale_factor=upscale_factor))def forward(self, x):return self.sequential(x)class Hidden_block(nn.Module):def __init__(self):super(Hidden_block, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(3 // 2, 3 // 2)),nn.BatchNorm2d(64),)def forward(self, x):return self.sequential(x)class TailConv(nn.Module):def __init__(self):super(TailConv, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=3, kernel_size=(9, 9), stride=(1, 1), padding=(9 // 2, 9 // 2)),nn.Tanh(),)def forward(self, x):return self.sequential(x)class SRResNet(nn.Module):def __init__(self, n_blocks=16):super(SRResNet, self).__init__()self.head = Head_Conv()self.resnet = list()for _ in range(n_blocks):self.resnet.append(ConvBlock(kernel_size=3, stride=1, n_inchannels=64))self.resnet = nn.Sequential(*self.resnet)self.hidden = Hidden_block()self.pixelShuufe = []for _ in range(2):self.pixelShuufe.append(PixelShuffle(n_channels=64, upscale_factor=2))self.pixelShuufe = nn.Sequential(*self.pixelShuufe)self.tail_conv = TailConv()def forward(self, x):head_out = self.head(x)resnet_out = self.resnet(head_out)out = head_out + resnet_outresult = self.pixelShuufe(out)out = self.tail_conv(result)return out

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model = SRResNet()def forward(self, x):''':param x:lr_img:return: '''return self.model(x)class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.hidden = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(3 // 2, 3 // 2)),nn.LeakyReLU(),nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(64),nn.LeakyReLU(),nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(128),nn.LeakyReLU(),nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(128),nn.LeakyReLU(),nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(256),nn.LeakyReLU(),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(256),nn.LeakyReLU(),nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(512),nn.LeakyReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(512),nn.LeakyReLU(),nn.AdaptiveAvgPool2d((6, 6)))self.out_layer = nn.Sequential(nn.Linear(512 * 6 * 6, 1024),nn.LeakyReLU(negative_slope=0.2, inplace=True),nn.Linear(1024, 1),nn.Sigmoid())def forward(self, x):result = self.hidden(x)# print(result.shape)result = result.reshape(result.shape[0], -1)out = self.out_layer(result)return out

SRGAN模型的生成器与判别器的实现


class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model = SRResNet()def forward(self, x):''':param x:lr_img:return: '''return self.model(x)class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.hidden = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(3 // 2, 3 // 2)),nn.LeakyReLU(),nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(64),nn.LeakyReLU(),nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(128),nn.LeakyReLU(),nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(128),nn.LeakyReLU(),nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(256),nn.LeakyReLU(),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(256),nn.LeakyReLU(),nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(512),nn.LeakyReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(512),nn.LeakyReLU(),nn.AdaptiveAvgPool2d((6, 6)))self.out_layer = nn.Sequential(nn.Linear(512 * 6 * 6, 1024),nn.LeakyReLU(negative_slope=0.2, inplace=True),nn.Linear(1024, 1),nn.Sigmoid())def forward(self, x):result = self.hidden(x)# print(result.shape)result = result.reshape(result.shape[0], -1)out = self.out_layer(result)return out```
- 针对VGG19 的层数截取
```python
class TruncatedVGG19(nn.Module):"""truncated VGG19网络,用于计算VGG特征空间的MSE损失"""def __init__(self, i, j):""":参数 i: 第 i 个池化层:参数 j: 第 j 个卷积层"""super(TruncatedVGG19, self).__init__()# 加载预训练的VGG模型vgg19 = torchvision.models.vgg19(pretrained=True)print(vgg19)maxpool_counter = 0conv_count = 0truncate_at = 0# 迭代搜索for layer in vgg19.features.children():truncate_at += 1# 统计if isinstance(layer, nn.Conv2d):conv_count += 1if isinstance(layer, nn.MaxPool2d):maxpool_counter += 1conv_counter = 0# 截断位置在第(i-1)个池化层之后(第 i 个池化层之前)的第 j 个卷积层if maxpool_counter == i - 1 and conv_count == j:break# 检查是否满足条件assert maxpool_counter == i - 1 and conv_count == j, "当前 i=%d 、 j=%d 不满足 VGG19 模型结构" % (i, j)# 截取网络self.truncated_vgg19 = nn.Sequential(*list(vgg19.features.children())[:truncate_at + 1])def forward(self, input):output = self.truncated_vgg19(input)  # (N, channels, _w,h)return output
```

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/64899.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

未来将要被淘汰的编程语言

COBOL - 这是一种非常古老的语言,主要用于大型企业系统和政府机构。随着老一代IT工作人员的退休,COBOL程序员变得越来越少。Fortran - 最初用于科学和工程计算,Fortran在特定领域仍然有其应用,但随着更现代的语言(如Py…

路由器做WPAD、VPN、透明代理中之间一个

本文章将采用家中TP-Link路由器 路由器进行配置DNS DNS理解知识本文DNS描述参考:网络安全基础知识&中间件简单介绍_计算机网络中间件-CSDN博客 TP LINK未知的错误,错误编号:-22025 TP-LINK 认证界面地址:https://realnam…

MacOS M3源代码编译Qt6.8.1

编译时间过长,如果不想自己编译,可以通过如果网盘进行下载: 链接: https://pan.baidu.com/s/17lvF5jQ-vR6vE-KEchzrVA?pwdts26 提取码: ts26 在macOS上编译Qt 6需要一些前置步骤和工具。以下是编译Qt 6的基本步骤: 安装Xcode和…

CentOS HTTPS自签证书访问失败问题的排查与解决全流程

sudo cp harbor.crt /usr/local/share/ca-certificates/sudo yum install -y ca-certificatessudo update-ca-trust force-enablesudo update-ca-trust extract 但是访问 https://172.16.20.20 仍然报错 * About to connect() to 172.16.20.20 port 443 (#0) * Trying 172.16.2…

PostgreSQL数据库访问限制详解

pg_hba.conf 文件是 PostgreSQL 数据库系统中非常重要的一个配置文件,它用于定义哪些用户(或客户端)可以连接到 PostgreSQL 数据库服务器,以及他们可以使用哪些认证方法进行连接。 pg_hba.conf 的名称来源于 "Host-Based Aut…

Tool之Excalidraw:Excalidraw(开源的虚拟手绘风格白板)的简介、安装和使用方法、艾米莉应用之详细攻略

Tool之Excalidraw:Excalidraw(开源的虚拟手绘风格白板)的简介、安装和使用方法、艾米莉应用之详细攻略 目录 Excalidraw 简介 1、Excalidraw 的主要特点: Excalidraw 安装和使用方法 1、Excalidraw的安装 T1、使用 npm 安装: T2、使用 …

【蓝桥杯选拔赛真题96】Scratch风车旋转 第十五届蓝桥杯scratch图形化编程 少儿编程创意编程选拔赛真题解析

目录 scratch风车旋转 一、题目要求 编程实现 二、案例分析 1、角色分析 2、背景分析 3、前期准备 三、解题思路 1、思路分析 2、详细过程 四、程序编写 五、考点分析 六、推荐资料 1、入门基础 2、蓝桥杯比赛 3、考级资料 4、视频课程 5、python资料 scratc…

奇怪问题| Chrome 访问csdn 创作中心的时候报错: 服务超时,请稍后重试

Chrome 访问csdn 创作中心的时候报错: 服务超时,请稍后重试用无痕浏览器可以正常访问 关闭代理无效清缓存和Cookies无效。考虑无痕浏览器模式下插件不生效,尝试把chrome 插件也禁用,发现有效,是该扩展程序的缘故

B2HGraphicBufferProducer和H2BGraphicBufferProducer

在 Android 的图形系统中,B2HGraphicBufferProducer 和 BnGraphicBufferProducer 是基于 Binder 机制的两个重要组件,它们负责图形缓冲区的生产接口。二者关系可以理解为 桥接和实现分离,以下是详细说明: 1. B2HGraphicBufferProd…

sentinel学习笔记7-熔断降级

本文属于sentinel学习笔记系列。网上看到吴就业老师的专栏,写的好值得推荐,我整理的有所删减,推荐看原文。 https://blog.csdn.net/baidu_28523317/category_10400605.html 限流需要我们根据不同的硬件条件做好压测,不好准确评估…

记录--uniapp 安卓端实现录音功能,保存为amr/mp3文件

🧑‍💻 写在开头 点赞 收藏 学会🤣🤣🤣 功能实现需要用到MediaRecorder、navigator.mediaDevices.getUserMedia、Blob等API,uniapp App端不支持,需要借助renderjs来实现 实现逻辑 通过naviga…

步进电机位置速度双环控制实现

步进电机位置速度双环控制实现 野火stm32电机教学 提高部分-第11讲 步进电机位置速度双环控制实现(1)_哔哩哔哩_bilibili PID模型 位置环作为外环,速度环作为内环。设定目标位置和实际转轴位置的位置偏差,经过位置PID获得位置期望,然后讲位置期望(位置变化反映了转轴的速…

myexcel的使用

参考: (1)api文档:https://www.bookstack.cn/read/MyExcel-2.x/624d8ce73162300b.md (2)源代码: https://github.com/liaochong/myexcel/issues 我: (1)m…

MySQL 8.0:explain analyze 分析 SQL 执行过程

介绍 MySQL 8.0.16 引入一个实验特性:explain formattree ,树状的输出执行过程,以及预估成本和预估返 回行数。在 MySQL 8.0.18 又引入了 EXPLAIN ANALYZE,在 formattree 基础上,使用时,会执行 SQL &#…

事务、管道

目录 事务 相关命令 悲观锁 乐观锁 管道 实例 Pipeline与原生批量命令对比 Pipeline与事物对比 使用Pipeline注意事项 事务 相关命令 命令描述discard取消事务,放弃执行事务块内的所有命令exec执行所有事务块内的事务(所有命令依次执行&#x…

list的常用操作

list的介绍 list是序列容器,它允许在常数范围O(1)进行插入和删除在这段序列的任意位置,并且可以双向遍历 它是弥补vector容器的缺点,与vector有互补的韵味, 这里我们可以将其进行与vector进行对比 vect…

机器人角度参考方式

机器人的角度可以根据需求和系统设计来决定。通常情况下,机器人角度(如航向角或偏航角)有两种常见的参考方式: 参考开机时的 0:这是最常见的方式,机器人在开机时会将当前的方向作为 0(即参考方向…

Go语言封装Cron定时任务

Go语言封装Cron定时任务 介绍目标项目背景代码分析代码实现主要功能 Cron表达式解析例子 使用示例总结 介绍 在现代应用中,定时任务是非常常见的需求,无论是用于定时清理数据、定时发送邮件,还是定时执行系统维护任务。Go语言作为一门现代编…

3.4 stm32系列:定时器(PWM、定时中断)

一、定时器概述 1.1 软件定时原理 使用纯软件(CPU死等)的方式实现定时(延时)功能; 不精准的延迟: /* 微秒级延迟函数* 不精准* stm32存在压出栈过程需要消耗时间* 存在流水线,执行时间不确定…

28、论文阅读:基于像素分布重映射和多先验Retinex变分模型的水下图像增强

A Pixel Distribution Remapping and Multi-Prior Retinex Variational Model for Underwater Image Enhancement 摘要介绍相关工作基于模型的水下图像增强方法:无模型水下图像增强方法:基于深度学习的水下图像增强方法: 论文方法概述像素分布…