Pytorch学习 day10(L1Loss、MSELoss、交叉熵Loss、反向传播)

Loss

  • loss的作用如下:
    • 计算实际输出和真实值之间的差距
    • 为我们更新模型提供一定的依据(反向传播)

L1Loss

  • 绝对值损失函数:在每一个batch_size内,求每个输入x和标签y的差的绝对值,最后返回他们平均值
    在这里插入图片描述

MSELoss

  • 均方损失函数:在每一个batch_size内,求每个输入x和标签y的差的平方,最后返回他们的平均值
    在这里插入图片描述

交叉熵Loss

  • 当我们在处理分类问题时,经常使用交叉熵损失函数。
    • 交叉熵能够衡量同一个随机变量中的两个不同概率分布的差异程度,在机器学习中就表示为真实概率分布与预测概率分布之间的差异。交叉熵的值越小,模型预测效果就越好。
    • 交叉熵在分类问题中常常与softmax是标配,softmax将输出的结果进行处理,使其多个分类的预测值和为1,再通过交叉熵来计算损失。
  • 由于以下内容需要理解Softmax函数和交叉熵损失函数,所以先回顾一遍:
  • Softmax函数:
    • 首先,分类任务的目标是通过比较每个类别的概率大小来判断预测的结果。但是,我们不能选择未规范化的线性输出作为我们的预测。原因有两点。
1. 线性输出的总和不一定为1
2. 线性输出可能有负值
  • 因此我们采用Softmax规范手段来保证输出的非负、和为1,公式和举例如下:
    • 左侧为Softmax函数公式,右侧的o为线性输出,y为Softmax规范后的输出
      在这里插入图片描述
  • 交叉熵损失函数:
    • 下图为交叉熵损失函数公式,P(x)为真实概率分布,q(x)为预测概率分布:
      在这里插入图片描述
  • 我们将Softmax规范后的输出代入交叉熵损失函数中,可得:
    • 在训练中,我们已知该样本的类别,那么在该样本的真实概率分布中,只有该类别为1,其他都为0。
    • 在计算机中的log,默认都是ln。
      请添加图片描述
  • 这就是Pytorch官网中的交叉熵损失函数公式:
    在这里插入图片描述
  • 注意:给此公式的交叉熵损失函数传入的input,不需要进行规范化,即不需要进行Softmax变换
  • 我们仍然使用该类的对象函数来调用forward方法,而forward方法需要满足以下条件:
    • input:第一位为batch_size,第二位为输入的class数量
    • target:只有一位,为batch_size
      在这里插入图片描述
  • 代码如下:
import torchx = torch.tensor([0.1, 0.2, 0.3])
print(x.shape)  # torch.Size([3])
print(x)    # tensor([0.1000, 0.2000, 0.3000])
y = torch.tensor([1])   
x = torch.reshape(x, (1,3)) # 由于交叉熵损失函数的forward方法要求输入是二维,且第一位是batch_size,第二位是class的数量
print(x.shape)  # torch.Size([1, 3])
print(x)    # tensor([[0.1000, 0.2000, 0.3000]])
loss_cross = torch.nn.CrossEntropyLoss()    # 交叉熵损失函数
result_loss = loss_cross(x, y)  # 计算交叉熵损失
print(result_loss)  # 输出结果:
# torch.Size([3])
# tensor([0.1000, 0.2000, 0.3000])
# torch.Size([1, 3])
# tensor([[0.1000, 0.2000, 0.3000]])
# tensor(1.1019)
  • 计算器的输出结果如下:
    • 代码中的log默认为ln
      在这里插入图片描述

反向传播

  • 当输入不变时,我们要想让总loss最小,就是要找到一组最小的w、b序列,这时我们可以采用一种系统的方法:梯度下降方法
    • 那么找w、b序列,就转换为求学习率和loss对w、b的偏导数,形象化的表示如下:
      在这里插入图片描述
    • 梯度下降的公式如下:
      在这里插入图片描述
  • 这其中:学习率是我们手动设定的,偏导数则是模型自动计算的。
  • 由于每一个节点都需要计算偏导数,如果我们采用正向传播计算,那么针对每一个节点,我们都需要正向计算到结尾一次,而反向传播,只需要我们从头正向计算到结尾一次,之后根据节点位置,进行反向偏导数相乘即可,流程图如下:
    在这里插入图片描述
  • 在模型代码中,偏导数用grad(梯度)表示,在模型的训练过程中,通过反向传播来计算每个网络层节点的对应梯度,并通过某种算法(优化器)不断更新节点的参数,最终达到loss最小的一个结果,代码如下:
