【新手适用】手把手教你从零开始实现一个基于Pytorch的卷积神经网络CNN二: 如何训练模型,内附详细损失、准确率、均值计算

手把手教你从零开始实现一个基于Pytorch的卷积神经网络CNN(新手适用)一: model.py:创建模块-CSDN博客

 从零开始实现一个基于Pytorch的卷积神经网络 - 知乎

目录

1 设备device定义

2 训练模型定义

3 开始训练

3.1 step、batchsize和数据集中图片数的关系

3.2 关于类型转换

4 保存模型

5 训练过程可视化

5.1 损失

5.1.1 损失的计算

 5.1.2 平均损失的计算

5.2 准确率

5.2.1 准确率计算

5.2.2 平均准确率计算

 6. train.py完整代码

7. 训练结果


设备device定义

通过torch.device()来指定使用的设备device,然后通过.to()方法将模型和数据放到指定的设备上,这样我们就可以通过定义device来指定是在cpu还是显卡上进行训练了,而且在多显卡的情况下也可以指定使用其中的某一张显卡进行训练。

torch.cuda.is_available()可以判断本设备是否支持CUDA,如果支持就返回True,不支持就返回False。这个函数可以让其自动判断是否支持CUDA加速并自动选择设备。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

训练模型定义

  1. 初始化和导入模型
  2. 定义超参数、数据集和DataLoader
  3. 定义损失函数loss function和优化器optimizer
  4. 启用梯度:torch.set_grad_enabled(True)

代码如下,其中的具体解释可看知乎原文。

# 训练数据
import torch
import torchvision
import torch.nn as nn
import torch.utils.data as Data# 1. 导入模型文件并且定义
from model import LeNet
model = LeNet()
# 2. 定义参数:轮数,批次和学习率
Epoch = 5
batch_size = 64
lr = 0.001
# 3. 获取训练数据
train_data = torchvision.datasets.MNIST(root='./data',train=True,transform=torchvision.transforms.ToTensor(),download=False)
#定义train data的数据集
train_loader = Data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
# 4. 定义损失函数、优化器
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# 5. 梯度计算
torch.set_grad_enabled(True)
#启用Batch Normalization层和Dropout层
model.train()

model.train()方法:该方法用于启用Batch Normalization层和Dropout层。虽然模型中并没有这两层,但是我们不妨将其加上,并作为一个习惯,以免在真正需要时忘记。

3 开始训练

1) 获得DataLoader中的数据x和标签y

2) 将优化器的梯度清零

3) 将数据送入模型中获得预测的结果y_pred

4) 将标签和预测结果送入损失函数获得损失

5) 将损失值反向传播

6) 使用优化器对模型的参数进行更新

# 训练
for epoch in range(Epoch):for step,data in enumerate(train_data):# 取出data中的数据和标签x,y=data# 优化器梯度清零optimizer.zero_grad()# 计算预测值y_pred = model(x.to(device,torch.float))# 计算损失loss = loss_function(y_pred, y.to(device, torch.long))# 梯度更新loss.backward()# 优化器更新模型参数optimizer.step()

3.1 step、batchsize和数据集中图片数的关系

一个 epoch 表示将训练数据集中的所有样本都用于训练一次。步数(steps)表示在一个 epoch 中所执行的批次数量。

  • step的计算方法:将总样本数除以批次大小。

我们在前面定义的batchsize是64且丢弃最后一批,手写数字数据集中有6万张图片,60000/64=937.5,故每个epoch中有step=937。

3.2 关于类型转换

# 计算预测值y_pred = model(x.to(device,torch.float))# 计算损失loss = loss_function(y_pred, y.to(device, torch.long))

两行中都涉及到了类型转换。在深度学习中,通常希望输入数据和模型参数的数据类型是一致的,这样可以避免类型不匹配的错误,并且可以更有效地利用硬件加速器(如 GPU)进行计算。

  1. 第一行把y_pred输入数据转换为 torch.float 类型的目的是确保模型接收到的数据类型与模型参数的数据类型匹配。在很多情况下,神经网络模型的参数通常是浮点数类型(float),因此将输入数据转换为 torch.float 类型可以确保与模型参数的类型匹配。
  2. 在深度学习中,通常使用整数类型(如 torch.long)来表示类别标签或离散值。许多损失函数(例如交叉熵损失函数)在计算损失时需要模型输出的预测值和真实标签值具有相同的数据类型。

