卷积神经网络(CNN)天气识别

文章目录

  • 前期工作
    • 1. 设置GPU(如果使用的是CPU可以忽略这步)
      • 我的环境:
    • 2. 导入数据
    • 3. 查看数据
  • 二、数据预处理
    • 1. 加载数据
    • 2. 可视化数据
    • 3. 再次检查数据
    • 4. 配置数据集
  • 三、构建CNN网络
  • 四、编译
  • 五、训练模型
  • 六、模型评估

前期工作

1. 设置GPU(如果使用的是CPU可以忽略这步)

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1
import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]                                        #如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")

2. 导入数据

import matplotlib.pyplot as plt
import os,PIL# 设置随机种子尽可能使结果可以重现
import numpy as np
np.random.seed(1)# 设置随机种子尽可能使结果可以重现
import tensorflow as tf
tf.random.set_seed(1)from tensorflow import keras
from tensorflow.keras import layers,modelsimport pathlib
data_dir = "weather_photos/"
data_dir = pathlib.Path(data_dir)

3. 查看数据

数据集一共分为cloudyrainshinesunrise四类,分别存放于weather_photos文件夹中以各自名字命名的子文件夹中。

image_count = len(list(data_dir.glob('*/*.jpg')))print("图片总数为:",image_count)
roses = list(data_dir.glob('sunrise/*.jpg'))
PIL.Image.open(str(roses[0]))

在这里插入图片描述

二、数据预处理

1. 加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset

batch_size = 32
img_height = 180
img_width = 180
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=123,image_size=(img_height, img_width),batch_size=batch_size)
Found 1125 files belonging to 4 classes.
Using 900 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=123,image_size=(img_height, img_width),batch_size=batch_size)
Found 1125 files belonging to 4 classes.
Using 225 files for validation.

我们可以通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称。

class_names = train_ds.class_names
print(class_names)
['cloudy', 'rain', 'shine', 'sunrise']

2. 可视化数据

plt.figure(figsize=(20, 10))for images, labels in train_ds.take(1):for i in range(20):ax = plt.subplot(5, 10, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")

在这里插入图片描述

3. 再次检查数据

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break
(32, 180, 180, 3)
(32,)
  • Image_batch是形状的张量(32,180,180,3)。这是一批形状180x180x3的32张图片(最后一维指的是彩色通道RGB)。
  • Label_batch是形状(32,)的张量,这些标签对应32张图片

4. 配置数据集

AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

三、构建CNN网络

卷积神经网络(CNN)的输入是张量 (Tensor) 形式的 (image_height, image_width, color_channels),包含了图像高度、宽度及颜色信息。不需要输入batch size。color_channels 为 (R,G,B) 分别对应 RGB 的三个颜色通道(color channel)。在此示例中,我们的 CNN 输入,fashion_mnist 数据集中的图片,形状是 (28, 28, 1)即灰度图像。我们需要在声明第一层时将形状赋值给参数input_shape

num_classes = 4model = models.Sequential([layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)), # 卷积层1,卷积核3*3  layers.AveragePooling2D((2, 2)),               # 池化层1,2*2采样layers.Conv2D(32, (3, 3), activation='relu'),  # 卷积层2,卷积核3*3layers.AveragePooling2D((2, 2)),               # 池化层2,2*2采样layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3layers.Dropout(0.3),  layers.Flatten(),                       # Flatten层,连接卷积层与全连接层layers.Dense(128, activation='relu'),   # 全连接层,特征进一步提取layers.Dense(num_classes)               # 输出层,输出预期结果
])model.summary()  # 打印网络结构
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
rescaling (Rescaling)        (None, 180, 180, 3)       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 178, 178, 16)      448       
_________________________________________________________________
average_pooling2d (AveragePo (None, 89, 89, 16)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 87, 87, 32)        4640      
_________________________________________________________________
average_pooling2d_1 (Average (None, 43, 43, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 41, 41, 64)        18496     
_________________________________________________________________
dropout (Dropout)            (None, 41, 41, 64)        0         
_________________________________________________________________
flatten (Flatten)            (None, 107584)            0         
_________________________________________________________________
dense (Dense)                (None, 128)               13770880  
_________________________________________________________________
dense_1 (Dense)              (None, 5)                 645       
=================================================================
Total params: 13,795,109
Trainable params: 13,795,109
Non-trainable params: 0
_________________________________________________________________

四、编译

  • 在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:
    • 损失函数(loss):用于衡量模型在训练期间的准确率。
    • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
    • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=opt,loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])

五、训练模型

