PyTorch入门学习(十七):完整的模型训练套路

目录

一、构建神经网络

二、数据准备

三、损失函数和优化器

四、训练模型

五、保存模型


一、构建神经网络

首先,需要构建一个神经网络模型。在示例代码中,构建了一个名为Tudui的卷积神经网络(CNN)模型。这个模型包括卷积层、池化层和全连接层,用于处理图像分类任务。

class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init()self.mode1 = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self, x):x = self.mode1(x)return x

二、数据准备

训练深度学习模型需要数据集。在示例中,使用CIFAR-10数据集作为示例数据。数据集的准备包括下载、预处理和分割成训练集和测试集。

import torch
import torchvision
from torch.utils.data import DataLoader# 准备数据集
train_data = torchvision.datasets.CIFAR10(root="D:\\Python_Project\\pytorch\\dataset2", train=True, transform=torchvision.transforms.ToTensor(), download=True)
test_data = torchvision.datasets.CIFAR10(root="D:\\Python_Project\\pytorch\\dataset2", train=False, transform=torchvision.transforms.ToTensor(), download=True)# 创建数据加载器
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)train_data_size = len(train_data)
test_data_size = len(test_data)

三、损失函数和优化器

在训练中,需要定义损失函数和优化器。损失函数用于度量模型的输出与真实标签之间的差距,而优化器用于更新模型的参数以减小损失。

loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

四、训练模型

模型训练分为多轮迭代,每轮包括训练和测试步骤。在训练步骤中,通过反向传播算法更新模型参数,以最小化损失函数。在测试步骤中,用测试集验证模型性能。

for epoch in range(10):  # 训练的轮数tudui.train()for data in train_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()tudui.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)total_test_loss += loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy / test_data_size))

五、保存模型

最后,可以保存训练好的模型,以备后续使用。示例代码展示了两种保存模型的方式,包括保存整个模型和仅保存模型参数。

# 保存方式一
torch.save(tudui, "tudui_{}.pth".format(epoch))
# 保存方式二(官方推荐)
# torch.save(tudui.state_dict(), 'tudui_{}.pth'.format(epoch))

完整代码如下:

import torch
from torch import nn# 搭建神经网络
class Tudui(nn.Module):def __init__(self):super(Tudui,self).__init__()self.mode1 = nn.Sequential(nn.Conv2d(3,32,5,1,2),nn.MaxPool2d(2),nn.Conv2d(32,32,5,1,2),nn.MaxPool2d(2),nn.Conv2d(32,64,5,1,2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4,64),nn.Linear(64,10))def forward(self, x):x = self.mode1(x)return xif __name__ == '__main__':tudui = Tudui()input = torch.ones((64,3,32,32))output = tudui(input)print(output.shape)
import torch
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from P27_model import *
import time# 准备数据集
train_data = torchvision.datasets.CIFAR10(root="D:\\Python_Project\\pytorch\\dataset2",train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="D:\\Python_Project\\pytorch\\dataset2",train=False,transform=torchvision.transforms.ToTensor(),download=True)# length 长度
train_data_size = len(train_data)
test_data_size = len(test_data)# 如果train_data_size=10,训练数据集的长度为:10
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))# 利用DataLoader 来加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)# 创建网络模型
tudui = Tudui()# 损失函数
loss_fn = nn.CrossEntropyLoss()# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(),lr=learning_rate)# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0# 训练的轮数
epoch = 10# 添加tensorboard
writer = SummaryWriter("logs_train")
# 添加开始时间
strat_time = time.time()for i in range(epoch):print("----------第{}轮训练开始----------".format(i+1))# 训练步骤开始tudui.train()  # 这两个层,只对一部分层起作用,比如 dropout层;如果有这些特殊的层,才需要调用这个语句for data in train_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)# 优化器优化模型optimizer.zero_grad() # 优化器,梯度清零loss.backward()optimizer.step()total_train_step = total_train_step + 1if total_train_step % 100 == 0:end_time = time.time()  # 结束时间print(end_time - strat_time)print("训练次数:{}, Loss:{}".format(total_train_step, loss.item()))   # 这里用到的 item()方法,有说法的,其实加不加都行,就是输出的形式不一样而已writer.add_scalar("train_loss", loss.item(),total_train_step)# 每训练完一轮,进行测试,在测试集上测试,以测试集的损失或者正确率,来评估有没有训练好,测试时,就不要调优了,就是以当前的模型,进行测试,所以不用再使用梯度(with no_grad 那句)# 测试步骤开始tudui.eval()  # 这两个层,只对一部分层起作用,比如 dropout层;如果有这些特殊的层,才需要调用这个语句total_test_loss = 0total_accuracy = 0with torch.no_grad():     # 这样后面就没有梯度了,  测试的过程中,不需要更新参数,所以不需要梯度?for data in test_dataloader: # 在测试集中,选取数据imgs, targets = dataoutputs = tudui(imgs)   # 分类的问题,是可以这样的,用一个output进行绘制loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item()     # 为了查看总体数据上的 loss,创建的 total_test_loss,初始值是0accuracy = (outputs.argmax(1) == targets).sum()  # 正确率,这是分类问题中,特有的一种,评价指标,语义分割之类的,不一定非要有这个东西,这里是存疑的,再看。total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy / test_data_size))   # 即便是输出了上一行的 loss,也不能很好的表现出效果。# 在分类问题上比较特有,通常使用正确率来表示优劣。因为其他问题,可以可视化地显示在tensorboard中。writer.add_scalar("test_loss", total_test_loss, total_test_step)writer.add_scalar("test_accuracy", total_accuracy / test_data_size, total_test_step)total_test_step = total_test_step + 1# print(total_test_step)# 保存方式一,其实后缀都可以自己取,习惯用 .pth。torch.save(tudui, "tudui_{}.pth".format(i))# 保存方式2(官方推荐)# torch.save(model.state_dict(), pth_dir + '/model_{}.pth'.format(i)print("模型已保存")writer.close()

