【Pytorch神经网络理论篇】 09 神经网络模块中的损失函数

同学你好!本文章于2021年末编写,获得广泛的好评!

故在2022年末对本系列进行填充与更新,欢迎大家订阅最新的专栏,获取基于Pytorch1.10版本的理论代码(2023版)实现,

Pytorch深度学习·理论篇(2023版)目录地址为:

CSDN独家 | 全网首发 | Pytorch深度学习·理论篇(2023版)目录本专栏将通过系统的深度学习实例,从可解释性的角度对深度学习的原理进行讲解与分析,通过将深度学习知识与Pytorch的高效结合,帮助各位新入门的读者理解深度学习各个模板之间的关系,这些均是在Pytorch上实现的,可以有效的结合当前各位研究生的研究方向,设计人工智能的各个领域,是经过一年时间打磨的精品专栏!https://v9999.blog.csdn.net/article/details/127587345欢迎大家订阅(2023版)理论篇

以下为2021版原文~~~~

 

1 训练模型的步骤与方法

  • 将样本书记输入到模型中计算出正向的结果
  • 计算模型结果与样本目标数值之间的差值(也称为损失值loss)
  • 根据损失值,使用链式反向求导的方法,依次计算出模型中每个参数/权重的梯度
  • 使用优化器中的策略对模型中的参数进行更新

2 神经网络模块中的损失函数

2.1 损失函数定义

损失函数主要用来计算“输出值”与“输入值”之间的差距,即误差,反向传播中依靠损失函数找到最优的权重。

2.2 L1损失函数/最小绝对值偏差(LAD)/最小绝对值误差(LAE)

L1损失函数用于最小化误差,该误差是真实值和预测值之间的所有绝对差之和。

2.2.1 代码实现==>以类的形式进行封装,需要对其实例化后再使用

import torch
### pre:模型的输出值
### label:模型的目标值
loss = torch.nn.L1Loss()[pre,label]

2.3 L2损失函数

L2损失函数用于最小化误差,该误差是真实值和预测值之间所有平方差的总和

2.4 均值平方差损失(MSE)

均值平方差损失(MSE)主要针对的是回归问题,主要表达预测值域真实值之间的差异

2.4.1 MSE的公式表述

 这里的n表示n个样本。ylabel与ypred的取值范围一般为0-1。

2.4.2 注释

  • MSE的值越小,表明模型越好
  • 在神经网络的计算中,预测值与真实值要控制在相同的数据分布中
  • 假设预测值输入Sigmoid激活函数后其取值范围为0到1之间,则真实值的取值范围也应该取到0到1之间

2.4.3 代码实现==>以类的形式进行封装,需要对其实例化后再使用

import torch
### pre:模型的输出值
### label:模型的目标值
loss = torch.nn.MSELoss()(pre,label)

2.5 交叉熵损失函数(Cross Entropy)

2.5.1 交叉熵损失函数简介

交叉熵损失函数可以用来学习模型分布与训练分布之间的差异,一般用作分类问题,数学含义为预测输入样本属于某一类别的概率。

2.5.2 公式介绍

 y^为真实分类y的概率值

2.5.3 代码实现==>以类的形式进行封装,需要对其实例化后再使用

import torch
### pre:模型的输出值
### label:模型的目标值
loss = torch.nn.CrossEntropyLoss()(pre,label)

2.5.4 图像理解

接下来,我们从图形的角度,分析交叉熵函数,加深大家的理解。首先,还是写出单个样本的交叉熵损失函数:


我们知道,当 y = 1 时:

 这时候,L 与预测输出的关系如下图所示:

看了 L 的图形,简单明了!横坐标是预测输出,纵坐标是交叉熵损失函数 L。显然,预测输出越接近真实样本标签 1,损失函数 L 越小;预测输出越接近 0,L 越大。因此,函数的变化趋势完全符合实际需要的情况。

当 y = 0 时:

这时候,L 与预测输出的关系如下图所示:

同样,预测输出越接近真实样本标签 0,损失函数 L 越小;预测函数越接近 1,L 越大。函数的变化趋势也完全符合实际需要的情况。

从上面两种图,可以帮助我们对交叉熵损失函数有更直观的理解。无论真实样本标签 y 是 0 还是 1,L 都表征了预测输出与 y 的差距。

从图形中我们可以发现:预测输出与 y 差得越多,L 的值越大,也就是说对当前模型的 “ 惩罚 ” 越大,而且是非线性增大,是一种类似指数增长的级别。这是由 log 函数本身的特性所决定的。这样的好处是模型会倾向于让预测输出更接近真实样本标签 y。

2.6 其他损失函数

2.6.1 SmoothL1Loss

SmoothL1Loss:平滑版的L1损失函数。此损失函数对于异常点的敏感性不如MSE-Loss。在某些情况下(如Fast R-CNN模型中),它可以防止梯度“爆炸”。这个损失函数也称为Huber loss。

