pytorch-损失函数-分类和回归区别

torch.nn 库和 torch.nn.functional库的区别

  1. torch.nn库:这个库提供了许多预定义的层,如全连接层(Linear)、卷积层(Conv2d)等,以及一些损失函数(如MSELoss、CrossEntropyLoss等)。这些层都是类,它们都继承自nn.Module,因此可以很方便地集成到自定义的模型中。torch.nn库中的层都有自己的权重和偏置,这些参数可以通过优化器进行更新。

    1. 当你需要的操作包含可学习的参数(例如权重和偏置)时,通常使用torch.nn库更为方便。例如,对于卷积层(Conv2d)、全连接层(Linear)等,由于它们包含可学习的参数,因此通常使用torch.nn库中的类。这些类会自动管理参数的创建和更新。

      例如:

    2. import torch.nn as nnconv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
      fc = nn.Linear(in_features=1024, out_features=10)
      
  2. torch.nn.functional库:这个库提供了一些函数,如激活函数(如relu、sigmoid等)、池化函数(如max_pool2d、avg_pool2d等)以及一些损失函数(如cross_entropy、mse_loss等)。这些函数更加灵活,但使用它们需要手动管理权重和偏置。

    1. 对于没有可学习参数的操作,例如ReLU激活函数、池化操作、dropout等,你可以选择使用torch.nn.functional库,因为这些操作不需要额外的参数。

    2. import torch.nn.functional as Fx = F.relu(x)
      x = F.max_pool2d(x, kernel_size=2)
      x = F.dropout(x, p=0.5, training=self.training)
      
  3. 对于损失函数,torch.nn库和torch.nn.functional库都提供了实现,你可以根据需要选择。如果你需要的损失函数有可学习的参数(例如nn.BCEWithLogitsLoss中的pos_weight),那么应该使用torch.nn库。如果你的损失函数没有可学习的参数,那么你可以选择使用torch.nn.functional库,这样可以避免创建不必要的对象。

    例如:

  4. import torch.nn as nn
    import torch.nn.functional as F# 使用nn库
    loss_fn = nn.CrossEntropyLoss()
    loss = loss_fn(prediction, target)# 使用functional库
    loss = F.cross_entropy(prediction, target)
    

torch.nn 库和 torch.nn.functional库损失函数的对应关系

以下是一些常见的损失函数在torch.nn和torch.nn.functional中的对应关系:

  1. 交叉熵损失:
    1. torch.nn.CrossEntropyLoss
    2. torch.nn.functional.cross_entropy
  2. 负对数似然损失:
    1. torch.nn.NLLLoss
    2. torch.nn.functional.nll_loss
  3. 均方误差损失:
    1. torch.nn.MSELoss
    2. torch.nn.functional.mse_loss
  4. 平均绝对误差损失:
    1. torch.nn.L1Loss
    2. torch.nn.functional.l1_loss

分类和回归损失函数的区别

  1. 分类问题:分类问题的目标是预测输入数据的类别。对于这类问题,常用的损失函数有交叉熵损失(Cross Entropy Loss)和负对数似然损失(Negative Log Likelihood Loss)。这些损失函数都是基于预测的概率分布和真实的概率分布之间的差异来计算损失的。
    1. nn.CrossEntropyLoss:这是用于分类问题的损失函数。它期望的输入是一个形状为(batch_size, num_classes)的张量,其中每个元素是对应类别的原始分数(通常是最后一个全连接层的输出),以及一个形状为(batch_size,)的张量,其中每个元素是真实的类别标签。
    2. nn.NLLLoss:这也是用于分类问题的损失函数。它期望的输入是一个形状为(batch_size, num_classes)的张量,其中每个元素是对应类别的对数概率(通常是log_softmax的输出),以及一个形状为(batch_size,)的张量,其中每个元素是真实的类别标签。
  2. 回归问题:回归问题的目标是预测一个连续的值。对于这类问题,常用的损失函数有均方误差损失(Mean Squared Error Loss)和平均绝对误差损失(Mean Absolute Error Loss)。这些损失函数都是基于预测值和真实值之间的差异来计算损失的。
    1. nn.MSELoss:这是用于回归问题的损失函数。它期望的输入是两个形状相同的张量,一个是预测值,一个是真实值。这两个张量的形状可以是任意的,只要它们相同即可。
    2. nn.L1Loss:这也是用于回归问题的损失函数。它期望的输入是两个形状相同的张量,一个是预测值,一个是真实值。这两个张量的形状可以是任意的,只要它们相同即可。
