DataLoader与Dataset

一、人民币二分类在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

二、DataLoader 与 Dataset

DataLoader

torch.utils.data.DataLoader

功能:构建可迭代的数据装载器
(只标注了较为重要的)
• dataset: Dataset类,决定数据从哪读取及如何读取
• batchsize : 批大小
• num_works: 是否多进程读取数据
• shuffle: 每个epoch是否乱序
• drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据

DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None
)
  • Epoch: 所有训练样本都已输入到模型中,称为一个Epoch
  • Iteration:一批样本输入到模型中,称之为一个Iteration
  • Batchsize:批大小,决定一个Epoch有多少个Iteration

样本总数:80, Batchsize:8
1 Epoch = 10 Iteration

样本总数:87, Batchsize:8
1 Epoch = 10 Iteration ? drop_last = True
1 Epoch = 11 Iteration ? drop_last = False

根据给定的样本总数和批大小,可以计算出一个Epoch中的Iteration数量。

  1. 样本总数为80,批大小为8:
    • 一个Epoch中的Iteration数量 = 样本总数 / 批大小 = 80 / 8 = 10
  2. 样本总数为87,批大小为8,且设置drop_last = True
    • 一个Epoch中的Iteration数量 = 样本总数 // 批大小 = 87 // 8 = 10
  3. 样本总数为87,批大小为8,且设置drop_last = False
    • 一个Epoch中的Iteration数量 = (样本总数 + 批大小 - 1) // 批大小 = (87 + 8 - 1) // 8 = 11

在第3种情况下,由于样本总数无法被批大小整除,因此在最后一个Epoch中会有一个额外的Iteration来处理剩余的样本。

Dataset

torch.utils.data.Dataset

功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__()

getitem :接收一个索引,返回一个样本

class Dataset(object):def __getitem__(self, index):raise NotImplementedErrordef __add__(self, other):return ConcatDataset([self, other])

上述代码定义了一个名为Dataset的类,该类是一个抽象基类。它包含了两个特殊方法:

  1. __getitem__(self, index)方法:这是一个抽象方法,需要在子类中实现。它用于根据给定的索引index返回对应的数据样本。在这里,抛出了NotImplementedError异常,表示子类必须覆盖这个方法来提供具体的实现。
  2. __add__(self, other)方法:这是一个特殊方法,用于实现对象的加法操作。在这里,它返回一个ConcatDataset对象,该对象将当前的self和另一个other数据集合并在一起。__add__方法的返回值是一个ConcatDataset对象,表示将当前数据集和另一个数据集进行连接。ConcatDataset是PyTorch中的一个类,用于将多个数据集连接在一起,以便在训练过程中一起使用。

四、模型训练

