遥感图像应用:在低分辨率图像上实现洪水损害检测(迁移学习)

本文是上一篇关于“在低分辨率图像上实现洪水损害检测”的博客的延申。

代码来源:https://github.com/weining20000/Flooding-Damage-Detection-from-Post-Hurricane-Satellite-Imagery-Based-on-CNN/tree/master

数据储存地址:https://github.com/JeffereyWu/FloodDamageDetection/tree/main

目标:利用迁移学习训练两个预训练的CNN模型(VGGResnet),自动化识别一个区域是否存在洪水损害。

运行环境:Google Colab

1. 导入库

# Pytoch
import torch
from torchvision import datasets, models
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn as nn
from torch_lr_finder import LRFinder# Data science tools
import numpy as np
import pandas as pd
import os
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrixfrom PIL import Image# Visualizations
import matplotlib.pyplot as plt
import seaborn as sns

2. 迁移学习知识点

  • 对于卷积神经网络(CNN)等模型,通常包括一些卷积层和池化层,这些层的权重用于提取图像的特征。当这些层的参数被冻结时,这些权重将保持不变,不会在训练过程中进行更新。这意味着模型会继续使用预训练模型的特征提取能力
  • 如果模型还包含其他的预训练层,例如预训练的全连接层,这些层的权重也将被冻结,不会更新。
  • 通常,当使用预训练模型进行微调时,会替换模型的最后一层或几层,以适应新的任务。新添加的自定义分类器层的权重将被训练和更新,以适应特定的分类任务。

3. 加载和配置预训练的深度学习模型

#Load pre-trained model
def get_pretrained_model(model_name):"""获取预训练模型的函数。参数:model_name: 要加载的预训练模型的名称(例如,'vgg16' 或 'resnet50')返回:MODEL: 加载并配置好的预训练模型"""if model_name == 'vgg16':model = models.vgg16(pretrained=True)# 将模型的参数(权重)冻结,不进行微调。这意味着这些参数在训练过程中不会更新for param in model.parameters():param.requires_grad = Falsen_inputs = model.classifier[6].in_features # 获取模型分类器最后一层的输入特征数n_classes = 2# 替换模型的分类器部分,添加自定义的分类器model.classifier[6] = nn.Sequential(nn.Linear(n_inputs, 256), nn.ReLU(), nn.Dropout(0.2),nn.Linear(256, n_classes))elif model_name == 'resnet50':model = models.resnet50(pretrained=True)for param in model.parameters():param.requires_grad = False# 获取模型最后一层全连接层的输入特征数n_inputs = model.fc.in_featuresn_classes = 2model.fc = nn.Sequential(nn.Linear(n_inputs, 256), nn.ReLU(), nn.Dropout(0.2),nn.Linear(256, n_classes))# Move to GPUMODEL = model.to(device)return MODEL # 返回加载和配置好的预训练模型

注意,这里vgg16的classifier结构原本为:
Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
以上代码替换了最后一层的classifier,改为:
Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Sequential(
(0): Linear(in_features=4096, out_features=256, bias=True)
(1): ReLU()
(2): Dropout(p=0.2, inplace=False)
(3): Linear(in_features=256, out_features=2, bias=True)
)
)

注意,这里resnet50的fc结构原本为:
Linear(in_features=2048, out_features=1000, bias=True)
以上代码替换了最后一层fc,改为:
Sequential(
(0): Linear(in_features=2048, out_features=256, bias=True)
(1): ReLU()
(2): Dropout(p=0.2, inplace=False)
(3): Linear(in_features=256, out_features=2, bias=True)
)

4. 建立模型

# VGG 16
model_vgg = get_pretrained_model('vgg16') # 包含加载和配置好的 VGG16 模型
criterion_vgg = nn.CrossEntropyLoss()
optimizer_vgg = torch.optim.Adam(model_vgg.parameters(), lr=0.00002)# ResNet 50
model_resnet50 = get_pretrained_model('resnet50') # 包含加载和配置好的 ResNet50 模型
criterion_resnet50 = nn.CrossEntropyLoss() 
optimizer_resnet50 = torch.optim.Adam(model_resnet50.parameters(), lr=0.001)

5. 定义计算准确率的函数

