卷积神经网络实现天气图像分类 - P3

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:Pytorch实战 | 第P3周:彩色图片识别:天气识别
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

目录

  • 环境
  • 步骤
    • 环境设置
    • 数据准备
    • 模型设计
    • 模型训练
    • 结果展示
  • 总结与心得体会


环境

  • 系统: Linux
  • 语言: Python3.8.10
  • 深度学习框架: Pytorch2.0.0+cu118

步骤

环境设置

首先是包引用

import torch # pytorch主包
import torch.nn as nn # 模型相关的包,创建一个别名少打点字
import torch.optim as optim # 优化器包,创建一个别名
import torch.nn.functional as F # 可以直接调用的函数,一般用来调用里面在的激活函数from torch.utils.data import DataLoader, random_split # 数据迭代包装器,数据集切分
from torchvision import datasets, transforms # 图像类数据集和图像转换操作函数import matplotlib.pyplot as plt # 图表库
from torchinfo import summary # 打印模型结构

查询当前环境的GPU是否可用

print(torch.cuda.is_available())

GPU可用情况
创建一个全局的设备对象,用于使各类数据处于相同的设备中

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 当GPU不可用时,使用CPU# 如果是Mac系统可以多增加一个if条件,启用mps
if torch.backends.mps.is_available():device = torch.device('mps')

数据准备

这次的天气图像是由K同学提供的,我提前下载下来放在了当前目录下的data文件夹中
加载文件夹中的图像数据集,要求文件夹按照不同的分类并列存储,一个简要的文件树为

datacloudyrainshinesunrise

使用torchvisio.datasets中的方法加载自定义图像数据集,可以免除一些文章中推荐的自己创建Dataset,个人感觉十分方便,而且这种文件的存储结构也兼容keras框架。

首先我们使用原生的PythonAPI来遍历一下文件夹,收集一下分类信息

import pathlibdata_lib = pathlib.Path('data')
class_names = [f.parts[-1] for f in data_lib.glob('*')] # 将data下级文件夹作为分类名
print(class_names)

打印分类信息
在所有的图片中随机选择几个文件打印一下信息。

import numpy as np
from PIL import Image
import randomimage_list = list(data_lib.glob('*/*'))
for _ in range(10):print(np.array(Image.open(random.choice(image_list))).shape)

打印图像信息
通过打印图像信息,发现图像的大小并不一致,需要在创建数据集时对图像进行缩放到统一的大小。

transform = transforms.Compose([transforms.Resize([224, 224]), # 将图像都缩放到224x224transforms.ToTensor(), # 将图像转换成pytorch tensor对象
]) # 定义一个全局的transform, 用于对齐训练验证以及测试数据

接下来就可以正式从文件夹中加载数据集了

dataset = datasets.ImageFolder('data', transform=tranform)

现在把整文件夹下的所有文件加载为了一个数据集,需要根据一定的比例划分为训练和验证集,方便模型的评估

train_size = int(len(dataset) *0.8) # 80% 训练集 20% 验证集
eval_size = len(dataset) - train_sizetrain_dataset, eval_dataset = random_split(dataset, [train_size, eval_size])

创建完数据集,打印一下数据集中的图像

plt.figure(figsize=(20, 4))
for i in range(20):image, label = train_dataset[i]plt.subplot(2, 10, i+1)plt.imshow(image.permute(1,2,0)) # pytorch的tensor格式为N,C,H,W,在imshow展示需要将格式变成H,W,C格式,使用permute切换一下plt.axis('off')plt.title(class_names[label])

预览数据集
最后用DataLoader包装一下数据集,方便遍历

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(eval_loader, batch_size=batch_size)

模型设计

使用一个带有BatchNorm的卷积神经网络来处理分类问题

