CNN实现卫星图像分类(tensorflow)

使用的数据集卫星图像有两类,airplane和lake,每个类别样本量各700张,大小为256*256,RGB三通道彩色卫星影像。搭建深度卷积神经网络,实现卫星影像二分类。
数据链接百度网盘地址,提取码: cq47

1、查看tensorflow版本

import tensorflow as tfprint('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

在这里插入图片描述

2、加载并显示训练数据

从文件夹中获取所有数据路径

import glob
import randomall_image_path = glob.glob('./data/air_lake_dataset/*/*.jpg')  # glob相比于pathlib更简洁
random.shuffle(all_image_path)

读取并处理图像

def load_and_preprocess_image(path):img_raw = tf.io.read_file(path)img_tensor = tf.image.decode_jpeg(img_raw,channels=3)img_tensor = tf.image.resize(img_tensor,[256,256])img_tensor = tf.cast(img_tensor,tf.float32)img_tensor = img_tensor/255return img_tensor

处理标签

label_to_index = {'airplane':0,'lake':1}
index_to_label = dict((v,k) for k,v in label_to_index.items())
labels = [label_to_index.get(img.split('/')[3]) for img in all_image_path]

显示卫星影像

import matplotlib.pyplot as pltdef plot_images_lables(all_image_path,labels,start_idx,num=5):fig = plt.gcf()fig.set_size_inches(12,14)images = [load_and_preprocess_image(img_path) for img_path in all_image_path[start_idx:start_idx+5]]for i in range(num):ax = plt.subplot(1,num,1+i)ax.imshow(images[i])title = 'label=' + index_to_label.get(labels[start_idx+i])ax.set_title(title,fontsize=10)ax.set_xticks([])ax.set_yticks([])plt.show()plot_images_lables(all_image_path,labels,0,5)

在这里插入图片描述

4、使用tf.data.Dataset制作训练/测试数据

制作 Dataset

img_ds = tf.data.Dataset.from_tensor_slices(all_image_path)
img_ds = img_ds.map(load_and_preprocess_image)
label_ds = tf.data.Dataset.from_tensor_slices(labels)
img_label_ds = tf.data.Dataset.zip((img_ds,label_ds))

训练集、测试集划分

test_count = int(len(labels)*0.2) 
train_count = len(labels) - test_counttrain_ds = img_label_ds.skip(test_count)
test_ds = img_label_ds.take(test_count)

分批次加载数据

BATCH_SIZE = 16
train_ds = train_ds.repeat().shuffle(100).batch(BATCH_SIZE)
test_ds = test_ds.repeat().batch(BATCH_SIZE)

5、CNN模型构建

from keras.layers import Input,Dense,Dropout
from keras.layers import Conv2D,MaxPool2D,GlobalAvgPool2Dmodel = tf.keras.Sequential([Input(shape=(256,256,3)),Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same'),  # 增加filter个数,增加模型拟合能力Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same'),MaxPool2D(),  # 默认2*2. 池化层扩大视野Dropout(0.2),  # 防止过拟合Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'),Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'),MaxPool2D(),Dropout(0.2),Conv2D(filters=256,kernel_size=(3,3),activation='relu',padding='same'),Conv2D(filters=256,kernel_size=(3,3),activation='relu',padding='same'),MaxPool2D(),Dropout(0.2),Conv2D(filters=512,kernel_size=(3,3),activation='relu',padding='same'),Conv2D(filters=512,kernel_size=(3,3),activation='relu',padding='same'),GlobalAvgPool2D(),  # 全局平均池化Dense(1024,activation='relu'),Dense(256,activation='relu'),Dense(1,activation='sigmoid') 
])model.summary()

在这里插入图片描述

6、模型编译与训练

model.compile(optimizer=tf.keras.optimizers.Adam(0.0001),loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),  # 已经使用sigmoid激活过了metrics=['acc'])steps_per_epoch = train_count//BATCH_SIZE
val_step = test_count//BATCH_SIZEH = model.fit(train_ds,epochs=10,steps_per_epoch=steps_per_epoch,validation_data=test_ds,validation_steps=val_step,verbose=1)

