[高光谱]PyTorch使用CNN对高光谱图像进行分类

项目原地址:

Hyperspectral-Classificationicon-default.png?t=N6B9https://github.com/eecn/Hyperspectral-ClassificationDataLoader讲解:

[高光谱]使用PyTorch的dataloader加载高光谱数据icon-default.png?t=N6B9https://blog.csdn.net/weixin_37878740/article/details/130929358

一、模型加载

        在原始项目中,提供了14种模型可供选择,从最简单的SVM到3D-CNN,这里以2D-CNN为例,在原项目中需要将model属性设置为:sharma。

         模型通过一个get_model(.)函数获得,该函数一共四个返回(model, optimizer, loss, hyperparams;分别为:模型,迭代器,损失函数,超参数),输入为模型类别。

        进入函数内部,找到对应的函数体如下:

elif name == 'sharma':kwargs.setdefault('batch_size', 60)        #batch_szieepoch = kwargs.setdefault('epoch', 30)     #迭代数lr = kwargs.setdefault('lr', 0.05)         #学习率center_pixel = True                        #是否开启中心像素模型# We assume patch_size = 64kwargs.setdefault('patch_size', 64)        #patch_szie,即图像块大小model = SharmaEtAl(n_bands, n_classes, patch_size=kwargs['patch_size'])  #模型本体optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=0.0005)    #迭代器criterion = nn.CrossEntropyLoss(weight=kwargs['weights'])            #交叉熵损失函数kwargs.setdefault('scheduler', optim.lr_scheduler.MultiStepLR(optimizer, milestones=[epoch // 2, (5 * epoch) // 6], gamma=0.1))

        这里设置了一部分超参数,同时设置了patch_size为64(此概念可以参见dataloader篇),采用的损失函数为常见的交叉熵损失函数,而模型本体则是使用SharmaEtAl(.)进行加载。

二、模型本体

        跳转至SharmaEtAl(nn.Module),其继承自nn.model,输入参数3个,分别为:输入通道数、分类数、图块尺寸。

def __init__(self, input_channels, n_classes, patch_size=64):

  该网络的结构如图,模型中里面包含3个卷积、2个bn、2个池化和2个全连接,如下:

# 卷积层1
self.conv1 = nn.Conv3d(1, 96, (input_channels, 6, 6), stride=(1,2,2))
self.conv1_bn = nn.BatchNorm3d(96)
self.pool1 = nn.MaxPool3d((1, 2, 2))
# 卷积层2
self.conv2 = nn.Conv3d(1, 256, (96, 3, 3), stride=(1,2,2))
self.conv2_bn = nn.BatchNorm3d(256)
self.pool2 = nn.MaxPool3d((1, 2, 2))
# 卷积层3
self.conv3 = nn.Conv3d(1, 512, (256, 3, 3), stride=(1,1,1))# 展平函数
self.features_size = self._get_final_flattened_size()# 由两个全连接组成的分类器
self.fc1 = nn.Linear(self.features_size, 1024)
self.dropout = nn.Dropout(p=0.5)
self.fc2 = nn.Linear(1024, n_classes)

        其中的展平函数_get_final_flattened_size(.),并不实际参与前向传递,仅计算转换后的通道数。

    def _get_final_flattened_size(self):with torch.no_grad():x = torch.zeros((1, 1, self.input_channels,self.patch_size, self.patch_size))x = F.relu(self.conv1_bn(self.conv1(x)))x = self.pool1(x)print(x.size())b, t, c, w, h = x.size()x = x.view(b, 1, t*c, w, h) x = F.relu(self.conv2_bn(self.conv2(x)))x = self.pool2(x)print(x.size())b, t, c, w, h = x.size()x = x.view(b, 1, t*c, w, h) x = F.relu(self.conv3(x))print(x.size())_, t, c, w, h = x.size()return t * c * w * h

        实际的前向传递如下:

    def forward(self, x):# 卷积块1x = F.relu(self.conv1_bn(self.conv1(x)))x = self.pool1(x)# 获取tensor尺寸b, t, c, w, h = x.size()# 调整tensor尺寸x = x.view(b, 1, t*c, w, h) # 卷积块2x = F.relu(self.conv2_bn(self.conv2(x)))x = self.pool2(x)# 获取tensor尺寸b, t, c, w, h = x.size()# 调整tensor尺寸x = x.view(b, 1, t*c, w, h) # 卷积块3x = F.relu(self.conv3(x))# 调整tensor尺寸x = x.view(-1, self.features_size)# 分类器x = self.fc1(x)x = self.dropout(x)x = self.fc2(x)return x

