学习pytorch18 pytorch完整的模型训练流程

pytorch完整的模型训练流程

  • 1. 流程
    • 1. 整理训练数据 使用CIFAR10数据集
    • 2. 搭建网络结构
    • 3. 构建损失函数
    • 4. 使用优化器
    • 5. 训练模型
    • 6. 测试数据 计算模型预测正确率
    • 7. 保存模型
  • 2. 代码
    • 1. model.py
    • 2. train.py
  • 3. 结果
    • tensorboard结果
      • 以下图片 颜色较浅的线是真实计算的值,颜色较深的线是做了平滑处理的值
      • 训练loss
      • 测试loss
      • 测试集正确率
  • 4. 需要注意的细节

1. 流程

1. 整理训练数据 使用CIFAR10数据集

train_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),download=True)

2. 搭建网络结构

在这里插入图片描述
model.py

3. 构建损失函数

loss_fn = nn.CrossEntropyLoss()

4. 使用优化器

learing_rate = 1e-2 # 0.01
optimizer = torch.optim.SGD(net.parameters(), lr=learing_rate)

5. 训练模型

output = net(imgs)    # 数据输入模型
loss = loss_fn(output, targets)  # 损失函数计算损失 看计算的输出和真实的标签误差是多少
# 优化器开始优化模型  1.梯度清零  2.反向传播  3.参数优化
optimizer.zero_grad()  # 利用优化器把梯度清零 全部设置为0
loss.backward()        # 设置计算的损失值的钩子,调用损失的反向传播,计算每个参数结点的参数
optimizer.step()       # 调用优化器的step()方法 对其中的参数进行优化  

6. 测试数据 计算模型预测正确率

output = net(imags)
# 计算测试集的正确率
preds = (output.argmax(1)==targets).sum()
accuracy += preds 
rate = accuracy/len(test_data)

调用模型输出tensor 数据类型的 argmax方法, argmax或获取一行或者一列数值中最大数值的下标位置,argmax(0) 是从列的维度取一列数值的最大值的下标,argmax(1) 是从行的维度取一行数值的最大值的下标
output.argmax(1)==targets 会输出如下图最后一行 [false, ture], 对应位置相同则为true,对应位置不同则为false;
调用sum()方法,计算求和,false值为0,true值为1.
最后计算得出测试集整体正确率: rate = accuracy/len(test_data)
在这里插入图片描述

7. 保存模型

torch.save(net, './net_epoch{}.pth'.format(i))

2. 代码

1. model.py

import torch
from torch import nn# 2. 搭建模型网络结构--神经网络
class Cifar10Net(nn.Module):def __init__(self):super(Cifar10Net, self).__init__()self.net = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(kernel_size=2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self, x):x = self.net(x)return xif __name__ == '__main__':net = Cifar10Net()input = torch.ones((64, 3, 32, 32))output = net(input)print(output.shape)

2. train.py

