深度学习之超分辨率算法——SRGAN

  • 更新版本

  • 实现了生成对抗网络在超分辨率上的使用

  • 更新了损失函数,增加先验函数
    在这里插入图片描述

  • SRresnet实现

import torch
import torchvision
from torch import nnclass ConvBlock(nn.Module):def __init__(self, kernel_size=3, stride=1, n_inchannels=64):super(ConvBlock, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=n_inchannels, out_channels=n_inchannels, kernel_size=(kernel_size, kernel_size),stride=(stride, stride), bias=False, padding=(1, 1)),nn.BatchNorm2d(n_inchannels),nn.PReLU(),nn.Conv2d(in_channels=n_inchannels, out_channels=n_inchannels, kernel_size=(kernel_size, kernel_size),stride=(stride, stride), bias=False, padding=(1, 1)),nn.BatchNorm2d(n_inchannels),nn.PReLU(),)def forward(self, x):redisious = xout = self.sequential(x)return redisious + outclass Head_Conv(nn.Module):def __init__(self):super(Head_Conv, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(9, 9), stride=(1, 1), padding=(9 // 2, 9 // 2)),nn.PReLU(),)def forward(self, x):return self.sequential(x)class PixelShuffle(nn.Module):def __init__(self, n_channels=64, upscale_factor=2):super(PixelShuffle, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=n_channels, out_channels=n_channels * (upscale_factor ** 2), kernel_size=(3, 3),stride=(1, 1), padding=(3 // 2, 3 // 2)),nn.BatchNorm2d(n_channels * (upscale_factor ** 2)),nn.PixelShuffle(upscale_factor=upscale_factor))def forward(self, x):return self.sequential(x)class Hidden_block(nn.Module):def __init__(self):super(Hidden_block, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(3 // 2, 3 // 2)),nn.BatchNorm2d(64),)def forward(self, x):return self.sequential(x)class TailConv(nn.Module):def __init__(self):super(TailConv, self).__init__()self.sequential = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=3, kernel_size=(9, 9), stride=(1, 1), padding=(9 // 2, 9 // 2)),nn.Tanh(),)def forward(self, x):return self.sequential(x)class SRResNet(nn.Module):def __init__(self, n_blocks=16):super(SRResNet, self).__init__()self.head = Head_Conv()self.resnet = list()for _ in range(n_blocks):self.resnet.append(ConvBlock(kernel_size=3, stride=1, n_inchannels=64))self.resnet = nn.Sequential(*self.resnet)self.hidden = Hidden_block()self.pixelShuufe = []for _ in range(2):self.pixelShuufe.append(PixelShuffle(n_channels=64, upscale_factor=2))self.pixelShuufe = nn.Sequential(*self.pixelShuufe)self.tail_conv = TailConv()def forward(self, x):head_out = self.head(x)resnet_out = self.resnet(head_out)out = head_out + resnet_outresult = self.pixelShuufe(out)out = self.tail_conv(result)return out

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model = SRResNet()def forward(self, x):''':param x:lr_img:return: '''return self.model(x)class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.hidden = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(3 // 2, 3 // 2)),nn.LeakyReLU(),nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(64),nn.LeakyReLU(),nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(128),nn.LeakyReLU(),nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(128),nn.LeakyReLU(),nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(256),nn.LeakyReLU(),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(256),nn.LeakyReLU(),nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(512),nn.LeakyReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(512),nn.LeakyReLU(),nn.AdaptiveAvgPool2d((6, 6)))self.out_layer = nn.Sequential(nn.Linear(512 * 6 * 6, 1024),nn.LeakyReLU(negative_slope=0.2, inplace=True),nn.Linear(1024, 1),nn.Sigmoid())def forward(self, x):result = self.hidden(x)# print(result.shape)result = result.reshape(result.shape[0], -1)out = self.out_layer(result)return out

SRGAN模型的生成器与判别器的实现


class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model = SRResNet()def forward(self, x):''':param x:lr_img:return: '''return self.model(x)class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.hidden = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(3 // 2, 3 // 2)),nn.LeakyReLU(),nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),nn.BatchNorm2d(64),nn.LeakyReLU(),nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(128),nn.LeakyReLU(),nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(128),nn.LeakyReLU(),nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(256),nn.LeakyReLU(),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(256),nn.LeakyReLU(),nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),nn.BatchNorm2d(512),nn.LeakyReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),nn.BatchNorm2d(512),nn.LeakyReLU(),nn.AdaptiveAvgPool2d((6, 6)))self.out_layer = nn.Sequential(nn.Linear(512 * 6 * 6, 1024),nn.LeakyReLU(negative_slope=0.2, inplace=True),nn.Linear(1024, 1),nn.Sigmoid())def forward(self, x):result = self.hidden(x)# print(result.shape)result = result.reshape(result.shape[0], -1)out = self.out_layer(result)return out```
- 针对VGG19 的层数截取
```python
class TruncatedVGG19(nn.Module):"""truncated VGG19网络,用于计算VGG特征空间的MSE损失"""def __init__(self, i, j):""":参数 i: 第 i 个池化层:参数 j: 第 j 个卷积层"""super(TruncatedVGG19, self).__init__()# 加载预训练的VGG模型vgg19 = torchvision.models.vgg19(pretrained=True)print(vgg19)maxpool_counter = 0conv_count = 0truncate_at = 0# 迭代搜索for layer in vgg19.features.children():truncate_at += 1# 统计if isinstance(layer, nn.Conv2d):conv_count += 1if isinstance(layer, nn.MaxPool2d):maxpool_counter += 1conv_counter = 0# 截断位置在第(i-1)个池化层之后(第 i 个池化层之前)的第 j 个卷积层if maxpool_counter == i - 1 and conv_count == j:break# 检查是否满足条件assert maxpool_counter == i - 1 and conv_count == j, "当前 i=%d 、 j=%d 不满足 VGG19 模型结构" % (i, j)# 截取网络self.truncated_vgg19 = nn.Sequential(*list(vgg19.features.children())[:truncate_at + 1])def forward(self, input):output = self.truncated_vgg19(input)  # (N, channels, _w,h)return output
```

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/64899.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

路由器做WPAD、VPN、透明代理中之间一个

本文章将采用家中TP-Link路由器 路由器进行配置DNS DNS理解知识本文DNS描述参考:网络安全基础知识&中间件简单介绍_计算机网络中间件-CSDN博客 TP LINK未知的错误,错误编号:-22025 TP-LINK 认证界面地址:https://realnam…

CentOS HTTPS自签证书访问失败问题的排查与解决全流程

sudo cp harbor.crt /usr/local/share/ca-certificates/sudo yum install -y ca-certificatessudo update-ca-trust force-enablesudo update-ca-trust extract 但是访问 https://172.16.20.20 仍然报错 * About to connect() to 172.16.20.20 port 443 (#0) * Trying 172.16.2…

Tool之Excalidraw:Excalidraw(开源的虚拟手绘风格白板)的简介、安装和使用方法、艾米莉应用之详细攻略

Tool之Excalidraw:Excalidraw(开源的虚拟手绘风格白板)的简介、安装和使用方法、艾米莉应用之详细攻略 目录 Excalidraw 简介 1、Excalidraw 的主要特点: Excalidraw 安装和使用方法 1、Excalidraw的安装 T1、使用 npm 安装: T2、使用 …

【蓝桥杯选拔赛真题96】Scratch风车旋转 第十五届蓝桥杯scratch图形化编程 少儿编程创意编程选拔赛真题解析

目录 scratch风车旋转 一、题目要求 编程实现 二、案例分析 1、角色分析 2、背景分析 3、前期准备 三、解题思路 1、思路分析 2、详细过程 四、程序编写 五、考点分析 六、推荐资料 1、入门基础 2、蓝桥杯比赛 3、考级资料 4、视频课程 5、python资料 scratc…

奇怪问题| Chrome 访问csdn 创作中心的时候报错: 服务超时,请稍后重试

Chrome 访问csdn 创作中心的时候报错: 服务超时,请稍后重试用无痕浏览器可以正常访问 关闭代理无效清缓存和Cookies无效。考虑无痕浏览器模式下插件不生效,尝试把chrome 插件也禁用,发现有效,是该扩展程序的缘故

sentinel学习笔记7-熔断降级

本文属于sentinel学习笔记系列。网上看到吴就业老师的专栏,写的好值得推荐,我整理的有所删减,推荐看原文。 https://blog.csdn.net/baidu_28523317/category_10400605.html 限流需要我们根据不同的硬件条件做好压测,不好准确评估…

记录--uniapp 安卓端实现录音功能,保存为amr/mp3文件

🧑‍💻 写在开头 点赞 收藏 学会🤣🤣🤣 功能实现需要用到MediaRecorder、navigator.mediaDevices.getUserMedia、Blob等API,uniapp App端不支持,需要借助renderjs来实现 实现逻辑 通过naviga…

步进电机位置速度双环控制实现

步进电机位置速度双环控制实现 野火stm32电机教学 提高部分-第11讲 步进电机位置速度双环控制实现(1)_哔哩哔哩_bilibili PID模型 位置环作为外环,速度环作为内环。设定目标位置和实际转轴位置的位置偏差,经过位置PID获得位置期望,然后讲位置期望(位置变化反映了转轴的速…

MySQL 8.0:explain analyze 分析 SQL 执行过程

介绍 MySQL 8.0.16 引入一个实验特性:explain formattree ,树状的输出执行过程,以及预估成本和预估返 回行数。在 MySQL 8.0.18 又引入了 EXPLAIN ANALYZE,在 formattree 基础上,使用时,会执行 SQL &#…

事务、管道

目录 事务 相关命令 悲观锁 乐观锁 管道 实例 Pipeline与原生批量命令对比 Pipeline与事物对比 使用Pipeline注意事项 事务 相关命令 命令描述discard取消事务,放弃执行事务块内的所有命令exec执行所有事务块内的事务(所有命令依次执行&#x…

list的常用操作

list的介绍 list是序列容器,它允许在常数范围O(1)进行插入和删除在这段序列的任意位置,并且可以双向遍历 它是弥补vector容器的缺点,与vector有互补的韵味, 这里我们可以将其进行与vector进行对比 vect…

3.4 stm32系列:定时器(PWM、定时中断)

一、定时器概述 1.1 软件定时原理 使用纯软件(CPU死等)的方式实现定时(延时)功能; 不精准的延迟: /* 微秒级延迟函数* 不精准* stm32存在压出栈过程需要消耗时间* 存在流水线,执行时间不确定…

28、论文阅读:基于像素分布重映射和多先验Retinex变分模型的水下图像增强

A Pixel Distribution Remapping and Multi-Prior Retinex Variational Model for Underwater Image Enhancement 摘要介绍相关工作基于模型的水下图像增强方法:无模型水下图像增强方法:基于深度学习的水下图像增强方法: 论文方法概述像素分布…

【路径规划】原理及实现

路径规划(Path Planning)是指在给定地图、起始点和目标点的情况下,确定应该采取的最佳路径。常见的路径规划算法包括A* 算法、Dijkstra 算法、RRT(Rapidly-exploring Random Tree)等。 目录 一.A* 1.算法原理 2.实…

java web springboot

0. 引言 SpringBoot对Spring的改善和优化,它基于约定优于配置的思想,提供了大量的默认配置和实现 使用SpringBoot之后,程序员只需按照它规定的方式去进行程序代码的开发即可,而无需再去编写一堆复杂的配置 SpringBoot的主要功能…

实验四 综合数据流处理-Storm (单机和集群配置部分)

1.前期准备 (1)把docker和docker-compose给下载好 参考:基于docker-compose来搭建zookeeper集群-CSDN博客(注意对于这篇文章下面配置zookeeper的内容,可以直接跳过,因为我们只需要看最上面下载docker-com…

前端开发 之 12个鼠标交互特效下【附完整源码】

前端开发 之 12个鼠标交互特效下【附完整源码】 文章目录 前端开发 之 12个鼠标交互特效下【附完整源码】七:粒子烟花绽放特效1.效果展示2.HTML完整代码 八:彩球释放特效1.效果展示2.HTML完整代码 九:雨滴掉落特效1.效果展示2.HTML完整代码 十…

Java设计模式 —— 【结构型模式】外观模式详解

文章目录 概述结构案例实现优缺点 概述 外观模式又名门面模式,是一种通过为多个复杂的子系统提供一个一致的接口,而使这些子系统更加容易被访问的模式。该模式对外有一个统一接口,外部应用程序不用关心内部子系统的具体的细节,这…

基于Springboot + vue实现的汽车资讯网站

🥂(❁◡❁)您的点赞👍➕评论📝➕收藏⭐是作者创作的最大动力🤞 💖📕🎉🔥 支持我:点赞👍收藏⭐️留言📝欢迎留言讨论 🔥🔥&…

Html:点击图标链接发起QQ临时会话

我们在做前端开发的时候&#xff0c;会遇到用户需要点击一个图标可以发起QQ临时会话&#xff0c;这样不用添加好友也能沟通的&#xff0c;那我们就来看看如何实现这个功能&#xff1a; <a href"http://wpa.qq.com/msgrd?v3&uin你的QQ号码&siteqq&menuyes…