def acc_vgg(x, y, return_labels=False):with torch.no_grad(): # 禁止梯度计算,因为在准确率计算中不需要梯度信息logits = model_vgg(x)pred_labels = np.argmax(logits.cpu().numpy(), axis=1)if return_labels:return pred_labelselse:return 100*accuracy_score(y.cpu().numpy(), pred_labels)def acc_resnet50(x, y, return_labels=False):with torch.no_grad():logits = model_resnet50(x)pred_labels = np.argmax(logits.cpu().numpy(), axis=1)if return_labels:return pred_labelselse:return 100*accuracy_score(y.cpu().numpy(), pred_labels)

6. 定义一个用于训练深度学习模型的函数

def train(model, criterion, optimizer, acc, xtrain, ytrain, xval, yval, save_file_name, n_epochs, BATCH_SIZE):"""训练深度学习模型的函数。参数:model: 要训练的深度学习模型criterion: 损失函数optimizer: 优化器acc: 准确率计算函数xtrain: 训练数据ytrain: 训练标签xval: 验证数据yval: 验证标签save_file_name: 保存训练后模型权重的文件名n_epochs: 训练的总轮数(epochs)BATCH_SIZE: 每个批次的样本数量返回:训练完成的模型和训练历史记录"""history1 = []# Number of epochs already trained (if using loaded in model weights)try:print(f'Model has been trained for: {model.epochs} epochs.\n')except:model.epochs = 0print(f'Starting Training from Scratch.\n')# Main loopfor epoch in range(n_epochs):# keep track of training and validation loss each epochtrain_loss = 0.0val_loss = 0.0train_acc = 0val_acc = 0# Set to trainingmodel.train()#Training loopfor batch in range(len(xtrain)//BATCH_SIZE):idx = slice(batch * BATCH_SIZE, (batch+1)*BATCH_SIZE)# Clear gradientsoptimizer.zero_grad()# Predicted outputsoutput = model(xtrain[idx])# Loss and BP of gradientsloss = criterion(output, ytrain[idx])loss.backward()# Update the parametersoptimizer.step()# Track train losstrain_loss += loss.item()train_acc = acc(xtrain, ytrain)# After training loops ends, start validation# set to evaluation modemodel.eval()# Don't need to keep track of gradientswith torch.no_grad():# Evaluation loop# F.P.y_val_pred = model(xval)# Validation lossloss = criterion(y_val_pred, yval)val_loss = loss.item()val_acc = acc(xval, yval)history1.append([train_loss / BATCH_SIZE, val_loss, train_acc, val_acc])torch.save(model.state_dict(), save_file_name) # 保存模型权重torch.cuda.empty_cache()# Print training and validation resultsprint("Epoch {} | Train Loss: {:.5f} | Train Acc: {:.2f} | Valid Loss: {:.5f} | Valid Acc: {:.2f} |".format(epoch, train_loss / BATCH_SIZE, acc(xtrain, ytrain), val_loss, acc(xval, yval)))# Format historyhistory = pd.DataFrame(history1, columns=['train_loss', 'val_loss', 'train_acc', 'val_acc'])return model, history

7. 开始训练

N_EPOCHS = 30model_vgg, history_vgg = train(model_vgg,criterion_vgg,optimizer_vgg,acc_vgg,x_train,y_train,x_val,y_val,save_file_name = 'model_vgg.pt',n_epochs = N_EPOCHS,BATCH_SIZE = 3)model_resnet50, history_resnet50 = train(model_resnet50,criterion_resnet50,optimizer_resnet50,acc_resnet50,x_train,y_train,x_val,y_val,save_file_name = 'model_resnet50.pt',n_epochs = N_EPOCHS,BATCH_SIZE = 4)

8. 绘画VGG训练和验证准确率的曲线图

plt.figure() # 创建一个新的绘图窗口
vgg_train_acc = history_vgg['train_acc']
vgg_val_acc = history_vgg['val_acc']
vgg_epoch = range(0, len(vgg_train_acc), 1) # 创建一个包含训练轮次(epochs)的范围对象
plot1, = plt.plot(vgg_epoch, vgg_train_acc, linestyle = "solid", color = "skyblue")
plot2, = plt.plot(vgg_epoch, vgg_val_acc, linestyle = "dashed", color = "orange")
plt.legend([plot1, plot2], ['training acc', 'validation acc']) # 添加图例,以标识图中的两条曲线
plt.xlabel('Epoch')
plt.ylabel('Average Accuracy per Batch')
plt.title('Model VGG-16: Training and Validation Accuracy', pad = 20)
plt.savefig('VGG16-Acc-Plot.png')

在这里插入图片描述

