深度学习pytorch实战4---猴逗病识别·

>- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/0dvHCaOoFnW8SCp3JpzKxg) 中的学习记录博客**
>- **🍖 原作者:[K同学啊](https://mtyjkh.blog.csdn.net/)**

引言

1.复习上周并反思

K同学针对大家近期的学习状态,写了一篇文章,希望可以让大家有所思考icon-default.png?t=N7T8https://www.yuque.com/mingtian-fkmxf/yb7bg5/lrsbsdx6q4prla0f

 K同学推荐正确的实验方式icon-default.png?t=N7T8https://blog.csdn.net/Banana0840/article/details/137651851

本机环境欠佳,下次争取在云gpu上训练,并积极探索,实现目标准确率

天气识别icon-default.png?t=N7T8http://t.csdnimg.cn/sEYUf

2.摆正心态

数据集icon-default.png?t=N7T8http://xn--https-bl8js66z7n7i//pan.baidu.com/s/1N2otldhzYZSqyh-E69IdjA?pwd=6666%20%20%E6%8F%90%E5%8F%96%E7%A0%81%EF%BC%9A6666%20%20--%E6%9D%A5%E8%87%AA%E7%99%BE%E5%BA%A6%E7%BD%91%E7%9B%98%E8%B6%85%E7%BA%A7%E4%BC%9A%E5%91%98V3%E7%9A%84%E5%88%86%E4%BA%AB

3.本机环境

同上期,

4.学习目标

一、前期准备

1.设置GPU

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasetsimport os,PIL,pathlibdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")device

2.导入数据

import os,PIL,random,pathlibdata_dir = './4-data/'
data_dir = pathlib.Path(data_dir)data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[1] for path in data_paths]
classeNames

3.图像预处理

 图片大小处理

total_datadir = './4-data/'# 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])total_data = datasets.ImageFolder(total_datadir,transform=train_transforms)
total_data

total_data.class_to_idx{'Monkeypox': 0, 'Others': 1}
total_data.class_to_idx是一个存储了数据集类别和对应索引的字典。在PyTorch的ImageFolder数据加载器中,根据数据集文件夹的组织结构,每个文件夹代表一个类别,class_to_idx字典将每个类别名称映射为一个数字索引。
具体来说,如果数据集文件夹包含两个子文件夹,比如Monkeypox和Others,class_to_idx字典将返回类似以下的映射关系:{'Monkeypox': 0, 'Others': 1}

划分数据集(4:1)

train_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
train_dataset, test_dataset

 数据集分类

train_size,test_size

 使用dataloder便于之后使用

batch_size = 32train_dl = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,batch_size=batch_size,shuffle=True,num_workers=1)
for X, y in test_dl:print("Shape of X [N, C, H, W]: ", X.shape)print("Shape of y: ", y.shape, y.dtype)breakShape of X [N, C, H, W]:  torch.Size([32, 3, 224, 224])
Shape of y:  torch.Size([32]) torch.int64

二、构建简单的CNN网络

import torch.nn.functional as Fclass Network_bn(nn.Module):def __init__(self):super(Network_bn, self).__init__()"""nn.Conv2d()函数:第一个参数(in_channels)是输入的channel数量第二个参数(out_channels)是输出的channel数量第三个参数(kernel_size)是卷积核大小第四个参数(stride)是步长,默认为1第五个参数(padding)是填充大小,默认为0"""self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(12)self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0)self.bn2 = nn.BatchNorm2d(12)self.pool = nn.MaxPool2d(2,2)self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=0)self.bn4 = nn.BatchNorm2d(24)self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=0)self.bn5 = nn.BatchNorm2d(24)self.fc1 = nn.Linear(24*50*50, len(classeNames))def forward(self, x):x = F.relu(self.bn1(self.conv1(x)))      x = F.relu(self.bn2(self.conv2(x)))     x = self.pool(x)                        x = F.relu(self.bn4(self.conv4(x)))     x = F.relu(self.bn5(self.conv5(x)))  x = self.pool(x)                        x = x.view(-1, 24*50*50)x = self.fc1(x)return xdevice = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))model = Network_bn().to(device)
model

 输出结果