class Network(nn.Module):def __init__(self, num_classes):super().__init__()self.conv1 = nn.Conv2d(3, 12, kernel_size=5, strides=1)self.conv2 = nn.Conv2d(12, 12, kernel_size=5, strides=1)self.conv3 = nn.Conv2d(12, 24, kernel_size=5, strides=1)self.conv4 = nn.Conv2d(24, 24, kernel_size=5, strides=1)self.maxpool = nn.MaxPool2d(2)self.bn1 = nn.BatchNorm2d(12)self.bn2 = nn.BatchNorm2d(12)self.bn3 = nn.BatchNorm2d(24)self.bn4 = nn.BatchNorm2d(24)# 224 [-> 220 -> 216 -> 108] [-> 104 -> 100 -> 50]self.fc1 = nn.Linear(50*50*24, num_classes)def forward(self, x):x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = self.maxpool(x)x = F.relu(self.bn3(self.conv3(x)))x = F.relu(self.bn4(self.conv4(x)))x = self.maxpool(x)x = x.view(x.size(0), -1)x = self.fc1(x)return xmodel = Network(len(class_names)).to(device) # 别忘了把定义的模型拉入共享中
summary(model, input_size=(32, 3, 224, 224))

模型结构

模型训练

首先定义一下每个epoch内训练和评估的逻辑

def train(train_loader, model, loss_fn, optimizer):train_size = len(train_loader.dataset)num_batches = len(train_loader)train_loss, train_acc = 0, 0for x, y in train_loader:x, y = x.to(device), y.to(device)preds = model(x)loss = loss_fn(preds, y)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()train_acc += (preds.argmax(1) == y).type(torch.float).sum().item()train_loss /= num_batchestrain_acc /= train_sizereturn train_loss, train_accdef eval(eval_loader, model, loss_fn):eval_size = len(eval_loader.dataset)num_batches = len(eval_loader)eval_loss, eval_acc = 0, 0for x, y in eval_loader:x, y = x.to(device), y.to(device)preds = model(x)loss = loss_fn(preds, y)eval_loss += loss.item()eval_acc += (preds.argmax(1) == y).type(torch.float).sum().item()eval_loss /= num_batcheseval_acc /= eval_sizereturn eval_loss, eval_acc

然后编写代码进行训练

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
epochs = 10train_loss, train_acc = [], []
eval_loss, eval_acc =[], []
for epoch in range(epochs):model.train()epoch_train_loss, epoch_train_acc = train(train_loader, model, loss_fn, optimizer)model.eval()model.no_grad():epoch_eval_loss, epoch_eval_acc = test(eval_loader, model, loss_fn)

结果展示

训练结果
基于训练和测试数据展示结果

range_epochs = range(len(train_loss))
plt.figure(figsize=(12, 4))
plt.subplot(1,2,1)
plt.plot(range_epochs, train_loss, label='train loss')
plt.plot(range_epochs, eval_loss, label='validation loss')
plt.legend(loc='upper right')
plt.title('Loss')plt.subplot(1,2,2)
plt.plot(range_epochs, train_acc, label='train accuracy')
plt.plot(range_epochs, eval_acc, label='validation accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')

训练历史图表

总结与心得体会

通过对训练过程的观察,训练过程中的数据波动很大,并且验证集上的最好正确率只有82%。
目前行业都流行小卷积核,于是我把卷积核调整为了3x3,并且每次卷积后我都进行池化操作,直到通道数为64,由于天气识别时,背景信息也比较重要,高层的卷积操作后我使用平均池化代替低层使用的最大池化,加大了全连接层的Dropout惩罚比重,用来抑制过拟合问题。最后的模型如下:

class Network(nn.Module):def __init__(self, num_classes):super().__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3)self.conv2 = nn.Conv2d(16, 32, kernel_size=3)self.conv3 = nn.Conv2d(32, 64, kernel_size=3)self.conv4 = nn.Conv2d(64, 64, kernel_size=3)self.bn1 = nn.BatchNorm2d(16)self.bn2 = nn.BatchNorm2d(32)self.bn3 = nn.BatchNorm2d(64)self.bn4 = nn.BatchNorm2d(64)self.maxpool = nn.MaxPool2d(2)self.avgpool = nn.AvgPool2d(2)self.dropout = nn.Dropout(0.5)# 224 -> 222-> 111 -> 109 -> 54 -> 52 -> 50 -> 25self.fc1 = nn.Linear(25*25*64, 128)self.fc2 = nn.Linear(128, num_classes)def forward(self, x):x = F.relu(self.bn1(self.conv1(x)))x = self.maxpool(x)x = F.relu(self.bn2(self.conv2(x)))x = self.avgpool(x)x = F.relu(self.bn3(self.conv3(x)))x = F.relu(self.bn4(self.conv4(x)))x = self.avgpool(x)x = x.view(x.size(0), -1)x = self.dropout(x)x = F.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return x

