学习pytorch18 pytorch完整的模型训练流程

pytorch完整的模型训练流程

  • 1. 流程
    • 1. 整理训练数据 使用CIFAR10数据集
    • 2. 搭建网络结构
    • 3. 构建损失函数
    • 4. 使用优化器
    • 5. 训练模型
    • 6. 测试数据 计算模型预测正确率
    • 7. 保存模型
  • 2. 代码
    • 1. model.py
    • 2. train.py
  • 3. 结果
    • tensorboard结果
      • 以下图片 颜色较浅的线是真实计算的值,颜色较深的线是做了平滑处理的值
      • 训练loss
      • 测试loss
      • 测试集正确率
  • 4. 需要注意的细节

1. 流程

1. 整理训练数据 使用CIFAR10数据集

train_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),download=True)

2. 搭建网络结构

在这里插入图片描述
model.py

3. 构建损失函数

loss_fn = nn.CrossEntropyLoss()

4. 使用优化器

learing_rate = 1e-2 # 0.01
optimizer = torch.optim.SGD(net.parameters(), lr=learing_rate)

5. 训练模型

output = net(imgs)    # 数据输入模型
loss = loss_fn(output, targets)  # 损失函数计算损失 看计算的输出和真实的标签误差是多少
# 优化器开始优化模型  1.梯度清零  2.反向传播  3.参数优化
optimizer.zero_grad()  # 利用优化器把梯度清零 全部设置为0
loss.backward()        # 设置计算的损失值的钩子,调用损失的反向传播,计算每个参数结点的参数
optimizer.step()       # 调用优化器的step()方法 对其中的参数进行优化  

6. 测试数据 计算模型预测正确率

output = net(imags)
# 计算测试集的正确率
preds = (output.argmax(1)==targets).sum()
accuracy += preds 
rate = accuracy/len(test_data)

调用模型输出tensor 数据类型的 argmax方法, argmax或获取一行或者一列数值中最大数值的下标位置,argmax(0) 是从列的维度取一列数值的最大值的下标,argmax(1) 是从行的维度取一行数值的最大值的下标
output.argmax(1)==targets 会输出如下图最后一行 [false, ture], 对应位置相同则为true,对应位置不同则为false;
调用sum()方法,计算求和,false值为0,true值为1.
最后计算得出测试集整体正确率: rate = accuracy/len(test_data)
在这里插入图片描述

7. 保存模型

torch.save(net, './net_epoch{}.pth'.format(i))

2. 代码

1. model.py

import torch
from torch import nn# 2. 搭建模型网络结构--神经网络
class Cifar10Net(nn.Module):def __init__(self):super(Cifar10Net, self).__init__()self.net = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(kernel_size=2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self, x):x = self.net(x)return xif __name__ == '__main__':net = Cifar10Net()input = torch.ones((64, 3, 32, 32))output = net(input)print(output.shape)

2. train.py

