Pytorch深度学习-----完整神经网络模型训练套路

系列文章目录

PyTorch深度学习——Anaconda和PyTorch安装
Pytorch深度学习-----数据模块Dataset类
Pytorch深度学习------TensorBoard的使用
Pytorch深度学习------Torchvision中Transforms的使用(ToTensor,Normalize,Resize ,Compose,RandomCrop)
Pytorch深度学习------torchvision中dataset数据集的使用(CIFAR10)
Pytorch深度学习-----DataLoader的用法
Pytorch深度学习-----神经网络的基本骨架-nn.Module的使用
Pytorch深度学习-----神经网络的卷积操作
Pytorch深度学习-----神经网络之卷积层用法详解
Pytorch深度学习-----神经网络之池化层用法详解及其最大池化的使用
Pytorch深度学习-----神经网络之非线性激活的使用(ReLu、Sigmoid)
Pytorch深度学习-----神经网络之线性层用法
Pytorch深度学习-----神经网络之Sequential的详细使用及实战详解
Pytorch深度学习-----损失函数(L1Loss、MSELoss、CrossEntropyLoss)
Pytorch深度学习-----优化器详解(SGD、Adam、RMSprop)
Pytorch深度学习-----现有网络模型的使用及修改(VGG16模型)
Pytorch深度学习-----神经网络模型的保存与加载(VGG16模型)


文章目录

  • 系列文章目录
  • 一、完整神经网络训练一般步骤
    • 1.数据集加载步骤
    • 2.模型创建步骤
    • 3.损失函数和优化器定义步骤
    • 4.训练循环步骤
    • 5.测试循环步骤
    • 6.训练和测试过程的记录和输出步骤
    • 7.结束训练步骤
  • 二、代码演示
  • 三、对上面代码进一步总结


一、完整神经网络训练一般步骤

1.数据集加载步骤

  • 使用适当的库加载数据集,例如torchvision、TensorFlow的tf.data等。
  • 将数据集分为训练集和测试集,并进行必要的预处理,如归一化、数据增强等。

2.模型创建步骤

  • 创建机器学习模型,可以是深度神经网络、传统机器学习模型或其它模型类型。
  • 定义模型架构,包括输入层、隐藏层和输出层的结构、激活函数、损失函数等。

3.损失函数和优化器定义步骤

  • 定义适当的损失函数来计算模型预测结果于真实标签之间的差异。
  • 选择适当的优化器算法来更新模型参数,如随机梯度下降(SGD)、Adam等。

4.训练循环步骤

  • 从训练集中获取一批样本数据,并将其输入模型进行前向传播。
  • 计算损失函数,并根据损失函数进行反向传播和参数更新。
  • 重复以上步骤,直到达到预定的训练次数或达到收敛条件。

5.测试循环步骤

  • 从测试集中获取一批样本数据,并将其输入模型进行前向传播。
  • 计算损失函数或评估指标,用于评估模型在测试集上的性能。

6.训练和测试过程的记录和输出步骤

  • 使用适当的工具或库记录训练过程中的损失值、准确率、评估指标等。
  • 可以使用TensorBoard、matplotlib、CSV文件等方式记录和可视化训练和测试结果。

7.结束训练步骤

  • 根据训练结束条件、例如达到预定的训练次数或收敛条件,结束训练。
  • 可以保存模型参数或整个模型,以便日后部署和使用。

二、代码演示

创建model.py代码如下

import torch
from torch import nn# 搭建神经网络
class Lgl(nn.Module):def __init__(self):super(Lgl, self).__init__()self.model = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(kernel_size=2),nn.Flatten(),nn.Linear(in_features=64*4*4, out_features=64),nn.Linear(in_features=64, out_features=10))def forward(self, x):x = self.model(x)return x

上述模型的原图如下所示
在这里插入图片描述

