卷积神经网络(CNN)天气识别

文章目录

  • 前期工作
    • 1. 设置GPU(如果使用的是CPU可以忽略这步)
      • 我的环境:
    • 2. 导入数据
    • 3. 查看数据
  • 二、数据预处理
    • 1. 加载数据
    • 2. 可视化数据
    • 3. 再次检查数据
    • 4. 配置数据集
  • 三、构建CNN网络
  • 四、编译
  • 五、训练模型
  • 六、模型评估

前期工作

1. 设置GPU(如果使用的是CPU可以忽略这步)

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1
import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]                                        #如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")

2. 导入数据

import matplotlib.pyplot as plt
import os,PIL# 设置随机种子尽可能使结果可以重现
import numpy as np
np.random.seed(1)# 设置随机种子尽可能使结果可以重现
import tensorflow as tf
tf.random.set_seed(1)from tensorflow import keras
from tensorflow.keras import layers,modelsimport pathlib
data_dir = "weather_photos/"
data_dir = pathlib.Path(data_dir)

3. 查看数据

数据集一共分为cloudyrainshinesunrise四类,分别存放于weather_photos文件夹中以各自名字命名的子文件夹中。

image_count = len(list(data_dir.glob('*/*.jpg')))print("图片总数为:",image_count)
roses = list(data_dir.glob('sunrise/*.jpg'))
PIL.Image.open(str(roses[0]))

在这里插入图片描述

二、数据预处理

1. 加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset

batch_size = 32
img_height = 180
img_width = 180
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=123,image_size=(img_height, img_width),batch_size=batch_size)
Found 1125 files belonging to 4 classes.
Using 900 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=123,image_size=(img_height, img_width),batch_size=batch_size)
Found 1125 files belonging to 4 classes.
Using 225 files for validation.

我们可以通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称。

class_names = train_ds.class_names
print(class_names)
['cloudy', 'rain', 'shine', 'sunrise']

2. 可视化数据

plt.figure(figsize=(20, 10))for images, labels in train_ds.take(1):for i in range(20):ax = plt.subplot(5, 10, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")

在这里插入图片描述

3. 再次检查数据

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break
(32, 180, 180, 3)
(32,)
  • Image_batch是形状的张量(32,180,180,3)。这是一批形状180x180x3的32张图片(最后一维指的是彩色通道RGB)。
  • Label_batch是形状(32,)的张量,这些标签对应32张图片

4. 配置数据集

AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

三、构建CNN网络

卷积神经网络(CNN)的输入是张量 (Tensor) 形式的 (image_height, image_width, color_channels),包含了图像高度、宽度及颜色信息。不需要输入batch size。color_channels 为 (R,G,B) 分别对应 RGB 的三个颜色通道(color channel)。在此示例中,我们的 CNN 输入,fashion_mnist 数据集中的图片,形状是 (28, 28, 1)即灰度图像。我们需要在声明第一层时将形状赋值给参数input_shape

num_classes = 4model = models.Sequential([layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)), # 卷积层1,卷积核3*3  layers.AveragePooling2D((2, 2)),               # 池化层1,2*2采样layers.Conv2D(32, (3, 3), activation='relu'),  # 卷积层2,卷积核3*3layers.AveragePooling2D((2, 2)),               # 池化层2,2*2采样layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3layers.Dropout(0.3),  layers.Flatten(),                       # Flatten层,连接卷积层与全连接层layers.Dense(128, activation='relu'),   # 全连接层,特征进一步提取layers.Dense(num_classes)               # 输出层,输出预期结果
])model.summary()  # 打印网络结构
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
rescaling (Rescaling)        (None, 180, 180, 3)       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 178, 178, 16)      448       
_________________________________________________________________
average_pooling2d (AveragePo (None, 89, 89, 16)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 87, 87, 32)        4640      
_________________________________________________________________
average_pooling2d_1 (Average (None, 43, 43, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 41, 41, 64)        18496     
_________________________________________________________________
dropout (Dropout)            (None, 41, 41, 64)        0         
_________________________________________________________________
flatten (Flatten)            (None, 107584)            0         
_________________________________________________________________
dense (Dense)                (None, 128)               13770880  
_________________________________________________________________
dense_1 (Dense)              (None, 5)                 645       
=================================================================
Total params: 13,795,109
Trainable params: 13,795,109
Non-trainable params: 0
_________________________________________________________________

四、编译

  • 在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:
    • 损失函数(loss):用于衡量模型在训练期间的准确率。
    • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
    • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=opt,loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])

五、训练模型

