CNN实现卫星图像分类(tensorflow)

使用的数据集卫星图像有两类,airplane和lake,每个类别样本量各700张,大小为256*256,RGB三通道彩色卫星影像。搭建深度卷积神经网络,实现卫星影像二分类。
数据链接百度网盘地址,提取码: cq47

1、查看tensorflow版本

import tensorflow as tfprint('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

在这里插入图片描述

2、加载并显示训练数据

从文件夹中获取所有数据路径

import glob
import randomall_image_path = glob.glob('./data/air_lake_dataset/*/*.jpg')  # glob相比于pathlib更简洁
random.shuffle(all_image_path)

读取并处理图像

def load_and_preprocess_image(path):img_raw = tf.io.read_file(path)img_tensor = tf.image.decode_jpeg(img_raw,channels=3)img_tensor = tf.image.resize(img_tensor,[256,256])img_tensor = tf.cast(img_tensor,tf.float32)img_tensor = img_tensor/255return img_tensor

处理标签

label_to_index = {'airplane':0,'lake':1}
index_to_label = dict((v,k) for k,v in label_to_index.items())
labels = [label_to_index.get(img.split('/')[3]) for img in all_image_path]

显示卫星影像

import matplotlib.pyplot as pltdef plot_images_lables(all_image_path,labels,start_idx,num=5):fig = plt.gcf()fig.set_size_inches(12,14)images = [load_and_preprocess_image(img_path) for img_path in all_image_path[start_idx:start_idx+5]]for i in range(num):ax = plt.subplot(1,num,1+i)ax.imshow(images[i])title = 'label=' + index_to_label.get(labels[start_idx+i])ax.set_title(title,fontsize=10)ax.set_xticks([])ax.set_yticks([])plt.show()plot_images_lables(all_image_path,labels,0,5)

在这里插入图片描述

4、使用tf.data.Dataset制作训练/测试数据

制作 Dataset

img_ds = tf.data.Dataset.from_tensor_slices(all_image_path)
img_ds = img_ds.map(load_and_preprocess_image)
label_ds = tf.data.Dataset.from_tensor_slices(labels)
img_label_ds = tf.data.Dataset.zip((img_ds,label_ds))

训练集、测试集划分

test_count = int(len(labels)*0.2) 
train_count = len(labels) - test_counttrain_ds = img_label_ds.skip(test_count)
test_ds = img_label_ds.take(test_count)

分批次加载数据

BATCH_SIZE = 16
train_ds = train_ds.repeat().shuffle(100).batch(BATCH_SIZE)
test_ds = test_ds.repeat().batch(BATCH_SIZE)

5、CNN模型构建

from keras.layers import Input,Dense,Dropout
from keras.layers import Conv2D,MaxPool2D,GlobalAvgPool2Dmodel = tf.keras.Sequential([Input(shape=(256,256,3)),Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same'),  # 增加filter个数,增加模型拟合能力Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same'),MaxPool2D(),  # 默认2*2. 池化层扩大视野Dropout(0.2),  # 防止过拟合Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'),Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'),MaxPool2D(),Dropout(0.2),Conv2D(filters=256,kernel_size=(3,3),activation='relu',padding='same'),Conv2D(filters=256,kernel_size=(3,3),activation='relu',padding='same'),MaxPool2D(),Dropout(0.2),Conv2D(filters=512,kernel_size=(3,3),activation='relu',padding='same'),Conv2D(filters=512,kernel_size=(3,3),activation='relu',padding='same'),GlobalAvgPool2D(),  # 全局平均池化Dense(1024,activation='relu'),Dense(256,activation='relu'),Dense(1,activation='sigmoid') 
])model.summary()

在这里插入图片描述

6、模型编译与训练

model.compile(optimizer=tf.keras.optimizers.Adam(0.0001),loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),  # 已经使用sigmoid激活过了metrics=['acc'])steps_per_epoch = train_count//BATCH_SIZE
val_step = test_count//BATCH_SIZEH = model.fit(train_ds,epochs=10,steps_per_epoch=steps_per_epoch,validation_data=test_ds,validation_steps=val_step,verbose=1)