9. 绘画VGG训练和验证损失的曲线图

plt.figure()
vgg_train_loss = history_vgg['train_loss']
vgg_val_loss = history_vgg['val_loss']
vgg_epoch = range(0, len(vgg_train_loss), 1)
plot3, = plt.plot(vgg_epoch, vgg_train_loss, linestyle = "solid", color = "skyblue")
plot4, = plt.plot(vgg_epoch, vgg_val_loss, linestyle = "dashed", color = "orange")
plt.legend([plot3, plot4], ['training loss', 'validation loss'])
plt.xlabel('Epoch')
plt.ylabel('Average Loss per Batch')
plt.title('Model VGG-16: Training and Validation Loss', pad = 20)
plt.savefig('VGG16-Loss-Plot.png')

在这里插入图片描述

10. 绘画Resnet训练和验证准确率的曲线图

# Training Reseults: Resnet50
plt.figure()
resnet50_train_acc = history_resnet50['train_acc']
resnet50_val_acc = history_resnet50['val_acc']
resnet50_epoch = range(0, len(resnet50_train_acc), 1)
plot5, = plt.plot(resnet50_epoch, resnet50_train_acc, linestyle = "solid", color = "skyblue")
plot6, = plt.plot(resnet50_epoch, resnet50_val_acc, linestyle = "dashed", color = "orange")
plt.legend([plot5, plot6], ['training acc', 'validation acc'])
plt.xlabel('Epoch')
plt.ylabel('Average Accuracy per Batch')
plt.title('Model Resnet50: Training and Validation Accuracy', pad = 20)
plt.savefig('Resnet50-Acc-Plot.png')

在这里插入图片描述

11. 绘画Resnet训练和验证损失的曲线图

plt.figure()
resnet50_train_loss = history_resnet50['train_loss']
resnet50_val_loss = history_resnet50['val_loss']
resnet50_epoch = range(0, len(resnet50_train_loss), 1)
plot7, = plt.plot(resnet50_epoch, resnet50_train_loss, linestyle = "solid", color = "skyblue")
plot8, = plt.plot(resnet50_epoch, resnet50_val_loss, linestyle = "dashed", color = "orange")
plt.legend([plot7, plot8], ['training loss', 'validation loss'])
plt.xlabel('Epoch')
plt.ylabel('Average Loss per Batch')
plt.title('Model Resnet50: Training and Validation Loss', pad = 20)
plt.savefig('Resnet50-Loss-Plot.png')

在这里插入图片描述

12. 绘画验证损失的比较图

plt.figure()
df_valid_loss = pd.DataFrame({'Epoch': range(0, N_EPOCHS, 1),'valid_loss_vgg': history_vgg['val_loss'],'valid_loss_resnet50':history_resnet50['val_loss']})
plota, = plt.plot('Epoch', 'valid_loss_vgg', data=df_valid_loss, linestyle = '--', color = 'skyblue')
plotb, = plt.plot('Epoch', 'valid_loss_resnet50', data=df_valid_loss, color = 'orange')
plt.xlabel('Epoch')
plt.ylabel('Average Validation Loss per Batch')
plt.title('Validation Loss Comparison', pad = 20)
plt.legend([plota, plotb], ['VGG16', 'Resnet50'])
plt.savefig('Result_Comparison.png')

在这里插入图片描述

13. 定义并执行预测函数

def predict(mymodel, model_name_pt, loader):model = mymodelmodel.load_state_dict(torch.load(model_name_pt))model.to(device)model.eval()y_actual_np = []y_pred_np = []for idx, data in enumerate(loader):test_x, test_label = data[0], data[1]test_x = test_x.to(device)y_actual_np.extend(test_label.cpu().numpy().tolist())with torch.no_grad():y_pred_logits = model(test_x)pred_labels = np.argmax(y_pred_logits.cpu().numpy(), axis=1)print("Predicting ---->", pred_labels)y_pred_np.extend(pred_labels.tolist())return y_actual_np, y_pred_npy_actual_vgg, y_predict_vgg = predict(model_vgg, "model_vgg.pt", test_loader)
y_actual_resnet50, y_predict_resnet50 = predict(model_resnet50, "model_resnet50.pt", test_loader)

14. 计算VGG的准确性和混淆矩阵