参考资料:

视频教程:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/135066.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redis系列之常见数据类型应用场景

文章目录 String简单介绍常见命令应用场景 Hash简单介绍常见命令应用场景 List简单介绍常见命令应用场景 Set简单介绍常见命令应用场景 Sorted Set(Zset)简单介绍常见命令应用场景 Bitmap简单介绍常见命令应用场景 附录 Redis支持多种数据类型,比如String、hash、li…

Midway.js打通WebSocket前后端监听通道

您好, 如果喜欢我的文章或者想上岸大厂,可以关注公众号「量子前端」,将不定期关注推送前端好文、分享就业资料秘籍,也希望有机会一对一帮助你实现梦想 前言 WebSocket协议允许客户端和服务端持久化连接,这种可以持续…

MAC设备(M1)环境下编译安装openCV for Java

最近发现一个需求,可以用openCV来实现,碰巧又新买了mac笔记本,就打算利用业余时间安装下openCV。这里将主要步骤记录下,希望能帮助有需要的人。 1、准备编译环境 #查询编译opencv相关依赖 brew info opencv查询结果如下图所示&a…

docker容器中运行jar 出现invalid or corrupt jarfile

1,背景: 在本地java开发完毕之后,想要打包成docker镜像,方便安装。由于本地没有docker环境,也懒得装了。有一台测试的linux机器可以使用,所以先在本地打包生成xxx.jar,然后拷贝到有docker环境的…

openGauss学习笔记-117 openGauss 数据库管理-设置数据库审计-查看审计结果

文章目录 openGauss学习笔记-117 openGauss 数据库管理-设置数据库审计-查看审计结果117.1 前提条件117.2 背景信息117.3 操作步骤 openGauss学习笔记-117 openGauss 数据库管理-设置数据库审计-查看审计结果 117.1 前提条件 审计功能总开关已开启。需要审计的审计项开关已开…

语言模型AI——聊聊GPT使用情形与影响

GPT的出现象征着人工智能自然语言处理技术的一次巨大飞跃。从编程助手到写作利器,它的身影在各个行业中越来越常见。百度【文心一言】、CSDN【C知道】等基于GPT的产品相继推出,让我们看到了其广泛的应用前景。然而,随着GPT的普及,…

upload-labs-1

文章目录 Pass-01 Pass-01 先上传一个正常的图片&#xff0c;查看返回结果&#xff0c;结果中带有文件上传路径&#xff0c;可以进行利用&#xff1a; 上传一个恶意的webshell&#xff0c;里面写入一句话木马&#xff1a; <?php eval($_POST[cmd]); echo "hello&quo…

【漏洞复现】Apache_HTTP_2.4.50_路径穿越漏洞(CVE-2021-42013)

感谢互联网提供分享知识与智慧&#xff0c;在法治的社会里&#xff0c;请遵守有关法律法规 文章目录 1.1、漏洞描述1.2、漏洞等级1.3、影响版本1.4、漏洞复现1、基础环境2、漏洞扫描3、漏洞验证方式一 curl方式二 bp抓捕 1.5、修复建议 说明内容漏洞编号CVE-2021-42013漏洞名称…

