Pytorch从零开始实战03

Pytorch从零开始实战——天气识别

本系列来源于365天深度学习训练营

原作者K同学

文章目录

  • Pytorch从零开始实战——天气识别
    • 环境准备
    • 数据集
    • 模型选择
    • 模型训练
    • 数据可视化
    • 总结

环境准备

本文基于Jupyter notebook,使用Python3.8,Pytorch2.0.1+cu118,torchvision0.15.2,需读者自行配置好环境且有一些深度学习理论基础。
第一步,导入常用包。

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
import torch.nn.functional as F
import torchvision.transforms as transforms
import random
import time
import numpy as np
import pandas as pd
import datetime
import gc
import pathlib
import os
import PIL
os.environ['KMP_DUPLICATE_LIB_OK']='True'  # 用于避免jupyter环境突然关闭
torch.backends.cudnn.benchmark=True  # 用于加速GPU运算的代码

创建设备对象。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device # device(type='cuda')

设置随机数种子

torch.manual_seed(428)
torch.cuda.manual_seed(428)
torch.cuda.manual_seed_all(428)
random.seed(428)
np.random.seed(428)

数据集

本次实验使用的天气图片数据集,共有1127张天气图片,分别存在’cloudy’, ‘sunrise’, ‘shine’, 'rain’四个文件夹中,其中文件夹名代表类别。数据集获取可联系K同学。
导入数据集。
根据自己数据集存放的路径,转换为pathlib.Path对象,然后获取路径下的所有文件路径,使用字符串分割函数获取文件名,也就是类别名。

data_dir = './data/weather_photos'
data_dir = pathlib.Path(data_dir) # 转换为pathlib.Path对象data_paths = list(data_dir.glob('*')) # 获取data_dir路径下的所有文件路径
data_paths # data/weather_photos/xxxx
classNames = [str(path).split("/")[2] for path in data_paths]
classNames # ['cloudy', 'sunrise', 'shine', 'rain']

对数据集进行预处理。调整到相同的尺寸,转换为张量对象,并进行标准化处理。使用torchvision.datasets.ImageFolder函数读取数据集,并且使用文件名当做数据集的标签。

total_dir = './data/weather_photos'
train_transforms = transforms.Compose([transforms.Resize([224, 224]), # 调整相同的尺寸transforms.ToTensor(),transforms.Normalize(          # 标准化处理-->转换为标准正太分布mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])
total_data = torchvision.datasets.ImageFolder(total_dir, transform=train_transforms) # 通过total_dir下的子文件夹当做标签
total_data

我们根据8:2划分训练集和测试集。

# 划分数据集
train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size
train_ds, test_ds = torch.utils.data.random_split(total_data, [train_size, test_size])
len(train_ds), len(test_ds) # (901, 226)

又是前面几篇出现的函数,随机查看五张图片。

def plotsample(data):fig, axs = plt.subplots(1, 5, figsize=(10, 10)) #建立子图for i in range(5):num = random.randint(0, len(data) - 1) #首先选取随机数,随机选取五次#抽取数据中对应的图像对象,make_grid函数可将任意格式的图像的通道数升为3,而不改变图像原始的数据#而展示图像用的imshow函数最常见的输入格式也是3通道npimg = torchvision.utils.make_grid(data[num][0]).numpy()nplabel = data[num][1] #提取标签 #将图像由(3, weight, height)转化为(weight, height, 3),并放入imshow函数中读取axs[i].imshow(np.transpose(npimg, (1, 2, 0))) axs[i].set_title(nplabel) #给每个子图加上标签axs[i].axis("off") #消除每个子图的坐标轴plotsample(train_ds)

在这里插入图片描述
使用DataLoder将它按照batch_size批量划分,并将数据集顺序打乱。

batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=True)
for X, y in test_dl:print(X.shape) # 32, 3, 224, 224print(y) # 3 0 2 0 3 2 0 0 2 1....break

模型选择

本文使用卷积神经网络,大致流程是卷积->卷积->池化->卷积->卷积->池化->线性层,并进行数据归一化处理,本文选用的卷积核大小为5 * 5。

class Model(nn.Module):def __init__(self):super(Model, self).__init__()self.conv1 = nn.Conv2d(3, 12, kernel_size=5, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(12)self.conv2 = nn.Conv2d(12, 12, kernel_size=5, stride=1, padding=0)self.bn2 = nn.BatchNorm2d(12)self.pool2 = nn.MaxPool2d(2, 2)self.conv3 = nn.Conv2d(12, 24, kernel_size=5, stride=1, padding=0)self.bn3 = nn.BatchNorm2d(24)self.conv4 = nn.Conv2d(24, 24, kernel_size=5, stride=1, padding=0)self.bn4 = nn.BatchNorm2d(24)self.pool4 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(24 * 50 * 50, len(classNames))def forward(self, x):x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = self.pool2(x)x = F.relu(self.bn3(self.conv3(x)))x = F.relu(self.bn4(self.conv4(x)))x = self.pool4(x)x = x.view(-1, 24 * 50 * 50)x = self.fc1(x)return x

请添加图片描述
使用summary展示模型架构。

from torchsummary import summary
# 将模型转移到GPU中
model = Model().to(device)
summary(model, input_size=(3, 224, 224))

请添加图片描述

模型训练

定义超参数,本次选择的学习率为0.0001,经实验,最初设置为0.01效果并不是很好。

loss_fn = nn.CrossEntropyLoss()
learn_rate = 0.0001
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)

训练函数。

def train(dataloader, model, loss_fn, opt):size = len(dataloader.dataset)num_batches = len(dataloader)train_acc, train_loss = 0, 0for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)opt.zero_grad()loss.backward()opt.step()train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