# VGG-16 Accuracy
print("=====================================================")
acc_rate_vgg = 100*accuracy_score(y_actual_vgg, y_predict_vgg)
print("The Accuracy rate for the VGG-16 model is: ", acc_rate_vgg)
# Confusion matrix for model-VGG-16
print("The Confusion Matrix for VGG-16 is as below:")
print(confusion_matrix(y_actual_vgg, y_predict_vgg))

输出为:

The Accuracy rate for the VGG-16 model is: 88.16666666666667
The Confusion Matrix for VGG-16 is as below:
[[7106 894]
[ 171 829]]

15. 计算Resnet的准确性和混淆矩阵

print("=====================================================")
acc_rate_resnet50 = 100*accuracy_score(y_actual_resnet50, y_predict_resnet50)
print("The Accuracy rate for the Resnet50 model is: ", acc_rate_resnet50)
# Confusion matrix for model Resnet50
print("The Confusion Matrix for Resnet50 is as below:")
print(confusion_matrix(y_actual_resnet50, y_predict_resnet50))

输出为:

The Accuracy rate for the Resnet50 model is: 85.35555555555555
The Confusion Matrix for Resnet50 is as below:
[[6843 1157]
[ 161 839]]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/75061.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CSS宽度问题

一、魔法 为 DOM 设置宽度有哪些方式呢?最常用的是配置width属性,width属性在配置时,也有多种方式: widthmin-widthmax-width 通常当配置了 width 时,不会再配置min-width max-width,如果将这三者混合使…

MySql 变量

1.系统变量 1.1 系统变量分类 变量由系统定义,不是用户定义,属于 服务器 层面。系统变量分为全局系统变量(需要添加 global 关键字)以及会话系统变量(需要添加 session 关键字),有时也把全局系…

Web安全——Web安全漏洞与利用上篇(仅供学习)

SQL注入 一、SQL 注入漏洞1、与 mysql 注入的相关知识2、SQL 注入原理3、判断是否存在注入回显是指页面有数据信息返回id 1 and 114、三种 sql 注释符5、注入流程6、SQL 注入分类7、接受请求类型区分8、注入数据类型的区分9、SQL 注入常规利用思路:10、手工注入常规…

MySQL的权限管理与远程访问

MySQL的权限管理 1、授予权限 授权命令: grant 权限1,权限2,…权限n on 数据库名称.表名称 to 用户名用户地址 identified by ‘连接口令’; 该权限如果发现没有该用户,则会直接新建一个用户。 比如 grant select,insert,delete,drop on atguigudb.…

驱动开发,stm32mp157a开发板的led灯控制实验

1.实验目的 编写LED灯的驱动,在应用程序中编写控制LED灯亮灭的代码逻辑实现LED灯功能的控制; 2.LED灯相关寄存器分析 LED1->PE10 LED1亮灭: RCC寄存器[4]->1 0X50000A28 GPIOE_MODER[21:20]->01 (输出) 0X50006000 GPIOE_ODR[10]-&g…

文件操作(个人学习笔记黑马学习)

C中对文件操作需要包含头文件<fstream > 文件类型分为两种: 1.文本文件&#xff1a;文件以文本的ASCII码形式存储在计算机中 2.二进制文件&#xff1a;文件以文本的二进制形式存储在计算机中&#xff0c;用户一般不能直接读懂它们 操作文件的三大类: 1.ofstream: 写操作 …

SpringBoot项目启动时预加载

SpringBoot项目启动时预加载 Spring Boot是一种流行的Java开发框架&#xff0c;它提供了许多方便的功能来简化应用程序的开发和部署。其中一个常见的需求是在Spring Boot应用程序启动时预加载一些数据或执行一些初始化操作。 1. CommandLineRunner 和 ApplicationRunner Spri…

Nginx 配置错误导致漏洞

文章目录 Nginx 配置错误导致漏洞1. 环境启动2. CRLF注入漏洞2.1 漏洞描述2.2 漏洞原理2.3 漏洞利用2.4 修复建议 3. 目录穿越漏洞3.1 漏洞描述3.2 漏洞原理3.3 漏洞利用3.4 修复建议 4. add_header被覆盖4.1 漏洞描述4.2 漏洞原理4.3 漏洞利用4.4 修复建议 Nginx 配置错误导致…

YOLO的基本原理详解

YOLO介绍 YOLO是一种新的目标检测方法。以前的目标检测方法通过重新利用分类器来执行检测。与先前的方案不同&#xff0c;将目标检测看作回归问题从空间上定位边界框&#xff08;bounding box&#xff09;并预测该框的类别概率。使用单个神经网络&#xff0c;在一次评估中直接…