epochs = 10
history = model.fit(train_ds,validation_data=val_ds,epochs=epochs
)
Epoch 1/10
29/29 [==============================] - 6s 58ms/step - loss: 1.5865 - accuracy: 0.4463 - val_loss: 0.5837 - val_accuracy: 0.7689
Epoch 2/10
29/29 [==============================] - 0s 12ms/step - loss: 0.5289 - accuracy: 0.8295 - val_loss: 0.5405 - val_accuracy: 0.8133
Epoch 3/10
29/29 [==============================] - 0s 12ms/step - loss: 0.2930 - accuracy: 0.8967 - val_loss: 0.5364 - val_accuracy: 0.8000
Epoch 4/10
29/29 [==============================] - 0s 12ms/step - loss: 0.2742 - accuracy: 0.9074 - val_loss: 0.4034 - val_accuracy: 0.8267
Epoch 5/10
29/29 [==============================] - 0s 11ms/step - loss: 0.1952 - accuracy: 0.9383 - val_loss: 0.3874 - val_accuracy: 0.8844
Epoch 6/10
29/29 [==============================] - 0s 11ms/step - loss: 0.1592 - accuracy: 0.9468 - val_loss: 0.3680 - val_accuracy: 0.8756
Epoch 7/10
29/29 [==============================] - 0s 12ms/step - loss: 0.0836 - accuracy: 0.9755 - val_loss: 0.3429 - val_accuracy: 0.8756
Epoch 8/10
29/29 [==============================] - 0s 12ms/step - loss: 0.0943 - accuracy: 0.9692 - val_loss: 0.3836 - val_accuracy: 0.9067
Epoch 9/10
29/29 [==============================] - 0s 12ms/step - loss: 0.0344 - accuracy: 0.9909 - val_loss: 0.3578 - val_accuracy: 0.9067
Epoch 10/10
29/29 [==============================] - 0s 11ms/step - loss: 0.0950 - accuracy: 0.9708 - val_loss: 0.4710 - val_accuracy: 0.8356

六、模型评估

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/148279.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

11.10 校招 实习 内推 面经

绿*泡*泡: neituijunsir 交流裙 ,内推/实习/校招汇总表格 1、校招丨海康机器人2024校招3大类岗位补录(内推) 校招丨海康机器人2024校招3大类岗位补录(内推) 2、校招&实习 | 华为数据存储研究部招聘…

SELinux零知识学习十九、SELinux策略语言之类型强制(4)

接前一篇文章:SELinux零知识学习十八、SELinux策略语言之类型强制(3) 二、SELinux策略语言之类型强制 2. 类型、属性和别名 (4)别名 别名是引用类型时的一个备选的名字,能够使用类型名的地方就可以使用别…

EEPROM与Flash的区别

EEPROM与Flash的区别 EEPROMEEPROM内部功能框图实现写入数据内部结构存储管在充电或放电状态下有着不同的阈值电压 问题点EEPROM是如何失效的呢?为何EEPROM不能做大呢? ------------------------------------------------------------------------------…

HTTP请求详解