图像对比方法介绍及实现

图像对比方法介绍及实现 1.引言 图像对比是在计算机视觉和图像处理中常见的任务之一。它可以用于识别重复图片、图像搜索、图像相似性比较等应用场景。实现图片对比方法的方法有多种&#xff0c;根据不同的需求和图片类型&#xff0c;可以选择适合的实现方案。如果对于简单的…

【Node.js入门】1.1Node.js 简介

Node.js入门之—1.1Node.js 简介 文章目录 Node.js入门之—1.1Node.js 简介什么是 Node.js错误说法 Node.js 的特点跨平台三方类库自带http服务器非阻塞I/O事件驱动单线程 Node.js 的应用场合适合用Node.js的场合不适合用Node.js的场合弥补Node.js不足的解决方案 什么是 Node.j…

Java连接数据库并查询表中的全部数据

1、导入相关jar包 这里创建简单的maven项目&#xff0c;我们导入相关的jar包 相关依赖&#xff1a; <dependency><groupId>mysql</groupId><artifactId>mysql-connector-java</artifactId><version>5.1.47</version></dependenc…

React进阶之路(一)-- JSX基础、组件基础

文章目录 React介绍React开发环境搭建项目目录说明以及相关调整 JSX基础JSX介绍JSX中使用js表达式JSX列表渲染JSX条件渲染JSX样式处理JSX注意事项 组件基础组件的概念函数组件类组件事件绑定如何绑定事件获取事件对象传递额外参数 组件状态状态不可变表单处理受控表单组件非受控…

menuTreeRef.value?.getCheckedKeys(true) as string[]

问: menuTreeRef.value?.getCheckedKeys(true) as string[]的as string[]什么意思? 回答: 举个例子:

Hive从入门到大牛【Hive 学习笔记】

文章目录 什么是HiveHive的数据存储Hive的系统架构MetastoreHive VS Mysql数据库 VS 数据仓库 Hive安装部署Hive的使用方式命令行方式JDBC方式 Set命令的使用Hive的日志配置Hive中数据库的操作Hive中表的操作 Hive中的数据类型基本数据类型复合数据类型ArrayMapStructStruct和M…

【Leetcode】【每日一题】【简单】2609. 最长平衡子字符串

力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台备战技术面试&#xff1f;力扣提供海量技术面试资源&#xff0c;帮助你高效提升编程技能&#xff0c;轻松拿下世界 IT 名企 Dream Offer。https://leetcode.cn/problems/find-the-longest-balanced-subs…

shell的for循环

列表for循环 列表for循环的语法结构如下: for variablein list #每一次循环&#xff0c;依次把列表list 中的一个值赋给循环变量 do #循环体开始的标志commands #循环变量每取一次值&#xff0c;循环体就执行一遍commands done #循环结束的标志&#xff0c;返回循环顶…

PCA9535模块移植

在虚拟机环境里面找到内核文件 更改需要的信息 比如内核设备名称与设备树的名称是否一样 如有需要添加的应用程序 也需要添加进去 根据实际情况来 更改设备名称 还有注意的 比如中断号 根据硬件信息本次中断号为32 所以所有的设备树文件中断号都改为32 现在准备编写驱动文…

TensorFlow学习笔记--(1)张量的随机生成

张量的生成 如何判断一个张量的维数&#xff1a;看张量的中括号有几层 0 1 2 &#xff1a;零维数列 [2 4 6] : 一维向量 [ [1 2 3] [4 5 6] ] : 二维数组 两行三列 第一行数据为 1 2 3 第二行数据为 4 5 6 以此类推 n维张量有n层中括号 tf.zeros(%指定一个张量的维数%) 生成一…

合肥工业大学数据库实验报告

✅作者简介:CSDN内容合伙人、信息安全专业在校大学生🏆 🔥系列专栏 :hfut实验课设 📃新人博主 :欢迎点赞收藏关注,会回访! 💬舞台再大,你不上台,永远是个观众。平台再好,你不参与,永远是局外人。能力再大,你不行动,只能看别人成功!没有人会关心你付出过多少…

如何将 ONLYOFFICE 文档 7.5 与 Odoo 进行集成

在本教程中&#xff0c;我们将了解如何使用集成应用实现 ONLYOFFICE 文档与 Odoo 之间的连接。 ONLYOFFICE 文档是什么 ONLYOFFICE 文档是一款全面的在线办公工具&#xff0c;提供了文本文档、电子表格和演示文稿的查看和编辑功能。它高度兼容微软 Office 格式&#xff0c;包括…