SSF-CNN:空间光谱融合的卷积光谱图像超分网络

SSF-CNN: SPATIAL AND SPECTRAL FUSION WITH CNN FOR HYPERSPECTRAL IMAGE SUPER-RESOLUTION

文章目录

  • SSF-CNN: SPATIAL AND SPECTRAL FUSION WITH CNN FOR HYPERSPECTRAL IMAGE SUPER-RESOLUTION
    • 简介
    • 解决问题
    • 网络框架
    • 代码实现
    • 训练部分
    • 运行结果

简介

​ 本文提出了一种利用空间和光谱进行高光谱融合图像超分辨率的新型CNN架构,首先是对高光谱图像进行双三次插值,使其空间分辨率大小和多光谱一致,然后进行concat操作。使用类似于SRCNN的网络框架对融合超分的图像进行优化,最后输出高分辨率高光谱超分图像。

​ 对于PDCon,也就是引入了部分密集连接,将输入concat到每一个卷积层后面。
Hyperspectral-Image-Super-Resolution-Benchmark——光谱图像超分基准-CSDN博客
Paper: IEEE
Code:https://github.com/miraclefan777/SSFCNN

2023-11-25_16-06-09

解决问题

  1. 传统方法通过基于优化的方法恢复 HR-HS 图像的质量在很大程度上取决于预定义的约束。此外,由于约束项数量较多,优化过程通常涉及较高的计算成本。
  2. 执行HSI SR的一个直接想法是直接应用这样的网络来放大LR-HS图像的空间维度或HR-RGB图像的光谱维度,我们称之为Spatial-CNN和Spectral-CNN,这两种单图像方法忽略了两种图像特有的信息互补优势。

网络框架

  1. 原始的SRCNN是将图片映射到Ycbcr空间,并只使用其中的 Y 分量作为输入来预测 HR Y 图像,该论文则是将图片的通道信息以及空间信息整个进行输入
  2. 原始SRCNN卷积核大小第1,2修改为3*3,增加上下文信息,同时为了避免高维数据(padding为same,保持和原有特征图大小一致)

代码实现

class SSFCNNnet(nn.Module):def __init__(self, num_spectral=31, scale_factor=8, pdconv=False):super(SSFCNNnet, self).__init__()self.scale_factor = scale_factorself.pdconv = pdconvself.Upsample = nn.Upsample(mode='bicubic', scale_factor=self.scale_factor)self.conv1 = nn.Conv2d(num_spectral + 3, 64, kernel_size=3, padding="same")if pdconv:self.conv2 = nn.Conv2d(64 + 3, 32, kernel_size=3, padding="same")self.conv3 = nn.Conv2d(32 + 3, num_spectral, kernel_size=5, padding="same")else:self.conv2 = nn.Conv2d(64, 32, kernel_size=3, padding="same")self.conv3 = nn.Conv2d(32, num_spectral, kernel_size=5, padding="same")self.relu = nn.ReLU(inplace=True)def forward(self, lr_hs, hr_ms):""":param lr_hs:LR-HSI低分辨率的高光谱图像:param hr_ms:高分辨率的多光谱图像:return:"""# 对LR-HSI低分辨率图像进行上采样,让其分辨率更高lr_hs_up = self.Upsample(lr_hs)# 将上采样后的LR-HSI低分辨率图像与高分辨率的多光谱图像进行拼接x = torch.cat((lr_hs_up, hr_ms), dim=1)x = self.relu(self.conv1(x))if self.pdconv:x = torch.cat((x, hr_ms), dim=1)x = self.relu(self.conv2(x))x = torch.cat((x, hr_ms), dim=1)else:x = self.relu(self.conv2(x))out = self.conv3(x)return out

如果需要使用密集连接,只需要在初始化网络模型时,传参pdconv=True

训练部分

未提供自定义dataset类,根据自己的dateset进行参数的修改即可。