举例说明

nn.MSELoss()

输入:预测值和目标值,它们的形状应该是相同的。例如,如果你有一个批量大小为batch_size的数据,每个数据有n个特征,那么预测值和目标值的形状都应该是(batch_size, n)。

输出:一个标量,表示计算得到的均方误差损失。

例如:

import torch
import torch.nn as nn# 假设我们有一个批量大小为3的数据,每个数据有2个特征
prediction = torch.randn(3, 2)
target = torch.randn(3, 2)loss_fn = nn.MSELoss()
loss = loss_fn(prediction, target)print(loss)  # 输出一个标量,表示计算得到的均方误差损失

F.cross_entropy()

输入:预测值和目标值。预测值的形状应该是(batch_size, num_classes),表示对每个类别的预测概率;目标值的形状应该是(batch_size,),表示每个数据的真实类别标签。

输出:一个标量,表示计算得到的交叉熵损失。

例如:

import torch
import torch.nn.functional as F# 假设我们有一个批量大小为3的数据,有4个类别
prediction = torch.randn(3, 4)
target = torch.tensor([1, 0, 3])  # 真实的类别标签loss = F.cross_entropy(prediction, target)print(loss)  # 输出一个标量,表示计算得到的交叉熵损失

多分类中CrossEntropyLoss() 和NLLLoss()的区别

  1. CrossEntropyLoss():它的输入是模型对每个类别的原始分数(通常是最后一个全连接层的输出),并且这些分数没有经过任何归一化处理。CrossEntropyLoss()内部会对这些分数进行log_softmax操作,然后计算交叉熵损失。
  2. NLLLoss():它的输入是模型对每个类别的对数概率,这些对数概率通常是通过对模型的原始输出进行log_softmax操作得到的。NLLLoss()会直接计算负对数似然损失。

CrossEntropyLoss() = softmax + log + NLLLoss() = log_softmax + NLLLoss()

二分类中BCELoss和BCEWithLogitsLoss的区别

BCELoss()和BCEWithLogitsLoss()都是PyTorch中常用的损失函数,主要用于二分类问题。但是它们的输入和处理方式有所不同。

  1. BCELoss():它的输入是模型对每个类别的概率,这些概率通常是通过对模型的原始输出进行sigmoid操作得到的。BCELoss()会直接计算二元交叉熵损失。
  2. BCEWithLogitsLoss():它的输入是模型对每个类别的原始分数(通常是最后一个全连接层的输出),并且这些分数没有经过任何归一化处理。BCEWithLogitsLoss()内部会对这些分数进行sigmoid操作,然后计算二元交叉熵损失。

总的来说,BCELoss()和BCEWithLogitsLoss()的主要区别在于它们的输入:BCELoss()期望的输入是模型的概率输出,而BCEWithLogitsLoss()期望的输入是模型的原始输出。在实际使用中,你可以根据自己的需求和模型的输出来选择使用哪一个损失函数。

另外,BCEWithLogitsLoss()在内部进行sigmoid和loss计算可以提高数值稳定性,因此在实际使用中,如果模型的输出是原始分数,推荐使用BCEWithLogitsLoss()。

回归损失函数中的reduction函数详解

它的完整定义是torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')。