import torch
import torchvision
from torch import nn
from torch.utils.tensorboard import SummaryWriterfrom p24_model import *# 1. 准备数据集
# 训练数据
from torch.utils.data import DataLoadertrain_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),download=True)
# 测试数据
test_data = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),download=True)# 查看数据大小--size
print("训练数据集大小:", len(train_data))
print("测试数据集大小:", len(test_data))
# 利用DataLoader来加载数据集
train_loader = DataLoader(dataset=train_data, batch_size=64)
test_loader = DataLoader(dataset=test_data, batch_size=64)# 2. 导入模型结构 创建模型
net = Cifar10Net()# 3. 创建损失函数  分类问题--交叉熵
loss_fn = nn.CrossEntropyLoss()# 4. 创建优化器
# learing_rate = 0.01
# 1e-2 = 1 * 10^(-2) = 0.01
learing_rate = 1e-2
print(learing_rate)
optimizer = torch.optim.SGD(net.parameters(), lr=learing_rate)# 设置训练网络的一些参数
epoch = 10   # 记录训练的轮数
total_train_step = 0  # 记录训练的次数
total_test_step = 0   # 记录测试的次数# 利用tensorboard显示训练loss趋势
writer = SummaryWriter('./train_logs')for i in range(epoch):# 训练步骤开始net.train()  # 可以加可以不加  只有当模型结构有 Dropout BatchNorml层才会起作用for data in train_loader:imgs, targets = data  # 获取数据output = net(imgs)    # 数据输入模型loss = loss_fn(output, targets)  # 损失函数计算损失 看计算的输出和真实的标签误差是多少# 优化器开始优化模型  1.梯度清零  2.反向传播  3.参数优化optimizer.zero_grad()  # 利用优化器把梯度清零 全部设置为0loss.backward()        # 设置计算的损失值,调用损失的反向传播,计算每个参数结点的参数optimizer.step()       # 调用优化器的step()方法 对其中的参数进行优化# 优化一次 认为训练了一次total_train_step += 1if total_train_step % 100 == 0:print('训练次数: {}   loss: {}'.format(total_train_step, loss))# 直接打印loss是tensor数据类型,打印loss.item()是打印的int或float真实数值, 真实数值方便做数据可视化【损失可视化】# print('训练次数: {}   loss: {}'.format(total_train_step, loss.item()))writer.add_scalar('train-loss', loss.item(), global_step=total_train_step)# 利用现有模型做模型测试# 测试步骤开始total_test_loss = 0accuracy = 0net.eval()  # 可以加可以不加  只有当模型结构有 Dropout BatchNorml层才会起作用with torch.no_grad():for data in test_loader:imags, targets = dataoutput = net(imags)loss = loss_fn(output, targets)total_test_loss += loss.item()# 计算测试集的正确率preds = (output.argmax(1)==targets).sum()accuracy += preds# writer.add_scalar('test-loss', total_test_loss, global_step=i+1)writer.add_scalar('test-loss', total_test_loss, global_step=total_test_step)writer.add_scalar('test-accracy', accuracy/len(test_data), total_test_step)total_test_step += 1print("---------test loss: {}--------------".format(total_test_loss))print("---------test accuracy: {}--------------".format(accuracy))# 保存每一个epoch训练得到的模型torch.save(net, './net_epoch{}.pth'.format(i))writer.close()

3. 结果

训练数据集大小: 50000
测试数据集大小: 10000
0.01
训练次数: 100   loss: 2.2905373573303223
训练次数: 200   loss: 2.2878968715667725
训练次数: 300   loss: 2.258394718170166
训练次数: 400   loss: 2.1968581676483154
训练次数: 500   loss: 2.0476632118225098
训练次数: 600   loss: 2.002145767211914
训练次数: 700   loss: 2.016021728515625
---------test loss: 316.382279753685--------------
训练次数: 800   loss: 1.8957302570343018
训练次数: 900   loss: 1.8659226894378662
训练次数: 1000   loss: 1.9004186391830444
训练次数: 1100   loss: 1.9708642959594727
......

tensorboard结果

安装tensorboard运行环境

pip install tensorboard
pip install opencv-python
pip install six
tensorboard --logdir=train_logs

以下图片 颜色较浅的线是真实计算的值,颜色较深的线是做了平滑处理的值

训练loss

在这里插入图片描述

测试loss

在这里插入图片描述

测试集正确率

在这里插入图片描述

4. 需要注意的细节

https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module

所有网络层继承于torch.nn.Module, net.train() net.eval() 在模型训练或测试之初 可以加可以不加 只有当模型结构有 Dropout BatchNorml层才会起作用,当模型有这两个网络层的时候,两个代码需要加上。
在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/208738.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

国产化软件突围!怿星科技eStation产品荣获2023铃轩奖“前瞻优秀奖”