import argparse
from calculate_metrics import Loss_SAM, Loss_RMSE, Loss_PSNR
from models.SSFCNNnet import SSFCNNnet
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from train_dataloader import CAVEHSIDATAprocess
from utils import create_F, fspecial,AverageMeter
import os
import copy
import torch
import torch.nn as nnif __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--model', type=str, default="SSFCNNnet")parser.add_argument('--train-file', type=str, required=True)parser.add_argument('--eval-file', type=str, required=True)parser.add_argument('--outputs-dir', type=str, required=True)parser.add_argument('--scale', type=int, default=2)parser.add_argument('--lr', type=float, default=1e-4)parser.add_argument('--batch-size', type=int, default=32)parser.add_argument('--num-workers', type=int, default=0)parser.add_argument('--num-epochs', type=int, default=400)parser.add_argument('--seed', type=int, default=123)args = parser.parse_args()assert args.model in ['SSFCNNnet', 'PDcon_SSF']outputs_dir = os.path.join(args.outputs_dir, '{}'.format(args.model))if not os.path.exists(outputs_dir):os.makedirs(outputs_dir)device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')torch.manual_seed(args.seed)# 训练参数# loss_func = nn.L1Loss(reduction='mean').cuda()criterion = nn.MSELoss()#################数据集处理#################R = create_F()PSF = fspecial('gaussian', 8, 3)downsample_factor = 8training_size = 64stride = 32stride1 = 32train_dataset = CAVEHSIDATAprocess(args.train_file, R, training_size, stride, downsample_factor, PSF, 20)train_dataloader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True)eval_dataset = CAVEHSIDATAprocess(args.eval_file, R, training_size, stride, downsample_factor, PSF, 12)eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)#################数据集处理################## 模型if args.model == 'SSFCNNnet':model = SSFCNNnet().cuda()else:model = SSFCNNnet(pdconv=True).cuda()best_weights = copy.deepcopy(model.state_dict())best_epoch = 0best_psnr = 0.0# 模型初始化for m in model.modules():if isinstance(m, (nn.Conv2d, nn.Linear)):nn.init.xavier_uniform_(m.weight)elif isinstance(m, nn.LayerNorm):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)optimizer = torch.optim.Adam([{'params': model.conv1.parameters()},{'params': model.conv2.parameters()},{'params': model.conv3.parameters(), 'lr': args.lr * 0.1}], lr=args.lr)start_epoch = 0for epoch in range(start_epoch, args.num_epochs):model.train()epoch_losses = AverageMeter()with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:t.set_description('epoch:{}/{}'.format(epoch, args.num_epochs - 1))for data in train_dataloader:label, lr_hs, hr_ms = datalabel = label.to(device)lr_hs = lr_hs.to(device)hr_ms = hr_ms.to(device)lr = optimizer.param_groups[0]['lr']pred = model(hr_ms, lr_hs)loss = criterion(pred, label)epoch_losses.update(loss.item(), len(label))optimizer.zero_grad()loss.backward()optimizer.step()t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg), lr='{0:1.8f}'.format(lr))t.update(len(label))# torch.save(model.state_dict(), os.path.join(outputs_dir, 'epoch_{}.pth'.format(epoch)))if epoch % 5 == 0:model.eval()val_loss = AverageMeter()SAM = Loss_SAM()RMSE = Loss_RMSE()PSNR = Loss_PSNR()sam = AverageMeter()rmse = AverageMeter()psnr = AverageMeter()for data in eval_dataloader:label, lr_hs, hr_ms = datalr_hs = lr_hs.to(device)hr_ms = hr_ms.to(device)label = label.cpu().numpy()with torch.no_grad():preds = model(hr_ms, lr_hs).cpu().numpy()sam.update(SAM(preds, label), len(label))rmse.update(RMSE(preds, label), len(label))psnr.update(PSNR(preds, label), len(label))if psnr.avg > best_psnr:best_epoch = epochbest_psnr = psnr.avgbest_weights = copy.deepcopy(model.state_dict())print('eval psnr: {:.2f}  RMSE: {:.2f}  SAM: {:.2f} '.format(psnr.avg, rmse.avg, sam.avg))

运行结果

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/172405.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MYSQL 及 SQL 注入