import torch
import torchvision
from torch import nntest_dataset = torchvision.datasets.CIFAR10(root='Dataset', train=False, download=True, transform=torchvision.transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)class Tudui(nn.Module):def __init__(self):super().__init__()self.module1 = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2, 2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2, 2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2, 2),nn.Flatten(),nn.Linear(1024, 64),nn.Linear(64, 10))def forward(self, input):output = self.module1(input)return outputloss = nn.CrossEntropyLoss()    # 交叉熵损失函数
tudui = Tudui()for data in test_loader:inputs, targets = dataoutputs = tudui(inputs)result_loss = loss(outputs, targets)    # 计算lossresult_loss.backward()  # 反向传播,注意需要使用计算后的lossa=1 # 用于调试,设置断点break
  • 结果如下:
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/746696.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Kafka配置SASL_PLAINTEXT权限。常用操作命令,创建用户,topic授权

查看已经创建的topic ./bin/kafka-topics.sh --bootstrap-server localhost:9092 --list 创建topic 创建分区和副本数为1的topic ./bin/kafka-topics.sh --create --bootstrap-server localhost:9092 --topic acltest --partitions 1 --replication-factor 1 创建kafka用户 …

HTML静态网页成品作业(HTML+CSS+JS)——迪士尼公主介绍(6个页面)

🎉不定期分享源码,关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 🏷️本套采用HTMLCSS,使用Javacsript代码,共有6个页面。 二、作品演示 三、代码…

基于单片机的酒精浓度测试仪

摘 要 现如今,人们对生活的态度和生活方式变得不同,,不仅私家车成为了人们最普遍的交通工具,大多数人都有自己的私家车,而且人们对酒精的消耗量也越来越大,这些就导致酒后驾车行为越来越普遍,酒后驾车意外越来越频繁&…

【深度学习笔记】10_10 束搜索beam-search

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图 10.10 束搜索 上一节介绍了如何训练输入和输出均为不定长序列的编码器—解码器。本节我们介绍如何使用编码器—解码器来预测不定长的序…

3dmax2020模型显示黑白不稳定---模大狮模型网

如果在3ds Max 2020中显示的模型出现黑白不稳定的情况,可能有几个常见原因和解决方法: 显卡驱动问题: 首先检查你的显卡驱动程序是否是最新版本。过时或不兼容的显卡驱动可能导致显示问题。建议更新到最新的显卡驱动程序,并确保其…

YOLOv9(3):YOLOv9损失(Loss)计算

1. 写在前面 YOLOv9的Loss计算与YOLOv8如出一辙,仅存在略微的差异。多说一句,数据的预处理和导入方式都是一样的。因此如果你已经对YOLOv8了解的比较透彻,那么对于YOLOv9你也只是需要多关注网络结构就可以。 YOLOv9本身也是Anchor-Free的&a…

编译esp32s3的ncnn,并运行mnist 手写数字识别

东哥科技,专注科技研发,wx交流:dg_i688 我的项目代码 https://github.com/cdmstrong/ncnn_on_esp32s3 下载ncnn git clone https://github.com/Tencent/ncnn.git安装idf 环境 这里直接按官网的可执行文件来就好了,直接安装完…

[mysql必备面试题]-mysql索引(B+ Tree )

一 B Tree 原理 1. 数据结构 B Tree 指的是 Balance Tree,也就是平衡树。平衡树是一颗查找树,并且所有叶子节点位于同一层。 B Tree 是基于 B Tree 和叶子节点顺序访问指针进行实现,它具有 B Tree 的平衡性,并且通过顺序访问指针…

【python中处理日期和时间二】扩展内容datetime模块-pytz模块-dateutil模块