在这里插入图片描述

7、模型评估

import matplotlib.pyplot as pltfig = plt.gcf()
fig.set_size_inches(12,4)
plt.subplot(1,2,1)
plt.plot(H.epoch, H.history['loss'], label='loss')
plt.plot(H.epoch, H.history['val_loss'], label='val_loss')
plt.legend()
plt.title('loss')plt.subplot(1,2,2)
plt.plot(H.epoch, H.history['acc'], label='acc')
plt.plot(H.epoch, H.history['val_acc'], label='val_acc')
plt.legend()
plt.title('acc')
plt.show()

在这里插入图片描述

8、模型预测

def pred_img(img_path):img = load_and_preprocess_image(img_path)img = tf.expand_dims(img, axis=0)pred = model.predict(img)pred = index_to_label.get((pred>0.5).astype('int')[0][0])return predimg_path = './data/air_lake_dataset/airplane/airplane_240.jpg'
pred = pred_img(img_path)
img_tensor = load_and_preprocess_image(img_path)
plt.imshow(img_tensor)
title = 'label=' + img_path.split('/')[3].strip() + ', pred=' + pred
plt.title(title)
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/832131.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CentOS常用命令有哪些?

目录 一、CentOS常用命令有哪些? 二、不熟悉命令怎么办? 场景一:如果是文件操作,可以使用FileZilla工具来完成 场景二:安装CentOS桌面 一、CentOS常用命令有哪些? CentOS 系统中有许多常用命令及其用法…

leetcode尊享面试100题(549二叉树最长连续序列||,python)

题目不长,就是分析时间太久了。 思路使用dfs深度遍历,先想好这个函数返回什么,题目给出路径可以是子-父-子的路径,那么1-2-3可以,3-2-1也可以,那么考虑dfs返回两个值,对于当前节点node来说&…

JavaScript —— APIs(五)

一、Window对象 1. BOM(浏览器对象模型) 2. 定时器-延时函数 ①、定义 ②、定时器比较 ③、【案例】 3. JS执行机制 4. location对象 注意:hash应用 不点击页面刷新号,点击刷新按钮也可以实现页面刷新 【案例】 5. navig…

电机控制系列模块解析(16)—— 电流环

一、FOC为什么使用串联控制器 在此说明,串联形式(内外环形式,速度环和电流环控制器串联)并不是必须的,但是对于线性控制系统来说,电机属于非线性控制对象,早期工程师们为了处理电机的非线性&am…

【ARM】ARM寄存器和异常处理

1.指令的执行过程 (1)一条指令的执行分为三个阶段 1.取址: CPU将PC寄存器中的地址发送给内存,内存将其地址中对应的指令返回 到CPU中的指令寄存器(IR) 2.译码: 译码器对IR中的指令…

神经网络中的算法优化(皮毛讲解)

抛砖引玉 在深度学习中,优化算法是训练神经网络时至关重要的一部分。 优化算法的目标是最小化(或最大化)一个损失函数,通常通过调整神经网络的参数来实现。 这个过程可以通过梯度下降法来完成,其中梯度指的是损失函数…

Grafana:云原生时代的数据可视化与监控王者

🐇明明跟你说过:个人主页 🏅个人专栏:《Grafana:让数据说话的魔术师》 🏅 🔖行路有良友,便是天堂🔖 目录 一、引言 1、Grafana简介 2、Grafana的重要性与影响力 …

全方位了解 Meta Llama 3

本文将为您提供 Llama 3 的全面概览,从其架构、性能到未来的发展方向,让您一文了解这一革命性大语言模型的所有要点。 Meta Llama 发展历程 Llama 1 Llama 是由 Meta(FaceBook) AI 发布的一个开源项目,允许商用,影响力巨大。Lla…

力扣每日一题111:二叉树的最小深度