trains.py文件开始对模型按步骤进行训练代码如下

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter# 加载自己搭建的神经网络
from model import *
"""
1.数据集加载
"""
# 准备训练数据集
train_data = torchvision.datasets.CIFAR10(root="dataset", train=True, transform=torchvision.transforms.ToTensor(), download=True)
# 准备测试数据集
test_data = torchvision.datasets.CIFAR10(root="dataset", train=False, transform=torchvision.transforms.ToTensor(), download=True)
# 训练数据集的长度
train_data_sise = len(train_data)
print("训练数据集的长度为:{}".format(train_data_sise))
# 测试数据集的长度
test_data_sise = len(test_data)
print("测试数据集的长度:".format(test_data_sise))
# 加载数据集
train_dataloader = DataLoader(test_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
"""
2.模型的创建,这里直接from model import * 故下面直接调用
"""
# 实例化网络模型
lgl = Lgl()
"""
3.损失函数和优化器
"""
# 定义交叉熵损失函数
loss_fn = nn.CrossEntropyLoss()
# 进行优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(lgl.parameters(), lr=learning_rate)
"""
4.训练循环步骤
4.1 为训练做的参数准备工作
"""
# 开始设置训练神经网络的一些参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 记录是第几轮训练
epoch = 10
# 添加Tensorboard
writer = SummaryWriter("logs")
for i in range(epoch):print("----第{}轮训练开始----".format(i))"""4.2 训练循环"""# 训练步骤for data in train_dataloader:imgs, targets = dataoutputs = lgl(imgs)loos_result = loss_fn(outputs, targets)# 优化器优化模型# 将上一轮的梯度清零optimizer.zero_grad()# 借助梯度进行反向传播loos_result.backward()optimizer.step()total_train_step = total_train_step + 1if total_train_step % 100 == 0:print("训练次数:{}, loss:{}".format(total_train_step, loos_result.item()))writer.add_scalar("train_loos", loos_result.item(), total_train_step)"""5.测试循环"""# 测试步骤开始total_test_loos = 0with torch.no_grad():for imgs, targets in test_dataloader:outputs = lgl(imgs)loos_result = loss_fn(outputs, targets)total_test_loos = total_test_loos + loos_result.item()"""6.测试过程的记录和输出"""print("整体测试集上损失函数loos:{}".format(total_test_loos))writer.add_scalar("test_loos", total_test_loos, total_test_step)total_test_step = total_test_step + 1torch.save(lgl, "test_{}.pth".format(i))print("模型已保存")
"""
7.结束训练步骤
"""
writer.close()

运行结果

训练数据集的长度为:50000
测试数据集的长度:
----0轮训练开始----
训练次数:100, loss:2.2938857078552246
整体测试集上损失函数loos:359.07741928100586
模型已保存
----1轮训练开始----
训练次数:200, loss:2.2591800689697266
训练次数:300, loss:2.263899087905884
整体测试集上损失函数loos:351.34613394737244
模型已保存
----2轮训练开始----
训练次数:400, loss:2.175294876098633
整体测试集上损失函数loos:340.2291133403778
模型已保存
----3轮训练开始----
训练次数:500, loss:2.096158981323242
训练次数:600, loss:1.9759657382965088
整体测试集上损失函数loos:344.92591202259064
模型已保存
----4轮训练开始----
训练次数:700, loss:2.043778896331787
整体测试集上损失函数loos:333.33667516708374
模型已保存
----5轮训练开始----
训练次数:800, loss:1.9719760417938232
训练次数:900, loss:1.8361881971359253
整体测试集上损失函数loos:318.2255847454071
模型已保存
----6轮训练开始----
训练次数:1000, loss:1.832183599472046
整体测试集上损失函数loos:303.4973853826523
模型已保存
----7轮训练开始----
训练次数:1100, loss:1.8691924810409546
训练次数:1200, loss:2.0134520530700684
整体测试集上损失函数loos:292.21254682540894
模型已保存
----8轮训练开始----
训练次数:1300, loss:1.7631018161773682
训练次数:1400, loss:1.6039265394210815
整体测试集上损失函数loos:283.98761427402496
模型已保存
----9轮训练开始----
训练次数:1500, loss:1.7172112464904785
整体测试集上损失函数loos:276.9621036052704
模型已保存

tensorboard中显示
在这里插入图片描述

三、对上面代码进一步总结

数据集加载步骤:

使用torchvision库加载CIFAR10数据集。
将训练集和测试集分别存放在train_data和test_data中。

模型创建步骤:

引用model.py文件,在其中创建名为"Lgl"的模型。

损失函数和优化器定义步骤:

定义损失函数为交叉熵损失(nn.CrossEntropyLoss)。
定义优化器为随机梯度下降(SGD)优化器,并将模型参数传递给优化器。

训练循环步骤:

从训练数据(train_dataloader)中迭代获取一个批次的图像和目标标签。
执行模型的前向传播,计算损失,执行反向传播,更新模型参数。
记录训练过程中的损失值。
每100个训练步骤后,打印当前的训练次数和损失值。

测试循环步骤:

使用torch.no_grad()上下文环境。
从测试数据(test_dataloader)中迭代获取一个批次的图像和目标标签。
执行模型的前向传播和损失计算,并累加测试集上的损失值。

损失记录和输出步骤:

使用SummaryWriter创建一个TensorBoard的日志写入器。
将训练过程中的损失值写入TensorBoard文件中。
在整个测试集上打印损失值。

结束训练步骤:

关闭TensorBoard写入器。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/33886.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

websocket知识点

http协议 http协议特点: 无状态协议每个请求是独立的单双工通信,且服务器无法主动给客户端发信息http协议受浏览器同源策略影响 http实现双向通信方法: 轮询长轮询iframe流sse EventSource websocket协议 websocket协议: 全双工协议支持跨域支持多…

自动测试框架airtest应用一:将XX读书书籍保存为PDF

一、Airtest的简介 Airtest是网易出品的一款基于图像识别和poco控件识别的一款UI自动化测试工具。Airtest的框架是网易团队自己开发的一个图像识别框架,这个框架的祖宗就是一种新颖的图形脚本语言Sikuli。Sikuli这个框架的原理是这样的,计算机用户不需要…

asp.net core webapi如何执行周期性任务

使用Api执行周期性任务 第一种,无图形化界面1.新建类,继承IJob,在实现的方法种书写需要周期性执行的事件。2.编写方法类,定义事件执行方式3.在启动方法中,进行设置,.net 6中在program.cs的Main方法中&#…

旅卦-火山旅

前言:人生就像一趟旅行,为谋生奔波也是旅,旅是人生的常态,我们看一下易经里的旅卦,分析下卦辞和爻辞以及自己的理解。 目录 卦辞 爻辞 总结 卦辞 旅:小亨,旅贞吉。 卦序:穷大者…

java获取到heapdump文件后,如何快速分析?

简介 在之前的OOM问题复盘之后,本周,又一Java服务出现了内存问题,这次问题不严重,只会触发堆内存占用高报警,没有触发OOM,但好在之前的复盘中总结了dump脚本,会在堆占用高时自动执行jstack与jm…

560. 和为 K 的子数组

思路 本题的主要思路为创建一个哈希表记录每个0~i的和,在遍历这个数组的时候查询有没有sum-k的值在哈希表中,如果有,说明有个位置到当前位置的和为k。   有可能不止一个,哈希表负责记录有几个sum-k,将和记录下来。这…

【ArcGIS Pro二次开发】(60):按图层导出布局

在使用布局导图时,会遇到如下问题: 为了切换图层和导图方便,一般情况下,会把相关图层做成图层组。 在导图的时候,如果想要按照图层组进行分开导图,如上图,想导出【现状图、规划图、管控边界】3…

UNIX网络编程——TCP协议API 基础demo服务器代码

目录 一.TCP客户端API 1.创建套接字 2.connect连接服务器​编辑 3.send发送信息 4.recv接受信息 5.close 二.TCP服务器API 1.socket创建tcp套接字(监听套接字) 2.bind给服务器套接字绑定port,ip地址信息 3.listen监听并创建连接队列 4.accept提取客户端的连接 5.send,r…

包管理工具详解npm 、 yarn 、 cnpm 、 npx 、 pnpm(2023)

1、包管理工具npm (1)包管理工具npm: Node Package Manager,也就是Node包管理器;但是目前已经不仅仅是Node包管理器了,在前端项目中我们也在使用它来管理依赖的包;比如vue、vue-router、vuex、…

【Spring中MySQL连接错误】Cannot load driver class: com.mysql.cj.jdbc.Driver

Caused by: Failed to instantiate [com.zaxxer.hikari.HikariDataSource]: Factory method ‘dataSource’ threw exception; nested exception is java.lang.IllegalStateException: Cannot load driver class: com.mysql.cj.jdbc.Driver Caused by: java.lang.IllegalState…

干货满满的Python知识,学会这些你也能成为大牛

目录 1. 爬取网站数据 2. 数据清洗与处理 3. 数据可视化 4. 机器学习模型训练 5. 深度学习模型训练 6. 总结 1. 爬取网站数据 在我们的Python中呢,使用爬虫可以轻松地获取网站的数据。可以使用urllib、requests、BeautifulSoup等库进行数据爬取和处理。以下是…

(kubernetes)k8s常用资源管理

目录 k8s常用资源管理 1、创建一个pod 1)创建yuml文件 2)创建容器 3)查看所有pod创建运行状态 4)查看指定pod资源 5)查看pod运行的详细信息 6)验证运行的pod 2、pod管理 1)删除pod 2…

