pytorch学习笔记(十二)

 以下代码是以CIFAR10这个10分类的图片数据集训练过程的完整的代码。

训练部分

train.py主要包含以下几个部件:

  • 准备训练、测试数据集
  • 用DateLoader加载两个数据集,要设置好batchsize
  • 创建网络模型(具体模型在model.py中)
  • 设置损失函数
  • 设置优化器,其中要包含优化的参数和学习率
  • 初始化一些参数,如训练测试的次数、以及训练的轮数epoch
  • 以训练轮数为循环进入训练
  • 从训练数据中加载数据,将数据(模型的输出和目标(标签))送进损失函数中计算损失
  • 梯度清零,并且反向传播损失函数,用优化器进行参数更新,并累计训练步数。
  • 在保证不调优的情况下看正确率(with)
    从测试集中拿数据,一样的讨论算损失,但是要算正确率
  • 用tensorboard可是话训练的结果

关于imgs, targets =data这句代码中的targets解释

  1. imgs (Images): 这个变量通常包含一批图像数据。在计算机视觉任务中,这些图像是模型的输入,可以是任何形式的视觉数据,比如照片、视频帧、医学影像等。在训练过程中,这些图像通过神经网络进行前向传播以生成预测结果。

  2. targets (Targets): 这个变量包含与 imgs 中每个图像对应的标签或目标。标签的具体形式取决于执行的任务:

    • 分类任务中,targets 可能是类别标签,例如识别图像中的对象(猫、狗、汽车等)。
    • 对象检测任务中,targets 可能包括对象的边界框(bounding boxes)和类别。
    • 语义分割任务中,targets 可能是每个像素的类别标签。
    • 回归任务中,targets 可能是一些连续值,如在面部关键点检测中的坐标点。

在训练过程中targets用于损失函数(交叉熵损失、均方误差等),这是模型学习并优化其参数的基础。损失函数衡量了模型预测和真实目标之间的差异,训练目标是最小化差异。

关于optimizer.step()的解释

在机器学习中,这玩意是个关键操作,就是用来根性模型参数的。

优化器和梯度下降,常用的优化算法(SGD、Adam、RMSprop等)来调整网络参数(如权重和偏差),以最小化损失函数。这个过程被称为梯度下降。

训练过程中的步骤:

  • Forward Pass:输入数据进行前向计算,生成预测。
  • 计算损失函数,比较网络的预测和真实计算损失
  • 反向传播:通过反向传播损失,计算每个参数梯度 loss.backward()来完成。
  • 更新参数optimizer.step()被调用来更新网络的参数。根据计算出的梯度和定义的优化算法,它会调整参数以减小损失。

注意: 

optimizer.step()根据优化器预定义的规则和计算出的梯度来更新模型参数。在调用它之后,会执行optimizer.zero_grad(),以便下一次迭代时从干净的状态开始。

import torch.nn
import torchvision
from torch.utils.tensorboard import SummaryWriterfrom model import *
from torch.utils.data import DataLoader
from torch import nn#准备数据集
train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
#测试数据集
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(),download=True)
#length 长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度:{}".format(train_data_size))
print("测试数据集的长度:{}".format(test_data_size))#利用DataLoader 来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)#创建网络模型
tudui = Tudui()#损失函数
loss_fn = nn.CrossEntropyLoss()#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)#设置训练网络的一些参数
#记录训练的次数
total_train_step = 0
#测试的次数
total_test_step = 0
#训练的轮数
epoch = 10#添加 tensorboard
writer = SummaryWriter("../logs_train")for i in range(epoch):print("-------------第{}轮训练开始-------------".format(i+1))#训练步骤开始#并不需要把网络设置成训练状态才能进行训练tudui.train()for data in train_dataloader:imgs, targets =dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)#梯度清零#优化器优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step = total_train_step + 1#避免无用信息覆盖if total_train_step % 100 == 0:print("训练次数: {},loss: {}".format(total_train_step, loss.item()))writer.add_scalar("train_loss", loss.item(), total_train_step)#测试步骤#也不是需要把网络设置成eval状态才能进行网络的一个测试tudui.eval()total_test_loss = 0#看正确率total_accuracy = 0#在with里面的代码没有了梯度,保证不会进行调优with torch.no_grad():for data in test_dataloader:imgs, targets =dataoutputs = tudui(imgs)#一部分数据在网络模型上的损失loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + lossaccuracy = (outputs.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))print("整体测试的正确率:{}".format(total_accuracy/test_data_size))writer.add_scalar("train_loss", loss.item(), total_test_step)writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)#测试的步骤+1否则图画不出来total_train_step = total_test_step + 1torch.save(tudui, "tudui_{}.pth".format(i))print("模型已保存")
writer.close()