在这里插入图片描述

7、模型评估

import matplotlib.pyplot as pltfig = plt.gcf()
fig.set_size_inches(12,4)
plt.subplot(1,2,1)
plt.plot(H.epoch, H.history['loss'], label='loss')
plt.plot(H.epoch, H.history['val_loss'], label='val_loss')
plt.legend()
plt.title('loss')plt.subplot(1,2,2)
plt.plot(H.epoch, H.history['acc'], label='acc')
plt.plot(H.epoch, H.history['val_acc'], label='val_acc')
plt.legend()
plt.title('acc')
plt.show()

在这里插入图片描述

8、模型预测

def pred_img(img_path):img = load_and_preprocess_image(img_path)img = tf.expand_dims(img, axis=0)pred = model.predict(img)pred = index_to_label.get((pred>0.5).astype('int')[0][0])return predimg_path = './data/air_lake_dataset/airplane/airplane_240.jpg'
pred = pred_img(img_path)
img_tensor = load_and_preprocess_image(img_path)
plt.imshow(img_tensor)
title = 'label=' + img_path.split('/')[3].strip() + ', pred=' + pred
plt.title(title)
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/832131.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CentOS常用命令有哪些?

目录 一、CentOS常用命令有哪些? 二、不熟悉命令怎么办? 场景一:如果是文件操作,可以使用FileZilla工具来完成 场景二:安装CentOS桌面 一、CentOS常用命令有哪些? CentOS 系统中有许多常用命令及其用法…

ApacheCordova 12 +Vs 2022 项目搭建教程_开发环境搭建教程

一、安装 cordova cli 并使用命令创建项目 npm install –g cordova 详细参考: Apache Cordova开发环境搭建(二)VS Code_天马3798-CSDN博客_cordova vscode 二、 Vs 2022 Android 开发搭建+调试 .Net MAUI 搭建Android 开发环境-CSDN博客 三、配置 JDK 环境变量、配置…

leetcode尊享面试100题(549二叉树最长连续序列||,python)

题目不长,就是分析时间太久了。 思路使用dfs深度遍历,先想好这个函数返回什么,题目给出路径可以是子-父-子的路径,那么1-2-3可以,3-2-1也可以,那么考虑dfs返回两个值,对于当前节点node来说&…

JavaScript —— APIs(五)

一、Window对象 1. BOM(浏览器对象模型) 2. 定时器-延时函数 ①、定义 ②、定时器比较 ③、【案例】 3. JS执行机制 4. location对象 注意:hash应用 不点击页面刷新号,点击刷新按钮也可以实现页面刷新 【案例】 5. navig…

电机控制系列模块解析(16)—— 电流环

一、FOC为什么使用串联控制器 在此说明,串联形式(内外环形式,速度环和电流环控制器串联)并不是必须的,但是对于线性控制系统来说,电机属于非线性控制对象,早期工程师们为了处理电机的非线性&am…

【ARM】ARM寄存器和异常处理

1.指令的执行过程 (1)一条指令的执行分为三个阶段 1.取址: CPU将PC寄存器中的地址发送给内存,内存将其地址中对应的指令返回 到CPU中的指令寄存器(IR) 2.译码: 译码器对IR中的指令…

一则不知从何谈起的故事

我觉得我一直很矛盾很迷茫 来到这个世界 这个美丽的世界二十年载了吧 也就是你常说的这个美丽的世界 它真的很美丽不是吗 花花草草 还有小猫小狗 依稀的记得 是23年的上半年 我不知道怎么的 心理出现了很大的问题 当时差点从学校旁边去中港的那天大桥上跳了下去 一开始也…

神经网络中的算法优化(皮毛讲解)