import torch
import torchvision
from torch import nn
from torch.utils.tensorboard import SummaryWriterfrom p24_model import *# 1. 准备数据集
# 训练数据
from torch.utils.data import DataLoadertrain_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),download=True)
# 测试数据
test_data = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),download=True)# 查看数据大小--size
print("训练数据集大小:", len(train_data))
print("测试数据集大小:", len(test_data))
# 利用DataLoader来加载数据集
train_loader = DataLoader(dataset=train_data, batch_size=64)
test_loader = DataLoader(dataset=test_data, batch_size=64)# 2. 导入模型结构 创建模型
net = Cifar10Net()# 3. 创建损失函数  分类问题--交叉熵
loss_fn = nn.CrossEntropyLoss()# 4. 创建优化器
# learing_rate = 0.01
# 1e-2 = 1 * 10^(-2) = 0.01
learing_rate = 1e-2
print(learing_rate)
optimizer = torch.optim.SGD(net.parameters(), lr=learing_rate)# 设置训练网络的一些参数
epoch = 10   # 记录训练的轮数
total_train_step = 0  # 记录训练的次数
total_test_step = 0   # 记录测试的次数# 利用tensorboard显示训练loss趋势
writer = SummaryWriter('./train_logs')for i in range(epoch):# 训练步骤开始net.train()  # 可以加可以不加  只有当模型结构有 Dropout BatchNorml层才会起作用for data in train_loader:imgs, targets = data  # 获取数据output = net(imgs)    # 数据输入模型loss = loss_fn(output, targets)  # 损失函数计算损失 看计算的输出和真实的标签误差是多少# 优化器开始优化模型  1.梯度清零  2.反向传播  3.参数优化optimizer.zero_grad()  # 利用优化器把梯度清零 全部设置为0loss.backward()        # 设置计算的损失值,调用损失的反向传播,计算每个参数结点的参数optimizer.step()       # 调用优化器的step()方法 对其中的参数进行优化# 优化一次 认为训练了一次total_train_step += 1if total_train_step % 100 == 0:print('训练次数: {}   loss: {}'.format(total_train_step, loss))# 直接打印loss是tensor数据类型,打印loss.item()是打印的int或float真实数值, 真实数值方便做数据可视化【损失可视化】# print('训练次数: {}   loss: {}'.format(total_train_step, loss.item()))writer.add_scalar('train-loss', loss.item(), global_step=total_train_step)# 利用现有模型做模型测试# 测试步骤开始total_test_loss = 0accuracy = 0net.eval()  # 可以加可以不加  只有当模型结构有 Dropout BatchNorml层才会起作用with torch.no_grad():for data in test_loader:imags, targets = dataoutput = net(imags)loss = loss_fn(output, targets)total_test_loss += loss.item()# 计算测试集的正确率preds = (output.argmax(1)==targets).sum()accuracy += preds# writer.add_scalar('test-loss', total_test_loss, global_step=i+1)writer.add_scalar('test-loss', total_test_loss, global_step=total_test_step)writer.add_scalar('test-accracy', accuracy/len(test_data), total_test_step)total_test_step += 1print("---------test loss: {}--------------".format(total_test_loss))print("---------test accuracy: {}--------------".format(accuracy))# 保存每一个epoch训练得到的模型torch.save(net, './net_epoch{}.pth'.format(i))writer.close()

3. 结果

训练数据集大小: 50000
测试数据集大小: 10000
0.01
训练次数: 100   loss: 2.2905373573303223
训练次数: 200   loss: 2.2878968715667725
训练次数: 300   loss: 2.258394718170166
训练次数: 400   loss: 2.1968581676483154
训练次数: 500   loss: 2.0476632118225098
训练次数: 600   loss: 2.002145767211914
训练次数: 700   loss: 2.016021728515625
---------test loss: 316.382279753685--------------
训练次数: 800   loss: 1.8957302570343018
训练次数: 900   loss: 1.8659226894378662
训练次数: 1000   loss: 1.9004186391830444
训练次数: 1100   loss: 1.9708642959594727
......

tensorboard结果

安装tensorboard运行环境

pip install tensorboard
pip install opencv-python
pip install six
tensorboard --logdir=train_logs

以下图片 颜色较浅的线是真实计算的值,颜色较深的线是做了平滑处理的值

训练loss

在这里插入图片描述

测试loss

在这里插入图片描述

测试集正确率

在这里插入图片描述

4. 需要注意的细节

https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module

所有网络层继承于torch.nn.Module, net.train() net.eval() 在模型训练或测试之初 可以加可以不加 只有当模型结构有 Dropout BatchNorml层才会起作用,当模型有这两个网络层的时候,两个代码需要加上。
在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/208738.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

国产化软件突围!怿星科技eStation产品荣获2023铃轩奖“前瞻优秀奖”