2.6.2 NLLLoss

NLLLoss:负对数似然损失函数,在分类任务中经常使用。

2.6.3 NLLLoss22d

NLLLoss22d:计算图片的负对数似然损失函数,即对每个像素计算NLLLoss。

2.6.4 KLDivLoss

KLDivLoss:计算KL散度损失函数。

2.6.5 BCELoss

BCELoss:计算真实标签与预测值之间的二进制交叉熵。

2.6.6 BCEWithLogitsLoss

BCEWithLogitsLoss:带有Sigmoid激活函数层的BCELoss,即计算target与Sigmoid(output)之间的二进制交叉熵。

2.6.7 MarginRankingLoss

MarginRankingLoss:按照一个特定的方法计算损失。计算给定输入x、x(一维张量)和对应的标y(一维张量,取值为-1或1)之间的损失值。如果y=1,那么第一个输入的值应该大于第二个输入的值;如果y=-1,则相反。

2.6.8 HingeEmbeddingLoss

HingeEmbeddingLoss:用来测量两个输入是否相似,使用L1距离。计算给定一个输入x(二维张量)和对应的标签y(一维张量,取值为-1或1)之间的损失值。

2.6.9 MultiLabelMarginLoss

MultiLabelMarginLoss:计算多标签分类的基于间隔的损失函数(hinge loss)。计算给定一个输入x(二维张量)和对应的标签y(二维张量)之间的损失值。其中,y表示最小批次中样本类别的索引。

2.6.10 SoftMarginLoss

SoftMarginLoss:用来优化二分类的逻辑损失。计算给定一个输入x(二维张量)和对应的标签y(一维张量,取值为-1或1)之间的损失值。

2.6.11 MultiLabelSoftMarginLoss

MultiLabelSoftMarginLoss:基于输入x(二维张量)和目标y(二维张量)的最大交叉熵,优化多标签分类(one-versus-al)的损失。

2.6.12 CosineEmbeddingLoss

CosineEmbeddingLoss:使用余弦距离测量两个输入是否相似,一般用于学习非线性embedding或者半监督学习。

2.6.13 MultiMarginLoss

MultiMarginLoss:用来计算多分类任务的hinge loss。输入是x(二维张量)和y(一维张量)。其中y代表类别的索引。

2.7 汇总

用输入标签数据的类型来选取损失函数

如果蝓入是无界的实数值,那么损失函数使用平方差

如果输入标签是位矢量(分类标识),那么使用交叉熵会更适合。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/469406.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

jquery hover事件中 fadeIn和fadeOut 效果不能及时停止