测试函数。

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)test_acc, test_loss = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss += loss.item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss

开始训练,训练20轮,在测试集准确率达到94.7%,还是很不错的。

import time
epochs = 20
train_loss = []
train_acc = []
test_loss = []
test_acc = []T1 = time.time()for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval() # 确保模型不会进行训练操作epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)print("epoch:%d, train_acc:%.1f%%, train_loss:%.3f, test_acc:%.1f%%, test_loss:%.3f"% (epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss))
print("Done")
T2 = time.time()
print('程序运行时间:%s毫秒' % ((T2 - T1)*1000))

请添加图片描述

数据可视化

使用matplotlib进行训练数据、测试数据的可视化。

import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

请添加图片描述

总结

经过几次实验,发现三个问题:
1.经过实验,将学习率从0.01改为0.0001,模型效果会好很多。
2.有的时候每轮epoch准确率一直为百分之20多,可能是模型陷入局部最小值或鞍点,所以后续可以引入提前停止。
3.无脑的增加层数并不会使模型效果变好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/78742.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux 修改SSH的显示样式,修改终端shell显示的样式,美观更改

要修改SSH的显示样式,您可以使用自定义的PS1(提示字符串1)变量来更改命令行提示符的外观。在您的情况下,您想要的格式似乎包括日期和时间,以及当前目录。以下是一个示例PS1设置,可以实现您所描述的样式&…

【搭建私人图床】本地PHP搭建简单Imagewheel云图床,在外远程访问

文章目录 1.前言2. Imagewheel网站搭建2.1. Imagewheel下载和安装2.2. Imagewheel网页测试2.3.cpolar的安装和注册 3.本地网页发布3.1.Cpolar临时数据隧道3.2.Cpolar稳定隧道(云端设置)3.3.Cpolar稳定隧道(本地设置) 4.公网访问测…

【Spring面试】三、Bean的配置、线程安全、自动装配

文章目录 Q1、什么是Spring Bean?和对象有什么区别Q2、配置Bean有哪几种方式?Q3、Spring支持的Bean有哪几种作用域?Q4、单例Bean的优势是什么?Q5、Spring的Bean是线程安全的吗?Q6、Spring如何处理线程并发问题&#xf…

【已解决】您所使用的密钥ak有问题,不支持jsapi服务,可以访问该网址了解如何获取有效密钥。

您所使用的密钥ak有问题,不支持jsapi服务,可以访问该网址了解如何获取有效密钥。详情查看:http://lbsyun.baidu.com/apiconsole/key#。 问题 百度密钥过期 思路 注册成为开发者 如果还没注册百度地图api账号的,点击以后就进入…

【深度学习】 Python 和 NumPy 系列教程(廿二):Matplotlib详解:2、3d绘图类型(8)3D饼图(3D Pie Chart)

一、前言 Python是一种高级编程语言,由Guido van Rossum于1991年创建。它以简洁、易读的语法而闻名,并且具有强大的功能和广泛的应用领域。Python具有丰富的标准库和第三方库,可以用于开发各种类型的应用程序,包括Web开发、数据分…

WebRTC 源码 编译 iOS端

1. 获取依赖工具 首先,确保你已经安装了以下工具: GitDepot ToolsXcode(确保已安装命令行工具) 2. 下载 depot_tools 使用 git 克隆 depot_tools 并将其添加到你的 PATH 中: /path/to/depot_tools 替换为自己的路径…

unity C#客户端与服务器程序

客户端和服务器公共的脚本 OSC.cs // This is version 1.01(2015.05.27) // Tested in Unity 4 // Most of the code is based on a library for the Make Controller Kit1/* using UnityEngine; using System; using System.Collections; using System.Threading; using Syst…

Furion api npm web vue混合开发

Furion api npm web vue混合开发 Furion-api项目获取swagger.json文件复制json制作ts包删除非.ts文件上传到npm获取npm包引用 Furion-api项目获取swagger.json文件 使用所有接口合并的配置文件 复制json制作ts包 https://editor.swagger.io 得到 typescript-axios-clien…

怎么科学管理固定资产呢

