PyTorch学习笔记(十五)——完整的模型训练套路

以 CIFAR10 数据集为例,分类问题(10分类)

 

model.py

import torch
from torch import nn# 搭建神经网络
class MyNN(nn.Module):def __init__(self):super(MyNN, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return xif __name__ == '__main__':# 验证网络的正确性mynn = MyNN()input = torch.ones(64,3,32,32)output = mynn(input)print(output)

运行结果:torch.Size([64,10]) 

返回64行数据,每一行数据有10个数据,代表每一张图片在10个类别中的概率

train.py

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# model.py必须和train.py在同一个文件夹下
from model import *# 准备数据集(CIFAR10 数据集是PIL Image,要转换为tensor数据类型)
train_data = torchvision.datasets.CIFAR10(root="../datasets",train=True,transform=torchvision.transforms.ToTensor(),download=False)
test_data = torchvision.datasets.CIFAR10(root="../datasets",train=False,transform=torchvision.transforms.ToTensor(),download=False)# 获得数据集的长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))# 利用dataloader来加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)# 创建网络模型
mynn = MyNN()
# 损失函数
loss_function = nn.CrossEntropyLoss()
# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(mynn.parameters(), lr=learning_rate) # SGD 随机梯度下降# 设置训练网络的一些参数
total_train_step = 0 # 记录训练次数
total_test_step = 0 # 记录测试次数
epoch = 10 # 训练的轮数# 添加tensorboard
writer = SummaryWriter("../logs_train")for i in range(epoch):print("----------第{}轮训练开始----------".format(i+1))# 训练步骤开始mynn.train()for data in train_dataloader:imgs,targets = dataoutputs = mynn(imgs)loss = loss_function(outputs, targets)# 优化器优化模型optimizer.zero_grad() # 首先要梯度清零loss.backward() # 反向传播得到每一个参数节点的梯度optimizer.step() # 对参数进行优化total_train_step += 1# 训练步骤逢百才打印记录if total_train_step % 100 == 0:print("训练次数:{},loss:{}".format(total_train_step, loss.item()))writer.add_scalar("train_loss",loss.item(),total_train_step)# 测试步骤开始mynn.eval()total_test_loss = 0total_accuracy = 0# 无梯度,不进行调优with torch.no_grad():for data in test_dataloader:imgs,targets = dataoutputs = mynn(imgs)loss = loss_function(outputs, targets)total_test_loss += loss# 即便得到整体测试集上的 loss,也不能很好说明在测试集上的表现效果# 在分类问题中可以用正确率表示accuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracyprint("整体测试集上的loss:{}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))writer.add_scalar("test_loss",total_test_loss,total_test_step)writer.add_scalar("test_accuracy",total_accuracy/test_data_size,total_test_step)total_test_step += 1# 保存每一轮训练的模型torch.save(mynn,"mynn_{}.pth".format(i))# torch.save(mynn.state_dict(),"mynn_{}.pth".format(i))print("模型已保存")writer.close()

 

 关于正确率的计算:

方式1:

import torchoutputs = torch.tensor([[0.1,0.2],[0.3,0.4]])
target = torch.tensor([0,1])predict = outputs.argmax(1)
print(predict)print(predict == target)
print((predict == target).sum())

 方式2:

import torchoutputs = torch.tensor([[0.1,0.2],[0.3,0.4]])
target = torch.tensor([0,1])predict = torch.max(outputs, dim=1)[1]
print(predict)print(torch.eq(predict,target))
print(torch.eq(predict,target).sum())
print(torch.eq(predict,target).sum().item())

关于mynn.train()和mynn.eval():

这两句不写网络依然可以运行,它们的作用是:

 

 

这个案例没有 Dropout 层或 BatchNorm 层,所以有没有这两行都无所谓。但如果有这些特殊层,一定要调用。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/45117.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VMware 虚拟机三种网络模式详解

文章目录 前言桥接模式(Bridged)桥接模式特点: 仅主机模式 (Host-only)仅主机模式 (Host-only)特点: NAT网络地址转换模式(NAT)网络地址转换模式(NAT 模式)特点: 前言 很多同学在初次接触虚拟机的时候对 VMware 产品的三种网络模式不是很理解,本文就 VMware 的三种网络模式进行…

【数据结构OJ题】合并两个有序链表

原题链接:https://leetcode.cn/problems/merge-two-sorted-lists/description/ 目录 1. 题目描述 2. 思路分析 3. 代码实现 1. 题目描述 2. 思路分析 可以先创建一个空链表,然后依次从两个有序链表中选取最小的进行尾插操作。(有点类似双…

Java-抽象类和接口(下)

接口使用实例 给对象数组排序 两个学生对象的大小关系怎么确定? 需要我们额外指定. 这里需要用到Comparable 接口 在Comparable 接口内部有一个compareTo 的方法,我们需要实现它 在下图中,我们需要将o强制转换为Student 之后调用Arrays.sort(array)即…

rabbitmq容器启动后修改连接密码

1、进入容器 docker exec -it rabbitmq bash 2、查看当前用户列表 rabbitmqctl list_users 3、修改密码 rabbitmqctl change_password [username] ‘[NewPassword]’ 4、修改后退出容器 ctrlpq 5、退出容器后即可生效,不需要重启容器

redis实战-缓存数据解决缓存与数据库数据一致性

缓存的定义 缓存(Cache),就是数据交换的缓冲区,俗称的缓存就是缓冲区内的数据,一般从数据库中获取,存储于本地代码。防止过高的数据访问猛冲系统,导致其操作线程无法及时处理信息而瘫痪,这在实际开发中对企业讲,对产品口碑,用户评价都是致命的;所以企业非常重视缓存…

Hi-TRS:骨架点视频序列的层级式建模及层级式自监督学习

论文题目:Hierarchically Self-Supervised Transformer for Human Skeleton Representation Learning 论文下载地址:https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136860181.pdf 代码地址:https://github.com/yuxiaochen1103…

Linux -- 进阶 利用大文件来增加分区 自动挂载大文件

情景引入 : 比如, 你的硬盘 分了三个区,但是,现在就是要求要分第四个区, 你一看硬盘没有剩余空 间了,分不出第四个区了,除非你再添加 一块儿 新硬盘。 那就可以使用我们介绍的这种方法 &…

并发-并发挑战及底层实现原理笔记

并发编程挑战 上下文切换 cpu通过给每个线程分配cpu时间片实现多线程执行,时间片是cpu分配给各个线程的时间,cpu通过不断切换线程执行。线程有创建和上下文切换的开销。减少上下文切换的方方法 – 无锁并发编程,eg:将数据的id按…

记录hutool http通过代理模式proxy访问外面的链接

效果: 代码: public class TestMain {public static void main(String[] args){HttpRequest httpRequest HttpRequest.get("https://www.youtube.com").timeout(30000);httpRequest.setProxy(new Proxy(Proxy.Type.HTTP,new InetSocketAddre…

Laravel 框架模型的定义 模型的增删改 批量赋值和软删除 ⑧

作者 : SYFStrive 博客首页 : HomePage 📜: THINK PHP 📌:个人社区(欢迎大佬们加入) 👉:社区链接🔗 📌:觉得文章不错可以点点关注 &#x1f44…

linux 搭建 nexus maven私服

目录 环境: 下载 访问百度网盘链接 官网下载 部署 : 进入目录,创建文件夹,进入文件夹 将安装包放入nexus文件夹,并解压​编辑 启动 nexus,并查看状态.​编辑 更改 nexus 端口为7020,并重新启动,访问虚拟机7020…

SpringBoot + Vue 前后端分离项目 微人事(九)

职位管理后端接口设计 在controller包里面新建system包,再在system包里面新建basic包,再在basic包里面创建PositionController类,在定义PositionController类的接口的时候,一定要与数据库的menu中的url地址到一致,不然…

JavaScript(JavaEE初阶系列13)

目录 前言: 1.初识JavaScript 2.JavaScript的书写形式 2.1行内式 2.2内嵌式 2.3外部式 2.4注释 2.5输入输出 3.语法 3.1变量的使用 3.2基本数据类型 3.3运算符 3.4条件语句 3.5循环语句 3.6数组 3.7函数 3.8对象 3.8.1 对象的创建 4.案例演示 4…

Linux 系统编程拾遗

Linux 系统编程拾遗 进程的创建 进程的创建 fork()、exit()、wait()以及execve()的简介 创建新进程:fork()

【ARM v8】如何在ARM上实现x86的rdtsc()函数

博主未授权任何人或组织机构转载博主任何原创文章,感谢各位对原创的支持! 博主链接 本人就职于国际知名终端厂商,负责modem芯片研发。 在5G早期负责终端数据业务层、核心网相关的开发工作,目前牵头6G算力网络技术标准研究。 博客…

LeetCode 热题 100(五):54. 螺旋矩阵、234. 回文链表、21. 合并两个有序链表

题目一: 54. 螺旋矩阵https://leetcode.cn/problems/spiral-matrix/ 题目要求: 思路:一定要先找好边界。如下图 ,上边界是1234,右边界是8、12,下边界是9、10、11,左边界是5,所以可…

滑块验证码-接口返回base64数据

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言所需包图片示例使用方法提示前言 滑动验证码在实际爬虫开发过程中会遇到很多,不同网站返回的数据也是千奇百怪。这里分享一种接口返回base64格式的情况以及处理方式 所需包 opencv-python、…

vue3 路由缓存问题

目录 解决问题的思路: 解决问题的方案: 1、给roter-view添加key(破坏复用机制,强制销毁重建) 2、使用beforeRouteUpdate导航钩子 3、使用watch监听路由 vue3路由缓存:当用户从/users/johnny导航到/use…

Linux网络编程:Socket套接字编程(Server服务器 Client客户端)

文章目录: 一:定义和流程分析 1.定义 2.流程分析 3.网络字节序 二:相关函数 IP地址转换函数inet_pton inet_ntop(本地字节序 网络字节序) socket函数(创建一个套接字) bind函数(给socket绑定一个服务器地址结…

Git概述

目录 一、什么是Git 二、什么是版本控制系统 三、Git和SVN对比 SVN集中式 SVN优缺点 Git分布式 Git优缺点 四、Git工作流程 四个工作区域 工作流程 五、Git下载与安装 一、什么是Git 很多人都知道,林纳斯托瓦兹在1991年创建了开源的Linux,从…