【深度学习】四种天气分类 模版函数 从0到1手敲版本

引入该引入的库

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import torch.optim as optim
%matplotlib inline
import os
import shutil
import glob
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

注意:os.environ[“KMP_DUPLICATE_LIB_OK”]=“TRUE” 必须要引入否则用plt出错

数据集整理

img_dir = r"F:\播放器\1、pytorch全套入门与实战项目\课程资料\参考代码和部分数据集\参考代码\参考代码\29-42节参考代码和数据集\四种天气图片数据集\dataset2"
base_dir = r"./dataset/4weather"img_list = glob.glob(img_dir+"/*.*")
test_dir = "test"
train_dir = "train"
species = ["cloudy","rain","shine","sunrise"]
for idx,img_path in enumerate(img_list):_,img_name = os.path.split(img_path)if idx%5==0:for specie in species:if img_path.find(specie) > -1:dst_dir = os.path.join(test_dir,specie)os.makedirs(dst_dir,exist_ok=True)dst_path = os.path.join(dst_dir,img_name)else:for specie in species:if img_path.find(specie) > -1:dst_dir = os.path.join(train_dir,specie)os.makedirs(dst_dir,exist_ok=True)dst_path = os.path.join(dst_dir,img_name)shutil.copy(img_path,dst_path)

生成测试和训练的文件夹,
目录结构如下:
在这里插入图片描述
rain 下面就是图片了
在这里插入图片描述

构建ds和dl

from torchvision import transforms
transform = transforms.Compose([transforms.Resize((96,96)),transforms.ToTensor(),transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])])
train_ds=torchvision.datasets.ImageFolder(train_dir,transform)
test_ds = torchvision.datasets.ImageFolder(train_dir,transform)

在这里插入图片描述
在这里插入图片描述
一张图片效果,这是rain图片 这里需要转换维度,把channel放到最后。同时把数据拉到0-1之间,原本std 和mean 【0.5,0,5】数据在-0.5~0.5之间
在这里插入图片描述
类的映射
在这里插入图片描述

plt.figure(figsize=(12, 8))
for i, (img, label) in enumerate(zip(imgs[:6], labels[:6])):img = (img.permute(1, 2, 0).numpy() + 1)/2plt.subplot(2, 3, i+1)plt.title(id_to_class.get(label.item()))plt.imshow(img)

这个方法要学会
在这里插入图片描述

定义网络

class Net(nn.Module):def __init__(self) -> None:super().__init__()self.conv1 = nn.Conv2d(3,16,3)self.conv2 = nn.Conv2d(16,32,3)self.conv3 = nn.Conv2d(32,64,3)self.pool = nn.MaxPool2d(2,2)self.dropout = nn.Dropout(0.3)self.fc1 = nn.Linear(64*10*10,1024)self.fc2 = nn.Linear(1024,4)def forward(self,x):x = F.relu(self.conv1(x))x = self.pool(x)x = F.relu(self.conv2(x))x = self.pool(x)x = F.relu(self.conv3(x))x = self.pool(x)x = self.dropout(x)# print(x.size()) 这里是可以计算出来的,需要掌握计算方法x = x.view(-1,64*10*10)x = F.relu(self.fc1(x))x = self.dropout(x)return self.fc2(x)
model = Net()        
preds = model(imgs)
preds.shape, preds

在这里插入图片描述
定义损失函数和优化函数:

loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(),lr=0.001)

定义网络

def fit(epoch, model, trainloader, testloader):correct = 0total = 0running_loss = 0for x, y in trainloader:if torch.cuda.is_available():x, y = x.to('cuda'), y.to('cuda')y_pred = model(x)loss = loss_fn(y_pred, y)optim.zero_grad()loss.backward()optim.step()with torch.no_grad():y_pred = torch.argmax(y_pred, dim=1)correct += (y_pred == y).sum().item()total += y.size(0)running_loss += loss.item()epoch_loss = running_loss / len(trainloader.dataset)epoch_acc = correct / totaltest_correct = 0test_total = 0test_running_loss = 0 with torch.no_grad():for x, y in testloader:if torch.cuda.is_available():x, y = x.to('cuda'), y.to('cuda')y_pred = model(x)loss = loss_fn(y_pred, y)y_pred = torch.argmax(y_pred, dim=1)test_correct += (y_pred == y).sum().item()test_total += y.size(0)test_running_loss += loss.item()epoch_test_loss = test_running_loss / len(testloader.dataset)epoch_test_acc = test_correct / test_totalprint('epoch: ', epoch, 'loss: ', round(epoch_loss, 3),'accuracy:', round(epoch_acc, 3),'test_loss: ', round(epoch_test_loss, 3),'test_accuracy:', round(epoch_test_acc, 3))return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc

训练:

epochs = 30
train_loss = []
train_acc = []
test_loss = []
test_acc = []for epoch in range(epochs):epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch,model,train_dl,test_dl)train_loss.append(epoch_loss)train_acc.append(epoch_acc)test_loss.append(epoch_test_loss)test_acc.append(epoch_test_acc)
epoch:  0 loss:  0.043 accuracy: 0.714 test_loss:  0.029 test_accuracy: 0.809
epoch:  1 loss:  0.03 accuracy: 0.807 test_loss:  0.023 test_accuracy: 0.867
epoch:  2 loss:  0.024 accuracy: 0.857 test_loss:  0.018 test_accuracy: 0.888
epoch:  3 loss:  0.021 accuracy: 0.869 test_loss:  0.017 test_accuracy: 0.894
epoch:  4 loss:  0.018 accuracy: 0.886 test_loss:  0.014 test_accuracy: 0.921
epoch:  5 loss:  0.017 accuracy: 0.897 test_loss:  0.022 test_accuracy: 0.869
epoch:  6 loss:  0.013 accuracy: 0.923 test_loss:  0.008 test_accuracy: 0.944
epoch:  7 loss:  0.009 accuracy: 0.947 test_loss:  0.011 test_accuracy: 0.924
epoch:  8 loss:  0.006 accuracy: 0.966 test_loss:  0.004 test_accuracy: 0.988
epoch:  9 loss:  0.004 accuracy: 0.979 test_loss:  0.002 test_accuracy: 0.998
epoch:  10 loss:  0.004 accuracy: 0.979 test_loss:  0.005 test_accuracy: 0.966

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
比较重要的点,
1.分类的数据集布局要记住
2.图片经过conv2 多次后的值要会算 todo
3.图片展示的方法要会

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/765970.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

iOS应用审核问题解决方案及优化方法 ✨

摘要 本文将针对iOS应用提交审核时可能遇到的问题,如“你必须在Xcode中添加com.apple.developer.game-center密钥”,以及突然间提交送审报错情况进行探讨。通过大量查询资料和尝试,结合案例分析,提供了解决方案和优化方法&#x…

【模糊逻辑】Type-1 Fuzzy Systems-2

【模糊逻辑】Type-1 Fuzzy Systems 3.4.3 模糊化及其推理的影响3.4.3.1 Singleton Fuzzifier例3.5例3.6 3.4.3.2 Non-Singleton Fuzzifier例3.7 Non-Singleton Fuzzifier 量化求解 Firing Level 3.5 对规则触发(Fired-Rule)的输出集进行组合3.5.1Mamdani…

【linux】CentOS查看系统信息

一、查看版本号 在CentOS中,可以通过多种方法来查看版本号。以下是几种常用的方法: 使用cat命令查看/etc/centos-release文件: CentOS的版本信息存储在/etc/centos-release文件中。可以使用cat命令来显示该文件的内容,从而获得C…

力扣hot100:153. 寻找旋转排序数组中的最小值(二分的理解)

由力扣hot100:33. 搜索旋转排序数组(二分的理解)-CSDN博客,我们知道二分实际上就是找到一个策略将区间“均分”。对于旋转数组问题,在任何位置分开两个区间,如果原区间不是顺序的,分开后必然有一…

BRAM底层原理详细解释(1)

目录 一、原语 二、端口简述 2.1 端口简介 2.2 SDP端口映射 三、端口信号含义补充说明 3.1 字节写使能(Byte-Write Enable)- WEA and WEBWE: 3.2 地址总线—ADDRARDADDR and ADDRBWRADDR 3.3 数据总线—DIADI, DIPADIP, DIBDI, and D…

【c++初阶】C++入门(下)

✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅ ✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨ 🌿🌿🌿🌿🌿🌿🌿🌿🌿🌿🌿🌿🌿🌿&#x1…

AI元年,这5款AI写作能为你提供帮助

自从人工智能技术的迅猛发展以来,AI在各个领域都取得了巨大的进步。其中,AI写作工具成为越来越多人关注的焦点。在这个AI元年,小编想向大家分享5款可能对你有帮助的AI写作工具,如果你也想找AI写作相关的工具,那么来看看…

【数据结构基础】之八大排序(C语言实现)

【数据结构基础】之八大排序(C语言实现) 🐧 冒泡排序♈️ 冒泡排序原理及代码实现♈️ 稳定性分析 🐧 选择排序♈️ 选择排序原理及代码实现♈️ 稳定性分析 🐧 插入排序♈️ 插入排序的原理及代码实现♈️ 稳定性分析 &#x1f4…