然后增大训练的epochs为30,学习率降低为1e-4

optimizer = optim.Adam(model.parameters(), lr=1e-4)
epochs = 30

训练结果如下
训练过程
可以看到,验证集上的正确率最高达到了95%以上
训练过程图示

在数据集中随机选取一个图像进行预测展示

image_path = random.choice(image_list)
image_input = transform(Image.open(image_path))
image_input = image_input.unsqueeze(0).to(device)
model.eval()
pred = model(image_input)plt.figure(figsize=(5, 5))
plt.imshow(image_input.cpu().squeeze(0).permute(1,2,0))
plt.axis('off')
plt.title(class_names[pred.argmax(1)])

结果如下
预测结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/52658.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【LeetCode75】第三十五题 统计二叉树中好节点的数目

目录 题目: 示例: 分析: 代码: 题目: 示例: 分析: 给我们一棵二叉树,让我们统计这棵二叉树中好节点的数目。 那么什么是好节点,题目中给出定义,从根节点…

1782. 统计点对的数目

给你一个无向图,无向图由整数 n ,表示图中节点的数目,和 edges 组成,其中 edges[i] [ui, vi] 表示 ui 和 vi 之间有一条无向边。同时给你一个代表查询的整数数组 queries 。 第 j 个查询的答案是满足如下条件的点对 (a, b) 的数…

实现高效消息传递:使用RabbitMQ构建可复用的企业级消息系统

文章目录 前言1.安装erlang 语言2.安装rabbitMQ3. 内网穿透3.1 安装cpolar内网穿透(支持一键自动安装脚本)3.2 创建HTTP隧道 4. 公网远程连接5.固定公网TCP地址5.1 保留一个固定的公网TCP端口地址5.2 配置固定公网TCP端口地址 前言 RabbitMQ是一个在 AMQP(高级消息队列协议)基…

【Linux】动态库和静态库

动态库和静态库 软链接硬链接硬链接要注意 自定义实现一个静态库(.a)解决、使用方法静态库的内部加载过程 自定义实现一个动态库(.so)动态库加载过程 静态库和动态库的特点 软链接 命令:ln -s 源文件名 目标文件名 软链接是独立连接文件的,他…

Tomcat运行后localhost:8080访问自己编写的网页

主要是注意项目结构,home.html放在src/resources/templates下的home.html下,application.properties可以不做任何配置。还有就是关于web包的位置,作者一开始将web包与tabtab包平行,访问8080出现了此类报错: Whitelabel…

c++ qt--页面布局(第五部分)

c qt–页面布局(第五部分) 一.页面布局 在设计页面的左侧一栏的组件中我们可以看到进行页面布局的一些组件 布局组件的使用 1.水平布局 使用:将别的组件拖到水平布局的组件中即可,可以选择是在哪个位置 2.垂直布局 使用&…

【业务功能篇81】微服务SpringCloud-ElasticSearch-Kibanan-docke安装-入门实战

ElasticSearch 一、ElasticSearch概述 1.ElasticSearch介绍 ES 是一个开源的高扩展的分布式全文搜索引擎,是整个Elastic Stack技术栈的核心。它可以近乎实时的存储,检索数据;本身扩展性很好,可以扩展到上百台服务器,…

GB28181设备接入侧如何对接外部编码后音视频数据并实现预览播放

技术背景 我们在对接GB28181设备接入模块的时候,遇到这样的技术诉求,好多开发者期望能提供编码后(H.264/H.265、AAC/PCMA)数据对接,确保外部采集设备,比如无人机类似回调过来的数据,直接通过模…

Vue中使用element-plus中的el-dialog定义弹窗-内部样式修改-v-model实现-demo

效果图 实现代码 <template><el-dialog class"no-code-dialog" v-model"isShow" title"没有收到验证码&#xff1f;"><div class"nocode-body"><div class"tips">请尝试一下操作</div><d…

C语言易错点整理

前言&#xff1a; 本文涵盖了博主在平常写C语言题目时经常犯的一些错误&#xff0c;在这里帮大家整理出来&#xff0c;一些易错点会帮大家标识出来&#xff0c;希望大家看完这篇文章后有所得&#xff0c;引以为戒~ 一、 题目&#xff1a; 解答&#xff1a; 首先在这个程序中…