上面是一个训练过程,下面介绍一下训练准确率怎么得来的。

假设有一个2分类的模型

Model(2分类)

#下面是得分

Outputs = [[0.2,0.3],[0.1,0.4]]

#通过Argmax 变成

Preds = [1]

                [1]

Inputs target=[0][1]

Preds==inputs target

#上面的这个式子返回的就是T or F

#加起来就是分类正确的个数了。

[false,true].sum()=1

                                       

这边注意一下output.argmax(x)的方向,x是0或是1,0的方向是竖着来的,1的方向是横着来的。

import torch
outputs = torch.tensor([[0.1,0.2],[0.3,0.4]])
print(outputs.argmax(1))
preds = outputs.argmax(1)
targets = torch.tensor([0,1])
print((preds == targets).sum())

-----------------------------------------------------未完待续1------------------------------------------------------------- 

 训练的一些细节:

如果有Dropout和BatchNorm等一些特殊层,需要

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/660736.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入了解C++:底层编译原理

进程的虚拟空间划分 任何编程语言,都会产生两样东西,指令和数据。 .exe程序运行的时候会从磁盘被加载到内存中,但是不能直接加载到物理内存中。Linux会给当前进程分配一块空间,比如x86 32位linux环境下会给进程分配2^32(4G)大小…

vue3页面跳转产生白屏,刷新后能正常展示的解决方案

可以依次检查以下问题: 1.是否在根组件标签最外层包含了个最大的div盒子包裹内容。 2.看看是否在template标签下面直接有注释,如果有需要把注释写到div里面。(即根标签下不要直接有注释) 3.在router-view 中给路由添加key标识。 …

(附源码)ssm 招聘信息管理系统-计算机毕设 78049

ssm 招聘客户管理系统 摘 要 由于数据库和数据仓库技术的快速发展,招聘客户管理系统建设越来越向模块化、智能化、自我服务和管理科学化的方向发展。招聘客户系统对处理对象和服务对象,自身的系统结构,处理能力,都将适应技术发展的…

脚本工具 mktemp 和 install

1.创建临时文件 mktemp 1.1 介绍 mktemp 命令用于创建并显示临时文件,可避免冲突 使用mktemp命令时,它会根据指定的模板在临时目录(默认为/tmp)中创建一个唯一的临时文件或目录,并返回该文件或目录的完整路径。临时…

在线UI设计工具有哪些?这5个包你满意

随着 UI 设计行业的蓬勃发展,越来越多的设计师进入 UI 设计,选择一款方便的 UI 设计工具尤为重要!除了传统的 UI 设计工具,在线 UI 设计工具也受到越来越多设计师的青睐。这种不受时间、地点和计算机配置限制的工作方法真的很刺激…

Python处理图片生成天际线(2024.1.29)

1、天际线简介 天际线(SkyLine)顾名思义就是天空与地面的边界线,人站在不同的高度,会看到不同的景色和地平线,天空与地面建筑物分离的标记线,不得不说,每天抬头仰望天空,相信大家都可…

屈子祠镇黑鱼岭,不可移动文物预防性保护系统守遗珍

一、何止秦俑 秦陵苑囿青铜水禽等文物集中展出 文物保护,尤其是不可移动文物的保护,一直都是文化遗产的重要环节。湖南省汨罗市屈子祠镇双楚村黑鱼岭墓地,作为长江中游地区的重大考古发现,其商朝晚期的历史背景赋予其不可估量的历…

响应式Web开发项目教程(HTML5+CSS3+Bootstrap)第2版 例5-6 绘制几何图形

代码 <!doctype html> <html> <head> <meta charset"utf-8"> <title>绘制几何图形</title> </head><body><canvas id"canvas" width"250" height"150" style"border: 1px b…

QUIC with CUBIC or BBR

拥塞控制 拥塞控制算法是 TCP/QUIC 协议的一个基础部分&#xff0c;多年来经过一个个版本的迭代&#xff08;如 Tahoe、Reno、Vegas 等&#xff09;&#xff0c;拥塞控制算法得到了持续的提升。由于篇幅有限&#xff0c;本文就目前比较流行的两种拥塞控制算法&#xff08;CUBI…