三、训练与测试

        主函数中,训练和测试结构如下:

        try:train(model, optimizer, loss, train_loader, hyperparams['epoch'],scheduler=hyperparams['scheduler'], device=hyperparams['device'],supervision=hyperparams['supervision'], val_loader=val_loader,display=viz)except KeyboardInterrupt:# Allow the user to stop the trainingpassprobabilities = test(model, img, hyperparams)prediction = np.argmax(probabilities, axis=-1)

        训练被封装在train(.)函数中,测试封装在test(.)函数中,下面逐一来看。

        首先是train函数,这里省去外围部分,仅看核心的循环控制段。

# 外循环控制,用于控制轮次(epoch)
for e in tqdm(range(1, epoch + 1), desc="Training the network"):# 进入训练模式net.train()avg_loss = 0.# 从dataloader中取出图像(data)和标签(target)for batch_idx, (data, target) in tqdm(enumerate(data_loader), total=len(data_loader)):# 如果是GPU模式则需要转换为cuda格式data, target = data.to(device), target.to(device)#---实际的训练部分---## 冻结梯度optimizer.zero_grad()# 训练模式(监督训练/半监督训练)if supervision == 'full':# 前向传递output = net(data)#target = target - 1# 交叉熵损失函数loss = criterion(output, target)elif supervision == 'semi':outs = net(data)output, rec = outs#target = target - 1loss = criterion[0](output, target) + net.aux_loss_weight * criterion[1](rec, data)#---实际的训练部分---## 损失函数反向传递loss.backward()# 迭代器步进optimizer.step()# 记录损失函数avg_loss += loss.item()losses[iter_] = loss.item()mean_losses[iter_] = np.mean(losses[max(0, iter_ - 100):iter_ + 1])iter_ += 1del(data, target, loss, output)

        接下来是test函数,与train不同的是,其参数为:model, img, hyperparams。其中img,是一整张高光谱图像,而不是由DataSet块采样后的图像块。故其结构也与train大不相同。

        在进行测试的时候,需要一个滑动窗口(sliding_window)函数将其进行切块以满足图像输入的要求。同时还需要一个grouper函数将其组装为batch送入神经网络中。所以我们可以看到循环控制的最外层实际上就是上面两个函数来组成的。

    # 图像切块iterations = count_sliding_window(img, **kwargs) // batch_sizefor batch in tqdm(grouper(batch_size, sliding_window(img, **kwargs)),total=(iterations),desc="Inference on the image"):#  锁定梯度with torch.no_grad():#  逐像素模式if patch_size == 1:data = [b[0][0, 0] for b in batch]data = np.copy(data)data = torch.from_numpy(data)# 其他模式else:data = [b[0] for b in batch]data = np.copy(data)data = data.transpose(0, 3, 1, 2)data = torch.from_numpy(data)data = data.unsqueeze(1)indices = [b[1:] for b in batch]# 类型转换data = data.to(device)# 前向传递output = net(data)if isinstance(output, tuple):output = output[0]output = output.to('cpu')if patch_size == 1 or center_pixel:output = output.numpy()else:output = np.transpose(output.numpy(), (0, 2, 3, 1))for (x, y, w, h), out in zip(indices, output):# 将得到的像素平装回原尺寸if center_pixel:probs[x + w // 2, y + h // 2] += outelse:probs[x:x + w, y:y + h] += outreturn probs

        这个函数会使用上述的两个函数,将图像切割成可以放入神经网络的尺寸并逐个进行前向传递,最后将得到的所有像素的结果按照原来的尺寸组成一个结果矩阵返回。

        最后,这个结果由一个argmax函数得到其概率最大的预测结果:

prediction = np.argmax(probabilities, axis=-1)

四、结果计算

        在完成上述步骤后,由metrics(.)函数计算最终的模型结果:

run_results = metrics(prediction, test_gt, ignored_labels=hyperparams['ignored_labels'], n_classes=N_CLASSES)

        其函数体如下:

