深度学习pytorch实战4---猴逗病识别·

>- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/0dvHCaOoFnW8SCp3JpzKxg) 中的学习记录博客**
>- **🍖 原作者:[K同学啊](https://mtyjkh.blog.csdn.net/)**

引言

1.复习上周并反思

K同学针对大家近期的学习状态,写了一篇文章,希望可以让大家有所思考icon-default.png?t=N7T8https://www.yuque.com/mingtian-fkmxf/yb7bg5/lrsbsdx6q4prla0f

 K同学推荐正确的实验方式icon-default.png?t=N7T8https://blog.csdn.net/Banana0840/article/details/137651851

本机环境欠佳,下次争取在云gpu上训练,并积极探索,实现目标准确率

天气识别icon-default.png?t=N7T8http://t.csdnimg.cn/sEYUf

2.摆正心态

数据集icon-default.png?t=N7T8http://xn--https-bl8js66z7n7i//pan.baidu.com/s/1N2otldhzYZSqyh-E69IdjA?pwd=6666%20%20%E6%8F%90%E5%8F%96%E7%A0%81%EF%BC%9A6666%20%20--%E6%9D%A5%E8%87%AA%E7%99%BE%E5%BA%A6%E7%BD%91%E7%9B%98%E8%B6%85%E7%BA%A7%E4%BC%9A%E5%91%98V3%E7%9A%84%E5%88%86%E4%BA%AB

3.本机环境

同上期,

4.学习目标

一、前期准备

1.设置GPU

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasetsimport os,PIL,pathlibdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")device

2.导入数据

import os,PIL,random,pathlibdata_dir = './4-data/'
data_dir = pathlib.Path(data_dir)data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[1] for path in data_paths]
classeNames

3.图像预处理

 图片大小处理

total_datadir = './4-data/'# 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])total_data = datasets.ImageFolder(total_datadir,transform=train_transforms)
total_data

total_data.class_to_idx{'Monkeypox': 0, 'Others': 1}
total_data.class_to_idx是一个存储了数据集类别和对应索引的字典。在PyTorch的ImageFolder数据加载器中,根据数据集文件夹的组织结构,每个文件夹代表一个类别,class_to_idx字典将每个类别名称映射为一个数字索引。
具体来说,如果数据集文件夹包含两个子文件夹,比如Monkeypox和Others,class_to_idx字典将返回类似以下的映射关系:{'Monkeypox': 0, 'Others': 1}

划分数据集(4:1)

train_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
train_dataset, test_dataset

 数据集分类

train_size,test_size

 使用dataloder便于之后使用

batch_size = 32train_dl = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,batch_size=batch_size,shuffle=True,num_workers=1)
for X, y in test_dl:print("Shape of X [N, C, H, W]: ", X.shape)print("Shape of y: ", y.shape, y.dtype)breakShape of X [N, C, H, W]:  torch.Size([32, 3, 224, 224])
Shape of y:  torch.Size([32]) torch.int64

二、构建简单的CNN网络

import torch.nn.functional as Fclass Network_bn(nn.Module):def __init__(self):super(Network_bn, self).__init__()"""nn.Conv2d()函数:第一个参数(in_channels)是输入的channel数量第二个参数(out_channels)是输出的channel数量第三个参数(kernel_size)是卷积核大小第四个参数(stride)是步长,默认为1第五个参数(padding)是填充大小,默认为0"""self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(12)self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0)self.bn2 = nn.BatchNorm2d(12)self.pool = nn.MaxPool2d(2,2)self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=0)self.bn4 = nn.BatchNorm2d(24)self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=0)self.bn5 = nn.BatchNorm2d(24)self.fc1 = nn.Linear(24*50*50, len(classeNames))def forward(self, x):x = F.relu(self.bn1(self.conv1(x)))      x = F.relu(self.bn2(self.conv2(x)))     x = self.pool(x)                        x = F.relu(self.bn4(self.conv4(x)))     x = F.relu(self.bn5(self.conv5(x)))  x = self.pool(x)                        x = x.view(-1, 24*50*50)x = self.fc1(x)return xdevice = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))model = Network_bn().to(device)
model

 输出结果