题目 简单 给定一个二叉树,找出其最小深度。 最小深度是从根节点到最近叶子节点的最短路径上的节点数量。 说明:叶子节点是指没有子节点的节点。 示例 1: 输入:root [3,9,20,null,null,15,7] 输出:2示例 2&#x…

C语言:文件操作(上)

片头 嗨!小伙伴们,今天我们来学习新的知识----文件操作,准备好了吗?我要开始咯! 目录 1. 为什么使用文件? 2. 什么是文件? 3. 二进制文件和文本文件? 4. 文件的打开和关闭 5. 文件顺序读写…

启发式算法解魔方——python

未完待续,填坑ing…… 魔方操作的表示——辛马斯特标记 辛马斯特标记(Singmaster Notation)是一种用于描述魔方和类似拼图的转动操作的标记系统。它以大卫辛马斯特(David Singmaster)的名字命名,辛马斯特…

C 认识指针

目录 一、取地址操作符(&) 二、解引用操作符(*) 三、指针变量 1、 指针变量的大小 2、 指针变量类型的意义 2.1 指针的解引用 2.2 指针 - 整数 2.3 调试解决疑惑 认识指针,指针比较害羞内敛,我们…

单调栈-java

本次主要通过数组模拟单调栈来解决问题。 目录 一、单调栈☀ 二、算法思路☀ 1.暴力做法🌙 2.优化做法🌙 3.单调递增栈和单调递减栈🌙 三、代码如下☀ 1.代码如下(示例):🌙 2.读入数据&a…

Ubuntu MATE系统下WPS显示错位

系统:Ubuntu MATE 22.04和24.04,在显示器设置200%放大的情况下,显示错位。 显示器配置: WPS显示错位: 这个问题当前没有找到好的解决方式。 因为4K显示屏设置4K分辨率,图标,字体太小&#xff…

prometheus搭建

1.prometheus下载 下载地址:Download | Prometheus 请下载LTS稳定版本 本次prometheus搭建使用prometheus-2.37.1.linux-amd64.tar.gz版本 2.上传prometheus-2.37.1.linux-amd64.tar.gz至服务器/opt目录 CentOS7.9 使用命令rz -byE上传 3.解压缩prometheus-2.37.1.linux…

【C++之map的应用】

C学习笔记---021 C之map的应用1、map的简单介绍1.1、基本概念1.2、map基本特性 2、map的基本操作2.1、插入元素2.2、访问元素2.3、删除元素2.4、遍历map2.5、检查元素是否存在2.6、获取map的大小2.7、清空map2.8、基本样例 3、map的基础模拟实现4、测试用例4.1、插入和遍历4.2、…

Flutter笔记:Widgets Easier组件库(11)- 使用提示吐丝

Flutter笔记 Widgets Easier组件库(11)使用提示吐丝 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite:http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this …

【多模态】29、OCRBench | 为大型多模态模型提供一个 OCR 任务测评基准

文章目录 一、背景二、实验2.1 测评标准和结果2.1.1 文本识别 Text Recognition2.1.2 场景文本中心的视觉问答 Scene Text-Centric VQA2.1.3 文档导向的视觉问答 Document-Oriented VQA2.1.4 关键信息提取 Key Information Extraction2.1.5 手写数学公式识别 Handwritten Mathe…

Ubuntu安装配置网络

参考 https://blog.csdn.net/qq_59633155/article/details/131252293https://blog.csdn.net/qq_59633155/article/details/131252293 Ubuntu配置网络 1,查看网络是否连接 终端输入 ping baidu.com 如若成功则如下图所示 如未能成功,则继续按下面步骤…

解决HTTP 403 Forbidden错误:禁止访问目录索引问题的解决方法

解决HTTP 403 Forbidden错误:禁止访问目录索引问题的解决方法 过去有人曾对我说,“一个人爱上小溪,是因为没有见过大海。”而如今我终于可以说,“我已见过银河,但我仍只爱你一颗星。” 在Web开发和服务器管理中&#x…