抛砖引玉 在深度学习中,优化算法是训练神经网络时至关重要的一部分。 优化算法的目标是最小化(或最大化)一个损失函数,通常通过调整神经网络的参数来实现。 这个过程可以通过梯度下降法来完成,其中梯度指的是损失函数…

Grafana:云原生时代的数据可视化与监控王者

🐇明明跟你说过:个人主页 🏅个人专栏:《Grafana:让数据说话的魔术师》 🏅 🔖行路有良友,便是天堂🔖 目录 一、引言 1、Grafana简介 2、Grafana的重要性与影响力 …

objdump命令解析

一、名称 objdump-显示目标文件的工具 二、简介 objdump [-a|--archive-headers] [-b bfdname|--targetbfdname] [-C|--demangle[style] ] [-d|--disassemble[symbol]] [-D|--disassemble-all] …

全方位了解 Meta Llama 3

本文将为您提供 Llama 3 的全面概览,从其架构、性能到未来的发展方向,让您一文了解这一革命性大语言模型的所有要点。 Meta Llama 发展历程 Llama 1 Llama 是由 Meta(FaceBook) AI 发布的一个开源项目,允许商用,影响力巨大。Lla…

Terraform数据类型

概括地说,Terraform的数据类型分为两种:原始类型,复杂类型。 原始类型 原始类型包含3个:string,number,bool。 string:表示一组Unicode字符,例如:”hello”number&…

力扣每日一题111:二叉树的最小深度

题目 简单 给定一个二叉树,找出其最小深度。 最小深度是从根节点到最近叶子节点的最短路径上的节点数量。 说明:叶子节点是指没有子节点的节点。 示例 1: 输入:root [3,9,20,null,null,15,7] 输出:2示例 2&#x…

C语言:文件操作(上)

片头 嗨!小伙伴们,今天我们来学习新的知识----文件操作,准备好了吗?我要开始咯! 目录 1. 为什么使用文件? 2. 什么是文件? 3. 二进制文件和文本文件? 4. 文件的打开和关闭 5. 文件顺序读写…

启发式算法解魔方——python

未完待续,填坑ing…… 魔方操作的表示——辛马斯特标记 辛马斯特标记(Singmaster Notation)是一种用于描述魔方和类似拼图的转动操作的标记系统。它以大卫辛马斯特(David Singmaster)的名字命名,辛马斯特…

C 认识指针

目录 一、取地址操作符(&) 二、解引用操作符(*) 三、指针变量 1、 指针变量的大小 2、 指针变量类型的意义 2.1 指针的解引用 2.2 指针 - 整数 2.3 调试解决疑惑 认识指针,指针比较害羞内敛,我们…

单调栈-java

本次主要通过数组模拟单调栈来解决问题。 目录 一、单调栈☀ 二、算法思路☀ 1.暴力做法🌙 2.优化做法🌙 3.单调递增栈和单调递减栈🌙 三、代码如下☀ 1.代码如下(示例):🌙 2.读入数据&a…

thinkphp5 配合阿里直播实现直播功能流程

要为你提供一个更详细的教程来结合ThinkPHP 5和阿里直播SDK实现直播功能,需要涵盖的内容相对较多。不过,我可以为你提供一个大致的、更详细的步骤指南,供你参考和扩展: 1. 准备工作 a. 注册阿里云账号 前往阿里云官网注册账号&…

Ubuntu MATE系统下WPS显示错位

系统:Ubuntu MATE 22.04和24.04,在显示器设置200%放大的情况下,显示错位。 显示器配置: WPS显示错位: 这个问题当前没有找到好的解决方式。 因为4K显示屏设置4K分辨率,图标,字体太小&#xff…

LeetCode题目100:递归、迭代、dfs使用栈多种算法图解相同的树

题目描述 给定两个二叉树的根节点 p 和 q,编写一个函数来检测这两棵树是否相同。如果两棵树在结构上相同,并且节点具有相同的值,则认为它们是相同的。 示例 示例 1 输入:p [1,2,3], q [1,2,3] 输出:True 示例 2…