C#winform导出DataGridView数据到Excel表

前提&#xff1a;NuGet安装EPPlus&#xff0c;选择合适的能兼容当前.net framwork的版本 主要代码&#xff1a; private void btn_export_Click(object sender, EventArgs e) {SaveFileDialog saveFileDialog new SaveFileDialog();saveFileDialog.Filter "Excel Files…

Mysql基于成本选择索引

本篇文章介绍mysql基于成本选择索引的行为&#xff0c;解释为什么有时候明明可以走索引&#xff0c;但mysql却没有走索引的原因 mysql索引失效的场景大致有几种 不符合最左前缀原则在索引列上使用函数或隐式类型转换使用like查询&#xff0c;如 %xxx回表代价太大索引列区分度过…

ElMessageBox.prompt 点击确认校验成功后关闭

ElMessageBox.prompt(, 验证取货码, {inputPattern: /^.{1,20}$/,inputErrorMessage: 请输入取货码,inputPlaceholder: 请输入取货码,beforeClose: (action, instance, done) > {if (action confirm) {if (instance.inputValue) {let flag false;if (flag) {done()} else …

pandas(四十三)Pandas实现复杂Excel的转置合并

一、Pandas实现复杂Excel的转置合并 读取并筛选第一张表 df1 pd.read_excel("第一个表.xlsx") df1# 删除无用列 df1 df1[[股票代码, 高数, 实际2]].copy() df1df1.dtypes股票代码 int64 高数 float64 实际2 int64 dtype: object读取并处理第二张表…

jmeter 数据库连接配置 JDBC Connection Configuration

jmeter 从数据库获取变量信息 官方文档参考&#xff1a; [jmeter安装路径]/printable_docs/usermanual/component_reference.html#JDBC_Connection_Configuration 引入数据库连接&#xff1a; 将MySQLjar包存放至jemter指定目录&#xff08;/apache-jmeter-3.3/lib&#xff09…

buuctf web 前5题

目录 一、[极客大挑战 2019]EasySQL 总结&#xff1a; 二、[极客大挑战 2019]Havefun 总结&#xff1a; 三、[HCTF 2018]WarmUp 总论&#xff1a; 四、[ACTF2020 新生赛]Include 总结&#xff1a; 五、[ACTF2020 新生赛]Exec 总结&#xff1a; 一、[极客大挑战 2019]…

有哪些适合初学者的编程语言?

C语言 那为什么我还要教你C语言呢&#xff1f;因为我想要让你成为一个更好、更强大的程序员。如果你要变得更好&#xff0c;C语言是一个极佳的选择&#xff0c;其原因有二。首先&#xff0c;C语言缺乏任何现代的安全功能&#xff0c;这意味着你必须更为警惕&#xff0c;时刻了…

Json“牵手”易贝商品详情数据方法,易贝商品详情API接口,易贝API申请指南

易贝是一个可让全球民众在网上买卖物品的线上拍卖及购物网站&#xff0c;易贝&#xff08;EBAY&#xff09;于1995年9月4日由Pierre Omidyar以Auctionweb的名称创立于加利福尼亚州圣荷塞。人们可以在易贝上通过网络出售商品。2014年2月20日&#xff0c;易贝宣布收购3D虚拟试衣公…

SpringMVC的简介及工作流程

一.简介 Spring MVC是一个基于Java的开发框架&#xff0c;用于构建灵活且功能强大的Web应用程序。它是Spring Framework的一部分&#xff0c;提供了一种模型-视图-控制器&#xff08;Model-View-Controller&#xff0c;MVC&#xff09;的设计模式&#xff0c;用于组织和管理Web…

MATLAB中isoutlier函数用法

目录 语法 说明 示例 检测向量中的离群值 使用均值检测方法 使用移窗检测法 检测矩阵中的离群值 可视化离群值阈值 isoutlier函数的功能是查找数据中的离群值 语法 TF isoutlier(A) TF isoutlier(A,method) TF isoutlier(A,"percentiles",threshold) TF…

controller接口上带@PreAuthorize的注解如何访问 (postman请求示例)

1. 访问接口 /*** 查询时段列表*/RateLimiter(time 10,count 10)ApiOperation("查询时段列表")PreAuthorize("ss.hasPermi(ls/sy:time:list)")GetMapping("/list")public TableDataInfo list(LsTime lsTime){startPage();List<LsTime> l…