Unity ProBuilder SetUVs 不起作用

ProBuilder SetUVs 不起作用 &#x1f41f; 需要设置face.manulUV true public static void Set01UV(this ProBuilderMesh mesh){foreach (var face in mesh.faces){face.manualUV true;//设置为手动uv}var vertices mesh.GetVertices().Select(v > v.position).ToArray(…

mq与mqtt的关系

文章目录 mqtt 与 mq的区别mqtt 与 mq的详细区别传统消息队列RocketMQ和微消息队列MQTT对比&#xff1a;MQ与RPC的区别 mqtt 与 mq的区别 mqtt&#xff1a;一种通信协议&#xff0c;规范 MQ&#xff1a;一种通信通道&#xff08;方式&#xff09;&#xff0c;也叫消息队列 MQ…

openCV实战-系列教程5:边缘检测(Canny边缘检测/高斯滤波器/Sobel算子/非极大值抑制/线性插值法/梯度方向/双阈值检测 )、原理解析、源码解读

打印一个图片可以做出一个函数&#xff1a; def cv_show(img,name):cv2.imshow(name,img)cv2.waitKey()cv2.destroyAllWindows() 1、Canny边缘检测流程 Canny是一个科学家在1986年写了一篇论文&#xff0c;所以用自己的名字来命名这个检测算法&#xff0c;Canny边缘检测算法…

C#-Tolewer和ToUpper的使用

目录 简介: 好处:​ 过程: 总结&#xff1a; 简介: 字符串是不可变的&#xff0c;所以这些函数都不会直接改变字符串的内容&#xff0c;而是把修改后的字符串的值通过函数返回值的形式返回。 ToLower和ToUpper是字符串处理函数&#xff0c;用于将字符中的英文字母转换为小…

使用DPO微调Llama2

简介 基于人类反馈的强化学习 (Reinforcement Learning from Human Feedback&#xff0c;RLHF) 事实上已成为 GPT-4 或 Claude 等 LLM 训练的最后一步&#xff0c;它可以确保语言模型的输出符合人类在闲聊或安全性等方面的期望。然而&#xff0c;它也给 NLP 引入了一些 RL 相关…

linux篇---使用systemctl start xxx启动自己的程序|开机启动|守护进程

linux篇---使用systemctl start xxx启动自己的程序|开机启动|守护进程 1、创建服务2、修改权限3、启动服务4、测试 机器&#xff1a;Nvidia Jetson Xavier系统&#xff1a;ubuntu 18.04 最近在使用symfony的console组件&#xff0c;需要执行一个后台的php进程&#xff0c;并且…

C++ Day5

目录 一、静态成员 1.1 概念 1.2 格式 1.3 银行账户实例 二、类的继承 2.1 目的 2.2 概念 2.3 格式 2.4 继承方式 2.5 继承中的特殊成员函数 2.5.1 构造函数 2.5.2析构函数 2.5.3 拷贝构造函数 2.5.4拷贝赋值函数 总结&#xff1a; 三、多继承 3.1 概念 3.2 格…

大数据领域都有什么发展方向

近年来越来越多的人选择大数据行业&#xff0c;大数据行业前景不错薪资待遇好&#xff0c;各大名企对于大数据人才需求不断上涨。 大数据从业领域很宽广&#xff0c;不管是科技领域还是食品产业&#xff0c;零售业等都是需要大数据人才进行大数据的处理&#xff0c;以提供更好…

【python】Leetcode(primer-set)

文章目录 78. 子集&#xff08;集合的所有子集&#xff09;90. 子集 II&#xff08;集合的所有子集&#xff09;792. 匹配子序列的单词数&#xff08;判断是否为子集&#xff09;500. 键盘行&#xff08;集合的交集&#xff09;409. 最长回文串&#xff08;set&#xff09; 更多…

测试先行:探索测试驱动开发的深层价值

引言 在软件开发的世界中,如何确保代码的质量和可维护性始终是一个核心议题。测试驱动开发(TDD)为此提供了一个答案。与传统的开发方法相比,TDD鼓励开发者从用户的角度出发,先定义期望的结果,再进行实际的开发。这种方法不仅可以确保代码满足预期的需求,还可以在整个开…