epochs = 10
history = model.fit(train_ds,validation_data=val_ds,epochs=epochs
)
Epoch 1/10
29/29 [==============================] - 6s 58ms/step - loss: 1.5865 - accuracy: 0.4463 - val_loss: 0.5837 - val_accuracy: 0.7689
Epoch 2/10
29/29 [==============================] - 0s 12ms/step - loss: 0.5289 - accuracy: 0.8295 - val_loss: 0.5405 - val_accuracy: 0.8133
Epoch 3/10
29/29 [==============================] - 0s 12ms/step - loss: 0.2930 - accuracy: 0.8967 - val_loss: 0.5364 - val_accuracy: 0.8000
Epoch 4/10
29/29 [==============================] - 0s 12ms/step - loss: 0.2742 - accuracy: 0.9074 - val_loss: 0.4034 - val_accuracy: 0.8267
Epoch 5/10
29/29 [==============================] - 0s 11ms/step - loss: 0.1952 - accuracy: 0.9383 - val_loss: 0.3874 - val_accuracy: 0.8844
Epoch 6/10
29/29 [==============================] - 0s 11ms/step - loss: 0.1592 - accuracy: 0.9468 - val_loss: 0.3680 - val_accuracy: 0.8756
Epoch 7/10
29/29 [==============================] - 0s 12ms/step - loss: 0.0836 - accuracy: 0.9755 - val_loss: 0.3429 - val_accuracy: 0.8756
Epoch 8/10
29/29 [==============================] - 0s 12ms/step - loss: 0.0943 - accuracy: 0.9692 - val_loss: 0.3836 - val_accuracy: 0.9067
Epoch 9/10
29/29 [==============================] - 0s 12ms/step - loss: 0.0344 - accuracy: 0.9909 - val_loss: 0.3578 - val_accuracy: 0.9067
Epoch 10/10
29/29 [==============================] - 0s 11ms/step - loss: 0.0950 - accuracy: 0.9708 - val_loss: 0.4710 - val_accuracy: 0.8356

六、模型评估

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/148279.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

EEPROM与Flash的区别

EEPROM与Flash的区别 EEPROMEEPROM内部功能框图实现写入数据内部结构存储管在充电或放电状态下有着不同的阈值电压 问题点EEPROM是如何失效的呢?为何EEPROM不能做大呢? ------------------------------------------------------------------------------…

Java多线程(3)

Java多线程(3) 深入剖析Java线程的生命周期,探秘JVM的线程状态! 线程的生命周期 Java 线程的生命周期主要包括五个阶段:新建、就绪、运行、阻塞和销毁。 **新建(New):**线程对象通过 new 关键字创建&…

tamarin运行

首先我们找到安装tamarin的文件位置,找到以后进入该文件夹下 ubuntuubuntu:~$ sudo find / -name tamarin-prover /home/linuxbrew/.linuxbrew/var/homebrew/linked/tamarin-prover /home/linuxbrew/.linuxbrew/Cellar/tamarin-prover /home/linuxbrew/.linuxbrew/…

mac下vue-cli从2.9.6升级到最新版本

由于mac之前安装了 vue 2.9.6 的版本,现在想升级到最新版本,用官方给的命令: npm uninstall vue-cli -g 发现不行。 1、究其原因:从vue-cli 3.0版本开始原来的npm install -g vue-cli 安装的都是旧版,最高到2.9.6。安…

基于Netty实现的简单聊天服务组件

目录 基于Netty实现的简单聊天服务组件效果展示技术选型:功能分析聊天服务基础设施配置(基于Netty)定义组件基础的配置(ChatProperties)定义聊天服务类(ChatServer)定义聊天服务配置初始化类&am…

后端接口错误总结

今天后端错误总结: 1.ConditionalOnExpression(“${spring.kafka.exclusive-group.enable:false}”) 这个标签负责加载Bean,因此这个位置必须打开,如果这个标签不打开就会报错 问题解决:这里的配置在application.yml文件中 kaf…

Linux Docker图形化工具Portainer如何进行远程访问?

文章目录 前言1. 部署Portainer2. 本地访问Portainer3. Linux 安装cpolar4. 配置Portainer 公网访问地址5. 公网远程访问Portainer6. 固定Portainer公网地址 前言 Portainer 是一个轻量级的容器管理工具,可以通过 Web 界面对 Docker 容器进行管理和监控。它提供了可…

Flutter最新稳定版3.16 新特性介绍

Flutter 3.16 默认采用 Material 3 主题,Android 平台预览 Impeller,DevTools 扩展等等 欢迎回到每季度一次的 Flutter 稳定版本发布,这次是 Flutter 3.16。这个版本将 Material 3 设为新的默认主题,为 Android 带来 Impeller 预览…

