深度学习pytorch实战第P3周--实现天气识别

>- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/0dvHCaOoFnW8SCp3JpzKxg) 中的学习记录博客**
>- **🍖 原作者:[K同学啊](https://mtyjkh.blog.csdn.net/)**

引言

1.复习上周

深度学习pytorch实战-第P2周-彩色图片识别icon-default.png?t=N7T8http://t.csdnimg.cn/f5l6F对于上周的学习,数据集是下载的

2.摆正心态

正如K同学所说,学到第三周左右就会有点感觉了,还真是这样,引领我入门,激发了我的兴趣,同时感谢同济子豪兄。

3.本机环境

见上文

4.学习目标

扎扎实实学好习,走好每一步。

一、前期准备

1.设置GPU

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasetsimport os,PIL,pathlib,randomdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")device

2.导入数据

data_dir = './weather_photos/'
print(data_dir)
data_dir = pathlib.Path(data_dir)data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[1] for path in data_paths]
classeNames
  • 第一步:使用pathlib.Path()函数将字符串类型的文件夹路径转换为pathlib.Path对象。
  • 第二步:使用glob()方法获取data_dir路径下的所有文件路径,并以列表形式存储在data_paths中。
  • 第三步:通过split()函数对data_paths中的每个文件路径执行分割操作,获得各个文件所属的类别名称,并存储在classeNames
  • 第四步:打印classeNames列表,显示每个文件所属的类别名称。

3.数据可视化

import matplotlib.pyplot as plt
from PIL import Image# 指定图像文件夹路径
image_folder = './weather_photos/cloudy/'# 获取文件夹中的所有图像文件
image_files = [f for f in os.listdir(image_folder) if f.endswith((".jpg", ".png", ".jpeg"))]# 创建Matplotlib图像
fig, axes = plt.subplots(3, 8, figsize=(16, 6))# 使用列表推导式加载和显示图像
for ax, img_file in zip(axes.flat, image_files):img_path = os.path.join(image_folder, img_file)img = Image.open(img_path)ax.imshow(img)ax.axis('off')# 显示图像
plt.tight_layout()
plt.show()

 

 图片大小处理

total_datadir = './data/'# 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])total_data = datasets.ImageFolder(total_datadir,transform=train_transforms)
total_data

划分数据集(4:1)

train_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
train_dataset, test_dataset

 

batch_size = 32train_dl = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,batch_size=batch_size,shuffle=True,num_workers=1)
for X, y in test_dl:print("Shape of X [N, C, H, W]: ", X.shape)print("Shape of y: ", y.shape, y.dtype)break

 

二、构建简单的CNN网络

对于一般的CNN网络来说,都是由特征提取网络和分类网络构成,其中特征提取网络用于提取图片的特征,分类网络用于将图片进行分类

参数详解见上文。

import torch.nn.functional as Fclass Network_bn(nn.Module):def __init__(self):super(Network_bn, self).__init__()"""nn.Conv2d()函数:第一个参数(in_channels)是输入的channel数量第二个参数(out_channels)是输出的channel数量第三个参数(kernel_size)是卷积核大小第四个参数(stride)是步长,默认为1第五个参数(padding)是填充大小,默认为0"""self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(12)self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0)self.bn2 = nn.BatchNorm2d(12)self.pool1 = nn.MaxPool2d(2,2)self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=0)self.bn4 = nn.BatchNorm2d(24)self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=0)self.bn5 = nn.BatchNorm2d(24)self.pool2 = nn.MaxPool2d(2,2)self.fc1 = nn.Linear(24*50*50, len(classeNames))def forward(self, x):x = F.relu(self.bn1(self.conv1(x)))      x = F.relu(self.bn2(self.conv2(x)))     x = self.pool1(x)                        x = F.relu(self.bn4(self.conv4(x)))     x = F.relu(self.bn5(self.conv5(x)))  x = self.pool2(x)                        x = x.view(-1, 24*50*50)x = self.fc1(x)return xdevice = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))model = Network_bn().to(device)
model

三、训练模型

1.设置超参数

loss_fn    = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-4 # 学习率
opt        = torch.optim.SGD(model.parameters(),lr=learn_rate)

2.编写训练函数