4 保存模型

把模型的定义和参数全保存在一个文件中,后面可以直接使用训练好的权重文件。

torch.save(model, './LeNet.pkl')

训练过程可视化

  • 查看训练过程中的损失和准确率等等过程参数。

可以在每隔一定的step后输出当前损失和准确率的平均值。MNIST的训练集共有六万张图像,而我们的batch_size是64且丢弃最后一批,因此在每个Epoch中有937个step,实际训练59968张图像。可以每迭代100次后输出当前Epoch的损失和准确率的平均值,并输出当前处在哪一次Epochstep

5.1 损失

5.1.1 损失的计算

对每次计算产生的损失进行相加,把结果放在running_loss中,因此,我们需要在反向传播后添加一个累加操作。由于loss在我们之前定义的设备上,因此我们需要获得loss的值,然后将其传回cpu并转换为float类型,即:

 # 计算损失loss = loss_function(y_pred, y.to(device, torch.long))# 梯度更新loss.backward()# 累加梯度running_loss += float(loss.data.cpu())

 5.1.2 平均损失的计算

如果当前step是100次,则计算一次平均损失。由于step从0开始,所以需要+1.

loss_avg = running_loss / (step + 1)

5.2 准确率

5.2.1 准确率计算

acc = 预测正确的数目 / 总数目

y_pred是一个二维的张量,其形状为[batch_size, num_classes],在这边channel是10,即十个数字。如果我们将batch中的任意一行提取出来就获得了一个10维的向量,向量里的每个数代表与其下标所对应的标签的相关性,相关性越大则代表越有可能是这个数字。

因此,我们需要获得这个向量中最大数的下标,在pytorch中,我们可以用.argmax(dim)方法实现,输入维度dim,即可返回这个维度下最大值的下标,即pred = y_pred.argmax(dim=1)。在此基础上,我们就可以计算其预测正确的数量了;

先获取pred的值,即模型预测的图片的类别,然后传回cpu,用==筛选模型预测的标签和图片标注的标签相等的个数,然后使用.sum()相加统计预测正确的个数

acc保存的即是模型预测正确的个数。

acc += (pred.data.cpu()==y.data).sum()

5.2.2 平均准确率计算

接下来先对steps进行统计,设置每轮数中的每100个steps计算一次平均损失、准确率等。

平均损失值 = 损失值 / steps

平均准确率 = 预测正确的个数 / 总个数 

 # 判断该step是不是该epoch中的第100步if step%100==99:# 平均损失 = 损失值/stepsloss_avg = running_loss / (step+1)# 平均准确率 = 预测正确的数量/总个数acc_avg = float(acc / ((step + 1) * batch_size))# 输出print('Epoch', epoch + 1, ',step', step + 1, '| Loss_avg: %.4f' % loss_avg, '|Acc_avg:%.4f' % acc_avg)

 6. train.py完整代码

