Pytorch完整的模型训练套路

Pytorch完整的模型训练套路

文章目录

  • Pytorch完整的模型训练套路
  • 以CIFAR10为例实践

  1. 数据集加载步骤

使用适当的库加载数据集,例如torchvision、TensorFlow的tf.data等。
将数据集分为训练集和测试集,并进行必要的预处理,如归一化、数据增强等。

  1. 模型创建步骤

创建机器学习模型,可以是深度神经网络、传统机器学习模型或其它模型类型。
定义模型架构,包括输入层、隐藏层和输出层的结构、激活函数、损失函数等。

  1. 损失函数和优化器定义步骤

定义适当的损失函数来计算模型预测结果于真实标签之间的差异。
选择适当的优化器算法来更新模型参数,如随机梯度下降(SGD)、Adam等。

  1. 训练循环步骤

从训练集中获取一批样本数据,并将其输入模型进行前向传播。
计算损失函数,并根据损失函数进行反向传播和参数更新。
重复以上步骤,直到达到预定的训练次数或达到收敛条件。

  1. 测试循环步骤

从测试集中获取一批样本数据,并将其输入模型进行前向传播。
计算损失函数或评估指标,用于评估模型在测试集上的性能。

  1. 训练和测试过程的记录和输出步骤

使用适当的工具或库记录训练过程中的损失值、准确率、评估指标等。

  1. 结束训练步骤

根据训练结束条件、例如达到预定的训练次数或收敛条件,结束训练。可以保存模型参数或整个模型,以便日后部署和使用。

以CIFAR10为例实践

并利用tensorboard可视化