Using cuda deviceNetwork_bn((conv1): Conv2d(3, 12, kernel_size=(5, 5), stride=(1, 1))(bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(12, 12, kernel_size=(5, 5), stride=(1, 1))(bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(conv4): Conv2d(12, 24, kernel_size=(5, 5), stride=(1, 1))(bn4): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv5): Conv2d(24, 24, kernel_size=(5, 5), stride=(1, 1))(bn5): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(fc1): Linear(in_features=60000, out_features=2, bias=True)
)

三、训练模型

1.设置超参数

loss_fn    = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-4 # 学习率
opt        = torch.optim.SGD(model.parameters(),lr=learn_rate)

2.编写训练函数

# 训练循环
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 训练集的大小,一共60000张图片num_batches = len(dataloader)   # 批次数目,1875(60000/32)train_loss, train_acc = 0, 0  # 初始化训练损失和正确率for X, y in dataloader:  # 获取图片及其标签X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X)          # 网络输出loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad()  # grad属性归零loss.backward()        # 反向传播optimizer.step()       # 每一步自动更新# 记录acc与losstrain_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc  /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

3.编写测试函数

def test (dataloader, model, loss_fn):size        = len(dataloader.dataset)  # 测试集的大小,一共10000张图片num_batches = len(dataloader)          # 批次数目,313(10000/32=312.5,向上取整)test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss        = loss_fn(target_pred, target)test_loss += loss.item()test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc  /= sizetest_loss /= num_batchesreturn test_acc, test_loss

4.正式训练

epochs     = 20
train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done')

 

四、结果可视化

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

五、保存模型并预测

# 模型保存
PATH = './model.pth'  # 保存的参数文件名
torch.save(model.state_dict(), PATH)# 将参数加载到model当中
model.load_state_dict(torch.load(PATH, map_location=device))
输出
<All keys matched successfully>

 六指定图片进行预测·

补充知识:torch.squeeze() & torch.unsqueeze() 详解

from PIL import Image classes = list(total_data.class_to_idx)def predict_one_image(image_path, model, transform, classes):test_img = Image.open(image_path).convert('RGB')# plt.imshow(test_img)  # 展示预测的图片test_img = transform(test_img)img = test_img.to(device).unsqueeze(0)model.eval()output = model(img)_,pred = torch.max(output,1)pred_class = classes[pred]print(f'预测结果是:{pred_class}')
# 预测训练集中的某张照片
predict_one_image(image_path='./4-data/Monkeypox/M01_01_00.jpg', model=model, transform=train_transforms, classes=classes)

七,总结

1,之前还在想字典文件是怎么获取的,这次清楚了,原来是此

total_data.class_to_idx是一个存储了数据集类别和对应索引的字典。在PyTorch的ImageFolder数据加载器中,根据数据集文件夹的组织结构,每个文件夹代表一个类别,class_to_idx字典将每个类别名称映射为一个数字索引。
具体来说,如果数据集文件夹包含两个子文件夹,比如Monkeypox和Others,class_to_idx字典将返回类似以下的映射关系:{'Monkeypox': 0, 'Others': 1}

2指定单张图片预测

接收图像路径和类别文件等参数,对图像进行读取和预处理,然后通过模型得到预测结果,并将其转换为可读的类别名称,最后打印出预测结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/3656.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

STM32 float浮点数转换成四个字节

float浮点数转换成四个字节 在C或C中&#xff0c;联合体&#xff08;union&#xff09;是一种特殊的数据结构&#xff0c;它允许在相同的内存位置存储不同的数据类型。联合体中的所有成员共享同一块内存区域&#xff0c;这意味着同一时间内&#xff0c;联合体只能保存其中一个…

python爬虫 - 爬取html中的script数据(36kr.com新闻信息)

文章目录 1. 分析页面内容数据格式2. 使用re.findall方法&#xff0c;爬取新闻3. 使用re.search 方法&#xff0c;爬取新闻 1. 分析页面内容数据格式 打开 https://36kr.com/ 按F12&#xff08;或 在网页上右键 --> 检查&#xff08;Inspect&#xff09;&#xff09; 找…

c++初阶——类和对象(中)

大家好&#xff0c;我是小锋&#xff0c;我们今天继续来学习类和对象。 类的6个默认成员函数 我们想一想如果一个类什么都没有那它就是一个空类&#xff0c;但是空类真的什么都没有吗&#xff1f; 其实并不是&#xff0c;任何类在什么都不写时&#xff0c;编译器会自动生成以…

电脑提示丢失iutils.dll怎么办?一分钟教你搞定dll丢失问题

在计算机世界中&#xff0c;DLL&#xff08;Dynamic Link Library&#xff0c;动态链接库&#xff09;文件扮演着至关重要的角色&#xff0c;它们如同乐高积木中的基础模块&#xff0c;不同程序通过调用这些模块来实现各种功能。其中&#xff0c;iutils.dll就是这样一款不可或缺…

transformer 最简单学习3, 训练文本数据输入的形式

1、输入数据中&#xff0c;源数据和目标数据的定义 def get_batch(source,i):用于获取每个批数据合理大小的源数据和目标数据参数source 是通过batchfy 得到的划分batch个 ,的所有数据&#xff0c;并且转置列表示i第几个batchbptt 15 #超参数&#xff0c;一次输入多少个ba…

Java 异常

异常概念 Java异常是程序在执行过程中发生的错误事件&#xff0c;它打断了正常的流程控制流。 在Java中&#xff0c;异常被定义为一种特殊的对象&#xff0c;它们可以被抛出、捕获和处理。以下是关于Java异常的一些详细信息&#xff1a; 异常分类&#xff1a;Java中的异常分…

AI大模型评测问题集合

文章目录 编程能力测试工作场景价值观测试生活场景数学能力测试中文能力测试编程能力测试 第一题:Google 面试题,Python 题目 给定一组不同的整数数组,给出一个算法对这些整数进行随机排序,使每个重 排序方法的可能性相等。换句话说,给定一副牌,你要如何洗牌才能确保牌的…

vue2 mixin的用法

在 Vue 2 中&#xff0c;mixin 是一种分发 Vue 组件中可复用功能的非常灵活的方式。一个 mixin 对象可以包含任意组件选项。当组件使用 mixin 对象时&#xff0c;所有 mixin 对象的选项将被“混合”进入该组件本身的选项。 下面是如何在 Vue 2 中使用 mixin 的基本步骤&#x…

聚类分析字符串数组

聚类分析字符串数组 对多个字符串进行聚类分析旨在根据它们之间的相似度将这些字符串划分成若干个类别&#xff0c;使得同一类别内的字符串彼此相似度高&#xff0c;而不同类别间的字符串相似度低 小结 数据要清洗。清洗的足够准确&#xff0c;可能不需要用聚类分析了数据要…

C++ //练习 13.34 编写本节所描述的Message。

C Primer&#xff08;第5版&#xff09; 练习 13.34 练习 13.34 编写本节所描述的Message。 环境&#xff1a;Linux Ubuntu&#xff08;云服务器&#xff09; 工具&#xff1a;vim 代码块 /*************************************************************************>…

机器学习笔记-01

一…AI&#xff08;人工智能&#xff09; 二.机器学习–是人工智能实现的途径 三.深度学习–是机器学习的一个方法 1.机器学习能做什么&#xff1a; 1.1 传统预测 1.2 图像识别 1.3 自然语言处理&#xff08;nlp&#xff09; 2.数据集包含&#xff1a;特征值 目标值 3.机器学…

python绘制三维图

在Python中&#xff0c;我们可以使用matplotlib库中的mplot3d工具包来绘制三维图。下面是一个简单的例子&#xff0c;绘制了一个三维的散点图和一个三维曲面图&#xff1a; 首先&#xff0c;确保已经安装了matplotlib库。如果没有&#xff0c;可以通过pip进行安装&#xff1a;…

C#中的Task:异步编程的瑞士军刀

在现代软件开发中&#xff0c;异步编程已经成为处理I/O密集型任务和网络操作的重要手段。C#中的Task是.NET Framework 4.0引入的一个并发编程的抽象&#xff0c;它在后续的.NET Core和.NET 5中得到了进一步的发展和完善。Task代表了一个异步操作&#xff0c;可以等待它的完成&a…

统一所有 LLM API:支持预算与速率限制 | 开源日报 No.229

BerriAI/litellm Stars: 6.7k License: NOASSERTION litellm 是一个使用 OpenAI 格式调用所有 LLM API 的工具。它支持 Bedrock、Azure、OpenAI、Cohere、Anthropic 等 100 多种 LLMs&#xff0c;提供企业级代理服务器和稳定版本 v1.30.2。 主要功能和优势包括&#xff1a; 将…

javaEE初阶——多线程(八)——常见的锁策略 以及 CAS机制

T04BF &#x1f44b;专栏: 算法|JAVA|MySQL|C语言 &#x1faf5; 小比特 大梦想 此篇文章与大家分享分治算法关于多线程进阶的章节——关于常见的锁策略以及CAS机制 如果有不足的或者错误的请您指出! 目录 多线程进阶1.常见的锁策略1.1乐观锁和悲观锁1.2重量级锁 和 轻量级锁1.…

【大数据】分布式数据库HBase

目录 1.概述 1.1.前言 1.2.数据模型 1.3.列式存储的优势 2.实现原理 2.1.region 2.2.LSM树 2.3.完整读写过程 2.4.master的作用 1.概述 1.1.前言 本文式作者大数据系列专栏中的一篇文章&#xff0c;按照专栏来阅读&#xff0c;循序渐进能更好的理解&#xff0c;专栏…

JS实现对用户名、密码进行正则表达式判断,按钮绑定多个事件,网页跳转

目标&#xff1a;使用JS实现对用户名和密码进行正则表达式判断&#xff0c;用户名和密码正确时&#xff0c;进行网页跳转。 用户名、密码的正则表达式检验 HTML代码&#xff1a; <button type"submit" id"login-btn" /*onclick"login();alidate…

探索企业微信助手工具:强化沟通协作,助力高效办公

随着企业信息化建设的深入发展&#xff0c;企业微信助手工具作为一种集成化、智能化的办公辅助工具&#xff0c;正逐渐受到企业的青睐。企业微信助手不仅能够帮助企业提高工作效率&#xff0c;还能增强沟通协作能力&#xff0c;为企业发展注入新的活力。本文将简要介绍企业微信…

精益思想赋能数字化转型:落地策略与实践路径

当下&#xff0c;数字化转型已不再是选择题&#xff0c;而是关乎企业生存与发展的必答题。然而&#xff0c;转型过程中如何确保效率、降低成本并快速实现价值创造&#xff0c;成为了摆在众多企业面前的难题。精益思想作为一种追求精益求精、持续改进的管理思维&#xff0c;为数…

2024最新版JavaScript逆向爬虫教程-------基础篇之面向对象

目录 一、概念二、对象的创建和操作2.1 JavaScript创建对象的方式2.2 对象属性操作的控制2.3 理解JavaScript创建对象2.3.1 工厂模式2.3.2 构造函数2.3.3 原型构造函数 三、继承3.1 通过原型链实现继承3.2 借用构造函数实现继承3.3 寄生组合式继承3.3.1 对象的原型式继承3.3.2 …