STM32 F103C8T6学习笔记1:开发环境与原理图的熟悉

作为一名大学生,学习单片机有一段时间了,也接触过嵌入式ARM的开发,但从未使用以及接触过STM32C8T6大开发使用,于是从今日开始,将学习使用它~ 本文介绍STM32C8T6最小系统开发环境搭建注意问题,STM32C8T6单片…

【Docker晋升记】No.2 --- Docker工具安装使用、命令行选项及构建、共享和运行容器化应用程序

文章目录 前言🌟一、Docker工具安装🌟二、Docker命令行选项🌏2.1.docker run命令选项:🌏2.2.docker build命令选项:🌏2.3.docker images命令选项:🌏2.4.docker ps命令选项…

20.5 HTML 媒体

1. video视频标签 video视频标签: 是HTML中用于在网页上嵌入视频的元素.常用的视频标签属性: - src属性: 指定视频文件的URL地址. - controls属性: 用于显示视频播放控件(如播放按钮, 进度条等), 使用户能够控制视频的播放. - width和height: 指定视频的宽度和高度. - autopla…

【Unity实战系列】Unity的下载安装以及汉化教程

君兮_的个人主页 即使走的再远,也勿忘启程时的初心 C/C 游戏开发 Hello,米娜桑们,这里是君兮_,怎么说呢,其实这才是我以后真正想写想做的东西,虽然才刚开始,但好歹,我总算是启程了。今天要分享…