SVM支持向量机

1.基本概念 支持向量机&#xff08;Support Vector Machine&#xff0c;SVM&#xff09;是一种有监督学习方法&#xff0c;主要用于分类和回归分析。它的基本思想是在特征空间中找到一个超平面&#xff0c;能够将不同类别的样本分开&#xff0c;并且使得离这个超平面最近的样本…

sql注入,布尔盲注和时间盲注,无回显

布尔盲注 通过order by分组可以看到&#xff0c;如果正确会i显示you are in&#xff0c;错误则无任何提示&#xff0c;由此可以判断出&#xff0c;目前只显示对错&#xff0c;此外前端不会显示任何数据 也就是说&#xff0c;目前结果只有两种&#xff0c;在这种只有两种变量的…

Uniapp登录页面获取头像、昵称的最新方法的简单使用

前言 写小程序写到登录页面的时候&#xff0c;发现官方文档中原来的wx.getUserInfo和wx.getUserProfile不太能用了&#xff0c;学习了相对比较新的方法&#xff0c;这种方法的文档链接如下&#xff1a; https://developers.weixin.qq.com/miniprogram/dev/framework/open-abil…

上位机是什么?与下位机是什么关系

在工业自动化领域中&#xff0c;上位机是一项关键而引人注目的技术。许多人对上位机的概念感到好奇&#xff0c;想要深入了解其在工业智能中的作用。那么&#xff0c;上位机究竟是什么呢&#xff1f; 首先&#xff0c;上位机是一种用于工业控制系统的软件应用&#xff0c;通常…

配置支持 OpenAPI 的 ASP.NET Core 应用

写在前面 Swagger 是一个规范和完整的框架&#xff0c;用于生成、描述、调用和可视化 RESTful 风格的 Web 服务。 本文记录如何配置基于Swagger 的 ASP.NET Core 应用程序的 OpenAPI 规范。 需要从NuGet 安装 Swashbuckle.AspNetCore 包 代码实现 var builder WebApplicati…

STM32G4单片机

单片机的基本结构 CPU就是中央处理器&#xff0c;是单片机的内核 时钟电路&#xff0c;时钟源是给整个电路提供时序 其他的外设、中断以及存储器都是通过系统总线与CPU进行连接 RAM相当于电脑的内存条&#xff0c;随机存储器&#xff0c;掉电会丢失 ROM相当于电脑的硬盘&am…

美籍华裔力学和数学家林家翘

林家翘&#xff08;1916年7月7日—2013年1月13日&#xff09;&#xff0c;出生于北京&#xff0c;男&#xff0c;原籍福建福州&#xff0c;力学和数学家&#xff0c;美国艺术与科学院院士、美国国家科学院院士、中国科学院外籍院士&#xff0c;麻省理工学院荣誉退休教授 [1]。 …

初始MySQL

一 SQL的基本概述 基本概述 ▶SQL全称: Structured Query Language&#xff0c;是结构化查询语言&#xff0c;用于访问和处理数据库的标准的计算机语言。SQL语言1974年由Boyce和Chamberlin提出&#xff0c;并首先在IBM公司研制的关系数据库系统SystemR上实现。 ▶美国国家标…

【Linux】Linux基本指令

目录 1.ls指令 2.cd指令 3.touch指令 4.mkdir指令 5.rmdir指令和rm指令 5.1rmdir指令 5.2rm指令 6.man指令 7.cp指令 8.mv指令 9.cat指令 10.more指令 && less指令 10.1more指令 10.2less指令 11.head指令 && tail指令 11.1head指令 11.2tai…

服务网格(Service Mesh)流行工具

在这篇博客中&#xff0c;我们将介绍微服务的最佳服务网格工具列表&#xff0c;这些工具提供安全性、金丝雀部署、遥测、负载均衡等。 用于部署和操作微服务的服务网格工具的数量不断增加。在这篇文章中&#xff0c;我们将探讨您应该用来构建自己的服务网格架构的顶级服务网格…

【Eclipse平台】2 Eclipse Workbench工作台介绍

Eclipse Workbench工作台介绍 本文介绍Eclipse工作台Workbench。 当工作台启动时&#xff0c;首先看到的是一个对话框&#xff0c;该对话框允许我们选择工作区的位置。工作区是存储工作的目录。现在&#xff0c;只需单击“确定”即可选择默认位置。 选择工作区位置后&#x…