# -*- coding: utf-8 -*-
"""
# @file name  : train_lenet.py
# @author     : siuserjy
# @date       : 2024-01-03 20:50:38
# @brief      : 人民币分类模型训练
"""
import os# 获取当前文件的目录路径
BASE_DIR = os.path.dirname(os.path.abspath(__file__))# 导入必要的库和模块
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt# 定义lenet.py和common_tools.py文件的路径并检查文件是否存在
path_lenet = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "model", "lenet.py"))
path_tools = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "tools", "common_tools.py"))
assert os.path.exists(path_lenet), "{}不存在,请将lenet.py文件放到 {}".format(path_lenet, os.path.dirname(path_lenet))
assert os.path.exists(path_tools), "{}不存在,请将common_tools.py文件放到 {}".format(path_tools, os.path.dirname(path_tools))# 将自定义模块所在的目录添加到Python路径中
import sys
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__) + os.path.sep + ".." + os.path.sep + "..")
sys.path.append(hello_pytorch_DIR)# 从自定义模块导入所需内容
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import set_seed# 设置随机种子
set_seed()# 定义人民币数据集的标签
rmb_label = {"1": 0, "100": 1}# 设置训练参数
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1# ============================ step 1/5 数据 ============================# 设置数据集路径
split_dir = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "data", "rmb_split"))
if not os.path.exists(split_dir):raise Exception(r"数据 {} 不存在, 回到lesson-06\1_split_dataset.py生成数据".format(split_dir))# 设置训练集和验证集路径
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")# 设置图像的均值和标准差
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]# 设置训练集的数据预处理
train_transform = transforms.Compose([transforms.Resize((32, 32)),  # 将图像大小调整为32x32transforms.RandomCrop(32, padding=4),  # 随机裁剪32x32大小的图像transforms.ToTensor(),  # 将图像转换为Tensor格式transforms.Normalize(norm_mean, norm_std),  # 标准化图像
])# 设置验证集的数据预处理
valid_transform = transforms.Compose([transforms.Resize((32, 32)),  # 将图像大小调整为32x32transforms.ToTensor(),  # 将图像转换为Tensor格式transforms.Normalize(norm_mean, norm_std),  # 标准化图像
])# 构建训练集和验证集的数据集实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)# 构建训练集和验证集的DataLoader
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)# ============================ step 2/5 模型 ============================
# 构建LeNet模型实例
net = LeNet(classes=2)
net.initialize_weights()# ============================ step 3/5 损失函数 ============================
# 设置损失函数
criterion = nn.CrossEntropyLoss()# ============================ step 4/5 优化器 ============================
# 设置优化器
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)# 设置学习率下降策略
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)# ============================ step 5/5 训练 ============================
train_curve = list()  # 记录训练集的loss值
valid_curve = list()  # 记录验证集的loss值for epoch in range(MAX_EPOCH):  # 迭代训练多个epochloss_mean = 0.  # 记录每个epoch的平均loss值correct = 0.  # 记录分类正确的样本数量total = 0.  # 记录总样本数量net.train()  # 将模型设置为训练模式for i, data in enumerate(train_loader):  # 遍历训练集数据# forwardinputs, labels = data  # 获取输入数据和标签outputs = net(inputs)  # 将输入数据输入模型,得到输出结果# backwardoptimizer.zero_grad()  # 将模型参数的梯度置零loss = criterion(outputs, labels)  # 计算损失值loss.backward()  # 反向传播,计算梯度# update weightsoptimizer.step()  # 更新模型参数# 统计分类情况_, predicted = torch.max(outputs.data, 1)  # 获取预测结果total += labels.size(0)  # 累计总样本数量correct += (predicted == labels).squeeze().sum().numpy()  # 累计分类正确的样本数量# 打印训练信息loss_mean += loss.item()  # 累计每个batch的loss值train_curve.append(loss.item())  # 将每个batch的loss值记录下来if (i+1) % log_interval == 0:  # 每隔一定的batch数打印一次训练信息loss_mean = loss_mean / log_interval  # 计算平均loss值print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))loss_mean = 0.  # 重置loss_meanscheduler.step()  # 更新学习率# validate the modelif (epoch+1) % val_interval == 0:  # 每隔一定的epoch数进行一次验证correct_val = 0.  # 记录验证集分类正确的样本数量total_val = 0.  # 记录验证集总样本数量loss_val = 0.  # 记录验证集的loss值net.eval()  # 将模型设置为评估模式with torch.no_grad():  # 不计算梯度for j, data in enumerate(valid_loader):  # 遍历验证集数据inputs, labels = data  # 获取输入数据和标签outputs = net(inputs)  # 将输入数据输入模型,得到输出结果loss = criterion(outputs, labels)  # 计算损失值_, predicted = torch.max(outputs.data, 1)  # 获取预测结果total_val += labels.size(0)  # 累计验证集总样本数量correct_val += (predicted == labels).squeeze().sum().numpy()  # 累计验证集分类正确的样本数量loss_val += loss.item()  # 累计验证集的loss值loss_val_epoch = loss_val / len(valid_loader)  # 计算验证集每个epoch的平均loss值valid_curve.append(loss_val_epoch)  # 将验证集每个epoch的平均loss值记录下来print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_epoch, correct_val / total_val))# 绘制训练曲线和验证曲线
train_x = range(len(train_curve))  # 训练曲线的x轴
train_y = train_curve  # 训练曲线的y轴train_iters = len(train_loader)  # 训练集的迭代次数
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval - 1  # 验证曲线的x轴,将epoch转换为iteration
valid_y = valid_curve  # 验证曲线的y轴plt.plot(train_x, train_y, label='Train')  # 绘制训练曲线
plt.plot(valid_x, valid_y, label='Valid')  # 绘制验证曲线plt.legend(loc='upper right')  # 设置图例位置
plt.ylabel('loss value')  # 设置y轴标签
plt.xlabel('Iteration')  # 设置x轴标签
plt.show()  # 显示图像# ============================ inference ============================# 设置基本路径
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(BASE_DIR, "test_data")# 创建测试数据集
test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)# 创建验证数据加载器
valid_loader = DataLoader(dataset=test_data, batch_size=1)# 遍历验证数据集
for i, data in enumerate(valid_loader):# 前向传播inputs, labels = dataoutputs = net(inputs)_, predicted = torch.max(outputs.data, 1)# 判断预测结果是1元还是100元rmb = 1 if predicted.numpy()[0] == 0 else 100# 打印模型获得的金额print("模型获得{}元".format(rmb))

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/594519.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习异常值处理 逻辑汇总一