在当今的商业环境中,固定资产的管理是企业成功的关键因素之一。然而,传统的固定资产管理方法往往过于繁琐,缺乏创新,导致资源的浪费和效率的低下。因此,我们需要一种新的、更加科学的方法来管理我们的固定资产。本文将…

C++多线程的用法(包含线程池小项目)

一些小tips: 编译命令如下&#xff1a; g 7.thread_pool.cpp -lpthread 查看运行时间&#xff1a; time ./a.out 获得本进程的进程id&#xff1a; this_thread::get_id() 需要引入的库函数有&#xff1a; #include<thread> // 引入线程库 #include<mutex> //…

Ui自动化测试上传文件方法都在这里了 ~

前言 实施UI自动化测试的时候&#xff0c;经常会遇见上传文件的操作&#xff0c;那么对于上传文件你知道几种方法呢&#xff1f;今天我们就总结一下几种常用的上传文件的方法&#xff0c;并分析一下每个方法的优点和缺点以及哪种方法效率&#xff0c;稳定性更高 被测HTML代码…

睿趣科技:抖音开店前期需要准备什么

抖音作为全球最受欢迎的短视频平台之一&#xff0c;已经成为了许多年轻人的创业和赚钱的机会。如果你计划在抖音上开店&#xff0c;那么在正式开业之前&#xff0c;有一些重要的准备工作是必不可少的。下面就是抖音开店前期需要准备的关键步骤和注意事项。 确定你的目标和产品&…

Matlab图像处理-三原色

三原色 根据详细的实验结果&#xff0c;人眼中负责颜色感知的细胞中约有65%对红光敏感&#xff0c;33%对绿光敏感&#xff0c;只有2%对蓝光敏感。正是人眼的这些吸收特性决定了所看到的彩色是一般所谓的原色红&#xff08;R&#xff09;、绿&#xff08;G&#xff09;和蓝&…

动态渲染 echarts 饼图(vue 2 + axios + Springboot)

目录 前言1. 项目搭建1.1. 前端1.2. 后端 2. 后端数据渲染前端2.1 补充1&#xff1a;在 vue 中使用 axios2.2. 补充2&#xff1a;Springboot 处理跨域问题2.3. 修改前端代码2.3.1 修改饼图样式2.3.2 调用后台数据渲染饼图2.3.3 改造成内外两个圈 前言 因为上文中提到的需求就是…

内网隧道代理技术(二十五)之 ICMP隧道反弹SHELL

ICMP隧道反弹SHELL ICMP隧道原理 由于ICMP报文自身可以携带数据,而且ICMP报文是由系统内核处理的,不占用任何端口,因此具有很高的隐蔽性。把数据隐藏在ICMP数据包包头的data字段中,建立隐蔽通道,可以实现绕过防火墙和入侵检测系统的阻拦。 ICMP隧道有以下的优点: ICMP…

腾讯云4核8G服务器选CVM还是轻量比较好?价格对比

腾讯云4核8G云服务器可以选择轻量应用服务器或CVM云服务器标准型S5实例&#xff0c;轻量4核8G12M服务器446元一年&#xff0c;CVM S5云服务器935元一年&#xff0c;相对于云服务器CVM&#xff0c;轻量应用服务器性价比更高&#xff0c;轻量服务器CPU和CVM有区别吗&#xff1f;性…

博客系统(升级(Spring))(四)(完)基本功能(阅读,修改,添加,删除文章)(附带项目)

博客系统 (三&#xff09; 博客系统博客主页前端后端个人博客前端后端显示个人文章删除文章 修改文章前端后端提取文章修改文章 显示正文内容前端后端文章阅读量功能 添加文章前端后端 如何使用Redis项目地点&#xff1a; 博客系统 博客系统是干什么的&#xff1f; CSDN就是一…

数字化转型对企业有哪些优势?

数字化转型为企业提供了众多优势&#xff0c;帮助他们在日益数字化的世界中保持竞争力、敏捷性和响应能力。以下是一些主要优势&#xff1a; 1.提高效率和生产力&#xff1a; 重复性任务和流程的自动化可以减少人为错误&#xff0c;并使员工能够专注于更具战略性的任务。简化…

Apache Linki 1.3.1+DataSphereStudio+正常启动+微服务+端口号

我使用的是一键部署容器化版本&#xff0c;官方文章 默认会启动6个 Linkis 微服务&#xff0c;其中下图linkis-cg-engineconn服务为运行任务才会启动,一共七个 LINKIS-CG-ENGINECONN:38681 LINKIS-CG-ENGINECONNMANAGER:9102 引擎管理服务 LINKIS-CG-ENTRANCE:9104 计算治理入…

Vue开发小注意点

改bug 更改了配置项啥的&#xff0c;保存刷新发现没变&#xff0c;那就重启项目&#xff01;&#xff01;&#xff01;&#xff01; binding.value 和 e.target.value binding.value Day5 指令的值 e.target.value Day4 表单组件封装 binding.value 和 e.target.valu…