PyTorch学习笔记(十五)——完整的模型训练套路

以 CIFAR10 数据集为例,分类问题(10分类)

 

model.py

import torch
from torch import nn# 搭建神经网络
class MyNN(nn.Module):def __init__(self):super(MyNN, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return xif __name__ == '__main__':# 验证网络的正确性mynn = MyNN()input = torch.ones(64,3,32,32)output = mynn(input)print(output)

运行结果:torch.Size([64,10]) 

返回64行数据,每一行数据有10个数据,代表每一张图片在10个类别中的概率

train.py

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# model.py必须和train.py在同一个文件夹下
from model import *# 准备数据集(CIFAR10 数据集是PIL Image,要转换为tensor数据类型)
train_data = torchvision.datasets.CIFAR10(root="../datasets",train=True,transform=torchvision.transforms.ToTensor(),download=False)
test_data = torchvision.datasets.CIFAR10(root="../datasets",train=False,transform=torchvision.transforms.ToTensor(),download=False)# 获得数据集的长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))# 利用dataloader来加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)# 创建网络模型
mynn = MyNN()
# 损失函数
loss_function = nn.CrossEntropyLoss()
# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(mynn.parameters(), lr=learning_rate) # SGD 随机梯度下降# 设置训练网络的一些参数
total_train_step = 0 # 记录训练次数
total_test_step = 0 # 记录测试次数
epoch = 10 # 训练的轮数# 添加tensorboard
writer = SummaryWriter("../logs_train")for i in range(epoch):print("----------第{}轮训练开始----------".format(i+1))# 训练步骤开始mynn.train()for data in train_dataloader:imgs,targets = dataoutputs = mynn(imgs)loss = loss_function(outputs, targets)# 优化器优化模型optimizer.zero_grad() # 首先要梯度清零loss.backward() # 反向传播得到每一个参数节点的梯度optimizer.step() # 对参数进行优化total_train_step += 1# 训练步骤逢百才打印记录if total_train_step % 100 == 0:print("训练次数:{},loss:{}".format(total_train_step, loss.item()))writer.add_scalar("train_loss",loss.item(),total_train_step)# 测试步骤开始mynn.eval()total_test_loss = 0total_accuracy = 0# 无梯度,不进行调优with torch.no_grad():for data in test_dataloader:imgs,targets = dataoutputs = mynn(imgs)loss = loss_function(outputs, targets)total_test_loss += loss# 即便得到整体测试集上的 loss,也不能很好说明在测试集上的表现效果# 在分类问题中可以用正确率表示accuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracyprint("整体测试集上的loss:{}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))writer.add_scalar("test_loss",total_test_loss,total_test_step)writer.add_scalar("test_accuracy",total_accuracy/test_data_size,total_test_step)total_test_step += 1# 保存每一轮训练的模型torch.save(mynn,"mynn_{}.pth".format(i))# torch.save(mynn.state_dict(),"mynn_{}.pth".format(i))print("模型已保存")writer.close()

 

 关于正确率的计算:

方式1:

import torchoutputs = torch.tensor([[0.1,0.2],[0.3,0.4]])
target = torch.tensor([0,1])predict = outputs.argmax(1)
print(predict)print(predict == target)
print((predict == target).sum())

 方式2:

import torchoutputs = torch.tensor([[0.1,0.2],[0.3,0.4]])
target = torch.tensor([0,1])predict = torch.max(outputs, dim=1)[1]
print(predict)print(torch.eq(predict,target))
print(torch.eq(predict,target).sum())
print(torch.eq(predict,target).sum().item())

关于mynn.train()和mynn.eval():

这两句不写网络依然可以运行,它们的作用是:

 

 

这个案例没有 Dropout 层或 BatchNorm 层,所以有没有这两行都无所谓。但如果有这些特殊层,一定要调用。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/45117.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VMware 虚拟机三种网络模式详解

文章目录 前言桥接模式(Bridged)桥接模式特点: 仅主机模式 (Host-only)仅主机模式 (Host-only)特点: NAT网络地址转换模式(NAT)网络地址转换模式(NAT 模式)特点: 前言 很多同学在初次接触虚拟机的时候对 VMware 产品的三种网络模式不是很理解,本文就 VMware 的三种网络模式进行…

【数据结构OJ题】合并两个有序链表

原题链接:https://leetcode.cn/problems/merge-two-sorted-lists/description/ 目录 1. 题目描述 2. 思路分析 3. 代码实现 1. 题目描述 2. 思路分析 可以先创建一个空链表,然后依次从两个有序链表中选取最小的进行尾插操作。(有点类似双…

Java-抽象类和接口(下)

接口使用实例 给对象数组排序 两个学生对象的大小关系怎么确定? 需要我们额外指定. 这里需要用到Comparable 接口 在Comparable 接口内部有一个compareTo 的方法,我们需要实现它 在下图中,我们需要将o强制转换为Student 之后调用Arrays.sort(array)即…

rabbitmq容器启动后修改连接密码

1、进入容器 docker exec -it rabbitmq bash 2、查看当前用户列表 rabbitmqctl list_users 3、修改密码 rabbitmqctl change_password [username] ‘[NewPassword]’ 4、修改后退出容器 ctrlpq 5、退出容器后即可生效,不需要重启容器

redis实战-缓存数据解决缓存与数据库数据一致性

缓存的定义 缓存(Cache),就是数据交换的缓冲区,俗称的缓存就是缓冲区内的数据,一般从数据库中获取,存储于本地代码。防止过高的数据访问猛冲系统,导致其操作线程无法及时处理信息而瘫痪,这在实际开发中对企业讲,对产品口碑,用户评价都是致命的;所以企业非常重视缓存…

Hi-TRS:骨架点视频序列的层级式建模及层级式自监督学习

论文题目:Hierarchically Self-Supervised Transformer for Human Skeleton Representation Learning 论文下载地址:https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136860181.pdf 代码地址:https://github.com/yuxiaochen1103…

Linux -- 进阶 利用大文件来增加分区 自动挂载大文件

情景引入 : 比如, 你的硬盘 分了三个区,但是,现在就是要求要分第四个区, 你一看硬盘没有剩余空 间了,分不出第四个区了,除非你再添加 一块儿 新硬盘。 那就可以使用我们介绍的这种方法 &…

并发-并发挑战及底层实现原理笔记

并发编程挑战 上下文切换 cpu通过给每个线程分配cpu时间片实现多线程执行,时间片是cpu分配给各个线程的时间,cpu通过不断切换线程执行。线程有创建和上下文切换的开销。减少上下文切换的方方法 – 无锁并发编程,eg:将数据的id按…

记录hutool http通过代理模式proxy访问外面的链接

效果: 代码: public class TestMain {public static void main(String[] args){HttpRequest httpRequest HttpRequest.get("https://www.youtube.com").timeout(30000);httpRequest.setProxy(new Proxy(Proxy.Type.HTTP,new InetSocketAddre…

Laravel 框架模型的定义 模型的增删改 批量赋值和软删除 ⑧

作者 : SYFStrive 博客首页 : HomePage 📜: THINK PHP 📌:个人社区(欢迎大佬们加入) 👉:社区链接🔗 📌:觉得文章不错可以点点关注 &#x1f44…

chatgpt汇编,两个关于chatgpt的问题

1.用C写一个高级的可自主学习的AI语言模型 编写一个高级的可自主学习的AI语言模型是一个复杂且庞大的工程,需要深入的专业知识和大量的时间精力。下面是一个简化的示例,演示如何使用C实现一个基本的可自主学习的AI语言模型。 首先,你需要定…

ffmpeg将图片压缩成视频过程中的问题记录

问题描述:服务器集群上在使用如下命令对img_path文件夹下的图片(图片名1.jpg……300.jpg)进行批量合成视频过程中,ffmpeg创建完视频文件后,写入过程中,自动kill,捕获异常为空。问题原因:集群pods的内存不够…

python读取word/pdf文档,指定文字内容和图片

读编号转文件夹目录然后放图片进去那个 一 先将word转为PDF pdf 读起来比较方便, 按页码读取文件: import pdfplumber from PIL import Image import cv2 import numpy as np import re import os import logging import iodef create_folder(folder_name):if not…

django sqlite3操作和manage.py功能介绍

参考链接:https://www.cnblogs.com/csd97/p/8432715.html manage.py 常用命令_python manage.py_追逐&梦想的博客-CSDN博客 python django操作sqlite3_django sqlite_浪子仙迹的博客-CSDN博客

linux 搭建 nexus maven私服

目录 环境: 下载 访问百度网盘链接 官网下载 部署 : 进入目录,创建文件夹,进入文件夹 将安装包放入nexus文件夹,并解压​编辑 启动 nexus,并查看状态.​编辑 更改 nexus 端口为7020,并重新启动,访问虚拟机7020…

SpringBoot + Vue 前后端分离项目 微人事(九)

职位管理后端接口设计 在controller包里面新建system包,再在system包里面新建basic包,再在basic包里面创建PositionController类,在定义PositionController类的接口的时候,一定要与数据库的menu中的url地址到一致,不然…

JavaScript(JavaEE初阶系列13)

目录 前言: 1.初识JavaScript 2.JavaScript的书写形式 2.1行内式 2.2内嵌式 2.3外部式 2.4注释 2.5输入输出 3.语法 3.1变量的使用 3.2基本数据类型 3.3运算符 3.4条件语句 3.5循环语句 3.6数组 3.7函数 3.8对象 3.8.1 对象的创建 4.案例演示 4…

【hive】hive修复分区或修复表 以及msck命令的使用

【hive】hive修复分区或修复表 以及msck命令的使用 文章目录 【hive】hive修复分区或修复表 以及msck命令的使用问题原因:解决方法:msck命令解析:例子: 问题原因: 之前hive里有数据,后面存储元数据信息的MySQL数据库坏…

rocketBot使用/Rpc调用监控

9 RocketBot使用 这里可以获取到比较详细的地方。可以通过追踪id的方式进行查询。只支持精准查询。[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-FvGvUo6l-1692459587743)(C:\Users\15870\AppData\Roaming\Typora\typora-user-images\image-202308…

Linux 系统编程拾遗

Linux 系统编程拾遗 进程的创建 进程的创建 fork()、exit()、wait()以及execve()的简介 创建新进程:fork()