11月11日,2023中国汽车供应链峰会暨第八届铃轩奖颁奖典礼在江苏省昆山市举行。怿星科技凭借eStation产品,荣获2023铃轩奖“前瞻智能座舱类优秀奖”,怿星CEO潘凯受邀出席铃轩奖晚会并代表领奖。 2023铃轩奖“前瞻智能座舱类优秀奖” 铃轩奖&a…

el-table 跨页多选

步骤一 在<el-table>中:row-key"getRowKeys"和selection-change"handleSelectionChange" 在<el-table-column>中type"selection"那列&#xff0c;添加:reserve-selection"true" <el-table:data"tableData"r…

队列排序:给定序列a,每次操作将a[1]移动到 从右往左第一个严格小于a[1]的元素的下一个位置,求能否使序列有序,若可以,求最少操作次数

题目 思路&#xff1a; 赛时代码&#xff08;先求右起最长有序区间长度&#xff0c;再求左边最小值是否小于等于右边有序区间左端点的数&#xff09; #include<bits/stdc.h> using namespace std; #define int long long const int maxn 1e6 5; int a[maxn]; int n; …

阿里云磁盘在线扩容

我们从阿里云的控制面板中给硬盘扩容后结果发现我们的磁盘空间并没有改变 注意&#xff1a;本次操作是针对CentOS 7的 &#xfeff;#使用df -h并没有发现我们的磁盘空间增加 #使用fdisk -l发现确实还有部分空间 运行df -h命令查看云盘分区大小。 以下示例返回分区&#xf…

python3安装redis

#!/usr/bin/python3import os import platform import argparse import shutil# 自定义变量 default_system "ubuntu" default_redis_version "6.2.6" default_install_path "/usr/local/redis" default_local_package_dir os.path.dirname(…

eve-ng镜像模拟设备-信息安全管理与评估-2023国赛

eve-ng镜像模拟设备-信息安全管理与评估-2023国赛 author&#xff1a;leadlife data&#xff1a;2023/12/4 mains&#xff1a;EVE-ng 模拟器 - 信息安全管理与评估模拟环境部署 references&#xff1a; EVE-ng 官网&#xff1a;https://www.eve-ng.net/EVE-ng 中文网&#xff1…

嵌入版python作为便携计算器(安装及配置ipython)

今天用别的电脑调试C&#xff0c;需要计算反三角函数时发现没有趁手工具&#xff0c;忽然想用python作为便携计算器放在U盘&#xff0c;遂想到嵌入版python 懒得自己配可以直接下载&#xff0c;使用方法见第4节 1&#xff0c;下载embeddable python&#xff08;嵌入版python&…

React中传入props.children后, 为什么会导致组件的重新渲染?

传入props.children后, 为什么会导致组件的重新渲染&#xff1f; 问题描述 在 react 中, 我想要对组件的渲染进行优化, 遇到了一个非常意思的问题, 当我向一个组件中传入了 props.children 之后, 每次父组件重新渲染都会导致这个组件的重新渲染; 它看起来的表现就像是被memo包…

【1day】​万户协同办公平台 convertFile 任意文件读取漏洞学习

注:该文章来自作者日常学习笔记,请勿利用文章内的相关技术从事非法测试,如因此产生的一切不良后果与作者无关。 目录 一、漏洞描述 二、影响版本 三、资产测绘 四、漏洞复现

图的邻接链表储存

喷了一节课 。。。。。。。、。 #include<stdio.h> #include<stdlib.h> #define MAXNUM 20 //每一个顶点的节点结构&#xff08;单链表&#xff09; typedef struct ANode{ int adjvex;//顶点指向的位置 struct ArcNode *next;//指向下一个顶点 …

C++ 内存分区模型

目录 程序运行前 代码区 全局区 程序运行后 new 在堆区开辟数据 delete释放堆区数据 堆区开辟数组 内存分区模型 栈&#xff08;Stack&#xff09; 堆&#xff08;Heap&#xff09; 全局/静态存储区&#xff08;Global/Static Storage&#xff09; 常量存储区&am…