Using cuda deviceNetwork_bn((conv1): Conv2d(3, 12, kernel_size=(5, 5), stride=(1, 1))(bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(12, 12, kernel_size=(5, 5), stride=(1, 1))(bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(conv4): Conv2d(12, 24, kernel_size=(5, 5), stride=(1, 1))(bn4): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv5): Conv2d(24, 24, kernel_size=(5, 5), stride=(1, 1))(bn5): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(fc1): Linear(in_features=60000, out_features=2, bias=True)
)

三、训练模型

1.设置超参数

loss_fn    = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-4 # 学习率
opt        = torch.optim.SGD(model.parameters(),lr=learn_rate)

2.编写训练函数

# 训练循环
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 训练集的大小,一共60000张图片num_batches = len(dataloader)   # 批次数目,1875(60000/32)train_loss, train_acc = 0, 0  # 初始化训练损失和正确率for X, y in dataloader:  # 获取图片及其标签X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X)          # 网络输出loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad()  # grad属性归零loss.backward()        # 反向传播optimizer.step()       # 每一步自动更新# 记录acc与losstrain_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc  /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

3.编写测试函数

def test (dataloader, model, loss_fn):size        = len(dataloader.dataset)  # 测试集的大小,一共10000张图片num_batches = len(dataloader)          # 批次数目,313(10000/32=312.5,向上取整)test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss        = loss_fn(target_pred, target)test_loss += loss.item()test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc  /= sizetest_loss /= num_batchesreturn test_acc, test_loss

4.正式训练

epochs     = 20
train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done')

 

四、结果可视化

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

五、保存模型并预测

# 模型保存
PATH = './model.pth'  # 保存的参数文件名
torch.save(model.state_dict(), PATH)# 将参数加载到model当中
model.load_state_dict(torch.load(PATH, map_location=device))
输出
<All keys matched successfully>

 六指定图片进行预测·

补充知识:torch.squeeze() & torch.unsqueeze() 详解

from PIL import Image classes = list(total_data.class_to_idx)def predict_one_image(image_path, model, transform, classes):test_img = Image.open(image_path).convert('RGB')# plt.imshow(test_img)  # 展示预测的图片test_img = transform(test_img)img = test_img.to(device).unsqueeze(0)model.eval()output = model(img)_,pred = torch.max(output,1)pred_class = classes[pred]print(f'预测结果是:{pred_class}')
# 预测训练集中的某张照片
predict_one_image(image_path='./4-data/Monkeypox/M01_01_00.jpg', model=model, transform=train_transforms, classes=classes)

七,总结

1,之前还在想字典文件是怎么获取的,这次清楚了,原来是此

total_data.class_to_idx是一个存储了数据集类别和对应索引的字典。在PyTorch的ImageFolder数据加载器中,根据数据集文件夹的组织结构,每个文件夹代表一个类别,class_to_idx字典将每个类别名称映射为一个数字索引。
具体来说,如果数据集文件夹包含两个子文件夹,比如Monkeypox和Others,class_to_idx字典将返回类似以下的映射关系:{'Monkeypox': 0, 'Others': 1}

2指定单张图片预测

接收图像路径和类别文件等参数,对图像进行读取和预处理,然后通过模型得到预测结果,并将其转换为可读的类别名称,最后打印出预测结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/3656.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python爬虫 - 爬取html中的script数据(36kr.com新闻信息)

文章目录 1. 分析页面内容数据格式2. 使用re.findall方法&#xff0c;爬取新闻3. 使用re.search 方法&#xff0c;爬取新闻 1. 分析页面内容数据格式 打开 https://36kr.com/ 按F12&#xff08;或 在网页上右键 --> 检查&#xff08;Inspect&#xff09;&#xff09; 找…

c++初阶——类和对象(中)

大家好&#xff0c;我是小锋&#xff0c;我们今天继续来学习类和对象。 类的6个默认成员函数 我们想一想如果一个类什么都没有那它就是一个空类&#xff0c;但是空类真的什么都没有吗&#xff1f; 其实并不是&#xff0c;任何类在什么都不写时&#xff0c;编译器会自动生成以…

电脑提示丢失iutils.dll怎么办?一分钟教你搞定dll丢失问题

在计算机世界中&#xff0c;DLL&#xff08;Dynamic Link Library&#xff0c;动态链接库&#xff09;文件扮演着至关重要的角色&#xff0c;它们如同乐高积木中的基础模块&#xff0c;不同程序通过调用这些模块来实现各种功能。其中&#xff0c;iutils.dll就是这样一款不可或缺…

transformer 最简单学习3, 训练文本数据输入的形式

1、输入数据中&#xff0c;源数据和目标数据的定义 def get_batch(source,i):用于获取每个批数据合理大小的源数据和目标数据参数source 是通过batchfy 得到的划分batch个 ,的所有数据&#xff0c;并且转置列表示i第几个batchbptt 15 #超参数&#xff0c;一次输入多少个ba…

聚类分析字符串数组

聚类分析字符串数组 对多个字符串进行聚类分析旨在根据它们之间的相似度将这些字符串划分成若干个类别&#xff0c;使得同一类别内的字符串彼此相似度高&#xff0c;而不同类别间的字符串相似度低 小结 数据要清洗。清洗的足够准确&#xff0c;可能不需要用聚类分析了数据要…

统一所有 LLM API:支持预算与速率限制 | 开源日报 No.229

BerriAI/litellm Stars: 6.7k License: NOASSERTION litellm 是一个使用 OpenAI 格式调用所有 LLM API 的工具。它支持 Bedrock、Azure、OpenAI、Cohere、Anthropic 等 100 多种 LLMs&#xff0c;提供企业级代理服务器和稳定版本 v1.30.2。 主要功能和优势包括&#xff1a; 将…

javaEE初阶——多线程(八)——常见的锁策略 以及 CAS机制

T04BF &#x1f44b;专栏: 算法|JAVA|MySQL|C语言 &#x1faf5; 小比特 大梦想 此篇文章与大家分享分治算法关于多线程进阶的章节——关于常见的锁策略以及CAS机制 如果有不足的或者错误的请您指出! 目录 多线程进阶1.常见的锁策略1.1乐观锁和悲观锁1.2重量级锁 和 轻量级锁1.…

【大数据】分布式数据库HBase

目录 1.概述 1.1.前言 1.2.数据模型 1.3.列式存储的优势 2.实现原理 2.1.region 2.2.LSM树 2.3.完整读写过程 2.4.master的作用 1.概述 1.1.前言 本文式作者大数据系列专栏中的一篇文章&#xff0c;按照专栏来阅读&#xff0c;循序渐进能更好的理解&#xff0c;专栏…

JS实现对用户名、密码进行正则表达式判断,按钮绑定多个事件,网页跳转

目标&#xff1a;使用JS实现对用户名和密码进行正则表达式判断&#xff0c;用户名和密码正确时&#xff0c;进行网页跳转。 用户名、密码的正则表达式检验 HTML代码&#xff1a; <button type"submit" id"login-btn" /*onclick"login();alidate…

精益思想赋能数字化转型:落地策略与实践路径

当下&#xff0c;数字化转型已不再是选择题&#xff0c;而是关乎企业生存与发展的必答题。然而&#xff0c;转型过程中如何确保效率、降低成本并快速实现价值创造&#xff0c;成为了摆在众多企业面前的难题。精益思想作为一种追求精益求精、持续改进的管理思维&#xff0c;为数…

2024最新版JavaScript逆向爬虫教程-------基础篇之面向对象

目录 一、概念二、对象的创建和操作2.1 JavaScript创建对象的方式2.2 对象属性操作的控制2.3 理解JavaScript创建对象2.3.1 工厂模式2.3.2 构造函数2.3.3 原型构造函数 三、继承3.1 通过原型链实现继承3.2 借用构造函数实现继承3.3 寄生组合式继承3.3.1 对象的原型式继承3.3.2 …

stm32HAL库-GPIO

一 什么是 GPIO: GPIO(general porpose intput output), 通用输入输出端口 . 二 我们先认识芯片控制 GPIO 输出控制。 2.1LED 硬件原理如图&#xff1a; 当电流从这根电线流通&#xff0c; LED 亮。当电流不通过这根电线&#xff0c; LED 灭。 上面 PF** &#xff0c;芯片电…

MySQL面试——聚簇/非聚簇索引

存储引擎是针对表结构&#xff0c;不是数据库 引擎层&#xff1a;对数据层以何种方式进行组织 update&#xff1a;加索引&#xff1a;行级锁&#xff1b;不加索引&#xff1a;表级锁

固态继电器:推进可再生能源系统

随着可再生能源系统的发展&#xff0c;太阳能系统日益成为现代能源解决方案的先锋。在这种背景下&#xff0c;固态继电器&#xff08;SSR&#xff09;&#xff0c;特别是光耦固态继电器的利用变得日益突出。本文旨在深入探讨SSR在可再生能源系统中的多方位应用&#xff0c;重点…

【学习笔记】Python 使用 matplotlib 画图

文章目录 安装中文显示折线图、点线图柱状图、堆积柱状图坐标轴断点参考资料 本文将介绍如何使用 Python 的 matplotlib 库画图&#xff0c;记录一些常用的画图 demo 代码 安装 # 建议先切换到虚拟环境中 pip install matplotlib中文显示 新版的 matplotlib 已经支持字体回退…

SD-WAN:灵活、低成本、便于管理

近年来&#xff0c;SD-WAN&#xff08;软件定义广域网&#xff09;技术成为企业网络领域的新趋势&#xff0c;其带来的变革性影响备受瞩目。凭借出色的灵活性、高效的可管理性以及显著的成本优势&#xff0c;SD-WAN技术为企业网络注入了新的活力。 首先&#xff0c;SD-WAN技术的…

如何利用diskpart命令界面在win10/win11上解除U盘写保护

背景 在把U盘作为系统盘装了一次后&#xff0c;惊讶的发现自己U盘的一个1M的小卷被写保护了。不能格式化&#xff0c;不能删除文件&#xff0c;在给用户拷文件的时候&#xff0c;小卷还会提示病毒告警&#xff0c;非常的尴尬&#xff0c;因此展开了研究。 失败的尝试 尝试了网…

58、回溯-组合总和

思路&#xff1a; 数组内的每一个元素都可以无线使用只要最后可以拼接成target就可以。那么如何限制呢&#xff1f; &#xff08;target-已经拼接的和 &#xff09;/当前元素 就是你可以利用的数量。代码如下&#xff1a; class Solution {public static List<List<I…

触发器的基本概念及分类

目录 触发器的基本概念 作用对象 触发事件 触发条件 触发时间 触发级别或者触发频率 触发器的分类 DML 触发器 INSTEAD OF 触发器 系统触发器 Oracle从入门到总裁:​​​​​​https://blog.csdn.net/weixin_67859959/article/details/135209645 触发器的基本概念 …

2024年电商视频号夏令营(第四期)零基础带你玩转微信视频号

教学内容&#xff1a; 下 载 地 址&#xff1a; laoa1.cn/1821.html 1.剪辑软件整套实例教程0基本一小时懂得视频编辑 1.上课前必看 1.如何获实拍视频的原创素材 2.怎样运送视频水印&#xff0c;提取图片文案脚本 2.如何发布爆款短视频 2.微信视频号基本功能解读 2.直播的时…