11月11日,2023中国汽车供应链峰会暨第八届铃轩奖颁奖典礼在江苏省昆山市举行。怿星科技凭借eStation产品,荣获2023铃轩奖“前瞻智能座舱类优秀奖”,怿星CEO潘凯受邀出席铃轩奖晚会并代表领奖。 2023铃轩奖“前瞻智能座舱类优秀奖” 铃轩奖&a…

el-table 跨页多选

步骤一 在<el-table>中:row-key"getRowKeys"和selection-change"handleSelectionChange" 在<el-table-column>中type"selection"那列&#xff0c;添加:reserve-selection"true" <el-table:data"tableData"r…

队列排序:给定序列a,每次操作将a[1]移动到 从右往左第一个严格小于a[1]的元素的下一个位置,求能否使序列有序,若可以,求最少操作次数

题目 思路&#xff1a; 赛时代码&#xff08;先求右起最长有序区间长度&#xff0c;再求左边最小值是否小于等于右边有序区间左端点的数&#xff09; #include<bits/stdc.h> using namespace std; #define int long long const int maxn 1e6 5; int a[maxn]; int n; …

阿里云磁盘在线扩容

我们从阿里云的控制面板中给硬盘扩容后结果发现我们的磁盘空间并没有改变 注意&#xff1a;本次操作是针对CentOS 7的 &#xfeff;#使用df -h并没有发现我们的磁盘空间增加 #使用fdisk -l发现确实还有部分空间 运行df -h命令查看云盘分区大小。 以下示例返回分区&#xf…

eve-ng镜像模拟设备-信息安全管理与评估-2023国赛

eve-ng镜像模拟设备-信息安全管理与评估-2023国赛 author&#xff1a;leadlife data&#xff1a;2023/12/4 mains&#xff1a;EVE-ng 模拟器 - 信息安全管理与评估模拟环境部署 references&#xff1a; EVE-ng 官网&#xff1a;https://www.eve-ng.net/EVE-ng 中文网&#xff1…

嵌入版python作为便携计算器(安装及配置ipython)

今天用别的电脑调试C&#xff0c;需要计算反三角函数时发现没有趁手工具&#xff0c;忽然想用python作为便携计算器放在U盘&#xff0c;遂想到嵌入版python 懒得自己配可以直接下载&#xff0c;使用方法见第4节 1&#xff0c;下载embeddable python&#xff08;嵌入版python&…

图的邻接链表储存

喷了一节课 。。。。。。。、。 #include<stdio.h> #include<stdlib.h> #define MAXNUM 20 //每一个顶点的节点结构&#xff08;单链表&#xff09; typedef struct ANode{ int adjvex;//顶点指向的位置 struct ArcNode *next;//指向下一个顶点 …

C++ 内存分区模型

目录 程序运行前 代码区 全局区 程序运行后 new 在堆区开辟数据 delete释放堆区数据 堆区开辟数组 内存分区模型 栈&#xff08;Stack&#xff09; 堆&#xff08;Heap&#xff09; 全局/静态存储区&#xff08;Global/Static Storage&#xff09; 常量存储区&am…

使用 Axios 进行网络请求的全面指南

使用 Axios 进行网络请求的全面指南 本文将向您介绍如何使用 Axios 进行网络请求。通过分步指南和示例代码&#xff0c;您将学习如何使用 Axios 库在前端应用程序中发送 GET、POST、PUT 和 DELETE 请求&#xff0c;并处理响应数据和错误。 准备工作 在开始之前&#xff0c;请…

电子学会C/C++编程等级考试2021年09月(五级)真题解析

C/C++等级考试(1~8级)全部真题・点这里 第1题:抓牛 农夫知道一头牛的位置,想要抓住它。农夫和牛都位于数轴上,农夫起始位于点N(0<=N<=100000),牛位于点K(0<=K<=100000)。农夫有两种移动方式: 1、从X移动到X-1或X+1,每次移动花费一分钟 2、从X移动到2*X,每…

Navicat 技术指引 | 适用于 GaussDB 分布式的自动运行功能