使用RecyclerView构建灵活的列表界面

使用RecyclerView构建灵活的列表界面 1. 引言 在现代移动应用中,列表界面是最常见的用户界面之一,它能够展示大量的数据,让用户可以浏览和操作。无论是社交媒体的动态流、商品展示、新闻列表还是任务清单,列表界面都扮演着不可或…

第一百二十四天学习记录:C++提高:STL-deque容器(上)(黑马教学视频)

deque容器 deque容器基本概念 功能: 双端数组,可以对头端进行插入删除操作 deque与vector区别 vector对于头部的插入删除效率低,数据量越大,效率越低 deque相对而言,对头部的插入删除速度比vector快 vector访问元素的…

LeetCode150道面试经典题--同构字符串(简单)

1.题目 给定两个字符串 s 和 t ,判断它们是否是同构的。如果 s 中的字符可以按某种映射关系替换得到 t ,那么这两个字符串是同构的。每个出现的字符都应当映射到另一个字符,同时不改变字符的顺序。不同字符不能映射到同一个字符上&#xff0c…

ppt怎么压缩?试试这样压缩文件

当PPT文件体积过大时,打开的速度就会很慢,演示的时候刘程度也会受到影响,其次,现在很多平台对于上传的文件是有大小限制的,比如超过100M的文件就无法上传、发送等等,那么,怎么才能压缩PPT文件呢…