PyTorch损失函数(二)

损失函数

5、nn.L1Loss

nn.L1Loss是一个用于计算输入和目标之间差异的损失函数,它计算输入和目标之间的绝对值差异。

主要参数:

  • reduction:计算模式,可以是nonesummean
    • none:逐个元素计算损失,返回一个与输入大小相同的张量。
    • sum:将所有元素的损失求和,返回一个标量值。
    • mean:计算所有元素的加权平均损失,返回一个标量值。

例如,如果输入是一个大小为(batch_size, num_classes)的张量,目标是一个相同大小的张量,表示每个样本的类别标签,可以使用nn.L1Loss计算它们之间的绝对值差异。

以下是一个示例代码:

import torch
import torch.nn as nn# 创建一个 L1Loss 的实例,设置 reduction 参数为 'mean'
loss_fn = nn.L1Loss(reduction='mean')# 创建输入张量和目标张量
input = torch.tensor([2.0, 4.0, 6.0])
target = torch.tensor([1.0, 3.0, 5.0])# 使用 L1Loss 计算输入和目标之间的损失
loss = loss_fn(input, target)# 打印损失值
print(loss)

输出结果为:

tensor(1.)

这表示输入和目标之间的平均绝对值差异为1.0。

在这里插入图片描述


6、nn.MSELoss

nn.MSELoss是一个用于计算输入和目标之间差异的损失函数,它计算输入和目标之间的平方差异。
主要参数:

  • reduction:计算模式,可以是nonesummean
    • none:逐个元素计算损失,返回一个与输入大小相同的张量。
    • sum:将所有元素的损失求和,返回一个标量值。
    • mean:计算所有元素的加权平均损失,返回一个标量值。

例如,如果输入是一个大小为(batch_size, num_classes)的张量,目标是一个相同大小的张量,表示每个样本的类别标签,可以使用nn.MSELoss计算它们之间的平方差异。

以下是一个示例代码:

import torch
import torch.nn as nn# 创建一个 MSELoss 的实例,设置 reduction 参数为 'mean'
loss_fn = nn.MSELoss(reduction='mean')# 创建输入张量和目标张量
input = torch.tensor([2.0, 4.0, 6.0])
target = torch.tensor([1.0, 3.0, 5.0])# 使用 MSELoss 计算输入和目标之间的损失
loss = loss_fn(input, target)# 打印损失值
print(loss)

输出结果为:

tensor(1.)

这表示输入和目标之间的平均平方差异为1.0。


7、 SmoothL1Loss

nn.SmoothL1Loss是一个用于计算输入和目标之间差异的平滑L1损失函数。
主要参数:

  • reduction:计算模式,可以是nonesummean
    • none:逐个元素计算损失,返回一个与输入大小相同的张量。
    • sum:将所有元素的损失求和,返回一个标量值。
    • mean:计算所有元素的加权平均损失,返回一个标量值。

平滑L1损失函数在输入和目标之间的差异较小时,使用平方函数计算损失,而在差异较大时,使用绝对值函数计算损失。这使得损失函数对离群值更加鲁棒。

以下是一个示例代码:

import torch
import torch.nn as nn# 创建一个 SmoothL1Loss 的实例,设置 reduction 参数为 'mean'
loss_fn = nn.SmoothL1Loss(reduction='mean')# 创建输入张量和目标张量
input = torch.tensor([2.0, 4.0, 6.0])
target = torch.tensor([1.0, 3.0, 5.0])# 使用 SmoothL1Loss 计算输入和目标之间的损失
loss = loss_fn(input, target)# 打印损失值
print(loss)

输出结果为:

tensor(0.5000)

这表示输入和目标之间的平均平滑L1损失为0.5。

在这里插入图片描述
在这里插入图片描述


8、nn.PoissonNLLLoss

nn.PoissonNLLLoss是一个用于计算泊松分布的负对数似然损失函数。