SpringBoot使用DevTools实现后端热部署

📑前言 本文主要SpringBoot通过DevTools实现热部署的文章,如果有什么需要改进的地方还请大佬指出⛺️ 🎬作者简介:大家好,我是青衿🥇 ☁️博客首页:CSDN主页放风讲故事 🌄每日一句&…

Windows使用ssh远程连接(虚拟机)Linux(Ubuntu)的方法

步骤 1.Windows下载一个SSH客户端软件 要使用SSH连接,当然得先有一个好用的客户端软件才方便。 我这里使用的是WindTerm,一个开源免费的SSH连接工具,用什么软件不是重点。 这里默认你已经生成过SSH的密钥了,如果没有&#xff0c…

C语言 字符函数汇总,模拟实现各字符函数(炒鸡详细)

目录 求字符串长度 strlen 示例 模拟实现strlen 长度不受限制的字符串函数 strcpy 示例 模拟实现strcpy strcat 模拟实现strcat strcmp 示例 模拟实现strcmp 长度受限制的字符串函数介绍 strncpy 示例 模拟实现strncpy strncat 示例 模拟实现strncat s…

Spring Boot 中使用 ResourceLoader 加载资源的完整示例

ResourceLoader 是 Spring 框架中用于加载资源的接口。它定义了一系列用于获取资源的方法,可以处理各种资源,包括类路径资源、文件系统资源、URL 资源等。 以下是 ResourceLoader 接口的主要方法: Resource getResource(String location)&am…

【Hello Go】Go语言异常处理

Go语言异常处理 异常处理error接口panicrecover延时调用错误问题 异常处理 error接口 Go语言引入了一个关于错误处理的标准模式 它是Go语言内建的接口类型 它的定义如下 type error interface {Error() string }Go语言的标准库代码包errors为用户提供了以下方法 package e…

人工智能轨道交通行业周刊-第65期(2023.10.30-11.19)

本期关键词:高铁自主创新、智慧城轨、调车司机、大模型垂直应用、大模型幻觉 1 整理涉及公众号名单 1.1 行业类 RT轨道交通人民铁道世界轨道交通资讯网铁路信号技术交流北京铁路轨道交通网上榜铁路视点ITS World轨道交通联盟VSTR铁路与城市轨道交通RailMetro轨道…

Kafka快速入门

文章目录 Kafka快速入门1、相关概念介绍前言1.1 基本介绍1.2 常见消息队列的比较1.3 Kafka常见相关概念介绍 2、安装Kafka3、初体验前期准备编码测试配置介绍 bug记录 Kafka快速入门 1、相关概念介绍 前言 在当今信息爆炸的时代,实时数据处理已经成为许多应用程序和…

汽车虚拟仿真视频数据理解--CLIP模型原理

CLIP模型原理 CLIP的全称是Contrastive Language-Image Pre-Training,中文是对比语言-图像预训练,是一个预训练模型,简称为CLIP。该模型是 OpenAI 在 2021 年发布的,最初用于匹配图像和文本的预训练神经网络模型,这个任…

【Ubuntu】设置永不息屏与安装 dconf-editor

方式一、GUI界面进行设置 No LSB modules are available. Distributor ID: Ubuntu Description: Ubuntu 20.04.6 LTS Release: 20.04 Codename: focal打开 Ubuntu 桌面环境的设置菜单。你可以通过点击屏幕右上角的系统菜单,然后选择设置。在设置菜单中,…

警惕.360勒索病毒,您需要知道的预防和恢复方法。

引言: 网络威胁的演变无常,.360勒索病毒作为一种新兴的勒索软件,以其狡猾性备受关注。本文将深入介绍.360勒索病毒的特点,提供解决方案以恢复被其加密的数据,并分享一系列强化网络安全的预防措施。如果您在面对被勒索…

Vue中实现div的任意移动

前言 在系统应用中,像图片,流程预览及打印预览等情况,当前视窗无法全部显示要预览的全部内容,设置左右和上下滚动条后,如果用鼠标拖动滚动条,又不太便利,如何用鼠标随意的移动呢? …

wpf devexpress自定义编辑器

打开前一个例子 步骤1-自定义FirstName和LastName编辑器字段 如果运行程序,会通知编辑器是空。对于例子,这两个未命名编辑器在第一个LayoutItem(Name)。和最终用户有一个访客左右编辑器查阅到First Name和Last Name字段,分别。如果你看到Go…