一 清除数据中恒定不变值 如果某个数据长时间不变,默认异常,清除掉该部分数据: # 使用 shift 和 cumsum 来创建一个分组键,每次值改变都会增加组号 g (df[沉淀池3号进水流量] ! df[沉淀池3号进水流量].shift()).cumsum()# 使用…

案例253:基于微信小程序的懂球短视频管理系统

文末获取源码 开发语言:Java 框架:SpringBoot JDK版本:JDK1.8 数据库:mysql 5.7 开发软件:eclipse/myeclipse/idea Maven包:Maven3.5.4 小程序框架:uniapp 小程序开发软件:HBuilder …

GC8549 大电流,双通道 12V,短地短电源保护等功能 可替代ONSEMI的LV8548/LV8549

GC8549 可以工作在 3.8~12V 的电源电压上,每 通道能提供高达 1.5A 持续输出电流或者 2.5A 峰值 电流,睡眠模式下功耗小于 1uA。具有 PWM(IN/EN)输入接口,与行业标 准器件兼容,并具有过温保护,欠压保护&…

【计算机网络】网络层

文章目录 网络层提供的服务虚电路数据报服务虚电路与数据报服务比较 虚拟互连网络IP地址IP层次结构IP地址分类特殊地址子网掩码 子网划分变长子网划分超网合并网络规律 IP地址与MAC地址ARP协议ARP欺骗的应用 数据包数据包首部 路由ICMP协议RIP动态路由协议OSPF协议BGP协议 VPNN…

进程等待(wait和wait函数)【Linux】

进程等待 wait和wait函数【Linux】 进程等待的概念进程等待的必要性进程等待的方法wait函数waitpid函数 非阻塞等待和阻塞等待的对比阻塞等待:非阻塞等待 进程等待的概念 进程等待就是通过 wait/waitpid的方式,让父进程对子进程进行等待子进程退出并且将…

【信号处理:小波包转换(WPT)/小波包分解(WPD) 】

【信号处理:小波包转换(WPT)/小波包分解(WPD) 】 小波包变换简介WPT/WPD的基础知识WPT/WPD的主要特点The Wavelet Packet Transform 小波包变换前向小波数据包变换最佳基础和成本函数数学中波纹的最佳基础其他成本函数…

酷狗高级Java面试真题

今年IT寒冬,大厂都裁员或者准备裁员,作为开猿节流主要目标之一,我们更应该时刻保持竞争力。为了抱团取暖,林老师开通了《知识星球》,并邀请我阿里、快手、腾讯等的朋友加入,分享八股文、项目经验、管理经验…

洛谷普及组P1044栈,题目讲解(无数论基础,纯打表找规律)

[NOIP2003 普及组] 栈 - 洛谷 我先写了个打表的代码&#xff0c;写了一个小时&#xff0c;o(╥﹏╥)o只能说我真不擅长dfs。 int n; std::unordered_map<std::string, int>map; void dfs(std::vector<int>&a, int step,std::stack<int>p, std::string …

Nginx中include配置文件,方便管理多域名

目录 1.加上include配置 2.配置 server 记录 一个网站对应一个server 记录&#xff0c;这样管理起来相对麻烦。我们可以将每个网站记录单独拆分出来即可&#xff0c;这就需要用到 nginx 中的 conf.d 文件 1.加上include配置 先进入到 nginx.conf 文件&#xff0c;然后将所有…

CentOS7搭建Elasticsearch与Kibana服务