HTTP请求格式 请求报文通常包含以下部分: 请求行(Request Line): 包括请求方法、请求的URL和协议版本。 示例:GET /index.html HTTP/1.1 请求头(Request Headers): 包含了一系列的键值对,用来描述客户端请求的相关信息,比如Accept(告诉服务器客户端能够处理的MIME类型…

Java多线程(3)

Java多线程(3) 深入剖析Java线程的生命周期,探秘JVM的线程状态! 线程的生命周期 Java 线程的生命周期主要包括五个阶段:新建、就绪、运行、阻塞和销毁。 **新建(New):**线程对象通过 new 关键字创建&…

tamarin运行

首先我们找到安装tamarin的文件位置,找到以后进入该文件夹下 ubuntuubuntu:~$ sudo find / -name tamarin-prover /home/linuxbrew/.linuxbrew/var/homebrew/linked/tamarin-prover /home/linuxbrew/.linuxbrew/Cellar/tamarin-prover /home/linuxbrew/.linuxbrew/…

mac下vue-cli从2.9.6升级到最新版本

由于mac之前安装了 vue 2.9.6 的版本,现在想升级到最新版本,用官方给的命令: npm uninstall vue-cli -g 发现不行。 1、究其原因:从vue-cli 3.0版本开始原来的npm install -g vue-cli 安装的都是旧版,最高到2.9.6。安…

基于Netty实现的简单聊天服务组件

目录 基于Netty实现的简单聊天服务组件效果展示技术选型:功能分析聊天服务基础设施配置(基于Netty)定义组件基础的配置(ChatProperties)定义聊天服务类(ChatServer)定义聊天服务配置初始化类&am…

后端接口错误总结

今天后端错误总结: 1.ConditionalOnExpression(“${spring.kafka.exclusive-group.enable:false}”) 这个标签负责加载Bean,因此这个位置必须打开,如果这个标签不打开就会报错 问题解决:这里的配置在application.yml文件中 kaf…

数据结构之双向带头循环链表函数功能实现与详细解析

个人主页:点我进入主页 专栏分类:C语言初阶 C语言程序设计————KTV C语言小游戏 C语言进阶 C语言刷题 数据结构初阶 欢迎大家点赞,评论,收藏。 一起努力,一起奔赴大厂。 目录 1.前言 2.带头双…

Linux Docker图形化工具Portainer如何进行远程访问?

文章目录 前言1. 部署Portainer2. 本地访问Portainer3. Linux 安装cpolar4. 配置Portainer 公网访问地址5. 公网远程访问Portainer6. 固定Portainer公网地址 前言 Portainer 是一个轻量级的容器管理工具,可以通过 Web 界面对 Docker 容器进行管理和监控。它提供了可…

Flutter最新稳定版3.16 新特性介绍

Flutter 3.16 默认采用 Material 3 主题,Android 平台预览 Impeller,DevTools 扩展等等 欢迎回到每季度一次的 Flutter 稳定版本发布,这次是 Flutter 3.16。这个版本将 Material 3 设为新的默认主题,为 Android 带来 Impeller 预览…

SpringBoot使用DevTools实现后端热部署

📑前言 本文主要SpringBoot通过DevTools实现热部署的文章,如果有什么需要改进的地方还请大佬指出⛺️ 🎬作者简介:大家好,我是青衿🥇 ☁️博客首页:CSDN主页放风讲故事 🌄每日一句&…

Windows使用ssh远程连接(虚拟机)Linux(Ubuntu)的方法

步骤 1.Windows下载一个SSH客户端软件 要使用SSH连接,当然得先有一个好用的客户端软件才方便。 我这里使用的是WindTerm,一个开源免费的SSH连接工具,用什么软件不是重点。 这里默认你已经生成过SSH的密钥了,如果没有&#xff0c…

C语言 字符函数汇总,模拟实现各字符函数(炒鸡详细)

目录 求字符串长度 strlen 示例 模拟实现strlen 长度不受限制的字符串函数 strcpy 示例 模拟实现strcpy strcat 模拟实现strcat strcmp 示例 模拟实现strcmp 长度受限制的字符串函数介绍 strncpy 示例 模拟实现strncpy strncat 示例 模拟实现strncat s…

前端js常用代码段总结

持续更新中… 以下内容仅供参考。如有错误,欢迎指正! 判断一个对象是否拥有某个属性 场景介绍 1、项目中后端返回的字段,有些时候存在有些时候不存在,前端的逻辑需要依靠这个字段 方法总结 Reflect.has() 静态方法 Reflect.has…

Spring Boot 中使用 ResourceLoader 加载资源的完整示例

ResourceLoader 是 Spring 框架中用于加载资源的接口。它定义了一系列用于获取资源的方法,可以处理各种资源,包括类路径资源、文件系统资源、URL 资源等。 以下是 ResourceLoader 接口的主要方法: Resource getResource(String location)&am…

【Hello Go】Go语言异常处理

Go语言异常处理 异常处理error接口panicrecover延时调用错误问题 异常处理 error接口 Go语言引入了一个关于错误处理的标准模式 它是Go语言内建的接口类型 它的定义如下 type error interface {Error() string }Go语言的标准库代码包errors为用户提供了以下方法 package e…

人工智能轨道交通行业周刊-第65期(2023.10.30-11.19)

本期关键词:高铁自主创新、智慧城轨、调车司机、大模型垂直应用、大模型幻觉 1 整理涉及公众号名单 1.1 行业类 RT轨道交通人民铁道世界轨道交通资讯网铁路信号技术交流北京铁路轨道交通网上榜铁路视点ITS World轨道交通联盟VSTR铁路与城市轨道交通RailMetro轨道…

Kafka快速入门

文章目录 Kafka快速入门1、相关概念介绍前言1.1 基本介绍1.2 常见消息队列的比较1.3 Kafka常见相关概念介绍 2、安装Kafka3、初体验前期准备编码测试配置介绍 bug记录 Kafka快速入门 1、相关概念介绍 前言 在当今信息爆炸的时代,实时数据处理已经成为许多应用程序和…