Navicat Premium&#xff08;16.3.3 Windows 版或以上&#xff09;正式支持 GaussDB 分布式数据库。GaussDB 分布式模式更适合对系统可用性和数据处理能力要求较高的场景。Navicat 工具不仅提供可视化数据查看和编辑功能&#xff0c;还提供强大的高阶功能&#xff08;如模型、结…

「Python编程基础」第7章:字符串操作

文章目录 一、回顾二、新手容易踩坑的引号三、转义字符四、多行字符串写法五、注释六、字符串索引和切片七、字符串的in 和 not in八、字符串拼接九、转换大小写十、合并字符串join()十一、分割字符串split()十二、字符串替换 replace()十三、字符串内容判断方法十四、字符串内…

Qt使用Cryptopp生成HMAC-MD5

近期项目中HTTPS通讯中&#xff0c;token需要使用HMAC-MD5算法生成&#xff0c;往上找了一些资料后&#xff0c;仍不能满足自身需求&#xff0c;故次一记。 前期准备&#xff1a; ①下载Cryptopp库&#xff08;我下载的是8.8.0 Release版本&#xff09;&#xff1a;Crypto Li…

基于ResNet模型的908种超大规模中草药图像识别系统

中草药药材图像识别相关的实践在前文中已有对应的实践了&#xff0c;感兴趣的话可以自行移步阅读即可&#xff1a; 《python基于轻量级GhostNet模型开发构建23种常见中草药图像识别系统》 《基于轻量级MnasNet模型开发构建40种常见中草药图像识别系统》 在上一篇文章中&…

RocketMQ-RocketMQ高性能核心原理(流程图)

1.NamesrvStartup 2.BrokerStartup 3. DefualtMQProducer 4.DefaultMQPushConsumer

maven工程的pom.xml文件中增加了依赖,但偶尔没有下载到本地仓库

maven工程pom.xml文件中的个别依赖没有下载到本地maven仓库。以前没有遇到这种情况&#xff0c;今天就遇到了这个问题&#xff0c;把解决过程记录下来。 我在eclipse中编辑maven工程的pom.xml文件&#xff0c;增加对mybatis的依赖&#xff0c;但保存文件后&#xff0c;依赖的j…

Java--1v1双向通信-控制台版

文章目录 前言客户端服务器端输出线程端End 前言 TCP&#xff08;Transmission Control Protocol&#xff09;是一种面向连接的、可靠的网络传输协议&#xff0c;它提供了端到端的数据传输和可靠性保证。 本程序就是基于tcp协议编写而成的。 利用 TCP 协议进行通信的两个应用…

HarmonyOS(鸿蒙操作系统)与Android系统 各自特点 架构对比 各自优势

综合对比 HarmonyOS&#xff08;鸿蒙操作系统&#xff09;是由华为开发的操作系统&#xff0c;旨在跨多种设备和平台使用。HarmonyOS的架构与谷歌开发的广泛使用的Android操作系统有显著不同。以下是两者之间的一些主要比较点&#xff1a; 设计理念和使用案例&#xff1a; Harm…

go语言 grpc 拦截器

文章目录 拦截器服务端拦截器一元拦截器流拦截器 客户端拦截器一元拦截器流拦截 多个拦截器 代码仓库 拦截器 gRPC拦截器&#xff08;interceptor&#xff09;是一种函数&#xff0c;它可以在gRPC调用之前和之后执行一些逻辑&#xff0c;例如认证、授权、日志记录、监控和统计…

docker学习(四、修改容器创建新的镜像推送到云上)

镜像是只读的&#xff0c;容器是可编辑的。Docker镜像是分层的&#xff0c;支持通过扩展镜像&#xff0c;创建新的镜像。 学到这里感觉docker跟git很想~~ 通过docker commit将修改的容器做成新的镜像 # 将容器做成新的镜像 docker commit -m"提交备注" -a"作…