1.部署单点es 1.1.创建网络 因为我们还需要部署kibana容器&#xff0c;因此需要让es和kibana容器互联。这里先创建一个网络&#xff1a; docker network create es-net 1.2拉取elasticsearch镜像 docker pull elasticsearch:7.11.1 1.3.运行 运行docker命令&#xff0c;部…

paddle v4 hubserving 部署

环境准备&#xff1a;https://github.com/PaddlePaddle/PaddleOCR/tree/release/2.7/deploy/hubserving#24-%E5%90%AF%E5%8A%A8%E6%9C%8D%E5%8A%A1 服务器启动命令 hub serving start -c deploy/hubserving/ocr_system/config.json客户端请求 python tools/test_hubserving.…

Java ORM 框架 Mybatis详解

&#x1f4d6; 内容 Mybatis 的前身就是 iBatis &#xff0c;是一款优秀的持久层框架&#xff0c;它支持自定义 SQL、存储过程以及高级映射。本文以一个 Mybatis 完整示例为切入点&#xff0c;结合 Mybatis 底层源码分析&#xff0c;图文并茂的讲解 Mybatis 的核心工作机制。 …

面向对象编程(高级)

面向对象编程&#xff08;高级&#xff09; 1、类变量和类方法 &#xff08;1&#xff09; 概念 类变量&#xff0c;也称为静态变量&#xff0c;是指在类级别声明的变量。它们与特定类相关联&#xff0c;而不是与类的实例&#xff08;对象&#xff09;相关联。每个类变量只有…

JavaSE语法之十五:异常(超全!!!)

文章目录 一、异常的概念与体系1. 异常的概念2. 异常的体系结构3. 异常的分类 二、异常的处理方式1. 防御式编程&#xff08;1&#xff09;LBYL 事前防御型&#xff08;2&#xff09;EAFP 时候认错型 2. 异常的抛出3. 异常的捕获&#xff08;1&#xff09;异常声明的 throws&am…

工作流入门这篇就够了!

总概 定义&#xff1a;工作流是在计算机支持下业务流程的自动或半自动化&#xff0c;其通过对流程进行描述以及按一定规则执行以完成相应工作。 应用&#xff1a;随着计算机技术的发展以及工业生产、办公自动化等领域的需求不断提升&#xff0c;面向事务审批、材料提交、业务…

OpenCV中实现图像旋转的方法

OpenCV中实现图像旋转的方法 函数&#xff1a;cv2.flip() 功能&#xff1a;水平或者垂直翻转 格式&#xff1a;dst cv2.flip(src,flipCode[,dst]) 参数说明&#xff1a; src&#xff1a;输入图像 dst&#xff1a;和原图像具有相同大小、类型的目标图像。 flipCode&#…

【Hotspot源码】揭秘Java线程创建过程中的各种细节

近期准备给大家分享专题系列文章&#xff0c;聚焦Java多线程机制。会从hotspot源码角度&#xff0c;给大家揭秘平时学习多线程那些从来没有想过的问题&#xff0c;或者存在疑虑却又无法证明的理论。 今天是系列文章首篇&#xff0c;咱们来谈谈Java线程创建的一些细节问题&#…

Vue 中的 ref 与 reactive:让你的应用更具响应性(中)

&#x1f90d; 前端开发工程师&#xff08;主业&#xff09;、技术博主&#xff08;副业&#xff09;、已过CET6 &#x1f368; 阿珊和她的猫_CSDN个人主页 &#x1f560; 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 &#x1f35a; 蓝桥云课签约作者、已在蓝桥云…

Lumerical Script------for语句

Lumerical------for语句 正文正文 关于 Lumerical 中 for 语句的用法这里不做过多说明了,仅仅做一个记录,具体用法如下: 通常我们用的比较多的形式是第一种步长值为 1 的情况。对于其他步长值的情况,我们可以使用第二种用法。对于 while 的类似使用方法可以使用第三种。 …

用Audio2Face驱动UE - MetaHuman

新的一年咯&#xff0c;很久没发博客了&#xff0c;就发两篇最近的研究吧。 开始之前说句话&#xff0c;别轻易保存任何内容&#xff0c;尤其是程序员不要轻易Ctrl S 在UE中配置Audio2Face 先检查自身电脑配置看是否满足&#xff0c;按最小配置再带个UE可能会随时崩&#x…