# 训练数据
import torch
import torchvision
import torch.nn as nn
import torch.utils.data as Data
# 导入模型文件并且定义
from model import LeNetmodel = LeNet()
# 定义参数:轮数,批次和学习率
Epoch = 5
batch_size = 64
lr = 0.001# 获取训练数据
train_data = torchvision.datasets.MNIST(root='./data',train=True,transform=torchvision.transforms.ToTensor(),download=False)
#定义train data的数据集
train_loader = Data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
# 定义损失函数、优化器
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
#梯度计算
torch.set_grad_enabled(True)
#启用Batch Normalization层和Dropout层
model.train()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)# 训练
for epoch in range(Epoch):# 定义损失值和准确率running_loss = 0.0acc = 0.0for step,data in enumerate(train_loader):# 取出data中的数据和标签,x为数据y为标签x,y=data# 优化器梯度清零optimizer.zero_grad()# 计算预测值y_pred = model(x.to(device,torch.float))# 计算损失loss = loss_function(y_pred, y.to(device, torch.long))# 梯度更新loss.backward()# 累加梯度running_loss += float(loss.data.cpu())# 取出预测的最大值pred = y_pred.argmax(dim=1)# 统计预测正确的个数acc += (pred.data.cpu()==y.data).sum()# 优化器更新模型参数optimizer.step()# 判断该step是不是该epoch中的第100步if step%100==99:# 平均损失 = 损失值/stepsloss_avg = running_loss / (step+1)# 平均准确率 = 预测正确的数量/总个数acc_avg = float(acc / ((step + 1) * batch_size))# 输出print('Epoch', epoch + 1, ',step', step + 1, '| Loss_avg: %.4f' % loss_avg, '|Acc_avg:%.4f' % acc_avg)# 保存模型
torch.save(model, './LeNet.pkl')

7. 训练结果

可以看到loss呈下降趋势,acc提升。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/787058.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

4.机器学习-十大算法之一线性回归算法(LinearRegression)案例讲解

机器学习-十大算法之一线性回归算法案例讲解 一摘要二个人简介三什么是线性回归四LinearRegression使用方法五糖尿病数据线性回归预测1.数据说明2.导包3.导入数据4.脱敏处理5.抽取训练数据和预测数据6.创建模型7.预测8.线性回归评估指标9.研究每个特征和标记结果之间的关系.来分…

职场成功的关键:提升软实力,成就非凡事业

在竞争激烈的职场中,专业技能固然重要,但软实力同样不可或缺。要想在职场中脱颖而出,实现事业上的成功,我们需要在提升软实力上下功夫。本文将探讨职场软实力的内涵及其在职场成功中的作用,并提供一些建议,…

解决Quartus与modelsim联合仿真问题:# Error loading design解决,是tb文件中没加:`timescale 1ns/1ns

解决Quartus与modelsim联合仿真问题:# Error loading design解决,是tb文件中没加:timescale 1,一直走下来,在modelsim中出现了下面问题2,rtl文件、tb文件2.1,rtl代码2.2,tb测试2.3&a…

java Web实现用户登录功能