$(".nav ul li").hover(function () {var id $(this).attr("id");$(".nav dl").each(function (index, domEle) {if ($(domEle).attr("id") id) {$(domEle).fadeIn();}else {$(domEle).stop().fadeOut();//在这里加入.stop() 以阻止…

【Pytorch神经网络理论篇】 10 优化器模块+退化学习率

同学你好!本文章于2021年末编写,获得广泛的好评! 故在2022年末对本系列进行填充与更新,欢迎大家订阅最新的专栏,获取基于Pytorch1.10版本的理论代码(2023版)实现, Pytorch深度学习理论篇(2023版)目录地址…

Android fb0 截屏实现

问题:我们有几个项目,在项目1和项目2上实现截屏是没有问题的,但是在项目3上实现截屏是不行的 原因:分辨率差异引起的问题,分辨率长宽一定要是32的整数倍 Dear customer, Sorry for the late reply due to annual leave. I dont think this issue relates with thediff…

Python trino执行hive insert overwrite不生效的问题

使用python的trino包执行insert overwrite,但是overwrite却没有生效的问题 根据trino的官网介绍的insert overwrite的开启方式,开启hive的insert overwrite会话,使当前会话的insert into语句支持insert overwrite,也即支持插入数…

HAProxy负载均衡原理及企业级实例部署haproxy集群

HAProxy是一种高效、可靠、免费的高可用及负载均衡解决方案,非常适合于高负载站点的七层数据请求。客户端通过HAProxy代理服务器获得站点页面,而代理服务器收到客户请求后根据负载均衡的规则将请求数据转发给后端真实服务器。 同一客户端访问服务器&…

【Pytorch神经网络实战案例】07 预测泰坦尼克号上生存的乘客

1 样本处理 1.1 载入样本代码---Titanic forecast.py(第1部分) import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from scipy import stats import pandas as pd import matplotlib.pyplot as plt import os o…

ubuntu下 安装 adb

1、把adb tool工具考到你要安装的目录夏目 <

基于sanic的服务使用celery完成动态修改定时任务

首先声明一下 考虑到celery目前和asyncio的不兼容性&#xff0c;协程任务需要转换为非异步的普通方法才能被当做task加入定时&#xff0c;并且celery和asyncio使用可能会带来预想不到的问题&#xff0c;在celery官方第二次承诺的6.0版本融合asyncio之前&#xff0c;需要慎重考虑…

shell 中的ifeq

libs_for_gcc -lgnunormal_libs foo: $(objects)ifeq ($(CC),gcc)$(CC) -o foo $(objects) $(libs_for_gcc)else$(CC) -o foo $(objects) $(normal_libs)endif 可见&#xff0c;在上面示例的这个规则中&#xff0c;目标“foo”可以根据变量“$(CC)”值来选取不同的函数库来编…

第一篇unity

在网上找的学习资料&#xff0c;做了点简单的效果。 半成品 http://files.cnblogs.com/files/buzhidaojiaoshenme/unity.rar 第二个游戏&#xff0c;方向键和“W”&#xff0c;”S“键移动方块&#xff0c;碰撞到最右边的方块过关。 http://files.cnblogs.com/files/buzhidaoji…

报错:OMP: Error #15: Initializing libomp.dylib, but found libiomp5.dylib already initialized.

问题描述&#xff1a; OMP: Error #15: Initializing libiomp5.dylib, but found libiomp5.dylib already initialized. OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program. That is dangerous, since it can degrade perf…

Pyscript,使用Python编写前端脚本

介绍 Anaconda的CEO Peter Wang在前两个月的时候发布了Pyscript&#xff0c;实现了在HTML支持Python的使用&#xff0c;整个引用过程甚至不需要安装任何环境&#xff0c;只需要使用link和script标签即可引用实现Python在HTML中运行的功能&#xff0c;在HTML中也可以运行和使用…

如何把应用程序app编译进android系统

转载&#xff1a;http://ywxiao66.blog.163.com/blog/static/175482055201152710441106/------------------------------------------------------------------把常用的应用程序编译到img文件中&#xff0c;就成了系统的一部分&#xff0c;用户不必自己安装&#xff0c;当然也卸…

【Pytorch神经网络实战案例】08 识别黑白图中的服装图案(Fashion-MNIST)

1 Fashion-MNIST简介 FashionMNIST 是一个替代 MNIST 手写数字集 的图像数据集。 它是由 Zalando&#xff08;一家德国的时尚科技公司&#xff09;旗下的研究部门提供。其涵盖了来自 10 种类别的共 7 万个不同商品的正面图片。 FashionMNIST 的大小、格式和训练集/测试集划分与…

PHP list的赋值

List右边的赋值对象是一个以数值为索引的数组&#xff0c;左边的变量的位置和赋值对象的键值一一对应&#xff0c;有些位置的变量可以省略不写。非末尾的被赋值变量省略时&#xff0c;分隔的逗号不能省略。左边变量被赋值的顺序是从右到左的。 1 list($a, ,$b,$c[],$c[]) [1,2…

Pyscript,创建一个能执行crud操作的网页应用

目录 实现一个添加邀请客人名单的功能 循序渐进&#xff0c;逐步实现&#xff1a; 输入客人名称&#xff0c;按下enter键添加客人名单点击客人名单在名单上添加或者取消添加删除线&#xff0c;表示已经检查客人到场或未到场 checkbox&#xff0c;点击客人名单或者点击checkb…

爬虫实战学习笔记_1 爬虫基础+HTTP原理

1 爬虫简介 网络爬虫&#xff08;又被称作网络蜘蛛、网络机器人&#xff0c;在某些社区中也经常被称为网页追逐者)可以按照指定的规则&#xff08;网络爬虫的算法&#xff09;自动浏览或抓取网络中的信息。 1.1 Web网页存在方式 表层网页指的是不需要提交表单&#xff0c;使…

LeetCode | HouseCode 算法题

题目&#xff1a; You are a professional robber planning to rob houses along a street. Each house has a certain amount of money stashed, the only constraint stopping you from robbing each of them is that adjacent houses have security system connected and it…

爬虫实战学习笔记_2 网络请求urllib模块+设置请求头+Cookie+模拟登陆

1 urllib模块 1.1 urllib模块简介 Python3中将urib与urllib2模块的功能组合&#xff0c;并且命名为urllib。Python3中的urllib模块中包含多个功能的子模块&#xff0c;具体内容如下。 urllib.request&#xff1a;用于实现基本HTTP请求的模块。urlb.error&#xff1a;异常处理…

Python解决多个进程服务重复运行定时任务的问题

记录多实例服务定时任务出现运行多次的问题 问题&#xff1a;web项目运行多个实例时&#xff0c;定时任务会被执行多次的问题 举例来说 我使用库APScheduler排定了一个定时任务taskA在每天的晚上9点需要执行一次&#xff0c;我的web服务使用分布式运行了8个实例&#xff0c;于…