pytorch学习笔记(十二)

 以下代码是以CIFAR10这个10分类的图片数据集训练过程的完整的代码。

训练部分

train.py主要包含以下几个部件:

  • 准备训练、测试数据集
  • 用DateLoader加载两个数据集,要设置好batchsize
  • 创建网络模型(具体模型在model.py中)
  • 设置损失函数
  • 设置优化器,其中要包含优化的参数和学习率
  • 初始化一些参数,如训练测试的次数、以及训练的轮数epoch
  • 以训练轮数为循环进入训练
  • 从训练数据中加载数据,将数据(模型的输出和目标(标签))送进损失函数中计算损失
  • 梯度清零,并且反向传播损失函数,用优化器进行参数更新,并累计训练步数。
  • 在保证不调优的情况下看正确率(with)
    从测试集中拿数据,一样的讨论算损失,但是要算正确率
  • 用tensorboard可是话训练的结果

关于imgs, targets =data这句代码中的targets解释

  1. imgs (Images): 这个变量通常包含一批图像数据。在计算机视觉任务中,这些图像是模型的输入,可以是任何形式的视觉数据,比如照片、视频帧、医学影像等。在训练过程中,这些图像通过神经网络进行前向传播以生成预测结果。

  2. targets (Targets): 这个变量包含与 imgs 中每个图像对应的标签或目标。标签的具体形式取决于执行的任务:

    • 分类任务中,targets 可能是类别标签,例如识别图像中的对象(猫、狗、汽车等)。
    • 对象检测任务中,targets 可能包括对象的边界框(bounding boxes)和类别。
    • 语义分割任务中,targets 可能是每个像素的类别标签。
    • 回归任务中,targets 可能是一些连续值,如在面部关键点检测中的坐标点。

在训练过程中targets用于损失函数(交叉熵损失、均方误差等),这是模型学习并优化其参数的基础。损失函数衡量了模型预测和真实目标之间的差异,训练目标是最小化差异。

关于optimizer.step()的解释

在机器学习中,这玩意是个关键操作,就是用来根性模型参数的。

优化器和梯度下降,常用的优化算法(SGD、Adam、RMSprop等)来调整网络参数(如权重和偏差),以最小化损失函数。这个过程被称为梯度下降。

训练过程中的步骤:

  • Forward Pass:输入数据进行前向计算,生成预测。
  • 计算损失函数,比较网络的预测和真实计算损失
  • 反向传播:通过反向传播损失,计算每个参数梯度 loss.backward()来完成。
  • 更新参数optimizer.step()被调用来更新网络的参数。根据计算出的梯度和定义的优化算法,它会调整参数以减小损失。

注意: 

optimizer.step()根据优化器预定义的规则和计算出的梯度来更新模型参数。在调用它之后,会执行optimizer.zero_grad(),以便下一次迭代时从干净的状态开始。

import torch.nn
import torchvision
from torch.utils.tensorboard import SummaryWriterfrom model import *
from torch.utils.data import DataLoader
from torch import nn#准备数据集
train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
#测试数据集
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(),download=True)
#length 长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度:{}".format(train_data_size))
print("测试数据集的长度:{}".format(test_data_size))#利用DataLoader 来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)#创建网络模型
tudui = Tudui()#损失函数
loss_fn = nn.CrossEntropyLoss()#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)#设置训练网络的一些参数
#记录训练的次数
total_train_step = 0
#测试的次数
total_test_step = 0
#训练的轮数
epoch = 10#添加 tensorboard
writer = SummaryWriter("../logs_train")for i in range(epoch):print("-------------第{}轮训练开始-------------".format(i+1))#训练步骤开始#并不需要把网络设置成训练状态才能进行训练tudui.train()for data in train_dataloader:imgs, targets =dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)#梯度清零#优化器优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step = total_train_step + 1#避免无用信息覆盖if total_train_step % 100 == 0:print("训练次数: {},loss: {}".format(total_train_step, loss.item()))writer.add_scalar("train_loss", loss.item(), total_train_step)#测试步骤#也不是需要把网络设置成eval状态才能进行网络的一个测试tudui.eval()total_test_loss = 0#看正确率total_accuracy = 0#在with里面的代码没有了梯度,保证不会进行调优with torch.no_grad():for data in test_dataloader:imgs, targets =dataoutputs = tudui(imgs)#一部分数据在网络模型上的损失loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + lossaccuracy = (outputs.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))print("整体测试的正确率:{}".format(total_accuracy/test_data_size))writer.add_scalar("train_loss", loss.item(), total_test_step)writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)#测试的步骤+1否则图画不出来total_train_step = total_test_step + 1torch.save(tudui, "tudui_{}.pth".format(i))print("模型已保存")
writer.close()