下面是这些参数的解释:

  1. size_average(已弃用):如果设置为True,损失函数会对每个小批量的损失取平均值。如果设置为False,损失函数会对每个小批量的损失求和。默认值是True。这个参数已经被弃用,推荐使用reduction参数。
  2. reduce(已弃用):如果设置为True,损失函数会返回一个标量值,即所有输入元素的损失的平均值或总和(取决于size_average参数)。如果设置为False,损失函数会返回一个损失值的向量,每个元素对应一个输入数据点的损失。默认值是True。这个参数已经被弃用,推荐使用reduction参数。
  3. reduction:指定如何减少损失。可以是'none'(不减少,返回一个损失值的向量),'mean'(取平均,返回所有输入元素的损失的平均值)或'sum'(求和,返回所有输入元素的损失的总和)。默认值是'mean'。

nn.MSELoss()函数的输入是两个张量,分别代表预测值和目标值。它们必须有相同的形状。函数的输出是一个标量值,表示损失。

nn.SmoothL1Loss相比于nn.MSELoss损失函数的优点

  1. nn.MSELoss(均方误差损失)对于回归问题非常有效,但它对于异常值(outliers)非常敏感,因为它会将每个误差的平方进行求和。这意味着,即使只有一个样本的预测值与真实值相差很大,也会导致整体损失值显著增加。
  2. 而nn.SmoothL1Loss(平滑L1损失)则在处理异常值时更为鲁棒。它结合了L1损失和L2损失的优点:当预测值与真实值的差距较大时,它的行为类似于L1损失(即绝对值损失),对异常值不敏感;而当预测值与真实值接近时,它的行为类似于L2损失(即均方误差损失),可以更精细地优化模型。

因此,nn.SmoothL1Loss的一个主要优点是它可以在处理异常值和进行精细优化之间找到一个平衡,这在某些任务中可能是非常有用的。

nn.SmoothL1Loss是通过一个特定的数学公式来实现这个优点的。这个公式如下:

SmoothL1Loss(x, y) = 0.5 * (x - y)^2, if abs(x - y) < 1= abs(x - y) - 0.5, otherwise

这个公式的含义是,当预测值和真实值的差距小于1时,使用平方误差损失(即L2损失);当差距大于或等于1时,使用绝对值误差损失(即L1损失)。

可以看到,当差距较小的时候,SmoothL1Loss的行为类似于nn.MSELoss,它会对这些小的误差进行精细优化。而当差距较大的时候,SmoothL1Loss的行为类似于L1损失,它不会对这些大的误差进行过度惩罚,从而提高了对异常值的鲁棒性。

这就是nn.SmoothL1Loss如何在处理异常值和进行精细优化之间找到平衡的。

nn.HuberLoss的作用

nn.HuberLoss也被称为Huber损失,是一种结合了均方误差损失(Mean Squared Error,MSE)和平均绝对误差损失(Mean Absolute Error,MAE)的损失函数。它在处理回归问题时,尤其是存在异常值(outliers)的情况下,表现出较好的性能。

Huber损失的计算公式如下:

HuberLoss(x, y) = 0.5 * (x - y)^2, if abs(x - y) < delta= delta * abs(x - y) - 0.5 * delta^2, otherwise

这个公式的含义是,当预测值和真实值的差距小于一个阈值delta时,使用平方误差损失(即MSE);当差距大于或等于delta时,使用线性误差损失(即MAE)。

与nn.SmoothL1Loss类似,nn.HuberLoss在处理异常值和进行精细优化之间找到了一个平衡。当预测误差较小的时候,它的行为类似于MSE,可以对这些小的误差进行精细优化;而当预测误差较大的时候,它的行为类似于MAE,不会对这些大的误差进行过度惩罚,从而提高了对异常值的鲁棒性。

另外,nn.HuberLoss的一个优点是它的梯度在整个定义域内都是有界的,这使得模型在训练过程中更稳定。

参考自:

pytorch中常用的损失函数用法说明 | w3cschool笔记

pytorch教程 (四)- 损失函数_pytorch对比损失-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/130627.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

nginx负载配置