扩展内容:日期和时间 datetime模块;pytz模块;dateutil模块 一、 datetime模块 查看datetime模块函数: >>> import datetime >>> dir(datetime) [MAXYEAR, MINYEAR, UTC, __all__, __builtins__, __cached__…

2024最新国内外低代码平台大全

博主猫头虎的技术世界 🌟 欢迎来到猫头虎的博客 — 探索技术的无限可能! 专栏链接: 🔗 精选专栏: 《面试题大全》 — 面试准备的宝典!《IDEA开发秘籍》 — 提升你的IDEA技能!《100天精通鸿蒙》 …

python面向对象的三大特性:封装,继承,多态

1、面向对象有哪些特性 三种:封装性、继承性、多态性 2、Python中的封装 在Python代码中,封装有两层含义: ① 把现实世界中的主体中的属性和方法书写到类的里面的操作即为封装 ② 封装可以为属性和方法添加为私有权限,不能直…

数据泄露态势(2024年2月)

监控说明:以下数据由零零信安0.zone安全开源情报系统提供,该系统监控范围包括约10万个明网、深网、暗网、匿名社交社群威胁源。在进行抽样事件分析时,涉及到我国的数据不会选取任何政府、安全与公共事务的事件进行分析。如遇到影响较大的伪造…

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的安全帽检测系统(深度学习模型+UI界面代码+训练数据集)

摘要:开发先进的安全帽识别系统对提升工作场所的安全性至关重要。本文详细介绍了使用深度学习技术创建此类系统的方法,并分享了完整的实现代码。系统采用了强大的YOLOv8算法,并对其与YOLOv7、YOLOv6、YOLOv5的性能进行了详细比较,…

Windows主机多网卡访问内外网

1:在实际生产环境有可能需要某台机器既能访问公司的内部网络也要能够访问外网。 2:首先机器要有两块网卡根据实际情况分别设置内外网的IP地址,掩码,网关,DNS等信息。设置完成时会出现下面的提示。 3:打开命…

空间计算综合指南

空间计算(spatial computing)是指使人类能够在三维空间中与计算机交互的一组技术。 该保护伞下的技术包括增强现实(AR)和虚拟现实(VR)。 这本综合指南将介绍有关空间计算所需了解的一切。 你将了解 AR、VR…

漏洞复现-红帆OA系列

漏洞复现-红帆OA GetWorkUnit.asmx存在SQL注入iOffice ioDesktopData存在SQL注入list接口存在SQL注入漏洞ioffice wssrtfile sql注入任意⽤户登录(2个)后台多处⽂件上传(7个)后台密码修改(1个)⽂件读取(2个)SQL注⼊(15个)红帆OA任意文件上传漏洞红帆HF Office系统SQL…

QComboBox相关的qss学习

QT有关QCobobox控件的样式设置(圆角、下拉框,向上展开、可编辑、内部布局等)_qcombobox样式-CSDN博客 原始图: 红色边框: QComboBox{ border:2px solid rgb(255, 85, 0); } 绿色背景: QComboBox{ border…

备战蓝桥杯Day27 - 省赛真题-2023

题目描述 大佬代码 import os import sysdef find(n):k 0for num in range(12345678,98765433):str1 ["2","0","2","3"]for x in str(num) :if x in str1:if str1[0] x:str1.pop(0)if len(str1) ! 0:k1print(k)print(85959030) 详…

C语言指针与数组(不适合初学者版):一篇文章带你深入了解指针与数组!

🎈个人主页:JAMES别扣了 💕在校大学生一枚。对IT有着极其浓厚的兴趣 ✨系列专栏目前为C语言初阶、后续会更新c语言的学习方法以及c题目分享. 😍希望我的文章对大家有着不一样的帮助,欢迎大家关注我,我也会回…

大模型笔记:吴恩达 ChatGPT Prompt Engineering for Developers(1) prompt的基本原则和策略

1 intro 基础大模型 VS 用指令tune 过的大模型 基础大模型 只会对prompt的文本进行续写 所以当你向模型发问的时候,它往往会像复读机一样续写几个问题这是因为在它见过的语料库文本(通常大多来自互联网)中,通常会连续列举出N个问…