上面是一个训练过程,下面介绍一下训练准确率怎么得来的。

假设有一个2分类的模型

Model(2分类)

#下面是得分

Outputs = [[0.2,0.3],[0.1,0.4]]

#通过Argmax 变成

Preds = [1]

                [1]

Inputs target=[0][1]

Preds==inputs target

#上面的这个式子返回的就是T or F

#加起来就是分类正确的个数了。

[false,true].sum()=1

                                       

这边注意一下output.argmax(x)的方向,x是0或是1,0的方向是竖着来的,1的方向是横着来的。

import torch
outputs = torch.tensor([[0.1,0.2],[0.3,0.4]])
print(outputs.argmax(1))
preds = outputs.argmax(1)
targets = torch.tensor([0,1])
print((preds == targets).sum())

-----------------------------------------------------未完待续1------------------------------------------------------------- 

 训练的一些细节:

如果有Dropout和BatchNorm等一些特殊层,需要

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/660736.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入了解C++:底层编译原理

进程的虚拟空间划分 任何编程语言,都会产生两样东西,指令和数据。 .exe程序运行的时候会从磁盘被加载到内存中,但是不能直接加载到物理内存中。Linux会给当前进程分配一块空间,比如x86 32位linux环境下会给进程分配2^32(4G)大小…

vue3页面跳转产生白屏,刷新后能正常展示的解决方案

可以依次检查以下问题: 1.是否在根组件标签最外层包含了个最大的div盒子包裹内容。 2.看看是否在template标签下面直接有注释,如果有需要把注释写到div里面。(即根标签下不要直接有注释) 3.在router-view 中给路由添加key标识。 …

c# cass10 获取宗地内所有封闭线段的面积

获取面积的主要流程如下: 获取当前AutoCAD应用中的活动文档、数据库和编辑器对象。创建一个选择过滤器,限制用户只能选择"宗地"图层上的LWPOLYLINE对象作为外部边界。提示用户根据上述规则进行实体选择,并获取选择结果。检查用户是…

前端下载导出文件流,excel/word/pdf/zip等

** 一、导入导出接口增加responseType:‘blob’ ** axios({url: 接口,method: post,data:{},responseType: blob });二、导出方法封装 //data 文件流 //fileName 文件名称 /* mineType 文件类型例如:* 下载 Excel : "application/vnd.m…

(附源码)ssm 招聘信息管理系统-计算机毕设 78049

ssm 招聘客户管理系统 摘 要 由于数据库和数据仓库技术的快速发展,招聘客户管理系统建设越来越向模块化、智能化、自我服务和管理科学化的方向发展。招聘客户系统对处理对象和服务对象,自身的系统结构,处理能力,都将适应技术发展的…

逃避自由是所有成长的前提

《逃避自由》是埃里希弗洛姆的一部重要作品,首次出版于1941年。该书主要探讨了现代人在面对自由和独立时所表现出的逃避倾向。在这部作品中,弗洛姆分析了自由的心理学基础,并论证了自由不仅仅是一个政治和经济的概念,更是一个深刻…

脚本工具 mktemp 和 install

1.创建临时文件 mktemp 1.1 介绍 mktemp 命令用于创建并显示临时文件,可避免冲突 使用mktemp命令时,它会根据指定的模板在临时目录(默认为/tmp)中创建一个唯一的临时文件或目录,并返回该文件或目录的完整路径。临时…

2023年12月CCF-GESP编程能力等级认证Python编程一级真题解析

一、单选题(共15题,共30分) 第1题 某公司新出了一款无人驾驶的小汽车,通过声控智能驾驶系统,乘客只要告诉汽车目的地,车子就能自动选择一条优化路线,告诉乘客后驶达那里。请问下面哪项不是驾驶系统完成选路所必须的。( )(2023年12月py一级) A:麦克风 B:扬声器 C…

工作中的小记录