Nginx是一款高性能的Web服务器&#xff0c;同时也是一款高效的反向代理和负载均衡工具。在高并发的情况下&#xff0c;使用Nginx进行负载均衡可以提高网站的并发处理能力&#xff0c;保证网站的稳定性和可用性。下面是一些关于Nginx负载均衡的基础知识和配置方法。 Nginx负载均…

算法——多数相和

三数 15. 三数之和 - 力扣&#xff08;LeetCode&#xff09; 所以代码实现应该是 vector<vector<int>> threeSum(vector<int>& nums) {int n nums.size();sort(nums.begin(), nums.end()); // 对数组进行排序&#xff0c;以便后续操作vector<vector…

【微信小程序】自定义组件(二)

自定义组件 纯数据字段1、什么是纯数据字段2、使用规则 组件的生命周期1、组件全部的生命周期函数2、组件主要的生命周期函数3、lifetimes节点 组件所在页面的生命周期1、什么是组件所在页面的生命周期2、 pageLifetimes节点3、生成随机的颜色值 纯数据字段 1、什么是纯数据字…

快速了解推荐引擎检索技术

目录 一、推荐引擎和其检索技术 二、推荐引擎的整体架构和工作过程 &#xff08;一&#xff09;用户画像 &#xff08;二&#xff09;文章画像 &#xff08;三&#xff09;推荐算法召回 三、基于内容的召回 &#xff08;一&#xff09;召回算法 &#xff08;二&#xf…

C#高级--IO详解

零、文章目录 IO详解 1、IO是什么 &#xff08;1&#xff09;IO是什么 IO是输入/输出的缩写&#xff0c;即Input/Output。在计算机领域&#xff0c;IO通常指数据在内部存储器和外部存储器或其他周边设备之间的输入和输出。输入和输出是信息处理系统&#xff08;例如计算器&…

分享者 - 携程旅游创作者搬砖项目图文教程

大家好&#xff01;携程这个出行旅游平台相信大家都不陌生吧。 每天都有大量的旅客在里面浏览攻略&#xff0c;寻找灵感和旅游建议。 那么&#xff0c;我们的项目就是把一些优质的小红书平台上的旅游攻略或作品&#xff0c;经过处理后搬运到携程平台上发布。 这个项目如何操作呢…

Portraiture4.1.2最新中文汉化版

提起PS后期修图人像美白磨皮&#xff0c;大家会想到各种磨皮工具&#xff0c;其中Portraiture这款磨皮效率超高&#xff0c;是99%摄影师的必备插件&#xff0c;一秒磨皮&#xff0c;无卡顿&#xff0c;效果好&#xff01;人像摄影师人均一款&#xff0c;磨皮质感非常好&#xf…

Java 正则表达式重复匹配篇

重复匹配 * 可以匹配任意个字符&#xff0c;包括0个字符。 可以匹配至少一个字符。? 可以匹配0个或一个字符。{n} 可以精确指定 n 个字符。{n,m} 可以精确匹配 n-m 个字符。你可以是 0 。 匹配任意个字符 匹配 D 开头&#xff0c;后面是任意数字的字符&#xff0c; String …

独创改进 | RT-DETR 引入双向级联特征融合结构 RepBi-PAN | 附手绘结构图原图

本专栏内容均为博主独家全网首发,未经授权,任何形式的复制、转载、洗稿或传播行为均属违法侵权行为,一经发现将采取法律手段维护合法权益。我们对所有未经授权传播行为保留追究责任的权利。请尊重原创,支持创作者的努力,共同维护网络知识产权。 文章目录 YOLOv6贡献RepBi-…

实习记录--(海量数据如何判重?)--每天都要保持学习状态和专注的状态啊!!!---你的未来值得你去奋斗

海量数据如何判重&#xff1f; 判断一个值是否存在&#xff1f;解决方法&#xff1a; 1.使用哈希表&#xff1a; 可以将数据进行哈希操作&#xff0c;将数据存储在相应的桶中。 查询时&#xff0c;根据哈希值定位到对应的桶&#xff0c;然后在桶内进行查找。这种方法的时间复…

一站式解决方案:体验亚马逊轻量服务器/VPS的顶级服务与灵活性