def metrics(prediction, target, ignored_labels=[], n_classes=None):"""Compute and print metrics (accuracy, confusion matrix and F1 scores).Args:prediction: list of predicted labelstarget: list of target labelsignored_labels (optional): list of labels to ignore, e.g. 0 for undefn_classes (optional): number of classes, max(target) by defaultReturns:accuracy, F1 score by class, confusion matrix"""ignored_mask = np.zeros(target.shape[:2], dtype=np.bool)for l in ignored_labels:ignored_mask[target == l] = Trueignored_mask = ~ignored_mask#target = target[ignored_mask] -1target = target[ignored_mask]prediction = prediction[ignored_mask]results = {}n_classes = np.max(target) + 1 if n_classes is None else n_classescm = confusion_matrix(target,prediction,labels=range(n_classes))results["Confusion matrix"] = cm# Compute global accuracytotal = np.sum(cm)accuracy = sum([cm[x][x] for x in range(len(cm))])accuracy *= 100 / float(total)results["Accuracy"] = accuracy# Compute F1 scoreF1scores = np.zeros(len(cm))for i in range(len(cm)):try:F1 = 2. * cm[i, i] / (np.sum(cm[i, :]) + np.sum(cm[:, i]))except ZeroDivisionError:F1 = 0.F1scores[i] = F1results["F1 scores"] = F1scores# Compute kappa coefficientpa = np.trace(cm) / float(total)pe = np.sum(np.sum(cm, axis=0) * np.sum(cm, axis=1)) / \float(total * total)kappa = (pa - pe) / (1 - pe)results["Kappa"] = kappareturn results

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/41572.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用JMeter创建数据库测试

好吧!我一直觉得我不聪明,所以,我用最详细,最明了的方式来书写这个文章。我相信,我能明白的,你们一定能明白。 我的环境:MySQL:mysql-essential-5.1.51-win32 jdbc驱动:…

mysql 03.查询(重点)

先准备测试数据,代码如下: -- 创建数据库 DROP DATABASE IF EXISTS mydb; CREATE DATABASE mydb; USE mydb;-- 创建student表 CREATE TABLE student (sid CHAR(6),sname VARCHAR(50),age INT,gender VARCHAR(50) DEFAULT male );-- 向student表插入数据…

PHP 公交公司充电桩管理系统mysql数据库web结构apache计算机软件工程网页wamp

一、源码特点 PHP 公交公司充电桩管理系统是一套完善的web设计系统,对理解php编程开发语言有帮助,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。 源码下载 https://download.csdn.net/download/qq_41221322/88220946 论文下…

苹果手机批量删除联系人的2个方法,请查收!

【想要清理通讯录里的“僵尸号”,但是突然发现手机不能批量删除。一个一个删除太麻烦了,有什么办法可以一次性多删几个人吗?】 小编想问问果粉们平时都是怎么删除联系人的?特别是要删除多个联系人的时候,大家还是选择…

matlab保存图片