1、在element的el-dialog中上传附件后在另一个el-form-item下的input输入框中获取该附件名 使用v-model无法双向绑定。使用this.$set this.$set(this.formData,"属性名","属性值")2、后端传来文件地址,点击直接下载 getCaseId(this.$route.quer…

datawhale 大模型学习 第十一章-大模型法律篇

简介 新技术与法律关系:大型语言模型(LLM)的出现引发了对现有法律适用性的探讨,尤其是在版权、隐私和公平使用等方面。互联网法律挑战:互联网的匿名性和无国界特性对法律的管辖权提出了挑战。法律与道德区分&#xff…

在线UI设计工具有哪些?这5个包你满意

随着 UI 设计行业的蓬勃发展,越来越多的设计师进入 UI 设计,选择一款方便的 UI 设计工具尤为重要!除了传统的 UI 设计工具,在线 UI 设计工具也受到越来越多设计师的青睐。这种不受时间、地点和计算机配置限制的工作方法真的很刺激…

Python处理图片生成天际线(2024.1.29)

1、天际线简介 天际线(SkyLine)顾名思义就是天空与地面的边界线,人站在不同的高度,会看到不同的景色和地平线,天空与地面建筑物分离的标记线,不得不说,每天抬头仰望天空,相信大家都可…

屈子祠镇黑鱼岭,不可移动文物预防性保护系统守遗珍

一、何止秦俑 秦陵苑囿青铜水禽等文物集中展出 文物保护,尤其是不可移动文物的保护,一直都是文化遗产的重要环节。湖南省汨罗市屈子祠镇双楚村黑鱼岭墓地,作为长江中游地区的重大考古发现,其商朝晚期的历史背景赋予其不可估量的历…

c# 语音播报

在C#中进行语音播报通常需要使用.NET Framework中的某个语音库或服务。一个常见的选择是使用System.Speech.Synthesis命名空间中的SpeechSynthesizer类,该类提供了文本到语音的转换功能。 以下是一个简单的示例,演示如何在C#中使用SpeechSynthesizer进行…

响应式Web开发项目教程(HTML5+CSS3+Bootstrap)第2版 例5-6 绘制几何图形

代码 <!doctype html> <html> <head> <meta charset"utf-8"> <title>绘制几何图形</title> </head><body><canvas id"canvas" width"250" height"150" style"border: 1px b…

QUIC with CUBIC or BBR

拥塞控制 拥塞控制算法是 TCP/QUIC 协议的一个基础部分&#xff0c;多年来经过一个个版本的迭代&#xff08;如 Tahoe、Reno、Vegas 等&#xff09;&#xff0c;拥塞控制算法得到了持续的提升。由于篇幅有限&#xff0c;本文就目前比较流行的两种拥塞控制算法&#xff08;CUBI…

SVM支持向量机

1.基本概念 支持向量机&#xff08;Support Vector Machine&#xff0c;SVM&#xff09;是一种有监督学习方法&#xff0c;主要用于分类和回归分析。它的基本思想是在特征空间中找到一个超平面&#xff0c;能够将不同类别的样本分开&#xff0c;并且使得离这个超平面最近的样本…

Linux 链接 GitHub 出现 Connection timed out

问题 安装GIT并完成公钥验证&#xff1a;Linux 系统拉取 Github项目 [rootxxx devtools]# ssh -T gitgithub.com ssh: connect to host github.com port 22: Connection timed out解决方案 进入在存放公钥私钥id_rsa.pub文件里&#xff0c;新建/修改config文本 [rootxxx my…

Java 异常处理下篇:11 个异常处理最佳实践

文章目录 前言最佳实践早抛出&#xff0c;晚捕获原则只捕获实际可处理的异常不要忽略捕捉的异常抛出具体的检查性异常正确包装自定义异常中的异常记录或抛出异常&#xff0c;但不要同时执行finally 中永远不要抛出异常或返回值避免使用异常进行流程控制使用模板方法处理重复的 …

算法训练day24回溯算法理论基础77组合

今日学习链接 https://programmercarl.com/%E5%9B%9E%E6%BA%AF%E7%AE%97%E6%B3%95%E7%90%86%E8%AE%BA%E5%9F%BA%E7%A1%80.html#%E7%90%86%E8%AE%BA%E5%9F%BA%E7%A1%80 https://programmercarl.com/0077.%E7%BB%84%E5%90%88.html#%E5%89%AA%E6%9E%9D%E4%BC%98%E5%8C%96 回溯算…