(附源码)基于Spring Boot和Vue的智能订餐与外卖系统设计与实现

1. 引言 这部分通常包含了研究背景、研究意义、国内外研究现状、本文研究内容以及论文结构安排。 研究背景:介绍当前外卖市场的快速发展,以及智能订餐系统对改善人们生活的影响。研究意义:强调这类系统在现代生活中的作用和开发的创新点。国…

Kubernetes一文上手【手把手系列】

目录 Kubernetes前言部署方式的演变 K8S概述K8S架构Master节点1. API Server2. Etcd3. Controller Manager4. Scheduler Node节点1. kubelet2. kube-proxy3. 容器运行时 组件与插件1. Kubernetes DNS2. Dashboard3. Heapster4. Ingress Controller K8S核心概念PodSerivceNamesp…

CodeSys创建自定义的html5控件

文章目录 背景创建html5control.xml文件控件界面以及逻辑的实现使用的资源安装自定义的html5控件库 背景 查看官方的资料:https://content.helpme-codesys.com/en/CODESYS%20Visualization/_visu_html5_dev.html 官方的例子:https://forge.codesys.com/…

使用 PyOpenGL 进行 2D 图形渲染总结

一、说明 OpenGL是一个广泛使用的开放式跨平台实时 3D 图形库,开发于二十多年前。它提供了一个低级API,允许开发人员以统一的方式访问图形硬件。在开发需要硬件加速且需要在不同平台上运行的复杂 2D 或 3D 应用程序时,它是首选平台。它可以在…

liunx centos7 下通过yum删除安装已经安装的php

执行下面命令查看php相关的包 rpm -qa | grep php 只需要卸载几个名为common的包即可,其他同版本依赖会被全部删除,删除php71w-common,71w版本的依赖包全部会被删除。 查看php包的命令 rpm -qa | grep php 或 yum list installed | gre…

unity编辑器扩展高级用法

在PropertyDrawer中,您不能使用来自GUILayout或EditorGUILayout的自动布局API,而只能使用来自GUI和EditorGUI的绝对Rect API始终传递相应的起始位置和维度。 你需要 计算显示嵌套内容所需的总高度将此高度添加到public override float GetPropertyHeig…

实用工具推荐:适用于 TypeScript 网络爬取的常用爬虫框架与库

随着互联网的迅猛发展,网络爬虫在信息收集、数据分析等领域扮演着重要角色。而在当前的技术环境下,使用TypeScript编写网络爬虫程序成为越来越流行的选择。TypeScript作为JavaScript的超集,通过类型检查和面向对象的特性,提高了代…

Linux :环境基础开发工具

目录: 1. Linux 软件包管理器 yum 1. 什么是软件包 2. 查看软件包 3. 如何安装软件 4. 如何卸载软件 2. Linux开发工具 1. Linux编辑器-vim的基本概念 2. vim使用 3. vim的基本操作 4. vim正常模式命令集 5. vim末行模式命令集 6. 简单vim配置 3. Linux编译器-gcc/…

常用相似度计算方法总总结

一、欧几里得相似度 1、欧几里得相似度 公式如下所示: 2、自定义代码实现 import numpy as np def EuclideanDistance(x, y):import numpy as npx np.array(x)y np.array(y)return np.sqrt(np.sum(np.square(x-y)))# 示例数据 # 用户1 的A B C D E商品数据 [3.3…

知识管理软件那么多,怎么挑选才适合初创企业?

对于初创企业来说,资源有限,效率显得尤其重要。此时,一个强大的知识管理软件就显得必不可少。它不仅利于信息的录入、查找和共享,还可以帮助团队更好的组织和协作,提高工作效率。那么,在众多的知识管理软件…

SQL-Labs靶场“34-35”关通关教程

君衍. 一、34关 POST单引号宽字节注入1、源码分析2、联合查询注入3、updatexml报错注入4、floor报错注入 二、35关 GET数字型报错注入1、源码分析2、联合查询注入3、updatexml报错注入4、floor报错注入 SQL-Labs靶场通关教程: SQL注入第一课 SQL注入思路基础 SQL无列…

第 6 章 ROS-xacro练习(自学二刷笔记)

重要参考: 课程链接:https://www.bilibili.com/video/BV1Ci4y1L7ZZ 讲义链接:Introduction Autolabor-ROS机器人入门课程《ROS理论与实践》零基础教程 6.4.3 Xacro_完整使用流程示例 需求描述: 使用 Xacro 优化 URDF 版的小车底盘模型实现 结果演示: 1.编写 X…