仅作为记录,大佬请跳过。 文章目录 用界面中的“另存为”用saveas 用界面中的“另存为” 即可。 参考 感谢大佬博主文章:传送门 用saveas 必须在编辑器中的plot之后用saveas(也就是不能在命令行中单独使用——比如在编辑器中plot&#xf…

基于平台的城市排水泵站管理系统设计

安科瑞 耿敏花 近年来我国城市内涝灾害频发,造成人员伤亡以及经济损失严重,严重威胁着城市的安全。数据显示,2015-2018年我国平均每年受淹或发生内涝城市的数量约占我国城市数量的1/5;人民生命财产也损失严重,据统计&a…

基于YOLOv5n/s/m不同参数量级模型开发构建茶叶嫩芽检测识别模型,使用pruning剪枝技术来对模型进行轻量化处理,探索不同剪枝水平下模型性能影响【续】

这里主要是前一篇博文的后续内容,简单回顾一下:本文选取了n/s/m三款不同量级的模型来依次构建训练模型,所有的参数保持同样的设置,之后探索在不同剪枝处理操作下的性能影响。 在上一篇博文中保持30的剪枝程度得到的效果还是比较理…

C++ 学习系列3 -- 函数压栈与出栈

在C中,函数压栈(函数调用)和出栈(函数返回)是函数调用过程中的两个关键步骤。下面将逐步解释这两个过程: 一 函数压栈与出栈过程简介 函数压栈(函数调用)的过程如下: …

2020年3月全国计算机等级考试真题(C语言二级)

2020年3月全国计算机等级考试真题(C语言二级) 第1题 有以下程序 void fun1 (char*p) { char*q; qp; while(*q!\0) { (*Q); q; } } main() { char a[]{"Program"},*p; p&a[3]; fun1(p); print…

新能源电动车充电桩控制主板安全特点

新能源电动车充电桩控制主板安全特点 你是否曾经担心过充电桩的安全问题?充电桩主板又是什么样的呢?今天我们就来聊聊这个话题。 充电桩主板采用双重安全防护系统,包括防水、防护、防尘等,确保充电桩安全、可靠。不仅如此,充电桩主板采用先…

简单的洗牌算法

目录 前言 问题 代码展现及分析 poker类 game类 Text类 前言 洗牌算法为ArrayList具体使用的典例,可以很好的让我们快速熟系ArrayList的用法。如果你对ArrayList还不太了解除,推荐先看本博主的ArrayList的详解。 ArrayList的详解_WHabcwu的博客-CSD…

mysql mysql 容器 忽略大小写配置

首先能够连接上mysql,然后输入下面这个命令查看mysql是否忽略大小写 show global variables like %lower_case%; lower_case_table_names 0:不忽略大小写 lower_case_table_names 1:忽略大小写 mysql安装分为两种(根据自己的my…

FPGA芯片IO口上下拉电阻的使用

FPGA芯片IO口上下拉电阻的使用 为什么要设置上下拉电阻一、如何设置下拉电阻二、如何设置上拉电阻为什么要设置上下拉电阻 这里以高云FPGA的GW1N-UV2QN48C6/I5来举例,这个芯片的上电默认初始化阶段,引脚是弱上来模式,且模式固定不能通过软件的配置来改变。如下图所示: 上…

centos 7.x 单用户模式

最近碰到 centos 7.9 一些参数设置错误无法启动系统的情况,研究后可以使用单用户模式进入系统进行恢复操作。 进入启动界面,按 e ro 替换为 rw init/sysroot/bin/sh 替换前 替换后 Ctrl-x 进行重启进入单用户模式 执行 chroot /sysroot 可以查看日…

java练习4.快速查找

题目: 数组 arr[6,1,3,7,9,8,5,4,2],用快速排序进行升序排序. import java.util.Random;public class recursionDemo {public static void main(String[] args) {/*快速排序:* 第一轮:以0索引为基准数,确定基准数在数组正确的位置,* 比基准数小的放到左边,比基准数大的放在右边…

Scada和lloT有什么区别?

人们经常混淆SCADA(监督控制和数据采集)和IIoT(工业物联网)。虽然SCADA系统已经存在多年,但IIoT是一种相对较新的技术,由于其能够收集和分析来自各种设备的大量数据而越来越受欢迎。SCADA和IIoT都用于提高工…

【学习笔记之vue】These dependencies were not found:

These dependencies were not found:方案一 全部安装一遍 我们先浅试一个axios >> npm install axios 安装完报错就没有axios了,验证咱们的想法没有问题,实行! ok

性能分析之MySQL慢查询日志分析(慢查询日志)

一、背景 MySQL的慢查询日志是MySQL提供的一种日志记录,他用来记录在MySQL中响应的时间超过阈值的语句,具体指运行时间超过long_query_time(默认是10秒)值的SQL,会被记录到慢查询日志中。 慢查询日志一般用于性能分析时开启,收集慢SQL然后通过explain进行全面分析,一…

使用PDF文件入侵任何操作系统

提示:我们8月28号开学,所以我得快点更新了,不能拖了😥 文章目录 前言一、打开终端总结 前言 PDF文件被广泛应用于共享信息,电子邮件,网站或文档或存储系统的真实链接 它可以用于恶意软件的载体。 不要问我什么意思&am…

在项目中如何解除idea和Git的绑定

在项目中如何解除idea和Git的绑定 1、点击File--->Settings...(CtrlAltS)--->Version Control--->Directory Mappings--->点击取消Git的注册根路径: 2、回到idea界面就没有Git了: 3、给这个项目初始化 这样就可以重新绑定远程仓库了&#x…