主要参数:

  • log_input:指定输入是否为对数形式。如果设置为True,则假设输入已经是对数形式,计算损失时会使用不同的公式。如果设置为False,则假设输入是线性形式,计算损失时会使用另一种公式。
  • full:指定是否计算所有样本的损失。如果设置为True,则计算所有样本的损失并返回一个标量值。如果设置为False,则计算每个样本的损失并返回一个与输入大小相同的张量。
  • eps:用于避免输入为零时取对数出现NaN的修正项。

以下是一个示例代码:

import torch
import torch.nn as nn# 创建一个 PoissonNLLLoss 的实例,设置 log_input 参数为 False,full 参数为 False,eps 参数为 1e-8
loss_fn = nn.PoissonNLLLoss(log_input=False, full=False, eps=1e-8)# 创建输入张量和目标张量
input = torch.tensor([2.0, 4.0, 6.0])
target = torch.tensor([1.0, 3.0, 5.0])# 使用 PoissonNLLLoss 计算输入和目标之间的损失
loss = loss_fn(input, target)# 打印损失值
print(loss)

输出结果为:

tensor([1.1269, 1.1269, 1.1269])

这表示每个样本的泊松分布的负对数似然损失。
在这里插入图片描述


9、nn.KLDivLoss

nn.KLDivLoss是一个用于计算KL散度(相对熵)的损失函数。

主要参数:

  • reduction:指定损失的归约方式。有以下选项:
    • 'none':逐个元素计算损失。
    • 'sum':将所有元素的损失求和,返回一个标量。
    • 'mean':计算所有元素的加权平均,返回一个标量。
    • 'batchmean':对每个batch的损失进行求和,然后再对所有batch求平均,返回一个标量。

需要注意的是,在使用nn.KLDivLoss之前,输入需要先计算log-probabilities,可以通过nn.log_softmax()来实现。

以下是一个示例代码:

import torch
import torch.nn as nn
# 创建一个 KLDivLoss 的实例,设置 reduction 参数为 'mean'
loss_fn = nn.KLDivLoss(reduction='mean')
# 创建输入张量和目标张量,并计算对数概率
input = torch.log_softmax(torch.tensor([2.0, 4.0, 6.0]), dim=0)
target = torch.tensor([0.3, 0.5, 0.2])
# 使用 KLDivLoss 计算输入和目标之间的损失
loss = loss_fn(input, target)
# 打印损失值
print(loss)

输出结果为:

tensor(0.5974)

这表示输入和目标之间的KL散度损失为0.5974。

在这里插入图片描述


10、nn.MarginRankingLoss是一个用于计算排序任务中两个向量之间相似度的损失函数。

主要参数:

  • margin:边界值,用于衡量两个向量之间的差异。如果两个向量之间的差异小于边界值,则认为它们是相似的;如果差异大于边界值,则认为它们是不相似的。
  • reduction:计算模式,用于指定损失的归约方式。有以下选项:
    • 'none':不进行归约,返回一个与输入大小相同的损失矩阵。
    • 'sum':将所有元素的损失求和,返回一个标量。
    • 'mean':计算所有元素的平均值,返回一个标量。

y = 1时, 希望x1比x2大,当x1>x2时,不产生loss
y = -1时,希望x2比x1大,当x2>x1时,不产生loss

以下是一个示例代码:

import torch
import torch.nn as nn
# 创建一个 MarginRankingLoss 的实例,设置 margin 参数为 0.0,reduction 参数为 'mean'
loss_fn = nn.MarginRankingLoss(margin=0.0, reduction='mean')
# 创建输入张量和目标张量
input1 = torch.tensor([1.0, 2.0, 3.0])
input2 = torch.tensor([2.0, 3.0, 4.0])
target = torch.tensor([1.0, -1.0, 1.0])
# 使用 MarginRankingLoss 计算输入和目标之间的损失
loss = loss_fn(input1, input2, target)
# 打印损失值
print(loss)

输出结果为:

tensor(0.)

这表示输入向量input1和input2之间的相似度损失为0.0。

在这里插入图片描述


11、nn.MultiLabelMarginLoss