# 训练循环
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 训练集的大小,一共60000张图片num_batches = len(dataloader)   # 批次数目,1875(60000/32)train_loss, train_acc = 0, 0  # 初始化训练损失和正确率for X, y in dataloader:  # 获取图片及其标签X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X)          # 网络输出loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad()  # grad属性归零loss.backward()        # 反向传播optimizer.step()       # 每一步自动更新# 记录acc与losstrain_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc  /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

 

3.编写测试函数

def test (dataloader, model, loss_fn):size        = len(dataloader.dataset)  # 测试集的大小,一共10000张图片num_batches = len(dataloader)          # 批次数目,313(10000/32=312.5,向上取整)test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss        = loss_fn(target_pred, target)test_loss += loss.item()test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc  /= sizetest_loss /= num_batchesreturn test_acc, test_loss

4.正式训练

epochs     = 20
train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done')

 也不是伦茨越多越好

四、结果可视化

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

五、保存模型并预测

# 1.保存模型# torch.save(model, 'model.pth') # 保存整个模型
torch.save(model.state_dict(), 'model_state_dict.pth') # 仅保存状态字典# 2. 加载模型 or 新建模型加载状态字典# model2 = torch.load('model.pth') 
# model2 = model2.to(device) # 理论上在哪里保存模型,加载模型也会优先在哪里,但是指定一下确保不会出错model2 = Network_bn().to(device) # 重新定义模型
model2.load_state_dict(torch.load('model_state_dict.pth')) # 加载状态字典到模型# 3.图片预处理
from PIL import Image
import torchvision.transforms as transforms# 输入图片预处理
def preprocess_image(image_path):image = Image.open(image_path)transform = transforms.Compose([transforms.Resize((224, 224)),  # 假设使用的是224x224的输入transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])image = transform(image).unsqueeze(0)  # 增加一个批次维度return image# 4.预测函数(指定路径)
def predict(image_path, model):model.eval()  # 将模型设置为评估模式with torch.no_grad():  # 关闭梯度计算image = preprocess_image(image_path)image = image.to(device)  # 确保图片在正确的设备上outputs = model(image)_, predicted = torch.max(outputs, 1)  # 获取最可能的预测类别return predicted.item()# 5.预测并输出结果
image_path = "./weather_photos/shine/shine101.jpg"  # 替换为你的图片路径
prediction = predict(image_path, model)
class_names = ["cloudy", "rain", "shine", "sunrise"]  # Replace with your class labels
predicted_label = class_names[prediction]
print("Predicted class:", predicted_label)

 

Predicted class: shine
# 选取dataloader中的一个图像进行判断
import numpy as np
# 选取图像
imgs,labels = next(iter(train_dl))
image,label = imgs[0],labels[0]# 选取指定图像并展示
# 调整维度为 [224, 224, 3]
image_to_show = image.numpy().transpose((1, 2, 0))# 归一化
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
image_to_show = std * image_to_show + mean
image_to_show = np.clip(image_to_show, 0, 1)# 显示图像
plt.imshow(image_to_show)
plt.show()# 将图像转移到模型所在的设备上(如果使用GPU)
image = image.to(device)# 预测
with torch.no_grad():output = model(image.unsqueeze(0))  # 添加批次维度# 输出预测结果
_, predicted = torch.max(output, 1)
class_names = ["cloudy", "rain", "shine", "sunrise"]  # Replace with your class labels
predicted_label = class_names[predicted]
print(f"Predicted: {predicted.item()}, Actual: {label.item()}")

 如何保存模型可以看这篇icon-default.png?t=N7T8http://t.csdnimg.cn/yCkr7

六,总结

构建数据集中

此次数据集不是直接从网上下载的,而是保存在本地的,那么我们构建数据集的代码就有3步了,第一步就是从我们的本地获取数据集,使用pathlib将路径转化为pathlib.path对象,使用glob方法获取路径下的所有文件路径并保存到datapath,通过spilt分割路径并保存在classname中,打印classname列表。这是在确保路径文件夹正确,且分类正确。接下来可以使用matplotlib.pyplot和PIL来获取图片,并可视化,过程中使用列表推导式加载和显示图像,第二步划分数据集,4:1比例划分并保存为dataset对象,第三步定义batch_size大小,使用dataloder加载器来管理数据,

构建cnn网络中

主要是特征提取网络和分类网络的构建,首先进行一个网络初始化init,定义网络结构,卷积层1,bn层1,(nn.BatchNorm2d(12) 是 PyTorch 中的一个批量归一化层。批量归一化用于加速神经网络的训练过程,并提高模型的泛化能力。在卷积神经网络中),卷积2,bn2,池化1,卷积4,bn层4,卷积5,bn层5,池化层2,全连接层1,然后定义进行前向传播结构,并在每一层(卷积+bn)进行relu函数返回给x,然后再池化,再relu,最后拉平,使用view函数,并传给fc层进行分类。

训练模型中

设置超参数,loss一般用交叉熵,学习率一般0.0001,opt采用sgd或者adam

编写训练函数中

编写训练函数时候,要使用size知道dataloaer加载的dataset大小长度,每batchnum就是dataloder的长度,定义训练集损失和准确率,获取图片和标签从dataloder中,计算误差,进行啊反向传播更新参数,每一步自动更新,记录acc和loss,最后计算train的acc和loss

编写测试函数同理

同理,再取图像之前加个当不进行训练时,停止梯度更新就行

来到了正式训练

定义epoch轮次,训练和特使的acc和loss定义,便于后续存储,轮次循环,model.train和model.eval,然后把每一批次的准确率loss都存在列表里,最后输出

最后可视化

定义plt窗口,可视化列表train_acc等等

保存模型

参上,现在我还有点不懂,慢慢学吧

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/810120.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一文了解HTTPS的加密原理

HTTPS是一种安全的网络通信协议,用于在互联网上提供端到端的加密通信,确保数据在客户端(如Web浏览器)与服务器之间传输时的机密性、完整性和身份验证。HTTPS的加密原理主要基于SSL/TLS协议,以下详细阐述其工作过程&…

常见程序故障排查及程序配置

文章目录 故障排查基础关机/重启/注销系统信息和性能查看磁盘和分区⽤户和⽤户组⽹络和进程管理常⻅系统服务命令⽂件和⽬录操作⽂件查看和处理打包和解压RPM包管理命令YUM包管理命令DPKG包管理命令APT软件⼯具 分析工具JDK自带分析工具jpsjstatjinfojmapjhatjstackjcmd GUI分析…

QT:QMainWindow、ui界面、资源文件的添加、信号和槽

1.练习:使用手动连接,将登录框中的取消按钮使用qt4版本的连接到自定义的槽函数中,在自定义的槽函数中调用关闭函数 #include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(…

第6章 6.3.1 正则表达式的语法(MATLAB入门课程)

讲解视频:可以在bilibili搜索《MATLAB教程新手入门篇——数学建模清风主讲》。​ MATLAB教程新手入门篇(数学建模清风主讲,适合零基础同学观看)_哔哩哔哩_bilibili 正则表达式可以由一般的字符、转义字符、元字符、限定符等元素组…

算法题解记录8+++爬楼梯(百日筑基)

题目描述: 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬 1 或 2 个台阶。你有多少种不同的方法可以爬到楼顶呢? 示例 1: 输入:n 2 输出:2 解释:有两种方法可以爬到楼顶。 1. 1 阶…

KVM虚拟机

文章目录 QEMU-KVM介绍虚拟网卡流程网卡透访流程 QEMU-KVM介绍 QEMU ● QEMU是一个主机上的VMM (Virtual machine monitor), 通过动态二进制模拟CPU,并提供一系列的硬件模型,使Guest OS能够与Host硬件交互。 ● QEMU的代码中有完整的虚拟机实现&#xf…

【C++】1.从C语言转向C++

目录 一.对C的认识 二.C的关键字 三.命名空间 3.1命名空间的定义 3.2命名空间的使用 四.C的输入与输出 五.缺省参数 5.1全缺省参数 5.2半缺省参数 六.函数重载 七.引用 7.1引用的特性 7.2引用和指针的区别 八.内联函数 九.auto关键字(C1…

WEB漏洞——XXE

文章目录 前言一、XXE简述及XML基础XXE简述XML基础xml简介文档格式xml树结构xml其它xml语法1、格式良好的xml2、编写第一段XML代码DTD介绍内部文档声明(即DTD在XML源文件中)外部文档声明(DTD位于XML源文件的外部)XML文档构建模块Elements(元素)数量词的用法Attributes(属…

CISA :恶意软件分析平台Malware Next-Gen全新升级

本周三,美国网络安全和基础设施安全局(CISA)发布了新版恶意软件分析平台Malware Next-Gen,现在公众可以提交任意恶意软件样本供 CISA 分析。 据悉,Malware Next-Gen 可用于检查恶意软件样本中是否存在可疑项目。它最初…

数据生成 | Matlab实现基于SNN浅层神经网络的数据生成

数据生成 | Matlab实现基于SNN浅层神经网络的数据生成 目录 数据生成 | Matlab实现基于SNN浅层神经网络的数据生成生成效果基本描述模型描述程序设计参考资料 生成效果 基本描述 1.Matlab实现基于SNN浅层神经网络的数据生成,运行环境Matlab2021b及以上; …

在windows中anaconda中安装fasttext (whl 文件安装)

Anaconda安装第三方包(whl文件) windows 安装fasttext 一直不成功,python 版本3.8 网上教程都是 https://www.lfd.uci.edu/~gohlke/pythonlibs/#fasttext 下载然后安装,但是这个网站里我没找到哈哈哈。。。 然后就是成功方案&am…

Pygame教程10:在背景图片上,添加一个雪花特效

------------★Pygame系列教程★------------ Pygame经典游戏:贪吃蛇 Pygame教程01:初识pygame游戏模块 Pygame教程02:图片的加载缩放旋转显示操作 Pygame教程03:文本显示字体加载transform方法 Pygame教程04:dra…

移植 amd blas 到 cuda 生态

1,下载源码 GitHub - ROCm/rocBLAS: Next generation BLAS implementation for ROCm platform $ git clone --recursive https://github.com/ROCm/rocBLAS.git 2, 编译 2.1 不带Tensile的编译 如果是在conda环境中,需要deactive conda 环境…

信息学奥赛一本通T1457-Power Strings【KMP】

信息学奥赛一本通T1457-Power Strings - C语言网 (dotcpp.com) #include <iostream> #include <algorithm> #include <cstring> using namespace std; const int N1e6100; char str[N]; int nex[N]; int res0; signed main() {while(scanf("%s",st…

[2024]最新激活Navicat教程附激活码

PS&#xff1a;在开始前&#xff0c;建议先断开本地网络&#xff01;&#xff01;&#xff01;建议先断开本地网络&#xff01;&#xff01;&#xff01;建议先断开本地网络&#xff01;&#xff01;&#xff01; 1 安装 1.1 点击下一步 1.2 许可证选择“我同意”&#xff0c…

【云计算】云网络产品体系概述

云网络产品体系概述 在介绍云网络产品体系前&#xff0c;先介绍几个与云计算相关的基础概念。 阿里云在基础设施层面分为 地域 和 可用区 两层&#xff0c;关系如下图所示。在一个地域内有多个可用区&#xff0c;每个地域完全独立&#xff0c;每个可用区完全隔离&#xff0c;同…

ViT:拉开Trasnformer在图像领域正式挑战CNN的序幕 | ICLR 2021

论文直接将纯Trasnformer应用于图像识别&#xff0c;是Trasnformer在图像领域正式挑战CNN的开山之作。这种简单的可扩展结构在与大型数据集的预训练相结合时&#xff0c;效果出奇的好。在许多图像分类数据集上都符合或超过了SOTA&#xff0c;同时预训练的成本也相对较低   来源…

安装 Kali NetHunter (完整版、精简版、非root版)、实战指南、ARM设备武器化指南

From&#xff1a;https://www.kali.org/docs/nethunter/ NetHunter 实战指南&#xff1a;https://www.vuln.cn/6430 乌云 存档&#xff1a;https://www.vuln.cn/wooyundrops 1、Kali NetHunter Kali NetHunter 简介 Net&#xff08;网络&#xff09;&#xff0c;hunter&#x…

今天讲讲MYSQL数据库事务怎么实现的!

目录 什么是数据库事务 Mysql如何保证原子性 Mysql如何保证持久性 MySQL怎么保证隔离性 事务隔离级别 脏读的解决 不可重复读的解决 幻读的解决 MVCC实现 Read View 那么RC、RR级别下的InnoDB快照读有什么不同&#xff1f; 什么是数据库事务 数据库事务是指一组数据…

Geotrust SSL证书:安全加密与身份验证的坚实后盾

Geotrust&#xff0c;作为全球知名的数字证书颁发机构&#xff0c;以其高度的安全性、可靠性和广泛的兼容性&#xff0c;在网络安全领域中占据了举足轻重的地位。尤其在SSL&#xff08;Secure Sockets Layer&#xff09;证书服务上&#xff0c;Geotrust更是凭借其卓越的技术实力…