文章目录 一、纯JSP方式实现用户登录功能(一)实现思路1、创建Web项目2、创建登录页面3、创建登录处理页面4、创建登录成功页面5、创建登录失败页面6、编辑项目首页 (三)测试结果 二、JSPServlet方式实现用户登录功能(一…

校园通勤车可视化系统的设计与实现

1.需求分析: 校园通勤车可视化系统的设计与实现,不用管什么可视化,就是一个小程序就是可以知道校园车的路线,然后往简单了弄就可以。 校园通勤车可视化系统的设计与实现,不用管什么可视化,就是一个小程序…

【C/C++】C++学籍信息管理系统(源码+报告)【独一无二】

👉博__主👈:米码收割机 👉技__能👈:C/Python语言 👉公众号👈:测试开发自动化【获取源码商业合作】 👉荣__誉👈:阿里云博客专家博主、5…

【threejs】较大物体或shape的贴图较小问题处理方法

问题 有的场景内相对体型差距过大的物体(如山地 海洋等)由于尺寸问题,加载贴图过于小,同时shader也无法完全展示,如图 我们可以获取物体的uv,进行缩放使得贴图可以完全展开 如果uv是乱的 可以用xyz坐标最…

【Redis】MISCONF Redis is configured to save RDB snapshots报错解决方案

【Redis】MISCONF Redis is configured to save RDB snapshots报错解决方案 大家好 我是寸铁👊 总结了一篇【Redis】MISCONF Redis is configured to save RDB snapshots报错解决方案✨ 喜欢的小伙伴可以点点关注 💝 前言 今天在登录redis时&#xff0c…

matlab函数化简和函数极限

文章目录 化简求函数极限泰勒公式泰勒公式求解 化简 simplify 函数是MATLAB中符号计算工具箱提供的一个函数,用于简化数学表达式。它可以根据预定义的简化规则,对给定的数学表达式进行简化和转化。 以下是simplify 函数的一些常用用法: 简…

[蓝桥杯 2022 省 B] 李白打酒加强版

题目链接 [蓝桥杯 2022 省 B] 李白打酒加强版 题目描述 话说大诗人李白,一生好饮。幸好他从不开车。 一天,他提着酒壶,从家里出来,酒壶中有酒 2 2 2 斗。他边走边唱: 无事街上走,提壶去打酒。 逢店加一倍…

python_绘图_多条折线图绘制_显示与隐藏

1. 需求 给定一个二维数组 100行, 5列, 每一列绘制一条折线, 横轴为行索引, 纵轴为对应位置的值, 绘制在一个子图里面, 使用python plot, 使用随机颜色进行区别添加显示和隐藏按钮, 可以对每条折线进行显示和隐藏 2. 代码 import numpy as np import matplotlib.pyplot as p…

为什么说FMEA是最主要的可靠性设计工具?——FMEA软件

免费试用FMEA软件-免费版-SunFMEA FMEA,即故障模式与影响分析(Failure Modes and Effects Analysis),是一种预防性的质量工具,广泛应用于各种工程领域,特别是在产品设计和制造过程中。它通过对产品或过程中…

工具_git提交时忽略某些文件或者目录,git提交排除某些文件或目录

git 提交时如果想忽略某些文件或者目录: 1.在根目录下创建 .gitignore 文件 2.在该文件中直接添加内容,如: 忽略.mdb、.sln、.sln,.config 文件,不忽视 .txt 文件 *.mdb *.ldb *.sln .config !.txt 忽略Debug目录及文件&#…

4月2日 qt密码生成小程序(可选择生成密码的格式),基于Python框架下的pyqt6

4月2日 密码生成小程序 代码展示: import stringfrom PyQt6.QtWidgets import (QApplication, QDialog,QMessageBox ) from untitled import Ui_PasswordGender import sys import string # py模块含有字符 import randomclass MyPasswordGenerate(Ui_Password…

快速入门Linux,Linux岗位有哪些?(一)

文章目录 Linux与Linux运维操作系统?操作系统图解 认识LinuxLinux受欢迎的原因什么是Linux运维Linux运维岗位Linux运维岗位职责Linux运维架构师岗位职责Linux运维职业发展路线计算机硬件分类运维人员的三大核心职责 运维人员工作(服务器)什么…

python爬虫----了解爬虫(十一天)

🎈🎈作者主页: 喔的嘛呀🎈🎈 🎈🎈所属专栏:python爬虫学习🎈🎈 ✨✨谢谢大家捧场,祝屏幕前的小伙伴们每天都有好运相伴左右,一定要天天…

6000000IOPS!FASS×kunpeng920全新突破

实测数据详见下文 网络环境 前端和后端网均采用100GE网络,管理网采用1Gbps以太网。 前端网和后端网通过不同网段隔离,与管理网物理隔离。 软硬件配置 存储端配置: 客户端配置: 软件配置: 存储集群配置: …

【MapBox】实现实时飞行轨迹功能

之前写了一篇MapBox添加带箭头的轨迹线,现在在这个基础之上实现获取到无人机的推送点位数据实时飞行的功能 首先创建实例,将无人机的图标加载在地图上 const MAP_UAV_FLIGHT_ING (values, layerKey 无人机飞行) > {ClearUAVMap();const map GET_…

认证的无线网络安全

​在今天,大多数计算机网络都是无线网络,更易管理,移动性更强,而且成本也更低。而另一方面,无线网络也更容易受到攻击,因为企业无法将网络的物理范围控制在办公楼内,大街或停车场上的任何人都有…

postgresql数据库扩展之fdw

1.介绍 PostgreSQL中的Foreign Data Wrapper(FDW)是一个强大的功能,它允许你访问和操作存储在外部源中的数据,就好像它是PostgreSQL数据库内的一个表一样。这意味着你可以直接从PostgreSQL查询和联接不同数据库和系统中的数据。F…