nn.MultiLabelMarginLoss是一个用于多标签分类任务的边界损失函数。

  1. 功能:多标签边界损失函数
  2. 举例:四分类任务,样本x属于0类和3类,
  3. 标签:[0, 3, -1, -1] , 不是[1, 0, 0, 1]

主要参数:

  • reduction:计算模式,用于指定损失的归约方式。有以下选项:
    • 'none':不进行归约,返回一个与输入大小相同的损失矩阵。
    • 'sum':将所有元素的损失求和,返回一个标量。
    • 'mean':计算所有元素的平均值,返回一个标量。

以下是一个示例代码:

import torch
import torch.nn as nn
# 创建一个 MultiLabelMarginLoss 的实例,设置 reduction 参数为 'mean'
loss_fn = nn.MultiLabelMarginLoss(reduction='mean')
# 创建输入张量和目标张量
input = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
target = torch.tensor([[0, 3, -1, -1]])
# 使用 MultiLabelMarginLoss 计算输入和目标之间的损失
loss = loss_fn(input, target)
# 打印损失值
print(loss)

输出结果为:

tensor(0.)

这表示输入和目标之间的多标签边界损失为0.0。在这个示例中,样本x属于0类和3类,与标签[0, 3, -1, -1]一致。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/628568.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MongoDB 启动提示错误code=killed, signal=ABRT

1.停止MongoDB sudo systemctl stop mongod 2.检查数据损坏 sudo mongod --repair --dbpath /var/lib/mongodb 3.赋权限 chown -R mongodb:mongodb /var/lib/mongodb chown mongodb:mongodb /tmp/mongodb-27017.sock 如果不赋权限,启动的时候则会提示 4.启动Mo…

静态路由添加404页面

静态路由添加404页面 引入404页面路由代码: {path: *,name: 404,component: () > import(/views/page404)}404页面样式

meter报OOM错误,如何解决?

根据在之前的压测过程碰到的问题,今天稍微总结总结,以后方便自己查找。 一、单台Mac进行压测时候,压测客户端Jmeter启动超过2000个线程,Jmeter报OOM错误,如何解决? 解答:单台Mac配置内存为8G&…

小红书青年文化洞察:新“旷野文学”兴起,用户回归人间清醒?

社交媒体的“议程设置”能够影响用户的关注焦点,乃至影响舆论风向,但是以UGC生态为主的小红书,受众手中的话语权影响力变大,用户能识别、参与,甚至抵抗议程设置。 用户越来越清醒,不再是“电视喂什么&…

MySQL的单表查询