import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
'''数据集加载'''
train_data = torchvision.datasets.CIFAR10(root='dataset',train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root='dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True)# 训练数据集的长度
train_data_size = len(train_data)
print(f"训练数据集的长度为:{train_data_size}")
# 测试数据集的长度
test_data_size = len(test_data)
print(f"测试数据集的长度:{test_data_size}")
#利用DataLoader加载数据集
train_dataloader = DataLoader(test_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
Files already downloaded and verified
Files already downloaded and verified
训练数据集的长度为:50000
测试数据集的长度:10000

‘’‘创建模型’‘’

以上篇文章《Pytorch损失函数、反向传播和优化器、Sequential使用》中的BS()为例

在这里插入图片描述

'''创建模型'''
class BS(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Conv2d(in_channels=3,out_channels=32,kernel_size=5,stride=1,padding=2),  #stride和padding计算得到nn.MaxPool2d(kernel_size=2),nn.Conv2d(in_channels=32,out_channels=32,kernel_size=5,stride=1,padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(in_channels=32,out_channels=64,kernel_size=5,padding=2),nn.MaxPool2d(kernel_size=2),nn.Flatten(),  #in_features变为64*4*4=1024nn.Linear(in_features=1024, out_features=64),nn.Linear(in_features=64, out_features=10),)def forward(self,x):x = self.model(x)return xbs = BS()
print(bs)
BS((model): Sequential((0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(6): Flatten(start_dim=1, end_dim=-1)(7): Linear(in_features=1024, out_features=64, bias=True)(8): Linear(in_features=64, out_features=10, bias=True))
)

一般来说,会将网络单独存放在一个model.py文件当中,然后利用from model import * 进行导入

'''定义损失函数和优化器'''
# 使用交叉熵损失函数
loss_fn = nn.CrossEntropyLoss()  
# 定义优化器
learning_rate = 1e-2  #学习率0.01
optimizer = torch.optim.SGD(bs.parameters(), lr=learning_rate)
"""
训练循环步骤
"""
# 开始设置训练神经网络的一些参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10writer = SummaryWriter(".logs") #Tensorboard可视化
for i in range(epoch):print("----第{}轮训练开始----".format(i))#bs.train() # bs.train()#有batchnorm、dropout层需要调用。官方文档见torch.nn.Module'''训练步骤开始'''for data in train_dataloader:imgs, targets = dataoutputs = bs(imgs)loss = loss_fn(outputs, targets)optimizer.zero_grad() # 首先要梯度清零loss.backward() #得到梯度optimizer.step() #进行优化total_train_step = total_train_step + 1if total_train_step % 100 == 0:print("训练次数:{}, loss:{}".format(total_train_step,loss.item()))writer.add_scalar("train_loss", loss.item(),total_train_step)'''测试步骤开始'''#bs.eval() # bs.train()#有batchnorm、dropout层需要调用。官方文档见torch.nn.Moduletotal_test_loss = 0#total_accuracytotal_accuracy = 0with torch.no_grad():#torch.no_grad()是一个上下文管理器,用来禁止梯度的计算,通常用来网络推断中,它可以减少计算内存的使用量。for imgs, targets in test_dataloader:outputs = bs(imgs)loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item() #.item()取出数字accuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracy"""测试过程的记录和输出"""print("整体测试集上损失函数loss:{}".format(total_test_loss))print("整体测试集上正确率:{}".format(total_accuracy/test_data_size))writer.add_scalar("test_loss", total_test_loss, total_test_step)writer.add_scalar('test_accuracy',total_accuracy/test_data_size)total_test_step = total_test_step + 1torch.save(bs, "test_{}.pth".format(i))print("模型已保存")
"""
结束训练步骤
"""
writer.close()

利用tensoraboard显示:

tensorboar --logdir logs

在这里插入图片描述

补充.item()

  1. .item()
import torch
a = torch.tensor(5)
print(a)
print(a.item())
tensor(5)
5
  1. model.train()和model.eval()
    官方网址见:torch.nn.Module(*args, **kwargs)
    在这里插入图片描述
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/153943.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深搜回溯剪枝-全排列

LCR 083. 全排列 - 力扣(LeetCode) 根据题意,要根据给定的整数数组,穷举出所有可能的排列,从直观的角度上来看,可以使用多层 for 循环来解决,但如果是数组长度太大的时候,这种方式不…

配置Java环境变量不生效的解决办法

问题: 直接更换Java_HOME的JDK安装路径后,竟然环境变量不生效,在cmd窗口输入java -version或者javac -version后报错???这是为什么呢? 问题剖析: 在使用安装版本的JDK程序时&#…

DataFunSummit:2023年数据基础架构峰会-核心PPT资料下载

一、峰会简介 正如From、Join、排序等是SQL的基本算子,存储与计算是也是数据架构中数据生产与消费的基本算子,对于数据架构之下的技术栈层级,我们可将其定义为数据基础架构。 数据存储技术在适应大数据时代的规模需求基础之上,持…

【23真题】难!985难度前五名!

今天分享的是23年中山大学884的信号与系统试题及解析。 本套试卷难度分析:22年中山大学884考研真题,我也发布过,若有需要,戳这里自取!22年并不是很难,今年难度突然大幅度提升!原因不明。23年平均分为100分…

新手教师如何迅速成长

对于许多新手教师来说,迈出教学的第一步可能会感到非常困难。不过,通过一些关键的策略和技巧,还是可以快速提升教学能力的,我将为大家提供一些实用的建议,帮助各位在教育领域迅速成长。 深入了解学科知识 作为一名老师…

数据库编程sqlite3库安装及使用

数据库编程 数据库的概念 数据库是“按照数据结构来组织、存储和管理数据的仓库”。是一个长期存储在计算机内的、有组织的、可共享的、统一管理的大量数据的集合。 数据库是存放数据的仓库。它的存储空间很大,可以存放百万条、千万条、上亿条数据。但是数据库并不是…

postgresql:记录表膨胀引起的io问题的处理

文章目录 1. io异常2.查看profile报告2.1 生成事发时间段的pgprofile2.2 查看报告 3.检查table是否膨胀4.执行vacuum full5.总结 1. io异常 iostat -x 1 20 Device r/s w/s rkB/s wkB/s rrqm/s wrqm/s %rrqm %wrqm r_await w_await aqu-sz rareq…

老师怎么才能让学生听话

在教育学生的过程中,如何让他们听话并且尊重师长,是一个老师需要深入思考的问题。这不仅涉及到学生的学习进步,还关系到他们的人格形成。以下是一些方法和策略,帮助教师更好地引导学生,使他们更愿意听从教导。 建立信任…

轻量封装WebGPU渲染系统示例<36>- 广告板(Billboard)(WGSL源码)

原理不再赘述&#xff0c;请见wgsl shader实现。 当前示例源码github地址: https://github.com/vilyLei/voxwebgpu/blob/feature/rendering/src/voxgpu/sample/BillboardEntityTest.ts 当前示例运行效果: WGSL顶点shader: group(0) binding(0) var<uniform> objMat :…

【机器学习】对比学习(contrastive learning)

对比学习是一种机器学习技术&#xff0c;算法学习区分相似和不相似的数据点。对比学习的目标是学习数据的表示&#xff0c;以捕捉不同数据点之间的基本结构和关系。 在对比学习中&#xff0c;算法被训练最大化相似数据点之间的相似度&#xff0c;并最小化不相似数据点之间的相似…

U-boot(三):start.S

本文主要探讨x210的uboot的start.S文件,也是uboot启动的第一阶段。 头文件 config.h config.h x210_sd.h,由mkconfig脚本生成,包含了开发板的配置宏 rootkaxi-virtual-machine:~/qt_x210v3s_160307/uboot/include# cat config.h /* Automatically generate…

el-date-picker ie模式下 初始化未赋值;未清空

el-date-picker ie模式下 初始化未赋值;未清空 给 dete-picker 加key属性 eg:

接口自动化测试实战:JMeter+Ant+Jenkins+钉钉机器人群通知完美结合

前言 一、本地JAVA环境安装配置,安装JAVA8和JAVA17 二、安装和配置Jmeter 三、安装和配置ant 四、jmeter + ant配置 五、jenkins安装和配置持续构建项目 六、jenkins配置流程 前言 搭建jmeter+ant+jenkins环境有些前提条件,那就是要先配置好java环境,本地java环境…

redis的高可用

redis-cli -h 192.168.233.10 -p 6379 redis的数据类型的增删改查 redis的高可用在集群当中有一个非常重要的指标&#xff0c;提供正常服务的时间的百分比(365天) 99.9% redis的高可用含义更加广泛&#xff0c;正常服务是指标之一&#xff0c;数据容量的扩展&#xff0c;数据…

2023亚太杯数学建模思路 - 案例:异常检测

文章目录 赛题思路一、简介 -- 关于异常检测异常检测监督学习 二、异常检测算法2. 箱线图分析3. 基于距离/密度4. 基于划分思想 建模资料 赛题思路 &#xff08;赛题出来以后第一时间在CSDN分享&#xff09; https://blog.csdn.net/dc_sinor?typeblog 一、简介 – 关于异常…

给新手教师的成长建议

随着教育的不断发展和进步&#xff0c;越来越多的新人加入到教师这个行列中来。从学生到教师&#xff0c;这是一个华丽的转身&#xff0c;需要我们不断地学习和成长。作为一名新手老师&#xff0c;如何才能快速成长呢&#xff1f;以下是一名老师教师给的几点建议&#xff1a; 一…

人工智能对我们的生活影响有多大

随着科技的飞速发展&#xff0c;人工智能已经渗透到我们生活的方方面面&#xff0c;并且越来越受到人们的关注。从智能语音助手到自动驾驶汽车&#xff0c;从智能家居系统到医疗诊断&#xff0c;人工智能技术正在改变着我们的生活方式。那么&#xff0c;人工智能对我们的生活影…

使用 RAFT 的光流:第 1 部分

一、说明 在这篇文章中&#xff0c;我们将了解一种旗舰的光流深度学习方法&#xff0c;该方法获得了 2020 年 ECCV 最佳论文奖&#xff0c;并被引用超过 1000 次。它也是KITTI基准测试中许多性能最佳的模型的基础。该模型称为 RAFT&#xff1a;Recurrent All-Pairs Field Trans…

微信表情太大怎么缩小?一分钟教会你!

在微信的较早版本中&#xff0c;单个表情的最大体积限制为500KB&#xff0c;而在后续版本中&#xff0c;这一限制已经放宽。目前&#xff0c;微信允许上传的单个表情最大体积为2MB。所以&#xff0c;我们只需要把图片或者GIF缩小到2MB即可&#xff0c;下面就向大家介绍三种实用…

如何给面试官解释什么是分布式和集群?

分布式&#xff08;distributed&#xff09; 是指在多台不同的服务器中部署不同的服务模块&#xff0c;通过远程调用协同工作&#xff0c;对外提供服务。 集群&#xff08;cluster&#xff09; 是指在多台不同的服务器中部署相同应用或服务模块&#xff0c;构成一个集群&#…