力扣230. 二叉搜索树中第K小的元素

深度优先搜索 思路&#xff1a; 二叉搜索树的特性&#xff0c;通过中序遍历得到有序序列&#xff0c;则遍历到第K个节点的时候即为结果&#xff1b;使用栈通过深度优先遍历进行中序遍历&#xff1a; 先将节点和左子节点压栈&#xff1b;然后栈顶上就是“最左”叶子节点&#x…

Linux DAC权限的简单应用

Linux的DAC&#xff08;Discretionary Access Control&#xff09;权限模型是一种常见的访问控制机制&#xff0c;它用于管理文件和目录的访问权限。作为一名经验丰富的Linux系统安全工程师&#xff0c;我会尽可能以简单明了的方式向计算机小白介绍Linux DAC权限模型。 在Linu…

jenkins中“Jenkins Plot Plugin”的使用方法,比较两个excel的数据差异

Jenkins Plot Plugin是Jenkins的一个插件&#xff0c;它可以用于生成图表和报表&#xff0c;以便更好地理解和分析构建和测试数据。下面是使用Jenkins Plot Plugin比较两个Excel数据差异的步骤&#xff1a; 1.安装Jenkins Plot Plugin&#xff1a;在Jenkins的插件管理页面搜索…

使用 Axios 进行网络请求的全面指南

使用 Axios 进行网络请求的全面指南 本文将向您介绍如何使用 Axios 进行网络请求。通过分步指南和示例代码&#xff0c;您将学习如何使用 Axios 库在前端应用程序中发送 GET、POST、PUT 和 DELETE 请求&#xff0c;并处理响应数据和错误。 准备工作 在开始之前&#xff0c;请…

电子学会C/C++编程等级考试2021年09月(五级)真题解析

C/C++等级考试(1~8级)全部真题・点这里 第1题:抓牛 农夫知道一头牛的位置,想要抓住它。农夫和牛都位于数轴上,农夫起始位于点N(0<=N<=100000),牛位于点K(0<=K<=100000)。农夫有两种移动方式: 1、从X移动到X-1或X+1,每次移动花费一分钟 2、从X移动到2*X,每…

ubuntu18.04安装opencv-4.5.5+opencv_contrib-4.5.5

一、安装opencv依赖 sudo apt-get install build-essential sudo apt-get install cmake git libgtk2.0-dev pkg-config libavcodec-dev libavformat-dev libswscale-dev sudo apt-get install python-dev python-numpy libtbb2 libtbb-dev libjpeg-dev libpng-dev libtiff-d…

Navicat 技术指引 | 适用于 GaussDB 分布式的自动运行功能

Navicat Premium&#xff08;16.3.3 Windows 版或以上&#xff09;正式支持 GaussDB 分布式数据库。GaussDB 分布式模式更适合对系统可用性和数据处理能力要求较高的场景。Navicat 工具不仅提供可视化数据查看和编辑功能&#xff0c;还提供强大的高阶功能&#xff08;如模型、结…

「Python编程基础」第7章:字符串操作

文章目录 一、回顾二、新手容易踩坑的引号三、转义字符四、多行字符串写法五、注释六、字符串索引和切片七、字符串的in 和 not in八、字符串拼接九、转换大小写十、合并字符串join()十一、分割字符串split()十二、字符串替换 replace()十三、字符串内容判断方法十四、字符串内…

读文章摘录

20%的时间可以做点业余项目。有个叫克莱舍基的人&#xff0c;写了一本书&#xff0c;书名叫《认知盈余-网络时代的创造与繁荣》&#xff0c;他有个观点&#xff0c;闲暇时间给人机会创造有价值的东西。 很重要的一点是选合适的人&#xff0c;把他们引入团队。何谓合适的人&…