单表查询的素材: 一、单表查询 素材: 表名:worker-- 表中字段均为中文,比如 部门号 工资 职工号 参加工作 等 CREATE TABLE worker ( 部门号 int(11) NOT NULL, 职工号 int(11) NOT NULL, 工作时间 date NOT NULL, 工资 floa…

在linux环境下安装lnmp

lnmp官网:https://lnmp.org 一:lnmp安装 参考:https://lnmp.org/install.html 1:下载lnmp安装包 wget https://soft.lnmp.com/lnmp/lnmp2.0.tar.gz -O lnmp2.0.tar.gz 2:解压lnmp安装包 tar zxf lnmp2.0.tar.gz …

01章【JAVA开发入门】

计算机基本概念 计算机组成原理 计算机组装 计算机:电子计算机,俗称电脑。是一种能够按照程序运行,自动、高速处理海量数据的现代化智能电子设备。由硬件和软件所组成,没有安装任何软件的计算机称为裸机。常见的形式有台式计算机、…

leetcode-344. 反转字符串、9. 回文数

题目1: 解题方法 直接用reverse()即可 代码: class Solution(object):def reverseString(self, s):""":type s: List[str]:rtype: None Do not return anything, modify s in-place instead."""return s.reverse()如果不…

动态规划day09(打家劫舍,树形dp)

目录 198.打家劫舍 看到题目的第一想法 看到代码随想录之后的想法 自己实现过程中遇到的困难 213.打家劫舍II 看到题目的第一想法 看到代码随想录之后的想法 自己实现过程中遇到的困难 337.打家劫舍 III(树形dp) 看到题目的第一想法 看到代码随想录之后的想法 自己实…

PLC绝对定位指令DDRVA往复运动(三菱FX系列简单状态机编程)

有关状态机的具体介绍,专栏有很多文章,大家可以通过下面的链接查看: https://rxxw-control.blog.csdn.net/article/details/125488089https://rxxw-control.blog.csdn.net/article/details/125488089三菱FX系列回原功能块介绍 https://rxxw-control.blog.csdn.net/article…

【MATLAB】 HANTS滤波算法

有意向获取代码,请转文末观看代码获取方式~ 1 基本定义 HANTS滤波算法是一种时间序列谐波分析方法,它综合了平滑和滤波两种方法,能够充分利用遥感图像存在时间性和空间性的特点,将其空间上的分布规律和时间上的变化规律联系起来…

怎么把一个网站地址生成二维码?扫码跳网站页面

怎么把一个网站地址生成二维码?现在经常会发现扫描日常生活中的一些二维码会跳转到一个对应的网站页面,那么这种类型的二维码是如何生成的呢?如果大家也想要将网址生成二维码图片使用,那么最简单快捷的方法就是找合适的二维码生成…

单片机中的PWM(脉宽调制)的工作原理以及它在电机控制中的应用。

目录 工作原理 在电机控制中的应用 脉宽调制(PWM)是一种在单片机中常用的控制技术,它通过调整信号的脉冲宽度来控制输出信号的平均电平。PWM常用于模拟输出一个可调电平的数字信号,用于控制电机速度、亮度、电压等。 工作原理 …

【Maven笔记3】Maven基础入门案例

本篇通过一个最基础的入门案例,熟悉一下maven最基础的使用方法。 编写POM maven项目的核心是pom.xml文件,pom定义了项目的基本信息,用于描述项目如何构建,声明项目依赖等等。 这里我们新建一个maven-demo-hello项目,…

【用队列实现栈】【用栈实现队列】Leetcode 232 225

【用队列实现栈】【用栈实现队列】Leetcode 232 225 队列的相关操作栈的相关操作用队列实现栈用栈实现队列 ---------------🎈🎈题目链接 用队列实现栈🎈🎈------------------- ---------------🎈🎈题目链…

TCP 的三次握手和四次挥手

Java 面试题 TCP 三次握手 第一次握手:客户端向服务端发送SYN包。报文中标志位SYN1,序列号seqx(x为随机整数)。此时客户端进入了 SYN_SEND 同步已发送状态。 第二次握手:服务端回复客户端SYNACK包。报文中标志位SYN1&…

宿舍管理系统的设计与实现:基于Spring Boot、Java、Vue.js和MySQL的完整解决方案

✍✍计算机编程指导师 ⭐⭐个人介绍:自己非常喜欢研究技术问题!专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目:有源码或者技术上的问题欢迎在评论区一起讨论交流! ⚡⚡ Java实战 |…

文理导航期刊投稿方式

《文理导航》杂志系国家新闻出版总署批准,内蒙古自治区文旅厅主管,内蒙古自治区北方文化研究院主办的,面向大中专院校、中小学教育的专业性教育刊物,阅读对象是关心教育事业发展的大中专院校、职业教育、中小学教育的专家、教研员…

Flask框架小程序后端分离开发学习笔记《1》网络知识

Flask框架小程序后端分离开发学习笔记《1》网络知识 Flask是使用python的后端,由于小程序需要后端开发,遂学习一下后端开发。 一、网址组成介绍 协议:http,https (https是加密的http)主机:g.cn zhihu.com之类的网址…

嵌入式-Stm32-江科大基于寄存器点亮LED灯

文章目录 前言:一:搭建基于寄存器控制LED的工程二:用江科大的STM32板子实现基于寄存器点亮LED灯三:用非江科大stm32板子实现基于寄存器点亮LED灯道友:一星陨落,黯淡不了星空灿烂;一花凋零&#…