文章目录 一、什么是轻量级服务器/VPS 二、服务器创建步骤 三、服务器连接客户端(私钥登录) 四、使用服务器搭建博客网站 五、个人浅解及总结 一、什么是轻量级服务器/VPS 亚马逊推出的轻量级服务器/VPS&#xff1a;是一种基于云计算技术的虚拟服务器解决方案。它允许用户…

0005Java安卓程序设计-ssm基于Android的网店系统

文章目录 **摘要**目录系统设计开发环境 编程技术交流、源码分享、模板分享、网课教程 &#x1f427;裙&#xff1a;776871563 摘要 随着Internet的发展&#xff0c;人们的日常生活已经离不开网络。未来人们的生活与工作将变得越来越数字化&#xff0c;网络化和电子化。网上管…

Spring Boot 3 整合 xxl-job 实现分布式定时任务调度,结合 Docker 容器化部署(图文指南)

目录 前言初始化数据库Docker 部署 xxl-job下载镜像创建容器并运行访问调度中心 SpringBoot 整合 xxl-jobpom.xmlapplication.ymlXxlJobConfig.java执行器注册查看 定时任务测试添加测试任务配置定时任务测试结果 结语附录xxl-job 官方文档xxl-job 源码测试项目源码 前言 xxl-…

代码随想录算法训练营第四十三天丨 动态规划part06

518.零钱兑换II 思路 这是一道典型的背包问题&#xff0c;一看到钱币数量不限&#xff0c;就知道这是一个完全背包。 对完全背包还不了解的同学&#xff0c;可以看这篇&#xff1a;动态规划&#xff1a;关于完全背包&#xff0c;你该了解这些&#xff01;(opens new window)…

Spring Boot spring.factories的原理

文章目录 1. spring.factories 用法2. spring.factories 实现原理3. spring.factories 用于解决什么问题&#xff1f; 3.1 业务场景思考及 starter 机制引入3.2 Spring Boot starter 机制 4. 小结 近期看到业务代码里用到 spring.factories 的配置&#xff0c;觉得场景不合适…

Java基础篇 | 多线程详解

✅作者简介&#xff1a;大家好&#xff0c;我是Leo&#xff0c;热爱Java后端开发者&#xff0c;一个想要与大家共同进步的男人&#x1f609;&#x1f609; &#x1f34e;个人主页&#xff1a;Leo的博客 &#x1f49e;当前专栏&#xff1a; Java从入门到精通 ✨特色专栏&#xf…

RMI初探

接口 import java.rmi.Remote; import java.rmi.RemoteException;public interface IFoo extends Remote {String say(String name) throws RemoteException; }import java.rmi.Remote; import java.rmi.RemoteException;public interface IBar extends Remote {String buy(Str…

【Nginx38】Nginx学习:SSL模块(二)错误状态码、变量及宝塔配置分析

Nginx学习&#xff1a;SSL模块&#xff08;二&#xff09;错误状态码、变量及宝塔配置分析 继续我们的 SSL 模块的学习。上回其实我们已经搭建起了一个 HTTPS 服务器了&#xff0c;只用了三个配置&#xff0c;其中一个是 listen 的参数&#xff0c;另外两个是指定密钥文件的地址…

overleaf里插入中文语句

作业要求是需要插入中文 我直接插入中文生成pdf会报错&#xff1a; 解决办法&#xff1a; overleaf官网里提供了教程&#xff1a;https://www.overleaf.com/learn/latex/Chinese 使用XeLaTeX或者LuaLaTeX进行编译是支持UTF-8编码。所以改变编译器的步骤如下&#xff1a; 点击…

3,感兴趣区域ROI

1&#xff0c;简介 ROI&#xff0c;感兴趣区域&#xff08;region of interest)&#xff0c;截取图像 2&#xff0c;获取方法 方法1&#xff1a;使用Rect cv::Mat srccv::imread("*.bmp");//读取原图 cv::Mat matROI src(cv::Rect(100,200,50,100));//截取原图&…