文章目录 前言什么是sql注入防止SQL注入Like语句中的注入后言 前言 hello world欢迎来到前端的新世界 😜当前文章系列专栏:Mysql 🐱‍👓博主在前端领域还有很多知识和技术需要掌握,正在不断努力填补技术短板。(如果出现…

Adversarial Attack on Graph Structured Data(2018 PMLR)

Adversarial Attack on Graph Structured Data----《图结构数据的对抗攻击》 摘要 基于图结构的深度学习已经在各种应用中显示出令人兴奋的结果。然而,与图像或文本对抗攻击和防御的大量研究工作相比,此类模型的鲁棒性却很少受到关注。在本文中&#xf…

⑩⑧【MySQL】InnoDB架构、事务原理、MVCC多版本并发控制

个人简介:Java领域新星创作者;阿里云技术博主、星级博主、专家博主;正在Java学习的路上摸爬滚打,记录学习的过程~ 个人主页:.29.的博客 学习社区:进去逛一逛~ InnoDB存储引擎 ⑩⑧【MySQL】详解InnoDB存储引…

Day42力扣打卡

打卡记录 统计子串中的唯一字符(找规律) 链接 大佬的题解 class Solution:def uniqueLetterString(self, s: str) -> int:ans total 0last0, last1 {}, {}for i, c in enumerate(s):total i - 2 * last0.get(c, -1) last1.get(c, -1)ans tot…

【数据中台】开源项目(2)-Dbus数据总线

1 背景 企业中大量业务数据保存在各个业务系统数据库中,过去通常的同步数据的方法有很多种,比如: 各个数据使用方在业务低峰期各种抽取所需数据(缺点是存在重复抽取而且数据不一致) 由统一的数仓平台通过sqoop到各个…

网络爬虫(Python:Requests、Beautiful Soup笔记)

网络爬虫(Python:Requests、Beautiful Soup笔记) 网络协议简要介绍一。OSI参考模型二、TCP/IP参考模型对应关系TCP/IP各层实现的协议应用层传输层网络层 HTTP协议HTTP请求HTTP响应HTTP状态码 Requests(Python)Requests…

【Vue】@keyup.enter @v-model.trim的用法

目录 keyup.enter v-model.trim 情景一: 情景二: keyup.enter 作用:监听键盘回车事件 上一篇内容: 记事本 https://blog.csdn.net/m0_67930426/article/details/134630834?spm1001.2014.3001.5502 这里有个添加任务的功能&…

ESP32控制数码管实现数字叠加案例

经过了几个小时的接线和代码实现终于搞定了代码,贴出来大家参考下 import machine import time# 定义4个Led的引脚 led1 machine.Pin(5,machine.Pin.OUT) led2 machine.Pin(18,machine.Pin.OUT) led3 machine.Pin(19,machine.Pin.OUT) led4 machine.Pin(21,mac…

手摸手vue2+Element-ui整合Axios

后端WebAPI准备 跨域问题 为了保证浏览器的安全,不同源的客户端脚本在没有明确授权的情况下,不能读写对方资源,称为同源策略,同源策略是浏览器安全的基石 同源策略( Sameoriginpolicy)是一种约定,它是浏览器最核心也最基本的安全功能 所谓同源(即指在同一个域)就是两个页面具…

Idea常用的快捷键

快捷键 快速生成main()方法:psvm,回车 快速生成输出语句:sout,回车 ctrlz撤回,ctrlshiftz取消撤回 ctrlr替换 CtrlAltspace(内容提示,代码补全等) ctrl句号。最小化方法,恢复最小化方法。 …

【数据中台】开源项目(2)-Moonbox计算服务平台

Moonbox是一个DVtaaS(Data Virtualization as a Service)平台解决方案。 Moonbox基于数据虚拟化设计思想,致力于提供批量计算服务解决方案。Moonbox负责屏蔽底层数据源的物理和使用细节,为用户带来虚拟数据库般使用体验&#xff0…

Tabular特征选择基准

学术实验中的表格基准通常是一小组精心选择的特征。相比之下,工业界数据科学家通常会收集尽可能多的特征到他们的数据集中,甚至从现有的特征中设计新的特征。为了防止在后续的下游建模中过拟合,数据科学家通常使用自动特征选择方法来获得特征子集。Tabular特征选择的现有基准…

JavaFX开发调用AWT创建系统托盘MenuItem菜单中文乱码

打开系统托盘MenuItem只能显示英文字符和中文显示方框 解决办法: 打开Edit Configurations… 选择Mofidy options 勾选Add VM options 在VM optios中填入以下代码 -Dfile.encodingGBK

【MySQL | TCP】宝塔面板结合内网穿透实现公网远程访问

文章目录 前言1.Mysql服务安装2.创建数据库3.安装cpolar3.2 创建HTTP隧道4.远程连接5.固定TCP地址5.1 保留一个固定的公网TCP端口地址5.2 配置固定公网TCP端口地址 前言 宝塔面板的简易操作性,使得运维难度降低,简化了Linux命令行进行繁琐的配置&#x…

Oracle的安装及使用流程

Oracle的安装及使用流程 1.Win10安装Oracle10g 1.1 安装与测试 安装版本: OracleXEUniv10.2.1015.exe 步骤参考:oracleXe下载与安装 安装完成后测试是否正常 # 输入命令连接oracle conn sys as sysdba; # 无密码,直接按回车 # 测试连接的s…

我的第一次SACC之旅

今年有很多第一次,第一次作为“游客”参加DTCC(中国数据库大会),第一次作为讲师参与ACDU中国行(成都站),第一次参加OB年度发布会(包含DBA老友会),而这次是第一…

leetcode面试经典150题——32 串联所有单词的子串(中等+困难)

题目: 串联所有单词的子串(1中等) 描述: 给定两个字符串 s 和 p,找到 s 中所有 p 的 异位词 的子串,返回这些子串的起始索引。不考虑答案输出的顺序。 异位词 指由相同字母重排列形成的字符串(包括相同的字符串&…

【涂鸦T2-U】1、开发环境搭建

前言 本章介绍T2-U的开发环境搭建流程,以及一些遇到的问题。 一、资料 试用网址: 【新品体验】涂鸦 T2-U 开发板免费试用 涂鸦官网文档: 涂鸦 T2-U 开发板 T2-U 模组规格书 T2-U 开发板 淘宝(资料较全): 涂鸦智能 TuyaOS开发…

【C语言】字母转换大小写的三种方法

🦄个人主页:修修修也 🎏所属专栏:C语言 ⚙️操作环境:Visual Studio 2022 目录 方法一:库函数法 1.小写转换大写:toupper()函数 2.大写转换小写:tolower()函数 方法二:自定义函数加减32法 1.小写转换大…

Redis报错:JedisConnectionException: Could not get a resource from the pool

1、问题描述: redis.clients.jedis.exceptions.JedisConnectionException: Could not get a resource from the pool 2、简要分析: redis.clients.util.Pool.getResource会从JedisPool实例池中返回一个可用的redis连接。分析源码可知JedisPool 继承了 r…